In [1]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
import time
import os
import copy
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
from PIL import Image
from FCN8 import FCN8s
from segnet import segnet
from createDataset import MyDataset
from utils import *
from torchvision.utils import make_grid, save_image
import datetime
from sklearn.model_selection import train_test_split
import json
# from tensorboardX import SummaryWriter
import cv2

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# (Swap in torch.device('cpu') here to force CPU execution.)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Notebook-wide settings used by the cells below.
dataroot = 'data/'        # root folder that the image / label sub-folders are joined onto
batch = 16                # batch_size handed to every DataLoader
num_classes = 39          # number of label ids; sizes the networks and the per-class counters
img_size = (128, 128)     # size passed to transforms.Resize and to MyDataset

cpu


In [3]:
########### Transforms ###########
# Per-channel (mean, std) pair. NOTE(review): neither pipeline below applies it
# (there is no Normalize step) — confirm whether that is intended.
mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Network input pipeline: resize to img_size, then convert to a tensor.
# NOTE(review): `interpolation = 1` is a bare integer resampling code
# (presumably a Pillow filter constant) — confirm which filter it selects.
input_transforms = transforms.Compose(
    [
        transforms.Resize(img_size, interpolation = 1),
        transforms.ToTensor(),
    ]
)

# Conversion-only pipeline (no resizing).
to_tensor = transforms.Compose([transforms.ToTensor()])


In [4]:
########### Dataloader ###########
seg_path = 'gt_labels/'
img_path = 'leftImg8bit_orig/'

colpath = os.path.join(dataroot, img_path)   # folder holding the colour input images
segpath = os.path.join(dataroot, seg_path)   # folder holding the label (ground-truth) images

# os.listdir() returns entries in arbitrary, platform-dependent order and the
# image and label folders are listed independently. Sorting both lists makes
# the order reproducible and keeps files that share a name prefix at the same
# position in X_* and Y_*.
# NOTE(review): presumably MyDataset pairs X[i] with Y[i] — confirm in createDataset.
X_train = sorted(os.listdir(os.path.join(colpath, 'train')))
Y_train = sorted(os.listdir(os.path.join(segpath, 'train')))
X_val = sorted(os.listdir(os.path.join(colpath, 'val')))
Y_val = sorted(os.listdir(os.path.join(segpath, 'val')))
X_test = sorted(os.listdir(os.path.join(colpath, 'test')))
Y_test = sorted(os.listdir(os.path.join(segpath, 'test')))

# One dataset per split; all share the same input transform and target size.
train_dataset = MyDataset(X_train, Y_train, dataroot, in_transforms = input_transforms, size = img_size,
                          phase = 'train')
test_dataset = MyDataset(X_test, Y_test, dataroot, in_transforms = input_transforms, size = img_size,
                         phase = 'test')
val_dataset = MyDataset(X_val, Y_val, dataroot, in_transforms = input_transforms, size = img_size,
                        phase = 'val')

# Only the training loader shuffles; val/test keep the (now sorted) file order.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch, shuffle = True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch, shuffle=False)
val_dataloader = torch.utils.data.DataLoader(val_dataset,batch_size=batch,shuffle=False)

In [5]:
def imshow(inp, title=None):
    """Plot a channels-first image tensor with matplotlib.

    The tensor is moved to channels-last layout, rescaled as
    ``std * x + mean`` with the fixed per-channel constants below, clipped to
    [0, 1] and drawn. ``title`` is used as the plot title when it is not None.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    # axis 0 (channels) goes last so the 3-element mean/std broadcast per pixel
    pixels = inp.numpy().transpose((1, 2, 0))
    pixels = np.clip(channel_std * pixels + channel_mean, 0, 1)

    plt.imshow(pixels)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # brief pause so the figure is refreshed

In [6]:
# Load the class-id -> colour table. JSON object keys are always strings, so
# they are converted to int to allow lookups by numeric label id.
with open('id_to_color.txt', 'r') as f:
    id_to_color_map = {int(class_id): color for class_id, color in json.load(f).items()}

In [7]:
def convertImgToSegColMap(img, color_map=None):
    """Turn a 2-D map of class ids into a 3-channel colour image.

    Args:
        img: array of class ids with shape (H, W); a trailing singleton
            channel, i.e. (H, W, 1), is also accepted.
        color_map: optional dict mapping int class id -> 3-element colour.
            Defaults to the notebook-level ``id_to_color_map``.

    Returns:
        float64 array of shape (H, W, 3) in which every pixel holds the colour
        assigned to its class id.

    Raises:
        KeyError: if ``img`` contains an id that is missing from the colour map.
    """
    if color_map is None:
        color_map = id_to_color_map

    img = np.asarray(img)
    if img.ndim == 3 and img.shape[2] == 1:
        # drop the singleton channel so the boolean masks below are (H, W)
        img = img[:, :, 0]

    new_img = np.zeros((img.shape[0], img.shape[1], 3))
    for key in np.unique(img):
        # One vectorised assignment per class instead of one Python-level
        # assignment per pixel; int() gives a plain-int dictionary key.
        new_img[img == key] = color_map[int(key)]
    return new_img
    

In [8]:
def calculatePixelAcc(predictedImg,gt_img):
    """Return the fraction of pixels whose predicted label equals the ground truth.

    Both inputs are flattened, so their shapes may differ (for example (H, W)
    versus (H, W, 1)) as long as they hold the same number of pixels.

    Args:
        predictedImg: array of predicted class ids.
        gt_img: array of ground-truth class ids.

    Returns:
        Pixel accuracy as a float in [0, 1].

    Raises:
        ValueError: if the inputs do not contain the same number of pixels.
        ZeroDivisionError: if the inputs are empty.
    """
    pred_arr = np.asarray(predictedImg).reshape(-1)
    gt_arr = np.asarray(gt_img).reshape(-1)
    if pred_arr.shape[0] != gt_arr.shape[0]:
        raise ValueError(
            "prediction and ground truth differ in size: %d vs %d pixels"
            % (pred_arr.shape[0], gt_arr.shape[0])
        )
    # Count matches inside NumPy: the builtin sum() walks the array element by
    # element in Python and can wrap around for narrow integer dtypes.
    matches = np.count_nonzero(pred_arr == gt_arr)
    return float(matches) / pred_arr.shape[0]

In [9]:
def checkDataImbalance(root,img_dir):
    """Count, for every class id, how many label images contain that id.

    Every file in ``root/img_dir`` is read with cv2 (read flag 0) and each
    distinct value found in it bumps that id's counter by one, so the result
    counts images per class, not pixels.

    Returns a float array of length ``num_classes``.
    """
    label_dir = os.path.join(root, img_dir)
    images_per_class = np.zeros((num_classes))
    for file_name in os.listdir(label_dir):
        label_img = cv2.imread(os.path.join(root, img_dir, file_name), 0)
        present_ids = np.unique(label_img)
        # each id present in this image is counted once, however many pixels it has
        images_per_class[present_ids] += 1
    return images_per_class

In [10]:
def tic():
    """Start the stopwatch: store the current wall-clock time in a module global.

    Counterpart of ``toc()``; mimics MATLAB's tic/toc pair.
    """
    global startTime_for_tictoc
    import time as _clock
    startTime_for_tictoc = _clock.time()

def toc():
    """Print the seconds elapsed since the last ``tic()``.

    Prints a notice instead when ``tic()`` has not been called yet.
    """
    import time as _clock
    if 'startTime_for_tictoc' not in globals():
        print ("Toc: start time not set")
        return
    elapsed = _clock.time() - startTime_for_tictoc
    print ("Elapsed time is " + str(elapsed) + " seconds.")

In [14]:
# Per-class image counts over the training label folder (segpath/train).
count_labels = checkDataImbalance(root=segpath, img_dir='train')

In [15]:
# Show the per-class counts as integers; astype returns a converted copy, so
# count_labels itself is not modified by this cell.
print(count_labels.astype(int))

[4966  475 3082 4998 2007 3811    0 2931 1582    0  483 4993 4975 4317
 3447  421 2742 4981    2  339   12 5022   77    1  279 2776    7    2
 4190 1616 4438  464 2976 4262 4435  197 2230  434    0]


In [17]:
# Rebind count_labels to an integer copy (checkDataImbalance builds it with
# np.zeros, i.e. as floats). Re-running this cell is harmless.
count_labels = count_labels.astype(int)

In [22]:
# Load the class-id -> label-name table. JSON object keys are always strings,
# so they are converted to int to allow lookups by numeric label id.
with open('class_id_to_label_map.txt', 'r') as f:
    id_to_label_map = {int(class_id): name for class_id, name in json.load(f).items()}

print(id_to_label_map)

{0: 'out of roi', 1: 'sky', 2: 'animal', 3: 'car', 4: 'bus', 5: 'wall', 6: 'sidewalk', 7: 'tunnel', 8: 'curb', 9: 'train', 10: 'fallback background', 11: 'caravan', 12: 'obs-str-bar-fallback', 13: 'person', 14: 'non-drivable fallback', 15: 'road', 16: 'rectification border', 17: 'billboard', 18: 'drivable fallback', 19: 'license plate', 20: 'traffic light', 21: 'bridge', 22: 'rider', 23: 'bicycle', 24: 'truck', 25: 'building', 26: 'vehicle fallback', 27: 'trailer', 28: 'vegetation', 29: 'autorickshaw', 30: 'fence', 31: 'polegroup', 32: 'rail track', 33: 'motorcycle', 34: 'guard rail', 35: 'traffic sign', 36: 'pole', 37: 'parking', 38: 'ego vehicle'}


In [24]:
# Label name -> number of training label images containing that class.
# NOTE(review): the printed pairing looks surprising (e.g. 'road' in 421 images
# but 'bridge' in 5022) — confirm that the ids in class_id_to_label_map.txt
# match the ids stored in the gt_labels images.
occ_dict = dict()
for idx, occ in enumerate(count_labels):
    label_name = id_to_label_map[idx]
    occ_dict[label_name] = occ

print(occ_dict)

{'guard rail': 4435, 'caravan': 4993, 'vehicle fallback': 7, 'road': 421, 'pole': 2230, 'train': 0, 'autorickshaw': 1616, 'building': 2776, 'person': 4317, 'vegetation': 4190, 'tunnel': 2931, 'sky': 475, 'rectification border': 2742, 'curb': 1582, 'non-drivable fallback': 3447, 'traffic sign': 197, 'parking': 434, 'car': 4998, 'billboard': 4981, 'fence': 4438, 'truck': 279, 'out of roi': 4966, 'obs-str-bar-fallback': 4975, 'bridge': 5022, 'bus': 2007, 'sidewalk': 0, 'animal': 3082, 'rail track': 2976, 'ego vehicle': 0, 'polegroup': 464, 'bicycle': 1, 'wall': 3811, 'license plate': 339, 'fallback background': 483, 'drivable fallback': 2, 'rider': 77, 'trailer': 2, 'motorcycle': 4262, 'traffic light': 12}


In [11]:
# Checkpoints to evaluate and the matching prefixes used for the output files.
# The first two checkpoints are FCN-8s weights, the last two SegNet weights.
saved_models =['weights_fcn.pth','weighted_fcn.pth','weights_segnet_ep8.pth','weights_wt_segnet_sameOP.pth']
op_names = ['unweighted_fcn','weighted_fcn','unweighted_segnet','weighted_segnet']
root = 'demoSet'
gtDir = 'gt_labels'
col_Dir = 'gt_labels_colored'
orig_Dir = 'leftImg8bit_orig'

saved_output_dir = 'saved_outputs'
demo_output_dir = 'saved_demo_outputs/'   # every file written by this cell goes here
demo_image_index = 6                      # only this test image is run through the models
all_imgs = os.listdir(os.path.join(root,orig_Dir))
# Sorted so that demo_image_index picks the same file on every machine
# (os.listdir order is arbitrary).
all_test_imgs = sorted(os.listdir(os.path.join(colpath,'test')))

total_images = len(all_test_imgs)
# NOTE(review): these totals keep accumulating across all models and only
# cover the images actually processed below; reset them per model before
# dividing by an image count.
total_iou = 0
total_pix_acc = 0


def _mean_iou(pred_labels, gt_labels):
    """Mean intersection-over-union across the classes present in either map.

    For every class id found in the prediction or the ground truth, IoU is
    |pred == id AND gt == id| / |pred == id OR gt == id|; the per-class values
    are averaged. Both inputs are flattened and must hold the same number of
    pixels. Returns 0.0 for empty inputs.
    """
    pred_flat = np.asarray(pred_labels).reshape(-1)
    gt_flat = np.asarray(gt_labels).reshape(-1)
    per_class = []
    for class_id in np.union1d(pred_flat, gt_flat):
        pred_mask = pred_flat == class_id
        gt_mask = gt_flat == class_id
        # class_id occurs in at least one of the maps, so the union is never empty
        overlap = float(np.logical_and(pred_mask, gt_mask).sum())
        per_class.append(overlap / np.logical_or(pred_mask, gt_mask).sum())
    return float(np.mean(per_class)) if per_class else 0.0


# The output folder must exist before anything is saved into it.
if not os.path.isdir(demo_output_dir):
    os.makedirs(demo_output_dir)

#### test model accuracy #######
for idxx,saved_model in enumerate(saved_models):
    # Build the architecture that matches the checkpoint. (This test used to
    # read the unrelated variable `idx`, so the second FCN checkpoint was
    # loaded into a SegNet.)
    if idxx in (0, 1):
        model = FCN8s(num_classes)
    else:
        model = segnet(num_classes)
    # map_location lets GPU-saved weights load on a CPU-only machine.
    model.load_state_dict(torch.load(saved_model, map_location=device))
    model.eval()
    model.to(device)
    tic()
    for idx,img_ in enumerate(all_test_imgs):
        if idx != demo_image_index:
            continue

        img1 = Image.open(os.path.join(colpath, 'test' ,img_))
        img = input_transforms(img1).unsqueeze(0).to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output = model(img)

        # Drop the batch axis, then take the arg-max over the class axis to
        # get one label per pixel.
        output_labels = torch.argmax(output.squeeze(0),dim=0)
        np_img = output_labels.cpu().numpy()
        np_img = np_img.reshape(np_img.shape[0],np_img.shape[1],1)

        # Save the input image and its reference colour map next to the outputs.
        img1.save(os.path.join(demo_output_dir,'orig.png'))
        img_col = Image.open(os.path.join(dataroot,col_Dir,'test',img_))
        img_col.save(os.path.join(demo_output_dir,'col.png'))

        # Scale the prediction back to the input image size; nearest-neighbour
        # keeps the values valid class ids.
        np_img_res = cv2.resize(np_img,img1.size,interpolation = cv2.INTER_NEAREST)
        cv2.imwrite(os.path.join(demo_output_dir,'out.png'),np_img_res)

        # The label file shares the image's prefix up to the first underscore.
        gt_img_name = img_[:img_.find('_')]+str('_id_gt.png')
        gt_img_path = os.path.join(segpath,'test',gt_img_name)
        gt_img = cv2.imread(gt_img_path,0)
        if gt_img is None:
            # cv2.imread signals failure by returning None rather than raising
            raise IOError('could not read ground-truth image: ' + gt_img_path)

        # Colourised ground truth and prediction, one pair per model.
        col_gt_img = convertImgToSegColMap(gt_img)
        my_gt_name = str(op_names[idxx])+str('_col_gt.png')
        cv2.imwrite(os.path.join(demo_output_dir,my_gt_name),col_gt_img)

        col_seg_img = convertImgToSegColMap(np_img_res)
        my_seg_name = str(op_names[idxx])+str('_col_seg.png')
        cv2.imwrite(os.path.join(demo_output_dir,my_seg_name),col_seg_img)

        # Pixel accuracy at the image's own resolution.
        pix_acc = calculatePixelAcc(np_img_res,gt_img)

        # IoU at the network's output resolution: shrink the ground truth to
        # the prediction's size (cv2 expects (width, height)) and compare
        # class by class. Bitwise AND/OR on the class ids themselves, as done
        # previously, is not an intersection/union.
        gt_res_img = cv2.resize(gt_img,(np_img.shape[1],np_img.shape[0]),interpolation = cv2.INTER_NEAREST)
        iou = _mean_iou(np_img,gt_res_img)
        print('iou',iou)

        total_iou+=iou
        total_pix_acc+=pix_acc
    toc()

KeyboardInterrupt: 

In [19]:
# NOTE(review): 378 is an unexplained constant — name it or derive it from data.
# total_images is the number of files listed in the test image folder.
378/total_images

0.2380352644836272