# Preprocessing
Imports, data loading, and dataloader definitions.

In [56]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch
import PIL.Image
import os
import seaborn as sbn
from albumentations import GaussNoise, RandomGridShuffle, Normalize, PixelDropout
import elasticdeform
from skimage.io import imread, imsave

# Apply seaborn's default styling to the matplotlib figures created later on.
sbn.set()

# Root folder of the dataset; ChallengeData builds its image and label paths under it.
datadir = "EM_ISBI_CHALLENGE"

In [57]:
# Random image subset
# Fixed seed so the train/validation split is identical on every run.
np.random.seed(42)

# Five uniform draws scaled to [0, 31), truncated to integers and put in
# ascending order.
# NOTE(review): for other seeds this scheme can produce 0 or repeated indices;
# presumably there is no image numbered 00 -- confirm before changing the seed.
validation_subset_idx = sorted((np.random.rand(5) * 31).astype(int).tolist())

# Every image number in 1..30 that was not drawn for validation is used for training.
train_subset_idx = [i for i in range(1, 31) if i not in validation_subset_idx]

In [58]:
# Make dataset class.
class ChallengeData(torch.utils.data.Dataset):
    """Dataset that loads every image of one split into memory up front.

    The split is chosen with ``type``:

    * ``'train'``      -- images and labels for the indices in the module-level
      ``train_subset_idx``.
    * ``'validation'`` -- images and labels for the indices in the module-level
      ``validation_subset_idx``.
    * ``'test'``       -- images only, for indices 1..30.

    Files are read as ``<data_dir>/<folder>/<prefix>_<NN>.png`` and every pixel
    value is divided by 255 before being stored as a float32 tensor.

    Args:
        data_dir: Root folder containing ``train_images``, ``train_labels``
            and ``test_images``.
        type: One of ``'train'``, ``'test'`` or ``'validation'``.
        margin_size: Accepted for backward compatibility; currently unused.

    Raises:
        ValueError: If ``type`` is not one of the supported split names.
    """

    _SPLITS = ('train', 'test', 'validation')

    def __init__(self, data_dir, type, margin_size=20):
        # Fail early with a clear message instead of an AttributeError later on.
        if type not in self._SPLITS:
            raise ValueError(
                f"type must be one of {self._SPLITS}, got {type!r}")

        self.images = []
        self.labels = []

        self.train_set = type == 'train'
        self.test_set = type == 'test'
        self.validation_set = type == 'validation'

        self.data_dir = data_dir
        # Train and validation images both live in the training folders.
        has_labels = self.train_set or self.validation_set
        self.img_prefix = 'train' if has_labels else 'test'
        self.img_folder = 'train_images' if has_labels else 'test_images'
        self.labels_folder = 'train_labels'
        self.labels_prefix = 'labels'

        if self.train_set:
            self.load_train_set()
        elif self.test_set:
            self.load_test_set()
        else:
            self.load_validation_set()

    @staticmethod
    def _read_scaled(path):
        """Read the image at *path* and return it as a float32 tensor scaled by 1/255."""
        # The context manager closes the underlying file once the pixels are copied.
        with PIL.Image.open(path) as img:
            pixels = np.array(img)
        return torch.tensor(pixels / 255, dtype=torch.float32)

    def _load(self, indices, with_labels):
        """Append the images (and optionally the labels) for *indices* to the dataset."""
        for idx in indices:
            self.images.append(self._read_scaled(
                f'{self.data_dir}/{self.img_folder}/{self.img_prefix}_{idx:02d}.png'))
            if with_labels:
                self.labels.append(self._read_scaled(
                    f'{self.data_dir}/{self.labels_folder}/{self.labels_prefix}_{idx:02d}.png'))

    def load_train_set(self):
        """Load image/label pairs for the indices in ``train_subset_idx``."""
        self._load(train_subset_idx, with_labels=True)

    def load_validation_set(self):
        """Load image/label pairs for the indices in ``validation_subset_idx``."""
        self._load(validation_subset_idx, with_labels=True)

    def load_test_set(self):
        """Load the test images numbered 1..30; the test split has no labels."""
        self._load(range(1, 31), with_labels=False)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for train/validation data, else just the image."""
        if self.train_set or self.validation_set:
            return self.images[idx], self.labels[idx]

        return self.images[idx]

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.images)

In [50]:
# Build the train, validation and test datasets, in that order.
# (Each one reads all of its images from disk on construction, which may take
# a few seconds.)
challengeTrainData, challengeValidationData, challengeTestData = (
    ChallengeData(datadir, type=split)
    for split in ('train', 'validation', 'test')
)


In [51]:
# Training batches are reshuffled each epoch; drop_last discards the final
# incomplete batch so every training batch holds exactly 10 samples.
trainloader = torch.utils.data.DataLoader(
    challengeTrainData, batch_size=10, shuffle=True, drop_last=True)

# Validation and test data are served in dataset order, without shuffling.
validationloader = torch.utils.data.DataLoader(
    challengeValidationData, batch_size=5)
testloader = torch.utils.data.DataLoader(
    challengeTestData, batch_size=10)

# Model Definition
Model classes

In [61]:
#%% Make model class.
class UNet128(torch.nn.Module):
    """Small three-level U-Net built from unpadded 3x3 convolutions.

    Because no convolution is padded, the output is 40 pixels smaller than the
    input along each spatial axis (e.g. 128 -> 88, 512 -> 472). The skip
    connections are cropped by fixed margins, which only line up with the
    upsampled feature maps when the input height and width are multiples of 4
    (and large enough to survive the convolutions).

    Args:
        out_channels: Number of channels produced by the final 1x1 convolution.
        in_channels: Number of channels of the input images. Defaults to 1
            (single-channel images).
    """

    def __init__(self, out_channels=2, in_channels=1):
        super().__init__()

        # Learnable layers. Creation order is kept stable so that seeded
        # weight initialisation stays reproducible.
        self.conv1A = torch.nn.Conv2d(in_channels, 8, 3)
        self.conv1B = torch.nn.Conv2d(8, 8, 3)
        self.conv2A = torch.nn.Conv2d(8, 16, 3)
        self.conv2B = torch.nn.Conv2d(16, 16, 3)
        self.conv3A = torch.nn.Conv2d(16, 32, 3)
        self.conv3B = torch.nn.Conv2d(32, 32, 3)
        # The decoder convolutions take the concatenation of the upsampled
        # map and the cropped skip connection, hence the doubled input width.
        self.conv4A = torch.nn.Conv2d(32, 16, 3)
        self.conv4B = torch.nn.Conv2d(16, 16, 3)
        self.conv5A = torch.nn.Conv2d(16, 8, 3)
        self.conv5B = torch.nn.Conv2d(8, 8, 3)
        self.convfinal = torch.nn.Conv2d(8, out_channels, 1)
        self.convtrans34 = torch.nn.ConvTranspose2d(32, 16, 2, stride=2)
        self.convtrans45 = torch.nn.ConvTranspose2d(16, 8, 2, stride=2)

        # Convenience (parameter-free) layers.
        self.relu = torch.nn.ReLU()
        self.pool = torch.nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: Tensor of shape ``(N, in_channels, H, W)`` with ``H`` and ``W``
                multiples of 4.

        Returns:
            Tensor of shape ``(N, out_channels, H - 40, W - 40)`` holding the
            raw (un-normalised) output of the final 1x1 convolution.
        """
        # Down, keeping the layer outputs needed for the skip connections.
        l1 = self.relu(self.conv1B(self.relu(self.conv1A(x))))
        l2 = self.relu(self.conv2B(self.relu(self.conv2A(self.pool(l1)))))
        out = self.relu(self.conv3B(self.relu(self.conv3A(self.pool(l2)))))

        # Up, overwriting `out` at each step. The skip tensors are cropped by
        # the margin lost to the unpadded convolutions below them (4 and 16
        # pixels per side) so they match the upsampled maps.
        out = torch.cat([self.convtrans34(out), l2[:, :, 4:-4, 4:-4]], dim=1)
        out = self.relu(self.conv4B(self.relu(self.conv4A(out))))
        out = torch.cat([self.convtrans45(out), l1[:, :, 16:-16, 16:-16]], dim=1)
        out = self.relu(self.conv5B(self.relu(self.conv5A(out))))

        # Finishing: 1x1 convolution down to the requested number of channels.
        out = self.convfinal(out)

        return out

# Training parameters
Specify here which types of data augmentation are used and whether a weight map is applied, and define the models and the number of epochs.

In [62]:
# Instantiate the network with its default arguments.
net = UNet128()

# Shape sanity check on a dummy all-zero batch holding one single-channel
# 512x512 image; the recorded output below shows a two-channel 472x472 result.
net(torch.zeros(1, 1, 512, 512)).shape

torch.Size([1, 2, 472, 472])

# Training Loop
Training loop for the model, using the parameters defined above.

# Results
Visualizations of the results for the report