## Dependencies

In [1]:
from tqdm import tqdm
import statistics

import torch
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score

import torch_geometric.transforms as T
from torch_geometric.datasets import SNAPDataset, DBLP, IMDB
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GINConv, to_hetero
from torch_geometric.utils import negative_sampling, to_networkx

torch.manual_seed(0)

%matplotlib notebook

  from .autonotebook import tqdm as notebook_tqdm


## Data

In [2]:
IMDB(root="../data/IMDB")[0]

HeteroData(
  [1mmovie[0m={
    x=[4278, 3066],
    y=[4278],
    train_mask=[4278],
    val_mask=[4278],
    test_mask=[4278]
  },
  [1mdirector[0m={ x=[2081, 3066] },
  [1mactor[0m={ x=[5257, 3066] },
  [1m(movie, to, director)[0m={ edge_index=[2, 4278] },
  [1m(movie, to, actor)[0m={ edge_index=[2, 12828] },
  [1m(director, to, movie)[0m={ edge_index=[2, 4278] },
  [1m(actor, to, movie)[0m={ edge_index=[2, 12828] }
)

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = T.Compose([
    T.ToDevice(device),
    T.RemoveIsolatedNodes(),
    T.RandomLinkSplit(
        num_val=0.05, 
        num_test=0.1, 
        is_undirected=True, 
        add_negative_train_samples=False,
        edge_types=[("movie", "to", "actor")]
    ),
    T.ToUndirected(),
])

dataset = IMDB(root="../data/IMDB", transform=transform)

train_data, val_data, test_data = dataset[0]

for data in train_data, val_data, test_data:
    del data[("director", "to", "movie")]
    del data[("actor", "to", "movie")]
    
    del data[("movie", "rev_to", "director")]
    del data[("movie", "rev_to", "actor")]

In [4]:
print(train_data)
print(val_data)
print(test_data)

HeteroData(
  [1mmovie[0m={
    x=[4278, 3066],
    y=[4278],
    train_mask=[4278],
    val_mask=[4278],
    test_mask=[4278]
  },
  [1mdirector[0m={ x=[2081, 3066] },
  [1mactor[0m={ x=[5257, 3066] },
  [1m(movie, to, director)[0m={ edge_index=[2, 4278] },
  [1m(movie, to, actor)[0m={
    edge_index=[2, 10905],
    edge_label=[10905],
    edge_label_index=[2, 10905]
  },
  [1m(director, rev_to, movie)[0m={ edge_index=[2, 4278] },
  [1m(actor, rev_to, movie)[0m={
    edge_index=[2, 10905],
    edge_label=[10905]
  }
)
HeteroData(
  [1mmovie[0m={
    x=[4278, 3066],
    y=[4278],
    train_mask=[4278],
    val_mask=[4278],
    test_mask=[4278]
  },
  [1mdirector[0m={ x=[2081, 3066] },
  [1mactor[0m={ x=[5257, 3066] },
  [1m(movie, to, director)[0m={ edge_index=[2, 4278] },
  [1m(movie, to, actor)[0m={
    edge_index=[2, 10905],
    edge_label=[1282],
    edge_label_index=[2, 1282]
  },
  [1m(director, rev_to, movie)[0m={ edge_index=[2, 4278] },
  [1m(actor, r

## Prediction

In [5]:
from torch import nn
import torch.nn.functional as F


class Encoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    
class SimpleNet(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, metadata):
        super().__init__()
        self.encoder = to_hetero(Encoder(hidden_channels=hidden_channels, out_channels=out_channels), metadata)
    
    def encode(self, x_dict, edge_index_dict):
        return self.encoder(x_dict, edge_index_dict)
    
    def decode(self, z1, z2, edge_label_index):
        x1 = z1[edge_label_index[0]]
        x2 = z2[edge_label_index[1]]
        return (x1 * x2).sum(dim=-1)
    
    
class Net(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, metadata):
        super().__init__()
        self.encoder = to_hetero(Encoder(hidden_channels=hidden_channels, out_channels=out_channels), metadata)
        
        self.W1 = nn.Linear(out_channels * 2, out_channels)
        self.W2 = nn.Linear(out_channels, 1)
        
    def encode(self, x_dict, edge_index_dict):
        return self.encoder(x_dict, edge_index_dict)
    
    def decode(self, z1, z2, edge_label_index):
        z_forward = torch.cat((z1[edge_label_index[0]], z2[edge_label_index[1]]), dim=1)
        out1 = self.W2(F.relu(self.W1(z_forward)).squeeze()).squeeze()
        
        z_reverse = torch.cat((z2[edge_label_index[1]], z1[edge_label_index[0]]), dim=1)
        out2 = self.W2(F.relu(self.W1(z_reverse)).squeeze()).squeeze()
        
        return (out1 + out2) / 2
    
    
simple_model = SimpleNet(hidden_channels=128, out_channels=32, metadata=train_data.metadata()).to(device)
simple_optimizer = torch.optim.Adam(params=simple_model.parameters(), lr=5e-3, weight_decay=1e-4)

model = Net(hidden_channels=128, out_channels=32, metadata=train_data.metadata()).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-3, weight_decay=1e-4)

criterion = torch.nn.BCEWithLogitsLoss()

In [6]:
def train(model, optimizer, data, key):
    start, _, end = key
    model.train()
    optimizer.zero_grad()
    z = model.encode(data.x_dict, data.edge_index_dict)

    # We perform a new round of negative sampling for every training epoch:
    neg_edge_index = negative_sampling(
        edge_index=data.edge_index_dict[key], 
        num_nodes=(data.x_dict[start].shape[0], data.x_dict[end].shape[0]),
        num_neg_samples=data.edge_label_index_dict[key].shape[1], 
        method='sparse'
    )
    
    edge_label_index = data.edge_label_index_dict[key]
    edge_label_index = torch.cat([edge_label_index, neg_edge_index], dim=-1)
    
    edge_label = data.edge_label_dict[key]
    edge_label = torch.cat([edge_label, edge_label.new_zeros(neg_edge_index.size(1))], dim=0)
    
    out = model.decode(z[start], z[end], edge_label_index)
    loss = criterion(out, edge_label)
    
    loss.backward()
    optimizer.step()
    
    return loss


@torch.no_grad()
def test(model, data, key):
    start, _, end = key
    model.eval()
    z = model.encode(data.x_dict, data.edge_index_dict)
    out = model.decode(z[start], z[end], data.edge_label_index_dict[key]).view(-1).sigmoid()
    a, b = data.edge_label_dict[key].cpu().numpy(), out.cpu().numpy()
    c = (out > 0.5).float().cpu().numpy()
        
    return roc_auc_score(a, b), accuracy_score(a, c)

In [7]:
key = ("movie", "to", "actor")
start, _, end = key

best_val_auc = final_test_auc = final_test_acc = 0
for epoch in range(1, 51):
    loss = train(simple_model, simple_optimizer, train_data, key)
    val_auc, val_acc = test(simple_model, val_data, key)
    test_auc, test_acc = test(simple_model, test_data, key)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
        final_test_acc = test_acc
    if epoch % 1 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f} {val_acc:.4f}, Test: {test_auc:.4f} {test_acc:.4f}')

print(f'Final Test: {final_test_auc:.4f} {final_test_acc:.4f}')

simple_z = simple_model.encode(test_data.x_dict, test_data.edge_index_dict)
simple_final_edge_index = simple_model.decode(simple_z[start], simple_z[end], test_data.edge_label_index_dict[key])

Epoch: 001, Loss: 0.6932, Val: 0.5290 0.4836, Test: 0.5204 0.4614
Epoch: 002, Loss: 0.6810, Val: 0.6054 0.5491, Test: 0.5938 0.5484
Epoch: 003, Loss: 0.6366, Val: 0.6695 0.6225, Test: 0.6490 0.6014
Epoch: 004, Loss: 0.5971, Val: 0.6283 0.5398, Test: 0.6098 0.5218
Epoch: 005, Loss: 0.7375, Val: 0.6905 0.6373, Test: 0.6647 0.6170
Epoch: 006, Loss: 0.5302, Val: 0.6588 0.5554, Test: 0.6395 0.5394
Epoch: 007, Loss: 0.5959, Val: 0.6670 0.5421, Test: 0.6496 0.5277
Epoch: 008, Loss: 0.5899, Val: 0.7122 0.5913, Test: 0.6886 0.5768
Epoch: 009, Loss: 0.5526, Val: 0.7354 0.6482, Test: 0.7191 0.6342
Epoch: 010, Loss: 0.5196, Val: 0.7291 0.6591, Test: 0.7193 0.6580
Epoch: 011, Loss: 0.4894, Val: 0.7150 0.6490, Test: 0.7063 0.6447
Epoch: 012, Loss: 0.4675, Val: 0.7115 0.6435, Test: 0.7035 0.6400
Epoch: 013, Loss: 0.4623, Val: 0.7269 0.6630, Test: 0.7201 0.6587
Epoch: 014, Loss: 0.4163, Val: 0.7426 0.6724, Test: 0.7368 0.6673
Epoch: 015, Loss: 0.3772, Val: 0.7472 0.6599, Test: 0.7413 0.6541
Epoch: 016

In [8]:
key = ("movie", "to", "actor")
start, _, end = key

best_val_auc = final_test_auc = final_test_acc = 0
for epoch in range(1, 51):
    loss = train(model, optimizer, train_data, key)
    val_auc, val_acc = test(model, val_data, key)
    test_auc, test_acc = test(model, test_data, key)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
        final_test_acc = test_acc
    if epoch % 1 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f} {val_acc:.4f}, Test: {test_auc:.4f} {test_acc:.4f}')

print(f'Final Test: {final_test_auc:.4f} {final_test_acc:.4f}')

z = model.encode(test_data.x_dict, test_data.edge_index_dict)
final_edge_index = model.decode(z[start], z[end], test_data.edge_label_index_dict[key])

Epoch: 001, Loss: 0.6972, Val: 0.4625 0.4930, Test: 0.4548 0.4992
Epoch: 002, Loss: 0.6920, Val: 0.4758 0.4548, Test: 0.4666 0.4555
Epoch: 003, Loss: 0.6895, Val: 0.5091 0.4540, Test: 0.4985 0.4497
Epoch: 004, Loss: 0.6848, Val: 0.5497 0.5351, Test: 0.5333 0.5257
Epoch: 005, Loss: 0.6765, Val: 0.5705 0.5538, Test: 0.5491 0.5488
Epoch: 006, Loss: 0.6664, Val: 0.5851 0.5569, Test: 0.5611 0.5519
Epoch: 007, Loss: 0.6522, Val: 0.6035 0.5858, Test: 0.5759 0.5651
Epoch: 008, Loss: 0.6329, Val: 0.6191 0.5944, Test: 0.5905 0.5722
Epoch: 009, Loss: 0.6122, Val: 0.6342 0.5991, Test: 0.6063 0.5905
Epoch: 010, Loss: 0.5985, Val: 0.6378 0.6014, Test: 0.6128 0.5909
Epoch: 011, Loss: 0.5781, Val: 0.6570 0.6334, Test: 0.6350 0.6123
Epoch: 012, Loss: 0.5651, Val: 0.6578 0.6209, Test: 0.6367 0.6197
Epoch: 013, Loss: 0.5410, Val: 0.6582 0.6225, Test: 0.6359 0.6178
Epoch: 014, Loss: 0.5319, Val: 0.6709 0.6342, Test: 0.6495 0.6193
Epoch: 015, Loss: 0.5067, Val: 0.6743 0.6279, Test: 0.6524 0.6131
Epoch: 016

## SubgraphX

In [9]:
from datetime import datetime

In [10]:
test_data.edge_label_index_dict[("movie", "to", "actor")][:, 12]

tensor([3057, 3242])

In [11]:
node_1 = 2344
node_2 = 4044

movie_to_actor_index = test_data.edge_index_dict[("movie", "to", "actor")]
movie_to_director_index = test_data.edge_index_dict[("movie", "to", "director")]

print(movie_to_actor_index.shape)
print(movie_to_director_index.shape)

torch.Size([2, 11546])
torch.Size([2, 4278])


In [12]:
node_1_actor_neighbors = set(movie_to_actor_index[:, movie_to_actor_index[0] == node_1][1].cpu().numpy())
node_1_director_neighbors = set(movie_to_director_index[:, movie_to_director_index[0] == node_1][1].cpu().numpy())

print("movie actors", node_1_actor_neighbors)
print("movie director", node_1_director_neighbors)

node_2_actor_neighbors = set(movie_to_actor_index[:, movie_to_actor_index[1] == node_2][0].cpu().numpy())

print("actor movies", node_2_actor_neighbors)

movie actors {4044, 476, 2126}
movie director {719}
actor movies {2208, 2818, 551, 2344, 280}


In [14]:
T = 5
for neighbor in node_2_actor_neighbors:
    pred_diffs = []
    sub_edge_mask = movie_to_actor_index[1] == node_2
    for t in range(T):
        S_filter = torch.zeros(movie_to_actor_index.shape[1], dtype=bool)
        S_filter[sub_edge_mask] = True
        S_filter[(sub_edge_mask) & (np.random.random(sub_edge_mask.shape[0]) > 0.5)] = False
        S_filter[(movie_to_actor_index[0] == neighbor)] = False
        
        temp_edge_index_dict = {k: v for k, v in test_data.edge_index_dict.items()}
        temp_edge_index_dict[("movie", "to", "actor")] = movie_to_actor_index[:, S_filter]
        
        old_z = model.encode(test_data.x_dict, temp_edge_index_dict)
        old_pred = model.decode(old_z["movie"], old_z["actor"], torch.tensor([[node_1], [node_2]]))
        
        S_filter[(movie_to_actor_index[0] == neighbor)] = True
        temp_edge_index_dict = {k: v for k, v in test_data.edge_index_dict.items()}
        temp_edge_index_dict[("movie", "to", "actor")] = movie_to_actor_index[:, S_filter]
        
        new_z = model.encode(test_data.x_dict, temp_edge_index_dict)
        new_pred = model.decode(new_z["movie"], new_z["actor"], torch.tensor([[node_1], [node_2]]))
        
        pred_diff = (new_pred - old_pred)
        pred_diffs.append(pred_diff.item())
    diff_avg, diff_std = sum(pred_diffs) / len(pred_diffs), statistics.stdev(pred_diffs) / np.sqrt(T)
    print(neighbor, "\t", round(diff_avg, 5), "\t", round(diff_std, 5), "\t", round(diff_avg / diff_std, 5))

  S_filter[(sub_edge_mask) & (np.random.random(sub_edge_mask.shape[0]) > 0.5)] = False


2208 	 0.22202 	 0.44816 	 0.49541
2818 	 3.01895 	 2.8313 	 1.06627
551 	 1.21486 	 0.4297 	 2.82721
2344 	 2.62995 	 3.02644 	 0.86899
280 	 -0.56216 	 0.11493 	 -4.89135
