# RGB Road Scene Material Segmentation

This notebook shows an example of running RMSNet on a single image. The model is described here:

https://github.com/kyotovision-public/RGB-Road-Scene-Material-Segmentation/

In [1]:
# --------------------------------------
import torch as pt

# --------------------------------------
import numpy as np

# --------------------------------------
import matplotlib.pyplot as plt

# --------------------------------------
from PIL import Image

# --------------------------------------
from dataclasses import dataclass

# --------------------------------------
import torchvision as tv

# --------------------------------------
from rsms.modeling.rmsnet.rmsnet import RMSNet
from rsms.dataloaders import make_data_loader
from rsms.dataloaders.utils import decode_segmap
from rsms import conf

Model parameters (**TODO:** should be cleaned up)

In [2]:
@dataclass
class Args:
    """Option container used to configure the model in this notebook.

    Every attribute carries a type annotation. ``@dataclass`` only turns
    *annotated* class attributes into fields; without annotations the
    generated ``__init__`` accepts no arguments and ``__repr__`` / ``__eq__``
    ignore all values. With annotations, individual options can be overridden
    at construction time, e.g. ``Args(sync_bn=True)``, while ``Args()`` still
    yields the same defaults as before.
    """

    positional_encoding: bool = False
    lr: float = 0
    workers: int = 1
    epochs: int = 1
    batch_size: int = 1
    batch_size_val: int = 8
    gpu_ids: str = "0"
    backbone: str = "mit_b2"
    checkname: str = "new"
    eval_interval: int = 1
    loss_type: str = "ce"
    dataset: str = "kitti_advanced"
    propagation: int = 0
    sync_bn: bool = False
    list_folder: str = "list_folder2"  # split-1: list_folder1; split-2: list_folder2
    lr_scheduler: str = "cos"  # choices=['poly', 'step', 'cos']
    use_balanced_weights: bool = False
    use_sbd: bool = False
    base_size: int = 512
    crop_size: int = 512


# Module-level settings passed directly to the model constructor
n_classes = 20
freeze_bn = False

# Default option set used by the cells below
args = Args()

In [None]:
# Build the network from the options defined above.
# NOTE(review): no checkpoint is loaded in this cell -- confirm that RMSNet
# obtains its weights on its own, otherwise the output comes from an
# untrained model.
model_kwargs = dict(
    num_classes=n_classes,
    backbone="segformer",
    encoder_id=2,
    sync_bn=args.sync_bn,
    freeze_bn=freeze_bn,
)
rms = RMSNet(**model_kwargs)

# Put the model into evaluation (inference) mode
rms.eval()

# Test images

In [4]:
# Choose the image to segment: keep exactly one `img_path` assignment active.
#
# A generic building with bricks
# ==================================================
img_path = conf.DATA_DIR / "bricks2.jpeg"

# An image from the KITTI Materials training set
# (uncomment the assignment below to use it instead of the image above)
# ==================================================
# img_path = (
#     conf.DATA_DIR
#     / "KITTI_Materials/train/image_2/2011_09_26_drive_0002_sync_0000000025.png"
# )

Open the image and show it so that we know what we are trying to segment

In [None]:
# `read_image` documents its path argument as a string, so convert the
# path object explicitly (older torchvision releases reject non-str paths).
# Forcing RGB guarantees three channels even for grayscale or RGBA files.
img = tv.io.read_image(str(img_path), mode=tv.io.ImageReadMode.RGB)

# Move channels last for matplotlib and scale by the maximum pixel value so
# the displayed values lie in [0, 1].
plt.imshow(img.permute(1, 2, 0) / img.max())

Run the image through the model

In [6]:
# Forward pass with gradient tracking disabled.
# NOTE(review): the image is fed as raw float pixel values; no normalization
# is applied in this notebook -- confirm this matches what the model expects.
with pt.no_grad():
    # Prepend a batch axis to the image tensor
    batch = img.float().unsqueeze(0)
    # Keep element 0 of the model output
    first_output = rms(batch)[0]
    # Detached, independent copy converted to a NumPy array
    out = first_output.detach().clone().numpy()

Extract the pixel classes and display the segmentation results

In [None]:
# Class index per pixel: position of the largest value along axis 0
classes = out.argmax(axis=0)

# Decode the class map with the "kitti_advanced" settings (presumably into a
# colour image -- see `decode_segmap`) and display it
seg = decode_segmap(classes, dataset="kitti_advanced")
plt.imshow(seg)