In [2]:
import torch
import torchvision
import numpy as np
import cv2
import misc.segm.lookup_table as lut
from collections import namedtuple
import os

base_dir = "/mnt/drive-d/anurag/roadwork/"
scene_data_dir = os.path.join(base_dir, "scene")
sem_seg_data_dir = os.path.join(scene_data_dir, "sem_seg")
device = "cuda:0"

In [3]:
class DatasetWorkzoneSemantic(torchvision.datasets.Cityscapes):

    CityscapesClass = namedtuple(
        "CityscapesClass",
        ["name", "id", "train_id", "category", "category_id", "has_instances", "ignore_in_eval", "color"],
    )

    classes = [
        CityscapesClass(  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),

        CityscapesClass(  'Road'                 ,  1 ,        1 , 'flat'            , 1       , False        , False        , ( 70, 70, 70) ),
        CityscapesClass(  'Sidewalk'             ,  2 ,        2 , 'flat'            , 1       , False        , False        , (102,102,156) ),
        CityscapesClass(  'Bike Lane'            ,  3 ,        3 , 'flat'            , 1       , False        , False        , (190,153,153) ),
        CityscapesClass(  'Off-Road'             ,  4 ,        4 , 'flat'            , 1       , False        , False        , (180,165,180) ),
        CityscapesClass(  'Roadside'             ,  5 ,        5 , 'flat'            , 1       , False        , False        , (150,100,100) ),

        CityscapesClass(  'Barrier'              ,  6 ,        6 , 'barrier'         , 2       , False        , False        , (246, 116, 185) ),
        CityscapesClass(  'Barricade'            ,  7 ,        7 , 'barrier'         , 2       , False        , False        , (248, 135, 182) ),
        CityscapesClass(  'Fence'                ,  8 ,        8 , 'barrier'         , 2       , False        , False        , (251, 172, 187) ),

        CityscapesClass(  'Police Vehicle'       ,  9 ,        9 , 'vehicle'         , 3       , True         , False        , (255, 68, 51) ),
        CityscapesClass(  'Work Vehicle'         ,  10,        10, 'vehicle'         , 3       , True         , False        , (255,104, 66) ),

        CityscapesClass(  'Police Officer'       ,  11,        11, 'human'           , 4       , True         , False        , (184, 107, 35) ),
        CityscapesClass(  'Worker'               ,  12,        12, 'human'           , 4       , True         , False        , (205, 135, 29) ),

        CityscapesClass(  'Cone'                 ,  13,        13, 'object'          , 5       , True         , False        , (30, 119, 179) ),
        CityscapesClass(  'Drum'                 ,  14,        14, 'object'          , 5       , True         , False        , (44, 79, 206) ),
        CityscapesClass(  'Vertical Panel'       ,  15,        15, 'object'          , 5       , True         , False        , (102, 81, 210) ),
        CityscapesClass(  'Tubular Marker'       ,  16,        16, 'object'          , 5       , True         , False        , (170, 118, 213) ),
        CityscapesClass(  'Work Equipment'       ,  17,        17, 'object'          , 5       , True         , False        , (214, 154, 219) ),

        CityscapesClass(  'Arrow Board'          ,  18,        18, 'guidance'        , 6       , True         , False        , (241, 71, 14) ),
        CityscapesClass(  'TTC Sign'             ,  19,        19, 'guidance'        , 6       , True         , False        , (254, 139, 32) ),
    ]

    def __init__(
            self,
            root,
            device,
            split: str = "train",
            mode: str = "fine",
            target_type = "semantic",
            transform = None,
            target_transform = None,
            transforms = None,
            small_size = False
        ) -> None:
        ## don't want to call Cityscapes.__init__
        ## instead want to call VisionDataset.__init__
        super(DatasetWorkzoneSemantic.__bases__[0], self).__init__(root, transforms, transform, target_transform)
        self.mode = "gtFine" if mode == "fine" else "gtCoarse"
        self.images_dir = os.path.join(self.root, "images", split)
        if mode == "fine":
            self.targets_dir = os.path.join(self.root, "gtFine", split)
        else:
            self.targets_dir = os.path.join(self.root, "gtCoarse", split)
        self.target_type = [ target_type ]
        self.split = split
        self.images = []
        self.targets = []

        # verify_str_arg(mode, "mode", ("fine", "coarse"))
        if mode == "fine":
            valid_modes = ("train", "val")
        else:
            valid_modes = ("train", "val")


        for file_name in os.listdir(self.images_dir):
            target_types = []
            for t in self.target_type:
                target_name = "{}{}".format(
                    os.path.splitext(file_name)[0], self._get_target_suffix(self.mode, t)
                )
                target_types.append(os.path.join(self.targets_dir, target_name))

            self.images.append(os.path.join(self.images_dir, file_name))
            self.targets.append(target_types)

        self.device = device
        # setup lookup tables for class/color conversions
        l_key_id, l_key_trainid, l_key_color = self._get_class_properties()
        ar_u_key_id = np.asarray(l_key_id, dtype = np.uint8)
        ar_u_key_trainid = np.asarray(l_key_trainid, dtype = np.uint8)
        ar_u_key_color = np.asarray(l_key_color, dtype = np.uint8)
        _, self.th_i_lut_id2trainid = lut.get_lookup_table(
            ar_u_key = ar_u_key_id,
            ar_u_val = ar_u_key_trainid,
            v_val_default = 0,  # default class is 0 - unlabeled
            device = self.device,
        )
        _, self.th_i_lut_trainid2id = lut.get_lookup_table(
            ar_u_key = ar_u_key_trainid,
            ar_u_val = ar_u_key_id,
            v_val_default = 0,  # default class is 0 - unlabeled
            device = self.device,
        )
        _, self.th_i_lut_trainid2color = lut.get_lookup_table(
            ar_u_key = ar_u_key_trainid,
            ar_u_val = ar_u_key_color,
            v_val_default = 0,  # default color is black
            device = self.device,
        )
    
    def _get_target_suffix(self, mode: str, target_type: str) -> str:
        if target_type == "instance":
            return f"_instanceIds.png"
        elif target_type == "semantic":
            return f"_labelIds.png"
        elif target_type == "color":
            return f"_color.png"
        else:
            raise ValueError(f"Unknown value '{target_type}' for argument target_type.")

    def _get_class_properties(self):
        # iterate over named tuples (nt)
        l_key_id = list()
        l_key_trainid = list()
        l_key_color = list()
        # append classes
        for nt_class in self.classes:
            if nt_class.train_id in [-1, 255]:
                continue
            l_key_id.append([nt_class.id])
            l_key_trainid.append([nt_class.train_id])
            l_key_color.append(nt_class.color)
        # append class background
        l_key_id.append([0])
        l_key_trainid.append([0])
        l_key_color.append([0, 0, 0])
        return l_key_id, l_key_trainid, l_key_color

    def __getitem__(self, index):
        # read images
        p_image = self.images[index]
        p_target = self.targets[index][0]  # 0 is index of semantic target type
        image = cv2.imread(p_image)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        target = cv2.imread(p_target, cv2.IMREAD_UNCHANGED)
        if self.transform is not None:
            transformed = self.transform(image = image, mask = target)
            image = transformed["image"]
            target = transformed["mask"]
        return image, target, p_image, p_target

In [4]:
## create symlink for images in sem_seg_data_dir
import os

def im_names(root_dir, split):
    train_ims_name = os.listdir(os.path.join(root_dir, split))
    train_ims_name = list(filter(lambda x: x.endswith("labelIds.png"), train_ims_name))
    train_ims_name = list(map(lambda x: x.replace("_labelIds.png", ".jpg"), train_ims_name))
    return train_ims_name

sem_seg_images_dir = os.path.join(sem_seg_data_dir, "images")
os.makedirs(sem_seg_images_dir, exist_ok = True)
os.makedirs(os.path.join(sem_seg_images_dir, "train"), exist_ok = True)
os.makedirs(os.path.join(sem_seg_images_dir, "val"), exist_ok = True)

sem_seg_targets_dir = os.path.join(sem_seg_data_dir, "gtFine")
train_ims_name = im_names(sem_seg_targets_dir, "train")
val_ims_name = im_names(sem_seg_targets_dir, "val")

imgs_dir = os.path.join(scene_data_dir, "images")

for im_name in train_ims_name:
    im_path = os.path.join(imgs_dir, im_name)
    out_path = os.path.join(sem_seg_images_dir, "train", im_name)
    if os.path.exists(im_path) and not os.path.exists(out_path):
       os.symlink(im_path, out_path)

for im_name in val_ims_name:
    im_path = os.path.join(imgs_dir, im_name)
    out_path = os.path.join(sem_seg_images_dir, "val", im_name)
    if os.path.exists(im_path) and not os.path.exists(out_path):
       os.symlink(im_path, out_path)

In [5]:
import albumentations as A
from segmentation_models_pytorch.encoders import get_preprocessing_fn
import yaml

model_base_path = "./models/sem_segm_gps_split/"
model_path = os.path.join(model_base_path, "DeeplabV3Plus_EfficientNetB4_best_model_epoch_0060_workzone.pth")
config_path =  os.path.join(model_base_path, "DeeplabV3Plus_EfficientNetB4_workzone.yaml")
config = yaml.safe_load(open(config_path, "r"))

def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype("float32")

preprocess_input = get_preprocessing_fn(
    encoder_name = config["segmentation_model_backbone"],
    pretrained = config["segmentation_pretrained_dataset"],
)

transform_full = A.Compose([
    A.Resize(width=1280, height=720),
    A.Lambda(name = "image_preprocessing", image = preprocess_input),
    A.PadIfNeeded(736, 1280),
    A.Lambda(name = "to_tensor", image = to_tensor),
])

dataset_test = DatasetWorkzoneSemantic(
    root = sem_seg_data_dir,
    split = "val",
    mode = "fine",
    transform = transform_full,
    device= device,
)

model = torch.load(model_path, map_location=device)
model.eval()

  from .autonotebook import tqdm as notebook_tqdm


DeepLabV3Plus(
  (encoder): EfficientNetEncoder(
    (_conv_stem): Conv2dStaticSamePadding(
      3, 48, kernel_size=(3, 3), stride=(2, 2), bias=False
      (static_padding): ZeroPad2d((0, 1, 0, 1))
    )
    (_bn0): BatchNorm2d(48, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_depthwise_conv): Conv2dStaticSamePadding(
          48, 48, kernel_size=(3, 3), stride=[1, 1], groups=48, bias=False
          (static_padding): ZeroPad2d((1, 1, 1, 1))
        )
        (_bn1): BatchNorm2d(48, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          48, 12, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          12, 48, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStatic

In [6]:
img, target, img_path, target_path = dataset_test.__getitem__(0)

## non-tensor version of the image, but padded
image = cv2.imread(img_path)
transform_padded = A.Compose([
    A.Resize(width=1280, height=720),
    A.PadIfNeeded(736, 1280),
])
img_padded = transform_padded(image = image)["image"]

with torch.inference_mode():
    model_input = torch.from_numpy(img).unsqueeze(0)
    model_input = model_input.to(device)
    logits = model(model_input)
    prediction = logits.argmax(axis = 1)

prediction_color = lut.lookup_chw(
    td_u_input = prediction.byte(),
    td_i_lut = dataset_test.th_i_lut_trainid2color,
).permute((1, 2, 0)).cpu().numpy()
blend = cv2.addWeighted(img_padded, 0.4, prediction_color, 0.6, 0.0)

gt_color = lut.lookup_chw(
    td_u_input = torch.from_numpy(target).unsqueeze(0).byte().to(device),
    td_i_lut = dataset_test.th_i_lut_trainid2color,
).permute((1, 2, 0)).cpu().numpy()
gt_blend = cv2.addWeighted(img_padded, 0.4, gt_color, 0.6, 0.0)

os.makedirs("./output/segm_gps_split/gt", exist_ok=True)
cv2.imwrite("./output/segm_gps_split/gt/{img}".format(img=os.path.basename(img_path)), gt_blend)

os.makedirs("./output/segm_gps_split/vis", exist_ok=True)
cv2.imwrite("./output/segm_gps_split/vis/{img}".format(img=os.path.basename(img_path)), blend)

True