## Installation

In [None]:
# install required libs
#!pip install -U segmentation-models-pytorch albumentations --user 

## Loading data

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt

from os import path
from matplotlib import gridspec

In [None]:
# root dir
PATH_ROOT = 'root_dir'

# dataset dirs (all derived from PATH_ROOT and the dataset name)
DATASET = 'dataset_name'
PATH_DATASET = path.join(PATH_ROOT, 'datasets/enriched/', DATASET)
PATH_TRAIN_IMAGES = path.join(PATH_DATASET, 'train/images/')
PATH_TRAIN_LABELS = path.join(PATH_DATASET, 'train/labels_thresh128/')

# NOTE(review): validation reads the very same directories as training, so
# "validation" scores are measured on training images -- confirm this is intended.
PATH_VAL_IMAGES, PATH_VAL_LABELS = PATH_TRAIN_IMAGES, PATH_TRAIN_LABELS
print(PATH_DATASET)

# number of training images vs. number of training labels
n_ti = len(os.listdir(PATH_TRAIN_IMAGES))
n_tl = len(os.listdir(PATH_TRAIN_LABELS))
print('train {}/{}'.format(n_ti, n_tl))

In [None]:
# keyworded visualization: title_name_1=image1, title_name_2=image2, ...
def visualize(**images):
    """Show the given images side by side in a single row.

    Each keyword argument is one image; the keyword itself becomes the
    subplot title (underscores turned into spaces, title-cased).
    """
    total = len(images)
    plt.figure(figsize=(16, 10))
    for position, (keyword, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, total, position)
        plt.title(keyword.replace('_', ' ').title()) # keyword as title
        plt.imshow(picture)
    plt.show()

# non-keyworded images: image1, image2, ...
def visualize_grid(*images, cols=4):
    """Show the given images in a grid, filled row by row, with axes hidden.

    Args:
        *images: images accepted by ``Axes.imshow``.
        cols: number of grid columns (keyword-only, defaults to 4).

    Nothing is drawn when no images are passed.
    """
    n = len(images)
    if n == 0:
        return # avoid building a grid with zero rows

    rows = int(np.ceil(n / cols))
    gs = gridspec.GridSpec(rows, cols)
    fig = plt.figure(figsize=(16, 4*rows))
    for i in range(n):
        ax = fig.add_subplot(gs[i])
        ax.imshow(images[i])
        ax.axis('off')

    # tight_layout only adjusts axes that already exist, so it has to run
    # after the subplots were added (it was a no-op on the empty figure)
    fig.tight_layout()
    plt.show()

## Dataloader

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

In [None]:
# classes/intensities: the list position of a class is printed as its value
CLASSES = ['background', 'wound']

for value, class_name in enumerate(CLASSES):
    print('value ', value, ':\t', class_name, sep='')

In [None]:
class Dataset(BaseDataset):
    """Segmentation dataset read from an image directory and a mask directory.

    Every file name found in ``images_dir`` is expected to exist under the
    same name in ``masks_dir``. Masks are read as single-channel images whose
    pixel values are compared against class indices (positions in ``CLASSES``).

    Args:
        images_dir: directory containing the input images.
        masks_dir: directory containing the masks (same file names).
        classes: class names to extract from the masks (case-insensitive);
            ``None`` selects every entry of ``CLASSES``.
        augmentation: optional callable invoked as ``f(image=..., mask=...)``
            that returns a mapping with 'image' and 'mask'.
        preprocessing: optional callable with the same calling convention,
            applied after the augmentation.

    Raises:
        ValueError: if ``classes`` contains a name that is not in ``CLASSES``.
    """

    CLASSES = ['background', 'wound']
    
    def __init__(self, images_dir, masks_dir, classes=None, augmentation=None, preprocessing=None):
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping stable between runs and machines
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
        
        # the default (None) used to crash when iterated; it now means "all classes"
        if classes is None:
            classes = self.CLASSES
        unknown = [cls for cls in classes if cls.lower() not in self.CLASSES]
        if unknown:
            raise ValueError('unknown classes {}; expected a subset of {}'.format(unknown, self.CLASSES))
        
        # convert str names to class values on masks
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]
        
        self.augmentation = augmentation
        self.preprocessing = preprocessing
    
    def __getitem__(self, i):
        """Return ``(image, mask)`` for sample ``i``.

        Before augmentation/preprocessing the image is the RGB-converted file
        content and the mask is a float array with one channel per selected
        class (1.0 where the mask pixel equals that class value, else 0.0).

        Raises:
            FileNotFoundError: if the image or the mask file cannot be read.
        """
        
        # read data (cv.imread signals failure by returning None, not by raising)
        image = cv.imread(self.images_fps[i])
        if image is None:
            raise FileNotFoundError('could not read image: {}'.format(self.images_fps[i]))
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
        
        mask = cv.imread(self.masks_fps[i], 0)
        if mask is None:
            # without this check `None == v` is a scalar and the stacked
            # "mask" silently ends up with the wrong shape
            raise FileNotFoundError('could not read mask: {}'.format(self.masks_fps[i]))
        
        # extract the selected classes from mask (e.g. wound)
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')
        
        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        
        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
            
        return image, mask
        
    def __len__(self):
        """Return the number of files found in the image directory."""
        return len(self.ids)

In [None]:
# inspect training data: load one raw sample (no augmentation, no preprocessing)
train_dataset = Dataset(images_dir=PATH_TRAIN_IMAGES, masks_dir=PATH_TRAIN_LABELS, classes=['wound'])
image, mask = train_dataset[1] # get some sample
visualize(image=image, wound_mask=mask.squeeze())

# save image height/width (taken from this one sample; used later for validation padding)
IMAGE_HEIGHT, IMAGE_WIDTH = image.shape[:2]
assert IMAGE_HEIGHT % 32 == 0 and IMAGE_WIDTH % 32 == 0, 'image height/width must be divisible by 32'

### Augmentations

In [None]:
import albumentations as albu

In [None]:
# label describing the augmentation setup; it only ends up in RUN_NAME and
# does not select or change the transforms built by get_training_augmentation()
AUGMENTATION = 'medium'

In [None]:
def get_training_augmentation(crop_height=352, crop_width=352):
    """Build the albumentations pipeline used for training samples.

    The pipeline always takes a random crop and then applies, each with
    probability 0.5: a flip, a shift/scale/rotate, one geometric distortion,
    one contrast/brightness change, one blur/sharpen filter, and noise.

    Args:
        crop_height: height of the random crop (default 352, as before).
        crop_width: width of the random crop (default 352, as before).
            NOTE(review): 352 = 11 * 32; the notebook asserts that full image
            sizes are divisible by 32, so crop sizes presumably should be
            as well -- confirm before changing them.

    Returns:
        albu.Compose: callable as ``f(image=..., mask=...)``.
    """
    train_transform = [
        
        # random crop (basic)
        albu.RandomCrop(
            height=crop_height,
            width=crop_width,
            always_apply=True
        ),
        
        # geometry (basic)
        albu.Flip(p=0.5),
        albu.ShiftScaleRotate(p=0.5),
        
        # distortion (heavy): exactly one of the two is applied
        albu.OneOf([
            albu.GridDistortion(p=1),            
            albu.ElasticTransform(p=1),
        ], p=0.5),      
        
        # contrast, brightness (basic): exactly one of the three is applied
        albu.OneOf([
            albu.CLAHE(p=1),
            albu.RandomGamma(p=1),
            albu.RandomBrightnessContrast(p=1),
        ], p=0.5),

        # blurring, sharpening (basic): exactly one of the three is applied
        albu.OneOf([
            albu.Sharpen(p=1),
            albu.Blur(blur_limit=8, p=1),
            albu.MotionBlur(blur_limit=8, p=1),
        ], p=0.5),      
        
        # noise (basic)
        albu.OneOf([
            albu.GaussNoise(p=1),
        ], p=0.5),
               
    ]
    
    return albu.Compose(train_transform)


def get_validation_augmentation():
    """Build the validation pipeline: pad samples up to the stored image size.

    Reads the module-level IMAGE_HEIGHT / IMAGE_WIDTH at call time, so the
    cell defining them has to run first.
    """
    pad_to_image_size = albu.PadIfNeeded(IMAGE_HEIGHT, IMAGE_WIDTH)
    return albu.Compose([pad_to_image_size])


def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose(2, 0, 1)
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn):
    """Build the preprocessing pipeline applied after augmentation.

    Step 1 runs ``preprocessing_fn`` on the image only; step 2 runs
    ``to_tensor`` on both image and mask.
    """
    apply_preprocessing_fn = albu.Lambda(image=preprocessing_fn)
    apply_to_tensor = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([apply_preprocessing_fn, apply_to_tensor])

In [None]:
# create dataset with augmentation pipeline for training
dataset = Dataset(
    PATH_TRAIN_IMAGES, PATH_TRAIN_LABELS, 
    classes=CLASSES, augmentation=get_training_augmentation(),
)

# visualize exemplary sample augmentations: the same sample is drawn 16 times,
# each draw going through the random pipeline again.
# The preferred sample index is clamped to the dataset size so the preview
# also works on datasets with fewer than 3104 images (was an IndexError).
SAMPLE_INDEX = min(3103, len(dataset) - 1)
visualize_grid(*[dataset[SAMPLE_INDEX][0] for _ in range(16)])

## Create model and train

In [None]:
import torch
import numpy as np
import segmentation_models_pytorch as smp
import time
import os

from torch.utils.tensorboard import SummaryWriter

In [None]:
# model parameters
# NOTE: MODEL is only a label that goes into RUN_NAME; the training cell
# always builds smp.FPN regardless of this value.
MODEL = 'fpn' # TODO: currently hard-coded, fix
ENCODER = 'se_resnext101_32x4d'
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation
DEVICE = 'cuda'
# NOTE: this rebinds the CLASSES list defined in the Dataloader section; cells
# above that read CLASSES will see this one-class list if re-run afterwards.
CLASSES = ['wound'] # ['background', 'wound']

# training parameters
# NOTE: OPTIMIZER is only a label that goes into RUN_NAME; the training cell
# always creates torch.optim.Adam.
OPTIMIZER = 'adam' # TODO: currently hard-coded, fix
LR_INIT = 0.0001 # learning rate the optimizer starts with
LR_DROP_1 = 0.00001 # learning rate set once epoch EPOCHS_DROP_1 has finished
#LR_DROP_2 = 0.00001

# batch size, epochs 
BATCH_SIZE = 24 # batch size of the training loader
EPOCHS = 150 # total number of training epochs
EPOCHS_DROP_1 = 120 # LR_DROP after stated epoch (e.g., 25 -> reduced LR from epoch 26)
#EPOCHS_DROP_2 = 225 # LR_DROP after stated epoch (e.g., 25 -> reduced LR from epoch 26)

In [None]:
# run dir
PATH_RUNS = path.join(PATH_ROOT, 'models/')

# A learning-rate drop is described by two settings that only make sense
# together. With just one of them set, the old code either crashed inside
# '{:.0e}'.format(None) or produced a run name advertising a drop that never
# happens, so fail early with a clear message instead.
if (LR_DROP_1 is None) != (EPOCHS_DROP_1 is None):
    raise ValueError('LR_DROP_1 and EPOCHS_DROP_1 must either both be set or both be None')

# run name (encodes dataset, model, optimizer, augmentation and schedule)
if LR_DROP_1 is None and EPOCHS_DROP_1 is None:
    # no lr adjustments
    RUN_NAME = 'training___{}___{}_{}_{}_{}___{}_lr{:.0e}___{}___bs{}_e{}'.format(
        DATASET, MODEL, ENCODER, ENCODER_WEIGHTS, ACTIVATION, 
        OPTIMIZER, LR_INIT, AUGMENTATION, BATCH_SIZE, EPOCHS,
    )
else:
    # lr adjustments
    #RUN_NAME = 'training___{}___{}_{}_{}_{}___{}_lr{:.0e}_lrd1{:.0e}_lrd2{:.0e}___{}___bs{}_e{}_ed1{}_ed2{}___sameconf1'.format(
    RUN_NAME = 'training___{}___{}_{}_{}_{}___{}_lr{:.0e}_lrd1{:.0e}___{}___bs{}_e{}_ed1{}___sameconf1'.format(
        DATASET, MODEL, ENCODER, ENCODER_WEIGHTS, ACTIVATION, 
        #OPTIMIZER, LR_INIT, LR_DROP_1, LR_DROP_2, AUGMENTATION, BATCH_SIZE, EPOCHS, EPOCHS_DROP_1, EPOCHS_DROP_2
        OPTIMIZER, LR_INIT, LR_DROP_1, AUGMENTATION, BATCH_SIZE, EPOCHS, EPOCHS_DROP_1
    )
print('dir:\t', PATH_RUNS)
print('run:\t', RUN_NAME)
print('path:\t', path.join(PATH_RUNS, RUN_NAME))

In [None]:
# create dirs
# (os.makedirs is called without exist_ok, so re-running this cell for an
# already existing run directory raises instead of writing into it)
PATH_RUN = path.join(PATH_RUNS, RUN_NAME)
PATH_LOG = path.join(PATH_RUN, 'log')
os.makedirs(PATH_LOG)

# perform training
print("Training: Started")
print('\n' + '-'*80)
training_start = time.time()

# tensorboard and text log
tb_log = SummaryWriter(log_dir=PATH_LOG)
txt_log = []


### MODEL
# create segmentation model with pretrained encoder, plus the input
# preprocessing function belonging to that encoder/weights combination
model = smp.FPN(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)


### DATASETS, LOADERS
# create training and validation datasets and loaders
train_dataset = Dataset(
    images_dir=PATH_TRAIN_IMAGES,
    masks_dir=PATH_TRAIN_LABELS,
    classes=CLASSES,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)
val_dataset = Dataset(
    images_dir=PATH_VAL_IMAGES,
    masks_dir=PATH_VAL_LABELS,
    classes=CLASSES,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)

# training batches are shuffled; validation runs sample by sample in fixed order
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=24,
    prefetch_factor=4,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)
    
    
### LOSS, SCORE, OPTIMIZER
# create loss, score, and optimizer
loss = smp.utils.losses.DiceLoss() # Dice/F1 loss
metrics = [
    smp.utils.metrics.Fscore(threshold=0.5), # Dice/F1 Score
    smp.utils.metrics.IoU(threshold=0.5), # Jaccard/IoU score
]
# a single parameter group holding all model parameters
optimizer = torch.optim.Adam([{'params': model.parameters(), 'lr': LR_INIT}])


### RUNNERS
# create epoch runners (loop of iterating over dataloader's samples);
# only the training runner receives the optimizer
train_epoch = smp.utils.train.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)
val_epoch = smp.utils.train.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

### TRAINING
# train model
VAL_EVERY = 50 # validation runs on every VAL_EVERY-th epoch (plus the first epoch)
N_LAST_EPOCHS = 5 # the final N_LAST_EPOCHS epochs are validated and snapshotted

max_score_dice = 0
max_score_jaccard = 0
best_epoch_dice = 0
best_epoch_jaccard = 0
for epoch in range(1, EPOCHS+1):
    
    print('\nEpoch: {}'.format(epoch))
    remaining = EPOCHS - epoch # remaining epochs
    
    # run epoch. Validation only runs on the first epoch, every VAL_EVERY-th
    # epoch, and the last N_LAST_EPOCHS epochs. Previously only the very last
    # epoch was validated, so the "IoU (val)" written next to the other
    # last-epoch snapshots was a stale value from an earlier epoch.
    train_logs = train_epoch.run(train_loader)    
    validated = (epoch == 1) or (epoch % VAL_EVERY == 0) or (remaining < N_LAST_EPOCHS)
    if validated:
        val_logs = val_epoch.run(val_loader)


    ### LOGGING
    # training curves: dice/f1 loss, dice/f1 score, jaccard/iou score
    tb_log.add_scalar('train/Dice loss (F1)', train_logs['dice_loss'], epoch)
    tb_log.add_scalar("train/Dice score (F1)", train_logs['fscore'], epoch)
    tb_log.add_scalar('train/Jaccard index (IoU)', train_logs['iou_score'], epoch)

    # validation curves: only logged for epochs that were actually validated,
    # so TensorBoard does not show carried-over values as measurements
    if validated:
        tb_log.add_scalar('val/Dice loss (F1)', val_logs['dice_loss'], epoch)
        tb_log.add_scalar("val/Dice score (F1)", val_logs['fscore'], epoch)
        tb_log.add_scalar('val/Jaccard index (IoU)', val_logs['iou_score'], epoch)

    # the text log keeps exactly one entry per epoch; for epochs without a
    # validation run it carries the most recent validation values forward,
    # and 'validated' tells the two cases apart
    txt_log.append({
        'epoch': epoch, 
        'dice_loss': val_logs['dice_loss'],
        'fscore': val_logs['fscore'],
        'iou_score': val_logs['iou_score'],
        'validated': validated,
    })


    ### EVALUATION
    # save epoch if better dice/f1 (scores can only change on validated epochs)
    if validated and max_score_dice < val_logs['fscore']:
        max_score_dice = val_logs['fscore']
        best_epoch_dice = epoch
        torch.save(model, path.join(PATH_RUN, './best_model_dice.pth'))
        with open(os.path.join(PATH_LOG, 'best_model_dice.txt'), 'w') as file:
            file.write('Best epoch: {}, Dice/F1 score (val): {:.4f}'.format(epoch, max_score_dice))
        print('Model saved (best Dice/F1 score)')

    # save epoch if better jaccard/iou
    if validated and max_score_jaccard < val_logs['iou_score']:
        max_score_jaccard = val_logs['iou_score']
        best_epoch_jaccard = epoch
        torch.save(model, path.join(PATH_RUN, './best_model_jaccard.pth'))
        with open(os.path.join(PATH_LOG, 'best_model_jaccard.txt'), 'w') as file:
            file.write('Best epoch: {}, Jaccard/IoU score (val): {:.4f}'.format(epoch, max_score_jaccard))
        print('Model saved (best Jaccard/IoU score)')
    
    # save last five epochs (file suffix = number of epochs remaining, 0 = final)
    if remaining < N_LAST_EPOCHS:
        
        # save range of last epochs
        torch.save(model, path.join(PATH_RUN, 'last_model-{}.pth'.format(remaining)))
        with open(path.join(PATH_LOG, 'last_model-{}.txt'.format(remaining)), 'w') as file:
            file.write('Last epoch: {}, IoU (val): {:.4f}'.format(epoch, val_logs['iou_score']))
        print('One of the last epochs saved!')
        
        # save last epoch under different name
        if epoch == EPOCHS:
            torch.save(model, path.join(PATH_RUN, './last_model.pth'))
            with open(path.join(PATH_LOG, 'last_model.txt'), 'w') as file:
                file.write('Last epoch: {}, IoU (val): {:.4f}'.format(epoch, val_logs['iou_score']))
            print('Last model saved!')
        
        
        
    ### LEARNING RATE MODIFICATION
    # drop learning rate once the stated epoch has finished; skipped when no
    # drop value is configured so the optimizer never receives lr=None
    if LR_DROP_1 is not None and epoch == EPOCHS_DROP_1:
        optimizer.param_groups[0]['lr'] = LR_DROP_1
        print('\nLearning rate dropped: {}'.format(LR_DROP_1))

    #if epoch == EPOCHS_DROP_2:
    #    optimizer.param_groups[0]['lr'] = LR_DROP_2
    #    print('\nLearning rate dropped: {}'.format(LR_DROP_2))
    

training_end = time.time()

# flush pending TensorBoard events to disk and release the event file;
# the writer was previously left open after training
tb_log.close()

print('\n', '-'*80, sep='')
print("\nTraining: Finished")
print("Duration: ~{:.0f} min".format((training_end - training_start)/60))


# write simple log (string representation of the per-epoch list of dicts)
with open(path.join(PATH_LOG, 'log.txt'), 'w') as file:
    file.write('{}'.format(str(txt_log)))
    print('Text log written to:', PATH_LOG)

## Visualize validation log

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# matplotlib format strings passed to ax.plot in the plots below ('r' = red)
colors = ['r']

In [None]:
fig, ax = plt.subplots()
fig.set_size_inches(20, 5)
fig.set_dpi(80)

ax.set(xlabel='Epoch', ylabel='Dice/F1 loss', title='Dice/F1 loss over epochs')
ax.grid()

# x values come from the epoch numbers stored in the log itself, so the plot
# also works when the log holds fewer than EPOCHS entries (interrupted run)
epochs = np.array([d["epoch"] for d in txt_log])
losses = [float(d["dice_loss"]) for d in txt_log] # float conversion

# txt_log stores the validation results, so the curve is labelled accordingly
# (it was mislabelled 'Training')
ax.plot(epochs, losses, colors[0], label = 'Validation')
legend = ax.legend(loc='upper right', shadow=True, fontsize='x-large')

plt.show()
fig.savefig(os.path.join(PATH_RUN, "dice_loss.png"))
fig.savefig(os.path.join(PATH_RUN, "dice_loss.pdf"))

In [None]:
fig, ax = plt.subplots()
fig.set_size_inches(20, 5)
fig.set_dpi(80)

ax.set(xlabel='Epoch', ylabel='IoU/Jaccard score', title='IoU/Jaccard scores over epochs')
ax.grid()

# x values come from the epoch numbers stored in the log itself, so the plot
# also works when the log holds fewer than EPOCHS entries (interrupted run)
epochs = np.array([d["epoch"] for d in txt_log])
scores = [float(d["iou_score"]) for d in txt_log] # float conversion

# txt_log stores the validation results, so the curve is labelled accordingly
# (it was mislabelled 'Training')
ax.plot(epochs, scores, colors[0], label = 'Validation')
legend = ax.legend(loc='lower right', shadow=True, fontsize='x-large')

plt.show()
fig.savefig(os.path.join(PATH_RUN, "iou_score.png"))
fig.savefig(os.path.join(PATH_RUN, "iou_score.pdf"))

In [None]:
# text output for sanity checks
#txt_log

## Test last saved model

In [None]:
print("Last model test: Started")
print("\n--------------------------------------------------------------------------------\n")

test_log = list()

# load the model saved after the final training epoch. The run directory is
# PATH_RUNS/RUN_NAME (the previously used WORKING_DIR is not defined anywhere).
# The variable keeps its name because the visualization cell below reads it.
# NOTE: torch.load unpickles a complete model object -- only load files you
# created yourself.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects fully pickled models; pass weights_only=False there if needed.
best_model = torch.load(os.path.join(PATH_RUNS, RUN_NAME, 'last_model.pth'))

# create test dataset and loader
# NOTE(review): test_image_dir / test_label_dir are not defined in any visible
# cell of this notebook; they have to be set before this cell can run.
test_dataset = Dataset(
    test_image_dir, test_label_dir, 
    augmentation=get_validation_augmentation(), 
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
)
test_dataloader = DataLoader(test_dataset)

# evaluate model on test set (same loss and metrics as during training)
test_epoch = smp.utils.train.ValidEpoch(model=best_model, loss=loss, metrics=metrics, device=DEVICE)
logs = test_epoch.run(test_dataloader)

# log test
test_log.append({
    'dice_loss': '{:.4f}'.format(logs['dice_loss']), 
    'iou_score': '{:.4f}'.format(logs['iou_score'])
})

print("\n--------------------------------------------------------------------------------")
print("Last model test: Finished")


# one line per log entry, written to the current working directory
with open('./test_log.txt', 'w') as f:
    for item in test_log:
        f.write("%s\n" % item)

In [None]:
# text output for sanity checks
#test_log

## Visualize predictions

In [None]:
# test dataset without transformations for image visualization
# same files as test_dataset but without augmentation/preprocessing, so the
# images can be shown as they are stored
test_dataset_vis = Dataset(images_dir=test_image_dir, masks_dir=test_label_dir, classes=CLASSES)

In [None]:
# show five randomly chosen test samples: raw image, ground truth, prediction
for _ in range(5):
    n = np.random.choice(len(test_dataset))

    # raw image for display; preprocessed image/mask for the model
    image_vis = test_dataset_vis[n][0].astype('uint8')
    image, gt_mask = test_dataset[n]
    gt_mask = gt_mask.squeeze()

    # add a batch dimension, predict, then round the output to a 0/1 mask
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    prediction = best_model.predict(x_tensor)
    pr_mask = prediction.squeeze().cpu().numpy().round()

    visualize(image=image_vis, ground_truth_mask=gt_mask, predicted_mask=pr_mask)