In [25]:
import torch.nn as nn

class NodeEmbeddings(nn.Module):
    """Trainable lookup tables for the three node types (drugs, proteins, diseases).

    Each node type gets its own ``nn.Embedding`` table, and all three tables
    use the same ``embedding_dim``.

    Args:
        n_drugs: number of rows in ``drug_embeddings``.
        n_proteins: number of rows in ``protein_embeddings``.
        n_diseases: number of rows in ``disease_embeddings``.
        embedding_dim: width of every embedding vector.
    """

    def __init__(self, n_drugs: int, n_proteins: int, n_diseases: int, embedding_dim: int):
        super().__init__()
        # Registration order (drug -> protein -> disease) fixes both the
        # state_dict key order and the order in which weights draw from the RNG.
        self.drug_embeddings = nn.Embedding(num_embeddings=n_drugs, embedding_dim=embedding_dim)
        self.protein_embeddings = nn.Embedding(num_embeddings=n_proteins, embedding_dim=embedding_dim)
        self.disease_embeddings = nn.Embedding(num_embeddings=n_diseases, embedding_dim=embedding_dim)


In [26]:
import torch
def project_onto_hyperplane(v, hyperplane_normal):
    """Project ``v`` onto the hyperplane through the origin with normal ``hyperplane_normal``.

    Computes ``v - (v . n / ||n||^2) * n`` along the last dimension, so the
    result is orthogonal to ``n``. The normal does not need to be unit length.

    Args:
        v: tensor whose last dimension is the vector dimension.
        hyperplane_normal: normal vector(s) that broadcast against ``v``. This can
            be a single ``(d,)`` vector shared by every row, or one normal per
            row (e.g. ``(batch, d)`` for ``v`` of shape ``(batch, d)``).

    Returns:
        Tensor with the broadcast shape of ``v`` and ``hyperplane_normal``.
        An all-zero normal gives NaN (0/0), because the hyperplane is undefined.
    """
    # Both the dot product and the squared norm are taken per vector along the
    # last dimension. Taking torch.norm over the whole tensor is only correct for
    # a single 1-D normal; with batched normals it silently mis-scales each row.
    dot_product = torch.sum(v * hyperplane_normal, dim=-1, keepdim=True)
    normal_sq_norm = torch.sum(hyperplane_normal * hyperplane_normal, dim=-1, keepdim=True)
    return v - (dot_product / normal_sq_norm) * hyperplane_normal


In [27]:
import torch

class HPProjection(nn.Module):
    """Embed drug/protein/disease nodes and project them onto one shared hyperplane.

    A single learnable normal vector (``hyperplane_embedding``) defines the
    hyperplane, and every node type is projected with ``project_onto_hyperplane``.

    Args:
        n_drugs, n_proteins, n_diseases: number of nodes of each type.
        embedding_dim: size of every embedding and of the hyperplane normal.
    """

    def __init__(self, n_drugs, n_proteins, n_diseases, embedding_dim):
        super().__init__()
        # Per-type lookup tables. They are built before the normal vector so the
        # RNG draw order is unchanged.
        self.node_embeddings = NodeEmbeddings(
            n_drugs=n_drugs,
            n_proteins=n_proteins,
            n_diseases=n_diseases,
            embedding_dim=embedding_dim,
        )
        # One normal vector shared by all node types. It is randomly
        # initialised and not normalised to unit length.
        self.hyperplane_embedding = nn.Parameter(torch.randn(embedding_dim))

    def forward(self, drug_index, protein_index, disease_index):
        """Look up the embeddings for each index tensor and project them.

        Returns:
            Tuple ``(drug_projected, protein_projected, disease_projected)``.
        """
        tables = self.node_embeddings
        looked_up = (
            tables.drug_embeddings(drug_index),
            tables.protein_embeddings(protein_index),
            tables.disease_embeddings(disease_index),
        )
        normal = self.hyperplane_embedding
        return tuple(project_onto_hyperplane(emb, normal) for emb in looked_up)

        


In [30]:
# Fix the RNG so the randomly initialised embeddings and hyperplane normal
# (and therefore every projection computed below) are reproducible across
# kernel restarts.
torch.manual_seed(42)

hyper_plane_projection = HPProjection(
    n_drugs=500, n_proteins=200, n_diseases=400, embedding_dim=128
)

In [32]:
# Run a forward pass on node indices 1..3 of each type; keep only the drug output.
drug_projection, _, _ = hyper_plane_projection(
    torch.tensor([1, 2, 3]),
    torch.tensor([1, 2, 3]),
    torch.tensor([1, 2, 3]),
)

In [33]:
drug_projection.size()

torch.Size([3, 128])