In [1]:
import numpy as np
import pandas as pd
import networkx as nx
import plotly.io as pio
import plotly.express as px
import plotly.graph_objects as go
import igviz as ig
from node2vec import Node2Vec
from gensim.models import KeyedVectors
import seaborn as sns

%load_ext autoreload
%autoreload 2


# import src.preprocess as pre
# import src.visualize as vis
# pio.renderers.default = "png"

from src import models, training
from torch_geometric.nn import GAE
import torch
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
preprocess_output_path = "./out/preprocessing_output/"
training_output_path = "./out/training_output/"
celllevel_data, genelevel_data = training.create_pyg_data(preprocess_output_path)
data = (celllevel_data, genelevel_data)

data

(Data(x=[1597, 125], edge_index=[2, 7985], y=[1597, 1597]),
 Data(x=[71865, 64], edge_index=[2, 1824841]))

In [3]:
def build_clarifyGAE_pytorch(data, hyperparams = None):
    num_cells, num_cellfeatures = data[0].x.shape[0], data[0].x.shape[1]
    num_genes, num_genefeatures = data[1].x.shape[0], data[1].x.shape[1]
    hidden_dim = hyperparams["concat_hidden_dim"] // 2
    num_genespercell = hyperparams["num_genespercell"]

    cellEncoder = models.GraphEncoder(num_cellfeatures, hidden_dim)
    geneEncoder = models.SubgraphEncoder(num_features=num_genefeatures, hidden_dim=hidden_dim, num_vertices = num_cells, num_subvertices = num_genespercell)
    
    multiviewEncoder = models.MultiviewEncoder(SubgraphEncoder = geneEncoder, GraphEncoder = cellEncoder)
    gae = GAE(multiviewEncoder)

    return gae


hyperparameters = {
    "num_genespercell": 45,
    "concat_hidden_dim": 64,
    "optimizer" : "adam",
    "criterion" : torch.nn.BCELoss(),
    "num_epochs": 400
}

trained_gae = build_clarifyGAE_pytorch(data, hyperparameters)
trained_gae.load_state_dict(torch.load(os.path.join(training_output_path,f'trained_gae_model.pth')))
trained_gae.eval()

GAE(
  (encoder): MultiviewEncoder(
    (encoder_g): SubgraphEncoder(
      (conv1): GCNConv(64, 32)
      (conv2): GCNConv(32, 32)
      (linear): Linear(in_features=1440, out_features=32, bias=True)
    )
    (encoder_c): GraphEncoder(
      (conv1): GCNConv(125, 32)
      (conv2): GCNConv(32, 32)
    )
  )
  (decoder): InnerProductDecoder()
)

In [4]:
cell_level_encoder = trained_gae.encoder.encoder_c
gene_level_encoder = trained_gae.encoder.encoder_g
z, z_c, z_g = trained_gae.encode(data[0].x,data[1].x, data[0].edge_index, data[1].edge_index)

In [5]:
z,z.shape

(tensor([[-0.1112,  0.1229, -0.2332,  ...,  0.0674,  0.0020, -0.0342],
         [-0.1296,  0.1299, -0.2572,  ...,  0.0346,  0.0223, -0.0537],
         [-0.0744,  0.1145, -0.1754,  ...,  0.0124,  0.0714, -0.0825],
         ...,
         [ 0.0970,  0.0127,  0.1533,  ...,  0.0845, -0.0419,  0.0275],
         [-0.0340, -0.0606, -0.3503,  ..., -0.0693,  0.0326, -0.0410],
         [ 0.0226, -0.0130, -0.4287,  ..., -0.1413,  0.1442, -0.1397]],
        grad_fn=<CatBackward0>),
 torch.Size([1597, 64]))

In [6]:
z_c,z_c.shape

(tensor([[-0.1112,  0.1229, -0.2332,  ...,  0.0888, -0.2396, -0.3305],
         [-0.1296,  0.1299, -0.2572,  ...,  0.0953, -0.2569, -0.3565],
         [-0.0744,  0.1145, -0.1754,  ...,  0.1117, -0.1832, -0.2226],
         ...,
         [ 0.0970,  0.0127,  0.1533,  ..., -0.1178, -0.5628,  0.2156],
         [-0.0340, -0.0606, -0.3503,  ...,  0.1191,  0.2159,  0.6880],
         [ 0.0226, -0.0130, -0.4287,  ...,  0.1142,  0.1929,  0.5404]],
        grad_fn=<AddBackward0>),
 torch.Size([1597, 32]))

In [7]:
z_g, z_g.shape

(tensor([[-0.0267, -0.0257,  0.0115,  ...,  0.0674,  0.0020, -0.0342],
         [ 0.0297, -0.0378, -0.0108,  ...,  0.0346,  0.0223, -0.0537],
         [ 0.0101, -0.0543, -0.0365,  ...,  0.0124,  0.0714, -0.0825],
         ...,
         [-0.0466,  0.0260,  0.0454,  ...,  0.0845, -0.0419,  0.0275],
         [ 0.0660, -0.0342, -0.0256,  ..., -0.0693,  0.0326, -0.0410],
         [ 0.1892, -0.1347, -0.1397,  ..., -0.1413,  0.1442, -0.1397]],
        grad_fn=<AddmmBackward0>),
 torch.Size([1597, 32]))

## Evaluate CCI Inference

## Evaluate GRN Refinement

In [19]:
import umap

numsubgraphs = 20
genespercell = 45

cellnumbers = []

gae_gene_embeddings = z_g[:genespercell*numsubgraphs, :].detach().numpy()


for i in range(numsubgraphs):
    val = f'Cell{i}'
    for j in range(genespercell):
        cellnumbers.append(val)

umap_manifold = umap.UMAP(n_neighbors=20, random_state=42).fit(gae_gene_embeddings)

df = pd.DataFrame({"UMAP1":umap_manifold.embedding_[:, 0], "UMAP2":umap_manifold.embedding_[:, 1], "Cell #":cellnumbers})
df


Unnamed: 0,UMAP1,UMAP2,Cell #
0,5.678921,9.090379,Cell0
1,2.931720,8.074990,Cell0
2,1.346900,6.735238,Cell0
3,7.117861,10.639055,Cell0
4,3.343302,10.291466,Cell0
...,...,...,...
895,7.183425,11.222568,Cell19
896,-0.579905,5.257681,Cell19
897,-0.820770,4.420348,Cell19
898,-1.018366,4.039659,Cell19


In [20]:
px.scatter(df,x="UMAP1",y="UMAP2", color="Cell #")


In [8]:
genelevel_edges = np.load(os.path.join(preprocess_output_path, "genelevel_edgelist.npy")).T
genelevel_graph = nx.from_edgelist(genelevel_edges)
genelevel_graph

<networkx.classes.graph.Graph at 0x7fc042e74760>

In [13]:
genelevel_adjmatrix = nx.to_numpy_array(genelevel_graph)
genelevel_adjmatrix.shape

(71865, 71865)

In [10]:
from scipy.linalg import block_diag

def create_intracellular_gene_mask(num_cells, num_genespercell):
  one_block = np.ones(shape=(num_genespercell,num_genespercell))
  block_list = [one_block for _ in range(num_cells)]
  return block_diag(*block_list)

intracellular_gene_mask = create_intracellular_gene_mask(1597, 45)
intracellular_gene_mask.shape

(71865, 71865)

In [12]:
genelevel_adjmatrix_masked = genelevel_adjmatrix[intracellular_gene_mask.astype(bool)]
np.save()

(3233925,)

In [14]:
genelevel_adjmatrix[intracellular_gene_mask.astype(bool)]

array([1., 1., 1., ..., 0., 0., 1.])

In [24]:
recon_grns = []
for i in range(z_g.shape[0]):
    z_gi = z_g[i].unsqueeze(1)
    recon_grns.append(z_gi @ z_gi.T)
    # 
torch.stack(recon_grns,dim=0).shape

torch.Size([1597, 32, 32])

In [15]:
np.load(os.path.join(preprocess_output_path, "initial_grns.npy")).flatten().shape

(3233925,)