## Dependencies

In [1]:
from tqdm import tqdm
import statistics

import torch
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score

import torch_geometric.transforms as T
from torch_geometric.datasets import SNAPDataset, DBLP, IMDB
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GINConv, to_hetero
from torch_geometric.utils import negative_sampling, to_networkx

torch.manual_seed(0)

%matplotlib notebook

  from .autonotebook import tqdm as notebook_tqdm


## Data

In [2]:
DBLP(root="../data/DBLP")[0]

HeteroData(
  [1mauthor[0m={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057]
  },
  [1mpaper[0m={ x=[14328, 4231] },
  [1mterm[0m={ x=[7723, 50] },
  [1mconference[0m={ num_nodes=20 },
  [1m(author, to, paper)[0m={ edge_index=[2, 19645] },
  [1m(paper, to, author)[0m={ edge_index=[2, 19645] },
  [1m(paper, to, term)[0m={ edge_index=[2, 85810] },
  [1m(paper, to, conference)[0m={ edge_index=[2, 14328] },
  [1m(term, to, paper)[0m={ edge_index=[2, 85810] },
  [1m(conference, to, paper)[0m={ edge_index=[2, 14328] }
)

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = T.Compose([
    T.ToDevice(device),
    T.RemoveIsolatedNodes(),
    T.RandomLinkSplit(
        num_val=0.05, 
        num_test=0.1, 
        is_undirected=True, 
        add_negative_train_samples=False,
        edge_types=[("paper", "to", "author")]
    ),
    T.ToUndirected(),
])

dataset = DBLP(root="../data/DBLP", transform=transform)

train_data, val_data, test_data = dataset[0]

for data in train_data, val_data, test_data:
    del data["term"]
    del data[("paper", "to", "term")]
    del data[("term", "to", "paper")]
    del data[("author", "to", "paper")]
    del data[("conference", "to", "paper")]
    
    del data[("paper", "rev_to", "author")]
    del data[("term", "rev_to", "paper")]
    del data[("paper", "rev_to", "term")]
    del data[("paper", "rev_to", "conference")]
    
    data["conference"].x = torch.ones((20, 1))

In [4]:
print(train_data)
print(val_data)
print(test_data)

HeteroData(
  [1mauthor[0m={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057]
  },
  [1mpaper[0m={ x=[14328, 4231] },
  [1mconference[0m={
    num_nodes=20,
    x=[20, 1]
  },
  [1m(paper, to, author)[0m={
    edge_index=[2, 16699],
    edge_label=[16699],
    edge_label_index=[2, 16699]
  },
  [1m(paper, to, conference)[0m={ edge_index=[2, 14328] },
  [1m(author, rev_to, paper)[0m={
    edge_index=[2, 16699],
    edge_label=[16699]
  },
  [1m(conference, rev_to, paper)[0m={ edge_index=[2, 14328] }
)
HeteroData(
  [1mauthor[0m={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057]
  },
  [1mpaper[0m={ x=[14328, 4231] },
  [1mconference[0m={
    num_nodes=20,
    x=[20, 1]
  },
  [1m(paper, to, author)[0m={
    edge_index=[2, 16699],
    edge_label=[1964],
    edge_label_index=[2, 1964]
  },
  [1m(paper, to, conference)[0m={ edge_index=[2, 14328] },
  [1m(author, rev_to

## Prediction

In [5]:
from torch import nn
import torch.nn.functional as F


class Encoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    
class SimpleNet(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, metadata):
        super().__init__()
        self.encoder = to_hetero(Encoder(hidden_channels=hidden_channels, out_channels=out_channels), metadata)
    
    def encode(self, x_dict, edge_index_dict):
        return self.encoder(x_dict, edge_index_dict)
    
    def decode(self, z1, z2, edge_label_index):
        x1 = z1[edge_label_index[0]]
        x2 = z2[edge_label_index[1]]
        return (x1 * x2).sum(dim=-1)
    
    
class Net(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, metadata):
        super().__init__()
        self.encoder = to_hetero(Encoder(hidden_channels=hidden_channels, out_channels=out_channels), metadata)
        
        self.W1 = nn.Linear(out_channels * 2, out_channels)
        self.W2 = nn.Linear(out_channels, 1)
        
    def encode(self, x_dict, edge_index_dict):
        return self.encoder(x_dict, edge_index_dict)
    
    def decode(self, z1, z2, edge_label_index):
        z_forward = torch.cat((z1[edge_label_index[0]], z2[edge_label_index[1]]), dim=1)
        out1 = self.W2(F.relu(self.W1(z_forward)).squeeze()).squeeze()
        
        z_reverse = torch.cat((z2[edge_label_index[1]], z1[edge_label_index[0]]), dim=1)
        out2 = self.W2(F.relu(self.W1(z_reverse)).squeeze()).squeeze()
        
        return (out1 + out2) / 2
    
    
simple_model = SimpleNet(hidden_channels=128, out_channels=32, metadata=train_data.metadata()).to(device)
simple_optimizer = torch.optim.Adam(params=simple_model.parameters(), lr=1e-3, weight_decay=1e-1)

model = Net(hidden_channels=128, out_channels=32, metadata=train_data.metadata()).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=1e-2)

criterion = torch.nn.BCEWithLogitsLoss()

In [6]:
def train(model, optimizer, data, key):
    start, _, end = key
    model.train()
    optimizer.zero_grad()
    z = model.encode(data.x_dict, data.edge_index_dict)

    # We perform a new round of negative sampling for every training epoch:
    neg_edge_index = negative_sampling(
        edge_index=data.edge_index_dict[key], 
        num_nodes=(data.x_dict[start].shape[0], data.x_dict[end].shape[0]),
        num_neg_samples=data.edge_label_index_dict[key].shape[1], 
        method='sparse'
    )
    
    edge_label_index = data.edge_label_index_dict[key]
    edge_label_index = torch.cat([edge_label_index, neg_edge_index], dim=-1)
    
    edge_label = data.edge_label_dict[key]
    edge_label = torch.cat([edge_label, edge_label.new_zeros(neg_edge_index.size(1))], dim=0)
    
    out = model.decode(z[start], z[end], edge_label_index)
    loss = criterion(out, edge_label)
    
    loss.backward()
    optimizer.step()
    
    return loss


@torch.no_grad()
def test(model, data, key):
    start, _, end = key
    model.eval()
    z = model.encode(data.x_dict, data.edge_index_dict)
    out = model.decode(z[start], z[end], data.edge_label_index_dict[key]).view(-1).sigmoid()
    a, b = data.edge_label_dict[key].cpu().numpy(), out.cpu().numpy()
    c = (out > 0.5).float().cpu().numpy()
        
    return roc_auc_score(a, b), accuracy_score(a, c)

In [7]:
key = ("paper", "to", "author")
start, _, end = key

best_val_auc = final_test_auc = final_test_acc = 0
for epoch in range(1, 51):
    loss = train(simple_model, simple_optimizer, train_data, key)
    val_auc, val_acc = test(simple_model, val_data, key)
    test_auc, test_acc = test(simple_model, test_data, key)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
        final_test_acc = test_acc
    if epoch % 1 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f} {val_acc:.4f}, Test: {test_auc:.4f} {test_acc:.4f}')

print(f'Final Test: {final_test_auc:.4f} {final_test_acc:.4f}')

simple_z = simple_model.encode(test_data.x_dict, test_data.edge_index_dict)
simple_final_edge_index = simple_model.decode(simple_z[start], simple_z[end], test_data.edge_label_index_dict[key])

Epoch: 001, Loss: 0.7300, Val: 0.6557 0.5168, Test: 0.6502 0.5303
Epoch: 002, Loss: 0.8422, Val: 0.6656 0.5545, Test: 0.6435 0.5481
Epoch: 003, Loss: 0.6663, Val: 0.6648 0.5010, Test: 0.6419 0.5013
Epoch: 004, Loss: 0.7530, Val: 0.7079 0.5550, Test: 0.6894 0.5642
Epoch: 005, Loss: 0.6658, Val: 0.7469 0.6762, Test: 0.7395 0.6726
Epoch: 006, Loss: 0.6109, Val: 0.7624 0.6441, Test: 0.7600 0.6451
Epoch: 007, Loss: 0.6471, Val: 0.7701 0.6298, Test: 0.7690 0.6372
Epoch: 008, Loss: 0.6651, Val: 0.7735 0.6558, Test: 0.7726 0.6568
Epoch: 009, Loss: 0.6287, Val: 0.7752 0.6838, Test: 0.7717 0.6861
Epoch: 010, Loss: 0.6006, Val: 0.7728 0.6899, Test: 0.7680 0.6940
Epoch: 011, Loss: 0.6032, Val: 0.7720 0.6767, Test: 0.7651 0.6861
Epoch: 012, Loss: 0.6192, Val: 0.7743 0.6731, Test: 0.7680 0.6917
Epoch: 013, Loss: 0.6183, Val: 0.7801 0.6935, Test: 0.7759 0.6983
Epoch: 014, Loss: 0.6031, Val: 0.7854 0.6991, Test: 0.7838 0.7014
Epoch: 015, Loss: 0.5901, Val: 0.7894 0.6869, Test: 0.7894 0.6879
Epoch: 016

In [8]:
key = ("paper", "to", "author")
start, _, end = key

best_val_auc = final_test_auc = final_test_acc = 0
for epoch in range(1, 51):
    loss = train(model, optimizer, train_data, key)
    val_auc, val_acc = test(model, val_data, key)
    test_auc, test_acc = test(model, test_data, key)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
        final_test_acc = test_acc
    if epoch % 1 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f} {val_acc:.4f}, Test: {test_auc:.4f} {test_acc:.4f}')

print(f'Final Test: {final_test_auc:.4f} {final_test_acc:.4f}')

z = model.encode(test_data.x_dict, test_data.edge_index_dict)
final_edge_index = model.decode(z[start], z[end], test_data.edge_label_index_dict[key])

Epoch: 001, Loss: 0.6972, Val: 0.4650 0.4990, Test: 0.4589 0.5013
Epoch: 002, Loss: 0.6924, Val: 0.4974 0.4903, Test: 0.4833 0.4661
Epoch: 003, Loss: 0.6907, Val: 0.5351 0.4893, Test: 0.5185 0.4891
Epoch: 004, Loss: 0.6900, Val: 0.5615 0.4980, Test: 0.5444 0.4975
Epoch: 005, Loss: 0.6893, Val: 0.5773 0.4863, Test: 0.5608 0.4779
Epoch: 006, Loss: 0.6883, Val: 0.5890 0.4883, Test: 0.5727 0.4801
Epoch: 007, Loss: 0.6871, Val: 0.5989 0.4908, Test: 0.5827 0.4809
Epoch: 008, Loss: 0.6859, Val: 0.6074 0.5351, Test: 0.5903 0.5232
Epoch: 009, Loss: 0.6846, Val: 0.6132 0.5804, Test: 0.5960 0.5728
Epoch: 010, Loss: 0.6835, Val: 0.6199 0.6029, Test: 0.6022 0.5860
Epoch: 011, Loss: 0.6820, Val: 0.6248 0.6120, Test: 0.6069 0.5952
Epoch: 012, Loss: 0.6805, Val: 0.6295 0.6074, Test: 0.6109 0.5934
Epoch: 013, Loss: 0.6789, Val: 0.6356 0.6085, Test: 0.6171 0.5970
Epoch: 014, Loss: 0.6773, Val: 0.6416 0.6054, Test: 0.6232 0.5983
Epoch: 015, Loss: 0.6750, Val: 0.6485 0.6110, Test: 0.6302 0.6013
Epoch: 016

## SubgraphX

In [9]:
from datetime import datetime

In [10]:
test_data.edge_label_index_dict[("paper", "to", "author")][:, 12]

tensor([12032,  2955])

In [11]:
node_1 = 12032
node_2 = 2955

paper_to_author_index = test_data.edge_index_dict[("paper", "to", "author")]
paper_to_conference_index = test_data.edge_index_dict[("paper", "to", "conference")]

print(paper_to_author_index.shape)
print(paper_to_conference_index.shape)

torch.Size([2, 17681])
torch.Size([2, 14328])


In [12]:
node_1_author_neighbors = set(paper_to_author_index[:, paper_to_author_index[0] == node_1][1].cpu().numpy())
node_1_conference_neighbors = set(paper_to_conference_index[:, paper_to_conference_index[0] == node_1][1].cpu().numpy())

print("paper coauthors", node_1_author_neighbors)
print("paper conference", node_1_conference_neighbors)

node_2_paper_neighbors = set(paper_to_author_index[:, paper_to_author_index[1] == node_2][0].cpu().numpy())

print("author papers", node_2_paper_neighbors)

paper coauthors set()
paper conference {16}
author papers {4480, 13507, 12862, 12031}


In [13]:
T = 5
for neighbor in node_2_paper_neighbors:
    pred_diffs = []
    sub_edge_mask = paper_to_author_index[1] == node_2
    for t in range(T):
        S_filter = torch.zeros(paper_to_author_index.shape[1], dtype=bool)
        S_filter[sub_edge_mask] = True
        S_filter[(sub_edge_mask) & (np.random.random(sub_edge_mask.shape[0]) > 0.5)] = False
        S_filter[(paper_to_author_index[0] == neighbor)] = False
        
        temp_edge_index_dict = {k: v for k, v in test_data.edge_index_dict.items()}
        temp_edge_index_dict[("paper", "to", "author")] = paper_to_author_index[:, S_filter]
        
        old_z = model.encode(test_data.x_dict, temp_edge_index_dict)
        old_pred = model.decode(old_z["paper"], old_z["author"], torch.tensor([[node_1], [node_2]]))
        
        S_filter[(paper_to_author_index[0] == neighbor)] = True
        temp_edge_index_dict = {k: v for k, v in test_data.edge_index_dict.items()}
        temp_edge_index_dict[("paper", "to", "author")] = paper_to_author_index[:, S_filter]
        
        new_z = model.encode(test_data.x_dict, temp_edge_index_dict)
        new_pred = model.decode(new_z["paper"], new_z["author"], torch.tensor([[node_1], [node_2]]))
        
        pred_diff = (new_pred - old_pred)
        pred_diffs.append(pred_diff.item())
    diff_avg, diff_std = sum(pred_diffs) / len(pred_diffs), statistics.stdev(pred_diffs) / np.sqrt(T)
    print(neighbor, "\t", round(diff_avg, 5), "\t", round(diff_std, 5), "\t", round(diff_avg / diff_std, 5))

  S_filter[(sub_edge_mask) & (np.random.random(sub_edge_mask.shape[0]) > 0.5)] = False


4480 	 0.02378 	 0.01707 	 1.39276
13507 	 0.41612 	 0.38254 	 1.08779
12862 	 0.80234 	 0.4506 	 1.7806
12031 	 -0.06213 	 0.01119 	 -5.55069
