In [1]:
import warnings
warnings.filterwarnings("ignore")

import sys 
sys.path.append("/glade/work/schreck/repos/HOLO/clean/holodec-ml")
from holodecml.data import *
from holodecml.losses import *
from holodecml.models import *
from holodecml.metrics import *
from holodecml.transforms import *

import os
import glob
import tqdm
import time
import yaml
import torch
import shutil
import pickle
import joblib
import random
import sklearn
import logging
import datetime

import torch.fft
import torchvision
import torchvision.models as models

import numpy as np
import pandas as pd
import xarray as xr
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from collections import defaultdict
from scipy.signal import convolve2d
from torch.optim.lr_scheduler import *
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, TensorDataset
from typing import List, Dict, Callable, Union, Any, TypeVar, Tuple

### Set up a root logger that prints INFO-level messages

In [2]:
root = logging.getLogger()
root.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')

# Console handler for INFO and above. StreamHandler() with no argument writes to
# sys.stderr (not stdout). The handler is tagged with a name so that re-running this
# cell replaces it instead of stacking a second one (which would print every message twice).
for stale in [h for h in root.handlers if h.get_name() == "notebook_console"]:
    root.removeHandler(stale)

ch = logging.StreamHandler()
ch.set_name("notebook_console")
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
root.addHandler(ch)

### Select the GPU device, or fall back to the CPU if no GPU is available

In [3]:
# Use the current CUDA device when one is available, otherwise run on the CPU.
is_cuda = torch.cuda.is_available()
if is_cuda:
    device = torch.device(torch.cuda.current_device())
else:
    device = torch.device("cpu")

### Set seeds for reproducibility

In [4]:
def seed_everything(seed=1234):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and, when present,
    all CUDA devices). It also sets ``PYTHONHASHSEED`` in the environment, which only
    affects Python subprocesses started afterwards (e.g. DataLoader workers).

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Seed every visible GPU, not just the current one.
        torch.cuda.manual_seed_all(seed)
        # cuDNN autotuning (benchmark=True) may select different, non-deterministic
        # kernels from run to run, which defeats deterministic=True. Disable it so
        # results are actually reproducible.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

In [5]:
# Fix all RNG seeds for this run.
seed_everything(seed=1000)

### Load the configuration and get the relevant variables

In [6]:
# Path to the YAML configuration for this run. An earlier run instead used
# /glade/work/schreck/repos/HOLO/clean/holodec-ml/results/unet_const_noisy/model.yml
config = "../config/unet_propagation.yml"

In [7]:
# Parse the YAML config into a nested dict (yaml.full_load == yaml.load with FullLoader).
with open(config) as config_file:
    conf = yaml.full_load(config_file)

In [8]:
# Data settings: the tiled training/validation pickles are named by tile and step size.
tile_size, step_size, data_path = (
    conf["data"][key] for key in ("tile_size", "step_size", "output_path")
)

fn_train = f"{data_path}/training_{tile_size}_{step_size}.pkl"
fn_valid = f"{data_path}/validation_{tile_size}_{step_size}.pkl"

# Trainer settings (model_loc is where checkpoints and logs are written).
(
    epochs,
    train_batch_size,
    valid_batch_size,
    batches_per_epoch,
    stopping_patience,
    model_loc,
) = (
    conf["trainer"][key]
    for key in (
        "epochs",
        "train_batch_size",
        "valid_batch_size",
        "batches_per_epoch",
        "stopping_patience",
        "output_path",
    )
)

# Model settings.
model_name, color_dim, inference_mode = (
    conf["model"][key] for key in ("name", "color_dim", "mode")
)

# Optimizer settings.
learning_rate, weight_decay = (
    conf["optimizer"][key] for key in ("learning_rate", "weight_decay")
)

In [9]:
# Create the output directory and keep a copy of the config alongside the results.
os.makedirs(name=model_loc, exist_ok=True)
shutil.copy(src=config, dst=os.path.join(model_loc, "model.yml"))

'/glade/work/schreck/repos/HOLO/clean/holodec-ml/results/baseline/model.yml'

### Load the preprocessing transforms

In [10]:
# One transformation pipeline per split, built from the "transforms" section of the config.
train_transforms, valid_transforms = (
    LoadTransformations(conf["transforms"][split])
    for split in ("training", "validation")
)

INFO:holodecml.transforms:Loaded RandomVerticalFlip transformation with probability 0.5
INFO:holodecml.transforms:Loaded RandomHorizontalFlip transformation with probability 0.5
INFO:holodecml.transforms:Loaded Normalize transformation that normalizes data color channel by dividing by 255.0 and phase pi
INFO:holodecml.transforms:Loaded ToTensor transformation
INFO:holodecml.transforms:Loaded Normalize transformation that normalizes data color channel by dividing by 255.0 and phase pi
INFO:holodecml.transforms:Loaded ToTensor transformation


### Create the datasets that read and prepare the data for training the U-Net

In [11]:
# Training data comes from UpsamplingReader, configured directly from `conf` (max_size=1000).
# Previously used alternative:
#   PickleReader(fn_train, transform=train_transforms,
#                max_images=int(0.8 * conf["data"]["total_training"]),
#                max_buffer_size=int(0.1 * conf["data"]["total_training"]),
#                color_dim=color_dim, shuffle=True)
train_dataset = UpsamplingReader(conf, max_size=1000, transform=train_transforms)

# Validation data: up to 10% of total_training images read from fn_valid, not shuffled.
test_dataset = PickleReader(
    fn_valid,
    color_dim=color_dim,
    shuffle=False,
    transform=valid_transforms,
    max_images=int(0.1 * conf["data"]["total_training"]),
    max_buffer_size=int(0.1 * conf["data"]["total_training"]),
)

In [12]:
#x, y = train_dataset.__getitem__(8)
#plt.imshow(x[0])

In [13]:
#plt.imshow(y)

### Create the data loaders that batch the data

In [14]:
# Training batches. Shuffling is done by the DataLoader itself (shuffle=True).
# num_workers can be raised up to the number of CPUs requested in the launch script.
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=train_batch_size,
    num_workers=16,
    pin_memory=True,
)

# Validation batches, in order (shuffle=False), loaded in the main process
# (DataLoader's default num_workers=0).
test_loader = DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=valid_batch_size,
    pin_memory=True,
)

### Load a u-net model (resnet based on https://www.kaggle.com/bigironsphere/loss-function-library-keras-pytorch)

In [15]:
# ResNet-based U-Net with a single output class (n_class=1), moved to the selected device.
unet = ResNetUNet(color_dim=color_dim, n_class=1)
unet = unet.to(device)

In [16]:
# Parameter counts: all parameters vs. only those that receive gradients.
total_params = sum(map(torch.Tensor.numel, unet.parameters()))
trainable_params = sum(
    t.numel() for t in filter(lambda t: t.requires_grad, unet.parameters())
)

In [17]:
total_params  # displayed as the cell output

18310121

### Load an optimizer

In [18]:
# AdamW using the learning rate and weight decay from the config.
optimizer = torch.optim.AdamW(
    params=unet.parameters(),
    weight_decay=weight_decay,
    lr=learning_rate,
)

### Specify the training and validation losses

In [19]:
# Training optimizes Dice + BCE; validation is scored with Dice alone.
train_criterion, test_criterion = DiceBCELoss(), DiceLoss()

### Load a learning rate scheduler

In [20]:
# Lower the learning rate when the monitored (validation) loss stops improving.
lr_scheduler = ReduceLROnPlateau(
    optimizer,
    verbose=True,
    min_lr=1e-13,
    patience=1,
)

### Train a U-net model

In [None]:
# Train with early stopping. Each epoch:
#   1. runs `batches_per_epoch` optimizer steps on the training loader,
#   2. evaluates `test_criterion` over the whole validation loader,
#   3. steps the LR scheduler on that validation loss,
#   4. checkpoints the best model so far to `{model_loc}/best.pt`,
#   5. rewrites `{model_loc}/training_log.csv` with the per-epoch history.
# Training stops once the validation loss has not improved for `stopping_patience` epochs.
epoch_test_losses = []
results_dict = defaultdict(list)

for epoch in range(epochs):

    ### Train the model
    unet.train()

    batch_loss = []

    # Progress bar over training batches. tqdm advances the bar itself as the
    # iterable is consumed, so it must not also be advanced with .update().
    batch_group_generator = tqdm.tqdm(
        enumerate(train_loader),
        total=batches_per_epoch,
        leave=True
    )

    for k, (inputs, y) in batch_group_generator:

        # Move data to the GPU, if not there already
        inputs = inputs.to(device).float()
        y = y.to(device).float()

        # Clear gradient
        optimizer.zero_grad()

        # get output from the model, given the inputs
        pred_mask = unet(inputs)

        # get loss for the predicted output
        loss = train_criterion(pred_mask, y)

        # get gradients w.r.t to parameters
        loss.backward()
        batch_loss.append(loss.item())

        # update parameters
        optimizer.step()

        # Show the running mean training loss and the current learning rate
        to_print = "Epoch {} train_loss: {:.4f}".format(epoch, np.mean(batch_loss))
        to_print += " lr: {:.12f}".format(optimizer.param_groups[0]['lr'])
        batch_group_generator.set_description(to_print)

        # Stop the epoch once exactly `batches_per_epoch` batches have updated the
        # weights (k is 0-based, so k + 1 batches have been processed at this point).
        if k + 1 >= batches_per_epoch:
            break

    # Shutdown the progbar
    batch_group_generator.close()

    # Compute the final training loss for this epoch before validation
    train_loss = np.mean(batch_loss)

    # Release unused cached GPU memory held by PyTorch's allocator
    torch.cuda.empty_cache()

    ### Validate the model
    unet.eval()
    with torch.no_grad():

        batch_loss = []

        # Progress bar over all validation batches (advanced automatically by tqdm)
        batch_group_generator = tqdm.tqdm(
            enumerate(test_loader),
            leave=True
        )

        for k, (inputs, y) in batch_group_generator:
            # Move data to the GPU, if not there already
            inputs = inputs.to(device).float()
            y = y.to(device).float()
            # get output from the model, given the inputs
            pred_mask = unet(inputs)
            # get loss for the predicted output
            loss = test_criterion(pred_mask, y)
            batch_loss.append(loss.item())
            # Show the running mean validation loss
            to_print = "Epoch {} test_loss: {:.4f}".format(epoch, np.mean(batch_loss))
            batch_group_generator.set_description(to_print)

        # Shutdown the progbar
        batch_group_generator.close()

    # The mean validation loss drives LR scheduling, checkpointing and early stopping
    test_loss = np.mean(batch_loss)
    epoch_test_losses.append(test_loss)

    # Lower the learning rate if we are not improving
    lr_scheduler.step(test_loss)

    # Save the model if it is the best so far.
    if test_loss == min(epoch_test_losses):
        state_dict = {
            'epoch': epoch,
            'model_state_dict': unet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': test_loss
        }
        torch.save(state_dict, f"{model_loc}/best.pt")

    # Learning rate after the scheduler step (i.e. the one the next epoch will use)
    learning_rate = optimizer.param_groups[0]['lr']

    # Put things into a results dictionary -> dataframe
    results_dict['epoch'].append(epoch)
    results_dict['train_loss'].append(train_loss)
    results_dict['valid_loss'].append(test_loss)
    results_dict["learning_rate"].append(learning_rate)
    df = pd.DataFrame.from_dict(results_dict).reset_index()

    # Save the dataframe to disk (rewritten every epoch)
    df.to_csv(f"{model_loc}/training_log.csv", index = False)

    # Stop training if we have not improved after `stopping_patience` epochs.
    # np.argmin returns the first epoch achieving the minimum loss.
    best_epoch = int(np.argmin(epoch_test_losses))
    offset = epoch - best_epoch
    if offset >= stopping_patience:
        break

Epoch 0 train_loss: 1.6053 lr: 0.000100000000: 100%|██████████| 500/500 [03:23<00:00,  2.46it/s]
Epoch 0 test_loss: 0.9982: : 157it [01:35,  1.65it/s]
Epoch 1 train_loss: 1.3663 lr: 0.000100000000: 100%|██████████| 500/500 [03:26<00:00,  2.43it/s]
Epoch 1 test_loss: 0.9963: : 157it [01:26,  1.82it/s]
Epoch 2 train_loss: 1.1680 lr: 0.000100000000: 100%|██████████| 500/500 [03:26<00:00,  2.42it/s]
Epoch 2 test_loss: 0.9925: : 157it [01:26,  1.81it/s]
Epoch 3 train_loss: 1.0694 lr: 0.000100000000: 100%|██████████| 500/500 [03:26<00:00,  2.42it/s]
Epoch 3 test_loss: 0.9844: : 157it [01:28,  1.78it/s]
Epoch 4 train_loss: 1.0161 lr: 0.000100000000: 100%|██████████| 500/500 [03:26<00:00,  2.42it/s]
Epoch 4 test_loss: 0.9651: : 157it [01:27,  1.80it/s]
Epoch 5 train_loss: 0.9579 lr: 0.000100000000: 100%|██████████| 500/500 [03:26<00:00,  2.42it/s]
Epoch 5 test_loss: 0.8399: : 157it [01:26,  1.81it/s]
Epoch 6 train_loss: 0.5760 lr: 0.000100000000: 100%|██████████| 500/500 [03:26<00:00,  2.42it/