[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/BouleJaune/picselliaT/blob/master/notebookOBDTC.ipynb)



In [0]:
import sys
# Make the local "slim" directory importable before the project modules are loaded
sys.path.append("slim")

# Object detection made easy
In this notebook, we will learn how to train an object detection model, starting from one of several pre-trained models, on the dataset you created on the picsell-IA platform.

## Imports

In [0]:
from picsellia import Client
import main
from util.infer import infer
import tensorflow as tf

## List of available models

- Mask R-CNN, an instance segmentation model. String value: 'mask_rcnn'
- Faster R-CNN, an accurate but slow object detection model. String value: 'faster_rcnn'
- SSD Inception, a fast but less accurate object detection model. String value: 'ssd_inception'

## Setup

Before anything else, we need to set up a few variables.

In [0]:
token = "f2a5daec-e0cc-4ea8-a5eb-10d04fd1e153"  # Token from the picsell-IA platform
model_picked = "faster_rcnn"  # Base model, chosen from the list of available models
model_name = "faster_rcnn"  # Name of the model you are about to train
annotation_type = "rectangle"  # Type of annotation used in your dataset

batch_size = 1
learning_rate = None  # You can leave this as None
nb_steps = 20000
mask_type = None  # Set this to 'PNG_MASKS' to train a mask segmentation model ('mask_rcnn')
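A quick sanity check can catch a typo in these variables before any data is downloaded. The sketch below uses a hypothetical helper, `check_setup`, which is not part of the picsellia client or the repository's `main` module; it only encodes the three model strings listed above and the rule that `mask_type` should be `'PNG_MASKS'` when training `'mask_rcnn'`.

```python
# Hypothetical helper: not part of picsellia or the repository's `main` module.
SUPPORTED_MODELS = {"mask_rcnn", "faster_rcnn", "ssd_inception"}


def check_setup(model_picked, mask_type, batch_size, nb_steps):
    """Return a list of problems found in the setup variables (empty if none)."""
    problems = []
    if model_picked not in SUPPORTED_MODELS:
        problems.append("unknown model {!r}; expected one of {}".format(
            model_picked, sorted(SUPPORTED_MODELS)))
    # A mask segmentation model is trained with PNG masks.
    if model_picked == "mask_rcnn" and mask_type != "PNG_MASKS":
        problems.append("mask_rcnn expects mask_type='PNG_MASKS'")
    if batch_size < 1:
        problems.append("batch_size must be at least 1")
    if nb_steps < 1:
        problems.append("nb_steps must be at least 1")
    return problems


print(check_setup("faster_rcnn", None, 1, 20000))  # -> []
```

Calling it with the cell's own variables, `check_setup(model_picked, mask_type, batch_size, nb_steps)`, should return an empty list before you move on.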

## Client initialisation and data pre-processing

We connect to the platform to create a new model and download the images and annotations.
From these, we generate the label map, split the data into training and evaluation sets, and create the TFRecord files that the model takes as input.
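The notebook does not show how the platform splits the data. As a point of comparison only, here is a plain random split with a fixed seed, so the same images land in the same subset on every run. The function name `split_train_eval` and the 80/20 default ratio are assumptions for illustration, not the platform's strategy.

```python
import random


def split_train_eval(image_names, eval_fraction=0.2, seed=42):
    """Hypothetical helper: deterministically shuffle image names and split them into train/eval lists."""
    if not 0.0 < eval_fraction < 1.0:
        raise ValueError("eval_fraction must be between 0 and 1")
    names = sorted(image_names)          # sort first so the input order does not matter
    random.Random(seed).shuffle(names)   # local RNG: no side effects on the global random state
    n_eval = max(1, int(round(len(names) * eval_fraction)))
    return names[n_eval:], names[:n_eval]


train, eval_ = split_train_eval(["img_{}.jpg".format(i) for i in range(10)])
print(len(train), len(eval_))  # -> 8 2
```

Sorting before shuffling makes the result depend only on the set of names and the seed, which keeps evaluation images stable across reruns.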

In [0]:
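path_models = "models/"
model_selected = path_models + model_picked + "/"  # Directory of the pre-trained base model

# Connect to the platform and initialise the model
clt = Client(token=token, host="https://backstage.picsellia.com/sdk/")
clt.init_model(model_name)

clt.dl_annotations()     # Download the annotations
clt.generate_labelmap()  # Generate the label map
clt.local_pic_save()     # Save the images locally

# Convert the images and annotations into TFRecord files
main.create_record_files(label_path=clt.label_path, record_dir=clt.record_dir,
                         tfExample_generator=clt.tf_vars_generator, annotation_type=annotation_type)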
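`clt.generate_labelmap()` produces the label map for us. For reference, the TensorFlow Object Detection API reads label maps in protobuf text format, with one `item` per class and ids starting at 1 (0 is reserved for the background). A minimal sketch that builds such a file's content from a list of class names; the helper name is hypothetical and it assumes the names contain no quote characters.

```python
def labelmap_pbtxt(class_names):
    """Hypothetical helper: build a label map in protobuf text format (ids start at 1, 0 = background)."""
    items = []
    for idx, name in enumerate(class_names, start=1):
        items.append("item {{\n  id: {}\n  name: '{}'\n}}\n".format(idx, name))
    return "\n".join(items)


text = labelmap_pbtxt(["car", "person"])
print(text)
```

The order of `class_names` decides the ids, so the same list must be used when creating the TFRecord files and when reading predictions back.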

Here we edit the base protobuf configuration of our model with our parameters.
We also check whether this is the first training of this model: if it is not, training starts from the checkpoint of the previous training instead of the pre-trained base model.
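The path arithmetic in the cell below assumes that `clt.base_dir` looks like `<root>/<model_name>/<training_id>/`, with a trailing slash; this layout is inferred from the code, not documented. Isolated as a pure function (the name `previous_checkpoint_dir` is hypothetical), the logic is easier to check:

```python
def previous_checkpoint_dir(base_dir, training_id):
    """Hypothetical helper: replace the training id at the end of base_dir with the
    previous one and point at its checkpoint/ folder.

    Assumes a layout such as "models/my_model/3/" (trailing slash included).
    """
    if training_id == 0:
        raise ValueError("the first training has no previous checkpoint")
    parts = base_dir.rstrip("/").split("/")
    parts[-1] = str(training_id - 1)
    return "/".join(parts) + "/checkpoint/"


print(previous_checkpoint_dir("models/faster_rcnn/3/", 3))
# -> models/faster_rcnn/2/checkpoint/
```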

In [0]:
if clt.training_id != 0:
    # Not the first training: start from the previous training's checkpoint.
    # base_dir is expected to end with "<training_id>/"; swap in the previous id.
    previous_path = clt.base_dir.split("/")[:-1]
    previous_path[-1] = str(clt.training_id - 1)
    model_selected = "/".join(previous_path) + "/checkpoint/"

main.edit_config(model_selected=model_selected,
                 config_output_dir=clt.config_dir,
                 record_dir=clt.record_dir,
                 label_map_path=clt.label_path,
                 masks=mask_type,
                 num_steps=nb_steps,
                 batch_size=batch_size,
                 learning_rate=learning_rate,
                 training_id=clt.training_id)
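`main.edit_config` is not shown in this notebook. As a rough illustration of what editing a pipeline configuration involves, the hypothetical sketch below rewrites scalar fields such as `batch_size` and `num_steps` in a protobuf text config with regular expressions. A real implementation would more likely parse the file into the generated protobuf classes with `google.protobuf.text_format`; the sample config here is illustrative only.

```python
import re


def set_scalar_field(config_text, field, value):
    """Hypothetical helper: replace every `field: <value>` line in a protobuf text config.

    Only handles scalar fields written one per line. Strings are quoted,
    booleans are lowercased, numbers are written as-is.
    """
    if isinstance(value, bool):
        rendered = "true" if value else "false"
    elif isinstance(value, str):
        rendered = '"{}"'.format(value)
    else:
        rendered = str(value)
    pattern = re.compile(r"^([ \t]*{}:[ \t]*).*$".format(re.escape(field)), re.MULTILINE)
    new_text, count = pattern.subn(lambda m: m.group(1) + rendered, config_text)
    if count == 0:
        raise KeyError("field {!r} not found in config".format(field))
    return new_text


sample = """train_config {
  batch_size: 24
  num_steps: 200000
  fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
}
"""
edited = set_scalar_field(sample, "batch_size", 1)
edited = set_scalar_field(edited, "num_steps", 20000)
print(edited)
```

Using a function as the replacement avoids surprises when a value (for example a Windows path) contains backslashes.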

## Training

We can now launch the training, and it's as easy as telling the function where the configuration file is and where the checkpoints and records should be saved!

In [0]:
main.legacy_train(ckpt_dir=clt.checkpoint_dir,
                  conf_dir=clt.config_dir)
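TensorFlow 1.x checkpoints are usually named after the global step, e.g. `model.ckpt-20000.index` next to `model.ckpt-20000.data-00000-of-00001`. TensorFlow already provides `tf.train.latest_checkpoint(ckpt_dir)`, which reads the `checkpoint` state file; the pure-Python sketch below instead scans file names, which can help if that state file is missing. The naming pattern is an assumption based on TensorFlow's default, and the helper name is hypothetical.

```python
import os
import re

# Matches "<prefix>-<step>.index", e.g. "model.ckpt-20000.index"
_CKPT_INDEX = re.compile(r"^(?P<prefix>.+)-(?P<step>\d+)\.index$")


def latest_checkpoint_prefix(filenames):
    """Hypothetical helper: return the checkpoint prefix with the highest step, or None."""
    best_step, best_prefix = -1, None
    for name in filenames:
        base = os.path.basename(name)
        match = _CKPT_INDEX.match(base)
        if match and int(match.group("step")) > best_step:
            best_step = int(match.group("step"))
            best_prefix = base[:-len(".index")]
    return best_prefix


files = ["checkpoint", "model.ckpt-500.index", "model.ckpt-20000.index",
         "model.ckpt-20000.data-00000-of-00001", "graph.pbtxt"]
print(latest_checkpoint_prefix(files))  # -> model.ckpt-20000
```

In the notebook, the input would be `os.listdir(clt.checkpoint_dir)`.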

Now that training has finished, we send the logs to our dashboard so we can watch the loss decrease.

In [0]:
dict_log = main.tfevents_to_dict(path=clt.checkpoint_dir)
clt.send_logs(dict_log)
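The exact structure returned by `main.tfevents_to_dict` is not documented here. Assuming the event files boil down to `(tag, step, value)` scalar records (in TensorFlow 1.x they can be read with `tf.train.summary_iterator`), grouping them per tag into step-sorted series is one plausible shape for such a dictionary. The helper name, the output layout and the example tags are all hypothetical.

```python
from collections import defaultdict


def scalars_to_dict(records):
    """Hypothetical helper: group (tag, step, value) tuples into
    {tag: {"steps": [...], "values": [...]}}, sorted by step."""
    by_tag = defaultdict(list)
    for tag, step, value in records:
        by_tag[tag].append((step, float(value)))
    logs = {}
    for tag, points in by_tag.items():
        points.sort()  # event files are not guaranteed to be in step order
        logs[tag] = {"steps": [s for s, _ in points], "values": [v for _, v in points]}
    return logs


records = [("total_loss", 200, 1.8), ("total_loss", 100, 2.5),
           ("learning_rate", 100, 0.0003)]
print(scalars_to_dict(records)["total_loss"])
# -> {'steps': [100, 200], 'values': [2.5, 1.8]}
```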

## Exporting and inferring

The model is trained, but we still need to export it as a TensorFlow graph (protobuf) before we can use it.

In [0]:
tf.reset_default_graph()  # Clear the default graph before building the export graph
main.export_infer_graph(ckpt_dir=clt.checkpoint_dir,
                        exported_model_dir=clt.exported_model,
                        pipeline_config_path=clt.config_dir,
                        write_inference_graph=True, input_type="image_tensor", input_shape=None)
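With `input_type="image_tensor"`, a graph exported by the TensorFlow Object Detection API takes a batch of images as a `uint8` tensor of shape `[batch, height, width, 3]`. To feed it a single RGB image, you need to add the batch dimension. A minimal NumPy sketch; the helper name is hypothetical.

```python
import numpy as np


def to_image_tensor(image):
    """Hypothetical helper: turn one HxWx3 RGB image into the 1xHxWx3 uint8 batch
    expected by an 'image_tensor' input."""
    array = np.asarray(image)
    if array.ndim != 3 or array.shape[2] != 3:
        raise ValueError("expected an HxWx3 RGB image, got shape {}".format(array.shape))
    if array.dtype != np.uint8:
        # Assumes pixel values are already in [0, 255]; clip before casting to be safe.
        array = np.clip(array, 0, 255).astype(np.uint8)
    return np.expand_dims(array, axis=0)


batch = to_image_tensor(np.zeros((480, 640, 3), dtype=np.float32))
print(batch.shape, batch.dtype)  # -> (1, 480, 640, 3) uint8
```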

Now we use the exported model to run inference on our evaluation set, then send the results to the dashboard.
Bounding boxes with a confidence score below min_score_thresh are discarded; set this threshold to the value that suits you best.

In [0]:
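min_score_thresh = 0.6  # Minimum confidence score for a bounding box to be kept
infer(clt.eval_list, exported_model_dir=clt.exported_model,
      label_map_path=clt.label_path, results_dir=clt.results_dir,
      min_score_thresh=min_score_thresh)
clt.send_examples()  # Send the inference results to the dashboard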
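`infer` applies `min_score_thresh` for us. For reference, a graph exported by the TensorFlow Object Detection API returns `detection_boxes` (normalized `[ymin, xmin, ymax, xmax]`), `detection_scores` and `detection_classes`, so keeping only confident detections boils down to a mask on the scores. The NumPy sketch below does this for a single image and also converts the boxes to pixel coordinates; it is a hypothetical helper, not the repository's `infer` implementation.

```python
import numpy as np


def filter_detections(boxes, scores, classes, image_height, image_width, min_score_thresh=0.6):
    """Hypothetical helper: keep detections scoring at or above min_score_thresh.

    boxes: (N, 4) normalized [ymin, xmin, ymax, xmax]; scores, classes: (N,).
    Returns pixel boxes as [xmin, ymin, xmax, ymax], plus the kept scores and classes.
    """
    boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
    scores = np.asarray(scores, dtype=np.float32)
    classes = np.asarray(classes).astype(np.int64)  # classes are returned as floats
    keep = scores >= min_score_thresh
    ymin, xmin, ymax, xmax = boxes[keep].T
    pixel_boxes = np.stack([xmin * image_width, ymin * image_height,
                            xmax * image_width, ymax * image_height], axis=1)
    return pixel_boxes, scores[keep], classes[keep]


boxes = [[0.1, 0.2, 0.5, 0.6], [0.0, 0.0, 1.0, 1.0]]
pixel_boxes, kept_scores, kept_classes = filter_detections(
    boxes, scores=[0.9, 0.3], classes=[1.0, 2.0], image_height=100, image_width=200)
print(pixel_boxes, kept_scores, kept_classes)
```

Raising `min_score_thresh` trades recall for precision: fewer false positives are kept, but more true objects are dropped.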