In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.optim
import torch.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision

import torchvision.transforms as transforms
import torch.optim as optim

import torch.nn as nn
from torch.autograd import Variable

from include.data import ObjectSegmentationDataset
from include.hybrid_net_batchnorm import SegmentationModel

from include.Utility_Functions import Validate_IOU
from torch.nn.modules.upsampling import Upsample

from torch.utils.data.sampler import SubsetRandomSampler


from include.Utility_Functions import IOU
from include import crf
import time


In [None]:
# Dataset used for CRF tuning and evaluation.

# Input images must be normalised the same way the network saw them during
# training (torchvision convention:
# https://pytorch.org/docs/stable/torchvision/models.html).
# NOTE(review): these per-channel RGB mean/std values are not the ImageNet
# defaults listed on that page; presumably they were computed on this project's
# training images -- confirm against the training notebook.
norm_transform = transforms.Normalize(
    mean=[0.45752926384356873, 0.4377081874543038, 0.40432555437277296],
    std=[0.2664644516691583, 0.2634024345463397, 0.2775109722016356])
transform = transforms.Compose([transforms.ToTensor(), norm_transform])
number_of_classes = 7

# Root of the VOCdevkit tree. Set the VOC_ROOT environment variable to point at
# a local copy instead of editing this machine-specific default.
VOC_ROOT = os.environ.get("VOC_ROOT", "/home/snajder/AML/data/VOCdevkit")

# Source images come from VOC2012; part annotations from the VOC2010
# person_trainval part-annotation images.
crf_dataset = ObjectSegmentationDataset(
    src_image_dir=os.path.join(VOC_ROOT, "VOC2012", "JPEGImages"),
    seg_image_dir=os.path.join(VOC_ROOT, "VOC2010", "person_trainval", "Annotations_Part_images"),
    num_classes=number_of_classes,
    transform=transform,
    gt_one_hot=False,
    augment=False,
    rescale=False)


In [None]:
# Load the trained segmentation network and switch it to inference mode.
# map_location="cpu" makes loading independent of the device the checkpoint was
# saved from; the model is moved to the GPU afterwards.
net_file = "Segnet_Best_params_VOC_Person_7class_09-16.pth"
checkpoint = torch.load(net_file, map_location="cpu")
net = SegmentationModel(num_classes=number_of_classes)
net.load_state_dict(checkpoint['state_dict'])
net = net.cuda()
net.eval();  # trailing ';' suppresses the module repr in the cell output

In [None]:
# Split the dataset indices with a fixed seed so the split is reproducible,
# then build the CRF-tuning and test loaders.
bs = 1

dataset_len = len(crf_dataset)
indices = list(range(dataset_len))
num_crf_train = 100
num_validation = 500
num_test = 250

np.random.seed(43)
validation_idx = np.random.choice(indices, size=num_validation, replace=False)
train_idx = list(set(indices) - set(validation_idx))
# NOTE(review): test images are drawn from the non-validation pool. If the
# network was trained on that pool, the test set overlaps its training data --
# confirm against the training split.
test_idx = np.random.choice(train_idx, size=num_test, replace=False)
# Remove the test images so train_idx and test_idx are disjoint.
train_idx = list(set(train_idx) - set(test_idx))
# Images used to tune the CRF (Train_Loader) are a subset of the validation images.
crf_train_idx = np.random.choice(validation_idx, size=num_crf_train, replace=False)
# Re-seed from the clock so later random draws are not tied to the split seed.
np.random.seed(int(time.time()))

Train_Loader = DataLoader(crf_dataset, batch_size=bs, num_workers=4, sampler=SubsetRandomSampler(crf_train_idx))
Test_Loader = DataLoader(crf_dataset, batch_size=bs, num_workers=4, sampler=SubsetRandomSampler(test_idx))
print(dataset_len)

In [None]:
# Grid-search the dense-CRF parameters on the CRF-tuning images; the resulting
# grid is displayed as the cell output.
total_improvement = crf.gridSearchCRFParameters(
    net, Train_Loader, norm_transform, number_of_classes)
total_improvement

In [None]:
# Locate the grid cell with the largest improvement and compare it with the
# grid cell this notebook labels as the published parameters.
flat_best = np.argmax(total_improvement)
best_params = np.unravel_index(flat_best, total_improvement.shape)
print(best_params)
published_params = (2, 5, 2)
print("Best improvement", total_improvement[best_params])
print("Improvement with published parameters", total_improvement[published_params])

In [None]:
def _accumulate_confusion(gt, pred, num_classes, tp, fp, fn):
    """Add one image's per-class pixel counts to running TP/FP/FN totals.

    gt, pred: label maps of the same shape (integer class ids).
    tp, fp, fn: arrays of length num_classes, updated in place.
    """
    for n in range(num_classes):
        in_gt = gt == n
        in_pred = pred == n
        tp[n] += np.sum(in_pred & in_gt)
        fp[n] += np.sum(in_pred & ~in_gt)
        fn[n] += np.sum(~in_pred & in_gt)


def _class_iou(tp, fp, fn):
    """Per-class IoU; a class with no GT and no predicted pixels so far is NaN."""
    with np.errstate(divide="ignore", invalid="ignore"):
        return tp / (tp + fp + fn)


# Evaluate per-class IoU of the network alone and after dense-CRF refinement
# on the test images.
val_set_size = 0
TP = np.zeros(number_of_classes)
FP = np.zeros(number_of_classes)
FN = np.zeros(number_of_classes)
TP_afterCRF = np.zeros(number_of_classes)
FP_afterCRF = np.zeros(number_of_classes)
FN_afterCRF = np.zeros(number_of_classes)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for src_img, seg_img, seg_img_ds, src_img_raw in Test_Loader:
        val_set_size += src_img.shape[0]

        Input = src_img.float().cuda()
        Output = net(Input)
        # Upsample logits to the input resolution, then softmax over classes.
        Output_upsampled = nn.functional.interpolate(
            Output, size=(src_img.shape[2], src_img.shape[3]),
            mode='bilinear', align_corners=True)
        out = nn.functional.softmax(Output_upsampled, dim=1).cpu().numpy()
        Result = np.argmax(out, axis=1)
        # Ground truth stays on the CPU; it is only compared with numpy arrays.
        gt_batch = seg_img.long().numpy()

        for i in range(Output.shape[0]):
            GT = gt_batch[i]
            _accumulate_confusion(GT, Result[i], number_of_classes, TP, FP, FN)

            # Apply the CRF with -log(probability) as the unary term. Clipping to
            # the smallest positive float keeps pixels whose softmax underflowed
            # to 0 finite instead of +inf.
            probs = np.clip(out[i], np.finfo(out.dtype).tiny, 1.0)
            unary = -np.log(probs.reshape((number_of_classes, -1)))
            image = (src_img_raw[i] * 255).byte().cpu().numpy()
            # NOTE(review): positional CRF parameters; their meaning is defined in
            # include.crf.applyDenseCRF. The single-image cell later in the notebook
            # uses 70,5,3,5,3 instead.
            after_crf = crf.applyDenseCRF(unary, image, 70, 5, 3, 3, 3)
            after_crf = np.array(after_crf).reshape(number_of_classes, image.shape[0], image.shape[1])
            Result_after_crf = np.argmax(after_crf, axis=0)
            _accumulate_confusion(GT, Result_after_crf, number_of_classes,
                                  TP_afterCRF, FP_afterCRF, FN_afterCRF)

            # Named iou_* so the imported IOU helper is not shadowed.
            iou_dcnn = _class_iou(TP, FP, FN)
            iou_crf = _class_iou(TP_afterCRF, FP_afterCRF, FN_afterCRF)

            # nanmean: classes not seen yet do not turn the mean into NaN.
            print("Image %d, mean IOU w/o CRF: %f, mean IOU with CRF: %f" % (val_set_size, np.nanmean(iou_dcnn), np.nanmean(iou_crf)))
            print("DCNN IOUS: ", iou_dcnn)

In [None]:
# (index, batch) iterator over the test loader; the next cell pulls one batch
# from it with next(en).
en = enumerate(iter(Test_Loader), start=0)

In [None]:

# Run the network on the next test batch and keep the first image's data for
# the visualisation cells below.
test_input, test_gt, test_gt_ds, test_input_raw = next(en)[1]
test_input_raw = test_input_raw[0].clone()
test_gt = test_gt[0]
test_gt_ds = test_gt_ds[0]
test_input = test_input.float().cuda()
m = nn.Upsample(size=(test_input.shape[2], test_input.shape[3]), mode='bilinear', align_corners=True)
# Inference only: without no_grad the autograd graph (and its GPU activations)
# would stay referenced by test_out / test_out_upsampled between cells.
with torch.no_grad():
    # Softmax over the class dimension of the first image, then upsample the
    # probabilities to the input resolution.
    # NOTE(review): the evaluation loop upsamples logits before the softmax;
    # here the order is reversed -- confirm this is intended.
    test_out = nn.functional.softmax(net(test_input)[0], dim=0)
    test_out_upsampled = m(test_out.unsqueeze(0)).cpu()
_, test_out_labels = test_out_upsampled[0].max(dim=0)

In [None]:
# Show the raw (unnormalised) test image.
plt.imshow(test_input_raw)
plt.title("Test image (raw input)")
plt.show()

In [None]:
# Apply the dense CRF to the single test image, using -log(probability) of the
# upsampled network output as the unary term.
unary_ori = test_out_upsampled[0].cpu().detach().numpy()
# Clip to the smallest positive float so pixels whose probability underflowed
# to 0 get a large finite energy instead of +inf.
unary = -np.log(np.clip(unary_ori, np.finfo(unary_ori.dtype).tiny, 1.0))
unary = unary.reshape((number_of_classes, -1))
image = (test_input_raw * 255).byte().cpu().detach().numpy()
# NOTE(review): these CRF parameters (70,5,3,5,3) differ from the evaluation
# loop (70,5,3,3,3) -- confirm which set is intended.
after_crf = crf.applyDenseCRF(unary, image, 70, 5, 3, 5, 3)
after_crf = np.array(after_crf).reshape(number_of_classes, image.shape[0], image.shape[1])
unary = unary.reshape(number_of_classes, image.shape[0], image.shape[1])



In [None]:
# Per-class maps from the network softmax and after the CRF.
for part in range(number_of_classes):
    print("Part ",part)
    plt.imshow(unary_ori[part])
    plt.title("Part %d: network probability" % part)
    plt.show()
    plt.imshow(after_crf[part])
    plt.title("Part %d: after CRF" % part)
    plt.show()

# Ground truth and label maps (argmax over classes) before and after the CRF.
label_maps = [
    (test_gt.numpy(), "Ground truth"),
    (np.argmax(unary_ori, axis=0), "Network prediction"),
    (np.argmax(after_crf, axis=0), "Prediction after CRF"),
]
for label_map, title in label_maps:
    plt.imshow(label_map)
    plt.title(title)
    plt.show()