In [1]:
import torch
import torch.nn as nn
from torch_geometric.data import DataLoader
import argparse
import numpy as np
import random
import ogb
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator

from graph_transformer import GraphTransformerModel
from utils import compute_mutual_shortest_distances

parser = argparse.ArgumentParser(description='PyTorch implementation of relative positional encodings and relation-aware self-attention for graph Transformers')
args = parser.parse_args("")
args.device = 7
args.device = torch.device('cuda:'+ str(args.device) if torch.cuda.is_available() else 'cpu')
# args.device = torch.device('cpu')
print("device:", args.device)
# torch.cuda.set_device(args.device)

torch.manual_seed(0)
np.random.seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)

def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

seed = 0
set_seed(seed)

device: cuda:7


In [2]:
## embed_dim // num_heads should remain constant
args.dataset = 'ogbg-molhiv'
args.n_classes = 1
args.batch_size = 2
args.graph_pooling = 'mean'
args.proj_mode = 'nonlinear'
args.eval_metric = 'rocauc'
args.embed_dim = 320
args.ff_embed_dim = 640
args.num_heads = 8
args.graph_layers = 4
args.dropout = 0.4
args.relation_type = 'shortest_dist'
args.pre_transform = compute_mutual_shortest_distances
args.max_vocab = 12
args.split = 'scaffold'
args.weights_dropout = True
args.saved_model = './models/model_172_ogbg-molhiv_lr5e-05.pth'
args.k_hop_neighbors = 2
args.weight_decay = 0.01

In [3]:
print("Loading data...")
print("dataset: {} ".format(args.dataset))
dataset = PygGraphPropPredDataset(name=args.dataset, pre_transform=args.pre_transform)

split_idx = dataset.get_idx_split()
if args.split == 'scaffold':
    test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False, drop_last=True)
elif args.split == '80-20':
    test_loader = DataLoader(dataset[int(0.8 * len(dataset)):], batch_size=args.batch_size, shuffle=False, drop_last=True)

model = GraphTransformerModel(args)
# Remove `module` artifact from using DistributedDataPrallel
state_dict = torch.load(args.saved_model)
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)
model = model.to(args.device)

criterion = torch.nn.BCEWithLogitsLoss(reduction = "mean")
evaluator = Evaluator(name=args.dataset)
    
##############
# EVALUATION #
##############

model.eval()

with torch.no_grad():
    loss_epoch = 0
    y_true = []
    y_scores = []
    for idx, batch in enumerate(test_loader):
        z = model(batch.to(args.device))

        y = batch.y.float()
        y_true.append(y)
        y_scores.append(z)
        is_valid = ~torch.isnan(y)

        loss = criterion(z[is_valid], y[is_valid])
        loss_epoch += loss.item()

    y_true = torch.cat(y_true, dim = 0)
    y_scores = torch.cat(y_scores, dim = 0)

input_dict = {"y_true": y_true, "y_pred": y_scores}
result_dict = evaluator.eval(input_dict)
print('Test loss:', loss_epoch / len(test_loader))
print('Test ROC-AUC:', result_dict[args.eval_metric])

Loading data...
dataset: ogbg-molhiv 
Test loss: 0.2049240654027298
Test ROC-AUC: 0.7768689873662249
