## Dependencies

In [1]:
from tqdm import tqdm
import statistics

import torch
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score

import torch_geometric.transforms as T
from torch_geometric.datasets import SNAPDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GINConv
from torch_geometric.utils import negative_sampling, to_networkx

torch.manual_seed(0)

%matplotlib notebook

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Data

In [3]:
transform = T.Compose(
    [
        T.ToDevice(device),
        T.RemoveIsolatedNodes(),
        T.RandomLinkSplit(
            num_val=0.05,
            num_test=0.1,
            is_undirected=True,
            add_negative_train_samples=False,
        ),
    ]
)

dataset = SNAPDataset(
    root="./data/SNAPDataset", name="ego-facebook", transform=transform
)

train_data = next(iter(DataLoader([x[0] for x in dataset], batch_size=10)))
val_data = next(iter(DataLoader([x[1] for x in dataset], batch_size=10)))
test_data = next(iter(DataLoader([x[2] for x in dataset], batch_size=10)))

In [4]:
print(train_data)
print(val_data)
print(test_data)

DataBatch(x=[4167, 1406], edge_index=[2, 151430], circle=[4233], circle_batch=[4233], edge_label=[75715], edge_label_index=[2, 75715], batch=[4167], ptr=[11])
DataBatch(x=[4167, 1406], edge_index=[2, 151430], circle=[4233], circle_batch=[4233], edge_label=[8894], edge_label_index=[2, 8894], batch=[4167], ptr=[11])
DataBatch(x=[4167, 1406], edge_index=[2, 160324], circle=[4233], circle_batch=[4233], edge_label=[17800], edge_label_index=[2, 17800], batch=[4167], ptr=[11])


## Prediction

In [5]:
from torch import nn
import torch.nn.functional as F


class SimpleNet(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)
    
    def forward(self, x, edge_index, data=None):
        z = self.encode(x, edge_index)
        out = model.decode(z, edge_index)
        return torch.hstack((-out, out)).T


class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # TODO: look into SAGEConv, GATConv, GINConv, comparison between
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        
        self.W1 = nn.Linear(out_channels * 2, out_channels)
        self.W2 = nn.Linear(out_channels, 1)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        z1 = torch.cat((z[edge_label_index[0]], z[edge_label_index[1]]), dim=1)
        out1 = self.W2(F.relu(self.W1(z1)).squeeze())
        
        z2 = torch.cat((z[edge_label_index[1]], z[edge_label_index[0]]), dim=1)
        out2 = self.W2(F.relu(self.W1(z2)).squeeze())
        
        return (out1 + out2) / 2
    
    def forward(self, x, edge_index, edge_label_index, data=None):
        z = self.encode(x, edge_index)
        out = model.decode(z, edge_label_index)
        return torch.hstack((-out, out)).T

simple_model = SimpleNet(dataset.num_features, 128, 32).to(device)
simple_optimizer = torch.optim.Adam(params=simple_model.parameters(), lr=3e-3, weight_decay=2e-3)
    
model = Net(dataset.num_features, 128, 32).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-3, weight_decay=2e-3)
criterion = torch.nn.BCEWithLogitsLoss()

# TODO: These methods simultaneously use node feature and graph structure properties.
#       Is it possible to train models that look at each aspect separately
#       Can look at only node features by just passing original layer to MLP
#       Unsure if can look at only graph by passing random vector into GCNConv
#       Should also read up on Node2Vec and other methods of generating node embeddings (talk to Rex)

In [6]:
def train(model, optimizer, data):
    model.train()
    optimizer.zero_grad()
    z = model.encode(data.x, data.edge_index)

    # We perform a new round of negative sampling for every training epoch:
    neg_edge_index = negative_sampling(
        edge_index=data.edge_index, 
        num_nodes=data.num_nodes,
        num_neg_samples=data.edge_label_index.shape[1], 
        method='sparse'
    )
    
    edge_label_index = torch.cat([data.edge_label_index, neg_edge_index], dim=-1)
    edge_label = torch.cat([data.edge_label, data.edge_label.new_zeros(neg_edge_index.size(1))], dim=0)

    out = model.decode(z, edge_label_index).view(-1)
    loss = criterion(out, edge_label)
    loss.backward()
    optimizer.step()
    return loss


@torch.no_grad()
def test(model, data):
    model.eval()
    z = model.encode(data.x, data.edge_index)
    out = model.decode(z, data.edge_label_index).view(-1).sigmoid()
    a, b = data.edge_label.cpu().numpy(), out.cpu().numpy()
    c = (out > 0.5).float().cpu().numpy()
    return roc_auc_score(a, b), accuracy_score(a, c)

In [7]:
best_val_auc = final_test_auc = final_test_acc = 0
for epoch in range(1, 301):
    loss = train(simple_model, simple_optimizer, train_data)
    val_auc, val_acc = test(simple_model, val_data)
    test_auc, test_acc = test(simple_model, test_data)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
        final_test_acc = test_acc
    
    if epoch % 10 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f} {val_acc:.4f}, Test: {test_auc:.4f} {test_acc:.4f}')

print(f'Final Test: {final_test_auc:.4f} {final_test_acc:.4f}')

simple_z = simple_model.encode(test_data.x, test_data.edge_index)
simple_final_edge_index = simple_model.decode(simple_z, test_data.edge_label_index)

Epoch: 010, Loss: 0.4942, Val: 0.8451 0.5345, Test: 0.8513 0.5370
Epoch: 020, Loss: 0.4327, Val: 0.8577 0.5291, Test: 0.8629 0.5319
Epoch: 030, Loss: 0.4278, Val: 0.8807 0.5426, Test: 0.8848 0.5438
Epoch: 040, Loss: 0.4221, Val: 0.8815 0.5442, Test: 0.8849 0.5476
Epoch: 050, Loss: 0.4210, Val: 0.8808 0.5386, Test: 0.8843 0.5407
Epoch: 060, Loss: 0.4208, Val: 0.8873 0.5421, Test: 0.8907 0.5435
Epoch: 070, Loss: 0.4188, Val: 0.8896 0.5407, Test: 0.8930 0.5412
Epoch: 080, Loss: 0.4174, Val: 0.8955 0.5427, Test: 0.8993 0.5434
Epoch: 090, Loss: 0.4168, Val: 0.8999 0.5543, Test: 0.9037 0.5535
Epoch: 100, Loss: 0.4165, Val: 0.9014 0.5641, Test: 0.9055 0.5644
Epoch: 110, Loss: 0.4180, Val: 0.9041 0.5625, Test: 0.9083 0.5620
Epoch: 120, Loss: 0.4176, Val: 0.9026 0.5625, Test: 0.9068 0.5620
Epoch: 130, Loss: 0.4185, Val: 0.9028 0.5609, Test: 0.9068 0.5596
Epoch: 140, Loss: 0.4164, Val: 0.9046 0.5625, Test: 0.9088 0.5636
Epoch: 150, Loss: 0.4171, Val: 0.9031 0.5631, Test: 0.9072 0.5638
Epoch: 160

In [8]:
best_val_auc = final_test_auc = final_test_acc = 0
for epoch in range(1, 301):
    loss = train(model, optimizer, train_data)
    val_auc, val_acc = test(model, val_data)
    test_auc, test_acc = test(model, test_data)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
        final_test_acc = test_acc

    if epoch % 10 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f} {val_acc:.4f}, Test: {test_auc:.4f} {test_acc:.4f}')

print(f'Final Test: {final_test_auc:.4f} {final_test_acc:.4f}')

z = model.encode(test_data.x, test_data.edge_index)
final_edge_index = model.decode(z, test_data.edge_label_index)

Epoch: 010, Loss: 0.6516, Val: 0.6647 0.5146, Test: 0.6696 0.5175
Epoch: 020, Loss: 0.5874, Val: 0.6566 0.5806, Test: 0.6639 0.5908
Epoch: 030, Loss: 0.5259, Val: 0.7309 0.6317, Test: 0.7386 0.6443
Epoch: 040, Loss: 0.4354, Val: 0.7835 0.6977, Test: 0.7933 0.7119
Epoch: 050, Loss: 0.3812, Val: 0.8182 0.7082, Test: 0.8237 0.7188
Epoch: 060, Loss: 0.3397, Val: 0.8617 0.7684, Test: 0.8674 0.7763
Epoch: 070, Loss: 0.3074, Val: 0.8780 0.7876, Test: 0.8841 0.7991
Epoch: 080, Loss: 0.2834, Val: 0.8856 0.7958, Test: 0.8912 0.8078
Epoch: 090, Loss: 0.2755, Val: 0.8886 0.8058, Test: 0.8938 0.8142
Epoch: 100, Loss: 0.2710, Val: 0.8851 0.8080, Test: 0.8900 0.8121
Epoch: 110, Loss: 0.2684, Val: 0.8835 0.8038, Test: 0.8882 0.8101
Epoch: 120, Loss: 0.2666, Val: 0.8805 0.7959, Test: 0.8851 0.8020
Epoch: 130, Loss: 0.2638, Val: 0.8799 0.7938, Test: 0.8847 0.8005
Epoch: 140, Loss: 0.2644, Val: 0.8782 0.7919, Test: 0.8832 0.7990
Epoch: 150, Loss: 0.2609, Val: 0.8760 0.7888, Test: 0.8811 0.7980
Epoch: 160