# Load torch device

In [1]:
import os

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam, AdamW

from Datasets.batching import BatchManager

# Seed torch's random number generators (CPU and CUDA) so that model weight
# initialisation and any torch-based randomness are reproducible on
# Restart & Run All. NOTE(review): if data_transform's sampling uses numpy or
# Python's `random`, those would need seeding too — not visible here.
SEED = 0
torch.manual_seed(SEED)

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


# Model Initialization

In [2]:
from Models.CleanEvGNN.recognition import RecognitionModel as EvGNN
from torch_geometric.data import Data as PyGData
from Datasets.ncaltech101 import NCaltech
# `device` comes from the first cell; it is intentionally not redefined here.
ncaltech_info = NCaltech.get_info()

# NOTE(review): the original note states image_size is (height=240, width=180);
# the axis order is defined by NCaltech and is not visible here — confirm it.
image_size: tuple[int, int] = ncaltech_info.image_size
input_shape: tuple[int, int, int] = (*image_size, 3)

# Swap the dimensions: node positions are [x, y] = [width, height], so the
# model is given (width, height) — i.e. (180, 240) instead of (240, 180).
img_shape_for_model = (image_size[1], image_size[0])

evgnn = EvGNN(
    network="graph_res",
    dataset="ncaltech101",
    num_classes=len(ncaltech_info.classes),
    img_shape=img_shape_for_model,  # (width=180, height=240)
    dim=3,
    conv_type="fuse",
    distill=False,  # no KD (knowledge distillation), just normal training
).to(device)


def transform_graph(graph: PyGData) -> PyGData:
    """Apply the model's own data transform to one graph and move it to `device`.

    Used as the `transform` of the NCaltech dataset. The fixed settings
    (n_samples, sampling, beta, radius, max_neighbors) are forwarded to
    `evgnn.data_transform` unchanged; what they do is defined by that method.

    Args:
        graph: a single PyG `Data` sample from the dataset.

    Returns:
        The transformed graph, placed on the global `device`.
    """
    transformed = evgnn.data_transform(
        graph,
        n_samples=25000,
        sampling=True,
        beta=0.5e-5,
        radius=5.0,
        max_neighbors=32,
    )
    return transformed.to(device)

# Dataset Initialization and Processing (using the pre-parsed N-Caltech101 dataset shared in the AEGNN issues thread)

In [3]:
# Root directory of the pre-parsed N-Caltech101 dataset. Set the NCALTECH_ROOT
# environment variable to run on another machine; the original author's local
# path is kept only as a fallback so existing behaviour is unchanged.
NCALTECH_ROOT = os.environ.get(
    "NCALTECH_ROOT",
    r"C:\Users\hanne\Documents\Hannes\Uni\Maastricht\Project\Datasets",
)

# Instantiate the N-Caltech101 dataset; every sample goes through transform_graph.
ncaltech = NCaltech(
    root=NCALTECH_ROOT,
    transform=transform_graph,
)

# Process only the training split of the dataset.
ncaltech.process(modes=["training"])

x

ðŸ“‚ Processing folder: accordion


accordion:   0%|          | 0/55 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: airplanes


airplanes:   0%|          | 0/800 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: anchor


anchor:   0%|          | 0/42 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: ant


ant:   0%|          | 0/42 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: barrel


barrel:   0%|          | 0/47 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: bass


bass:   0%|          | 0/54 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: beaver


beaver:   0%|          | 0/46 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: binocular


binocular:   0%|          | 0/33 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: bonsai


bonsai:   0%|          | 0/128 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: brain


brain:   0%|          | 0/98 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: brontosaurus


brontosaurus:   0%|          | 0/43 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: buddha


buddha:   0%|          | 0/85 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: butterfly


butterfly:   0%|          | 0/91 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: camera


camera:   0%|          | 0/50 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: cannon


cannon:   0%|          | 0/43 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: car_side


car_side:   0%|          | 0/123 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: ceiling_fan


ceiling_fan:   0%|          | 0/47 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: cellphone


cellphone:   0%|          | 0/59 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: chair


chair:   0%|          | 0/62 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: chandelier


chandelier:   0%|          | 0/107 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: cougar_body


cougar_body:   0%|          | 0/47 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: cougar_face


cougar_face:   0%|          | 0/69 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: crab


crab:   0%|          | 0/73 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: crayfish


crayfish:   0%|          | 0/70 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: crocodile


crocodile:   0%|          | 0/50 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: crocodile_head


crocodile_head:   0%|          | 0/51 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: cup


cup:   0%|          | 0/57 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: dalmatian


dalmatian:   0%|          | 0/67 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: dollar_bill


dollar_bill:   0%|          | 0/52 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: dolphin


dolphin:   0%|          | 0/65 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: dragonfly


dragonfly:   0%|          | 0/68 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: electric_guitar


electric_guitar:   0%|          | 0/75 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: elephant


elephant:   0%|          | 0/64 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: emu


emu:   0%|          | 0/53 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: euphonium


euphonium:   0%|          | 0/64 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: ewer


ewer:   0%|          | 0/85 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: Faces_easy


Faces_easy:   0%|          | 0/435 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: ferry


ferry:   0%|          | 0/67 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: flamingo


flamingo:   0%|          | 0/67 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: flamingo_head


flamingo_head:   0%|          | 0/45 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: garfield


garfield:   0%|          | 0/34 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: gerenuk


gerenuk:   0%|          | 0/34 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: gramophone


gramophone:   0%|          | 0/51 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: grand_piano


grand_piano:   0%|          | 0/99 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: hawksbill


hawksbill:   0%|          | 0/100 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: headphone


headphone:   0%|          | 0/42 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: hedgehog


hedgehog:   0%|          | 0/54 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: helicopter


helicopter:   0%|          | 0/88 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: ibis


ibis:   0%|          | 0/80 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: inline_skate


inline_skate:   0%|          | 0/31 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: joshua_tree


joshua_tree:   0%|          | 0/64 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: kangaroo


kangaroo:   0%|          | 0/86 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: ketch


ketch:   0%|          | 0/114 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: lamp


lamp:   0%|          | 0/61 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: laptop


laptop:   0%|          | 0/81 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: Leopards


Leopards:   0%|          | 0/200 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: llama


llama:   0%|          | 0/78 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: lobster


lobster:   0%|          | 0/41 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: lotus


lotus:   0%|          | 0/66 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: mandolin


mandolin:   0%|          | 0/43 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: mayfly


mayfly:   0%|          | 0/40 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: menorah


menorah:   0%|          | 0/87 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: metronome


metronome:   0%|          | 0/32 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: minaret


minaret:   0%|          | 0/76 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: Motorbikes


Motorbikes:   0%|          | 0/798 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: nautilus


nautilus:   0%|          | 0/55 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: octopus


octopus:   0%|          | 0/35 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: okapi


okapi:   0%|          | 0/39 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: pagoda


pagoda:   0%|          | 0/47 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: panda


panda:   0%|          | 0/38 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: pigeon


pigeon:   0%|          | 0/45 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: pizza


pizza:   0%|          | 0/53 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: platypus


platypus:   0%|          | 0/34 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: pyramid


pyramid:   0%|          | 0/57 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: revolver


revolver:   0%|          | 0/82 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: rhino


rhino:   0%|          | 0/59 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: rooster


rooster:   0%|          | 0/49 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: saxophone


saxophone:   0%|          | 0/40 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: schooner


schooner:   0%|          | 0/63 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: scissors


scissors:   0%|          | 0/39 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: scorpion


scorpion:   0%|          | 0/84 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: sea_horse


sea_horse:   0%|          | 0/57 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: snoopy


snoopy:   0%|          | 0/35 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: soccer_ball


soccer_ball:   0%|          | 0/64 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: stapler


stapler:   0%|          | 0/45 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: starfish


starfish:   0%|          | 0/86 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: stegosaurus


stegosaurus:   0%|          | 0/59 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: stop_sign


stop_sign:   0%|          | 0/64 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: strawberry


strawberry:   0%|          | 0/35 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: sunflower


sunflower:   0%|          | 0/85 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: tick


tick:   0%|          | 0/49 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: trilobite


trilobite:   0%|          | 0/86 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: umbrella


umbrella:   0%|          | 0/75 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: watch


watch:   0%|          | 0/239 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: water_lilly


water_lilly:   0%|          | 0/37 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: wheelchair


wheelchair:   0%|          | 0/59 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: wild_cat


wild_cat:   0%|          | 0/34 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: windsor_chair


windsor_chair:   0%|          | 0/56 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: wrench


wrench:   0%|          | 0/39 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: yin_yang


yin_yang:   0%|          | 0/60 [00:00<?, ?it/s]

# Training Batches

TODO: display an example event data point (the cells below only set up batching and training).

In [4]:
# Batch iterator over the processed training split.
BATCH_SIZE = 8
training_set = BatchManager(dataset=ncaltech, batch_size=BATCH_SIZE, mode="training")

In [5]:
# Optimiser and classification objective.
LEARNING_RATE = 5e-5
optimizer = Adam(evgnn.parameters(), lr=LEARNING_RATE)
loss_fn = CrossEntropyLoss()

# Integer label for each class name = its position in the dataset's class list.
classes = ncaltech.get_info().classes
cls_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}

In [None]:
NUM_ITERATIONS = 200
CHECKPOINT_PATH = "evgnn_ncaltech_test.pth"

evgnn.train()
for i in range(NUM_ITERATIONS):
    # Clear gradients before the backward pass. Zeroing only at the end of the
    # loop lets an interrupted/errored run leave stale gradients behind, which
    # would then be added into the first step when this cell is re-run.
    optimizer.zero_grad()

    examples = next(training_set)
    # Target class indices, created directly on the training device.
    reference = torch.tensor(
        [cls_to_idx[cls] for cls in examples.label],
        dtype=torch.long,
        device=device,
    )

    out = evgnn(examples)
    loss = loss_fn(out, reference)
    loss.backward()
    optimizer.step()
    print(f"Iteration {i} loss: {loss.item()}")

torch.save(evgnn.state_dict(), CHECKPOINT_PATH)

Iteration 0 loss: 4.840540885925293
Iteration 1 loss: 4.454733371734619
Iteration 2 loss: 4.309764385223389
Iteration 3 loss: 5.130025863647461
Iteration 4 loss: 4.360721111297607
Iteration 5 loss: 3.81257963180542
Iteration 6 loss: 3.5215659141540527
Iteration 7 loss: 4.5277838706970215
Iteration 8 loss: 4.685369968414307
Iteration 9 loss: 4.981554985046387
Iteration 10 loss: 5.860132217407227
Iteration 11 loss: 5.223621845245361
Iteration 12 loss: 5.345272541046143
Iteration 13 loss: 4.210827827453613
Iteration 14 loss: 3.5913376808166504
Iteration 15 loss: 2.6900486946105957
Iteration 16 loss: 4.999479293823242
Iteration 17 loss: 5.074362277984619
Iteration 18 loss: 3.7000162601470947
Iteration 19 loss: 3.8278160095214844
