# Train a DL Semantic Segmentation Model from Scratch

Preliminaries: 

1. This notebook requires a GPU.
From the toolbar select "Runtime" → "Change runtime type" → "GPU" to enable a GPU accelerator.

2. To maintain high-priority Colab user status, so that sufficient GPU resources remain available to you in the future, make sure to free the runtime when you finish running this notebook. To do so, open 'Runtime → Manage sessions' and click 'Terminate'.

3. Ensure you have sufficient Google Drive storage available for checkpointing 
models. Saving the weights of a single DeepLab_v3 instance takes ~152MB, so allow 650 free MB for saving four checkpoints over 60 epochs of training.

4. Do not expect to see results immediately: training takes several hours even with
the top-of-the-line NVIDIA P100 GPU available in Colab (GPU RAM Free will be
near 15079MB for a 16GB card). One pass over the validation set takes about 1.5 minutes on the P100.

In [None]:
# Check if this notebook is running in Colab or local workstation
import sys

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    !ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi
    !pip install gputil
    !pip install psutil
    !pip install humanize

import psutil
import humanize
import os
import GPUtil as GPU
GPUs = GPU.getGPUs()

try:
    # XXX: only one GPU on Colab and isn’t guaranteed
    gpu = GPUs[0]
    def printm():
        process = psutil.Process(os.getpid())
        print("Gen RAM Free: " + humanize.naturalsize(
            psutil.virtual_memory().available ), 
            " | Proc size: " + humanize.naturalsize(process.memory_info().rss))
        print("GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB".format(
            gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil*100, gpu.memoryTotal))
    printm() 

    # Check if GPU capacity is sufficient to proceed
    if gpu.memoryFree < 10000:
        print("\nInsufficient memory! Some cells may fail. Please try restarting the runtime using 'Runtime → Restart Runtime...' from the menu bar. If that doesn't work, terminate this session and try again later.")
    else:
        print('\nGPU memory is sufficient to proceeed.')
except:
    print('Select the Runtime → "Change runtime type" menu to enable a GPU accelerator, ')
    print('and then re-execute this cell.')

In [None]:
if IN_COLAB:

    # Mount Google Drive; the dataset and checkpoints are read from / written to it.
    from google.colab import drive
    drive.mount('/content/drive')
    # Root folder for datasets and checkpoints on Drive.
    # NOTE(review): DATA_PATH is only defined inside this Colab branch, but later
    # cells (dataroot, resume, ckpt_path) use it unconditionally. When running
    # on a local workstation, define DATA_PATH yourself before continuing.
    DATA_PATH = r'/content/drive/My Drive/Data'
    
    # cd into git repo so python can find utils (local modules such as
    # transforms, task_3_utils and folder2lmdb imported in the next cell)
    %cd '/content/drive/My Drive/cciw-zebra-mussel/predict'

    # Also make packages stored at the top level of 'My Drive' importable.
    sys.path.append('/content/drive/My Drive')

In [None]:
import os
import os.path as osp
import time

import glob

# for manually reading high resolution images
import cv2
import csv
import numpy as np

# for comparing predictions to lab analysis data frames
import pandas as pd

# for plotting
import matplotlib.pyplot as plt

# pytorch core library
import torch  # tested with torch.__version__==1.4.0
# pytorch neural network functions
from torch import nn

# import pytorch computer vision utils
import torchvision

# for DeepLab_v3 segmentation model
from torchvision.models import segmentation as models

# pytorch dataloader
from torch.utils.data import DataLoader

# for post-processing model predictions by conditional random field 
#import pydensecrf.densecrf as dcrf
#import pydensecrf.utils as utils

# using tqdm.notebook for friendly progress bar
import tqdm 

# evaluation metrics
from sklearn.metrics import r2_score
from sklearn.metrics import jaccard_score as jsc

# load custom pre-processeing transforms for use with VOCSegmentation loader
import transforms as T

# various helper functions, metrics that can be evaluated on the GPU
from task_3_utils import (save_checkpoint,
                          save_amp_checkpoint,
                          evaluate,
                          eval_binary_iou,
                          evaluate_loss_and_iou_torchvision,
                          adjust_learning_rate)

# Custom dataloader for rapidly loading images from a single LMDB file
from folder2lmdb import VOCSegmentationLMDB

In [None]:
# Confirm that this cell prints "Found GPU, cuda". If not, select "GPU" as
# "Hardware Accelerator" under the "Runtime" tab of the main menu.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print('Found GPU,', device)
else:
    # Fall back to the CPU so that `device` is always defined for later cells.
    # Training will be extremely slow, so this is mainly useful for debugging.
    device = torch.device('cpu')
    print('No GPU found, falling back to', device)

# 1. Basic setup and meta-parameters

Preliminary setup including path to dataset, data version, split type, and
logdir folder for saving intermediate results and trained model weights.

In [None]:
# Path to the dataset serialized in LMDB format, which loads efficiently from
# Google Drive into Colab.
dataroot = osp.join(DATA_PATH, 'ADIG_Labelled_Dataset/LMDB')

# Dataset version following the https://semver.org/ convention; one of
# 'v100', 'v101', 'v110', 'v111', 'v112', 'v120'.
data_version = 'v120'

# Training split: 'train' or 'trainval'. Use 'train' to validate
# generalization to 2019, and 'trainval' before the final deployment.
split = 'train'

# Top-level directory for checkpoints and statistics; if None, nothing is saved.
logdir = 'logs'
# logdir = osp.join(DATA_PATH, 'Checkpoints/logs')

# Checkpoint to resume from, or None to train from scratch.
# NOTE: this checkpoint's name encodes ep60/seed2, which differs from the
# epochs/seed settings of the current run.
# resume = None
resume = osp.join(
    DATA_PATH,
    'Checkpoints/logs/train_v120/'
    'deeplabv3_resnet50_lr1e-01_wd5e-04_bs50_ep60_seed2/checkpoint/'
    'deeplabv3_resnet50_lr1e-01_wd5e-04_bs50_ep60_seed2_epoch20.ckpt')

# Whether to print training and validation statistics during training.
do_print = True

# Seed for PyTorch's random number generator.
# NOTE(review): seeding alone may not make GPU training fully deterministic
# (e.g. nondeterministic cuDNN kernels). Confirm if exact reproducibility matters.
seed = 1

torch.manual_seed(seed)

In [None]:
# Every public, lower-case callable exported by torchvision.models.segmentation
# is a model constructor that can be selected via `arch`.
model_names = sorted(
    attr for attr, obj in vars(models).items()
    if attr.islower() and not attr.startswith("__") and callable(obj)
)
print('Available model choices:\n', model_names)

Select model architecture from `model_names` and meta-parameters: number of training epochs, batch size, learning rate schedule, etc.

In [None]:
# type=str, model architecture to use, choices=model_names
arch = 'deeplabv3_resnet50'

# type=int, number of training epochs, i.e., number of passes over the full 
# dataset and number of times each sample is shown to the model.
epochs = 40

# type=int, epoch to drop the initial learning rate by a factor of ten
drop = 30

# type=int, mini-batch size for stochastic gradient descent (SGD) training
bs = 50

# type=float, the initial learning rate value for SGD
lr = 1e-1

# type=float, L2 weight decay regularization constant
wd = 5e-4

# use apex to train with 16-bit float parameters (not currently enabled)
fp16 = False

In [None]:
# Encode the main hyper-parameters in the run name, e.g.
# 'deeplabv3_resnet50_lr1e-01_wd5e-04_bs50_ep40_seed1'.
ckpt_name = arch + '_lr%.e_wd%.e_bs%d_ep%d_seed%d' % (lr, wd, bs, epochs, seed)

if logdir is not None:
    # TensorBoard events and the CSV log are written under the local logdir.
    save_path = osp.join(logdir, split + '_' + data_version, ckpt_name)
    print('Saving training statistics to ', save_path)
    # Model weights go under DATA_PATH/Checkpoints/ (Google Drive in Colab).
    ckpt_path = osp.join(DATA_PATH, 'Checkpoints', save_path)
    print('Saving model weights to ', ckpt_path)

    # Logging stats
    result_folder = osp.join(save_path, 'results/')

    if not osp.exists(result_folder):
        # exist_ok guards against the folder appearing between check and create.
        os.makedirs(result_folder, exist_ok=True)
        print("Created folder", result_folder)

    logname = osp.join(result_folder, ckpt_name + '.csv')

    # Write the CSV header only once so re-running this cell (or resuming)
    # appends to the existing log. newline='' is required by the csv module
    # to avoid spurious blank rows on some platforms.
    if not osp.exists(logname):
        with open(logname, 'w', newline='') as logfile:
            logwriter = csv.writer(logfile, delimiter=',')
            logwriter.writerow(['epoch', 'lr', 'train loss', 'val loss'])
else:
    print('No checkpoints will be saved')

# 2. Define Image Pre-Processing Transforms and Data Augmentation

Here, we define transforms to be applied to input images (`inputs`) and segmentation masks (`targets`)
on the fly as we draw mini-batches iteratively like:

```
for inputs, targets in dataloader:
    pass
```

These transforms are documented here: https://pytorch.org/docs/stable/torchvision/transforms.html

We may wish to experiment with additional ones in the future, e.g., `ColorJitter` to perturb the image colours, 
or `Grayscale` to convert the dataset to Greyscale and quantify the marginal impact of colour information on model performance.

In [None]:
# Training-time pipeline. It operates on (image, mask) pairs, and crops and
# flips are applied identically to both. Steps, in order:
#   1. RandomCrop(224): a random square 224x224 crop.
#   2. RandomHorizontalFlip(0.5) and 3. RandomVerticalFlip(0.5): each flip
#      happens with probability 0.5. Mussel appearance does not depend on
#      orientation, so flipped copies are valid new samples and enlarge the
#      effective size of the training set.
#   4. ToTensor: Python Imaging Library (PIL aka Pillow) image -> PyTorch tensor.
#   5. Normalize(mean, std): image = (image - mean) / std, per RGB channel.
#
# On the normalization constants: for a mini-batch `inputs` with
# inputs.shape = torch.Size([N, C, H, W]) (N samples, C channels, e.g. 3 for
# RGB, height H, width W), per-channel statistics are
# inputs.mean(dim=(0, 2, 3)) (e.g. tensor([0.2613, 0.2528, 0.2255])) and
# inputs.std(dim=(0, 2, 3)). Averaging these over all mini-batches gives
# dataset-wide values. For the natural mussel dataset (not the Lab images),
# lighting and hue vary so much that such global values are of little use,
# so mean = std = 0.5 for every channel is used instead. This maps pixels from
# [0, 1] to [-1, 1], centring the images and the resulting feedforward
# activations around zero, which lets training proceed more smoothly.
training_tforms = T.Compose([
    T.RandomCrop(224),
    T.RandomHorizontalFlip(0.5),
    T.RandomVerticalFlip(0.5),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

For validation and testing, we want these transforms to be deterministic, so that changes in the metrics reflect progress of the model on the natural image distribution rather than randomness in the pre-processing. We therefore evaluate on fixed, center-cropped 224x224 patches rather than random crops.

For evaluating robustness, we could add `ColorJitter` and do scaling or shearing with various Affine transforms here...

In [None]:
# Deterministic evaluation pipeline: a fixed 224x224 center crop, conversion
# to a tensor, and the same per-channel (0.5, 0.5) normalization that maps
# pixels from [0, 1] to [-1, 1], as used for training.
test_tform = T.Compose(
    [T.CenterCrop(224), T.ToTensor(), T.Normalize((0.5,) * 3, (0.5,) * 3)]
)

# 3. Create Efficient Data Loaders

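Specify the mini-batch size (`bs`) used for validation, and the path to the serialized LMDB dataset (`dataroot`).

The batch size is arbitrary at test time. In evaluation mode (`net.eval()`), the BatchNorm layers of the model use their accumulated running statistics rather than mini-batch statistics, so predictions do not depend on how images are grouped into batches. The main consideration is to use the largest batch size the GPU memory allows, to maximize throughput. The default setting should be fine.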

The `VOCSegmentationLMDB` class was adapted from https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.VOCSegmentation
to enable reading data from a single `*.lmdb` database which is much more efficient on conventional hard drives than randomly reading images.

Note that transforms provided to the `transforms` argument apply to both input images and masks. Each mask is cropped and flipped exactly like its input image, but the labels are left unaffected by the normalization and keep their 0/1 values.

In [None]:
# Training data. Use the random crop/flip augmentation pipeline
# (`training_tforms`) rather than the deterministic evaluation transform, so
# that each epoch sees different crops and orientations of every image.
training_set = VOCSegmentationLMDB(
    root=osp.join(dataroot, split + '_' + data_version + '.lmdb'),
    transforms=training_tforms)

# Shuffle so that mini-batches differ from epoch to epoch.
train_loader = DataLoader(training_set, batch_size=bs, shuffle=True)

In [None]:
# Validation data: always the 'val_v101' split (independent of data_version),
# with the deterministic center-crop transform and a fixed (unshuffled) order.
valset = VOCSegmentationLMDB(root=osp.join(dataroot, 'val_v101.lmdb'),
                             transforms=test_tform)
val_loader = DataLoader(valset, shuffle=False, batch_size=bs)

In [None]:
# TensorBoard logging is only set up when a log directory is configured.
# Events go to the run-specific `save_path` and are flushed at least every 60 s.
if logdir is not None:
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(log_dir=save_path, flush_secs=60)

In [None]:
# Build the segmentation network with a single output channel.
#
# Although there are two classes (mussel and background), one logit per pixel
# is enough: after nn.Sigmoid(), values near 0 mean background and values near
# 1 mean mussel. The alternative, num_classes=2 with nn.CrossEntropyLoss(),
# would let the *channel* rather than the *value* encode the class, but needs
# the loss and label format to be changed accordingly.
#
# Rule of thumb for the number of output channels (inputs are 3-channel RGB):
#   - one class plus background (i.e. two classes): num_classes=1
#   - N > 2 classes:                                num_classes=N
print(f"=> creating model '{arch}'")
net = models.__dict__[arch](num_classes=1).to(device)

In [None]:
# Training procedure: plain SGD (no momentum) with L2 weight decay, over all
# parameters of `net`.
optimizer = torch.optim.SGD(params=net.parameters(), lr=lr, weight_decay=wd)

To compute the `pos_weight` from the dataset, uncomment the following cell.

In [None]:
# Note: this cell is optional! It defines a helper for estimating `pos_weight`
# from a data loader. Uncomment the call at the bottom to run it.
def compute_pos_weight(data_loader, verbose=True):
    """Return the ratio of background (label 0) to mussel (label 1) pixels.

    This is the inverse mussel-class frequency used as `pos_weight` for
    nn.BCEWithLogitsLoss. Pixels with any other label value are ignored.

    Args:
        data_loader: iterable of (inputs, targets) batches, where targets are
            label tensors. len(data_loader) is used only when verbose is True.
        verbose: print the running estimate after every batch and the final
            value.

    Returns:
        float: (#label-0 pixels) / (#label-1 pixels).

    Raises:
        ValueError: if no label-1 pixels are found.
    """
    n_pos = 0  # mussel pixels
    n_neg = 0  # background pixels
    for idx, (_, targets) in enumerate(data_loader):
        n_pos += int((targets == 1).sum())
        n_neg += int((targets == 0).sum())
        if verbose:
            running = n_neg / n_pos if n_pos else float('inf')
            print('Batch %d of %d, pos_weight=%.4f' % (idx, len(data_loader), running))
    if n_pos == 0:
        raise ValueError('No mussel (label 1) pixels found; pos_weight is undefined.')
    pos_weight = n_neg / n_pos
    if verbose:
        print('pos_weight={:.4f}'.format(pos_weight))
    return pos_weight


# pos_weight = compute_pos_weight(train_loader)

In [None]:
# pos_weight for nn.BCEWithLogitsLoss: the inverse frequency of `mussel`
# pixels (background / mussel pixel ratio) in the training labels,
# pre-computed for each supported (data_version, split) pair.
TRAIN_POS_WEIGHTS = {
    ('v101', 'train'): 3.6891,  # v101 uses one value for both splits
    ('v101', 'trainval'): 3.6891,
    ('v111', 'train'): 3.4270,
    ('v111', 'trainval'): 3.6633,
    ('v112', 'train'): 3.6021,
    ('v120', 'train'): 3.1849,
    ('v120', 'trainval'): 3.4297,
}

try:
    pos_weight = TRAIN_POS_WEIGHTS[(data_version, split)]
except KeyError:
    # Fail loudly instead of leaving pos_weight undefined (or stale from an
    # earlier execution of this cell) for unsupported combinations.
    raise ValueError(
        'No pre-computed pos_weight for data_version={!r}, split={!r}; compute '
        'the background/mussel pixel ratio of the training labels and add it '
        'to TRAIN_POS_WEIGHTS.'.format(data_version, split)) from None

print('Will weight the loss by inverse mussel class frequency', pos_weight)
train_pos_weight = torch.FloatTensor([pos_weight]).to(device)
loss_fn = nn.BCEWithLogitsLoss(pos_weight=train_pos_weight)

# 4.2838 for val_v101 (the validation split used by val_loader)
val_pos_weight = torch.FloatTensor([4.2838]).to(device)
val_loss_fn = nn.BCEWithLogitsLoss(pos_weight=val_pos_weight)

sig = nn.Sigmoid()  # maps logits to mussel probabilities

In [None]:
def evaluate_loss(net, data_loader, loss_fn, device):
    """Evaluates the cross entropy loss of DL model given by `net` on data from
    `data_loader`
    """
    running_loss = 0

    for inputs, targets in data_loader:
        break

    with torch.no_grad():
        for inputs, targets in tqdm.notebook.tqdm(data_loader, unit=' batches'):
            inputs, targets = inputs.to(device), targets.to(device)
            pred = net(inputs)['out']
            # dataloader outputs targets with shape NHW, but we need NCHW
            batch_loss = loss_fn(pred, targets.unsqueeze(dim=1).float())
            running_loss += batch_loss.item()
    return running_loss / len(data_loader)

# 4. Optionally load a pre-trained checkpoint

Can be used to resume training if it was previously interrupted.

In [None]:
if resume is not None:
    if not osp.isfile(resume):
        raise FileNotFoundError('Checkpoint not found: {}'.format(resume))
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    # NOTE: torch.load unpickles arbitrary objects, so only resume from trusted
    # checkpoint files. map_location='cpu' keeps the saved RNG state on the CPU
    # (as torch.set_rng_state requires) and lets GPU checkpoints load anywhere.
    checkpoint = torch.load(resume, map_location='cpu')

    # Copy the saved weights into the existing `net` instead of rebinding the
    # name. `optimizer` was built from this net's parameters, so rebinding
    # would leave it updating a discarded model. The checkpoint may hold a
    # whole nn.Module or a plain state_dict.
    saved_net = checkpoint['net']
    state_dict = (saved_net.state_dict() if isinstance(saved_net, nn.Module)
                  else saved_net)
    net.load_state_dict(state_dict)
    # only if using 16-bit floating point with apex amp
    #optimizer.load_state_dict(checkpoint['optimizer'])
    #amp.load_state_dict(checkpoint['amp'])
    start_epoch = checkpoint['epoch'] + 1
    val_loss = checkpoint['val_loss']
    torch.set_rng_state(checkpoint['rng_state'])
    # The batch size is parsed from the checkpoint file name, e.g.
    # '..._bs50_ep60_seed2_epoch20.ckpt' -> 'bs50' -> 50.
    ckpt_bs = int(osp.basename(resume).split('_')[-4][2:])
    global_step = start_epoch * (len(training_set) // ckpt_bs)

    # Compute the validation loss to ensure the model was loaded correctly and
    # that the data were pre-processed consistently with the training run.
    # Eval mode is required: in train mode BatchNorm would use batch statistics
    # and also update its running averages during this check.
    net.eval()
    calculate_validation_loss = evaluate_loss(
        net, val_loader, val_loss_fn, device)
    # NOTE(review): the training loop records val_loss *before* an epoch's
    # updates but checkpoints the weights *after* them, so this check may fail
    # for checkpoints written by that loop. Confirm what save_checkpoint stores.
    assert np.allclose(calculate_validation_loss, val_loss, atol=1e-3)
    print('\n Validation loss of {:.4f} matches checkpoint'.format(
        calculate_validation_loss))
else:
    start_epoch = 0
    global_step = 0

Define a function for efficiently evaluating the cross entropy loss and IoU

In [None]:
def evaluate_loss_and_iou_torchvision(net, data_loader, loss_fn, device):
    """Return the mean loss and mean IoU of `net` over `data_loader`.

    This definition shadows the function of the same name imported from
    task_3_utils.

    Args:
        net: segmentation model whose forward pass returns a dict with the
            logits under 'out' (torchvision segmentation models do; remove the
            ['out'] indexing for other models, e.g. FCN8slim or U-Net).
        data_loader: iterable of (inputs, targets) batches that supports
            len(). Targets have shape NHW.
        loss_fn: loss taking (logits, float targets of shape N1HW).
        device: device the batches are moved to.

    Returns:
        (mean_loss, mean_iou). mean_loss averages the per-batch losses over
        all batches. mean_iou averages the per-batch IoU over the batches
        that pass the filter below, and is 0.0 if none do.

    The caller is responsible for putting `net` in eval mode beforehand.
    """
    # `import tqdm` alone does not load tqdm.notebook. tqdm.auto picks the
    # notebook widget inside Jupyter/Colab. Without tqdm, skip the progress bar.
    try:
        from tqdm.auto import tqdm as progress
    except ImportError:
        def progress(iterable, **kwargs):
            return iterable

    sig = nn.Sigmoid()
    n_iou_batches = 0
    running_iou = 0.0
    running_loss = 0.0

    with torch.no_grad():
        # The progress bar counts batches; multiply its rate by the batch size
        # to get the number of images processed per second.
        for inputs, targets in progress(data_loader, unit=' batches'):
            inputs, targets = inputs.to(device), targets.to(device)

            pred = net(inputs)['out']
            # dataloader outputs targets with shape NHW, but we need NCHW
            batch_loss = loss_fn(pred, targets.unsqueeze(dim=1).float())
            running_loss += batch_loss.item()

            # Round sigmoid probabilities to hard 0/1 masks before scoring.
            bin_iou = eval_binary_iou(sig(pred).round(), targets)
            # NOTE(review): `> 1` requires at least two positive IoU entries
            # before a batch is counted, and zero entries are excluded from the
            # mean. Confirm against eval_binary_iou's output layout whether
            # this filter is intended (`> 0` would only avoid an empty mean).
            if (bin_iou > 0).sum() > 1:
                running_iou += bin_iou[bin_iou > 0].mean().item()
                n_iou_batches += 1

    mean_iou = running_iou / n_iou_batches if n_iou_batches else 0.0
    return running_loss / len(data_loader), mean_iou

In [None]:
# Load the TensorBoard notebook extension
# (IPython line magic; it provides the %tensorboard command that launches the
# dashboard inline in this notebook).
%load_ext tensorboard

After launching tensorboard using the next cell, you will initially
see the message: 
- __No dashboards are active for the current data set__. 

This dashboard will be populated automatically once training is launched. 


Once the model starts training, you will see two (2) tabs in TensorBoard: "SCALARS" and "IMAGES". 

1. In the __SCALARS__ pane, you will see three drop down menus:

*   The __IoU__ menu tracks the mean intersection-over-union for the validation set over time, which should generally increase.

*   The __L2norm__ menu tracks the L2 norm of various parameters, and serves as a basic sanity check that the parameters are changing from their initial values during learning. The L2 norm can go up or down, but will generally go down in proportion to the magnitude of the weight decay `wd` regularization term.

*   The __Loss__ menu tracks the loss for each training mini-batch and for the 
validation set. The validation loss should generally go down but it's possible for IoU to continue to increase while the validation loss increases as the 
model confidence increases but the number of pixelwise errors decreases.

2. In the __IMAGES__ pane, there are three drop down menus:

* *images* show 224 pixel input images (training set)

* *labels* show the corresponding ground truth

* *predictions* show the model output as a greyscale image

The statistics plotted here will only be available for the duration of 
the Colab session due to the prohibitive latency of frequently writing data 
into Google Drive. Data for each of the scalars panes can be downloaded as a 
`csv` file, however, by checking the box __Show data download links__, then hovering over the grey text __run to download__ in the bottom right hand corner of any pane, select a run, then click the __CSV__ button to trigger the download.

On the other hand the model checkpoints are saved into the `ckpt_path` which is persistent on Google Drive. Once checkpoints are saved there this notebook can be halted and training resumed at any time using the `resume` argument in Section 1.

In [None]:
# Launch TensorBoard inline, reading event files from the 'logs' directory.
# NOTE(review): the path is hard-coded. Keep it in sync with `logdir` from
# Section 1 (currently 'logs'), under which the SummaryWriter writes its events.
%tensorboard --logdir logs

# 5. Train the model

In [None]:
for epoch in range(start_epoch, epochs):

    eval_start_time = time.time()

    # Put the model into evaluation mode. This matters for layers such as
    # dropout and batch normalization, which behave differently at training
    # and test time. Batch normalization normalizes activations with
    # mini-batch statistics during training, but with accumulated population
    # statistics during evaluation.
    net.eval()

    # Validate before each epoch's training, so the first logged point
    # reflects the initial (or resumed) weights.
    val_loss, val_iou = evaluate_loss_and_iou_torchvision(
        net, val_loader, val_loss_fn, device)

    if do_print:
        print('Epoch [%d/%d], val IoU %.4f, val loss %.4f, took %.2f sec, ' % (
            epoch, epochs, val_iou, val_loss, time.time() - eval_start_time))
    if logdir is not None:
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('IoU/val', val_iou, epoch)

    epoch_start_time = time.time()

    train_loss = 0

    net.train()

    for batch, (inputs, targets) in enumerate(train_loader):

        # implements the learning rate schedule, drops lr if necessary
        # NOTE(review): `lr` is overwritten with the helper's return value on
        # every batch and fed back in on the next call. This is only correct if
        # adjust_learning_rate is idempotent for a given epoch. Confirm in task_3_utils.
        lr = adjust_learning_rate(optimizer, epoch, drop, lr)

        optimizer.zero_grad()  # reset gradients to zero

        # inputs are in NCHW format (N=nb. samples, C=channels, H=height,
        # W=width). Use inputs.permute(0, 2, 3, 1) to visualize them as RGB.
        inputs, targets = inputs.to(device), targets.to(device)
        pred = net(inputs)['out']  # fprop for torchvision.models.segmentation

        # dataloader outputs targets with shape NHW, but we need NCHW
        batch_loss = loss_fn(pred, targets.unsqueeze(dim=1).float())
        batch_loss_value = batch_loss.item()
        train_loss += batch_loss_value

        if fp16:
            # NOTE(review): no `amp` (NVIDIA apex) import appears in this
            # notebook, so fp16=True would raise NameError here.
            with amp.scale_loss(batch_loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            batch_loss.backward()  # bprop

        optimizer.step()  # update parameters

        if batch % 10 == 0:
            print('Batch [{}/{}], train loss: {:.4f}'
                  .format(batch, len(train_loader), batch_loss_value))

            if logdir is not None:
                writer.add_scalar(
                    'Loss/train mini-batch', batch_loss_value, global_step)

                # Track the L2 norm of every weight tensor as a sanity check
                # that the parameters change during training.
                with torch.no_grad():
                    for n, p in net.named_parameters():
                        if 'weight' in n.split('.'):
                            writer.add_scalar(
                                'L2norm/' + n, p.norm(2), global_step)
        global_step += 1

    epoch_time = time.time() - epoch_start_time
    train_loss /= len(train_loader)

    # `writer` and `ckpt_path` only exist when logdir is set, so all logging
    # and checkpointing is skipped otherwise ("nothing will be saved").
    if logdir is not None:
        writer.add_scalar('Loss/train', train_loss, epoch + 1)
        # Visualize the last mini-batch of the epoch: inputs, sigmoid
        # predictions (detached from the autograd graph) and labels.
        img_grid = torchvision.utils.make_grid(inputs[:16])
        sig_grid = torchvision.utils.make_grid(sig(pred[:16].detach()))
        lab_grid = torchvision.utils.make_grid(
            targets[:16].unsqueeze(dim=1).float())
        writer.add_image('images', img_grid, epoch)
        writer.add_image('predictions', sig_grid, epoch)
        writer.add_image('labels', lab_grid, epoch)

        # Save a checkpoint every 20 epochs by default. Could alternatively
        # save a checkpoint iff it reduces the validation loss.
        # NOTE(review): val_loss was measured before this epoch's training,
        # while the saved weights are from after it.
        if epoch % 20 == 0:
            print('Saving checkpoint ', epoch)
            if fp16:
                save_amp_checkpoint(
                    net, amp, optimizer, val_loss, train_loss, epoch, ckpt_path,
                    ckpt_name)
            else:
                save_checkpoint(
                    net, val_loss, train_loss, epoch, ckpt_path, ckpt_name)

    print('Epoch [{}/{}], train loss: {:.4f}, val loss: {:.4f}, took {:.2f} s'
          .format(epoch + 1, epochs, train_loss, val_loss, epoch_time))

    if logdir is not None:
        # newline='' is required by the csv module when writing files.
        with open(logname, 'a', newline='') as logfile:
            logwriter = csv.writer(logfile, delimiter=',')
            logwriter.writerow(
                [epoch, lr, np.round(train_loss, 4), np.round(val_loss, 4)])

if logdir is not None:
    # Final checkpoint, only if at least one epoch ran (otherwise train_loss
    # and epoch would be undefined).
    if start_epoch < epochs:
        if fp16:
            save_amp_checkpoint(
                net, amp, optimizer, val_loss, train_loss, epoch, ckpt_path,
                ckpt_name)
        else:
            save_checkpoint(
                net, val_loss, train_loss, epoch, ckpt_path, ckpt_name)

    writer.close()