In [1]:
%load_ext autoreload
%autoreload 2
import networkx as nx

import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="7"

import torch
import torch.nn as nn
from torch_geometric.data import Data, Batch
from torch_geometric.data import DataLoader
import argparse
import numpy as np
import random
import ogb
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
from graph_transformer import GT
from utils import pre_process, pre_process_with_summary, concat_pre_process_with_summary, inf_sum_pre_process_with_summary, fin_sum_pre_process_with_summary, get_n_params, get_optimizer
import datetime
from tqdm import tqdm
from tensorboardX import SummaryWriter
import pytz

In [2]:
parser = argparse.ArgumentParser(description='PyTorch implementation of relative positional encodings and relation-aware self-attention for graph Transformers')
args = parser.parse_args("")

args.dataset = 'ogbg-molhiv'
args.n_classes = 1
args.lr = 3e-4
args.n_hid = 512
args.n_heads = 8
args.n_layer = 4
args.dropout = 0.3
args.num_epochs = 50
args.k_hop_neighbors = 3
args.k_hop = True
args.weight_decay = 1e-2
# args.bsz      = 512
args.bsz      = 128
args.strategies = ['ea', 'rw_inf_sum']
args.summary_node = True
args.hier_levels = 3
args.lap_k = None
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
args.metric = 'rocauc'
print("device:", args.device)

device: cuda


In [3]:
print("Loading data...")
print("dataset: {} ".format(args.dataset))
tz = pytz.timezone('US/Pacific')
time_now = datetime.datetime.now(tz).strftime('%m-%d_%H:%M:%S')

if args.summary_node:
    pre_transform = lambda d : inf_sum_pre_process_with_summary(d, args)
    root_path= f'dataset/{args.dataset}/inf_sum_with_summary_{args.k_hop_neighbors}'
    # args.writer = SummaryWriter(log_dir=f'runs_new/{args.dataset}/inf_sum_with_summary_{args.k_hop_neighbors}/strats={"-".join(args.strategies)}/{time_now}')

else:
    pre_transform = lambda d : pre_process(d, args)
    root_path= f'dataset/{args.dataset}/{args.k_hop_neighbors}'
    # args.writer = SummaryWriter(log_dir=f'runs_new/{args.dataset}/k={args.k_hop_neighbors}/strats={"-".join(args.strategies)}/{time_now}')
    
    
dataset = PygGraphPropPredDataset(name=args.dataset, pre_transform=pre_transform, root = root_path)
evaluator = Evaluator(name=args.dataset)
split_idx = dataset.get_idx_split()
edge_dim_dict = {'ea': None, \
                 'disc': {
#                      'sd': (dataset.data.sd_edge_attr.max(dim=0)[0].int().view(-1) + 1).tolist(), \
#                      'cn': (dataset.data.cn_edge_attr.max(dim=0)[0].int().view(-1) + 1).tolist(), \
#                      'hsd': (dataset.data.hsd_edge_attr.max(dim=0)[0].int().view(-1) + 1).tolist(), \
                    },
                 'cont': {
#                      **{('rw_' + str(k)): args.n_hid for k in range(1, args.k_hop_neighbors + 1)}
                     'rw': args.n_hid
                 }
                }
model = GT(args.n_hid, args.n_classes, args.n_heads, args.n_layer, edge_dim_dict, args.dropout, args.summary_node, args.lap_k).to(args.device)

Loading data...
dataset: ogbg-molhiv 


In [4]:
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.bsz, shuffle = False)
test_loader  = DataLoader(dataset[split_idx["test"]],  batch_size=args.bsz, shuffle = False)

In [5]:
print('Model #Params: %d' % get_n_params(model))

criterion = torch.nn.BCEWithLogitsLoss(reduction = "mean")

optimizer = get_optimizer(model, weight_decay = args.weight_decay, learning_rate = args.lr)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 500, eta_min=1e-6)
scheduler.step(-500)

Model #Params: 7717377




In [6]:
import seaborn as sb
def mat_visualize(node_size, edge_index, edge_attr):
    mat = np.zeros((node_size, node_size))
    for e, v in zip(edge_index, edge_attr):
        mat[e[0]][e[1]] = v
    sb.heatmap(mat)

  import pandas.util.testing as tm


In [7]:
def turn_prob(inp):
    prob = torch.sigmoid(inp)
    prob = torch.cat([prob, 1-prob], dim=1)
    return prob

In [8]:
train_mask = torch.zeros(len(dataset), dtype=bool)
valid_mask = torch.zeros(len(dataset), dtype=bool)
test_mask = torch.zeros(len(dataset), dtype=bool)

train_mask[split_idx["train"]] = True
valid_mask[split_idx["valid"]] = True
test_mask[split_idx["test"]] = True
def entropy_loss(pred, label):
    return torch.mean(torch.sum(-label * pred, dim=1))

In [None]:
stats = []
for epoch in range(args.num_epochs):
    model.train()
    train_loss = []
    train_adv  = []
    y_true = []
    y_scores = []
    all_idx = torch.randperm(len(dataset))
    for batch_idx in tqdm(range(len(all_idx) // args.bsz)):
        batch = all_idx[batch_idx * args.bsz : (batch_idx + 1) * args.bsz]
        train_msk = train_mask[batch]    
        data = Batch.from_data_list(dataset[batch])
        data.to(args.device)
        
#         strats = {'ea': data.edge_attr, \
#                   **{('rw_' + str(k)): data['rw_edge_attr_' + str(k)] for k in range(1, args.k_hop_neighbors + 1)}}
        strats = {'ea': data.edge_attr, 'rw': data.rw_edge_attr}
        out = model(data.x, data.batch, data.edge_index, strats)
        with torch.no_grad():
#             strats = {'ea': data.edge_attr, \
#                   **{('rw_' + str(k)): data['rw_edge_attr_' + str(k)] for k in range(1, args.k_hop_neighbors + 1)}}
            strats = {'ea': data.edge_attr, 'rw': data.rw_edge_attr}
            adv_out = model(data.x, data.batch, data.edge_index, strats)
        
        loss = criterion(out[train_msk], data.y[train_msk].float())
        adv_loss = entropy_loss(turn_prob(out).log(), turn_prob(adv_out))
        (loss + 0.5 * adv_loss).backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
        
        train_loss += [loss.item()]
        train_adv  += [adv_loss.item()]
        
        y_true += [data.y]
        y_scores += [out]

    input_dict = {"y_true": torch.cat(y_true), "y_pred": torch.cat(y_scores)}
    train_metric = evaluator.eval(input_dict)[args.metric]
    

    model.eval()
    with torch.no_grad():
        valid_loss = []
        y_true = []
        y_scores = []
        for num_iters, data in enumerate(tqdm(valid_loader)):
            data.to(args.device)
#             strats = {'ea': data.edge_attr, \
#                   **{('rw_' + str(k)): data['rw_edge_attr_' + str(k)] for k in range(1, args.k_hop_neighbors + 1)}}
            strats = {'ea': data.edge_attr, 'rw': data.rw_edge_attr}
            out = model(data.x, data.batch, data.edge_index, strats)
        
            loss = criterion(out, data.y.float())
            valid_loss += [loss.item()]

            y_true += [data.y]
            y_scores += [out]

        input_dict = {"y_true": torch.cat(y_true), "y_pred": torch.cat(y_scores)}
        valid_metric = evaluator.eval(input_dict)[args.metric]
        
        test_loss = []
        y_true = []
        y_scores = []
        for data in test_loader:
            data.to(args.device)
#             strats = {'ea': data.edge_attr, \
#                   **{('rw_' + str(k)): data['rw_edge_attr_' + str(k)] for k in range(1, args.k_hop_neighbors + 1)}}
            strats = {'ea': data.edge_attr, 'rw': data.rw_edge_attr}
            out = model(data.x, data.batch, data.edge_index, strats)
        
            loss = criterion(out, data.y.float())
            test_loss += [loss.item()]

            y_true += [data.y]
            y_scores += [out]

        input_dict = {"y_true": torch.cat(y_true), "y_pred": torch.cat(y_scores)}
        test_metric = evaluator.eval(input_dict)[args.metric]

    print('Epoch %d: LR: %.5f, Train loss: %.3f Train %s: %.3f Train Adv: %.3f Valid loss: %.3f  Valid %s: %.3f \
        Test loss: %.3f  Test %s: %.3f' \
          % (epoch + 1, optimizer.param_groups[0]['lr'], np.average(train_loss), args.metric, train_metric, \
             np.average(train_adv), np.average(valid_loss), args.metric, valid_metric, \
             np.average(test_loss), args.metric, test_metric))
    stats += [[epoch, np.average(train_loss), train_metric, np.average(valid_loss), valid_metric, np.average(test_loss), test_metric]]

100%|██████████| 321/321 [03:10<00:00,  1.68it/s]
100%|██████████| 33/33 [00:06<00:00,  4.74it/s]
  0%|          | 0/321 [00:00<?, ?it/s]

Epoch 1: LR: 0.00022, Train loss: 0.182 Train rocauc: 0.519 Train Adv: 0.205 Valid loss: 0.093  Valid rocauc: 0.700         Test loss: 0.138  Test rocauc: 0.681


100%|██████████| 321/321 [03:08<00:00,  1.71it/s]
100%|██████████| 33/33 [00:06<00:00,  5.34it/s]
  0%|          | 0/321 [00:00<?, ?it/s]

Epoch 2: LR: 0.00024, Train loss: 0.157 Train rocauc: 0.610 Train Adv: 0.154 Valid loss: 0.091  Valid rocauc: 0.710         Test loss: 0.144  Test rocauc: 0.680


100%|██████████| 321/321 [03:09<00:00,  1.70it/s]
100%|██████████| 33/33 [00:06<00:00,  5.16it/s]
  0%|          | 0/321 [00:00<?, ?it/s]

Epoch 3: LR: 0.00001, Train loss: 0.148 Train rocauc: 0.690 Train Adv: 0.149 Valid loss: 0.086  Valid rocauc: 0.738         Test loss: 0.122  Test rocauc: 0.701


100%|██████████| 321/321 [03:09<00:00,  1.69it/s]
100%|██████████| 33/33 [00:06<00:00,  4.99it/s]
  0%|          | 0/321 [00:00<?, ?it/s]

Epoch 4: LR: 0.00018, Train loss: 0.142 Train rocauc: 0.719 Train Adv: 0.146 Valid loss: 0.088  Valid rocauc: 0.738         Test loss: 0.121  Test rocauc: 0.695


100%|██████████| 321/321 [03:09<00:00,  1.69it/s]
100%|██████████| 33/33 [00:06<00:00,  5.15it/s]
  0%|          | 0/321 [00:00<?, ?it/s]

Epoch 5: LR: 0.00027, Train loss: 0.147 Train rocauc: 0.681 Train Adv: 0.146 Valid loss: 0.099  Valid rocauc: 0.672         Test loss: 0.129  Test rocauc: 0.730


100%|██████████| 321/321 [03:07<00:00,  1.71it/s]
100%|██████████| 33/33 [00:06<00:00,  5.13it/s]
  0%|          | 0/321 [00:00<?, ?it/s]

Epoch 6: LR: 0.00002, Train loss: 0.139 Train rocauc: 0.724 Train Adv: 0.144 Valid loss: 0.081  Valid rocauc: 0.766         Test loss: 0.123  Test rocauc: 0.732


100%|██████████| 321/321 [03:08<00:00,  1.70it/s]
100%|██████████| 33/33 [00:06<00:00,  4.87it/s]
  0%|          | 0/321 [00:00<?, ?it/s]

Epoch 7: LR: 0.00015, Train loss: 0.136 Train rocauc: 0.745 Train Adv: 0.140 Valid loss: 0.088  Valid rocauc: 0.744         Test loss: 0.119  Test rocauc: 0.728


 91%|█████████ | 292/321 [02:52<00:17,  1.68it/s]

In [None]:
import matplotlib.pyplot as plt
labels = ['epoch', 'train_loss', 'train_metric', 'valid_loss', 'valid_metric', 'test_loss', 'test_metric']
fig = plt.figure(figsize=(15, 10))
stats_np = np.array(stats)
best_valid = stats_np[stats_np[:50, 4].argmax()]
print(best_valid)
for i in range(1, stats_np.shape[-1]):
    ax = fig.add_subplot(2, 3, i)
    ax.plot(stats_np[:, i], label=labels[i])
    ax.scatter(x=best_valid[0], y=best_valid[i], color='red')
    ax.annotate(best_valid[i].round(3), xy=(best_valid[0]+5, best_valid[i]), color='red')
    ax.legend()


In [None]:
from torch_geometric.utils import to_dense_adj, to_networkx, dense_to_sparse, remove_self_loops, to_undirected
from draw_mols_demo import pyg_to_mol, mol_to_svg, HorizontalDisplay

model.eval()
data = dataset[3]    
data.to(args.device)
strats = {'ea': data.edge_attr,  'rw': data.rw_edge_attr}
out = model(data.x, 0, data.edge_index, strats)

threshold = 0.975

# fig, axes = plt.subplots(nrows=len(model.gcs), ncols=2, figsize=(15, 10 * len(model.gcs)))
for layer_idx, gc in enumerate(model.gcs):
    imgs = []
    
    adj = to_dense_adj(edge_index=data.edge_index, edge_attr=gc.att)[0]
    adj_mean = adj.mean(dim=-1).detach().cpu()
    adj_max = adj.max(dim=-1)[0].detach().cpu()
    
    adj_mean_sorted = adj_mean.flatten().sort()[0]
    adj_mean_threshold = adj_mean_sorted[int(threshold * len(adj_mean_sorted))]
    adj_mean[adj_mean < adj_mean_threshold] = 0
    adj_mean[adj_mean >= adj_mean_threshold] = 1
    
#     adj_max_sorted = adj_max.flatten().sort()[0]
#     adj_max_threshold = adj_max_sorted[int(threshold * len(adj_max_sorted))]
#     adj_max[adj_max < max_threshold] = 0
#     adj_max[adj_max >= max_threshold] = 1
    
    mean_edge_index = dense_to_sparse(adj_mean.long())[0]
    mean_edge_index = remove_self_loops(mean_edge_index)[0]
    mean_data = Data(x=data.x, edge_index=mean_edge_index)
    
#     max_edge_index = dense_to_sparse(adj_max.long())[0]
#     max_edge_index = remove_self_loops(max_edge_index)[0]
#     max_data = Data(x=data.x, edge_index=max_edge_index)
    
    # ax = axes[layer_idx][0]
    # ax.set_title(f'Layer {layer_idx + 1}, Mean')
    # im = ax.matshow(adj_mean)
    # fig.colorbar(im, ax=ax)
    # molecule_draw_with_color(to_networkx(mean_data, node_attrs=['x']), ax=ax, labels='node_id')
    mol = pyg_to_mol(mean_data)  
    # mc = mol_to_svg(mol, molSize=(150, 150))
    svg = mol_to_svg(mol, molSize=(150, 150))
    imgs += [svg]
    
    # ax = axes[layer_idx][1]
    # ax.set_title(f'Layer {layer_idx + 1}, Max')
    # im = ax.matshow(adj_max)
    # fig.colorbar(im, ax=ax)
    # molecule_draw_with_color(to_networkx(max_data, node_attrs=['x']), ax=ax, labels='node_id')
    
    # mol = pyg_to_mol(max_data)  
    # mc = mol_to_svg(mol, molSize=(150, 150))
    svg = mol_to_svg(mol, molSize=(150, 150))
    # imgs += [svg]
    row = HorizontalDisplay(*imgs)
    display(row)