AUTHORS: Alejandro

https://medium.com/@pytorch_geometric/link-prediction-on-heterogeneous-graphs-with-pyg-6d5c29677c70

In [1]:
import pandas as pd
import numpy as np
import os 

### Build HeteroData Object

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
import torch_geometric
from torch_geometric.nn import SAGEConv, to_hetero
from   torch.utils.data      import Dataset, DataLoader
from   torch_geometric.data  import Data
from   torch_geometric.utils import negative_sampling

from torch_geometric.nn import SAGEConv, to_hetero

import torch_geometric.transforms as T

import torch 
import os 
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
from torch_geometric.loader import NeighborLoader

DATA_PATH = os.path.join("data","graph_obj")
GRAPH     = T.ToUndirected()(torch.load(DATA_PATH, map_location="cpu"))




  from .autonotebook import tqdm as notebook_tqdm


In [4]:
data = HeteroData()
data['Compound'].x  =  GRAPH["Compound"]["x"].float().to("cpu")
data['Disease'].x   = GRAPH["Disease"]["x"].float().to("cpu")
ctd                 = GRAPH[("Compound", "Compound treats the disease", "Disease")]["edge_index"].to_sparse().to("cpu")
data['Compound', 'treats', 'Disease'].edge_index = ctd
target_label                                     = data['Compound', 'treats', 'Disease'].edge_index
#data['Compound', 'treats', 'Disease'].edge_label =  torch.ones(target_label.shape[1],)
data = T.ToUndirected()(data)





print(target_label.shape)

print(data)



torch.Size([2, 48185])
HeteroData(
  [1mCompound[0m={ x=[15182, 768] },
  [1mDisease[0m={ x=[3750, 768] },
  [1m(Compound, treats, Disease)[0m={ edge_index=[2, 48185] },
  [1m(Disease, rev_treats, Compound)[0m={ edge_index=[2, 48185] }
)


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
import torch_geometric
from torch_geometric.nn import SAGEConv, to_hetero
from   torch.utils.data      import Dataset, DataLoader
from   torch_geometric.data  import Data
from   torch_geometric.utils import negative_sampling

from torch_geometric.nn import SAGEConv, to_hetero

import torch_geometric.transforms as T


class GNNStack(torch.nn.Module):
    def __init__(self, input_dim:int, hidden_dim:int, output_dim:int, layers:int, dropout:float=0.3, return_embedding=False):
        """
            A stack of GraphSAGE Module 
            input_dim        <int>:   Input dimension
            hidden_dim       <int>:   Hidden dimension
            output_dim       <int>:   Output dimension
            layers           <int>:   Number of layers
            dropout          <float>: Dropout rate
            return_embedding <bool>:  Whether to return the return_embeddingedding of the input graph
        """
        
        super(GNNStack, self).__init__()
        graphSage_conv               = pyg.nn.SAGEConv
        self.dropout                 = dropout
        self.layers                  = layers
        self.return_embedding        = return_embedding

        ### Initalize the layers ###
        self.convs                   = nn.ModuleList()                      # ModuleList to hold the layers
        for l in range(self.layers):
            if l == 0:
                ### First layer  maps from input_dim to hidden_dim ###
                self.convs.append(graphSage_conv(input_dim, hidden_dim))
            else:
                ### All other layers map from hidden_dim to hidden_dim ###
                self.convs.append(graphSage_conv(hidden_dim, hidden_dim))

        # post-message-passing processing MLP
        self.post_mp = nn.Sequential(
                                     nn.Linear(hidden_dim, hidden_dim), 
                                     nn.Dropout(self.dropout),
                                     nn.Linear(hidden_dim, output_dim))

    def forward(self, x, edge_index):
        for i in range(self.layers):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.post_mp(x)

        # Return final layer of return_embeddingeddings if specified
        if self.return_embedding:
            return x

        # Else return class probabilities
        return F.log_softmax(x, dim=1)

    def loss(self, pred, label):
        return F.nll_loss(pred, label)
    


class LinkPredictorMLP(nn.Module):
    def __init__(self, in_channels:int, hidden_channels:int, out_channels:int, n_layers:int,dropout_probabilty:float=0.3):
        """
        Args:
            in_channels (int):     Number of input features.
            hidden_channels (int): Number of hidden features.
            out_channels (int):    Number of output features.
            n_layers (int):        Number of MLP layers.
            dropout (float):       Dropout probability.
            """
        super(LinkPredictorMLP, self).__init__()
        self.dropout_probabilty    = dropout_probabilty  # dropout probability
        self.mlp_layers            = nn.ModuleList()     # ModuleList: is a list of modules
        self.non_linearity         = F.relu              # non-linearity
        
        for i in range(n_layers - 1):                                 
            if i == 0:
                self.mlp_layers.append(nn.Linear(in_channels, hidden_channels))          # input layer (in_channels, hidden_channels)
            else:
                self.mlp_layers.append(nn.Linear(hidden_channels, hidden_channels))      # hidden layers (hidden_channels, hidden_channels)

        self.mlp_layers.append(nn.Linear(hidden_channels, out_channels))                 # output layer (hidden_channels, out_channels)


    def reset_parameters(self):
        for mlp_layer in self.mlp_layers:
            mlp_layer.reset_parameters()

    def forward(self, x_i, x_j):
        x = x_i * x_j                                                     # element-wise multiplication
        for mlp_layer in self.mlp_layers[:-1]:                            # iterate over all layers except the last one
            x = mlp_layer(x)                                              # apply linear transformation
            x = self.non_linearity(x)                                     # Apply non linear activation function
            x = F.dropout(x, p=self.dropout_probabilty,training=self.training)      # Apply dropout
        x = self.mlp_layers[-1](x)                                        # apply linear transformation to the last layer
        x = torch.sigmoid(x)                                              # apply sigmoid activation function to get the probability
        return x
    
### We will use This function to save our best model during trainnig ###
def save_torch_model(model,epoch,PATH:str,optimizer):
    print(f"Saving Model in Path {PATH}")
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer':optimizer,      
                }, PATH)

In [6]:
def train(model, link_predictor, dataset, optimizer,device:str="cuda"):
    """
    Runs offline training for model, link_predictor and node embeddings given the message
    edges and supervision edges.
    :param model: Torch Graph model used for updating node embeddings based on message passing 
        (If None, no embbeding is performed) 
    :param link_predictor: Torch model used for predicting whether edge exists or not
    :param emb: (N, d) Initial node embeddings for all N nodes in graph
    :param edge_index: (2, E) Edge index for all edges in the graph

    :param optimizer: Torch Optimizer to update model parameters
    :return: Average supervision loss over all positive (and correspondingly sampled negative) edges
    """
    if model != None:
        model.train()
    link_predictor.train()
    train_losses = []
    for x, edge_index in tqdm(dataset):                       # Get X and Index from Dataste
        optimizer.zero_grad()                                 # Reset Gradients
        edge_index     = torch.tensor(edge_index).T           # Reshape edge index     (2,|E|)
        x              = x.squeeze(dim=1)                     # Reshape Feature matrix (|N|,D)
        x , edge_index = x.to(device) , edge_index.to(device) # Move data to devices
        
        
        ### Step 1: Get Embeddings:
        # Run message passing on the inital node embeddings to get updated embeddings
        
        ### This model has the option of only running link predictor without graphsage, for that case the node embedding
        ### is equal to the original embedding (X)
        if model !=  None:
            node_emb   = model(data.x_dict, data.edge_index_dict) # Embed Bert Embeddigns with graphsage (N, d) 
            
        else:
            node_emb = x                     # Else (None) use Bert Embedddings
        # Predict the class probabilities on the batch of positive edges using link_predictor
        #print(node_emb[edge_index[0]].shape)
        edge_index = data['compounds', 'treats', 'disease'].edge_index   
        pred       = link_predictor(node_emb["compounds"][edge_index[0]], node_emb["disease"][edge_index[0]])   # (B, )

        loss = F.binary_cross_entropy_with_logits(pred, ground_truth)
        # Backpropagate and update parameters
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())
    return sum(train_losses) / len(train_losses)


def evaluate(model, predictor,dataset,device="cpu",threshold=0.5,ppi_:int=0.9,verbose:bool=False,best_accuracy=0,show_extra_metrics:bool=False):
    if model != None:
        model.eval()
    possitive_acc  = 0 
    negative_acc   = 0
    batches        = 0
    if show_extra_metrics:
        yhat_total     = []
        y_total        = []
        
    for x, edge_index in dataset:                              # Get X and Index from Dataste

        edge_index      = torch.tensor(edge_index).T           # Reshape edge index     (2,|E|)
        number_of_edges =  edge_index.shape[1]                 # Retrive number of edges 
        permutations    =  torch.randperm(number_of_edges)     # Create Permutations for edge index
        edge_index      = edge_index[:,permutations]           # Run permutation
        limit           = int(ppi_*number_of_edges)            # get limit  (based on ppis to embed)
        ppi_index_embed = edge_index[:,0:limit]                # PPI to embed with GraphSage 
        ppi_index_infer = edge_index[:,limit:]                 # PPI to make inference
        
        x                    = x.squeeze(dim=1)                          # Reshape Feature matrix (|N|,D)
        x ,ppi_index_embed   = x.to(device) , ppi_index_embed.to(device) # Move data to devices
        if model !=  None:
            node_emb             = model(x,ppi_index_embed)              # Get all node embeddings
        else:
            node_emb = x                                                 # Else (None) use Bert Embedddings
            
        if verbose:
            print(f" {limit} Positive Protein Interactions were used to Embed a graph with {number_of_edges} ppi's")
        
        del ppi_index_embed 
        with torch.no_grad():
            ### Positive PPI ###
            positive_pairs_embeddings = node_emb[ppi_index_infer[0]].to(device), node_emb[ppi_index_infer[1]].to(device)
            predictions               = predictor(positive_pairs_embeddings[0], positive_pairs_embeddings[1]) 
            y                         = torch.ones_like(input=predictions)
            predictions,y             = predictions.cpu(),y.cpu()
            possitive_acc            += accuracy_score(predictions > threshold  ,y)
            if show_extra_metrics:
                yhat_total.extend(predictions.tolist())
                y_total.extend(y.tolist())
                
            else:
                del y, predictions , positive_pairs_embeddings,ppi_index_infer

            ### Negative PPI ##
            edge_index  =  edge_index.to(device)
            neg_edge    = negative_sampling(edge_index       = edge_index,        # Possitve PPI's
                                         num_nodes        = x.shape[0],           # Total number of nodes in graph
                                         num_neg_samples  = edge_index.shape[1],  # Same Number of edges as in positive example
                                         method           = 'dense',              # Method for edge generation
                                         force_undirected = True)                 # Our graph is undirected

            negative_pairs_embeddings = node_emb[neg_edge[0]].to(device), node_emb[neg_edge[1]].to(device)
            predictions               = predictor(negative_pairs_embeddings[0], negative_pairs_embeddings[1])   
            y                         = torch.zeros_like(input=predictions)
            predictions,y             = predictions.cpu(),y.cpu()
            negative_acc             += accuracy_score(predictions > threshold,y)
            if show_extra_metrics:
                yhat_total.extend(predictions.tolist())
                y_total.extend(y.tolist())
                
            else:
                del y,  predictions  ,negative_pairs_embeddings 
            batches +=1

    negative_acc  = negative_acc/batches
    possitive_acc = possitive_acc/batches
    total_acc     = 0.5*possitive_acc  + 0.5*negative_acc
    if show_extra_metrics == False:
        print(f"Sensitivity (poss_acc):{possitive_acc:.4f} Specificity (negative_acc):{negative_acc:.4f} accuracy:{total_acc:.4f}")
    
    elif show_extra_metrics == True:
        
        fig, ax = plt.subplots(1, 2,figsize=(10,2))
        fpr, tpr, thresholds = metrics.roc_curve( y_total, yhat_total)
        
        sens      =  tpr
        spec      =  1 - fpr
        j         = sens + spec -1
        opt_index = np.where(j == np.max(j))[0][0]
        op_point  = thresholds[opt_index]
        
        print(f"Youdens  index: {op_point:.4f} Sensitivity: {round(sens[opt_index],4)} Specificity: {round(spec[opt_index],4)}")
       
        ax[0].set_title("ROC Curve")
        ax[1].set_title("Confussion Matrix")
        if model == None:
            ax[0].plot(fpr,tpr,label="MLP") 
        else:
            ax[0].plot(fpr,tpr,label="GraphSage+MLP") 
        ax[0].plot([0, 1], [0, 1], 'k--')
        ax[0].set_ylabel('True Positive Rate')
        ax[0].set_xlabel('False Positive Rate')
        ax[0].legend()
       
    
        cfm = metrics.confusion_matrix(y_total, np.array(yhat_total)> op_point)
        
        cmn = cfm.astype('float') / cfm.sum(axis=1)[:, np.newaxis] # Normalise
        disp = ConfusionMatrixDisplay(cmn)
        disp.plot(ax=ax[1])
        
        plt.show()
     
    return total_acc 

In [7]:
data.metadata

<bound method HeteroData.metadata of HeteroData(
  [1mCompound[0m={ x=[15182, 768] },
  [1mDisease[0m={ x=[3750, 768] },
  [1m(Compound, treats, Disease)[0m={ edge_index=[2, 48185] },
  [1m(Disease, rev_treats, Compound)[0m={ edge_index=[2, 48185] }
)>

In [8]:
epochs        = 500
hidden_dim    = 524      # 256 
dropout       = 0.7
num_layers    = 3
learning_rate = 1e-4
node_emb_dim  = 768
device        = "cpu"

HomoGNN         = GNNStack(node_emb_dim, hidden_dim, hidden_dim, num_layers, dropout, return_embedding=True).to(device) # the graph neural network that takes all the node embeddings as inputs to message pass and agregate
HeteroGNN       = to_hetero(HomoGNN   , data.metadata(), aggr='sum')
link_predictor  = LinkPredictorMLP(hidden_dim, hidden_dim, 1, num_layers , dropout).to(device) # the MLP that takes embeddings of a pair of nodes and predicts the existence of an edge between them
#optimizer      = torch.optim.AdamW(list(model.parameters()) + list(link_predictor.parameters() ), lr=learning_rate, weight_decay=1e-4)
optimizer       = torch.optim.Adam(list(HeteroGNN.parameters()) + list(link_predictor.parameters() ), lr=learning_rate)

print(HeteroGNN )
print(link_predictor)
print(f"Models Loaded to {device}")

GraphModule(
  (convs): ModuleList(
    (0): ModuleDict(
      (Compound__treats__Disease): SAGEConv(768, 524, aggr=mean)
      (Disease__rev_treats__Compound): SAGEConv(768, 524, aggr=mean)
    )
    (1-2): 2 x ModuleDict(
      (Compound__treats__Disease): SAGEConv(524, 524, aggr=mean)
      (Disease__rev_treats__Compound): SAGEConv(524, 524, aggr=mean)
    )
  )
  (post_mp): ModuleList(
    (0): ModuleDict(
      (Compound): Linear(in_features=524, out_features=524, bias=True)
      (Disease): Linear(in_features=524, out_features=524, bias=True)
    )
    (1): ModuleDict(
      (Compound): Dropout(p=0.7, inplace=False)
      (Disease): Dropout(p=0.7, inplace=False)
    )
    (2): ModuleDict(
      (Compound): Linear(in_features=524, out_features=524, bias=True)
      (Disease): Linear(in_features=524, out_features=524, bias=True)
    )
  )
)



def forward(self, x, edge_index):
    x_dict = torch_geometric_nn_to_hetero_transformer_get_dict(x);  x = None
    x__Compound = x_dict.get(

In [9]:
node_emb   = HeteroGNN(data.x_dict, data.edge_index_dict)
edge_index = data['compounds', 'treats', 'disease'].edge_index 
pos_pred    = link_predictor(node_emb["compounds"][edge_index[0]], node_emb["disease"][edge_index[0]])   # (B, )



  src = src.to_sparse_csr()


RuntimeError: Expected tensor for argument #1 'other' to have scalar type Long; but got torch.FloatTensor instead (while checking arguments for sparse_mm_reduce)

torch.Size([2, 5])

In [101]:
pos_pred 

tensor([[0.5184],
        [0.4856],
        [0.4953],
        [0.5092],
        [0.4947],
        [0.5142],
        [0.5107],
        [0.5027],
        [0.5004],
        [0.4941],
        [0.4897],
        [0.4992],
        [0.4904],
        [0.5110],
        [0.4954],
        [0.5033],
        [0.5056],
        [0.4836],
        [0.5020],
        [0.5049],
        [0.4978],
        [0.5036],
        [0.4993],
        [0.5033],
        [0.5132],
        [0.5004],
        [0.5130],
        [0.4967],
        [0.5031],
        [0.5005],
        [0.5025],
        [0.4947],
        [0.5097],
        [0.5148],
        [0.5067],
        [0.4963],
        [0.4975],
        [0.5010],
        [0.4936],
        [0.5053],
        [0.5058],
        [0.5101],
        [0.4967],
        [0.4891],
        [0.4929],
        [0.4987],
        [0.4818],
        [0.4959],
        [0.4999],
        [0.4986],
        [0.4968],
        [0.5041],
        [0.4915],
        [0.4974],
        [0.4964],
        [0

In [None]:
train_loss                      = []
train_accuracy                  = []
show_metrics_every              = 20
best_accuracy                   = 0 
best_graphsage_model_path       = ""
best_link_predictor_model_path  = ""

for epoch in range(1,epochs):
    
    ### TRAIN ####
    loss = train(model, link_predictor,train_dataset, optimizer,device) # Get Loss
    train_loss.append(loss)
    print(f"Epoch {epoch}: loss: {round(loss, 5)}")
    
    ### EVALUATE ###
    if (epoch % 20 == 0) or (epoch ==1):
        accuracy = evaluate(model, link_predictor ,test_dataset,device=device,best_accuracy=best_accuracy,show_extra_metrics=True)
        
    else:
        accuracy = evaluate(model, link_predictor ,test_dataset,device=device,best_accuracy=best_accuracy)
    
    train_accuracy.append(accuracy)
    ### SAVE ###
    if best_accuracy < accuracy:
        if os.path.exists(best_graphsage_model_path):
            
            os.remove(best_graphsage_model_path)
            
        if os.path.exists(best_link_predictor_model_path):
            os.remove(best_link_predictor_model_path)
        print(f"Replacing models: {best_graphsage_model_path }  {best_link_predictor_model_path}")
            
        best_accuracy  = accuracy
        best_graphsage_model_path      = f"GraphSage_epoch_{epoch}.pt"
        best_link_predictor_model_path =  f"link_predictor_epoch_{epoch}.pt"
        print(f"with: Best models at {best_graphsage_model_path }  {best_link_predictor_model_path}")
        save_torch_model(model,         epoch=epoch,PATH=best_graphsage_model_path ,     optimizer=optimizer)
        save_torch_model(link_predictor,epoch=epoch,PATH=best_link_predictor_model_path, optimizer=optimizer)

        
#### Load Best Models ####

print(f"Loading best models:  {best_graphsage_model_path }  {best_link_predictor_model_path}")
checkpoint = torch.load(best_graphsage_model_path)
model.load_state_dict(checkpoint['model_state_dict'])

checkpoint = torch.load(best_link_predictor_model_path)
link_predictor.load_state_dict(checkpoint['model_state_dict'])

del checkpoint
