# Select the torch device (CUDA if available, otherwise CPU)

In [None]:
from Datasets.ncaltech101 import NCaltech
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from Datasets.batching import BatchManager

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Model Initialization

In [None]:
from Models.CleanAEGNN.GraphRes import GraphRes as AEGNN
from torch_geometric.data import Data as PyGData

# Fetch the dataset metadata once; it provides both the image size used for
# the model input and the class list that sizes the model output.
dataset_info = NCaltech.get_info()
image_size: tuple[int, int] = dataset_info.image_size
# Input shape is the image size with a trailing dimension of 3 appended
# (its meaning is defined by GraphRes — TODO confirm).
input_shape: tuple[int, int, int] = (*image_size, 3)

model = AEGNN(
    input_shape=input_shape,
    kernel_size=8,
    n=[1, 16, 32, 32, 32, 128, 128, 128],
    pooling_outputs=128,
    num_outputs=len(dataset_info.classes),
).to(device)

def transform_graph(graph: PyGData) -> PyGData:
    """Apply the model's own preprocessing to a raw dataset graph.

    Delegates to ``model.data_transform`` with fixed sampling/graph-building
    settings and moves the resulting graph to the global ``device``. The exact
    meaning of each keyword is defined by ``GraphRes.data_transform``.

    Args:
        graph: A single PyG ``Data`` sample as produced by the dataset.

    Returns:
        The transformed graph, placed on ``device``.
    """
    transformed = model.data_transform(
        graph,
        n_samples=25000,
        sampling=True,
        beta=0.5e-5,
        radius=5.0,
        max_neighbors=32,
    )
    return transformed.to(device)

# Dataset Initialization and Processing (using the pre-parsed dataset shared in the AEGNN issues thread)

In [None]:
# Instantiate the NCaltech dataset.
# The dataset root can be overridden with the NCALTECH_ROOT environment
# variable so this notebook runs on machines other than the original one;
# the previously hard-coded path remains the default.
import os

ncaltech_root = os.environ.get(
    "NCALTECH_ROOT",
    r"D:\Uniwersytet\GNNBenchmarking\Datasets\NCaltech",
)
ncaltech = NCaltech(
    root=ncaltech_root,
    transform=transform_graph,
)

# Process only the training split of the dataset.
ncaltech.process(modes=["training"])

# Prepare the training batch manager (TODO: display an example event data point)

In [None]:
# Wrap the processed training split in a batch iterator (8 samples per batch);
# the training loop pulls batches from it with next().
training_set = BatchManager(dataset=ncaltech, batch_size=8, mode="training")

In [None]:
# Optimizer and classification loss for training.
optimizer = Adam(model.parameters(), lr=5e-5)
loss_fn = CrossEntropyLoss()

# Map each class name to its position in the dataset's class list; these
# integer indices are the targets fed to CrossEntropyLoss.
classes = ncaltech.get_info().classes
cls_to_idx = {cls: idx for idx, cls in enumerate(classes)}

In [None]:
# Short training run: a fixed number of optimization steps over training batches.
model.train()
num_iterations = 50
for i in range(num_iterations):
    # Clear gradients *before* the backward pass so that stale gradients
    # (e.g. left over from an interrupted or re-run cell) cannot leak into
    # this step's update.
    optimizer.zero_grad()

    examples = next(training_set)
    # Integer class targets, created directly on the model's device.
    reference = torch.tensor(
        [cls_to_idx[cls] for cls in examples.label],
        dtype=torch.long,
        device=device,
    )

    out = model(examples)
    loss = loss_fn(out, reference)
    loss.backward()
    optimizer.step()
    print(f"Iteration {i} loss: {loss.item()}")

# torch.save(model.state_dict(), "aegnn_ncaltech.pth")