## Dependencies

In [None]:
# Environment checks: show GPU status and the running notebook servers.
!nvidia-smi
!jupyter notebook list
# Restrict this kernel to GPU 3 (hard-coded; adjust for your machine). This runs
# before torch is imported below; if torch/CUDA was already initialised in this
# kernel, the variable has no effect until the kernel is restarted.
%env CUDA_VISIBLE_DEVICES=3

%matplotlib inline
# Auto-reload edited project modules (models, datasets, utils) before each cell executes.
%load_ext autoreload
%autoreload 2

import time
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from models import tiramisu
from models import tiramisu_bilinear
from models import tiramisu_m1
from datasets import deepglobe
from datasets import maroads
from datasets import joint_transforms
import utils.imgs
import utils.training as train_utils

# tensorboard
from torch.utils.tensorboard import SummaryWriter

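## Data

The CamVid dataset (not used by this notebook) can be obtained by cloning this repository:
```
git clone https://github.com/alexgkendall/SegNet-Tutorial
```
This notebook uses DeepGlobe instead. Place the DeepGlobe dataset in `datasets/deepglobe/dataset/{train,test,valid}` and the MA Roads dataset in `datasets/maroads/dataset`.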

In [None]:
# Experiment configuration: run name, dataset locations and output directories.
run = "expM.1.drop2.1.dicebce"  # experiment name; WEIGHTS_PATH is a sub-directory named after it
out_run = ""                    # optional suffix appended to `run` for the prediction output directory
DEEPGLOBE_PATH = Path('datasets/', 'deepglobe/dataset')
MAROADS_PATH = Path('datasets/', 'maroads/dataset')
RESULTS_PATH = Path('.results/')
WEIGHTS_PATH = Path('.weights/') / run
RUNS_PATH    = Path('.runs/')

# parents=True: WEIGHTS_PATH is nested (.weights/<run>), so a plain mkdir raises
# FileNotFoundError on a fresh checkout where .weights/ does not exist yet.
RESULTS_PATH.mkdir(parents=True, exist_ok=True)
WEIGHTS_PATH.mkdir(parents=True, exist_ok=True)
RUNS_PATH.mkdir(parents=True, exist_ok=True)

batch_size = 1 # TODO: Should be `MAX_BATCH_PER_CARD * torch.cuda.device_count()` (which in this case is 1 assuming max of 1 batch per card)

In [None]:
# Data pipeline.
# Training = first 4000 items of the DeepGlobe 'train' split + all of MA Roads.
# Test     = the remaining DeepGlobe 'train' items (held out via `test_slice`).
# Every split is normalised with the DeepGlobe channel statistics.
# resize = joint_transforms.JointRandomCrop((300, 300))  # disabled random crop

normalize = transforms.Normalize(mean=deepglobe.mean, std=deepglobe.std)

# Geometric augmentations passed as `joint_transform` (presumably applied to image
# and mask together — see datasets.joint_transforms).
train_joint_transformer = transforms.Compose([
    joint_transforms.JointRandomHorizontalFlip(),
    joint_transforms.JointRandomVerticalFlip(),
    joint_transforms.JointRandomRotate(),
])


def make_image_transform(color_jitter):
    """Build a fresh image-only pipeline: optional colour jitter, then ToTensor and `normalize`."""
    steps = []
    if color_jitter:
        steps.append(transforms.ColorJitter(brightness=.4, contrast=.4, saturation=.4))
    steps += [transforms.ToTensor(), normalize]
    return transforms.Compose(steps)


train_slice = slice(None, 4000)
test_slice = slice(4000, None)

train_dset = deepglobe.DeepGlobe(
    DEEPGLOBE_PATH, 'train',
    slc=train_slice,
    joint_transform=train_joint_transformer,
    transform=make_image_transform(color_jitter=True),
)
train_dset_ma = maroads.MARoads(
    MAROADS_PATH,
    joint_transform=train_joint_transformer,
    transform=make_image_transform(color_jitter=True),
)

# Alternatives: train on a single source by passing train_dset or train_dset_ma.
train_dset_combine = torch.utils.data.ConcatDataset((train_dset, train_dset_ma))
train_loader = torch.utils.data.DataLoader(
    train_dset_combine, batch_size=batch_size, shuffle=True)

# Evaluation splits get no joint augmentation (the resize transform is disabled).
resize_joint_transformer = None

val_dset = deepglobe.DeepGlobe(
    DEEPGLOBE_PATH, 'valid',
    joint_transform=resize_joint_transformer,
    transform=make_image_transform(color_jitter=False),
)
val_loader = torch.utils.data.DataLoader(
    val_dset, batch_size=batch_size, shuffle=False)

test_dset = deepglobe.DeepGlobe(
    DEEPGLOBE_PATH, 'train',
    slc=test_slice,
    joint_transform=resize_joint_transformer,
    transform=make_image_transform(color_jitter=False),
)
test_loader = torch.utils.data.DataLoader(
    test_dset, batch_size=batch_size, shuffle=False)

In [None]:
# Sanity-check the data pipeline: split sizes and one training batch.
# len(dataset) is exactly what each DataLoader iterates. The ConcatDataset behind
# train_loader has no `.imgs`, so use the same measure for every split instead of
# relying on a dataset-specific attribute (which may not reflect `slc` slicing).
print("Train: %d" % len(train_loader.dataset))
print("Val: %d" % len(val_loader.dataset))
print("Test: %d" % len(test_loader.dataset))

inputs, targets = next(iter(train_loader))
print("Inputs: ", inputs.size())
print("Targets: ", targets.size())

utils.imgs.view_image(inputs[0])
utils.imgs.view_annotated(targets[0])

# Show the distinct label values rather than dumping the entire mask tensor.
print("Target values: ", targets[0].unique())

## Set up Model

In [None]:
# Optimisation hyper-parameters and random seeding.
LR = 1e-4                 # initial learning rate passed to the optimiser
LR_DECAY = 0.995          # decay factor for train_utils.adjust_learning_rate
DECAY_EVERY_N_EPOCHS = 1  # decay interval (epochs) for train_utils.adjust_learning_rate
N_EPOCHS = 1000           # training runs epochs range(start_epoch, N_EPOCHS + 1)

SEED = 0
import random

# torch.manual_seed seeds the CPU generator *and* every CUDA device. The previous
# torch.cuda.manual_seed(0) left the CPU RNG unseeded, and the model is constructed
# on the CPU before `.cuda()`, so weight initialisation was not reproducible.
# Python's and NumPy's global RNGs are seeded as well in case project transforms use them.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Model, optimiser and loss.
from utils.bceloss import dice_bce_loss
from loss.BCESSIM import BCESSIM  # alternative criterion (see below)

# FCDenseNetSmall from models.tiramisu_m1: one output channel, dropout 0.2, on the GPU.
model = tiramisu_m1.FCDenseNetSmall(n_classes=1, dropout_rate=0.2)
model = model.cuda()

optimizer = torch.optim.RMSprop(
    params=model.parameters(),
    lr=LR,
    weight_decay=1e-4,
)

# Dice + BCE loss; swap in `BCESSIM()` to try the SSIM-based loss instead.
criterion = dice_bce_loss()

In [None]:
# Resume from a saved checkpoint. load_weights' return value becomes the epoch the
# training loop starts from; alternative checkpoint: 'latest.th'.
CHECKPOINT_FILE = 'weights-368-1.008-0.462.pth'

start_epoch = 0  # value left in place if loading below raises
start_epoch = train_utils.load_weights(model, WEIGHTS_PATH / CHECKPOINT_FILE)
print(start_epoch)

In [None]:
# Optional TensorBoard writer (commented out). The explicit log_dir below writes to
# ./.runs/run<run>/ rather than SummaryWriter's default ./runs/ directory.
# NOTE(review): uncommenting this as-is rebinds the experiment name `run` (set in the
# configuration cell and used to build OUT_PATH in the prediction cell) to "26";
# use a separate variable for the TensorBoard run name.
# # Writer will output to ./runs/ directory by default
# run = "26"
# writer = SummaryWriter(log_dir=("./.runs/run" + str(run) + "/"))

In [None]:
#test_loader.dataset[0]

In [None]:
# Save the model's validation-set predictions as image files under out/<run><out_run>/.
from utils import imgs as img_utils

OUT_PATH = Path('out/') / (run + out_run)
# parents=True: on a fresh checkout the top-level out/ directory does not exist yet,
# and a plain mkdir(exist_ok=True) would raise FileNotFoundError.
OUT_PATH.mkdir(parents=True, exist_ok=True)

from PIL import Image
import os

predictions = train_utils.predict_validation(model, val_loader)

for pred in predictions:
    # NOTE(review): judging by the indexing below, each item is a tuple whose [1] is
    # the predicted mask and whose [2][0] is the source image path — confirm against
    # train_utils.predict_validation.
    # NOTE(review): ToPILImage already multiplies *floating-point* tensors by 255
    # before casting to uint8; if pred[1] is a float tensor, the extra *255 overflows.
    # Check pred[1]'s dtype.
    im = transforms.ToPILImage()(pred[1]*255).convert("RGB")
    src_path = pred[2][0]
    print(src_path)
    # Reuse the source file name inside OUT_PATH. NOTE(review): the data is always
    # PNG-encoded, even if the source name has another extension (e.g. .jpg).
    path = os.path.join(OUT_PATH, os.path.basename(src_path))
    print(path)
    im.save(path, "PNG")

# First prediction kept around for interactive inspection.
_pred = predictions[0]

In [None]:
# Training loop. Disabled by default so "Run All" does not launch a long training run
# (this used to be a bare `break`, i.e. a deliberate SyntaxError that aborted the cell).
# Set RUN_TRAINING = True to train from `start_epoch` through N_EPOCHS.
RUN_TRAINING = False

from torch.autograd import Variable  # legacy import, not used below


def _log_sample_prediction(writer, epoch):
    """Log one test sample to TensorBoard as a grid: input | target | prediction | raw output."""
    inputs, targets, pred, loss, err, iou = train_utils.get_sample_predictions(
        model, test_loader, n=1, criterion=criterion)
    # Inference only: don't build an autograd graph for the extra forward pass.
    with torch.no_grad():
        raw = model(inputs.cuda()).cpu()
    img = torchvision.utils.make_grid(torch.stack([
        inputs[0],
        targets[0].unsqueeze(0).expand(3, -1, -1).float(),
        pred[0].unsqueeze(0).expand(3, -1, -1).float(),
        raw[0].expand(3, -1, -1).float(),
    ]), normalize=True)
    writer.add_image('test/sample_pred', img, epoch)


if RUN_TRAINING:
    # The logging calls below need a writer; none was defined before (the cell that
    # created one is commented out). Logs go to .runs/<run>/.
    writer = SummaryWriter(log_dir=str(RUNS_PATH / run))
    try:
        for epoch in range(start_epoch, N_EPOCHS+1):
            since = time.time()

            ### Train ###
            trn_loss, trn_err = train_utils.train(
                model, train_loader, optimizer, criterion, epoch)
            print('Epoch {:d}\nTrain - Loss: {:.4f}, Acc: {:.4f}'.format(
                epoch, trn_loss, 1-trn_err))
            time_elapsed = time.time() - since
            print('Train Time {:.0f}m {:.0f}s'.format(
                time_elapsed // 60, time_elapsed % 60))

            ### Test ### (held-out slice of the DeepGlobe 'train' split)
            # Validation on val_loader is currently disabled:
            # val_loss, val_err = train_utils.test(model, val_loader, criterion, epoch)
            tes_loss, tes_err, tes_iou = train_utils.test(model, test_loader, criterion, epoch)
            print('Tes - Loss: {:.4f} | Acc: {:.4f}'.format(tes_loss, 1-tes_err))
            time_elapsed = time.time() - since
            print('Total Time {:.0f}m {:.0f}s\n'.format(
                time_elapsed // 60, time_elapsed % 60))

            ### Checkpoint ###
            train_utils.save_weights(model, epoch, tes_loss, tes_err)

            ### Adjust Lr ### (disabled)
            # train_utils.adjust_learning_rate(LR, LR_DECAY, optimizer,
            #                                  epoch, DECAY_EVERY_N_EPOCHS)

            ### Log on TensorBoard ###
            writer.add_scalar('Loss/train', trn_loss, epoch)
            writer.add_scalar('Loss/test', tes_loss, epoch)

            writer.add_scalar('Error/train', trn_err, epoch)
            writer.add_scalar('Error/test', tes_err, epoch)

            writer.add_scalar('Accuracy/test', tes_iou, epoch)

            for param_group in optimizer.param_groups:
                # Tag was previously misspelled 'Params/learning_rage'.
                writer.add_scalar('Params/learning_rate', param_group['lr'], epoch)

            # One sample image per epoch (previously a range(3) loop that broke after one pass).
            _log_sample_prediction(writer, epoch)

            # This epoch is finished and checkpointed: if the cell is interrupted and
            # re-run, resume at the next epoch (`= epoch` re-trained it because the
            # range above includes start_epoch).
            start_epoch = epoch + 1
    finally:
        writer.close()
    


## Test

In [None]:
# Evaluate on the held-out test split with use_tta left at its default.
# The result is (test_loss, test_error, jaccard).
test_metrics = train_utils.test(model, test_loader, criterion, epoch=1)
test_metrics

In [None]:
# Evaluation with flip test-time augmentation.
flip_tta_metrics = train_utils.test(model, test_loader, criterion, epoch=1, use_tta=True)
flip_tta_metrics

In [None]:
# Evaluation with rotation test-time augmentation.
# NOTE(review): this makes the same `train_utils.test(..., use_tta=True)` call as the
# flip-TTA cell above; unless train_utils was edited in between (autoreload is on),
# both cells measure the same thing. Confirm how rotation TTA is selected.
rot_tta_metrics = train_utils.test(model, test_loader, criterion, epoch=1, use_tta=True)
rot_tta_metrics

In [None]:
# Visualise one test sample and print its stats, in (loss, error, jaccard) order.
stats = train_utils.view_sample_predictions(
    model, test_loader, n=1, criterion=criterion)
stat_names = ("loss", "error", "jaccard")
print(*stat_names)
print(stats)
