In [1]:
import os

# Set hyperparameters
# NOTE: the training loop calls next(training_set) once per iteration, so
# `epochs` is a number of optimizer steps (batches), not full dataset passes.
epochs = 1000
batch_size = 32  # optimal: 32
lr = 5e-4
dataset = 'ncaltech'  # 'ncars' or 'ncaltech'
convType = "fuse"  # fuse or ori_aegnn

# Raw dataset roots. Set NCARS_PATH / NCALTECH_PATH in the environment to point
# at another location without editing the notebook; the defaults are the
# original local paths.
ncars_path = os.environ.get(
    'NCARS_PATH',
    r'/Users/hannes/Documents/University/Datasets/raw_ncars/Prophesee_Dataset_n_cars',
)
ncaltech_path = os.environ.get(
    'NCALTECH_PATH',
    r'/Users/hannes/Documents/University/Datasets/raw_ncaltec',
)

# Later cells only branch on these two names; fail here instead of with a
# NameError several cells down.
if dataset not in ('ncars', 'ncaltech'):
    raise ValueError(f"dataset must be 'ncars' or 'ncaltech', got {dataset!r}")


# EvGNN Training Pipeline

In [2]:


import os
import sys

# Must be set BEFORE torch is imported: lets ops that MPS does not support
# fall back to the CPU.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

# The notebook's working directory is assumed to sit one level below the
# project root; make both the root and its `src/` folder importable.
notebook_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(notebook_dir, os.pardir))
src_path = os.path.join(project_root, 'src')

# Inserted in this order so project_root ends up ahead of src_path.
if src_path not in sys.path:
    sys.path.insert(0, src_path)
if project_root not in sys.path:
    sys.path.insert(0, project_root)
# If this sys.path setup does not work on your platform (e.g. Windows),
# comment it out and remove the 'src.' prefixes from the imports.

from src.Datasets.batching import BatchManager

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW



# Configuration: Device Selection

In [3]:
# Device options: 'auto', 'mps', 'cuda', 'cpu'
#   'auto' tries MPS, then CUDA, then falls back to CPU.
#   'mps' / 'cuda' fall back to CPU when that backend is unavailable.
USE_DEVICE = 'cpu'

CPU_THREADS = 8
# Apply settings (also governs ops that fall back from MPS to the CPU).
torch.set_num_threads(CPU_THREADS)

DEVICE_OPTIONS = ('auto', 'mps', 'cuda', 'cpu')


def select_device(preference: str, cpu_threads: int) -> torch.device:
    """Resolve a device preference string to a concrete ``torch.device``.

    Args:
        preference: One of ``DEVICE_OPTIONS``. ``'cpu'`` forces the CPU,
            ``'mps'``/``'cuda'`` request that backend (CPU if unavailable) and
            ``'auto'`` tries MPS, then CUDA, then CPU.
        cpu_threads: CPU thread count; only used in the printed status lines.

    Returns:
        The selected device. A status line describing the choice is printed.

    Raises:
        ValueError: If ``preference`` is not one of ``DEVICE_OPTIONS``, so a
            typo no longer silently behaves like ``'auto'``.
    """
    if preference not in DEVICE_OPTIONS:
        raise ValueError(
            f"USE_DEVICE must be one of {DEVICE_OPTIONS}, got {preference!r}"
        )

    if preference == 'cpu':
        print(f"Using device: cpu with {cpu_threads} threads (forced)")
        return torch.device("cpu")

    if preference == 'mps':
        if torch.backends.mps.is_available():
            print("Using device: mps (Apple Silicon GPU with CPU fallback for unsupported ops)")
            print(f"CPU operations will use {cpu_threads} threads")
            return torch.device("mps")
        print(f"MPS not available, falling back to cpu with {cpu_threads} threads")
        return torch.device("cpu")

    if preference == 'cuda':
        if torch.cuda.is_available():
            print("Using device: cuda")
            return torch.device("cuda")
        print(f"CUDA not available, falling back to cpu with {cpu_threads} threads")
        return torch.device("cpu")

    # 'auto': prefer MPS, then CUDA, then CPU.
    if torch.backends.mps.is_available():
        print("Using device: mps (Apple Silicon GPU with CPU fallback)")
        print(f"CPU operations will use {cpu_threads} threads")
        return torch.device("mps")
    if torch.cuda.is_available():
        print("Using device: cuda")
        return torch.device("cuda")
    print(f"Using device: cpu with {cpu_threads} threads")
    return torch.device("cpu")


device = select_device(USE_DEVICE, CPU_THREADS)

Using device: cpu with 8 threads (forced)


# Dataset Selection

In [4]:
from src.Datasets.ncaltech101 import NCaltech
from src.Datasets.ncars import NCars
# Map each supported dataset name to its class so an unsupported name fails
# loudly here instead of leaving num_classes / image_size undefined.
_dataset_classes = {'ncars': NCars, 'ncaltech': NCaltech}
if dataset not in _dataset_classes:
    raise ValueError(
        f"Unknown dataset {dataset!r}; expected one of {sorted(_dataset_classes)}"
    )

# Query the dataset metadata once instead of calling get_info() twice.
dataset_info = _dataset_classes[dataset].get_info()
num_classes = len(dataset_info.classes)
image_size = dataset_info.image_size

print(f"Dataset: {dataset}")
print(f"Number of classes: {num_classes}")

Dataset: ncaltech
Number of classes: 101


# Model Setup

In [5]:
from src.Models.CleanEvGNN.recognition import RecognitionModel as EvGNN
from torch_geometric.data import Data as PyGData

# The model receives the two image_size entries in reversed order. Labelled
# "(width, height)", which assumes image_size is (height, width) — NOTE(review):
# confirm against Dataset.get_info().
img_shape_for_model = image_size[1], image_size[0]

# Graph-residual recognition network with the conv variant chosen by convType.
evgnn = EvGNN(
    network="graph_res",
    dataset=dataset,
    num_classes=num_classes,
    img_shape=img_shape_for_model,
    dim=3,
    conv_type=convType,
    distill=False,
)
evgnn = evgnn.to(device)

def transform_graph(
    graph: PyGData,
    n_samples: int = 10000,
    sampling: bool = True,
    beta: float = 0.5e-5,
    radius: float = 3.0,
    max_neighbors: int = 16,
) -> PyGData:
    """Preprocess one raw graph with the model's data transform and move it to ``device``.

    Passed as the dataset ``transform`` in the loading cell. The keyword defaults
    reproduce the settings this notebook was run with, so a call with only
    ``graph`` behaves exactly as before; override them to experiment.

    Args:
        graph: Input graph sample.
        n_samples, sampling, beta, radius, max_neighbors: Forwarded unchanged to
            ``evgnn.data_transform`` (their semantics are defined there).

    Returns:
        The transformed graph, moved to the globally selected ``device``.
    """
    return evgnn.data_transform(
        graph, n_samples=n_samples, sampling=sampling,
        beta=beta, radius=radius, max_neighbors=max_neighbors,
    ).to(device)

# Dataset Loading

In [6]:
# Dataset class and raw-data root for each supported name; an unknown name
# fails loudly here instead of leaving dataset_obj undefined.
_dataset_sources = {
    'ncaltech': (NCaltech, ncaltech_path),
    'ncars': (NCars, ncars_path),
}
if dataset not in _dataset_sources:
    raise ValueError(
        f"Unknown dataset {dataset!r}; expected one of {sorted(_dataset_sources)}"
    )
_dataset_cls, _dataset_root = _dataset_sources[dataset]
# transform_graph is presumably applied per sample when the dataset is read.
dataset_obj = _dataset_cls(
    root=_dataset_root,
    transform=transform_graph
)

# Only the training split is processed here; the test split is processed
# in the evaluation cell.
dataset_obj.process(modes=["training"])
num_training_samples = dataset_obj.get_mode_length("training")
print(f"Training samples: {num_training_samples}")

# Batch source consumed with next() by the training loop.
training_set = BatchManager(
    dataset=dataset_obj,
    batch_size=batch_size,
    mode="training"
)

x

ðŸ“‚ Processing folder: gerenuk


gerenuk:   0%|          | 0/27 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: hawksbill


Error processing line 1 of /opt/anaconda3/envs/GNNBenchmark/lib/python3.11/site-packages/distutils-precedence.pth:



hawksbill:   0%|          | 0/80 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: headphone


  Traceback (most recent call last):
    File "<frozen site>", line 195, in addpackage
    File "<string>", line 1, in <module>
  ModuleNotFoundError: No module named '_distutils_hack'

Remainder of file ignored


headphone:   0%|          | 0/33 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: ant


ant:   0%|          | 0/33 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: butterfly


butterfly:   0%|          | 0/72 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: lamp


lamp:   0%|          | 0/48 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: strawberry


strawberry:   0%|          | 0/28 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: water_lilly


water_lilly:   0%|          | 0/29 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: chandelier


chandelier:   0%|          | 0/85 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: dragonfly


dragonfly:   0%|          | 0/54 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: crab


crab:   0%|          | 0/58 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: pagoda


pagoda:   0%|          | 0/37 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: dollar_bill


dollar_bill:   0%|          | 0/41 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: emu


emu:   0%|          | 0/42 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: inline_skate


inline_skate:   0%|          | 0/24 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: platypus


platypus:   0%|          | 0/27 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: dalmatian


dalmatian:   0%|          | 0/53 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: cup


cup:   0%|          | 0/45 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: airplanes


airplanes:   0%|          | 0/640 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: joshua_tree


joshua_tree:   0%|          | 0/51 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: cougar_body


cougar_body:   0%|          | 0/37 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: grand_piano


grand_piano:   0%|          | 0/79 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: trilobite


trilobite:   0%|          | 0/68 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: brontosaurus


brontosaurus:   0%|          | 0/34 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: wild_cat


wild_cat:   0%|          | 0/27 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: pigeon


pigeon:   0%|          | 0/36 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: dolphin


dolphin:   0%|          | 0/52 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: soccer_ball


soccer_ball:   0%|          | 0/51 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: wrench


wrench:   0%|          | 0/31 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: scorpion


scorpion:   0%|          | 0/67 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: flamingo_head


flamingo_head:   0%|          | 0/36 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: nautilus


nautilus:   0%|          | 0/44 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: accordion


accordion:   0%|          | 0/44 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: cougar_face


cougar_face:   0%|          | 0/55 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: pyramid


pyramid:   0%|          | 0/45 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: camera


camera:   0%|          | 0/40 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: barrel


barrel:   0%|          | 0/37 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: schooner


schooner:   0%|          | 0/50 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: cellphone


cellphone:   0%|          | 0/47 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: panda


panda:   0%|          | 0/30 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: revolver


revolver:   0%|          | 0/65 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: lobster


lobster:   0%|          | 0/32 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: menorah


menorah:   0%|          | 0/69 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: lotus


lotus:   0%|          | 0/52 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: stapler


stapler:   0%|          | 0/36 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: crocodile


crocodile:   0%|          | 0/40 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: chair


chair:   0%|          | 0/49 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: helicopter


helicopter:   0%|          | 0/70 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: minaret


minaret:   0%|          | 0/60 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: starfish


starfish:   0%|          | 0/68 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: ceiling_fan


ceiling_fan:   0%|          | 0/37 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: ketch


ketch:   0%|          | 0/91 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: mayfly


mayfly:   0%|          | 0/32 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: wheelchair


wheelchair:   0%|          | 0/47 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: bass


bass:   0%|          | 0/43 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: yin_yang


yin_yang:   0%|          | 0/48 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: crocodile_head


crocodile_head:   0%|          | 0/40 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: saxophone


saxophone:   0%|          | 0/32 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: beaver


beaver:   0%|          | 0/36 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: mandolin


mandolin:   0%|          | 0/34 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: bonsai


bonsai:   0%|          | 0/102 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: Leopards


Leopards:   0%|          | 0/160 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: car_side


car_side:   0%|          | 0/98 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: ibis


ibis:   0%|          | 0/64 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: electric_guitar


electric_guitar:   0%|          | 0/60 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: kangaroo


kangaroo:   0%|          | 0/68 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: stegosaurus


stegosaurus:   0%|          | 0/47 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: ferry


ferry:   0%|          | 0/53 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: snoopy


snoopy:   0%|          | 0/28 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: umbrella


umbrella:   0%|          | 0/60 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: rhino


rhino:   0%|          | 0/47 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: okapi


okapi:   0%|          | 0/31 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: watch


watch:   0%|          | 0/191 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: brain


brain:   0%|          | 0/78 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: gramophone


gramophone:   0%|          | 0/40 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: scissors


scissors:   0%|          | 0/31 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: rooster


rooster:   0%|          | 0/39 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: cannon


cannon:   0%|          | 0/34 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: binocular


binocular:   0%|          | 0/26 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: anchor


anchor:   0%|          | 0/33 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: octopus


octopus:   0%|          | 0/28 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: buddha


buddha:   0%|          | 0/68 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: laptop


laptop:   0%|          | 0/64 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: windsor_chair


windsor_chair:   0%|          | 0/44 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: hedgehog


hedgehog:   0%|          | 0/43 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: pizza


pizza:   0%|          | 0/42 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: euphonium


euphonium:   0%|          | 0/51 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: stop_sign


stop_sign:   0%|          | 0/51 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: Motorbikes


Motorbikes:   0%|          | 0/638 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: sea_horse


sea_horse:   0%|          | 0/45 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: flamingo


flamingo:   0%|          | 0/53 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: ewer


ewer:   0%|          | 0/68 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: garfield


garfield:   0%|          | 0/27 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: crayfish


crayfish:   0%|          | 0/56 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: Faces_easy


Faces_easy:   0%|          | 0/348 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: sunflower


sunflower:   0%|          | 0/68 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: llama


llama:   0%|          | 0/62 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: elephant


elephant:   0%|          | 0/51 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: tick


tick:   0%|          | 0/39 [00:00<?, ?it/s]


ðŸ“‚ Processing folder: metronome


metronome:   0%|          | 0/25 [00:00<?, ?it/s]

Training samples: 6559


# Training Loop

In [7]:



optimizer = AdamW(evgnn.parameters(), lr=lr, weight_decay=1e-7)
# The scheduler is stepped once per iteration with that iteration's batch loss,
# so patience / cooldown are counted in iterations, not epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=250, cooldown=25
)
loss_fn = CrossEntropyLoss()

print("Starting training...")

evgnn.train()
losses = []

# Each iteration draws exactly one batch, so `epochs` is a step count.
for i in range(epochs):
    optimizer.zero_grad()
    examples = next(training_set)
    reference = examples.y.to(device)

    out = evgnn(examples)
    loss = loss_fn(out, reference)

    # Stop on a diverged loss before backward/step can corrupt the weights.
    if not torch.isfinite(loss):
        print(f"Loss is NaN/Inf at iteration {i}!")
        break

    loss.backward()
    torch.nn.utils.clip_grad_norm_(evgnn.parameters(), max_norm=1.0)
    optimizer.step()

    # Convert to a Python float once and reuse it below.
    loss_value = loss.item()
    scheduler.step(loss_value)
    losses.append(loss_value)

    if i % 10 == 0:
        # Accuracy on the current training batch only (model in train mode).
        with torch.no_grad():
            predictions = out.argmax(dim=-1)
            accuracy = (predictions == reference).float().mean().item()
        print(f"Iteration {i:4d} | Loss: {loss_value:.4f} | Acc: {accuracy:.3f}")

print("\n✅ Training complete!")


Starting training...
Iteration    0 | Loss: 4.6521 | Acc: 0.000
Iteration   10 | Loss: 4.5185 | Acc: 0.219
Iteration   20 | Loss: 4.4660 | Acc: 0.344
Iteration   30 | Loss: 3.5567 | Acc: 0.312
Iteration   40 | Loss: 3.0597 | Acc: 0.250
Iteration   50 | Loss: 3.4164 | Acc: 0.406
Iteration   60 | Loss: 3.4108 | Acc: 0.344
Iteration   70 | Loss: 2.8955 | Acc: 0.344
Iteration   80 | Loss: 2.9931 | Acc: 0.406
Iteration   90 | Loss: 3.1978 | Acc: 0.375
Iteration  100 | Loss: 3.4000 | Acc: 0.344
Iteration  110 | Loss: 2.5911 | Acc: 0.375


KeyboardInterrupt: 

In [8]:
# Save model in results directory
os.makedirs('../results/TrainedModels', exist_ok=True)
model_path = f'../results/TrainedModels/evgnn_{dataset}_{convType}.pth'
torch.save(evgnn.state_dict(), model_path)
print(f"Model saved to: {model_path}")

Model saved to: ../results/TrainedModels/evgnn_ncaltech_fuse.pth


# Loss Visualization

In [None]:
import numpy as np
import matplotlib.pyplot as plt

window = 50  # moving-average width, in iterations

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(losses, alpha=0.3, color='blue', label='Raw loss')
# With fewer than `window` points (e.g. an early NaN break), np.convolve
# 'valid' raises (empty input) or returns values that do not line up with
# range(window-1, len(losses)), which makes the plot call fail.
if len(losses) >= window:
    moving_mean = np.convolve(losses, np.ones(window) / window, mode='valid')
    ax.plot(range(window - 1, len(losses)), moving_mean,
            color='orange', linewidth=2, label='Moving avg')
ax.set_title(f'{dataset.upper()} Training Loss')
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
ax.grid(True, alpha=0.3)
ax.legend()
plt.show()

# Test Evaluation

In [None]:
dataset_obj.process(modes=["test"])
num_test_samples = dataset_obj.get_mode_length("test")
print(f"Test samples: {num_test_samples}")

# One constant feeds both the BatchManager and the batch count, so the two
# cannot drift apart.
test_batch_size = 32
test_set = BatchManager(
    dataset=dataset_obj,
    batch_size=test_batch_size,
    mode="test"
)

evgnn.eval()
correct = 0
total = 0

print("Evaluating on test set...")
# Ceiling division so the final, possibly partial, batch is included.
num_test_batches = (num_test_samples + test_batch_size - 1) // test_batch_size
# NOTE(review): if BatchManager pads or wraps the last batch with samples from
# the start of the split, `total` can exceed num_test_samples and some samples
# are counted twice — confirm its end-of-split behaviour.

with torch.no_grad():
    for _ in range(num_test_batches):
        examples = next(test_set)
        reference = examples.y.to(device)
        out = evgnn(examples)
        predictions = out.argmax(dim=-1)

        correct += (predictions == reference).sum().item()
        total += reference.size(0)

if total == 0:
    raise RuntimeError("No test samples were evaluated; check the test split.")
test_accuracy = correct / total
print(f"\n{'='*50}")
print(f"TEST ACCURACY: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")
print(f"{'='*50}")

# Save results
import json

# Label the run with the conv variant actually used, so the JSON content,
# the file name and the printed path all agree.
results = {
    'dataset': dataset,
    'architecture': convType,
    'test_accuracy': float(test_accuracy),
    'num_test_samples': total
}

# NOTE(review): the model checkpoint goes to '../results/...', but this JSON
# goes to 'results/' relative to the working directory — confirm which is intended.
results_path = f'results/test_results_{dataset}_{convType}.json'
os.makedirs(os.path.dirname(results_path), exist_ok=True)
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2)

print(f"Results saved to: {results_path}")
