In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import StepLR
import cv2
import os
from tqdm import tqdm_notebook as tqdm
from PIL import Image
import torchvision.models as models
import torch.optim as optim
from google.colab import files
from torch.optim.lr_scheduler import LambdaLR
import glob
import matplotlib.gridspec as gridspec

# Run on the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Get the dataset

PASCAL VOC 2012

In [0]:
# Download the PASCAL VOC 2012 train/val archive as VOCtrainval.tar and unpack it
# into the current working directory.
# NOTE(review): the path settings further down expect the data under
# /content/VOCdevkit/VOC2012/ - confirm the working directory is /content.
!wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar -O VOCtrainval.tar
!tar -xf VOCtrainval.tar

Cityscapes (optional — requires a registered account; the download commands below are commented out)

In [0]:
# Cityscapes downloads require a registered account. To use them, uncomment the
# lines below and substitute your own credentials for <USERNAME> / <PASSWORD>.
# Never save real credentials in the notebook.
#!wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username=<USERNAME>&password=<PASSWORD>&submit=Login' https://www.cityscapes-dataset.com/login/
#!wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=1
#!wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3
#!unzip -qq gtFine_trainvaltest.zip
#!unzip -qq leftImg8bit_trainvaltest.zip

## DeepLabv3 Model

In [0]:
class ResNet_50 (nn.Module):
  """Feature extractor built from torchvision's pretrained ResNet-50.

  The forward pass runs the stem (conv1, bn1, ReLU, max-pool) followed by
  layer1, layer2 and layer3 of the wrapped network and returns the layer3
  activations; the remaining parts of the torchvision model are not called.
  """

  def __init__(self, in_channels = 3, conv1_out = 64):
    # NOTE(review): in_channels and conv1_out are accepted but never read;
    # the wrapped torchvision model fixes the stem configuration.
    super().__init__()

    # The full pretrained network is stored under `resnet_50`, so state_dict
    # keys are prefixed with `resnet_50.`.
    self.resnet_50 = models.resnet50(pretrained = True)
    self.relu = nn.ReLU(inplace=True)

  def forward(self,x):
    backbone = self.resnet_50

    # Stem: convolution -> batch norm -> ReLU -> max-pool.
    features = backbone.conv1(x)
    features = backbone.bn1(features)
    features = self.relu(features)
    features = backbone.maxpool(features)

    # Residual stages, stopping after layer3.
    for stage in (backbone.layer1, backbone.layer2, backbone.layer3):
      features = stage(features)

    return features

In [0]:
class ASSP(nn.Module):
  """Multi-branch dilated-convolution head.

  Five branches read the same input tensor:
    * conv1/bn1          - 1x1 convolution
    * conv2..conv4 (+bn) - 3x3 convolutions with dilation 6, 12 and 18
                           (padding equals the dilation)
    * conv5/bn5          - global average pool, 1x1 convolution, then
                           bilinear resize to the spatial size of branch 4
  Every branch ends in batch norm + ReLU. The branch outputs are concatenated
  along the channel axis and fused by convf/bnf/ReLU into `out_channels` maps.
  """

  def __init__(self,in_channels,out_channels = 256):
    super().__init__()

    self.relu = nn.ReLU(inplace=True)

    # (kernel_size, dilation) for conv1 .. conv5. Modules are registered in
    # the order conv1, bn1, conv2, bn2, ... so names and ordering of the
    # parameters stay the same as when they are written out one by one.
    branch_specs = ((1, 1), (3, 6), (3, 12), (3, 18), (1, 1))
    for number, (kernel, dilation) in enumerate(branch_specs, start=1):
      conv = nn.Conv2d(in_channels = in_channels,
                       out_channels = out_channels,
                       kernel_size = kernel,
                       stride = 1,
                       padding = dilation * (kernel // 2),
                       dilation = dilation,
                       bias = False)
      setattr(self, 'conv%d' % number, conv)
      setattr(self, 'bn%d' % number, nn.BatchNorm2d(out_channels))

    # 1x1 fusion of the five concatenated branch outputs.
    self.convf = nn.Conv2d(in_channels = out_channels * 5,
                           out_channels = out_channels,
                           kernel_size = 1,
                           stride = 1,
                           padding = 0,
                           dilation = 1,
                           bias = False)
    self.bnf = nn.BatchNorm2d(out_channels)

    self.adapool = nn.AdaptiveAvgPool2d(1)

  def forward(self,x):

    def run_branch(number, tensor):
      # conv -> batch norm -> ReLU for branch `number`
      conv = getattr(self, 'conv%d' % number)
      norm = getattr(self, 'bn%d' % number)
      return self.relu(norm(conv(tensor)))

    # Branches 1-4 operate directly on the input.
    branches = [run_branch(number, x) for number in range(1, 5)]

    # Branch 5: pooled to 1x1, projected, then resized to match branch 4.
    pooled = run_branch(5, self.adapool(x))
    pooled = F.interpolate(pooled, size = tuple(branches[-1].shape[-2:]), mode='bilinear')
    branches.append(pooled)

    fused = torch.cat(branches, dim = 1) #channels first
    fused = self.convf(fused)
    fused = self.bnf(fused)
    return self.relu(fused)

In [0]:
class DeepLabv3(nn.Module):
  """Segmentation network: ResNet_50 features -> ASSP head -> 1x1 classifier.

  Args:
    nc - number of output channels (one score map per class)

  forward() returns a tensor with `nc` channels, resized bilinearly to the
  height and width of the input batch.
  """

  def __init__(self, nc):
    super().__init__()

    self.nc = nc

    self.resnet = ResNet_50()

    # ASSP is built for 1024 input channels and, by its default, 256 outputs.
    self.assp = ASSP(in_channels = 1024)

    # Per-pixel classifier on the 256-channel ASSP output.
    self.conv = nn.Conv2d(in_channels = 256,
                          out_channels = self.nc,
                          kernel_size = 1,
                          stride = 1,
                          padding = 0)

  def forward(self,x):
    # Remember the input resolution so the score maps can be resized to it.
    _batch, _channels, height, width = x.shape

    scores = self.conv(self.assp(self.resnet(x)))
    return F.interpolate(scores, size=(height, width), mode='bilinear')

## Define the loader

In [0]:
def loader(input_path, segmented_path, batch_size, h=1024, w=2048):
    """
    Infinite generator of (inputs, labels) mini batches for training.

    Args:

    input_path - path to the images folder; for every file in segmented_path
                 an image named '<stem>.jpg' is read from here
    segmented_path - path to labels (segmented images)
    batch_size - number of images in each mini batch (drawn at random, with
                 replacement), or 'all' to return every file exactly once
                 in each batch
    h - image height after resizing
    w - image width after resizing

    Yields:

    inputs - tensor of shape B x C x H x W built from the RGB photos
    labels - tensor of shape B x H x W built from the label images, with the
             value 255 replaced by 0

    Raises:

    ValueError - if segmented_path contains no files
    """
    # Sorted so that the file order does not depend on the file system.
    filenames = sorted(os.listdir(segmented_path))
    total_files = len(filenames)
    if total_files == 0:
        raise ValueError('No label files found in {!r}'.format(segmented_path))

    # splitext keeps stems that contain dots intact.
    inp_files = [os.path.splitext(name)[0] + '.jpg' for name in filenames]

    take_all = str(batch_size).lower() == 'all'

    while True:
        if take_all:
            # Every file once; random sampling with replacement would skip
            # some files and repeat others.
            batch_idxs = np.arange(total_files)
        else:
            # Choosing random indexes of images and labels
            batch_idxs = np.random.randint(0, total_files, batch_size)

        inputs = []
        labels = []

        for jj in batch_idxs:
            # Reading the photo as an H x W x 3 array
            with Image.open(os.path.join(input_path, inp_files[jj])) as photo:
                img = np.array(photo.convert('RGB'))
            # The interpolation flag must be passed by keyword: the third
            # positional argument of cv2.resize is `dst`, so a positional flag
            # is ignored and the default (bilinear) is used. Bilinear is kept
            # for photos, which is what the positional call effectively did.
            img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)
            inputs.append(img)

            # Reading semantic image
            with Image.open(os.path.join(segmented_path, filenames[jj])) as seg:
                mask = np.array(seg)
            mask[mask == 255] = 0
            # Labels are class indices: nearest neighbour is required so that
            # resizing never blends two classes into a third, unrelated index.
            mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
            labels.append(mask)

        # B x H x W x C  ->  B x C x H x W
        inputs = torch.from_numpy(np.stack(inputs, axis=0)).permute(0, 3, 1, 2)

        # Stack first: building a tensor from a list of arrays is very slow.
        labels = torch.from_numpy(np.stack(labels, axis=0))

        yield inputs, labels

## Define the Global Variables and helper functions

In [0]:
# Defining Global variables used for training
###############################

# Folder with the input photos (passed to loader() and show_pascal()).
training_path = '/content/VOCdevkit/VOC2012/JPEGImages/'
# Folder with the label images; its file list also drives which photos are read.
segmented_path = '/content/VOCdevkit/VOC2012/SegmentationClass/'
# NOTE(review): identical to training_path and only referenced by the
# commented-out eval loader - there is no separate validation split here.
eval_path = '/content/VOCdevkit/VOC2012/JPEGImages/'
# Both counts are the number of files in segmented_path.
train_samples = len(os.listdir(segmented_path))
eval_samples = len(os.listdir(segmented_path))
# Number of classes / output channels of the model.
nc = 21
# Height and width every image and label map is resized to.
H = 360
W = 500
batch_size = 16

# NOTE(review): `lr` is not read by the optimizer cell (it hard-codes 0.0001),
# and `epochs` is reassigned in later cells.
lr = 0.001
epochs = 10
# Frequencies (in epochs) for checkpointing, printing and evaluation.
save_every = 1
print_every = 1
eval_every = 1

In [0]:
def get_class_weights(loader, num_classes, c=1.02):
    '''
    Compute a loss weight for every class from one batch of labels.

    The weight of class k is 1 / ln(c + p_k), where p_k is the fraction of
    label pixels in the batch that belong to class k, so rarer classes get
    larger weights.

    Arguments:
    - loader : The generator object which yields (inputs, labels). Exactly one
               item is taken from it with next().
               Do Note: the weights describe the whole dataset only if all the
               labels are returned in that one iteration.
    - num_classes : The number of classes
    - c : Constant added to the class fraction inside the logarithm
    Return:
    - class_weights : An array equal in length to the number of classes
                      containing the class weights for each class
    Raises:
    - ValueError : if the batch holds no labels, or holds a label that is
                   >= num_classes (the result would otherwise have more
                   entries than there are classes)
    '''

    _, labels = next(loader)

    # Accept tensors on any device as well as plain arrays.
    if torch.is_tensor(labels):
        labels = labels.detach().cpu().numpy()
    all_labels = np.asarray(labels).ravel()

    if all_labels.size == 0:
        raise ValueError('Cannot compute class weights from an empty label batch')

    each_class = np.bincount(all_labels, minlength=num_classes)
    if len(each_class) > num_classes:
        raise ValueError(
            'Found label {} but num_classes is {}'.format(len(each_class) - 1, num_classes))

    propensity_score = each_class / all_labels.size
    class_weights = 1 / (np.log(c + propensity_score))
    return class_weights

In [0]:
def show_pascal(model, path, fname, ext='.jpg'):
  """Run `model` on one image and plot the input plus one score map per class.

  Args:
    model - network taking a 1 x 3 x H x W float tensor and returning
            1 x C x H x W scores
    path  - folder that holds the image
    fname - file name; only its stem is used, `ext` replaces the extension
            (so a label file name can be passed to show the matching photo)
    ext   - extension of the image file to open

  Reads the globals `W`, `H` (resize target) and `device`. The model is put in
  eval mode for the forward pass and returned to its previous mode afterwards.
  """
  image_file = os.path.join(path, os.path.splitext(fname)[0] + ext)
  with Image.open(image_file) as photo:
    tmg_ = np.array(photo.convert('RGB'))
  # The interpolation flag must be a keyword argument: the third positional
  # argument of cv2.resize is `dst`, so a positional flag is silently ignored.
  # Bilinear is what the positional call effectively used.
  tmg_ = cv2.resize(tmg_, (W, H), interpolation=cv2.INTER_LINEAR)

  # H x W x C -> 1 x C x H x W
  tmg = torch.from_numpy(tmg_).permute(2, 0, 1).unsqueeze(0).float().to(device)

  was_training = model.training
  model.eval()
  with torch.no_grad():
      out1 = model(tmg).squeeze(0)
  if was_training:
      model.train()

  out2 = out1.cpu().numpy()

  plt.figure()
  plt.title('Input Image')
  plt.axis('off')
  plt.imshow(tmg_)

  # One panel per output channel, four per row, so every class is shown.
  n_maps = out2.shape[0]
  n_cols = 4
  n_rows = max(1, -(-n_maps // n_cols))

  plt.figure(figsize=(10, 10))
  gs = gridspec.GridSpec(n_rows, n_cols)
  gs.update(wspace=0.025, hspace=0.0)

  for label in range(n_maps):
    plt.subplot(gs[label])
    plt.axis('off')
    plt.imshow(out2[label, :, :])
  plt.show()

In [0]:
###############################
# Create here two loaders
# Only the training generator is created; the evaluation loader stays
# commented out because `segmented_eval_path` is not defined in this notebook.

pipe = loader(training_path, segmented_path, batch_size, h = H, w = W)
#eval_pipe = loader(eval_path, segmented_eval_path, batch_size, h = H, w = W)

In [0]:
###############################
# Build the network with `nc` output channels and move it to `device`.
# Constructing it calls models.resnet50(pretrained = True).

model = DeepLabv3(nc).to(device)
#model = DeepLabV3(nc).to(device)

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.torch/models/resnet50-19c8e357.pth
102502400it [00:01, 92771258.75it/s]


In [0]:
# NOTE(review): this takes ONE batch (batch_size images) from the training
# generator, so the class weights are estimated from that sample only, while
# get_class_weights() documents that it expects all labels in one iteration.
class_weights = get_class_weights(pipe, nc)
# NOTE(review): loader() rewrites label value 255 to 0 before batching, so
# ignore_index=255 appears never to match any target pixel - confirm intent.
criterion = nn.CrossEntropyLoss(weight = torch.FloatTensor(class_weights).to(device), ignore_index=255)

In [0]:
# Define Optimizer
# NOTE(review): the learning rate is hard-coded here; the global `lr` (0.001)
# defined with the other training settings is not used.
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0001)

In [0]:
# Loading model and optimizer
#
# Set ckpt_path to a checkpoint file to resume from it; leave it as None to
# keep the freshly built model and optimizer (the cell is then a no-op, so the
# notebook still runs top to bottom on a fresh kernel).
ckpt_path = None

if ckpt_path is not None:
    # torch.load unpickles the file - only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location=device)

    # Accept both key layouts: 'model_state_dict' and the 'state_dict' key
    # written by the training loop's checkpoint dictionary.
    if 'model_state_dict' in ckpt:
        model.load_state_dict(ckpt['model_state_dict'])
    elif 'state_dict' in ckpt:
        model.load_state_dict(ckpt['state_dict'])
    else:
        raise KeyError("Checkpoint has neither 'model_state_dict' nor 'state_dict'")

    # The optimizer state is optional; without it the optimizer starts fresh.
    if 'opt_state_dict' in ckpt:
        optimizer.load_state_dict(ckpt['opt_state_dict'])

In [0]:
# NOTE(review): the training-loop cell reassigns `epochs` to 50, and the
# scheduler's lambda reads `epochs` each time it is called, so this value is
# only in effect while the scheduler is being constructed.
epochs = 5

In [0]:
# Define learning rate scheduler
#
# Polynomial decay: the base learning rate is multiplied by
# (1 - epoch / epochs) ** 0.9.
# `epochs` is a global that is looked up every time the scheduler calls the
# lambda, so the value in effect during training is whatever `epochs` holds
# at that moment, not the value it has when this cell runs.
# The base is clamped at 0: once `epoch` exceeds `epochs` (for example when
# the training cell is run again with the same scheduler) a negative base
# raised to 0.9 would produce a complex number instead of a learning rate.

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: max(0.0, 1 - (epoch / epochs)) ** 0.9)

In [0]:
# File names of the label images; the training loop picks one at random after
# every epoch and passes it to show_pascal() to visualise the matching photo.
all_tests = os.listdir(segmented_path)

In [0]:
# Training loop
#
# Each epoch draws `bc_train` random batches from `pipe`, accumulates the
# summed batch loss, advances the learning-rate schedule, writes a checkpoint
# and plots the model output for one random image.

epochs = 50

train_losses = [] # holds the running train loss
eval_losses = []  # holds the running validation loss

# Calculate the number of mini batches
bc_train = train_samples // batch_size
bc_eval = eval_samples // batch_size

for e in range(1, epochs+1):

    train_loss = 0
    print ('-'*15,'Epoch %d' % e, '-'*15)

    model.train()

    for _ in tqdm(range(bc_train)):

        # generate batch
        X_batch, mask_batch = next(pipe)

        # assign to GPU/CPU
        X_batch, mask_batch = X_batch.to(device), mask_batch.to(device)

        optimizer.zero_grad()

        out = model(X_batch.float())

        loss = criterion(out, mask_batch.long())
        loss.backward()
        optimizer.step()

        # Sum (not mean) of the batch losses over the epoch.
        train_loss += loss.item()

    # Advance the schedule only after this epoch's optimizer steps. Stepping
    # at the top of the epoch skips the first value of the schedule and makes
    # the decay factor (1 - epoch / epochs) ** 0.9 reach 0 during the final
    # epoch, which would then train with a learning rate of zero.
    scheduler.step()

    print ()
    train_losses.append(train_loss)

    # Same `e % N` convention as the checkpoint check below.
    if e % print_every == 0:
        print ('Epoch {}/{}...'.format(e, epochs),
                'Loss {:6f}'.format(train_loss))

    # Validation is disabled: `eval_pipe` is not created in this notebook.
    # The original block is kept here for reference.
    #
    # if e % eval_every == 0:
    #     with torch.no_grad():
    #         model.eval()
    #
    #         eval_loss = 0
    #
    #         for _ in tqdm(range(bc_eval)):
    #             inputs, labels = next(eval_pipe)
    #
    #             inputs, labels = inputs.to(device), labels.to(device)
    #             out = model(inputs.float())
    #
    #             loss = criterion(out, labels.long())
    #
    #             eval_loss += loss.item()
    #
    #         print ()
    #         print ('Loss {:6f}'.format(eval_loss))
    #
    #         eval_losses.append(eval_loss)
    #
    # scheduler.step(eval_loss)

    if e % save_every == 0:
        model_state = model.state_dict()
        # 'epochs' and 'state_dict' are the original keys. 'model_state_dict'
        # (same object, stored once) and 'opt_state_dict' are the keys the
        # checkpoint-loading cell reads; the optimizer state makes each file
        # larger but allows training to be resumed.
        checkpoint = {
            'epochs' : e,
            'state_dict' : model_state,
            'model_state_dict' : model_state,
            'opt_state_dict' : optimizer.state_dict()
        }
        torch.save(checkpoint, './ckpt-enet-{}-{:2f}.pth'.format(e, train_loss))
        print ('Model saved!')

    show_pascal(model, training_path, all_tests[np.random.randint(0, len(all_tests))])

## Save the model or images to drive

In [0]:
# Colab only: `auth` is not imported anywhere else in the notebook, so import
# it here before asking the user to authorise this runtime.
from google.colab import auth

auth.authenticate_user()

In [0]:
from googleapiclient.discovery import build
# MediaFileUpload is used by save_to_drive() and was never imported.
from googleapiclient.http import MediaFileUpload

drive_service = build('drive', 'v3')

def save_to_drive(fcname, fname):
    """Upload a file from /content to Google Drive.

    Args:
    fcname - name of the file inside /content on the Colab machine
    fname - name the file gets on Drive

    Prints the id of the created Drive file. Uses the module-level
    `drive_service` client.
    """
    file_metadata = {
        'name' : '{}'.format(fname)
    }

    # resumable=True requests a resumable upload from the client library.
    media = MediaFileUpload('/content/{}'.format(fcname), resumable=True)

    created = drive_service.files().create(body=file_metadata,
                                           media_body=media,
                                           fields='id').execute()

    print ('[INFO] File created with id = {}'.format(created['id']))

In [0]:
# Fill in both names to upload a file. With the placeholders left empty the
# upload is skipped, so running the whole notebook does not try to upload
# the /content directory itself.
colab_file_name = ''
drive_file_name = ''
if colab_file_name and drive_file_name:
    save_to_drive(colab_file_name, drive_file_name)
else:
    print ('[INFO] Set colab_file_name and drive_file_name to upload a file.')