# Select the torch device

In [3]:
import os

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam, AdamW

from Datasets.batching import BatchManager

# Seed torch's RNGs (on all devices) before the model is built so that weight
# initialisation and any torch-based random event sampling in the graph
# transform are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


# Model Initialization

In [4]:
from Models.CleanAEGNN.GraphRes import GraphRes as AEGNN
from torch_geometric.data import Data as PyGData
from Datasets.ncaltech101 import NCaltech

# Dataset metadata is read once and shared by the shape and class-count arguments.
ncaltech_info = NCaltech.get_info()

image_size: tuple[int, int] = ncaltech_info.image_size
# Sensor resolution plus a trailing dimension of 3; presumably the per-event
# coordinate/feature width GraphRes expects — TODO confirm against GraphRes.
input_shape: tuple[int, int, int] = (*image_size, 3)

aegnn = AEGNN(
    input_shape=input_shape,
    kernel_size=8,
    # Presumably the per-layer channel widths, input layer first — verify in GraphRes.
    n=[1, 16, 32, 32, 32, 128, 128, 128],
    pooling_outputs=128,
    # One output logit per NCaltech class.
    num_outputs=len(ncaltech_info.classes),
).to(device)

def transform_graph(graph: PyGData) -> PyGData:
    """Prepare one raw event graph for the model and move it to `device`.

    Delegates to the model's own ``data_transform`` with fixed settings
    (``n_samples=25000``, ``sampling=True``, ``beta=0.5e-5``, ``radius=5.0``,
    ``max_neighbors=32``). The exact meaning of these knobs is defined by
    GraphRes.data_transform; presumably it subsamples events and builds a
    radius graph — TODO confirm there.

    Args:
        graph: A single PyG ``Data`` sample from the dataset.

    Returns:
        The transformed graph, placed on the module-level ``device``.
    """
    prepared = aegnn.data_transform(
        graph,
        n_samples=25000,
        sampling=True,
        beta=0.5e-5,
        radius=5.0,
        max_neighbors=32,
    )
    return prepared.to(device)

# Dataset Initialization and Processing

Uses the pre-parsed NCaltech101 dataset shared in the AEGNN repository's issues thread.

In [5]:
# Root directory of the pre-parsed NCaltech101 dataset. Set the NCALTECH_ROOT
# environment variable to point elsewhere instead of hard-coding a
# machine-specific absolute path here. The default is relative to the working
# directory, i.e. the project root this notebook is launched from (the same
# directory that makes the `Datasets` package importable).
NCALTECH_ROOT = os.environ.get("NCALTECH_ROOT", os.path.join("Datasets", "NCaltech"))

# Instantiate the NCaltech dataset; `transform_graph` is handed to it as the
# per-sample transform.
ncaltech = NCaltech(
    root=NCALTECH_ROOT,
    transform=transform_graph,
)

# Process only the training split of the dataset.
ncaltech.process(modes=["training"])

FileNotFoundError: [WinError 3] Das System kann den angegebenen Pfad nicht finden: 'D:\\Uniwersytet\\GNNBenchmarking\\Datasets\\NCaltech\\Caltech101_annotations'

# Training Batches

Wrap the processed training split in a `BatchManager` that yields batches of graphs for the training loop.

In [4]:
# Number of graphs drawn per optimisation step.
TRAIN_BATCH_SIZE = 8

# Batches are pulled from the training split of the processed dataset.
training_set = BatchManager(dataset=ncaltech, batch_size=TRAIN_BATCH_SIZE, mode="training")

In [5]:
LEARNING_RATE = 5e-5

optimizer = Adam(aegnn.parameters(), lr=LEARNING_RATE)
loss_fn = CrossEntropyLoss()

classes = ncaltech.get_info().classes

# Integer label for each class name: its position in `classes`
# (CrossEntropyLoss expects class indices as targets).
cls_to_idx = {cls: idx for idx, cls in enumerate(classes)}

In [6]:
NUM_ITERATIONS = 200
CHECKPOINT_PATH = "aegnn_ncaltech.pth"

aegnn.train()
for i in range(NUM_ITERATIONS):
    # Clear gradients *before* the backward pass. If an earlier run of this cell
    # was interrupted between backward() and the end of the loop, stale
    # gradients left on the parameters would otherwise be added into this step.
    optimizer.zero_grad()

    examples = next(training_set)
    # Class-index targets for CrossEntropyLoss, created directly on the model's
    # device. Presumably `examples.label` holds one class name per graph in the
    # batch — TODO confirm against BatchManager.
    reference = torch.tensor(
        [cls_to_idx[cls] for cls in examples.label],
        dtype=torch.long,
        device=device,
    )

    out = aegnn(examples)
    loss = loss_fn(out, reference)
    loss.backward()
    optimizer.step()
    print(f"Iteration {i} loss: {loss.item()}")

torch.save(aegnn.state_dict(), CHECKPOINT_PATH)

Iteration 0 loss: 17.573427200317383
Iteration 1 loss: 17.83142852783203
Iteration 2 loss: 9.329046249389648
Iteration 3 loss: 11.488554954528809
Iteration 4 loss: 9.886789321899414
Iteration 5 loss: 9.645904541015625
Iteration 6 loss: 8.793233871459961
Iteration 7 loss: 7.6187944412231445
Iteration 8 loss: 9.701828002929688
Iteration 9 loss: 8.32140064239502
Iteration 10 loss: 9.481127738952637
Iteration 11 loss: 8.300086975097656
Iteration 12 loss: 8.616653442382812
Iteration 13 loss: 9.078987121582031
Iteration 14 loss: 7.399202346801758
Iteration 15 loss: 4.016829490661621
Iteration 16 loss: 7.69382905960083
Iteration 17 loss: 7.873544692993164
Iteration 18 loss: 5.749243259429932
Iteration 19 loss: 6.941044330596924
Iteration 20 loss: 6.451145648956299
Iteration 21 loss: 8.85360336303711
Iteration 22 loss: 6.640111923217773
Iteration 23 loss: 7.282352924346924
Iteration 24 loss: 4.321983337402344
Iteration 25 loss: 4.7192816734313965
Iteration 26 loss: 5.341638565063477
Iteration 