# Import

In [None]:
from functools import partial
from pathlib import Path
from typing import Literal

import torch
from tqdm.auto import tqdm

from Datasets.base import DatasetMode
from Datasets.batching import BatchManager

# Parameter Setup

In [None]:
# Select which trained (dataset, model) pair to evaluate.
dataset_name: Literal["ncaltech", "ncars"] = "ncars"
model_name: Literal["AEGNN", "AEGNN-EVGNN", "EVGNN", "EGSST"] = "EVGNN"

# Folder name shared by this pair's checkpoint directory and its results directory.
run_name = f"Recognition_{model_name}_on_{dataset_name}"

# Filesystem locations (machine specific — adjust before running).
dataset_path: Path = Path("/home/benio/Documents/Datasets/NCars")
pretrained_model_path: Path | None = (
    Path("/home/benio/Documents/GNNBenchmark/Results") / run_name / "TrainedModels" / "500.pth"
)
# Relative to the notebook's working directory.
results_path: Path = Path("../Results") / run_name

# Evaluation parameters.
# `mode` handed to BatchManager. NOTE(review): "training" means accuracy is
# measured on the training split — confirm this is intended.
trained_mode: DatasetMode = "training"
# Number of batches drawn by the accuracy loop (one batch per iteration, not training epochs).
epoch_count: int = 100
batch_size: int = 8

# Graph preprocessing parameters, forwarded to each model's data_transform.
inference_event_count: int = 10000  # passed as n_samples
beta: float = 0.0001
radius: float = 5.0

# EGSST-specific parameters.
min_nodes_subgraph: int = 1000
# NOTE(review): the two flags below are not passed to EGSST in this notebook — confirm.
ecnn_flag: bool = True  # whether the enhanced CNN should be used
ti_flag: bool = True  # whether TAC augmentation should be used

# AEGNN-specific architecture parameters.
kernel_size: int = 8
pooling_outputs: int = 128
# Graph-building parameters used by the AEGNN and EvGNN transforms
# (`sampling` is also forwarded to EGSST as sub_sample).
max_neighbors: int = 32
sampling: bool = True
sampling: bool = True

# Dataset Initialization

In [None]:
# Resolve the dataset class first; the import stays inside each branch so only
# the selected dataset module is loaded.
if dataset_name == "ncars":
    from Datasets.ncars import NCars
    dataset_cls = NCars
elif dataset_name == "ncaltech":
    from Datasets.ncaltech101 import NCaltech
    dataset_cls = NCaltech
else:
    raise ValueError(f"Dataset {dataset_name} not implemented.")

dataset = dataset_cls(root=dataset_path)
# Explicit processing step before use (presumably builds/caches processed files — TODO confirm).
dataset.process()
print("Dataset Initialized.")

# `mode` selects which part of the dataset the manager serves batches from.
batch_manager = BatchManager(dataset=dataset, batch_size=batch_size, mode=trained_mode)

# Model Initialization

In [None]:
# Build the selected model, define the graph transform that matches it,
# optionally load pretrained weights, and move the model to the compute device.
dataset_info = dataset.get_info()
num_classes = len(dataset_info.classes)

if model_name == "AEGNN":
    from src.Models.CleanAEGNN.GraphRes import GraphRes as AEGNN
    model: AEGNN = AEGNN(
        input_shape=(*dataset_info.image_size, 3),
        kernel_size=kernel_size,
        n=[1, 16, 32, 32, 32, 128, 128, 128],
        pooling_outputs=pooling_outputs,
        num_outputs=num_classes,
    )

    graph_transform = partial(
        model.data_transform,
        n_samples=inference_event_count,
        sampling=sampling,
        beta=beta,
        radius=radius,
        max_neighbors=max_neighbors,
    )
elif model_name in ("EVGNN", "AEGNN-EVGNN"):
    from src.Models.CleanEvGNN.GraphRes_Base import GraphRes as EvGNN
    # Both variants share the EvGNN backbone; only the convolution type differs.
    model: EvGNN = EvGNN(
        input_shape=torch.tensor([*dataset_info.image_size, 3]),
        dataset=dataset_name,
        num_outputs=num_classes,
        conv_type="fuse" if model_name == "EVGNN" else "ori_aegnn",
        distill=False,  # no knowledge distillation
    )

    graph_transform = partial(
        model.data_transform,
        n_samples=inference_event_count,
        sampling=sampling,
        beta=beta,
        radius=radius,
        max_neighbors=max_neighbors,
    )
elif model_name == "EGSST":
    # NOTE(review): imported from `Models.` while the other models use `src.Models.`
    # — confirm which package root is on sys.path.
    from Models.EGSST.EGSST import EGSST

    # NOTE(review): ecnn_flag / ti_flag from the parameter cell are not passed here.
    model: EGSST = EGSST(
        dataset_information=dataset_info,
        detection_head_config="",
        task="cls",
    )

    # Must be named `graph_transform`: it is attached to the dataset below.
    def graph_transform(graph):
        """Apply EGSST preprocessing; return None when no usable graph (or positions) results."""
        graph = model.data_transform(
            graph,
            beta=beta,
            radius=radius,
            min_nodes_subgraph=min_nodes_subgraph,
            n_samples=inference_event_count,
            sub_sample=sampling,
        )

        if graph is None or graph.pos is None:
            return None

        return graph
else:
    raise ValueError(f"Model {model_name} not implemented.")

# Choose the device before loading so the checkpoint can be mapped onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if pretrained_model_path is not None:
    print(f"Loading model from: {pretrained_model_path}")
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    state_dict = torch.load(pretrained_model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    print("Model Loaded.")

dataset.transform = graph_transform

model.to(device)
print(f"Using device: {device}")

# Assessing Accuracy

In [None]:
# Measure classification accuracy over `epoch_count` batches and write a summary file.
model.eval()
predictions_made = 0
correct = 0

# Evaluation only: disable autograd so no computation graph is kept per batch.
with torch.no_grad():
    for _ in tqdm(range(epoch_count)):
        batch = next(batch_manager).to(device)
        reference = batch.y.long()

        prediction = model(batch).argmax(dim=-1)
        is_correct = prediction == reference
        correct += int(is_correct.sum().item())
        predictions_made += is_correct.shape[0]

if predictions_made == 0:
    raise RuntimeError("No predictions were made; check epoch_count and the batch manager.")

accuracy = correct / predictions_made * 100
# pretrained_model_path is typed `Path | None`; avoid AttributeError when no checkpoint was used.
model_label = (
    pretrained_model_path.name
    if pretrained_model_path is not None
    else "untrained (no checkpoint loaded)"
)

print(f"Accuracy: {accuracy:.2f}%")
# The results directory may not exist yet for a fresh model/dataset pair.
results_path.mkdir(parents=True, exist_ok=True)
with open(results_path / "accuracy_results.txt", "w", encoding="utf-8") as f:
    f.write(f"Analyzed model: {model_label}\n")
    f.write(f"Accuracy: {accuracy:.2f}%\n")
    f.write(f"   Predictions made: {predictions_made}\n")
    f.write(f"   Correctly Predicted: {correct}\n")