# Semantic Segmentation Demo

This is a notebook for running the benchmark semantic segmentation network from the [ADE20K MIT Scene Parsing Benchmark](http://sceneparsing.csail.mit.edu/).

The code for this notebook is available here
https://github.com/CSAILVision/semantic-segmentation-pytorch/tree/master/notebooks

It can be run on Colab at this URL https://colab.research.google.com/github/CSAILVision/semantic-segmentation-pytorch/blob/master/notebooks/DemoSegmenter.ipynb

## Imports and utility functions

We need PyTorch, NumPy, and the code for the segmentation model, plus some utilities for visualizing the data.

In [1]:
# System libs
import os, csv, torch, numpy, scipy.io, PIL.Image, torchvision.transforms
from matplotlib import pyplot as plt
import cv2

# Our libs
from mit_semseg.models import ModelBuilder, SegmentationModule
from mit_semseg.utils import colorEncode

# Color table used by colorEncode to paint each predicted class.
colors = scipy.io.loadmat('data/color150.mat')['colors']

# Map the integer id in column 0 of the CSV to a short class name: the text of
# column 5 up to (not including) its first ';'.
names = {}
with open('data/object150_info.csv') as f:
    rows = csv.reader(f)
    next(rows)  # skip the header row
    names.update((int(row[0]), row[5].split(";")[0]) for row in rows)

def visualize_result(img, pred, index=None):
    """Display `img` next to the colorized prediction `pred`.

    When `index` is given, a copy of `pred` is taken, every pixel whose value
    differs from `index` is set to -1 in that copy, and the class name
    `names[index+1]` is printed before the image is shown.  The caller's
    `pred` array is not modified.
    """
    if index is not None:
        # Work on a copy so the caller's prediction stays intact.
        single_class = pred.copy()
        single_class[single_class != index] = -1
        pred = single_class
        # `names` is keyed one higher than the prediction index.
        print(f'{names[index+1]}:')

    # Turn the label map into an RGB image using the shared color table.
    seg_rgb = colorEncode(pred, colors).astype(numpy.uint8)

    # Put the input and the segmentation side by side and show them inline.
    side_by_side = numpy.concatenate([img, seg_rgb], axis=1)
    display(PIL.Image.fromarray(side_by_side))

In [2]:
# Coarse labels produced from these groups:
#   0 = others, 1 = light_source, 2 = reflective, 3 = geometry
# Keys are the ids used by `names`; values are only there for readability.
light_source = {83: 'light', 37: 'lamp', 9: 'windowpane'}

reflective = {28: 'mirror'}

geometry = {1: 'wall', 4: 'floor', 6: 'ceiling', 15: 'door', 19: 'curtain'}

# Show the full class table followed by each group, for a quick sanity check.
for table in (names, light_source, reflective, geometry):
    print(table)

# Every class id that belongs to one of the three groups, in group order.
important_keys = [*light_source, *reflective, *geometry]
print(important_keys)

{1: 'wall', 2: 'building', 3: 'sky', 4: 'floor', 5: 'tree', 6: 'ceiling', 7: 'road', 8: 'bed', 9: 'windowpane', 10: 'grass', 11: 'cabinet', 12: 'sidewalk', 13: 'person', 14: 'earth', 15: 'door', 16: 'table', 17: 'mountain', 18: 'plant', 19: 'curtain', 20: 'chair', 21: 'car', 22: 'water', 23: 'painting', 24: 'sofa', 25: 'shelf', 26: 'house', 27: 'sea', 28: 'mirror', 29: 'rug', 30: 'field', 31: 'armchair', 32: 'seat', 33: 'fence', 34: 'desk', 35: 'rock', 36: 'wardrobe', 37: 'lamp', 38: 'bathtub', 39: 'railing', 40: 'cushion', 41: 'base', 42: 'box', 43: 'column', 44: 'signboard', 45: 'chest', 46: 'counter', 47: 'sand', 48: 'sink', 49: 'skyscraper', 50: 'fireplace', 51: 'refrigerator', 52: 'grandstand', 53: 'path', 54: 'stairs', 55: 'runway', 56: 'case', 57: 'pool', 58: 'pillow', 59: 'screen', 60: 'stairway', 61: 'river', 62: 'bridge', 63: 'bookcase', 64: 'blind', 65: 'coffee', 66: 'toilet', 67: 'flower', 68: 'book', 69: 'hill', 70: 'bench', 71: 'countertop', 72: 'stove', 73: 'palm', 74: '

## Loading the segmentation model

Here we load a pretrained segmentation model.  Like any pytorch model, we can call it like a function, or examine the parameters in all the layers.

After loading, we move it to the GPU (or keep it on the CPU when CUDA is not available).  Since we are doing inference, not training, we put the model in eval mode.

In [3]:
# Use CUDA when it is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Encoder and decoder share the same feature width.
feature_dim = 2048

# Network builders: each half is created and loaded from its own checkpoint.
net_encoder = ModelBuilder.build_encoder(
    arch='resnet50dilated',
    fc_dim=feature_dim,
    weights='ckpt/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth',
)
net_decoder = ModelBuilder.build_decoder(
    arch='ppm_deepsup',
    fc_dim=feature_dim,
    num_class=150,
    weights='ckpt/ade20k-resnet50dilated-ppm_deepsup/decoder_epoch_20.pth',
    use_softmax=True,
)

# SegmentationModule takes a loss criterion as its third argument.
crit = torch.nn.NLLLoss(ignore_index=-1)

# Assemble the full model and switch it to inference mode.
segmentation_module = SegmentationModule(net_encoder, net_decoder, crit).eval()
# Last expression of the cell: moves the model to `device` and shows its repr.
segmentation_module.to(device)

Loading weights for net_encoder
Loading weights for net_decoder


SegmentationModule(
  (encoder): ResnetDilated(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): SynchronizedBatchNorm2d(64, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): SynchronizedBatchNorm2d(64, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
    (relu2): ReLU(inplace=True)
    (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn3): SynchronizedBatchNorm2d(128, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True)
    (relu3): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): SynchronizedBatchNorm2d(64, eps=1

## Load test data

Now we load and normalize the test images one at a time.  Here we use the commonplace convention of normalizing the image to a scale for which the RGB values of a large photo dataset would have zero mean and unit standard deviation.  (These numbers come from the ImageNet dataset.)  With this normalization, the limiting ranges of RGB values are about -2.2 to +2.7.

In [7]:
import sys
sys.path.insert(1, '../../../')  # make the project-level `util` module importable
import util
import numpy as np

input_dir = "../../../../Dataset/LavalIndoor/crop/"
output_dir = "../../../../Dataset/LavalIndoor/semantics/"
# cv2.imwrite does not create missing directories (it only returns False), so
# make sure the destination exists before the long batch run starts.
os.makedirs(output_dir, exist_ok=True)
nms = os.listdir(input_dir)
# Only .exr files are processed; counting them up front lets the progress
# print reach its total even when the folder also holds other files.
exr_nms = [nm for nm in nms if nm.endswith('.exr')]
handle = util.PanoramaHandler()
tone = util.TonemapHDR()

# Convert an HxWx3 uint8 image into a normalized CxHxW float tensor.
pil_to_tensor = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406], # These are RGB mean+std values
        std=[0.229, 0.224, 0.225])  # across a large photo dataset.
])

# Coarse label written to the output PNG for each semantic group; 0 is left
# for every class that is in none of them.  Groups are applied in this order,
# so a class listed in two groups would end up with the later label.
coarse_groups = ((1, light_source), (2, reflective), (3, geometry))

for i, nm in enumerate(exr_nms, start=1):
    exr = handle.read_hdr(os.path.join(input_dir, nm))
    # NOTE(review): assumes tone(...)[0] is an image with values in [0, 1];
    # confirm against util.TonemapHDR, otherwise the uint8 cast below wraps.
    original_img = (255 * tone(exr, True)[0]).astype(np.uint8)

    # Load and normalize one image as a singleton tensor batch
    img_data = pil_to_tensor(original_img)
    singleton_batch = {'img_data': img_data[None].to(device)}
    output_size = img_data.shape[1:]

    # Run the segmentation at the highest resolution.
    with torch.no_grad():
        scores = segmentation_module(singleton_batch, segSize=output_size)

    # Per-pixel argmax over the class dimension gives the predicted class.
    _, pred = torch.max(scores, dim=1)

    # Collapse the fine-grained classes into the coarse labels.
    coarse = torch.zeros_like(pred)
    for label, group in coarse_groups:
        for key in group:
            # Group keys follow the ids of `names`, which are one higher than
            # the model's prediction index.
            coarse[pred == (key - 1)] = label

    # Labels are 0..3, so uint8 is sufficient and avoids handing an int64
    # array to the PNG writer.
    pred = coarse.cpu()[0].numpy().astype(np.uint8)
    # For interactive inspection: visualize_result(original_img, pred)

    # Swap only the file extension; str.replace('exr', 'png') would also
    # rewrite any "exr" that happens to occur inside the base name.
    output_path = os.path.join(output_dir, os.path.splitext(nm)[0] + '.png')
    if not cv2.imwrite(output_path, pred):
        raise OSError(f'cv2.imwrite failed to write {output_path}')

    print(i, len(exr_nms))

1 19575
2 19575
3 19575
4 19575
5 19575
6 19575
7 19575
8 19575
9 19575
10 19575
11 19575
12 19575
13 19575
14 19575
15 19575
16 19575
17 19575
18 19575
19 19575
20 19575
21 19575
22 19575
23 19575
24 19575
25 19575
26 19575
27 19575
28 19575
29 19575
30 19575
31 19575
32 19575
33 19575
34 19575
35 19575
36 19575
37 19575
38 19575
39 19575
40 19575
41 19575
42 19575
43 19575
44 19575
45 19575
46 19575
47 19575
48 19575
49 19575
50 19575
51 19575
52 19575
53 19575
54 19575
55 19575
56 19575
57 19575
58 19575
59 19575
60 19575
61 19575
62 19575
63 19575
64 19575
65 19575
66 19575
67 19575
68 19575
69 19575
70 19575
71 19575
72 19575
73 19575
74 19575
75 19575
76 19575
77 19575
78 19575
79 19575
80 19575
81 19575
82 19575
83 19575
84 19575
85 19575
86 19575
87 19575
88 19575
89 19575
90 19575
91 19575
92 19575
93 19575
94 19575
95 19575
96 19575
97 19575
98 19575
99 19575
100 19575
101 19575
102 19575
103 19575
104 19575
105 19575
106 19575
107 19575
108 19575
109 19575
110 19575
111 1957