# Imports

In [1]:
import os
from pathlib import Path
from typing import Literal

import numpy as np
import torch
from mean_average_precision import MetricBuilder
from tqdm.auto import tqdm

from Datasets.base import DatasetMode
from Datasets.batching import BatchManager

  from .autonotebook import tqdm as notebook_tqdm


# Parameter Setup

In [2]:
# Choose an investigated dataset and model
dataset_name: Literal["NCaltech101", "Gen1"] = "Gen1"
model_name: Literal["EGSST", "AEGNN"] = "AEGNN"

# Set corresponding paths.
# The dataset root can be overridden through the DATASET_PATH environment variable so the
# notebook runs on other machines without editing this cell; the default is unchanged.
dataset_path: Path = Path(os.environ.get("DATASET_PATH", r"/home/benio/Documents/Datasets/Gen1"))
# NOTE(review): hard-coded for AEGNN on Gen1 — update when changing dataset_name / model_name.
results_path: Path = Path(r"../Results/Detection_AEGNN_on_Gen1")
model_path: Path = results_path / "TrainedModels" / "300.pth"

# mAP evaluation parameters
batch_size: int = 2
investigated_mode: DatasetMode = "validation"
# Number of batches drawn from the batch manager during evaluation (not training epochs).
epoch_count: int = 50

# Graph preprocessing parameters (passed to the data_transform of both models)
inference_event_count: int = 10000
beta: float = 0.0001
radius: float = 5.
sampling: bool = True  # Sub-sampling flag; used by both the EGSST and AEGNN transforms

# EGSST specific parameters
min_nodes_subgraph: int = 1000
# NOTE(review): Gen1-specific head config — confirm before evaluating on NCaltech101.
detection_head_config: Path = Path(r"../confs/rtdtr_head_gen1.yml")
ecnn_flag: bool = True # Whether enhanced cnn should be used
ti_flag: bool = True # Whether TAC augmentation should be used

# AEGNN specific parameters
kernel_size: int = 8
pooling_outputs: int = 128
max_neighbors: int = 32

# Loading Dataset

In [3]:
print(f"Loading {dataset_name} dataset...")

# Import only the selected dataset's module, then build it rooted at `dataset_path`.
if dataset_name == "NCaltech101":
    from Datasets.ncaltech101 import NCaltech
    dataset = NCaltech(root=dataset_path)
elif dataset_name == "Gen1":
    from Datasets.gen1 import Gen1
    dataset = Gen1(root=dataset_path)
else:
    raise ValueError(f"Dataset {dataset_name} not implemented.")

# Run the dataset's processing step before any batches are requested from it.
dataset.process()
print("Dataset Initialized.")

# Batches are drawn from the split named by `investigated_mode`.
batch_manager = BatchManager(dataset=dataset, batch_size=batch_size, mode=investigated_mode)

Loading Gen1 dataset...


training (train_c): 100%|██████████| 250/250 [00:00<00:00, 192752.94it/s]
validation (val_b): 100%|██████████| 179/179 [00:00<00:00, 194805.50it/s]

Dataset Initialized.





# Initializing Model

In [4]:
print(f"Initializing {model_name} model...")


def _select_nearest_bbox(graph):
    """Reduce a dict of candidate bounding boxes to the one nearest the graph's last event.

    If ``graph.bbox`` is a dict, its keys are compared with the maximum of
    ``graph.pos[:, 2]`` (presumably the event-time column, with keys being annotation
    times — TODO confirm against the dataset). ``graph.bbox`` is replaced in place by the
    value whose key has the smallest squared difference. Any non-dict ``graph.bbox`` is
    left untouched.

    Returns:
        The same ``graph`` object.
    """
    if isinstance(graph.bbox, dict):
        box_times = list(graph.bbox.keys())
        maximum_time = graph.pos[:, 2].max()
        time_diff = (torch.tensor(box_times) - maximum_time) ** 2
        graph.bbox = graph.bbox[box_times[int(time_diff.argmin())]]
    return graph


if model_name == "EGSST":
    from Models.EGSST.EGSST import EGSST
    model = EGSST(
        dataset_information = dataset.get_info(),
        detection_head_config = str(detection_head_config)
    )

    def transform_graph(graph):
        """Preprocess a raw sample for EGSST; returns None when no usable graph remains."""
        graph = model.data_transform(
            graph, beta = beta,
            radius = radius,
            min_nodes_subgraph = min_nodes_subgraph,
            n_samples = inference_event_count,
            sub_sample = sampling
        )
        if graph is None or graph.pos is None:
            return None
        return _select_nearest_bbox(graph)

elif model_name == "AEGNN":
    from Models.CleanAEGNN.AEGNN_Detection import AEGNN_Detection
    num_classes = len(dataset.get_info().classes)
    if dataset_name == "Gen1":
        # NOTE(review): one extra output class for Gen1 (presumably background) — confirm.
        # The mAP metric later in the notebook is built with len(classes) classes.
        num_classes += 1
    model = AEGNN_Detection(
        input_shape = (*dataset.get_info().image_size, 3),
        kernel_size = kernel_size,
        n = [1, 16, 32, 32, 32, 128, 128, 128],
        pooling_outputs = pooling_outputs,
        num_classes = num_classes,
    )

    def transform_graph(graph):
        """Preprocess a raw sample for AEGNN; returns None when no usable graph remains."""
        graph = model.data_transform(
            graph, n_samples = inference_event_count, sampling = sampling,
            beta =  beta, radius = radius,
            max_neighbors = max_neighbors
        )
        # Same guard as the EGSST transform: without it, a None result from data_transform
        # raises AttributeError on `graph.bbox`. Returning None matches what the EGSST path
        # already hands to the dataset (presumably skipped by BatchManager — confirm).
        if graph is None or graph.pos is None:
            return None
        return _select_nearest_bbox(graph)

else:
    raise ValueError(f"Model {model_name} not implemented.")

graph_transform = transform_graph
print(f"Model Initialized.")

print(f"Loading model from: {model_path}")
# NOTE(review): torch.load unpickles the checkpoint — only load trusted files
# (consider weights_only=True on PyTorch versions that support it).
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
model.load_state_dict(state_dict)
print("Model Loaded.")

# Every sample fetched from the dataset is now preprocessed by the selected model's transform.
dataset.transform = graph_transform

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"Using device: {device}")

Initializing AEGNN model...
Model Initialized.
Loading model from: ../Results/Detection_AEGNN_on_Gen1/TrainedModels/300.pth
Model Loaded.
Using device: cuda


# mAP Computation

In [5]:
def _to_map_ground_truth(bbox: np.ndarray) -> np.ndarray:
    """Convert ground-truth boxes to the 7-column layout passed to ``metric_fn.add``.

    Each input row is read as ``[label, x_center, y_center, width, height]``
    (inferred from the column indexing below — TODO confirm against the dataset).
    Output rows are ``[x_min, y_min, x_max, y_max, label, difficult, crowd]``, with the
    ``difficult`` and ``crowd`` flags set to 0.
    """
    labels = bbox[:, 0]
    half_widths = bbox[:, 3] / 2
    half_heights = bbox[:, 4] / 2
    flags = np.zeros_like(labels)
    # np.stack is available on every NumPy version, unlike np.concat (NumPy >= 2.0 only).
    return np.stack(
        [
            bbox[:, 1] - half_widths,
            bbox[:, 2] - half_heights,
            bbox[:, 1] + half_widths,
            bbox[:, 2] + half_heights,
            labels,
            flags,
            flags,
        ],
        axis = 1,
    )


def _to_map_predictions(prediction: dict) -> np.ndarray:
    """Convert one model output dict into rows of ``[x_min, y_min, x_max, y_max, label, score]``.

    Assumes ``prediction["boxes"]`` is ``(N, 4)`` in corner format — TODO confirm against
    the model's post-processing.
    """
    boxes = prediction["boxes"]
    # Labels may be an integer tensor; cast explicitly so the concatenation yields one dtype.
    return torch.cat(
        [
            boxes,
            prediction["labels"][:, None].to(boxes.dtype),
            prediction["scores"][:, None].to(boxes.dtype),
        ],
        dim = 1,
    ).detach().cpu().numpy()


metric_fn = MetricBuilder.build_evaluation_metric("map_2d", async_mode=False, num_classes=len(dataset.get_info().classes))

evaluated_batches = 0
failed_batches = 0
model.eval()
with torch.no_grad():
    # `epoch_count` is the number of batches drawn from the batch manager.
    for _ in tqdm(range(epoch_count), desc = "Assessing mAP"):
        try:
            batch = next(batch_manager).to(device)
            output = model(batch)

            for i in range(batch.num_graphs):
                gt_bbox = _to_map_ground_truth(batch.get_example(i).bbox.detach().cpu().numpy())
                metric_fn.add(_to_map_predictions(output[i]), gt_bbox)
            evaluated_batches += 1
        except StopIteration:
            # Exhausted iterators should keep raising StopIteration, so further draws are pointless.
            tqdm.write("Batch manager exhausted; stopping evaluation early.")
            break
        except Exception as error:
            # Best-effort evaluation: skip the failing batch, but make the failure visible
            # rather than silently producing an mAP computed on fewer (or zero) batches.
            failed_batches += 1
            tqdm.write(f"Skipping batch: {type(error).__name__}: {error}")

print(f"Evaluated {evaluated_batches} batch(es); {failed_batches} failed and were skipped.")

# One "MAP@<threshold>: <value>" line per IoU threshold.
iou_thresholds = (0.9, 0.5, 0.25, 0.1, 0.05)
with open(results_path / "mAP_results.txt", "w") as f:
    for threshold in iou_thresholds:
        f.write(f"MAP@{threshold}: {metric_fn.value(iou_thresholds = threshold)['mAP']}\n")

Assessing mAP: 100%|██████████| 50/50 [01:25<00:00,  1.72s/it]
