# Installing Packages and Frameworks:

In [None]:
import torch
import os
os.environ['TORCH'] = torch.__version__
print("PyTorch has version {}".format(torch.__version__))

PyTorch has version 2.0.1+cu118


In [None]:
import os
import random
import pandas as pd
import numpy as np

import pickle

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Subset
from torch.utils.data import DataLoader
from torch.nn import Linear
from torch_sparse import SparseTensor

import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv
from torch_geometric.data import NeighborSampler, Data, Dataset
from torch_geometric.utils import negative_sampling, convert, to_dense_adj
from torch_geometric.utils import subgraph, to_networkx, from_networkx
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import to_torch_csc_tensor
from typing import Union, Tuple
from torch_geometric.typing import OptPairTensor, Adj, OptTensor, Size

from tqdm import trange


from ogb.linkproppred import PygLinkPropPredDataset, Evaluator
from ogb.io import DatasetSaver

import networkx as nx
import matplotlib.pyplot as plt

# Building Subgraph:

In [None]:
# Load the ogbl-ddi dataset
dataset_name = 'ogbl-ddi'
dataset = PygLinkPropPredDataset(name='ogbl-ddi')
# dataset = PygLinkPropPredDataset(name='ogbl-ddi',
#                                      transform=T.ToSparseTensor())
print(f'The {dataset_name} dataset has {len(dataset)} graph(s).')

Downloading http://snap.stanford.edu/ogb/data/linkproppred/ddi.zip


Downloaded 0.04 GB: 100%|██████████| 46/46 [00:00<00:00, 56.18it/s]


Extracting dataset/ddi.zip


Processing...


Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:00<00:00, 36.04it/s]


Converting graphs into PyG objects...


100%|██████████| 1/1 [00:00<00:00, 660.10it/s]

Saving...
The ogbl-ddi dataset has 1 graph(s).



Done!


In [None]:
ddi_graph = dataset[0]

In [None]:
#Finding the right nodes to choose:
df = pd.read_csv("data/dataWithSmiles.csv")

In [1]:
with open('data/node_features.pickle', 'rb') as handle:
    node_features = pickle.load(handle)
with open('data/train_pos.pickle', 'rb') as handle:
    train_pos = pickle.load(handle)
with open('data/valid_pos.pickle', 'rb') as handle:
    valid_pos = pickle.load(handle)
with open('data/valid_neg.pickle', 'rb') as handle:
    valid_neg = pickle.load(handle)
with open('data/test_pos.pickle', 'rb') as handle:
    test_pos = pickle.load(handle)
with open('data/test_neg.pickle', 'rb') as handle:
    test_neg = pickle.load(handle)
with open('data/nodes.pickle', 'rb') as handle:
    nodes = pickle.load(handle)
with open('data/nodes_removed.pickle', 'rb') as handle:
    nodes_removed = pickle.load(handle)

NameError: name 'pickle' is not defined

In [None]:
print(f'Number of training positive edges: {train_pos.shape[0]}')
print(f'Number of validation positive edges: {valid_pos.shape[0]}')
print(f'Number of validation negative edges: {valid_neg.shape[0]}')
print(f'Number of test positive edges: {test_pos.shape[0]}')
print(f'Number of test negative edges: {test_neg.shape[0]}')

Number of training positive edges: 887234
Number of validation positive edges: 111117
Number of validation negative edges: 74182
Number of test positive edges: 114869
Number of test negative edges: 70962


In [None]:
nodes = torch.tensor(nodes)

In [None]:
data = dataset[0].subgraph(nodes)

# Training and Evaluation of Subgraph


In [None]:
def get_spd_matrix(G, S, max_spd=5):
    spd_matrix = np.zeros((G.number_of_nodes(), len(S)), dtype=np.float32)
    for i, node_S in enumerate(S):
        for node, length in nx.shortest_path_length(G, source=node_S).items():
            spd_matrix[node, i] = min(length, max_spd)
    return spd_matrix


class Logger(object):
    def __init__(self, runs, info=None):
        self.info = info
        self.results = [[] for _ in range(runs)]

    def add_result(self, run, result):
        assert len(result) == 2
        assert run >= 0 and run < len(self.results)
        self.results[run].append(result)

    def print_statistics(self, run=None):
        if run is not None:
            result = 100 * torch.tensor(self.results[run])
            argmax = result[:, 0].argmax().item()
            print(f'Run {run + 1:02d}:')
            print(f'Highest Valid: {result[:, 0].max():.2f}')
            print(f'   Final Test: {result[argmax, 1]:.2f}')
        else:
            result = 100 * torch.tensor(self.results)
            best_results = []
            for r in result:
                valid = r[:, 0].max().item()
                test = r[r[:, 0].argmax(), 1].item()
                best_results.append((valid, test))
            best_result = torch.tensor(best_results)
            print(f'All runs:')
            r = best_result[:, 0]
            print(f'Highest Valid: {r.mean():.4f} ± {r.std():.4f}')
            r = best_result[:, 1]
            print(f'   Final Test: {r.mean():.4f} ± {r.std():.4f}')


class SAGEConv(MessagePassing):
    r"""The GraphSAGE operator from the `"Inductive Representation Learning on
    Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot
        \mathrm{mean}_{j \in \mathcal{N(i)}} \mathbf{x}_j

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
        out_channels (int): Size of each output sample.
        normalize (bool, optional): If set to :obj:`True`, output features
            will be :math:`\ell_2`-normalized, *i.e.*,
            :math:`\frac{\mathbf{x}^{\prime}_i}
            {\| \mathbf{x}^{\prime}_i \|_2}`.
            (default: :obj:`False`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, normalize: bool = False,
                 root_weight: bool = True,
                 bias: bool = True, **kwargs):  # yapf: disable
        kwargs.setdefault('aggr', 'mean')
        super(SAGEConv, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.root_weight = root_weight

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.lin_l = Linear(in_channels[0], out_channels, bias=bias)
        if self.root_weight:
            self.lin_r = Linear(in_channels[1], out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        if self.root_weight:
            self.lin_r.reset_parameters()


    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, size: Size = None) -> Tensor:
        """"""
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        # Node and edge feature dimensionalites need to match.
        if isinstance(edge_index, Tensor):
            assert edge_attr is not None
            assert x[0].size(-1) == edge_attr.size(-1)
        elif isinstance(edge_index, SparseTensor):
            assert x[0].size(-1) == edge_index.size(-1)

        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
        out = self.lin_l(out)

        x_r = x[1]
        if self.root_weight and x_r is not None:
            out += self.lin_r(x_r)

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out


    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        return F.relu(x_j + edge_attr)


    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
        super(GraphSAGE,self).__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))
        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, adj_t, edge_attr, emb_ea):
        edge_attr = torch.mm(edge_attr, emb_ea)
        for conv in self.convs[:-1]:
            x = conv(x, adj_t, edge_attr)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, adj_t, edge_attr)
        return x


class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(LinkPredictor, self).__init__()

        self.lins = torch.nn.ModuleList()
        self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
        self.lins.append(torch.nn.Linear(hidden_channels, out_channels))

        self.dropout = dropout

    def reset_parameters(self):
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, x_i, x_j):
        x = x_i * x_j
        for lin in self.lins[:-1]:
            x = lin(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        return torch.sigmoid(x)

In [None]:
def train(model, predictor, edge_attr, x, emb_ea, adj_t, train_edge, optimizer, batch_size):
    edge_index = adj_t

    model.train()
    predictor.train()

    pos_train_edge = train_edge.to(x.device)

    total_loss = total_examples = 0
    for perm in DataLoader(range(pos_train_edge.size(0)), batch_size, shuffle=True):
        optimizer.zero_grad()

        h = model(x, adj_t, edge_attr, emb_ea)

        edge = pos_train_edge[perm].t()

        pos_out = predictor(h[edge[0]], h[edge[1]])
        pos_loss = -torch.log(pos_out + 1e-15).mean()

        edge = negative_sampling(edge_index, num_nodes=x.size(0),
                                 num_neg_samples=perm.size(0), method='dense')

        neg_out = predictor(h[edge[0]], h[edge[1]])
        neg_loss = -torch.log(1 - neg_out + 1e-15).mean()

        loss = pos_loss + neg_loss
        loss.backward()

        torch.nn.utils.clip_grad_norm_(x, 1.0)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0)

        optimizer.step()

        num_examples = pos_out.size(0)
        total_loss += loss.item() * num_examples
        total_examples += num_examples

    return total_loss / total_examples


In [None]:
@torch.no_grad()
def test(model, predictor, edge_attr, x, emb_ea, adj_t, pos_valid, neg_valid, pos_test, neg_test, evaluator, batch_size):
    model.eval()
    predictor.eval()

    h = model(x, adj_t, edge_attr, emb_ea)

    pos_valid_edge = pos_valid.to(x.device)
    neg_valid_edge = neg_valid.to(x.device)
    pos_test_edge = pos_test.to(x.device)
    neg_test_edge = neg_test.to(x.device)

    pos_valid_preds = []
    for perm in DataLoader(range(pos_valid_edge.size(0)), batch_size):
        edge = pos_valid_edge[perm].t()
        pos_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_valid_pred = torch.cat(pos_valid_preds, dim=0)

    neg_valid_preds = []
    for perm in DataLoader(range(neg_valid_edge.size(0)), batch_size):
        edge = neg_valid_edge[perm].t()
        neg_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    neg_valid_pred = torch.cat(neg_valid_preds, dim=0)

    pos_test_preds = []
    for perm in DataLoader(range(pos_test_edge.size(0)), batch_size):
        edge = pos_test_edge[perm].t()
        pos_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_test_pred = torch.cat(pos_test_preds, dim=0)

    neg_test_preds = []
    for perm in DataLoader(range(neg_test_edge.size(0)), batch_size):
        edge = neg_test_edge[perm].t()
        neg_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    neg_test_pred = torch.cat(neg_test_preds, dim=0)

    results = {}
    for K in [20, 50, 100]:
        evaluator.K = K
        valid_hits = evaluator.eval({
            'y_pred_pos': pos_valid_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        test_hits = evaluator.eval({
            'y_pred_pos': pos_test_pred,
            'y_pred_neg': neg_test_pred,
        })[f'hits@{K}']

        results[f'Hits@{K}'] = (valid_hits, test_hits)

    return results, pos_valid_pred, neg_valid_pred, pos_test_pred, neg_test_pred

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_layers = 2
num_samples = 5
node_emb = 256
hidden_channels = 256
dropout = 0.3
batch_size = 64 * 1024
lr = 0.003
epochs = 400
log_steps = 1
eval_steps = 1
runs = 2

In [None]:
edge_index = data.edge_index.to(device)

In [None]:
edge_index

tensor([[3408, 2126, 3408,  ...,  254,  742, 2975],
        [2126, 3408, 3303,  ...,  618, 2975,  742]], device='cuda:0')

In [None]:
model = GraphSAGE(node_emb, hidden_channels, hidden_channels,
                      num_layers, dropout).to(device)
emb = torch.nn.Embedding(data.num_nodes, node_emb).to(device)
emb_ea = torch.nn.Embedding(num_samples, node_emb).to(device)
predictor = LinkPredictor(hidden_channels, hidden_channels, 1,
                              num_layers+1, dropout).to(device)

In [None]:
print('Number of parameters:',
          sum(p.numel() for p in list(model.parameters()) +
          list(predictor.parameters()) + list(emb.parameters()) + list(emb_ea.parameters())))

Number of parameters: 1301761


In [None]:
np.random.seed(0)
nx_graph = to_networkx(data, to_undirected=True)
node_mask = []
for _ in range(num_samples):
    node_mask.append(np.random.choice(500, size=200, replace=False))
node_mask = np.array(node_mask)
node_subset = np.random.choice(nx_graph.number_of_nodes(), size=500, replace=False)
spd = get_spd_matrix(G=nx_graph, S=node_subset, max_spd=5)
spd = torch.Tensor(spd).to(device)

In [None]:
edge_attr = spd[edge_index, :].mean(0)[:, node_mask].mean(2)

In [None]:
a_max = torch.max(edge_attr, dim=0, keepdim=True)[0]
a_min = torch.min(edge_attr, dim=0, keepdim=True)[0]
edge_attr = (edge_attr - a_min) / (a_max - a_min + 1e-6)

In [None]:
import argparse

In [None]:
parser = argparse.ArgumentParser(description='Link_Pred_DDI')
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--num_layers', type=int, default=2)
parser.add_argument('--num_samples', type=int, default=5)
parser.add_argument('--node_emb', type=int, default=256)
parser.add_argument('--hidden_channels', type=int, default=256)
parser.add_argument('--dropout', type=float, default=0.3)
parser.add_argument('--batch_size', type=int, default=64 * 1024)
parser.add_argument('--lr', type=float, default=0.003)
parser.add_argument('--epochs', type=int, default=400)
parser.add_argument('--log_steps', type=int, default=1)
parser.add_argument('--eval_steps', type=int, default=1)
parser.add_argument('--runs', type=int, default=2)
parser.add_argument("-f", "--file", required=False)
args = parser.parse_args()
print(args)

Namespace(device=0, num_layers=2, num_samples=5, node_emb=256, hidden_channels=256, dropout=0.3, batch_size=65536, lr=0.003, epochs=400, log_steps=1, eval_steps=1, runs=2, file='/root/.local/share/jupyter/runtime/kernel-f5e7f22a-f782-4547-878b-9ddb736657ef.json')


In [None]:
evaluator = Evaluator(name='ogbl-ddi')
loggers = {
    'Hits@20': Logger(runs, args),
    'Hits@50': Logger(args.runs, args),
    'Hits@100': Logger(args.runs, args),
}

In [None]:
for run in range(runs):
    random.seed(run)
    torch.manual_seed(run)
    torch.nn.init.xavier_uniform_(emb.weight)
    torch.nn.init.xavier_uniform_(emb_ea.weight)
    model.reset_parameters()
    predictor.reset_parameters()
    optimizer = torch.optim.Adam(
        list(model.parameters()) + list(emb.parameters()) +
        list(emb_ea.parameters()) + list(predictor.parameters()), lr=lr)

    for epoch in range(1, 1 + epochs):
        loss = train(model, predictor, edge_attr, emb.weight, emb_ea.weight, edge_index, train_pos,
                      optimizer, batch_size)

        if epoch % eval_steps == 0:
            results, pos_valid_pred, neg_valid_pred, pos_test_pred, neg_test_pred = test(model, predictor, edge_attr, emb.weight, emb_ea.weight, edge_index, valid_pos, valid_neg, test_pos, test_neg,
                            evaluator, batch_size)
            for key, result in results.items():
                loggers[key].add_result(run, result)

            if epoch % log_steps == 0:
                for key, result in results.items():
                    valid_hits, test_hits = result
                    print(key)
                    print(f'Run: {run + 1:02d}, '
                          f'Epoch: {epoch:02d}, '
                          f'Loss: {loss:.4f}, '
                          f'Valid: {100 * valid_hits:.2f}%, '
                          f'Test: {100 * test_hits:.2f}%')
                print('---')

    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics(run)

for key in loggers.keys():
    print(key)
    loggers[key].print_statistics()

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
---
Hits@20
Run: 01, Epoch: 92, Loss: 0.2041, Valid: 61.70%, Test: 49.30%
Hits@50
Run: 01, Epoch: 92, Loss: 0.2041, Valid: 67.34%, Test: 77.16%
Hits@100
Run: 01, Epoch: 92, Loss: 0.2041, Valid: 69.79%, Test: 88.48%
---
Hits@20
Run: 01, Epoch: 93, Loss: 0.2040, Valid: 63.89%, Test: 15.05%
Hits@50
Run: 01, Epoch: 93, Loss: 0.2040, Valid: 68.10%, Test: 57.46%
Hits@100
Run: 01, Epoch: 93, Loss: 0.2040, Valid: 69.89%, Test: 85.68%
---
Hits@20
Run: 01, Epoch: 94, Loss: 0.2047, Valid: 63.13%, Test: 44.88%
Hits@50
Run: 01, Epoch: 94, Loss: 0.2047, Valid: 67.85%, Test: 74.94%
Hits@100
Run: 01, Epoch: 94, Loss: 0.2047, Valid: 70.26%, Test: 89.20%
---
Hits@20
Run: 01, Epoch: 95, Loss: 0.2025, Valid: 65.04%, Test: 24.34%
Hits@50
Run: 01, Epoch: 95, Loss: 0.2025, Valid: 67.88%, Test: 77.71%
Hits@100
Run: 01, Epoch: 95, Loss: 0.2025, Valid: 70.10%, Test: 89.07%
---
Hits@20
Run: 01, Epoch: 96, Loss: 0.2029, Valid: 63.04%, Test: 56.80%
H

In [None]:
pos_valid_pred, neg_valid_pred, pos_test_pred, neg_test_pred

tensor([0.8543, 0.9810, 0.9821,  ..., 0.9696, 0.9829, 0.9816])

In [None]:
with open('pos_valid_pred.pickle', 'wb') as handle:
    pickle.dump(pos_valid_pred, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open('neg_valid_pred.pickle', 'wb') as handle:
    pickle.dump(neg_valid_pred, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open('pos_test_pred.pickle', 'wb') as handle:
    pickle.dump(pos_test_pred, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open('neg_test_pred.pickle', 'wb') as handle:
    pickle.dump(neg_test_pred, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
emb = node_features.to(device)

In [None]:
data = T.ToSparseTensor()(data)

In [None]:
adj_t = data.adj_t.to(device)

In [None]:
adj_t

SparseTensor(row=tensor([   0,    0,    0,  ..., 3538, 3538, 3538], device='cuda:0'),
             col=tensor([  29,   31,   39,  ..., 2976, 3034, 3356], device='cuda:0'),
             size=(3539, 3539), nnz=1774468, density=14.17%)

In [None]:
class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout, aggr="add"):
        super(SAGE, self).__init__()

        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels, normalize=True, aggr=aggr))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels, normalize=True, aggr=aggr))
        self.convs.append(SAGEConv(hidden_channels, out_channels, normalize=True, aggr=aggr))

        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, adj_t):
        for conv in self.convs[:-1]:
            x = conv(x, adj_t)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, adj_t)
        return x

class DotProductLinkPredictor(torch.nn.Module):
    def __init__(self):
        super(DotProductLinkPredictor, self).__init__()

    def forward(self, x_i, x_j):
        out = (x_i*x_j).sum(-1)
        return torch.sigmoid(out)

    def reset_parameters(self):
      pass

In [None]:
# Initialize our model and LinkPredictor
hidden_dimension = 256
model = SAGE(2048, hidden_dimension, hidden_dimension, 7, 0.5).to(device)
predictor = DotProductLinkPredictor().to(device)

# Run our initial "node features" through the GNN to get node embeddings
model.eval()
predictor.eval()
h = model(emb, adj_t)

In [None]:
# Randomly sample some training edges and pass them through our basic predictor
idx = torch.randperm(train_pos.size(0))[:10]
edges = train_pos[idx].t()
predictor(h[edges[0]], h[edges[1]])

tensor([0.7311, 0.7311, 0.7311, 0.7310, 0.7311, 0.7311, 0.7311, 0.7311, 0.7310,
        0.7311], device='cuda:0', grad_fn=<SigmoidBackward0>)

In [None]:
adj_t.t().coo()

(tensor([   0,    0,    0,  ..., 3538, 3538, 3538], device='cuda:0'),
 tensor([  29,   31,   39,  ..., 2976, 3034, 3356], device='cuda:0'),
 None)

In [None]:
def create_train_batch(all_pos_train_edges, perm, edge_index):
    # First, we get our positive edges, reshaping them to the form (2, hidden_dimension)
    pos_edges = all_pos_train_edges[perm].t().to(device)

    # We then sample the negative edges using PyG functionality
    neg_edges = negative_sampling(edge_index, num_nodes=3539,
                                  num_neg_samples=perm.shape[0], method='dense').to(device)

    # Our training batch is just the positive edges concatanted with the negative ones
    train_edge = torch.cat([pos_edges, neg_edges], dim=1)

    # Our labels are all 1 for the positive edges and 0 for the negative ones
    pos_label = torch.ones(pos_edges.shape[1], )
    neg_label = torch.zeros(neg_edges.shape[1], )
    train_label = torch.cat([pos_label, neg_label], dim=0).to(device)

    return train_edge, train_label

def train(model, predictor, x, adj_t, train_edge, loss_fn, optimizer, batch_size, num_epochs, edge_model=False, spd=None):
  # adj_t isn't used everywhere in PyG yet, so we switch back to edge_index for negative sampling
  row, col, edge_attr = adj_t.t().coo()
  edge_index = torch.stack([row, col], dim=0)

  model.train()
  predictor.train()

  model.reset_parameters()
  predictor.reset_parameters()

  all_pos_train_edges = train_edge
  for epoch in range(num_epochs):
    epoch_total_loss = 0
    for perm in DataLoader(range(all_pos_train_edges.shape[0]), batch_size,
                           shuffle=True):
      optimizer.zero_grad()

      train_edge, train_label = create_train_batch(all_pos_train_edges, perm, edge_index)

      # Use the GNN to generate node embeddings
      if edge_model:
        h = model(x, edge_index, spd)
      else:
        h = model(x, adj_t)

      # Get predictions for our batch and compute the loss
      preds = predictor(h[train_edge[0]], h[train_edge[1]])
      loss = loss_fn(preds, train_label)

      epoch_total_loss += loss.item()

      # Update our parameters
      loss.backward()
      torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
      torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0)
      optimizer.step()
    print(f'Epoch {epoch} has loss {round(epoch_total_loss, 4)}')

In [None]:
def accuracy(pred, label):
  pred_rounded = torch.round(pred)
  accu = torch.eq(pred_rounded, label).sum() / label.shape[0]
  accu = round(accu.item(), 4)
  return accu

@torch.no_grad()
def test(model, predictor, x, adj_t, split_edge_pos, split_edge_neg, evaluator, batch_size, edge_model=False, spd=None):
    model.eval()
    predictor.eval()

    if edge_model:
        # adj_t isn't used everywhere in PyG yet, so we switch back to edge_index
        row, col, edge_attr = adj_t.t().coo()
        edge_index = torch.stack([row, col], dim=0)
        h = model(x, edge_index, spd)
    else:
        h = model(x, adj_t)

    pos_eval_edge = split_edge_pos.to(device)
    neg_eval_edge = split_edge_neg.to(device)

    pos_eval_preds = []
    for perm in DataLoader(range(pos_eval_edge.shape[0]), batch_size):
        edge = pos_eval_edge[perm].t()
        pos_eval_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_eval_pred = torch.cat(pos_eval_preds, dim=0)

    neg_eval_preds = []
    for perm in DataLoader(range(neg_eval_edge.size(0)), batch_size):
        edge = neg_eval_edge[perm].t()
        neg_eval_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    neg_eval_pred = torch.cat(neg_eval_preds, dim=0)

    total_preds = torch.cat((pos_eval_pred, neg_eval_pred), dim=0)
    labels = torch.cat((torch.ones_like(pos_eval_pred), torch.zeros_like(neg_eval_pred)), dim=0)
    acc = accuracy(total_preds, labels)

    results = {}
    for K in [10, 20, 30, 40, 50]:
        evaluator.K = K
        valid_hits = evaluator.eval({
            'y_pred_pos': pos_eval_pred,
            'y_pred_neg': neg_eval_pred,
        })[f'hits@{K}']
        results[f'Hits@{K}'] = (valid_hits)
    results['Accuracy'] = acc

    return results, pos_eval_pred, neg_eval_pred
eval = Evaluator(name='ogbl-ddi')
# ogb Evaluators can be invoked to get their expected format
print(eval.expected_input_format)

==== Expected input format of Evaluator for ogbl-ddi
{'y_pred_pos': y_pred_pos, 'y_pred_neg': y_pred_neg}
- y_pred_pos: numpy ndarray or torch tensor of shape (num_edges, ). Torch tensor on GPU is recommended for efficiency.
- y_pred_neg: numpy ndarray or torch tensor of shape (num_edges, ). Torch tensor on GPU is recommended for efficiency.
y_pred_pos is the predicted scores for positive edges.
y_pred_neg is the predicted scores for negative edges.
Note: As the evaluation metric is ranking-based, the predicted scores need to be different for different edges.


In [None]:
optimizer = torch.optim.Adam(
            list(model.parameters())  +
            list(predictor.parameters()), lr=0.01)
train(model, predictor, emb, adj_t, train_pos, torch.nn.BCELoss(),
      optimizer, 64 * 1024, 30)
test(model, predictor, emb, adj_t, valid_pos, valid_neg, Evaluator(name='ogbl-ddi'), 64*1024)

Epoch 0 has loss 9.783
Epoch 1 has loss 9.195
Epoch 2 has loss 8.5566
Epoch 3 has loss 8.2481
Epoch 4 has loss 8.2215
Epoch 5 has loss 8.1786
Epoch 6 has loss 8.1583
Epoch 7 has loss 8.1546
Epoch 8 has loss 8.1457
Epoch 9 has loss 8.1805
Epoch 10 has loss 8.1378
Epoch 11 has loss 8.1436
Epoch 12 has loss 8.2384
Epoch 13 has loss 8.1381
Epoch 14 has loss 8.1053
Epoch 15 has loss 8.0789
Epoch 16 has loss 8.0385
Epoch 17 has loss 8.0132
Epoch 18 has loss 7.9893
Epoch 19 has loss 7.9629
Epoch 20 has loss 7.9628
Epoch 21 has loss 7.9348
Epoch 22 has loss 7.9473
Epoch 23 has loss 7.9157
Epoch 24 has loss 7.9022
Epoch 25 has loss 7.8883
Epoch 26 has loss 7.9069
Epoch 27 has loss 7.8904
Epoch 28 has loss 7.8803
Epoch 29 has loss 7.8885


#Model Enhancement

In [None]:
class SkipConnSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_dimension, out_channels, num_layers,
                 dropout):
        super(SkipConnSAGE, self).__init__()

        self.convs = torch.nn.ModuleList()

        self.convs.append(SAGEConv(in_channels, hidden_dimension, normalize=True, aggr="add"))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_dimension, hidden_dimension, normalize=True, aggr="add"))
        self.convs.append(SAGEConv(hidden_dimension, out_channels, normalize=True, aggr="add"))

        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, adj_t):
        prev_x = None
        for i in range(len(self.convs) - 1):
          prev_x = x
          x = self.convs[i](x, adj_t)
          # Skip Connection
          if i > 0:
            x = x + prev_x
          x = F.relu(x)
          x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, adj_t)
        return x

In [None]:
class PostProcessSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_dimension, out_channels, num_conv_layers,
                 num_linear_layers, dropout):
        super(PostProcessSAGE, self).__init__()

        self.convs = torch.nn.ModuleList()
        self.lins = torch.nn.ModuleList()

        self.convs.append(SAGEConv(in_channels, hidden_dimension, normalize=True, aggr="add"))
        for _ in range(num_conv_layers - 1):
            self.convs.append(SAGEConv(hidden_dimension, hidden_dimension, normalize=True, aggr="add"))

        for _ in range(num_linear_layers - 1):
            self.lins.append(torch.nn.Linear(hidden_dimension, hidden_dimension))
        self.lins.append(torch.nn.Linear(hidden_dimension, out_channels))

        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, adj_t):
        for conv in self.convs[:-1]:
            x = conv(x, adj_t)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, adj_t)

        # Post-process
        for lin in self.lins[:-1]:
          x = lin(x)
          x = F.relu(x)
        x = self.lins[-1](x)
        return x

In [None]:
class NeuralLinkPredictor(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(NeuralLinkPredictor, self).__init__()

        self.lins = torch.nn.ModuleList()
        self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
        self.lins.append(torch.nn.Linear(hidden_channels, out_channels))

        self.dropout = dropout

    def reset_parameters(self):
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, x_i, x_j):
        x = x_i * x_j
        for lin in self.lins[:-1]:
            x = lin(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        return torch.sigmoid(x).squeeze()

In [None]:
model = SAGE(2048 , hidden_dimension, hidden_dimension, 5, 0.3).to(device)
predictor = NeuralLinkPredictor(hidden_dimension, hidden_dimension, 1, 4, 0.3).to(device)
optimizer = torch.optim.Adam(
            list(model.parameters())  +
            list(predictor.parameters()), lr=0.01)
train(model, predictor, emb, adj_t, train_pos, torch.nn.BCELoss(),
      optimizer, 64 * 1024, 50)
test(model, predictor, emb, adj_t, valid_pos, valid_neg, Evaluator(name='ogbl-ddi'), 64*1024)

Epoch 0 has loss 9.3322
Epoch 1 has loss 8.3561
Epoch 2 has loss 7.9216
Epoch 3 has loss 8.1477
Epoch 4 has loss 7.1912
Epoch 5 has loss 6.5159
Epoch 6 has loss 6.1057
Epoch 7 has loss 5.9734
Epoch 8 has loss 5.9502
Epoch 9 has loss 5.8657
Epoch 10 has loss 5.807
Epoch 11 has loss 5.7704
Epoch 12 has loss 5.9188
Epoch 13 has loss 5.862
Epoch 14 has loss 5.7531
Epoch 15 has loss 5.7001
Epoch 16 has loss 5.6568
Epoch 17 has loss 5.7993
Epoch 18 has loss 5.7442
Epoch 19 has loss 5.4495
Epoch 20 has loss 5.3639
Epoch 21 has loss 5.2336
Epoch 22 has loss 5.1209
Epoch 23 has loss 5.0228
Epoch 24 has loss 4.8656
Epoch 25 has loss 4.8159
Epoch 26 has loss 4.6775
Epoch 27 has loss 4.5975
Epoch 28 has loss 4.5301
Epoch 29 has loss 4.7629
Epoch 30 has loss 4.5344
Epoch 31 has loss 4.4697
Epoch 32 has loss 4.4174
Epoch 33 has loss 4.415
Epoch 34 has loss 4.3874
Epoch 35 has loss 4.4274
Epoch 36 has loss 4.3698
Epoch 37 has loss 4.4142
Epoch 38 has loss 4.3327
Epoch 39 has loss 4.4014
Epoch 40 has 

({'Hits@10': 0.13005210723831637,
  'Hits@20': 0.17434775956874285,
  'Hits@30': 0.2434910949719665,
  'Hits@40': 0.2606351863351242,
  'Hits@50': 0.27728430393189163,
  'Accuracy': 0.8741},
 tensor([0.6843, 0.9283, 0.9694,  ..., 0.8228, 0.9717, 0.9760]),
 tensor([0.2969, 0.0221, 0.6009,  ..., 0.0447, 0.1054, 0.1150]))

In [None]:
G = convert.to_networkx(data, to_undirected=True)

NameError: ignored

In [None]:
K = 200
sampled_nodes = sorted(random.sample(G.nodes, K))

spd = torch.ones(3539, K, dtype=torch.float64).to(device)
for k in range(K):
  distance_from_sample_k_to_all_nodes = nx.shortest_path_length(G, source=sampled_nodes[k])
  for node in distance_from_sample_k_to_all_nodes:
    spd[node][k] = distance_from_sample_k_to_all_nodes[node]
spd = spd.float()
spd

In [None]:
class SAGEConvWithEdgesConceptual(MessagePassing):
    def __init__(self, in_channels,
                 out_channels, normalize = False,
                 root_weight = True,
                 bias = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(SAGEConvWithEdges, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.root_weight = root_weight

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.lin_l = torch.nn.Linear(in_channels[0], out_channels, bias=bias)
        self.lin_e = torch.nn.Linear(1, in_channels[0], bias=bias)
        if self.root_weight:
            self.lin_r = torch.nn.Linear(in_channels[1], out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_e.reset_parameters()
        if self.root_weight:
            self.lin_r.reset_parameters()


    def forward(self, x, edge_index, spd, size = None):
        if isinstance(x, Tensor):
            x = (x, x)
        out = self.propagate(edge_index, x=x, spd=spd)
        out = self.lin_l(out)

        x_r = x[1]
        if self.root_weight and x_r is not None:
            out += self.lin_r(x_r)

        if self.normalize:
          out = F.normalize(out, p=2., dim=-1)
        return out


    def message(self, x_j, spd_i, spd_j):
        dist_mean = torch.mean(spd_i + spd_j, 1, True)
        return F.relu(x_j + self.lin_e(dist_mean))


    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)

In [None]:
class SAGEConvWithEdges(MessagePassing):

    def __init__(self, in_channels,
                 out_channels, normalize = False,
                 root_weight = True,
                 bias = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(SAGEConvWithEdges, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.root_weight = root_weight

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.lin_l = torch.nn.Linear(in_channels[0], out_channels, bias=bias)
        self.lin_e = torch.nn.Linear(1, in_channels[0])
        if self.root_weight:
            self.lin_r = torch.nn.Linear(in_channels[1], out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_e.reset_parameters()
        if self.root_weight:
            self.lin_r.reset_parameters()


    def forward(self, x, edge_index, spd, size = None):
        if isinstance(x, Tensor):
            x = (x, x)

        spd = torch.sum(spd, dim=1, keepdim=True) / spd.shape[1]
        spd = self.lin_e(spd)

        out = self.propagate(edge_index, x=x, spd=spd)
        out = self.lin_l(out)

        x_r = x[1]
        if self.root_weight and x_r is not None:
            out += self.lin_r(x_r)

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out


    def message(self, x_j, spd_i, spd_j):
        dist_mean = F.relu(spd_i + spd_j)
        return x_j + dist_mean


    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


In [None]:
class EdgeSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout, aggr="mean"):
        super(EdgeSAGE, self).__init__()

        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConvWithEdges(in_channels, hidden_channels, normalize=True, aggr=aggr))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConvWithEdges(hidden_channels, hidden_channels, normalize=True, aggr=aggr))
        self.convs.append(SAGEConvWithEdges(hidden_channels, out_channels, normalize=True, aggr=aggr))

        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, edge_index, spd):
        for conv in self.convs[:-1]:
            x = conv(x, edge_index, spd)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index, spd)
        return x

In [None]:
hidden_dimension = 1024
model = EdgeSAGE(2048, hidden_dimension, hidden_dimension, 5, 0.3, aggr="add").to(device)
predictor = NeuralLinkPredictor(hidden_dimension, hidden_dimension, 1, 4, 0.3).to(device)
optimizer = torch.optim.Adam(
            list(model.parameters())  +
            list(predictor.parameters()), lr=0.003)
train(model, predictor, emb, adj_t, train_pos, torch.nn.BCELoss(),
      optimizer, 64 * 1024, 100, edge_model=True, spd=spd)
test(model, predictor, emb, adj_t, valid_pos, valid_neg, Evaluator(name='ogbl-ddi'), 64*1024, edge_model=True, spd=spd)

In [None]:
KEKW, pos, neg = test(model, predictor, emb, adj_t, valid_pos, valid_neg, Evaluator(name='ogbl-ddi'), 64*1024)