# Imports

In [None]:
from pathlib import Path
from typing import Literal

import numpy as np
import torch
from tqdm.auto import tqdm

from Datasets.base import DatasetMode
from Datasets.batching import BatchManager
from mean_average_precision import MetricBuilder

# Parameter Setup

In [None]:
# ---- Experiment selection -------------------------------------------------
# Which dataset / model pair is evaluated; the loading cells branch on these.
dataset_name: Literal["NCaltech101", "Gen1"] = "Gen1"
model_name: Literal["EGSST"] = "EGSST"

# ---- Locations on disk ----------------------------------------------------
# Dataset root handed to the dataset constructor, and the checkpoint whose
# state dict is loaded into the model.
dataset_path: Path = Path(r"D:\Uniwersytet\GNNBenchmarking\Datasets\GEN1-DS")
model_path: Path = Path(r"C:\Users\benia\PycharmProjects\GNNBenchmark\Results\Detection_EGSST_on_Gen1\TrainedModels\1000.pth")

# ---- Evaluation loop ------------------------------------------------------
batch_size: int = 8                          # graphs per batch from BatchManager
investigated_mode: DatasetMode = "validation"  # split/mode passed to BatchManager
# Number of batches drawn (one next(batch_manager) per iteration of the mAP
# loop) -- despite the name, this is not a count of full passes over the data.
epoch_count: int = 10

# ---- Graph preprocessing --------------------------------------------------
inference_event_count: int = 10_000  # keep only the first N rows of x / pos
beta: float = 1e-4                   # forwarded to model.data_transform
radius: float = 5.0                  # forwarded to model.data_transform

# ---- EGSST-only settings --------------------------------------------------
min_nodes_subgraph: int = 1000  # forwarded to model.data_transform
# Head config YAML; a relative Path, so it resolves against the working directory.
detection_head_config: Path = Path(r"../confs/rtdtr_head_gen1.yml")
# NOTE(review): the two flags below are not referenced anywhere else in this
# notebook, so they currently have no effect on the model that gets built.
ecnn_flag: bool = True  # Whether enhanced cnn should be used
ti_flag: bool = True  # Whether TAC augmentation should be used

# Loading Dataset

In [None]:
# Resolve the dataset class lazily so only the selected dataset module is imported.
print(f"Loading {dataset_name} dataset...")
if dataset_name == "Gen1":
    from Datasets.gen1 import Gen1 as dataset_class
elif dataset_name == "NCaltech101":
    from Datasets.ncaltech101 import NCaltech as dataset_class
else:
    raise ValueError(f"Dataset {dataset_name} not implemented.")

# Construct once for whichever branch was taken, then run its processing step.
dataset = dataset_class(root=dataset_path)
dataset.process()
print("Dataset Initialized.")

# Batches of `batch_size` graphs, configured with `investigated_mode`.
batch_manager = BatchManager(dataset=dataset, batch_size=batch_size, mode=investigated_mode)

# Initializing Model

In [None]:
print(f"Initializing {model_name} model...")
if model_name == "EGSST":
    from Models.EGSST.EGSST import EGSST
    model = EGSST(
        dataset_information = dataset.get_info(),
        detection_head_config = str(detection_head_config)
    )

    def transform_graph(graph):
        """Preprocess one raw graph for EGSST inference.

        Truncates ``graph.x`` / ``graph.pos`` to the first ``inference_event_count``
        rows, runs ``model.data_transform`` with the notebook's preprocessing
        parameters, and returns ``None`` when that transform yields no usable graph.

        When ``graph.bbox`` is a dict (presumably keyed by annotation timestamp --
        TODO confirm against the dataset), it is replaced by the entry whose key is
        closest to the largest value in ``graph.pos[:, 2]``.
        """
        graph.x = graph.x[:inference_event_count, :]
        graph.pos = graph.pos[:inference_event_count, :]
        graph = model.data_transform(graph, beta = beta, radius = radius, min_nodes_subgraph = min_nodes_subgraph)

        if graph is None or graph.pos is None:
            return None

        if isinstance(graph.bbox, dict):
            maximum_time = graph.pos[:, 2].max()
            box_times = list(graph.bbox.keys())
            # Absolute distance gives the same ordering as the squared distance but
            # avoids squaring large time values in float32.
            time_distance = (torch.tensor(box_times) - maximum_time).abs()
            graph.bbox = graph.bbox[box_times[int(time_distance.argmin())]]

        return graph

    graph_transform = transform_graph
else:
    raise ValueError(f"Model {model_name} not implemented.")
print("Model Initialized.")

print(f"Loading model from: {model_path}")
# Map the checkpoint onto the CPU first: a checkpoint saved from a CUDA device
# otherwise fails to deserialize on a CPU-only machine. The model is moved to the
# selected device below. torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict)
print("Model Loaded.")

# Every sample read from the dataset now goes through the preprocessing above.
dataset.transform = graph_transform

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Using device: {device}")

# mAP Computation

In [None]:
def _ground_truth_to_metric_rows(bbox: np.ndarray) -> np.ndarray:
    """Convert ground-truth boxes to the ``mean_average_precision`` GT row layout.

    Assumes each row of *bbox* is ``[class_id, x_center, y_center, width, height]``
    (inferred from the arithmetic below -- TODO confirm against the dataset).

    Returns an ``(N, 7)`` array of ``[x_min, y_min, x_max, y_max, class_id,
    difficult, crowd]`` rows with ``difficult`` and ``crowd`` set to 0.
    """
    labels = bbox[:, 0]
    half_width = bbox[:, 3] / 2
    half_height = bbox[:, 4] / 2
    zeros = np.zeros_like(labels)
    # np.stack works on every NumPy version; np.concat only exists from NumPy 2.0.
    return np.stack(
        [
            bbox[:, 1] - half_width,
            bbox[:, 2] - half_height,
            bbox[:, 1] + half_width,
            bbox[:, 2] + half_height,
            labels,
            zeros,
            zeros,
        ],
        axis=1,
    )


def _predictions_to_metric_rows(prediction: dict) -> np.ndarray:
    """Convert one model prediction dict to ``[box..., label, score]`` rows.

    Expects ``prediction["boxes"]`` to already be in the corner layout the metric
    uses (assumed ``x_min, y_min, x_max, y_max`` -- TODO confirm against the head's
    post-processing). Labels and scores are cast to the box dtype so the integer
    labels concatenate cleanly with the float boxes.
    """
    boxes = prediction["boxes"]
    return torch.cat(
        [
            boxes,
            prediction["labels"][:, None].to(boxes.dtype),
            prediction["scores"][:, None].to(boxes.dtype),
        ],
        dim=1,
    ).detach().cpu().numpy()


metric_fn = MetricBuilder.build_evaluation_metric("map_2d", async_mode=False, num_classes=len(dataset.get_info().classes))

model.eval()
with torch.no_grad():
    # Each iteration evaluates one batch; `epoch_count` is the number of batches.
    for _ in tqdm(range(epoch_count), total = epoch_count, desc = "Assessing mAP"):
        batch = next(batch_manager).to(device)
        output = model(batch)

        for i in range(batch.num_graphs):
            gt_rows = _ground_truth_to_metric_rows(batch.get_example(i).bbox.detach().cpu().numpy())
            pred_rows = _predictions_to_metric_rows(output[i])
            metric_fn.add(pred_rows, gt_rows)

# Report mAP over the accumulated samples at several IoU thresholds.
for iou_threshold in (0.9, 0.5, 0.25, 0.1, 0.05):
    print(f"MAP@{iou_threshold}: {metric_fn.value(iou_thresholds = iou_threshold)['mAP']}")