In [1]:
import argparse, sys, os, time
import cv2
import numpy as np
import pandas as pd
import torch
from torch import optim
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm

from dataloader import CityscapesDataset

# Module-level numeric constant; it is not referenced by any of the cells below.
# NOTE(review): meaning and units are not documented here -- presumably a scale
# factor for the Cityscapes depth values; confirm against dataloader / usage.
DEPTH_CORRECTION = 2.1116e-09

def compute_loss(batch_X, batch_y_segmt, batch_y_depth, batch_mask_segmt, batch_mask_depth, model,
                 criterion=None, optimizer=None, is_train=True, device=None):
    """Run one forward pass (plus backward/step when training) on a multi-task batch.

    Args:
        batch_X: model input tensor.
        batch_y_segmt: segmentation target, forwarded to ``criterion``.
        batch_y_depth: depth target, forwarded to ``criterion``.
        batch_mask_segmt: segmentation mask, forwarded to ``criterion``.
        batch_mask_depth: depth mask, forwarded to ``criterion``.
        model: network; switched to train/eval mode according to ``is_train``.
        criterion: callable
            ``criterion(output, y_segmt, y_depth, mask_segmt, mask_depth)``
            returning a pair of scalar loss tensors. Required.
        optimizer: optimizer stepped when ``is_train`` is True. Required for training.
        is_train: if True, back-propagate and take an optimizer step; otherwise
            only evaluate, with autograd disabled.
        device: device the batch is moved to. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model).

    Returns:
        float: sum of the two loss values.

    Raises:
        TypeError: if ``criterion`` is missing, or ``optimizer`` is missing
            while ``is_train`` is True.
    """
    if criterion is None:
        raise TypeError("compute_loss requires a criterion")
    if is_train and optimizer is None:
        raise TypeError("compute_loss requires an optimizer when is_train=True")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train(is_train)

    batch_X, batch_y_segmt, batch_y_depth, batch_mask_segmt, batch_mask_depth = (
        t.to(device, non_blocking=True)
        for t in (batch_X, batch_y_segmt, batch_y_depth, batch_mask_segmt, batch_mask_depth)
    )

    # Only record the autograd graph when we are going to back-propagate;
    # evaluation would otherwise hold every activation in memory for nothing.
    with torch.set_grad_enabled(is_train):
        output = model(batch_X)
        image_loss, label_loss = criterion(output, batch_y_segmt, batch_y_depth,
                                           batch_mask_segmt, batch_mask_depth)

    if is_train:
        optimizer.zero_grad()
        # One backward on the sum accumulates the same gradients as two separate
        # backward calls, without having to retain the graph in between.
        (image_loss + label_loss).backward()
        optimizer.step()

    return image_loss.item() + label_loss.item()

torch.manual_seed(0)
# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device: {}".format(device))

DATA_ROOT = '../data/cityscapes'
IMG_HEIGHT, IMG_WIDTH = 256, 512
# NOTE(review): presumably the label CityscapesDataset assigns to ignored/void
# pixels -- confirm in dataloader.py.
IGNORE_INDEX = 20
BATCH_SIZE = 8
NUM_WORKERS = 2
# `.to(device, non_blocking=True)` is only asynchronous for pinned host memory.
PIN_MEMORY = device.type == "cuda"

print("Loading dataset...")
train_data = CityscapesDataset(root_path=DATA_ROOT, height=IMG_HEIGHT, width=IMG_WIDTH,
                               split='train', transform=["random_flip"], ignore_index=IGNORE_INDEX)
valid_data = CityscapesDataset(root_path=DATA_ROOT, height=IMG_HEIGHT, width=IMG_WIDTH,
                               split='val', transform=None, ignore_index=IGNORE_INDEX)
train = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True,
                   num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
# Validation order does not change the metrics; a fixed order keeps runs comparable.
valid = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False,
                   num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

  (fname, cnt))
  (fname, cnt))


device: cuda
Loading dataset...


In [2]:
class TTDown(nn.Module):
    """Encoder stage: two 3x3 conv + ReLU layers, then 2x2 max-pooling.

    Channels go ``in_features -> mid_features -> out_features``; the pooling
    halves height and width (floor). ``mid_features`` falls back to
    ``out_features`` when it is omitted or falsy.
    """

    def __init__(self, in_features, out_features, mid_features=None):
        super().__init__()
        hidden = mid_features or out_features
        layers = [
            nn.Conv2d(in_features, hidden, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_features, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*layers)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.double_conv(x)
        return self.pool(features)

class TTUp(nn.Module):
    """Decoder stage: 2x bilinear upsampling, then two 3x3 conv + ReLU layers.

    Height and width are doubled (``align_corners=True``); channels go
    ``in_features -> mid_features -> out_features``. ``mid_features`` falls
    back to ``out_features`` when it is omitted or falsy.
    """

    def __init__(self, in_features, out_features, mid_features=None):
        super().__init__()
        hidden = mid_features or out_features
        self.double_conv = nn.Sequential(*[
            nn.Conv2d(in_features, hidden, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_features, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ])
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        upsampled = self.up(x)
        return self.double_conv(upsampled)

class TSNet1(nn.Module):
    """Encoder-decoder (no skip connections) mapping one dense map to another.

    Three ``TTDown`` stages (32, 64, 128 channels) each halve the spatial size,
    three ``TTUp`` stages double it back, and a bias-free 1x1 conv projects to
    ``out_features`` channels. The output spatial size equals the input's when
    height and width are divisible by 8.
    """

    def __init__(self, in_features, out_features):
        super(TSNet1, self).__init__()
        narrow, medium, wide = 32, 64, 128
        # Construction order is kept fixed so seeded initialization is reproducible.
        self.enc1 = TTDown(in_features=in_features, out_features=narrow)
        self.enc2 = TTDown(in_features=narrow, out_features=medium)
        self.enc3 = TTDown(in_features=medium, out_features=wide)
        self.dec1 = TTUp(in_features=wide, out_features=medium)
        self.dec2 = TTUp(in_features=medium, out_features=narrow)
        self.dec3 = TTUp(in_features=narrow, out_features=narrow)
        self.final_conv = nn.Conv2d(narrow, out_features, kernel_size=1, stride=1, bias=False)

    def forward(self, x):
        pipeline = (self.enc1, self.enc2, self.enc3,
                    self.dec1, self.dec2, self.dec3,
                    self.final_conv)
        for stage in pipeline:
            x = stage(x)
        return x
    
class TSNet2(nn.Module):
    """Shallow fully-convolutional baseline that preserves spatial size.

    conv3x3(in -> 512) + BN + ReLU, conv3x3(512 -> 32) + BN + ReLU, then a
    bias-free 1x1 conv projecting to ``out_features`` channels.
    """

    def __init__(self, in_features, out_features):
        super(TSNet2, self).__init__()
        hidden, bottleneck = 512, 32
        self.conv = nn.Sequential(
            nn.Conv2d(in_features, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, bottleneck, kernel_size=3, padding=1),
            nn.BatchNorm2d(bottleneck),
            nn.ReLU(inplace=True),
        )
        self.conv_out = nn.Conv2d(bottleneck, out_features, kernel_size=1, stride=1, bias=False)

    def forward(self, x):
        return self.conv_out(self.conv(x))

In [3]:
# Segmentation -> depth: 19 one-hot class channels in, a single depth channel out.
model = TSNet1(in_features=19, out_features=1)
model = model.to(device)

In [14]:
# Pixel-wise squared-error regression on the depth target, optimized with Adam.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [16]:
NUM_SEG_CLASSES = 19  # one-hot channels fed to the seg->depth model (in_features=19)

model.train()  # explicit: an out-of-order cell may have left the model in eval mode
for e in range(5):
    tr_loss = 0.
    for batch in tqdm(train, desc=f"epoch {e}"):
        _, _, seg, dep, _, _ = batch
        seg = seg.to(device, non_blocking=True)
        dep = dep.to(device, non_blocking=True)
        # One-hot encode on the device instead of shipping a 19x larger float
        # tensor from the host. Labels outside [0, NUM_SEG_CLASSES) -- e.g. the
        # dataset's ignore label -- become an all-zero vector instead of making
        # F.one_hot raise.
        in_range = (seg >= 0) & (seg < NUM_SEG_CLASSES)
        seg_onehot = F.one_hot(seg.masked_fill(~in_range, 0), num_classes=NUM_SEG_CLASSES)
        seg_onehot = (seg_onehot * in_range.unsqueeze(-1)).permute(0, 3, 1, 2).float()
        pred = model(seg_onehot)
        # NOTE(review): every pixel contributes to the MSE; if one of the unused
        # batch entries is a depth-validity mask, consider excluding invalid pixels.
        loss = criterion(pred, dep)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tr_loss += loss.item()
    print(tr_loss)

HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


0.7532201349386014


HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


0.49119944410631433


HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


0.4556331892381422


HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


0.4362934945966117


HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


0.41987054637866095


In [3]:
# Segmentation -> depth with the shallow baseline: 19 class channels in, 1 depth channel out.
model = TSNet2(in_features=19, out_features=1)
model = model.to(device)

In [4]:
# Same objective and optimizer settings as the TSNet1 run, for a like-for-like comparison.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [5]:
NUM_SEG_CLASSES = 19  # one-hot channels fed to the seg->depth model (in_features=19)

model.train()  # explicit: BatchNorm must be in training mode even if cells ran out of order
for e in range(10):
    tr_loss = 0.
    for batch in tqdm(train, desc=f"epoch {e}"):
        _, _, seg, dep, _, _ = batch
        seg = seg.to(device, non_blocking=True)
        dep = dep.to(device, non_blocking=True)
        # One-hot encode on the device. Labels outside [0, NUM_SEG_CLASSES) --
        # including the dataset's ignore label -- become an all-zero vector,
        # whereas one_hot(num_classes=20) would raise on a label of 20.
        in_range = (seg >= 0) & (seg < NUM_SEG_CLASSES)
        seg_onehot = F.one_hot(seg.masked_fill(~in_range, 0), num_classes=NUM_SEG_CLASSES)
        seg_onehot = (seg_onehot * in_range.unsqueeze(-1)).permute(0, 3, 1, 2).float()
        pred = model(seg_onehot)
        # NOTE(review): every pixel contributes to the MSE; if one of the unused
        # batch entries is a depth-validity mask, consider excluding invalid pixels.
        loss = criterion(pred, dep)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tr_loss += loss.item()
    print(tr_loss)

HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


2.290089809568599


HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


0.9548620418645442


HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


0.902598551590927


HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


0.8788690216606483


HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


0.8575124908238649


HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


0.8509416352026165


HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


0.8469498814083636


HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


0.8305354914627969


HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


0.8193271509371698


HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


0.8203162858262658


In [3]:
# Depth -> segmentation: one depth channel in, 19 class logits out.
model = TSNet1(in_features=1, out_features=19).to(device)
# The datasets were built with ignore_index=20 and the model only has 19 logits,
# so a target of 20 must be skipped by the loss rather than treated as a class.
# NOTE(review): assumes CityscapesDataset writes 20 into the label map for
# ignored pixels -- confirm in dataloader.py.
criterion = nn.CrossEntropyLoss(ignore_index=20)
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [4]:
model.train()  # explicit: an out-of-order cell may have left the model in eval mode
for e in range(5):
    tr_loss = 0.
    for batch in tqdm(train, desc=f"epoch {e}"):
        _, _, seg, dep, _, _ = batch
        # seg is the integer class map (used directly as the cross-entropy target);
        # dep is the single-channel model input.
        seg = seg.to(device, non_blocking=True)
        dep = dep.to(device, non_blocking=True)
        pred = model(dep)  # per-class logits
        loss = criterion(pred, seg)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tr_loss += loss.item()
    print(tr_loss)

HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


770.8169963359833


HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


567.2385686635971


HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


494.38499319553375


HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


456.75744158029556


HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


439.2776182293892


In [3]:
# Depth -> segmentation with the shallow baseline: 1 depth channel in, 19 class logits out.
model = TSNet2(in_features=1, out_features=19).to(device)
# The datasets were built with ignore_index=20 and the model only has 19 logits,
# so a target of 20 must be skipped by the loss rather than treated as a class.
# NOTE(review): assumes CityscapesDataset writes 20 into the label map for
# ignored pixels -- confirm in dataloader.py.
criterion = nn.CrossEntropyLoss(ignore_index=20)
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [None]:
model.train()  # explicit: BatchNorm must be in training mode even if cells ran out of order
for e in range(5):
    tr_loss = 0.
    for batch in tqdm(train, desc=f"epoch {e}"):
        _, _, seg, dep, _, _ = batch
        # seg is the integer class map (used directly as the cross-entropy target);
        # dep is the single-channel model input.
        seg = seg.to(device, non_blocking=True)
        dep = dep.to(device, non_blocking=True)
        pred = model(dep)  # per-class logits
        loss = criterion(pred, seg)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tr_loss += loss.item()
    print(tr_loss)

HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


848.749220252037


HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))


684.063108086586


HBox(children=(FloatProgress(value=0.0, max=372.0), HTML(value='')))