# Imports

In [None]:
import json
import os
from pathlib import Path
from typing import Literal

import torch
from torch.optim import AdamW

from Benchmarks.ModelTester import ModelTester
from Datasets.base import DatasetMode
from Datasets.batching import BatchManager

# Parameter Setup

In [None]:
# ---------------------------------------------------------------------------
# Experiment selection
# ---------------------------------------------------------------------------
# Which dataset / model pair this notebook run trains and benchmarks.
dataset_name: Literal["NCaltech101", "Gen1"] = "Gen1"
model_name: Literal["EGSST", "AEGNN"] = "EGSST"

# ---------------------------------------------------------------------------
# Filesystem locations
# ---------------------------------------------------------------------------
# Root directory of the dataset on the local machine (machine-specific).
dataset_path: Path = Path(r"D:\Uniwersytet\GNNBenchmarking\Datasets\GEN1-DS")
# Optional checkpoint to start from; None means training from scratch.
pretrained_model_path: Path | None = None
# Every artifact of this run (checkpoints, logs) goes below this directory.
results_path: Path = Path(r"../Results") / f"Detection_{model_name}_on_{dataset_name}"

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
# Dataset split fed to the batch manager and to the performance test.
trained_mode: DatasetMode = "validation"
# Number of training-loop iterations; each iteration consumes one batch.
epoch_count: int = 1000
batch_size: int = 8

# Optimizer step size and the ReduceLROnPlateau settings.
learning_rate: float = 4e-4
scheduler_patience: int = 100
scheduler_factor: float = 0.5

# A checkpoint is written every `saving_frequency` iterations.
saving_frequency: int = 100

# ---------------------------------------------------------------------------
# Graph preprocessing (forwarded to the model's data_transform)
# ---------------------------------------------------------------------------
inference_event_count: int = 10_000
beta: float = 1e-4
radius: float = 5.0

# ---------------------------------------------------------------------------
# Performance analysis (forwarded to ModelTester.test_model_performance)
# ---------------------------------------------------------------------------
sampled_graphs: int = 50
batch_sizes: list[int] = [1, 2, 4, 8]
test_sizes: list[int] = [100, 100, 100, 100]
# When True, ModelTester.detail_model_parameters() is also run.
detail_model_parameters: bool = False

# ---------------------------------------------------------------------------
# EGSST-only settings
# ---------------------------------------------------------------------------
min_nodes_subgraph: int = 1000
detection_head_config: Path = Path(r"../confs/rtdtr_head_gen1.yml")
ecnn_flag: bool = True  # Whether enhanced cnn should be used
ti_flag: bool = True  # Whether TAC augmentation should be used

# ---------------------------------------------------------------------------
# AEGNN-only settings
# ---------------------------------------------------------------------------
kernel_size: int = 8
pooling_outputs: int = 128
max_neighbors: int = 32
sampling: bool = True

# Loading Dataset

In [None]:
print(f"Loading {dataset_name} dataset...")

# Import only the dataset implementation that was selected, then build it
# through a single construction site.
if dataset_name == "NCaltech101":
    from Datasets.ncaltech101 import NCaltech
    dataset_class = NCaltech
elif dataset_name == "Gen1":
    from Datasets.gen1 import Gen1
    dataset_class = Gen1
else:
    raise ValueError(f"Dataset {dataset_name} not implemented.")

dataset = dataset_class(root=dataset_path)

# Run the dataset's own processing step before any batches are requested.
dataset.process()
print("Dataset Initialized.")

# Batches are drawn from the configured split with the configured size.
batch_manager = BatchManager(dataset=dataset, batch_size=batch_size, mode=trained_mode)

# Initializing Model

In [None]:
print(f"Initializing {model_name} model...")


def _select_closest_bbox(graph):
    """Collapse a time-keyed ``graph.bbox`` dict to its temporally closest entry.

    If ``graph.bbox`` is a dict, its keys are compared with the largest value
    found in column 2 of ``graph.pos``; the value stored under the key with
    the smallest squared difference replaces the dict. A ``bbox`` that is not
    a dict is left untouched.

    Returns the same graph object, modified in place.
    """
    if isinstance(graph.bbox, dict):
        box_times = list(graph.bbox.keys())
        maximum_time = graph.pos[:, 2].max()
        time_diff = (torch.tensor(box_times) - maximum_time) ** 2
        graph.bbox = graph.bbox[box_times[time_diff.argmin()]]
    return graph


if model_name == "EGSST":
    from Models.EGSST.EGSST import EGSST
    # NOTE(review): ecnn_flag and ti_flag are configured and recorded as
    # hyperparameters, but they are not passed to EGSST here - confirm
    # whether the constructor should receive them.
    model = EGSST(
        dataset_information = dataset.get_info(),
        detection_head_config = str(detection_head_config)
    )

    def transform_graph(graph):
        """Prepare one raw graph for EGSST; returns None if nothing usable remains."""
        # Keep only the first `inference_event_count` rows of features/positions.
        graph.x = graph.x[:inference_event_count, :]
        graph.pos = graph.pos[:inference_event_count, :]
        graph = model.data_transform(graph, beta = beta, radius = radius, min_nodes_subgraph = min_nodes_subgraph)

        if graph is None or graph.pos is None:
            return None

        return _select_closest_bbox(graph)

    graph_transform = transform_graph
elif model_name == "AEGNN":
    from Models.CleanAEGNN.AEGNN_Detection import AEGNN_Detection
    model = AEGNN_Detection(
        input_shape = (*dataset.get_info().image_size, 3),
        kernel_size = kernel_size,
        n = [1, 16, 32, 32, 32, 128, 128, 128],
        pooling_outputs = pooling_outputs,
        num_classes = len(dataset.get_info().classes),
    )

    def transform_graph(graph):
        """Prepare one raw graph for AEGNN using the model's own data transform."""
        graph = model.data_transform(
            graph, n_samples = inference_event_count, sampling = sampling,
            beta = beta, radius = radius,
            max_neighbors = max_neighbors
        )

        return _select_closest_bbox(graph)

    graph_transform = transform_graph
else:
    raise ValueError(f"Model {model_name} not implemented.")
print("Model Initialized.")

if pretrained_model_path is not None:
    print(f"Loading model from: {pretrained_model_path}")
    state_dict = torch.load(pretrained_model_path, map_location=torch.device("cpu"))
    # load_state_dict() updates the module in place and returns a report of
    # missing/unexpected keys, NOT the module. Rebinding `model` to that
    # return value would break model.to(device) below and the `model` used
    # inside transform_graph, so the result is deliberately not assigned.
    model.load_state_dict(state_dict)
    print("Model Loaded.")

# From here on the dataset applies the model-specific preprocessing.
dataset.transform = graph_transform

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"Using device: {device}")

# Performance Check

In [None]:
model_tester = ModelTester(results_path=results_path, model=model)

# Preprocessing settings shared by both models; model-specific entries are
# appended below so the recorded section order stays the same.
preprocessing_parameters = {
    "inference_event_count": inference_event_count,
    "beta": beta,
    "radius": radius,
}
model_hyperparameters = {"Data Preprocessing": preprocessing_parameters}

if model_name == "EGSST":
    preprocessing_parameters["min_nodes_subgraph"] = min_nodes_subgraph
    # NOTE(review): detection_head_config is recorded as a Path object here;
    # confirm that ModelTester can serialize it.
    model_hyperparameters["EGSST Internal"] = {
        "detection_head_config": detection_head_config,
        "ecnn_flag": ecnn_flag,
        "ti_flag": ti_flag,
    }
elif model_name == "AEGNN":
    preprocessing_parameters["sampling"] = sampling
    preprocessing_parameters["max_neighbors"] = max_neighbors
    model_hyperparameters["AEGNN Internal"] = {
        "kernel_size": kernel_size,
        "pooling_outputs": pooling_outputs,
    }

print("Recording Model's Hyperparameters...")
model_tester.record_model_hyperparameters(model_hyperparameters)

# The performance benchmark is skipped for AEGNN.
if model_name != "AEGNN":
    print("Assessing Model's performance Metrics...")
    model_tester.test_model_performance(
        dataset=dataset,
        mode=trained_mode,
        sampled_count=sampled_graphs,
        batch_sizes=batch_sizes,
        test_sizes=test_sizes,
        device=device,
    )

# Optional, opt-in parameter breakdown.
if detail_model_parameters:
    print("Detailing Model's Parameters...")
    model_tester.detail_model_parameters()

# Training Setup

In [None]:
# AdamW over all model parameters at the configured learning rate.
optimizer = AdamW(model.parameters(), lr=learning_rate)

# Multiply the learning rate by `scheduler_factor` once the monitored value
# has stopped decreasing for `scheduler_patience` scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=scheduler_factor,
    patience=scheduler_patience,
)

# Make sure the checkpoint directory exists before training starts.
os.makedirs(results_path / "TrainedModels", exist_ok=True)

# Training

In [None]:
# Per-component loss history (loss name -> list of floats), dumped to JSON
# once training has finished.
loss_lists = {}
best_loss = float("inf")
epoch_of_best_loss = 0

# Loss components that are printed and logged when the model reports them.
tracked_losses = ("loss_bbox", "loss_giou", "loss_ce")
checkpoint_dir = results_path / "TrainedModels"

with model_tester:
    model.train()

    # NOTE: each "epoch" here is a single optimization step on one batch.
    for epoch in range(1, epoch_count + 1):
        batch = next(batch_manager)
        batch = batch.to(device)

        # ---- forward + backward ----
        optimizer.zero_grad()
        total_loss, loss_dict = model(batch)
        total_loss.backward()
        optimizer.step()

        # Hand the scheduler a plain float rather than the graph-attached tensor.
        epoch_loss = total_loss.item()
        scheduler.step(epoch_loss)

        if epoch_loss < best_loss:
            best_loss = epoch_loss
            epoch_of_best_loss = epoch

        print(f"Epoch: {epoch} | Learning Rate: {optimizer.param_groups[0]['lr']}")
        print(f"Total loss: {epoch_loss:.4f} | Epoch of best loss: {epoch_of_best_loss}")
        for loss_name in tracked_losses:
            # A component the model did not report is skipped entirely;
            # indexing loss_dict with it would raise KeyError.
            if loss_name not in loss_dict:
                continue

            loss_value = loss_dict[loss_name].item()
            print(f"  - {loss_name}: {loss_value:.4f}")
            loss_lists.setdefault(loss_name, []).append(loss_value)

        # Checkpoint periodically, on every new best loss, and at the very end.
        if epoch % saving_frequency == 0 or epoch == epoch_of_best_loss or epoch == epoch_count:
            torch.save(model.state_dict(), checkpoint_dir / f"{epoch}.pth")

with open(results_path / "training_loss_log.json", 'w') as f:
    json.dump(loss_lists, f, indent=4)

# Power Consumption

In [None]:
# Report the power consumption gathered by the ModelTester.
# NOTE(review): presumably this covers the `with model_tester:` training
# block above - confirm against ModelTester's implementation.
model_tester.print_power_consumption()