In [8]:
import warnings
warnings.filterwarnings("ignore")

import os
import sys
import glob
import tqdm
import time
import yaml
import torch
import pickle
import joblib
import random
import sklearn
import logging
import datetime
import torch.fft
import torchvision
import torchvision.models as models

#torch.multiprocessing.set_start_method('spawn')

import numpy as np
import pandas as pd
import xarray as xr
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from collections import defaultdict
from scipy.signal import convolve2d
from torch.optim.lr_scheduler import *
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, TensorDataset
from typing import List, Dict, Callable, Union, Any, TypeVar, Tuple

### Set up the GPU device id, or CPU if no GPU available

In [16]:
# Pick the compute device: the currently selected CUDA device when one is
# available, otherwise the CPU.
is_cuda = torch.cuda.is_available()

if is_cuda:
    device = torch.device(torch.cuda.current_device())
    # Turn on cuDNN benchmark mode for the GPU run.
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True
else:
    device = torch.device("cpu")

### Load data from disk

In [1]:
# # Load all of the data into memory
# images = []
# labels = []
# masks = []

# loaded = 0
# # when training, load 10k to 20k images at once
# max_images = 10

# start_time = time.time()
# with open("/glade/work/schreck/repos/HOLO/holodec-ml/holodecml/reader/training_512x512_128_50000.pkl", "rb") as fid:
#     while True:
        
#         try:
#             image, label, u_net_mask, image_tile_idx, image_tile_coors = pickle.load(fid)
#             images.append(np.expand_dims(image, 0))
#             labels.append(label)
#             masks.append(np.expand_dims(u_net_mask, 0))
            
#             loaded += 1
            
#             if len(images) == max_images:
#                 break
            
#         except Exception as E:
#             break
            
# images = np.vstack(images)
# labels = np.vstack(labels)
# masks = np.vstack(masks)

# end_time = time.time()

In [2]:
#print(f"It took {end_time - start_time} s to load {loaded} (x,y) points")

In [3]:
#images.shape, labels.shape, masks.shape

In [4]:
# train_indices = random.sample(list(range(images.shape[0])), 4 * images.shape[0] // 5)
# test_indices = list(set(range(images.shape[0])) - set(train_indices))

# X_train = images[train_indices]
# X_test = images[test_indices]
# y_train = masks[train_indices]
# y_test = masks[test_indices]

#X_train, X_test, y_train, y_test = train_test_split(images, masks, test_size=0.20, random_state=42)

In [9]:
class HologramLoader(Dataset):
    """Stream (image, mask) pairs sequentially from a file of serialized records.

    Records are read one after another with ``joblib.load`` from a single open
    file handle. Each record must unpack into ``(image, label, mask)``; the
    label is discarded. The ``idx`` passed to ``__getitem__`` is ignored: every
    call returns the next record in the file or, once shuffling is active, a
    random sample from the buffer.

    Parameters
    ----------
    fn : str
        Path of the file that holds the records.
    max_buffer_size : int
        Maximum number of samples kept in the shuffle buffer.
    max_images : int
        Value reported by ``len()``. After this many records have been read the
        file is rewound and ``epoch`` is incremented.
    shuffle : bool
        If False, records are returned in file order. If True, records are
        returned as read during epoch 0 (while the buffer fills) and drawn at
        random from the buffer from epoch 1 onward.

    Notes
    -----
    * ``joblib.load`` unpickles data, which can execute arbitrary code; only
      point ``fn`` at files you trust.
    * NOTE(review): the object is stateful and ignores ``idx``. When wrapped in
      a ``DataLoader`` with ``num_workers > 0`` every worker holds its own copy
      of this state -- confirm that this is the intended behaviour.
    """

    def __init__(self, fn, max_buffer_size = 5000, max_images = 40000, shuffle = True):
        self.fn = fn
        self.buffer = []
        self.max_buffer_size = max_buffer_size
        self.shuffle = shuffle
        self.max_images = max_images

        self.fid = open(self.fn, "rb")
        self.loaded = 0
        self.epoch = 0

    def __getitem__(self, idx):
        """Return the next ``(image, mask)`` pair as float tensors.

        Raises
        ------
        EOFError
            If no record can be read even right after rewinding the file.
        """
        self.on_epoch_end()

        image, label, mask = self._read_record()
        # Drop the leading axis of both arrays; the label is not used.
        image = torch.FloatTensor(image.squeeze(0))
        mask = torch.FloatTensor(mask.squeeze(0))
        data = (image, mask)
        self.loaded += 1

        if not self.shuffle:
            return data

        # Keep a bounded, shuffled pool of recently read samples.
        self.buffer.append(data)
        random.shuffle(self.buffer)
        if len(self.buffer) > self.max_buffer_size:
            self.buffer = self.buffer[:self.max_buffer_size]

        if self.epoch > 0:
            return self.buffer.pop()

        # Wait until all data has been seen before sampling from the buffer.
        return data

    def __len__(self):
        """Return the nominal number of samples per epoch (``max_images``)."""
        return self.max_images

    def on_epoch_end(self):
        """Rewind the file and bump ``epoch`` once ``max_images`` records were read."""
        if self.loaded == self.__len__():
            self._rewind()
            self.epoch += 1

    def close(self):
        """Close the underlying file handle; safe to call more than once."""
        fid = getattr(self, "fid", None)
        if fid is not None and not fid.closed:
            fid.close()

    def _rewind(self):
        """Close the current handle, reopen the file at its start, reset ``loaded``."""
        self.close()
        self.fid = open(self.fn, "rb")
        self.loaded = 0

    def _read_record(self):
        """Read the next raw record, wrapping around to the start of the file at EOF."""
        try:
            return joblib.load(self.fid)
        except EOFError:
            # The file holds fewer than ``max_images`` records: every record has
            # been seen, so treat this as the end of a pass and start over
            # instead of handing ``None`` to the caller.
            self._rewind()
            self.epoch += 1

        try:
            return joblib.load(self.fid)
        except EOFError:
            raise EOFError(f"No records could be read from {self.fn}") from None

### Define the binary model (U-Net)

In [10]:
class Block(nn.Module):
    """Two 3x3 convolutions with a ReLU in between.

    No activation is applied after the second convolution. The convolutions use
    ``nn.Conv2d``'s default padding and stride.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Attribute names and creation order are kept so checkpoints and seeded
        # initialisation stay compatible.
        self.conv1 = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 to ``x``."""
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        return self.conv2(hidden)
class Encoder(nn.Module):
    """Contracting path: a stack of ``Block`` stages with 2x2 max-pooling in between.

    Parameters
    ----------
    chs : sequence of int
        Channel sizes; stage ``i`` maps ``chs[i]`` to ``chs[i+1]`` channels.
    """

    def __init__(self, chs=(3,64,128,256,512,1024)):
        super().__init__()
        self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        self.pool       = nn.MaxPool2d(2)

    def forward(self, x):
        """Return the list of per-stage feature maps, shallowest first.

        Pooling is applied only between stages. The output of the last stage is
        never pooled because that result would be discarded anyway; skipping it
        saves work and avoids an error when the last feature map is smaller
        than the pooling window.
        """
        ftrs = []
        last_stage = len(self.enc_blocks) - 1
        for stage, block in enumerate(self.enc_blocks):
            x = block(x)
            ftrs.append(x)
            if stage < last_stage:
                x = self.pool(x)
        return ftrs
    
class Decoder(nn.Module):
    """Expanding path: upsample, concatenate the cropped skip features, refine.

    Parameters
    ----------
    chs : sequence of int
        Channel sizes; stage ``i`` maps ``chs[i]`` to ``chs[i+1]`` channels.
    """

    def __init__(self, chs=(1024, 512, 256, 128, 64)):
        super().__init__()
        self.chs = chs
        stage_channels = list(zip(chs[:-1], chs[1:]))
        # All transposed convolutions are created before the blocks, as before,
        # so parameter order in the state dict is unchanged.
        self.upconvs = nn.ModuleList(
            [nn.ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in stage_channels]
        )
        self.dec_blocks = nn.ModuleList(
            [Block(c_in, c_out) for c_in, c_out in stage_channels]
        )

    def forward(self, x, encoder_features):
        """Run every decoder stage; ``encoder_features[i]`` is the skip input of stage ``i``."""
        for stage in range(len(self.chs) - 1):
            upsampled = self.upconvs[stage](x)
            skip = self.crop(encoder_features[stage], upsampled)
            x = self.dec_blocks[stage](torch.cat([upsampled, skip], dim=1))
        return x

    def crop(self, enc_ftrs, x):
        """Center-crop ``enc_ftrs`` to the height and width of the 4-D tensor ``x``."""
        _batch, _channels, height, width = x.shape
        return torchvision.transforms.CenterCrop([height, width])(enc_ftrs)
    
class UNet(nn.Module):
    """Encoder/decoder network with a 1x1 convolutional output head.

    Parameters
    ----------
    enc_chs : sequence of int
        Channel sizes handed to ``Encoder``.
    dec_chs : sequence of int
        Channel sizes handed to ``Decoder``; the last entry feeds the head.
    num_class : int
        Number of output channels of the head.
    retain_dim : bool
        If True, the output is resized to ``out_sz`` with ``F.interpolate``
        (called with its default mode).
    out_sz : tuple of int
        Target spatial size used when ``retain_dim`` is True.
    """

    def __init__(self,
                 enc_chs=(2, 64, 128, 256, 512, 1024),
                 dec_chs=(1024, 512, 256, 128, 64),
                 num_class=1,
                 retain_dim=False, out_sz=(572,572)):
        super().__init__()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.head = nn.Conv2d(dec_chs[-1], num_class, kernel_size=1)
        self.retain_dim = retain_dim
        self.out_sz = out_sz

    def forward(self, x):
        """Encode ``x``, decode from the deepest feature map, then apply the head."""
        # The decoder starts from the deepest feature map and consumes the
        # remaining ones as skip connections, deepest first.
        deepest_first = self.encoder(x)[::-1]
        decoded = self.decoder(deepest_first[0], deepest_first[1:])
        logits = self.head(decoded)
        if not self.retain_dim:
            return logits
        return F.interpolate(logits, self.out_sz)


In [11]:
# Files holding the serialized training and validation tiles.
fn_train = "/glade/p/cisl/aiml/ai4ess_hackathon/holodec/tiled_synthetic/training_512x512_128_40000.pkl"
fn_valid = "/glade/p/cisl/aiml/ai4ess_hackathon/holodec/tiled_synthetic/validation_512x512_128_10000.pkl"

# Training schedule: number of epochs and weight updates per epoch.
epochs, batches_per_epoch = 200, 100

# Batch sizes for the training and validation loaders.
train_batch_size = valid_batch_size = 32

# Epochs without a new best validation loss before training stops early.
stopping_patience = 5

In [17]:
# train_dataset = torch.utils.data.TensorDataset(
#     torch.from_numpy(X_train), 
#     torch.from_numpy(y_train)
# )

# test_dataset = torch.utils.data.TensorDataset(
#     torch.from_numpy(X_test), 
#     torch.from_numpy(y_test)
# )

# Training stream: shuffled through an in-memory buffer of at most 5000 samples.
train_dataset = HologramLoader(fn_train, max_buffer_size=5000, max_images=40000, shuffle=True)

# Validation stream: records come back in file order (no shuffling).
test_dataset = HologramLoader(fn_valid, max_images=10000, shuffle=False)

In [18]:
# NOTE(review): HologramLoader ignores the requested index and reads its file
# sequentially from a handle opened in __init__. With num_workers > 0 every
# worker process gets its own copy of the dataset object -- confirm the workers
# do not hand back duplicated or interleaved records.

# Training batches. shuffle=False because the reader does the shuffling itself.
# num_workers can be raised to the number of CPUs requested in the launch
# script (usually 8).
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=train_batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

# Validation batches, read in file order.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=valid_batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

In [19]:
# Resize the network output to the mask resolution.
mask_size = (512, 512)
retain_dim = True

# Build the model, then move its parameters to the selected device.
unet = UNet(retain_dim=retain_dim, out_sz=mask_size)
unet = unet.to(device)

In [20]:
# Adam hyper-parameters: initial step size and L2 penalty (disabled here).
weight_decay = 0.0
learning_rate = 1e-04

optimizer = torch.optim.Adam(params=unet.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [21]:
# Training objective: smooth L1 (Huber-like) -- quadratic for small residuals,
# linear for large ones.
train_criterion = nn.SmoothL1Loss()

# Validation metric: mean absolute error.
test_criterion = nn.L1Loss()

In [22]:
# Reduce the learning rate when the value passed to lr_scheduler.step() stops
# improving; never go below min_lr, and report every change (verbose).
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer, patience=1, min_lr=1.0e-10, verbose=True
)


In [23]:
# Per-epoch validation losses; they drive checkpointing, the LR schedule and
# early stopping.
epoch_test_losses = []
# Per-epoch metrics, written to a CSV log after every epoch.
results_dict = defaultdict(list)


for epoch in range(epochs):

    ### Train the model
    unet.train()

    batch_loss = []

    # Progress bar over one training epoch of `batches_per_epoch` updates.
    # tqdm advances by itself while being iterated, so no manual update() calls.
    batch_group_generator = tqdm.tqdm(
        enumerate(train_loader),
        total=batches_per_epoch,
        leave=True
    )

    for k, (inputs, y) in batch_group_generator:

        # Move data to the GPU, if not there already
        inputs = inputs.to(device).float()
        y = y.to(device).float()
        # NOTE(review): assumes y has exactly the same shape as the model output;
        # a missing channel axis would make the loss broadcast silently
        # (warnings are suppressed at the top of the notebook) -- TODO confirm.

        # Clear gradient
        optimizer.zero_grad()

        # get output from the model, given the inputs
        pred_mask = unet(inputs)

        # get loss for the predicted output
        loss = train_criterion(pred_mask, y)

        # get gradients w.r.t to parameters
        loss.backward()
        batch_loss.append(loss.item())

        # update parameters
        optimizer.step()

        # show the running mean loss and current learning rate in the bar
        to_print = "Epoch {} train_loss: {:.4f}".format(epoch, np.mean(batch_loss))
        to_print += " lr: {:.12f}".format(optimizer.param_groups[0]['lr'])
        batch_group_generator.set_description(to_print)

        # Stop the training epoch once exactly `batches_per_epoch` batches have
        # been used to update the weights (k is zero-based).
        if k + 1 >= batches_per_epoch:
            break

    # Compute final training metrics before doing validation
    train_loss = np.mean(batch_loss)

    # clear the cached memory from the gpu
    torch.cuda.empty_cache()

    ### Test the model
    unet.eval()
    with torch.no_grad():

        batch_loss = []

        # Progress bar over the full validation set (not the training data).
        batch_group_generator = tqdm.tqdm(
            enumerate(test_loader),
            total=len(test_loader),
            leave=True
        )

        for k, (inputs, y) in batch_group_generator:
            # Move data to the GPU, if not there already
            inputs = inputs.to(device).float()
            # Same float dtype as in training so the loss sees comparable targets.
            y = y.to(device).float()
            # get output from the model, given the inputs
            pred_mask = unet(inputs)
            # get loss for the predicted output
            loss = test_criterion(pred_mask, y)
            batch_loss.append(loss.item())
            # show the running mean loss in the bar
            to_print = "Epoch {} test_loss: {:.4f}".format(epoch, np.mean(batch_loss))
            batch_group_generator.set_description(to_print)

    # The mean validation loss is the metric that toggles the learning rate,
    # checkpointing and early stopping.
    test_loss = np.mean(batch_loss)
    epoch_test_losses.append(test_loss)

    # Lower the learning rate if we are not improving
    lr_scheduler.step(test_loss)

    # Save the model if it is the best so far.
    if test_loss == min(epoch_test_losses):
        state_dict = {
            'epoch': epoch,
            'model_state_dict': unet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': test_loss
        }
        #TODO: add directory
        torch.save(state_dict, "best_unet.pt")

    # Get the last learning rate
    learning_rate = optimizer.param_groups[0]['lr']

    # Put things into a results dictionary -> dataframe
    results_dict['epoch'].append(epoch)
    results_dict['train_loss'].append(train_loss)
    results_dict['valid_loss'].append(test_loss)
    results_dict["learning_rate"].append(learning_rate)
    df = pd.DataFrame.from_dict(results_dict).reset_index()

    # Save the dataframe to disk
    #TODO: add directory
    df.to_csv("training_log_unet.csv", index = False)

    # Stop training if we have not improved after X epochs (stopping patience).
    # list.index returns the first epoch that reached the minimum loss.
    best_epoch = epoch_test_losses.index(min(epoch_test_losses))
    offset = epoch - best_epoch
    if offset >= stopping_patience:
        break
        

Epoch 0 train_loss: 0.5359 lr: 0.000100000000:  10%|█         | 10/100 [00:26<03:58,  2.65s/it]


KeyboardInterrupt: 