In [1]:
import json
import os
import socket
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import rasterio.shutil
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

from models.UNet import UNet
from utils.dataset.SegmentationDataset import SegmentationDataset
from utils.dataset.transforms import transforms as T
from utils.eval import predict_tiff, eval_model
from utils.loss import assymetric_tversky_loss
from utils.loss import iou

In [2]:
DISABLE_CUDA = False

# For running script locally without Docker use these for e.g
checkpoint_dir = Path('unet/seagrass/checkpoints')
weights_dir = Path('unet/seagrass/model_weights')
train_data_dir = Path('unet/seagrass/train_input/data/train')
eval_data_dir = Path('unet/seagrass/train_input/data/eval')
seg_in_dir = Path("unet/seagrass/train_input/data/segmentation")
seg_out_dir = Path("unet/seagrass/train_output/segmentation")
restart_training = False

# Load hyper-parameters dictionary
num_classes = 2  # "2",
num_epochs = 200  # "200",
batch_size = 1  # "8",
lr = 0.001
weight_decay = 0.0

# Make results reproducible
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(0)

if not DISABLE_CUDA and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device:", device)

Using device: cuda


# Check ratio of classes

In [3]:
ds_train = SegmentationDataset(train_data_dir, transform=T.test_transforms,
                               target_transform=T.test_target_transforms)
ds_val = SegmentationDataset(eval_data_dir, transform=T.test_transforms,
                             target_transform=T.test_target_transforms)

dataloader_opts = {
    "batch_size": batch_size,
    "pin_memory": True,
    "drop_last": True,
    "num_workers": os.cpu_count()
}
data_loaders = {
    'train': DataLoader(ds_train, shuffle=False, **dataloader_opts),
    'eval': DataLoader(ds_val, shuffle=False, **dataloader_opts),
}

In [4]:
# totals = {
#     'train': np.zeros((2,)),
#     'eval': np.zeros((2,))
# }

# for phase in ['train', 'eval']:
#     for _, y in tqdm(data_loaders[phase]):
#         y = y.to(device)
#         values, counts = torch.unique(y, return_counts=True)

#         for i, c in zip(values, counts):
#             totals[phase][i] += c

#     print(totals[phase])

In [5]:
# import matplotlib.pyplot as plt

# plt.bar(["not seagrass", "seagrass"], totals['train'])
# plt.title("Train")
# plt.show()

# plt.bar(["not seagrass", "seagrass"], totals['eval'])
# plt.title("Eval")
# plt.show()

# Create vegetation indices for images

In [6]:
from torchvision import transforms

def add_veg_indices(x):
    """Add a slew of RGB vegetation indices to the image.
    x: torch.Tensor RGB image of shape (n, c, h, w)

    returns modified image tensor with bands in order
        n, [rgbvi, vari, gli, ngrdi, r, g, b], h, w
    """
    x = add_ngrdi_band(x)
    x = add_gli_band(x)
    x = add_vari_band(x)
    x = add_rgbvi_band(x)
    return x
    
def add_rgbvi_band(img):
    """Add RGBVI as band 0.
    img: a Torch tensor of shape (n, c, h, w)
     It is assumed there are at least 3 channels, 
        with RGB located at index -3, -2, -1, respectively

    returns img of shape (n, c+1, h, w)
    """
    r, g, b = img[:,-3,:,:], img[:,-2,:,:], img[:,-1,:,:]

    rgbvi = (torch.mul(r,r) - torch.mul(r,b)) / (torch.mul(g,g) + torch.mul(r,b))
    return torch.cat((rgbvi.unsqueeze(1), img), dim=1)

def add_gli_band(img):
    """Add GLI as band 0.
    img: a Torch tensor of shape (n, c, h, w)

    returns img of shape (n, c+1, h, w)
    """
    r, g, b = img[:,-3,:,:], img[:,-2,:,:], img[:,-1,:,:]
    gli = (2*g - r - b) / (2*g + r + b)
    return torch.cat((gli.unsqueeze(1), img), dim=1)

def add_vari_band(img):
    """Add VARI as band 0.
    img: a Torch tensor of shape (n, c, h, w)

    returns img of shape (n, c+1, h, w)
    """
    r, g, b = img[:,-3,:,:], img[:,-2,:,:], img[:,-1,:,:]
    vari = (g - r) / (g + r - b)
    return torch.cat((vari.unsqueeze(1), img), dim=1)

def add_ngrdi_band(img):
    """Add NGRDI as band 0.
    img: a Torch tensor of shape (n, c, h, w)

    returns img of shape (n, c+1, h, w)
    """
    r, g, b = img[:,-3,:,:], img[:,-2,:,:], img[:,-1,:,:]
    ngrdi = (g - r) / (g + r)
    return torch.cat((ngrdi.unsqueeze(1), img), dim=1)

out = add_veg_indices(next(iter(data_loaders['train']))[0])

print(out.shape)

torch.Size([1, 7, 513, 513])


# Create UNet instance

In [7]:
n_channels = 7  # [rgbvi, vari, gli, ngrdi, r, g, b]

# Model
os.environ['TORCH_HOME'] = str(checkpoint_dir.parents[0])
model = UNet(n_channels=n_channels, n_classes=num_classes, bilinear=True)
model = model.to(device)
model = nn.DataParallel(model)

## Train loop

In [8]:
def assymetric_tversky_loss(p, g, beta=1.):
    """Loss function from the paper S. R. Hashemi, et al, 2018. "Asymmetric loss functions and deep densely-connected
    networks for highly-imbalanced medical image segmentation: application to multiple sclerosis lesion detection"
    https://ieeexplore.ieee.org/abstract/document/8573779.
    Electronic ISSN: 2169-3536. DOI: 10.1109/ACCESS.2018.2886371.

    p: predicted output from a sigmoid-like activation. (i.e. range is 0-1)
    g: ground truth label of pixel (0 or 1)
    beta: parameter that adjusts weight between FP and FN error importance. beta=1. simplifies to the Dice loss function
    (F1 score) and weights both FP and FNs equally. B=0 is precicion, B=2 is the F_2 score

    >>> np.around(assymetric_tversky_loss(torch.Tensor([0.9, 0.5, 0.2]), torch.Tensor([1., 0., 1.]), beta=1.).numpy(), 6)
    0.611111
    """
    p = p.flatten().float()
    g = g.flatten().float()
    bsq = beta * beta
    pg = torch.sum(torch.mul(p, g))
    similarity_coeff = ((1 + bsq) * pg) / (
        ((1 + bsq) * pg) + (bsq * torch.sum(torch.mul((1 - p), g))) + (torch.sum(torch.mul(p, (1 - g))))
    )
    
    return 1 - similarity_coeff

In [9]:
# %load ./unet/unet.py
#!/usr/bin/env python

def alpha_blend(bg, fg, alpha=0.5):
    return fg * alpha + bg * (1 - alpha)


def train_model(model, device, dataloaders, num_classes, optimizer, criterion, num_epochs, checkpoint_dir, output_dir,
                lr_scheduler=None, start_epoch=0):
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    log_dir = os.path.join(checkpoint_dir.joinpath('runs', current_time + '_' + socket.gethostname()))

    writers = {
        'train': SummaryWriter(log_dir=log_dir + '_train'),
        'eval': SummaryWriter(log_dir=log_dir + '_eval')
    }

    best_val_loss = None
    best_val_iou_seagrass = None
    best_val_miou = None

    for epoch in range(start_epoch, num_epochs):
        for phase in ['train']:
            sum_loss = 0.
            sum_iou = np.zeros(num_classes)

            for x, y in tqdm(iter(dataloaders[phase]), desc=f"{phase} epoch {epoch}", file=sys.stdout):
                y = y.to(device)
                x = x.to(device)
                
                # Add vegetation indices
                x = add_veg_indices(x)

                optimizer.zero_grad()

                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                pred = model(x)
#                 loss = criterion(pred.float(), y)
                sig = torch.sigmoid(pred[:,1])
                loss = assymetric_tversky_loss(sig, y, beta=1.)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # Compute metrics
                sum_loss += loss.detach().cpu().item()
                sum_iou += iou(y, pred.float()).detach().cpu().numpy()

            mloss = sum_loss / len(dataloaders[phase])
            ious = np.around(sum_iou / len(dataloaders[phase]), 4)
            miou = np.mean(ious)
            iou_bg = sum_iou[0] / len(dataloaders[phase])
            iou_seagrass = sum_iou[1] / len(dataloaders[phase])

            print(f'{phase}-loss={mloss}; {phase}-miou={miou}; {phase}-iou-bg={iou_bg}; {phase}-iou-seagrass={iou_seagrass};')

            writers[phase].add_scalar('Loss', mloss, epoch)
            writers[phase].add_scalar('IoU/Mean', miou, epoch)
            writers[phase].add_scalar('IoU/BG', iou_bg, epoch)
            writers[phase].add_scalar('IoU/Kelp', iou_seagrass, epoch)

            # Show images
            x = x[:,-3:,:,:]
            img_grid = torchvision.utils.make_grid(x, nrow=8)
            img_grid = T.inv_normalize(img_grid)

            y = y.unsqueeze(dim=1)
            label_grid = torchvision.utils.make_grid(y, nrow=8).cuda()
            label_grid = alpha_blend(img_grid, label_grid)
            writers[phase].add_image('Labels/True', label_grid.detach().cpu(), epoch)

            pred = pred.max(dim=1)[1].unsqueeze(dim=1)
            pred_grid = torchvision.utils.make_grid(pred, nrow=8).cuda()
            pred_grid = alpha_blend(img_grid, pred_grid)
            writers[phase].add_image('Labels/Pred', pred_grid.detach().cpu(), epoch)

            # Save model checkpoints
            if phase == 'train':
                # Model checkpoint after every train phase every epoch
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'mean_eval_loss': mloss,
                }, Path(checkpoint_dir).joinpath('unet.pt'))
            else:
                # Save best models for eval set
                if best_val_loss is None or mloss < best_val_loss:
                    best_val_loss = mloss
                    torch.save(model.state_dict(), Path(output_dir).joinpath('unet_best_val_loss.pt'))
                if best_val_iou_seagrass is None or iou_seagrass < best_val_iou_seagrass:
                    best_val_iou_seagrass = iou_seagrass
                    torch.save(model.state_dict(), Path(output_dir).joinpath('unet_best_val_seagrass_iou.pt'))
                if best_val_miou is None or miou < best_val_miou:
                    best_val_miou = miou
                    torch.save(model.state_dict(), Path(output_dir).joinpath('unet_best_val_miou.pt'))

        if lr_scheduler is not None:
            lr_scheduler.step()

    writers['train'].flush()
    writers['train'].close()
    writers['eval'].flush()
    writers['eval'].close()

In [10]:
indices = {'train': [], 'eval': []}

for phase in ['train', 'eval']:
    for i, (_, y) in enumerate(tqdm(data_loaders[phase])):
        y = y.to(device)
        values, counts = torch.unique(y, return_counts=True)
        if len(counts) > 1:
            indices[phase].append(i)
            
        if len(indices[phase]) == batch_size:
            break

HBox(children=(FloatProgress(value=0.0, max=33320.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=8470.0), HTML(value='')))

In [11]:
indices

{'train': [3], 'eval': [43]}

In [12]:
# Datasets
num_gpus = torch.cuda.device_count()
print(num_gpus, "GPU(s) detected")
if num_gpus > 1:
    batch_size *= num_gpus

ds_train = torch.utils.data.Subset(ds_train, indices['train'])
ds_val = torch.utils.data.Subset(ds_val, indices['eval'])
data_loaders = {
    'train': DataLoader(ds_train, shuffle=False, **dataloader_opts),
    'eval': DataLoader(ds_val, shuffle=False, **dataloader_opts),
}

# Optimizer, Loss
print(lr)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()

# Train the model
checkpoint_path = checkpoint_dir.joinpath('unet.pt')

train_model(model, device, data_loaders, num_classes, optimizer, criterion, num_epochs, checkpoint_dir,
            weights_dir, start_epoch=0)

1 GPU(s) detected
0.001


HBox(children=(FloatProgress(value=0.0, description='train epoch 0', max=1.0, style=ProgressStyle(description_…


train-loss=0.5541805624961853; train-miou=0.33785; train-iou-bg=0.38191282749176025; train-iou-seagrass=0.2937708795070648;


HBox(children=(FloatProgress(value=0.0, description='train epoch 1', max=1.0, style=ProgressStyle(description_…


train-loss=0.519179105758667; train-miou=0.3489; train-iou-bg=0.37440454959869385; train-iou-seagrass=0.3234141170978546;



HBox(children=(FloatProgress(value=0.0, description='train epoch 2', max=1.0, style=ProgressStyle(description_…


train-loss=0.4683103561401367; train-miou=0.38039999999999996; train-iou-bg=0.39330166578292847; train-iou-seagrass=0.3674554228782654;


HBox(children=(FloatProgress(value=0.0, description='train epoch 3', max=1.0, style=ProgressStyle(description_…

KeyboardInterrupt: 