# Predicting individual cell responses to pharmacologic compounds with the Shennong framework
### Visualizing the latent space of the dataset and the influence term score for each cell

In [1]:
import warnings
warnings.simplefilter(action='ignore')

In [2]:
import scanpy as sc
import torch
import scarches as sca
import pandas as pd
import numpy as np
import gdown

 captum (see https://github.com/pytorch/captum).
INFO:lightning_fabric.utilities.seed:Global seed set to 0


In [3]:
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sb
import os
from scarches.plotting.terms_scores import plot_abs_bfs_key

In [4]:
# Configure scanpy's figure defaults in a single call.
# `sc.set_figure_params` re-applies its default value for every argument that
# is not passed, so three separate calls (frameon, then dpi, then figsize)
# leave only the last one in effect and reset `frameon` and `dpi` to defaults.
sc.set_figure_params(frameon=False, dpi=300, figsize=(5, 5))

# Compact tensor printing: 3 decimals, no scientific notation, 7 edge items.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

# Matplotlib overrides are applied after the scanpy call so they are not
# overwritten by the rcParams preset that scanpy installs.
plt.rcParams["axes.grid"] = False
matplotlib.rc('ytick', labelsize=14)
matplotlib.rc('xtick', labelsize=14)

In [5]:
# Raise scanpy's logging verbosity.
# NOTE(review): per the scanpy docs level 3 additionally shows "hint" messages — confirm for the installed version.
sc.settings.verbosity = 3

## Data paths and model directories

In [6]:
# NOTE: despite the `adata_` prefix, every name in this cell is a file *path*
# (str), not a loaded AnnData object.

# Reference (training) dataset and its UMAP / metadata exports.
adata_train = "train_gmt.h5ad"
adata_train_umap = "train_gmt_umap.h5ad"
adata_train_umap_metadata = "train_gmt_umap_metadata.csv"

# Query dataset: raw input, processed output, and metadata export.
adata_query = "query.h5ad"
adata_query_output = "query_gmt.h5ad"
adata_query_metadata = "query_gmt_umap_metadata.csv"

# Joint reference + query dataset and its metadata export.
adata_train_query = "train_query_gmt.h5ad"
adata_train_query_metadata = "train_query_gmt_umap_metadata.csv"

# Destination file for the selected-term latent matrix written at the end.
latents_sub_adata_dir = "latents_sub.h5ad"

In [7]:
# Directories holding the saved models: one for the reference (train) model,
# one for the query model that is loaded later in this notebook.
model_train_dir = "train_normal"
model_query_dir = "query_normal"

In [8]:
# Output locations: general results, plus a sub-directory for per-term gene files.
output_dir = 'output/'
output_gene_dir = 'output/term_gene/'

# `makedirs(..., exist_ok=True)` creates any missing parent directories, is a
# no-op when the directory already exists, and raises if the path exists but
# is not a directory. This replaces a check-then-create `exists()`/`mkdir()`
# pair, which would silently accept a regular file at that path.
os.makedirs(output_dir, exist_ok=True)
os.makedirs(output_gene_dir, exist_ok=True)

In [9]:
# Forwarded as the `mean=` argument to every `get_latent` call in this notebook.
# NOTE(review): presumably False means latents are sampled rather than taken
# as the posterior mean, which would make results non-deterministic — confirm.
MEAN = False

## Load the expiMap model

In [10]:
# Load the processed query dataset (an .h5ad file) with the same reader that
# is used for the training dataset further down.
query = sc.read_h5ad(adata_query_output)

In [11]:
# Restore the saved expiMap model from `model_query_dir`, passing the query
# AnnData loaded above. The AnnData summary and architecture printout shown
# in this cell's output are emitted by this call.
q_intr_cvae = sca.models.EXPIMAP.load(model_query_dir, query)

AnnData object with n_obs × n_vars = 42517 × 8500
    obs: 'Cell', 'nCount_RNA', 'nFeature_RNA', 'reads', 'depth', 'percent.mt', 'Sample_Name', 'Patient', 'Cancer_Type', 'Tissue_Source', 'Cluster', 'Celltype', 'Celllineage', 'Annotation', 'sample_cluster', 'sample_celltype', 'sample_lineage', 'Tissue', 're_clusters', 're_annotation', 're_cluster_annotation', 're_clusters_raw', 'Malignangt', 're_cluster_merge', 're_cluster_merge2', 'batch'
    uns: 'terms'
    layers: 'counts'

INITIALIZING NEW NETWORK..............
Encoder Architecture:
	Input Layer in, out and cond: 8500 512 3
	Hidden Layer 1 in/out: 512 512
	Hidden Layer 2 in/out: 512 512
	Hidden Layer 3 in/out: 512 512
	Hidden Layer 4 in/out: 512 512
	Mean/Var Layer in/out: 512 17276
Decoder Architecture:
	Masked linear layer in, ext_m, ext, cond, out:  17276 0 0 3 8500
	with hard mask.
Last Decoder layer: softmax


### Load data

In [14]:
# Load the reference (training) dataset from the path configured above.
adata = sc.read_h5ad(adata_train)

In [None]:
# Stack the reference and query cells into a single AnnData.
# `batch_key='batch_join'` names the obs column that records which input each
# cell came from; `uns_merge='same'` is the merge strategy for `.uns`
# (`uns['terms']` is read from the result later in this notebook).
# NOTE(review): `AnnData.concatenate` is deprecated in recent anndata releases
# in favour of `anndata.concat` — consider migrating once versions are pinned.
query_pbmc = adata.concatenate(
    query,
    batch_key='batch_join',
    uns_merge='same',
)

In [None]:
# Encode all cells of the joint dataset with the loaded model and store the
# result in `obsm['X_cvae']`. The second argument is the per-cell
# `Tissue_Source` annotation.
# NOTE(review): `only_active=True` presumably restricts the output to the
# model's active terms, so the width may be smaller than the full latent — confirm.
query_pbmc.obsm['X_cvae'] = q_intr_cvae.get_latent(query_pbmc.X, 
                                                   query_pbmc.obs['Tissue_Source'], 
                                                   mean=MEAN, 
                                                   only_active=True)

In [None]:
# Compute latent directions for the joint dataset.
# NOTE(review): presumably this writes `query_pbmc.uns['directions']`, which
# the following cells read — confirm against the scarches documentation.
q_intr_cvae.latent_directions(adata = query_pbmc)

In [None]:
# Element-wise product of the stored latent representation with
# `uns['directions']`, saved as `obsm['X_cvae_direction']`.
# NOTE(review): `X_cvae` was computed with `only_active=True`, whereas the
# later `latents` computation multiplies the full latent by the same
# `directions`. If any term is inactive the two widths would differ and this
# product would fail to broadcast — confirm the shapes match.
query_pbmc.obsm['X_cvae_direction'] = query_pbmc.obsm['X_cvae'] * query_pbmc.uns['directions']

In [None]:
# Display the summary of the joint AnnData (obs / uns / obsm keys) as a sanity check.
query_pbmc

## Analysis of the extension nodes for reference + query dataset

## Calculate latents for single cells

In [22]:
# Latent scores for every cell of the joint dataset, multiplied element-wise
# by `uns['directions']`. Unlike `obsm['X_cvae']`, no `only_active` argument
# is passed here.
# NOTE(review): this re-runs `get_latent` rather than reusing an earlier
# result; if `mean=False` samples from the latent distribution, the values
# will differ between runs — confirm.
latents = q_intr_cvae.get_latent(
    query_pbmc.X,
    query_pbmc.obs['Tissue_Source'],
    mean=MEAN,
) * query_pbmc.uns['directions']

In [23]:
# Inspect the latent matrix dimensions (see the tuple in this cell's output).
latents.shape

(388646, 17276)

### Latents of the significantly different terms

In [1]:
# All term names known to the joint dataset, as a list so `.index` can be used.
terms = list(query_pbmc.uns['terms'])

# Previously exported table of top terms.
select_terms = pd.read_csv(output_dir + 'train_query_top_term.csv', index_col=0)

# Position of each selected term within `terms`.
# NOTE(review): `select_terms` is a DataFrame, so iterating it yields its
# *column labels*, not its rows or index. Confirm the CSV stores the term
# names as column headers; otherwise iterate `select_terms.index` or the
# relevant column instead.
idx = list(map(terms.index, select_terms))

In [26]:
# Keep only the latent columns that correspond to the selected terms (all rows).
latents_sub = latents[:, idx]

In [30]:
# Wrap the selected latent columns in an AnnData and persist it to disk.
#
# `AnnData` has no `columns` keyword (passing one raises TypeError); variable
# (column) names are supplied through the index of `var`. The labels are
# obtained by iterating `select_terms`, the same iteration that produced
# `idx`, so they line up one-to-one with the columns of `latents_sub`.
latents_sub_adata = sc.AnnData(
    latents_sub,
    var=pd.DataFrame(index=[str(term) for term in select_terms]),
)

# Name the compression filter explicitly: the documented values for
# `AnnData.write` are 'gzip' and 'lzf', not a boolean.
latents_sub_adata.write(latents_sub_adata_dir, compression='gzip')