In [None]:
import os
import pathlib
import sys

# Make the project package importable before anything non-stdlib is loaded.
sys.path.insert(0, "../")

import anndata as ad
import infercnvpy as cnv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scipy
import scvi
import seaborn as sns

import simul.cna.api as cna
import simul.patients.api as patut
import simul.run.api as run
import simul.run.utils as plotut

In this notebook, we will generate a dataset from scratch and take a brief look at the generated data. Prerequisites are that 

(a) an original adata for simulation be selected 
(b) an scVI model be pretrained on the original adata object 

# Simulate

We load the data exactly as it was when the scVI model was pretrained. 

The dataset is a subset from Pelka et al., https://doi.org/10.1016/j.cell.2021.08.003
We use macrophages and plasma cells to simulate healthy cells, and T cells in 3 different flavors (CD4+, CD8+ and gamma delta, PLZF+ T cells ie TZBTB16) as the basis for our malignant cells

In [None]:
# Directory holding the preprocessed CRC data used for pretraining.
DATAPATH = pathlib.Path(
    "/cluster/work/boeva/scRNAdata/preprocessed/crc/2022-06-14_14-54-09/data"
)
# Only the non-malignant cells are read; they are the basis of the simulation.
adata = sc.read_h5ad(DATAPATH.joinpath("non_malignant.h5ad"))

In [None]:
# Keep only the six cell types that take part in the simulation.
kept_celltypes = ["TCD4", "TCD8", "Tgd", "TZBTB16", "Macro", "Plasma"]
is_kept = adata.obs["celltype"].isin(kept_celltypes)
adata = adata[is_kept].copy()

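First we define the subset of patients in the original data that will be used for the generative process. 
These have more than 30 cells for each program (CD8+ and gamma delta T cells are counted together) and more than 30 cells of each healthy cell type (macrophages and plasma cells).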

In [None]:
SELECTED_PATIENTS = []
min_cells = 30  # a group must contain strictly more cells than this
for sample in adata.obs.sample_id.unique():
    patadata = adata[adata.obs.sample_id == sample].copy()
    valcounts = patadata.obs.celltype.value_counts()
    # The patient must have exactly six cell types listed in the counts.
    if valcounts.shape[0] != 6:
        continue
    enough_cells = (
        valcounts["TCD4"] > min_cells
        and valcounts["Macro"] > min_cells
        and valcounts["Plasma"] > min_cells
        # CD8+ and gamma delta T cells are pooled into a single group
        and valcounts["TCD8"] + valcounts["Tgd"] > min_cells
        and valcounts["TZBTB16"] > min_cells
    )
    if enough_cells:
        SELECTED_PATIENTS.append(sample)

*Note: we remove C162 because the patient was sampled twice in different chemistries and C171 drove the latent space because of the amount of cells*

In [None]:
# Patients excluded by hand; np.setdiff1d returns the remaining ones as a sorted array.
excluded_patients = ["C162", "C171"]
SELECTED_PATIENTS = np.setdiff1d(SELECTED_PATIENTS, excluded_patients)

In [None]:
# Display the patients retained for the generative process.
SELECTED_PATIENTS

We load the pretrained scVI model

In [None]:
# Register the AnnData with scVI, using the sample of origin as batch key.
scvi.model.SCVI.setup_anndata(adata, batch_key="sample_id")
# Load the pretrained model from disk and attach it to the registered AnnData.
model = scvi.model.SCVI.load("../scvi-model/", adata=adata)

If you want to visualize the latent space of the trained model, you can use the following code

In [None]:
# Optional: uncomment the lines below to plot a UMAP of the scVI latent space.
#lr = model.get_latent_representation()

#adata.obsm["X_scvi"] = lr

#sc.pp.neighbors(adata, use_rep="X_scvi")
#sc.tl.umap(adata)

#sc.pl.umap(adata, color=["celltype","sample_id"])

We define the anchors associated with each program. In this case we took anchors that were linked to the original programs (ie CD4 of T CD4+, CD8A for T CD8+/gamma delta and CD7 for PLZF+ T ie TZBTB16)

In [None]:
# One anchor gene per malignant program: CD4 (T CD4+), CD8A (T CD8+/gamma delta), CD7 (PLZF+ T, i.e. TZBTB16).
anchors = ["CD4","CD8A","CD7"]

We create a genome object and a CNVGenerator object.

WARNING: do not list a chromosome that carries an anchor gene among the chromosomes on which a loss can occur; otherwise an anchor gain could end up in the middle of a lost region.

In [None]:
# The gene table of the original data supplies the chromosome and start columns.
genome = cna.Genome(
    genes_df=adata.var,
    chromosome_column="chromosome",
    start_column="start",
)

The different hyperparameters chosen here:
- chromosomes_gain: which chromosomes a potential gain can occur on
- chromosomes_loss: which chromosomes a potential loss can occur on
- dropout: the probability of dropping a chromosome for gain/loss in the ancestral subclone
- dropout_child: the probability of dropping a chromosome for gain/loss in the children subclones
- p_anchor: the probability of gaining an anchor at each call 
- min_region_length: the minimal length of the CNV region (not considering the cut linked to the end of the chromosome)
- max_region_length: the maximal length of the CNV region (not considering the cut linked to the end of the chromosome)
- seed: the random seed associated with the CNV generation

In [None]:
# Draw the seed once, keep it in a variable and report it: the value used to be
# drawn inline and was lost, so the generated CNV profiles could not be reproduced.
# To regenerate a given run, replace the draw with the printed value.
cnv_seed = np.random.randint(100)
print(f"CNVGenerator seed: {cnv_seed}")

CNVGenerator = cna.CNVPerBatchGenerator(
        genome=genome,
        anchors=anchors,
        # chromosomes on which a gain / a loss may be generated
        chromosomes_gain=["chr1","chr2","chr4","chr5","chr10","chr11","chr14","chr19","chr20"],
        chromosomes_loss=["chr3","chr6","chr8","chr13","chr16","chr18",],
        dropout=0.5,
        dropout_child=0.7,
        p_anchor=0.2,
        min_region_length=500,
        max_region_length=700,
        seed=cnv_seed,
)

Now we generate a dataset. We use the Dataset class provided in this framework that will automatically instantiate patients, subclones etc. For more info see the `patients` subpackage.

The hyperparameters chosen here are:
- n_batches: the number of patients to generate. BEWARE if using sampling without replacement you cannot generate more patients than you have selected patients (in our case 5)
- n_programs: number of programs to generate (here fixed at 3)
- CNVGenerator: the CNVGenerator instantiated above
- seed: random seed
- n_subclones_min: min number of subclones per patient
- n_subclones_max: max number of subclones per patient
- n_malignant_max: max number of malignant cells
- n_malignant_min: min number of malignant cells
- n_healthy_max: max number of healthy cells
- n_healthy_min: min number of healthy cells
- subclone_alpha: alpha for the dirichlet distribution for subclone proportion sampling ($\alpha_i =$subclone_alpha $\forall i$)

*Note: when generating subclones, we start from an ancestral clone and create a branching structure, see the simul.cna subpackage for more details. If the generated child is identical to the ancestral clone, the child generation is called again, which is indicated by a print message. If this goes on for too long, consider changing the child dropout probability (`dropout_child`) of the CNVGenerator: the current value may make it too unlikely for a child to differ from the ancestral clone.*

In [None]:
# Draw the seed once, keep it in a variable and report it: the value used to be
# drawn inline and was lost, so the generated patients could not be reproduced.
# To regenerate a given run, replace the draw with the printed value.
dataset_seed = np.random.randint(100)
print(f"Dataset seed: {dataset_seed}")

dataset = patut.Dataset(
    # one simulated patient per selected original patient
    n_batches=len(SELECTED_PATIENTS),
    n_programs=3,
    CNVGenerator=CNVGenerator,
    seed=dataset_seed,
    n_subclones_min=1,
    n_subclones_max=3,
    n_malignant_max=400,
    n_malignant_min=200,
    n_healthy_max=250,
    n_healthy_min=100,
    subclone_alpha=5,
)

We look at what subclone proportions will exist in the data per patient

In [None]:
# One line per simulated patient: name, total number of cells, subclone proportions.
for patient in dataset.patients:
    print(
        patient.batch,
        patient.n_total_cells(),
        patient.subclone_proportions,
    )

We plot and save the heatmap showing the generator CNV profiles for the entire dataset

In [None]:
# Heatmap of the generated CNV profiles for the whole dataset, written to the given file.
run.plot_subclone_profile(
    dataset=dataset,
    filename="figures/heatmap_highcnv_subclones.png",
)

We also provide utilities to visualize the subclone profiles per patient with chromosomal annotation.

In [None]:
# Per-patient view of the subclone profiles; the figure is only displayed (no filename given).
plotut.plot_cnv_heatmap(
    dataset=dataset,
    patient="patient3",
    adata=adata,
)

Now we generate the alphas for the dirichlet distributions associated with the anchors

In [None]:
# start_alpha has one entry per program (three here, matching the three anchors).
anchor_alphas = run.generate_anchor_alphas(
    anchors=anchors,
    alpha_add=10,
    start_alpha=[10, 10, 5],
)

In [None]:
# Display the generated Dirichlet alphas.
anchor_alphas

We create a probability distribution P(Program|anchors,batch) using the aforementioned alphas and sampling from the resulting dirichlet distributions 

*Note: here we set dropout at 0 because we will remove certain programs after the generation procedure rather than before to avoid the possibility of dropping out the most probable clone.*

In [None]:
# Forwarded as `min_programs` to cna.generate_probabilities.
MIN_PROGRAMS = 2
# Forwarded as `prob_dropout`; kept at zero because programs are removed after generation instead.
DROPOUT = 0

In [None]:
# Draw the seed once, keep it in a variable and report it: the value used to be
# drawn inline and was lost, so the sampled probabilities could not be reproduced.
# To regenerate a given run, replace the draw with the printed value.
distribution_seed = np.random.randint(100)
print(f"Probability distribution seed: {distribution_seed}")

distribution = cna.generate_probabilities(
    anchors_to_alphas=anchor_alphas,
    batches=dataset.batches,
    min_programs=MIN_PROGRAMS,
    prob_dropout=DROPOUT,
    program_names=dataset.programs,
    seed=distribution_seed,
)

In [None]:
# Inspect the sampled probabilities; note that this reads a private attribute of the object.
distribution._conditional_probability

We now start the simulation procedure. First we create the composition of the malignant compartment for all patients. We pick for each cell which program and subclone it belongs to based on the generated probability distribution.

In [None]:
# Assign every malignant cell of every patient using the distribution above.
all_malignant_obs = run.simulate_malignant_comp_batches(
    dataset=dataset,
    prob_dist=distribution,
)

We then drop programs for patients. We drop the rarest program with probability `p_1`. Iff the rarest program is dropped, we drop the second rarest program with probability `p_2`. 

In [None]:
# p_1: probability of dropping the rarest program;
# p_2: probability of also dropping the second rarest one once the rarest is gone.
all_malignant_obs, dataset = run.drop_rarest_program(
    all_malignant_obs,
    dataset,
    p_1=0.3,
    p_2=0.2,
)

We now pick the healthy compartment. Cells are either macrophages or plasma cells in our case.

In [None]:
# Draw the healthy compartment (macrophages / plasma cells) of the simulated patients.
all_healthy_obs = run.simulate_healthy_comp_batches(dataset=dataset)

We sample which original patients the simulated patients will come from. One can either choose to proceed without replacement, so that no two simulated patients share common priors, or with replacement, to be able to generate more patients in the dataset.

In [None]:
# Sampling without replacement.
sample_patients = run.sample_patient_original(
    dataset=dataset,
    selected_patients=SELECTED_PATIENTS,
)
# Alternative, with replacement:
#sample_patients = run.sample_patient_original_replacement(dataset=dataset, selected_patients=SELECTED_PATIENTS)

We add here the chemistry with which the original patient was sampled. Indeed, patients were sampled with V2 and V3 chemistry in the Pelka et al dataset, which can lead to strong batch effects in addition to patient-specific batch effects. We want to be able to quantify how well these are later learned.

In [None]:
# Annotate the malignant cells with the chemistry of the sampled original patient.
all_malignant_obs = run.add_chemistry_obs(
    all_malignant_obs=all_malignant_obs,
    sample_patients=sample_patients,
    adata=adata,
)

# Same helper for the healthy cells; its keyword is named `all_malignant_obs` either way.
all_healthy_obs = run.add_chemistry_obs(
    all_malignant_obs=all_healthy_obs,
    sample_patients=sample_patients,
    adata=adata,
)

Finally we sample from a ZINB the malignant and healthy compartments of all the patients.

In [None]:
# Simulated expression of the malignant compartment of every patient.
all_malignant_gex = run.simulate_gex_malignant(
    adata=adata,
    model=model,
    dataset=dataset,
    all_malignant_obs=all_malignant_obs,
    sample_patients=sample_patients,
)

In [None]:
# Simulated expression of the healthy compartment of every patient.
all_healthy_gex = run.simulate_gex_healthy(
    adata=adata,
    model=model,
    all_healthy_obs=all_healthy_obs,
    sample_patients=sample_patients,
)

We use this function to save a patient as a `.h5ad` object

In [None]:
def save_batch_gex(batch_name,batch_gex,gene_names,df_obs,savedir):
    """Write the expression of one patient to ``savedir / f"{batch_name}.h5ad"``.

    The matrix is converted to CSR, wrapped in a sparse DataFrame that uses the
    index of ``df_obs`` as rows and ``gene_names`` as columns, and stored in an
    AnnData object together with ``df_obs`` as ``obs``.

    Parameters
    ----------
    batch_name
        Name of the patient; used as the stem of the output file.
    batch_gex
        Expression matrix; anything ``scipy.sparse.csr_matrix`` accepts.
    gene_names
        Column labels of the expression matrix.
    df_obs
        Per-cell annotation; its index labels the rows of the matrix.
    savedir
        Output directory; must support the ``/`` operator (e.g. ``pathlib.Path``).
    """
    sparse_gex = scipy.sparse.csr_matrix(batch_gex)
    gex_frame = pd.DataFrame.sparse.from_spmatrix(
        sparse_gex,
        index=df_obs.index,
        columns=gene_names,
    )
    patient_adata = ad.AnnData(gex_frame, obs=df_obs)
    patient_adata.write(savedir / f"{batch_name}.h5ad")

For every patient, we concatenate the malignant and healthy components and save them as a .h5ad object

In [None]:
# Output directory for the simulated patients; created (with parents) if missing.
# `os` is imported in the import cell at the top of the notebook rather than here,
# so that all dependencies are declared in one place.
savedir = pathlib.Path("/cluster/work/boeva/scRNAdata/cna_simulation/raw_datasets/highcnv_subclones/")
os.makedirs(savedir, exist_ok=True)

In [None]:
gene_names = adata.var_names.tolist()

for pat in all_malignant_gex:
    # Malignant cells first, healthy cells second, for both the annotation
    # and the expression matrix so that rows stay aligned.
    df_obs = pd.concat([all_malignant_obs[pat], all_healthy_obs[pat]])
    batch_gex = np.concatenate([all_malignant_gex[pat], all_healthy_gex[pat]])

    print(f"Saving {pat}")
    save_batch_gex(
        batch_name=pat,
        batch_gex=batch_gex,
        gene_names=gene_names,
        df_obs=df_obs,
        savedir=savedir,
    )

We also save the cnv profiles for reference

In [None]:
resdir = savedir / "cnvprofiles"
os.makedirs(resdir, exist_ok=True)

# One CSV per patient, with one column per subclone profile.
for pat in dataset.patients:
    profiles = [subc.profile for subc in pat.subclones]
    profile_table = pd.concat(profiles, axis=1)
    profile_table.to_csv(resdir / f"{pat.batch}_cnv.csv")
    

# Analyse (briefly)

We load the simulated data and create one adata object for the whole dataset

In [None]:
# The dataset was generated with n_batches = len(SELECTED_PATIENTS) and the files are
# numbered from patient1, so the range has to run up to and including that count;
# the previous upper bound silently left out the last simulated patient.
n_simulated = len(SELECTED_PATIENTS)
adatas = [sc.read_h5ad(savedir / f"patient{i}.h5ad") for i in range(1, n_simulated + 1)]

# Stack all patients into one AnnData object.
simadata = adatas[0].concatenate(*adatas[1:])

We compute standard quality control metrics, compute the cell cycle score, and then we compute the UMAP representation of the dataset.

In [None]:
cc_genes = pd.read_csv("../data/cc_genes_2.csv")

# Series.ravel() is deprecated in recent pandas releases; to_numpy() returns the same array.
s_genes = cc_genes["G1/S"].str.strip().dropna().to_numpy()
g2m_genes = cc_genes["G2/M"].str.strip().dropna().to_numpy()

simadata.var['mt'] = simadata.var_names.str.startswith('MT-')  # annotate the group of mitochondrial genes as 'mt'
sc.pp.calculate_qc_metrics(simadata, qc_vars=['mt'], percent_top=None, log1p=True, inplace=True)

# QC metrics are computed above, before X is normalised and log-transformed in place.
sc.pp.normalize_total(simadata, target_sum=10000)
sc.pp.log1p(simadata)
sc.tl.score_genes_cell_cycle(simadata, s_genes=s_genes, g2m_genes=g2m_genes)

We can select only the highly variable genes in the malignant cells, as this is a common preprocessing step for batch integration algorithms

In [None]:
# Highly variable genes are computed on the malignant cells only ...
is_malignant = simadata.obs["malignant_key"] == "malignant"
maladata = simadata[is_malignant].copy()
sc.pp.highly_variable_genes(maladata, n_top_genes=2000)
# ... and the whole dataset is then restricted to those genes.
hvg_mask = maladata.var["highly_variable"]
simadata = simadata[:, hvg_mask].copy()

In [None]:
# Build the neighbourhood graph on the gene-restricted data, then embed it with UMAP.
sc.pp.neighbors(simadata)
sc.tl.umap(simadata)

In [None]:
# Label of the form "<subclone>_<batch>"; any label containing "NA" is collapsed to plain "NA".
subclone_labels = simadata.obs["subclone"].astype(str)
batch_labels = simadata.obs["batch"].astype(str)
combined_labels = subclone_labels + "_" + batch_labels
simadata.obs["pat_subclone"] = combined_labels.map(
    lambda label: "NA" if "NA" in label else label
)

In [None]:
# One UMAP panel per annotation, two panels per row; the figure is also saved.
umap_colors = [
    "malignant_key",
    "chemistry",
    "batch",
    "pat_subclone",
    "program",
    "log1p_total_counts",
    "pct_counts_mt",
    "phase",
]
sc.pl.umap(
    simadata,
    color=umap_colors,
    ncols=2,
    wspace=0.25,
    save="highcnv_subclones.png",
)

We want to briefly see how well inferCNV manages to pick up the true CNVs we simulated. We thus apply inferCNV to a patient in the set and visualize the results vs the true CNVs.

In [None]:
# adatas[3] is the fourth file that was read, i.e. patient4; it is modified in place.
infercnv_adata = adatas[3]
# Reuse the gene table of the original data (it carries the chromosome and start columns).
infercnv_adata.var = adata.var
sc.pp.normalize_total(infercnv_adata, target_sum=10000)
sc.pp.log1p(infercnv_adata)

In [None]:
# Healthy cells (macrophages and plasma cells) serve as the reference.
cnv.tl.infercnv(
    adatas[3],
    reference_key="program",
    reference_cat=["Macro", "Plasma"],
    window_size=200,
)

In [None]:
# Inferred CNV heatmap, with cells grouped by their simulated subclone.
cnv.pl.chromosome_heatmap(
    adatas[3],
    groupby="subclone",
    save="infercnv_hard_small_new_pat4.png",
)

In [None]:
# Ground-truth CNV profiles of the same patient, for comparison with the inferred ones.
plotut.plot_cnv_heatmap(
    dataset=dataset,
    patient="patient4",
    adata=adata,
    filename="figures/hard_small_new_pat4.png",
)

# Comparing simulated and original

We want to see how the simulated data looks in comparison with the original data it was simulated from.

In [None]:
# `batch` is zero-based while patient names are one-based, hence the shift by one.
patient_number = simadata.obs["batch"].astype(int) + 1
simadata.obs["sample_id"] = "patient" + patient_number.astype(str)

In [None]:
# Restrict the original data to the patients used for the simulation.
is_selected = adata.obs["sample_id"].isin(SELECTED_PATIENTS)
adata = adata[is_selected].copy()

In [None]:
# Stack original cells (batch '0') and simulated cells (batch '1') into one object.
joint = adata.concatenate(simadata)

In [None]:
# might be needed to avoid running out of memory!
# Both objects are now contained in `joint`, so the individual names can be released.
del simadata, adata

In [None]:
# we put all the information in the same columns, as the naming conventions were different
# chemistry: fall back on SINGLECELL_TYPE where it is missing
joint.obs["chemistry"] = joint.obs["chemistry"].fillna(joint.obs["SINGLECELL_TYPE"])

# program: fall back on the cell type where it is missing
program_labels = joint.obs["program"].astype(str).replace({"nan": np.nan})
joint.obs["program"] = program_labels.fillna(joint.obs["celltype"].astype(str))

# batch '0' / '1' comes from the concatenation of original and simulated data
joint.obs["cell_origin"] = joint.obs["batch"].replace({'0': "Original", '1': "Simulated"})

# total_counts: fall back on n_counts where it is missing
joint.obs["total_counts"] = joint.obs["total_counts"].fillna(joint.obs["n_counts"])

In [None]:
# NOTE(review): `simadata` already went through normalize_total + log1p before it was
# concatenated into `joint`, so the simulated cells are transformed a second time here,
# whereas the original cells are presumably still raw counts — confirm this is intended.
sc.pp.normalize_total(joint, target_sum=10000)
sc.pp.log1p(joint)

# Neighbourhood graph and UMAP on the joint (original + simulated) data.
sc.pp.neighbors(joint)
sc.tl.umap(joint)

In [None]:
# Original and simulated cells side by side, one UMAP panel per annotation.
joint_colors = [
    "chemistry",
    "program",
    "cell_origin",
    "total_counts",
    "sample_id",
]
sc.pl.umap(
    joint,
    color=joint_colors,
    ncols=2,
    wspace=0.25,
    save="juxt_original_simulated_new.png",
)