# TODO
1. Use Weights & Biases to log models
2. Build a validation set
3. Write loops for training, validation and the overall training run
4. Clarify memory management

Heavily inspired by the great work in https://www.kaggle.com/code/awsaf49/uwmgi-unet-train-pytorch#📒-Notebooks

# Config

In [None]:

# Params
# Input locations (Kaggle competition layout).
train_csv_location: str = '../input/uw-madison-gi-tract-image-segmentation/train.csv'
training_images_directory: str = '../input/uw-madison-gi-tract-image-segmentation/train/'

# Data / training settings.
training_prop: float = 0.8   # share of cases intended for the training split
image_size: int = 224
batch_size: int = 64

# Seed handed to set_seed() further down.
random_seed: int = 100


# Load Packages

## Install Required Packages

## Load

In [None]:

# Python Basics
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
pd.options.plotting.backend = "plotly"
import os, glob, cv2, PIL
import matplotlib.pyplot as plt

# Torch
import torch

## Torch Dataset and Loaders
from torch.utils.data import Dataset, DataLoader

## TorchVision
from torchvision.transforms import ToTensor
from torchvision import utils
from torchvision.io import read_image

## Torch Models
import torch.nn as nn
import torch.nn.functional as F

# From Reference Notebook

import random
from glob import glob
import os, shutil
from tqdm import tqdm
tqdm.pandas()
import time
import copy
import joblib
from collections import defaultdict
import gc
from IPython import display as ipd

# visualization
import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Sklearn
from sklearn.model_selection import StratifiedKFold, KFold, StratifiedGroupKFold

# PyTorch 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

#import timm

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2



# Albumentations for augmentations
#import albumentations as A
#from albumentations.pytorch import ToTensorV2

#import rasterio
#from joblib import Parallel, delayed

# For colored terminal text
#from colorama import Fore, Back, Style
#c_  = Fore.GREEN
#sr_ = Style.RESET_ALL

import warnings
# Silence every warning category to keep notebook output compact
# (this also hides library deprecation notices).
warnings.filterwarnings("ignore")

# Synchronous CUDA kernel launches, for descriptive error messages.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Prepare Environment

## Reproducibility

In [None]:
import numpy as np
def set_seed(seed=42):
    '''Seed every RNG this notebook uses (NumPy, Python `random`, PyTorch CPU
    and CUDA) and force deterministic cuDNN behaviour, so repeated runs give
    the same results.'''
    for seeder in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    # cuDNN: deterministic algorithms only, and no autotuning benchmark,
    # since the benchmark may pick different kernels from run to run.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    # Exported for child processes; the already-running interpreter's hash
    # seed was fixed at startup and is not changed by this assignment.
    os.environ['PYTHONHASHSEED'] = str(seed)
    print('> SEEDING DONE')
    
# Seed all RNGs once, before data splitting and model initialisation below.
set_seed(random_seed)

In [None]:
%load_ext autoreload
%autoreload 2

# Define Functions and Classes


## Utils

In [None]:
# Functions
def rle2mask(mask_rle, shape, label=1):
    """
    Decode a run-length-encoded (RLE) mask into a 2-D numpy array.

    mask_rle: run-length string formatted "start length start length ..."
              with 1-based starts. The string 'nan' (how pandas stringifies a
              missing value), '', None or a float NaN all mean "empty mask".
    shape: (height, width) of the array to return.
    label: value written into every pixel covered by a run (default 1).
    Returns a uint8 numpy array of `shape`: `label` inside runs, 0 elsewhere.
    Raises ValueError if the RLE string has an odd number of tokens.
    """
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)

    # Non-strings (NaN floats, None) and 'nan' denote a missing annotation.
    if isinstance(mask_rle, str) and mask_rle != 'nan':
        tokens = mask_rle.split()
        if len(tokens) % 2:
            raise ValueError('RLE must contain (start, length) pairs, got {} tokens'.format(len(tokens)))
        starts = np.asarray(tokens[0::2], dtype=int) - 1  # RLE starts are 1-based
        lengths = np.asarray(tokens[1::2], dtype=int)
        for lo, length in zip(starts, lengths):
            img[lo:lo + length] = label

    return img.reshape(shape)  # Needed to align to RLE direction

def show_img(img, mask=None):
    """Draw an image, optionally with a semi-transparent 3-channel mask overlay.

    img: array accepted by plt.imshow; single-channel input uses the 'bone' colormap.
    mask: optional (H, W, 3) overlay drawn at 50% opacity. Its channels are
          rendered as red, green and blue. CustomImageDataset.get_masks_for_image
          stacks them as (stomach, large bowel, small bowel), and the legend
          follows that order.
    """
    plt.imshow(img, cmap='bone')

    if mask is not None:
        plt.imshow(mask, alpha=0.5)
        # Legend entries must follow the mask channel order R, G, B.
        colors = [(0.667, 0.0, 0.0), (0.0, 0.667, 0.0), (0.0, 0.0, 0.667)]
        labels = ["Stomach", "Large Bowel", "Small Bowel"]
        handles = [Rectangle((0, 0), 1, 1, color=c) for c in colors]
        plt.legend(handles, labels)
    plt.axis('off')
    
    
def plot_batch(imgs, msks, size=3):
    """Plot the first `size` samples of a batch side by side, each with its mask overlay.

    imgs: tensor (B, C, H, W), expected to hold values in [0, 1].
    msks: tensor (B, 3, H, W) of binary (0/1) masks or thresholded predictions.
    size: number of samples to show; capped at the batch size.
    """
    n = min(size, len(imgs))
    if n <= 0:
        return
    # One 5-inch-wide panel per sample (previously always 5 slots, which
    # crashed for size > 5 and left empty slots for size < 5).
    plt.figure(figsize=(5 * n, 5))
    for idx in range(n):
        plt.subplot(1, n, idx + 1)
        img = imgs[idx].detach().cpu().permute(1, 2, 0).numpy() * 255.0
        img = img.astype('uint8')
        # Float RGB for imshow must lie in [0, 1]; binary masks already do,
        # so no *255 scaling (which imshow would just clip back).
        msk = msks[idx].detach().cpu().permute(1, 2, 0).numpy().astype('float32')
        show_img(img, msk)
    plt.tight_layout()
    plt.show()
    
    
def load_img(path):
    """Read an image file as a float32, 3-channel array scaled to [0, 1].

    The file is read unchanged (the source images are uint16). The single
    grayscale channel is replicated to 3 channels, and values are divided by
    the image maximum. All-zero images stay all zero.

    Raises FileNotFoundError when OpenCV cannot read `path`.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError('Could not read image: {}'.format(path))
    img = np.repeat(img[..., None], 3, axis=-1)  # gray to rgb
    img = img.astype('float32')
    mx = np.max(img)
    if mx:
        img /= mx  # scale image to [0, 1]
    return img

## Prepare Training Data

In [None]:
def parse_case_day_slice_info(row):
    """Split row['id'] (e.g. 'case123_day20_slice_0001') into its parts.

    Returns the tuple (case_number, day_number, slice_number) as strings:
    the 1st, 2nd and 4th '_'-separated fields of the id.
    """
    fields = row['id'].split('_')
    # The third field (presumably the literal 'slice') is skipped.
    return tuple(fields[i] for i in (0, 1, 3))

def parse_image_filename(row):
    """Parse [case, day, slice, height, width, h_pixels, w_pixels] from an image path.

    `row` may be a row/mapping holding the path under 'file_name', or the path
    string itself. Paths follow the layout
    '<root>/<case>/<case>_<day>/<dir>/slice_<slice>_<a>_<b>_<c>_<d>.png'.
    The result is joined with the annotations and also used for scaling.

    Returns (case_number, day_number, day_id, slice_number, height, width,
    h_pixels, w_pixels). All are strings except day_id, which is an int.
    """
    # The old body read an undefined `file_name_string` and always raised NameError.
    file_name_string = row if isinstance(row, str) else row['file_name']
    parts = file_name_string.replace('\\', '/').split('/')

    # splitext drops exactly the extension; rstrip('.png') strips any trailing
    # '.', 'p', 'n' or 'g' characters.
    fields = os.path.splitext(parts[-1])[0].split('_')  # ['slice', slice, a, b, c, d]
    slice_number, height, width, h_pixels, w_pixels = fields[1:6]

    case_number = parts[-4]
    day_number = parts[-3].split('_')[1]
    day_id = int(day_number[len('day'):] if day_number.startswith('day') else day_number)

    return case_number, day_number, day_id, slice_number, height, width, h_pixels, w_pixels

def parse_image_info(row):
    """Row-wise parser for the image path in row['file_name'].

    Expects the layout produced by get_training_image_names' glob:
    '<root>/<case>/<case>_<day>/<dir>/slice_<slice>_<a>_<b>_<c>_<d>.png'.
    The result is joined with the annotations and also used for scaling.

    Returns (case_number, day_number, day_id, slice_number, height, width,
    h_pixels, w_pixels). All are strings except day_id, which is an int.

    NOTE(review): CustomImageDataset.get_masks_for_image passes
    shape=[width, height] to rle2mask, whose docstring expects
    (height, width). Either the 'height'/'width' labels here are swapped or
    the masks are built transposed; verify against the dataset before
    renaming anything.
    """
    parts = row['file_name'].replace('\\', '/').split('/')

    # splitext drops exactly the extension; rstrip('.png') strips a character
    # set and would also eat trailing '.', 'p', 'n' or 'g' characters.
    fields = os.path.splitext(parts[-1])[0].split('_')  # ['slice', slice, a, b, c, d]
    slice_number, height, width, h_pixels, w_pixels = fields[1:6]

    # Positional directories instead of splitting on 'train/', so custom
    # image roots (any directory name) work too.
    case_number = parts[-4]
    day_number = parts[-3].split('_')[1]
    day_id = int(day_number[len('day'):] if day_number.startswith('day') else day_number)

    return case_number, day_number, day_id, slice_number, height, width, h_pixels, w_pixels


def get_training_masks(train_csv_location = train_csv_location):
    """Load train.csv and pivot it to one row per case_day_slice id.

    Returns a DataFrame with:
    - an 'id' column;
    - one column per segmentation class holding the RLE string ('' where
      that class has no annotation);
    - the parsed 'id_info' tuple and its 'case_number', 'day_number' and
      'slice_number' parts.

    This portion was adapted from another Kaggle notebook (reference to be added).

    NOTE(review): pivot_table's default dropna=True may drop ids whose
    classes are all missing, i.e. fully empty slices. Confirm whether that is
    intended.
    """
    print('Getting All Available annotations for the training')
    train_csv = pd.read_csv(train_csv_location)

    # Change layout to have each slice's info on one row.
    train_csv = pd.pivot_table(
        train_csv,
        values='segmentation',
        index='id',
        columns='class',
        aggfunc='max',  # string alias; passing np.max is deprecated in pandas
    )
    # Fill *before* casting: astype(str) turns NaN into the literal 'nan',
    # which made the old trailing fillna('') a no-op.
    train_csv = train_csv.fillna('').astype(str).reset_index()

    # Parse case / day / slice out of the id for joining with image files.
    train_csv['id_info'] = train_csv.apply(parse_case_day_slice_info, axis = 1)
    train_csv[['case_number', 'day_number', 'slice_number']] = pd.DataFrame(
        train_csv.id_info.tolist(), index=train_csv.index
    )

    return train_csv


def get_training_image_names(location = training_images_directory):
    """Find every training image under `location` and parse its path metadata.

    Files are matched exactly four directory levels below `location`. Each
    path is parsed with parse_image_info.

    Returns a DataFrame with 'file_name', the raw parsed tuple 'slice_info',
    and its parts 'case_number', 'day_number', 'day_id', 'slice_number',
    'height', 'width', 'h_pixels' and 'w_pixels'.

    Raises FileNotFoundError if nothing matches, rather than failing later
    with an obscure pandas error on an empty frame.
    """
    print('Getting Filenames of all available training images')
    print('From location {}'.format(location))
    # Honour the `location` argument; it used to be overwritten with the
    # module-level default, silently ignoring callers' values.
    file_names = sorted(glob(os.path.join(location, '*', '*', '*', '*')))
    print(len(file_names))
    if not file_names:
        raise FileNotFoundError('No training images found under {}'.format(location))

    image_info = pd.DataFrame(file_names, columns = ['file_name'])

    image_info['slice_info'] = image_info.apply(parse_image_info, axis = 1)
    image_info[['case_number', 'day_number', 'day_id', 'slice_number', 'height', 'width', 'h_pixels', 'w_pixels']] = pd.DataFrame(
        image_info.slice_info.tolist(),
        index=image_info.index
    )

    return image_info

def get_training_dataset_info():
    """Build the master training table: one row per annotated slice with its image file.

    Inner-joins the output of get_training_masks() (RLE masks per slice id)
    with that of get_training_image_names() (file path plus parsed metadata
    per image) on case, day and slice. Only slices present in both survive.

    Output: DataFrame with the case_day_slice id, file_name location, the
    per-class target masks and the parsed image metadata.
    """
    masks_df = get_training_masks()
    # With all masks known, look up each case-day-slice image's file name and location.
    images_df = get_training_image_names()

    join_keys = ['case_number', 'day_number', 'slice_number']
    joined_data = masks_df.merge(images_df, how = 'inner', on = join_keys)
    print('Joined Data has shape {}'.format(joined_data.shape))

    return joined_data

## Custom DataSet Generator

In [None]:


class CustomImageDataset(Dataset):
    """Torch Dataset over slices described by `image_mask_associations`.

    Each row describes one [case_number, day_number, slice_number] slice and
    must provide 'file_name'. When labels are requested it must also provide
    'stomach', 'large_bowel', 'small_bowel' (RLE strings), 'height' and 'width'.
    Items are channel-first tensors: (image, masks) when `labels` is True,
    otherwise just the image.
    """

    def __init__(self, image_mask_associations, img_dir = training_images_directory, transform=None, labels = True):
        self.image_mask_associations = image_mask_associations
        self.img_dir = img_dir  # kept for callers; not used by this class ('file_name' holds full paths)
        self.transforms = transform
        self.labels = labels

    def __len__(self):
        return len(self.image_mask_associations)

    def get_masks_for_image(self, idx):
        """Decode the three class masks of row *position* `idx` into one uint8 array.

        Channels are stacked along the last axis in the order
        (stomach, large_bowel, small_bowel).
        """
        # Positional lookup, matching __getitem__ (which uses iloc). The old
        # label-based .loc only lined up when the frame had a 0..n-1 index.
        row = self.image_mask_associations.iloc[idx]

        img_size_h = int(row['height'])
        img_size_w = int(row['width'])
        # NOTE(review): rle2mask documents shape as (height, width), but this
        # passes [width, height]. Either the parsed column names are swapped
        # or the masks come out transposed; verify before changing either side.
        shape = [img_size_w, img_size_h]

        masks = [rle2mask(row[col], shape=shape) for col in ('stomach', 'large_bowel', 'small_bowel')]
        return np.stack(masks, axis = 2).astype('uint8')

    def __getitem__(self, idx):
        img_path = self.image_mask_associations['file_name'].iloc[idx]
        img = load_img(img_path)

        # Without labels, only the image is transformed and returned.
        if not self.labels:
            if self.transforms:
                img = self.transforms(image = img)['image']
            return torch.tensor(np.transpose(img, (2, 0, 1)))

        msks = self.get_masks_for_image(idx)
        if self.transforms:
            # Image and mask go through the same transform call.
            data = self.transforms(image = img, mask = msks)
            img, msks = data['image'], data['mask']

        # HWC -> CHW for PyTorch.
        img = np.transpose(img, (2, 0, 1))
        msks = np.transpose(msks, (2, 0, 1))
        return torch.tensor(img), torch.tensor(msks)

In [None]:
#data_transforms = A.Compose([
#  	 A.Resize(*[224,224])
# ]
#)
# Derive the target size from the `image_size` config value instead of
# repeating 224 here, so the two cannot drift apart.
img_size = [image_size, image_size]
data_transforms = {
    # Training: resize, then light geometric / dropout augmentation.
    "train": A.Compose([
        A.Resize(*img_size, interpolation=cv2.INTER_NEAREST),
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5),
        A.OneOf([
            A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
            # NOTE(review): `alpha_affine` may be deprecated in newer
            # albumentations releases; check the installed version.
            A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0),
        ], p=0.25),
        A.CoarseDropout(max_holes=8, max_height=img_size[0]//20, max_width=img_size[1]//20,
                        min_holes=5, fill_value=0, mask_fill_value=0, p=0.5),
    ], p=1.0),

    # Validation: deterministic resize only.
    "valid": A.Compose([
        A.Resize(*img_size, interpolation=cv2.INTER_NEAREST),
    ], p=1.0),
}

## Train / Validation Split

In [None]:
# Get Image-Mask Associations
# Master table: one row per annotated slice, joined with its image file metadata.
image_mask_associations = get_training_dataset_info()

In [None]:
# Split by *case* so that no case contributes slices to both sets.
cases = image_mask_associations['case_number'].unique()
training_set_membership = int(training_prop*len(cases))
# Use the configured proportion (this previously hard-coded 0.8 and left
# training_prop unused). `random` is seeded by set_seed; the old set()
# round-trip only scrambled list order, so it is dropped.
training_set_cases = random.sample(list(cases), training_set_membership)
validation_set_cases = [x for x in cases if x not in training_set_cases]

def validation_set_split(row):
    """Label a row 'training' or 'validation' according to its case_number.

    Every case in `cases` lands in exactly one of the two lists above, so
    'leave_for_now' is only returned for case numbers outside `cases`.
    """
    if row['case_number'] in training_set_cases:
        return 'training'
    elif row['case_number'] in validation_set_cases:
        return 'validation'
    else:
        return 'leave_for_now'


image_mask_associations['t_v_set'] = image_mask_associations.apply(validation_set_split, axis = 1)
image_mask_associations.groupby('t_v_set')['case_number'].count()

## Data Transforms and Prepare Loaders

In [None]:
# Non-null counts per column for each case (i.e. slices per case).
image_mask_associations.groupby('case_number').count()

In [None]:
# Non-null counts per column for each split label in 't_v_set'.
image_mask_associations.groupby('t_v_set').count()

In [None]:
# Peek at the first rows of the joined table.
image_mask_associations.head()

In [None]:
def prepare_loaders(batch_size, transforms = data_transforms):
    """Build (train_loader, valid_loader) from the module-level image_mask_associations.

    Rows are split on the 't_v_set' column ('training' / 'validation'), and
    each subset is re-indexed 0..n-1. transforms['train'] and
    transforms['valid'] are applied to the respective datasets. Training
    batches are shuffled; validation batches are not. Both loaders use 4
    worker processes and pinned memory.
    """
    def _subset(split_name):
        # Fresh positional index per subset.
        return image_mask_associations.query(f"t_v_set == '{split_name}'").reset_index(drop=True)

    train_dataset = CustomImageDataset(_subset('training'), transform=transforms['train'])
    valid_dataset = CustomImageDataset(_subset('validation'), transform=transforms['valid'])

    shared = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, drop_last=False, **shared)
    valid_loader = DataLoader(valid_dataset, shuffle=False, **shared)

    return train_loader, valid_loader

## UNet Model Definitions

In [None]:
class UNET(nn.Module):
    """U-Net with padded 3x3 convolutions (each level keeps its spatial size).

    in_channels: channels of the input image.
    classes: channels of the output map (raw scores; no activation applied).
    features: encoder channel widths, shallowest first. The default
        reproduces the original 6-level 64..2048 network (identical layers).
    """

    def __init__(self, in_channels=3, classes=1, features=(64, 128, 256, 512, 1024, 2048)):
        super(UNET, self).__init__()
        self.layers = [in_channels, *features]

        self.double_conv_downs = nn.ModuleList(
            [self.__double_conv(layer, layer_n) for layer, layer_n in zip(self.layers[:-1], self.layers[1:])])

        rev = self.layers[::-1]
        self.up_trans = nn.ModuleList(
            [nn.ConvTranspose2d(layer, layer_n, kernel_size=2, stride=2)
             for layer, layer_n in zip(rev[:-2], rev[1:-1])])

        # Input = skip features + upsampled features (both `skip` channels).
        self.double_conv_ups = nn.ModuleList(
            [self.__double_conv(2 * skip, skip) for skip in rev[1:-1]])

        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.final_conv = nn.Conv2d(self.layers[1], classes, kernel_size=1)

    def __double_conv(self, in_channels, out_channels):
        conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        return conv

    def forward(self, x):
        # Down path: keep every level's output (except the deepest) as a skip connection.
        skips = []
        last = len(self.double_conv_downs) - 1
        for i, down in enumerate(self.double_conv_downs):
            x = down(x)
            if i != last:
                skips.append(x)
                x = self.max_pool_2x2(x)

        # Up path: deepest skip first.
        for up_trans, double_conv_up, skip in zip(self.up_trans, self.double_conv_ups, reversed(skips)):
            x = up_trans(x)
            # Pooling floors odd sizes, so the upsampled map can be one pixel
            # short. Resize it to match the skip. (The old code called an
            # undefined `TF.resize` here and raised NameError.)
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)

            x = double_conv_up(torch.cat((skip, x), dim=1))

        return self.final_conv(x)

In [None]:
class Block(nn.Module):
    """Two unpadded 3x3 convolutions with a ReLU between them.

    Each convolution trims one pixel from every border, so H and W shrink by 4.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3)

    def forward(self, x):
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        return self.conv2(hidden)


class Encoder(nn.Module):
    """Contracting path: one Block per consecutive pair in `chs`, 2x max-pool between them.

    forward returns every Block's output (pre-pooling), shallowest first. The
    pool after the last Block is still computed but its result is discarded.
    """

    def __init__(self, chs=(3,64,128,256,512,1024)):
        super().__init__()
        self.enc_blocks = nn.ModuleList([Block(c_in, c_out) for c_in, c_out in zip(chs[:-1], chs[1:])])
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        features = []
        for block in self.enc_blocks:
            x = block(x)
            features.append(x)
            x = self.pool(x)
        return features


class Decoder(nn.Module):
    """Expanding path of the unpadded U-Net.

    Each stage upsamples 2x with a transposed convolution, then concatenates
    the matching encoder features (center-cropped to the upsampled size), then
    applies a Block.
    """

    def __init__(self, chs=(1024, 512, 256, 128, 64)):
        super().__init__()
        self.chs = chs
        self.upconvs = nn.ModuleList([nn.ConvTranspose2d(chs[i], chs[i+1], 2, 2) for i in range(len(chs)-1)])
        self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])

    def forward(self, x, encoder_features):
        """x: deepest feature map. encoder_features: skip maps ordered deepest first."""
        for i in range(len(self.chs)-1):
            x = self.upconvs[i](x)
            enc_ftrs = self.crop(encoder_features[i], x)
            x = torch.cat([x, enc_ftrs], dim=1)
            x = self.dec_blocks[i](x)
        return x

    def crop(self, enc_ftrs, x):
        """Center-crop `enc_ftrs` spatially to the (H, W) of `x`.

        Offsets are rounded like torchvision's CenterCrop. Plain slicing
        replaces `torchvision.transforms.CenterCrop`, because `torchvision`
        was never imported as a module and the old call raised NameError.
        Raises ValueError if `enc_ftrs` is smaller than `x`.
        """
        _, _, H, W = x.shape
        h, w = enc_ftrs.shape[-2:]
        if h < H or w < W:
            raise ValueError('encoder features {}x{} smaller than target {}x{}'.format(h, w, H, W))
        top = int(round((h - H) / 2.0))
        left = int(round((w - W) / 2.0))
        return enc_ftrs[..., top:top + H, left:left + W]


class UNet(nn.Module):
    """Unpadded U-Net assembled from Encoder, Decoder and a 1x1 conv head.

    Outputs are smaller than inputs because of the unpadded convolutions.
    With retain_dim=True the output is interpolated to `out_sz`.
    """

    def __init__(self, enc_chs=(3,64,128,256,512,1024), dec_chs=(1024, 512, 256, 128, 64), num_class=1, retain_dim=False, out_sz=(572,572)):
        super().__init__()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.head = nn.Conv2d(dec_chs[-1], num_class, 1)
        self.retain_dim = retain_dim
        # Stored so forward can use it; previously forward referenced an
        # undefined `out_sz` and raised NameError whenever retain_dim was set.
        self.out_sz = out_sz

    def forward(self, x):
        enc_ftrs = self.encoder(x)[::-1]  # deepest first
        out = self.decoder(enc_ftrs[0], enc_ftrs[1:])
        out = self.head(out)
        if self.retain_dim:
            out = F.interpolate(out, self.out_sz)
        return out
    
    

In [None]:
def build_model():
    """Create the segmentation_models_pytorch U-Net and move it to `device`.

    efficientnet-b2 encoder, 3 input channels, 3 output channels (one per
    class) and activation=None, so the model returns raw scores.
    NOTE(review): encoder_weights is not passed, so smp's default for that
    argument applies; check whether it downloads pretrained weights.
    """
    unet_kwargs = dict(
        encoder_name='efficientnet-b2',  # e.g. mobilenet_v2 or efficientnet-b7 also possible
        in_channels=3,                   # 1 for gray-scale, 3 for RGB
        classes=3,                       # model output channels
        activation=None,
    )
    model = smp.Unet(**unet_kwargs)
    model.to(device)
    return model

In [None]:
# Release unreachable Python objects before the heavier cells that follow.
gc.collect()

## Losses

In [None]:
!pip install segmentation_models_pytorch
import segmentation_models_pytorch as smp

# Candidate segmentation losses from segmentation_models_pytorch.
# Only DiceLoss is currently used, via `criterion`.
_MULTILABEL = 'multilabel'
JaccardLoss = smp.losses.JaccardLoss(mode=_MULTILABEL)
DiceLoss = smp.losses.DiceLoss(mode=_MULTILABEL)
BCELoss = smp.losses.SoftBCEWithLogitsLoss()
LovaszLoss = smp.losses.LovaszLoss(mode=_MULTILABEL, per_image=False)
TverskyLoss = smp.losses.TverskyLoss(mode=_MULTILABEL, log_loss=False)

def dice_coef(y_true, y_pred, thr=0.5, dim=(2,3), epsilon=0.001):
    """Mean Dice coefficient between targets and thresholded predictions.

    y_true: target masks, 4-D with the default `dim` (e.g. (B, C, H, W)).
    y_pred: scores/probabilities, binarised with `y_pred > thr`.
    epsilon: smoothing term, so empty target + empty prediction scores 1.
    Returns a 0-d tensor: per-(sample, channel) Dice averaged over channels and batch.
    """
    target = y_true.to(torch.float32)
    hard_pred = (y_pred > thr).to(torch.float32)
    overlap = (target * hard_pred).sum(dim=dim)
    total = target.sum(dim=dim) + hard_pred.sum(dim=dim)
    per_item = (2 * overlap + epsilon) / (total + epsilon)
    return per_item.mean(dim=(1, 0))

def iou_coef(y_true, y_pred, thr=0.5, dim=(2,3), epsilon=0.001):
    """Mean IoU (Jaccard) between targets and thresholded predictions.

    y_true: target masks, 4-D with the default `dim` (e.g. (B, C, H, W)).
    y_pred: scores/probabilities, binarised with `y_pred > thr`.
    epsilon: smoothing term, so empty target + empty prediction scores 1.
    Returns a 0-d tensor: per-(sample, channel) IoU averaged over channels and batch.
    """
    target = y_true.to(torch.float32)
    hard_pred = (y_pred > thr).to(torch.float32)
    overlap = target * hard_pred
    intersection = overlap.sum(dim=dim)
    union = (target + hard_pred - overlap).sum(dim=dim)
    return ((intersection + epsilon) / (union + epsilon)).mean(dim=(1, 0))

def criterion(y_pred, y_true):
    """Loss shared by training and validation: the multilabel smp DiceLoss.

    y_pred are the model outputs (build_model uses activation=None).
    Alternative kept for reference: 0.5*BCELoss(y_pred, y_true) + 0.5*TverskyLoss(y_pred, y_true)
    """
    loss_fn = DiceLoss
    return loss_fn(y_pred, y_true)

# Begin Modelling Process

## Train One Epoch

In [None]:
def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    """Train `model` for one pass over `dataloader` with AMP and gradient accumulation.

    Relies on two module-level names:
    - `criterion`: the loss, applied to raw model outputs;
    - `n_accumulate`: number of batches whose gradients are summed before
      each optimizer step.
    When `scheduler` is not None it is stepped with the current batch loss
    after every optimizer step, which fits ReduceLROnPlateau. `epoch` is
    accepted for symmetry with valid_one_epoch and is not used.

    Returns the sample-weighted mean training loss of the epoch (0.0 for an
    empty dataloader).
    """
    model.train()
    scaler = amp.GradScaler()
    # Start clean: a partial accumulation window from an earlier call must not
    # leak into this epoch's first optimizer step.
    optimizer.zero_grad()

    dataset_size = 0
    running_loss = 0.0
    epoch_loss = 0.0

    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc='Train ')
    for step, (images, masks) in pbar:
        images = images.to(device, dtype=torch.float)
        masks = masks.to(device, dtype=torch.float)

        batch_size = images.size(0)

        with amp.autocast(enabled=True):
            y_pred = model(images)
            # Differentiable training loss, the same one valid_one_epoch uses.
            # The old `dice_coef(y_pred, masks)` was a thresholded similarity
            # *metric* called with its arguments swapped; minimising it
            # optimised in the wrong direction.
            loss = criterion(y_pred, masks)

        # Divide so the accumulated gradient is an average over the window.
        scaler.scale(loss / n_accumulate).backward()
        loss_value = loss.item()

        if (step + 1) % n_accumulate == 0:
            scaler.step(optimizer)
            scaler.update()

            # zero the parameter gradients
            optimizer.zero_grad()

            if scheduler is not None:
                scheduler.step(loss_value)

        # Report the unscaled loss so the number does not depend on n_accumulate.
        running_loss += loss_value * batch_size
        dataset_size += batch_size
        epoch_loss = running_loss / dataset_size

        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        current_lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix(train_loss=f'{epoch_loss:0.4f}',
                         lr=f'{current_lr:0.5f}',
                         gpu_mem=f'{mem:0.2f} GB')
    torch.cuda.empty_cache()
    gc.collect()

    return epoch_loss

## Validate One Epoch

In [None]:
@torch.no_grad()
def valid_one_epoch(model, dataloader, device, epoch, optimizer=None):
    """Evaluate `model` on `dataloader` without gradient tracking.

    The loss is the module-level `criterion` on raw model outputs. Dice and
    Jaccard are computed on sigmoid outputs with dice_coef / iou_coef (default
    threshold 0.5).

    optimizer: used only to show the current learning rate in the progress
        bar. If omitted, the module-level `optimizer` is used when one exists
        (the previous, implicit behaviour); otherwise no LR is shown.
    `epoch` is accepted for symmetry and is not used.

    Returns (epoch_loss, val_scores):
    - epoch_loss: sample-weighted mean loss (0.0 for an empty dataloader);
    - val_scores: [Dice, Jaccard], each averaged over batches.
    """
    model.eval()

    if optimizer is None:
        optimizer = globals().get('optimizer')

    dataset_size = 0
    running_loss = 0.0
    epoch_loss = 0.0

    val_scores = []

    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc='Valid ')
    for step, (images, masks) in pbar:
        images = images.to(device, dtype=torch.float)
        masks = masks.to(device, dtype=torch.float)

        batch_size = images.size(0)

        y_pred = model(images)
        loss = criterion(y_pred, masks)

        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size

        epoch_loss = running_loss / dataset_size

        y_pred = nn.Sigmoid()(y_pred)
        val_dice = dice_coef(masks, y_pred).cpu().detach().numpy()
        val_jaccard = iou_coef(masks, y_pred).cpu().detach().numpy()
        val_scores.append([val_dice, val_jaccard])

        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        postfix = {'valid_loss': f'{epoch_loss:0.4f}'}
        if optimizer is not None:
            postfix['lr'] = f"{optimizer.param_groups[0]['lr']:0.7f}"
        postfix['gpu_memory'] = f'{mem:0.2f} GB'
        pbar.set_postfix(**postfix)
    val_scores = np.mean(val_scores, axis=0)
    torch.cuda.empty_cache()
    gc.collect()

    return epoch_loss, val_scores

## Run Training

In [None]:

def run_training(model, optimizer, scheduler, device, num_epochs):
    """Train for `num_epochs` epochs, keeping the weights with the best validation Dice.

    Reads the module-level `train_loader` and `valid_loader`. After every
    epoch the weights are saved to 'last_epoch_{epoch}.bin'. Whenever the
    validation Dice ties or beats the best so far, they are also saved to
    'best_epoch.bin'.

    Returns (model with the best weights loaded, history), where history maps
    'Train Loss', 'Valid Loss', 'Valid Dice' and 'Valid Jaccard' to per-epoch lists.
    """
    # TODO: log gradients / metrics to Weights & Biases.

    if torch.cuda.is_available():
        print("cuda: {}\n".format(torch.cuda.get_device_name()))

    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_dice = -np.inf
    # Defined up front so the final report cannot hit a NameError when Dice
    # never improves (e.g. NaN scores, or num_epochs == 0).
    best_jaccard = -np.inf
    best_epoch = -1
    history = defaultdict(list)

    for epoch in range(1, num_epochs + 1):
        gc.collect()
        print(f'Epoch {epoch}/{num_epochs}', end='')
        train_loss = train_one_epoch(model, optimizer, scheduler,
                                     dataloader=train_loader,
                                     device=device, epoch=epoch)

        val_loss, val_scores = valid_one_epoch(model, valid_loader,
                                               device=device,
                                               epoch=epoch)
        val_dice, val_jaccard = val_scores

        history['Train Loss'].append(train_loss)
        history['Valid Loss'].append(val_loss)
        history['Valid Dice'].append(val_dice)
        history['Valid Jaccard'].append(val_jaccard)

        print(f'Valid Dice: {val_dice:0.4f} | Valid Jaccard: {val_jaccard:0.4f}')

        # Keep (and checkpoint) the best weights by validation Dice.
        if val_dice >= best_dice:
            print(f"Valid Score Improved ({best_dice:0.4f} ---> {val_dice:0.4f})")
            best_dice = val_dice
            best_jaccard = val_jaccard
            best_epoch = epoch
            best_model_wts = copy.deepcopy(model.state_dict())
            PATH = "best_epoch.bin"
            torch.save(model.state_dict(), PATH)
            print("Model Saved")

        PATH = f"last_epoch_{epoch}.bin"
        torch.save(model.state_dict(), PATH)

        print(); print()

    time_elapsed = time.time() - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))
    print("Best Score: {:.4f}".format(best_jaccard))
    print(f"Best Dice: {best_dice:0.4f} (epoch {best_epoch})")

    # load best model weights
    model.load_state_dict(best_model_wts)

    return model, history

## Optimizer

In [None]:

"""
def fetch_scheduler(optimizer):
    if CFG.scheduler == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,T_max=CFG.T_max, 
                                                   eta_min=CFG.min_lr)
    elif CFG.scheduler == 'CosineAnnealingWarmRestarts':
        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer,T_0=CFG.T_0, 
                                                             eta_min=CFG.min_lr)
    elif CFG.scheduler == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   mode='min',
                                                   factor=0.1,
                                                   patience=7,
                                                   threshold=0.0001,
                                                   min_lr=CFG.min_lr,)
    elif CFG.scheduer == 'ExponentialLR':
        scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.85)
    elif CFG.scheduler == None:
        return None
        
    return scheduler

"""

In [None]:
# Free unreachable objects before building the data loaders and the model.
import gc
gc.collect()

# Train Model

In [None]:
# Use the configured `batch_size`. n_accumulate and T_max below are derived
# from that same value, so a hard-coded 128 here made them disagree.
train_loader, valid_loader = prepare_loaders(batch_size = batch_size)

In [None]:
# Batches per optimizer step: 32 // batch_size, but at least 1.
n_accumulate = max(1, 32 // batch_size)
# NOTE(review): run_training is called with num_epochs=25 below; `epochs` only feeds T_max.
epochs = 15
# NOTE(review): `lr` is unused; the Adam optimizer below hard-codes lr=0.01.
lr = 2e-3

# Scheduler values. NOTE(review): min_lr, T_max, T_0 and warmup_epochs are only
# referenced by the disabled fetch_scheduler string; the ReduceLROnPlateau
# below passes its own literal min_lr=1e-6.
min_lr = 1e-6
T_max = int(30000 / batch_size * epochs) + 50
T_0 = 25
warmup_epochs = 0
wd = 1e-6  # Adam weight decay

model = build_model()


In [None]:
# Adam with a hard-coded learning rate of 0.01 and weight decay `wd`.
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=wd)

# train_one_epoch steps this with the batch loss after every optimizer step,
# so `patience` counts optimizer steps, not epochs.
# NOTE(review): factor=0.005 cuts the LR 200x per reduction; confirm this is intended.
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.005,
    patience=50,
    threshold=0.0001,
    min_lr=1e-6,
    verbose=True,
)

In [None]:
# Earlier experiment, kept for reference (custom UNET + SGD):
#   unet = UNET(in_channels=1, classes=3).to(device)
#   optimizer = optim.SGD(unet.parameters(), lr=0.7, momentum=0.9)
#   scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.85)
#   scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
#                                              patience=10, threshold=0.001,
#                                              min_lr=1e-6, verbose=True)

# NOTE(review): the `epochs` hyper-parameter (15) is not used here.
model, history = run_training(
    model,
    optimizer,
    scheduler,
    device=device,
    num_epochs=25,
)
#run.finish()
#display(ipd.IFrame(run.url, width=1000, height=720))


In [None]:
"""
imgs = a
msks = b
idx = 10
img = imgs[idx,].permute((1, 2, 0)).numpy()*255.0
img = img.astype('uint8')
msk = msks[idx,].permute((1, 2, 0)).numpy()*255.0
output = pred[idx,].cpu().detach().permute((1, 2, 0)).numpy()

plt.subplot(1, 5, 1)
plt.imshow(img)
plt.subplot(1, 5, 2)
plt.imshow(msk)
plt.subplot(1, 5, 3)
plt.imshow(output)"""

In [None]:

"""# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.imshow(np.moveaxis(np.asarray(label), 0, 2), cmap="gray", alpha = 0.2)
plt.show()
#print(f"Label: {label}")"""

In [None]:
"""

print(f"Using {device} device")


for epoch in range(100):  # loop over the dataset multiple times

  running_loss = 0.0
  for i, data in enumerate(train_dataloader, 0):
      # get the inputs; data is a list of [inputs, labels]
      #print(i)
      inputs, labels = data
      inputs, labels = inputs.cuda(), labels.cuda() # add this line
      # zero the parameter gradients
      optimizer.zero_grad()

      # forward + backward + optimize
      outputs = unet(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      #plt.imshow(outputs)
      # print statistics
      running_loss += loss.item()
      
      if i % 20 == 19:    # print every 2000 mini-batches
          
          print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
          print(running_loss)
          running_loss = 0.0
  scheduler.step(loss)
    
print('Finished Training')
"""

## Inspect Model

# Prediction

In [None]:
# TODO : Build Test from Images that do not have masks.
#test_df = image_data_associations.query()
#test_dataset = BuildDataset(df.query("fold==0 & empty==0").sample(frac=1.0), label=False, 
                            #transforms=data_transforms['valid'])
#test_loader  = DataLoader(test_dataset, batch_size=5, 
                        #  num_workers=4, shuffle=False, pin_memory=True)
def load_model(path):
    """Rebuild the model with build_model(), load the state dict at `path`, and switch to eval mode.

    map_location=device lets checkpoints written on a GPU be loaded on a
    CPU-only machine (and vice versa).
    """
    model = build_model()
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    return model

# Sanity check: run the best checkpoint on one batch and plot binarised predictions.
# NOTE(review): this samples the (augmented) training loader; valid_loader may be more informative.
imgs, msks = next(iter(train_loader))
imgs, msks = (t.to(device, dtype=torch.float) for t in (imgs, msks))

model2 = load_model("best_epoch.bin")
with torch.no_grad():
    pred = model2(imgs)
    pred = (torch.sigmoid(pred) > 0.5).double()  # 0/1 per pixel and channel

imgs = imgs.cpu().detach()
preds = pred
print(pred.min(), pred.max())
plot_batch(imgs.cpu(), preds.cpu(), size=5)

In [None]:
# Ground-truth masks for the same batch, for comparison with the predictions plotted above.
plot_batch(imgs.cpu(), msks.cpu(), size=5)

In [None]:
#pred = (nn.Sigmoid()(pred)>0.5).double()
# Value range of the thresholded predictions (0.0 / 1.0 after the cast above).
pred.min(), pred.max()