In [1]:
import torch
import torch.nn as nn
from torch_geometric.data import DataLoader
import argparse
import numpy as np
import random
import ogb
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator

from graph_transformer import GraphTransformerModel
from utils import compute_mutual_shortest_distances

import os
os.environ['CUDA_VISIBLE_DEVICES']='0,2,3,6'

parser = argparse.ArgumentParser(description='PyTorch implementation of relative positional encodings and relation-aware self-attention for graph Transformers')
args = parser.parse_args("")
args.device = 0
args.device = torch.device('cuda:'+ str(args.device) if torch.cuda.is_available() else 'cpu')
# args.device = torch.device('cpu')
print("device:", args.device)
# torch.cuda.set_device(args.device)

torch.manual_seed(0)
np.random.seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)

def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

seed = 0
set_seed(seed)

device: cuda:0


In [2]:
## embed_dim // num_heads should remain constant
args.dataset = 'ogbg-moltox21'
args.n_classes = 12
args.batch_size = 2
args.lr = 0.001
args.graph_pooling = 'mean'
args.proj_mode = 'nonlinear'
args.eval_metric = 'rocauc'
args.embed_dim = 512
args.ff_embed_dim = 1024
args.num_heads = 8
args.graph_layers = 4
args.dropout = 0.2
args.relation_type = 'shortest_dist'
args.pre_transform = compute_mutual_shortest_distances
args.max_vocab = 12
args.split = 'scaffold'
args.num_epochs = 200
args.weights_dropout = True

In [3]:
print("Loading data...")
print("dataset: {} ".format(args.dataset))
dataset = PygGraphPropPredDataset(name=args.dataset, pre_transform=args.pre_transform).shuffle()

split_idx = dataset.get_idx_split()
if args.split == 'scaffold':
    train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=False, drop_last=True)
    test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False, drop_last=True)
elif args.split == '80-20':
    train_loader = DataLoader(dataset[:int(0.8 * len(dataset))], batch_size=args.batch_size, shuffle=False, drop_last=True)
    test_loader = DataLoader(dataset[int(0.8 * len(dataset)):], batch_size=args.batch_size, shuffle=False, drop_last=True)

model = GraphTransformerModel(args).to(args.device)
# model = nn.DataParallel(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000, eta_min=1e-6)
criterion = torch.nn.BCEWithLogitsLoss(reduction = "mean")
evaluator = Evaluator(name=args.dataset)

for epoch in range(args.num_epochs):
    ############
    # TRAINING #
    ############
    
    model.train()
    
    loss_epoch = 0
    for idx, batch in enumerate(train_loader):
        z = model(batch.to(args.device))

        y = batch.y.float()
        is_valid = ~torch.isnan(y)

        optimizer.zero_grad()
        loss = criterion(z[is_valid], y[is_valid])
        loss.backward()
        optimizer.step()
        scheduler.step()

        loss_epoch += loss.item()

    print('Train loss:', loss_epoch / len(train_loader))
    
    ##############
    # EVALUATION #
    ##############
    
    model.eval()
    
    with torch.no_grad():
        loss_epoch = 0
        y_true = []
        y_scores = []
        for idx, batch in enumerate(test_loader):
            z = model(batch.to(args.device))

            y = batch.y.float()
            y_true.append(y)
            y_scores.append(z)
            is_valid = ~torch.isnan(y)

            loss = criterion(z[is_valid], y[is_valid])
            loss_epoch += loss.item()
        
        y_true = torch.cat(y_true, dim = 0)
        y_scores = torch.cat(y_scores, dim = 0)

    input_dict = {"y_true": y_true, "y_pred": y_scores}
    result_dict = evaluator.eval(input_dict)
    print('Test loss:', loss_epoch / len(test_loader))
    print('Test ROC-AUC:', result_dict[args.eval_metric])

Loading data...
dataset: ogbg-moltox21 
Train loss: 0.2815435029569856
Test loss: 0.2702550831611972
Test ROC-AUC: 0.4949803764641379


KeyboardInterrupt: 

In [None]:
## TODOs:
# ogbg-molhiv, molpcba, try all OGB leaderboard datasets
# explore network feature extraction using Networkx, survey (pick 5-10)
# try features independently, then try different compositions -- analyze thoroughly
# continuous relation features: https://arxiv.org/abs/2003.09229
# link prediction is the simplest structural task to implement
# TransformerXL has different way to do relative positional encoding
# https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py#L92
# Blockwise Self-Attention for Long Document Understanding
# Longformer <-- standard way