In [1]:
# Install PyTorch Geometric, its compiled extensions, and the OGB package.
# NOTE(review): the wheel index below targets torch 1.8.0 built for CUDA 10.1; it must
# match the torch/CUDA build actually present in the runtime — confirm before reuse.
# NOTE(review): no versions are pinned, so a later re-run may install different
# releases (the captured output shows ogb 1.3.0 at the time this ran).
!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
!pip install -q torch-geometric
!pip install ogb

[K     |████████████████████████████████| 2.6MB 3.8MB/s 
[K     |████████████████████████████████| 1.5MB 8.5MB/s 
[K     |████████████████████████████████| 194kB 8.7MB/s 
[K     |████████████████████████████████| 235kB 43.4MB/s 
[K     |████████████████████████████████| 2.2MB 47.7MB/s 
[K     |████████████████████████████████| 51kB 8.5MB/s 
[?25h  Building wheel for torch-geometric (setup.py) ... [?25l[?25hdone
Collecting ogb
[?25l  Downloading https://files.pythonhosted.org/packages/34/47/16573587124ee85c8255cebd30c55981fa78c815eaff966ff111fb11c32c/ogb-1.3.0-py3-none-any.whl (67kB)
[K     |████████████████████████████████| 71kB 5.0MB/s 
Collecting outdated>=0.2.0
  Downloading https://files.pythonhosted.org/packages/86/70/2f166266438a30e94140f00c99c0eac1c45807981052a1d4c123660e1323/outdated-0.2.0.tar.gz
Collecting littleutils
  Downloading https://files.pythonhosted.org/packages/4e/b1/bb4e06f010947d67349f863b6a2ad71577f85590180a935f60543f622652/littleutils-0.2.2.tar.gz
Buil

In [2]:
import torch
import copy
import numpy as np
import networkx as nx
import random
import pandas as pd

import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torch_geometric.utils import negative_sampling, to_networkx
from torch_geometric.nn import GCNConv, SAGEConv, TAGConv, JumpingKnowledge

import torch_sparse

In [3]:
# Record which GPU (model, driver, memory) this run had available.
!nvidia-smi

Sun Mar 21 20:50:36 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.56       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   43C    P8    10W /  70W |      3MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Load Dataset

In [4]:
from ogb.linkproppred import PygLinkPropPredDataset, Evaluator

# Prefer the GPU whenever torch can see one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device: {device}')

# ogbl-ddi ships a single graph; dataset[0] is that graph.
dataset = PygLinkPropPredDataset(name='ogbl-ddi')
print(f'Task type: {dataset.task_type}')
graph = dataset[0]

# Official train / valid / test edge split provided with the dataset.
split_idx = dataset.get_edge_split()

# Hits@20 evaluator; show the input and output formats it expects.
evaluator = Evaluator(name='ogbl-ddi')
print(evaluator.expected_input_format, evaluator.expected_output_format, sep='\n')

# Edge list with shape 2 x (number of edges * 2): every edge is stored in both directions.
edge_index = graph.edge_index
print('edge_index:', edge_index, edge_index.shape, sep='\n')

# Sparse (num_nodes x num_nodes) adjacency built from the same edge list.
adj_t = torch_sparse.SparseTensor.from_edge_index(edge_index)
print('adj_t:', adj_t, sep='\n')

  0%|          | 0/46 [00:00<?, ?it/s]

Device: cuda
Downloading http://snap.stanford.edu/ogb/data/linkproppred/ddi.zip


Downloaded 0.04 GB: 100%|██████████| 46/46 [00:15<00:00,  3.02it/s]


Extracting dataset/ddi.zip


  0%|          | 0/1 [00:00<?, ?it/s]

Processing...
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:00<00:00, 35.84it/s]
100%|██████████| 1/1 [00:00<00:00, 632.05it/s]


Converting graphs into PyG objects...
Saving...
Done!
Task type: link prediction
==== Expected input format of Evaluator for ogbl-ddi
{'y_pred_pos': y_pred_pos, 'y_pred_neg': y_pred_neg}
- y_pred_pos: numpy ndarray or torch tensor of shape (num_edge, ). Torch tensor on GPU is recommended for efficiency.
- y_pred_neg: numpy ndarray or torch tensor of shape (num_edge, ). Torch tensor on GPU is recommended for efficiency.
y_pred_pos is the predicted scores for positive edges.
y_pred_neg is the predicted scores for negative edges.
Note: As the evaluation metric is ranking-based, the predicted scores need to be different for different edges.
==== Expected output format of Evaluator for ogbl-ddi
{hits@20': hits@20}
- hits@20 (float): Hits@20 score

edge_index:
tensor([[4039, 2424, 4039,  ...,  338,  835, 3554],
        [2424, 4039,  225,  ...,  708, 3554,  835]])
torch.Size([2, 2135822])
adj_t:
SparseTensor(row=tensor([   0,    0,    0,  ..., 4266, 4266, 4266]),
             col=tensor([   4

In [5]:
# Edge counts per split. Negative edges are reported only for the valid and test splits.
for label, key, has_neg in (('Train', 'train', False),
                            ('Val', 'valid', True),
                            ('Test', 'test', True)):
    print(f'{label} edges:', split_idx[key]['edge'].shape)
    if has_neg:
        print(f'{label} negative edges:', split_idx[key]['edge_neg'].shape)

Train edges: torch.Size([1067911, 2])
Val edges: torch.Size([133489, 2])
Val negative edges: torch.Size([101882, 2])
Test edges: torch.Size([133489, 2])
Test negative edges: torch.Size([95599, 2])


# Preprocess Data for DE-GNN

In [6]:
# Classic structural node statistics, later used as extra node features.
nx_graph = to_networkx(graph, to_undirected=True)
nx_degree = nx_graph.degree()
nx_pagerank = nx.pagerank(nx_graph)

# The two calls below are the slow part of this cell.
nx_clustering = nx.clustering(nx_graph)
nx_centrality = nx.closeness_centrality(nx_graph)

In [7]:
# Shortest Path Distance 
def get_spd_matrix(G, S, max_spd=5):
    """Histogram of shortest-path distances from an anchor set ``S`` to every node.

    A BFS is run from every anchor in ``S``. Each node ``v`` reached at distance ``d``
    adds one count to column ``min(d, max_spd)``, so distances at or beyond
    ``max_spd`` share the last column. A node that an anchor cannot reach gets no
    count for that anchor. The counts are divided by ``len(S)``.

    Node labels of ``G`` are used directly as row indices. This assumes they are
    ``0 .. G.number_of_nodes() - 1``, which holds for graphs built with ``to_networkx``.

    Returns:
        ndarray of shape (G.number_of_nodes(), max_spd + 1).
    """
    counts = np.zeros((G.number_of_nodes(), max_spd + 1), dtype=np.int32)
    for anchor in S:
        distances = nx.shortest_path_length(G, source=anchor)
        for target, dist in distances.items():
            counts[target, min(dist, max_spd)] += 1
    return counts / len(S)

# Random Walk Landing Probability (S = all nodes)
def get_lp_matrix(A, max_steps=5):
    """Random-walk landing-probability features, with every node used as a start node.

    Builds the row-normalised transition matrix ``W = D^{-1} A`` and its powers
    ``W^0 .. W^max_steps``. Entry ``[j, k]`` of the result is the probability that a
    ``k``-step walk ends at node ``j``, averaged uniformly over all start nodes.

    Args:
        A: dense (n, n) adjacency, either a NumPy array or a torch tensor. A tensor is
            moved to the CPU and converted; the computation always runs in float64 NumPy.
        max_steps: largest walk length ``k``.

    Returns:
        float64 ndarray of shape (n, max_steps + 1).

    Rows of ``A`` that sum to zero (isolated nodes) get an all-zero transition row,
    instead of the NaNs that ``0 / 0`` would otherwise spread through every power.
    """
    # Previously torch and NumPy were mixed (torch division, then np.matmul). That
    # breaks for CUDA tensors, so work in NumPy only.
    if isinstance(A, torch.Tensor):
        A = A.detach().cpu().numpy()
    A = np.asarray(A, dtype=np.float64)

    degree = A.sum(axis=1, keepdims=True)
    W = np.divide(A, degree, out=np.zeros_like(A), where=degree > 0)

    power = np.identity(A.shape[0])
    powers = [power]
    for _ in range(max_steps):
        power = power @ W
        powers.append(power)

    # (n, n, max_steps + 1) -> average over start nodes (axis 0).
    return np.stack(powers, axis=2).mean(axis=0)

In [8]:
# Distance-encoding features.
# SPD uses an anchor set S of 200 random nodes; one BFS per anchor makes this slow.
np.random.seed(0)
node_subset = np.random.choice(nx_graph.number_of_nodes(), size=200, replace=False)
spd_feature = get_spd_matrix(G=nx_graph, S=node_subset, max_spd=5)
print(spd_feature, spd_feature.shape, sep='\n')

# Landing probabilities use every node as an anchor, computed on the dense adjacency.
lp_feature = get_lp_matrix(adj_t.to_dense(), max_steps=5)
print(lp_feature, lp_feature.shape, sep='\n')

[[0.    0.12  0.74  0.135 0.005 0.   ]
 [0.    0.1   0.565 0.33  0.005 0.   ]
 [0.    0.    0.225 0.725 0.045 0.005]
 ...
 [0.    0.005 0.04  0.74  0.21  0.005]
 [0.    0.01  0.705 0.28  0.005 0.   ]
 [0.    0.    0.575 0.42  0.005 0.   ]]
(4267, 6)
[[2.34356691e-04 2.93526852e-04 2.23200240e-04 2.45474974e-04
  2.20286597e-04 2.23163665e-04]
 [2.34356691e-04 1.80790480e-04 2.31190380e-04 1.76067754e-04
  1.92013031e-04 1.67547330e-04]
 [2.34356691e-04 1.87849777e-06 8.47383599e-05 2.61183075e-06
  3.26818208e-05 2.56206488e-06]
 ...
 [2.34356691e-04 3.70326652e-05 1.97831345e-05 1.53761192e-05
  1.03435513e-05 9.72830004e-06]
 [2.34356691e-04 5.87933213e-05 5.66102474e-05 2.75958480e-05
  2.91710197e-05 2.42242912e-05]
 [2.34356691e-04 4.25467285e-06 3.23737135e-05 8.08685402e-06
  1.12771452e-05 8.52324894e-06]]
(4267, 6)


In [9]:
# Convert the per-node statistics to (num_nodes, 1) float32 tensors.
# Values are looked up by node id rather than relying on dict/DegreeView iteration
# order, so row i of every feature is guaranteed to describe node i.
node_ids = range(graph.num_nodes)
tensor_degree = torch.tensor([nx_degree[n] for n in node_ids], dtype=torch.float32).unsqueeze(1)
tensor_pagerank = torch.tensor([nx_pagerank[n] for n in node_ids], dtype=torch.float32).unsqueeze(1)
tensor_clustering = torch.tensor([nx_clustering[n] for n in node_ids], dtype=torch.float32).unsqueeze(1)
tensor_centrality = torch.tensor([nx_centrality[n] for n in node_ids], dtype=torch.float32).unsqueeze(1)

# The distance-encoding matrices are already indexed by node id (row i = node i).
tensor_spd = torch.as_tensor(spd_feature, dtype=torch.float32)
tensor_lp = torch.as_tensor(lp_feature, dtype=torch.float32)

# Concatenate along the feature axis: 4 scalar statistics + SPD bins + LP steps.
feature_tensor_list = [tensor_degree, tensor_pagerank, tensor_clustering,
                       tensor_centrality, tensor_spd, tensor_lp]
assert all(t.shape[0] == graph.num_nodes for t in feature_tensor_list), 'feature rows misaligned'
x_feature = torch.cat(feature_tensor_list, dim=1)
print(x_feature.shape)

torch.Size([4267, 16])


In [None]:
# x_feature_numpy = x_feature.numpy()
# x_df = pd.DataFrame(x_feature_numpy)
# x_df.to_csv('x_feature.csv', index=False)

In [None]:
# Use this cell to skip preprocessing (reloads the features saved to x_feature.csv by the cell above)

# x_df = pd.read_csv('x_feature.csv')
# x_feature_numpy = x_df.to_numpy()
# x_feature = torch.Tensor(x_feature_numpy)
# print(x_feature.shape)

In [10]:
# Min-max scale each feature column to roughly [0, 1]. The 1e-6 term keeps constant
# columns from dividing by zero.
# NOTE: this rebinds x_feature, so re-running the cell rescales already-scaled values.
x_min = x_feature.min(dim=0, keepdim=True).values
x_max = x_feature.max(dim=0, keepdim=True).values
x_range = x_max - x_min + 1e-6
x_feature = (x_feature - x_min) / x_range

In [11]:
# Move the graph structure and node features onto the training device.
edge_index, adj_t, x_feature = (t.to(device) for t in (edge_index, adj_t, x_feature))

In [12]:
def train(model, optimizer, evaluator, graph, x_feature, edge_index, adj_t, split_idx,
          batch_size=1024*64, num_epochs=200, save_model=False):
    """Train a link predictor with BCE on positive train edges vs. freshly sampled negatives.

    Each epoch shuffles the positive training edges into mini-batches. Every batch is
    paired with an equal number of negative node pairs drawn by PyG's
    ``negative_sampling`` with respect to ``edge_index``. After each epoch, ``evaluate``
    reports validation and test Hits@20, and the epoch with the best validation score
    is tracked.

    Args:
        model: module called as ``model(x_feature, adj_t, edge_label_index)`` that returns
            one logit per edge and exposes ``model.loss(pred, target)``.
        optimizer: torch optimizer over ``model``'s parameters.
        evaluator: OGB ``Evaluator``, forwarded to ``evaluate``.
        graph: PyG data object; only ``graph.num_nodes`` is read.
        x_feature: extra node features or ``None``; passed straight to the model.
        edge_index: edge list handed to ``negative_sampling``.
        adj_t: adjacency passed to the model for message passing.
        split_idx: OGB edge split; ``split_idx['train']['edge']`` holds the positives.
        batch_size: positive edges per mini-batch.
        num_epochs: passes over the positive training edges.
        save_model: if True, deep-copy the model at each new best validation epoch.

    Returns:
        ``(best_model, best_val_score, best_test_score)``. NOTE: when ``save_model`` is
        False, ``best_model`` is the live model in its final state, not a snapshot taken
        at the best epoch.

    Relies on the module-level ``device`` for placing edges and labels.
    """
    best_val_score = 0
    best_test_score = 0
    best_epoch = 0
    best_model = model

    # (2, num_train_edges) positive edges, kept on the training device.
    train_pos = split_idx['train']['edge'].transpose(0, 1).to(device)
    edge_ids = list(range(train_pos.shape[1]))

    for epoch in range(1, num_epochs + 1):
        loss_total = 0
        n_seen = 0
        for batch in DataLoader(edge_ids, batch_size=batch_size, shuffle=True):
            model.train()
            pos_edges = train_pos[:, batch]
            neg_edges = negative_sampling(edge_index=edge_index,
                                          num_nodes=graph.num_nodes,
                                          num_neg_samples=pos_edges.shape[1],
                                          method='dense').to(device)
            edge_label_index = torch.cat([pos_edges, neg_edges], dim=1).to(device)
            # Label 1 for observed edges and 0 for sampled negatives, in the same order.
            edge_label = torch.cat([torch.ones(pos_edges.shape[1], ),
                                    torch.zeros(neg_edges.shape[1], )], dim=0).to(device)

            optimizer.zero_grad()
            logits = model(x_feature, adj_t, edge_label_index)
            loss = model.loss(logits, edge_label.type_as(logits))
            loss.backward()
            # Cap the global gradient norm at 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            n_batch = edge_label.shape[0]
            loss_total += loss.item() * n_batch
            n_seen += n_batch

        val_score, test_score = evaluate(model, x_feature, adj_t, split_idx, evaluator)
        if val_score > best_val_score:
            best_val_score, best_epoch, best_test_score = val_score, epoch, test_score
            if save_model:
                best_model = copy.deepcopy(model)

        print('Epoch: {:03d}, Loss: {:.4f}, Val Hits: {:.2f}%, Test Hits: {:.2f}%'.format(
            epoch, loss_total / n_seen, 100 * val_score, 100 * test_score))

    print('Final model:')
    print('Epoch: {:03d}, Val Hits: {:.2f}%, Test Hits: {:.2f}%'.format(
        best_epoch, 100 * best_val_score, 100 * best_test_score))
    return best_model, best_val_score, best_test_score

@torch.no_grad()
def evaluate(model, x_feature, adj_t, split_idx, evaluator):
    """Compute Hits@20 on the validation and test splits.

    Puts ``model`` in eval mode. For each split it scores the positive (``'edge'``) and
    negative (``'edge_neg'``) edge lists, which are stored as (num_edges, 2) and
    transposed to (2, num_edges) before being passed to the model. ``evaluator`` then
    turns those scores into Hits@20.

    Returns:
        ``(val_hits, test_hits)``, the ``'hits@20'`` values reported by ``evaluator``.
    """
    model.eval()

    def _hits_at_20(split_name):
        # Score one split's positive and negative edges, then ask the evaluator.
        split = split_idx[split_name]
        pos_scores = model(x_feature, adj_t, split['edge'].transpose(0, 1))
        neg_scores = model(x_feature, adj_t, split['edge_neg'].transpose(0, 1))
        metrics = evaluator.eval({'y_pred_pos': pos_scores, 'y_pred_neg': neg_scores})
        return metrics['hits@20']

    return tuple(_hits_at_20(name) for name in ('valid', 'test'))

# DEA-GCN-JK 


## Model

In [13]:
class DEA_GNN_JK(torch.nn.Module):
    """Edge-level link predictor: TAGConv stack + Jumping Knowledge + MLP decoder.

    Pipeline (see ``forward``):
    1. A trainable embedding per node, optionally concatenated with extra node features.
    2. ``len(self.convs)`` TAGConv layers, each followed by optional BatchNorm, ReLU
       and dropout.
    3. Jumping Knowledge over all layer outputs.
    4. The element-wise product of the two endpoint embeddings of each queried edge.
    5. An MLP (optional BatchNorm, ReLU and dropout between layers) producing
       ``mlp_out_dim`` logits per edge.

    Args:
        num_nodes: number of rows in the trainable node embedding table.
        embed_dim: width of the trainable node embedding.
        gnn_in_dim: input width of the first TAGConv. It must equal ``embed_dim`` when
            ``forward`` receives ``x_feature=None``, and ``embed_dim + x_feature.shape[1]``
            otherwise.
        gnn_hidden_dim: output width of every TAGConv except the last.
        gnn_out_dim: output width of the last TAGConv. The 'max', 'sum', 'mean' and
            'lstm' JK modes combine layer outputs of equal width, so they need
            ``gnn_out_dim == gnn_hidden_dim``.
        gnn_num_layers: number of TAGConv layers; at least 2 are always built.
        mlp_in_dim: input width of the MLP. It must match the JK output width:
            ``gnn_hidden_dim`` for the element-wise modes, or the sum of all layer
            widths for 'cat'.
        mlp_hidden_dim: width of the hidden MLP layers.
        mlp_out_dim: logits per edge; with 1, the output is squeezed to (num_edges,).
        mlp_num_layers: number of Linear layers; at least 2 are always built.
        dropout: dropout probability after every GNN layer and every hidden MLP layer.
        gnn_batchnorm: apply BatchNorm1d after every TAGConv.
        mlp_batchnorm: apply BatchNorm1d after every hidden Linear layer.
        K: number of hops used by each TAGConv.
        jk_mode: one of 'max', 'sum', 'mean', 'lstm', 'cat'.
    """

    def __init__(self, num_nodes, embed_dim,
                 gnn_in_dim, gnn_hidden_dim, gnn_out_dim, gnn_num_layers,
                 mlp_in_dim, mlp_hidden_dim, mlp_out_dim=1, mlp_num_layers=2,
                 dropout=0.5, gnn_batchnorm=False, mlp_batchnorm=False, K=2, jk_mode='max'):
        super(DEA_GNN_JK, self).__init__()

        assert jk_mode in ['max', 'sum', 'mean', 'lstm', 'cat']
        # Trainable node embedding; its whole weight matrix is the GNN input.
        self.emb = torch.nn.Embedding(num_nodes, embedding_dim=embed_dim)

        # GNN: in -> hidden -> ... -> hidden -> out
        convs_list = [TAGConv(gnn_in_dim, gnn_hidden_dim, K)]
        for _ in range(gnn_num_layers - 2):
            convs_list.append(TAGConv(gnn_hidden_dim, gnn_hidden_dim, K))
        convs_list.append(TAGConv(gnn_hidden_dim, gnn_out_dim, K))
        self.convs = torch.nn.ModuleList(convs_list)

        # MLP: in -> hidden -> ... -> hidden -> out
        lins_list = [torch.nn.Linear(mlp_in_dim, mlp_hidden_dim)]
        for _ in range(mlp_num_layers - 2):
            lins_list.append(torch.nn.Linear(mlp_hidden_dim, mlp_hidden_dim))
        lins_list.append(torch.nn.Linear(mlp_hidden_dim, mlp_out_dim))
        self.lins = torch.nn.ModuleList(lins_list)

        # Batchnorm
        self.gnn_batchnorm = gnn_batchnorm
        self.mlp_batchnorm = mlp_batchnorm
        if self.gnn_batchnorm:
            # One BatchNorm per TAGConv, sized to that conv's output width: the last conv
            # emits gnn_out_dim and the others emit gnn_hidden_dim. Counting via
            # len(convs_list) also matches the real layer count when gnn_num_layers < 2.
            self.gnn_bns = torch.nn.ModuleList(
                [torch.nn.BatchNorm1d(gnn_hidden_dim) for _ in range(len(convs_list) - 1)]
                + [torch.nn.BatchNorm1d(gnn_out_dim)])

        if self.mlp_batchnorm:
            # One BatchNorm per hidden Linear layer (every layer except the last).
            self.mlp_bns = torch.nn.ModuleList(
                [torch.nn.BatchNorm1d(mlp_hidden_dim) for _ in range(len(lins_list) - 1)])

        self.jk_mode = jk_mode
        if self.jk_mode in ['max', 'lstm', 'cat']:
            # 'sum' and 'mean' are computed directly in forward().
            self.jk = JumpingKnowledge(mode=self.jk_mode, channels=gnn_hidden_dim,
                                       num_layers=len(convs_list))

        self.dropout = dropout
        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the embedding (Xavier uniform) and every sub-module's parameters."""
        torch.nn.init.xavier_uniform_(self.emb.weight)
        for conv in self.convs:
            conv.reset_parameters()
        for lin in self.lins:
            lin.reset_parameters()
        if self.gnn_batchnorm:
            for bn in self.gnn_bns:
                bn.reset_parameters()
        if self.mlp_batchnorm:
            for bn in self.mlp_bns:
                bn.reset_parameters()
        if self.jk_mode in ['max', 'lstm', 'cat']:
            self.jk.reset_parameters()

    def forward(self, x_feature, adj_t, edge_label_index):
        """Score the node pairs in ``edge_label_index``.

        Args:
            x_feature: extra node features concatenated to the embedding along dim 1,
                or ``None`` to use the embedding alone.
            adj_t: adjacency passed to every TAGConv layer.
            edge_label_index: (2, num_edges) node-index pairs to score.

        Returns:
            Raw logits; shape (num_edges,) when ``mlp_out_dim == 1``.
        """
        if x_feature is not None:
            out = torch.cat([self.emb.weight, x_feature], dim=1)
        else:
            out = self.emb.weight

        out_list = []
        for i, conv in enumerate(self.convs):
            out = conv(out, adj_t)
            if self.gnn_batchnorm:
                out = self.gnn_bns[i](out)
            out = F.relu(out)
            out = F.dropout(out, p=self.dropout, training=self.training)
            out_list.append(out)

        if self.jk_mode in ['max', 'lstm', 'cat']:
            out = self.jk(out_list)
        elif self.jk_mode == 'mean':
            out = torch.stack(out_list, dim=0).mean(dim=0)
        elif self.jk_mode == 'sum':
            out = torch.stack(out_list, dim=0).sum(dim=0)

        # (2, num_edges, dim): embeddings of both endpoints of every queried edge.
        gnn_embed = out[edge_label_index, :]
        out = gnn_embed[0] * gnn_embed[1]

        for i in range(len(self.lins) - 1):
            out = self.lins[i](out)
            if self.mlp_batchnorm:
                out = self.mlp_bns[i](out)
            out = F.relu(out)
            out = F.dropout(out, p=self.dropout, training=self.training)
        out = self.lins[-1](out).squeeze(1)

        return out

    def loss(self, y_pred, y_true):
        """Binary cross-entropy on raw logits (``BCEWithLogitsLoss``)."""
        return self.loss_fn(y_pred, y_true)

## Train

In [14]:
# mlp_out_dim = 1
# gnn_num_layers >= 2
# mlp_num_layers >= 2
# jk_mode = cat, max, lstm, mean, sum

# Toggle the distance-encoding node features on or off.
USE_DE = True

# Width of the trainable node embedding. When USE_DE is on, the DE features are
# concatenated to it, so the first TAGConv must accept the wider input.
EMBED_DIM = 256
gnn_in_dim = EMBED_DIM + x_feature.shape[1] if USE_DE else EMBED_DIM

model = DEA_GNN_JK(num_nodes=graph.num_nodes, embed_dim=EMBED_DIM,
                   gnn_in_dim=gnn_in_dim, gnn_hidden_dim=256, gnn_out_dim=256, gnn_num_layers=3,
                   mlp_in_dim=256, mlp_hidden_dim=256, mlp_out_dim=1, mlp_num_layers=2,
                   dropout=0.5, gnn_batchnorm=True, mlp_batchnorm=True, K=2, jk_mode='max').to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
print(model)
n_params = sum(p.numel() for p in model.parameters())
print('Number of parameters:', n_params)

DEA_GNN_JK(
  (emb): Embedding(4267, 256)
  (convs): ModuleList(
    (0): TAGConv(272, 256, K=2)
    (1): TAGConv(256, 256, K=2)
    (2): TAGConv(256, 256, K=2)
  )
  (lins): ModuleList(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=1, bias=True)
  )
  (gnn_bns): ModuleList(
    (0): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (mlp_bns): ModuleList(
    (0): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (jk): JumpingKnowledge(max)
  (loss_fn): BCEWithLogitsLoss()
)
Number of parameters: 1763329


In [15]:
# Multiple runs: retrain from scratch with RUNS different seeds, then report the mean
# and standard deviation of the best-validation-epoch scores.
RUNS = 10
NUM_EPOCHS = 400
BATCH_SIZE = 1024 * 64
LR = 0.005
best_val_scores = np.zeros((RUNS,))
best_test_scores = np.zeros((RUNS,))

# The feature input does not change between runs, so choose it once instead of
# duplicating the whole train() call in two branches.
run_features = x_feature if USE_DE else None

for i in range(RUNS):
    seed = i + 1
    random.seed(seed)
    torch.manual_seed(seed)
    # Re-initialise the model's parameters and build a fresh optimizer, so no Adam
    # state carries over from the previous run.
    model.reset_parameters()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    result = train(model, optimizer, evaluator, graph, run_features, edge_index, adj_t, split_idx,
                   batch_size=BATCH_SIZE, num_epochs=NUM_EPOCHS, save_model=False)

    best_val_scores[i] = result[1]
    best_test_scores[i] = result[2]

    print('Run', i+1, 'done.')

# The sample SD (ddof=1) is undefined for a single run (NumPy returns NaN), so fall
# back to ddof=0 there.
ddof = 1 if RUNS > 1 else 0
log = 'Mean Val Hits: {:.4f}, SD Val Hits: {:.4f}'
print(log.format(np.mean(best_val_scores), np.std(best_val_scores, ddof=ddof)))
log = 'Mean Test Hits: {:.4f}, SD Test Hits: {:.4f}'
print(log.format(np.mean(best_test_scores), np.std(best_test_scores, ddof=ddof)))

Epoch: 001, Loss: 0.4247, Val Hits: 2.11%, Test Hits: 1.16%
Epoch: 002, Loss: 0.3419, Val Hits: 4.48%, Test Hits: 1.95%
Epoch: 003, Loss: 0.2976, Val Hits: 12.52%, Test Hits: 13.36%
Epoch: 004, Loss: 0.2611, Val Hits: 8.69%, Test Hits: 5.32%
Epoch: 005, Loss: 0.2355, Val Hits: 5.96%, Test Hits: 2.88%
Epoch: 006, Loss: 0.2150, Val Hits: 7.92%, Test Hits: 3.47%
Epoch: 007, Loss: 0.1982, Val Hits: 9.39%, Test Hits: 4.07%
Epoch: 008, Loss: 0.1862, Val Hits: 12.23%, Test Hits: 6.10%
Epoch: 009, Loss: 0.1742, Val Hits: 13.21%, Test Hits: 6.35%
Epoch: 010, Loss: 0.1659, Val Hits: 14.43%, Test Hits: 7.62%
Epoch: 011, Loss: 0.1590, Val Hits: 14.68%, Test Hits: 10.69%
Epoch: 012, Loss: 0.1527, Val Hits: 18.33%, Test Hits: 14.27%
Epoch: 013, Loss: 0.1477, Val Hits: 19.44%, Test Hits: 18.43%
Epoch: 014, Loss: 0.1434, Val Hits: 19.16%, Test Hits: 12.18%
Epoch: 015, Loss: 0.1387, Val Hits: 19.38%, Test Hits: 20.96%
Epoch: 016, Loss: 0.1359, Val Hits: 22.76%, Test Hits: 22.73%
Epoch: 017, Loss: 0.132