## Setup

In [126]:
import os
import importlib
from tqdm import tqdm
import numpy as np

# Torch modules
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torch import amp

# Custom modules
import segmentation.utils
importlib.reload(segmentation.utils)
from segmentation.utils import preprocessing, model_utils

import segmentation.s3_utils
importlib.reload(segmentation.s3_utils)
from segmentation.s3_utils import *

import segmentation.dataset
importlib.reload(segmentation.dataset)
from segmentation.dataset import PointDataset

from models.unet_model import UNET

import importlib.readers
from segmentation.show import *
import segmentation.constants
importlib.reload(segmentation.constants)
from segmentation.constants import VisualisationConstants

# Use the GPU when torch can see one, otherwise fall back to the CPU.
# Kept as a plain string because it is also passed as autocast's device_type.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'
from segmentation.utils import traininglog


In [16]:
# Dataset layout: each split folder has a colour-image directory (inputs, x)
# and a label directory (masks, y).
DATA_DIR = 'Dataset/'
x_test_dir, y_test_dir = (os.path.join(DATA_DIR, sub) for sub in ('Test/color', 'Test/label'))
x_trainVal_dir, y_trainVal_dir = (os.path.join(DATA_DIR, sub) for sub in ('TrainVal/color', 'TrainVal/label'))

# Test split: relative paths to the images (x) and their masks/labels (y).
x_test_fps, y_test_fps = preprocessing.get_testing_paths(x_test_dir, y_test_dir)

# TrainVal split: divided into training and validation paths.
# NOTE(review): 0.2 is presumably the validation fraction — confirm in
# preprocessing.train_val_split.
x_train_fps, x_val_fps, y_train_fps, y_val_fps = preprocessing.train_val_split(
    x_trainVal_dir, y_trainVal_dir, 0.2
)

In [17]:
# Build the datasets: training uses the training augmentation, validation the
# validation augmentation; both share the same preprocessing function.
train_augmentation = preprocessing.get_training_augmentation()
preprocessing_fn = preprocessing.get_preprocessing()
val_augmentation = preprocessing.get_validation_augmentation()

train_dataset = PointDataset(
    x_train_fps, y_train_fps,
    augmentation=train_augmentation,
    preprocessing=preprocessing_fn,
)
val_dataset = PointDataset(
    x_val_fps, y_val_fps,
    augmentation=val_augmentation,
    preprocessing=preprocessing_fn,
)

# Identical loader settings for both splits; only training batches are shuffled.
train_loader, val_loader = (
    DataLoader(ds, batch_size=16, num_workers=4, pin_memory=True, shuffle=shuffle)
    for ds, shuffle in ((train_dataset, True), (val_dataset, False))
)


In [128]:
# Identifiers for this run. The log/checkpoint file names are used both for the
# local files and for the S3 download/upload helpers.
expirement_nunber = 0
experiment_name = 'unet_point_experiment_{}'.format(expirement_nunber)
checkpoint_name, training_log_name = 'best_checkpoint.pth', 'training_log.csv'

In [129]:
# Pull this experiment's training log and best checkpoint from S3, presumably so
# that train_and_evaluate (which resumes when checkpoint_name exists locally) can
# pick up a previous run. The output below shows nothing was found this time.
dowload_expirement_files(experiment_name, training_log_name, checkpoint_name)

No files downloaded from s3
No files downloaded from s3


In [None]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    '''
    One epoch of training
    '''
    model.train()
    loop = tqdm(loader)
    running_loss = 0.0

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE, dtype=torch.float32)
        targets = targets.long().unsqueeze(1).to(device=DEVICE)

        # forward
        with torch.autocast(device_type=DEVICE, dtype=torch.float16):
            targets = targets.squeeze(1)  # (N, H, W)
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())

    avg_loss = running_loss / len(loader)
    return avg_loss

def validate_fn(loader, model, loss_fn):
    model.eval()  # set the model to evaluation mode
    total_loss = 0.0
    with torch.no_grad():
        for data, targets in loader:
            data = data.to(device=DEVICE, dtype=torch.float32)
            targets = targets.long().unsqueeze(1).to(device=DEVICE)
            with torch.autocast(device_type=DEVICE, dtype=torch.float16):
                # squeeze targets to get shape (N, H, W)
                predictions = model(data)
                loss = loss_fn(predictions, targets.squeeze(1))
            total_loss += loss.item()
    avg_loss = total_loss / len(loader)
    model.train()  # set model back to training mode
    return avg_loss


def train_and_evaluate(model, optimizer, train_loader, valid_loader, loss_fn, num_epochs, checkpoint_name):
    '''
    Train ``model`` for up to ``num_epochs`` epochs, keeping the best checkpoint.

    If ``checkpoint_name`` already exists, the model weights are restored from
    it (plus the optimizer state and best validation loss when the checkpoint
    contains them), and training resumes at the epoch after the stored one.

    Each epoch's train/validation losses are logged to the global
    ``training_log_name`` via ``traininglog.log_training``. Whenever the
    validation loss improves, a checkpoint is saved to ``checkpoint_name``.
    The log and checkpoint are uploaded with ``upload_experiment_files`` (using
    the global ``experiment_name``) every 10 epochs and after the final epoch.

    Args:
        model: network to train.
        optimizer: optimizer for ``model``'s parameters.
        train_loader: training batches, passed to ``train_fn``.
        valid_loader: validation batches, passed to ``validate_fn``.
        loss_fn: criterion shared by training and validation.
        num_epochs: total number of epochs (exclusive upper epoch index).
        checkpoint_name: path of the best-checkpoint file to resume from / write.

    Returns:
        The best validation loss seen (``float('inf')`` if no epoch ran and no
        checkpoint supplied one).
    '''
    start_epoch = 0
    # inf, not a magic constant: any first validation loss is an improvement.
    best_val_loss = float('inf')

    # Loading the checkpoint if there is one
    if os.path.exists(checkpoint_name):
        checkpoint = model_utils.return_checkpoint_from(checkpoint_name)
        model.load_state_dict(checkpoint['state_dict'])
        # Older checkpoints stored only weights and epoch.
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
        best_val_loss = checkpoint.get('best_val_loss', best_val_loss)
        # The stored epoch already finished training; resume after it.
        start_epoch = checkpoint['epoch'] + 1
        print(f'Continuing to train from {start_epoch}')
    else:
        print('Training new model')

    scaler = amp.GradScaler()

    for epoch in range(start_epoch, num_epochs):
        train_loss = train_fn(train_loader, model, optimizer, loss_fn, scaler)
        val_loss = validate_fn(valid_loader, model, loss_fn)
        traininglog.log_training(log_filename=training_log_name, epoch=epoch, train_loss=train_loss, val_loss=val_loss)
        print(f'validation loss: {val_loss}')

        # Saving if the validation loss is better
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoint = {'state_dict': model.state_dict(),
                          'optimizer': optimizer.state_dict(),
                          'epoch': epoch,
                          'best_val_loss': best_val_loss}
            model_utils.save_checkpoint(checkpoint, checkpoint_name)

        # Upload periodically (and after the last epoch) in case Colab crashes.
        if epoch % 10 == 0 or epoch == num_epochs - 1:
            upload_experiment_files(experiment_name, training_log_name, checkpoint_name)

    return best_val_loss

In [None]:
# Two output channels with CrossEntropyLoss: per-pixel 2-class segmentation.
# NOTE(review): in_channels=4 presumably matches the channel count PointDataset
# yields (e.g. image channels plus a point channel) — confirm against the dataset.
model = UNET(in_channels=4, out_channels=2).to(DEVICE)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# checkpoint_name is a required argument of train_and_evaluate; without it the
# call raises TypeError before training starts.
train_and_evaluate(model=model,
                   optimizer=optimizer,
                   train_loader=train_loader,
                   valid_loader=val_loader,
                   loss_fn=loss_fn,
                   num_epochs=100,
                   checkpoint_name=checkpoint_name,
                   )

In [26]:
# Scratch check of the loss object's class (displayed as the cell output);
# no other visible cell uses the result.
nn.CrossEntropyLoss().__class__

torch.nn.modules.loss.CrossEntropyLoss

In [None]:
# Only `segmentation` (from `import segmentation.s3_utils`) and the module's
# star-exported names are bound in this notebook, so a bare `s3_utils` will
# presumably raise NameError on a fresh kernel; reach the module via its package.
segmentation.s3_utils