## Dependencies

In [1]:
from tqdm import tqdm
import statistics

import torch
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score

import torch_geometric.transforms as T
from torch_geometric.datasets import SNAPDataset, DBLP, IMDB
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GINConv, to_hetero
from torch_geometric.utils import negative_sampling, to_networkx

torch.manual_seed(0)

%matplotlib notebook

  from .autonotebook import tqdm as notebook_tqdm


## Data

In [2]:
# Peek at the first graph of the IMDB dataset as loaded with no transform applied.
IMDB(root="../data/IMDB")[0]

HeteroData(
  [1mmovie[0m={
    x=[4278, 3066],
    y=[4278],
    train_mask=[4278],
    val_mask=[4278],
    test_mask=[4278]
  },
  [1mdirector[0m={ x=[2081, 3066] },
  [1mactor[0m={ x=[5257, 3066] },
  [1m(movie, to, director)[0m={ edge_index=[2, 4278] },
  [1m(movie, to, actor)[0m={ edge_index=[2, 12828] },
  [1m(director, to, movie)[0m={ edge_index=[2, 4278] },
  [1m(actor, to, movie)[0m={ edge_index=[2, 12828] }
)

In [3]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hold out 5% / 10% of the (movie, to, actor) edges as validation / test
# supervision edges. Training negatives are not pre-sampled here
# (add_negative_train_samples=False); train() draws fresh ones on every call.
link_split = T.RandomLinkSplit(
    num_val=0.05,
    num_test=0.1,
    is_undirected=True,
    add_negative_train_samples=False,
    edge_types=[("movie", "to", "actor")],
)

# Order matters: move to the device, prune isolated nodes, split the links,
# and only then add the reverse ("rev_to") relations to each resulting split.
transform = T.Compose([
    T.ToDevice(device),
    T.RemoveIsolatedNodes(),
    link_split,
    T.ToUndirected(),
])

dataset = IMDB(root="../data/IMDB", transform=transform)
train_data, val_data, test_data = dataset[0]

# Keep only the two movie->X relations and the rev_to relations mirrored from
# them; drop the dataset's own X->movie relations together with their mirrors.
# NOTE(review): presumably this keeps the unsplit actor->movie edges from
# leaking held-out links into message passing — confirm.
redundant_edge_types = (
    ("director", "to", "movie"),
    ("actor", "to", "movie"),
    ("movie", "rev_to", "director"),
    ("movie", "rev_to", "actor"),
)
for data in (train_data, val_data, test_data):
    for edge_type in redundant_edge_types:
        del data[edge_type]

In [4]:
# Show the structure of the three splits, one after another.
print(train_data, val_data, test_data, sep="\n")

HeteroData(
  [1mmovie[0m={
    x=[4278, 3066],
    y=[4278],
    train_mask=[4278],
    val_mask=[4278],
    test_mask=[4278]
  },
  [1mdirector[0m={ x=[2081, 3066] },
  [1mactor[0m={ x=[5257, 3066] },
  [1m(movie, to, director)[0m={ edge_index=[2, 4278] },
  [1m(movie, to, actor)[0m={
    edge_index=[2, 10905],
    edge_label=[10905],
    edge_label_index=[2, 10905]
  },
  [1m(director, rev_to, movie)[0m={ edge_index=[2, 4278] },
  [1m(actor, rev_to, movie)[0m={
    edge_index=[2, 10905],
    edge_label=[10905]
  }
)
HeteroData(
  [1mmovie[0m={
    x=[4278, 3066],
    y=[4278],
    train_mask=[4278],
    val_mask=[4278],
    test_mask=[4278]
  },
  [1mdirector[0m={ x=[2081, 3066] },
  [1mactor[0m={ x=[5257, 3066] },
  [1m(movie, to, director)[0m={ edge_index=[2, 4278] },
  [1m(movie, to, actor)[0m={
    edge_index=[2, 10905],
    edge_label=[1282],
    edge_label_index=[2, 1282]
  },
  [1m(director, rev_to, movie)[0m={ edge_index=[2, 4278] },
  [1m(actor, r

## Prediction

In [5]:
from torch import nn
import torch.nn.functional as F


class Encoder(torch.nn.Module):
    """Two-layer SAGEConv encoder producing node embeddings.

    Args:
        hidden_channels: Output width of the first convolution.
        out_channels: Output width of the second convolution, i.e. the
            embedding size.

    Both layers are built with ``(-1, -1)`` input channels, so the input
    widths are not fixed here.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        lazy_in = (-1, -1)
        self.conv1 = SAGEConv(lazy_in, hidden_channels)
        self.conv2 = SAGEConv(lazy_in, out_channels)

    def forward(self, x, edge_index):
        """Apply conv1, a ReLU, then conv2 over the same ``edge_index``."""
        hidden = self.conv1(x, edge_index)
        hidden = hidden.relu()
        return self.conv2(hidden, edge_index)

    
class SimpleNet(torch.nn.Module):
    """Link predictor: heterogeneous encoder plus a dot-product decoder.

    Args:
        hidden_channels: Hidden width passed to ``Encoder``.
        out_channels: Embedding width passed to ``Encoder``.
        metadata: Graph metadata handed to ``to_hetero`` to convert the
            encoder for heterogeneous input.
    """

    def __init__(self, hidden_channels, out_channels, metadata):
        super().__init__()
        base_encoder = Encoder(hidden_channels=hidden_channels, out_channels=out_channels)
        self.encoder = to_hetero(base_encoder, metadata)

    def encode(self, x_dict, edge_index_dict):
        """Return whatever the converted encoder produces for the given inputs."""
        return self.encoder(x_dict, edge_index_dict)

    def decode(self, z1, z2, edge_label_index):
        """Score each candidate edge by the inner product of its endpoints.

        Row 0 of ``edge_label_index`` indexes into ``z1`` and row 1 into
        ``z2``; one score is returned per column.
        """
        src_emb = z1[edge_label_index[0]]
        dst_emb = z2[edge_label_index[1]]
        return torch.sum(src_emb * dst_emb, dim=-1)
    
    
class Net(torch.nn.Module):
    """Link predictor: heterogeneous encoder plus a symmetric MLP decoder.

    Args:
        hidden_channels: Hidden width passed to ``Encoder``.
        out_channels: Embedding width passed to ``Encoder``; also sets the
            decoder sizes (``2 * out_channels -> out_channels -> 1``).
        metadata: Graph metadata handed to ``to_hetero`` to convert the
            encoder for heterogeneous input.
    """

    def __init__(self, hidden_channels, out_channels, metadata):
        super().__init__()
        self.encoder = to_hetero(Encoder(hidden_channels=hidden_channels, out_channels=out_channels), metadata)

        self.W1 = nn.Linear(out_channels * 2, out_channels)
        self.W2 = nn.Linear(out_channels, 1)

    def encode(self, x_dict, edge_index_dict):
        """Return whatever the converted encoder produces for the given inputs."""
        return self.encoder(x_dict, edge_index_dict)

    def _score(self, pair_features):
        """Map concatenated endpoint embeddings ``(E, 2*C)`` to logits ``(E,)``."""
        hidden = F.relu(self.W1(pair_features))
        # Drop only the trailing size-1 output dimension. An unqualified
        # squeeze() would also remove the edge dimension when E == 1 (and the
        # feature dimension when out_channels == 1), changing the output rank.
        return self.W2(hidden).squeeze(-1)

    def decode(self, z1, z2, edge_label_index):
        """Score candidate edges with the MLP, averaged over both endpoint orders.

        Row 0 of ``edge_label_index`` indexes into ``z1`` and row 1 into
        ``z2``. Each pair is scored as ``[src, dst]`` and as ``[dst, src]``
        and the two logits are averaged.

        Returns:
            A 1-D tensor with one logit per column of ``edge_label_index``,
            including when there is only a single column.
        """
        src_emb = z1[edge_label_index[0]]
        dst_emb = z2[edge_label_index[1]]

        forward_score = self._score(torch.cat((src_emb, dst_emb), dim=1))
        reverse_score = self._score(torch.cat((dst_emb, src_emb), dim=1))

        return (forward_score + reverse_score) / 2
    
    
# Both models use the same layer sizes and the same Adam settings, so the two
# runs differ only in the decoder.
size_kwargs = dict(hidden_channels=128, out_channels=32)
adam_kwargs = dict(lr=1e-3, weight_decay=1e-5)

# Dot-product decoder baseline.
simple_model = SimpleNet(metadata=train_data.metadata(), **size_kwargs).to(device)
simple_optimizer = torch.optim.Adam(params=simple_model.parameters(), **adam_kwargs)

# MLP decoder variant.
model = Net(metadata=train_data.metadata(), **size_kwargs).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), **adam_kwargs)

# Shared loss; both decoders emit raw logits, to which test() applies a sigmoid.
criterion = torch.nn.BCEWithLogitsLoss()

In [6]:
def train(model, optimizer, data, key):
    """Run one full-batch optimisation step for link prediction on ``key``.

    Args:
        model: Object exposing ``encode(x_dict, edge_index_dict)`` and
            ``decode(z_src, z_dst, edge_label_index)``.
        optimizer: Optimizer over ``model``'s parameters.
        data: Heterogeneous graph providing ``x_dict``, ``edge_index_dict``,
            ``edge_label_index_dict`` and ``edge_label_dict``.
        key: ``(source_type, relation, target_type)`` edge type to predict.

    Returns:
        The loss tensor computed by the notebook-level ``criterion``.
    """
    src_type, _, dst_type = key
    model.train()
    optimizer.zero_grad()
    embeddings = model.encode(data.x_dict, data.edge_index_dict)

    pos_index = data.edge_label_index_dict[key]
    pos_label = data.edge_label_dict[key]

    # Draw a new set of negatives on every call, as many as there are
    # supervision edges, avoiding the edges present in edge_index.
    num_src = data.x_dict[src_type].shape[0]
    num_dst = data.x_dict[dst_type].shape[0]
    neg_index = negative_sampling(
        edge_index=data.edge_index_dict[key],
        num_nodes=(num_src, num_dst),
        num_neg_samples=pos_index.shape[1],
        method='sparse',
    )
    # Size the zero labels from the negatives actually returned.
    neg_label = pos_label.new_zeros(neg_index.size(1))

    candidate_index = torch.cat([pos_index, neg_index], dim=-1)
    target = torch.cat([pos_label, neg_label], dim=0)

    logits = model.decode(embeddings[src_type], embeddings[dst_type], candidate_index)
    loss = criterion(logits, target)

    loss.backward()
    optimizer.step()

    return loss


@torch.no_grad()
def test(model, data, key):
    """Evaluate link prediction on ``data`` for edge type ``key``.

    Args:
        model: Object exposing ``encode`` and ``decode`` as used by train().
        data: Heterogeneous graph whose ``edge_label_index_dict[key]`` and
            ``edge_label_dict[key]`` hold the candidate edges and their labels.
        key: ``(source_type, relation, target_type)`` edge type to evaluate.

    Returns:
        ``(roc_auc, accuracy)``, where accuracy thresholds the sigmoid
        probabilities at 0.5.
    """
    src_type, _, dst_type = key
    model.eval()

    embeddings = model.encode(data.x_dict, data.edge_index_dict)
    logits = model.decode(embeddings[src_type], embeddings[dst_type], data.edge_label_index_dict[key])
    probs = logits.view(-1).sigmoid()

    y_true = data.edge_label_dict[key].cpu().numpy()
    y_score = probs.cpu().numpy()
    y_pred = (probs > 0.5).float().cpu().numpy()

    return roc_auc_score(y_true, y_score), accuracy_score(y_true, y_pred)

In [7]:
# Train the dot-product baseline (SimpleNet) on movie->actor link prediction.
key = ("movie", "to", "actor")
start, _, end = key

# Report the test metrics from the epoch with the best validation AUC.
best_val_auc = final_test_auc = final_test_acc = 0
for epoch in range(1, 51):
    loss = train(simple_model, simple_optimizer, train_data, key)
    val_auc, val_acc = test(simple_model, val_data, key)
    test_auc, test_acc = test(simple_model, test_data, key)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
        final_test_acc = test_acc

    if epoch % 5 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f} {val_acc:.4f}, Test: {test_auc:.4f} {test_acc:.4f}')

print(f'Final Test: {final_test_auc:.4f} {final_test_acc:.4f}')

# Score the test candidate edges with the model as it stands after the last
# epoch. This is pure inference, so run it without autograd; otherwise a
# full-graph computation graph is built and kept alive by these tensors.
# Despite its name, `simple_final_edge_index` holds the decoder's raw scores
# (one per candidate edge), not an edge index.
with torch.no_grad():
    simple_z = simple_model.encode(test_data.x_dict, test_data.edge_index_dict)
    simple_final_edge_index = simple_model.decode(simple_z[start], simple_z[end], test_data.edge_label_index_dict[key])

Epoch: 005, Loss: 0.6721, Val: 0.5443 0.5172, Test: 0.5713 0.5137
Epoch: 010, Loss: 0.5973, Val: 0.6252 0.5866, Test: 0.6354 0.6057
Epoch: 015, Loss: 0.5109, Val: 0.6510 0.6123, Test: 0.6547 0.6236
Epoch: 020, Loss: 0.4368, Val: 0.6671 0.6232, Test: 0.6714 0.6318
Epoch: 025, Loss: 0.3655, Val: 0.6795 0.6240, Test: 0.6813 0.6396
Epoch: 030, Loss: 0.2945, Val: 0.6876 0.6295, Test: 0.6886 0.6459
Epoch: 035, Loss: 0.2397, Val: 0.6950 0.6123, Test: 0.6992 0.6408
Epoch: 040, Loss: 0.1958, Val: 0.7044 0.6061, Test: 0.7057 0.6330
Epoch: 045, Loss: 0.1683, Val: 0.7115 0.6193, Test: 0.7099 0.6388
Epoch: 050, Loss: 0.1456, Val: 0.7207 0.6069, Test: 0.7161 0.6221
Final Test: 0.7161 0.6221


In [8]:
# Train the MLP-decoder model (Net) on movie->actor link prediction.
key = ("movie", "to", "actor")
start, _, end = key

# Report the test metrics from the epoch with the best validation AUC.
best_val_auc = final_test_auc = final_test_acc = 0
for epoch in range(1, 51):
    loss = train(model, optimizer, train_data, key)
    val_auc, val_acc = test(model, val_data, key)
    test_auc, test_acc = test(model, test_data, key)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
        final_test_acc = test_acc

    if epoch % 5 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f} {val_acc:.4f}, Test: {test_auc:.4f} {test_acc:.4f}')

print(f'Final Test: {final_test_auc:.4f} {final_test_acc:.4f}')

# Score the test candidate edges with the model as it stands after the last
# epoch. This is pure inference, so run it without autograd; otherwise a
# full-graph computation graph is built and kept alive by these tensors.
# Despite its name, `final_edge_index` holds the decoder's raw scores
# (one per candidate edge), not an edge index.
with torch.no_grad():
    z = model.encode(test_data.x_dict, test_data.edge_index_dict)
    final_edge_index = model.decode(z[start], z[end], test_data.edge_label_index_dict[key])

Epoch: 005, Loss: 0.6909, Val: 0.4724 0.5125, Test: 0.4906 0.5308
Epoch: 010, Loss: 0.6840, Val: 0.4989 0.4626, Test: 0.5143 0.4746
Epoch: 015, Loss: 0.6692, Val: 0.5415 0.5398, Test: 0.5552 0.5437
Epoch: 020, Loss: 0.6440, Val: 0.5759 0.5686, Test: 0.5860 0.5718
Epoch: 025, Loss: 0.6117, Val: 0.6035 0.5764, Test: 0.6092 0.5885
Epoch: 030, Loss: 0.5736, Val: 0.6218 0.5889, Test: 0.6241 0.5971
Epoch: 035, Loss: 0.5368, Val: 0.6378 0.5944, Test: 0.6375 0.5991
Epoch: 040, Loss: 0.4966, Val: 0.6494 0.6045, Test: 0.6462 0.6010
Epoch: 045, Loss: 0.4624, Val: 0.6564 0.6131, Test: 0.6501 0.6170
Epoch: 050, Loss: 0.4297, Val: 0.6593 0.6170, Test: 0.6494 0.6158
Final Test: 0.6494 0.6158
