In [None]:
# Imports and global plotting configuration for the Figure 5 (cross-species) notebook.
import os 
import numpy as np
import pandas as pd
import anndata
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import matplotlib
# Global tick-label size for matplotlib figures (later plotting cells raise it to 16).
matplotlib.rc('ytick', labelsize=14)
matplotlib.rc('xtick', labelsize=14)
# Save figures at 300 dpi.
# NOTE(review): a later sc.set_figure_params(...) call without dpi_save presumably
# resets this to scanpy's default — confirm saved figure resolution.
sc.set_figure_params(dpi_save=300)

# Make the local trvae package (one directory up) importable.
import sys
sys.path.append("../")
from scipy import stats  # NOTE(review): duplicates the scipy.stats import above.
from scipy.sparse import issparse
import trvae
from trvae import pl
from trvae.models._trvae import CLDRCVAE 


In [None]:
# Directory for every Figure 5 panel; scanpy's `save=` output is also routed here.
path_to_save = "../results/Figures/Figure 5/"
sc.settings.figdir = path_to_save
os.makedirs(path_to_save, exist_ok=True)

In [None]:
# Full cross-species data set (all species, all conditions).
train_file = "../data/train_species.h5ad"
train = sc.read(train_file)
train

In [None]:
# Number of cells for each (condition, species) combination.
count_keys = ['condition', 'species']
train.obs.groupby(count_keys).size()

In [None]:
# `.obs` column names and the two condition labels used throughout the notebook.
condition_key, cell_type_key = "condition", "species"
ctrl_key, stim_key = "unst", "LPS6"

## Creating the model object

In [None]:
# Training subset: every cell except stimulated (LPS6) rat cells, which are held out
# so the model has to predict them.
is_rat = train.obs[cell_type_key] == "rat"
is_stim = train.obs[condition_key] == stim_key
adata = train[~(is_rat & is_stim)]

In [None]:
# Condition and species labels present in the training subset; both lists are
# passed to the CLDRCVAE constructor below.
conditions = adata.obs[condition_key].unique().tolist()
cell_types = adata.obs[cell_type_key].unique().tolist()
# Display both lists: a bare name in the middle of a cell is not rendered,
# only the cell's last expression is.
conditions, cell_types

In [None]:
# Model trained on `adata` (all cells except the held-out rat/LPS6 cells).
network = CLDRCVAE(
    # arguments derived from the training data
    gene_size=adata.shape[1],
    gene_names=adata.var_names.tolist(),
    conditions=conditions,
    cell_types=cell_types,
    cell_type_key=cell_type_key,
    # fixed hyperparameters for this experiment
    architecture=[256, 64],
    n_topic=50,
    dropout_rate=0.3,
    output_activation='relu',
    loss_fn='sse',
    alpha=0.0001,
    beta=100,
    eta=100,
    contrastive_lambda=10.0,
    second_contrastive_lambda=10.0,
    topk=5,
    # where the model would be stored
    model_path='./models/CLDRCVAE/species/',
)

In [None]:
# Training settings for this run (the model is not saved: save=False).
# NOTE(review): early_stop_limit equals n_epochs, so early stopping presumably
# cannot trigger within this run — confirm that is intended.
train_kwargs = dict(
    train_size=0.8,
    n_epochs=50,
    batch_size=512,
    early_stop_limit=50,
    lr_reducer=20,
    verbose=5,
    save=False,
)
network.train(adata, condition_key, **train_kwargs)

## Making predictions
Predict the held-out rat LPS6 cells from the unstimulated rat cells, then evaluate how similar the predicted data is to the real data, quantifying their agreement through correlation.

In [None]:
# Per-species subsets of the full data (including the held-out rat LPS6 cells),
# each split into control and stimulated cells.
# `.copy()` makes every species subset a standalone AnnData: the ranking cells
# below write into `rat.uns`, `rabbit.uns`, ... and writing to a view forces an
# implicit view -> actual conversion (with a warning) inside those calls.
rabbit = train[train.obs[cell_type_key] == "rabbit"].copy()
rabbit_cd = rabbit[rabbit.obs[condition_key] == ctrl_key]
rabbit_stim = rabbit[rabbit.obs[condition_key] == stim_key]
pig = train[train.obs[cell_type_key] == "pig"].copy()
pig_cd = pig[pig.obs[condition_key] == ctrl_key]
pig_stim = pig[pig.obs[condition_key] == stim_key]
mouse = train[train.obs[cell_type_key] == "mouse"].copy()
mouse_cd = mouse[mouse.obs[condition_key] == ctrl_key]
mouse_stim = mouse[mouse.obs[condition_key] == stim_key]
rat = train[train.obs[cell_type_key] == "rat"].copy()
rat_cd = rat[rat.obs[condition_key] == ctrl_key]
rat_stim = rat[rat.obs[condition_key] == stim_key]

In [None]:
# Real rat cells (both conditions): the reference the prediction is compared to.
is_rat_cell = train.obs[cell_type_key] == "rat"
ground_truth = train[is_rat_cell]

# Unstimulated rat cells are passed to the model with LPS6 as the target condition.
adata_source = train[is_rat_cell & (train.obs[condition_key] == ctrl_key)]
predicted_data = network.predict(
    adata=adata_source,
    condition_key=condition_key,
    target_condition=stim_key,
)

In [None]:
# Wrap the prediction in an AnnData whose condition label is "predicted", then stack
# it under the real rat cells so both can be compared per condition.
adata_pred = sc.AnnData(predicted_data)
adata_pred.var_names = adata_source.var_names.tolist()
adata_pred.obs[condition_key] = ["predicted"] * adata_pred.n_obs
# NOTE(review): AnnData.concatenate is deprecated in newer anndata releases in
# favour of anndata.concat — check before upgrading anndata.
all_adata = ground_truth.concatenate(adata_pred)

## Mean and variance correlation plots

In [None]:
# Wilcoxon ranking of rat genes by condition; keep the 100 top-ranked LPS6 genes.
sc.tl.rank_genes_groups(rat, groupby=condition_key, method="wilcoxon", n_genes=100)
ranked_names = rat.uns['rank_genes_groups']['names']
gene_list = ranked_names[stim_key].tolist()
gene_list[:5]

In [None]:
def _top_stim_genes(species_adata, n_genes=10):
    """Return the top `n_genes` LPS6 genes of a Wilcoxon ranking by condition.

    Overwrites ``species_adata.uns['rank_genes_groups']`` with the new ranking.
    """
    sc.tl.rank_genes_groups(species_adata, groupby=condition_key, method="wilcoxon", n_genes=n_genes)
    return species_adata.uns['rank_genes_groups']['names'][stim_key].tolist()


top_10_rat_gene_list = _top_stim_genes(rat)
top_10_rabbit_gene_list = _top_stim_genes(rabbit)
top_10_mouse_gene_list = _top_stim_genes(mouse)
top_10_pig_gene_list = _top_stim_genes(pig)

In [None]:
# Alphabetised top-10 LPS6 genes per species, in the order rat, mouse, rabbit, pig.
for species_top10 in (top_10_rat_gene_list, top_10_mouse_gene_list,
                      top_10_rabbit_gene_list, top_10_pig_gene_list):
    print(sorted(species_top10))

In [None]:
# Bucket the union of the per-species top-10 LPS6 genes by how many of the four
# species share them; `gene_list_dot` lists genes shared by all four species first.
species_top_lists = [
    top_10_mouse_gene_list,
    top_10_rat_gene_list,
    top_10_rabbit_gene_list,
    top_10_pig_gene_list,
]
gene_list_4 = []
gene_list_3 = []
gene_list_2 = []
gene_list_1 = []
buckets = {4: gene_list_4, 3: gene_list_3, 2: gene_list_2, 1: gene_list_1}

all_genes = set().union(*species_top_lists)

# Iterate in sorted order: the raw iteration order of a set of strings changes
# between kernel restarts (hash randomisation), which would reshuffle the
# dot-plot columns on every run.
for gene in sorted(all_genes):
    # Every gene comes from the union of four lists, so this is always 1..4.
    n_species = sum(gene in top_list for top_list in species_top_lists)
    buckets[n_species].append(gene)

print(gene_list_4)
print(gene_list_3)
print(gene_list_2)
print(gene_list_1)
gene_list_dot = gene_list_4 + gene_list_3 + gene_list_2 + gene_list_1

In [None]:
# Gene order passed to the cross-species dot plot.
list(gene_list_dot)

In [None]:
# Per-gene mean expression of predicted vs. real rat LPS6 cells (Fig. 5b, mean).
# The axis mapping uses its own name so the `conditions` list that was passed to
# CLDRCVAE is not overwritten (re-running the model cell would otherwise get a dict).
axis_conditions = {"pred_stim": "predicted", "real_stim": "LPS6"}
matplotlib.rc('ytick', labelsize=16)
matplotlib.rc('xtick', labelsize=16)
trvae.pl.reg_mean_plot(all_adata,
                       condition_key="condition",
                       axis_keys={"x": axis_conditions["pred_stim"], "y": axis_conditions["real_stim"]},
                       gene_list=gene_list[:5],
                       top_100_genes=gene_list,
                       path_to_save=os.path.join(path_to_save, "Fig5b_study_reg_mean_all_genes.pdf"),
                       legend=False,
                       title="",
                       fontsize=15,
                       labels={"x": "predicted LPS", "y": "stimulation by LPS"},
                       show=True,
                       x_coeff=0.40,
                       y_coeff=0.85,
                       range=[0, 7, 1])

In [None]:
# Per-gene variance of predicted vs. real rat LPS6 cells (Fig. 5b, variance).
# Saved under its own "reg_var" file name: the previous name was identical to the
# mean plot's and silently overwrote that figure.
# The axis mapping uses its own name so the `conditions` list passed to CLDRCVAE
# is not overwritten.
axis_conditions = {"pred_stim": "predicted", "real_stim": "LPS6"}
matplotlib.rc('ytick', labelsize=16)
matplotlib.rc('xtick', labelsize=16)
trvae.pl.reg_var_plot(all_adata,
                      condition_key="condition",
                      axis_keys={"x": axis_conditions["pred_stim"], "y": axis_conditions["real_stim"]},
                      gene_list=gene_list[:5],
                      top_100_genes=gene_list,
                      path_to_save=os.path.join(path_to_save, "Fig5b_study_reg_var_all_genes.pdf"),
                      legend=False,
                      title="",
                      fontsize=15,
                      labels={"x": "predicted LPS", "y": "stimulation by LPS"},
                      show=True,
                      x_coeff=0.40,
                      y_coeff=0.85,
                      range=[0, 7, 1])

In [None]:
# Sanity check: cells per condition label in the real + predicted rat object.
all_adata.obs[condition_key].value_counts()

In [None]:
def _dense(matrix):
    """Return `matrix` as a dense numpy array, whether it is scipy-sparse or dense.

    Replaces `.X.A`: `.A` exists only on scipy sparse matrices (and has been
    removed in recent SciPy releases), so it breaks when `.X` is already dense.
    """
    return matrix.toarray() if issparse(matrix) else np.asarray(matrix)


# (label, cells) pairs in the row order of the combined object. Building both the
# expression matrix and the labels from this single list keeps them aligned.
stim_groups = [
    ("rat_LPS6", rat_stim),
    ("rat_LPS6_pred", adata_pred),
    ("rat_ctrl", rat_cd),
    ("rabbit_ctrl", rabbit_cd),
    ("rabbit_LPS6", rabbit_stim),
    ("pig_ctrl", pig_cd),
    ("pig_LPS6", pig_stim),
    ("mouse_ctrl", mouse_cd),
    ("mouse_LPS6", mouse_stim),
]
all_stim = sc.AnnData(np.concatenate([_dense(group.X) for _, group in stim_groups]))
all_stim.var_names = train.var_names
all_stim.obs["condition"] = [label for label, group in stim_groups for _ in range(group.n_obs)]

In [None]:
# Cross-species dot plot of the overlap-ordered LPS6 genes (Fig. 5c).
# dpi_save is passed again: calling set_figure_params without it resets the save
# resolution to scanpy's default instead of the 300 dpi chosen in the setup cell.
sc.set_figure_params(fontsize=14, dpi_save=300)
sc.pl.dotplot(all_stim, groupby="condition", var_names=gene_list_dot,
              save="_cross_species.pdf", use_raw=False)
# scanpy writes "dotplot" + the save suffix into figdir; rename it to the panel name.
# os.replace (unlike os.rename) overwrites an existing target on every platform,
# so re-running this cell does not fail on Windows.
os.replace(src=os.path.join(path_to_save, "dotplot_cross_species.pdf"),
           dst=os.path.join(path_to_save, "Fig5c_dotplot_cross_species.pdf"))

In [None]:
# Sanity check: number of cells in each species/condition group of the dot plot.
all_stim.obs[condition_key].value_counts()