In [1]:
import warnings
warnings.filterwarnings("ignore")

import sys,os

import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import convolve2d

import glob
import xarray as xr
import datetime

# import yaml
import tqdm
import time
import torch
import torchvision

import pickle
import joblib
import logging
import random

from typing import List, Dict, Callable, Union, Any, TypeVar, Tuple
# from multiprocessing import cpu_count

import torch.fft
from torch import nn

import torch.nn.functional as F

# from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torchvision.models as models
from torch.optim.lr_scheduler import *

from sklearn.model_selection import train_test_split

from collections import defaultdict
import pandas as pd

In [2]:
is_cuda = torch.cuda.is_available()
device = torch.device(torch.cuda.current_device()) if is_cuda else torch.device("cpu")

if is_cuda:
    torch.backends.cudnn.benchmark = True
    #torch.backends.cudnn.deterministic = True

### Load data from disk

In [6]:
# Load all of the data into memory
images = []
labels = []
masks = []

loaded = 0
max_images = 10

start_time = time.time()
with open("training_512x512_128_50000.pkl", "rb") as fid:
    while True:
        
        try:
            image, label, u_net_mask, image_tile_idx, image_tile_coors = pickle.load(fid)
            images.append(np.expand_dims(image, 0))
            labels.append(label)
            masks.append(np.expand_dims(u_net_mask, 0))
            
            loaded += 1
            
            if len(images) == max_images:
                break
            
        except Exception as E:
            break
            
images = np.vstack(images)
labels = np.vstack(labels)
masks = np.vstack(masks)

end_time = time.time()

In [7]:
print(f"It took {end_time - start_time} s to load {loaded} (x,y) points")

It took 0.11396145820617676 s to load 10 (x,y) points


In [8]:
images.shape, labels.shape, masks.shape

((10, 2, 512, 512), (10, 1), (10, 512, 512))

In [5]:
X_train, X_test, y_train, y_test = train_test_split(images, labels, test_size=0.20, random_state=42)

### Load the binary model 

In [6]:
class ResNet(nn.Module):
    def __init__(self, fcl_layers = [], dr = 0.0, output_size = 1, resnet_model = 18, pretrained = True):
        super(ResNet, self).__init__()
        self.pretrained = pretrained
        self.resnet_model = resnet_model 
        if self.resnet_model == 18:
            resnet = models.resnet18(pretrained=self.pretrained)
        elif self.resnet_model == 34:
            resnet = models.resnet34(pretrained=self.pretrained)
        elif self.resnet_model == 50:
            resnet = models.resnet50(pretrained=self.pretrained)
        elif self.resnet_model == 101:
            resnet = models.resnet101(pretrained=self.pretrained)
        elif self.resnet_model == 152:
            resnet = models.resnet152(pretrained=self.pretrained)
        resnet.conv1 = torch.nn.Conv1d(2, 64, (7, 7), (2, 2), (3, 3), bias=False) # Manually change color dim to match our data
        modules = list(resnet.children())[:-1]      # delete the last fc layer.
        self.resnet_output_dim = resnet.fc.in_features
        self.resnet = nn.Sequential(*modules)
        self.fcn = self.make_fcn(self.resnet_output_dim, output_size, fcl_layers, dr)
        
    def make_fcn(self, input_size, output_size, fcl_layers, dr):
        if len(fcl_layers) > 0:
            fcn = [
                nn.Dropout(dr),
                nn.Linear(input_size, fcl_layers[0]),
                nn.BatchNorm1d(fcl_layers[0]),
                torch.nn.LeakyReLU()
            ]
            if len(fcl_layers) == 1:
                fcn.append(nn.Linear(fcl_layers[0], output_size))
            else:
                for i in range(len(fcl_layers)-1):
                    fcn += [
                        nn.Linear(fcl_layers[i], fcl_layers[i+1]),
                        nn.BatchNorm1d(fcl_layers[i+1]),
                        torch.nn.LeakyReLU(),
                        nn.Dropout(dr)
                    ]
                fcn.append(nn.Linear(fcl_layers[i+1], output_size))
        else:
            fcn = [
                nn.Dropout(dr),
                nn.Linear(input_size, output_size)
            ]
        if output_size > 1:
            fcn.append(torch.nn.LogSoftmax(dim=1))
        return nn.Sequential(*fcn)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.resnet(x)
        x = x.view(x.size(0), -1)  # flatten
        x = self.fcn(x)
        return x

In [7]:
epochs = 200
train_batch_size = 32
valid_batch_size = 32
batches_per_epoch = 500

stopping_patience = 5

In [8]:
train_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(X_train), 
    torch.from_numpy(y_train)
)

test_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(X_test), 
    torch.from_numpy(y_test)
)

In [9]:
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=train_batch_size, 
    #num_workers=0,
    pin_memory=True,
    shuffle=True)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=valid_batch_size,
    #num_workers=0,
    pin_memory=True,
    shuffle=False)

In [10]:
fcl_layers = []
dropout = 0.2
output_size = 2
resnet_model = 50
pretrained = True

model = ResNet(fcl_layers, 
               dr = dropout, 
               output_size = output_size, 
               resnet_model=resnet_model, 
               pretrained = pretrained).to(device)

In [11]:
learning_rate = 1e-04
weight_decay = 0.0

optimizer = torch.optim.Adam(
    model.parameters(), 
    lr=learning_rate, 
    weight_decay=weight_decay
)

In [12]:
train_criterion = torch.nn.CrossEntropyLoss()
test_criterion = torch.nn.CrossEntropyLoss() 

In [13]:
lr_scheduler = ReduceLROnPlateau(
    optimizer, 
    patience = 1, 
    min_lr = 1.0e-10,
    verbose = True
)

In [14]:
epoch_test_losses = []
results_dict = defaultdict(list)


for epoch in range(epochs):

    ### Train the model 
    model.train()

    batch_loss = []
    accuracy = [] 
        
    # set up a custom tqdm
    batch_group_generator = tqdm.tqdm(
        enumerate(train_loader), 
        total=batches_per_epoch,
        leave=True
    )
 
    for k, (inputs, y) in batch_group_generator:
        
        # Move data to the GPU, if not there already
        inputs = inputs.to(device).float()
        y = y.to(device).long()
        
        # Clear gradient
        optimizer.zero_grad()

        # get output from the model, given the inputs
        pred_z_logits = model(inputs)

        # get loss for the predicted output
        loss = train_criterion(pred_z_logits, y.squeeze(-1))
        
        # compute the top-1 accuracy
        pred_z_labels = torch.argmax(pred_z_logits, 1)
        accuracy += list((pred_z_labels == y.squeeze(1)).float().detach().cpu().numpy())
        
        # get gradients w.r.t to parameters
        loss.backward()
        batch_loss.append(loss.item())

        # update parameters
        optimizer.step()

        # update tqdm
        to_print = "Epoch {} train_loss: {:.4f}".format(epoch, np.mean(batch_loss))
        to_print += " train_acc: {:.4f}".format(np.mean(accuracy))
        to_print += " lr: {:.12f}".format(optimizer.param_groups[0]['lr'])
        batch_group_generator.set_description(to_print)
        batch_group_generator.update()
                     
        # stop the training epoch when train_batches_per_epoch have been used to update 
        # the weights to the model
        if k >= batches_per_epoch and k > 0:
            break
            
        #lr_scheduler.step(epoch + k / batches_per_epoch)
        
    # Compuate final performance metrics before doing validation
    train_loss = np.mean(batch_loss)
    train_acc = np.mean(accuracy)
        
    # clear the cached memory from the gpu
    torch.cuda.empty_cache()

    ### Test the model 
    model.eval()
    with torch.no_grad():

        batch_loss = []
        accuracy = []
        
        # set up a custom tqdm
        batch_group_generator = tqdm.tqdm(
            enumerate(train_loader),
            leave=True
        )

        for k, (inputs, y) in batch_group_generator:
            # Move data to the GPU, if not there already
            inputs = inputs.to(device).float()
            y = y.to(device).long()
            # get output from the model, given the inputs
            pred_z_logits = model(inputs)
            # get loss for the predicted output
            loss = test_criterion(pred_z_logits, y.squeeze(-1))
            batch_loss.append(loss.item())
            # compute the accuracy
            pred_z_labels = torch.argmax(pred_z_logits, 1)
            accuracy += list((pred_z_labels == y.squeeze(1)).float().detach().cpu().numpy())
            # update tqdm
            to_print = "Epoch {} test_loss: {:.4f}".format(epoch, np.mean(batch_loss))
            to_print += " test_acc: {:.4f}".format(np.mean(accuracy))
            batch_group_generator.set_description(to_print)
            batch_group_generator.update()

    # Use the accuracy as the performance metric to toggle learning rate and early stopping
    test_loss = 1 - np.mean(accuracy)
    epoch_test_losses.append(test_loss)
    
    # Lower the learning rate if we are not improving
    lr_scheduler.step(test_loss)

    # Save the model if its the best so far.
    if test_loss == min(epoch_test_losses):
        state_dict = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': test_loss
        }
        torch.save(state_dict, "best_resnet.pt")
        
    # Get the last learning rate
    learning_rate = optimizer.param_groups[0]['lr']
        
    # Put things into a results dictionary -> dataframe
    results_dict['epoch'].append(epoch)
    results_dict['train_loss'].append(train_loss)
    results_dict['valid_loss'].append(np.mean(batch_loss))
    results_dict['train_accuracy'].append(train_acc)
    results_dict['valid_accuracy'].append(np.mean(accuracy))
    results_dict["learning_rate"].append(learning_rate)
    df = pd.DataFrame.from_dict(results_dict).reset_index()

    # Save the dataframe to disk
    df.to_csv("training_log_resnet.csv", index = False)
        
    # Stop training if we have not improved after X epochs
    best_epoch = [i for i,j in enumerate(epoch_test_losses) if j == min(epoch_test_losses)][0]
    offset = epoch - best_epoch
    if offset >= stopping_patience:
        break
        
    if results_dict['valid_accuracy'][-1] == 1.0:
        break

Epoch 0 train_loss: 0.6744 train_acc: 0.5534 lr: 0.000100000000:   1%|          | 4/500 [00:04<10:05,  1.22s/it]
Epoch 0 test_loss: 0.6814 test_acc: 0.5146: : 4it [00:00,  5.47it/s]
Epoch 1 train_loss: 0.4664 train_acc: 0.8155 lr: 0.000100000000:   1%|          | 4/500 [00:01<03:37,  2.28it/s]
Epoch 1 test_loss: 0.9785 test_acc: 0.5146: : 4it [00:00,  5.33it/s]
Epoch 2 train_loss: 0.2500 train_acc: 0.9709 lr: 0.000100000000:   1%|          | 4/500 [00:01<03:31,  2.35it/s]
Epoch 2 test_loss: 0.4034 test_acc: 0.8155: : 4it [00:00,  5.38it/s]
Epoch 3 train_loss: 0.2367 train_acc: 0.9806 lr: 0.000100000000:   1%|          | 4/500 [00:01<03:31,  2.35it/s]
Epoch 3 test_loss: 0.6478 test_acc: 0.7573: : 4it [00:00,  5.71it/s]
Epoch 4 train_loss: 0.0837 train_acc: 0.9903 lr: 0.000100000000:   1%|          | 4/500 [00:01<03:31,  2.34it/s]
Epoch 4 test_loss: 0.4390 test_acc: 0.7476: : 4it [00:00,  5.35it/s]
  0%|          | 0/500 [00:00<?, ?it/s]

Epoch     5: reducing learning rate of group 0 to 1.0000e-05.


Epoch 5 train_loss: 0.0517 train_acc: 1.0000 lr: 0.000010000000:   1%|          | 4/500 [00:01<03:37,  2.28it/s]
Epoch 5 test_loss: 0.3030 test_acc: 0.9029: : 4it [00:00,  5.35it/s]
Epoch 6 train_loss: 0.1068 train_acc: 1.0000 lr: 0.000010000000:   1%|          | 4/500 [00:01<03:32,  2.34it/s]
Epoch 6 test_loss: 0.1472 test_acc: 1.0000: : 4it [00:00,  5.70it/s]
