In [1]:
import os
import sys

# Relative paths used later (e.g. "./Auditory cortex data") resolve against the
# working directory, and presumably the `synqronix` imports rely on it too, so
# move two levels up from the notebook's folder.
# Record the notebook's own directory only on the first run: calling
# os.chdir("..") twice on every execution would keep climbing the tree each
# time this cell is re-run.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
print(NOTEBOOK_DIR)
os.chdir(os.path.normpath(os.path.join(NOTEBOOK_DIR, os.pardir, os.pardir)))
print(os.getcwd())

/Users/mariayuffa/synqronix/src/synqronix/quantum_models
/Users/mariayuffa/synqronix/src


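# Node-level training

Train a `QGCN` model with `GNNTrainer` on the auditory cortex graph data, then plot the resulting training curves.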

In [None]:
from synqronix.trainer import GNNTrainer
from synqronix.evaluation import plot_training_curves
from synqronix.models.qgcn import QGCN
from torch.nn import LeakyReLU
import torch
from synqronix.dataproc.dataloader import ColumnarNeuralGraphDataLoader, NeuralGraphDataLoader
# Seed every RNG the pipeline may draw from *before* the loaders and the model
# are built. Any shuffling or splitting the loader does through these RNGs, and
# the QGCN weight initialisation further down, then become reproducible across
# Restart & Run All.
import random

import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA

# Resolved relative to the current working directory.
data_dir = "./Auditory cortex data"
dataloader = NeuralGraphDataLoader(
    data_dir=data_dir,
    k=20,  # NOTE(review): presumably neighbours per node in graph construction — confirm
    connectivity_threshold=0.5,
    batch_size=32,
)

train_loader, val_loader, test_loader = dataloader.get_dataloaders()

# Sizes derived from the dataset, plus GCN-style hyperparameters.
# NOTE(review): the QGCN constructed below only consumes 'num_features' and
# 'num_classes'; 'hidden_dim', 'num_layers' and 'dropout_rate' are not passed to it.
model_kwargs = dict(
    num_features=dataloader.get_num_features(),
    num_classes=dataloader.get_num_classes(),
    hidden_dim=64,
    num_layers=3,
    dropout_rate=0.5,
)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Where GNNTrainer writes checkpoints, and how often (interval unit: see GNNTrainer).
checkpoint_dir = "./results/checkpoints"
checkpoint_freq = 5

model = QGCN(
    input_dims=model_kwargs['num_features'],
    q_depths=[2, 2],
    output_dims=model_kwargs['num_classes'],
    activ_fn=LeakyReLU(negative_slope=0.2),
    classifier=None,
    readout=False,
)

# Optimisation schedule.
lr = 0.001
epochs = 160

# The project trainer owns the model, the target device and checkpointing
# (written to `checkpoint_dir` every `checkpoint_freq` steps of its schedule).
trainer = GNNTrainer(model=model, device=device, save_dir=checkpoint_dir, checkpoint_freq=checkpoint_freq)

# Adam optimizer at learning rate `lr`, with weight decay of 1e-4.
trainer.setup_optimizer(optimizer_type='Adam', lr=lr, weight_decay=1e-4)
    
print("Starting training...")
# trainer.train returns four per-epoch histories: train/val loss and train/val accuracy.
history = trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=epochs,
    resume_from=None,  # no checkpoint to resume from
)
train_losses, val_losses, train_accs, val_accs = history

# The figure is saved next to the checkpoints; validation F1 scores come from the trainer.
curves_path = os.path.join(checkpoint_dir, 'training_curves.png')
plot_training_curves(
    train_losses, val_losses, train_accs, val_accs, trainer.val_f1_scores,
    save_path=curves_path,
)
