# Lab-06 DeepLabv3+

Please run the code inside a VS Code Dev Container ("VScode-devcontainer").

> The official Visual Studio Code tutorial on Dev Containers is available here:  
> [https://code.visualstudio.com/docs/devcontainers/containers](https://code.visualstudio.com/docs/devcontainers/containers)

## Import Required Libraries

In [43]:
import os
import time
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from glob import glob

import torch
import torch.nn as nn
from torchvision import transforms
from tqdm import tqdm

import network
from datasets import Cityscapes

## Utilities

In [44]:
def time_synchronized():
    """Return the current wall-clock time in seconds, CUDA-aware.

    CUDA kernels are launched asynchronously. When a GPU is present, this waits
    for all queued device work to finish before reading the clock, so the
    difference of two calls measures the GPU work done in between.
    """
    gpu_present = torch.cuda.is_available()
    if gpu_present:
        torch.cuda.synchronize()
    now = time.time()
    return now


def set_bn_momentum(model, momentum=0.1):
    """Set ``momentum`` on every ``nn.BatchNorm2d`` layer inside ``model`` (in place).

    Other layer types, including other BatchNorm variants, are left untouched.
    """
    bn2d_layers = (layer for layer in model.modules() if isinstance(layer, nn.BatchNorm2d))
    for layer in bn2d_layers:
        layer.momentum = momentum

## Define Prediction Functions

In [45]:
def _collect_image_files(inputs):
    """Return the image paths to process for ``inputs``.

    A file path is returned as-is, in a one-element list. A directory is
    searched recursively for ``.png``/``.jpg``/``.jpeg`` files in any letter
    case. The result is sorted and contains no duplicates. Any other path
    yields an empty list.
    """
    if os.path.isfile(inputs):
        return [inputs]
    if not os.path.isdir(inputs):
        return []
    image_exts = {".png", ".jpg", ".jpeg"}
    # A single glob plus a case-insensitive suffix check replaces one glob per
    # spelling. That also avoids matching the same file twice on
    # case-insensitive filesystems.
    candidates = glob(os.path.join(inputs, "**", "*"), recursive=True)
    return sorted(
        p
        for p in candidates
        if os.path.isfile(p) and os.path.splitext(p)[1].lower() in image_exts
    )


def predict_images(inputs, result, model_name, ckpt):
    """Segment one image or a directory of images and display/save colour overlays.

    Args:
        inputs: Path to a single image file, or a directory that is searched
            recursively for .png/.jpg/.jpeg images (any letter case).
        result: Output directory for ``<image name>.png`` overlays. It is
            created if missing. A falsy value (e.g. ``""`` or ``None``) skips
            saving.
        model_name: Name of a model constructor in ``network.modeling``.
        ckpt: Checkpoint path whose ``"model_state"`` entry is loaded into the model.
    """
    decode_fn = Cityscapes.decode_target

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device: %s" % device)

    # Saving is optional (see the `if result:` below), so only create the
    # directory when a path was given. os.makedirs("") would raise.
    if result:
        os.makedirs(result, exist_ok=True)

    image_files = _collect_image_files(inputs)
    if not image_files:
        print("No input images found at %s" % inputs)
        return

    # Set up model (all models are constructed in network.modeling); output_stride: 8 or 16.
    # 19 output classes, matching the Cityscapes decoder used below.
    model = network.modeling.__dict__[model_name](num_classes=19, output_stride=16)

    # BatchNorm momentum only matters in train mode; the model runs in eval mode below.
    set_bn_momentum(model.backbone, momentum=0.01)

    checkpoint = torch.load(ckpt, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint["model_state"])
    model = nn.DataParallel(model)
    model.to(device)
    print("Resume model from %s" % ckpt)
    del checkpoint  # drop the CPU copy of the weights

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    with torch.no_grad():
        model = model.eval()
        for img_path in tqdm(image_files):
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            origin_img = Image.open(img_path).convert("RGB")
            img = transform(origin_img).unsqueeze(0).to(device)  # To tensor of NCHW

            t1 = time_synchronized()
            pred = model(img).max(1)[1].cpu().numpy()[0]  # HW, per-pixel argmax class
            t2 = time_synchronized()

            print(f"Done. ({(1E3 * (t2 - t1)):.1f}ms) Inference.")

            colorized_preds = Image.fromarray(decode_fn(pred).astype("uint8"))
            # Image.blend gives 0.6 * prediction colours + 0.4 * original image.
            colorized_preds = Image.blend(colorized_preds, origin_img, alpha=0.4)

            if result:
                # NOTE: images sharing a base name in different sub-directories
                # are written to the same output file.
                colorized_preds.save(os.path.join(result, img_name + ".png"))

            plt.axis("off")
            plt.imshow(colorized_preds)
            plt.show()

In [46]:
def _open_video_streams(path, result):
    """Open ``path`` for reading and ``result`` for writing with matching frame size and FPS.

    Returns:
        ``(cap, writer)``: an opened ``cv2.VideoCapture`` and ``cv2.VideoWriter`` (mp4v).

    Raises:
        OSError: if either stream cannot be opened. Anything already opened is
            released first.
    """
    # VideoWriter does not create missing directories; it just fails to open.
    out_dir = os.path.dirname(result)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        cap.release()
        raise OSError("Cannot open input video: %s" % path)

    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:  # 0.0 (or NaN) when the container does not report a frame rate
        fps = 30.0  # fall back to the previous hard-coded rate
    frame_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))

    writer = cv2.VideoWriter(result, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size)
    if not writer.isOpened():
        cap.release()
        writer.release()
        raise OSError("Cannot open output video for writing: %s" % result)
    return cap, writer


def predict_video(path, result, model_name, ckpt):
    """Segment every frame of a video and write a colour-overlay video.

    Args:
        path: Input video file readable by ``cv2.VideoCapture``.
        result: Output video path (mp4v). Its parent directory is created if
            missing. The output keeps the input's frame size and frame rate.
        model_name: Name of a model constructor in ``network.modeling``.
        ckpt: Checkpoint path whose ``"model_state"`` entry is loaded into the model.

    Raises:
        OSError: if the input video or the output writer cannot be opened.
    """
    decode_fn = Cityscapes.decode_target

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device: %s" % device)

    cap, video_writer = _open_video_streams(path, result)
    try:
        torch.backends.cudnn.benchmark = True  # set True to speed up constant image size inference

        # Set up model (all models are constructed in network.modeling); output_stride: 8 or 16.
        # 19 output classes, matching the Cityscapes decoder used below.
        model = network.modeling.__dict__[model_name](num_classes=19, output_stride=16)

        # BatchNorm momentum only matters in train mode; the model runs in eval mode below.
        set_bn_momentum(model.backbone, momentum=0.01)

        checkpoint = torch.load(ckpt, map_location=torch.device("cpu"))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        print("Resume model from %s" % ckpt)
        del checkpoint  # drop the CPU copy of the weights

        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

        print("Start to predict ...")

        with torch.no_grad():
            model = model.eval()
            total_img = 0

            # Wall time for the whole loop, including video decode/encode.
            t1 = time_synchronized()
            while True:
                success, frame_bgr = cap.read()
                if not success:
                    break
                total_img += 1

                # OpenCV decodes frames as BGR. The model expects RGB input,
                # normalised with the RGB-ordered mean/std above; predict_images
                # feeds it PIL images converted to RGB.
                frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
                img = transform(frame_rgb).unsqueeze(0).to(device)  # To tensor of NCHW

                pred = model(img).max(1)[1].cpu().numpy()[0]  # HW, per-pixel argmax class

                # decode_fn produces an RGB colour map (predict_images treats it
                # as RGB). Convert it to BGR before blending with the BGR frame,
                # because VideoWriter expects BGR.
                colorized_preds = cv2.cvtColor(decode_fn(pred).astype("uint8"), cv2.COLOR_RGB2BGR)
                overlay = cv2.addWeighted(colorized_preds, 0.5, frame_bgr, 0.5, 0)
                video_writer.write(overlay)

            t2 = time_synchronized()

            print(f"Done. ({(1E3 * (t2 - t1)):.1f}ms) Inference, Total frame : {total_img}.")
    finally:
        # Release both streams even if model loading or inference fails,
        # so the output file is finalized and handles are not leaked.
        cap.release()
        video_writer.release()

    cv2.destroyAllWindows()

## Predict

In [47]:
def PredictImage():
    """Segment the sample image with the MobileNet DeepLabv3+ Cityscapes model.

    Overlays are written to ``test_results``. To use the ResNet-101 backbone
    instead, pass ``model_name="deeplabv3plus_resnet101"`` and
    ``ckpt="weights/best_deeplabv3plus_resnet101_cityscapes_os16.pth"``.
    """
    predict_images(
        inputs="test_inputs/test1.png",
        result="test_results",
        model_name="deeplabv3plus_mobilenet",
        ckpt="weights/best_deeplabv3plus_mobilenet_cityscapes_os16.pth",
    )


# PredictImage()

In [48]:
def PredictVideo():
    """Segment the sample video with the ResNet-101 DeepLabv3+ Cityscapes model.

    The overlay video is written to ``./test_results/test.mp4``. To use the
    lighter MobileNet backbone instead, pass
    ``model_name="deeplabv3plus_mobilenet"`` and
    ``ckpt="weights/best_deeplabv3plus_mobilenet_cityscapes_os16.pth"``.
    """
    predict_video(
        path="./test_inputs/test.mp4",
        result="./test_results/test.mp4",
        model_name="deeplabv3plus_resnet101",
        ckpt="weights/best_deeplabv3plus_resnet101_cityscapes_os16.pth",
    )


PredictVideo()

Device: cuda
Resume model from weights/best_deeplabv3plus_resnet101_cityscapes_os16.pth
Start to predict ...
Done. (54687.9ms) Inference, Total frame : 300.
