In [5]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
import numpy as np
import torch
import torch.nn as nn

from torch.utils.data import DataLoader

from stitt import Stitt, StittGraphClassifier
from utils import create_spectral_dataset, Trainer, collate_spectral_dataset_no_eigenvects

from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm

In [7]:
dataset = PygGraphPropPredDataset(name='ogbg-molhiv')
idx = dataset.get_idx_split()
max_graph = max([graph.num_nodes for graph in dataset])

In [8]:
train_spect_ds = create_spectral_dataset(dataset[idx["train"]], upsample=[15, 40])

Creating Spectral Dataset:   0%|          | 0/32901 [00:00<?, ?it/s]

Creating Spectral Dataset: 100%|██████████| 32901/32901 [06:25<00:00, 85.37it/s] 


In [9]:
batch_size = 8
n_heads = 8
n_layers = 4
d_input = 256
d_attn = 256
d_ffn = 128
n_epochs = 50
r_warmup = 0.1
n_classes = 2
lr = 5e-5
device = torch.device("cuda")

In [10]:
model = StittGraphClassifier(d_input=d_input, d_attn=d_attn, d_ffn=d_ffn,
                             max_graph=max_graph, n_heads=n_heads, n_layers=n_layers, n_classes=n_classes,
                             device=device)

In [11]:
train_loader = DataLoader(train_spect_ds, batch_size=batch_size, shuffle=True,collate_fn=collate_spectral_dataset_no_eigenvects)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0)

training_steps = n_epochs * len(train_spect_ds)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps= r_warmup * training_steps, num_training_steps=training_steps)

In [12]:
trainer = Trainer(model.to(device), optimizer=optimizer, scheduler=scheduler, criterion=criterion, device=device)

In [15]:
trainer.train(train_loader, n_epochs)

Epoch 1 completed
Epoch 2 completed
Epoch 3 completed
Epoch 4 completed
Epoch 5 completed
Epoch 6 completed
Epoch 7 completed
Epoch 8 completed
Epoch 9 completed


In [None]:
val_spect_ds = create_spectral_dataset(dataset[idx["valid"]])
val_loader = train_loader = DataLoader(
    train_spect_ds,
    batch_size=batch_size * 4,
    collate_fn=collate_spectral_dataset_no_eigenvects,
)

Creating Spectral Dataset:   0%|          | 0/4113 [00:00<?, ?it/s]

Creating Spectral Dataset: 100%|██████████| 4113/4113 [00:02<00:00, 1537.62it/s]


In [None]:
evaluator = Evaluator(name='ogbg-molhiv')
y_true = []
y_pred = []

with torch.no_grad():
    for batch in tqdm(val_loader):
        features, eigvects, mask, labels = batch
        features = features.to(device)
        eigvects = eigvects.to(device)
        mask = mask.to(device)
        labels = labels.to(device)

        outputs = model(features, eigvects, mask)
        predicted = torch.argmax(outputs, dim=1)

        y_true.append(labels)
        y_pred.append(predicted)

y_true = torch.concat(y_true).unsqueeze(1)
y_pred = torch.concat(y_pred).unsqueeze(1)

  0%|          | 0/4097 [00:00<?, ?it/s]

100%|██████████| 4097/4097 [05:26<00:00, 12.56it/s]


In [None]:

roc_auc = evaluator.eval({
    'y_true': y_true,
    'y_pred': y_pred
})['rocauc']

print(f"ROC-AUC: {roc_auc:.3f}")

ROC-AUC: 0.891


In [None]:
torch.save(model, "stitt_molhiv_1.pt")