In [15]:
from unet_model import *
from glob import glob
import random
import matplotlib.pyplot as plt
from pathlib import Path
from torchvision import transforms
from datetime import datetime
from collections import namedtuple
import numpy as np
from PIL import Image


In [16]:
# Prefer the GPU whenever PyTorch can see one; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Input frames for each split, matched one directory level below the split folder.
train_path = glob(
    'data/leftImg8bit/train/*/*leftImg8bit.png'
)
val_path = glob(
    'data/leftImg8bit/val/*/*leftImg8bit.png'
)


In [17]:
# A coarse segmentation class assembled from several raw label ids.
GroupedLabel = namedtuple("GroupedLabel", [
    "id",     # group id; also the channel index used for this group
    "name",   # human-readable group name
    "ids",    # raw label ids (values found in the ground-truth map) merged into this group
    "color",  # (R, G, B) display color, each component in 0-255
])

# Every raw id 0..33 belongs to exactly one group (no id may be listed twice).
# NOTE: "void" and "ego vehicle" deliberately share black, so they are
# indistinguishable once rendered to RGB.
grouped_labels = [
    GroupedLabel(0, "motor vehicles" , [26, 27, 28, 29, 30, 31, 32],               (  0,   0, 142)),
    GroupedLabel(1, "pedestrians"    , [24, 25, 33],                               (220,  20,  60)),
    GroupedLabel(2, "road"           , [6, 7, 8, 9, 10],                           (128,  64, 128)),
    GroupedLabel(3, "traffic objects", [17, 18, 19, 20],                           (250, 170,  30)),
    GroupedLabel(4, "background"     , [4, 5, 11, 12, 13, 14, 15, 16, 21, 22, 23], ( 70,  70,  70)),
    GroupedLabel(5, "void"           , [0, 2, 3],                                  (  0,   0,   0)),
    GroupedLabel(6, "ego vehicle"    , [1],                                        (  0,   0,   0)),
]

def classes_to_rgb(output: torch.Tensor) -> torch.Tensor:
    """Colorize per-class scores using the palette in ``grouped_labels``.

    Each pixel takes the color of the group whose ``id`` equals the index of
    the highest-scoring channel; pixels whose winning index matches no group
    stay black.

    Args:
        output: scores shaped (num_classes, H, W), or (1, num_classes, H, W)
            for a batch holding a single image. May live on any device.

    Returns:
        CPU tensor shaped (3, H, W); each channel holds the group's color
        component divided by 255.

    Raises:
        ValueError: if ``output`` is neither 3-D nor a 4-D batch of one image.
    """
    # Drop only an explicit leading batch axis. A blanket squeeze() would also
    # collapse an H, W or class axis that happens to have size 1.
    if output.dim() == 4 and output.size(0) == 1:
        output = output[0]
    if output.dim() != 3:
        raise ValueError(
            "expected scores shaped (num_classes, H, W) or (1, num_classes, H, W), "
            f"got {tuple(output.shape)}"
        )

    # The canvas lives on the CPU, so the class map must too: a mask on a
    # different device cannot be used to index it.
    class_map = torch.argmax(output, dim=0).cpu()
    rgb = torch.zeros((3, *class_map.shape), device="cpu")
    for label in grouped_labels:
        mask = class_map == label.id  # computed once per group, reused for R, G, B
        for c in range(3):
            rgb[c][mask] = label.color[c] / 255
    return rgb

def gt_to_classes(gt: np.ndarray) -> torch.Tensor:
    """Expand a map of raw label ids into one binary plane per label group.

    Args:
        gt: (H, W) array of raw label ids.

    Returns:
        Tensor shaped (len(grouped_labels), H, W). Plane ``group.id`` is 1
        where the pixel's id is listed in ``group.ids`` and 0 elsewhere.
    """
    num_groups = len(grouped_labels)
    masks = torch.zeros((num_groups, *gt.shape))
    for group in grouped_labels:
        # Boolean membership mask; assignment casts it to the 0/1 floats of `masks`.
        belongs = np.isin(gt, group.ids)
        masks[group.id] = torch.from_numpy(belongs)
    return masks

def infer(unet_model, img_path, out_path):
    """Segment one image, save the colorized prediction, and return the visuals.

    The ground-truth file is located by rewriting ``img_path``:
    ``leftImg8bit`` becomes ``gtFine`` and ``.png`` becomes ``_labelIds.png``.
    Image and ground truth are both downscaled to a quarter of the original
    height and width before use.

    Args:
        unet_model: network invoked as ``unet_model(batch)`` on a
            (1, C, H//4, W//4) tensor placed on ``device``. It is switched to
            eval mode for the forward pass and restored afterwards.
        img_path: path of the input image (``str`` or path-like).
        out_path: file the colorized prediction is written to; its parent
            directory is created when missing.

    Returns:
        Tuple ``(img_in, output, label_in)`` of channel-last NumPy arrays:
        the resized input image, the colorized prediction and the colorized
        ground truth.
    """
    img_path = str(img_path)  # accept Path objects as well as plain strings
    label_path = img_path.replace("leftImg8bit", "gtFine").replace(".png", "_labelIds.png")

    # Context managers release the file handles as soon as pixels are copied out.
    with Image.open(img_path) as img_file:
        img = np.array(img_file)
    with Image.open(label_path) as label_file:
        label = np.array(label_file)
    h, w = img.shape[:2]
    target_size = (h // 4, w // 4)

    image_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(target_size, antialias=True),
    ])
    label_transform = transforms.Resize(target_size, antialias=True)

    img_in = image_transform(img).to(device)[None]
    # The ground truth never goes through the network, so it stays on the CPU.
    label_in = label_transform(gt_to_classes(label))[None]

    # Inference must not run in training mode (layers such as dropout or
    # batch-norm, if the network has any, behave differently there) and needs
    # no autograd graph. The previous mode is restored for the caller.
    was_training = unet_model.training
    unet_model.eval()
    try:
        with torch.no_grad():
            output = unet_model(img_in)
    finally:
        unet_model.train(was_training)

    # Hand CPU tensors to the colorizer so its CPU canvas is never indexed
    # with a mask living on another device.
    output = classes_to_rgb(output.cpu())
    label_in = classes_to_rgb(label_in)
    output = output.cpu().detach().numpy().transpose(1, 2, 0).clip(0, 1)
    img_in = img_in.cpu().detach()[0].numpy().transpose(1, 2, 0)
    label_in = label_in.cpu().detach().numpy().transpose(1, 2, 0)

    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    plt.imsave(out_path, output)
    return img_in, output, label_in

In [18]:
# Render one randomly chosen validation image with every saved checkpoint so
# the frames can be assembled into a training-progress video.
Path("output").mkdir(parents=True, exist_ok=True)

models = Path("weights").glob("unetsegment_231126_171755_checkpoint_*.pt")
# Sort by the numeric checkpoint index; a plain string sort would put 10 before 2.
models = sorted(models, key=lambda x: int(x.name.split("_")[-1].split(".")[0]))
img = random.choice(val_path)
label = None
for model in models:
    print(f"Saving {model} output")
    unet_model = UNet(3, 7).float().to(device)
    # map_location lets checkpoints saved from a GPU load on a CPU-only machine.
    checkpoint = torch.load(model, map_location=device)
    unet_model.load_state_dict(checkpoint['model_state_dict'])
    _, _, label = infer(unet_model, img, f"output/{model.stem}.png")
# The ground truth is identical for every checkpoint, so it is written once.
# It is skipped when no checkpoint matched (`model` and `label` would be unset).
if label is not None:
    plt.imsave(f"output/{model.stem}_gt.png", label)

Saving weights\unetsegment_231126_171755_checkpoint_0.pt output
Saving weights\unetsegment_231126_171755_checkpoint_1.pt output
Saving weights\unetsegment_231126_171755_checkpoint_2.pt output
Saving weights\unetsegment_231126_171755_checkpoint_3.pt output
Saving weights\unetsegment_231126_171755_checkpoint_4.pt output
Saving weights\unetsegment_231126_171755_checkpoint_5.pt output
Saving weights\unetsegment_231126_171755_checkpoint_6.pt output
Saving weights\unetsegment_231126_171755_checkpoint_7.pt output
Saving weights\unetsegment_231126_171755_checkpoint_8.pt output
Saving weights\unetsegment_231126_171755_checkpoint_9.pt output
Saving weights\unetsegment_231126_171755_checkpoint_10.pt output
Saving weights\unetsegment_231126_171755_checkpoint_11.pt output
Saving weights\unetsegment_231126_171755_checkpoint_12.pt output
Saving weights\unetsegment_231126_171755_checkpoint_13.pt output
Saving weights\unetsegment_231126_171755_checkpoint_14.pt output
Saving weights\unetsegment_231126_1

In [19]:
# Assemble the per-checkpoint PNG frames into an H.264 video with ffmpeg.
#   -y                 overwrite the output file without asking
#   -f image2 -i ...%d read a numbered image sequence (checkpoint_0.png, checkpoint_1.png, ...)
#   -framerate 5       treat the input sequence as 5 frames per second
#   -vf hqx=4          upscale each frame with the hqx filter, factor 4
#   -pix_fmt yuv420p   output pixel format
#   -crf 13 / psy-rd=0 x264 quality settings
# The {...} part is evaluated by IPython, so the file name carries the current timestamp.
! ffmpeg -y -f image2 -framerate 5 -i "output/unetsegment_231126_171755_checkpoint_%d.png" -c:v libx264 -vf hqx=4 -pix_fmt yuv420p -crf 13 -x264-params psy-rd=0 output/video_{datetime.now().strftime(r"%y%m%d_%H%M%S")}.mp4

ffmpeg version 6.1-full_build-www.gyan.dev Copyright (c) 2000-2023 the FFmpeg developers
  built with gcc 12.2.0 (Rev10, Built by MSYS2 project)
  configuration: --enable-gpl --enable-version3 --enable-static --pkg-config=pkgconf --disable-w32threads --disable-autodetect --enable-fontconfig --enable-iconv --enable-gnutls --enable-libxml2 --enable-gmp --enable-bzlib --enable-lzma --enable-libsnappy --enable-zlib --enable-librist --enable-libsrt --enable-libssh --enable-libzmq --enable-avisynth --enable-libbluray --enable-libcaca --enable-sdl2 --enable-libaribb24 --enable-libaribcaption --enable-libdav1d --enable-libdavs2 --enable-libuavs3d --enable-libzvbi --enable-librav1e --enable-libsvtav1 --enable-libwebp --enable-libx264 --enable-libx265 --enable-libxavs2 --enable-libxvid --enable-libaom --enable-libjxl --enable-libopenjpeg --enable-libvpx --enable-mediafoundation --enable-libass --enable-frei0r --enable-libfreetype --enable-libfribidi --enable-libharfbuzz --enable-liblensfun --ena