In [None]:
import sys
sys.path.append("../")
import anndata
import pandas as pd
import scanpy as sc
import numpy as np
from scipy import stats
from scipy.sparse import issparse
import trvae
from trvae import pl
from trvae.models._trvae import CLDRCVAE

## Loading and preparing data

In [None]:
# Load raw counts, library-size normalise, log-transform and keep the 2000 most variable genes.
# NOTE(review): `normalize_per_cell` is deprecated in recent scanpy in favour of `normalize_total`;
# kept here to preserve the original preprocessing.
adata = sc.read("../data/haber_count.h5ad")
sc.pp.normalize_per_cell(adata)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=2000)
# `.copy()` materialises the HVG subset. Without it `adata` stays a *view* of the full matrix,
# and every later slice / in-place scanpy call on it goes through implicit view-to-copy conversions.
adata = adata[:, adata.var["highly_variable"]].copy()
# Number of distinct conditions (not referenced elsewhere in the visible notebook).
n_conditions = adata.obs["condition"].unique().shape[0]

In [None]:
# Names of the `adata.obs` columns holding the condition label and the cell-type label.
condition_key, cell_type_key = "condition", "cell_label"

In [None]:
# Hold out the Tuft cells under Hpoly.Day10 so they can later serve as unseen ground truth
# for the Control -> Hpoly.Day10 prediction (this requires training on `adata_train`).
holdout_mask = (adata.obs[cell_type_key] == "Tuft") & (adata.obs[condition_key] == "Hpoly.Day10")
# `.copy()` yields an independent AnnData, so neighbors/UMAP below write into it directly
# instead of implicitly converting a view.
adata_train = adata[~holdout_mask].copy()

In [None]:
# kNN graph + UMAP embedding of the training subset, for a visual sanity check of the input data.
sc.pp.neighbors(adata=adata_train)
sc.tl.umap(adata=adata_train)

In [None]:
# Colour the training-set UMAP by condition and by cell type, side by side.
sc.pl.umap(adata_train, color=[condition_key, cell_type_key], wspace=0.4)

## Creating the model object

In [None]:
# Distinct condition labels in the full dataset (order of first appearance); passed to the model below.
unique_conditions = adata.obs[condition_key].unique()
conditions = unique_conditions.tolist()
conditions

In [None]:
# Distinct cell-type labels in the full dataset (order of first appearance); passed to the model below.
unique_cell_types = adata.obs[cell_type_key].unique()
cell_types = unique_cell_types.tolist()
cell_types

In [None]:
# Hyper-parameters for CLDRCVAE on the Haber data. The meaning of the loss weights
# (alpha, beta, eta, contrastive_lambda) and of n_topic/topk is defined by CLDRCVAE itself,
# which is not visible in this notebook.
cldrcvae_params = dict(
    gene_size=adata.shape[1],
    architecture=[256, 64],
    n_topic=50,
    gene_names=adata.var_names.tolist(),
    conditions=conditions,
    cell_types=cell_types,
    cell_type_key=cell_type_key,
    model_path="./models/CLDRCVAE/haber/",
    dropout_rate=0.1,
    alpha=0.0001,
    beta=50,
    eta=100,
    contrastive_lambda=10.0,
    topk=5,
    loss_fn="sse",
    output_activation="relu",
)
network = CLDRCVAE(**cldrcvae_params)

### Training CLDRCVAE

In [None]:
# Train on `adata_train`, i.e. WITHOUT the held-out Tuft/Hpoly.Day10 cells. The Control ->
# Hpoly.Day10 prediction for Tuft cells below is then evaluated on cells the model never saw.
# Training on the full `adata` would leak the evaluation target into training.
network.train(
    adata_train,
    condition_key,
    train_size=0.8,
    n_epochs=300,
    batch_size=512,
    early_stop_limit=50,
    lr_reducer=20,
    verbose=5,
    save=False,
)

## Visualizing the latent space

In [None]:
# Encode every cell with the trained model and embed the latent representation with UMAP,
# coloured by condition and cell type.
latent_y = network.get_latent(adata=adata, batch_key=condition_key, return_z=True)

adata_latent = sc.AnnData(latent_y)
# `.tolist()` copies the labels positionally, ignoring index alignment between the two obs frames.
for obs_key in ("cell_label", condition_key):
    adata_latent.obs[obs_key] = adata.obs[obs_key].tolist()

sc.pp.neighbors(adata_latent)
sc.tl.umap(adata_latent)
sc.pl.umap(adata_latent, color=[condition_key, "cell_label"])

## Making predictions

In [None]:
def calc_R2(pred_adata, real_adata, n_trials=1000, frac=0.9, seed=None):
    """Subsampled R² between predicted and real per-gene mean and variance.

    In each of ``n_trials`` trials, a random ``frac`` of the cells of each object is drawn
    without replacement. The per-gene mean and variance of the two subsamples are then
    regressed against each other with ``scipy.stats.linregress``, and the squared
    correlation coefficient is recorded.

    Parameters
    ----------
    pred_adata, real_adata
        Objects exposing a cells x genes matrix as ``.X`` (dense or scipy sparse), e.g. AnnData.
        Both must have the same genes in the same order. They are NOT modified.
    n_trials : int
        Number of random subsampling trials.
    frac : float
        Fraction of cells drawn in each trial (default 0.9, as before).
    seed : int, numpy.random.Generator or None
        Seed for the subsampling RNG; pass a value for reproducible results.

    Returns
    -------
    tuple
        ``(mean R² of means, std of it, mean R² of variances, std of it)`` across trials.

    Raises
    ------
    ValueError
        If ``frac`` times the number of cells of either object rounds down to zero.
    """

    def _dense(matrix):
        # Local densified copy; never assign back to the caller's `.X` (on an AnnData view
        # that would write into the parent object).
        return np.asarray(matrix.toarray() if issparse(matrix) else matrix)

    # Densify once, outside the loop, checking each matrix's sparsity independently.
    pred_X = _dense(pred_adata.X)
    real_X = _dense(real_adata.X)

    n_pred = int(frac * pred_X.shape[0])
    n_real = int(frac * real_X.shape[0])
    if n_pred < 1 or n_real < 1:
        raise ValueError("Subsample is empty: increase `frac` or provide more cells.")

    rng = np.random.default_rng(seed)
    r2_mean = np.zeros(n_trials)
    r2_var = np.zeros(n_trials)

    for i in range(n_trials):
        pred_sub = pred_X[rng.choice(pred_X.shape[0], n_pred, replace=False)]
        real_sub = real_X[rng.choice(real_X.shape[0], n_real, replace=False)]

        r2_mean[i] = stats.linregress(pred_sub.mean(axis=0), real_sub.mean(axis=0)).rvalue ** 2
        r2_var[i] = stats.linregress(pred_sub.var(axis=0), real_sub.var(axis=0)).rvalue ** 2

    return r2_mean.mean(), r2_mean.std(), r2_var.mean(), r2_var.std()

In [None]:
is_tuft = adata.obs[cell_type_key] == "Tuft"

# Real Tuft cells under Control and Hpoly.Day10: the reference population for the plots below.
# `.copy()` makes these independent AnnData objects rather than views, so later code that
# assigns to `.X`/`.obs` on them cannot write through into `adata`.
ground_truth = adata[is_tuft & adata.obs[condition_key].isin(["Hpoly.Day10", "Control"])].copy()

# Control Tuft cells: the input that the model maps to the Hpoly.Day10 condition.
adata_source = adata[is_tuft & (adata.obs[condition_key] == "Control")].copy()

predicted_data = network.predict(
    adata=adata_source,
    condition_key=condition_key,
    target_condition="Hpoly.Day10",
)

In [None]:
# Wrap the model output as an AnnData labelled "predicted" and stack it with the real cells.
# NOTE(review): this cell previously used `predicted_data_recon`, `adata_source_recon` and
# `ground_truth_recon`, none of which are defined in this notebook (NameError on a fresh
# kernel). It now uses the objects created by the prediction cell.
if isinstance(predicted_data, anndata.AnnData):
    adata_pred = predicted_data.copy()
else:
    adata_pred = sc.AnnData(predicted_data)
adata_pred.obs[condition_key] = "predicted"
adata_pred.var_names = adata_source.var_names.tolist()

all_adata = ground_truth.concatenate(adata_pred)

In [None]:
# PCA of real Control, real Hpoly.Day10 and predicted Tuft cells, coloured by condition,
# followed by the per-condition distribution of Defa24.
sc.tl.pca(all_adata)
sc.pl.pca(all_adata, color=[condition_key])

sc.pl.violin(all_adata, keys="Defa24", groupby=condition_key)

In [None]:
# Compute the R² of per-gene mean and variance (mean and std across subsampling trials).
# The reference is ONLY the real Hpoly.Day10 Tuft cells, i.e. the prediction target.
# `ground_truth` also holds Control cells, which would dilute the comparison, and the
# previously referenced `ground_truth_recon` is not defined anywhere in this notebook.
real_target = ground_truth[ground_truth.obs[condition_key] == "Hpoly.Day10"].copy()
r2_mean, r2_mean_std, r2_var, r2_var_std = calc_R2(adata_pred, real_target)

# Print the results.
print(f"R² Mean: {r2_mean}, R² Mean Std: {r2_mean_std}")
print(f"R² Var: {r2_var}, R² Var Std: {r2_var_std}")

## Regression plots of per-gene mean and variance (predicted vs. real) for Tuft cells

In [None]:
# Differentially expressed genes in real Tuft cells, Hpoly.Day10 vs Control (Wilcoxon test).
# `top_genes` = 50 genes up-regulated in Hpoly.Day10, followed by 50 genes up-regulated in
# Control (i.e. down-regulated under Hpoly.Day10).
# A new name is used instead of overwriting the global `adata`, so earlier cells still see the
# full dataset if re-run. `.copy()` lets rank_genes_groups write `.uns` into a real AnnData.
adata_ctrl_hpoly = adata[adata.obs[condition_key].isin(["Control", "Hpoly.Day10"])]
cell_type_adata = adata_ctrl_hpoly[adata_ctrl_hpoly.obs[cell_type_key] == "Tuft"].copy()

sc.tl.rank_genes_groups(cell_type_adata, reference="Control",
                        groupby=condition_key, groups=["Hpoly.Day10"],
                        key_added="up_reg_genes", n_genes=50, method="wilcoxon")

sc.tl.rank_genes_groups(cell_type_adata, reference="Hpoly.Day10",
                        groupby=condition_key, groups=["Control"],
                        key_added="down_reg_genes", n_genes=50, method="wilcoxon")

up_genes = cell_type_adata.uns["up_reg_genes"]["names"]["Hpoly.Day10"]
down_genes = cell_type_adata.uns["down_reg_genes"]["names"]["Control"]
top_genes = up_genes.tolist() + down_genes.tolist()

In [None]:
# Quick check of the DE gene list and of the condition labels present in the stacked object.
print(f"Top genes: {top_genes}\nNumber of top genes: {len(top_genes)}")
print(all_adata.obs["condition"].unique())

In [None]:
# Per-gene mean of predicted (x) vs real Hpoly.Day10 (y) Tuft cells over the top DE genes.
# The first 5 up- and first 5 down-regulated genes are passed as `gene_list`
# (presumably annotated on the plot; see trvae.pl.reg_mean_plot).
highlighted_genes = top_genes[:5] + top_genes[50:55]
trvae.pl.reg_mean_plot(
    all_adata,
    condition_key=condition_key,
    axis_keys={"x": "predicted", "y": "Hpoly.Day10"},
    labels={"x": "", "y": ""},
    top_100_genes=top_genes,
    gene_list=highlighted_genes,
    legend=False,
    show=True,
    x_coeff=1.0,
    y_coeff=0.0,
)

In [None]:
# Per-gene variance of predicted (x) vs real Hpoly.Day10 (y) Tuft cells over the top DE genes.
# The first 5 up- and first 5 down-regulated genes are passed as `gene_list`
# (presumably annotated on the plot; see trvae.pl.reg_var_plot).
highlighted_genes = top_genes[:5] + top_genes[50:55]
trvae.pl.reg_var_plot(
    all_adata,
    condition_key=condition_key,
    axis_keys={"x": "predicted", "y": "Hpoly.Day10"},
    labels={"x": "", "y": ""},
    top_100_genes=top_genes,
    gene_list=highlighted_genes,
    legend=False,
    show=True,
    x_coeff=1.0,
    y_coeff=0.1,
)