In [None]:
!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.13.1+cu116.html
!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.13.1+cu116.html
!pip install torch-geometric

In [None]:
# Report the installed torch version.
import torch
print(torch.__version__)

In [None]:
# Shell escape: run nvidia-smi to print the GPU / driver status of the machine hosting the kernel.
!nvidia-smi


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
import torch_geometric
from torch_geometric.nn import SAGEConv, to_hetero
from   torch.utils.data      import Dataset, DataLoader
from   torch_geometric.data  import Data
from   torch_geometric.utils import negative_sampling

from torch_geometric.nn import SAGEConv, to_hetero

import torch_geometric.transforms as T


class GNNStack(torch.nn.Module):
    def __init__(self, input_dim:int, hidden_dim:int, output_dim:int, layers:int, dropout:float=0.3, return_embedding=False):
        """
            GraphSAGE encoder: `layers` stacked SAGEConv layers followed by a two-layer linear head.
            input_dim        <int>:   Size of the input node features
            hidden_dim       <int>:   Size of every hidden representation
            output_dim       <int>:   Size of the head's output
            layers           <int>:   Number of SAGEConv layers
            dropout          <float>: Dropout rate (applied after every conv and inside the head)
            return_embedding <bool>:  If True, forward() returns the head output as-is;
                                      otherwise it returns log-softmax over dim 1
        """
        super().__init__()
        self.dropout          = dropout
        self.layers           = layers
        self.return_embedding = return_embedding

        # The first conv maps input_dim -> hidden_dim; every later conv maps hidden_dim -> hidden_dim.
        self.convs = nn.ModuleList([
            pyg.nn.SAGEConv(input_dim if depth == 0 else hidden_dim, hidden_dim)
            for depth in range(self.layers)
        ])

        # Head applied after message passing: Linear -> Dropout -> Linear.
        head_modules = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.Dropout(self.dropout),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.post_mp = nn.Sequential(*head_modules)

    def forward(self, x, edge_index):
        """Run every conv (each followed by ReLU and dropout), then the head."""
        for depth in range(self.layers):
            activated = F.relu(self.convs[depth](x, edge_index))
            x = F.dropout(activated, p=self.dropout, training=self.training)

        out = self.post_mp(x)

        # Embedding mode returns the head output untouched; otherwise log-probabilities over dim 1.
        return out if self.return_embedding else F.log_softmax(out, dim=1)

    def loss(self, pred, label):
        """Negative log-likelihood loss between `pred` and `label`."""
        return F.nll_loss(input=pred, target=label)
    


class LinkPredictorMLP(nn.Module):
    def __init__(self, in_channels:int, hidden_channels:int, out_channels:int, n_layers:int,dropout_probabilty:float=0.3):
        """
        MLP that scores a pair of node embeddings.

        Args:
            in_channels (int):          Number of input features.
            hidden_channels (int):      Number of hidden features.
            out_channels (int):         Number of output features.
            n_layers (int):             Number of linear layers. Values <= 1 build a single
                                        Linear(in_channels, out_channels) layer.
            dropout_probabilty (float): Dropout probability applied after every hidden layer.
        """
        super(LinkPredictorMLP, self).__init__()
        self.dropout_probabilty    = dropout_probabilty  # dropout probability
        self.mlp_layers            = nn.ModuleList()     # linear layers, applied in order
        self.non_linearity         = F.relu              # activation used between layers

        # Layer widths: in -> hidden -> ... -> hidden -> out.
        # Always starting from in_channels keeps the single-layer case consistent with
        # the tensors forward() receives (previously it started from hidden_channels).
        widths = [in_channels] + [hidden_channels] * max(n_layers - 1, 0) + [out_channels]
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.mlp_layers.append(nn.Linear(fan_in, fan_out))


    def reset_parameters(self):
        """Re-initialise every linear layer."""
        for mlp_layer in self.mlp_layers:
            mlp_layer.reset_parameters()

    def forward(self, x_i, x_j):
        """
        Combine the two embeddings by element-wise product and pass the result through the MLP.

        Returns the sigmoid of the last layer's output, i.e. values in [0, 1]
        (probabilities, not logits).
        """
        x = x_i * x_j                                                     # element-wise multiplication
        for mlp_layer in self.mlp_layers[:-1]:                            # every layer except the last one
            x = mlp_layer(x)                                              # linear transformation
            x = self.non_linearity(x)                                     # non-linear activation
            x = F.dropout(x, p=self.dropout_probabilty,training=self.training)      # dropout (training mode only)
        x = self.mlp_layers[-1](x)                                        # output layer
        x = torch.sigmoid(x)                                              # squash to a probability
        return x
    
### Checkpoint helper: used to save the best model found during training ###
def save_torch_model(model,epoch,PATH:str,optimizer):
    """
    Write a checkpoint to PATH with torch.save.

    The checkpoint is a dict with the keys 'epoch', 'model_state_dict'
    (model.state_dict()) and 'optimizer' (the optimizer exactly as passed in).
    """
    print(f"Saving Model in Path {PATH}")
    checkpoint = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer=optimizer,
    )
    torch.save(checkpoint, PATH)
    
    
def train(model, link_predictor, dataset, optimizer,device:str="cuda"):
    """
    Run one training pass over `dataset`, updating `model` (if given) and `link_predictor`.

    For every (x, edge_index) pair the graph edges are used both for message passing and
    as positive supervision edges; the same number of negative edges is requested from
    negative_sampling. The loss is the binary cross-entropy between the predictor's
    probabilities and the 1/0 edge labels.

    Requires `tqdm` to be available in the notebook namespace (it is not imported in this cell).

    :param model: Torch graph model called as model(x, edge_index) to produce node embeddings.
        If None, the input features x are used as node embeddings.
    :param link_predictor: Torch model called as link_predictor(emb_src, emb_dst); it must return
        probabilities (already passed through a sigmoid), as LinkPredictorMLP does
    :param dataset: Iterable of (x, edge_index). x is squeezed on dim 1; edge_index is transposed
        before use, so it is expected as (E, 2) pairs -- TODO confirm against the dataset class
    :param optimizer: Torch optimizer holding the parameters to update
    :param device: Device the tensors are moved to
    :return: Mean loss over the batches, or nan when the dataset yields no batch
    """
    if model is not None:
        model.train()
    link_predictor.train()
    train_losses = []
    for x, edge_index in tqdm(dataset):                       # one (features, edges) pair per step
        optimizer.zero_grad()                                 # reset gradients
        edge_index     = torch.as_tensor(edge_index).T        # edge index as (2, |E|)
        x              = x.squeeze(dim=1)                     # feature matrix as (|N|, D)
        x , edge_index = x.to(device) , edge_index.to(device) # move data to the device

        ### Step 1: node embeddings
        # Without a graph model the raw features act as the embeddings.
        node_emb = model(x, edge_index) if model is not None else x

        ### Step 2: score positive edges (source row 0, destination row 1)
        pos_pred = link_predictor(node_emb[edge_index[0]], node_emb[edge_index[1]])

        ### Step 3: score sampled negative edges (same sampling settings as evaluate())
        neg_edge = negative_sampling(edge_index       = edge_index,
                                     num_nodes        = x.shape[0],
                                     num_neg_samples  = edge_index.shape[1],
                                     method           = 'dense',
                                     force_undirected = True)
        neg_pred = link_predictor(node_emb[neg_edge[0]], node_emb[neg_edge[1]])

        ### Step 4: loss on probabilities
        # The predictor already applies a sigmoid, so the plain BCE is used
        # (the *_with_logits variant would apply a second sigmoid).
        pred         = torch.cat([pos_pred, neg_pred], dim=0)
        ground_truth = torch.cat([torch.ones_like(pos_pred), torch.zeros_like(neg_pred)], dim=0)
        loss         = F.binary_cross_entropy(pred, ground_truth)

        # Backpropagate and update parameters
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())

    if not train_losses:
        # Nothing was trained on; avoid dividing by zero.
        return float("nan")
    return sum(train_losses) / len(train_losses)


def evaluate(model, predictor,dataset,device="cpu",threshold=0.5,ppi_:float=0.9,verbose:bool=False,best_accuracy=0,show_extra_metrics:bool=False):
    """
    Score the link predictor on held-out positive edges and on sampled negative edges.

    For every (x, edge_index) pair in `dataset` the edges are shuffled; the first `ppi_`
    fraction is passed to `model` for message passing and the remaining edges are scored as
    positives. Negative edges are requested from negative_sampling (as many as there are
    edges in the graph) and scored as negatives. Accuracies are averaged over the batches.

    Relies on names that are not imported in this cell and must already be in the notebook
    namespace: accuracy_score, and -- only when show_extra_metrics is True -- metrics,
    ConfusionMatrixDisplay, plt and np.

    :param model: Graph model called as model(x, edge_index); if None, x is used as the embeddings
    :param predictor: Link predictor called as predictor(emb_src, emb_dst); its output is compared to `threshold`
    :param dataset: Iterable of (x, edge_index); x is squeezed on dim 1 and edge_index is transposed before use
    :param device: Device the tensors are moved to
    :param threshold: Predictions above this value count as "edge"
    :param ppi_: Fraction of the (shuffled) edges used for message passing
    :param verbose: Print how many edges were used for embedding
    :param best_accuracy: Unused; kept for interface compatibility
    :param show_extra_metrics: Plot a ROC curve and a normalised confusion matrix instead of printing the summary line
    :return: 0.5 * positive accuracy + 0.5 * negative accuracy, or nan when the dataset yields no batch
    """
    if model is not None:
        model.eval()
    # The predictor may apply dropout (LinkPredictorMLP does); eval mode keeps scoring deterministic.
    predictor.eval()

    possitive_acc  = 0
    negative_acc   = 0
    batches        = 0
    yhat_total     = []
    y_total        = []

    # Nothing here is back-propagated, so the whole loop runs without autograd bookkeeping.
    with torch.no_grad():
        for x, edge_index in dataset:
            edge_index      = torch.as_tensor(edge_index).T        # edge index as (2, |E|)
            number_of_edges = edge_index.shape[1]
            permutations    = torch.randperm(number_of_edges)      # random edge order
            edge_index      = edge_index[:,permutations]
            limit           = int(ppi_*number_of_edges)            # number of edges used for embedding
            ppi_index_embed = edge_index[:,0:limit].to(device)     # edges used for message passing
            ppi_index_infer = edge_index[:,limit:].to(device)      # held-out edges scored as positives

            x        = x.squeeze(dim=1).to(device)                 # feature matrix as (|N|, D)
            node_emb = x if model is None else model(x, ppi_index_embed)

            if verbose:
                print(f" {limit} Positive Protein Interactions were used to Embed a graph with {number_of_edges} ppi's")

            ### Positive PPI ###
            predictions    = predictor(node_emb[ppi_index_infer[0]], node_emb[ppi_index_infer[1]])
            y              = torch.ones_like(input=predictions)
            predictions,y  = predictions.cpu(),y.cpu()
            possitive_acc += accuracy_score(predictions > threshold  ,y)
            if show_extra_metrics:
                yhat_total.extend(predictions.tolist())
                y_total.extend(y.tolist())

            ### Negative PPI ##
            edge_index  = edge_index.to(device)
            neg_edge    = negative_sampling(edge_index       = edge_index,           # positive edges to avoid
                                            num_nodes        = x.shape[0],           # total number of nodes in the graph
                                            num_neg_samples  = edge_index.shape[1],  # as many negatives as there are edges
                                            method           = 'dense',
                                            force_undirected = True)

            predictions   = predictor(node_emb[neg_edge[0]], node_emb[neg_edge[1]])
            y             = torch.zeros_like(input=predictions)
            predictions,y = predictions.cpu(),y.cpu()
            negative_acc += accuracy_score(predictions > threshold,y)
            if show_extra_metrics:
                yhat_total.extend(predictions.tolist())
                y_total.extend(y.tolist())

            batches += 1

    if batches == 0:
        # Nothing was scored; avoid dividing by zero below.
        print("evaluate(): dataset produced no batches, nothing to score")
        return float("nan")

    negative_acc  = negative_acc/batches
    possitive_acc = possitive_acc/batches
    total_acc     = 0.5*possitive_acc  + 0.5*negative_acc
    if not show_extra_metrics:
        print(f"Sensitivity (poss_acc):{possitive_acc:.4f} Specificity (negative_acc):{negative_acc:.4f} accuracy:{total_acc:.4f}")

    else:
        fig, ax = plt.subplots(1, 2,figsize=(10,2))
        fpr, tpr, thresholds = metrics.roc_curve( y_total, yhat_total)

        # Operating point: the threshold maximising sensitivity + specificity - 1.
        sens      =  tpr
        spec      =  1 - fpr
        j         = sens + spec -1
        opt_index = int(np.argmax(j))                            # first index of the maximum
        op_point  = thresholds[opt_index]

        print(f"Youdens  index: {op_point:.4f} Sensitivity: {round(sens[opt_index],4)} Specificity: {round(spec[opt_index],4)}")

        ax[0].set_title("ROC Curve")
        ax[1].set_title("Confussion Matrix")
        ax[0].plot(fpr,tpr,label="MLP" if model is None else "GraphSage+MLP")
        ax[0].plot([0, 1], [0, 1], 'k--')
        ax[0].set_ylabel('True Positive Rate')
        ax[0].set_xlabel('False Positive Rate')
        ax[0].legend()

        cfm = metrics.confusion_matrix(y_total, np.array(yhat_total)> op_point)

        cmn = cfm.astype('float') / cfm.sum(axis=1)[:, np.newaxis] # row-normalise
        disp = ConfusionMatrixDisplay(cmn)
        disp.plot(ax=ax[1])

        plt.show()

    return total_acc

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch 
import os 
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
from torch_geometric.loader import NeighborLoader


In [3]:
# Path of the serialized graph object, relative to the notebook's working directory.
DATA_PATH = os.path.join("data","graph_obj")
# Load the graph onto the CPU and pass it through T.ToUndirected().
# NOTE(review): torch.load unpickles the file, so only load graph files from a trusted source.
GRAPH     = T.ToUndirected()(torch.load(DATA_PATH, map_location="cpu"))

In [64]:
## Remove relations from GRAPH before building the heterogeneous model.
## NOTE(review): this cell mutates GRAPH in place; re-running it presumably raises because the
## keys are already gone -- confirm, or rebuild GRAPH from the load cell before re-running.

# Relations whose names contain a '.' ("esp."), together with their rev_ counterparts.
# Presumably required because the relation name ends up in a sub-module name, which may not contain '.' -- TODO confirm.
del GRAPH ['inhibits cell growth (esp. cancers)']
del GRAPH['rev_inhibits cell growth (esp. cancers)']
del GRAPH ['binding, ligand (esp. receptors)']
del GRAPH['rev_binding, ligand (esp. receptors)']


# Original note: "saved as none" -- presumably these relations were stored without usable data; TODO confirm.
# NOTE(review): 'increases expression/production' is deleted by relation name and then again as a full
# (Compound, ..., Gene) triple a few lines below -- confirm both deletions are needed.
del GRAPH['increases expression/production']
del GRAPH['rev_increases expression/production']
del GRAPH[("Compound", 'increases expression/production', "Gene")]
del GRAPH[("Compound", 'decreases expression/production', "Gene")]
del GRAPH[("Compound", 'affects expression/production (neutral)', "Gene")]
del GRAPH[("Compound", 'metabolism, pharmacokinetics', "Gene")]
del GRAPH[("Compound", 'transport, channels', "Gene")]
del GRAPH[("Compound", 'biomarkers (of_disease_progression)', "Disease")]
del GRAPH[("Compound", 'alleviates, reduces', "Disease")]
del GRAPH[("Compound", 'prevents, suppresses', "Disease")]





In [97]:
# Edge index of the supervision relation (Compound -> Disease, "Compound treats the disease").
target_label = GRAPH[("Compound", "Compound treats the disease", "Disease")].edge_index
print(target_label.shape)
# Attach an all-ones integer edge_label with one entry per existing edge of that relation.
# NOTE(review): when edge_label already exists, RandomLinkSplit presumably treats it as a class label and
# shifts positive labels by one once negatives are added (0 = negative) -- confirm the labels stay binary.
GRAPH[("Compound", "Compound treats the disease", "Disease")].edge_label =  torch.ones(target_label.shape[1],).long()

torch.Size([2, 48185])


In [116]:
# Split the "Compound treats the disease" edges into train / validation / test link-prediction sets.
# NOTE(review): negatives are added to the training split here (add_negative_train_samples=True);
# confirm this is intended if the training loader also samples negatives on the fly.
transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    disjoint_train_ratio=0.3,
    neg_sampling_ratio=3.0,
    add_negative_train_samples=True,
    edge_types=("Compound", "Compound treats the disease", "Disease"),
    # The reverse relation added by T.ToUndirected() runs Disease -> Compound, so its
    # source and destination node types are swapped relative to the forward edge type.
    rev_edge_types=("Disease", "rev_Compound treats the disease", "Compound"),
)

In [121]:
# Apply the link split to the full graph; the result is unpacked as (train, validation, test).
train_data, val_data, test_data = transform(GRAPH)


In [122]:
# Supervision edges and their labels from the training split.
# NOTE(review): the next cell assigns the same two names again, so this cell is redundant.
edge_label_index = train_data["Compound", "Compound treats the disease", "Disease"].edge_label_index
edge_label = train_data["Compound", "Compound treats the disease", "Disease"].edge_label

In [123]:
from torch_geometric.loader import LinkNeighborLoader
# Define seed edges:
# the supervision edges of the training split and their labels.
edge_label_index = train_data["Compound", "Compound treats the disease", "Disease"].edge_label_index
edge_label       = train_data["Compound", "Compound treats the disease", "Disease"].edge_label


# Mini-batch loader over the seed edges: num_neighbors=[20, 10], 128 seed edges per batch, shuffled.
# NOTE(review): the split was built with add_negative_train_samples=True and this loader also sets
# neg_sampling_ratio=2.0 -- presumably negatives are sampled twice; confirm which one is intended.
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[20, 10],
    neg_sampling_ratio=2.0,
    edge_label_index=(("Compound", "Compound treats the disease", "Disease"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=True,
)

AttributeError: 'EdgeStorage' object has no attribute 'edge_index'

In [67]:
#print(GRAPH)
# Node-level neighbor loader over the full graph, seeded from 'Compound' nodes.
# NOTE(review): this rebinds `train_loader`, replacing the LinkNeighborLoader created in the previous cell.
train_loader = NeighborLoader(
    GRAPH,
   
    num_neighbors=[15] * 2,   # Sample 15 neighbors for each node and each edge type for 2 iterations
    batch_size=128,           # Use a batch size of 128 seed nodes of type 'Compound'
    input_nodes=('Compound'), # the parentheses do not build a tuple; this is the plain string 'Compound'
)

# Pull a single mini-batch from the loader.
batch = next(iter(train_loader))



AttributeError: 'EdgeStorage' object has no attribute 'edge_index'

In [None]:
# Loader over GRAPH with num_neighbors=[-1] (a single hop; -1 presumably means "all neighbors"), 4096 seeds per batch.
# NOTE(review): GRAPH is used as a heterogeneous graph elsewhere (GRAPH.metadata(), x_dict); NeighborLoader
# presumably needs input_nodes to name a node type in that case -- confirm input_nodes=None is accepted.
# NOTE(review): num_workers=12 is hard-coded; confirm it suits the machine this runs on.
subgraph_loader = NeighborLoader(
    GRAPH,
    input_nodes=None,
    num_neighbors=[-1],
    batch_size=4096,
    num_workers=12,
    persistent_workers=True,
)

In [None]:
# Training settings
epochs        = 500
learning_rate = 1e-4
device        = "cpu"

# Model sizes
node_emb_dim  = 768
hidden_dim    = 524      # 256
num_layers    = 3
dropout       = 0.7

# Homogeneous GraphSAGE encoder over the node features; it message-passes and aggregates node embeddings.
HomoGNN = GNNStack(
    input_dim=node_emb_dim,
    hidden_dim=hidden_dim,
    output_dim=hidden_dim,
    layers=num_layers,
    dropout=dropout,
    return_embedding=True,
).to(device)

# Heterogeneous version of the encoder, built from the graph's metadata.
HeteroGNN = to_hetero(HomoGNN, GRAPH.metadata(), aggr='sum')

# MLP that takes the embeddings of a pair of nodes and predicts whether an edge exists between them.
link_predictor = LinkPredictorMLP(
    in_channels=hidden_dim,
    hidden_channels=hidden_dim,
    out_channels=1,
    n_layers=num_layers,
    dropout_probabilty=dropout,
).to(device)

#optimizer      = torch.optim.AdamW(list(model.parameters()) + list(link_predictor.parameters() ), lr=learning_rate, weight_decay=1e-4)
# One Adam optimizer over the parameters of both models.
optimizer = torch.optim.Adam(
    [*HeteroGNN.parameters(), *link_predictor.parameters()],
    lr=learning_rate,
)

#print(HeteroGNN )
#print(link_predictor)
print(f"Models Loaded to {device}")

In [68]:
#GRAPH.x_dict
#GRAPH.edge_index_dict

In [69]:
# Display the conv layers of the heterogeneous model (one SAGEConv per edge type in each layer).
HeteroGNN.convs

ModuleList(
  (0): ModuleDict(
    (Compound__activation__Gene): SAGEConv(768, 524, aggr=mean)
    (Compound__allosteric_modulation__Gene): SAGEConv(768, 524, aggr=mean)
    (Compound__antagonism__Gene): SAGEConv(768, 524, aggr=mean)
    (Compound__antibody__Gene): SAGEConv(768, 524, aggr=mean)
    (Compound__binding__Gene): SAGEConv(768, 524, aggr=mean)
    (Compound__blocking__Gene): SAGEConv(768, 524, aggr=mean)
    (Compound__inhibition__Gene): SAGEConv(768, 524, aggr=mean)
    (Compound__modulation__Gene): SAGEConv(768, 524, aggr=mean)
    (Compound__other__Gene): SAGEConv(768, 524, aggr=mean)
    (Compound__partial_agonism__Gene): SAGEConv(768, 524, aggr=mean)
    (Compound__positive_allosteric_modulation__Gene): SAGEConv(768, 524, aggr=mean)
    (Compound__carrier__Gene): SAGEConv(768, 524, aggr=mean)
    (Compound__enzyme__Gene): SAGEConv(768, 524, aggr=mean)
    (Compound__target__Gene): SAGEConv(768, 524, aggr=mean)
    (Compound__association__Gene): SAGEConv(768, 524, aggr=m

In [124]:
# Forward pass of the heterogeneous model over the test split, without tracking gradients.
# NOTE(review): the inline comment mentions lazy modules, but GNNStack is built with explicit input
# sizes (node_emb_dim) -- confirm whether this warm-up pass is still needed.
with torch.no_grad():  # Initialize lazy modules.
    node_emb   = HeteroGNN(test_data.x_dict, test_data .edge_index_dict)

ValueError: `MessagePassing.propagate` only supports integer tensors of shape `[2, num_messages]`, `torch_sparse.SparseTensor` or `torch.sparse.Tensor` for argument `edge_index`.

In [None]:
# Embed every node type, then score the known "Compound treats the disease" edges.
# The keys must match the graph's own node/edge type names: asking GRAPH for an edge type
# it does not contain would not return the supervision edges.
node_emb   = HeteroGNN(GRAPH.x_dict, GRAPH.edge_index_dict)
edge_index = GRAPH["Compound", "Compound treats the disease", "Disease"].edge_index
# Row 0 holds the source (Compound) indices, row 1 the destination (Disease) indices.
pos_pred    = link_predictor(node_emb["Compound"][edge_index[0]], node_emb["Disease"][edge_index[1]])   # (B, )