In [1]:
import warnings
warnings.filterwarnings("ignore")

import os
import sys
import glob
import tqdm
import time
import yaml
import torch
import pickle
import joblib
import random
import sklearn
import logging
import datetime
import torch.fft
import torchvision
import torchvision.models as models

#torch.multiprocessing.set_start_method('spawn')

import numpy as np
import pandas as pd
import xarray as xr
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from collections import defaultdict
from scipy.signal import convolve2d
from torch.optim.lr_scheduler import *
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, TensorDataset
from typing import List, Dict, Callable, Union, Any, TypeVar, Tuple

from unet_losses import *

### Set up the GPU device id, or CPU if no GPU available

In [2]:
# Pick the compute device once: the currently selected CUDA device when a GPU
# is visible to PyTorch, otherwise the CPU.
is_cuda = torch.cuda.is_available()

if is_cuda:
    device = torch.device(torch.cuda.current_device())
    # Let cuDNN benchmark its convolution algorithms and keep the fastest one.
    torch.backends.cudnn.benchmark = True
    #torch.backends.cudnn.deterministic = True
else:
    device = torch.device("cpu")

In [3]:
# Reference for segmentation loss functions (presumably the basis of `unet_losses`):
# https://www.kaggle.com/bigironsphere/loss-function-library-keras-pytorch

### Load data from disk

In [4]:
class DataReader(Dataset):
    """Stream ``(image, mask)`` pairs from a file of consecutive joblib dumps.

    Each record in ``fn`` is expected to unpack as ``(image, label, mask)``;
    ``label`` is discarded and ``mask`` must provide ``.toarray()`` (e.g. a
    scipy sparse matrix -- TODO confirm against the writer of the file).

    The reader is sequential: ``__getitem__`` ignores its index and simply
    returns the next record, rewinding the file when it runs out or when
    ``max_images`` records have been served.

    Args:
        fn: Path of the file to read.
        max_buffer_size: Maximum number of samples kept in the shuffle buffer.
        max_images: Value reported by ``len()``; also the number of records
            after which a new epoch is started.
        shuffle: If False, records are returned strictly in file order. If
            True, records are returned in file order during epoch 0 (while the
            buffer fills) and drawn from the shuffled buffer afterwards.
        normalize: If True, images are standardised with ``self.mean`` /
            ``self.std`` after being scaled by 1/255.
    """

    def __init__(self, fn, max_buffer_size = 5000, max_images = 40000, shuffle = True, normalize = True):
        self.fn = fn
        self.buffer = []
        self.max_buffer_size = max_buffer_size
        self.shuffle = shuffle
        self.max_images = max_images

        self.fid = open(self.fn, "rb")
        self.loaded = 0  # records read since the file was last (re)opened
        self.epoch = 0   # number of completed passes of ``max_images`` records

        self.normalize = normalize
        # Single-channel statistics: the mean of the three per-channel values listed here.
        self.mean = np.mean([0.485, 0.456, 0.406])
        self.std = np.mean([0.229, 0.224, 0.225])

    def _reopen(self):
        """Close the current handle (if open) and restart reading from the top of the file."""
        fid = getattr(self, "fid", None)
        if fid is not None and not fid.closed:
            fid.close()
        self.fid = open(self.fn, "rb")
        self.loaded = 0

    def _prepare(self, image, mask):
        """Scale (and optionally standardise) ``image`` and convert both arrays to float tensors."""
        # Scale exactly once; the previous code divided by 255 a second time
        # whenever ``normalize`` was enabled. Out-of-place so integer arrays work too.
        image = image / 255.0
        if self.normalize:
            image = (image - self.mean) / self.std
        return torch.FloatTensor(image), torch.FloatTensor(mask.toarray())

    def close(self):
        """Release the underlying file handle."""
        if not self.fid.closed:
            self.fid.close()

    def __getitem__(self, idx):
        """Return the next ``(image, mask)`` pair; ``idx`` is ignored.

        Raises:
            RuntimeError: if the file yields no record even right after being
                re-opened (i.e. it is empty).
        """
        self.on_epoch_end()

        hit_eof = False
        while True:
            try:
                record = joblib.load(self.fid)
            except EOFError:
                # End of file: rewind once. Two EOFs in a row mean the file has
                # no records at all, which used to loop forever.
                if hit_eof:
                    raise RuntimeError(f"No samples could be read from {self.fn}")
                hit_eof = True
                self._reopen()
                continue

            image, label, mask = record
            sample = self._prepare(image, mask)
            self.loaded += 1

            if not self.shuffle:
                return sample

            # Keep a bounded, freshly shuffled pool of recent samples.
            self.buffer.append(sample)
            random.shuffle(self.buffer)
            if len(self.buffer) > self.max_buffer_size:
                self.buffer = self.buffer[:self.max_buffer_size]

            # During epoch 0 the buffer is still filling, so samples are
            # returned in file order; afterwards they come from the buffer.
            if self.epoch > 0 and self.buffer:
                return self.buffer.pop()
            return sample

    def __len__(self):
        return self.max_images

    def on_epoch_end(self):
        """Rewind the file and bump the epoch counter once ``max_images`` records were read."""
        if self.loaded == self.__len__():
            self._reopen()
            self.epoch += 1

### Define the binary segmentation models

In [5]:
class Block(nn.Module):
    """Two 3x3 convolutions (no padding) with a ReLU between them.

    No activation is applied after the second convolution. Each convolution
    trims one pixel from every border, so the output is 4 pixels smaller than
    the input in both height and width.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        kernel = 3
        self.conv1 = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=kernel)
        self.relu  = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=kernel)

    def forward(self, x):
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        return self.conv2(hidden)

class Encoder(nn.Module):
    """Contracting path: a chain of ``Block`` stages, each followed by 2x2 max-pooling.

    ``chs`` lists the channel count before the first stage and after every
    stage, so ``len(chs) - 1`` stages are created.
    """

    def __init__(self, chs=(3,64,128,256,512,1024)):
        super().__init__()
        stages = [Block(c_in, c_out) for c_in, c_out in zip(chs[:-1], chs[1:])]
        self.enc_blocks = nn.ModuleList(stages)
        self.pool       = nn.MaxPool2d(2)

    def forward(self, x):
        """Return the pre-pooling output of every stage, shallowest first."""
        ftrs = []
        for stage in self.enc_blocks:
            stage_out = stage(x)
            ftrs.append(stage_out)
            # Pooling also runs after the last stage; its result is not returned.
            x = self.pool(stage_out)
        return ftrs
    
class Decoder(nn.Module):
    """Expanding path: up-convolve, concatenate the cropped skip feature, refine with a ``Block``.

    ``chs`` lists the channel count entering the decoder and after every
    stage. Each ``Block`` receives the concatenation of the up-convolved
    tensor and the skip feature, so it is built with ``chs[i]`` input channels.
    """

    def __init__(self, chs=(1024, 512, 256, 128, 64)):
        super().__init__()
        self.chs = chs
        channel_pairs = list(zip(chs[:-1], chs[1:]))
        self.upconvs    = nn.ModuleList([nn.ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in channel_pairs])
        self.dec_blocks = nn.ModuleList([Block(c_in, c_out) for c_in, c_out in channel_pairs])

    def forward(self, x, encoder_features):
        """Decode ``x`` using ``encoder_features[i]`` as the skip connection of stage ``i``."""
        for i, (upconv, refine) in enumerate(zip(self.upconvs, self.dec_blocks)):
            x = upconv(x)
            skip = self.crop(encoder_features[i], x)
            x = refine(torch.cat([x, skip], dim=1))
        return x

    def crop(self, enc_ftrs, x):
        """Center-crop ``enc_ftrs`` to the spatial size (last two dims) of the 4-D tensor ``x``."""
        _batch, _channels, height, width = x.shape
        center_crop = torchvision.transforms.CenterCrop([height, width])
        return center_crop(enc_ftrs)
    
class UNet(nn.Module):
    """Encoder/decoder segmentation network with a 1x1 convolution head and sigmoid output.

    Args:
        enc_chs: Channel sizes passed to ``Encoder``.
        dec_chs: Channel sizes passed to ``Decoder``.
        num_class: Number of output channels of the head.
        retain_dim: If True, the head output is resized to ``out_sz``.
        out_sz: Target spatial size used when ``retain_dim`` is True.
    """

    def __init__(self,
                 enc_chs=(2, 64, 128, 256, 512, 1024),
                 dec_chs=(1024, 512, 256, 128, 64),
                 num_class=1,
                 retain_dim=False, out_sz=(572,572)):

        super().__init__()
        self.encoder     = Encoder(enc_chs)
        self.decoder     = Decoder(dec_chs)
        self.head        = nn.Conv2d(dec_chs[-1], num_class, 1)
        self.retain_dim  = retain_dim
        self.out_sz = out_sz

    def forward(self, x):
        # Deepest feature first: it seeds the decoder, the rest become skip connections.
        deepest_first = self.encoder(x)[::-1]
        bottleneck, skips = deepest_first[0], deepest_first[1:]

        out = self.head(self.decoder(bottleneck, skips))
        if self.retain_dim:
            out = F.interpolate(out, self.out_sz)
        return torch.sigmoid(out)

In [7]:
def convrelu(in_channels, out_channels, kernel, padding):
    """Build a Conv2d -> in-place ReLU -> BatchNorm2d stack as an ``nn.Sequential``.

    Note the order: batch normalisation is applied after the activation.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel, padding=padding),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(num_features=out_channels),
    ]
    return nn.Sequential(*layers)


class ResNetUNet(nn.Module):
    """U-Net style decoder on top of a torchvision ResNet-18 encoder, with sigmoid output.

    Args:
        n_class: Number of output channels of the final 1x1 convolution.
        color_dim: Number of input channels. The ResNet stem convolution is
            replaced by a freshly initialised one with this many input channels.
    """

    def __init__(self, n_class, color_dim = 2):
        super().__init__()

        backbone = models.resnet18(pretrained=True)
        # Swap the stem so the network accepts ``color_dim`` channels; this layer
        # starts from a fresh initialisation rather than the pretrained weights.
        backbone.conv1 = nn.Conv2d(color_dim, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.base_model = backbone
        self.base_layers = list(backbone.children())
        stages = self.base_layers

        # Encoder stages, each paired with a 1x1 projection applied to its skip feature.
        self.layer0 = nn.Sequential(*stages[:3]) # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, kernel=1, padding=0)
        self.layer1 = nn.Sequential(*stages[3:5]) # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, kernel=1, padding=0)
        self.layer2 = stages[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, kernel=1, padding=0)
        self.layer3 = stages[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, kernel=1, padding=0)
        self.layer4 = stages[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, kernel=1, padding=0)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Fusion convolutions: input channels = skip channels + upsampled channels.
        self.conv_up3 = convrelu(256 + 512, 512, kernel=3, padding=1)
        self.conv_up2 = convrelu(128 + 512, 256, kernel=3, padding=1)
        self.conv_up1 = convrelu(64 + 256, 256, kernel=3, padding=1)
        self.conv_up0 = convrelu(64 + 256, 128, kernel=3, padding=1)

        # Full-resolution branch computed directly from the input.
        self.conv_original_size0 = convrelu(color_dim , 64, kernel=3, padding=1)
        self.conv_original_size1 = convrelu(64, 64, kernel=3, padding=1)
        self.conv_original_size2 = convrelu(64 + 128, 64, kernel=3, padding=1)

        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, input):
        # Features at the input resolution, merged back in at the very end.
        full_res = self.conv_original_size1(self.conv_original_size0(input))

        # Encoder pass, remembering the output of every stage (shallowest first).
        skips = []
        features = input
        for stage in (self.layer0, self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
            skips.append(features)

        x = self.layer4_1x1(skips.pop())

        # Decoder pass, deepest skip first: upsample, project the skip, concatenate, fuse.
        decoder_path = (
            (self.layer3_1x1, self.conv_up3),
            (self.layer2_1x1, self.conv_up2),
            (self.layer1_1x1, self.conv_up1),
            (self.layer0_1x1, self.conv_up0),
        )
        for project, fuse in decoder_path:
            x = self.upsample(x)
            skip = project(skips.pop())
            x = fuse(torch.cat([x, skip], dim=1))

        x = self.upsample(x)
        x = self.conv_original_size2(torch.cat([x, full_res], dim=1))

        return torch.sigmoid(self.conv_last(x))

In [8]:
# Read the experiment configuration (paths, trainer and optimizer settings).
# NOTE(review): yaml.load with FullLoader is kept as-is; yaml.safe_load would
# be the stricter choice if the file only holds plain scalars/lists/maps -- confirm.
config_file = "models/unet_double_compare/holo_data.yml"
with open(config_file) as cf:
    conf = yaml.load(cf, Loader=yaml.FullLoader)

In [9]:
# --- data: tiling parameters and the directory holding the pickled tiles
tile_size, step_size, data_path = (
    conf["data"][key] for key in ("tile_size", "step_size", "output_path")
)

fn_train = f"{data_path}/training_{tile_size}_{step_size}.pkl"
fn_valid = f"{data_path}/validation_{tile_size}_{step_size}.pkl"

# --- trainer: loop length, batch sizes, early stopping and output directory
(
    epochs,
    train_batch_size,
    valid_batch_size,
    batches_per_epoch,
    stopping_patience,
    model_loc,
) = (
    conf["trainer"][key]
    for key in (
        "epochs",
        "train_batch_size",
        "valid_batch_size",
        "batches_per_epoch",
        "stopping_patience",
        "output_path",
    )
)

# --- resnet: model hyper-parameters read from the config
fcl_layers, dropout, output_size, resnet_model, pretrained = (
    conf["resnet"][key]
    for key in ("fcl_layers", "dropout", "output_size", "resnet_model", "pretrained")
)

# --- optimizer
learning_rate, weight_decay = (
    conf["optimizer"][key] for key in ("learning_rate", "weight_decay")
)

In [10]:
# Dataset sizes are expressed as fractions of the configured total:
# 80% for training, 10% for validation, and a shuffle buffer of 10%.
total_training = conf["data"]["total_training"]
shuffle_buffer_size = int(0.1 * total_training)

# Training stream: shuffled by the reader itself, images only scaled to [0, 1].
train_dataset = DataReader(
    fn_train,
    max_images = int(0.8 * total_training),
    max_buffer_size = shuffle_buffer_size,
    shuffle = True,
    normalize = False
)

# Validation stream: read in file order (the buffer size is irrelevant when shuffle is off).
test_dataset = DataReader(
    fn_valid,
    max_images = int(0.1 * total_training),
    max_buffer_size = shuffle_buffer_size,
    shuffle = False,
    normalize = False
)

In [11]:
# Options shared by both loaders. Shuffling stays off here because the
# training DataReader does its own shuffling and ignores the requested index.
#num_workers=8 can be added here (up to the number of CPUs requested in the launch script; usually 8)
loader_options = dict(pin_memory=True, shuffle=False)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=train_batch_size, **loader_options
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=valid_batch_size, **loader_options
)

In [12]:
# Settings for the plain UNet variant; they are only consumed by the
# commented-out alternative below, not by ResNetUNet.
encoder_channels = (2, 16, 32, 64, 128, 256)
decoder_channels = [256, 128, 64, 32, 16]
mask_size = (512,512)
retain_dim = True

# Alternative model (disabled):
# unet = UNet(
#     enc_chs = encoder_channels,
#     dec_chs = decoder_channels,
#     retain_dim = retain_dim,
#     out_sz = mask_size,
#     num_class = 1
# ).to(device)

# Active model: ResNet-18 backed U-Net with a single output channel.
unet = ResNetUNet(n_class = 1)
unet = unet.to(device)

In [13]:
# Walk the parameters once, recording each tensor's size and whether it is trainable.
param_counts = [(p.numel(), p.requires_grad) for p in unet.parameters()]
total_params = sum(count for count, _ in param_counts)
trainable_params = sum(count for count, needs_grad in param_counts if needs_grad)

In [14]:
# Display the total number of model parameters (trainable or not).
total_params

18313833

In [15]:
# AdamW over every model parameter, with the learning rate and weight decay from the config.
# NOTE: the training loop later reassigns `learning_rate`, so re-running this
# cell after training would start from the last (decayed) learning rate.
optimizer = torch.optim.AdamW(params=unet.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [16]:
mix = 0.001 #1e-1

def element_weighted_MSELoss(y_hat, y):
    """Weighted mean of the element-wise squared error.

    Each element gets the weight ``(1 - mix) * y + mix``: elements where the
    target is 1 weigh 1, elements where it is 0 weigh ``mix``. The result is
    the weighted error sum divided by the weight sum.
    """
    per_element_weight = mix + (1 - mix) * y
    squared_error = F.mse_loss(y_hat, y, reduction='none')
    return (squared_error * per_element_weight).sum() / per_element_weight.sum()


# Loss optimised during training (class comes from unet_losses; presumably Dice + BCE -- confirm).
# Alternatives tried: element_weighted_MSELoss, FocalTverskyLoss()
train_criterion = DiceBCELoss() #element_weighted_MSELoss #FocalTverskyLoss()
# Loss reported on the validation set; it also drives the LR scheduler,
# checkpointing and early stopping in the training loop.
test_criterion = DiceLoss() #DiceBCELoss()

#train_criterion = torch.nn.SmoothL1Loss() # Huber (MSE, but once converges, MAE)
#test_criterion = torch.nn.L1Loss() # MAE

In [17]:
# NOTE(review): this hard-coded value overrides the `stopping_patience` read
# from conf["trainer"] earlier in the notebook -- keep only one source of truth.
stopping_patience = 5

# Lower the learning rate when the monitored validation loss stops improving
# for more than `patience` epochs, never going below `min_lr`.
scheduler_options = dict(patience = 1, min_lr = 1.0e-10, verbose = True)
lr_scheduler = ReduceLROnPlateau(optimizer, **scheduler_options)

In [19]:
# Validation loss of every epoch (drives checkpointing and early stopping)
# and the running log that is written to CSV after each epoch.
epoch_test_losses = []
results_dict = defaultdict(list)

# Create the output directory up front so that saving the checkpoint and the
# CSV log cannot fail on a missing path.
os.makedirs(model_loc, exist_ok=True)


for epoch in range(epochs):

    ### Train the model
    unet.train()

    batch_loss = []

    # Manually driven progress bar: exactly one update() per optimisation step.
    # (Iterating a tqdm object AND calling update() on it counts every step twice.)
    train_progress = tqdm.tqdm(total=batches_per_epoch, leave=True)

    for k, (inputs, y) in enumerate(train_loader):

        # Move data to the GPU, if not there already
        inputs = inputs.to(device).float()
        y = y.to(device).float()

        # Clear gradient
        optimizer.zero_grad()

        # get output from the model, given the inputs
        pred_mask = unet(inputs)

        # get loss for the predicted output
        loss = train_criterion(pred_mask, y)

        # get gradients w.r.t. the parameters
        loss.backward()
        batch_loss.append(loss.item())

        # update parameters
        optimizer.step()

        # update tqdm
        to_print = "Epoch {} train_loss: {:.4f}".format(epoch, np.mean(batch_loss))
        to_print += " lr: {:.12f}".format(optimizer.param_groups[0]['lr'])
        train_progress.set_description(to_print)
        train_progress.update()

        # Stop the training epoch once batches_per_epoch batches have been used.
        # k is zero-based, so the k-th iteration is batch number k + 1; the old
        # `k >= batches_per_epoch` test let one extra batch through.
        if k + 1 >= batches_per_epoch:
            break

    train_progress.close()

    # Compute the final training loss before doing validation
    train_loss = np.mean(batch_loss)

    # clear the cached memory from the gpu
    torch.cuda.empty_cache()

    ### Test the model
    unet.eval()
    with torch.no_grad():

        batch_loss = []

        # Wrapping the loader directly lets tqdm advance itself and pick up the
        # total from len(test_loader).
        test_progress = tqdm.tqdm(test_loader, leave=True)

        for inputs, y in test_progress:
            # Move data to the GPU, if not there already
            inputs = inputs.to(device).float()
            y = y.to(device).float()
            # get output from the model, given the inputs
            pred_mask = unet(inputs)
            # get loss for the predicted output
            loss = test_criterion(pred_mask, y)
            batch_loss.append(loss.item())
            # update tqdm
            to_print = "Epoch {} test_loss: {:.4f}".format(epoch, np.mean(batch_loss))
            test_progress.set_description(to_print)

    # Use the validation loss as the performance metric to toggle learning rate and early stopping
    test_loss = np.mean(batch_loss)
    epoch_test_losses.append(test_loss)
    best_loss = min(epoch_test_losses)

    # Lower the learning rate if we are not improving
    lr_scheduler.step(test_loss)

    # Save the model if it is the best so far (ties with the best also save).
    if test_loss == best_loss:
        state_dict = {
            'epoch': epoch,
            'model_state_dict': unet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': test_loss
        }
        torch.save(state_dict, f"{model_loc}/best.pt")

    # Get the last learning rate.
    # NOTE: this reuses the name of the configured initial learning rate.
    learning_rate = optimizer.param_groups[0]['lr']

    # Put things into a results dictionary -> dataframe
    results_dict['epoch'].append(epoch)
    results_dict['train_loss'].append(train_loss)
    results_dict['valid_loss'].append(test_loss)
    results_dict["learning_rate"].append(learning_rate)
    df = pd.DataFrame.from_dict(results_dict).reset_index()

    # Save the dataframe to disk
    df.to_csv(f"{model_loc}/training_log.csv", index = False)

    # Stop training if we have not improved after X epochs (stopping patience).
    # list.index returns the first epoch that reached the best loss.
    best_epoch = epoch_test_losses.index(best_loss)
    offset = epoch - best_epoch
    if offset >= stopping_patience:
        break

Epoch 0 train_loss: 1.6077 lr: 0.000100000000: 100%|██████████| 200/200 [01:22<00:00,  2.42it/s]
Epoch 0 test_loss: 0.9984: : 157it [01:37,  1.61it/s]
Epoch 1 train_loss: 1.5367 lr: 0.000100000000: 100%|██████████| 200/200 [01:19<00:00,  2.50it/s]
Epoch 1 test_loss: 0.9982: : 157it [01:36,  1.63it/s]
Epoch 2 train_loss: 1.4405 lr: 0.000100000000: 100%|██████████| 200/200 [01:21<00:00,  2.46it/s]
Epoch 2 test_loss: 0.9978: : 157it [01:32,  1.69it/s]
Epoch 3 train_loss: 1.3414 lr: 0.000100000000: 100%|██████████| 200/200 [01:22<00:00,  2.42it/s]
Epoch 3 test_loss: 0.9971: : 157it [01:31,  1.72it/s]
Epoch 4 train_loss: 1.2503 lr: 0.000100000000: 100%|██████████| 200/200 [01:22<00:00,  2.43it/s]
Epoch 4 test_loss: 0.9962: : 157it [01:32,  1.69it/s]
Epoch 5 train_loss: 1.1847 lr: 0.000100000000: 100%|██████████| 200/200 [01:22<00:00,  2.42it/s]
Epoch 5 test_loss: 0.9948: : 157it [01:30,  1.73it/s]
Epoch 6 train_loss: 1.1267 lr: 0.000100000000: 100%|██████████| 200/200 [01:22<00:00,  2.43it/

Epoch    28: reducing learning rate of group 0 to 1.0000e-05.


Epoch 28 train_loss: 0.0626 lr: 0.000010000000: 100%|██████████| 200/200 [01:21<00:00,  2.44it/s]
Epoch 28 test_loss: 0.0463: : 157it [01:33,  1.67it/s]
Epoch 29 train_loss: 0.0666 lr: 0.000010000000: 100%|██████████| 200/200 [01:22<00:00,  2.44it/s]
Epoch 29 test_loss: 0.0468: : 157it [01:33,  1.67it/s]
Epoch 30 train_loss: 0.0650 lr: 0.000010000000: 100%|██████████| 200/200 [01:22<00:00,  2.43it/s]
Epoch 30 test_loss: 0.0459: : 157it [01:31,  1.72it/s]
Epoch 31 train_loss: 0.0664 lr: 0.000010000000: 100%|██████████| 200/200 [01:22<00:00,  2.44it/s]
Epoch 31 test_loss: 0.0451: : 157it [01:33,  1.68it/s]
Epoch 32 train_loss: 0.0683 lr: 0.000010000000: 100%|██████████| 200/200 [01:22<00:00,  2.43it/s]
Epoch 32 test_loss: 0.0443: : 157it [01:44,  1.51it/s]
Epoch 33 train_loss: 0.0463 lr: 0.000010000000: 100%|██████████| 200/200 [01:22<00:00,  2.43it/s]
Epoch 33 test_loss: 0.0440: : 157it [01:34,  1.66it/s]
Epoch 34 train_loss: 0.0475 lr: 0.000010000000: 100%|██████████| 200/200 [01:22<00

Epoch    39: reducing learning rate of group 0 to 1.0000e-06.


Epoch 39 train_loss: 0.0528 lr: 0.000001000000: 100%|██████████| 200/200 [01:22<00:00,  2.43it/s]
Epoch 39 test_loss: 0.0425: : 157it [01:35,  1.64it/s]
Epoch 40 train_loss: 0.0546 lr: 0.000001000000: 100%|██████████| 200/200 [01:22<00:00,  2.43it/s]
Epoch 40 test_loss: 0.0424: : 157it [01:37,  1.62it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch    41: reducing learning rate of group 0 to 1.0000e-07.


Epoch 41 train_loss: 0.0674 lr: 0.000000100000: 100%|██████████| 200/200 [01:22<00:00,  2.44it/s]
Epoch 41 test_loss: 0.0424: : 157it [01:32,  1.69it/s]
