In [None]:
#clustering cells from GAE concatenated embeddings and comparison to scRNA clusters

In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, DataLoader
from sklearn.cluster import KMeans
import pandas as pd
import numpy as np
import wandb
import matplotlib.pyplot as plt
from sklearn.preprocessing import normalize
import scanpy as sc
from torch.utils.tensorboard import SummaryWriter


Code without tensorboard

In [107]:
class GAEModel(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GAEModel, self).__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        self.conv2 = GCNConv(2 * out_channels, out_channels)
        self.decoder = torch.nn.Bilinear(out_channels, out_channels, 1)

    def encode(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)

    def decode(self, z, edge_index):
        src, dst = z[edge_index[0]], z[edge_index[1]]
        return torch.sigmoid(self.decoder(src, dst))

    def forward(self, x, edge_index):
        z = self.encode(x, edge_index)
        out = self.decode(z, edge_index)
        return out
    
    def get_node_embeddings(self, x, edge_index):
        return self.encode(x, edge_index)

def load_data(node_feature_path, edge_incidence_path):
    # Load gene expression data for each cell
    node_features = pd.read_csv(node_feature_path, index_col=0).head(1000)
    # Load edge incidence matrix (common to all cells)
    edge_incidence = pd.read_csv(edge_incidence_path, index_col=0).T

    # Create a node index map based on the gene order in node_features
    node_index_map = {gene: idx for idx, gene in enumerate(node_features.columns)}

    # Create edge_index from the edge incidence matrix
    edge_index = []
    for col in edge_incidence.columns:
        if col in node_index_map:
            target_idx = node_index_map[col]
            source_genes = edge_incidence.index[edge_incidence[col] == 1].tolist()
            for source in source_genes:
                if source in node_index_map:
                    source_idx = node_index_map[source]
                    edge_index.append([source_idx, target_idx])

    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    print('edge_index', edge_index)
    print('edge_index shape', edge_index.shape)

    # Create a list of Data objects for each cell
    data_list = []
    for _, row in node_features.iterrows():
        x = torch.tensor(row.values, dtype=torch.float).unsqueeze(1)
        data = Data(x=x, edge_index=edge_index)
        data_list.append(data)
    return data_list

def train(model, dataset, num_epochs=6):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.BCEWithLogitsLoss()
    
    train_size = int(0.8 * len(dataset))
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, len(dataset) - train_size])
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    for epoch in range(num_epochs):
        model.train()
        for data in train_loader:
            optimizer.zero_grad()
            edge_predictions = model(data.x, data.edge_index).squeeze()
            loss = criterion(edge_predictions, edge_predictions.detach())
            loss.backward()
            optimizer.step()
            wandb.log({"loss": loss.item()})
        
        evaluate(model, criterion, val_loader)
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}')
    return model

def evaluate(model, criterion, val_loader):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data in val_loader:
            edge_predictions = model(data.x, data.edge_index).squeeze()
            loss = criterion(edge_predictions, edge_predictions.detach())
            val_loss += loss.item()
    val_loss /= len(val_loader)
    print(f'Validation Loss: {val_loss:.4f}')
    wandb.log({"val_loss": val_loss})

def extract_cell_embeddings(model, dataset):
    model.eval()
    cell_embeddings = []
    with torch.no_grad():
        for data in dataset:
            z = model.get_node_embeddings(data.x, data.edge_index)
            cell_embedding = z.flatten().cpu().numpy()
            cell_embeddings.append(cell_embedding)
    return np.array(cell_embeddings)


if __name__ == "__main__":
    wandb.init(project="sc_GAE")
    
    # Load gene expression and interaction data
    data_list = load_data('/Users/work/Desktop/expression_matrix_yitao.csv', '/Users/work/Desktop/suberites_presence_absence_yitao.csv')

    model = GAEModel(in_channels=1, out_channels=8)
    trained_model = train(model, data_list, num_epochs=100)
    
    cell_embeddings = extract_cell_embeddings(trained_model, data_list)
 

0,1
loss,▁█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,0.69315
val_loss,0.69315


edge_index tensor([[  901,  1482,  1632,  ...,  7939,  8086,  8258],
        [    0,     0,     0,  ..., 13133, 13133, 13133]])
edge_index shape torch.Size([2, 82068])
Validation Loss: 0.6932
Epoch 1/100, Loss: 0.6932
Validation Loss: 0.6932
Epoch 2/100, Loss: 0.6931
Validation Loss: 0.6932
Epoch 3/100, Loss: 0.6932
Validation Loss: 0.6932
Epoch 4/100, Loss: 0.6931
Validation Loss: 0.6931
Epoch 5/100, Loss: 0.6932
Validation Loss: 0.6931
Epoch 6/100, Loss: 0.6931
Validation Loss: 0.6931
Epoch 7/100, Loss: 0.6931
Validation Loss: 0.6931
Epoch 8/100, Loss: 0.6931
Validation Loss: 0.6931
Epoch 9/100, Loss: 0.6931
Validation Loss: 0.6931
Epoch 10/100, Loss: 0.6931
Validation Loss: 0.6931
Epoch 11/100, Loss: 0.6931
Validation Loss: 0.6931
Epoch 12/100, Loss: 0.6931
Validation Loss: 0.6931
Epoch 13/100, Loss: 0.6931
Validation Loss: 0.6931
Epoch 14/100, Loss: 0.6931
Validation Loss: 0.6931
Epoch 15/100, Loss: 0.6931
Validation Loss: 0.6931
Epoch 16/100, Loss: 0.6931
Validation Loss: 0.6931
E

In [4]:
data_list = load_data('/Users/work/Desktop/expression_matrix_yitao.csv', '/Users/work/Desktop/suberites_presence_absence_yitao.csv')


edge_index tensor([[  901,  1482,  1632,  ...,  7939,  8086,  8258],
        [    0,     0,     0,  ..., 13133, 13133, 13133]])
edge_index shape torch.Size([2, 82068])


In [98]:
pd.read_csv("/Users/work/Desktop/expression_matrix_yitao.csv")

Unnamed: 0.1,Unnamed: 0,SUB2.g1,SUB2.g2,SUB2.g3,SUB2.g4,SUB2.g5,SUB2.g6,SUB2.g7,SUB2.g8,SUB2.g9,...,SUB2.g13135,SUB2.g13136,SUB2.g13137,SUB2.g13138,SUB2.g13139,SUB2.g13140,SUB2.g13141,SUB2.g13142,SUB2.g13143,SUB2.g13144
0,AAACCCAAGGACAGCT-1,0.009687,-1.032545,-1.566363,1.018223,-0.933981,-0.410398,-0.349553,-0.444472,-0.132447,...,0.0,0.0,0.0,-0.293056,0.0,0.0,0.0,0.0,-0.365098,-0.652811
1,AAACCCAAGGGTTTCT-1,1.598354,-0.300789,-0.202161,-0.002592,0.181964,0.269252,-0.393041,0.037609,0.115123,...,0.0,0.0,0.0,-0.192047,0.0,0.0,0.0,0.0,-0.296943,-0.677305
2,AAACCCACAAATGGTA-1,-0.777102,-0.792809,-1.281675,1.262175,-0.185155,-0.726397,-0.455605,-0.671795,-0.448440,...,0.0,0.0,0.0,-0.315716,0.0,0.0,0.0,0.0,-0.205751,0.327006
3,AAACCCATCGAGAAGC-1,0.696619,-0.306533,0.700373,-1.587061,-0.054666,0.263420,-0.105024,0.250046,0.661372,...,0.0,0.0,0.0,-0.300775,0.0,0.0,0.0,0.0,-0.310817,-0.878208
4,AAACCCATCTTGTTAC-1,-1.319795,-1.191836,-1.658408,1.626106,-0.142163,-1.176196,-0.391015,-0.537126,-0.254368,...,0.0,0.0,0.0,-0.301799,0.0,0.0,0.0,0.0,-0.303617,-0.420889
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2666,TTTGGTTGTGATGAAT-1,1.373431,2.000000,1.041506,-0.439186,0.454166,0.649058,2.000000,2.000000,0.291616,...,0.0,0.0,0.0,-0.320812,0.0,0.0,0.0,0.0,-0.169915,0.400232
2667,TTTGGTTGTTATAGCC-1,1.079402,0.559011,-0.162774,-0.475533,-0.359524,0.943268,-0.280327,-0.424414,-0.203679,...,0.0,0.0,0.0,0.842400,0.0,0.0,0.0,0.0,-0.301879,0.520434
2668,TTTGTTGAGAGAGTGA-1,-1.486299,-1.097147,-1.712721,1.434417,0.358839,-1.098839,-0.345341,-0.442162,-0.121341,...,0.0,0.0,0.0,-0.292260,0.0,0.0,0.0,0.0,-0.370699,-0.793447
2669,TTTGTTGAGAGGTCGT-1,1.075499,0.579794,1.183538,-0.750108,0.465700,1.618033,-0.163200,-0.444858,-0.234243,...,0.0,0.0,0.0,-0.237166,0.0,0.0,0.0,0.0,-0.313765,0.076373


In [20]:
from sklearn.cluster import SpectralClustering


def cluster_cells(cell_embeddings, n_clusters=2):
    spectral = SpectralClustering(n_clusters=n_clusters, affinity='nearest_neighbors', random_state=0)
    return spectral.fit_predict(cell_embeddings)  # Returns cluster labels for each cell


# def cluster_cells(cell_embeddings, n_clusters=2):
#     kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(cell_embeddings)
#     return kmeans.labels_  # Returns cluster labels for each cell

cluster_labels = cluster_cells(cell_embeddings, n_clusters=15)

   
#    Save clustering results
pd.DataFrame(cluster_labels, columns=['Cluster']).to_csv("cell_clusters.csv", index=False)
#cluster_labels = pd.read_csv("/Users/work/Library/Mobile Documents/com~apple~CloudDocs/Desktop/ADesktop/Studium/PhD/DataMining/TF_Gene_Interaction_Prediction/src/cell_clusters.csv")["Cluster"]
import scanpy as sc
adata = sc.read_h5ad("/Users/work/Desktop/subdom_processed.h5ad")
adata.obs_names = np.arange(0, len(adata)).astype(str)
labels_adata = adata.obs['clusters']

from sklearn.metrics import adjusted_rand_score

# Assuming `labels_ml` contains the cluster labels from your ML pipeline
# and that `labels_ml` and `labels_adata` are aligned and of the same length.

# Calculate Adjusted Rand Index
ari = adjusted_rand_score(labels_adata, cluster_labels)

# Print the ARI result
print(f"Adjusted Rand Index (ARI) spectral: {ari}")




# Comparison to un-embedded data #

In [7]:
adata = sc.read_h5ad("/Users/work/Desktop/subdom_processed.h5ad")

In [11]:
data_matrix =adata.raw.X.toarray()

In [8]:
data_matrix =adata.X


In [12]:
k = 15  # Specify the number of clusters
kmeans = KMeans(n_clusters=k, random_state=42)
adata.obs['kmeans_clusters'] = kmeans.fit_predict(data_matrix)
kmeans_pre_embedding = adata.obs['kmeans_clusters']

kmeans = KMeans(n_clusters=k, random_state=42, init='k-means++', n_init=10)
spherical_pre_embedding = kmeans.fit_predict(data_matrix)


In [13]:
from sklearn.metrics import adjusted_rand_score

# Assuming `labels_ml` contains the cluster labels from your ML pipeline
# and that `labels_ml` and `labels_adata` are aligned and of the same length.

# Calculate Adjusted Rand Index
ari = adjusted_rand_score(labels_adata, kmeans_pre_embedding)
ari2 = adjusted_rand_score(labels_adata, spherical_pre_embedding)


# Print the ARI result
print(f"Adjusted Rand Index (ARI): {ari,ari2}")


Adjusted Rand Index (ARI): (0.5623482165003406, 0.6224136833939046)


# time line analysis #

Average gene expression over clusters (and later metacells) to construct timeline interactions of TFs

In [9]:
adata = sc.read_h5ad("/Users/work/Desktop/subdom_processed.h5ad")
adata.obs_names = np.arange(0, len(adata)).astype(str)
labels_adata = adata.obs['clusters']

In [10]:
adata.var

Unnamed: 0,gene_ids,feature_types,genome,n_cells_by_counts,mean_counts,pct_dropout_by_counts,total_counts,highly_variable,means,dispersions,dispersions_norm,mean,std,ct_gene_corr,ct_correlates
SUB2.g1,SUB2.g1,Gene Expression,isoseq_reference,219,0.089854,91.800824,240.0,False,4.980643e-02,-4.542987,-0.922765,7.669924e-12,0.022066,0.668407,False
SUB2.g2,SUB2.g2,Gene Expression,isoseq_reference,364,0.163984,86.372145,438.0,True,8.974101e-02,-2.857633,0.520586,-2.021063e-11,0.062651,0.507098,False
SUB2.g3,SUB2.g3,Gene Expression,isoseq_reference,592,0.269188,77.836016,719.0,False,1.482865e-01,-3.145663,0.273915,-1.628408e-11,0.071940,0.673111,False
SUB2.g4,SUB2.g4,Gene Expression,isoseq_reference,258,0.107450,90.340696,287.0,False,6.187559e-02,-4.124707,-0.564547,-7.620040e-12,0.029855,-0.288837,False
SUB2.g5,SUB2.g5,Gene Expression,isoseq_reference,107,0.043804,95.994010,117.0,False,2.365992e-02,-4.478562,-0.867591,-1.315720e-12,0.015843,0.379239,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SUB2.g13140,SUB2.g13140,Gene Expression,isoseq_reference,0,0.000000,100.000000,0.0,False,1.000000e-12,,0.000000,0.000000e+00,1.000000,,False
SUB2.g13141,SUB2.g13141,Gene Expression,isoseq_reference,0,0.000000,100.000000,0.0,False,1.000000e-12,,0.000000,0.000000e+00,1.000000,,False
SUB2.g13142,SUB2.g13142,Gene Expression,isoseq_reference,0,0.000000,100.000000,0.0,False,1.000000e-12,,0.000000,0.000000e+00,1.000000,,False
SUB2.g13143,SUB2.g13143,Gene Expression,isoseq_reference,2,0.000749,99.925122,2.0,False,5.678186e-04,-4.993979,-1.308999,1.392674e-13,0.001943,0.188750,False


In [11]:
import pandas as pd
import numpy as np

# Assuming `adata` is your AnnData object and cluster labels are in `adata.obs['clusters']`
clusters = adata.obs['clusters']

# Convert the expression matrix to a DataFrame for easy manipulation
expression_df = pd.DataFrame(
    adata.X.toarray() if hasattr(adata.X, "toarray") else adata.X, 
    index=adata.obs_names, 
    columns=adata.var_names
)

# Add cluster information to the expression DataFrame
expression_df['cluster'] = clusters.values

# Group by cluster and calculate the mean for each gene
cluster_avg_expression = expression_df.groupby('cluster').mean()

# Drop the cluster column after grouping
cluster_avg_expression = cluster_avg_expression.drop(columns='cluster', errors='ignore')


  cluster_avg_expression = expression_df.groupby('cluster').mean()


In [30]:
#cluster_avg_expression.to_csv("cluster_avg_expression.csv")
cluster_avg_expression = pd.read_csv("/Users/work/Library/Mobile Documents/com~apple~CloudDocs/Desktop/ADesktop/Studium/PhD/DataMining/TF_Gene_Interaction_Prediction/src/cluster_avg_expression.csv")



two cluster approach

In [76]:
import pandas as pd

# Load your data (replace with actual file paths)
gene_expression = cluster_avg_expression.T
adj_matrix = pd.read_csv("/Users/work/Desktop/suberites_presence_absence_yitao.csv",index_col=0)
# Select only the relevant cell types (0 and 7)
expression_filtered = gene_expression[[0, 7]]
# Convert adjacency matrix to a long format (edges list)
edges = adj_matrix.stack().reset_index()
edges.columns = ["outgoing_gene", "incoming_gene", "interaction"]
edges = edges[edges["interaction"] == 1]  # Keep only existing interactions

# Merge to get expression values
edges = edges.merge(expression_filtered[0], left_on="outgoing_gene", right_index=True)
edges = edges.merge(expression_filtered[7], left_on="incoming_gene", right_index=True)

# Rename columns for clarity
edges = edges.rename(columns={0: "outgoing_gene_expr", 7: "incoming_gene_expr"})

del edges["interaction"]
# Save the final table
edges.to_csv("filtered_gene_interactions.csv", index=False)

print(edges.head())


   outgoing_gene incoming_gene  outgoing_gene_expr  incoming_gene_expr
3        SUB2.g1     SUB2.g902            0.466621            0.031868
11       SUB2.g1    SUB2.g1483            0.466621           -0.521559
13       SUB2.g1    SUB2.g1633            0.466621           -0.435638
14       SUB2.g1    SUB2.g1666            0.466621           -0.281713
16       SUB2.g1    SUB2.g1976            0.466621            0.290553


In [79]:
edges.columns

Index(['outgoing_gene', 'incoming_gene', 'outgoing_gene_expr',
       'incoming_gene_expr'],
      dtype='object')

## conditional probability ##

In [None]:
# take all cells of both of the clusters, train MLP on these, cluster the conditional probabilities;
# P(cell5 | gene 1)


#first take all the cells for the separate clusters, only using the genes from cluster 0 that appear in the adjacency matrix
#the caculate the conditional probailites and train the MLP

In [105]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define an MLP for conditional probability estimation
class ConditionalMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(ConditionalMLP, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Softmax(dim=-1)  # Outputs conditional probabilities
        )

    def forward(self, x):
        return self.network(x)

# Example usage
model = ConditionalMLP(input_dim=2, hidden_dim=32, output_dim=3)  # 3 output clusters
x = torch.randn(10, 2)  # 10 samples with 2 features each
probs = model(x)
print(probs)  # Outputs probability distribution over 3 clusters


tensor([[0.3707, 0.3358, 0.2935],
        [0.3051, 0.3007, 0.3942],
        [0.2773, 0.2813, 0.4414],
        [0.4571, 0.3270, 0.2159],
        [0.3077, 0.3373, 0.3550],
        [0.3444, 0.3422, 0.3134],
        [0.3614, 0.3331, 0.3055],
        [0.3348, 0.3431, 0.3222],
        [0.2486, 0.2454, 0.5060],
        [0.3400, 0.3414, 0.3186]], grad_fn=<SoftmaxBackward0>)
