In [1]:
import scanpy as sc
import gc

In [2]:
import pandas as pd
import csv
import networkx as nx

In [3]:
import torch.nn as nn
import torch 

# Model Definition for scEGA


In [None]:
class graphattention_layer(nn.Module):
    def __init__(self,input_size,numnodes,output_size,adjM):
        #WkH_{i-1} is out dimension : CurrentNodeShape x N
        self.inpshape = input_size
        self.opshape = output_size
        self.A = adjM
        super.__init__(graphattention_layer)
        self.vkt = nn.Linear(in_features=numnodes,output_size = 1)
        self.vkr = nn.Linear(in_features=numnodes,output_size = 1)
        self.W =  nn.Linear(in_features=input_size,out_features=output_size) 
    def forward(self, H_k):
        #H_k represents the previous layer's graph representation
        M_s = self.A * self.vkt(nn.ReLU(self.W(H_k))).T
        M_r = (self.A * self.vkr(nn.ReLU(self.W(H_k))).T).T
        Attention = nn.Softmax(nn.Sigmoid(M_s+M_r))
        H_new = nn.ReLU(self.W(H_k))@Attention
        return H_new

In [None]:
class encoder(nn.Module):
    def __init__(self,hidden_size,adjM):
        super.__init__(encoder)
        self.layer1 = graphattention_layer(input_size=adjM.shape[0]
                                           ,numnodes=adjM.shape(0),
                                           output_size=(adjM.shape(0))/8,adjM=adjM)
        self.layer2 = graphattention_layer(1,adjM)
        self.layer3 = graphattention_layer(1,adjM)
    def forward(self, H):
        H1 = self.layer1(H)
        H2 = self.layer2(H1)
        H3 = self.layer3(H2)
        return H3

In [None]:
class decoder(nn.Module):
    def __init__(self,hidden_size,adjM,GeneGraph):
        super.__init__(encoder)
        self.gGraph = GeneGraph
        self.layer1 = graphattention_layer(input_size=adjM.shape[0]
                                           ,numnodes=adjM.shape(0),
                                           output_size=(adjM.shape(0))/8,adjM=adjM)
        self.layer2 = graphattention_layer(1,adjM)
        self.layer3 = graphattention_layer(1,adjM)
    def forward(self, H):
        decodedPass = torch.stack(self.gGraph,H)
        H1 = self.layer1(decodedPass)
        H2 = self.layer2(H1)
        H3 = self.layer3(H2)
        return H3

In [None]:
class scdEGA(nn.Module):
    def __init__(self,hidden_size,adjM,GeneGraph):
        super.__init__(scdEGA)
        self.encoder = encoder(hidden_size,adjM)
        self.decoder = decoder(hidden_size,adjM,GeneGraph)
    def forward(self, H):
        H1 = self.encoder(H)
        H2 = self.decoder(H1)
        return H2

In [9]:
check = nn.Linear(in_features=5,out_features=1)
mat = torch.rand(size=(10,5)) # 5 nodes , each of a feature size of 10
print(check(mat))

tensor([[-0.1999],
        [-0.3973],
        [-0.8067],
        [-0.7447],
        [-0.5682],
        [-0.4899],
        [-0.5840],
        [-0.6114],
        [-0.6487],
        [-0.4664]], grad_fn=<AddmmBackward0>)


## Part 1:Data preprocessing

This will involve the following steps:
1. Obtaining a knn cell-cell graph
2. Obtaining a gene-gene graph based on protein-protein interactomes from https://string-db.org/cgi/network?taskId=bSmGYnEnGRS8&sessionId=bG2dm5uLdSGY

### Step 1 : Obtaining Cell matrix

In [4]:
adata = sc.read_text("./GSE57249_fpkm.txt")
adata = adata.T
sc.pp.pca(adata)

In [14]:
adata

AnnData object with n_obs × n_vars = 56 × 25737
    obs: 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb'
    var: 'mt', 'ribo', 'hb', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
    uns: 'pca'
    obsm: 'X_pca'
    varm: 'PCs'

In [6]:
# mitochondrial genes
adata.var["mt"] = adata.var_names.str.startswith("MT-")
# ribosomal genes
adata.var["ribo"] = adata.var_names.str.startswith(("RPS","RPL"))
# hemoglobin genes.
adata.var["hb"] = adata.var_names.str.contains(("^HB[^(P)]"))
sc.pp.calculate_qc_metrics(
    adata, qc_vars=["mt", "ribo", "hb"], inplace=True, percent_top=[20], log1p=True
)
adata

AnnData object with n_obs × n_vars = 56 × 25737
    obs: 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb'
    var: 'mt', 'ribo', 'hb', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
    uns: 'pca'
    obsm: 'X_pca'
    varm: 'PCs'

In [None]:
import seaborn as sns
p1 = sns.displot(adata.obs["total_counts"], bins=100, kde=False)
# sc.pl.violin(adata, 'total_counts')
p2 = sc.pl.violin(adata, "pct_counts_mt")
p3 = sc.pl.scatter(adata, "total_counts", "n_genes_by_counts", color="pct_counts_mt")

Theres like no mtrna lol.

In [None]:
p1 = sns.displot(adata.obs["total_counts"], bins=100, kde=False)
# sc.pl.violin(adata, 'total_counts')
p2 = sc.pl.violin(adata, "pct_counts_mt")
p3 = sc.pl.scatter(adata, "total_counts", "n_genes_by_counts", color="pct_counts_mt")

In [None]:

# Remove cells with a minimum number of expressed genes
sc.pp.filter_cells(adata, min_genes=3)

# Print the updated shape of the AnnData object
print("Updated shape:", adata.shape)

In [60]:
sc.pp.pca(adata)
sc.pp.neighbors(adata)
adj_list = adata.obsp["distances"].indices.reshape(adata.shape[0],-1)

In [61]:
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=2000)


  dispersion = np.log(dispersion)


In [65]:
sc.neighbors.neighbors(adata.T,5)

         Falling back to preprocessing with `sc.pp.pca` and default params.


Our adjacency matrix in this case is simply our knn matrix, where we store the indices("is:cells"), whose expressions are closest to the cell of that row.

In [66]:
A = torch.zeros(size=(adata.T.shape[1],adata.T.shape[1]))
for i in range(adj_list.shape[0]):
    for j in range(5):
        A[i][adj_list[i][j]] = 1

Now to obtain the gene-gene graph



In [74]:
with open("highly_variable_genes.txt", "w") as f:    
     f.write(','.join(adata.var_names[adata.var["highly_variable"]]))

In [77]:
edge_list = pd.DataFrame(csv.reader(open("string_interactions.tsv"),delimiter="\t"))
edge_list.columns = edge_list.iloc[0]
edge_list = edge_list.drop(0)

In [83]:
edge_list

Unnamed: 0,#node1,node2,node1_string_id,node2_string_id,neighborhood_on_chromosome,gene_fusion,phylogenetic_cooccurrence,homology,coexpression,experimentally_determined_interaction,database_annotated,automated_textmining,combined_score
1,0610010F05Rik,Ahsa2,10090.ENSMUSP00000044265,10090.ENSMUSP00000020529,0,0,0,0,0.091,0,0,0.392,0.423
2,1300017J02Rik,Trf,10090.ENSMUSP00000035163,10090.ENSMUSP00000035158,0,0,0.061,0.959,0.104,0,0.900,0,0.908
3,1300017J02Rik,Ica1,10090.ENSMUSP00000035163,10090.ENSMUSP00000040062,0,0,0,0,0,0,0,0.473,0.473
4,1600014C10Rik,Coasy,10090.ENSMUSP00000130271,10090.ENSMUSP00000102929,0,0,0,0,0.050,0,0,0.538,0.542
5,1600014C10Rik,BC048679,10090.ENSMUSP00000130271,10090.ENSMUSP00000120616,0,0,0,0,0,0,0,0.604,0.604
...,...,...,...,...,...,...,...,...,...,...,...,...,...
3458,Zfp933,Rpp25,10090.ENSMUSP00000101343,10090.ENSMUSP00000079358,0,0,0,0,0,0.073,0,0.388,0.408
3459,Zglp1,Figla,10090.ENSMUSP00000111157,10090.ENSMUSP00000032070,0,0,0,0,0.089,0,0,0.402,0.432
3460,Znrd1as,Rpp21,10090.ENSMUSP00000048695,10090.ENSMUSP00000025319,0,0,0,0,0.079,0,0,0.538,0.556
3461,Zswim2,Cct6b,10090.ENSMUSP00000044913,10090.ENSMUSP00000021040,0,0,0,0,0.052,0,0,0.511,0.517


In [91]:
gene_graph = nx.Graph()
gene_graph.add_edges_from(edge_list[["node1_string_id","node2_string_id"]].values)

In [None]:
from node2vec import Node2Vec
node2vec = Node2Vec(gene_graph, dimensions=64, walk_length=30, num_walks=200, workers=4)

In [None]:

for i in range(edge_list.shape[0]):
    gene_graph.add_edge(edge_list., edge_list.iloc[i,1])