# Combined Notebook

# Imports and Global Variables

In [None]:
# initial imports

import time
import datetime
import math
import os
import re
import cv2
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
import random
from PIL import Image
import gc
import json
import re
import shutil

The configuration cell below can be adjusted to reproduce any of our experiments.

In [None]:
# some constants
# Experiment configuration; these module-level values are read by the cells below.
CONFIGURATION_NAME = 'unet3plusRS_bcedicemixcustom_nopre_allval'  # presumably names logs/checkpoints — not referenced in this part of the notebook
SEED = 123  # seed for Python's, NumPy's and PyTorch's global RNGs
PATCH_SIZE = 16  # pixels per side of square patches
CUTOFF = 0.25  # a patch is labeled road (1) when the mean of its mask pixels exceeds this value
DATASETS = ['kaggle_data', 'new_data']  # folder names under cil/datasets; 'kaggle_data' is never filtered
BALANCE_DATASETS = False # if true each dataset will contribute a sample of equal size to the training set
MAX_SAMPLE_SIZE = 0 # number of files to sample for each training iteration (0 => all files)
VAL_SIZE = 200  # size of the validation set (number of images)
VALIDATE_ONLY_KAGGLE = False # choose whether to only validate on kaggle data (mainly for early stopping)
BATCH_SIZE = 8
N_EPOCHS = 60
USE_EARLY_STOPPING = True

# filtering of the (non-kaggle) data, using min / max thresholds for the fraction of road-labeled
# patches per image (see CUTOFF), as well as a similarity threshold and which similarity network
# (alex or VGG) to compare against
FILTER_DATA = True
MIN_ROADS_PCT = 0.07  # images whose road-patch fraction is <= this are removed
MAX_ROADS_PCT = 0.7  # images whose road-patch fraction is >= this are removed
PREPROCESS_SIMILARITY_SCORE = 0.85  # images whose mean similarity score is >= this are removed
SIMILARITIY_NETWORK = "alex"  # part of the precomputed similarity CSV file name
SIMILARITIY_IMAGE_TYPE = "groundtruth"  # part of the precomputed similarity CSV file name

# should never set both augmentation options to true!
AUGMENT_DATA = False # creates a file for every 90 degree rotation for all images
AUGMENT_ONLY_KAGGLE = True # does the augmentation, but only for the included kaggle data
PREPROCESS_CONTRAST = False # normalize contrast: subtract each sample's mean, divide by the dataset-wide std
PREPROCESS_RANDOMIZE = False # randomize contrast and brightness
PREPROCESS_GAUSSIAN_BLUR_KERNEL_SIZE = 0 # kernel size of a Gaussian blur applied to every input (0 disables it)
POSTPROCESSING_MAJORITY = 4 # if over 0, rotates and predicts multiple times for each test image, taking the majority consensus
POSTPROCESSING_MORPHOLOGICAL_ITERATIONS = 0 # if over 0, morph. postprocessing is used

# a pre-processing feature extraction filter used for detecting tubular structures;
# when enabled, its response is appended to each input as an extra channel
PREPROCESS_FRANGI = False

# global constants used to implement full-scale skip connections in the U-Net3+
GLOBAL_RES = 384
GLOBAL_FIRST_CHANNELS = 64

# our models and loss functions are defined throughout the course of this notebook,
# so the MODELS and LOSS_FNS dictionaries below start with None placeholders and are
# filled in as the notebook executes.
MODEL = 'unet'  # model to be used (key of MODELS)
LOSS_FN = 'bce' # loss function to be used (key of LOSS_FNS)
# Dictionaries to store models and loss functions to make them accessible
MODELS = {
    'unet': None,             # basic UNet()
    'sdunet': None,           # SDUNet
    'attention_sdunet': None, # Attention UNet using SDC blocks
    'unet3plusRS': None       # Our custom U-Net3+RS
    }
LOSS_FNS = {
    'bce': None,                  # vanilla nn.BCELoss()
    'soft_dice': None,            # soft dice loss
    'custom_bce': None,           # a custom softening idea for the focal BCE, not used in final experiments due to poor performance
    'focal_bce': None,            # vanilla focal BCE
    'bce_dice_mix': None,         # mix of dice and focal BCE
    'custom_bce_dice_mix': None,  # a custom softening idea for the focal BCE, not used in final experiments
    'focal_dice_chi': None        # our final custom loss function, detailed in the report
    }

# training can be resumed from a loaded pytorch model checkpoint by setting this parameter
LOAD_CHECKPOINT = None
# this dictionary is used to create logging output for reproducibility of any experiment;
# 'train_set_size' is filled in once the training set has been built
HYPERPARAMETERS = {
    'datasets': DATASETS,
    'balance_datasets': BALANCE_DATASETS,
    'train_set_size': 0,
    'val_set_size': VAL_SIZE,
    'n_epochs': N_EPOCHS,
    'seed': SEED,
    'model': MODEL,
    'loss_fn': LOSS_FN,
    'frangi': PREPROCESS_FRANGI
    }
# training progress counters (presumably updated by the training loop — not visible here)
STATE = {
    'total_iterations': 0,
    'current_iteration': 0,
    'epoch': 0,
    'time_trained': 0
    }

# enable some debugging prints
DEBUG = True

In [None]:
# globally seed Python's `random` module and NumPy's global RNG from SEED for reproducibility
for _seed_global_rng in (random.seed, np.random.seed):
    _seed_global_rng(SEED)

# Data

## Functions
Here we define various utility functions for loading, filtering and augmenting the data.

In [None]:
def load_from_path(path, isGroundtruth, sample=None):
    """Load PNG images from ``path`` into one float32 numpy array stacked on axis 0.

    Args:
        path: directory containing the ``.png`` files.
        isGroundtruth: if True, images are converted to 1-bit mode ('1'), giving values
            0/1; otherwise they are converted to RGB and scaled to [0, 1].
        sample: optional sequence of basenames to load, in the given order. None or an
            empty sequence loads every ``*.png`` in ``path`` in sorted order.

    Returns:
        float32 array stacking the loaded images along a new first axis.

    Raises:
        ValueError: if no files are selected.
    """
    # None replaces the former mutable default ``[]``; an empty sequence still means "all files"
    if sample is None or len(sample) == 0:
        files = sorted(glob(path + '/*.png'))
    else:
        files = [os.path.join(path, basename) for basename in sample]
    if not files:
        raise ValueError(f"no .png files to load from {path!r}")

    mode = '1' if isGroundtruth else 'RGB'
    images = []
    for f in files:
        # the context manager closes the file handle as soon as the pixels are read
        with Image.open(f) as img:
            images.append(np.array(img.convert(mode)))
    img_stack = np.stack(images).astype(np.float32)
    if not isGroundtruth:
        img_stack = img_stack / 255.

    print(f"{len(files)} files loaded. Shape = {img_stack.shape}. Max Value = {img_stack.max()}")
    return img_stack

def generate_sample(path, sample_size, already_sampled=None):
    """Pick basenames of ``*.png`` files in ``path`` that have not been sampled before.

    Args:
        path: directory to list.
        sample_size: number of files to draw with ``random.sample``. A value <= 0, or one
            not smaller than the number of candidates, returns every candidate.
        already_sampled: optional iterable of basenames to exclude.

    Returns:
        Sorted list of basenames.
    """
    # a set gives O(1) membership tests; None replaces the former mutable default ``[]``
    excluded = set() if already_sampled is None else set(already_sampled)
    all_files = [os.path.basename(f) for f in sorted(glob(path + '/*.png'))]
    sample = [f for f in all_files if f not in excluded]
    if 0 < sample_size < len(sample):
        sample = sorted(random.sample(sample, sample_size))
    print(f"Generated sample consisting of {len(sample)} files.")
    return sample

def show_first_n(imgs, masks, n=5):
    """Plot the first ``n`` images (top row) above their segmentation masks (bottom row).

    Args:
        imgs: sequence of images accepted by ``plt.imshow``.
        masks: sequence of masks aligned with ``imgs``.
        n: maximum number of image/mask pairs to draw (default 5).
    """
    imgs_to_draw = min(n, len(imgs))  # previously hard-coded to 5, ignoring ``n``
    if imgs_to_draw <= 0:
        return  # nothing to draw
    # squeeze=False keeps axs two-dimensional even when a single column is drawn
    fig, axs = plt.subplots(2, imgs_to_draw, figsize=(18.5, 6), squeeze=False)
    for i in range(imgs_to_draw):
        axs[0, i].imshow(imgs[i])
        axs[1, i].imshow(masks[i])
        axs[0, i].set_title(f'Image {i}')
        axs[1, i].set_title(f'Mask {i}')
        axs[0, i].set_axis_off()
        axs[1, i].set_axis_off()
    plt.show()

In [None]:
# returns the indices of images whose fraction of road patches is out of range (to be excluded)
def filter_roads(masks, patch_size=None, cutoff=None, min_pct=None, max_pct=None):
    """Return indices of masks with too few or too many road patches.

    Each mask is split into square patches the same way as ``image_to_patches``; a
    patch counts as road when its mean pixel value exceeds ``cutoff``. An image is
    selected (for exclusion by the caller) when its fraction of road patches is
    <= ``min_pct`` or >= ``max_pct``.

    Args:
        masks: array indexed as (image, row, column[, channel]) of ground-truth masks.
        patch_size: patch side length in pixels; defaults to PATCH_SIZE.
        cutoff: road threshold on a patch's mean value; defaults to CUTOFF.
        min_pct: lower road-fraction bound; defaults to MIN_ROADS_PCT.
        max_pct: upper road-fraction bound; defaults to MAX_ROADS_PCT.

    Returns:
        Ascending list of int indices of the images to exclude.

    Raises:
        ValueError: if the mask height or width is not a multiple of ``patch_size``.
    """
    patch_size = PATCH_SIZE if patch_size is None else patch_size
    cutoff = CUTOFF if cutoff is None else cutoff
    min_pct = MIN_ROADS_PCT if min_pct is None else min_pct
    max_pct = MAX_ROADS_PCT if max_pct is None else max_pct

    n_images = masks.shape[0]
    h, w = masks.shape[1:3]
    if h % patch_size or w % patch_size:
        raise ValueError(f"mask size {h}x{w} is not divisible by patch size {patch_size}")
    h_patches = h // patch_size
    w_patches = w // patch_size

    # (image, h_patch, row, w_patch, col, ch) -> (image, h_patch, w_patch, row, col, ch)
    patches = masks.reshape((n_images, h_patches, patch_size, w_patches, patch_size, -1))
    patches = np.moveaxis(patches, 2, 3)
    labels = (np.mean(patches, (-1, -2, -3)) > cutoff).astype(np.float32)

    # fraction of road patches per image, vectorized over all images
    road_fraction = labels.reshape(n_images, h_patches * w_patches).mean(axis=1)
    exclude = (road_fraction <= min_pct) | (road_fraction >= max_pct)
    return np.flatnonzero(exclude).tolist()

# Load data with this: e.g. tTI = genfromtxt('cil/results/similarities/similaritiesAlexTrainTrainImages.csv', delimiter=',')
def filter_similarities(data, threshold=None):
    """Return indices of rows whose mean similarity reaches ``threshold``.

    Args:
        data: 2-D array-like of similarity scores, one row per image (e.g. a CSV read
            with ``np.genfromtxt(..., delimiter=',')``).
        threshold: rows with mean >= threshold are selected; defaults to
            PREPROCESS_SIMILARITY_SCORE.

    Returns:
        Ascending list of int row indices (images to exclude).
    """
    if threshold is None:
        threshold = PREPROCESS_SIMILARITY_SCORE
    # np.mean also accepts plain Python lists, unlike row.mean()
    return [i for i, row in enumerate(data) if np.mean(row) >= threshold]

In [None]:
# go through all the images and rotate them 3 times, producing an augmented dataset
def produce_rotations(sourcePath, onlykaggle=False):
    """Save 90, 180 and 270 degree rotated copies of every PNG in ``sourcePath``.

    Each copy is written next to its source as ``<stem>_<angle>.png``. The file list is
    taken before anything is written, so new copies are not rotated again in this call.

    Args:
        sourcePath: directory holding the images (or masks) to augment.
        onlykaggle: if True, only files whose name contains 'satimage' (the kaggle
            images) are augmented, giving that data more weight.
    """
    # only PNGs are considered, matching the '*.png' globs used by all loaders
    files = sorted(f for f in os.listdir(sourcePath) if f.endswith('.png'))
    if onlykaggle:
        files = [f for f in files if 'satimage' in f]
    for img_file in files:
        stem = os.path.splitext(img_file)[0]
        # the context manager closes the source file once all rotations are saved
        with Image.open(os.path.join(sourcePath, img_file)) as img:
            for angle in (90, 180, 270):
                img.rotate(angle).save(os.path.join(sourcePath, f"{stem}_{angle}.png"))

## Import data from GitHub

In [None]:
# record the start time so the duration of the clone below can be reported
start_download_data = time.time()

In [None]:
# clone the project repository; later cells read cil/datasets and cil/results/similaritiesForPipeline from it
!git clone https://github.com/Rasilu/cil

In [None]:
stop_download_data = time.time()
# elapsed wall-clock seconds for the clone, rounded to two decimals
print(f"Time to load data: {round(stop_download_data - start_download_data, 2)}")

## Create training and validation sets
These cells implement our data filtering as described in our report. Importantly, we never filter the included Kaggle data. At the start of every run, we also delete from the training folder any images that are also in the validation folder. This enables faster testing: the data does not have to be re-downloaded when the notebook is restarted and re-run.

In [None]:
# prepare datasets, by loading and filtering them
# Copies every dataset in DATASETS from the cloned repo into datasets/, then (if FILTER_DATA)
# deletes training images/masks that fail the road-fraction or similarity filters.
print("Prepare datasets ...")
!mkdir -p datasets
!rm -r datasets/*
!cp -r cil/datasets/kaggle_data .
for d in DATASETS:
    !cp -r cil/datasets/{d} datasets
    # never filter kaggle data!
    if (not FILTER_DATA) or (d == 'kaggle_data'):
        continue;
    
    # load the filenames of the given dataset
    data_path = 'datasets/' + d
    data_filtered_path = 'datasets/filtered/' + d  # not used in the visible cells
    data_train_img_path = data_path + '/training/images'
    data_train_msk_path = data_path + '/training/groundtruth'
    dataset_img_files = sorted(glob(data_train_img_path + '/*.png'))
    dataset_msk_files = sorted(glob(data_train_msk_path + '/*.png'))  # not used in the visible cells
    data_basenames = list(map(os.path.basename, dataset_img_files))

    # load the masks and select images whose road-patch fraction is out of range
    dataset_masks = load_from_path(data_train_msk_path, isGroundtruth=True)
    filter_roads_exclude = filter_roads(dataset_masks)

    # select images that are too similar to the kaggle training data, using precomputed scores;
    # NOTE(review): assumes row i of the CSV corresponds to the i-th image in sorted order — confirm
    file_name = "similarities_" + d + "_kaggle_data_train_" + SIMILARITIY_IMAGE_TYPE + "_" + SIMILARITIY_NETWORK + ".csv"
    csv_np = np.genfromtxt("cil/results/similaritiesForPipeline/" + file_name, delimiter=',')
    filter_similar_exclude = filter_similarities(csv_np)

    # union of both exclusion lists (list membership tests, fine for these sizes)
    filter_combined_exclude = [idx for idx in range(len(dataset_img_files)) if idx in filter_similar_exclude or idx in filter_roads_exclude]
    print(f"{len(filter_combined_exclude)} files deleted from {d}")
    filter_dataset_exclude = np.array(data_basenames)[filter_combined_exclude]
    # an image and its mask share the same basename
    for basename in filter_dataset_exclude:
        os.remove(f"{data_train_img_path}/{basename}")
        os.remove(f"{data_train_msk_path}/{basename}")

In [None]:
def replace_name(s):
    """Return ``s`` with every occurrence of '.png' removed (e.g. 'img_1.png' -> 'img_1')."""
    return ''.join(s.split('.png'))

# create training set
# Rebuilds ./training (images/ and groundtruth/) from the prepared datasets/ folders.
print("Creating training set ...")
train_path = 'training'
train_images_path = 'training/images'
train_masks_path = 'training/groundtruth'
!mkdir -p {train_path}
!rm -r {train_path}/*
!mkdir {train_path}/images
!mkdir {train_path}/groundtruth
# provides the option of cutting each dataset down to the same amount of samples as the smallest dataset
if not BALANCE_DATASETS:
    # copy each dataset's training/ folder into ./training, merging the subfolders
    for d in DATASETS:
        !cp -r datasets/{d}/training .
else:
    # find the size of the smallest dataset
    smallest_dataset = 1000000
    for d in DATASETS:
        dataset_size = len(glob(f"datasets/{d}/training/images/*.png"))
        print(f"{dataset_size} files in {d}")
        if dataset_size < smallest_dataset:
            smallest_dataset = dataset_size
    print(f"Size of smallest dataset = {smallest_dataset} -> sample {smallest_dataset} files from each dataset")
    # move a seeded random sample of that size from every dataset into ./training
    for d in DATASETS:
        dataset_files = sorted(glob(f"datasets/{d}/training/images/*.png"))
        dataset_sample = sorted(random.sample(dataset_files, smallest_dataset))
        for img in dataset_sample:
            # NOTE(review): replaces every 'images' substring in the path, not only the folder name
            mask = img.replace('images', 'groundtruth')
            os.rename(img, img.replace(f"datasets/{d}/training", 'training'))
            os.rename(mask, mask.replace(f"datasets/{d}/training", 'training'))
# create validation set
print("\nCreating validation set ...")
!cp -r cil/datasets/kaggle_data .
val_path = 'validation'
!mkdir -p {val_path}
!rm -r {val_path}/*
!mkdir {val_path}/images
!mkdir {val_path}/groundtruth
kdata = ""
# provides the option to only validate on kaggle data
if VALIDATE_ONLY_KAGGLE:
    kdata = "kaggle_data/"
val_files = glob(kdata + "training/images/*.png")
if VAL_SIZE >= 0 and VAL_SIZE <= len(val_files):
    val_files = sorted(random.sample(val_files, VAL_SIZE))
for img in val_files:
    mask = img.replace('images', 'groundtruth')
    os.rename(img, img.replace(kdata + 'training', 'validation'))
    os.rename(mask, mask.replace(kdata + 'training', 'validation'))

# remove training images (and their rotated variants) whose base name is also in the validation set;
# useful in the case that the notebook runtime is restarted without files being reset
print("\nRemoving validation images from training set...")
train_files = sorted(glob(train_images_path + '/*.png'))    # ex. ['training/images/satimage_0.png', ...]
train_basenames = list(map(os.path.basename, train_files))  # ex. ['satimage_0.png', 'satimage_1.png', ...]
val_basenames = list(map(os.path.basename, val_files))      # ex. ['satimage_2.png', ...]
val_base = [os.path.splitext(split)[0] for split in val_basenames]  # ex. ['satimage_2', ...]
# escape the stems so characters such as '.' or '+' in file names are matched literally
val_base_re = '|'.join(map(re.escape, val_base))            # ex. 'satimage_2|satimage_15|...|satimage_140'
# a file matches when its name starts with a validation stem followed by a non-digit, so the
# stem 'satimage_2' matches 'satimage_2.png' and 'satimage_2_90.png' but not 'satimage_20.png'
pat = re.compile(f'({val_base_re})\\D')
count = 0
# with no validation stems the pattern would be '()\D', which matches every file whose
# name does not start with a digit, so nothing must be removed in that case
if val_base:
    for train_file, train_basename in zip(train_files, train_basenames):
        if pat.match(train_basename):
            os.remove(train_file)
            os.remove(train_file.replace('/images', '/groundtruth'))
            count += 1
print(str(count) + " images removed.")

In [None]:
# rotation-based augmentation of the training set (see produce_rotations):
# AUGMENT_DATA augments every image and takes precedence; AUGMENT_ONLY_KAGGLE augments only
# the kaggle images. Running both would rotate the already-rotated kaggle copies again,
# so at most one of them is applied.
if AUGMENT_DATA and AUGMENT_ONLY_KAGGLE:
    print("Both AUGMENT_DATA and AUGMENT_ONLY_KAGGLE are set; augmenting all images only.")
if AUGMENT_DATA or AUGMENT_ONLY_KAGGLE:
    only_kaggle = not AUGMENT_DATA
    produce_rotations(train_images_path, onlykaggle=only_kaggle)
    produce_rotations(train_masks_path, onlykaggle=only_kaggle)


In [None]:
# sanity check: check size of training set and make sure the mask set is the same size
train_img_set_size = len(glob(train_images_path + '/*.png'))
train_msk_set_size = len(glob(train_masks_path + '/*.png'))
# raise explicitly (asserts are skipped under `python -O`) and report both counts
if train_img_set_size != train_msk_set_size:
    raise RuntimeError(f"Training set mismatch: {train_img_set_size} images but {train_msk_set_size} masks")
HYPERPARAMETERS['train_set_size'] = train_img_set_size
print("Training set size: " + str(train_img_set_size))

In [None]:
def image_to_patches(images, masks=None):
    # takes in a 4D np.array containing images and (optionally) a 4D np.array containing the segmentation masks
    # returns a 4D np.array with an ordered sequence of patches extracted from the image and (optionally) a np.array containing labels
    n_images = images.shape[0]  # number of images
    h, w = images.shape[1:3]  # shape of images
    assert (h % PATCH_SIZE) + (w % PATCH_SIZE) == 0  # make sure images can be patched exactly

    images = images[:,:,:,:3]
    
    h_patches = h // PATCH_SIZE
    w_patches = w // PATCH_SIZE
    
    patches = images.reshape((n_images, h_patches, PATCH_SIZE, w_patches, PATCH_SIZE, -1))
    patches = np.moveaxis(patches, 2, 3)
    patches = patches.reshape(-1, PATCH_SIZE, PATCH_SIZE, 3)
    if masks is None:
        return patches

    masks = masks.reshape((n_images, h_patches, PATCH_SIZE, w_patches, PATCH_SIZE, -1))
    masks = np.moveaxis(masks, 2, 3)
    labels = np.mean(masks, (-1, -2, -3)) > CUTOFF  # compute labels
    labels = labels.reshape(-1).astype(np.float32)
    return patches, labels


def show_patched_image(patches, labels, h_patches=25, w_patches=25):
    """Lay patches back out on an ``h_patches`` x ``w_patches`` grid and display them.

    The red channel of each patch is raised to at least its label, so patches labeled
    as road are painted red.
    """
    fig, axs = plt.subplots(h_patches, w_patches, figsize=(18.5, 18.5))
    for idx, (patch, label) in enumerate(zip(patches, labels)):
        row, col = divmod(idx, w_patches)
        ax = axs[row, col]
        red_overlay = np.array([label.item(), 0., 0.])
        ax.imshow(np.maximum(patch, red_overlay))
        ax.set_axis_off()
    plt.show()

In [None]:
def create_submission(labels, test_filenames, submission_filename, test_path='test/images', patch_size=None):
    """Write per-patch predictions to ``submission_filename`` as an ``id,prediction`` CSV.

    Args:
        labels: one 2-D array per test image, aligned with ``sorted(test_filenames)``.
        test_filenames: test image file names or paths; the first run of digits in each
            file's basename is taken as the image number.
        submission_filename: path of the CSV to (over)write.
        test_path: unused; kept so existing calls keep working.
        patch_size: patch side length used for the id offsets; defaults to PATCH_SIZE.

    For element ``[i, j]`` of an image's array the row id is
    ``<image number, zero-padded to 3 digits>_<j * patch_size>_<i * patch_size>`` and the
    prediction is ``int(value)``.

    Raises:
        ValueError: if a file's basename contains no digits.
    """
    patch_size = PATCH_SIZE if patch_size is None else patch_size
    with open(submission_filename, 'w') as f:
        f.write('id,prediction\n')
        for fn, patch_array in zip(sorted(test_filenames), labels):
            # search only the basename so digits in directory names are never used
            match = re.search(r"\d+", os.path.basename(fn))
            if match is None:
                raise ValueError(f"no image number found in test file name {fn!r}")
            img_number = int(match.group(0))
            for i in range(patch_array.shape[0]):
                for j in range(patch_array.shape[1]):
                    f.write("{:03d}_{}_{},{}\n".format(img_number, j*patch_size, i*patch_size, int(patch_array[i, j])))

## Dataloader

In [None]:
import torch
from torch import nn
from tqdm.notebook import tqdm
from torchvision import transforms
import torchvision
from skimage.filters import frangi

# seed PyTorch's global RNG with the same SEED as Python and NumPy
torch.manual_seed(SEED)
# use the GPU when CUDA is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: " + device)
# show GPU status via the nvidia-smi shell command
!nvidia-smi

def np_to_tensor(x, device):
    """Wrap a numpy array as a torch tensor on ``device``.

    On 'cpu' the returned tensor shares memory with ``x``. For any other device the
    data is made contiguous, pinned, and copied with ``non_blocking=True``.
    """
    tensor = torch.from_numpy(x)
    if device != 'cpu':
        return tensor.contiguous().pin_memory().to(device=device, non_blocking=True)
    return tensor.cpu()


class ImageDataset(torch.utils.data.Dataset):
    """Images and masks read from ``<path>/images`` and ``<path>/groundtruth``.

    With ``use_patches`` each image is split by ``image_to_patches`` into patches with
    one label each; otherwise whole images are kept and resized to ``resize_to`` when
    their size differs. ``__getitem__`` returns an (input, target) pair of tensors on
    ``device`` after the preprocessing enabled by the PREPROCESS_* flags.
    """

    def __init__(self, path, device, use_patches=True, resize_to=(400, 400), sample=None):
        """
        Args:
            path: dataset root containing ``images/`` and ``groundtruth/`` subfolders.
            device: 'cpu' or another torch device string; tensors are created there per item.
            use_patches: split images into labeled patches instead of keeping whole images.
            resize_to: ``dsize`` passed to ``cv2.resize`` when not using patches.
            sample: optional list of basenames to load; None or empty loads every PNG.
        """
        self.path = path
        self.device = device
        self.use_patches = use_patches
        self.resize_to = resize_to
        self.x, self.y, self.n_samples = None, None, None
        self._load_data(sample)
        if PREPROCESS_CONTRAST:
            # one standard deviation over the whole dataset, shared by all samples
            self.deviation = np.std(self.x)

    def _load_data(self, sample=None):
        """Load images/masks into ``self.x`` (CHW inputs) and ``self.y``, and set ``n_samples``."""
        # load_from_path treats an empty list as "load every *.png"; None replaces the
        # former shared mutable default
        sample = [] if sample is None else sample
        self.x = load_from_path(os.path.join(self.path, 'images'), isGroundtruth=False, sample=sample)[:,:,:,:3]
        self.y = np.ceil(load_from_path(os.path.join(self.path, 'groundtruth'), isGroundtruth=True, sample=sample))
        if self.use_patches:  # split each image into patches
            self.x, self.y = image_to_patches(self.x, self.y)
        elif self.resize_to != (self.x.shape[1], self.x.shape[2]):  # resize images
            # NOTE(review): cv2's dsize is (width, height) while shape[1:3] is (height, width);
            # this only matters for non-square resize_to — confirm if that is ever used
            self.x = np.stack([cv2.resize(img, dsize=self.resize_to) for img in self.x], 0)
            self.y = np.stack([cv2.resize(mask, dsize=self.resize_to) for mask in self.y], 0)
        self.x = np.moveaxis(self.x, -1, 1)  # pytorch works with CHW format instead of HWC
        self.n_samples = len(self.x)

    def _preprocess(self, x, y):
        """Apply the preprocessing selected by the PREPROCESS_* flags to one input tensor."""
        # various options for preprocessing can be configured at the start of the notebook
        if PREPROCESS_CONTRAST:
            x = (x - torch.mean(x)) / self.deviation
        if PREPROCESS_RANDOMIZE:
            jitter = torchvision.transforms.ColorJitter(brightness=0.1, contrast=0.3)
            x = jitter.forward(x)
        if PREPROCESS_GAUSSIAN_BLUR_KERNEL_SIZE > 0:
            transform = transforms.Compose([transforms.GaussianBlur(kernel_size=PREPROCESS_GAUSSIAN_BLUR_KERNEL_SIZE)])
            x = transform(x)
        if PREPROCESS_FRANGI:
            # filters only work on grayscale image, so it must first be converted;
            # assumes x is a single CHW image (no batch axis), so squeezing axis 0 leaves HW
            with torch.no_grad():
                grayx = np.squeeze(transforms.Grayscale().forward(x).cpu().numpy(), axis=0)
                stackx = frangi(grayx)
                # now stack this information (feature extraction) onto the channel dimension of x
                x = torch.cat([x, torch.unsqueeze(np_to_tensor(stackx.astype(np.float32), self.device), 0)], dim=0)
        return x, y

    def __getitem__(self, item):
        # indexing y with [item] keeps a leading length-1 axis on the target
        return self._preprocess(np_to_tensor(self.x[item], self.device), np_to_tensor(self.y[[item]], self.device))

    def __len__(self):
        return self.n_samples


def show_val_samples(x, y, y_hat, segmentation=False):
    """Training callback that plots up to five validation samples with predictions.

    When ``x`` and ``y`` share their last two dimensions the samples are treated as
    segmentation: row 1 shows the input, row 2 the predicted mask and row 3 the true
    mask. Otherwise they are treated as patch classification and each panel is titled
    with the rounded true and predicted labels.

    Args:
        x: per-sample inputs in channel-first layout (moved to channel-last for display).
        y: targets aligned with ``x``.
        y_hat: predictions aligned with ``x``.
        segmentation: unused; the mode is inferred from the shapes. Kept for callers.
    """
    imgs_to_draw = min(5, len(x))
    if imgs_to_draw == 0:
        return  # nothing to draw
    if x.shape[-2:] == y.shape[-2:]:  # segmentation
        # squeeze=False keeps axs two-dimensional even when only one sample is drawn
        fig, axs = plt.subplots(3, imgs_to_draw, figsize=(18.5, 12), squeeze=False)
        for i in range(imgs_to_draw):
            axs[0, i].imshow(np.moveaxis(x[i], 0, -1))
            # presumably single-channel masks; repeated 3x on the last axis for display
            axs[1, i].imshow(np.concatenate([np.moveaxis(y_hat[i], 0, -1)] * 3, -1))
            axs[2, i].imshow(np.concatenate([np.moveaxis(y[i], 0, -1)] * 3, -1))
            axs[0, i].set_title(f'Sample {i}')
            axs[1, i].set_title(f'Predicted {i}')
            axs[2, i].set_title(f'True {i}')
            axs[0, i].set_axis_off()
            axs[1, i].set_axis_off()
            axs[2, i].set_axis_off()
    else:  # classification
        # squeeze=False: with a single sample plt.subplots would otherwise return one Axes
        fig, axs = plt.subplots(1, imgs_to_draw, figsize=(18.5, 6), squeeze=False)
        for i in range(imgs_to_draw):
            ax = axs[0, i]
            ax.imshow(np.moveaxis(x[i], 0, -1))
            ax.set_title(f'True: {np.round(y[i]).item()}; Predicted: {np.round(y_hat[i]).item()}')
            ax.set_axis_off()
    plt.show()

# Models
All models outlined in our report are implemented here.


In [None]:
class Block(nn.Module):
    """Double 3x3 convolution: conv -> ReLU -> BatchNorm -> conv -> ReLU.

    Padding 1 keeps the spatial size; channels go from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_ch),
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        # the attribute name and layer order define the state_dict keys (block.<i>.*)
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out

        
class UNet(nn.Module):
    """U-Net-like network for single-class semantic segmentation.

    ``chs`` lists channel counts from the input through each encoder level. The decoder
    mirrors them back to ``chs[1]`` and a 1x1 conv + sigmoid yields one output channel.
    Each decoder Block receives upconv output concatenated with the skip (out_ch + out_ch
    channels), which equals its in_ch when channels double per level as in the default.
    Height and width should be divisible by 2 ** (len(chs) - 2) so skips line up.
    """

    def __init__(self, chs=(3,64,128,256,512,1024)):
        super().__init__()
        dec_chs = chs[::-1][:-1]  # decoder channel counts, deepest first
        enc_pairs = list(zip(chs[:-1], chs[1:]))
        dec_pairs = list(zip(dec_chs[:-1], dec_chs[1:]))
        # modules are created in the original order so seeded initialization is unchanged
        self.enc_blocks = nn.ModuleList([Block(c_in, c_out) for c_in, c_out in enc_pairs])
        self.pool = nn.MaxPool2d(2)  # parameter-free, shared by all levels
        self.upconvs = nn.ModuleList([nn.ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in dec_pairs])
        self.dec_blocks = nn.ModuleList([Block(c_in, c_out) for c_in, c_out in dec_pairs])
        self.head = nn.Sequential(nn.Conv2d(dec_chs[-1], 1, 1), nn.Sigmoid())

    def forward(self, x):
        skips = []
        *down_blocks, bottleneck = self.enc_blocks
        # encoder: keep each level's output for the skip connections, then downsample
        for enc in down_blocks:
            x = enc(x)
            skips.append(x)
            x = self.pool(x)
        x = bottleneck(x)
        # decoder: upsample, concatenate the matching skip, refine
        for dec, up, skip in zip(self.dec_blocks, self.upconvs, reversed(skips)):
            x = dec(torch.cat([up(x), skip], dim=1))
        return self.head(x)
        

In [None]:
class StackedDilConv(nn.Module):
    """Stacked dilated convolution (SDC) block.

    A chain of 3x3 'same'-padded convolutions with dilations 1, 3, 6, 9 and 12, each fed
    by the previous one and each followed by ReLU. Their outputs (out_ch/2, out_ch/4,
    out_ch/8, out_ch/16 and out_ch/16 channels, integer-divided) are concatenated, which
    gives out_ch channels when out_ch is divisible by 16 and mixes features from
    receptive fields of increasing size.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        def conv_relu(c_in, c_out, dilation):
            return nn.Sequential(
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding='same', dilation=dilation),
                nn.ReLU(),
            )

        # attribute names are kept as-is: they define the state_dict keys of saved checkpoints
        self.hlf = conv_relu(in_ch, out_ch // 2, 1)
        self.quarter = conv_relu(out_ch // 2, out_ch // 4, 3)
        self.eigth = conv_relu(out_ch // 4, out_ch // 8, 6)
        self.sixt1 = conv_relu(out_ch // 8, out_ch // 16, 9)
        self.sixt2 = conv_relu(out_ch // 16, out_ch // 16, 12)

    def forward(self, x):
        outputs = []
        for stage in (self.hlf, self.quarter, self.eigth, self.sixt1, self.sixt2):
            x = stage(x)
            outputs.append(x)
        return torch.cat(outputs, dim=1)
            

class StackedDilBloc(nn.Module):
    """Two SDC blocks separated by BatchNorm, the SDC analogue of U-Net's double conv.

    The BatchNorm expects out_ch channels from the first SDC block, so out_ch should
    be divisible by 16.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        first_sdc = StackedDilConv(in_ch, out_ch)
        second_sdc = StackedDilConv(out_ch, out_ch)
        # Sequential order (block.0 / block.1 / block.2) fixes the state_dict keys
        self.block = nn.Sequential(first_sdc, nn.BatchNorm2d(out_ch), second_sdc)

    def forward(self, x):
        out = self.block(x)
        return out

class UNetStackedDilations(nn.Module):
    """SDU-Net: the U-Net layout with every double-conv block replaced by StackedDilBloc.

    ``chs`` lists channel counts from the input through each encoder level. The decoder
    mirrors them back to ``chs[1]`` and a 1x1 conv + sigmoid yields one output channel.
    """

    def __init__(self, chs=(3,64,128,256,512,1024)):
        super().__init__()
        dec_chs = chs[::-1][:-1]  # decoder channel counts, deepest first
        enc_pairs = list(zip(chs[:-1], chs[1:]))
        dec_pairs = list(zip(dec_chs[:-1], dec_chs[1:]))
        # modules are created in the original order so seeded initialization is unchanged
        self.enc_blocks = nn.ModuleList([StackedDilBloc(c_in, c_out) for c_in, c_out in enc_pairs])
        self.pool = nn.MaxPool2d(2)  # parameter-free, shared by all levels
        self.upconvs = nn.ModuleList([nn.ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in dec_pairs])
        self.dec_blocks = nn.ModuleList([StackedDilBloc(c_in, c_out) for c_in, c_out in dec_pairs])
        self.head = nn.Sequential(nn.Conv2d(dec_chs[-1], 1, 1), nn.Sigmoid())

    def forward(self, x):
        skips = []
        *down_blocks, bottleneck = self.enc_blocks
        # encoder: keep each level's output for the skip connections, then downsample
        for enc in down_blocks:
            x = enc(x)
            skips.append(x)
            x = self.pool(x)
        x = bottleneck(x)
        # decoder: upsample, concatenate the matching skip, refine
        for dec, up, skip in zip(self.dec_blocks, self.upconvs, reversed(skips)):
            x = dec(torch.cat([up(x), skip], dim=1))
        return self.head(x)

In [None]:
class AttentionBlock(nn.Module):
    """Additive attention gate for U-Net skip connections (https://arxiv.org/abs/1804.03999).

    ``g`` is the gating signal from one level lower (2 * in_ch channels, half resolution).
    ``x`` is the skip feature (in_ch channels, full resolution). Both are projected to
    out_ch channels and summed. The sum passes through ReLU, is reduced to one channel
    and a sigmoid, then upsampled 2x by a transposed convolution. The resulting map
    multiplies ``x``. Because the transposed conv follows the sigmoid, the map is not
    restricted to [0, 1].
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # gating path: 1x1 projection of the coarser level's 2 * in_ch channels
        self.W_g = nn.Sequential(
            nn.Conv2d(in_ch * 2, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(out_ch),
        )
        # skip path: stride 2 halves the resolution so it matches g
        self.W_x = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=2, padding=0, bias=True),
            nn.BatchNorm2d(out_ch),
        )
        # single-channel attention map, upsampled back to the resolution of x
        self.psi = nn.Sequential(
            nn.Conv2d(out_ch, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
            nn.ConvTranspose2d(1, 1, 2, 2),
        )
        self.relu = nn.ReLU()

    def forward(self, g, x):
        gate = self.relu(self.W_g(g) + self.W_x(x))
        return x * self.psi(gate)


class UNetStackedDilationsAttention(nn.Module):
    """Attention U-Net built from SDC blocks.

    Same layout as the SDU-Net, but every skip feature is passed through an
    AttentionBlock, gated by the decoder state of the level below, before it is
    concatenated.
    """

    def __init__(self, chs=(3,64,128,256,512,1024)):
        super().__init__()
        dec_chs = chs[::-1][:-1]  # decoder channel counts, deepest first
        enc_pairs = list(zip(chs[:-1], chs[1:]))
        dec_pairs = list(zip(dec_chs[:-1], dec_chs[1:]))
        # modules are created in the original order so seeded initialization is unchanged
        self.enc_blocks = nn.ModuleList([StackedDilBloc(c_in, c_out) for c_in, c_out in enc_pairs])
        self.pool = nn.MaxPool2d(2)  # parameter-free, shared by all levels
        self.upconvs = nn.ModuleList([nn.ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in dec_pairs])
        self.dec_blocks = nn.ModuleList([StackedDilBloc(c_in, c_out) for c_in, c_out in dec_pairs])
        # one gate per skip: skip has c_out channels, gate projects to c_out // 2
        self.att_blocks = nn.ModuleList([AttentionBlock(c_out, c_out // 2) for c_out in dec_chs[1:]])
        self.head = nn.Sequential(nn.Conv2d(dec_chs[-1], 1, 1), nn.Sigmoid())

    def forward(self, x):
        skips = []
        *down_blocks, bottleneck = self.enc_blocks
        # encoder: keep each level's output for the skip connections, then downsample
        for enc in down_blocks:
            x = enc(x)
            skips.append(x)
            x = self.pool(x)
        x = bottleneck(x)
        for dec, up, skip, gate in zip(self.dec_blocks, self.upconvs, reversed(skips), self.att_blocks):
            # the paper's "x" is the skip feature; its gating signal "g" is our current x,
            # taken before upsampling (the gate downsamples the skip with stride 2 instead)
            gated_skip = gate(g=x, x=skip)
            x = dec(torch.cat([up(x), gated_skip], dim=1))
        return self.head(x)

In [None]:
class SkipBlock(nn.Module):
    """Decoder block of UNet3+ implementing full-scale skip connections.

    The block works at one of four decoder resolutions, derived from the channel count
    `out_ch` of the encoder level it mirrors:
    ``expectedRes = GLOBAL_FIRST_CHANNELS * GLOBAL_RES // out_ch`` must equal
    ``GLOBAL_RES // 2**k`` for a level k in {0, 1, 2, 3}.

    Each of its five inputs is handled by one branch b1..b5. The inputs are the encoder
    feature maps 0..4, with the map at index k + 1 replaced by `prevX`, the output of the
    previous (deeper) decoder stage. A branch resamples its input to the block's resolution
    with a shared max-pool (shallower inputs) or upsample layer (deeper inputs), then reduces
    it to GLOBAL_FIRST_CHANNELS channels with conv-BN-ReLU. The five results are concatenated
    and fused by `finalConv` into GLOBAL_FIRST_CHANNELS * 5 channels.

    Args:
        out_ch: channel count of the encoder level this decoder block mirrors.
        sharedups: indexable upsampling layers; entry i is expected to upscale by 2**(i+1)
            (as built in UNet3plusRS).
        shareddowns: indexable pooling layers; entry i is expected to downscale by 2**(i+1).

    Raises:
        ValueError: if `out_ch` does not map onto one of the four supported resolutions
            (previously such a block was built without branches and failed only in forward).
    """

    def __init__(self, out_ch, sharedups, shareddowns):
        super().__init__()
        # depending on the amount of out-channels of the corresponding enc block,
        # a different configuration of max-pools, convolutions and up-convolutions is needed.
        # down and up sampling layers are shared, as they are not trainable.
        gfc = GLOBAL_FIRST_CHANNELS
        self.expectedRes = (gfc * GLOBAL_RES) // out_ch

        # decoder level k works at resolution GLOBAL_RES / 2**k
        levels = {GLOBAL_RES // 2 ** k: k for k in range(4)}
        if self.expectedRes not in levels:
            raise ValueError(
                f"SkipBlock: out_ch={out_ch} gives resolution {self.expectedRes}, "
                f"expected one of {sorted(levels)}")
        self.level = levels[self.expectedRes]

        # Build b1..b5 strictly in this order: it fixes both the random initialisation
        # sequence of the convolutions and the state_dict keys of saved checkpoints
        # (resampling layer first, so the conv sits at index 1 when a resampler is present).
        for i in range(5):
            if i == self.level + 1 and i < 4:
                in_ch = gfc * 5  # output of the previous decoder stage (finalConv width)
            else:
                in_ch = gfc * 2 ** i  # encoder level i (the bottleneck for i == 4)
            layers = []
            if i < self.level:
                layers.append(shareddowns[self.level - i - 1])  # shallower input -> pool down
            elif i > self.level:
                layers.append(sharedups[i - self.level - 1])  # deeper input -> upsample
            # note the inclusion of a batch norm and ReLU activation after every individual convolution.
            layers += [nn.Conv2d(in_ch, gfc, kernel_size=3, padding=1), nn.BatchNorm2d(gfc), nn.ReLU(inplace=True)]
            setattr(self, f"b{i + 1}", nn.Sequential(*layers))

        self.finalConv = nn.Sequential(nn.Conv2d(gfc*5, gfc*5, kernel_size=3, padding=1), nn.BatchNorm2d(gfc*5), nn.ReLU(inplace=True))

    # full scale --> uses all skip connections
    def forward(self, encList, prevX):
        """Fuse the five full-scale inputs at this block's resolution.

        Args:
            encList: the 5 encoder feature maps, shallowest first.
            prevX: output of the previous decoder stage (the bottleneck for the first stage);
                it replaces encList[level + 1].
        """
        inputs = list(encList[:5])
        inputs[self.level + 1] = prevX
        branches = (self.b1, self.b2, self.b3, self.b4, self.b5)
        return self.finalConv(torch.cat([branch(t) for branch, t in zip(branches, inputs)], dim=1))

class UNet3plusRS(nn.Module):
    """U-Net3+RS: stacked-dilation (SDC) encoder blocks plus UNet3+ full-scale skip decoding.

    Args:
        chs: channel counts; chs[0] is the number of input channels and chs[1:] are the
            encoder widths from the highest to the lowest resolution.
    """

    def __init__(self, chs=(3,64,128,256,512,1024)):
        super().__init__()
        # encoder: one StackedDilBloc per consecutive pair of channel counts
        self.enc_blocks = nn.ModuleList([StackedDilBloc(c_in, c_out) for c_in, c_out in zip(chs, chs[1:])])
        self.pool = nn.MaxPool2d(2)  # parameter-free, so a single instance is reused

        # parameter-free resampling layers shared by every SkipBlock:
        # upconvs[i] upsamples by 2**(i+1), downconvs[i] max-pools by 2**(i+1)
        self.upconvs = nn.ModuleList([nn.Upsample(scale_factor = 2 ** k, mode='bilinear') for k in range(1, 5)])
        self.downconvs = nn.ModuleList([nn.MaxPool2d(2 ** k) for k in range(1, 4)])

        # decoder: one full-scale SkipBlock per encoder level except the bottleneck, deepest first;
        # each block picks the shared resampling layers it needs
        decoder_widths = chs[::-1][1:-1]
        self.skip_blocks = nn.ModuleList([SkipBlock(width, self.upconvs, self.downconvs) for width in decoder_widths])
        # 1x1 convolution + sigmoid producing the single-channel output
        self.head = nn.Sequential(nn.Conv2d(GLOBAL_FIRST_CHANNELS * 5, 1, 1), nn.Sigmoid())

    def forward(self, x):
        """Encode, decode with full-scale skips, and return the 1-channel map."""
        *down_blocks, bottleneck = self.enc_blocks
        skips = []
        for enc_block in down_blocks:
            x = enc_block(x)
            skips.append(x)
            x = self.pool(x)
        x = bottleneck(x)
        # the deepest level is also needed by the full-scale skips
        skips.append(x)
        # each SkipBlock is a complete decoder stage
        for decoder_stage in self.skip_blocks:
            x = decoder_stage(skips, x)
        return self.head(x)

## Checkpoints
These utility functions let us store and load checkpoints of any PyTorch model during training.

In [None]:
# define path for saving checkpoints (one sub-directory per run, named after the
# configuration and the wall-clock time the cell was run)
current_time = datetime.datetime.now().strftime("%H%M%S")
checkpoint_path = f"checkpoints/{CONFIGURATION_NAME}_{current_time}"
# os.makedirs instead of `!mkdir`: creates both directory levels at once, does not fail
# when `checkpoints/` already exists, and also works outside IPython
os.makedirs(checkpoint_path, exist_ok=True)

In [None]:
def create_checkpoint(model, path, name = 'checkpoint', metadata = None):
    """Save a model checkpoint as a pair of files in directory `path`.

    Writes `<path>/<name>.pt` (the model's state_dict via torch.save) and
    `<path>/<name>.json` (the `metadata` dict, pretty-printed JSON).

    Args:
        model: torch module whose state_dict is saved.
        path: existing directory to write into.
        name: base file name without extension.
        metadata: JSON-serializable dict; defaults to an empty dict. (A `None` default replaces
            the former shared mutable `{}` default.)
    """
    if metadata is None:
        metadata = {}
    print(f"Creating checkpoint {name}")
    torch.save(model.state_dict(), os.path.join(path, name + '.pt'))
    with open(os.path.join(path, name + '.json'), 'w') as fp:
        json.dump(metadata, fp, indent=4)

def load_checkpoint(model, path):
    """Load a checkpoint written by `create_checkpoint` into `model`.

    Args:
        model: torch module matching the saved state_dict; it is updated in place and
            switched to eval mode.
        path: checkpoint path with or without extension. `<root>.pt` and `<root>.json` are
            read, plus `already_sampled.list` from the same directory.

    Returns:
        (model, metadata, already_sampled): the loaded model, the metadata dict and the
        list of file names stored in `already_sampled.list`.

    Raises:
        Warning: raised as an exception (after the weights were loaded) when the stored
            hyperparameters differ from the current global HYPERPARAMETERS.
    """
    root = os.path.splitext(path)[0]
    baseroot = os.path.splitext(os.path.basename(path))[0]
    dirname = os.path.dirname(path)
    print("Loading checkpoint " + baseroot)
    # load model; map tensors straight onto the model's device so that checkpoints
    # saved on a GPU can also be loaded on a CPU-only machine
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else 'cpu'
    model.load_state_dict(torch.load(root + '.pt', map_location=map_location))
    model.eval()
    # load metadata
    with open(root + '.json', 'r') as fp:
        metadata = json.load(fp)
    # load list of already sampled files; os.path.join also handles a bare file name
    # (dirname == ''), which previously resolved to '/already_sampled.list'
    with open(os.path.join(dirname, 'already_sampled.list'), 'r') as fp:
        already_sampled = [f.rstrip() for f in fp.readlines()]
    # the stored hyperparameters went through JSON (tuples became lists, non-string keys
    # became strings), so compare against a JSON round-trip of the current ones
    if metadata['hyperparameters'] != json.loads(json.dumps(HYPERPARAMETERS)):
        raise Warning("Hyperparameters do not match!")
    return model, metadata, already_sampled

## Loss Functions
In these cells we define various loss functions and their compounds. We iterated over several ideas before arriving at the final loss functions used in our report; the loss functions not mentioned in the report are still provided here for possible future use.

In [None]:
def accuracy_fn(y_hat, y):
    """Return the fraction of elements whose rounded prediction equals the rounded target."""
    matches = torch.eq(torch.round(y_hat), torch.round(y))
    return matches.float().mean()


def patch_accuracy_fn(y_hat, y):
    """Patch-level accuracy, the metric used for evaluation on Kaggle.

    Both maps are split into PATCH_SIZE x PATCH_SIZE patches; a patch counts as road when
    its mean exceeds CUTOFF. Returns the fraction of patches whose labels agree.
    """
    rows = y.shape[-2] // PATCH_SIZE
    cols = y.shape[-1] // PATCH_SIZE

    def road_patches(t):
        grid = t.reshape(-1, 1, rows, PATCH_SIZE, cols, PATCH_SIZE)
        return grid.mean((-1, -3)) > CUTOFF

    return (road_patches(y) == road_patches(y_hat)).float().mean()


def jaccard_index(y_hat, y, smooth=1):
    """Smoothed Jaccard index (intersection over union) computed over all elements."""
    overlap = torch.sum(y_hat * y)
    total = torch.sum(y_hat) + torch.sum(y)
    return ((overlap + smooth) / (total - overlap + smooth)).mean()


def dice_coef(y_hat, y, smooth=1):
    """Smoothed Dice coefficient: (2 * |A ∩ B| + smooth) / (|A| + |B| + smooth)."""
    overlap = torch.sum(y_hat * y)
    denominator = torch.sum(y_hat) + torch.sum(y) + smooth
    return (2. * overlap + smooth) / denominator


def soft_dice_loss(y_hat, y):
    """Soft Dice loss, i.e. one minus the smoothed Dice coefficient."""
    dice = dice_coef(y_hat, y)
    return 1 - dice

In [None]:
# we define our own weighted BCE loss function to cope with class imbalance
# careful: the weight parameter in the included nn.BCE just applies weighting to batch elements!
def BCELoss_class_weighted(weights):
    """Build a class-weighted binary cross-entropy loss.

    Args:
        weights: indexable pair; weights[0] scales the loss of background (target 0) pixels
            and weights[1] the loss of road (target 1) pixels.

    Returns:
        loss(input, target): mean weighted BCE, with `input` clamped to [1e-7, 1 - 1e-7]
        so the logarithms stay finite.
    """

    def loss(input, target):
        probs = input.clamp(min=1e-7, max=1 - 1e-7)
        per_pixel = - weights[1] * target * torch.log(probs) - (1 - target) * weights[0] * torch.log(1 - probs)
        return per_pixel.mean()

    return loss


# note: this weighting is based on the ratio of 1-pixels across all target (kaggle_data) training ground-truths
loss_fn_weighted = BCELoss_class_weighted([1, 0.75])

In [None]:
# side length of the averaging window used by the neighbourhood-aware losses;
# PADDING keeps the spatial size unchanged for stride-1 pooling with an odd kernel
KERNEL_SIZE = 15
PADDING = (KERNEL_SIZE - 1) // 2

# sharpening and connectivity loss amplification ideas
# these are not covered in our report, but included for future use here.
def create_sharpening_matrix(input, target):
    """Per-pixel penalties for false-one predictions, softened in road-dense regions.

    Predictions at or below CUTOFF are zeroed first. A pixel is penalised by how far the
    remaining prediction exceeds the target, scaled by (1 - local road density of the
    target), where the density is a KERNEL_SIZE x KERNEL_SIZE average (same output size as
    the target for stride 1 and PADDING). Values are clamped to [1e-7, 1 - 1e-7].
    """
    kept = nn.Threshold(CUTOFF, 0.0, inplace=False)(input)
    # positive only where the kept prediction is larger than the target (missed zeros);
    # correct 1-predictions are removed by this second threshold
    false_one_excess = nn.Threshold(0.0, 0.0, inplace=False)(-1.0 * (target - kept))
    # average pooling over the target approximates the "true oneness" of each neighbourhood
    density = torch.nn.AvgPool2d(KERNEL_SIZE, stride=1, padding=PADDING, count_include_pad=False)(target)
    penalties = false_one_excess * (1 - density)
    return torch.clamp(penalties, min=1e-7, max=1 - 1e-7)

def create_connectivity_matrix(input, target):
    """Per-pixel penalties for missed ones, hardened in road-dense regions.

    Predictions at or below CUTOFF are zeroed. A pixel is penalised by how far the target
    exceeds the remaining prediction (false ones are left to the sharpening matrix),
    scaled by the local road density of the target (KERNEL_SIZE x KERNEL_SIZE average).
    Values are clamped to [1e-7, 1 - 1e-7].
    """
    kept = nn.Threshold(CUTOFF, 0.0, inplace=False)(input)
    # positive only where the target is larger than the kept prediction (missed ones)
    missed_one_gap = nn.Threshold(0.0, 0.0, inplace=False)(target - kept)
    density = torch.nn.AvgPool2d(KERNEL_SIZE, stride=1, padding=PADDING, count_include_pad=False)(target)
    penalties = missed_one_gap * density
    return torch.clamp(penalties, min=1e-7, max=1 - 1e-7)

def Sharpening_Loss(beta=1):
    """Return a loss blending the mean sharpening and mean connectivity penalties.

    loss = beta * mean(sharpening matrix) + (1 - beta) * mean(connectivity matrix).
    """
    def sharpening(input, target):
        sharpen_term = torch.mean(create_sharpening_matrix(input, target))
        connect_term = torch.mean(create_connectivity_matrix(input, target))
        return beta * sharpen_term + (1 - beta) * connect_term
    return sharpening

sharpening_loss = Sharpening_Loss(1)  # sharpening penalties only
connectivity_loss = Sharpening_Loss(0)  # connectivity penalties only


In [None]:
# custom focal weighted BCE loss idea that was also experimented with
# but ultimately did not improve performance.
# the idea here is to soften penalties corresponding to how badly a pixel is
# misclassified relative to its neighbors: e.g a missed 1 surrounded by ground-truth
# ones should be penalized more heavily than a missed 1 surrounded by 0s.
def Focal_BCE(gamma):
    """Return a focal BCE loss whose terms are weighted by the local road density.

    With p the clamped prediction and d the KERNEL_SIZE x KERNEL_SIZE average of the
    target, the per-pixel loss is
    -(1-p)^gamma * y * log(p) * d - p^gamma * (1-y) * log(1-p) * (1-d), averaged over
    all elements.
    """
    def focal(input, target):
        probs = torch.clamp(input, min=1e-7, max=1 - 1e-7)
        density = torch.nn.AvgPool2d(KERNEL_SIZE, stride=1, padding=PADDING, count_include_pad=False)(target)
        road_term = - torch.pow(1 - probs, gamma) * target * torch.log(probs) * (density)
        background_term = torch.pow(probs, gamma) * (1 - target) * torch.log(1 - probs) * (1 - density)
        return torch.mean(road_term - background_term)
    return focal
loss_fn_focalbce = Focal_BCE(2)

# vanilla implementation of the Focal BCE loss
def Focal_BCE_vanilla(gamma):
    """Return the standard binary focal loss with focusing parameter `gamma`.

    Per pixel: -(1-p)^gamma * y * log(p) - p^gamma * (1-y) * log(1-p), with p clamped to
    [1e-7, 1 - 1e-7]; the mean over all elements is returned.
    """
    def focalvanilla(input, target):
        probs = torch.clamp(input, min=1e-7, max=1 - 1e-7)
        road_term = - torch.pow(1 - probs, gamma) * target * torch.log(probs)
        background_term = torch.pow(probs, gamma) * (1 - target) * torch.log(1 - probs)
        return torch.mean(road_term - background_term)
    return focalvanilla

# compound of the dice loss with a custom Focal BCE idea
def compound_focal_dice(beta):
    """Return loss = beta * density-weighted focal BCE (gamma=2) + (1 - beta) * soft Dice."""
    def compound(input, target):
        focal_term = Focal_BCE(2)(input, target)
        dice_term = soft_dice_loss(input, target)
        return beta * focal_term + (1 - beta) * dice_term
    return compound
loss_fn_focalbce_dice = compound_focal_dice(0.5)

# compound of the dice loss with the vanilla Focal BCE
def compound_focal_dice_vanilla(beta):
    """Return loss = beta * vanilla focal BCE (gamma=2) + (1 - beta) * soft Dice."""
    def compoundvanilla(input, target):
        focal_term = Focal_BCE_vanilla(2)(input, target)
        dice_term = soft_dice_loss(input, target)
        return beta * focal_term + (1 - beta) * dice_term
    return compoundvanilla
loss_fn_focalbce_dice_vanilla = compound_focal_dice_vanilla(0.5)
        

In [None]:
def chi_square_dist(y_hat, y):
    """Patch-wise chi-squared distance between predicted and target road densities.

    Both maps are averaged over non-overlapping PATCH_SIZE x PATCH_SIZE patches, and the
    mean of (p_hat - p)^2 / (p_hat + p) over all patches is returned. It is used as a
    component of loss functions.

    For non-negative inputs, patches where both means are 0 contribute 0. The denominator
    is clamped to 1e-7; previously such a patch produced 0/0 = NaN, which poisoned the whole
    loss and its gradients. For denominators above 1e-7 the result is unchanged.
    """
    h_patches = y.shape[-2] // PATCH_SIZE
    w_patches = y.shape[-1] // PATCH_SIZE
    patches_hat_mean = y_hat.reshape(-1, 1, h_patches, PATCH_SIZE, w_patches, PATCH_SIZE).mean((-1, -3))
    patches_mean = y.reshape(-1, 1, h_patches, PATCH_SIZE, w_patches, PATCH_SIZE).mean((-1, -3))
    numerator = torch.pow(patches_hat_mean - patches_mean, 2)
    denominator = torch.clamp(patches_hat_mean + patches_mean, min=1e-7)
    return torch.mean(torch.div(numerator, denominator))

def compound_focal_dice_chi():
    """Return the final loss: vanilla focal BCE (gamma=2) + soft Dice + patch chi-squared distance."""
    def compound(input, target):
        focal_term = Focal_BCE_vanilla(2)(input, target)
        dice_term = soft_dice_loss(input, target)
        chi_term = chi_square_dist(input, target)
        return focal_term + dice_term + chi_term
    return compound
loss_fn_focal_dice_chi = compound_focal_dice_chi()


# Training Loop

## Function

In [None]:
def _plot_train_history(history):
    """Plot per-epoch training/validation loss and, when logged, patch accuracy."""
    epochs = list(history.values())
    plt.plot([h['loss'] for h in epochs], label='Training Loss')
    plt.plot([h['val_loss'] for h in epochs], label='Validation Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epochs')
    plt.legend()
    plt.show()
    # patch accuracy only exists when a metric named 'patch_acc' was passed to train()
    if epochs and 'patch_acc' in epochs[0]:
        plt.plot([h['patch_acc'] for h in epochs], label='Patch Acc.')
        plt.plot([h['val_patch_acc'] for h in epochs], label='Validation Patch Acc.')
        plt.ylabel('Patch Acc.')
        plt.xlabel('Epochs')
        plt.legend()
        plt.show()


def train(train_dataloader, eval_dataloader, model, loss_fn, metric_fns, optimizer, n_epochs, 
          train_iter=(1, 1), useEarlyStopping=USE_EARLY_STOPPING):
    """Run one training iteration of up to `n_epochs` epochs with per-epoch validation.

    After each epoch the averaged training/validation metrics are printed, and the model is
    checkpointed to `checkpoint_path` (name 'checkpoint') whenever the validation loss
    reaches a new minimum. At the end, the best checkpoint is reloaded into `model`, the
    loss (and patch accuracy) curves are plotted, and a summary is printed.

    Args:
        train_dataloader: iterable of (x, y) training batches.
        eval_dataloader: iterable of (x, y) validation batches.
        model: torch module to train (updated in place).
        loss_fn: callable (y_hat, y) -> scalar loss tensor.
        metric_fns: dict name -> callable (y_hat, y) -> scalar tensor, logged for the
            training and ('val_' + name) the validation set.
        optimizer: torch optimizer over the model's parameters.
        n_epochs: maximum number of epochs.
        train_iter: (current iteration, total iterations), stored in STATE for checkpoints.
        useEarlyStopping: stop once early_stop_counter exceeds 10, i.e. after more than 10
            epochs without a new best validation loss (for positive losses).

    Relies on the notebook globals STATE, start_model_train, HYPERPARAMETERS,
    checkpoint_path, tqdm, show_val_samples, DATASETS, train_img_set_size, BATCH_SIZE
    and LOSS_FN.
    """
    history = {}  # epoch index -> averaged metrics of that epoch
    STATE['total_iterations'] = train_iter[1]
    STATE['current_iteration'] = train_iter[0]
    STATE['time_trained'] = time.time() - start_model_train
    early_stop_counter = 0
    # start at +inf (previously 1.0) so that the first epoch always writes a checkpoint,
    # also for compound losses whose value can exceed 1
    lowest_val_loss = float('inf')

    for epoch in range(n_epochs):  # loop over the dataset multiple times

        # initialize metric list
        STATE['epoch'] = epoch+1
        metrics = {'loss': [], 'val_loss': []}
        for k in metric_fns:
            metrics[k] = []
            metrics['val_'+k] = []

        pbar = tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{n_epochs}')
        # training
        model.train()
        for (x, y) in pbar:
            optimizer.zero_grad()  # zero out gradients
            y_hat = model(x)  # forward pass
            loss = loss_fn(y_hat, y)
            loss.backward()  # backward pass
            optimizer.step()  # optimize weights

            # log partial metrics
            metrics['loss'].append(loss.item())
            for k, fn in metric_fns.items():
                metrics[k].append(fn(y_hat, y).item())
            pbar.set_postfix({k: sum(v)/len(v) for k, v in metrics.items() if len(v) > 0})

        # validation
        model.eval()
        with torch.no_grad():  # do not keep track of gradients
            for (x, y) in eval_dataloader:
                y_hat = model(x)  # forward pass
                loss = loss_fn(y_hat, y)

                # log partial metrics
                metrics['val_loss'].append(loss.item())
                for k, fn in metric_fns.items():
                    metrics['val_'+k].append(fn(y_hat, y).item())

        # summarize metrics and hyperparameters
        history[epoch] = {k: sum(v) / len(v) for k, v in metrics.items()}
        metadata = {'hyperparameters': HYPERPARAMETERS,
                    'state': STATE,
                    'metrics': history[epoch],
                    }
        print(' '.join(['\t- '+str(k)+' = '+str(round(v, 3))+'\n ' for (k, v) in history[epoch].items()]))

        # Always keep the best model on the validation set: the end of this function reloads
        # that checkpoint, which previously failed when early stopping was disabled because
        # no checkpoint was ever written. The flag now only controls stopping.
        if history[epoch]['val_loss'] < lowest_val_loss:
            create_checkpoint(model, checkpoint_path, metadata=metadata)
            lowest_val_loss = history[epoch]['val_loss']
            early_stop_counter = 0
        elif lowest_val_loss < 1.1 * history[epoch]['val_loss']:
            early_stop_counter += 1
        if useEarlyStopping and early_stop_counter > 10:
            print("early stopping")
            break

        # display validation results
        # there may be feature augmentation in x!
        print(x.detach().cpu().numpy().shape)
        show_val_samples(x.detach().cpu().numpy()[:,:3,:,:], y.detach().cpu().numpy(), y_hat.detach().cpu().numpy())

    # end of train iteration
    print(f'Finished training iteration: {train_iter[0]}/{train_iter[1]}')
    print(f'Elapsed time since start of training: {round(time.time() - start_model_train, 2)} seconds')
    best_checkpoint = checkpoint_path + '/checkpoint'
    if os.path.exists(best_checkpoint + '.pt'):
        model, metadata, _ = load_checkpoint(model, best_checkpoint)
    else:
        # e.g. n_epochs == 0 or a NaN validation loss: nothing was ever checkpointed
        print('No checkpoint found, keeping the current model weights')
    _plot_train_history(history)
    print('Datasets: ' + '; '.join(f"{d} " for d in DATASETS))
    print(f'Total train set size: {train_img_set_size}')
    print(f'BATCH_SIZE: {BATCH_SIZE}')
    print(f'LOSS_FN: {LOSS_FN}')
    if history:  # summary of the last epoch that was run
        last_metrics = history[max(history)]
        print(' '.join(['\t- '+str(k)+' = '+str(round(v, 3))+'\n ' for (k, v) in last_metrics.items()]))
    

## Execution

In [None]:
# input channels: RGB, plus one extra channel when the Frangi feature extraction is enabled
channels = (4 if PREPROCESS_FRANGI else 3, 64, 128, 256, 512, 1024)

# now that all definitions have been made, we can fill in the config dictionaries.
# NOTE: keep this instantiation order; building layers such as nn.Conv2d draws from the
# torch RNG, so reordering would change the initial weights for a given SEED.
MODELS.update({
    'unet': UNet(chs=channels),
    'sdunet': UNetStackedDilations(chs=channels),
    'attention_sdunet': UNetStackedDilationsAttention(chs=channels),
    'unet3plusRS': UNet3plusRS(chs=channels),
})
LOSS_FNS.update({
    'bce': nn.BCELoss(),
    'soft_dice': soft_dice_loss,
    'focal_bce': Focal_BCE(2),
    'custom_bce_dice_mix': loss_fn_focalbce_dice,
    'bce_dice_mix': loss_fn_focalbce_dice_vanilla,
    'focal_dice_chi': loss_fn_focal_dice_chi,
})

In [None]:
!pip install torchmetrics
import torchmetrics

f1 = torchmetrics.F1Score(num_classes=1, threshold=0.5, average = 'weighted').to(device)
def f1withintconv(input, target):
    # the f1 metric needs flat input and target, and the target must be of type int.
    return f1(input.view(-1), target.int().view(-1))

model = MODELS[MODEL].to(device)
loss_fn = LOSS_FNS[LOSS_FN]
metric_fns = {'acc': accuracy_fn, 'patch_acc': patch_accuracy_fn, 'f1': f1withintconv}
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

start_train_iter, n_train_iter = 1, 1
if MAX_SAMPLE_SIZE > 0:
    n_train_iter = len(os.listdir(train_images_path)) // MAX_SAMPLE_SIZE + 1

# provides the ability to resume training from a previously saved pytorch checkpoint
if LOAD_CHECKPOINT is None:
    already_sampled = []
else:
    model, metadata, already_sampled = load_checkpoint(model, LOAD_CHECKPOINT)
    start_train_iter = metadata['state']['current_iteration'] + 1
    start_model_train += metadata['state']['time_trained']

In [None]:
# wall-clock reference for the training timers read by train() and the "Time to train model" summary.
# NOTE(review): running this cell overwrites any resume offset applied to start_model_train
# when a checkpoint was loaded - confirm intended.
start_model_train = time.time()

In [None]:
# begin one training iteration.
# in the case of large datasets, it is possible to run multiple training iterations
for i in range(start_train_iter-1, n_train_iter):
    print(f"Start of training iteration {i+1}/{n_train_iter}")
    current_sample = generate_sample(train_images_path, sample_size=MAX_SAMPLE_SIZE, already_sampled=already_sampled)
    already_sampled = sorted(already_sampled + current_sample)
    # persist the sampled file list next to the checkpoints; load_checkpoint reads it back
    with open(checkpoint_path + '/already_sampled.list', "w") as outfile:
        outfile.write("\n".join(already_sampled))
    train_dataset = ImageDataset(train_path, device, use_patches=False, resize_to=(384, 384), sample=current_sample)
    val_dataset = ImageDataset(val_path, device, use_patches=False, resize_to=(384, 384))
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
    train(train_dataloader, val_dataloader, model, loss_fn, metric_fns, optimizer, N_EPOCHS, train_iter=(i+1, n_train_iter))
    # in the case of multiple training iterations, free up the GPU. Drop the datasets too, not
    # only the dataloaders: the datasets (built with `device`) presumably hold the loaded images
    # and were previously kept alive until the next iteration. Then collect them and hand
    # PyTorch's cached blocks back to the driver.
    train_dataloader = None
    val_dataloader = None
    train_dataset = None
    val_dataset = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
# total wall-clock training time in seconds, rounded to two decimals
stop_model_train = time.time()
print("Time to train model: " + str(round(stop_model_train - start_model_train, 2)))

In [None]:
# reload the best checkpoint written during training and store a copy under a final, descriptive name
model, metadata, _ = load_checkpoint(model, f"{checkpoint_path}/checkpoint")
create_checkpoint(model, checkpoint_path, name=CONFIGURATION_NAME + '_final_model', metadata=metadata)

# Prediction

In [None]:
def morphological_postprocessing(imgs):
    """Binarise prediction maps and apply a morphological opening to remove solitary ones.

    Each map is scaled to [0, 255], thresholded at 127 (cv2.THRESH_BINARY), mapped back to
    {0, 1}, then eroded and dilated POSTPROCESSING_MORPHOLOGICAL_ITERATIONS times with a
    3x3 kernel. Returns the processed maps stacked into one array.
    """
    kernel = np.ones((3, 3), np.uint8)
    cleaned = []
    for img in imgs:
        _, binary = cv2.threshold(img * 255, 127, 255, cv2.THRESH_BINARY)
        binary = binary / 255
        binary = cv2.erode(binary, kernel, iterations=POSTPROCESSING_MORPHOLOGICAL_ITERATIONS)
        cleaned.append(cv2.dilate(binary, kernel, iterations=POSTPROCESSING_MORPHOLOGICAL_ITERATIONS))
    return np.stack(cleaned)
    
def prediction(imagesOrg, flip, rot):
    """Predict patch labels under one flip/rotation augmentation (one vote for majority voting).

    Each image is resized to 384x384 and its first 3 channels are kept. It is rotated `rot`
    times by 90 degrees (np.rot90) and, if `flip`, flipped along axis 1. The model output is
    resized back to the original size, optionally cleaned by morphological post-processing,
    and reduced to PATCH_SIZE x PATCH_SIZE patch labels (patch mean > CUTOFF). Finally the
    augmentation is undone on the patch grid, so results of different calls can be summed.

    Args:
        imagesOrg: image batch indexed (N, H, W, C) with C >= 3 (assumed from shape[1:3]
            and the channel slicing - confirm).
        flip: truthy to flip along axis 1 after rotating.
        rot: number of 90-degree rotations (0-3).

    Returns:
        array of shape (N, H // PATCH_SIZE, W // PATCH_SIZE) holding 0/1 labels.

    NOTE: undoing an odd rotation on the patch grid only restores the layout for square images.
    """
    size = imagesOrg.shape[1:3]  # (height, width)
    # cv2.resize takes dsize as (width, height); passing `size` directly transposed
    # non-square predictions
    dsize_orig = (size[1], size[0])
    images = np.stack([cv2.resize(img, dsize=(384, 384)) for img in imagesOrg], 0)
    images = images[:, :, :, :3]

    images_aug = []
    for img in images:
        img_aug = np.rot90(img, k=rot)
        if flip:
            img_aug = np.flip(img_aug, axis=1)
        images_aug.append(img_aug)
    images_aug = np.array(images_aug)

    test_images = np_to_tensor(np.moveaxis(images_aug, -1, 1), device)
    # the test images need to be subject to identical pre-processing
    if PREPROCESS_FRANGI:
        hess_img = []
        for t in test_images:
            with torch.no_grad():
                grayx = np.squeeze(transforms.Grayscale().forward(t).cpu().numpy(), axis=0)
                stackx = frangi(grayx)
                # now stack this information (feature extraction) onto the channel dimension of x
                tnew = torch.cat([t, torch.unsqueeze(np_to_tensor(stackx.astype(np.float32), device), 0)], dim=0)
                hess_img.append(tnew)
        hess_img_tensor = torch.cat([x.unsqueeze(0) for x in hess_img], dim=0)
        test_images = hess_img_tensor

    # inference only: no autograd graph needs to be built
    with torch.no_grad():
        test_pred = [model(t).detach().cpu().numpy() for t in test_images.unsqueeze(1)]
    test_pred = np.concatenate(test_pred, 0)
    test_pred = np.moveaxis(test_pred, 1, -1)  # CHW to HWC
    test_pred = np.stack([cv2.resize(img, dsize=dsize_orig) for img in test_pred], 0)  # resize to original shape
    if POSTPROCESSING_MORPHOLOGICAL_ITERATIONS:
        test_pred = morphological_postprocessing(test_pred)
    # now compute labels (width uses size[1]; it previously reused the height)
    h_patches, w_patches = size[0] // PATCH_SIZE, size[1] // PATCH_SIZE
    test_pred = test_pred.reshape((-1, h_patches, PATCH_SIZE, w_patches, PATCH_SIZE))
    test_pred = np.moveaxis(test_pred, 2, 3)
    test_pred = np.round(np.mean(test_pred, (-1, -2)) > CUTOFF)
    # undo the augmentation on the patch grid: flip back first, then rotate by the
    # remaining (4 - rot) quarter turns
    test_pred_back = []
    for tp in test_pred:
        tp_back = np.flip(tp, axis=1) if flip else tp
        tp_back = np.rot90(tp_back, k=4 - rot)
        test_pred_back.append(tp_back)

    return np.array(test_pred_back)

def postprocessing_prediction(path):
    """Predict patch labels for all test images found in directory `path`.

    When POSTPROCESSING_MAJORITY > 0, eight predictions are made: the plain one plus seven
    flip/rotation combinations. A patch is labelled road when at least half of the eight
    votes agree (ties count as road), and a boolean array is returned. Otherwise the single
    plain prediction is returned.

    The directory passed in is now actually used; previously the global `test_path` was
    read instead, and an unused glob of `path` was computed.
    """
    test_images = load_from_path(path, isGroundtruth=False)

    test_preds = prediction(test_images, 0, 0)
    if POSTPROCESSING_MAJORITY > 0:
        for flip, rot in [(0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3)]:
            test_preds += prediction(test_images, flip, rot)
        test_preds = test_preds / 8 >= 0.5
    return test_preds

In [None]:
# wall-clock start of the prediction phase
# NOTE(review): start_model_predict is not read anywhere in this part of the notebook - confirm it is still needed
start_model_predict = time.time()

In [None]:
# load the final (best) checkpoint before predicting and report what it was trained with
model, metadata = load_checkpoint(model, f"{checkpoint_path}/{CONFIGURATION_NAME}_final_model")[0:2]
for section in ('hyperparameters', 'state', 'metrics'):
    print(' '.join([f'\t- {k} = {v}\n ' for (k, v) in metadata[section].items()]))

In [None]:
test_path = 'kaggle_data/test/images'
test_filenames = sorted(glob(f'{test_path}/*.png'))
result = postprocessing_prediction(test_path)
# NOTE(review): the rows of `result` must be in the same order as the sorted file names;
# that depends on load_from_path's ordering, which is not visible here - confirm.
create_submission(
    result,
    test_filenames,
    submission_filename=f"{checkpoint_path}/{CONFIGURATION_NAME}_submission.csv",
)

In [None]:
# finally, download the prediction and the checkpoint for reproducibility
# note that this was only used for our convenience in running the experiments on
# google colab, and so this cell is only compatible with the colab platform.
from google.colab import files
import os
import shutil
files.download(f"{checkpoint_path}/{CONFIGURATION_NAME}_submission.csv")
# Archive the checkpoint directory with shutil instead of os.system("zip -r ..."). This removes
# the dependency on an external `zip` binary and the shell-string interpolation of paths, and
# a failure now raises instead of being ignored (os.system's exit status was discarded).
# With root_dir='.' and base_dir=checkpoint_path, the archive keeps the same relative paths
# (checkpoints/<run>/...) that `zip -r` produced.
shutil.make_archive(f"{CONFIGURATION_NAME}_checkpoints", 'zip', root_dir='.', base_dir=checkpoint_path)
files.download(f"{CONFIGURATION_NAME}_checkpoints.zip")

In [None]:
# finally, list the names of all trainable parameters of the model
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    print(name)