In [None]:
!pip install picsellia

In [None]:
import tensorflow as tf
from picsellia import Client

In [None]:
import os

# Tokens are read from the environment when set, so they need not be
# hard-coded in the notebook; the placeholders below are used otherwise.
api_token = os.environ.get("PICSELLIA_API_TOKEN", "your_token")  # API token from the Picsellia platform
project_token = os.environ.get("PICSELLIA_PROJECT_TOKEN", "your_project_token")  # project token, found in Project -> Settings

In [None]:
# Name to give the soon-to-be-trained model; it is passed to clt.init_model.
model_name = "your_model_name"

In [None]:
# Connect to the Picsellia platform and prepare the local workspace.
# The calls run in a fixed order and presumably each relies on state set by
# the previous ones (project -> model -> annotations -> label map -> images).
# NOTE(review): confirm the exact semantics against the picsellia Client docs.
# Later cells read clt.label_map, clt.record_dir, clt.tf_vars_generator,
# clt.checkpoint_dir and clt.exported_model from this client.
clt = Client(api_token=api_token)
clt.init_project(project_token=project_token)
clt.init_model(model_name)
clt.dl_annotations()  # presumably downloads the project's annotations — TODO confirm
clt.generate_labelmap()  # presumably populates clt.label_map used when writing TFRecords
clt.local_pic_save()  # presumably saves the project images locally — TODO confirm

In [None]:
def _int64_feature(value):
    """Return a ``tf.train.Feature`` holding ``value`` as a one-element Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

def _bytes_feature(value):
    """Return a ``tf.train.Feature`` holding ``value`` as a one-element BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)

In [None]:
 def create_record_files(label_map, record_dir, tfExample_generator, annotation_type):
    ensembles = ["train", "eval"]    
    for ensemble in ensembles:
        output_path = record_dir+ensemble+".record"
        writer = tf.io.TFRecordWriter(output_path)
        for variables in tfExample_generator(label_map, ensemble=ensemble, annotation_type=annotation_type):
            (width, height, filename, encoded_jpg, image_format, 
                classes_text, classes) = variables

            tf_example = tf.train.Example(features=tf.train.Features(feature={
                'image/encoded': _bytes_feature(encoded_jpg),
                'image/object/class/label': _int64_feature(classes[0]-1)
                }))
            writer.write(tf_example.SerializeToString())
    
        writer.close()
        print('Successfully created the TFRecords: {}'.format(output_path))

annotation_type = "classification"                
create_record_files(label_map=clt.label_map, record_dir=clt.record_dir, 
                    tfExample_generator=clt.tf_vars_generator)

In [None]:
# Schema of the records produced by create_record_files: the encoded image
# bytes plus an int64 class label (0 when the label feature is missing).
feature_description = dict([
    ('image/encoded', tf.io.FixedLenFeature([], tf.string)),
    ('image/object/class/label', tf.io.FixedLenFeature([], tf.int64, default_value=0)),
])


def _parse_function(example_proto):
    """Decode one serialized ``tf.train.Example`` into a dict of tensors keyed as in ``feature_description``."""
    parsed = tf.io.parse_single_example(
        serialized=example_proto,
        features=feature_description,
    )
    return parsed

In [None]:
def load_split(split):
    """Return the parsed TFRecord dataset for one split ("train" or "eval").

    The path is ``clt.record_dir + split + ".record"``; record_dir is
    concatenated directly, matching how the record files are written.
    """
    return tf.data.TFRecordDataset(clt.record_dir + split + ".record").map(_parse_function)


train_dataset = load_split("train")
eval_dataset = load_split("eval")

In [None]:
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input


def map_img_label(example_proto, image_size=(224, 224), num_classes=2):
    """Turn a parsed example dict into an ``(image, one_hot_label)`` pair.

    Args:
        example_proto: dict produced by ``_parse_function`` with keys
            ``"image/encoded"`` (JPEG bytes) and ``"image/object/class/label"``.
        image_size: target ``(height, width)``; the defaults match the
            MobileNetV2 input shape used for the model.
        num_classes: depth of the one-hot label; must equal the number of
            units in the model's final Dense layer.

    Returns:
        ``(img, label)``: the resized image after MobileNetV2 preprocessing,
        and the label one-hot encoded with ``num_classes`` entries.
    """
    img = tf.io.decode_jpeg(example_proto["image/encoded"], channels=3)
    img = tf.image.resize(img, image_size)
    img = preprocess_input(img)
    label = tf.one_hot(example_proto["image/object/class/label"], depth=num_classes)
    return img, label


train_set = train_dataset.map(map_img_label)
eval_set = eval_dataset.map(map_img_label)

In [None]:
# Number of examples per batch for both training and evaluation.
BATCH_SIZE = 16
# Size of the shuffle buffer; a small buffer only mixes nearby records.
SHUFFLE_BUFFER_SIZE = 50

# NOTE: these reassign train_set / eval_set, so re-running this cell without
# re-running the mapping cell batches already-batched data.
eval_set = eval_set.batch(batch_size=BATCH_SIZE)
train_set = train_set.shuffle(buffer_size=SHUFFLE_BUFFER_SIZE).batch(batch_size=BATCH_SIZE)

In [None]:
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import AveragePooling2D
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

In [None]:
# ImageNet-pretrained MobileNetV2 backbone without its classification top,
# fed 224x224 RGB images.
baseModel = MobileNetV2(
    weights="imagenet",
    include_top=False,
    input_tensor=Input(shape=(224, 224, 3)),
)

# New classification head, applied in order on top of the backbone output.
# The 7x7 pooling window assumes the backbone's final feature map is 7x7 for
# a 224x224 input.
head_layers = [
    AveragePooling2D(pool_size=(7, 7)),
    Flatten(name="flatten"),
    Dense(128, activation="relu"),
    Dropout(0.5),
    Dense(2, activation="softmax"),
]
headModel = baseModel.output
for head_layer in head_layers:
    headModel = head_layer(headModel)

model = Model(inputs=baseModel.input, outputs=headModel)

# Freeze the pretrained backbone so that only the head is trained.
for layer in baseModel.layers:
    layer.trainable = False

In [None]:
# Print the layer-by-layer architecture and parameter counts; the frozen
# MobileNetV2 layers should be reported as non-trainable parameters.
model.summary()

In [None]:
# Training configuration. These are defined here so that the notebook runs on
# a fresh kernel: the model must be compiled before `fit`, and `opt`,
# `EPOCHS`, and `tensorboard_callback` must exist.
EPOCHS = 10  # TODO: tune for the dataset
LEARNING_RATE = 1e-4  # small step size; only the new head layers are trainable
LOG_DIR = "logs"  # TensorBoard log directory, relative to the working directory

opt = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
# Labels are one-hot encoded (tf.one_hot) and the head ends in a softmax,
# so categorical cross-entropy is the matching loss.
model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=["accuracy"])
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR)

History = model.fit(train_set,
    validation_data=eval_set,
    callbacks=[tensorboard_callback],
    epochs=EPOCHS)

In [None]:
# Upload per-epoch metrics: each metric name maps to the list of epoch
# indices and the corresponding values recorded by model.fit.
logs = {name: {"step": History.epoch, "value": values}
        for name, values in History.history.items()}
clt.send_logs(logs)

# Checkpoint the model together with the optimizer it was compiled with.
# model.optimizer is used rather than a separately defined global, so this
# cell does not depend on a variable from another cell.
checkpoint = tf.train.Checkpoint(optimizer=model.optimizer, model=model)
checkpoint.save(clt.checkpoint_dir + "model.ckpt")
clt.send_checkpoints()

# Export the trained model and upload it to the platform.
model.save(clt.exported_model)
clt.send_model()