In [None]:
import os
import cv2
import numpy as np
from os.path import join
import tensorflow as tf

from modules.datasets import ImageTargetDataset, RandomConcatDataset, ConcatDataset
from modules.segm_transforms import train_transforms, test_transforms, ToTensorColor
from modules.metrics import FbSegm
from train.train import Model

import matplotlib.pyplot as plt
%matplotlib inline

%load_ext autoreload
%autoreload 2

In [None]:
# Size TensorFlow's CPU thread pools: inter-op threads run independent ops in
# parallel, intra-op threads parallelize work inside a single op.
tf.config.threading.set_inter_op_parallelism_threads(16)
tf.config.threading.set_intra_op_parallelism_threads(16)

In [None]:
# Batch sizes used by the training and validation datasets.
train_batch_size = val_batch_size = 32

# Spatial size handed to the transforms (`scale_size`) and used for the model input shape.
INPUT_SIZE = (224, 224)

# Augmentation settings forwarded to train_transforms (`augment_params`, `ang_range`).
# NOTE(review): AUG_PARAMS looks like three (low, high) pairs -- confirm the exact
# meaning of each pair in modules.segm_transforms.
AUG_PARAMS = [
    0.75, 1.25,
    0.75, 1.25,
    0.6, 1.4,
]
ANG_RANGE = 15

# Device string passed to the datasets and to the model.
device = 'GPU:0'

In [None]:
# Transform settings for training; the result is later unpacked with ** into the
# training dataset constructors.
train_trns = train_transforms(
    dataset='picsart',
    scale_size=INPUT_SIZE,
    ang_range=ANG_RANGE,
    augment_params=AUG_PARAMS,
    add_background=False,
    crop_scale=0.02,
)

# Transform settings for validation (no augmentation arguments are passed here).
val_trns = test_transforms(
    dataset='picsart',
    scale_size=INPUT_SIZE,
)

In [None]:
# Root folders of the "hq" dataset group. A 'train' and a 'val' subfolder are
# derived from each root further down in the notebook.
data_dirs_hq = [
    '/workdir/data/datasets/picsart/',
    '/workdir/data/datasets/supervisely_person/',
]

# Root folders of the "coco" dataset group.
# Every active entry ends with a comma: without it, enabling the optional entry
# below would make Python silently concatenate the two adjacent string literals
# into a single, non-existent path.
data_dirs_coco = [
    '/workdir/data/datasets/coco_person/',
    # '/workdir/data/datasets/cityscapes_person/',  # optional extra dataset
]


In [None]:
# Derive the 'train' / 'val' subfolder of every dataset root, per dataset group.
train_dirs_hq = [os.path.join(root, 'train') for root in data_dirs_hq]
val_dirs_hq = [os.path.join(root, 'val') for root in data_dirs_hq]

train_dirs_coco = [os.path.join(root, 'train') for root in data_dirs_coco]
val_dirs_coco = [os.path.join(root, 'val') for root in data_dirs_coco]

In [None]:
# "hq" training set: shuffled, using the training transform settings.
train_dataset_hq = ImageTargetDataset(
    train_dirs_hq,
    train_batch_size,
    shuffle=True,
    device=device,
    **train_trns,
    IMG_EXTN='.jpg',
    TRGT_EXTN='.png',
)

# "hq" validation set: fixed order, using the validation transform settings.
val_dataset_hq = ImageTargetDataset(
    val_dirs_hq,
    val_batch_size,
    shuffle=False,
    device=device,
    **val_trns,
    IMG_EXTN='.jpg',
    TRGT_EXTN='.png',
)

In [None]:
# "coco" training set: shuffled, using the training transform settings.
train_dataset_coco = ImageTargetDataset(
    train_dirs_coco,
    train_batch_size,
    shuffle=True,
    device=device,
    **train_trns,
    IMG_EXTN='.jpg',
    TRGT_EXTN='.png',
)

# "coco" validation set: fixed order, using the validation transform settings.
val_dataset_coco = ImageTargetDataset(
    val_dirs_coco,
    val_batch_size,
    shuffle=False,
    device=device,
    **val_trns,
    IMG_EXTN='.jpg',
    TRGT_EXTN='.png',
)

In [None]:
# Combine both training sets into the dataset that is actually trained on.
# NOTE(review): [0.95, 0.05] are presumably per-dataset sampling weights and
# `size` the reported dataset length -- confirm in modules.datasets.
train_dataset = RandomConcatDataset(
    [train_dataset_hq, train_dataset_coco],
    [0.95, 0.05],
    size=300,
)

In [None]:
# Report the lengths of the datasets used for training and validation.
print("Train dataset len: {}".format(len(train_dataset)))
print("Val dataset len: {}".format(len(val_dataset_hq)))

## Visualize datasets

In [None]:
def vis_dataset(dataset, num_samples=train_batch_size):
    """Plot samples from the first batch of `dataset` with the target overlaid.

    Args:
        dataset: Iterable of batches. Element 0 of a batch holds the images and
            element 1 the matching targets; both are indexed along their first
            axis.
        num_samples: Maximum number of samples to draw. It is clamped to the
            number of images in the batch, so asking for more samples than the
            batch contains no longer raises an indexing error.
    """
    for batch in dataset:
        img, target = batch[0], batch[1]
        # Never index past the end of a short batch.
        n_shown = min(num_samples, len(img))
        for i in range(n_shown):
            print("Image shape: {}, target shape: {}".format(img[i].shape, target[i].shape))
            plt.imshow(img[i])
            # Draw the target semi-transparently on top of the image.
            plt.imshow(np.squeeze(target[i]), alpha=0.4)
            plt.show()
        # Only the first batch is visualized.
        break

In [None]:
# Preview a handful of augmented training samples.
vis_dataset(train_dataset, num_samples=8)

In [None]:
# Preview a handful of validation samples.
vis_dataset(val_dataset_hq, num_samples=8)

## Build a model

In [None]:
# Model configuration
old_model_path = None  # None, or the path to the previous saved model
n_class = 1
model_name = 'mobilenet_small'

In [None]:
# Training configuration (all of it is forwarded to Model.prepare_train).

# Dataset lengths.
n_train = len(train_dataset)
n_val = len(val_dataset_hq)
batch_size = train_batch_size

# Objective, optimizer and metrics.
loss_name = 'fb_combined'
optimizer = 'Adam'
lr = 0.00005
metrics = [FbSegm(channel_axis=-1)]

# Schedule.
# NOTE(review): reduce_factor / epoches_limit / early_stoping presumably drive
# learning-rate reduction and early stopping -- confirm in train.train.
max_epoches = 1000
reduce_factor = 0.75
epoches_limit = 5
early_stoping = 100

# Directory passed to prepare_train as `save_directory`.
save_directory = '/workdir/data/experiments/mobilenetv3_test'

In [None]:
# Build the model wrapper; the input shape is (batch, height, width, 3 channels).
mobilenet_model = Model(
    device=device,
    model_name=model_name,
    n_class=n_class,
    input_shape=(train_batch_size, *INPUT_SIZE, 3),
    old_model_path=old_model_path,
    shape=INPUT_SIZE,
)

In [None]:
# Hand the datasets and every training setting defined above to the model.
mobilenet_model.prepare_train(
    train_loader=train_dataset,
    val_loader=val_dataset_hq,
    n_train=n_train,
    n_val=n_val,
    loss_name=loss_name,
    optimizer=optimizer,
    lr=lr,
    batch_size=batch_size,
    max_epoches=max_epoches,
    save_directory=save_directory,
    reduce_factor=reduce_factor,
    epoches_limit=epoches_limit,
    early_stoping=early_stoping,
    metrics=metrics,
)

In [None]:
# Run training; fit() takes no arguments here, so it presumably uses the
# settings given to prepare_train in the previous cell.
mobilenet_model.fit()

In [None]:
# Evaluate on the "hq" validation dataset; n_val is len(val_dataset_hq).
mobilenet_model.validate(val_dataset_hq, n_val)

## Test on some images

In [None]:
# File names in the test folder. os.listdir returns entries in arbitrary order,
# so sort them to make the order of the plots below reproducible between runs.
test_imgs = sorted(os.listdir('/workdir/data/test_examples/'))

In [None]:
# Run the model on every file of the test folder and overlay the predicted mask.
for img_name in test_imgs:
    img_path = os.path.join('/workdir/data/test_examples/', img_name)
    test_img = cv2.imread(img_path)
    # cv2.imread does not raise on failure: it returns None when the file cannot
    # be read or decoded (e.g. a sub-directory or a non-image file). Skip such
    # entries instead of crashing on the slicing below.
    if test_img is None:
        print("Skipping unreadable image:", img_path)
        continue
    # OpenCV loads images as BGR; reverse the channel axis to get RGB.
    test_img = test_img[:,:,::-1]
    test_img = cv2.resize(test_img, INPUT_SIZE)
    test_tensor = ToTensorColor()(test_img)
    # Add a leading batch axis before calling the model.
    test_tensor = tf.expand_dims(test_tensor, 0)
    out = mobilenet_model.predict(test_tensor)
    out_img = np.squeeze(out)
    print("Prediction shape:", out_img.shape)
    plt.imshow(test_img)
    # Threshold the prediction at 0.5 and draw it semi-transparently on top.
    plt.imshow((out_img>0.5)*255, alpha=0.4)
    plt.show()