# Imports & Standard stuff

In [2]:
import torch
from torch_geometric.nn import SAGEConv
from torch_geometric.data import Data
from torch_geometric.transforms import RandomLinkSplit
import torch.nn.functional as F
from torch_geometric.loader import LinkNeighborLoader
from tqdm.notebook import  trange, tqdm
from ogb.linkproppred import Evaluator
import wandb
import numpy as np
from torch_geometric.sampler import NegativeSampling

In [3]:
# For printing precise tensor values rather than round-off (For Qualitative Analysis)
torch.set_printoptions(precision=10)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model Definitions

In [5]:
class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(SAGE, self).__init__()

        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, adj_t):
        for conv in self.convs[:-1]:
            x = conv(x, adj_t)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, adj_t)
        return x

In [6]:
class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(LinkPredictor, self).__init__()

        self.lins = torch.nn.ModuleList()
        self.lins.append(torch.nn.Linear(in_channels*2, hidden_channels))
        for _ in range(num_layers - 2):
            self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
        self.lins.append(torch.nn.Linear(hidden_channels, out_channels))

        self.dropout = dropout

    def reset_parameters(self):
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, x_i, x_j):
        # x = x_i * x_j
        x = torch.concat([x_i,x_j], dim=1)
        for lin in self.lins[:-1]:
            x = lin(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        return torch.sigmoid(x)

# Loading Dataset

In [7]:
EMBEDDINGS = "ERNIE"

In [8]:
data = torch.load("./Data/Wiki_en_small_MHRWS_1M_ERNIE.pt")
data.num_nodes = data.x.size(0)
# data.x = data.x.float()

In [9]:
transform = RandomLinkSplit(num_val=0.1, num_test=0.2, is_undirected=False, add_negative_train_samples=False)
train_data, val_data, test_data = transform(data)

In [10]:
train_data

Data(x=[1000000, 768], edge_index=[2, 9516737], num_nodes=1000000, edge_label=[9516737], edge_label_index=[2, 9516737])

# Negative Sampling & Loader

In [11]:
train_negative_sampling = NegativeSampling(mode='binary', amount=1)
val_negative_sampling = NegativeSampling(mode='triplet', amount=100)

In [12]:
3072/2

1536.0

In [14]:
NUM_NBRS_PER_HOP = [20,20] #20 Nodes will be sampled from the first hop and 20 from second hop
TRAIN_BATCH = 1356
VAL_BATCH = 2048

train_loader = LinkNeighborLoader(train_data, 
                                  num_neighbors=NUM_NBRS_PER_HOP, # FIX THIS WRT TO NUM_LAYERS
                                  batch_size=TRAIN_BATCH,
                                  neg_sampling=train_negative_sampling,
                                  edge_label_index=train_data.edge_label_index,
                                  edge_label=train_data.edge_label,
                                  shuffle=True)

val_loader = LinkNeighborLoader(val_data, 
                                  num_neighbors=NUM_NBRS_PER_HOP, #FIX THIS WRT TO NUM_LAYERS
                                  batch_size=VAL_BATCH,
                                  neg_sampling=val_negative_sampling,
                                  edge_label_index=val_data.edge_label_index,
                                  shuffle=True)

# Model Object 

In [15]:
HIDDEN_CHANNELS = 512
NUM_LAYERS = 2
DROPOUT = 0.1

In [16]:
model = SAGE(data.num_features, HIDDEN_CHANNELS,
             HIDDEN_CHANNELS, NUM_LAYERS,
             DROPOUT).to(device)

In [17]:
model

SAGE(
  (convs): ModuleList(
    (0): SAGEConv(768, 512, aggr=mean)
    (1): SAGEConv(512, 512, aggr=mean)
  )
)

In [18]:
link_predictor = LinkPredictor(HIDDEN_CHANNELS, 1024, 1,
                          4,DROPOUT).to(device)

# Train Function

In [19]:
def train(model, link_predictor, train_loader, optimizer, cached_embeddings, GAMMA=2, ALPHA=0.5):
    model.train()
    link_predictor.train()
    total_loss = total_examples = 0

    pos_means = []
    neg_means = []
    for idx,batch in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()

        #Sending the batch to device
        x, edge_index = batch.x.to(device), batch.edge_index.to(device)
        
        #Forward Pass through the GNN
        embed = model(x, edge_index)
        
        #Saving the Embeddings to the Cache 
        # These are only the nodes who's parameterized random k-hop neighbourhood is present in the batch
        complete_embeddings = embed[:batch.num_sampled_nodes[0]].cpu()  
        # These are the corresponding positions of the embeddings in the global cache
        global_indices = batch.n_id[:batch.num_sampled_nodes[0]].cpu()  
        # Here we update the cache
        cached_embeddings[global_indices] = complete_embeddings 

        #Extracting Edge_Embeddings for Positive Edges
        pos_edges = batch.edge_label_index[:,batch.edge_label.bool()].to(device)
        start_node_embeddings, end_node_embeddings = embed[pos_edges[0]], embed[pos_edges[1]]
        pos_op = link_predictor(start_node_embeddings, end_node_embeddings)
        pos_loss = -(ALPHA * torch.pow(1-pos_op, GAMMA) * torch.log(pos_op + 1e-15)).mean()

        #Extracting Edge_Embedding for Negative Edges
        neg_edges = batch.edge_label_index[:,~batch.edge_label.bool()].to(device)
        start_node_embeddings, end_node_embeddings = embed[neg_edges[0]], embed[neg_edges[1]]
        neg_op = link_predictor(start_node_embeddings, end_node_embeddings)
        neg_loss = -((1-ALPHA)*torch.pow(neg_op, GAMMA) * torch.log(1 - neg_op + 1e-15)).mean()

        loss = pos_loss + neg_loss
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        torch.nn.utils.clip_grad_norm_(link_predictor.parameters(), 1.0)
        optimizer.step()
        num_examples = pos_op.size(0)
        batch_loss = loss.item() * num_examples
        total_loss += batch_loss

        total_examples += num_examples
        wandb.log({"train_loss":loss.item(), 
                   "pos_mean":pos_op.mean().item(), 
                   "neg_mean":neg_op.mean().item(),
                   "pos_min":pos_op.min().item(),
                   "neg_max":neg_op.max().item()
                  })
    return total_loss/total_examples, cached_embeddings

# Evaluation Function

In [20]:
def evaluate(link_predictor, val_loader, cached_embedding, evaluator, log_flag=True):
    link_predictor.eval()
    pos_preds = []
    neg_preds = []
    total_loss = total_examples = 0
    for idx,batch in enumerate(tqdm(val_loader)):
        with torch.no_grad():
            #First convert the local node_ids to global node_ids
            #For positive edges
            global_src_idx = batch.n_id[batch.src_index]
            global_dest_idx = batch.n_id[batch.dst_pos_index]
            # Here we are reading via cache as we dont need the gradients to flow
            start_node_embeddings= cached_embedding[global_src_idx].to(device) 
            end_node_embeddings  = cached_embedding[global_dest_idx].to(device)
            
            pos_op = link_predictor(start_node_embeddings, end_node_embeddings)
            pos_loss = -torch.log(pos_op + 1e-15).mean()
            pos_scores = pos_op.squeeze().cpu()
            pos_preds += [pos_scores]

            #For negative Edges 
            all_global_src_idx = batch.n_id[batch.src_index]
            
            # repeating src for c corruptions per src_index c = batch.dst_neg_index.size(1)
            global_src_idx = all_global_src_idx.repeat(batch.dst_neg_index.size(1),1).T.flatten()
            global_dest_idx = batch.n_id[batch.dst_neg_index.flatten()]
            
            # Retrieve node embeddings from cache
            start_node_embeddings = cached_embedding[global_src_idx].to(device)
            end_node_embeddings = cached_embedding[global_dest_idx].to(device)
            
            # Pass embeddings through the link predictor
            neg_op = link_predictor(start_node_embeddings, end_node_embeddings)
            neg_loss = -torch.log(1 - neg_op + 1e-15).mean()
            
            # Reshape the output predictions to match the original format
            neg_scores = neg_op.squeeze().cpu()
            neg_preds += [neg_scores]

            # Updating Loss  
            loss = pos_loss + neg_loss
            num_examples = (pos_op.size(0) + neg_op.size(0))/2
            batch_loss = loss.item() * num_examples
            total_loss += batch_loss
            total_examples += num_examples
                
    pos_pred = torch.cat(pos_preds, dim=0)
    neg_pred = torch.cat(neg_preds, dim=0).reshape(pos_pred.size(0),-1)

    mrr = evaluator.eval({
                'y_pred_pos': pos_pred,
                'y_pred_neg': neg_pred,
            })['mrr_list'].mean().item()
    
    wandb.log({"val_loss":total_loss, 
               "mrr":mrr
              })
    return total_loss, mrr

# Training & Validation Loop

In [21]:
evaluator = Evaluator(name='ogbl-citation2')

In [22]:
RUNS = 1
EPOCHS = 10
EVAL_STEP = 1
LR = 0.01
OUT_CHANNELS = HIDDEN_CHANNELS
num_nodes = data.num_nodes

In [23]:
cached_embeddings = torch.zeros([num_nodes,HIDDEN_CHANNELS], dtype=torch.float32)

In [None]:
wandb.init(
    project="SBERT+SAGE",
    config={"RUNS":RUNS,
            "EPOCHS":EPOCHS,
            "HIDDEN_CHANNELS":HIDDEN_CHANNELS,
            "EMBEDDINGS":EMBEDDINGS,
            "NUM_LAYERS":NUM_LAYERS,
            }
) 
for run in trange(RUNS):
    #Re-setting the parameters
    model.reset_parameters()
    link_predictor.reset_parameters()

    #Optimizer
    optimizer = torch.optim.Adam(list(model.parameters()) + list(link_predictor.parameters()))
    
    for epoch in trange(EPOCHS):
        loss, cached_embedding = train(model, link_predictor, train_loader, optimizer, cached_embeddings)
        print(f"For {run=} and {epoch=} training {loss=}")

        
        if (epoch%EVAL_STEP == 0):
            val_loss, mrr  = evaluate(link_predictor, val_loader, cached_embedding, evaluator, log_flag=True)
            print(f"For {run=} and for {epoch=} the {val_loss=} with {mrr=}")


[34m[1mwandb[0m: Currently logged in as: [33madityakadam[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/7019 [00:00<?, ?it/s]

For run=0 and epoch=0 training loss=0.03231949127207814


  0%|          | 0/1328 [00:00<?, ?it/s]

For run=0 and for epoch=0 the val_loss=375979941.66914654 with mrr=0.38483351469039917


  0%|          | 0/7019 [00:00<?, ?it/s]

For run=0 and epoch=1 training loss=0.023707562000722313


  0%|          | 0/1328 [00:00<?, ?it/s]

For run=0 and for epoch=1 the val_loss=404230995.0028391 with mrr=0.39071738719940186


  0%|          | 0/7019 [00:00<?, ?it/s]