In [1]:
import numpy as np
import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
import os
from PIL import Image

In [2]:
from cityscapedataset import Dataset, CityScapeClasses
##import the sdn network here
from sdn_architecture import SDN, Supervised_CE_loss

In [3]:
# Class names exposed by the project's CityScapeClasses helper. The bare
# expression on the last line makes the notebook display the list.
classes = CityScapeClasses().classes
classes

['road',
 'sidewalk',
 'building',
 'fence',
 'pedestrian-railing',
 'pole',
 'traffic-light',
 'traffic-sign',
 'tree',
 'vegetation',
 'sky',
 'person',
 'rider',
 'car',
 'truck',
 'bus',
 'train',
 'motorcycle',
 'bicycle',
 'misc']

In [4]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Active configuration: the 'train' split plus a 'val' dataset built with
# length=10 (presumably caps the number of validation items - confirm in
# cityscapedataset). Both use im_size=(512, 1024).
tds = Dataset('train', im_size=(512, 1024))
val_ds = Dataset('val', im_size=(512, 1024), length=10)

# Final-training variant kept for reference (only the validation length differs):
#   tds = Dataset('train', im_size=(512,1024))
#   val_ds = Dataset('val', im_size=(512,1024), length=40)

In [5]:
class GPUDL():
    """Wrap an iterable of `(xb, yb)` batches and move each batch to `device`.

    `device` is the notebook-level global; it is looked up when a batch is
    yielded, not when the wrapper is created.
    """

    def __init__(self, dl):
        # The wrapped loader; it is only iterated and passed to len() here.
        self.dl = dl

    def __len__(self):
        """Return the number of batches reported by the wrapped loader."""
        return len(self.dl)

    def __iter__(self):
        """Yield every `(xb, yb)` pair with both members transferred to `device`."""
        for inputs, targets in self.dl:
            inputs_on_device = inputs.to(device, non_blocking=True)
            targets_on_device = targets.to(device, non_blocking=True)
            yield inputs_on_device, targets_on_device

In [6]:
import time

In [7]:
from unet_utils import miou, pixelwiseacc

In [8]:
def train(model, lr, batch_size, epochs, weight_decay, tds, val_ds):
    """Train `model` with SGD + OneCycleLR, validating and checkpointing every epoch.

    Args:
        model: network whose forward pass returns `(final, preds1, preds2, preds3)`.
            `preds3[2]` is the output scored by `miou` / `pixelwiseacc`.
        lr: learning rate given to SGD and used as OneCycleLR's `max_lr`.
        batch_size: batch size of both the training and the validation loader.
        epochs: number of passes over `tds`.
        weight_decay: weight decay given to SGD.
        tds: training dataset yielding `(image, label)` pairs.
        val_ds: validation dataset yielding `(image, label)` pairs.

    Returns:
        `(losses, val_losses)`: two lists with one entry per epoch, each entry
        being the list of per-batch loss values (Python floats) of that epoch.

    Side effects:
        Moves `model` to the global `device`, prints two summary lines per epoch
        and writes `../working/SDN_<epoch>cepochsCityScapes.pt` after every epoch.
        The model is left in training mode.
    """
    # Create the checkpoint directory up front so a missing directory fails
    # immediately instead of after the first (long) epoch.
    os.makedirs('../working', exist_ok=True)

    # NOTE(review): OneCycleLR defaults to cycle_momentum=True, which cycles the
    # optimizer momentum between base_momentum=0.85 and max_momentum=0.95, so the
    # 0.99 passed here is overridden once the scheduler is created - confirm that
    # this is intended.
    optimizer = torch.optim.SGD(model.parameters(), lr, weight_decay=weight_decay, momentum=0.99)
    # Project loss taking the final output, the three intermediate prediction
    # groups and the labels.
    loss_fn = Supervised_CE_loss()

    dl = GPUDL(torch.utils.data.DataLoader(tds, batch_size, shuffle=True, num_workers=2))
    val_dl = GPUDL(torch.utils.data.DataLoader(val_ds, batch_size, num_workers=2))

    model.to(device)

    # The scheduler is stepped once per batch, hence steps_per_epoch=len(dl).
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, lr, epochs=epochs, steps_per_epoch=len(dl))

    losses = []
    val_losses = []
    mious = []
    val_mious = []
    paccs = []
    val_paccs = []
    for epoch in range(epochs):
        begin = time.time()
        losses.append([])
        val_losses.append([])
        mious.append([])
        val_mious.append([])
        paccs.append([])
        val_paccs.append([])

        # ---- training pass ----
        for xb, yb in dl:
            final, preds1, preds2, preds3 = model(xb)
            loss = loss_fn(final, preds1, preds2, preds3, yb)
            optimizer.zero_grad()
            loss.backward()

            # Clamp every gradient element to [-0.1, 0.1] before the update.
            nn.utils.clip_grad_value_(model.parameters(), 0.1)
            optimizer.step()
            scheduler.step()

            losses[epoch].append(loss.item())
            # Metrics are bookkeeping only; keep autograd from tracking them.
            with torch.no_grad():
                mious[epoch].append(miou(preds3[2], yb))
                paccs[epoch].append(pixelwiseacc(preds3[2], yb))
            # Drop references to the batch tensors before releasing cached blocks.
            del xb, yb, loss, final, preds1, preds2, preds3
            torch.cuda.empty_cache()

        # ---- validation pass ----
        model.eval()
        with torch.no_grad():
            for val_xb, val_yb in val_dl:
                val_final, val_preds1, val_preds2, val_preds3 = model(val_xb)
                val_loss = loss_fn(val_final, val_preds1, val_preds2, val_preds3, val_yb)
                val_losses[epoch].append(val_loss.item())
                val_mious[epoch].append(miou(val_preds3[2], val_yb))
                val_paccs[epoch].append(pixelwiseacc(val_preds3[2], val_yb))
                del val_xb, val_yb, val_loss, val_final, val_preds1, val_preds2, val_preds3
                torch.cuda.empty_cache()
        model.train()

        # torch.save creates/overwrites the file itself. Truncating it first would
        # only leave an empty checkpoint behind if the save failed.
        torch.save(model.state_dict(), f'../working/SDN_{epoch+1}cepochsCityScapes.pt')

        print('Epoch:', epoch + 1, 'TrainLoss:', f'{np.mean(losses[epoch]):.4f}', 'TMiou',f'{np.mean(mious[epoch]):.4f}', 'TPacc', f'{np.mean(paccs[epoch]):.4f}')
        print('ValLoss', f'{np.mean(val_losses[epoch]):.4f},','VMiou',f'{np.mean(val_mious[epoch]):.4f}', 'VPacc', f'{np.mean(val_paccs[epoch]):.4f}',
              f'{((time.time() - begin) / 60):.2f}', 'minutes')
    return losses, val_losses

In [9]:
# Build the SDN and warm-start it from a previously trained checkpoint.
model = SDN(no_deconv=True, pretrained=False)
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a CPU-only
# machine as well; load_state_dict copies the values into the model's own
# parameters, and train() moves the model to `device` afterwards.
model.load_state_dict(
    torch.load('/kaggle/input/38vmiousdn/SDN_3epochsCityScapes37VMIOU93PACC.pt',
               map_location='cpu')
)
model.train();  # the trailing ';' keeps the notebook from echoing the module repr

In [10]:
# Hyperparameters handed to train(...).
lr = 2e-4            # given to SGD and used as the OneCycleLR maximum
batch_size = 1       # batch size of both data loaders
epochs = 10          # passes over the training set
weight_decay = 5e-6  # SGD weight decay

In [None]:
# Run the training loop; both results hold one list of per-batch losses per epoch.
losses, val_losses = train(
    model=model,
    lr=lr,
    batch_size=batch_size,
    epochs=epochs,
    weight_decay=weight_decay,
    tds=tds,
    val_ds=val_ds,
)

Epoch: 1 TrainLoss: 1.5251 TMiou 0.3662 TPacc 0.9187
ValLoss 1.3795, VMiou 0.3697 VPacc 0.9240 80.90 minutes
Epoch: 2 TrainLoss: 1.5542 TMiou 0.3648 TPacc 0.9162
ValLoss 1.3929, VMiou 0.3689 VPacc 0.9273 81.33 minutes
Epoch: 3 TrainLoss: 1.4766 TMiou 0.3693 TPacc 0.9202
ValLoss 1.3269, VMiou 0.3771 VPacc 0.9314 81.51 minutes
Epoch: 4 TrainLoss: 1.3389 TMiou 0.3772 TPacc 0.9278
ValLoss 1.3010, VMiou 0.3736 VPacc 0.9308 80.71 minutes
Epoch: 5 TrainLoss: 1.2030 TMiou 0.3862 TPacc 0.9353
ValLoss 1.3858, VMiou 0.3833 VPacc 0.9319 80.85 minutes
Epoch: 6 TrainLoss: 1.0964 TMiou 0.3943 TPacc 0.9408
ValLoss 1.2783, VMiou 0.3844 VPacc 0.9339 80.87 minutes
Epoch: 7 TrainLoss: 1.0077 TMiou 0.4017 TPacc 0.9455
ValLoss 1.2753, VMiou 0.3863 VPacc 0.9338 80.76 minutes
Epoch: 8 TrainLoss: 0.9389 TMiou 0.4082 TPacc 0.9491
ValLoss 1.2405, VMiou 0.3951 VPacc 0.9367 81.07 minutes


In [None]:
# `losses` / `val_losses` hold one list of per-batch losses per epoch, and the two
# lists have different lengths within an epoch, so they cannot be plotted against
# each other. Plot the per-epoch mean of each as a curve over the epoch number.
train_curve = [np.mean(epoch_losses) for epoch_losses in losses]
val_curve = [np.mean(epoch_losses) for epoch_losses in val_losses]

fig, ax = plt.subplots()
ax.plot(range(1, len(train_curve) + 1), train_curve, marker='o', label='train')
ax.plot(range(1, len(val_curve) + 1), val_curve, marker='o', label='validation')
ax.set_xlabel('epoch')
ax.set_ylabel('mean loss')
ax.set_title('Training vs. validation loss')
ax.legend();

In [None]:
# Persist the final weights under the FINAL name. torch.save creates/overwrites
# the file itself, so no placeholder file has to be opened first. Previously the
# FINAL file was left empty and the weights went to a different file name.
torch.save(model.state_dict(), '../working/SDN_FINALepochsCityScapes.pt')

# Prediction on Dataset

In [None]:
# Sample visualised by the cells below: the first item of the training set,
# i.e. data the model was fitted on (useful as an overfitting check).
img, label = tds[0]

# Alternative: inspect a validation item to judge accuracy on unseen data.
# NOTE(review): val_ds is built with length=10 above, so index 12 may be out of
# range for that configuration - confirm before enabling.
#img, label = val_ds[12]

In [None]:
# Show the input image. permute(1, 2, 0) moves the leading axis last, which is
# the layout imshow expects for colour images.
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(img.permute(1, 2, 0))

## Ground Truth Value

In [None]:
# Show the ground-truth label map as a grayscale image.
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(label, cmap='gray')

## Prediction indices visualized

In [None]:
# Run the network once on the selected image (reshaped to a batch of 3x512x1024).
# train() leaves the model in training mode, so switch to eval mode here - the
# same mode the validation metrics were computed in - and disable autograd so no
# graph is kept alive for a pure visualisation pass.
model.eval()
with torch.no_grad():
    final, pred1, pred2, pred3 = model(img.to(device).reshape(-1, 3, 512, 1024))

In [None]:
# pred1[0]: per-pixel index of the maximum along the 20-channel axis.
max_vals, idxs = pred1[0].reshape(20, 512, 1024).max(dim=0)
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(idxs.cpu(), cmap='gray')

In [None]:
# pred1[1]: per-pixel index of the maximum along the 20-channel axis.
max_vals, idxs = pred1[1].reshape(20, 512, 1024).max(dim=0)
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(idxs.cpu(), cmap='gray')

In [None]:
# pred1[2]: per-pixel index of the maximum along the 20-channel axis.
max_vals, idxs = pred1[2].reshape(20, 512, 1024).max(dim=0)
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(idxs.cpu(), cmap='gray')

In [None]:
# pred2[0]: per-pixel index of the maximum along the 20-channel axis.
max_vals, idxs = pred2[0].reshape(20, 512, 1024).max(dim=0)
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(idxs.cpu(), cmap='gray')

In [None]:
# pred2[1]: per-pixel index of the maximum along the 20-channel axis.
max_vals, idxs = pred2[1].reshape(20, 512, 1024).max(dim=0)
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(idxs.cpu(), cmap='gray')

In [None]:
# pred2[2]: per-pixel index of the maximum along the 20-channel axis.
max_vals, idxs = pred2[2].reshape(20, 512, 1024).max(dim=0)
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(idxs.cpu(), cmap='gray')

In [None]:
# pred3[0]: per-pixel index of the maximum along the 20-channel axis.
max_vals, idxs = pred3[0].reshape(20, 512, 1024).max(dim=0)
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(idxs.cpu(), cmap='gray')

In [None]:
# pred3[1]: per-pixel index of the maximum along the 20-channel axis.
max_vals, idxs = pred3[1].reshape(20, 512, 1024).max(dim=0)
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(idxs.cpu(), cmap='gray')

In [None]:
# pred3[2] (the output scored by the training metrics): per-pixel index of the
# maximum along the 20-channel axis.
max_vals, idxs = pred3[2].reshape(20, 512, 1024).max(dim=0)
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(idxs.cpu(), cmap='gray')

In [None]:
# Final network output: per-pixel index of the maximum along the 20-channel axis.
max_vals, idxs = final.reshape(20, 512, 1024).max(dim=0)
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(idxs.cpu(), cmap='gray')

### Prediction on Validation Data

In [None]:
# Colour-code the most recent index map (`idxs` from the cell above).
# NOTE(review): colorify is defined in the *following* cell, so on a fresh kernel
# this call raises NameError unless that cell is run first - move the definition
# above this cell.
colorify(idxs.cpu(), img_shape=(512, 1024))

In [None]:
def colorify(idxs, img_shape, figsize=(20,20)):
    """Draw a 2-D map of class indices as an RGB image using a fixed 20-colour palette.

    Args:
        idxs: 2-D class-index map; a torch tensor (on any device) or anything
            `np.asarray` accepts.
        img_shape: expected `(height, width)` of `idxs`.
        figsize: `(width, height)` handed to `plt.figure`.

    Raises:
        AssertionError: if `idxs` is not 2-D or its shape differs from `img_shape`.

    Pixels whose value is not one of 0..19 stay black. Returns None; the image
    is drawn on a newly created matplotlib figure.
    """
    if torch.is_tensor(idxs):
        idxs = idxs.detach().cpu().numpy()
    else:
        idxs = np.asarray(idxs)
    assert idxs.ndim == 2 and img_shape[0] == idxs.shape[0] and img_shape[1] == idxs.shape[1]

    # One RGB row per class index, in the order of the dataset's class list.
    palette = np.array([
        [162,  75, 175],  #  0 road
        [244,  31, 181],  #  1 sidewalk
        [113, 103, 112],  #  2 building
        [100,  75,  44],  #  3 fence
        [147,  98,  39],  #  4 pedestrian-railing
        [203, 202, 190],  #  5 pole
        [219, 169,  42],  #  6 traffic-light
        [255, 246,  69],  #  7 traffic-sign
        [ 21, 149,  18],  #  8 tree
        [155, 243, 154],  #  9 vegetation
        [ 39, 201, 200],  # 10 sky
        [234,  22,  47],  # 11 person
        [255,   0,  30],  # 12 rider
        [ 39,  62, 163],  # 13 car
        [ 78, 105, 221],  # 14 truck
        [128, 147, 224],  # 15 bus
        [ 82,  96, 159],  # 16 train
        [ 94,  19,  36],  # 17 motorcycle
        [123,  48,  66],  # 18 bicycle
        [  0,   0,   0],  # 19 misc
    ], dtype=np.uint8)

    # Height x width x RGB canvas, painted one class at a time so that values
    # outside the palette simply remain black.
    rgb = np.zeros((img_shape[0], img_shape[1], 3), dtype=np.uint8)
    for class_id, color in enumerate(palette):
        rgb[idxs == class_id] = color

    fig = plt.figure(figsize=(figsize[0], figsize[1]))
    ax = fig.add_subplot()
    ax.imshow(rgb)

In [None]:
# Bare expression: the notebook displays the model's repr as the cell output.
model