In [None]:
%matplotlib inline
from utils import *
from sklearn.model_selection import train_test_split
from deeplabv3p import preprocess_input
# Root of the VOC2012 dataset. It can be overridden with the VOC2012_PATH
# environment variable so the notebook is not tied to one machine's directory
# layout; the default is the original hard-coded location.
# os.path.join(..., '') guarantees a trailing separator, matching the default's
# trailing '/' (presumably SegModel builds sub-paths by string concatenation;
# TODO confirm).
import os

PATH = os.path.join(
    os.environ.get('VOC2012_PATH', '/workspace/datasets/OpenSourceDatasets/VOCdevkit/VOC2012/'),
    '',
)
image_size = (512, 512)  # input size passed to SegModel

In [None]:
# Build the segmentation model wrapper and configure training length.
# NOTE(review): `use_coords=True` presumably adds CoordConv channels (the
# evaluation cell reports "added coordconv channels" when `.coords` is set);
# confirm in SegModel.
deeplab_seg = SegModel(PATH, image_size, use_coords=True)
deeplab_seg.set_batch_size(4)
deeplab_seg.set_num_epochs(10)

# SGD with momentum (Adam() was the alternative tried here).
# `lr=` is the keyword older Keras releases expect; newer Keras renames it
# `learning_rate`, so update it if the Keras version is upgraded.
opt = SGD(lr=0.01, momentum=0.8)
deeplab_model = deeplab_seg.create_seg_model(
    opt,
    net='deeplabv3',
    load_weights=False,
    w=30,  # NOTE(review): meaning of `w` is not visible here; confirm in create_seg_model
    multi_gpu=False,
)

## Create Generators

In [None]:
# Fraction of the data reserved for validation. The 'training' and
# 'validation' subsets must be drawn with the same fraction, so it is defined
# once here.
VALIDATION_SPLIT = .2

# Geometric augmentations move pixels, so images and masks must receive exactly
# the same settings or the masks drift out of alignment with their images.
# A single shared definition keeps the two configs in lockstep.
# (Alignment also assumes create_generators seeds the image and mask
# generators identically; TODO confirm.)
shared_geometric_aug_args = dict(horizontal_flip=True,
                                 rotation_range=45,
                                 width_shift_range=0.01,
                                 height_shift_range=0.01,
                                 zoom_range=0.2)

data_trn_gen_args_mask = dict(shared_geometric_aug_args,
                              preprocessing_function=preprocess_mask,
                              validation_split=VALIDATION_SPLIT)

data_trn_gen_args_image = dict(shared_geometric_aug_args,
                               preprocessing_function=preprocess_input,
                               validation_split=VALIDATION_SPLIT,
                               # Intensity-only jitter: it does not move pixels,
                               # so it is safe to apply to images alone.
                               channel_shift_range=.2)

# Validation data is not augmented; only the preprocessing and split remain.
data_val_gen_args_image = dict(preprocessing_function=preprocess_input,
                               validation_split=VALIDATION_SPLIT)
# NOTE(review): unlike the training mask config, this applies no
# preprocess_mask. Confirm that validation masks do not need the same mapping.
data_val_gen_args_mask = dict(validation_split=VALIDATION_SPLIT)

train_generator = deeplab_seg.create_generators(data_trn_gen_args_image, data_trn_gen_args_mask, subset='training')
valid_generator = deeplab_seg.create_generators(data_val_gen_args_image, data_val_gen_args_mask, subset='validation')

In [None]:
# Preview augmented image/mask pairs with the image preprocessing disabled
# (presumably so the preview is not normalized by preprocess_input).
# Build a modified copy instead of assigning into data_trn_gen_args_image:
# mutating the training config in place makes this cell non-idempotent and
# leaks the change to anything that later reads that dict.
preview_gen_args_image = dict(data_trn_gen_args_image, preprocessing_function=None)

show_aug_data(train_generator, preview_gen_args_image, data_trn_gen_args_mask)

## Train Model

In [None]:
# Train the model on the augmented training generator, with the held-out
# split as validation data.
# NOTE(review): `h` is presumably the Keras History object returned by the fit
# call; confirm in SegModel.train_generator.
h = deeplab_seg.train_generator(
    deeplab_model,
    train_generator=train_generator,
    valid_generator=valid_generator,
    tf_board=True,  # presumably enables TensorBoard logging; confirm
    mp=True,        # presumably enables multiprocessing data loading; confirm
)

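## Evaluation

Note: the evaluation cell expects `X_test`, `y_test`, `X_valid` and `y_valid` to be loaded already; no cell above creates them.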

In [None]:
# Pixel-wise evaluation of the trained model on the test and validation sets.
classes = list(get_VOC2012_classes().values())

# NOTE(review): X_test / y_test / X_valid / y_valid are not created in any
# visible cell, so on a fresh kernel this cell relies on hidden state (or on
# names exported by `from utils import *`). Fail up front with an actionable
# message instead of a bare NameError partway through the cell.
_missing_eval_data = [name for name in ('X_test', 'y_test', 'X_valid', 'y_valid')
                      if name not in globals()]
if _missing_eval_data:
    raise NameError('Evaluation data not defined: ' + ', '.join(_missing_eval_data)
                    + '. Load the test/validation arrays before running this cell.')

# Predict each split, then score the predictions against its ground truth.
y_preds = deeplab_model.predict(X_test, batch_size=deeplab_seg.batch_size, verbose=1)
df_test, conf_test, mean_acc_test = evaluate_model(y_preds, y_test, data_from='test')

y_preds = deeplab_model.predict(X_valid, batch_size=deeplab_seg.batch_size, verbose=1)
df_valid, conf_valid, mean_acc_valid = evaluate_model(y_preds, y_valid, data_from='valid')

# Show the validation IoU next to the test metrics in a single table.
df_test['IoU valid'] = df_valid['IoU valid']

model_name = ('deeplabv3+ with added coordconv channels' if deeplab_seg.coords
              else 'deeplabv3+')

# Side-by-side normalized confusion matrices: test on the left, validation on the right.
plt.figure(figsize=(17, 7))
for panel, (conf, split_name) in enumerate([(conf_test, 'Test'), (conf_valid, 'Validation')], start=1):
    plt.subplot(1, 2, panel)
    plot_confusion_matrix(conf, classes, normalize=True,
                          title='Pixel-wise confusion matrix of ' + split_name + ' set. Model: ' + model_name)

print('test mean acc: ', mean_acc_test)
print('validation mean acc: ', mean_acc_valid)
df_test