# Task 3

In [None]:
import os
import random
import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.transforms.functional as TF
from torchvision import transforms, utils
from torchvision.io import read_image

from skimage.transform import resize
from skimage import exposure
from skimage.morphology import erosion, dilation, opening, closing, disk, square, square

from PIL import Image
from kornia.filters import MedianBlur

from albumentations.augmentations.transforms import HueSaturationValue, RandomBrightness, RandomContrast, Flip, Normalize
from albumentations.augmentations import Resize, CLAHE
from albumentations.augmentations import CenterCrop
from albumentations.core.composition import Compose, OneOf
from albumentations.augmentations import Rotate
from albumentations.augmentations.geometric.transforms import Affine

import pickle
import gzip
from tqdm import tqdm
from beepy import beep


In [None]:
# setting seeds
# One seed value is pushed into every random number generator used below.
seed_value = 1508

# NOTE(review): PYTHONHASHSEED is normally read when the interpreter starts,
# so this assignment presumably only reaches child processes -- confirm.
os.environ['PYTHONHASHSEED']=str(seed_value)
# Python, NumPy and PyTorch (CUDA and default) generators.
random.seed(seed_value)
np.random.seed(seed_value)
torch.cuda.manual_seed(seed_value)
torch.manual_seed(seed_value)
# Ask cuDNN for deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [None]:
# parameters
# TRAINING: when False the training cell only prints "Training skipped...".
TRAINING = True
# MAX_EPOCH: number of epochs run by the training loop.
MAX_EPOCH = 40
# IMG_SIZE: size handed to transforms.Resize for every frame/mask pair.
IMG_SIZE = 128
# BATCH_SIZE: batch size of the training and validation loaders.
BATCH_SIZE = 10

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Uncomment the next line to force CPU execution.
#device = "cpu"
print(device)

## Load data

In [None]:
# load data
def load_zipped_pickle(filename):
    """Return the object stored in the gzip-compressed pickle at *filename*.

    NOTE: unpickling can execute arbitrary code; only open trusted files.
    """
    with gzip.open(filename, 'rb') as handle:
        return pickle.load(handle)

# Read the three gzip-pickled files from the working directory.
train_data = load_zipped_pickle("train.pkl")
test_data = load_zipped_pickle("test.pkl")
samples = load_zipped_pickle("sample.pkl")

split_index = 60

# split data
train = train_data[:split_index] # 56 # 46 is first expert
validation = train_data[split_index:] # collect only expert labels for val

# prepare data for dataloader
# Training and validation use only the frame indices listed under 'frames'
# (presumably the annotated frames -- confirm against the data description).
X_train = [vid['video'][:, :, frame] for vid in train for frame in vid['frames']]
y_train = [vid['label'][:, :, frame] for vid in train for frame in vid['frames']]
name_train = [vid['name'] for vid in train for _ in vid['frames']]

X_val = [vid['video'][:, :, frame] for vid in validation for frame in vid['frames']]
y_val = [vid['label'][:, :, frame] for vid in validation for frame in vid['frames']]
name_val = [vid['name'] for vid in validation for _ in vid['frames']]

# The test set has no labels: every frame is used and paired with an
# all-zero placeholder mask of the same shape and dtype.
X_test = [vid['video'][:, :, frame]
          for vid in test_data
          for frame in range(vid['video'].shape[2])]
y_test = [np.zeros_like(test_frame) for test_frame in X_test]
name_test = [vid['name']
             for vid in test_data
             for _ in range(vid['video'].shape[2])]

# Per-video targets at original resolution, used later for scoring.
val_targets = [
    {'name': vid['name'], 'label': vid['label'][:, :, vid['frames']]}
    for vid in validation
]

train_targets = [
    {'name': vid['name'], 'label': vid['label'][:, :, vid['frames']]}
    for vid in train
]

sub_targets = [
    {'name': vid['name'], 'label': np.zeros_like(vid['video'])}
    for vid in test_data
]

In [None]:
class RotationTransform:
    """Rotate the input by a fixed angle.

    Despite the plural parameter name, ``angles`` is stored as given and
    handed unchanged to ``TF.rotate`` as its angle argument.
    """

    def __init__(self, angles):
        self.angles = angles

    def __call__(self, x):
        return TF.rotate(x, self.angles)

# Fixed 15-degree rotation instance.
# NOTE(review): not referenced anywhere in the visible part of this file.
rotation_transform = RotationTransform(angles=15)


class AddGaussianNoise(object):
    """Add non-negative noise derived from a normal distribution.

    A sample ``randn * std + mean`` with the shape of the input is divided by
    ten and its absolute value is added, so the result is never smaller than
    the input.
    """

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        gaussian = torch.randn(tensor.size()) * self.std + self.mean
        return tensor + torch.abs(gaussian / 10.)

    def __repr__(self):
        return f'{self.__class__.__name__}(mean={self.mean}, std={self.std})'

In [None]:
def seed_worker(worker_id):
    """Re-seed NumPy and ``random`` from the current PyTorch seed.

    Intended as a ``worker_init_fn`` for ``DataLoader``; ``worker_id`` is
    accepted for that interface but not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)

# Dedicated generator seeded with seed_value; it is handed to the DataLoaders
# below through their `generator=` argument.
g = torch.Generator()
g.manual_seed(seed_value)

class CustomImageDataset(Dataset):
    """Dataset of single frames paired with their segmentation masks.

    Args:
        inputs: sequence of 2-D frames (indexed with ``[idx]``).
        labels: sequence of masks aligned with ``inputs``.
        name: sequence with the source-video name of every frame.
        augmentation: when True, random geometric and intensity
            augmentations are applied in ``__getitem__``.

    ``__getitem__`` returns ``(img, label, video_name)`` where ``img`` and
    ``label`` are single-channel tensors produced by the same crop/resize
    (and, with augmentation, the same random rotation/shear), and ``label``
    is binarised to 0./1.
    """

    def __init__(self, inputs, labels, name, augmentation=False):
        self.img_labels = labels
        self.img = inputs
        self.name = name
        self.augmentation = augmentation

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        image = self.img[idx]
        label = self.img_labels[idx]

        # Geometric steps act on image and mask stacked as two channels, so
        # both always receive exactly the same transformation.  Nearest
        # interpolation is therefore used for the image as well.
        geometric_steps = [
            transforms.ToTensor(),
            transforms.CenterCrop(image.shape[0]),
            transforms.Resize(IMG_SIZE, interpolation=Image.NEAREST),
        ]
        if self.augmentation:
            geometric_steps.append(
                transforms.RandomChoice([
                    RotationTransform(angles=0),
                    RotationTransform(angles=5),
                    RotationTransform(angles=-5),
                    RotationTransform(angles=10),
                    RotationTransform(angles=-10),
                    transforms.RandomAffine(degrees=(0, 0), shear=(10, 10)),
                    transforms.RandomAffine(degrees=(0, 0), shear=(-10, -10)),
                ])
            )
        transf = transforms.Compose(geometric_steps)

        stacked = transf(np.dstack([image, label]))
        img, label = torch.chunk(stacked, chunks=2, dim=0)

        if self.augmentation:
            # Intensity augmentations touch the image only, never the mask.
            transf_img_alb = Compose([
                OneOf([
                    RandomBrightness(limit=0.3),
                    RandomContrast(limit=0.4),
                ], p=1),
            ])
            transf_img = transforms.Compose([AddGaussianNoise()])

            img = transf_img_alb(image=img.cpu().detach().numpy())
            img = torch.tensor(img['image'])
            img = transf_img(img)

            # Brightness/contrast/noise may leave [0, 1]; clip back.
            img = torch.clamp(img, 0., 1.)

        video_name = self.name[idx]
        # Any non-zero mask value counts as foreground.
        return img, label.bool().float(), video_name

In [None]:
# Datasets: only `training_data` applies random augmentation.
training_data = CustomImageDataset(
    X_train, y_train, name_train, augmentation=True
)
training_data_for_eval = CustomImageDataset(
    X_train, y_train, name_train, augmentation=False
)
validation_data = CustomImageDataset(
    X_val, y_val, name_val, augmentation=False
)
testing_data = CustomImageDataset(
    X_test, y_test, name_test, augmentation=False
)

# Batched loaders used by the training loop.
train_dataloader = DataLoader(
    training_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=False,
    worker_init_fn=seed_worker,
    generator=g,
)
val_dataloader = DataLoader(
    validation_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    worker_init_fn=seed_worker,
    generator=g,
)

# no augmentation
# Single-sample loaders in dataset order for per-frame inference.
train_dataloader_for_eval = DataLoader(
    training_data_for_eval,
    batch_size=1,
    shuffle=False,
    drop_last=False,
    worker_init_fn=seed_worker,
    generator=g,
)
val_dataloader_for_eval = DataLoader(
    validation_data,
    batch_size=1,
    shuffle=False,
    drop_last=False,
    worker_init_fn=seed_worker,
    generator=g,
)

test_dataloader_for_submission = DataLoader(
    testing_data,
    batch_size=1,
    shuffle=False,
    drop_last=False,
    worker_init_fn=seed_worker,
    generator=g,
)

In [None]:
# Sanity check: pull one batch from each loader and report tensor shapes.
train_features, train_labels, _ = next(iter(train_dataloader))
print(f"Feature train batch shape: {train_features.size()}")
print(f"Labels train batch shape: {train_labels.size()}")
print("-"*50)
val_features, val_labels, _ = next(iter(val_dataloader))
print(f"Feature val batch shape: {val_features.size()}")
print(f"Labels val batch shape: {val_labels.size()}")

# Index of the sample inside the training batch that is displayed below.
test_idx = 1

# Augmented image and its value range.
img = train_features[test_idx].squeeze()
print(img.min(), img.max())

# Matching mask and its value range.
img_m = train_labels[test_idx].squeeze()
print(img_m.min(), img_m.max())

# Overlay: mask pixels are painted with the value 1 on a copy of the image.
img_merge = img.clone()
img_merge[img_m.bool()] = 1
print(img_merge.min(), img_merge.max())

# Left to right: image, mask, overlay.
f, axarr = plt.subplots(1, 3, figsize=(15, 3))
axarr[0].imshow(img, cmap='gray')
axarr[1].imshow(img_m, cmap='gray')
axarr[2].imshow(img_merge, cmap='gray')
plt.show()

## Network

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive (3x3 convolution => batch norm => ReLU) stages.

    ``mid_channels`` is the width between the two stages; when it is falsy
    ``out_channels`` is used.  Padding of 1 keeps the spatial size.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels

        # Layers are created in the same order as they are applied so that
        # the Sequential indices (and therefore checkpoint keys) stay 0..5.
        stages = []
        for c_in, c_out in ((in_channels, mid_channels), (mid_channels, out_channels)):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder step: 2x2 max pooling followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder step: upsample, merge with the skip connection, DoubleConv.

    With ``bilinear=True`` a fixed bilinear upsampling is used and the
    DoubleConv narrows through ``in_channels // 2``; otherwise a transposed
    convolution halves the channel count while upsampling.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
            return

        # Bilinear upsampling has no weights, so the convolutions take over
        # the channel reduction.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        upsampled = self.up(x1)

        # Dimensions 2 and 3 are height and width; pad the upsampled map so
        # it matches the skip connection before concatenating.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        # Skip-connection channels come first, upsampled channels second.
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

In [None]:
class UNet(nn.Module):
    """U-Net with four down steps, four up steps and skip connections.

    Args:
        n_channels: number of input channels.
        n_classes: number of output channels (raw logits, no activation).
        bilinear: forwarded to every ``Up`` step; with bilinear upsampling
            the deeper layers use half the channel count.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        factor = 2 if bilinear else 1

        # Encoder.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        # Decoder.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        # Output head.
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder outputs are kept for the skip connections.
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottleneck = self.down4(skip4)

        decoded = self.up1(bottleneck, skip4)
        decoded = self.up2(decoded, skip3)
        decoded = self.up3(decoded, skip2)
        decoded = self.up4(decoded, skip1)
        return self.outc(decoded)

## Training

In [None]:
def iou_score(output, target):
    """Return the smoothed intersection-over-union of two masks.

    Tensor ``output`` is treated as logits (sigmoid is applied); tensor
    ``target`` is converted as is.  Both are binarised at 0.5.  The result is
    a NumPy scalar and carries no gradient.
    """
    smooth = 1e-5

    if torch.is_tensor(output):
        output = torch.sigmoid(output).detach().cpu().numpy()
    if torch.is_tensor(target):
        target = target.detach().cpu().numpy()

    pred_mask = output > 0.5
    true_mask = target > 0.5

    overlap = (pred_mask & true_mask).sum()
    combined = (pred_mask | true_mask).sum()
    return (overlap + smooth) / (combined + smooth)

In [None]:
if TRAINING:
    # BW
    # Single-channel input, single-channel logit output.
    model = UNet(1, 1).to(device)
    criterion = nn.BCEWithLogitsLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    min_valid_loss = np.inf
    train_losses = []
    val_losses = []

    # Batch losses are accumulated weighted by batch size, so the epoch mean
    # has to be taken over the number of samples, not the number of batches.
    n_train_samples = len(train_dataloader.dataset)
    n_val_samples = len(val_dataloader.dataset)

    for epoch in range(MAX_EPOCH):
        train_loss = 0.0
        valid_loss = 0.0

        model.train()
        for data in train_dataloader:

            # get inputs
            inputs, labels, names = data

            # send to device
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the params gradients
            optimizer.zero_grad()

            # forward, backward, optimize
            outputs = model(inputs)

            loss_bce = criterion(outputs, labels)
            # iou_score returns a NumPy value, so the IoU term only shifts
            # the reported loss; gradients come from the BCE term alone.
            loss_iou = iou_score(outputs, labels)
            loss = loss_bce * 0.5 + (loss_iou * -0.5 + 0.5)
            loss.backward()

            optimizer.step()

            # accumulate the batch loss weighted by batch size
            train_loss += loss.item() * inputs.size(0)

        # Validation: evaluation mode and no autograd bookkeeping, so
        # BatchNorm statistics are not updated with validation data.
        model.eval()
        with torch.no_grad():
            for data_val in val_dataloader:

                # get inputs
                inputs_val, labels_val, names_val = data_val

                # send to device
                inputs_val, labels_val = inputs_val.to(device), labels_val.to(device)

                target_val = model(inputs_val)

                loss_val = criterion(target_val, labels_val)
                loss_iou_val = iou_score(target_val, labels_val)
                # Same combination as the training loss: lower is better.
                loss_val = loss_val * 0.5 + (loss_iou_val * -0.5 + 0.5)
                valid_loss += loss_val.item() * inputs_val.size(0)

        epoch_train_loss = train_loss / n_train_samples
        epoch_valid_loss = valid_loss / n_val_samples

        train_losses.append(epoch_train_loss)
        val_losses.append(epoch_valid_loss)

        print(f'Epoch {epoch+1} \t\t Training loss: {epoch_train_loss} \t\t Validation loss: {epoch_valid_loss}')

        if min_valid_loss > epoch_valid_loss:

            print(f'Validation Loss Decreased({min_valid_loss: .6f}----->{epoch_valid_loss:.6f}) \t Saving the model..')
            min_valid_loss = epoch_valid_loss

            # save model
            torch.save(model.state_dict(), 'unet_min.pth')

        # periodic checkpoint every fifth epoch index (5, 10, ...)
        if epoch % 5 == 0 and epoch > 0:
            torch.save(model.state_dict(), 'unet_' + str(epoch) + '.pth')


    print('Finished Training')
    torch.save(model.state_dict(), 'unet.pth')

    plt.plot(train_losses)
    plt.plot(val_losses)
    plt.show()

    beep(4)

else:
    print("Training skipped...")

In [None]:
# testing
def test_net(test_dataloader, net_name):
    """Run a saved U-Net over a loader on the CPU.

    Args:
        test_dataloader: loader yielding ``(inputs, labels, names)``.  Only
            ``names[0]`` is stored per batch, so a batch size of 1 is
            presumably expected -- confirm against the callers.
        net_name: path of the ``state_dict`` checkpoint to load.

    Returns:
        ``(mask, mask_aux)``: two lists with one dict per batch.  ``mask``
        holds ``name`` and the binary ``prediction`` (sigmoid > 0.5, float);
        ``mask_aux`` holds ``name``, ``input`` and ``label``.
    """
    # Build the network and load the weights once, not once per batch.
    # map_location keeps GPU-saved checkpoints loadable on CPU-only machines.
    unet = UNet(1, 1).to('cpu')
    unet.load_state_dict(torch.load(net_name, map_location='cpu'))
    unet.eval()

    mask = []
    mask_aux = []

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for data in tqdm(test_dataloader):

            # get inputs
            inputs_test, labels_test, names_test = data

            # send to device
            inputs_test, labels_test = inputs_test.to('cpu'), labels_test.to('cpu')

            # predict and turn logits into probabilities
            preds_test = torch.sigmoid(unet(inputs_test))

            mask.append({
                'name': names_test[0],
                'prediction': (preds_test.cpu() > 0.5).float()
            })

            mask_aux.append({
                'name': names_test[0],
                'input': inputs_test.cpu(),
                'label': labels_test.cpu()
            })

    return mask, mask_aux

def show_masks(mask, mask_aux):
    """Plot input, label and prediction for up to the first 15 samples.

    ``mask`` and ``mask_aux`` are the two lists returned by ``test_net``.
    Shorter lists are handled by plotting only the samples that exist.
    """
    n_shown = min(15, len(mask), len(mask_aux))
    for test_ind in range(n_shown):
        f, axarr = plt.subplots(1, 3, figsize=(10, 3))
        axarr[2].imshow(mask[test_ind]['prediction'].detach().numpy().squeeze(0).squeeze(0), cmap='gray')
        axarr[0].imshow(mask_aux[test_ind]['input'].detach().numpy().squeeze(0).squeeze(0), cmap='gray')
        axarr[1].imshow(mask_aux[test_ind]['label'].detach().numpy().squeeze(0).squeeze(0), cmap='gray')
        plt.show()

# Checkpoint evaluated on all three splits.
net_name_to_test = 'unet_25.pth'

val_mask, val_mask_aux = test_net(
    test_dataloader=val_dataloader_for_eval,
    net_name=net_name_to_test,
)
train_mask, train_mask_aux = test_net(
    test_dataloader=train_dataloader_for_eval,
    net_name=net_name_to_test,
)
sub_mask, sub_mask_aux = test_net(
    test_dataloader=test_dataloader_for_submission,
    net_name=net_name_to_test,
)

In [None]:
# postprocessing
def postprocess(prd, label, image, show=False):
    """Erode a predicted mask with a 2x2 square footprint.

    Args:
        prd: prediction tensor whose two leading dimensions are of size 1.
        label: label tensor of the same layout (used for plotting only).
        image: input tensor of the same layout (used for plotting only).
        show: when True, plot image, thresholded image, label, raw
            prediction and eroded prediction side by side.

    Returns:
        The eroded prediction as a tensor with the two leading singleton
        dimensions restored.
    """
    raw_pred = prd.numpy().squeeze(0).squeeze(0)

    # erode mask
    erosion_square = square(2) # slightly better with just erosion 2
    eroded = erosion(raw_pred, erosion_square)

    if show:
        label_2d = label.numpy().squeeze(0).squeeze(0)
        image_2d = image.numpy().squeeze(0).squeeze(0)
        # threshold on original image (display only)
        image_th = (image_2d*255 > 30).astype(float)

        f, axarr = plt.subplots(1, 5, figsize=(15, 3))
        axarr[0].imshow(image_2d, cmap='gray')
        axarr[1].imshow(image_th, cmap='gray')
        axarr[2].imshow(label_2d, cmap='gray')
        axarr[3].imshow(raw_pred, cmap='gray')
        # Plot the 2-D eroded array; the 4-D tensor cannot be shown.
        axarr[4].imshow(eroded, cmap='gray')
        plt.show()

    # return to tensor
    return torch.tensor(eroded).unsqueeze(0).unsqueeze(0)


def postprocess_dataset(dataset, dataset_aux):
    """Apply ``postprocess`` to every prediction of *dataset* in place.

    ``dataset_aux`` supplies the matching ``label`` and ``input`` by index.
    The (mutated) ``dataset`` list is returned.
    """
    for index in tqdm(range(len(dataset))):
        entry = dataset[index]
        aux = dataset_aux[index]
        entry['prediction'] = postprocess(entry['prediction'], aux['label'], aux['input'])
    return dataset
        
# Erode the predictions of every split.
val_mask = postprocess_dataset(dataset=val_mask, dataset_aux=val_mask_aux)
train_mask = postprocess_dataset(dataset=train_mask, dataset_aux=train_mask_aux)
sub_mask = postprocess_dataset(dataset=sub_mask, dataset_aux=sub_mask_aux)

In [None]:
def evaluate(predictions, targets):
    """Print IoU statistics for paired predictions and targets.

    Args:
        predictions: list of dicts with ``name`` and ``prediction``.
        targets: list of dicts with ``name`` and ``label``, in the same
            order as ``predictions``.

    Each pair must share its name and array shape.  When both masks of a
    pair are empty the IoU is counted as 1.0 (perfect agreement) instead of
    the NaN that 0/0 would give and that would poison every statistic.
    """
    ious = []
    for p, t in zip(predictions, targets):
        assert p['name'] == t['name']
        prediction = np.array(p['prediction'], dtype=bool)
        target = np.array(t['label'], dtype=bool)

        assert target.shape == prediction.shape
        overlap = np.logical_and(prediction, target).sum()
        union = np.logical_or(prediction, target).sum()

        ious.append(overlap / float(union) if union else 1.0)

    print("\nMedian IOU:", round(np.median(ious), 3))
    print("\nMean IOU:  ", round(np.mean(ious), 3))
    print("\nMin IOU:   ", round(np.min(ious), 3))
    print("Q10 IOU:   ", round(np.quantile(ious, 0.10), 3))
    print("Q15 IOU:   ", round(np.quantile(ious, 0.15), 3))
    print("Q25 IOU:   ", round(np.quantile(ious, 0.25), 3))
    print("Q40 IOU:   ", round(np.quantile(ious, 0.40), 3))
    print("Q50 IOU:   ", round(np.quantile(ious, 0.50), 3))
    print("Q75 IOU:   ", round(np.quantile(ious, 0.75), 3))
    print("Max IOU:   ", round(np.max(ious), 3))


def search_by_name(name, dataset):
    """Return ``(index, item)`` of the first item named *name*, else None."""
    matches = (
        (position, entry)
        for position, entry in enumerate(dataset)
        if entry['name'] == name
    )
    return next(matches, None)

def revert_changes(img, size):
    """Undo the resize and centre crop applied before inference.

    Args:
        img: prediction whose two leading dimensions are of size 1.
        size: ``(size[0], size[1])`` of the original frame.

    The mask is resized to ``size[0] x size[0]`` and then zero-padded along
    the second axis up to ``size[1]``; an odd difference puts the extra
    column on the right.

    NOTE(review): ``resize`` is called with its defaults, so the result may
    contain interpolated values; callers that cast to bool will treat every
    non-zero value as foreground -- confirm this is intended.
    """
    side = size[0]
    plane = np.array(img.squeeze(0).squeeze(0))
    plane = resize(plane, (side, side))

    size_diff = size[1] - side
    left_pad = int(size_diff / 2)
    # size_diff % 2 is 1 exactly when the difference is odd.
    right_pad = left_pad + size_diff % 2

    return np.pad(plane, ((0, 0), (left_pad, right_pad)))

def revert_mask(mask, target):
    """Bring every prediction in *mask* back to its target's frame size.

    The size is taken from the first two dimensions of the ``label`` of the
    target with the same name.  Returns a new list of
    ``{'name', 'prediction'}`` dicts in the order of *mask*.
    """
    reverted = []
    for entry in mask:
        matched_target = search_by_name(entry['name'], target)[1]
        original_size = matched_target['label'].shape[:2]
        reverted.append({
            'name': entry['name'],
            'prediction': revert_changes(entry['prediction'], original_size),
        })
    return reverted

def get_predictions(mask, target):
    """Group reverted per-frame predictions into one array per name.

    Args:
        mask: list of per-frame ``{'name', 'prediction'}`` dicts.
        target: list of target dicts passed on to ``revert_mask``.

    Returns:
        A list of ``{'name', 'prediction'}`` dicts ordered by first
        appearance of each name.  Frames sharing a name are stacked along a
        third axis in the order they occur; a name with a single frame keeps
        its 2-D array.  Names are used as dict keys and must be hashable.
    """
    # Collect frames per name first and stack once per name; this avoids a
    # linear search and a growing array copy for every single frame.
    frames_by_name = {}
    for rev_item in tqdm(revert_mask(mask, target)):
        frames = frames_by_name.setdefault(rev_item['name'], [])
        frames.append(rev_item['prediction'].astype(bool))

    predictions = []
    for name, frames in frames_by_name.items():
        stacked = frames[0] if len(frames) == 1 else np.dstack(frames)
        predictions.append({'name': name, 'prediction': stacked})

    return predictions

# Stack the per-frame masks of every split back into per-video volumes.
val_predictions = get_predictions(mask=val_mask, target=val_targets)
train_predictions = get_predictions(mask=train_mask, target=train_targets)
sub_predictions = get_predictions(mask=sub_mask, target=sub_targets)

print('VALIDATION SCORE:')
evaluate(predictions=val_predictions, targets=val_targets)

print('\nTRAINING SCORE:')
evaluate(predictions=train_predictions, targets=train_targets)

# The submission targets are all-zero placeholders, so this score is only a
# smoke test of the pipeline.
print('\nSUBMISSION MOCK SCORE:')
evaluate(predictions=sub_predictions, targets=sub_targets)

In [None]:
def save_zipped_pickle(obj, filename):
    """Pickle *obj* with protocol 2 into the gzip-compressed file *filename*."""
    with gzip.open(filename, 'wb') as handle:
        pickle.dump(obj, handle, protocol=2)
        
# Write the per-video test-set predictions to the submission file.
save_zipped_pickle(sub_predictions, 'submission.gzip')

### Post-processing

In [None]:
# Show the network input stored for the first validation sample.
plt.imshow(val_mask_aux[0]['input'].squeeze(), cmap='gray')
plt.show()

In [None]:
# Validation sample to inspect.
idx = 2
# Left to right: input, label, post-processed prediction.
f, axarr = plt.subplots(1, 3, figsize=(10, 3))
axarr[0].imshow(val_mask_aux[idx]['input'].squeeze(), cmap='gray')
axarr[1].imshow(val_mask_aux[idx]['label'].squeeze(), cmap='gray')
axarr[2].imshow(val_mask[idx]['prediction'].squeeze(), cmap='gray')
plt.show()