In [1]:
import torch as th
import torch_geometric as tg
from torch import nn
from torch_geometric.nn import GATConv
from torch_geometric.data import Data, DataLoader
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scanpy as sc
import torch.optim as optim
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
import csv
from node2vec import Node2Vec


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class DataProcessor:
    def __init__(self, 
                 dataset, 
                 data_path, 
                 k=5, 
                 gene_embedding_dim=64, 
                 string_file="string_interactions.tsv", 
                 string_query=False, 
                 output_hvg="highly_variable_genes.txt"):
        """
        Initialize the data processor with the given parameters.

        Parameters
        ----------
        dataset : str
            Name of the dataset (e.g., 'Biase', 'Darmanis', etc.)
        data_path : str
            Path to the input data file.
        k : int
            Number of nearest neighbors for adjacency construction.
        gene_embedding_dim : int
            Dimension of the gene embeddings (e.g., 64).
        string_file : str
            Path to the STRING interactions file (TSV).
        string_query : bool
            If True, attempt to query the STRING database for interactions.
        output_hvg : str
            File to write HVGs.
        """
        self.dataset = dataset
        self.data_path = data_path
        self.k = k
        self.gene_embedding_dim = gene_embedding_dim
        self.string_file = string_file
        self.string_query = string_query
        self.output_hvg = output_hvg

    def load_dataset(self, dataset_name, data_path):
        """
        Load the dataset specified by dataset_name.
        Implement your own logic here depending on how your datasets are stored.
        """
        # Placeholder: For now assume a text file containing genes x cells
        print(f"this is  {data_path}")
        adata = sc.read_text(data_path)
        
        return adata

    def preprocess_data(self, adata):
        """
        Preprocess the data:
        - Transpose if necessary
        - PCA
        - Filter genes/cells
        - Compute QC metrics, etc.
        """
        adata = adata.T
        adata.var["mt"] = adata.var_names.str.startswith("MT-")
        adata.var["ribo"] = adata.var_names.str.startswith(("RPS","RPL"))
        adata.var["hb"] = adata.var_names.str.contains(("^HB[^(P)]"))

        sc.pp.calculate_qc_metrics(
            adata, qc_vars=["mt", "ribo", "hb"], inplace=True, percent_top=[20], log1p=True
        )

        # Basic filtering
        sc.pp.filter_cells(adata, min_genes=3)

        # PCA and neighbors
        sc.pp.pca(adata)
        sc.pp.neighbors(adata)

        # Normalize and log transform
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)

        # Find HVGs
        sc.pp.highly_variable_genes(adata, n_top_genes=2000)
        return adata

    def build_adjacency(self, adata, k):
        """
        Build an adjacency (KNN) graph for cells.
        """
        distances = adata.obsp["distances"]
        distances_csr = distances.tocsr()
        N = adata.shape[0]
        adj_list = []
        for i in range(N):
            row_indices = distances_csr[i].indices
            row_data = distances_csr[i].data
            # Pick top k neighbors
            k_neighbors = row_indices[row_data.argsort()[:k]]
            adj_list.append(k_neighbors)       
        A = th.zeros(size=(N, N))
        adj_list = np.array(adj_list)
        for i in range(adj_list.shape[0]):
            for j in range(k):
                A[i, adj_list[i][j]] = 1

        edge_list = th.zeros(size=(2, N*k))
        for i in range(N):
            for j in range(k):
                edge_list[0, i*k+j] = i
                edge_list[1, i*k+j] = adj_list[i][j]        
        return A, adj_list,edge_list

    def write_hvg_list(self, adata, output_file):
        hvgs = adata.var_names[adata.var["highly_variable"]]
        with open(output_file, "w") as f:
            f.write(",".join(hvgs))
        return hvgs

    def query_string_db(self, hvgs, output_file):
        """
        Query the STRING database for interactions among the given list of HVGs.
        This is pseudo-code. Implement it based on the STRING API.
        """
        # For now, assume string_interactions.tsv is already available.
        # In practice, you'd implement an API call here.
        pass

    def build_gene_graph_and_embed(self, string_file, emb_dim):
        """
        Build the gene graph from string_interactions.tsv and run node2vec.
        """
        edge_list = pd.DataFrame(csv.reader(open(string_file), delimiter="\t"))
        edge_list.columns = edge_list.iloc[0]
        edge_list = edge_list.drop(0)

        gene_graph = nx.Graph()
        gene_graph.add_edges_from(edge_list[["node1_string_id","node2_string_id"]].values)

        node2vec_model = Node2Vec(gene_graph, dimensions=emb_dim, walk_length=30, num_walks=200, workers=4)
        model = node2vec_model.fit(window=10, min_count=1, batch_words=4)
        node_embeddings = model.wv.vectors
        return node_embeddings

    def get_data(self):
        """
        Run the entire pipeline:
        - Load dataset
        - Preprocess
        - Build adjacency
        - Write HVG list
        - Optionally query STRING
        - Build gene embeddings
        Return adata, adjacency matrix A, and node_embeddings.
        """
        # 1. Load dataset
        adata = self.load_dataset(self.dataset, self.data_path)

        # 2. Preprocess
        adata = self.preprocess_data(adata)

        # 3. Build adjacency
        A, adj_list,edge_list = self.build_adjacency(adata, self.k)

        # 4. Write HVG list
        hvgs = self.write_hvg_list(adata, self.output_hvg)

        # 5. Optionally query STRING DB
        if self.string_query:
            self.query_string_db(hvgs, self.string_file)

        # 6. Build gene graph and node embeddings
        node_embeddings = self.build_gene_graph_and_embed(self.string_file, self.gene_embedding_dim)
        return adata, A, node_embeddings,edge_list


processor = DataProcessor(
    dataset="Biase",
    data_path="./tests/datastuff/GSE57249_fpkm.txt",
    k=5,
    gene_embedding_dim=64,
    string_file="./tests/datastuff/string_interactions.tsv",
    string_query=False,  # Set to True if you implement the query function
    output_hvg="./tests/datastuff/highly_variable_genes.txt"
)

In [16]:

class DeepEmbeddedClustering(nn.Module):
    def __init__(self, input_dim, hidden_dims, n_clusters,gene_numwalkers, alpha=1.0):
        """
        Deep Embedded Clustering Model
        
        Args:
        - input_dim (int): Dimension of input features
        - hidden_dims (list): List of hidden layer dimensions
        - n_clusters (int): Number of clusters
        - alpha (float): Hyperparameter for soft assignment
        """
        super(DeepEmbeddedClustering, self).__init__()
        
        # Encoder layers
        encoder_layers = []
        decoder_layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            encoder_layers.append(GATConv(prev_dim, hidden_dim))
            decoder_layers.append(nn.Linear(hidden_dim, prev_dim))
            prev_dim = hidden_dim
        
        decoder_layers = decoder_layers[::-1]
        self.encoder = encoder_layers
        self.decoder = decoder_layers

        self.genedecoder = nn.Linear(input_dim, gene_numwalkers)
        

        # Clustering layer
        self.clustering_layer = nn.Linear(hidden_dims[-1], n_clusters, bias=False)
        
        # Hyperparameters
        self.alpha = alpha
        self.n_clusters = n_clusters

    def encode(self, x, graphedge_list):
        """Forward pass through the encoder"""
        for layer in self.encoder:
            x = layer(x, graphedge_list)
        return x
    def decode(self, x, gene_embeddings):
        """Forward pass through the decoder
        gene embedding is of size (n_genes, hidden_dim) , while x is of size (n_cells, hidden_dim), so we append it.
        And the final x size is (n_cells+n_genes, hidden_dim)"""
        n_cells = x.size(0)
        x = th.cat((x,gene_embeddings),0)
        for layer in self.decoder:
            x = layer(x)
        ''' Our output will be of size (n_cells+n_genes, input_dim), so we need to split it into two parts, one for cells and one for genes'''
        x_cells = x[:n_cells]
        x_genes = x[n_cells:]
        x_genes = self.genedecoder(x_genes)
        return x_cells, x_genes
     
    def forward(self, x,graph_edge_list):
        """Forward pass through the network"""
        z = self.encode(x,graph_edge_list)
        q = self._soft_assignment(z)
        return z, q
    def _soft_assignment(self, z):
        """Compute soft assignment probabilities"""
        # Compute similarity between embedded points and cluster centers
        weights = self.clustering_layer.weight
        q = 1.0 / ((1.0 + th.sum((z.unsqueeze(1) - weights)**2, dim=2) / self.alpha) + 1e-8) ** ((self.alpha + 1.0) / 2.0)
        q = q / th.sum(q, dim=1, keepdim=True)
        return q
    
    def target_distribution(self, q):
        """
        Compute target distribution (sharpened version of q)
        
        Args:
        - q (th.Tensor): Soft assignment probabilities
        
        Returns:
        - p (th.Tensor): Sharpened target distribution
        """
        p = q**2 / th.sum(q, dim=0)
        p = p / th.sum(p, dim=1, keepdim=True)
        return p.detach()

In [34]:

def pretrain_autoencoder(model, data,gene_embeddings,edge_list, epochs=50, lr=1e-3):
    """
    Pretrain autoencoder for initial feature extraction
    
    Args:
    - model (DeepEmbeddedClustering): DEC model
    - data (th.Tensor): Input data
    - epochs (int): Number of pretraining epochs
    - lr (float): Learning rate
    
    Returns:
    - Pretrained model
    """
    cosine_loss = nn.CosineSimilarity()
    mae_loss = nn.L1Loss()
    total_loss = 0

    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(epochs):
        optimizer.zero_grad()
        z = model.encode(data,edge_list)
        x_cells,x_genes = model.decode(z,gene_embeddings)
        loss = -1*th.sum(cosine_loss(data,x_cells)) + mae_loss(x_genes,gene_embeddings)
        #print(loss)
        loss.backward()
        optimizer.step()
        if epoch % 10 == 0:
            print(f'Pretraining Epoch [{epoch}/{epochs}], Loss: {loss.item():.4f}')
    
    return model

def dec_loss(p, q):
    """
    Compute DEC loss (KL divergence between p and q)
    
    Args:
    - p (th.Tensor): Target distribution
    - q (th.Tensor): Soft assignment probabilities
    
    Returns:
    - loss (th.Tensor): KL divergence loss
    """
    return th.mean(th.sum(p * th.log(p / q), dim=1))

def train_dec(model,edge_list, data, epochs=100, lr=1e-3):
    """
    Train Deep Embedded Clustering model
    
    Args:
    - model (DeepEmbeddedClustering): DEC model
    - data (th.Tensor): Input data
    - epochs (int): Number of training epochs
    - lr (float): Learning rate
    
    Returns:
    - Trained model and cluster assignments
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # Initialize cluster centers using KMeans
    with th.no_grad():
        z = model.encode(data,edge_list)
        kmeans = KMeans(n_clusters=model.n_clusters, n_init=20)
        cluster_labels = kmeans.fit_predict(z.numpy())
        
        # Set initial cluster centers
        model.clustering_layer.weight.copy_(
            th.from_numpy(kmeans.cluster_centers_).float()
        )
    
    for epoch in range(epochs):
        optimizer.zero_grad()
        
        # Forward pass
        z, q = model(data,edge_list)
        
        # Compute target distribution
        p = model.target_distribution(q)
        
        # Compute loss
        loss = dec_loss(p, q)
        loss.backward()
        optimizer.step()
        
        if epoch % 10 == 0:
            print(f'DEC Epoch [{epoch}/{epochs}], Loss: {loss.item():.4f}')
    
    # Get final cluster assignments
    _, q = model(data,edge_list)
    cluster_assignments = th.argmax(q, dim=1)
    
    return model, cluster_assignments



In [5]:
# Simulated data
processor = DataProcessor(
    dataset="Biase",
    data_path="./datasets/Biase/GSE57249_fpkm.txt",
    k=5,
    gene_embedding_dim=64,
    string_file="./datasets/Biase/string_interactions.tsv",
    string_query=False,  # Set to True if you implement the query function
    output_hvg="./datasets/Biase/highly_variable_genes.txt"
)
adata, A, node_embeddings,edge_list = processor.get_data()
# so our input is the sc matrix of only the hvg genes
hvg = adata.var['highly_variable']
adata_hvg = adata[:, hvg].X
data = th.tensor(adata_hvg)


this is  ./datasets/Biase/GSE57249_fpkm.txt


  dispersion = np.log(dispersion)
Computing transition probabilities: 100%|██████████| 647/647 [00:00<00:00, 5343.35it/s]


In [37]:
input_dim = data.shape[1]
hidden_dims = [256, 128, 64]
n_clusters = 3
gene_numwalkers = node_embeddings.shape[1]
# Initialize model

In [47]:
model = DeepEmbeddedClustering(input_dim, hidden_dims, n_clusters,gene_numwalkers)
edge_list = edge_list.int()
pretrain_autoencoder(model,data,th.tensor(node_embeddings),edge_list,epochs=500,lr=2e-3)

# Train model
trained_model, cluster_labels = train_dec(model,edge_list ,data,epochs=2000,lr=1e-3)

print("Clustering complete. Cluster labels:", cluster_labels)

# Compare with sklearn clustering
scaler = StandardScaler()
data_scaled = scaler.fit_transform(data.numpy())

# Fit KMeans from sklearn
kmeans_sklearn = KMeans(n_clusters=n_clusters, n_init=20)
cluster_labels_sklearn = kmeans_sklearn.fit_predict(data_scaled)

# Print comparison
print("Cluster labels from DEC model:", cluster_labels.numpy())
print("Cluster labels from sklearn KMeans:", cluster_labels_sklearn)

# Compute ARI and NMI


Pretraining Epoch [0/500], Loss: 0.8628
Pretraining Epoch [10/500], Loss: 0.6943
Pretraining Epoch [20/500], Loss: 0.6389
Pretraining Epoch [30/500], Loss: 0.6056
Pretraining Epoch [40/500], Loss: 0.5822
Pretraining Epoch [50/500], Loss: 0.5645
Pretraining Epoch [60/500], Loss: 0.5508
Pretraining Epoch [70/500], Loss: 0.5408
Pretraining Epoch [80/500], Loss: 0.5347
Pretraining Epoch [90/500], Loss: 0.5286
Pretraining Epoch [100/500], Loss: 0.5284
Pretraining Epoch [110/500], Loss: 0.5266
Pretraining Epoch [120/500], Loss: 0.5250
Pretraining Epoch [130/500], Loss: 0.5298
Pretraining Epoch [140/500], Loss: 0.5330
Pretraining Epoch [150/500], Loss: 0.5311
Pretraining Epoch [160/500], Loss: 0.5305
Pretraining Epoch [170/500], Loss: 0.5276
Pretraining Epoch [180/500], Loss: 0.5283
Pretraining Epoch [190/500], Loss: 0.5277
Pretraining Epoch [200/500], Loss: 0.5266
Pretraining Epoch [210/500], Loss: 0.5237
Pretraining Epoch [220/500], Loss: 0.5253
Pretraining Epoch [230/500], Loss: 0.5257
Pre

In [48]:
celltypes = pd.read_csv("./datasets/Biase/subtype.ann",delimiter="\t",header=None)
celltypes.columns = ['cell','type']


In [49]:


ari = adjusted_rand_score(cluster_labels[:49],celltypes.values[:,1][1:])
nmi = normalized_mutual_info_score(cluster_labels[:49],celltypes.values[:,1][1:])

print(f"Adjusted Rand Index (ARI): {ari:.4f}")
print(f"Normalized Mutual Information (NMI): {nmi:.4f}")

print("Sklearn ARI and NMI:")
ari = adjusted_rand_score(cluster_labels_sklearn[:49],celltypes.values[:,1][1:])
nmi = normalized_mutual_info_score(cluster_labels_sklearn[:49],celltypes.values[:,1][1:])

print(f"Adjusted Rand Index (ARI): {ari:.4f}")
print(f"Normalized Mutual Information (NMI): {nmi:.4f}")




Adjusted Rand Index (ARI): 0.3012
Normalized Mutual Information (NMI): 0.4388
Sklearn ARI and NMI:
Adjusted Rand Index (ARI): 0.3746
Normalized Mutual Information (NMI): 0.6023


In [27]:
input1 = th.randn(100, 128)
input2 = th.randn(100, 128)
cos = nn.CosineSimilarity(dim=0, eps=1e-6)
output = cos(input1, input2)

In [30]:
output

tensor([ 6.0687e-02, -3.8181e-02,  1.0573e-01,  6.7259e-02,  1.7614e-02,
        -1.7498e-01, -5.8014e-02,  3.4063e-02,  1.7652e-02, -6.9140e-02,
        -2.9786e-02, -1.0883e-01, -8.3510e-02,  1.9375e-01, -3.4685e-02,
         7.4517e-02,  6.4099e-02,  7.9769e-02, -3.9772e-02, -9.3325e-02,
        -5.3324e-02,  1.0575e-01, -5.2770e-02, -2.6001e-03,  1.2925e-01,
         4.9770e-02,  1.0250e-02, -1.3765e-01, -2.3494e-01,  1.2739e-02,
        -1.1573e-01, -8.1230e-02, -1.0116e-01,  1.3597e-01, -1.1521e-01,
        -6.0887e-02, -1.6510e-02,  1.5036e-01,  1.1067e-02,  5.6749e-02,
        -1.4481e-01, -1.5417e-02, -2.6879e-01, -2.0842e-02,  7.9560e-02,
        -6.4373e-03,  1.1987e-01, -1.1367e-01, -6.6169e-02, -4.3092e-02,
         2.6214e-01,  4.9591e-02, -1.0859e-01, -2.7429e-02, -1.9571e-02,
         1.1765e-01,  5.8556e-02, -1.1562e-01,  4.0295e-02,  2.2053e-01,
         4.1145e-02, -1.3348e-02,  1.3305e-02, -3.6769e-03,  5.0300e-02,
        -5.6241e-02,  2.3380e-02,  4.1098e-02, -7.4

In [None]:

ari = adjusted_rand_score(cluster_labels.numpy())
nmi = normalized_mutual_info_score(cluster_labels.numpy)

print(f"Adjusted Rand Index (ARI): {ari:.4f}")
print(f"Normalized Mutual Information (NMI): {nmi:.4f}")


