
# Pneumothorax segmentation


In [None]:

#! git remote add origin https://github.com/andrew-johnson-melb/kaggle-pneumothorax-segmentation.git


In [None]:

# Requires albumentations; to install it into this kernel's environment, uncomment:
# %pip install albumentations
# Directory holding the project's helper modules (utils, unet, transforms, ...),
# which the imports cell appends to sys.path.
SRC_FILES = '/home/ec2-user/SageMaker/seg_project/src'


In [None]:

import os
import sys
import torch
import gc
from tqdm import tqdm
import cv2
from glob import glob
from collections import namedtuple
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import WeightedRandomSampler

sys.path.append(SRC_FILES)

from utils import *
from unet import UNet
from transforms import get_transforms
from trainer import train_one_epoch, evaluate
from dataset import PneumothoraxDataset
from vis import show_batch, compare_masks
from loss import MixedLoss, dice_loss, dice_metric, MetricCollector

%load_ext autoreload
%autoreload 2
%matplotlib inline


In [None]:

# Every input lives under one project data root; the image/mask folders are the
# preprocessed "size-512" variants. Trailing slashes matter: file names are appended
# to these directory strings by plain string concatenation later on.
_DATA_ROOT = '/home/ec2-user/SageMaker/seg_project/data/'
_PREPROCESSED_ROOT = _DATA_ROOT + 'preprocessed_data/'

LABELLED_DATA = _PREPROCESSED_ROOT + 'train/size-512/'
LABELS_MASKS = _PREPROCESSED_ROOT + 'train-masks/size-512/'
TEST_DATA = _PREPROCESSED_ROOT + 'test/size-512/'
# Despite the name, this is the path of the metadata CSV file, not a directory.
META_DATA_DIR = _DATA_ROOT + 'meta_data_siim.csv'



### Constructing training and testing sets



Load the metadata, which contains file names, labels, and patient info,
then split the dataframe into the labelled data (used for development) and the test data used for Kaggle's validation.


In [None]:

# Read the metadata table and split it on its `train_set` column (assumed boolean):
# flagged rows are the labelled development data, the remaining rows are Kaggle's test set.
meta_data_df = pd.read_csv(META_DATA_DIR, index_col=0)
in_train_set = meta_data_df['train_set']
labelled_df_set = meta_data_df.loc[in_train_set]
test_df_set = meta_data_df.loc[~in_train_set]



Add the full paths to the images and masks; the dataloader will use these paths
to read the files from disk.


In [None]:

# Add the absolute on-disk paths the dataloader reads from (images for both sets, masks
# only for the labelled set). Both frames are boolean-mask slices of `meta_data_df`, so
# assigning columns in place raises pandas' SettingWithCopyWarning; `.assign` returns
# new, independent frames instead and keeps this cell safe to re-run.
test_df_set = test_df_set.assign(images=TEST_DATA + test_df_set.file_name + '.png')
labelled_df_set = labelled_df_set.assign(
    images=LABELLED_DATA + labelled_df_set.file_name + '.png',
    masks=LABELS_MASKS + labelled_df_set.file_name + '.png',
)
labelled_df_set.head(2)



Split the labelled data randomly into a training set and a validation set.
Cross-validation could be used here to get a more
accurate measure of the model's generalisation, but for speed of development we use a simple random split. Ideally, the
train/validation split would be stratified by the positive class to ensure the validation set has a reasonable number of positive samples.


In [None]:

# Split the labelled rows into training and validation frames using the project helper
# `train_dev_split`; the split ratio/strategy is defined in that helper (not visible here).
train_set_df, val_set_df = train_dev_split(labelled_df_set)


### Create dataloaders


In [None]:

# One labelled dataset per split: training rows get the training augmentations,
# validation rows get the validation transforms (both returned by get_transforms()).
aug_training, aug_validation = get_transforms()
train_dataset, val_dataset = (
    PneumothoraxDataset(files_df=split_df, labelled=True, transform=split_aug)
    for split_df, split_aug in ((train_set_df, aug_training), (val_set_df, aug_validation))
)


### Set training parameters



In [None]:

# Hyper-parameters for this run, bundled into an immutable record.
_CONFIG_FIELDS = ('batch_size', 'ratio_pos_neg_sample', 'lr', 'num_epochs')
Configs = namedtuple('TrainingConfigs', _CONFIG_FIELDS)
configs = Configs(
    batch_size=8,            # samples per training batch
    ratio_pos_neg_sample=5,  # forwarded to gen_upsampling_weights(ratio_pos_neg=...)
    lr=0.00005,              # optimizer learning rate
    num_epochs=40,           # training epochs; also used as the LR scheduler period
)


### Dealing with Class imbalance

Only 20% of the samples contain the positive class label (indicating pneumothorax).
This imbalance causes issues when training the model: it ends up predicting
only zeros for the entire region. To counter this, we increase the
frequency at which positive samples are drawn, using
PyTorch's WeightedRandomSampler. We use the labels contained
in the metadata to construct a vector of per-sample weights, which the WeightedRandomSampler
then uses to sample the data.



In [None]:

# Attach a per-row `weights` column, presumably up-weighting positive (pneumothorax) rows
# according to `ratio_pos_neg`.
train_set_df = gen_upsampling_weights(train_set_df, ratio_pos_neg=configs.ratio_pos_neg_sample)
# NOTE(review): train_dataset was built from train_set_df before this reassignment; the
# sampler's indices only line up with it if gen_upsampling_weights preserves row order — confirm.
sample_weights = train_set_df['weights'].values
# Draw one epoch's worth of indices (as many as there are training rows) by weight.
weighted_sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(train_set_df))


In [None]:

# Worker processes used by both loaders (single place to tune for the host machine).
NUM_WORKERS = 6

# Training batches come through the weighted sampler (DataLoader forbids combining
# `sampler` with `shuffle`); validation keeps a fixed order across epochs.
# Both loaders take their batch size from `configs`, so changing it in the config cell
# is no longer silently ignored for validation.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=configs.batch_size,
                              num_workers=NUM_WORKERS, sampler=weighted_sampler)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=configs.batch_size,
                            num_workers=NUM_WORKERS, shuffle=False)



### Visualize a batch of training data


In [None]:
# It can be very helpful to inspect the transformed data, i.e. (normalization aside)
# exactly the data that goes into training the model.
# The trailing semicolon suppresses the cell's output repr; assigning to `_` instead would
# shadow IPython's output-history variable and stop `_`/`__`/`___` from updating.
show_batch(train_dataloader);


### Testing the model outputs and shape


In [None]:
# Get one batch of (augmented) training data.
input_, target = get_sample(train_dataloader)

# Build a fresh model and run a single forward pass. This is only a shape/sanity check,
# so skip building the autograd graph: it saves memory and yields a tensor that does not
# require grad for the inspection cells below.
unet = UNet()
with torch.no_grad():
    output = unet(input_)
# Drop all singleton dimensions (presumably a 1-channel output axis) so predictions can be
# compared against the masks. NOTE: with a batch size of 1 this also drops the batch axis.
output = output.squeeze()

print(f'Input Shape  = {input_.shape}')
print(f'Target (y) shape = {target.shape}')
print(f"Output shape = {output.shape}")


In [None]:

# Let's have a quick look at the freshly constructed model's prediction next to the mask.
# Each printed value is labelled with the function that produced it (the previous
# labels called all of them "Best ...", mislabelling the loss and the metric).
# NOTE(review): `bce_loss` is not imported by name in this notebook; presumably it comes
# from `from utils import *` — confirm, otherwise this cell raises NameError.
for i in range(1):  # widen the range to inspect more samples from the batch
    im, t = output[i], target[i]
    pred, mask = im.float(), t.float()
    print(f'Dice loss = {dice_loss(pred, mask)}')
    print(f'Dice metric = {dice_metric(pred, mask)}')
    print(f'BCE loss = {bce_loss(pred, mask)}')
    compare_masks(im, t, label=i)


In [None]:

# Check the sizes: report size information for `unet` on the sample batch `input_`
# via the project helper `print_model_sizes` (presumably per-layer shapes — confirm in its source).
print_model_sizes(unet, input_)


### Train the segmentation model

In [None]:

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = configs.num_epochs

# UNet built with freeze_encoder=True, moved onto the chosen device.
model = UNet(freeze_encoder=True).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=configs.lr)

# Cosine-anneal the learning rate over a period of `num_epochs` scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)


In [None]:

# Training objective. NOTE(review): the meaning of the two positional weights is defined
# by loss.MixedLoss — confirm there before tuning.
loss_fn = MixedLoss(10.0, 2.0)
# Metrics tracked on both splits; the two collectors share this same dict object.
validation_losses = dict(dice_loss=dice_loss, mixed_loss=loss_fn, dice_metric=dice_metric)
val_metric_collector, train_metric_collector = (
    MetricCollector(validation_losses, set_label=split_label)
    for split_label in ('validation', 'train')
)


In [None]:

# Baseline: score the model on the validation set before any training step.
# Logged under epoch -1 so it precedes the first training epoch (index 0).
BASELINE_EPOCH = -1
print('Random model loss validation')
evaluate(model, val_dataloader, device, val_metric_collector, epoch=BASELINE_EPOCH)


In [None]:

# Train for the configured number of epochs. The loop bound must match `num_epochs`:
# the cosine scheduler was created with T_max=num_epochs, so a hard-coded shorter loop
# (previously 5 of 40 epochs) would only ever run the first part of the LR schedule.
# Each epoch: one training pass, one scheduler step, then a validation pass.
for epoch_idx in range(num_epochs):
    train_one_epoch(model, optimizer, loss_fn, train_dataloader, device, epoch_idx, train_metric_collector)
    lr_scheduler.step()
    evaluate(model, val_dataloader, device, val_metric_collector, epoch_idx)
