In [5]:
from torch_geometric.loader import DataLoader,NodeLoader
from torch_geometric.utils import to_scipy_sparse_matrix
from torch import nn
from torch.nn import functional as F
import torch as th
from tqdm import tqdm 

In [2]:
class graphattention_layer(nn.Module):
    """Single dense graph-attention layer.

    Node features are projected by a shared linear map followed by ReLU. Two
    scalar scoring heads (``vks`` and ``vkr``) turn the projected features into
    one score per node; these scores are spread over the adjacency matrix,
    squashed with a sigmoid and normalised row-wise with a softmax to obtain
    the attention matrix that mixes the projected features.

    Args:
        input_size: size of the last dimension of the incoming node features.
        output_size: size of the projected node features.
        adjM: adjacency matrix used when ``forward`` receives ``A=None``.
            Presumably a dense (N, N) tensor -- TODO confirm against caller.
    """

    def __init__(self,input_size,output_size,adjM):
        # nn.Module must be initialised before attributes are assigned;
        # assigning a Parameter or sub-module earlier raises an error.
        super().__init__()
        self.inpshape = input_size
        self.opshape = output_size
        # Kept as a plain attribute (not a buffer): it is shared between
        # layers and is therefore NOT moved by ``.to(device)``.
        self.A = adjM
        self.vks = nn.Linear(in_features=output_size,out_features=1)
        self.vkr = nn.Linear(in_features=output_size,out_features=1)
        self.W = nn.Linear(in_features=input_size,out_features=output_size)

    def forward(self, H_k,A=None):
        """Compute the next node representation.

        Args:
            H_k: node representation produced by the previous layer; its last
                dimension must equal ``input_size``.
            A: optional adjacency matrix of a sub-graph. When ``None`` the
                matrix given at construction time is used.

        Returns:
            The attention-weighted projected features, ``Attention @ relu(W H_k)``.
        """
        adjacency = self.A if A is None else A
        # The projection is shared by both scoring heads and by the final
        # aggregation, so compute it once.
        projected = F.relu(self.W(H_k))
        # vks/vkr give one score per node, shape (N, 1); transposing lets the
        # score broadcast across the rows of the adjacency matrix.
        sender_scores = adjacency * self.vks(projected).T
        receiver_scores = (adjacency * self.vkr(projected).T).T
        # dim=1 is what the deprecated implicit choice picked for 2-D input.
        # NOTE(review): entries where the adjacency is 0 become sigmoid(0)=0.5
        # and still receive attention weight -- confirm this is intended.
        attention = F.softmax(th.sigmoid(sender_scores + receiver_scores), dim=1)
        return attention @ projected

In [3]:
class encoder(nn.Module):
    """Stack of three graph-attention layers: input -> 512 -> 256 -> 64.

    PyTorch treats the last dimension of the input as the feature dimension,
    so every row of the input matrix is the representation of one cell and
    ``input_embeddings`` is the length of such a row. For scDEGA the encoder
    stays the same; only the decoder differs.
    """

    def __init__(self,adjM,input_embeddings):
        super(encoder,self).__init__( )
        widths = (input_embeddings, 512, 256, 64)
        # Registers layer1, layer2, layer3 in that order.
        for position, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(
                self,
                f"layer{position}",
                graphattention_layer(input_size=n_in, output_size=n_out, adjM=adjM),
            )

    def forward(self, X,A):
        """Encode the node embeddings ``X``, described as (N, embedding_size).

        ``A`` is forwarded unchanged to every layer. The result has 64
        features per node.
        """
        hidden = X
        for layer in (self.layer1, self.layer2, self.layer3):
            hidden = layer(hidden, A)
        return hidden

In [4]:
    
class gene_decoder(nn.Module):
    """Decoder that reconstructs the cell matrix and the gene matrix.

    Args:
        adjM: default adjacency matrix handed to every attention layer.
        reconstruction_embedding: size of the decoded embedding, i.e. the
            original embedding size that should be reconstructed.
        gene_embeddings: gene embedding matrix; ``shape[1]`` is taken as the
            embedding size of a single gene. Its rows are appended to the
            encoder output in ``forward``, so its second dimension must match
            the encoder output width (64).
    """

    def __init__(self,adjM,reconstruction_embedding,gene_embeddings):
        super(gene_decoder,self).__init__()
        self.gene_embeddings = gene_embeddings
        gene_embedding_size = gene_embeddings.shape[1]  # embedding of a single gene
        self.cell_layer1 = graphattention_layer(input_size=64,
                                           output_size=256,adjM=adjM)
        self.cell_layer2 = graphattention_layer(input_size=256,
                                           output_size=512,adjM=adjM)
        self.cell_layer3 = graphattention_layer(input_size=512,
                                           output_size=reconstruction_embedding,adjM=adjM)
        self.Wcr = nn.Parameter(
            th.randn(reconstruction_embedding,reconstruction_embedding+gene_embedding_size))
        self.Wgr = nn.Parameter(
            th.randn(gene_embedding_size,gene_embedding_size+reconstruction_embedding))

    def forward(self, H,A):
        """Decode the encoder output.

        Args:
            H: encoder output, one row per cell.
            A: adjacency matrix for the attention layers (``None`` selects the
                matrix given at construction time). It has to cover the cell
                rows AND the appended gene rows -- TODO confirm with caller.

        Returns:
            Tuple ``(X_cr, X_gr)``: reconstructed cell matrix and
            reconstructed gene matrix.
        """
        # Append the gene rows below the cell rows. th.cat joins along an
        # existing dimension; th.stack would need equally shaped tensors in a
        # sequence and takes an integer as its second argument.
        genes = th.as_tensor(self.gene_embeddings, dtype=H.dtype, device=H.device)
        Hpass = th.cat((H, genes), dim=0)
        cell_H1 = self.cell_layer1(Hpass,A)
        cell_H2 = self.cell_layer2(cell_H1,A)
        cell_H3 = self.cell_layer3(cell_H2,A)
        # NOTE(review): both products below require the row count of their
        # right operand to equal reconstruction_embedding + gene_embedding_size
        # (the second dimension of Wcr / Wgr) -- verify the intended shapes.
        X_cr = self.Wcr@cell_H3 # cell reconstructed matrix.
        X_gr = self.Wgr@Hpass
        return (X_cr,X_gr)


In [None]:
from sklearn.cluster import KMeans
class clustering_layer(nn.Module):
    """Clustering-optimisation layer (DEC-style soft assignment).

    ``forward`` maps cell embeddings to a soft assignment matrix ``Q`` and a
    hard cluster label per cell, and stores the KL divergence between the
    sharpened target distribution ``P`` and ``Q`` in ``self.kl_loss``.

    Args:
        cell_embeddings: only ``shape[1]`` (the embedding size) is read.
        n_clusters: number of clusters.
        initial_cluster_assignments: ``"k_means"`` re-estimates the cluster
            centres with scikit-learn's KMeans on every forward pass; any
            other value uses the learnable ``center_embeddings`` parameter.
    """

    def __init__(self,cell_embeddings,n_clusters,initial_cluster_assignments="k_means"):
        super(clustering_layer,self).__init__()
        embedding_size = cell_embeddings.shape[1]
        self.n_clusters = n_clusters
        # Always define the attribute so forward() can test it; None selects
        # the learnable centres.
        self.clustering_method = None
        if initial_cluster_assignments == "k_means":
            self.clustering_method = KMeans(n_clusters=n_clusters)
        self.center_embeddings = nn.Parameter(th.randn(n_clusters,embedding_size))

    @staticmethod
    def calculate_membership_matrix(cell_embeddings, center_embeddings):
        """Return the (N, C) soft assignment matrix ``Q``.

        Uses the Student's t kernel with one degree of freedom:
        ``q_ij = (1 + ||z_i - mu_j||^2)^-1``, normalised so every row sums to 1.

        Args:
            cell_embeddings: tensor of shape (N, D).
            center_embeddings: tensor of shape (C, D).
        """
        # (N, 1, D) - (1, C, D) broadcasts to all cell/centre pairs.
        differences = cell_embeddings.unsqueeze(1) - center_embeddings.unsqueeze(0)
        squared_distances = differences.pow(2).sum(dim=2)  # (N, C)
        kernel = 1.0 / (1.0 + squared_distances)
        # The kernel is strictly positive, so the row sum is never zero for
        # finite input.
        return kernel / kernel.sum(dim=1, keepdim=True)

    @staticmethod
    def calculate_auxilliary_matrix(Q):
        """Return the target distribution ``P`` for a soft assignment ``Q``.

        ``P_ij = (q_ij^2 / f_j) / sum_j' (q_ij'^2 / f_j')`` where
        ``f_j = sum_i q_ij`` is the soft frequency of cluster ``j``.

        Args:
            Q: tensor of shape (n, m).

        Returns:
            Tensor of shape (n, m) whose rows sum to 1.
        """
        cluster_frequency = Q.sum(dim=0)  # (m,)
        weighted = Q ** 2 / cluster_frequency
        return weighted / weighted.sum(dim=1, keepdim=True)

    def forward(self, H):
        """Cluster the embeddings ``H`` of shape (N, D).

        Returns:
            Tuple ``(cluster_assignments, Q)``: the arg-max label per row and
            the (N, C) soft assignment matrix. ``self.kl_loss`` is updated as
            a side effect.
        """
        if self.clustering_method is not None:
            # scikit-learn works on numpy arrays and must not see the autograd
            # graph; the fitted centres are brought back as a tensor on H's
            # device. No gradient reaches the centres on this path.
            self.clustering_method.fit(H.detach().cpu().numpy())
            center_embeddings = th.as_tensor(
                self.clustering_method.cluster_centers_,
                dtype=H.dtype, device=H.device)
        else:
            center_embeddings = self.center_embeddings
        Q = self.calculate_membership_matrix(H,center_embeddings)
        # The target distribution is treated as a constant during optimisation.
        P = self.calculate_auxilliary_matrix(Q).detach()
        # F.kl_div expects log-probabilities as input and probabilities as
        # target; this evaluates KL(P || Q) averaged over the rows.
        self.kl_loss = F.kl_div(th.log(Q.clamp_min(1e-12)), P, reduction="batchmean")
        cluster_assignments = th.argmax(Q,dim=1)
        return (cluster_assignments,Q)

In [None]:
class scdEGA(nn.Module):
    """Graph auto-encoder with a clustering head.

    Args:
        hidden_size: width of the cell embedding given to the clustering
            layer. NOTE(review): the encoder output width is fixed at 64, so
            this presumably has to be 64 -- confirm. An object with a
            ``shape`` attribute is passed through unchanged.
        cellGraph: cell matrix to reconstruct (described as the PCA-reduced
            cell matrix); ``shape[1]`` is the encoder input size.
        adjM: adjacency matrix of the cell graph.
        GeneGraph: gene embeddings (described as node2vec embeddings of a PPI
            graph); this is the gene matrix to reconstruct.
        n_clusters: number of clusters for the clustering layer. Required.

    Raises:
        ValueError: if ``n_clusters`` is not given.
    """

    def __init__(self,hidden_size,cellGraph,adjM,GeneGraph,n_clusters=None):
        super(scdEGA,self).__init__()
        if n_clusters is None:
            # The clustering layer cannot be built without a cluster count.
            raise ValueError("n_clusters must be provided to build the clustering layer")
        self.gc = cellGraph
        self.gg = GeneGraph
        cell_embeddingsize = cellGraph.shape[1]
        gene_embeddingsize = GeneGraph.shape[1]
        self.encoder = encoder(adjM,cell_embeddingsize)
        self.decoder = gene_decoder(adjM,gene_embeddingsize,GeneGraph)
        # The clustering layer only reads the shape of its first argument, so
        # an uninitialised placeholder of the right shape is sufficient.
        if hasattr(hidden_size, "shape"):
            embedding_template = hidden_size
        else:
            embedding_template = th.empty(cellGraph.shape[0], hidden_size)
        self.clustering_layer = clustering_layer(embedding_template, n_clusters)

    def forward(self, H,A=None):
        """Run auto-encoding and clustering.

        Args:
            H: input node features for the encoder.
            A: optional adjacency matrix; ``None`` uses the matrix given at
                construction time.

        Returns:
            Tuple ``(reconstructed_cell_matrix, reconstructed_gene_matrix,
            cluster_assignments, Q)``. The individual losses and their sum are
            stored on the module (``reconstruction_cell_loss``,
            ``reconstruction_gene_loss``, ``selfsupervised_loss``, ``kl_loss``,
            ``total_loss``).
        """
        ## Self-supervised optimization part
        cell_embeddings = self.encoder(H,A)
        reconstructed_cell_matrix,reconstructed_gene_matrix = self.decoder(cell_embeddings,A)
        # Cosine similarity grows with reconstruction quality, so the loss is
        # its complement, reduced to a scalar.
        # NOTE(review): assumes the reconstructions have the same shape as
        # self.gc / self.gg -- confirm against the decoder.
        self.reconstruction_cell_loss = (
            1 - F.cosine_similarity(reconstructed_cell_matrix,self.gc)).mean()
        # Mean absolute error is provided by torch as l1_loss.
        self.reconstruction_gene_loss = F.l1_loss(reconstructed_gene_matrix,self.gg)
        self.selfsupervised_loss = self.reconstruction_cell_loss+self.reconstruction_gene_loss
        ## Clustering optimization part
        cluster_assignments,Q = self.clustering_layer(cell_embeddings)
        self.kl_loss = self.clustering_layer.kl_loss
        self.total_loss = self.selfsupervised_loss+self.kl_loss
        return (reconstructed_cell_matrix,reconstructed_gene_matrix,cluster_assignments,Q)



In [10]:
import torch as th

def calculate_membership_matrix(cell_embeddings, center_embeddings):
    '''
    Return the (N, C) soft assignment matrix of cells to cluster centers.

    Uses the Student's t kernel with one degree of freedom,
    q_ij = (1 + ||z_i - mu_j||^2)^-1, normalised so every row sums to 1.

    Parameters:
    cell_embeddings (torch.Tensor): cell embeddings of shape (N, D).
    center_embeddings (torch.Tensor): cluster centers of shape (C, D).

    Returns:
    torch.Tensor: membership matrix of shape (N, C).
    '''
    # (N, 1, D) - (1, C, D) broadcasts to every cell/center pair.
    differences = cell_embeddings.unsqueeze(1) - center_embeddings.unsqueeze(0)
    # The t kernel is defined on the SQUARED Euclidean distance.
    squared_distances = differences.pow(2).sum(dim=2)  # Shape: (N, C)
    kernel = 1.0 / (1.0 + squared_distances)
    # Normalise over the clusters; the kernel is strictly positive for finite
    # input, so the row sum cannot be zero.
    membership_matrix = kernel / kernel.sum(dim=1, keepdim=True)
    return membership_matrix
# Shape check on random data: 10 cells and 5 centers, 64 features each.
cell_embeddings = th.randn(10, 64)
center_embeddings = th.randn(5, 64)
print(
    calculate_membership_matrix(
        cell_embeddings=cell_embeddings,
        center_embeddings=center_embeddings,
    ).shape
)

torch.Size([10, 5])


In [None]:
class CellGraphDataset(th.utils.data.Dataset):
    """Map-style dataset holding one full cell graph.

    The dataset has exactly one item: the tuple
    ``(cell_embeddings, adjM)`` given at construction time.

    Args:
        cell_embeddings: node features of the cell graph.
        adjM: adjacency matrix of the cell graph.
    """

    def __init__(self,cell_embeddings,adjM):
        super().__init__()
        self.cell_embeddings = cell_embeddings
        self.adjM = adjM

    def __getitem__(self,idx):
        """Return the whole graph for the single valid index (0 or -1).

        Raises:
            IndexError: for any other index, which also lets sequence-style
                iteration over the dataset terminate.
        """
        if idx not in (0, -1):
            raise IndexError("CellGraphDataset holds a single graph")
        return (self.cell_embeddings,self.adjM)

    def __len__(self):
        return 1