In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.optim
import torch.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision

import torchvision.transforms as transforms
import torch.optim as optim

import torch.nn as nn
from torch.autograd import Variable

from include.data import ObjectSegmentationDataset
from include.hybrid_net_batchnorm import SegmentationModel

from include.Utility_Functions import Validate_IOU
from torch.nn.modules.upsampling import Upsample

from torch.utils.data.sampler import SubsetRandomSampler

from include import crf
import time


In [None]:
# Loaders for Training and Validation

# Input images are normalised per channel before they enter the network; see
# https://pytorch.org/docs/stable/torchvision/models.html for why pretrained
# torchvision backbones expect normalised input.
# NOTE(review): these mean/std values are NOT the ImageNet statistics given on
# that page (mean 0.485/0.456/0.406, std 0.229/0.224/0.225); presumably they
# were computed on this dataset — confirm which the pretrained VGG expects.
norm_transform = transforms.Normalize(
    mean=[0.4441137029691302, 0.4153741837809971, 0.38574530437912646],
    std=[0.2804474648193802, 0.2737859518590525, 0.2825218670683551])
transform = transforms.Compose([transforms.ToTensor(), norm_transform])
number_of_classes = 7

# Root of the PASCAL VOC devkit. Set the VOC_ROOT environment variable to run
# the notebook on another machine instead of editing the paths below.
VOC_ROOT = os.environ.get("VOC_ROOT", "/home/snajder/AML/data/VOCdevkit")
SRC_IMAGE_DIR = os.path.join(VOC_ROOT, "VOC2012", "JPEGImages")
SEG_IMAGE_DIR = os.path.join(VOC_ROOT, "VOC2010", "person_trainval", "Annotations_Part_images")

# `augment` is not passed here, so the ObjectSegmentationDataset default is
# used (presumably augmentation on — see augmentation_factor below).
train_dataset = ObjectSegmentationDataset(
    src_image_dir=SRC_IMAGE_DIR,
    seg_image_dir=SEG_IMAGE_DIR,
    num_classes=number_of_classes,
    transform=transform,
    gt_one_hot=False,
    rescale=False)

# Same images without augmentation: one dataset entry per source image.
val_dataset = ObjectSegmentationDataset(
    src_image_dir=SRC_IMAGE_DIR,
    seg_image_dir=SEG_IMAGE_DIR,
    num_classes=number_of_classes,
    transform=transform,
    gt_one_hot=False,
    augment=False,
    rescale=False)

# Number of train_dataset entries per source image. Must match the
# augmentation options above:
#   augmentation off                 -> 1
#   augmentation on, rescale=True    -> 4
#   augmentation on, rescale=False   -> 8
augmentation_factor = 8
# Fail early on a mismatch; the train/validation split below relies on it.
assert len(train_dataset) == augmentation_factor * len(val_dataset), (
    "augmentation_factor does not match the dataset augmentation options")

In [None]:
# Build the segmentation network and restore its weights from a checkpoint.
net = SegmentationModel(num_classes=number_of_classes)
# NOTE(review): the training cells pass this same file name to Validate_IOU
# when checkpointing, so this file is presumably overwritten during training.
net_file = "Segnet_Best_params_VOC_Person_7class_09-16.pth"
# Load onto the CPU first so the file can be read regardless of the device it
# was saved from; the model is moved to the GPU after the weights are copied.
checkpoint = torch.load(net_file, map_location="cpu")
net.load_state_dict(checkpoint['state_dict'])
net = net.cuda()

In [None]:
# State shared by the training cells below.
bs = 1              # batch size of Train_Loader (also passed to Validate_IOU)
best_score = 0      # best validation score so far; updated by Validate_IOU
running_loss = 0.0  # most recent training loss
losses = []         # one training-loss entry appended per epoch
Val_Score = []      # one validation-score entry appended per epoch

net.train()

# Create training, validation and test loaders.
# Images must not be separated from their augmented versions: if image 0 is in
# the validation set, then image 1 (image 0 horizontally flipped) must not be in
# the training set. So indices are first drawn over the *un-augmented* images.
# For the training set only, each image index k is then expanded to the
# augmented indices k*augmentation_factor ... (k+1)*augmentation_factor - 1.
# Validation and test indices address val_dataset (augment=False) directly.

num_train = len(train_dataset)
# Indices of images (NOT considering the augmentations)
indices = list(range(num_train // augmentation_factor))
num_validation = 500
num_test = 250

# A private RandomState seeded with 43 draws exactly what np.random.seed(43)
# followed by np.random.choice would. It keeps the split reproducible without
# fixing (and then having to re-randomise) numpy's global RNG.
split_rng = np.random.RandomState(43)
validation_idx = split_rng.choice(indices, size=num_validation, replace=False)
train_idx = sorted(set(indices) - set(validation_idx))
test_idx = split_rng.choice(train_idx, size=num_test, replace=False)
# Remove the held-out test images from the training set. (Previously the
# validation indices were subtracted a second time, so test images were
# also trained on.)
train_idx = sorted(set(train_idx) - set(test_idx))

assert not set(train_idx) & set(validation_idx), "train/validation overlap"
assert not set(train_idx) & set(test_idx), "train/test overlap"
assert not set(validation_idx) & set(test_idx), "validation/test overlap"

# add indices of augmentations
train_idx = [idx * augmentation_factor + aug_offset
             for idx in train_idx
             for aug_offset in range(augmentation_factor)]

# shuffle must stay False: a sampler and shuffle=True are mutually exclusive.
Train_Loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=False,
                                           num_workers=4, sampler=SubsetRandomSampler(train_idx))
VAL_Loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False,
                                         num_workers=4, sampler=SubsetRandomSampler(validation_idx))
Test_Loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False,
                                          num_workers=4, sampler=SubsetRandomSampler(test_idx))

In [None]:
# First training stage: ordinary mini-batch SGD over Train_Loader (batch size
# `bs`). The VGG backbone gets a 10x smaller learning rate than `net.atr`.
# NOTE(review): this stage was originally described as training on batches of
# downscaled images, but the datasets above use rescale=False — confirm.

learning_rate = 1e-3
log_every = 50  # print and plot every `log_every` batches

optimizer = optim.SGD([
                {'params': net.vgg.parameters(), 'lr': learning_rate*0.1},
                {'params': net.atr.parameters(), 'lr': learning_rate}
            ], momentum=0.9, weight_decay=0.0005)

for epoch in range(2000):
    iterations = 0
    epoch_loss_sum = 0.0
    net.train()
    for i, data in enumerate(Train_Loader, 0):
        src_img, seg_img, seg_img_ds, src_img_raw = data

        # Plain tensors carry autograd information; torch.autograd.Variable
        # has been a no-op wrapper since PyTorch 0.4.
        Input = src_img.float().cuda()
        Target = seg_img_ds.long().cuda()

        optimizer.zero_grad()
        Output = net(Input)

        # Weighted cross entropy: class c gets weight (N - n_c) / n_c, where N
        # is the number of target pixels and n_c the softmax mass predicted
        # for c over the batch, so rarely predicted classes weigh more. The
        # clamp keeps a weight finite if a class's softmax mass underflows to 0.
        ps = nn.functional.softmax(Output, dim=1).sum(dim=(0, 2, 3)).clamp(min=1e-6)
        weights = (Target.numel() - ps) / ps
        loss = nn.functional.cross_entropy(Output, Target, weight=weights.detach())

        loss.backward()
        optimizer.step()

        # cross_entropy already returns a (weighted) mean over the batch, so
        # it is not divided by the batch size again.
        batch_loss = loss.item()
        epoch_loss_sum += batch_loss
        # Mean loss over this epoch's batches so far.
        running_loss = epoch_loss_sum / (i + 1)

        iterations += bs

        if (i + 1) % log_every == 0:
            print('Epoch = %d, Iteration = %d, loss= %.3f, running loss= %.3f'
                  % (epoch + 1, iterations, batch_loss, running_loss))
            out = nn.functional.softmax(Output[-1].detach(), dim=0).cpu().numpy()
            I = np.argmax(out, axis=0)

            # Input image | downsampled ground truth | predicted labels
            fig = plt.figure()
            fig.add_subplot(1, 3, 1)
            plt.imshow(src_img_raw[-1].numpy())
            fig.add_subplot(1, 3, 2)
            plt.imshow(seg_img_ds[-1].numpy())
            plt.axis('off')
            fig.add_subplot(1, 3, 3)
            plt.imshow(I)
            plt.axis('off')
            plt.show()

    net.eval()
    # Checkpoint Network and Optimiser after each epoch
    best_score, score = Validate_IOU(net, optimizer, epoch, losses, Val_Score, bs,
                                     learning_rate, VAL_Loader, best_score, net_file)
    print("Current validation score: ", score)
    # Record the epoch's mean training loss rather than only its last batch.
    losses.append(running_loss)
    Val_Score.append(score)

In [None]:
# Later training stage: batches of size 1 (so images may differ in size), with
# gradients accumulated over `v_bs` consecutive batches before each optimizer
# step. Compared with stepping after every single image, this reduces gradient
# noise.

learning_rate = 5e-4

v_bs = 3        # batches accumulated per optimizer step ("virtual" batch size)
log_every = 50  # print and plot every `log_every` optimizer steps
optimizer = optim.SGD([
                {'params': net.vgg.parameters(), 'lr': learning_rate*0.1},
                {'params': net.atr.parameters(), 'lr': learning_rate}
            ], momentum=0.91, weight_decay=0.0006)

for epoch in range(2000):
    iterations = 0
    num_steps = 0
    epoch_loss_sum = 0.0
    virtual_batch_loss = 0.0
    net.train()
    # Start each epoch from zero gradients. A partial virtual batch left over
    # at the end of the previous epoch is discarded.
    optimizer.zero_grad()

    for i, data in enumerate(Train_Loader, 0):
        src_img, seg_img, seg_img_ds, src_img_raw = data

        # torch.autograd.Variable is a no-op wrapper since PyTorch 0.4.
        Input = src_img.float().cuda()
        Target = seg_img_ds.long().cuda()

        Output = net(Input)

        # Weighted cross entropy: class c gets weight (N - n_c) / n_c, where N
        # is the number of target pixels and n_c the softmax mass predicted
        # for c, so rarely predicted classes weigh more. The clamp keeps a
        # weight finite if a class's softmax mass underflows to 0.
        ps = nn.functional.softmax(Output, dim=1).sum(dim=(0, 2, 3)).clamp(min=1e-6)
        weights = (Target.numel() - ps) / ps
        loss = nn.functional.cross_entropy(Output, Target, weight=weights.detach())

        # Backpropagate each loss immediately. Gradients add up in .grad, which
        # equals calling backward() on the sum of the v_bs losses, but each
        # computation graph is freed right away instead of keeping v_bs alive.
        loss.backward()

        batch_loss = loss.item()
        virtual_batch_loss += batch_loss
        epoch_loss_sum += batch_loss
        # Mean loss over this epoch's batches so far.
        running_loss = epoch_loss_sum / (i + 1)
        iterations += bs

        # Step after every v_bs batches. The previous test,
        # `iterations % v_bs == v_bs - 1`, made the first virtual batch of each
        # epoch one batch short and never fired when bs was a multiple of v_bs.
        if (i + 1) % v_bs == 0:
            optimizer.step()
            optimizer.zero_grad()
            num_steps += 1

            if num_steps % log_every == 0:
                print('Epoch = %d, Iteration = %d, loss= %.3f, running loss= %.3f'
                      % (epoch + 1, iterations, virtual_batch_loss / v_bs, running_loss))
                out = nn.functional.softmax(Output[-1].detach(), dim=0).cpu().numpy()
                I = np.argmax(out, axis=0)

                # Input image | downsampled ground truth | predicted labels
                fig = plt.figure()
                fig.add_subplot(1, 3, 1)
                plt.imshow(src_img_raw[-1].numpy())
                fig.add_subplot(1, 3, 2)
                plt.imshow(seg_img_ds[-1].numpy())
                plt.axis('off')
                fig.add_subplot(1, 3, 3)
                plt.imshow(I)
                plt.axis('off')
                plt.show()
            virtual_batch_loss = 0.0

    net.eval()
    # Checkpoint Network and Optimiser after each epoch
    best_score, score = Validate_IOU(net, optimizer, epoch, losses, Val_Score, bs,
                                     learning_rate, VAL_Loader, best_score, net_file)
    print("Current validation score: ", score)
    # Record the epoch's mean training loss rather than only its last batch.
    losses.append(running_loss)
    Val_Score.append(score)

In [None]:
# Visualize how well it works on the validation set: input image, downsampled
# ground truth and predicted labels, side by side for each image.
max_val_examples = None  # set to an int to show only the first N images

net.eval()
# Inference only: no autograd graph is needed.
with torch.no_grad():
    for i, data in enumerate(VAL_Loader, 0):
        if max_val_examples is not None and i >= max_val_examples:
            break
        src_img, seg_img, seg_img_ds, src_img_raw = data

        Output = net(src_img.float().cuda())
        out = nn.functional.softmax(Output[-1], dim=0).cpu().numpy()
        I = np.argmax(out, axis=0)

        fig = plt.figure()
        fig.add_subplot(1, 3, 1)
        # Show the raw image, as the training cells do; the normalised
        # `src_img` has values outside [0, 1] and renders with wrong colours.
        plt.imshow(src_img_raw[-1].numpy())
        fig.add_subplot(1, 3, 2)
        plt.imshow(seg_img_ds[-1].numpy())
        plt.axis('off')
        fig.add_subplot(1, 3, 3)
        plt.imshow(I)
        plt.axis('off')
        plt.show()

In [None]:
# Code to test some visualization of the CRF: run the network on one training
# image, refine its class probabilities with include.crf.applyDenseCRF, and
# compare per-class maps and label maps before and after the CRF.

# Evaluation mode: batch norm uses (and does not update) its running
# statistics. The cells above may have left the network in training mode.
net.eval()

test_input, test_gt, test_gt_ds, test_input_raw = next(iter(Train_Loader))
test_input_raw = test_input_raw[0].clone()
test_gt = test_gt[0]
test_gt_ds = test_gt_ds[0]
test_input = test_input.cuda()

# Bilinear upsampling of the network output to the input resolution.
m = nn.Upsample(size=(test_input.shape[2], test_input.shape[3]), mode='bilinear', align_corners=True)
with torch.no_grad():  # inference only
    test_out = nn.functional.softmax(net(test_input)[0], dim=0)
    test_out_upsampled = m(test_out.unsqueeze(0)).cpu()
_, test_out_labels = test_out_upsampled[0].max(dim=0)

# Per-class probability maps at input resolution (class axis first).
unary_ori = test_out_upsampled[0].numpy()
# Unary potentials are negative log-probabilities, one row per class.
# Clipping keeps -log finite where a probability underflows to 0.
unary = -np.log(np.clip(unary_ori.reshape((number_of_classes, -1)), 1e-8, None))
image = (test_input_raw * 255).byte().numpy()
# NOTE(review): the numeric arguments are CRF hyper-parameters whose meaning
# is defined by include/crf.applyDenseCRF — confirm there before tuning.
after_crf = crf.applyDenseCRF(unary, image, 70, 5, 3, 5, 3)
after_crf = np.array(after_crf).reshape(number_of_classes, image.shape[0], image.shape[1])
unary = unary.reshape(number_of_classes, image.shape[0], image.shape[1])

# For each class: network probability map, then the map after the CRF.
for part in range(number_of_classes):
    print("Part ", part)
    plt.imshow(unary_ori[part])
    plt.show()
    plt.imshow(after_crf[part])
    plt.show()

# Ground truth, network labels, CRF labels.
plt.imshow(test_gt.numpy())
plt.show()
plt.imshow(np.argmax(unary_ori, axis=0))
plt.show()
plt.imshow(np.argmax(after_crf, axis=0))
plt.show()