## Import Packages

In [1]:
import zipfile as zf
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.utils.data as data
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import StepLR
import cv2
import os
from tqdm.notebook import tqdm
from PIL import Image

import future
from torch.utils.tensorboard import SummaryWriter

%load_ext autoreload
%autoreload 2

## Define the ENet model

We model the following residual blocks as separate classes, which are used to build the ENet encoder and decoders:
    - Initial block
    - RDDNeck - class for regular, downsampling and dilated bottlenecks
    - ASNeck - class for asymmetric bottlenecks
    - UBNeck - class for upsampling bottlenecks

The ENet architecture is an autoencoder-based (encoder–decoder) model divided into 5 sub-blocks. Please refer to the [ENet paper](https://arxiv.org/pdf/1606.02147.pdf) for details of each sub-block. The ENet building-block code is taken from [here](https://github.com/iArunava/ENet-Real-Time-Semantic-Segmentation).

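Fast Scene Understanding uses the first 2 sub-blocks as the encoder and the remaining 3 as the decoder. In this implementation there is 1 shared encoder and 3 separate decoders, one per task (instance segmentation, semantic segmentation, depth estimation). Note: the instance-segmentation decoder is currently disabled, so the model below returns only two outputs (semantic segmentation, depth).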

In [2]:
import sys
import os

# The notebook's working directory: put it on sys.path so the project-local
# packages (models/, data/) import, and reuse it later for runs/checkpoint paths.
nb_dir = os.getcwd()
# Appended rather than prepended, so installed packages keep priority; skip duplicates.
if nb_dir not in sys.path:
    sys.path.append(nb_dir)

In [3]:
from models.ENetDecoder import ENetDecoder
from models.ENetEncoder import ENetEncoder

class BranchedENet(nn.Module):
    """ENet with one shared encoder and one decoder branch per task.

    Branches currently built:
      * ``dec1`` -- semantic-segmentation decoder with ``C`` output channels
      * ``dec2`` -- depth decoder with a single output channel
    A third (instance-segmentation) decoder is disabled for now.

    Args:
        C: number of semantic classes.
    """

    def __init__(self, C):
        super().__init__()
        self.C = C  # number of semantic classes

        # Attribute names are the state_dict keys written into checkpoints; keep them stable.
        self.enc = ENetEncoder(C)
        self.dec1 = ENetDecoder(C)
        self.dec2 = ENetDecoder(1)

    def forward(self, x):
        """Encode ``x`` once and run every decoder branch on the shared features.

        Returns:
            tuple ``(semantic_out, depth_out)``.
        """
        # The encoder returns the features plus two extra values that every decoder
        # consumes (presumably pooling indices for unpooling -- TODO confirm in models/ENetEncoder).
        features, idx1, idx2 = self.enc(x)
        semantic_out = self.dec1(features, idx1, idx2)
        depth_out = self.dec2(features, idx1, idx2)
        return semantic_out, depth_out

## Instantiate the ENet model

In [4]:
# The semantic branch predicts 20 classes; this must equal len(class_weights) given to the loss.
enet = BranchedENet(C=20)

In [5]:
# Checking if there is any gpu available and pass the model to gpu or cpu
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
enet = enet.to(device)

## Define Dataloader

In [6]:
from data.cityscapes import Cityscapes as dataset

# Every sample is resized to (height, width) and converted to a tensor.
height, width = 512, 1024
dataset_dir = 'data/cityscape'
image_transform = transforms.Compose([
    transforms.Resize((height, width)),
    transforms.ToTensor(),
])
# NOTE(review): whether data.cityscapes also applies this transform to the label maps is
# not visible here. Resize interpolates bilinearly by default and ToTensor rescales uint8
# images to [0, 1]; either would corrupt integer class ids -- verify in data/cityscapes.
train_set, test_set = (
    dataset(dataset_dir, transform=image_transform),
    dataset(dataset_dir, mode='test', transform=image_transform),
)

In [7]:
# get class weights
"""
train_loader = data.DataLoader(train_set,batch_size=1,num_workers=4)
trainiter = iter(train_loader)

all_labels_sum = 0
each_class = np.zeros(20)
for _ in tqdm(range(len(train_set))):
    _, labels, _, _ = trainiter.next()
    all_labels = labels.flatten()
    each_class += np.bincount(all_labels, minlength=20)
    all_labels_sum += len(all_labels)
    
propensity_score = each_class / all_labels_sum
class_weights = 1 / (np.log(1.02 + propensity_score))
"""
# manually define class weights from previous attempt
class_weights = [1.42227783, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979,
50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979,
50.49834979, 50.49834979]

In [8]:
# show the class weights
# Sanity check: show the per-class weights handed to the cross-entropy loss
print(str(class_weights))

[1.42227783, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979, 50.49834979]


In [9]:
batch_size = 1

# Identical loader settings for both splits (shuffled, 4 worker processes).
train_loader, test_loader = (
    data.DataLoader(split, batch_size=batch_size, shuffle=True, num_workers=4)
    for split in (train_set, test_set)
)

# Persistent iterators; the training/validation loops recycle them once exhausted.
trainiter, testiter = iter(train_loader), iter(test_loader)

## 3 - Losses
(1) Semantic Segmentation Loss

(2) Instance Segmentation Loss

(3) Depth Estimation Loss

In [10]:
def inverse_huber_loss(out, target, c_ratio=0.2, eps=1e-6):
    """Reverse Huber ("berHu") loss, averaged over all elements.

    Per element, with d = |out - target| and threshold C = c_ratio * max(d):
        d                        if d < C
        (d**2 + C**2) / (2 * C)  otherwise

    Args:
        out: predicted tensor.
        target: ground-truth tensor (broadcast against ``out``).
        c_ratio: fraction of the largest absolute error used as threshold C (default 0.2).
        eps: lower bound on C. Without it a perfect prediction (max error 0) gives
            C == 0, the quadratic branch evaluates 0/0, and both loss and gradients
            become NaN.

    Returns:
        Scalar tensor with the mean loss.
    """
    absdiff = torch.abs(out - target)
    # Threshold scales with the worst error in the batch; clamp prevents division by zero.
    C = torch.clamp(c_ratio * torch.max(absdiff), min=eps)
    return torch.mean(torch.where(absdiff < C, absdiff, (absdiff * absdiff + C * C) / (2 * C)))

In [11]:
def instance_loss(out, target):
    """Placeholder for the instance-segmentation loss (not implemented yet).

    Raises instead of returning ``None``, so enabling the instance branch fails
    loudly here rather than with an unrelated ``TypeError`` when the result is
    added to the total loss.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        'instance_loss is not implemented; the instance-segmentation branch is disabled')

## 7 - Define the Hyperparameters (TODO)

In [12]:
from data.utils import enet_weighing
# Optimizer hyperparameters
lr = 5e-4
weight_decay = 2e-4

# Loss functions.
# TODO: enet_weighing (imported above) is not used yet; class_weights are hard-coded instead.
criterion_label = nn.CrossEntropyLoss(
    weight=torch.FloatTensor(class_weights).to(device),
)
# The instance-segmentation loss is disabled: instance_loss is still a stub.
criterion_dpth = inverse_huber_loss

optimizer = torch.optim.Adam(enet.parameters(), lr=lr, weight_decay=weight_decay)

# Experiment name, used to keep TensorBoard runs apart (e.g. during hyperparameter tuning)
experiment = 'experiment_lr-{}_bc-{}_wd-{}'.format(lr, batch_size, weight_decay)

# Epoch intervals: validate every `eval_every` epochs, save a checkpoint every `print_every` epochs
eval_every = 2
print_every = 2

## 8 - Training loop (TODO)

In [13]:
# Lists for recording per-epoch mean losses (training / validation)
train_losses = []
eval_losses = []

# Mini-batches per training epoch and per validation pass, derived from the loaders.
# The previous hard-coded 367 // batch_size and 101 // batch_size match the CamVid
# train/val split sizes, not necessarily the dataset actually loaded here, so an
# "epoch" could cover only part of (or wrap around) the data.
bc_train = len(train_loader)
bc_eval = len(test_loader)

epochs = 100

train_writer = SummaryWriter(nb_dir + '/runs/' + experiment + '/train')
val_writer = SummaryWriter(nb_dir + '/runs/' + experiment + '/val')

In [None]:
# Train loop


def _next_batch(it, loader):
    """Fetch the next batch from ``it``, restarting from ``loader`` once it is exhausted.

    Returns:
        ``(iterator, batch)`` -- the iterator actually used (a fresh one after a
        restart) and the batch it produced.
    """
    try:
        return it, next(it)
    except StopIteration:
        # Only exhaustion triggers a restart; any other error (e.g. inside the dataset)
        # propagates instead of being masked by a bare ``except``.
        it = iter(loader)
        return it, next(it)


def _prepare_batch(batch):
    """Move an ``(img, label, inst, dpth)`` batch to ``device`` and squeeze dim 1 of label/inst."""
    img, label, inst, dpth = (t.to(device) for t in batch)
    return img, label.squeeze(1), inst.squeeze(1), dpth


# Checkpoints are written to <nb_dir>/content; create it so torch.save does not fail.
os.makedirs(os.path.join(nb_dir, 'content'), exist_ok=True)

for e in range(1, epochs+1):

    train_loss = 0
    print ('-'*15,'Epoch %d' % e, '-'*15)

    enet.train()

    for _ in tqdm(range(bc_train)):
        # get new data/recycle if done, then assign it to cpu/gpu
        trainiter, batch = _next_batch(trainiter, train_loader)
        img, label, inst, dpth = _prepare_batch(batch)

        optimizer.zero_grad()
        out = enet(img.float())

        # The model returns (semantic output, depth output); the instance branch is disabled
        label_out, dpth_out = out[0], out[1]

        # semantic-segmentation loss + depth loss
        loss = criterion_label(label_out, label.long()).float()
        loss += criterion_dpth(dpth_out, dpth.float())
        loss.backward()

        # update weights
        optimizer.step()

        train_loss += loss.item()

    train_losses.append(train_loss / bc_train)
    train_writer.add_scalar('Loss', train_loss/bc_train, e)

    if e % eval_every == 0:
        enet.eval()
        eval_loss = 0
        with torch.no_grad():
            # Validation loop
            for _ in tqdm(range(bc_eval)):
                testiter, batch = _next_batch(testiter, test_loader)
                img, label, inst, dpth = _prepare_batch(batch)

                out = enet(img.float())
                label_out, dpth_out = out[0], out[1]

                eval_loss += criterion_label(label_out, label.long()).float().item()
                eval_loss += criterion_dpth(dpth_out, dpth.float()).item()

        eval_losses.append(eval_loss / bc_eval)
        val_writer.add_scalar('Loss', eval_loss/bc_eval, e)

    # Save a checkpoint every `print_every` epochs
    if e % print_every == 0:
        checkpoint = {
            'epochs' : e,
            'state_dict' : enet.state_dict()
        }
        torch.save(checkpoint, nb_dir+'/content/ckpt-enet-{}-{}.pth'.format(e, train_loss))
        print ('Model saved!')
    train_writer.flush()
    val_writer.flush()

# Mean of this run's per-epoch mean training losses (train_loss itself is a single float
# holding only the last epoch's sum, so it cannot be summed over epochs).
run_losses = train_losses[-epochs:]
print('Epoch {}/{}...'.format(e, epochs),
      'Total Mean Loss: {:6f}'.format(sum(run_losses) / len(run_losses)))

train_writer.close()
val_writer.close()

--------------- Epoch 1 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 2 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 3 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 4 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 5 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 6 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 7 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 8 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 9 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 10 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 11 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 12 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 13 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 14 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 15 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 16 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 17 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 18 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 19 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 20 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 21 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 22 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 23 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 24 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 25 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 26 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 27 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 28 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 29 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 30 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 31 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 32 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 33 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 34 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 35 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 36 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 37 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 38 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 39 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 40 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 41 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 42 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 43 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 44 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=101.0), HTML(value='')))


Model saved!
--------------- Epoch 45 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))


--------------- Epoch 46 ---------------


HBox(children=(FloatProgress(value=0.0, max=367.0), HTML(value='')))