In [1]:
# This notebook should be used inside the repository (https://github.com/VainF/DeepLabV3Plus-Pytorch)
# This repository should be downloaded and all requirements should be installed.

In [2]:
from tqdm import tqdm
import network
import utils
import os
import random
import argparse
import numpy as np
from collections import namedtuple
import pandas as pd

from torch.utils import data
from datasets import VOCSegmentation, Cityscapes
from utils import ext_transforms as et
from metrics import StreamSegMetrics

import torch
import torch.nn as nn
from torchvision import transforms

from utils.visualizer import Visualizer

from PIL import Image
import matplotlib
import matplotlib.pyplot as plt

In [3]:
!pwd

/home/ghadeer/Projects/KAMAZ/DeepLabV3Plus-Pytorch


In [4]:
# Parameters
seed = 0
num_classes = 19
output_stride = 8
batch_size = 4
crop_size = 513

In [5]:
def validate(opts, model, loader, device, metrics, ret_samples_ids=None):
    """Do validation and return specified samples"""
    metrics.reset()
    ret_samples = []
    if opts.save_val_results:
        if not os.path.exists('results'):
            os.mkdir('results')
        denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], 
                                   std=[0.229, 0.224, 0.225])
        img_id = 0

    with torch.no_grad():
        for i, (images, labels) in tqdm(enumerate(loader)):
            
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(images)
            preds = outputs.detach().max(dim=1)[1].cpu().numpy()
            targets = labels.cpu().numpy()

            metrics.update(targets, preds)
            if ret_samples_ids is not None and i in ret_samples_ids:  # get vis samples
                ret_samples.append(
                    (images[0].detach().cpu().numpy(), targets[0], preds[0]))

            if opts.save_val_results:
                for i in range(len(images)):
                    image = images[i].detach().cpu().numpy()
                    target = targets[i]
                    pred = preds[i]

                    image = (denorm(image) * 255).transpose(1, 2, 0).astype(np.uint8)
                    target = loader.dataset.decode_target(target).astype(np.uint8)
                    pred = loader.dataset.decode_target(pred).astype(np.uint8)

                    Image.fromarray(image).save('results/%d_image.png' % img_id)
                    Image.fromarray(target).save('results/%d_target.png' % img_id)
                    Image.fromarray(pred).save('results/%d_pred.png' % img_id)

                    fig = plt.figure()
                    plt.imshow(image)
                    plt.axis('off')
                    plt.imshow(pred, alpha=0.7)
                    ax = plt.gca()
                    ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    ax.yaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    plt.savefig('results/%d_overlay.png' % img_id, bbox_inches='tight', pad_inches=0)
                    plt.close()
                    img_id += 1

        score = metrics.get_results()
    return score, ret_samples

Defining the encoding for the result of the network according to the pretrained model dataset

In [6]:
CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
                                                 'has_instances', 'ignore_in_eval', 'color'])
classes = [
    CityscapesClass('unlabeled',            0, 255, 'void', 0, False, True, (0, 0, 0)),
    CityscapesClass('ego vehicle',          1, 255, 'void', 0, False, True, (0, 0, 0)),
    CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
    CityscapesClass('out of roi',           3, 255, 'void', 0, False, True, (0, 0, 0)),
    CityscapesClass('static',               4, 255, 'void', 0, False, True, (0, 0, 0)),
    CityscapesClass('dynamic',              5, 255, 'void', 0, False, True, (111, 74, 0)),
    CityscapesClass('ground',               6, 255, 'void', 0, False, True, (81, 0, 81)),
    CityscapesClass('road',                 7, 0, 'flat', 1, False, False, (128, 64, 128)),
    CityscapesClass('sidewalk',             8, 1, 'flat', 1, False, False, (244, 35, 232)),
    CityscapesClass('parking',              9, 255, 'flat', 1, False, True, (250, 170, 160)),
    CityscapesClass('rail track',           10, 255, 'flat', 1, False, True, (230, 150, 140)),
    CityscapesClass('building',             11, 2, 'construction', 2, False, False, (70, 70, 70)),
    CityscapesClass('wall',                 12, 3, 'construction', 2, False, False, (102, 102, 156)),
    CityscapesClass('fence',                13, 4, 'construction', 2, False, False, (190, 153, 153)),
    CityscapesClass('guard rail',           14, 255, 'construction', 2, False, True, (180, 165, 180)),
    CityscapesClass('bridge',               15, 255, 'construction', 2, False, True, (150, 100, 100)),
    CityscapesClass('tunnel',               16, 255, 'construction', 2, False, True, (150, 120, 90)),
    CityscapesClass('pole',                 17, 5, 'object', 3, False, False, (153, 153, 153)),
    CityscapesClass('polegroup',            18, 255, 'object', 3, False, True, (153, 153, 153)),
    CityscapesClass('traffic light',        19, 6, 'object', 3, False, False, (250, 170, 30)),
    CityscapesClass('traffic sign',         20, 7, 'object', 3, False, False, (220, 220, 0)),
    CityscapesClass('vegetation',           21, 8, 'nature', 4, False, False, (107, 142, 35)),
    CityscapesClass('terrain',              22, 9, 'nature', 4, False, False, (152, 251, 152)),
    CityscapesClass('sky',                  23, 10, 'sky', 5, False, False, (70, 130, 180)),
    CityscapesClass('person',               24, 11, 'human', 6, True, False, (220, 20, 60)),
    CityscapesClass('rider',                25, 12, 'human', 6, True, False, (255, 0, 0)),
    CityscapesClass('car',                  26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
    CityscapesClass('truck',                27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
    CityscapesClass('bus',                  28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
    CityscapesClass('caravan',              29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
    CityscapesClass('trailer',              30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
    CityscapesClass('train',                31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
    CityscapesClass('motorcycle',           32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
    CityscapesClass('bicycle',              33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
    CityscapesClass('license plate',        -1, 255, 'vehicle', 7, False, True, (0, 0, 142)),
]

train_id_to_color = [c.color for c in classes if (c.train_id != -1 and c.train_id != 255)]
train_id_to_color.append([0, 0, 0])
train_id_to_color = np.array(train_id_to_color)
id_to_train_id = np.array([c.train_id for c in classes])


def decode_result(res):
    return train_id_to_color[res]

In [7]:
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device: %s" % device)

Device: cuda


In [8]:
# Setup random seed
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [9]:
# Define the model architecture
model = network.deeplabv3plus_mobilenet(num_classes=num_classes, output_stride=output_stride)

In [10]:
 optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.001},
        {'params': model.classifier.parameters(), 'lr': 0.01},
    ], lr=0.01, momentum=0.9, weight_decay=1e-4)

In [11]:
# Load pretrained model
checkpoint = torch.load("checkpoints/best_deeplabv3plus_mobilenet_cityscapes_os16.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state"])
cur_itrs = checkpoint["cur_itrs"]
model.to(device)

DeepLabV3(
  (backbone): IntermediateLayerGetter(
    (low_level_features): Sequential(
      (0): ConvBNReLU(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      )
      (1): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNReLU(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (2): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNReLU(
            (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        

In [49]:
model.eval()

DeepLabV3(
  (backbone): IntermediateLayerGetter(
    (low_level_features): Sequential(
      (0): ConvBNReLU(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      )
      (1): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNReLU(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU6(inplace=True)
          )
          (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (2): InvertedResidual(
        (conv): Sequential(
          (0): ConvBNReLU(
            (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        

In [53]:
# Transform the Image to be the same as the training data
img_transform = transforms.Compose([
    transforms.Resize((crop_size-1, crop_size-1)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], 
                            std=[0.229, 0.224, 0.225])

In [54]:
path_to_data = "data/"
path_to_results = os.path.join(path_to_data, "results")
images = [img_name for img_name in os.listdir(path_to_data+"RGB/")]



In [55]:
# Test the pretrained model on the images
for name in images:
    img = Image.open(os.path.join(path_to_data, "RGB", name)).convert('RGB')
    img = img_transform(img).unsqueeze(0)
    img = img.to(device, dtype=torch.float32)
    outputs = model(img)
    preds = outputs.detach().max(dim=1)[1].cpu().numpy()
    pred = preds[0]
    decoded_preds = decode_result(preds).astype(np.uint8)
    Image.fromarray(decoded_preds[0]).save(os.path.join(path_to_results, name))

In [59]:
img = Image.open("data/RGB/0000051.png").convert('RGB')
# img.show()
img = img_transform(img).unsqueeze(0)
img = img.to(device, dtype=torch.float32)

In [60]:
outputs = model(img)
preds = outputs.detach().max(dim=1)[1].cpu().numpy()

In [61]:
image = img[0].detach().cpu().numpy()
image = (denorm(image) * 255).transpose(1, 2, 0).astype(np.uint8)
Image.fromarray(image).show() 

In [62]:
pred = preds[0]
decoded_preds = decode_result(preds).astype(np.uint8)
Image.fromarray(decoded_preds[0]).show()

In [24]:
print(pred[230][200:210])

[10 10 10 10 10 10 10 10 10 10]


In [25]:
train_transform = et.ExtCompose([
    #et.ExtResize( 512 ),
    et.ExtRandomCrop(size=(crop_size, crop_size)),
    et.ExtColorJitter( brightness=0.5, contrast=0.5, saturation=0.5 ),
    et.ExtRandomHorizontalFlip(),
    et.ExtToTensor(),
    et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
])

In [42]:
class CityscapesLabelEncoder:

    def __init__(self, select_classes=[]):
        # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py  
        self.labels = [
            (                   "name","id", "trainId",         "category",  "catId","hasInstances","ignoreInEval",        "color"),
            (  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
            (  'ego vehicle'          ,  1 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
            (  'rectification border' ,  2 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
            (  'out of roi'           ,  3 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
            (  'static'               ,  4 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
            (  'dynamic'              ,  5 ,      255 , 'void'            , 0       , False        , True         , (111, 74,  0) ),
            (  'ground'               ,  6 ,      255 , 'void'            , 0       , False        , True         , ( 81,  0, 81) ),
            (  'road'                 ,  7 ,        0 , 'flat'            , 1       , False        , False        , (128, 64,128) ),
            (  'sidewalk'             ,  8 ,        1 , 'flat'            , 1       , False        , False        , (244, 35,232) ),
            (  'parking'              ,  9 ,      255 , 'flat'            , 1       , False        , True         , (250,170,160) ),
            (  'rail track'           , 10 ,      255 , 'flat'            , 1       , False        , True         , (230,150,140) ),
            (  'building'             , 11 ,        2 , 'construction'    , 2       , False        , False        , ( 70, 70, 70) ),
            (  'wall'                 , 12 ,        3 , 'construction'    , 2       , False        , False        , (102,102,156) ),
            (  'fence'                , 13 ,        4 , 'construction'    , 2       , False        , False        , (190,153,153) ),
            (  'guard rail'           , 14 ,      255 , 'construction'    , 2       , False        , True         , (180,165,180) ),
            (  'bridge'               , 15 ,      255 , 'construction'    , 2       , False        , True         , (150,100,100) ),
            (  'tunnel'               , 16 ,      255 , 'construction'    , 2       , False        , True         , (150,120, 90) ),
            (  'pole'                 , 17 ,        5 , 'object'          , 3       , False        , False        , (153,153,153) ),
            (  'polegroup'            , 18 ,      255 , 'object'          , 3       , False        , True         , (153,153,153) ),
            (  'traffic light'        , 19 ,        6 , 'object'          , 3       , False        , False        , (250,170, 30) ),
            (  'traffic sign'         , 20 ,        7 , 'object'          , 3       , False        , False        , (220,220,  0) ),
            (  'vegetation'           , 21 ,        8 , 'nature'          , 4       , False        , False        , (107,142, 35) ),
            (  'terrain'              , 22 ,        9 , 'nature'          , 4       , False        , False        , (152,251,152) ),
            (  'sky'                  , 23 ,       10 , 'sky'             , 5       , False        , False        , ( 70,130,180) ),
            (  'person'               , 24 ,       11 , 'human'           , 6       , True         , False        , (220, 20, 60) ),
            (  'rider'                , 25 ,       12 , 'human'           , 6       , True         , False        , (255,  0,  0) ),
            (  'car'                  , 26 ,       13 , 'vehicle'         , 7       , True         , False        , (  0,  0,142) ),
            (  'truck'                , 27 ,       14 , 'vehicle'         , 7       , True         , False        , (  0,  0, 70) ),
            (  'bus'                  , 28 ,       15 , 'vehicle'         , 7       , True         , False        , (  0, 60,100) ),
            (  'caravan'              , 29 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0, 90) ),
            (  'trailer'              , 30 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0,110) ),
            (  'train'                , 31 ,       16 , 'vehicle'         , 7       , True         , False        , (  0, 80,100) ),
            (  'motorcycle'           , 32 ,       17 , 'vehicle'         , 7       , True         , False        , (  0,  0,230) ),
            (  'bicycle'              , 33 ,       18 , 'vehicle'         , 7       , True         , False        , (119, 11, 32) ),
            (  'license plate'        , -1 ,       -1 , 'vehicle'         , 7       , False        , True         , (  0,  0,142) ),
        ]

        # create labels dataframe
        self.cityscapes_labels_df = pd.DataFrame(self.labels[1:], columns=self.labels[0])
        self.cityscapes_labels_df.loc[self.cityscapes_labels_df["trainId"].isin([255, -1]), "trainId"] = 19
        self.categories = np.arange(self.cityscapes_labels_df["catId"].nunique())
        if select_classes:
            selected = self.cityscapes_labels_df[
                self.cityscapes_labels_df["name"].isin(select_classes)]["id"].unique()
            self.cityscapes_labels_df.loc[~self.cityscapes_labels_df["id"].isin(selected), "trainId"] = len(selected)
            for i, j in enumerate(selected):
                self.cityscapes_labels_df.loc[self.cityscapes_labels_df["id"] == j, "trainId"] = i
        self.classes = self.cityscapes_labels_df["trainId"].unique()
        self.classes.sort() # in-place labels ascending sort

    def make_ohe(self, labelIds, mode="catId"):
        """
        converts image with labels into the one-hot encoded format
        (img[...,] --> img[..., N_CLASSES])
        mode : `catId` or `trainId`
        """ 
        classes = self.categories
        if mode == "trainId":
            classes = self.classes

        if len(classes) == 2:
            classes = [0]
        
        labelIds2 = labelIds.copy()
#         for unique in np.unique(labelIds):
#             labelIds[labelIds == unique] = self.cityscapes_labels_df[self.cityscapes_labels_df["id"] == unique][mode]
#         labelIds = labelIds.astype(int)
        
        
        for unique in np.unique(labelIds):
            labelIds2[np.where((labelIds2 == unique).all(axis=2))] = self.cityscapes_labels_df[self.cityscapes_labels_df["id"] == unique][mode]
        labelIds2 = labelIds.astype(int)

        return labelIds2
    
#         ohe_labels = np.zeros(labelIds.shape[:2] + (len(classes),))
#         for c in classes:
#             ys, xs = np.where(labelIds[..., 0] == c)
#             ohe_labels[ys, xs, c] = 1
#         return ohe_labels.astype(int)

In [43]:
class Dataset_Collection(data.Dataset):
    # Add the dictionary of colors to id 
    def __init__(self, root, split='train', transform=None, select_classes=[]):
        self.root = os.path.expanduser(root)
        self.images_dir = os.path.join(self.root, 'data', split, 'RGB')
        self.targets_dir = os.path.join(self.root, 'data', split, 'GTDebug')
        self.transform = transform
        self.images = []
        self.targets = []
        self.label_encoder = CityscapesLabelEncoder(select_classes)
        
        if split not in ['train', 'test', 'val']:
            raise ValueError('Invalid split for mode! Please use split="train", split="test"'
                             ' or split="val"')
        for file_name in os.listdir(self.images_dir):
            self.images.append(os.path.join(self.images_dir,file_name))
            self.targets.append(os.path.join(self.targets_dir,file_name))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
            than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
        """
        image = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.targets[index])
        if self.transform:
            image, target = self.transform(image, target)
        target = self.encode_target(target)
        target = torch.from_numpy(np.array(target, dtype='uint8'))
        return image, target
    
    def __len__(self):
        return len(self.images)
    
    def encode_target(self, target):
        target = target.detach().cpu().numpy()
        return self.label_encoder.make_ohe(target)



In [63]:
dataset = Dataset_Collection("", transform=train_transform)
criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
loader = data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=2)

In [64]:
TL_Label =  (250,170,30, 255)
TL_Label =  (250,170,30, 255)

In [69]:
current_path = os.path.abspath(os.getcwd())

depth_path = os.path.join(current_path, "data/Depth")
depth_deb_path = os.path.join(current_path, "data/DepthDebug")
gt_deb_path = os.path.join(current_path, "data/GTDebug")
rgb_path = os.path.join(current_path, "data/RGB")
res_path = os.path.join(current_path, "data/results")

In [72]:
image_name = "0000013.png"

In [73]:
mask = Image.open(os.path.join(gt_deb_path, image_name)) 
pixels = np.array(mask)
pixels_mask = np.array([int(pixels[i][j][0]==TL_Label[0] and pixels[i][j][1]==TL_Label[1]) for i in range(len(pixels)) for j in range(len(pixels[0]))])

In [74]:
pixels_mask = np.reshape(pixels_mask,(1080,1920))
pixels_mask = [pixels_mask, pixels_mask, pixels_mask, pixels_mask]
pixels_mask = np.swapaxes(pixels_mask, 0,2)
pixels_mask = np.swapaxes(pixels_mask, 1,0)
inv_mask = 1-pixels_mask

depth = Image.open(os.path.join(depth_path, image_name)) 
depth_pixels = np.array(depth)


In [80]:
rgb = Image.open(os.path.join(rgb_path, image_name)) 
rgb_pixels = np.array(rgb)

In [84]:
new_rgb_pixels = inv_mask*rgb_pixels + pixels_mask*TL_Label
new_rgb = Image.fromarray(new_rgb_pixels.astype(np.uint8)).save(current_path+"/test_rgb.png")

In [99]:
# Test the pretrained model on the images
for name in images:
    img = Image.open(os.path.join(path_to_data, "RGB", name)).convert('RGBA')
    mask = Image.open(os.path.join(path_to_data, "results", name)).convert('RGB')
    size_ = img.size
#     print(size_)
    mask = mask.resize(size_, Image.ANTIALIAS)
#     print(mask.size)
    pixels = np.array(mask)
    pixels_mask = np.array([int(pixels[i][j][0]==TL_Label[0] and pixels[i][j][1]==TL_Label[1]) for i in range(len(pixels)) for j in range(len(pixels[0]))])
    pixels_mask = np.reshape(pixels_mask,size_)
    pixels_mask = [pixels_mask, pixels_mask, pixels_mask, pixels_mask]
    pixels_mask = np.swapaxes(pixels_mask, 0,2)
#     pixels_mask = np.swapaxes(pixels_mask, 1,0)
    inv_mask = 1-pixels_mask
    rgb_pixels = np.array(img)
    new_rgb_pixels = pixels_mask*TL_Label + inv_mask*rgb_pixels
    new_rgb = Image.fromarray(new_rgb_pixels.astype(np.uint8)).save(current_path+"/data/results_with_mask/"+name)