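# Training pipeline notebook

Trains a UNet mask predictor (MSE loss) on the `fold_1` split of the AICM simulation dataset, logging to TensorBoard and saving periodic checkpoints.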

### Set up imports

In [2]:
# Imports

# Python imports
import os
import cv2
import sys
import glob
import json
import random
import pathlib
import numpy as np
import seaborn as sns
from PIL import Image
from barbar import Bar
from natsort import natsorted
import time

from models import UNet, init_net
from dataloader import EndoMaskDataset
from tensorboardX import SummaryWriter

# For experimentation purpose
import torch
import torchvision
import albumentations as alb
from torchsummary import summary
from torchvision.utils import save_image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split
import albumentations.augmentations.transforms as alb_tr

# Project imports
import utils as ut

GPU_ID = 1  # Physical GPU index to expose to this process
# CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is first initialised
# (torch initialises CUDA lazily, so setting it here, before any CUDA call, works).
os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU_ID)
# With a single visible GPU it is always enumerated as cuda:0; fall back to the CPU
# when no GPU is available so the notebook still runs (slowly) elsewhere.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Setup interact widget
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# Load the TensorBoard notebook extension (provides the %tensorboard magic for viewing logs inline)
%load_ext tensorboard

# Auto-reload magic function setup: re-import edited modules (e.g. the project
# modules models / dataloader / utils imported above) before each cell executes
%load_ext autoreload
%autoreload 2

# Matplotlib magic function setup: render figures inline in the notebook
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (20,10)  # Default (width, height) in inches for every figure below

### Data setup

In [4]:
# Load the list of training samples for the chosen split.
# TODO: parameterise the fold number (e.g. a FOLD constant) instead of hard-coding "fold_1".
split_file_path = "../aicm_sim_dataset/fold_1/{}_files.txt"
train_filenames = ut.read_lines_from_text_file(split_file_path.format("train"))

# Set this to False to turn augmentation off entirely,
# so the augmentation pipelines below don't have to be built.
aug = True


HEIGHT = 448  # Default input size; change as needed
WIDTH = 448
AUG_PROB = 0.5  # Probability with which each random augmentation is applied


# Photometric augmentations are applied to the image only;
# geometric augmentations are applied jointly to the image and its mask.
if aug:
    # NOTE(review): Resize is applied to the image only; presumably EndoMaskDataset
    # resizes the mask to HEIGHT x WIDTH itself (it receives height/width) — confirm.
    image_aug = alb.Compose([alb.Resize(height=HEIGHT, width=WIDTH),
                             alb_tr.ColorJitter(brightness=0.2,
                                                contrast=(0.3, 1.5),
                                                saturation=(0.5, 2),
                                                hue=0.1,
                                                p=AUG_PROB)])

    image_mask_aug = alb.Compose([alb.Rotate(limit=(-60, 60), p=AUG_PROB),
                                  # imgaug's translate_percent is a *fraction* of the image size
                                  # (0.1 == 10 %). The previous value of 10 sampled shifts of
                                  # 0-10x the image size, pushing the image out of frame.
                                  # NOTE(review): imgaug's shear is in degrees, so 0.1 is a
                                  # negligible shear — confirm the intended range.
                                  alb.IAAAffine(translate_percent=(-0.1, 0.1), shear=0.1, p=AUG_PROB),
                                  alb.HorizontalFlip(p=AUG_PROB),
                                  alb.VerticalFlip(p=AUG_PROB)])

else:
    image_aug = None
    image_mask_aug = None


DATAROOT = "/mnt/sds-stud/guest/data_preprocessed/data_coco_final_v3"  # An example, you can change this later
MASK_PATH = "mask"  # NOTE(review): not passed to the dataset below — confirm whether it is still needed

# Instantiate the PyTorch dataset
train_dataset = EndoMaskDataset(data_root_folder=DATAROOT,
                                filenames=train_filenames,
                                height=HEIGHT,
                                width=WIDTH,
                                image_aug=image_aug,
                                image_mask_aug=image_mask_aug)


BATCH_SIZE = 32

# Package this into a PyTorch dataloader
train_dataloader = DataLoader(train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              drop_last=False)

VAL = True  # If you don't want validation just toggle this flag

if VAL:
    val_filenames = ut.read_lines_from_text_file(split_file_path.format("val"))

    # Validation data is never augmented
    val_dataset = EndoMaskDataset(data_root_folder=DATAROOT,
                                  filenames=val_filenames,
                                  height=HEIGHT,
                                  width=WIDTH,
                                  image_aug=None,
                                  image_mask_aug=None)


    val_dataloader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=False,
                                drop_last=False)

# Serialise the augmentation pipelines so they can be written to a config file
# alongside the run (nothing is written to disk in this cell).
aug_dict = {"image_aug": alb.to_dict(image_aug) if image_aug else None,
            "image_mask_aug": alb.to_dict(image_mask_aug) if image_mask_aug else None}

### Model setup

In [None]:
# Seed every RNG in use so runs are reproducible
seed = 10
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # Covers every GPU, so a separate torch.cuda.manual_seed is redundant
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Create the model
model = UNet(n_channels=3,  # Input channels
             n_classes=1)  # Output channels

# Initialise the weights, then move the model to the training device: batches are
# moved to DEVICE during training, so the model must live there too.
model = init_net(model, type="kaiming", mode="fan_in",
                 activation_mode="relu",
                 distribution="normal")
model = model.to(DEVICE)

# Build the optimizer only after init_net / .to(): init_net reassigns `model`, and the
# optimizer must reference the parameters of the module that is actually trained.
LR = 0.001  # Learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                          patience=10,
                                                          min_lr=1e-10,
                                                          factor=0.1)

# Loss and metric functions
mse = torch.nn.MSELoss()


def metric_fn(pred, target, threshold=0.5, eps=1e-7):
    """Mean per-sample Dice coefficient of thresholded predictions vs. target masks.

    Args:
        pred: Model output with the batch on dim 0.
        target: Ground-truth masks with the batch on dim 0 and the same number of
            elements per sample as ``pred`` (channel dims may differ in layout).
        threshold: Value above which a predicted pixel counts as foreground.
            NOTE(review): assumes predictions and masks live in [0, 1] (the model
            is regressed onto the masks with MSE) — confirm.
        eps: Smoothing term so that an empty prediction vs. an empty mask scores 1.

    Returns:
        0-d tensor in [0, 1] averaged over the batch.
    """
    pred_bin = (pred > threshold).float().flatten(1)
    target_flat = target.float().flatten(1)
    intersection = (pred_bin * target_flat).sum(dim=1)
    denominator = pred_bin.sum(dim=1) + target_flat.sum(dim=1)
    return ((2 * intersection + eps) / (denominator + eps)).mean()

### Trainer setup

In [None]:
def compute_epoch(dataloader, train=True):
    """Run one pass over ``dataloader`` and return the averaged loss and metric.

    Uses the notebook-level ``model``, ``optimizer``, ``mse`` loss, ``metric_fn``
    and ``DEVICE``.

    Args:
        dataloader: DataLoader yielding ``(image, mask, filename)`` batches.
        train: If True, put the model in training mode and update its weights;
            otherwise put it in eval mode and run without gradient tracking.

    Returns:
        Tuple ``(epoch_loss, epoch_metric, preview)``:
        ``epoch_loss`` is the per-sample average loss;
        ``epoch_metric`` is the per-sample average metric multiplied by 100;
        ``preview`` is the first prediction of the last batch, detached and on
        the CPU (e.g. for image logging).

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    running_loss = 0.0
    running_metric = 0.0
    n_samples = 0
    pred_mask = None

    # train(False) is equivalent to eval(); nn.Module.train requires a real bool
    model.train(bool(train))

    for image, mask, _filename in Bar(dataloader):
        image, mask = image.to(DEVICE), mask.to(DEVICE)
        # Size of *this* batch: the last one can be smaller than BATCH_SIZE (drop_last=False)
        batch_size = image.size(0)

        # Only build the autograd graph when we are going to backpropagate
        with torch.set_grad_enabled(bool(train)):
            pred_mask = model(image)
            loss = mse(pred_mask, mask)  # nn.MSELoss takes (input, target) positionally

        if train:
            optimizer.zero_grad()  # Reset gradients accumulated by the previous step
            loss.backward()
            optimizer.step()

        metric = metric_fn(pred=pred_mask.detach(), target=mask)
        # loss/metric are batch means; weight by the batch size to accumulate per-sample sums
        running_metric += metric.item() * batch_size
        running_loss += loss.item() * batch_size
        n_samples += batch_size

    if pred_mask is None:
        raise ValueError("compute_epoch received a dataloader that yielded no batches")

    epoch_loss = running_loss / n_samples
    epoch_metric = (running_metric * 100) / n_samples
    return epoch_loss, epoch_metric, pred_mask[0].detach().cpu()

### Train loop

In [None]:
MODEL_NAME = "unet_baseline"
# Directory for TensorBoard event files *and* model checkpoints (set it explicitly if you like).
# The default is a fresh timestamped run folder, so checkpoints are stored next to the events
# instead of landing in the current working directory as they did with an empty path.
LOG_PATH = os.path.join("runs", "{}_{}".format(MODEL_NAME, time.strftime("%Y%m%d-%H%M%S")))
writer = SummaryWriter(LOG_PATH)

def log_losses(name, loss, epoch):
    """Add one or several scalars to the TensorBoard events file.

    A dict of ``{sub_tag: value}`` is written as a group via ``writer.add_scalars``
    (several curves in one chart); any other value is written as a single scalar
    via ``writer.add_scalar``. ``epoch`` is used as the global step.
    """
    log_fn = writer.add_scalars if isinstance(loss, dict) else writer.add_scalar
    log_fn(name, loss, epoch)

def log_images(name, loss, epoch):
    """Add an image to the TensorBoard events file.

    Args:
        name: Tag under which the image appears in TensorBoard.
        loss: The image to log (despite the parameter name); forwarded unchanged
            to ``writer.add_image``.
        epoch: Global step associated with the image.
    """
    tb_writer = writer
    tb_writer.add_image(name, loss, epoch)

def save_model(epoch):
    """Save the current model weights to disk.

    Writes ``model.state_dict()`` to
    ``<LOG_PATH>/model_weights/weights_<epoch>/<MODEL_NAME>.pth``.

    Args:
        epoch: Epoch number used to name the output folder.

    Returns:
        Path of the written file.
    """
    save_folder = os.path.join(LOG_PATH, "model_weights", "weights_{}".format(epoch))
    # exist_ok: re-running the notebook, or a checkpoint already saved for this epoch,
    # must not crash with FileExistsError
    os.makedirs(save_folder, exist_ok=True)
    save_path = os.path.join(save_folder, "{}.pth".format(MODEL_NAME))
    torch.save(model.state_dict(), save_path)
    return save_path

def save_checkpoint(epoch, loss):
    """Save model weights and optimizer state to disk so training can be resumed.

    Writes a dict with keys ``epoch``, ``model_state_dict``, ``optimizer_state_dict``
    and ``loss`` to ``<LOG_PATH>/model_weights/weights_<epoch>/<MODEL_NAME>.pt``.
    (This is a plain notebook function; the stray ``self`` parameter it used to
    take made the keyword-only call in the train loop fail.)

    Args:
        epoch: Epoch number stored in the checkpoint and used to name the folder.
        loss: Loss value stored alongside the weights.

    Returns:
        Path of the written file.
    """
    save_folder = os.path.join(LOG_PATH, "model_weights", "weights_{}".format(epoch))
    # exist_ok: re-runs, or save_model for the same epoch, must not crash
    os.makedirs(save_folder, exist_ok=True)
    save_path = os.path.join(save_folder, "{}.pt".format(MODEL_NAME))
    checkpoint = {'epoch': epoch,
                  'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'loss': loss}
    torch.save(checkpoint, save_path)
    return save_path

In [None]:
NUM_EPOCHS = 60
SAVE_FREQ = 10  # Save a checkpoint every SAVE_FREQ epochs

for epoch in range(NUM_EPOCHS):
    epoch_num = epoch + 1  # 1-based epoch index used for logging and checkpoint names
    print("Epoch {}".format(epoch_num))

    # *** Train loop ***
    model.train()
    time_before_epoch_train = time.time()
    train_loss, train_metric, train_pred = compute_epoch(dataloader=train_dataloader, train=True)
    epoch_train_duration = time.time() - time_before_epoch_train
    # *** End train loop ***

    log_losses('loss', {"train_loss": train_loss}, epoch_num)
    log_losses('train_metric', train_metric, epoch_num)
    # Hand TensorBoard a detached CPU tensor rather than a graph-attached GPU tensor
    log_images('train', train_pred.detach().cpu(), epoch_num)
    print('Epoch {} mean batch train loss: {:0.5f} | train metric: {:0.4f} | epoch train time: {:0.2f}s'.
          format(epoch_num, train_loss, train_metric, epoch_train_duration))

    # Quantity the plateau scheduler monitors: validation loss when available
    monitored_loss = train_loss

    if VAL:
        # *** Val loop ***
        model.eval()
        time_before_epoch_val = time.time()
        with torch.no_grad():
            val_loss, val_metric, val_pred = compute_epoch(dataloader=val_dataloader, train=False)
        epoch_val_duration = time.time() - time_before_epoch_val
        # *** End val loop ***
        monitored_loss = val_loss

        log_losses('loss', {"val_loss": val_loss}, epoch_num)
        log_losses('val_metric', val_metric, epoch_num)
        log_images('val', val_pred.detach().cpu(), epoch_num)
        print('Epoch {} mean batch val loss: {:0.5f} | val metric: {:0.4f} | epoch val time: {:0.2f}s'.
              format(epoch_num, val_loss, val_metric, epoch_val_duration))

    # ReduceLROnPlateau is stepped once per epoch with the monitored loss
    lr_scheduler.step(monitored_loss)

    # Save a model checkpoint every SAVE_FREQ epochs
    if epoch_num % SAVE_FREQ == 0:
        save_checkpoint(epoch=epoch_num, loss=train_loss)

*End of notebook.*