In [None]:
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
import cv2
import matplotlib.pyplot as plt
from segmentation_models_pytorch import utils
import pandas as pd

In [None]:
# git pull here  (NOTE(review): presumably "update the dataset checkout before running" — confirm)
# Dataset root directory. Set the DATASET_ROOT environment variable to point at
# your own copy; the fallback is the original author's machine-specific path.
root = os.environ.get('DATASET_ROOT', '/Users/amograo/Desktop/test/Dataset')

In [None]:

# Locations inside `root` (the Dataset class joins each of these onto root).
images_dir, masks_dir = 'png_files', 'mask_files'
train_csv, val_csv = 'csv/train_upsampled.csv', 'csv/val_final.csv'


In [None]:

def visualize(**images):
    """Plot images side by side in one row, one titled panel per keyword argument.

    Each keyword name becomes its panel title, with underscores replaced by
    spaces and title-cased (e.g. ``ground_truth_mask`` -> "Ground Truth Mask").

    Single-channel arrays shaped (H, W, 1) or (1, H, W) -- the mask layouts
    produced by the Dataset before and after ``to_tensor`` -- are squeezed to
    (H, W) first, because ``imshow`` cannot display a channel-first array.
    Nothing is drawn when called with no images.
    """
    n = len(images)
    if n == 0:
        return
    # squeeze=False keeps `axes` 2-D even for a single panel, so indexing is uniform.
    fig, axes = plt.subplots(1, n, figsize=(16, 5), squeeze=False)
    for ax, (name, image) in zip(axes[0], images.items()):
        image = np.asarray(image)
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[..., 0]
        elif image.ndim == 3 and image.shape[0] == 1:
            image = image[0]
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(' '.join(name.split('_')).title())
        ax.imshow(image)
    plt.show()

# Dataset Class

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

In [None]:
class Dataset(BaseDataset):
    """Image/mask segmentation dataset indexed by a CSV file.

    Each CSV row names one sample (``id_col``) and an augmentation code
    (``aug_col``). The image is read from ``<root>/<images_dir>/<id>.png`` and
    the mask from ``<root>/<masks_dir>/<id>_mask.png``.

    Args:
        root: dataset root; ``images_dir``, ``masks_dir`` and ``csv`` are
            resolved relative to it.
        images_dir: sub-directory holding the input PNGs.
        masks_dir: sub-directory holding the ``*_mask.png`` files.
        csv: path of the CSV index, relative to ``root``.
        aug_fn: callable mapping an augmentation code to a transform that is
            called as ``transform(image=..., mask=...)`` and returns a dict with
            ``'image'`` and ``'mask'``. Required only if some row has a
            non-empty, non-zero augmentation code.
        id_col: CSV column with the sample id (file stem).
        aug_col: CSV column with the augmentation code; 0, empty or NaN means
            "no augmentation".
        preprocessing_fn: optional transform with the same call convention as
            ``aug_fn``'s result, applied after augmentation.
    """

    def __init__(
            self,
            root,
            images_dir,
            masks_dir,
            csv,
            aug_fn=None,
            id_col='DICOM',
            aug_col='Augmentation',
            preprocessing_fn=None,
    ):
        images_dir = os.path.join(root, images_dir)
        masks_dir = os.path.join(root, masks_dir)
        df = pd.read_csv(os.path.join(root, csv))

        # Zipping whole columns is much faster than DataFrame.iterrows() and
        # keeps each column's own dtype instead of building a Series per row.
        self.ids = list(zip(df[id_col].tolist(), df[aug_col].tolist()))
        # f-strings (instead of `+`) also work for ids pandas parsed as numbers.
        self.images = [os.path.join(images_dir, f"{image_id}.png") for image_id, _ in self.ids]
        self.masks = [os.path.join(masks_dir, f"{image_id}_mask.png") for image_id, _ in self.ids]
        self.aug_fn = aug_fn
        self.preprocessing_fn = preprocessing_fn

    @staticmethod
    def _has_augmentation(aug):
        """Return True if ``aug`` requests an augmentation (not 0 / None / NaN / empty)."""
        # pandas reads empty CSV cells as NaN and bool(nan) is True, so missing
        # values must be filtered out before the truthiness test.
        return not pd.isna(aug) and bool(aug)

    def __getitem__(self, i):
        """Return ``(image, mask)`` for sample ``i``.

        The image is read as RGB. The mask is read as grayscale, binarised
        (pixel == 255 -> 1.0, anything else -> 0.0) and given a trailing
        channel axis. The sample's augmentation (if any) and then
        ``preprocessing_fn`` (if set) are applied to both.

        Raises:
            FileNotFoundError: if the image or mask file cannot be read.
            ValueError: if the sample requests augmentation but no ``aug_fn``
                was given.
        """
        image_path, mask_path = self.images[i], self.masks[i]

        image = cv2.imread(image_path)
        # cv2.imread signals a missing/unreadable file by returning None.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        raw_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if raw_mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        mask = (raw_mask == 255).astype('float')
        mask = np.expand_dims(mask, axis=-1)

        aug = self.ids[i][1]
        if self._has_augmentation(aug):
            if self.aug_fn is None:
                raise ValueError(
                    f"Sample {i} requests augmentation {aug!r} but no aug_fn was given"
                )
            augmented = self.aug_fn(aug)(image=image, mask=mask)
            image, mask = augmented['image'], augmented['mask']

        if self.preprocessing_fn:
            sample = self.preprocessing_fn(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Number of samples (rows in the CSV index)."""
        return len(self.ids)

# Augmentation and Preprocessing Functions

In [None]:
from albumentations import (HorizontalFlip, RandomBrightnessContrast, RandomGamma, CLAHE, ElasticTransform, GridDistortion, OpticalDistortion, ShiftScaleRotate, Normalize, GaussNoise, Compose, Lambda)

def augmentation_fn(value):
    """Return the albumentations ``Compose`` for augmentation code ``value``.

    Codes 1-7 apply one transform; codes 8-20 chain two or three of the same
    building blocks in the listed order. Every transform uses p=1, so it is
    always applied. Only the transforms needed for ``value`` are constructed.

    Args:
        value: augmentation code, one of 1-20.
    Returns:
        albumentations.Compose over the selected transforms.
    Raises:
        KeyError: if ``value`` is not a known code.
    """
    # One factory per building block, so every recipe shares a single set of
    # hyper-parameters and nothing is instantiated unless it is used.
    blocks = {
        'flip': lambda: HorizontalFlip(p=1),
        'brightness_contrast': lambda: RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1),
        'gamma': lambda: RandomGamma(p=1),
        'clahe': lambda: CLAHE(clip_limit=4.0, tile_grid_size=(4, 4), p=1),
        'optical': lambda: OpticalDistortion(p=1),
        'shift_scale_rotate': lambda: ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=15, p=1),
        'noise': lambda: GaussNoise(p=1),
    }
    recipes = {
        1: ('flip',),
        2: ('brightness_contrast',),
        3: ('gamma',),
        4: ('clahe',),
        5: ('optical',),
        6: ('shift_scale_rotate',),
        7: ('noise',),
        8: ('flip', 'brightness_contrast', 'gamma'),
        9: ('flip', 'brightness_contrast', 'clahe'),
        10: ('flip', 'brightness_contrast', 'optical'),
        11: ('flip', 'brightness_contrast', 'noise'),
        12: ('shift_scale_rotate', 'noise'),
        13: ('clahe', 'noise'),
        14: ('clahe', 'optical'),
        15: ('clahe', 'gamma'),
        16: ('gamma', 'optical'),
        17: ('brightness_contrast', 'noise'),
        18: ('shift_scale_rotate', 'gamma'),
        19: ('shift_scale_rotate', 'flip'),
        20: ('shift_scale_rotate', 'noise', 'optical'),
    }

    try:
        recipe = recipes[value]
    except KeyError:
        raise KeyError(
            f"Unknown augmentation code {value!r}; expected one of {sorted(recipes)}"
        ) from None

    return Compose([blocks[name]() for name in recipe])


In [None]:
def to_tensor(x, **kwargs):
    """Convert a channel-last array to a float32 channel-first array.

    (H, W, C) becomes (C, H, W). A 2-D (H, W) array is treated as a single
    channel and returned as (1, H, W). Extra keyword arguments are accepted
    and ignored so the function can serve as an albumentations ``Lambda``
    target, which may pass additional parameters.
    """
    if x.ndim == 2:
        x = x[..., None]
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing(preprocessing_fn):
    """Construct the preprocessing transform.

    Args:
        preprocessing_fn (callable or None): data normalization function
            (can be specific for each pretrained neural network). It is
            applied to the image only; the mask is never normalized. Pass
            None to skip normalization.
    Return:
        transform: albumentations.Compose that normalizes the image (if a
        function was given) and then converts both image and mask to
        float32 channel-first arrays with ``to_tensor``.
    """
    transforms = []
    if preprocessing_fn is not None:
        transforms.append(Lambda(image=preprocessing_fn))
    transforms.append(Lambda(image=to_tensor, mask=to_tensor))
    return Compose(transforms)

# Create Model

In [None]:
import torch
import numpy as np
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils

In [None]:
import random

ENCODER = 'efficientnet-b5'
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation
SEED = 42

# Use whichever accelerator is present instead of hard-coding 'mps', which
# only exists on Apple builds of PyTorch; fall back to CPU otherwise.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

# Seed Python, NumPy and torch before the model is built, so the randomly
# initialised (non-pretrained) parts and DataLoader shuffling are repeatable.
# NOTE(review): whether albumentations draws from these RNGs depends on its version — confirm.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# create segmentation model with pretrained encoder
model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    activation=ACTIVATION,
)

# Input normalisation matching the chosen encoder's pretrained weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [None]:
# Both splits read the same folders and use the same augmentation table; they
# differ only in the CSV index. Each split gets its own preprocessing pipeline.
# NOTE(review): augmentation codes come from each CSV — confirm val_final.csv
# has none, otherwise validation samples are augmented too.
_split_kwargs = dict(
    root=root,
    images_dir=images_dir,
    masks_dir=masks_dir,
    aug_fn=augmentation_fn,
)

train_dataset = Dataset(csv=train_csv, preprocessing_fn=get_preprocessing(preprocessing_fn), **_split_kwargs)
print(len(train_dataset))

val_dataset = Dataset(csv=val_csv, preprocessing_fn=get_preprocessing(preprocessing_fn), **_split_kwargs)
print(len(val_dataset))

del _split_kwargs

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32, num_workers=0)
valid_loader = DataLoader(val_dataset, shuffle=False, batch_size=16, num_workers=0)

In [None]:
# Dice/F1 score - https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
# IoU/Jaccard score - https://en.wikipedia.org/wiki/Jaccard_index

# Training objective: Dice loss, built on the Dice/F1 score
#   https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
# Reported metric: IoU/Jaccard, with predictions thresholded at 0.5
#   https://en.wikipedia.org/wiki/Jaccard_index
loss = utils.losses.DiceLoss()
metrics = [utils.metrics.IoU(threshold=0.5)]

# Adam with one parameter group covering the whole model.
optimizer = torch.optim.Adam([{'params': model.parameters(), 'lr': 1e-4}])

In [None]:
# Epoch runners: each `.run(loader)` iterates once over a DataLoader and
# returns a logs dict (the training loop reads 'iou_score' and 'dice_loss').
# Only the training runner receives the optimizer.
_runner_kwargs = dict(loss=loss, metrics=metrics, device=DEVICE, verbose=True)

train_epoch = utils.train.TrainEpoch(model, optimizer=optimizer, **_runner_kwargs)
valid_epoch = utils.train.ValidEpoch(model, **_runner_kwargs)

del _runner_kwargs

In [None]:
# Train for N_EPOCHS epochs. The model with the best validation IoU is saved to
# ./best_model.pth, and the final model to ./latest_model.pth.
N_EPOCHS = 40
LR_DECAY_EPOCH = 25   # the learning rate drops once this epoch has finished
DECAYED_LR = 1e-5

max_score = 0
iou_scores = []   # validation IoU, one entry per epoch
dice_loss = []    # validation Dice loss, one entry per epoch
for i in range(N_EPOCHS):

    print(f'\nEpoch: {i}')
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    iou_scores.append(valid_logs['iou_score'])
    dice_loss.append(valid_logs['dice_loss'])

    if valid_logs['iou_score'] > max_score:
        max_score = valid_logs['iou_score']
        # Pickles the whole model object; loading it requires the same class definitions.
        torch.save(model, './best_model.pth')
        print('Model saved!')

    if i == LR_DECAY_EPOCH:
        # Set on every parameter group. The optimizer has a single group that
        # covers the whole model, so this is not decoder-specific.
        for group in optimizer.param_groups:
            group['lr'] = DECAYED_LR
        print(f'Decrease learning rate to {DECAYED_LR}!')

torch.save(model, './latest_model.pth')
print('Latest model saved!')