# Segmentation Inference

In [None]:
%load_ext autoreload
%autoreload 2

import os

# Notebooks have no usable `__file__`, so paths are anchored at the current
# working directory; the project root is taken to be its parent directory.
CURRENT_DPATH = os.path.abspath(os.curdir)
PROJECT_ROOT = os.path.normpath(os.path.join(CURRENT_DPATH, os.pardir))
DATA_DPATH = os.path.join(PROJECT_ROOT, "external_data", "apolloscape")

import matplotlib.pyplot as plt

from lane_detection_hackathon.datasets import DatasetMode, FileDataset
from lane_detection_hackathon.utils.fs import read_image
from lane_detection_hackathon.inference import SegmentationInference
from lane_detection_hackathon.masks import MaskProcessor


## Data Loading

In [None]:
# Select the preprocessed dataset snapshot and load its TEST split.
dataset_name, dataset_version = "examples_preprocessed", "2023_02_27"
dataset_dpath = os.path.join(DATA_DPATH, dataset_name, dataset_version)

file_dataset = FileDataset(dataset_dpath)
test_df = file_dataset.get_data(mode=DatasetMode.TEST)

# Displayed by the notebook as the cell's output.
test_df.shape

In [None]:
# Path to a single raw ApolloScape lane-marking frame under DATA_DPATH.
# Reusing DATA_DPATH keeps this path in sync with the data root defined above
# instead of re-spelling "external_data"/"apolloscape" here.
# Note: the directory component is "Camera 5" (with a space) while the file
# name uses "Camera_5".
test_img_fpath = os.path.join(
    DATA_DPATH,
    "lane_marking_examples",
    "road02",
    "ColorImage",
    "Record001",
    "Camera 5",
    "170927_063811892_Camera_5.jpg",
)

# Fail early with a clear message if the example data is not present,
# rather than relying on whatever read_image does with a missing file.
if not os.path.isfile(test_img_fpath):
    raise FileNotFoundError(f"Test image not found: {test_img_fpath}")

test_image = read_image(test_img_fpath)
test_image.shape

In [None]:
# Visual check of the loaded test frame.
fig, ax = plt.subplots()
ax.imshow(test_image)
plt.show()

## Trained Model Loading

In [None]:
# Checkpoint of the training run identified by MODEL_ID; the file is named
# "best-valid-iou_<MODEL_ID>.pth" inside PROJECT_ROOT/train_checkpoints.
MODEL_ID = "88896693014a41989b3c00645e04c30c"
MODEL_FNAME = f"best-valid-iou_{MODEL_ID}.pth"
CHECKPOINT_DPATH = os.path.join(PROJECT_ROOT, "train_checkpoints")
MODEL_FPATH = os.path.join(CHECKPOINT_DPATH, MODEL_FNAME)

# Build the inference wrapper from the checkpoint on the "cuda" device.
inference = SegmentationInference.from_file(
    MODEL_FPATH,
    device="cuda",
    batch_size=2,
    verbose=False,
)

In [None]:
# Run the loaded model on the test image. Later cells index `results[0]`, so
# the return value is used as a sequence of per-image results (presumably one
# entry per input image — confirm against SegmentationInference.predict).
results = inference.predict(test_image)

In [None]:
# Label mask of the first prediction; it is colorized further below.
label_cell = results[0].get_label_mask()
# Displayed as the cell output for a quick shape sanity check.
label_cell.shape

In [None]:
# Heatmap mask of the first prediction for index 4.
# NOTE(review): 4 is a magic number — presumably a class/channel index;
# confirm its meaning and consider giving it a named constant.
heatmap_cell = results[0].get_heatmap_mask(4)

In [None]:
# Inspect the heatmap shape: the display cell passes it to imshow with a
# grayscale colormap, which only applies to 2-D (single-channel) data.
heatmap_cell.shape

In [None]:
import numpy as np

# Convert the heatmap to an 8-bit grayscale image for display.
# Assumes the raw heatmap is floating point with [0, 1] as its meaningful
# range — confirm against get_heatmap_mask.
# The dtype guard makes the cell safe to re-run: without it, a second run
# would clip the already-scaled 0..255 values back to 0..1 and the image
# would become binary (0 or 255).
if heatmap_cell.dtype != np.uint8:
    heatmap_cell = (np.clip(heatmap_cell, 0, 1) * 255).astype(np.uint8)

# imshow autoscales single-channel data to its own min/max; pin the range so
# pixel intensities are shown on the fixed 0..255 scale set above.
plt.imshow(heatmap_cell, cmap="gray", vmin=0, vmax=255)
plt.show()

In [None]:
# Colorize the predicted label mask with the model's label map so the classes
# can be viewed as an image (presumably (H, W, 3) RGB — check the displayed
# shape).
mask_processor = MaskProcessor()
rgb_cell = mask_processor.label_to_rgb(label_cell, inference.label_map)
rgb_cell.shape

In [None]:
# Visual check of the colorized label mask.
fig, ax = plt.subplots()
ax.imshow(rgb_cell)
plt.show()