DRKG

Adapted from: https://github.com/gnn4dr/DRKG/blob/master/drkg_with_dgl/loading_drkg_in_dgl.ipynb

In [1]:
import pandas as pd
import numpy as np
import os 
import torch_geometric.transforms as T
import torch
from sklearn.metrics import accuracy_score

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
import torch_geometric
from torch_geometric.nn import SAGEConv, to_hetero
from   torch.utils.data      import Dataset, DataLoader
from   torch_geometric.data  import Data
from   torch_geometric.utils import negative_sampling

from torch_geometric.nn import SAGEConv, to_hetero


import csv

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from SIMP_LLM.dataloader_mappings import load_graph


In [3]:
def load_csv_as_list(file_path):
    data = []
    set_ = set()
    with open(file_path, 'r') as file:
        csv_reader = csv.reader(file)
        next(csv_reader)
        for row in csv_reader:
            if row[3] == 'Atc' or row[3] == "Tax":
                continue
            else:
                data.append((row[1],row[2],row[3]))
    return data
triplets = load_csv_as_list('triplets.csv')
#triplets =  list(set(triplets))

In [4]:
triplets =[('Compound', 'Compound_treats_the_disease', 'Disease')]
data = load_graph(triplets )
#data

data2


In [5]:
dictionaries = torch.load("data2/mapping_dict")
#dictionaries["Compound"]

## GRAPH SAGE

In [12]:




class GNNStack(torch.nn.Module):
    def __init__(self, input_dim:int, hidden_dim:int, output_dim:int, layers:int, dropout:float=0.3, return_embedding=False):
        """
            A stack of GraphSAGE Module 
            input_dim        <int>:   Input dimension
            hidden_dim       <int>:   Hidden dimension
            output_dim       <int>:   Output dimension
            layers           <int>:   Number of layers
            dropout          <float>: Dropout rate
            return_embedding <bool>:  Whether to return the return_embeddingedding of the input graph
        """
        
        super(GNNStack, self).__init__()
        graphSage_conv               = pyg.nn.SAGEConv
        self.dropout                 = dropout
        self.layers                  = layers
        self.return_embedding        = return_embedding

        ### Initalize the layers ###
        self.convs                   = nn.ModuleList()                      # ModuleList to hold the layers
        for l in range(self.layers):
            if l == 0:
                ### First layer  maps from input_dim to hidden_dim ###
                self.convs.append(graphSage_conv(input_dim, hidden_dim))
            else:
                ### All other layers map from hidden_dim to hidden_dim ###
                self.convs.append(graphSage_conv(hidden_dim, hidden_dim))

        # post-message-passing processing MLP
        self.post_mp = nn.Sequential(
                                     nn.Linear(hidden_dim, hidden_dim), 
                                     nn.Dropout(self.dropout),
                                     nn.Linear(hidden_dim, output_dim))

    def forward(self, x, edge_index):
        for i in range(self.layers):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.post_mp(x)

        # Return final layer of return_embeddingeddings if specified
        if self.return_embedding:
            return x

        # Else return class probabilities
        return F.log_softmax(x, dim=1)

    #def loss(self, pred, label):
    #    return F.nll_loss(pred, label)
    


class LinkPredictorMLP(nn.Module):
    def __init__(self, in_channels:int, hidden_channels:int, out_channels:int, n_layers:int,dropout_probabilty:float=0.3):
        """
        Args:
            in_channels (int):     Number of input features.
            hidden_channels (int): Number of hidden features.
            out_channels (int):    Number of output features.
            n_layers (int):        Number of MLP layers.
            dropout (float):       Dropout probability.
            """
        super(LinkPredictorMLP, self).__init__()
        self.dropout_probabilty    = dropout_probabilty  # dropout probability
        self.mlp_layers            = nn.ModuleList()     # ModuleList: is a list of modules
        self.non_linearity         = F.relu              # non-linearity
        
        for i in range(n_layers - 1):                                 
            if i == 0:
                self.mlp_layers.append(nn.Linear(in_channels, hidden_channels))          # input layer (in_channels, hidden_channels)
            else:
                self.mlp_layers.append(nn.Linear(hidden_channels, hidden_channels))      # hidden layers (hidden_channels, hidden_channels)

        self.mlp_layers.append(nn.Linear(hidden_channels, out_channels))                 # output layer (hidden_channels, out_channels)


    def reset_parameters(self):
        for mlp_layer in self.mlp_layers:
            mlp_layer.reset_parameters()

    def forward(self, x_i, x_j):
        x = x_i * x_j                                                     # element-wise multiplication
        for mlp_layer in self.mlp_layers[:-1]:                            # iterate over all layers except the last one
            x = mlp_layer(x)                                              # apply linear transformation
            x = self.non_linearity(x)                                     # Apply non linear activation function
            #x = F.dropout(x, p=self.dropout_probabilty,training=self.training)      # Apply dropout
            #x = F.dropout(x, p=self.dropout_probabilty)      # Apply dropout
        x = self.mlp_layers[-1](x)                                        # apply linear transformation to the last layer
        x = torch.sigmoid(x)                                              # apply sigmoid activation function to get the probability
        return x
    
### We will use This function to save our best model during trainnig ###
def save_torch_model(model,epoch,PATH:str,optimizer):
    print(f"Saving Model in Path {PATH}")
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer':optimizer,      
                }, PATH)

In [13]:
data

HeteroData(
  [1mCompound[0m={ x=[15182, 768] },
  [1mDisease[0m={ x=[4098, 768] },
  [1m(Compound, Compound_treats_the_disease, Disease)[0m={ edge_index=[2, 48554] },
  [1m(Disease, rev_Compound_treats_the_disease, Compound)[0m={ edge_index=[2, 48554] }
)

In [14]:
epochs        = 500
hidden_dim    = 524      # 256 
dropout       = 0.7
num_layers    = 3
learning_rate = 1e-4
node_emb_dim  = 768
device        = "cpu"



HomoGNN         = GNNStack(node_emb_dim, hidden_dim, hidden_dim, num_layers, dropout, return_embedding=True).to(device) # the graph neural network that takes all the node embeddings as inputs to message pass and agregate
HeteroGNN       = to_hetero(HomoGNN   , data.metadata(), aggr='sum')
link_predictor  = LinkPredictorMLP(hidden_dim, hidden_dim, 1, num_layers , dropout).to(device) # the MLP that takes embeddings of a pair of nodes and predicts the existence of an edge between them
#optimizer      = torch.optim.AdamW(list(model.parameters()) + list(link_predictor.parameters() ), lr=learning_rate, weight_decay=1e-4)
optimizer       = torch.optim.Adam(list(HeteroGNN.parameters()) + list(link_predictor.parameters() ), lr=learning_rate)

print(HeteroGNN )
print(link_predictor)
print(f"Models Loaded to {device}")
data.to(device)
HeteroGNN.to(device)

GraphModule(
  (convs): ModuleList(
    (0): ModuleDict(
      (Compound__Compound_treats_the_disease__Disease): SAGEConv(768, 524, aggr=mean)
      (Disease__rev_Compound_treats_the_disease__Compound): SAGEConv(768, 524, aggr=mean)
    )
    (1-2): 2 x ModuleDict(
      (Compound__Compound_treats_the_disease__Disease): SAGEConv(524, 524, aggr=mean)
      (Disease__rev_Compound_treats_the_disease__Compound): SAGEConv(524, 524, aggr=mean)
    )
  )
  (post_mp): ModuleList(
    (0): ModuleDict(
      (Compound): Linear(in_features=524, out_features=524, bias=True)
      (Disease): Linear(in_features=524, out_features=524, bias=True)
    )
    (1): ModuleDict(
      (Compound): Dropout(p=0.7, inplace=False)
      (Disease): Dropout(p=0.7, inplace=False)
    )
    (2): ModuleDict(
      (Compound): Linear(in_features=524, out_features=524, bias=True)
      (Disease): Linear(in_features=524, out_features=524, bias=True)
    )
  )
)



def forward(self, x, edge_index):
    x_dict = torch_geo

GraphModule(
  (convs): ModuleList(
    (0): ModuleDict(
      (Compound__Compound_treats_the_disease__Disease): SAGEConv(768, 524, aggr=mean)
      (Disease__rev_Compound_treats_the_disease__Compound): SAGEConv(768, 524, aggr=mean)
    )
    (1-2): 2 x ModuleDict(
      (Compound__Compound_treats_the_disease__Disease): SAGEConv(524, 524, aggr=mean)
      (Disease__rev_Compound_treats_the_disease__Compound): SAGEConv(524, 524, aggr=mean)
    )
  )
  (post_mp): ModuleList(
    (0): ModuleDict(
      (Compound): Linear(in_features=524, out_features=524, bias=True)
      (Disease): Linear(in_features=524, out_features=524, bias=True)
    )
    (1): ModuleDict(
      (Compound): Dropout(p=0.7, inplace=False)
      (Disease): Dropout(p=0.7, inplace=False)
    )
    (2): ModuleDict(
      (Compound): Linear(in_features=524, out_features=524, bias=True)
      (Disease): Linear(in_features=524, out_features=524, bias=True)
    )
  )
)

In [15]:
node_emb   = HeteroGNN(data.x_dict, data.edge_index_dict)
edge_index = data['Compound', 'Compound_treats_the_disease', 'Disease'].edge_index 
pos_pred   = link_predictor(node_emb["Compound"][edge_index[0]], node_emb["Disease"][edge_index[1]])   # (B, )



In [16]:
def forward_pass(model, link_predictor,data,return_node_emb:bool=False,prediction_entites:tuple=("Compound","Disease")):
    ## If model is provided get GNN embeddings ##
    if model !=  None:                                        # If model is provided, embedd
        node_emb   = model(data.x_dict, data.edge_index_dict) # Embed Bert Embeddigns with graphsage (N, d) 
    else:                                                     # else 
        node_emb = x                                          #  use Bert default  Embedddings
   
    edge_index = data['Compound', 'Compound_treats_the_disease', 'Disease'].edge_label_index
    pred       = link_predictor(node_emb[prediction_entites[0]][edge_index[0]], node_emb[prediction_entites[1]][edge_index[1]])   
    if model !=  None and return_node_emb == True:
        return (pred,node_emb)
    else:
        return pred 


def train(model, link_predictor, data, optimizer,triplet:tuple=('Compound', 'Compound_treats_the_disease', 'Disease'),device:str="cuda"):
    """
    Runs offline training for model, link_predictor and node embeddings given the message
    edges and supervision edges.
    :param model: Torch Graph model used for updating node embeddings based on message passing 
        (If None, no embbeding is performed) 
    :param link_predictor: Torch model used for predicting whether edge exists or not
    :param emb: (N, d) Initial node embeddings for all N nodes in graph
    :param edge_index: (2, E) Edge index for all edges in the graph

    :param optimizer: Torch Optimizer to update model parameters
    :return: Average supervision loss over all positive (and correspondingly sampled negative) edges
    """
    if model != None: 
        model.train()
    link_predictor.train()
    train_losses = []

    optimizer.zero_grad()                                  # Reset Gradients
    #edge_index     = torch.tensor(edge_index).T           # Reshape edge index     (2,|E|)
    #x              = x.squeeze(dim=1)                     # Reshape Feature matrix (|N|,D)
    #x , edge_index = x.to(device) , edge_index.to(device) # Move data to devices


    ### Step 1: Get Embeddings:
    # Run message passing on the inital node embeddings to get updated embeddings

    ### This model has the option of only running link predictor without graphsage, for that case the node embedding
    ### is equal to the original embedding (X)
    pred           = forward_pass(model, link_predictor,data,return_node_emb=False)
    ground_truth   = data[triplet].edge_label.to(device)
    ground_truth[ground_truth  == 2] = 1
    #print(ground_truth)
    
    
    loss = F.binary_cross_entropy_with_logits(pred, ground_truth.unsqueeze(1).float())
    loss.backward()    # Backpropagate and update parameters
    optimizer.step()

    train_losses.append(loss.item())
    return sum(train_losses) / len(train_losses)

In [17]:
#data.to("cuda")

In [18]:
data['Compound', 'Compound_treats_the_disease', 'Disease'].edge_label = torch.ones(data['Compound', 'Compound_treats_the_disease', 'Disease'].edge_index.shape[1], dtype=torch.long).to(device)
transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    disjoint_train_ratio=0.5,
    neg_sampling_ratio=2,
    add_negative_train_samples=True,
    edge_types=("Compound", "Compound_treats_the_disease", "Disease"),
    rev_edge_types=("Compound", "rev_Compound_treats_the_disease", "Disease"), 
   
)

train_data, val_data, test_data = transform(data)


#print(f"Train Data:\n{train_data}")
#print(f"Validation Data:\n{val_data}")
#print(f"Test Data:\n{test_data}")

In [19]:
node_emb   = HeteroGNN(train_data.x_dict, train_data.edge_index_dict)
edge_index = train_data['Compound', 'Compound_treats_the_disease', 'Disease'].edge_label_index
pos_pred   = link_predictor(node_emb["Compound"][edge_index[0]], node_emb["Disease"][edge_index[1]])   # (B, )/'

print(pos_pred.shape)

torch.Size([58266, 1])


In [20]:
train_loss                      = []
train_accuracy                  = []
show_metrics_every              = 20
best_accuracy                   = 0 
best_graphsage_model_path       = ""
best_link_predictor_model_path  = ""
epochs                          = 5000

for epoch in range(1,epochs):
    
    ### TRAIN ####
    loss = train(model =  HeteroGNN, link_predictor = link_predictor, data = train_data, optimizer =  optimizer,device=device)
    train_loss.append(loss)
    print(f"Epoch {epoch}: loss: {round(loss, 5)}")
    




Epoch 1: loss: 0.80813
Epoch 2: loss: 0.80768
Epoch 3: loss: 0.80718
Epoch 4: loss: 0.8066


KeyboardInterrupt: 

In [26]:
threshold = 0.5
with torch.no_grad():
    node_emb         = HeteroGNN(val_data.x_dict, val_data.edge_index_dict)
    edge_index       = test_data['Compound', 'Compound_treats_the_disease', 'Disease'].edge_label_index
    
    pos_pred         = link_predictor(node_emb["Compound"][edge_index[0]], node_emb["Disease"][edge_index[1]])   # (B, )
    ground_truth     = test_data['Compound', 'Compound_treats_the_disease', 'Disease'].edge_label
    ground_truth[ground_truth  == 2] = 1

    acc           = accuracy_score(pos_pred.to("cpu")  > threshold  ,ground_truth.to("cpu") )
    print(acc)




0.6666666666666666


In [90]:
with torch.no_grad():
    predictions  = forward_pass(model =  HeteroGNN, link_predictor = link_predictor, data = train_data)
ground_truth     = test_data['Compound', 'Compound_treats_the_disease', 'Disease'].edge_label
ground_truth[ground_truth  == 2] = 1

In [44]:

print(ground_truth )

print(predictions  > 0.5)

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,