## Dependencies

In [1]:
from tqdm import tqdm
import statistics

import torch
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score

import torch_geometric.transforms as T
from torch_geometric.datasets import SNAPDataset, DBLP, IMDB
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GINConv, to_hetero
from torch_geometric.utils import negative_sampling, to_networkx

torch.manual_seed(0)

%matplotlib notebook

  from .autonotebook import tqdm as notebook_tqdm


## Data

In [2]:
DBLP(root="../data/DBLP")[0]

HeteroData(
  [1mauthor[0m={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057]
  },
  [1mpaper[0m={ x=[14328, 4231] },
  [1mterm[0m={ x=[7723, 50] },
  [1mconference[0m={ num_nodes=20 },
  [1m(author, to, paper)[0m={ edge_index=[2, 19645] },
  [1m(paper, to, author)[0m={ edge_index=[2, 19645] },
  [1m(paper, to, term)[0m={ edge_index=[2, 85810] },
  [1m(paper, to, conference)[0m={ edge_index=[2, 14328] },
  [1m(term, to, paper)[0m={ edge_index=[2, 85810] },
  [1m(conference, to, paper)[0m={ edge_index=[2, 14328] }
)

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = T.Compose([
    T.ToDevice(device),
    T.RemoveIsolatedNodes(),
    T.RandomLinkSplit(
        num_val=0.05, 
        num_test=0.1, 
        is_undirected=True, 
        add_negative_train_samples=False,
        edge_types=[("paper", "to", "author"), ("paper", "to", "conference")]
    ),
    T.ToUndirected(),
])

dataset = DBLP(root="../data/DBLP", transform=transform)

train_data, val_data, test_data = dataset[0]

for data in train_data, val_data, test_data:
    del data["term"]
    del data[("paper", "to", "term")]
    del data[("term", "to", "paper")]
    del data[("author", "to", "paper")]
    del data[("conference", "to", "paper")]
    
    del data[("paper", "rev_to", "author")]
    del data[("term", "rev_to", "paper")]
    del data[("paper", "rev_to", "term")]
    del data[("paper", "rev_to", "conference")]
    
    data["conference"].x = torch.ones((20, 1))

In [9]:
print(train_data)
print(val_data)
print(test_data)

HeteroData(
  [1mauthor[0m={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057]
  },
  [1mpaper[0m={ x=[14328, 4231] },
  [1mconference[0m={
    num_nodes=20,
    x=[20, 1]
  },
  [1m(paper, to, author)[0m={
    edge_index=[2, 16699],
    edge_label=[16699],
    edge_label_index=[2, 16699]
  },
  [1m(paper, to, conference)[0m={
    edge_index=[2, 12180],
    edge_label=[12180],
    edge_label_index=[2, 12180]
  },
  [1m(author, rev_to, paper)[0m={
    edge_index=[2, 16699],
    edge_label=[16699]
  },
  [1m(conference, rev_to, paper)[0m={
    edge_index=[2, 12180],
    edge_label=[12180]
  }
)
HeteroData(
  [1mauthor[0m={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057]
  },
  [1mpaper[0m={ x=[14328, 4231] },
  [1mconference[0m={
    num_nodes=20,
    x=[20, 1]
  },
  [1m(paper, to, author)[0m={
    edge_index=[2, 16699],
    edge_label=[1964],
    edge_label_index=[2, 

## Prediction

In [10]:
from torch import nn
import torch.nn.functional as F


class Encoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    
class SimpleNet(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, metadata):
        super().__init__()
        self.encoder = to_hetero(Encoder(hidden_channels=hidden_channels, out_channels=out_channels), metadata)
    
    def encode(self, x_dict, edge_index_dict):
        return self.encoder(x_dict, edge_index_dict)
    
    def decode(self, z1, z2, edge_label_index):
        x1 = z1[edge_label_index[0]]
        x2 = z2[edge_label_index[1]]
        return (x1 * x2).sum(dim=-1)
    
    
class Net(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, metadata):
        super().__init__()
        self.encoder = to_hetero(Encoder(hidden_channels=hidden_channels, out_channels=out_channels), metadata)
        
        self.W1 = nn.Linear(out_channels * 2, out_channels)
        self.W2 = nn.Linear(out_channels, 1)
        
    def encode(self, x_dict, edge_index_dict):
        return self.encoder(x_dict, edge_index_dict)
    
    def decode(self, z1, z2, edge_label_index):
        z_forward = torch.cat((z1[edge_label_index[0]], z2[edge_label_index[1]]), dim=1)
        out1 = self.W2(F.relu(self.W1(z_forward)).squeeze()).squeeze()
        
        z_reverse = torch.cat((z2[edge_label_index[1]], z1[edge_label_index[0]]), dim=1)
        out2 = self.W2(F.relu(self.W1(z_reverse)).squeeze()).squeeze()
        
        return (out1 + out2) / 2
    
    
simple_model = SimpleNet(hidden_channels=128, out_channels=32, metadata=train_data.metadata()).to(device)
simple_optimizer = torch.optim.Adam(params=simple_model.parameters(), lr=1e-3, weight_decay=1e-1)

model = Net(hidden_channels=128, out_channels=32, metadata=train_data.metadata()).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=1e-2)

criterion = torch.nn.BCEWithLogitsLoss()

In [11]:
def train(model, optimizer, data, key):
    start, _, end = key
    model.train()
    optimizer.zero_grad()
    z = model.encode(data.x_dict, data.edge_index_dict)

    # We perform a new round of negative sampling for every training epoch:
    neg_edge_index = negative_sampling(
        edge_index=data.edge_index_dict[key], 
        num_nodes=(data.x_dict[start].shape[0], data.x_dict[end].shape[0]),
        num_neg_samples=data.edge_label_index_dict[key].shape[1], 
        method='sparse'
    )
    
    edge_label_index = data.edge_label_index_dict[key]
    edge_label_index = torch.cat([edge_label_index, neg_edge_index], dim=-1)
    
    edge_label = data.edge_label_dict[key]
    edge_label = torch.cat([edge_label, edge_label.new_zeros(neg_edge_index.size(1))], dim=0)
    
    out = model.decode(z[start], z[end], edge_label_index)
    loss = criterion(out, edge_label)
    
    loss.backward()
    optimizer.step()
    
    return loss


@torch.no_grad()
def test(model, data, key):
    start, _, end = key
    model.eval()
    z = model.encode(data.x_dict, data.edge_index_dict)
    out = model.decode(z[start], z[end], data.edge_label_index_dict[key]).view(-1).sigmoid()
    a, b = data.edge_label_dict[key].cpu().numpy(), out.cpu().numpy()
    c = (out > 0.5).float().cpu().numpy()
        
    return roc_auc_score(a, b), accuracy_score(a, c)

In [12]:
key = ("paper", "to", "author")
start, _, end = key

best_val_auc = final_test_auc = final_test_acc = 0
for epoch in range(1, 51):
    loss = train(simple_model, simple_optimizer, train_data, key)
    val_auc, val_acc = test(simple_model, val_data, key)
    test_auc, test_acc = test(simple_model, test_data, key)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
        final_test_acc = test_acc
    if epoch % 1 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f} {val_acc:.4f}, Test: {test_auc:.4f} {test_acc:.4f}')

print(f'Final Test: {final_test_auc:.4f} {final_test_acc:.4f}')

simple_z = simple_model.encode(test_data.x_dict, test_data.edge_index_dict)
simple_final_edge_index = simple_model.decode(simple_z[start], simple_z[end], test_data.edge_label_index_dict[key])

Epoch: 001, Loss: 0.7029, Val: 0.5818 0.5127, Test: 0.6133 0.5143
Epoch: 002, Loss: 0.7652, Val: 0.6377 0.5784, Test: 0.6385 0.5703
Epoch: 003, Loss: 0.6491, Val: 0.6450 0.5071, Test: 0.6332 0.4980
Epoch: 004, Loss: 0.6952, Val: 0.6757 0.5193, Test: 0.6645 0.5064
Epoch: 005, Loss: 0.6655, Val: 0.7262 0.6538, Test: 0.7285 0.6594
Epoch: 006, Loss: 0.6263, Val: 0.7442 0.6619, Test: 0.7602 0.6660
Epoch: 007, Loss: 0.6260, Val: 0.7482 0.6426, Test: 0.7697 0.6403
Epoch: 008, Loss: 0.6420, Val: 0.7570 0.6502, Test: 0.7765 0.6522
Epoch: 009, Loss: 0.6306, Val: 0.7679 0.6696, Test: 0.7810 0.6848
Epoch: 010, Loss: 0.6123, Val: 0.7737 0.6914, Test: 0.7785 0.6894
Epoch: 011, Loss: 0.6004, Val: 0.7732 0.6828, Test: 0.7715 0.6884
Epoch: 012, Loss: 0.6051, Val: 0.7726 0.6894, Test: 0.7673 0.6757
Epoch: 013, Loss: 0.6076, Val: 0.7759 0.6909, Test: 0.7717 0.6843
Epoch: 014, Loss: 0.6047, Val: 0.7812 0.6935, Test: 0.7811 0.6955
Epoch: 015, Loss: 0.5955, Val: 0.7859 0.6955, Test: 0.7898 0.6920
Epoch: 016

In [13]:
key = ("paper", "to", "author")
start, _, end = key

best_val_auc = final_test_auc = final_test_acc = 0
for epoch in range(1, 51):
    loss = train(model, optimizer, train_data, key)
    val_auc, val_acc = test(model, val_data, key)
    test_auc, test_acc = test(model, test_data, key)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
        final_test_acc = test_acc
    if epoch % 1 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f} {val_acc:.4f}, Test: {test_auc:.4f} {test_acc:.4f}')

print(f'Final Test: {final_test_auc:.4f} {final_test_acc:.4f}')

z = model.encode(test_data.x_dict, test_data.edge_index_dict)
final_edge_index = model.decode(z[start], z[end], test_data.edge_label_index_dict[key])

Epoch: 001, Loss: 0.6978, Val: 0.5112 0.5000, Test: 0.4654 0.5000
Epoch: 002, Loss: 0.6928, Val: 0.5783 0.5148, Test: 0.5449 0.5150
Epoch: 003, Loss: 0.6907, Val: 0.6075 0.5677, Test: 0.5928 0.5616
Epoch: 004, Loss: 0.6898, Val: 0.6185 0.5519, Test: 0.6101 0.5494
Epoch: 005, Loss: 0.6887, Val: 0.6202 0.5672, Test: 0.6092 0.5636
Epoch: 006, Loss: 0.6874, Val: 0.6188 0.5871, Test: 0.6044 0.5942
Epoch: 007, Loss: 0.6858, Val: 0.6203 0.5876, Test: 0.6048 0.6013
Epoch: 008, Loss: 0.6845, Val: 0.6237 0.5789, Test: 0.6072 0.5807
Epoch: 009, Loss: 0.6832, Val: 0.6288 0.5611, Test: 0.6118 0.5435
Epoch: 010, Loss: 0.6820, Val: 0.6352 0.5570, Test: 0.6185 0.5300
Epoch: 011, Loss: 0.6804, Val: 0.6420 0.5596, Test: 0.6265 0.5384
Epoch: 012, Loss: 0.6790, Val: 0.6488 0.5708, Test: 0.6344 0.5670
Epoch: 013, Loss: 0.6772, Val: 0.6546 0.5937, Test: 0.6412 0.5985
Epoch: 014, Loss: 0.6755, Val: 0.6606 0.6166, Test: 0.6475 0.6174
Epoch: 015, Loss: 0.6732, Val: 0.6664 0.6268, Test: 0.6535 0.6293
Epoch: 016

## SubgraphX

In [16]:
from datetime import datetime

In [17]:
test_data.edge_label_index_dict[("paper", "to", "author")][:, 12]

tensor([10576,  1257])

In [18]:
node_1 = 12032
node_2 = 2955

paper_to_author_index = test_data.edge_index_dict[("paper", "to", "author")]
paper_to_conference_index = test_data.edge_index_dict[("paper", "to", "conference")]

print(paper_to_author_index.shape)
print(paper_to_conference_index.shape)

torch.Size([2, 17681])
torch.Size([2, 12896])


In [19]:
node_1_author_neighbors = set(paper_to_author_index[:, paper_to_author_index[0] == node_1][1].cpu().numpy())
node_1_conference_neighbors = set(paper_to_conference_index[:, paper_to_conference_index[0] == node_1][1].cpu().numpy())

print("paper coauthors", node_1_author_neighbors)
print("paper conference", node_1_conference_neighbors)

node_2_paper_neighbors = set(paper_to_author_index[:, paper_to_author_index[1] == node_2][0].cpu().numpy())

print("author papers", node_2_paper_neighbors)

paper coauthors {2955, 1013}
paper conference {16}
author papers {12032, 4480, 13507, 12862, 12031}


In [20]:
T = 5
for neighbor in node_2_paper_neighbors:
    pred_diffs = []
    sub_edge_mask = paper_to_author_index[1] == node_2
    for t in range(T):
        S_filter = torch.zeros(paper_to_author_index.shape[1], dtype=bool)
        S_filter[sub_edge_mask] = True
        S_filter[(sub_edge_mask) & (np.random.random(sub_edge_mask.shape[0]) > 0.5)] = False
        S_filter[(paper_to_author_index[0] == neighbor)] = False
        
        temp_edge_index_dict = {k: v for k, v in test_data.edge_index_dict.items()}
        temp_edge_index_dict[("paper", "to", "author")] = paper_to_author_index[:, S_filter]
        
        old_z = model.encode(test_data.x_dict, temp_edge_index_dict)
        old_pred = model.decode(old_z["paper"], old_z["author"], torch.tensor([[node_1], [node_2]]))
        
        S_filter[(paper_to_author_index[0] == neighbor)] = True
        temp_edge_index_dict = {k: v for k, v in test_data.edge_index_dict.items()}
        temp_edge_index_dict[("paper", "to", "author")] = paper_to_author_index[:, S_filter]
        
        new_z = model.encode(test_data.x_dict, temp_edge_index_dict)
        new_pred = model.decode(new_z["paper"], new_z["author"], torch.tensor([[node_1], [node_2]]))
        
        pred_diff = (new_pred - old_pred)
        pred_diffs.append(pred_diff.item())
    diff_avg, diff_std = sum(pred_diffs) / len(pred_diffs), statistics.stdev(pred_diffs) / np.sqrt(T)
    print(neighbor, "\t", round(diff_avg, 5), "\t", round(diff_std, 5), "\t", round(diff_avg / diff_std, 5))

  S_filter[(sub_edge_mask) & (np.random.random(sub_edge_mask.shape[0]) > 0.5)] = False


12032 	 0.37756 	 0.33862 	 1.11499
4480 	 0.00706 	 0.01427 	 0.49465
13507 	 0.04632 	 0.01014 	 4.56798
12862 	 0.40277 	 0.32232 	 1.24958
12031 	 0.20102 	 0.30044 	 0.66909
