In [None]:
import cv2
import numpy as np

In [None]:
%matplotlib inline
import cv2
import numpy as np
import matplotlib.pyplot as plt
import chainercv
from chainercv.datasets import VOCSemanticSegmentationDataset
from chainer.datasets import TransformDataset
from chainercv.evaluations import eval_semantic_segmentation
import os.path
from os import path
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import random
import glob
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
import copy
use_cuda = torch.cuda.is_available()
device =  torch.device("cuda" if use_cuda else "cpu")
from PIL import Image
import torch.utils.data as data
from torch.utils.data import DataLoader

In [None]:
from vslam.semantic import Net
# Show which device (cuda or cpu) models and tensors will be moved to.
print(device)

## Create Dataset

In [None]:
class DataLoaderSegmentation(data.Dataset):
    """TAS500 segmentation dataset with random crop / flip augmentation.

    Images are read from ``<root>/data/tas500v1.1/tas500v1.1/<split>``; for the
    ``train`` and ``val`` splits, label masks are read from the sibling
    ``<split>_labels_ids`` directory.

    :param split: one of ``'train'``, ``'val'`` or ``'test'``.
    :param root: directory that contains ``data/``; defaults to the current
        working directory at construction time (the original behaviour).
    :raises ValueError: if ``split`` is not one of the three names above.
    """

    def __init__(self, split, root=None):
        super().__init__()

        if split not in ['train', 'val', 'test']:
            raise ValueError('please pick split from \'train\', \'test\', or \'val\'')

        dir_path = os.getcwd() if root is None else root
        dataset_root = os.path.join(dir_path, 'data', 'tas500v1.1', 'tas500v1.1')
        self.split = split
        self.image_paths = os.path.join(dataset_root, split)
        if self.split != 'test':
            self.label_paths = os.path.join(dataset_root, split + '_labels_ids')
            # os.listdir returns entries in arbitrary order; both listings are
            # sorted so that files[i] and labels[i] refer to the same frame.
            # NOTE(review): assumes image and label files have matching,
            # identically-sorting names — confirm for the dataset on disk.
            self.labels = sorted(os.listdir(self.label_paths))
        self.files = sorted(os.listdir(self.image_paths))

    def __len__(self):
        """Number of image files in the split directory."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for train/val, or only ``image`` for test.

        ``image`` is a float tensor (3, 512, 512) of 0-255 pixel values in BGR
        channel order. ``label`` is a long tensor whose values are remapped to
        {0, 1, 2}: raw ids 6 and 7 -> 0, id 20 -> 1, every other id -> 2.
        """
        image_name = self.files[idx]

        image = Image.open(os.path.join(self.image_paths, image_name)).convert('RGB')

        if self.split != 'test':
            label_name = self.labels[idx]
            mask = Image.open(os.path.join(self.label_paths, label_name))
            x, y = self.transformData(image, mask)
            # Collapse the raw label ids into three classes (see docstring).
            y = np.where((y != 6) & (y != 7) & (y != 20), 2, y)
            y = np.where((y == 6) | (y == 7), 0, y)
            y = np.where(y == 20, 1, y)
            x, y = torch.from_numpy(x).float(), torch.from_numpy(y).long()
            x = torch.permute(x, (2, 0, 1))  # HWC -> CHW
            return x, y
        else:
            x, y = self.transformData(image, None)
            x = torch.from_numpy(x).float()
            x = torch.permute(x, (2, 0, 1))  # HWC -> CHW
            return x

    def transformData(self, image, mask=None):
        """Apply the same random 512x512 crop and random h/v flips to image and mask.

        NOTE(review): this runs for every split, including val and test, so
        evaluation samples are also randomly cropped and flipped.

        :returns: ``(image, mask)`` as numpy arrays; ``image`` has its channel
            axis reversed (RGB -> BGR); ``mask`` is None when none was given.
        """
        # Random crop (same window for image and mask)
        i, j, h, w = transforms.RandomCrop.get_params(
            image, output_size=(512,512))
        image = TF.crop(image, i, j, h, w)
        if mask is not None:
            mask = TF.crop(mask, i, j, h, w)

        # Random horizontal flip
        if random.random() > 0.5:
            image = TF.hflip(image)
            if mask is not None:
                mask = TF.hflip(mask)

        # Random vertical flip
        if random.random() > 0.5:
            image = TF.vflip(image)
            if mask is not None:
                mask = TF.vflip(mask)

        image = np.array(image)
        if mask is not None:
            mask = np.array(mask)

        # RGB -> BGR; copy() so the result is contiguous for torch.from_numpy.
        image = image[:, :, ::-1].copy()

        return image, mask

In [None]:
# Training split; each lookup returns a randomly cropped/flipped (image, label) pair.
train_data = DataLoaderSegmentation('train')

In [None]:
def find_image(dataset=None, target_label=1, max_count=3, show=True):
    """Plot the first samples whose label map contains ``target_label``.

    ``DataLoaderSegmentation.__getitem__`` remaps raw label id 20 to 1, so the
    raw ``person_label = 20`` the original searched for can never appear. The
    default ``target_label=1`` is that same class after remapping.

    :param dataset: indexable of ``(image, label)`` pairs; defaults to the
        notebook-level ``train_data``.
    :param target_label: class index to look for in each label map.
    :param max_count: stop after this many matches.
    :param show: if False, only collect indices and do not plot.
    :returns: list of matching sample indices (at most ``max_count``).
    """
    if dataset is None:
        dataset = train_data

    found = []
    for i in range(len(dataset)):
        if len(found) >= max_count:
            break
        # Fetch each sample only once: lookups apply a random crop/flip, so a
        # second lookup may return a view that no longer contains the class.
        img, label = dataset[i]
        if target_label not in np.unique(np.asarray(label)):
            continue
        found.append(i)
        if show:
            fig = plt.figure(figsize=(4, 3))
            ax = fig.add_subplot(1, 1, 1)
            ax.set_title('Image ' + str(i))
            # CHW -> HWC for imshow (rolling axis 0 to position 0 was a no-op).
            ax.imshow(np.rollaxis(np.asarray(img).astype(int), 0, 3))
    return found

# find_image()

# Colour map for class indices 0..20: one (R, G, B) triple per index, flattened
# into the 1-D list that PIL's Image.putpalette expects.
palette = [channel for rgb in [
    (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128),
    (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128),
    (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
] for channel in rgb]

# Some relevant images
# NOTE(review): every train_data[i] lookup draws a new random 512x512 crop and
# random flips (DataLoaderSegmentation.transformData), so these samples change
# on each run and may not contain the object they were picked for.
img_113, seg_img_113 = train_data[113]
img_28, seg_img_28 = train_data[28]
img_354, seg_img_354 = train_data[354]

def colorize_mask(mask):
    """Render a 2-D class-index mask as a paletted PIL image.

    :param mask: torch tensor (CPU or GPU) or numpy array of class indices.
        Values are cast to uint8, so values outside 0-255 wrap around.
    :returns: ``PIL.Image`` in mode ``'P'`` coloured with the module-level ``palette``.
    """
    # Accept tensors on any device as well as plain arrays; the original
    # required a CPU tensor because it called mask.numpy() directly.
    if torch.is_tensor(mask):
        mask = mask.detach().cpu().numpy()
    new_mask = Image.fromarray(np.asarray(mask).astype(np.uint8)).convert('P')
    new_mask.putpalette(palette)
    return new_mask

def add_img_plot(fig, index, img, title, sub_plot_id):
    """Add one image subplot to ``fig`` titled ``title + str(index)``.

    :param fig: matplotlib Figure to draw into.
    :param index: value appended to ``title`` (usually a sample index).
    :param img: anything ``Axes.imshow`` accepts (array or PIL image).
    :param title: title prefix.
    :param sub_plot_id: ``(nrows, ncols, position)`` passed to ``fig.add_subplot``.
    :returns: the created Axes.
    """
    ax = fig.add_subplot(sub_plot_id[0], sub_plot_id[1], sub_plot_id[2])
    # Title the new Axes directly. plt.title targets pyplot's *current* figure,
    # which is not necessarily ``fig``.
    ax.set_title(title + str(index))
    ax.imshow(img)
    return ax

    
fig = plt.figure(figsize=(10,10))

# One row per sample: the raw image (CHW -> HWC) next to its colorized ground truth.
# 113: person image, 28: car image, 354: animal image.
for row, (index, img, seg_img) in enumerate([(113, img_113, seg_img_113),
                                              (28, img_28, seg_img_28),
                                              (354, img_354, seg_img_354)]):
    add_img_plot(fig, index, np.rollaxis(img.numpy().astype(int), 0, 3), 'Image ', [3, 2, 2 * row + 1])
    add_img_plot(fig, index, colorize_mask(seg_img), 'Segmented Image ', [3, 2, 2 * row + 2])

In [None]:
# Compare the device *type*. A torch.device does not compare equal to the plain
# string 'cpu', so the original `device == 'cpu'` always took the else branch.
if device.type == 'cpu':
    train_loader = DataLoader(train_data, batch_size=5, shuffle=True, num_workers=0)
else:
    train_loader = DataLoader(train_data, batch_size=4, shuffle=True, num_workers=0)

In [None]:
torch.cuda.empty_cache()
from vslam.semantic import Net
untrained_net = Net().to(device)
untrained_net.eval()

# Use ONE sample for the network input, the ground truth and the plot. Every
# train_data[i] lookup draws a new random crop/flip, so img_113/seg_img_113 from
# an earlier cell are a different view than the one fed to the network.
sample_img, sample_target = train_data[113]

with torch.no_grad():  # inference only: no autograd graph needed
    untrained_output = untrained_net(sample_img[None].to(device))
untrained_nn_seg_img_113 = torch.argmax(untrained_output.cpu(), dim=1).numpy()[0]

fig = plt.figure(figsize=(8,2))
# CHW float (0-255) -> HWC int for imshow. ToPILImage would multiply float
# tensors by 255 and overflow.
add_img_plot(fig, 113, np.rollaxis(sample_img.numpy().astype(int), 0, 3), 'Image ', [1,3,1])
add_img_plot(fig, 113, colorize_mask(sample_target), 'GT Segmented Image ', [1,3,2])
# colorize_mask_nn_output is only defined in a later cell, so wrap the numpy
# prediction in a tensor and reuse colorize_mask here.
add_img_plot(fig, 113, colorize_mask(torch.from_numpy(untrained_nn_seg_img_113)), 'NN Segmented Image', [1,3,3])

In [None]:
# Training hyper-parameters, read as globals by train_mobilenet_v2.
checkpoint = None     # path of a checkpoint to resume from, or None to start fresh
batch_size = 3        # images per SGD step
iterations = 170      # requested SGD steps; converted to whole epochs during training
weight_decay = 5e-4   # SGD weight decay
momentum = 0.5        # SGD momentum
# Optional speed-up (left disabled): import torch.backends.cudnn and set cudnn.benchmark = True

def adjust_learning_rate(optimizer, lr_to):
    """Scale every parameter group's learning rate in place.

    :param optimizer: optimizer exposing ``param_groups`` (list of dicts with ``'lr'``).
    :param lr_to: multiplicative factor applied to each group's current ``'lr'``.
    """
    for group in optimizer.param_groups:
        group['lr'] *= lr_to

def save_checkpoint(epoch, model, optimizer, scheduler, filename=None):
    """
    Save model checkpoint.

    The ``model``, ``optimizer`` and ``scheduler`` objects themselves (not their
    state dicts) are stored via ``torch.save``. Loading them back therefore
    unpickles those objects.

    :param epoch: epoch number
    :param model: model
    :param optimizer: optimizer
    :param scheduler: learning-rate scheduler, or None when no scheduler is used
    :param filename: output path; when None it is chosen from ``scheduler``
        exactly as before
    """
    state = {'epoch': epoch,
             'model': model,
             'optimizer': optimizer,
             'scheduler': scheduler}
    if filename is None:
        if scheduler is None:
            filename = 'checkpoint_test_jack.pth.tar'
        else:
            filename = 'checkpoint_mobilenet_v2_scheduler.pth.tar'
    torch.save(state, filename)

def train_mobilenet_v2(lr_type):
    """
    Training.

    Builds a fresh ``Net`` (or resumes from the path held in the global
    ``checkpoint``) and trains it with SGD on the ``train`` split. After every
    epoch it saves a checkpoint and redraws the loss curve.

    Reads the notebook globals ``batch_size``, ``iterations``, ``weight_decay``,
    ``momentum`` and ``checkpoint``.

    :param lr_type: ``'original_scheduler'`` uses lr 0.1, multiplied by 0.1 at
        the epochs corresponding to iterations 45000 and 55000.
        ``'pytorch_scheduler'`` is not implemented.
    :returns: the trained model.
    :raises NotImplementedError: for any ``lr_type`` other than ``'original_scheduler'``.
    """
    # Kept module-level (as before) so they can be inspected after a run.
    global start_epoch, epoch

    if lr_type == 'original_scheduler':
        lr = 1e-1
        decay_lr_at = [45000, 55000] # decay learning rate after this many iterations
        decay_lr_to = 0.1
    else:
        # 'pytorch_scheduler' was only stubbed: every code path for it raised
        # NotImplementedError. Fail before loading data or building a model.
        raise NotImplementedError

    train_dataset = DataLoaderSegmentation('train')
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # Full batches per epoch (a partial last batch is not counted, as before).
    # At least 1, so a dataset smaller than one batch cannot divide by zero.
    steps_per_epoch = max(1, len(train_dataset) // batch_size)
    epochs = iterations // steps_per_epoch
    if iterations > 0 and epochs == 0:
        # Requesting fewer iterations than one epoch used to train nothing at all.
        epochs = 1

    decay_lr_at = [it // steps_per_epoch for it in decay_lr_at]
    print("Epochs to decay learning rate:", decay_lr_at)

    if checkpoint is None:
        start_epoch = 0
        model = Net().to(device)

        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=lr,
                                    weight_decay=weight_decay,
                                    momentum=momentum,
                                    nesterov=False)
    else:
        # Load into a local. The original rebound the global `checkpoint` (a path)
        # to the loaded dict, so a second call did torch.load(<dict>) and failed.
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which
        # rejects these pickled model/optimizer objects; pass weights_only=False
        # (trusted files only) if loading fails.
        state = torch.load(checkpoint)
        start_epoch = state['epoch'] + 1
        print('\nLoaded checkpoint from epoch %d.\n' % start_epoch)
        model = state['model']
        optimizer = state['optimizer']

    # NOTE(review): presumably Net.forward uses self.criterion when called with
    # gts=... (see train_model). Labels from DataLoaderSegmentation are 0-2, so
    # ignore_index=-1 never applies.
    model.criterion = nn.CrossEntropyLoss(ignore_index=-1)

    loss_graph = []

    fig = plt.figure(figsize=(12,6))
    fig.subplots_adjust(bottom=0.2,right=0.85,top=0.95)
    ax = fig.add_subplot(1,1,1)

    for epoch in range(start_epoch, epochs):

        if epoch in decay_lr_at:
            adjust_learning_rate(optimizer, decay_lr_to)
            for g in optimizer.param_groups:
                print("Optimizer learning rate changed to {}".format(g['lr']))

        loss = train_model(train_loader=train_loader, model=model, loss_graph=loss_graph,
                           epoch=epoch, device=device, optimizer=optimizer)

        # No torch scheduler is used, so the scheduler-less checkpoint file is written.
        save_checkpoint(epoch, model, optimizer, scheduler=None)

        ax.clear()
        ax.set_xlabel('iterations')
        ax.set_ylabel('loss value')
        ax.set_title('Training loss curve for trained  net')
        ax.plot(loss_graph, label='training loss')
        ax.legend(loc='upper right')
        fig.canvas.draw()
        print("Epoch: {} Loss: {}".format(epoch, loss))

    return model

def train_model(train_loader, model, optimizer, epoch, device, loss_graph):
    """Run one training epoch and return the last batch's loss.

    :param train_loader: iterable of ``(images, targets)`` batches.
    :param model: module where ``model(images, gts=targets)`` returns a scalar loss.
    :param optimizer: optimizer over ``model``'s parameters.
    :param epoch: epoch number (not used here; kept for the call signature).
    :param device: device each batch is moved to.
    :param loss_graph: list to which each batch's loss value (float) is appended.
    :returns: the last batch's loss tensor, detached from the autograd graph.
    :raises ValueError: if ``train_loader`` yields no batches (the original
        raised UnboundLocalError here).
    """
    model.train()

    last_loss = None
    # `images`/`targets` avoid shadowing the module-level `data` alias.
    for images, targets in train_loader:
        img, gt_seg_img = images.to(device), targets.to(device)

        main_loss = model(img, gts=gt_seg_img)

        loss_value = main_loss.item()
        loss_graph.append(loss_value)
        print(loss_value)

        optimizer.zero_grad()
        main_loss.backward()
        optimizer.step()
        # Detach so the returned tensor does not keep the last graph alive.
        last_loss = main_loss.detach()

    if last_loss is None:
        raise ValueError('train_loader yielded no batches')
    return last_loss

In [None]:
# Release cached GPU memory from earlier cells, then train with the manual
# step-decay learning-rate schedule ('original_scheduler').
# NOTE(review): the trained model is not assigned to a notebook variable here;
# later cells load it back from a checkpoint file instead.
torch.cuda.empty_cache()
train_mobilenet_v2('original_scheduler')

In [None]:
# NOTE(review): no earlier cell defines `model` at notebook level (it is local to
# train_mobilenet_v2), so on a fresh kernel this line raises NameError.
model

In [None]:
# Evaluate a saved model on the validation and training splits.
# NOTE(review): `validate` is only defined in a later cell, so on a fresh kernel
# this cell raises NameError unless that cell has been run first.
# NOTE(review): train_mobilenet_v2('original_scheduler') saves to
# 'checkpoint_test_jack.pth.tar', not to the file loaded here; confirm which
# model is intended.
val_data = DataLoaderSegmentation('val')
val_loader = DataLoader(val_data, batch_size=1, shuffle=True, num_workers=0)

# Load into `saved_state`, not `checkpoint`. The global `checkpoint` is the
# resume path read by train_mobilenet_v2, and rebinding it to a dict would make
# a later training run call torch.load on that dict.
saved_state = torch.load('./checkpoint_mobilenet_v2.pth.tar')
model = saved_state['model'].to(device)

train_dataset = DataLoaderSegmentation('train')
train_loader = DataLoader(train_dataset, batch_size=3, shuffle=True, num_workers=0)

print("mIoU over the validation dataset:{}".format(validate(val_loader, model)[1]))
print("mIoU over the training dataset:{}".format(validate(train_loader, model)[1]))

In [None]:
def colorize_mask_nn_output(mask):
    """Render a numpy class-index array (e.g. an argmax prediction) as a paletted PIL image.

    :param mask: numpy array of class indices; cast to uint8.
    :returns: ``PIL.Image`` in mode ``'P'`` coloured with the module-level ``palette``.
    """
    indices = mask.astype(np.uint8)
    paletted = Image.fromarray(indices).convert('P')
    paletted.putpalette(palette)
    return paletted

In [None]:
def validate(val_loader, net):
    """Average per-image mIoU of ``net`` over every image from ``val_loader``.

    :param val_loader: iterable of ``(images, targets)`` batches.
    :param net: segmentation network. Its output is argmax-ed over dim 1, so it
        presumably returns (batch, classes, H, W) scores — confirm against Net.
    :returns: ``(val_loss, mean_miou)``. ``val_loss`` is always 0; it is kept
        for the existing tuple interface.
    :raises ValueError: if the loader yields no images (previously a
        ZeroDivisionError).
    """
    iou_arr = []
    val_loss = 0

    net.eval()

    with torch.no_grad():
        for images, targets in val_loader:
            output = net(images.to(device))
            preds = torch.argmax(output, dim=1).cpu().numpy()
            gts = targets.cpu().numpy()
            # Score every image in the batch. The original only used index 0,
            # silently skipping the rest whenever batch_size > 1.
            for pred, gt_np in zip(preds, gts):
                conf = eval_semantic_segmentation(pred[None], gt_np[None])
                iou_arr.append(conf['miou'])

    if not iou_arr:
        raise ValueError('val_loader yielded no images')
    return val_loss, (sum(iou_arr) / len(iou_arr))

In [None]:
# Switch the loaded model to inference mode before running the visualisations below.
model.eval()

def colorize_mask(mask):
    """Render a 2-D class-index mask as a paletted PIL image.

    Re-defined here so this cell also works on its own; keep it in sync with
    the earlier definition of the same name.

    :param mask: torch tensor (CPU or GPU) or numpy array of class indices.
        Values are cast to uint8, so values outside 0-255 wrap around.
    :returns: ``PIL.Image`` in mode ``'P'`` coloured with the module-level ``palette``.
    """
    # Accept tensors on any device as well as plain arrays; the original
    # required a CPU tensor because it called mask.numpy() directly.
    if torch.is_tensor(mask):
        mask = mask.detach().cpu().numpy()
    new_mask = Image.fromarray(np.asarray(mask).astype(np.uint8)).convert('P')
    new_mask.putpalette(palette)
    return new_mask

def add_img_plot(fig, index, img, title, sub_plot_id):
    """Add one image subplot to ``fig`` titled ``title + str(index)``.

    Re-defined here so this cell also works on its own.

    :param fig: matplotlib Figure to draw into.
    :param index: value appended to ``title`` (usually a sample index).
    :param img: anything ``Axes.imshow`` accepts (array or PIL image).
    :param title: title prefix.
    :param sub_plot_id: ``(nrows, ncols, position)`` passed to ``fig.add_subplot``.
    :returns: the created Axes.
    """
    ax = fig.add_subplot(sub_plot_id[0], sub_plot_id[1], sub_plot_id[2])
    # Title the new Axes directly. plt.title targets pyplot's *current* figure,
    # which is not necessarily ``fig``.
    ax.set_title(title + str(index))
    ax.imshow(img)
    return ax


def add_img_txt_plot(fig, index, img, title, sub_plot_id, txt):
    """Add an image subplot like ``add_img_plot``, plus an ``mIoU = <txt>`` overlay.

    :param fig: matplotlib Figure to draw into.
    :param index: value appended to ``title``.
    :param img: anything ``Axes.imshow`` accepts.
    :param title: title prefix.
    :param sub_plot_id: ``(nrows, ncols, position)`` passed to ``fig.add_subplot``.
    :param txt: numeric score, rendered with the ``{:_>8.6f}`` format.
    :returns: the created Axes.
    """
    ax = fig.add_subplot(sub_plot_id[0], sub_plot_id[1], sub_plot_id[2])
    # Title the new Axes directly rather than pyplot's current figure.
    ax.set_title(title + str(index))
    # White text at data coordinates (10, 25), i.e. near the image's top-left corner.
    ax.text(10, 25, 'mIoU = {:_>8.6f}'.format(txt), fontsize=20, color='white')
    ax.imshow(img)
    return ax
    
def get_nn_seg_img(net, img, mask):
    """Predict a segmentation for one image and score it against its mask.

    :param net: segmentation network; its output is argmax-ed over dim 1.
    :param img: single image tensor without a batch dimension (one is added
        here), presumably (C, H, W) as produced by DataLoaderSegmentation.
    :param mask: ground-truth class-index tensor for ``img``.
    :returns: ``(pred, miou)``: the predicted class-index numpy array and the
        chainercv mIoU of that prediction against ``mask``.
    """
    # Run on whatever device the network's parameters live on. The original
    # hard-coded .cuda() and so failed on CPU-only machines.
    param = next(net.parameters(), None)
    target_device = param.device if param is not None else torch.device('cpu')

    with torch.no_grad():  # inference only
        nn_seg_output = net(img[None].to(target_device))
    nn_seg_img = torch.argmax(nn_seg_output, dim=1).cpu().numpy()[0]

    gts = mask.cpu().numpy()

    conf = eval_semantic_segmentation(nn_seg_img[None], gts[None])

    return nn_seg_img, conf['miou']


# Variable names are historical. The plot titles show the real split and index
# of each sample (the original titles said 113/28/354 for every row).
img_113, seg_img_113 = val_data[25]
img_28, seg_img_28 = train_data[99]
img_354, seg_img_354 = train_data[100]

nn_seg_img_113, miou_113 = get_nn_seg_img(model, img_113, seg_img_113)
nn_seg_img_28, miou_28 = get_nn_seg_img(model, img_28, seg_img_28)
nn_seg_img_354, miou_354 = get_nn_seg_img(model, img_354, seg_img_354)

# Each row: (split name, dataset index, image, ground truth, prediction, mIoU).
prediction_rows = [
    ('val ', 25, img_113, seg_img_113, nn_seg_img_113, miou_113),
    ('train ', 99, img_28, seg_img_28, nn_seg_img_28, miou_28),
    ('train ', 100, img_354, seg_img_354, nn_seg_img_354, miou_354),
]

fig = plt.figure(figsize=(20,15))

for row, (split_name, index, img, gt, pred, miou) in enumerate(prediction_rows):
    add_img_plot(fig, index, np.rollaxis(img.numpy().astype(int), 0, 3), split_name + 'Image ', [3, 3, 3 * row + 1])
    add_img_plot(fig, index, colorize_mask(gt), split_name + 'GT Segmented Image ', [3, 3, 3 * row + 2])
    add_img_txt_plot(fig, index, colorize_mask_nn_output(pred), split_name + 'NN Segmented Image ', [3, 3, 3 * row + 3], miou)

In [None]:
# Unlabelled test split: each lookup returns only the (randomly cropped/flipped) image tensor.
test_data = DataLoaderSegmentation('test')

In [None]:
def get_nn_seg_img_test(net, data):
    """Predict the class-index map for one unlabelled image.

    :param net: segmentation network; its output is argmax-ed over dim 1.
    :param data: single image tensor without a batch dimension (one is added here).
    :returns: numpy array of predicted class indices for that image.
    """
    # Follow the device of the network's parameters instead of hard-coding
    # .cuda(), so this also works on CPU-only machines.
    param = next(net.parameters(), None)
    target_device = param.device if param is not None else torch.device('cpu')

    with torch.no_grad():  # inference only
        nn_seg_output = net(data[None].to(target_device))

    return torch.argmax(nn_seg_output, dim=1).cpu().numpy()[0]

# Fetch the test sample ONCE. Every test_data[i] lookup applies a new random
# crop/flip, so indexing twice displayed one view and segmented another.
test_img_9 = test_data[25]

test_nn_seg_img_9 = get_nn_seg_img_test(model, test_img_9)

fig = plt.figure(figsize=(20,15))

# Titles use the real dataset index (25); the original labelled this sample 9.
add_img_plot(fig, 25, np.rollaxis(test_img_9.numpy().astype(int), 0, 3), 'Test Image ', [3,3,1])
add_img_plot(fig, 25, colorize_mask_nn_output(test_nn_seg_img_9), 'Test NN Segmented Image ', [3,3,2])