In [None]:
import os
import torch
# Expose the installed torch version as an environment variable so the shell
# commands below can pick the PyG wheel index page that matches it.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# Install torch-scatter / torch-sparse from the PyG wheel index for this torch
# version; ${TORCH} resolves to the environment variable set above.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# PyTorch Geometric itself, installed from the GitHub main branch (unpinned).
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

In [None]:
# OGB supplies the ogbl-ddi dataset loader and the Hits@K Evaluator used below (unpinned).
!pip install ogb

In [None]:
from ogb.linkproppred import PygLinkPropPredDataset, Evaluator
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import negative_sampling, add_self_loops
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Load the ogbl-ddi link-prediction dataset; root="dataset" is its local storage directory.
dataset = PygLinkPropPredDataset(root="dataset", name="ogbl-ddi")
# Take the graph object stored at index 0.
data = dataset[0]

In [None]:
# Store the adjacency as a SparseTensor `data.adj_t` (PyG keeps it transposed),
# which the GNN layers consume directly for message passing.
data = T.ToSparseTensor()(data)
# Rebuild a (2, E) edge_index from adj_t. Because adj_t is transposed, its COO
# `col` entries are source nodes and its `row` entries are target nodes.
row, col, _ = data.adj_t.coo()
data.edge_index = torch.stack((col, row))

In [None]:
# creating train test split
# OGB's predefined edge split. split_edge[split]['edge'] holds the positive edges
# for 'train' / 'valid' / 'test'; 'valid' and 'test' also carry fixed negative
# edges under 'edge_neg' (both keys are read in get_pos_neg_edges).
split_edge = dataset.get_edge_split()

In [None]:
# Move the graph (adj_t, edge_index, ...) to the selected device. The split_edge
# tensors are not moved here; Model.train / Model.test call .to(self.device) on them.
data = data.to(device)

In [None]:
class GNNStack(torch.nn.Module):
    """Stack of ``num_layers`` SAGEConv layers used as the node encoder.

    The first layer maps ``input_dim -> hidden_dim``; every following layer maps
    ``hidden_dim -> hidden_dim``. ReLU and dropout are applied after every layer
    except the last, whose raw output is returned.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        # One input width per layer: input_dim for the first, hidden_dim for the rest.
        in_dims = [input_dim] + [hidden_dim] * max(num_layers - 1, 0)
        self.convs = torch.nn.ModuleList(
            SAGEConv(in_dim, hidden_dim) for in_dim in in_dims
        )
        self.dropout = dropout
        self.num_layers = num_layers

    def forward(self, x, edge_index):
        """Encode node features ``x`` over the graph given by ``edge_index`` (or an adj_t)."""
        hidden_convs = [self.convs[k] for k in range(self.num_layers - 1)]
        for conv in hidden_convs:
            x = F.dropout(F.relu(conv(x, edge_index)),
                          p=self.dropout, training=self.training)
        # Final layer: no activation / dropout.
        return self.convs[-1](x, edge_index)

In [None]:
# GNNStack(512,512,2,0.3)

In [None]:
class LinkPredictor(torch.nn.Module):
    """MLP that scores a node pair from the element-wise product of their embeddings.

    Layer widths are ``in_channels -> hidden_channels -> ... -> out_channels``, with
    ``num_layers - 2`` hidden-to-hidden layers in between (at least two layers total).
    ReLU and dropout follow every layer except the last.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super().__init__()
        # Node widths along the MLP; consecutive pairs give each Linear's (in, out).
        widths = ([in_channels]
                  + [hidden_channels] * max(num_layers - 1, 1)
                  + [out_channels])
        self.lins = torch.nn.ModuleList(
            torch.nn.Linear(d_in, d_out)
            for d_in, d_out in zip(widths[:-1], widths[1:])
        )
        self.dropout = dropout

    def forward(self, x_i, x_j):
        """Return raw (pre-sigmoid) scores for pairs; ``x_i`` and ``x_j`` are both (E, D)."""
        *hidden_layers, head = self.lins
        z = x_i * x_j
        for layer in hidden_layers:
            z = F.dropout(F.relu(layer(z)), p=self.dropout, training=self.training)
        return head(z)

In [None]:
# LinkPredictor(512, 512, 1, 2, 0.3)

In [None]:
def get_pos_neg_edges(split, split_edge, edge_index=None, num_nodes=None, num_neg=None):
    """Return ``(pos_edge, neg_edge)`` for one split of an OGB link-prediction dataset.

    Args:
        split: ``'train'``, ``'valid'`` or ``'test'``; key into ``split_edge``.
        split_edge: dict of splits; ``split_edge[split]['edge']`` holds the positive
            edges and, for non-train splits, ``split_edge[split]['edge_neg']`` holds
            the fixed negative edges.
        edge_index: graph connectivity used to exclude existing edges from negative
            sampling. Only needed for ``'train'``.
        num_nodes: number of nodes, forwarded to ``negative_sampling`` (``'train'`` only).
        num_neg: number of negatives to sample per positive edge (``'train'`` only).

    Returns:
        pos_edge: ``split_edge[split]['edge']`` unchanged.
        neg_edge: for ``'train'``, freshly sampled negatives reshaped to
            ``(num_pos, num_neg, 2)``; otherwise ``split_edge[split]['edge_neg']``.
    """
    pos_edge = split_edge[split]['edge']
    if split != 'train':
        # valid/test negatives are fixed by the split; nothing to sample.
        return pos_edge, split_edge[split]['edge_neg']

    # Add self-loops so (v, v) pairs are treated as existing edges and never sampled.
    sampling_edge_index, _ = add_self_loops(edge_index)
    num_samples = pos_edge.size(0) * num_neg
    neg_edge = negative_sampling(sampling_edge_index, num_nodes=num_nodes,
                                 num_neg_samples=num_samples, method='sparse')
    # The reshape below needs exactly num_neg negatives per positive edge, but
    # negative_sampling only guarantees an approximate count.
    assert neg_edge.size(1) == num_samples, (
        f'negative_sampling returned {neg_edge.size(1)} of {num_samples} requested edges')
    # (2, num_pos * num_neg) -> (num_pos * num_neg, 2) -> (num_pos, num_neg, 2).
    neg_edge = neg_edge.t().reshape(-1, num_neg, 2)
    return pos_edge, neg_edge

def evaluate_hits(evaluator, pos_val_pred, neg_val_pred,
                  pos_test_pred, neg_test_pred):
    """Compute Hits@20/50/100 for the validation and test predictions.

    ``evaluator.K`` is set to each cutoff before calling ``evaluator.eval``, so the
    evaluator is left with ``K == 100`` afterwards.

    Returns:
        dict mapping ``'Hits@K'`` to ``(valid_hits, test_hits)``.
    """
    split_preds = (
        (pos_val_pred, neg_val_pred),
        (pos_test_pred, neg_test_pred),
    )
    results = {}
    for cutoff in (20, 50, 100):
        evaluator.K = cutoff
        metric = f'hits@{cutoff}'
        results[f'Hits@{cutoff}'] = tuple(
            evaluator.eval({'y_pred_pos': pos, 'y_pred_neg': neg})[metric]
            for pos, neg in split_preds
        )
    return results

In [None]:
class Model():
    """Learnable node embeddings + GNNStack encoder + LinkPredictor, trained jointly.

    Node inputs are the rows of a ``torch.nn.Embedding`` table. They are encoded by
    ``GNNStack`` over ``data.adj_t`` and pairs are scored by ``LinkPredictor``. A
    single Adam optimizer updates all three parts.
    """

    def __init__(self, num_nodes, emb_dim, device, num_layers, dropout, lr) -> None:
        self.num_nodes = num_nodes
        self.device = device
        self.emb = torch.nn.Embedding(num_nodes, emb_dim).to(device)
        self.encoder = GNNStack(emb_dim, emb_dim, num_layers, dropout).to(device)
        self.predictor = LinkPredictor(emb_dim, emb_dim, 1, num_layers, dropout).to(device)
        self.optimizer = torch.optim.Adam(
            list(self.emb.parameters())
            + list(self.encoder.parameters())
            + list(self.predictor.parameters()),
            lr=lr,
        )
        # Every trainable tensor, used for parameter counting.
        self.para_list = (list(self.encoder.parameters())
                          + list(self.predictor.parameters())
                          + list(self.emb.parameters()))

    def auc_loss(self, pos_out, neg_out, num_neg):
        """Square-loss AUC surrogate, summed (not averaged) over the batch.

        Each positive score is paired with its ``num_neg`` negatives (``neg_out`` is
        reshaped to ``(-1, num_neg)``). Each pair contributes ``(1 - (pos - neg))**2``.
        """
        pos_out = torch.reshape(pos_out, (-1, 1))
        neg_out = torch.reshape(neg_out, (-1, num_neg))
        return torch.square(1 - (pos_out - neg_out)).sum()

    def train(self, data, split_edge, batch_size, num_neg):
        """Run one training epoch and return the mean loss per positive training edge."""
        self.encoder.train()
        self.predictor.train()

        pos_train_edge, neg_train_edge = get_pos_neg_edges('train', split_edge,
                                                           edge_index=data.edge_index,
                                                           num_nodes=self.num_nodes,
                                                           num_neg=num_neg)
        pos_train_edge = pos_train_edge.to(self.device)
        neg_train_edge = neg_train_edge.to(self.device)

        total_loss = total_examples = 0
        for perm in DataLoader(range(pos_train_edge.size(0)), batch_size, shuffle=True):
            self.optimizer.zero_grad()

            # Re-encode every step: the parameters change after each optimizer step.
            h = self.encoder(self.emb.weight, data.adj_t)
            pos_edge = pos_train_edge[perm].t()
            # (B, num_neg, 2) -> (2, B * num_neg), keeping each positive's negatives contiguous.
            neg_edge = torch.reshape(neg_train_edge[perm], (-1, 2)).t()

            pos_out = self.predictor(h[pos_edge[0]], h[pos_edge[1]])
            neg_out = self.predictor(h[neg_edge[0]], h[neg_edge[1]])

            loss = self.auc_loss(pos_out, neg_out, num_neg)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), 2)
            torch.nn.utils.clip_grad_norm_(self.predictor.parameters(), 2)

            self.optimizer.step()

            # auc_loss already sums over the batch, so accumulate it unscaled;
            # multiplying by the batch size again would weight batches quadratically.
            total_loss += loss.item()
            total_examples += pos_out.size(0)

        return total_loss / total_examples

    @torch.no_grad()
    def batch_predict(self, h, edges, batch_size):
        """Score ``edges`` (one (src, dst) pair per row) in mini-batches; return a 1-D CPU tensor."""
        self.predictor.eval()
        preds = []
        for perm in DataLoader(range(edges.size(0)), batch_size):
            edge = edges[perm].t()
            # squeeze(-1) rather than squeeze(): a final batch holding a single edge
            # must stay 1-D, otherwise torch.cat below rejects the 0-d tensor.
            preds.append(self.predictor(h[edge[0]], h[edge[1]]).squeeze(-1).cpu())
        return torch.cat(preds, dim=0)

    @torch.no_grad()
    def test(self, data, split_edge, batch_size, evaluator):
        """Evaluate on the valid/test splits and return ``{'Hits@K': (valid, test)}``."""
        self.encoder.eval()
        self.predictor.eval()

        # Encode once and reuse for both splits. Dropout is off in eval mode and no
        # parameters change in between, so a second encoder pass would be redundant.
        # NOTE(review): a fallback embedding for unseen nodes (index -1), e.g. the
        # mean of h, would have to be appended here if a split contained such nodes.
        h = self.encoder(self.emb.weight, data.adj_t)

        split_preds = {}
        for split in ('valid', 'test'):
            pos_edge, neg_edge = get_pos_neg_edges(split, split_edge,
                                                   edge_index=data.edge_index,
                                                   num_nodes=self.num_nodes)
            split_preds[split] = (
                self.batch_predict(h, pos_edge.to(self.device), batch_size),
                self.batch_predict(h, neg_edge.to(self.device), batch_size),
            )

        return evaluate_hits(evaluator, *split_preds['valid'], *split_preds['test'])

In [None]:
# 512-dim embeddings, 2 GNN / predictor layers, dropout 0.3, Adam lr 1e-3.
model = Model(
    num_nodes=data.num_nodes,
    emb_dim=512,
    device=device,
    num_layers=2,
    dropout=0.3,
    lr=1e-3,
)

In [None]:
# Expected total for this configuration (ogbl-ddi, emb_dim=512, num_layers=2): 3_497_473.
# Tensor.numel() already counts every element of a parameter, so there is no need
# to iterate over its rows (which is slow for the embedding table and fails on 0-d tensors).
total_params = sum(param.numel() for param in model.para_list)
total_params_print = f'Total number of model parameters is {total_params}'
total_params_print

In [None]:
# OGB evaluator for ogbl-ddi; evaluate_hits sets its K attribute before each eval call.
evaluator = Evaluator(name='ogbl-ddi')
print(evaluator.expected_input_format)

In [None]:
# Training hyper-parameters; num_neg is the number of sampled negatives per positive edge.
epochs, batch_size, num_neg = 1000, 2 ** 16, 3

In [None]:
train_loss, val_hits, test_hits = [], [], []

for epoch in range(1, epochs + 1):
    loss = model.train(data, split_edge, batch_size=batch_size, num_neg=num_neg)
    print(f"Epoch {epoch}: loss: {round(loss, 5)}")
    train_loss.append(loss)

    # Evaluate only on every 10th epoch.
    if epoch % 10 != 0:
        continue
    results = model.test(data, split_edge, batch_size=batch_size, evaluator=evaluator)
    # results maps 'Hits@K' -> (valid, test); only Hits@20 is tracked for plotting.
    val_hits.append(results['Hits@20'][0])
    test_hits.append(results['Hits@20'][1])
    print(results)


In [None]:
# train_loss[i] is the loss of epoch i + 1 (the training loop counts epochs from 1),
# so plot against the actual epoch numbers rather than the list index.
loss_epochs = np.arange(1, len(train_loss) + 1)

fig, ax = plt.subplots()
ax.set_title('Link Prediction on OGB-ddi using GraphSAGE GNN - Loss Curve')
ax.plot(loss_epochs, train_loss, label="training loss")
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
fig.savefig('loss-curve-ddi.png', dpi=300)
plt.show()

In [None]:
# The training loop evaluates when epoch % 10 == 0 (epochs counted from 1), so the
# k-th recorded value belongs to epoch 10 * k. Deriving the x positions from
# len(val_hits) keeps x and y aligned even if training was stopped early.
eval_epochs = 10 * np.arange(1, len(val_hits) + 1)

fig, ax = plt.subplots()
ax.set_title('Link Prediction on OGB-ddi using GraphSAGE GNN - Hits@20')
ax.plot(eval_epochs, val_hits, label="Hits@20 on validation")
ax.plot(eval_epochs, test_hits, label="Hits@20 on test")
ax.set_xlabel('Epochs')
ax.set_ylabel('Hits@20')
ax.legend()
fig.savefig('Hits@20-curve-ddi.png', dpi=300)
plt.show()