In [1]:
%load_ext autoreload
%autoreload 2
import networkx as nx

import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="4"

import torch
import torch.nn as nn
from torch_geometric.data import Data, Batch
from torch_geometric.data import DataLoader
import argparse
import numpy as np
import random
import ogb
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
from graph_transformer import GT
from utils import pre_process, pre_process_with_summary, concat_pre_process_with_summary, inf_sum_pre_process_with_summary, fin_sum_pre_process_with_summary, get_n_params, get_optimizer
import datetime
from tqdm import tqdm
from tensorboardX import SummaryWriter
import torch.nn.functional as F
import pytz
from torch_geometric.utils import to_dense_adj, to_networkx, dense_to_sparse, remove_self_loops, to_undirected

In [2]:
parser = argparse.ArgumentParser(description='PyTorch implementation of relative positional encodings and relation-aware self-attention for graph Transformers')
args = parser.parse_args("")

args.dataset = 'ogbg-molhiv'
args.n_classes = 1
args.lr = 3e-4
args.n_hid = 512
args.n_heads = 8
args.n_layer = 4
args.dropout = 0.3
args.num_epochs = 50
# args.num_epochs = 1
args.k_hop_neighbors = 3
args.k_hop = True
args.weight_decay = 1e-2
# args.bsz      = 512
args.bsz      = 448
args.strategies = ['ea', 'rw_concat']
args.summary_node = True
args.hier_levels = 3
args.lap_k = None
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
args.temp = torch.nn.Linear(1, 1).to(args.device)
args.metric = 'rocauc'
print("device:", args.device)

device: cuda


In [3]:
print("Loading data...")
print("dataset: {} ".format(args.dataset))
tz = pytz.timezone('US/Pacific')
time_now = datetime.datetime.now(tz).strftime('%m-%d_%H:%M:%S')

if args.summary_node:
    pre_transform = lambda d : concat_pre_process_with_summary(d, args)
    root_path= f'dataset/{args.dataset}/concat_with_summary_{args.k_hop_neighbors}'
    # args.writer = SummaryWriter(log_dir=f'runs_new/{args.dataset}/concat_with_summary_{args.k_hop_neighbors}/strats={"-".join(args.strategies)}/{time_now}')

else:
    pre_transform = lambda d : pre_process(d, args)
    root_path= f'dataset/{args.dataset}/{args.k_hop_neighbors}'
    # args.writer = SummaryWriter(log_dir=f'runs_new/{args.dataset}/k={args.k_hop_neighbors}/strats={"-".join(args.strategies)}/{time_now}')
    
    
dataset = PygGraphPropPredDataset(name=args.dataset, pre_transform=pre_transform, root = root_path)
orig_dataset = PygGraphPropPredDataset(name=args.dataset)
evaluator = Evaluator(name=args.dataset)
split_idx = dataset.get_idx_split()
edge_dim_dict = {'ea': None, \
                 'disc': {
#                      'sd': (dataset.data.sd_edge_attr.max(dim=0)[0].int().view(-1) + 1).tolist(), \
#                      'cn': (dataset.data.cn_edge_attr.max(dim=0)[0].int().view(-1) + 1).tolist(), \
#                      'hsd': (dataset.data.hsd_edge_attr.max(dim=0)[0].int().view(-1) + 1).tolist(), \
                    },
                 'cont': {
                     **{('rw_' + str(k)): args.n_hid for k in range(1, args.k_hop_neighbors + 1)}
#                      'rw': args.n_hid
                 }
                }
model = GT(args.n_hid, args.n_classes, args.n_heads, args.n_layer, edge_dim_dict, args.dropout, args.summary_node, args.lap_k).to(args.device)

Loading data...
dataset: ogbg-molhiv 


In [4]:
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.bsz, shuffle = False)
test_loader  = DataLoader(dataset[split_idx["test"]],  batch_size=args.bsz, shuffle = False)

orig_valid_loader = DataLoader(orig_dataset[split_idx["valid"]], batch_size=args.bsz, shuffle = False)
orig_test_loader  = DataLoader(orig_dataset[split_idx["test"]],  batch_size=args.bsz, shuffle = False)

In [5]:
print('Model #Params: %d' % get_n_params(model))

criterion = torch.nn.BCEWithLogitsLoss(reduction = "sum")

optimizer = get_optimizer(model, weight_decay = args.weight_decay, learning_rate = args.lr)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 500, eta_min=1e-6)
scheduler.step(-500)

Model #Params: 10338817




In [6]:
import seaborn as sb
def mat_visualize(node_size, edge_index, edge_attr):
    mat = np.zeros((node_size, node_size))
    for e, v in zip(edge_index, edge_attr):
        mat[e[0]][e[1]] = v
    sb.heatmap(mat)

  import pandas.util.testing as tm


In [7]:
def turn_prob(inp):
    prob = torch.sigmoid(inp)
    prob = torch.cat([prob, 1-prob], dim=1)
    return prob

In [8]:
train_mask = torch.zeros(len(dataset), dtype=bool)
valid_mask = torch.zeros(len(dataset), dtype=bool)
test_mask = torch.zeros(len(dataset), dtype=bool)

train_mask[split_idx["train"]] = True
valid_mask[split_idx["valid"]] = True
test_mask[split_idx["test"]] = True
def entropy_loss(pred, label):
    return torch.mean(torch.sum(-label * pred, dim=1))

In [9]:
stats = []
for epoch in range(args.num_epochs):
    model.train()
    train_loss = []
    train_adv  = []
    y_true = []
    y_scores = []
    all_idx = torch.randperm(len(dataset))
    for batch_idx in tqdm(range(len(all_idx) // args.bsz)):
        batch = all_idx[batch_idx * args.bsz : (batch_idx + 1) * args.bsz]
        train_msk = train_mask[batch]    
        data = Batch.from_data_list(dataset[batch])
        data.to(args.device)
        
        strats = {'ea': data.edge_attr, \
                  **{('rw_' + str(k)): data['rw_edge_attr_' + str(k)] for k in range(1, args.k_hop_neighbors + 1)}}
#         strats = {'ea': data.edge_attr, 'rw': data.rw_edge_attr}
        out, reps = model(data.x, data.batch, data.edge_index, strats)
        with torch.no_grad():
            strats = {'ea': data.edge_attr, \
                  **{('rw_' + str(k)): data['rw_edge_attr_' + str(k)] for k in range(1, args.k_hop_neighbors + 1)}}
            # strats = {'ea': data.edge_attr, 'rw': data.rw_edge_attr}
            adv_out, reps = model(data.x, data.batch, data.edge_index, strats)
        
        # loss = criterion(out[train_msk], data.y[train_msk].float())
        loss = 0.0
        numel = 0
        for idx, orig_data in enumerate(orig_dataset[batch]):
            if not train_msk[idx]:
                continue
            orig_adj = to_dense_adj(edge_index=to_undirected(orig_data.edge_index), max_num_nodes=orig_data.x.size(0))[0].float().to(args.device)
            normalized = F.normalize(reps[data.batch == idx][:-1])
            pred_adj = args.temp(torch.mm(normalized, normalized.t()).unsqueeze(-1)).squeeze(-1)
            loss += criterion(pred_adj, orig_adj + torch.eye(orig_adj.size(0), device=orig_adj.device))
            numel += pred_adj.numel()
        print(pred_adj.sigmoid())
        loss = loss / numel
        
        adv_loss = entropy_loss(turn_prob(out[train_msk]).log(), turn_prob(adv_out[train_msk]))
        adv_loss = adv_loss * 0
        (loss + 0.5 * adv_loss).backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
        
        train_loss += [loss.item()]
        train_adv  += [adv_loss.item()]
        
        y_true += [data.y]
        y_scores += [out]

    input_dict = {"y_true": torch.cat(y_true), "y_pred": torch.cat(y_scores)}
    train_metric = evaluator.eval(input_dict)[args.metric]
    

    model.eval()
    with torch.no_grad():
        valid_loss = []
        y_true = []
        y_scores = []
        for num_iters, (data, orig_data_batch) in enumerate(tqdm(zip(valid_loader, orig_valid_loader))):
            data.to(args.device)
            strats = {'ea': data.edge_attr, \
                  **{('rw_' + str(k)): data['rw_edge_attr_' + str(k)] for k in range(1, args.k_hop_neighbors + 1)}}
            # strats = {'ea': data.edge_attr, 'rw': data.rw_edge_attr}
            out, reps = model(data.x, data.batch, data.edge_index, strats)
        
            # loss = criterion(out, data.y.float())
            loss = 0.0
            numel = 0
            for idx, orig_data in enumerate(orig_data_batch.to_data_list()):
                orig_adj = to_dense_adj(edge_index=to_undirected(orig_data.edge_index), max_num_nodes=orig_data.x.size(0))[0].float().to(args.device)
                normalized = F.normalize(reps[data.batch == idx][:-1])
                pred_adj = args.temp(torch.mm(normalized, normalized.t()).unsqueeze(-1)).squeeze(-1)
                loss += criterion(pred_adj, orig_adj + torch.eye(orig_adj.size(0), device=orig_adj.device))
                numel += pred_adj.numel()
            loss = loss / numel
            
            valid_loss += [loss.item()]

            y_true += [data.y]
            y_scores += [out]

        input_dict = {"y_true": torch.cat(y_true), "y_pred": torch.cat(y_scores)}
        valid_metric = evaluator.eval(input_dict)[args.metric]
        
        test_loss = []
        y_true = []
        y_scores = []
        for (data, orig_data_batch) in zip(test_loader, orig_test_loader):
            data.to(args.device)
            strats = {'ea': data.edge_attr, \
                  **{('rw_' + str(k)): data['rw_edge_attr_' + str(k)] for k in range(1, args.k_hop_neighbors + 1)}}
            # strats = {'ea': data.edge_attr, 'rw': data.rw_edge_attr}
            out, reps = model(data.x, data.batch, data.edge_index, strats)
        
            # loss = criterion(out, data.y.float())
            loss = 0.0
            numel = 0
            for idx, orig_data in enumerate(orig_data_batch.to_data_list()):
                orig_adj = to_dense_adj(edge_index=to_undirected(orig_data.edge_index), max_num_nodes=orig_data.x.size(0))[0].float().to(args.device)
                normalized = F.normalize(reps[data.batch == idx][:-1])
                pred_adj = args.temp(torch.mm(normalized, normalized.t()).unsqueeze(-1)).squeeze(-1)
                loss += criterion(pred_adj, orig_adj + torch.eye(orig_adj.size(0), device=orig_adj.device))
                numel += pred_adj.numel()
            loss = loss / numel
        
            test_loss += [loss.item()]

            y_true += [data.y]
            y_scores += [out]

        input_dict = {"y_true": torch.cat(y_true), "y_pred": torch.cat(y_scores)}
        test_metric = evaluator.eval(input_dict)[args.metric]

    print('Epoch %d: LR: %.5f, Train loss: %.3f Train %s: %.3f Train Adv: %.3f Valid loss: %.3f  Valid %s: %.3f \
        Test loss: %.3f  Test %s: %.3f' \
          % (epoch + 1, optimizer.param_groups[0]['lr'], np.average(train_loss), args.metric, train_metric, \
             np.average(train_adv), np.average(valid_loss), args.metric, valid_metric, \
             np.average(test_loss), args.metric, test_metric))
    stats += [[epoch, np.average(train_loss), train_metric, np.average(valid_loss), valid_metric, np.average(test_loss), test_metric]]

  0%|          | 0/91 [00:00<?, ?it/s]

tensor([[0.5376, 0.4789, 0.4766, 0.4733, 0.4757, 0.4758, 0.4808, 0.4808, 0.4797,
         0.4824, 0.4846, 0.4738, 0.4726, 0.4787, 0.4766, 0.4824, 0.4808, 0.4846,
         0.4772],
        [0.4789, 0.5376, 0.4904, 0.4914, 0.4952, 0.4951, 0.4971, 0.4949, 0.5017,
         0.4968, 0.4964, 0.5003, 0.4901, 0.4967, 0.4927, 0.4944, 0.4974, 0.4958,
         0.5028],
        [0.4766, 0.4904, 0.5376, 0.4989, 0.4999, 0.4990, 0.4969, 0.5046, 0.4967,
         0.4991, 0.4948, 0.4955, 0.5020, 0.4973, 0.4950, 0.5010, 0.5001, 0.5037,
         0.4931],
        [0.4733, 0.4914, 0.4989, 0.5376, 0.5056, 0.5004, 0.5013, 0.5015, 0.4973,
         0.5030, 0.4950, 0.4894, 0.4991, 0.5020, 0.5051, 0.5033, 0.5008, 0.5021,
         0.4930],
        [0.4757, 0.4952, 0.4999, 0.5056, 0.5376, 0.5078, 0.5052, 0.5060, 0.4979,
         0.5030, 0.5000, 0.4950, 0.4966, 0.5034, 0.5042, 0.5054, 0.5041, 0.5011,
         0.4962],
        [0.4758, 0.4951, 0.4990, 0.5004, 0.5078, 0.5376, 0.5055, 0.4988, 0.4961,
         0.5022, 0.

  1%|          | 1/91 [00:02<03:11,  2.13s/it]

tensor([[0.5376, 0.4968, 0.4975, 0.5005, 0.5021, 0.5020, 0.4930, 0.4897, 0.4863,
         0.4917, 0.4986, 0.4855, 0.4887, 0.4874, 0.4888, 0.5049, 0.4893, 0.4992,
         0.4840, 0.4934, 0.4913, 0.4890, 0.5030, 0.4878, 0.5013, 0.4879],
        [0.4968, 0.5376, 0.5035, 0.5076, 0.5019, 0.5048, 0.4925, 0.4858, 0.4832,
         0.4937, 0.5025, 0.4847, 0.4935, 0.4877, 0.4894, 0.4988, 0.4858, 0.4991,
         0.4876, 0.4963, 0.4898, 0.4872, 0.4971, 0.4877, 0.4975, 0.4830],
        [0.4975, 0.5035, 0.5376, 0.5055, 0.5051, 0.5011, 0.4961, 0.4913, 0.4859,
         0.4915, 0.5001, 0.4913, 0.4944, 0.4887, 0.4916, 0.4945, 0.4875, 0.5025,
         0.4894, 0.4940, 0.4878, 0.4886, 0.4960, 0.4884, 0.4940, 0.4821],
        [0.5005, 0.5076, 0.5055, 0.5376, 0.5056, 0.5028, 0.4954, 0.4890, 0.4859,
         0.4958, 0.5029, 0.4920, 0.4920, 0.4876, 0.4926, 0.4960, 0.4866, 0.5018,
         0.4890, 0.4958, 0.4915, 0.4899, 0.5022, 0.4876, 0.4954, 0.4839],
        [0.5021, 0.5019, 0.5051, 0.5056, 0.5376, 0.5048,

  2%|▏         | 2/91 [00:04<03:03,  2.06s/it]

tensor([[0.5376, 0.4920, 0.4927, 0.4947, 0.4972, 0.5006, 0.5005, 0.4918, 0.4946,
         0.4943, 0.4952, 0.4939, 0.4947, 0.4980, 0.4980],
        [0.4920, 0.5376, 0.4941, 0.4939, 0.4956, 0.4924, 0.4942, 0.4898, 0.4915,
         0.4885, 0.4874, 0.4924, 0.4888, 0.4895, 0.4908],
        [0.4927, 0.4941, 0.5376, 0.4955, 0.4996, 0.4959, 0.4887, 0.4853, 0.4889,
         0.4852, 0.4871, 0.4847, 0.4827, 0.4897, 0.4876],
        [0.4947, 0.4939, 0.4955, 0.5376, 0.5038, 0.4967, 0.4959, 0.4915, 0.4925,
         0.4909, 0.4921, 0.4944, 0.4914, 0.4915, 0.4942],
        [0.4972, 0.4956, 0.4996, 0.5038, 0.5376, 0.4960, 0.4938, 0.4915, 0.4939,
         0.4964, 0.4922, 0.4928, 0.4910, 0.4928, 0.4925],
        [0.5006, 0.4924, 0.4959, 0.4967, 0.4960, 0.5376, 0.5030, 0.5007, 0.5032,
         0.5034, 0.5002, 0.5038, 0.5013, 0.5037, 0.5003],
        [0.5005, 0.4942, 0.4887, 0.4959, 0.4938, 0.5030, 0.5376, 0.4992, 0.5027,
         0.4998, 0.5039, 0.5018, 0.5046, 0.5027, 0.5033],
        [0.4918, 0.4898, 0.

  3%|▎         | 3/91 [00:05<02:50,  1.94s/it]

tensor([[0.5376, 0.4996, 0.4868,  ..., 0.4844, 0.5001, 0.4955],
        [0.4996, 0.5376, 0.4927,  ..., 0.4889, 0.5044, 0.4964],
        [0.4868, 0.4927, 0.5376,  ..., 0.4975, 0.4872, 0.4848],
        ...,
        [0.4844, 0.4889, 0.4975,  ..., 0.5376, 0.4859, 0.4885],
        [0.5001, 0.5044, 0.4872,  ..., 0.4859, 0.5376, 0.4963],
        [0.4955, 0.4964, 0.4848,  ..., 0.4885, 0.4963, 0.5376]],
       device='cuda:0', grad_fn=<SigmoidBackward>)


  4%|▍         | 4/91 [00:07<02:42,  1.87s/it]

tensor([[0.5376, 0.4910, 0.4885,  ..., 0.4916, 0.5038, 0.4893],
        [0.4910, 0.5376, 0.4969,  ..., 0.4981, 0.4894, 0.4959],
        [0.4885, 0.4969, 0.5376,  ..., 0.4979, 0.4937, 0.5022],
        ...,
        [0.4916, 0.4981, 0.4979,  ..., 0.5376, 0.4938, 0.4929],
        [0.5038, 0.4894, 0.4937,  ..., 0.4938, 0.5376, 0.4908],
        [0.4893, 0.4959, 0.5022,  ..., 0.4929, 0.4908, 0.5376]],
       device='cuda:0', grad_fn=<SigmoidBackward>)


  5%|▌         | 5/91 [00:10<02:52,  2.01s/it]


KeyboardInterrupt: 

In [None]:
print(args.temp.weight)

In [None]:
import matplotlib.pyplot as plt
labels = ['epoch', 'train_loss', 'train_metric', 'valid_loss', 'valid_metric', 'test_loss', 'test_metric']
fig = plt.figure(figsize=(15, 10))
stats_np = np.array(stats)
best_valid = stats_np[stats_np[:50, 4].argmax()]
print(best_valid)
for i in range(1, stats_np.shape[-1]):
    ax = fig.add_subplot(2, 3, i)
    ax.plot(stats_np[:, i], label=labels[i])
    ax.scatter(x=best_valid[0], y=best_valid[i], color='red')
    ax.annotate(best_valid[i].round(3), xy=(best_valid[0]+5, best_valid[i]), color='red')
    ax.legend()


In [None]:
import random
from torch_geometric.utils import degree
from torch.distributions.multinomial import Multinomial

def generateSequence(startIndex, transitionMatrix, path_length, alpha):
    result = [startIndex]
    current = startIndex

    for i in range(0, path_length):
        if random.random() < alpha:
            nextIndex = startIndex
        else:
            probs = transitionMatrix[current]
            assert np.sum(probs) != 0, print(probs)
            nextIndex = np.random.choice(len(probs), 1, p=probs)[0]

        result.append(nextIndex)
        current = nextIndex

    return result

def weighted_random_walk(data, transitionMatrix, path_length, alpha, degree_weighted_start=True, num_samples=3):
    if degree_weighted_start:
        # Exclude degree 1 nodes, soft max over remaining degrees
        p = degree(data.edge_index[0])
        p[p == 1] = 0
    else:
        p = torch.ones(data.num_nodes)    
    m = Multinomial(num_samples, probs=p.exp()-1)
    start_node = m.sample().long().tolist()
    start = torch.Tensor(sum([[i] * start_node[i] for i in range(len(start_node))], [])).long()
    
    sentenceList = []
    nodes = list(range(data.num_nodes))
    
    for j in range(0, num_samples):
        indexList = generateSequence(start[j].item(), transitionMatrix, path_length, alpha)
        sentence = [nodes[tmp] for tmp in indexList]
        sentence = torch.LongTensor(sentence).unique()
        sentenceList.append(sentence)

    return sentenceList

In [None]:
from draw_mols_demo import pyg_to_mol, mol_to_svg, HorizontalDisplay
from torch_geometric.utils import sort_edge_index

model.eval()
# threshold = 0.1

for idx in range(20, 41):
    orig_data = orig_dataset[idx]

    data = dataset[idx]
    data.to(args.device)
    strats = {'ea': data.edge_attr,  **{('rw_' + str(k)): data['rw_edge_attr_' + str(k)] for k in range(1, args.k_hop_neighbors + 1)}}
    out, _ = model(data.x, 0, data.edge_index, strats)
    
    imgs = []
    for layer_idx, gc in enumerate(model.gcs):

        adj = to_dense_adj(edge_index=data.edge_index, edge_attr=gc.att)[0]
        adj_mean = adj.mean(dim=-1).detach().cpu()
        adj_mean = adj_mean[:-1, :-1] # remove virtual node
        
        # only include edges that were in the original graph
#         orig_adj = to_dense_adj(edge_index=orig_data.edge_index)[0].bool()
#         adj_mean[~orig_adj] = 0
        adj_mean /= adj_mean.sum(dim=1, keepdim=True)
        
#         adj_mean_sorted = adj_mean.flatten().sort()[0]
#         adj_mean_sorted = adj_mean_sorted[adj_mean_sorted != 0]
#         adj_mean_threshold = adj_mean_sorted[int(threshold * len(adj_mean_sorted))]
#         adj_mean[adj_mean < adj_mean_threshold] = 0
#         adj_mean[adj_mean >= adj_mean_threshold] = 1
        subsetList = weighted_random_walk(orig_data, adj_mean.cpu().numpy(), 10, 0)
        edgeIndexList = []
        for s in subsetList:
            edgeIndexList.append(torch.LongTensor(s).repeat_interleave(2)[1:-1].reshape(-1, 2))
        mean_edge_index = sort_edge_index(torch.cat(edgeIndexList).t().contiguous())[0]
    
        # mean_edge_index = dense_to_sparse(adj_mean.long())[0]
        mean_data = Data(x=data.x, edge_index=mean_edge_index)

        mol = pyg_to_mol(mean_data)  
        svg = mol_to_svg(mol, molSize=(150, 150))
        imgs += [svg]

    mol = pyg_to_mol(orig_data)
    svg = mol_to_svg(mol, molSize=(150, 150))
    imgs += [svg]
    row = HorizontalDisplay(*imgs)
    display(row)

In [None]:
from draw_mols_demo import pyg_to_mol, mol_to_svg, HorizontalDisplay

model.eval()
# threshold = 0.0
orig_dataset = PygGraphPropPredDataset(name='ogbg-molhiv')

for idx in range(20, 41):
    orig_data = orig_dataset[idx]

    data = dataset[idx]
    data.to(args.device)
    strats = {'ea': data.edge_attr,  **{('rw_' + str(k)): data['rw_edge_attr_' + str(k)] for k in range(1, args.k_hop_neighbors + 1)}}
    out, reps = model(data.x, 0, data.edge_index, strats)
    
    imgs = []
    normalized = F.normalize(reps[:-1])
    adj_mean = args.temp(torch.mm(normalized, normalized.t()))
    adj_mean = adj_mean.sigmoid()
    adj_mean[adj_mean >= 0.5] = 1
    adj_mean[adj_mean < 0.5] = 0
    adj_mean -= torch.eye(adj_mean.size(0), device=adj_mean.device)

    mean_edge_index = dense_to_sparse(adj_mean.long())[0]
    mean_data = Data(x=data.x, edge_index=mean_edge_index)

    mol = pyg_to_mol(mean_data)  
    svg = mol_to_svg(mol, molSize=(150, 150))
    imgs += [svg]

    mol = pyg_to_mol(orig_data)
    svg = mol_to_svg(mol, molSize=(150, 150))
    imgs += [svg]
    row = HorizontalDisplay(*imgs)
    display(row)