In [None]:
import datetime
import os
import random
import sys

In [None]:
# Dataset layout: every path below is resolved relative to data_root.
data_root = "../dataset"
images_root, masks_root = (
    os.path.join(data_root, folder) for folder in ("images_all", "masks_all")
)

# CSV that is handed to the train/validation dataset builder further down.
train_data_dist = os.path.join(data_root, "train_data.csv")

In [None]:
# Directory that is appended to sys.path in the next cell.
# Presumably it holds the project modules imported below (constants,
# train_utils, metrics) — confirm against the repository layout.
scripts_path = "../scripts"

In [None]:
# Make the directory in scripts_path importable. The membership check keeps
# this cell idempotent: re-running it no longer stacks duplicate entries on
# sys.path.
if scripts_path not in sys.path:
    sys.path.append(scripts_path)

In [None]:
import segmentation_models_pytorch as smp
import torch
import constants as const
import numpy as np

from train_utils import *
from metrics import get_iou_metric
from pytorch_toolbelt import losses as L

## Pre-load transformation functions

## Data loading and splitting

In [None]:
# Build the training and validation datasets from the CSV and the image/mask
# directories configured above. The helper is not defined in this notebook
# (presumably it arrives via `from train_utils import *` — confirm).
train_dataset, val_dataset = get_train_val_dataset_segmentation(
    train_data_dist,
    images_root,
    masks_root,
)

## Modelling

In [None]:
# FPN segmentation network. Encoder, encoder weights, class list and output
# activation are all read from the project's constants module.
model = smp.FPN(
    encoder_name=const.ENCODER,
    encoder_weights=const.ENCODER_WEIGHTS,
    classes=len(const.CLASSES),
    activation=const.ACTIVATION,
)

In [None]:
# Background reading for the loss and the metric used below:
#   Dice/F1 score     - https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
#   IoU/Jaccard score - https://en.wikipedia.org/wiki/Jaccard_index

# Training objective: Dice loss.
loss = smp.utils.losses.DiceLoss()

# Reported metric: IoU computed with a 0.5 threshold.
metrics = [smp.utils.metrics.IoU(threshold=0.5)]

# Adam with a single parameter group covering every model parameter. The
# training loop later adjusts the learning rate through param_groups[0].
optimizer = torch.optim.Adam(
    [{"params": model.parameters(), "lr": 0.0001}]
)

## Data pipelines

In [None]:
# Loaders for both splits; batch size and worker count come from the
# constants module. Only the validation loader passes shuffle=False; the
# training loader relies on whatever default get_data_loader uses.
train_loader = get_data_loader(
    train_dataset,
    batch_size=const.batch_size_train,
    num_workers=const.num_workers_train,
)
valid_loader = get_data_loader(
    val_dataset,
    batch_size=const.batch_size_val,
    shuffle=False,
    num_workers=const.num_workers_val,
)

In [None]:
# One runner per phase. Both share the same model, loss, metrics and device;
# only the training runner is given the optimizer.
train_epoch = smp.utils.train.TrainEpoch(
    model, loss=loss, metrics=metrics, optimizer=optimizer, device=const.DEVICE, verbose=True
)

valid_epoch = smp.utils.train.ValidEpoch(
    model, loss=loss, metrics=metrics, device=const.DEVICE, verbose=True
)

## Training

In [None]:
def generate_model_name(architecture: str = "unet", encoder: str = "resnet_34"):
    """Build a unique, filesystem-safe base name for a saved model.

    Args:
        architecture: Name of the segmentation architecture to embed in the name.
        encoder: Name of the encoder/backbone to embed in the name.

    Returns:
        A string of the form
        ``"<architecture>_backbone_<encoder>_<YYYY-MM-DD_HH-MM-SS-ffffff>"``.
    """
    # str(datetime.now()) contains spaces and colons, which are invalid in
    # Windows file names and awkward in shells. Format the timestamp
    # explicitly, keeping microseconds so successive names stay distinct.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")
    return "{}_backbone_{}_{}".format(architecture, encoder, timestamp)

In [None]:
max_score = 0
epochs = 20
# The learning rate is lowered once, part-way through training. The previous
# trigger (`i == 25`) could never fire because the loop only covers epochs
# 0..19, so the epoch is now derived from `epochs` (5/8 of the run, i.e.
# epoch 12 here). Tune as needed.
lr_decay_epoch = (epochs * 5) // 8
models_dir = '../models'

# torch.save does not create missing directories, so make sure it exists.
os.makedirs(models_dir, exist_ok=True)

for i in range(epochs):

    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    # Checkpoint every time the validation IoU beats the best seen so far.
    # Each checkpoint gets its own timestamped file; older ones are kept.
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        # The model built earlier in this notebook is an smp.FPN, so label the
        # file as such and do not fall back to the helper's "unet" default.
        model_name = generate_model_name(architecture="fpn", encoder=const.ENCODER)
        torch.save(model, os.path.join(models_dir, '{}.pth'.format(model_name)))
        print('Model saved!')

    if i == lr_decay_epoch:
        # The optimizer has a single parameter group covering the whole
        # model, so this lowers the learning rate for every parameter.
        optimizer.param_groups[0]['lr'] = 1e-5
        print('Decrease learning rate to 1e-5!')

In [None]:
# Draw one random training sample for a visual sanity check. The index range
# is taken from the dataset itself: the old hard-coded bound of 2000 had no
# relation to the dataset size and could go out of range or skip samples.
# NOTE(review): assumes train_dataset implements __len__ — confirm in train_utils.
train_image, train_mask = train_dataset[random.randrange(len(train_dataset))]

In [None]:
# Show the distinct values found in the sampled image and in its mask,
# one array per line.
print(np.unique(train_image), np.unique(train_mask), sep="\n")

In [None]:
# Show the sampled image and its mask side by side.
# NOTE(review): `plt` is not imported anywhere in this notebook's visible
# cells; presumably it arrives via `from train_utils import *` — confirm.
# NOTE(review): transpose(2, 1, 0) reverses all three axes. If the arrays are
# channel-first (C, H, W) this yields (W, H, C), i.e. a transposed picture;
# (1, 2, 0) would give (H, W, C). Verify the intended layout.
f, axarr = plt.subplots(1, 2, figsize=(16, 8))
for ax, picture, caption in zip(axarr, (train_image, train_mask), ("Image", "Mask")):
    ax.imshow(picture.transpose(2, 1, 0), cmap='gray')
    ax.title.set_text(caption)