# Imports

In [None]:
import json
import os
from functools import partial
from pathlib import Path
from typing import Literal

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW

from Benchmarks.ModelTester import ModelTester
from Datasets.base import DatasetMode
from Datasets.batching import BatchManager

# Parameter Setup

In [None]:
# Choose a trained dataset and model
dataset_name: Literal["ncaltech", "ncars"] = "ncars"
model_name: Literal["AEGNN", "AEGNN-EVGNN", "EVGNN", "EGSST"] = "EVGNN"

# Set corresponding paths.
# The dataset root can be overridden with the GNN_BENCHMARK_DATASET_PATH environment
# variable so the notebook runs on other machines; the fallback is the original
# hard-coded path. It is NOT derived from `dataset_name`, so it must be changed
# together with it (the default points at the NCars data).
dataset_path: Path = Path(
    os.environ.get(
        "GNN_BENCHMARK_DATASET_PATH",
        r"D:\Uniwersytet\GNNBenchmarking\Datasets\NCars",
    )
)
# Optional checkpoint (a saved state_dict) to load before training; None keeps the
# freshly constructed model's weights.
pretrained_model_path: Path | None = None
# Relative path, resolved against the notebook's current working directory.
results_path: Path = Path(r"../Results") / f"Recognition_{model_name}_on_{dataset_name}"

# Training parameters
trained_mode: DatasetMode = "training"
# Number of optimisation steps: each "epoch" of the training loop draws ONE batch
# from the batch manager, not a full pass over the dataset.
epoch_count: int = 500
batch_size: int = 8

learning_rate: float = 4e-3
# ReduceLROnPlateau is stepped once per batch, so patience is counted in steps.
scheduler_patience: int = 100
scheduler_factor: float = 0.5

# A checkpoint is written every `saving_frequency` steps (and additionally whenever a
# new best loss is reached, and at the final step).
saving_frequency: int = 100

# Graph preprocessing parameters (forwarded to the model's data_transform)
inference_event_count: int = 10000
beta: float = 0.0001
radius: float = 3.

# Performance Analysis parameters (forwarded to ModelTester.test_model_performance).
# NOTE(review): test_sizes presumably pairs element-wise with batch_sizes — keep both
# lists the same length (TODO confirm against ModelTester).
sampled_graphs: int = 50
batch_sizes: list[int] = [1, 2, 4, 8]
test_sizes: list[int] = [100, 100, 100, 100]
detail_model_parameters: bool = False

# EGSST specific parameters
min_nodes_subgraph: int = 1000
ecnn_flag: bool = True  # Whether enhanced cnn should be used
ti_flag: bool = True  # Whether TAC augmentation should be used

# AEGNN specific parameters (max_neighbors and sampling are also used by the EvGNN variants)
kernel_size: int = 8
pooling_outputs: int = 128
max_neighbors: int = 32
sampling: bool = True

# Dataset Initialization

In [None]:
if dataset_name == "ncars":
    from Datasets.ncars import NCars
    dataset = NCars(
        root = dataset_path
    )
elif dataset_name == "ncaltech":
    from Datasets.ncaltech101 import NCaltech
    dataset = NCaltech(
        root = dataset_path
    )
else:
    raise ValueError(f"Dataset {dataset_name} not implemented.")

dataset.process()
print(f"Dataset Initialized.")

batch_manager = BatchManager(
    dataset = dataset,
    batch_size = batch_size,
    mode = trained_mode
)

# Model Initialization

In [None]:
# Build the selected model and the per-sample graph transform it expects.
# Model modules are imported inside the branches so only the chosen architecture is loaded.
if model_name == "AEGNN":
    from src.Models.CleanAEGNN.GraphRes import GraphRes as AEGNN
    model: AEGNN = AEGNN(
        input_shape=(*dataset.get_info().image_size, 3),
        kernel_size=kernel_size,
        n=[1, 16, 32, 32, 32, 128, 128, 128],
        pooling_outputs=pooling_outputs,
        num_outputs=len(dataset.get_info().classes),
    )

    graph_transform = partial(
        model.data_transform,
        n_samples=inference_event_count,
        sampling=sampling,
        beta=beta,
        radius=radius,
        max_neighbors=max_neighbors,
    )
elif model_name in ("EVGNN", "AEGNN-EVGNN"):
    from src.Models.CleanEvGNN.GraphRes_Base import GraphRes as EvGNN
    # Both variants use the same EvGNN class and differ only in `conv_type`:
    # "fuse" for EVGNN, "ori_aegnn" for AEGNN-EVGNN.
    model: EvGNN = EvGNN(
        input_shape=torch.tensor([*dataset.get_info().image_size, 3]),
        dataset=dataset_name,
        num_outputs=len(dataset.get_info().classes),
        conv_type="fuse" if model_name == "EVGNN" else "ori_aegnn",
        distill=False,  # no knowledge distillation, just normal training
    )

    graph_transform = partial(
        model.data_transform,
        n_samples=inference_event_count,
        sampling=sampling,
        beta=beta,
        radius=radius,
        max_neighbors=max_neighbors,
    )
elif model_name == "EGSST":
    # NOTE(review): unlike the other models this import has no `src.` prefix — confirm
    # both import roots resolve from the notebook's working directory.
    from Models.EGSST.EGSST import EGSST

    # NOTE(review): ecnn_flag / ti_flag are defined as parameters and recorded as
    # hyperparameters, but are not passed to this constructor — confirm they take effect.
    model: EGSST = EGSST(
        dataset_information=dataset.get_info(),
        detection_head_config="",
        task="cls",
    )

    def transform_graph(graph):
        """Keep the first `inference_event_count` events and apply EGSST's data transform.

        Returns None when the transform produces no graph or a graph without positions;
        presumably the data pipeline skips such samples — TODO confirm.
        """
        graph.x = graph.x[:inference_event_count, :]
        graph.pos = graph.pos[:inference_event_count, :]
        graph = model.data_transform(graph, beta=beta, radius=radius, min_nodes_subgraph=min_nodes_subgraph)

        if graph is None or graph.pos is None:
            return None

        return graph

    graph_transform = transform_graph
else:
    raise ValueError(f"Model {model_name} not implemented.")

if pretrained_model_path is not None:
    print(f"Loading model from: {pretrained_model_path}")
    # map_location="cpu" lets a checkpoint saved from a CUDA run load on a CPU-only
    # machine; load_state_dict copies the tensors into the model's own parameters,
    # and the model is moved to `device` just below.
    state_dict = torch.load(pretrained_model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    print("Model Loaded.")

dataset.transform = graph_transform

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Using device: {device}")

# Performance Check

In [None]:
model_tester = ModelTester(
    results_path=results_path,
    model=model,
)

# Hyperparameters recorded alongside the results; grouped by section.
model_hyperparameters = {
    "Data Preprocessing": {
        "inference_event_count": inference_event_count,
        "beta": beta,
        "radius": radius,
    }
}
preprocessing_hyperparameters = model_hyperparameters["Data Preprocessing"]

if model_name == "EGSST":
    preprocessing_hyperparameters["min_nodes_subgraph"] = min_nodes_subgraph
    model_hyperparameters["EGSST Internal"] = {
        "ecnn_flag": ecnn_flag,
        "ti_flag": ti_flag,
    }

# AEGNN, EVGNN and AEGNN-EVGNN all pass `sampling` and `max_neighbors` to their
# data_transform, so these are part of the recorded preprocessing for all three
# (previously EVGNN runs omitted them).
if model_name in ("AEGNN", "AEGNN-EVGNN", "EVGNN"):
    preprocessing_hyperparameters["sampling"] = sampling
    preprocessing_hyperparameters["max_neighbors"] = max_neighbors

if model_name == "AEGNN":
    model_hyperparameters["AEGNN Internal"] = {
        "kernel_size": kernel_size,
        "pooling_outputs": pooling_outputs,
    }

print("Recording Model's Hyperparameters...")
model_tester.record_model_hyperparameters(model_hyperparameters)

# NOTE(review): the performance benchmark is skipped for plain AEGNN; the reason is not
# visible here — confirm and document why.
if model_name != "AEGNN":
    print("Assessing Model's performance Metrics...")
    model_tester.test_model_performance(
        dataset=dataset,
        mode=trained_mode,
        sampled_count=sampled_graphs,
        batch_sizes=batch_sizes,
        test_sizes=test_sizes,
        device=device,
    )

if detail_model_parameters:
    print("Detailing Model's Parameters...")
    model_tester.detail_model_parameters()

# Training Setup

In [None]:
# Cross-entropy over the model's class logits.
loss_fn = CrossEntropyLoss()

optimizer = AdamW(model.parameters(), lr=learning_rate)
# Multiply the learning rate by `scheduler_factor` once the monitored (minimised) loss
# has stopped improving for `scheduler_patience` scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=scheduler_factor,
    patience=scheduler_patience,
)

# Destination for the checkpoints written during training (parents created as needed).
(results_path / "TrainedModels").mkdir(parents=True, exist_ok=True)

# Training

In [None]:
model.train()

losses = []
best_loss = float("inf")
epoch_of_best_loss = 0
checkpoint_dir = results_path / "TrainedModels"

# Each "epoch" below is a single optimisation step on one batch from batch_manager.
# model_tester wraps the loop as a context manager (presumably to record resource /
# power usage during training — TODO confirm).
with model_tester:
    for i in range(1, epoch_count + 1):
        batch = next(batch_manager).to(device)
        reference = batch.y.long()

        # Clear gradients BEFORE the backward pass so nothing left over from an earlier
        # cell or an interrupted run is accumulated into this step.
        optimizer.zero_grad()
        out = model(batch)
        loss = loss_fn(out, reference)
        loss.backward()
        optimizer.step()

        # One host/device sync per step; the scalar feeds the scheduler, logging and
        # best-loss tracking.
        loss_value = loss.item()
        scheduler.step(loss_value)

        if loss_value < best_loss:
            best_loss = loss_value
            epoch_of_best_loss = i

        print(f"Epoch {i} | Learning Rate: {optimizer.param_groups[0]['lr']:.2e} | Epoch of Best Loss: {epoch_of_best_loss} | Loss: {loss_value:.2e}")
        losses.append(loss_value)

        # Checkpoint periodically, on every new best loss, and at the final step.
        if i % saving_frequency == 0 or i == epoch_of_best_loss or i == epoch_count:
            torch.save(model.state_dict(), checkpoint_dir / f"{i}.pth")

with open(results_path / "training_loss_log.json", "w", encoding="utf-8") as f:
    json.dump(losses, f, indent=4)

# Power Consumption

In [None]:
# Print the power consumption reported by ModelTester (presumably what it measured while
# the `with model_tester:` training block was active — TODO confirm).
model_tester.print_power_consumption()