[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Picsell-ia/training/blob/master/Object_Detection_TF1_easy.ipynb)

In this notebook we define a simple wrapper function that runs, in a single call, every step described in the "Train a custom object detection model with Tensorflow 1" HOW TO.

In [None]:
import os
import sys
from getpass import getpass

sys.path.append("slim")
from picsellia import Client
import picsell_utils
import tensorflow as tf

def wrapper_function(api_token, project_token, model_name, batch_size, nb_steps,
                     learning_rate=None, annotation_type="rectangle"):
    """Run every step of the TF1 object-detection HOW TO for one Picsellia project.

    The pipeline, in order:
      1. check out the project and the network on Picsellia;
      2. run ``train_test_split`` and download the pictures;
      3. create the record files and edit the training config;
      4. train, collect the training logs, and evaluate the checkpoint;
      5. export the inference graph and run ``infer`` on ``clt.record_dir``;
      6. send the training logs and metrics back with ``clt.send_everything``.

    Args:
        api_token: Picsellia API token, passed to ``Client``.
        project_token: token of the project to check out.
        model_name: name of the network to check out.
        batch_size: training batch size, forwarded to ``edit_config``.
        nb_steps: number of training steps, forwarded to ``edit_config`` as
            ``num_steps``.
        learning_rate: forwarded to ``edit_config``; ``None`` presumably keeps
            the model's default learning rate — TODO confirm in picsell_utils.
        annotation_type: annotation format used both for the record files and
            for the config (default ``"rectangle"``).

    Returns:
        None. Results are written to the client's directories and sent to
        Picsellia.

    Raises:
        ValueError: if ``batch_size`` or ``nb_steps`` is smaller than 1. This is
            checked before any Picsellia call, so an obviously invalid run fails
            immediately instead of after the picture download.
    """
    # Fail fast: checkout, download and record creation below are slow, and a
    # non-positive batch size or step count can only fail (or misbehave) later.
    for arg_name, value in (("batch_size", batch_size), ("nb_steps", nb_steps)):
        if value < 1:
            raise ValueError("{} must be >= 1, got {!r}".format(arg_name, value))

    # 1-2. Fetch the project, the network and the pictures.
    clt = Client(api_token)
    clt.checkout_project(project_token=project_token)
    clt.checkout_network(model_name)
    clt.train_test_split()
    clt.dl_pictures()

    # 3. Build the training inputs and the training config.
    picsell_utils.create_record_files(label_path=clt.label_path,
                                      record_dir=clt.record_dir,
                                      tfExample_generator=clt.tf_vars_generator,
                                      annotation_type=annotation_type)
    picsell_utils.edit_config(model_selected=clt.model_selected,
                              config_output_dir=clt.config_dir,
                              record_dir=clt.record_dir,
                              label_map_path=clt.label_path,
                              num_steps=nb_steps,
                              batch_size=batch_size,
                              learning_rate=learning_rate,
                              annotation_type=annotation_type,
                              eval_number=len(clt.eval_list))

    # 4. Train, then gather logs and evaluation metrics from the checkpoints.
    picsell_utils.train(ckpt_dir=clt.checkpoint_dir,
                        conf_dir=clt.config_dir)
    dict_log = picsell_utils.tfevents_to_dict(path=clt.checkpoint_dir)
    metrics = picsell_utils.evaluate(clt.metrics_dir, clt.config_dir,
                                     clt.checkpoint_dir)

    # 5. Export the trained graph and run inference with it.
    picsell_utils.export_infer_graph(ckpt_dir=clt.checkpoint_dir,
                                     exported_model_dir=clt.exported_model_dir,
                                     pipeline_config_path=clt.config_dir)
    picsell_utils.infer(clt.record_dir,
                        exported_model_dir=clt.exported_model_dir,
                        label_map_path=clt.label_path,
                        results_dir=clt.results_dir)

    # 6. Report everything back to Picsellia.
    clt.send_everything(dict_log, metrics)

In [None]:
# Credentials are read from the environment, or prompted for, instead of being
# typed into the cell, so they never end up in the saved notebook or its outputs.
# Set PICSELLIA_API_TOKEN / PICSELLIA_PROJECT_TOKEN to skip the prompts.
api_token = os.environ.get("PICSELLIA_API_TOKEN") or getpass("Picsellia API token: ")
project_token = (os.environ.get("PICSELLIA_PROJECT_TOKEN")
                 or getpass("Picsellia project token: "))
model_name = "your_model_name"

# Training hyper-parameters: tune these here rather than in the call below.
BATCH_SIZE = 10
NB_STEPS = 1000

wrapper_function(api_token, project_token, model_name,
                 batch_size=BATCH_SIZE, nb_steps=NB_STEPS)