# Segmentation Inference

In [None]:
%load_ext autoreload
%autoreload 2

import os

# Jupyter kernels do not define `__file__`, so resolve paths from the kernel's
# current working directory. The project root is taken to be its parent, which
# assumes the notebook is launched from a direct subdirectory of the project
# (e.g. `notebooks/`) — TODO confirm if the layout changes.
CURRENT_DPATH = os.path.abspath(os.getcwd())
PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DPATH, os.pardir))
DATA_DPATH = os.path.join(PROJECT_ROOT, "external_data")

import matplotlib.pyplot as plt
from tqdm import tqdm

from lane_detection_hackathon.datasets import DatasetMode, FileDataset
from lane_detection_hackathon.utils.fs import read_image, VideoWriter
from lane_detection_hackathon.inference import SegmentationInference
from lane_detection_hackathon.baseparser import BaseParser
from lane_detection_hackathon.utils.image import overlay

## Trained Model Loading

In [None]:
import torch

# Checkpoint to load: <PROJECT_ROOT>/train_checkpoints/best-valid-iou_<MODEL_ID>.pth
MODEL_ID = "40446a2b3b7543c292301a3b2da1ed67"
CHECKPOINT_DPATH = os.path.join(PROJECT_ROOT, "train_checkpoints")
MODEL_FNAME = f"best-valid-iou_{MODEL_ID}.pth"

MODEL_FPATH = os.path.join(CHECKPOINT_DPATH, MODEL_FNAME)

# Prefer the GPU, but fall back to CPU instead of failing on machines without
# CUDA. Assumes SegmentationInference.from_file accepts a torch device string
# such as "cpu" — TODO confirm.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

inference = SegmentationInference.from_file(
    MODEL_FPATH, device=DEVICE, batch_size=32, verbose=False
)

## Data Loading

In [None]:
# The dataset lives at <DATA_DPATH>/<dataset_name>/<dataset_version>.
dataset_name, dataset_version = "check2", "2023_03_04"
dataset_dpath = os.path.join(DATA_DPATH, dataset_name, dataset_version)

# Load the TEST split as a DataFrame and display its size.
file_dataset = FileDataset(dataset_dpath)
test_df = file_dataset.get_data(mode=DatasetMode.TEST)

test_df.shape

In [None]:
# Position (within test_df) of the sample to inspect.
TEST_IDX = 0

test_row = test_df.iloc[TEST_IDX]

# The stored paths are joined onto DATA_DPATH, so they are presumably
# relative to it — confirm against the dataset parser.
src_rel, trg_rel = test_row[[BaseParser.src_key, BaseParser.target_key]]
src_fpath, trg_fpath = [os.path.join(DATA_DPATH, rel) for rel in (src_rel, trg_rel)]

# Read the source image first, then its target mask.
src_image, trg_image = read_image(src_fpath), read_image(trg_fpath)

src_image.shape, trg_image.shape

## Inference

In [None]:
# Predict on the single sample, then render the mask as RGB at the size given
# by the first two dimensions of the source image.
result = inference.predict(src_image)
mask_size = src_image.shape[:2]
pred_mask = result.get_rgb_mask(mask_size)

pred_mask.shape

In [None]:
# Source, ground truth and prediction side by side, one panel each.
panels = [
    (src_image, "Source Image"),
    (trg_image, "Target Mask"),
    (pred_mask, "Pred Mask"),
]

fig, ax = plt.subplots(nrows=1, ncols=len(panels), figsize=(30, 10))

for axis, (panel_image, panel_title) in zip(ax, panels):
    axis.imshow(panel_image)
    axis.set_title(panel_title)

plt.show()

In [None]:
# First argument forwarded to get_heatmap. It is presumably a class/channel
# index — confirm against SegmentationInference's result type.
HEATMAP_IDX = 5

plt.figure(figsize=(8, 8))

# NOTE(review): the get_rgb_mask call receives src_image.shape[:2], but this
# call passes the full src_image.shape (including any channel axis). Confirm
# which form get_heatmap expects.
heatmap_img = result.get_heatmap(HEATMAP_IDX, src_image.shape)
plt.imshow(heatmap_img)

plt.show()

In [None]:
# Single 8x8-inch view of the predicted RGB mask blended onto the source image.
plt.figure(figsize=(8, 8))
overlayed_img = overlay(src_image, pred_mask)
plt.imshow(overlayed_img)
plt.show()

In [None]:
# Prediction overlaid on the source vs. the ground-truth mask, saved to disk.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 10))

overlayed_img = overlay(src_image, pred_mask)

comparison = (
    (overlayed_img, "Source Image with Predicted Mask"),
    (trg_image, "Target Mask"),
)
for axis, (panel_image, panel_title) in zip(ax, comparison):
    axis.imshow(panel_image)
    axis.set_title(panel_title)

# Relative path, so the file lands in the kernel's working directory.
fig.savefig("infer.jpg")

plt.show()

## Video Processing

In [None]:
# The camera name is the image's parent directory. Split on both "/" and "\"
# instead of os.sep, so paths stored with either separator parse the same way
# on any OS. With os.sep, "/"-style paths on Windows would yield NaN.
test_df["src_camera"] = test_df["src"].str.split(r"[\\/]", regex=True).str[-2]

test_df.head()

In [None]:
output_dpath = os.path.join(DATA_DPATH, "inference_video")
os.makedirs(output_dpath, exist_ok=True)

output_fpath = os.path.join(output_dpath, f"test_camera_6.mp4")
output_fpath

In [None]:
test_fpaths = test_df[test_df["src_camera"] == "Camera 6"]["src"].values
test_fpaths.shape

In [None]:
# Overlay the predicted mask on every selected frame and append it to the
# output video at 15 fps. Loop variables use a `frame_` prefix so this cell
# does not overwrite the single-sample globals (src_fpath, src_image, result,
# pred_mask, overlayed_img) that the visualisation cells rely on.
# NOTE(review): frames are predicted one at a time; `inference` was built with
# batch_size=32, so a batched predict may be faster if the API supports it.
with VideoWriter(output_fpath, fps=15) as video_writer:
    for frame_rel_fpath in tqdm(test_fpaths, desc="Writing video frames"):
        frame_fpath = os.path.join(DATA_DPATH, frame_rel_fpath)
        frame_image = read_image(frame_fpath)

        frame_result = inference.predict(frame_image)
        frame_mask = frame_result.get_rgb_mask(frame_image.shape[:2])

        video_writer.write_frame(overlay(frame_image, frame_mask))
        