In [1]:
%load_ext autoreload
%autoreload 2
import networkx as nx

import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="6"

import torch
import torch.nn as nn
from torch_geometric.data import DataLoader
import argparse
import numpy as np
import random
import ogb
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
from graph_transformer import GT
from utils import pre_process, pre_process_with_summary, get_n_params, get_optimizer
import datetime
from tqdm import tqdm
from tensorboardX import SummaryWriter
import pytz


Data(adamic_edge_attr=[308, 1], alloc_edge_attr=[308, 1], cn_edge_attr=[308, 1], comm_edge_attr=[308, 1], edge_attr=[308, 3], edge_index=[2, 308], hier_label=[32, 4], hsd_edge_attr=[308, 4], jaccard_edge_attr=[308, 1], lap_x=[32, 10], orig_edge_attr=[70, 3], orig_edge_index=[2, 70], sd_edge_attr=[308, 1], x=[32, 9], y=[1, 1])
torch.Size([70, 3])
torch.Size([308, 3])
tensor([[0.0000],
        [0.0000],
        [0.5000],
        [0.0000],
        [0.2500],
        [0.0000],
        [0.2500],
        [0.2500],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.5000],
        [0.0000],
        [0.0000],
        [0.5000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.3333],
        [0.0000],
        [0.3333],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.5000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.5000],
 

In [2]:
parser = argparse.ArgumentParser(description='PyTorch implementation of relative positional encodings and relation-aware self-attention for graph Transformers')
# No command-line flags are declared: parsing an empty argument string just
# yields an empty namespace that is then filled with the experiment settings.
args = parser.parse_args("")

hyperparams = {
    # Dataset / task
    'dataset': 'ogbg-molhiv',
    'n_classes': 1,
    # Optimisation (an alternative learning rate that was tried: 1e-3)
    'lr': 2e-4,
    # Model size
    'n_hid': 512,
    'n_heads': 8,
    'n_layer': 4,
    'dropout': 0.2,
    # Training schedule and pre-processing options
    'num_epochs': 60,
    'k_hop_neighbors': 3,
    'weight_decay': 1e-2,
    'bsz': 128,
    'strategies': ['ea', 'sd'],
    'summary_node': True,
    'hier_levels': 3,
    'lap_k': 10,
    # Use the GPU when one is visible to this process, otherwise the CPU.
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'metric': 'rocauc',
}
# Namespace attributes live in its __dict__, so this is equivalent to assigning
# each attribute one by one (in the order listed above).
vars(args).update(hyperparams)
print("device:", args.device)

device: cuda


In [3]:
print("Loading data...")
print("dataset: {} ".format(args.dataset))

# Timestamp (US/Pacific) used as the last component of the TensorBoard log directory.
tz = pytz.timezone('US/Pacific')
time_now = datetime.datetime.now(tz).strftime('%m-%d_%H:%M:%S')
strategy_tag = "-".join(args.strategies)

# The summary-node variant has its own pre-processing function, its own dataset
# root and its own log directory, so the two configurations never share files.
if not args.summary_node:
    def pre_transform(d):
        """Apply `pre_process` to one data object using the notebook's `args`."""
        return pre_process(d, args)

    root_path = f'dataset/{args.dataset}/{args.k_hop_neighbors}'
    log_dir = f'runs_new/{args.dataset}/k={args.k_hop_neighbors}/strats={strategy_tag}/{time_now}'
else:
    def pre_transform(d):
        """Apply `pre_process_with_summary` to one data object using the notebook's `args`."""
        return pre_process_with_summary(d, args)

    root_path = f'dataset/{args.dataset}/with_summary_{args.k_hop_neighbors}'
    log_dir = f'runs_new/{args.dataset}/with_summary_k={args.k_hop_neighbors}/strats={strategy_tag}/{time_now}'

args.writer = SummaryWriter(log_dir=log_dir)

dataset = PygGraphPropPredDataset(name=args.dataset, pre_transform=pre_transform, root=root_path)
evaluator = Evaluator(name=args.dataset)
split_idx = dataset.get_idx_split()


def _max_plus_one(attr):
    """Column-wise maximum of `attr`, cast to int, flattened, plus one."""
    return attr.max(dim=0)[0].int().view(-1) + 1


# 'ea' stays a tensor while the 'disc' entries are plain lists, exactly as the
# model constructor receives them today.
edge_dim_dict = {
    'ea': _max_plus_one(dataset.data.edge_attr),
    'disc': {
        'sd': _max_plus_one(dataset.data.sd_edge_attr).tolist(),
        # Currently disabled discrete strategies:
        #   'cn'  -> _max_plus_one(dataset.data.cn_edge_attr).tolist()
        #   'hsd' -> _max_plus_one(dataset.data.hsd_edge_attr).tolist()
    },
    'cont': {
        # Currently disabled continuous strategies:
        #   'ja' -> _max_plus_one(dataset.data.jaccard_edge_attr)
        #   'ad' -> _max_plus_one(dataset.data.adamic_edge_attr)
    },
}
model = GT(
    args.n_hid,
    args.n_classes,
    args.n_heads,
    args.n_layer,
    edge_dim_dict,
    args.dropout,
    args.summary_node,
    args.lap_k,
).to(args.device)

Loading data...
dataset: ogbg-molhiv 


In [4]:
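# Regular loaders (plain shuffling, no class re-weighting) -- disabled in favour of the weighted-sampler loaders in the next cell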
# train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.bsz, shuffle=True)
# valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.bsz, shuffle=False)
# test_loader  = DataLoader(dataset[split_idx["test"]],  batch_size=args.bsz, shuffle=False)

In [5]:
# Loader with weighted sampler, for unbalanced data

from torch.utils.data import WeightedRandomSampler

# Class weights: label 0 keeps weight 1, label 1 is up-weighted by
# sqrt(#label-0 / #label-1), with both counts taken over the whole dataset.
num_label0 = (dataset.data.y == 0).sum().item()
num_label1 = (dataset.data.y == 1).sum().item()
weight = [1.0, np.sqrt(num_label0 / num_label1)]

# One sampling weight per training graph, looked up from its label in a single
# tensor-indexing step. The tensor is created as double directly; the original
# cast was assigned to a misspelled name (`samples_weigth`) and never used.
train_labels = dataset.data.y.view(-1)[split_idx["train"]]
samples_weight = torch.tensor(weight, dtype=torch.double)[train_labels.long()]

# Draw as many samples per epoch as there are training graphs.
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

# Only the training loader is re-weighted; validation and test keep dataset order.
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.bsz, sampler=sampler)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.bsz, shuffle=False)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.bsz, shuffle=False)

In [6]:
n_params = get_n_params(model)
print('Model #Params: %d' % n_params)

# Binary cross-entropy on raw logits, averaged over the batch.
criterion = torch.nn.BCEWithLogitsLoss(reduction="mean")

optimizer = get_optimizer(model, weight_decay=args.weight_decay)

# One-cycle schedule stepped once per batch: the total step count is
# len(train_loader) * num_epochs, with the first 5% spent ramping up to max_lr
# and linear annealing between phases.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=args.lr,
    pct_start=0.05,
    steps_per_epoch=len(train_loader),
    epochs=args.num_epochs,
    anneal_strategy='linear',
)
# Alternative that was tried:
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000, eta_min=1e-6)

Model #Params: 7462401


In [None]:
def _edge_strategies(data):
    """Return the edge-feature tensors handed to the model, keyed by strategy name."""
    return {'ea': data.edge_attr, 'sd': data.sd_edge_attr}  # 'lap_x': data.lap_x


def _evaluate(loader, show_progress=False):
    """Run the model over `loader` in eval mode without gradients.

    Args:
        loader: iterable of batches (e.g. `valid_loader` or `test_loader`).
        show_progress: wrap the loader in a tqdm progress bar when True.

    Returns:
        (mean of the per-batch losses, value of `args.metric` from the evaluator).
    """
    losses, y_true, y_scores = [], [], []
    batches = tqdm(loader) if show_progress else loader
    model.eval()
    with torch.no_grad():
        for data in batches:
            data = data.to(args.device)
            out = model(data.x, data.batch, data.edge_index, _edge_strategies(data))
            losses.append(criterion(out, data.y.float()).item())
            y_true.append(data.y)
            y_scores.append(out)
    input_dict = {"y_true": torch.cat(y_true), "y_pred": torch.cat(y_scores)}
    return np.average(losses), evaluator.eval(input_dict)[args.metric]


# One row per finished epoch:
# [epoch (0-based), train loss, train metric, valid loss, valid metric, test loss, test metric]
stats = []
try:
    for epoch in range(args.num_epochs):
        model.train()
        train_losses, y_true, y_scores = [], [], []
        for data in tqdm(train_loader):
            data = data.to(args.device)
            out = model(data.x, data.batch, data.edge_index, _edge_strategies(data))
            loss = criterion(out, data.y.float())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()  # the scheduler is stepped per batch, not per epoch

            train_losses.append(loss.item())
            y_true.append(data.y)
            # Detach so the list does not keep every batch's autograd graph
            # alive until the end of the epoch.
            y_scores.append(out.detach())

        step = epoch + 1  # logs and printed messages are 1-based
        current_lr = optimizer.param_groups[0]['lr']
        train_loss = np.average(train_losses)
        input_dict = {"y_true": torch.cat(y_true), "y_pred": torch.cat(y_scores)}
        train_metric = evaluator.eval(input_dict)[args.metric]
        args.writer.add_scalar("LR/epoch", current_lr, step)
        args.writer.add_scalar("Loss/train", train_loss, step)
        args.writer.add_scalar(args.metric + "/train", train_metric, step)

        valid_loss, valid_metric = _evaluate(valid_loader, show_progress=True)
        args.writer.add_scalar("Loss/valid", valid_loss, step)
        args.writer.add_scalar(args.metric + "/valid", valid_metric, step)

        test_loss, test_metric = _evaluate(test_loader)
        args.writer.add_scalar("Loss/test", test_loss, step)
        args.writer.add_scalar(args.metric + "/test", test_metric, step)

        print('Epoch %d: LR: %.5f, Train loss: %.3f Train %s: %.3f Valid loss: %.3f  Valid %s: %.3f Test loss: %.3f  Test %s: %.3f' \
              % (step, current_lr, train_loss, args.metric, train_metric,
                 valid_loss, args.metric, valid_metric,
                 test_loss, args.metric, test_metric))
        stats += [[epoch, train_loss, train_metric, valid_loss, valid_metric, test_loss, test_metric]]
finally:
    # Close the TensorBoard writer even when training is interrupted part-way;
    # `stats` keeps the rows of the epochs that did finish.
    args.writer.close()

100%|██████████| 258/258 [01:20<00:00,  3.20it/s]
100%|██████████| 33/33 [00:05<00:00,  5.81it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 1: LR: 0.00007, Train loss: 0.432 Train rocauc: 0.633 Valid loss: 0.231  Valid rocauc: 0.707 Test loss: 0.259  Test rocauc: 0.691


100%|██████████| 258/258 [01:09<00:00,  3.73it/s]
100%|██████████| 33/33 [00:04<00:00,  7.21it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 2: LR: 0.00014, Train loss: 0.378 Train rocauc: 0.753 Valid loss: 0.132  Valid rocauc: 0.752 Test loss: 0.154  Test rocauc: 0.750


100%|██████████| 258/258 [01:09<00:00,  3.71it/s]
100%|██████████| 33/33 [00:04<00:00,  7.25it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 3: LR: 0.00020, Train loss: 0.354 Train rocauc: 0.794 Valid loss: 0.344  Valid rocauc: 0.763 Test loss: 0.361  Test rocauc: 0.763


100%|██████████| 258/258 [01:10<00:00,  3.65it/s]
100%|██████████| 33/33 [00:04<00:00,  6.85it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 4: LR: 0.00020, Train loss: 0.326 Train rocauc: 0.824 Valid loss: 0.113  Valid rocauc: 0.757 Test loss: 0.148  Test rocauc: 0.700


100%|██████████| 258/258 [01:10<00:00,  3.68it/s]
100%|██████████| 33/33 [00:04<00:00,  7.20it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 5: LR: 0.00019, Train loss: 0.311 Train rocauc: 0.843 Valid loss: 0.112  Valid rocauc: 0.810 Test loss: 0.145  Test rocauc: 0.728


100%|██████████| 258/258 [01:12<00:00,  3.54it/s]
100%|██████████| 33/33 [00:04<00:00,  7.26it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 6: LR: 0.00019, Train loss: 0.301 Train rocauc: 0.859 Valid loss: 0.149  Valid rocauc: 0.751 Test loss: 0.190  Test rocauc: 0.715


100%|██████████| 258/258 [01:10<00:00,  3.64it/s]
100%|██████████| 33/33 [00:04<00:00,  7.24it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 7: LR: 0.00019, Train loss: 0.287 Train rocauc: 0.870 Valid loss: 0.154  Valid rocauc: 0.780 Test loss: 0.184  Test rocauc: 0.704


100%|██████████| 258/258 [01:09<00:00,  3.71it/s]
100%|██████████| 33/33 [00:04<00:00,  7.27it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 8: LR: 0.00018, Train loss: 0.278 Train rocauc: 0.884 Valid loss: 0.132  Valid rocauc: 0.780 Test loss: 0.185  Test rocauc: 0.674


100%|██████████| 258/258 [01:10<00:00,  3.67it/s]
100%|██████████| 33/33 [00:04<00:00,  6.64it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 9: LR: 0.00018, Train loss: 0.265 Train rocauc: 0.898 Valid loss: 0.240  Valid rocauc: 0.811 Test loss: 0.271  Test rocauc: 0.732


100%|██████████| 258/258 [01:10<00:00,  3.66it/s]
100%|██████████| 33/33 [00:04<00:00,  7.10it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 10: LR: 0.00018, Train loss: 0.249 Train rocauc: 0.912 Valid loss: 0.230  Valid rocauc: 0.798 Test loss: 0.262  Test rocauc: 0.722


100%|██████████| 258/258 [01:09<00:00,  3.74it/s]
100%|██████████| 33/33 [00:04<00:00,  7.30it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 11: LR: 0.00017, Train loss: 0.238 Train rocauc: 0.921 Valid loss: 0.132  Valid rocauc: 0.794 Test loss: 0.162  Test rocauc: 0.745


100%|██████████| 258/258 [01:11<00:00,  3.63it/s]
100%|██████████| 33/33 [00:04<00:00,  7.22it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 12: LR: 0.00017, Train loss: 0.223 Train rocauc: 0.930 Valid loss: 0.181  Valid rocauc: 0.807 Test loss: 0.220  Test rocauc: 0.736


100%|██████████| 258/258 [01:09<00:00,  3.69it/s]
100%|██████████| 33/33 [00:04<00:00,  7.29it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 13: LR: 0.00016, Train loss: 0.213 Train rocauc: 0.941 Valid loss: 0.219  Valid rocauc: 0.774 Test loss: 0.246  Test rocauc: 0.736


100%|██████████| 258/258 [01:09<00:00,  3.71it/s]
100%|██████████| 33/33 [00:04<00:00,  7.32it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 14: LR: 0.00016, Train loss: 0.205 Train rocauc: 0.946 Valid loss: 0.128  Valid rocauc: 0.760 Test loss: 0.178  Test rocauc: 0.696


100%|██████████| 258/258 [01:09<00:00,  3.73it/s]
100%|██████████| 33/33 [00:04<00:00,  7.38it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 15: LR: 0.00016, Train loss: 0.196 Train rocauc: 0.952 Valid loss: 0.176  Valid rocauc: 0.805 Test loss: 0.238  Test rocauc: 0.687


100%|██████████| 258/258 [01:08<00:00,  3.74it/s]
100%|██████████| 33/33 [00:04<00:00,  7.08it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 16: LR: 0.00015, Train loss: 0.181 Train rocauc: 0.960 Valid loss: 0.162  Valid rocauc: 0.799 Test loss: 0.219  Test rocauc: 0.716


100%|██████████| 258/258 [01:08<00:00,  3.76it/s]
100%|██████████| 33/33 [00:04<00:00,  7.28it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 17: LR: 0.00015, Train loss: 0.179 Train rocauc: 0.962 Valid loss: 0.145  Valid rocauc: 0.803 Test loss: 0.205  Test rocauc: 0.728


100%|██████████| 258/258 [01:09<00:00,  3.71it/s]
100%|██████████| 33/33 [00:04<00:00,  7.33it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 18: LR: 0.00015, Train loss: 0.168 Train rocauc: 0.967 Valid loss: 0.188  Valid rocauc: 0.797 Test loss: 0.241  Test rocauc: 0.719


100%|██████████| 258/258 [01:09<00:00,  3.70it/s]
100%|██████████| 33/33 [00:04<00:00,  7.30it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 19: LR: 0.00014, Train loss: 0.155 Train rocauc: 0.973 Valid loss: 0.142  Valid rocauc: 0.770 Test loss: 0.198  Test rocauc: 0.739


100%|██████████| 258/258 [01:10<00:00,  3.68it/s]
100%|██████████| 33/33 [00:04<00:00,  7.31it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 20: LR: 0.00014, Train loss: 0.150 Train rocauc: 0.974 Valid loss: 0.131  Valid rocauc: 0.774 Test loss: 0.192  Test rocauc: 0.733


100%|██████████| 258/258 [01:10<00:00,  3.68it/s]
100%|██████████| 33/33 [00:04<00:00,  7.07it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 21: LR: 0.00014, Train loss: 0.143 Train rocauc: 0.976 Valid loss: 0.145  Valid rocauc: 0.812 Test loss: 0.228  Test rocauc: 0.712


100%|██████████| 258/258 [01:12<00:00,  3.56it/s]
100%|██████████| 33/33 [00:04<00:00,  7.18it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 22: LR: 0.00013, Train loss: 0.134 Train rocauc: 0.980 Valid loss: 0.131  Valid rocauc: 0.795 Test loss: 0.202  Test rocauc: 0.742


100%|██████████| 258/258 [01:19<00:00,  3.23it/s]
100%|██████████| 33/33 [00:05<00:00,  6.50it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 23: LR: 0.00013, Train loss: 0.129 Train rocauc: 0.981 Valid loss: 0.226  Valid rocauc: 0.801 Test loss: 0.286  Test rocauc: 0.736


100%|██████████| 258/258 [01:14<00:00,  3.44it/s]
100%|██████████| 33/33 [00:05<00:00,  5.79it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 24: LR: 0.00013, Train loss: 0.124 Train rocauc: 0.983 Valid loss: 0.142  Valid rocauc: 0.800 Test loss: 0.221  Test rocauc: 0.720


100%|██████████| 258/258 [01:15<00:00,  3.40it/s]
100%|██████████| 33/33 [00:05<00:00,  6.44it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 25: LR: 0.00012, Train loss: 0.118 Train rocauc: 0.984 Valid loss: 0.151  Valid rocauc: 0.823 Test loss: 0.246  Test rocauc: 0.713


100%|██████████| 258/258 [01:17<00:00,  3.31it/s]
100%|██████████| 33/33 [00:05<00:00,  6.51it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 26: LR: 0.00012, Train loss: 0.111 Train rocauc: 0.987 Valid loss: 0.127  Valid rocauc: 0.812 Test loss: 0.223  Test rocauc: 0.690


100%|██████████| 258/258 [01:16<00:00,  3.39it/s]
100%|██████████| 33/33 [00:04<00:00,  6.65it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 27: LR: 0.00012, Train loss: 0.108 Train rocauc: 0.987 Valid loss: 0.145  Valid rocauc: 0.816 Test loss: 0.246  Test rocauc: 0.727


100%|██████████| 258/258 [01:13<00:00,  3.53it/s]
100%|██████████| 33/33 [00:04<00:00,  7.26it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 28: LR: 0.00011, Train loss: 0.102 Train rocauc: 0.988 Valid loss: 0.155  Valid rocauc: 0.797 Test loss: 0.254  Test rocauc: 0.706


100%|██████████| 258/258 [01:08<00:00,  3.75it/s]
100%|██████████| 33/33 [00:04<00:00,  7.33it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 29: LR: 0.00011, Train loss: 0.092 Train rocauc: 0.990 Valid loss: 0.148  Valid rocauc: 0.786 Test loss: 0.242  Test rocauc: 0.701


100%|██████████| 258/258 [01:08<00:00,  3.74it/s]
100%|██████████| 33/33 [00:04<00:00,  7.33it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 30: LR: 0.00011, Train loss: 0.093 Train rocauc: 0.991 Valid loss: 0.157  Valid rocauc: 0.793 Test loss: 0.256  Test rocauc: 0.711


100%|██████████| 258/258 [01:09<00:00,  3.72it/s]
100%|██████████| 33/33 [00:04<00:00,  7.20it/s]
  0%|          | 0/258 [00:00<?, ?it/s]

Epoch 31: LR: 0.00010, Train loss: 0.093 Train rocauc: 0.990 Valid loss: 0.144  Valid rocauc: 0.793 Test loss: 0.243  Test rocauc: 0.705


 52%|█████▏    | 133/258 [00:36<00:33,  3.68it/s]

In [None]:
import matplotlib.pyplot as plt

# Column names of the per-epoch rows collected in `stats` by the training cell.
labels = ['epoch', 'train_loss', 'train_metric', 'valid_loss', 'valid_metric', 'test_loss', 'test_metric']
stats_np = np.array(stats)

if stats_np.size == 0:
    # Nothing was recorded (e.g. training was interrupted during the first epoch).
    print("No completed epochs to plot.")
else:
    # Row of the epoch with the highest validation metric (column 4).
    best_valid = stats_np[stats_np[:, 4].argmax()]
    print(best_valid)

    fig = plt.figure(figsize=(15, 10))
    for i in range(1, stats_np.shape[-1]):
        ax = fig.add_subplot(2, 3, i)
        ax.plot(stats_np[:, i], label=labels[i])
        # Mark the best-validation epoch on every curve.
        ax.scatter(x=best_valid[0], y=best_valid[i], color='red')
        # Anchor the label at the marker and shift the text by a few points.
        # Anchoring at "epoch + 5" in data coordinates can put the anchor
        # outside the axes, in which case matplotlib hides the annotation.
        ax.annotate(
            str(best_valid[i].round(3)),
            xy=(best_valid[0], best_valid[i]),
            xytext=(5, 5),
            textcoords='offset points',
            color='red',
        )
        ax.set_title(labels[i])
        ax.set_xlabel(labels[0])
        ax.legend()
