<a href="https://colab.research.google.com/github/arnavvats/pytorch-cnns/blob/master/unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
os.environ['KAGGLE_USERNAME'] = "your-kaggle-username"
os.environ['KAGGLE_KEY'] = "your-kaggle-key"
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!kaggle datasets download -d kumaresanmanickavelu/lyft-udacity-challenge

lyft-udacity-challenge.zip: Skipping, found more recently modified local copy (use --force to force download)


In [0]:
!unzip lyft-udacity-challenge.zip
!tar -xvzf dataA.tar.gz
!ls

In [0]:
imdir = './dataA'

In [6]:
import matplotlib.pyplot as plt
import numpy as np
import os
import imageio
from torchsummary import summary
import time
imageio.imread(imdir + '/CameraSeg/07_00_141.png').shape

(600, 800, 3)

In [0]:
import torch.nn.functional as F
import torch.nn as nn
from torch import optim
import torch

In [0]:
class UNet(nn.Module):
    def __init__(self, n_classes):
        super(UNet, self).__init__()
        self.conv_1 = nn.Conv2d(3,8,3, padding = 1)
        self.bn_1 = nn.BatchNorm2d(8)
        self.conv_2 = nn.Conv2d(8,8,3, padding = 1)
        self.bn_2 = nn.BatchNorm2d(8)
        self.conv_3 = nn.Conv2d(8,16,3, padding = 1)
        self.bn_3 = nn.BatchNorm2d(16)
        self.conv_4 = nn.Conv2d(16,16,3, padding = 1)
        self.bn_4 = nn.BatchNorm2d(16)
        self.conv_5 = nn.Conv2d(16,32,3, padding = 1)
        self.bn_5 = nn.BatchNorm2d(32)
        self.conv_6 = nn.Conv2d(32,32,3, padding = 1)
        self.bn_6 = nn.BatchNorm2d(32)
        self.conv_7 = nn.Conv2d(32,64,3, padding = 1)
        self.bn_7 = nn.BatchNorm2d(64)
        self.conv_8 = nn.Conv2d(64,64,3, padding = 1)
        self.bn_8 = nn.BatchNorm2d(64)
        self.conv_9 = nn.Conv2d(64,128, 3, padding = 1)
        self.bn_9 = nn.BatchNorm2d(128)
        self.conv_10 = nn.Conv2d(128,128,3, padding = 1)
        self.bn_10 = nn.BatchNorm2d(128)
        self.up_conv_11 = nn.ConvTranspose2d(128,64,2,stride = 2)
        self.conv_12 = nn.Conv2d(128, 64, 3, padding = 1)
        self.bn_12 = nn.BatchNorm2d(64)
        self.conv_13 = nn.Conv2d(64, 64, 3, padding = 1)
        self.bn_13 = nn.BatchNorm2d(64)
        self.up_conv_14 = nn.ConvTranspose2d(64,32,2, stride = 2)
        self.conv_15 = nn.Conv2d(64, 32, 3, padding = 1)
        self.bn_15 = nn.BatchNorm2d(32)
        self.conv_16 = nn.Conv2d(32,32,3, padding= 1)
        self.bn_16 = nn.BatchNorm2d(32)
        self.up_conv_17 = nn.ConvTranspose2d(32, 16, 2, stride = 2)
        self.conv_18 = nn.Conv2d(32,16,3, padding = 1)
        self.bn_18 = nn.BatchNorm2d(16)
        self.conv_19 = nn.Conv2d(16,16,3, padding = 1)
        self.bn_19 = nn.BatchNorm2d(16)
        self.up_conv_20 = nn.ConvTranspose2d(16, 8, 2, stride = 2)
        self.conv_21 = nn.Conv2d(16, 8, 3, padding = 1)
        self.bn_21 = nn.BatchNorm2d(8)
        self.conv_22 = nn.Conv2d(8,8,3, padding = 1)
        self.bn_22 = nn.BatchNorm2d(8)
        self.conv_23 = nn.Conv2d(8,n_classes,1)
        
    def forward(self, x):
        x =F.relu(self.bn_1(self.conv_1(x)))
        d_1 = F.relu(self.bn_2(self.conv_2(x)))
        x = F.max_pool2d(d_1, (2,2),2)
        x = F.relu(self.bn_3(self.conv_3(x)))
        d_2 = F.relu(self.bn_4(self.conv_4(x)))
        x = F.max_pool2d(d_2, (2,2),2)
        x = F.relu(self.bn_5(self.conv_5(x)))
        d_3 = F.relu(self.bn_6(self.conv_6(x)))
        x = F.max_pool2d(d_3, (2,2),2)
        x = F.relu(self.bn_7(self.conv_7(x)))
        d_4 = F.relu(self.bn_8(self.conv_8(x)))
        x = F.max_pool2d(d_4, (2,2),2)
        x = F.relu(self.bn_9(self.conv_9(x)))
        x = F.relu(self.bn_10(self.conv_10(x)))
        x = self.up_conv_11(x)
        dx = d_4.size()[2] - x.size()[2]
        dy = d_4.size()[3] - x.size()[3]
        x = F.pad(x, (dy//2, dy - dy//2, dx//2, dx - dx//2))
        x = torch.cat([d_4, x], dim = 1)
        x = F.relu(self.bn_12(self.conv_12(x)))
        x = F.relu(self.bn_13(self.conv_13(x)))
        x = self.up_conv_14(x)
        dx = d_3.size()[2] - x.size()[2]
        dy = d_3.size()[3] - x.size()[3]
        x = F.pad(x, (dy//2, dy - dy//2, dx//2, dx - dx // 2))
        x = torch.cat([d_3, x], dim = 1)
        x = F.relu(self.bn_15(self.conv_15(x)))
        x = F.relu(self.bn_16(self.conv_16(x)))
        x = self.up_conv_17(x)
        dx = d_2.size()[2] - x.size()[2]
        dy = d_2.size()[3] - x.size()[3]
        x = F.pad(x, (dy // 2, dy - dy // 2, dx // 2, dx - dx // 2))
        x = torch.cat([d_2, x], dim = 1)
        x = F.relu(self.bn_18(self.conv_18(x)))
        x = F.relu(self.bn_19(self.conv_19(x)))
        x = self.up_conv_20(x)
        dx = d_1.size()[2] - x.size()[2]
        dy = d_1.size()[3] - x.size()[3]
        x = F.pad(x, (dy // 2, dy - dy // 2, dx // 2, dx - dx // 2))
        x = torch.cat([d_1, x], dim = 1)
        x = F.relu(self.bn_21(self.conv_21(x)))
        x = F.relu(self.bn_22(self.conv_22(x)))
        x = torch.sigmoid(self.conv_23(x))
        return x

In [0]:
mask_dir = imdir + '/CameraSeg'
input_dir = imdir + '/CameraRGB'
input_list = np.array(os.listdir(input_dir))

In [0]:
def generate_io(im_list):
    im_x_list = np.empty(shape=(0,3,600,800))
    im_y_list = np.empty(shape=(0,1,600,800))
    for im in im_list:
        im_x = np.reshape(imageio.imread(input_dir + '/' + im), (1,3,600,800))
        mask = imageio.imread(mask_dir + '/' + im)
        mask = np.amax(mask, axis = 2)
        im_y = np.zeros((600, 800))
        im_y[np.where(mask==7)[0], np.where(mask==7)[1]]=1
        im_y = im_y.reshape((1,1,600,800))
        im_x_list = np.concatenate((im_x_list,im_x), axis = 0)
        im_y_list = np.concatenate((im_y_list, im_y), axis = 0)
    return torch.tensor(im_x_list, dtype = torch.float32).cuda(), torch.tensor(im_y_list, dtype = torch.float32).cuda()

In [0]:
unet = UNet(1).cuda()
optimizer = optim.Adam(unet.parameters(),lr =0.5)
criterion = nn.BCELoss()
epochs = 200
steps = 21
batch_size = 24
threshold = 0.04
#summary(unet, (3, 600, 800))

In [12]:
for epoch in range(epochs):
  unet.train()
  np.random.seed(epoch)
  print('-----Starting Epoch {}---- '.format(epoch + 1))
  epoch_loss = 0
  np.random.shuffle(input_list)
  for step in range(steps):
    optimizer.zero_grad()
    input_batch = input_list[step * batch_size: (step + 1) * batch_size]
    input_im, masks_im = generate_io(input_batch)
    masks_pred = unet(input_im)
    masks_pred_flat = masks_pred.view(-1)
    masks_im_flat = masks_im.view(-1)
    loss = criterion(masks_pred_flat, masks_im_flat)
    epoch_loss += loss.item()
    loss.backward()
    optimizer.step()
  avg_epoch_loss = epoch_loss / steps
  print('Average epoch loss {}'.format(avg_epoch_loss))
  if avg_epoch_loss <= threshold:
    break


-----Starting Epoch 1---- 
Average epoch loss 0.40872676031930105
-----Starting Epoch 2---- 
Average epoch loss 0.25302205483118695
-----Starting Epoch 3---- 
Average epoch loss 0.21664260824521384
-----Starting Epoch 4---- 
Average epoch loss 0.2060413999216897
-----Starting Epoch 5---- 
Average epoch loss 0.19395494744891212
-----Starting Epoch 6---- 
Average epoch loss 0.19061319459052312
-----Starting Epoch 7---- 
Average epoch loss 0.19044679474262965
-----Starting Epoch 8---- 
Average epoch loss 0.18772398006348384
-----Starting Epoch 9---- 
Average epoch loss 0.18004640582061948
-----Starting Epoch 10---- 
Average epoch loss 0.17469733031023116
-----Starting Epoch 11---- 
Average epoch loss 0.17056523405370258
-----Starting Epoch 12---- 
Average epoch loss 0.16626455031690143
-----Starting Epoch 13---- 
Average epoch loss 0.1624888898361297
-----Starting Epoch 14---- 
Average epoch loss 0.16752744359629496
-----Starting Epoch 15---- 
Average epoch loss 0.16693033632777987
-----S

In [0]:
torch.save(unet.state_dict(), '/content/drive/My Drive/unet_model.h5')
