# Detection Inference

In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt 
from project_paths import paths 
from ipywidgets import interact, IntSlider

from lane_detection_medium.inference import DetectionInference
from lane_detection_medium.utils.fs import read_image
from lane_detection_medium.utils.viz import render_bbox
from lane_detection_medium.utils.video_processing import VideoReader, VideoWriter
from lane_detection_medium.utils.fs import get_date_string

%matplotlib inline

## Trained Model Loading

In [None]:
# Name of the YOLO training run whose checkpoint is loaded below.
EXP_NAME = "train-2023-07-05"

CHECKPOINT_DPATH = paths.yolo_dpath / "LaneMarkingsDetection" / EXP_NAME / "weights"
# `best.pt` — presumably the best-scoring checkpoint saved by that run (TODO confirm).
MODEL_FPATH = CHECKPOINT_DPATH / "best.pt"

# The checkpoint location is handed to `from_file` as a plain string;
# inference runs on device "cuda:0" with a 640x640 input size.
inference = DetectionInference.from_file(
    str(MODEL_FPATH),
    img_size=(640, 640),
    device="cuda:0",
)

## Test Data Preparation

In [None]:
# Test split of the 2023_07_03 dataset export.
data_dpath = paths.yolo_dpath / "data" / "2023_07_03" / "test"
images_dpath = data_dpath / "images"

# Match the extension case-insensitively: `glob("*.PNG")` is case-sensitive on
# POSIX filesystems and would silently skip images saved as `.png`.
# Sorting gives a stable order, so a given index always maps to the same file.
img_fpaths = sorted(
    fpath
    for fpath in images_dpath.iterdir()
    if fpath.is_file() and fpath.suffix.lower() == ".png"
)
if not img_fpaths:
    # Fail here with a clear message instead of an IndexError below
    # (or an invalid slider range in the widget cell).
    raise FileNotFoundError(f"No PNG images found in {images_dpath}")

TEST_INDEX = 100  # must be < len(img_fpaths)
test_image = read_image(img_fpaths[TEST_INDEX])

plt.figure(figsize=(12, 12))
plt.imshow(test_image)
plt.show()

## Single Image Inference 

In [None]:
# `detect` works on a batch of images: run a one-image batch (confidence
# threshold 0.25) and keep the result belonging to that single image.
batch_results = inference.detect([test_image], conf=0.25)
detection_result = batch_results[0]

In [None]:
# Draw on a copy so `test_image` itself is left unannotated.
canva = test_image.copy()

# Box colour per class label. The yellow-line classes are included so their
# detections do not raise KeyError during rendering.
color_map = {
    "solid_white": (255, 0, 0),
    "break_white": (0, 0, 255),
    "zebra": (255, 255, 0),
    "solid_yellow": (255, 0, 255),
    "break_yellow": (0, 255, 255),
}
# Colour for any label missing from `color_map`.
default_color = (0, 255, 0)

plt.figure(figsize=(12, 12))

for det in detection_result:
    label_name = f"{det.label_name}: {det.conf:.2f}"
    color = color_map.get(det.label_name, default_color)
    render_bbox(canva, det.bbox, label=label_name, color=color)

plt.imshow(canva)
plt.show()

## Inference Widget 

In [None]:
# Box colour per class label; unknown labels fall back to `default_color`
# instead of raising KeyError inside the widget callback.
color_map = {
    "solid_white": (255, 0, 0),
    "break_white": (0, 0, 255),
    "zebra": (255, 255, 0),
    "solid_yellow": (255, 0, 255),
    "break_yellow": (0, 255, 255),
}
default_color = (0, 255, 0)


@interact
def show_inference(index=IntSlider(value=0, min=0, max=len(img_fpaths) - 1)):
    """Run detection on the ``index``-th test image and plot the predicted boxes.

    ``interact`` calls this whenever the slider value changes.

    Args:
        index: Position of the image in ``img_fpaths``.
    """
    test_image = read_image(img_fpaths[index])
    detection_result = inference.detect([test_image], conf=0.25)[0]

    canva = test_image.copy()

    plt.figure(figsize=(12, 12))

    for det in detection_result:
        label_name = f"{det.label_name}: {det.conf:.2f}"
        color = color_map.get(det.label_name, default_color)
        render_bbox(canva, det.bbox, label=label_name, color=color)

    plt.imshow(canva)
    plt.show()

## Video Processing

In [None]:
# Input clip. Alternatives previously used here: "bad_road_example.mp4", "pulkovo.mp4".
video_fname = "archangel1.mp4"
video_fpath = paths.data / "videos" / video_fname

# Annotated output goes to a sub-folder named by `get_date_string()`.
cache_dpath = paths.data / "output_videos" / get_date_string()
cache_dpath.mkdir(parents=True, exist_ok=True)
cache_fpath = cache_dpath / f"output_{video_fname}"

# Box colour per class label.
color_map = {
    "solid_white": (255, 0, 0),
    "break_white": (0, 0, 255),
    "zebra": (255, 255, 0),
    "solid_yellow": (255, 0, 255),
    "break_yellow": (0, 255, 255),
}
# Fallback for labels missing from `color_map`. Without it, a single unexpected
# class would raise KeyError and abort the run, leaving a truncated output video.
default_color = (0, 255, 0)

with VideoReader(video_fpath, verbose=True) as reader, VideoWriter(
    cache_fpath, fps=reader.fps
) as writer:
    for frame_img in reader.get_frames():
        detections = inference.detect([frame_img], conf=0.25)[0]

        # Annotate a copy so the decoded frame itself is left untouched.
        canva = frame_img.copy()
        for det in detections:
            label_name = f"{det.label_name}: {det.conf:.2f}"
            color = color_map.get(det.label_name, default_color)
            render_bbox(canva, det.bbox, label=label_name, color=color)

        writer.write_frame(canva)