In [1]:
import os
import mmcv
import numpy as np
from mmseg.models import BaseSegmentor
from mmengine.structures import PixelData
from mmseg.structures import SegDataSample
from mmseg.apis import inference_model, init_model
import warnings


def fxn():
    """Emit a demo ``DeprecationWarning`` whose message is "deprecated"."""
    warnings.warn("deprecated", category=DeprecationWarning)


# NOTE(review): ``catch_warnings`` restores the previous filter list when the
# ``with`` block exits, so the "ignore" filter below only silences the warning
# raised by ``fxn()`` inside the block; warnings from later cells are NOT
# suppressed by this.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")
    fxn()

# Model config and matching checkpoint (SegFormer MiT-B4, ADE20K, 512x512 per
# the file names).
config_file: str = "pretrained_models/segformer_mit-b4_8xb2-160k_ade20k-512x512.py"
checkpoint_file: str = "pretrained_models/segformer_mit-b4_512x512_160k_ade20k_20210728_183055-7f509d7d.pth"

# Input image and the directory the rendered overlay is written to.
img = "demo_img/0000007.jpg"
save_dir = "outputs/test"
# exist_ok avoids the check-then-create race and keeps the cell idempotent.
os.makedirs(save_dir, exist_ok=True)

In [2]:
import torch

# Use the first GPU when available; otherwise fall back to CPU so the notebook
# still runs end to end on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = init_model(config_file, checkpoint_file, device=device)
result = inference_model(model, img)
# Output path: save_dir joined with the input image's file name
# (os.path.basename is portable, unlike splitting on "/").
out_file = os.path.join(save_dir, os.path.basename(img))



Loads checkpoint by local backend from path: pretrained_models/segformer_mit-b4_512x512_160k_ade20k_20210728_183055-7f509d7d.pth


In [3]:
from segmentation.inference import Segmentor

# Class metadata from the project's Segmentor helper.
classes = Segmentor.ade_classes()
dynamic_classes = Segmentor.dynamic_classes()
num_classes = len(classes)

# Load the input image in RGB channel order.
image = mmcv.imread(img, channel_order="rgb")

# Bring the predicted label map to the CPU for numpy-side processing.
sem_seg = result.pred_sem_seg.cpu().data
print(sem_seg)

tensor([[[43, 43, 43,  ...,  2,  2,  2],
         [43, 43, 43,  ...,  2,  2,  2],
         [43, 43, 43,  ...,  2,  2,  2],
         ...,
         [ 6,  6,  6,  ...,  6,  6,  6],
         [ 6,  6,  6,  ...,  6,  6,  6],
         [ 6,  6,  6,  ...,  6,  6,  6]]])


In [4]:
print(f"shape of image: {image.shape}")
print(f"shape of sem_seg: {sem_seg.shape}")
# The overlay cell indexes the (H, W, 3) image with sem_seg[0], so the spatial
# sizes must agree; fail early with a clear message if they do not.
assert tuple(image.shape[:2]) == tuple(sem_seg.shape[-2:]), (
    f"image size {tuple(image.shape[:2])} != sem_seg size {tuple(sem_seg.shape[-2:])}")

shape of image: (240, 352, 3)
shape of sem_seg: torch.Size([1, 240, 352])


In [5]:
# Class ids present in the prediction, largest id first.
ids_desc = np.unique(sem_seg)[::-1]
# Keep only ids that have an entry in the class-name table.
legal_indices = ids_desc < num_classes
ids = ids_desc[legal_indices]
print(f"ids: {ids}")

ids: [83 43 32 29  6  4  2  0]


In [6]:
# Select only the ids that are NOT listed in dynamic_classes.
# Passing dtype=ids.dtype keeps an integer array even when every id is
# filtered out (np.array([]) on its own would be float64).
ids = np.array(
    [class_id for class_id in ids if class_id not in dynamic_classes],
    dtype=ids.dtype,
)
print(f"ids: {ids}")

ids: [32 29  6  4  0]


In [7]:
def _get_center_loc(mask: np.ndarray) -> np.ndarray:
    """Get semantic seg center coordinate.

    Args:
        mask: np.ndarray: get from sem_seg
    """
    loc = np.argwhere(mask == 1)

    loc_sort = np.array(
        sorted(loc.tolist(), key=lambda row: (row[0], row[1])))
    y_list = loc_sort[:, 0]
    unique, indices, counts = np.unique(
        y_list, return_index=True, return_counts=True)
    y_loc = unique[counts.argmax()]
    y_most_freq_loc = loc[loc_sort[:, 0] == y_loc]
    center_num = len(y_most_freq_loc) // 2
    x = y_most_freq_loc[center_num][1]
    y = y_most_freq_loc[center_num][0]
    return np.array([x, y])

In [8]:
import cv2
import torch

# Opacity of the colored segmentation/label layer blended over the image.
alpha = 0.5
labels = np.array(ids, dtype=np.int64)
palette = Segmentor.ade_palette()
colors = [palette[label] for label in labels]

# Convert the label map to numpy once so every comparison below is a plain
# numpy operation (instead of indexing numpy arrays with a torch tensor).
seg_map = sem_seg[0]
if isinstance(seg_map, torch.Tensor):
    seg_map = seg_map.numpy()
seg_map = np.asarray(seg_map)

# One 0/1 mask per selected label, shape (num_labels, H, W).
masks = (seg_map == labels[:, None, None]).astype(np.uint8)

# Paint each selected label's pixels with its palette color.
mask = np.zeros_like(image, dtype=np.uint8)
for label_mask, color in zip(masks, colors):
    mask[label_mask.astype(bool), :] = color

font = cv2.FONT_HERSHEY_SIMPLEX
# (0,1] to change the size of the text relative to the image
scale = 0.05
fontScale = min(image.shape[0], image.shape[1]) / (25 / scale)
fontColor = (255, 255, 255)
# Thinner strokes on small images.
if image.shape[0] < 300 or image.shape[1] < 300:
    thickness = 1
    rectangleThickness = 1
else:
    thickness = 2
    rectangleThickness = 2
lineType = 2

img_h, img_w = image.shape[:2]
for label_mask, classes_id, classes_color in zip(masks, labels, colors):
    loc = _get_center_loc(label_mask)
    text = classes[classes_id]
    (label_width, label_height), baseline = cv2.getTextSize(
        text, font, fontScale, thickness)
    box_w = label_width + baseline
    box_h = label_height + baseline
    # Shift the label box left/up when needed so it stays inside the image
    # instead of being clipped at the right or bottom border.
    x0 = int(max(0, min(loc[0], img_w - 1 - box_w)))
    y0 = int(max(0, min(loc[1], img_h - 1 - box_h)))
    top_left = (x0, y0)
    bottom_right = (x0 + box_w, y0 + box_h)
    # Filled background in the class color, then a black outline, then text.
    mask = cv2.rectangle(mask, top_left, bottom_right, classes_color, -1)
    mask = cv2.rectangle(mask, top_left, bottom_right, (0, 0, 0),
                         rectangleThickness)
    mask = cv2.putText(mask, text, (x0, y0 + label_height), font, fontScale,
                       fontColor, thickness, lineType)
color_seg = (image * (1 - alpha) + mask * alpha).astype(np.uint8)

In [9]:
# The overlay was built in RGB; convert to BGR before handing it to mmcv.imwrite.
color_seg_bgr = mmcv.rgb2bgr(color_seg)
mmcv.imwrite(color_seg_bgr, out_file)

True