In [None]:
# Put the parent directory first on the module search path
# (presumably so the local `simul` package imported below resolves - confirm).
import sys
sys.path.insert(0, "../")

In [None]:
import infercnvpy
import pandas as pd
import pathlib as pl
import numpy as np
import scanpy as sc

In [None]:
import simul.base.utils as utils
import simul.simulate.run as run
import simul.cnv.profiles as cnvprofile
import simul.cnv.sampling as cnvsample
import simul.patients.dataset as patut
from simul.base.config import SimCellConfig
from simul.simulate.utils import save_dataset

# Create dataset

In [None]:
# Random seed shared by every seeded step below: it is passed to the CNV
# generator, the Dataset, the probability sampling and SimCellConfig.
rs = 1

In [None]:
# Anchor genes: handed to the CNV generator and used to build the anchor alphas.
anchors = ["CD4","CD8A","CD7"]

In [None]:
# Gene annotation table, indexed by its first column. Its "chromosome",
# "start" and "end" columns are used further down.
# NOTE(review): placeholder path - point this at the real annotation file.
adatavar = pd.read_csv(
    "/path/to/data/with/chromosome/positions",
    index_col=0,
)

In [None]:
# Genome object built from the annotation table, telling it which columns
# hold the chromosome name and the gene start position.
genome = cnvprofile.Genome(
    genes_df=adatavar,
    chromosome_column="chromosome",
    start_column="start",
)

In [None]:
# Positional columns only; this frame is later passed as `var` to run.counts_to_adata.
vartosave = adatavar[["chromosome","start","end"]]

In [None]:
# Per-batch CNV generator. Alternative settings the author kept for reference:
#   chromosomes_gain=["chr1","chr4","chr10","chr14","chr20"]
#   chromosomes_loss=["chr3","chr8","chr16","chr18"]
#   dropout=0.5  (labelled "high CNV"; the active 0.8 is labelled "normal")
CNVGenerator = cnvprofile.CNVPerBatchGenerator(
    genome=genome,
    anchors=anchors,
    seed=rs,
    chromosomes_gain=[
        "chr1", "chr2", "chr4", "chr5", "chr10",
        "chr11", "chr14", "chr19", "chr20",
    ],
    chromosomes_loss=["chr3", "chr6", "chr8", "chr13", "chr16", "chr18"],
    dropout=0.8,
    dropout_child=0.8,
    p_anchor=0.2,
    min_region_length=200,
    max_region_length=300,
)

In [None]:
# Simulated cohort: 20 batches, 3 programs, with per-batch bounds on the
# number of subclones, malignant cells and healthy cells.
dataset = patut.Dataset(
    CNVGenerator=CNVGenerator,
    seed=rs,
    n_batches=20,
    n_programs=3,
    n_subclones_min=1,
    n_subclones_max=3,
    subclone_alpha=5,
    n_malignant_min=300,
    n_malignant_max=600,
    n_healthy_min=300,
    n_healthy_max=500,
)

In [None]:
# Plot the subclone CNV profiles and write the figure to disk.
# Nothing earlier in the notebook creates "figures/", so create it up front
# in case the plotting helper does not; a fresh checkout would otherwise
# risk failing here on Restart & Run All.
# NOTE(review): the file name says "highcnv" while the generator above uses the
# dropout value labelled "normal" - confirm the name is still intended.
pl.Path("figures").mkdir(parents=True, exist_ok=True)
utils.plot_subclone_profile(
    dataset=dataset,
    filename="figures/heatmap_highcnv_subclones.png",
)

In [None]:
# One summary line per patient: batch, total cell count, subclone proportions.
for patient in dataset.patients:
    print(
        patient.batch,
        patient.n_total_cells(),
        patient.subclone_proportions,
    )

In [None]:
# Alpha values per anchor gene, consumed by generate_probabilities below.
anchor_alphas = utils.generate_anchor_alphas(
    anchors=anchors,
    alpha_add=10,
    start_alpha=[5, 10, 10],
)

In [None]:
### LEAVE AS IS
# Fixed settings forwarded to cnvsample.generate_probabilities as
# `min_programs` and `prob_dropout` respectively.
MIN_PROGRAMS = 2
DROPOUT = 0

In [None]:
# Probability distribution over programs, built per batch from the anchor alphas.
distribution = cnvsample.generate_probabilities(
    seed=rs,
    batches=dataset.batches,
    program_names=dataset.programs,
    anchors_to_alphas=anchor_alphas,
    min_programs=MIN_PROGRAMS,
    prob_dropout=DROPOUT,
)

In [None]:
# Display the distribution's `_conditional_probability` attribute
# (private; shown here only as a quick sanity check).
distribution._conditional_probability

In [None]:
# Group names handed to SimCellConfig; "Macro" and "Plasma" are later used as
# the reference categories for infercnv. The per-group parameter arrays in the
# config cell have five entries, presumably aligned by position with this list.
celltypes = ["Macro","Plasma","program1","program2","program3"]

In [None]:
# Per-patient observation tables drawn from the dataset and the program
# distribution; later cells index this by patient and read `.shape[0]`.
full_obs = run.simulate_full_obs(
    dataset=dataset,
    prob_dist=distribution,
    p_drop=[0.3, 0.2],
)

In [None]:
# Total number of simulated cells, summed over every patient table.
n_cells = np.sum([full_obs[patient_id].shape[0] for patient_id in full_obs])

# Simulation settings. The five-element arrays carry one value per entry of
# `celltypes`; the batch-level settings are scalars.
config = SimCellConfig(
    random_seed=rs,
    n_genes=adatavar.shape[0],
    n_cells=n_cells,
    group_names=celltypes,
    batch_names=list(full_obs.keys()),
    batch_effect=True,
    shared_cnv=False,
    # library size
    libsize_loc=10,
    libsize_scale=0.25,
    # per-group differential expression
    p_de_list=np.array([0.2, 0.2, 0.1, 0.1, 0.1]),
    p_down_list=np.array([0.5, 0.5, 0.5, 0.5, 0.5]),
    de_location_list=np.array([0.4, 0.4, 0.25, 0.25, 0.25]),
    de_scale_list=np.array([0.5, 0.5, 0.1, 0.1, 0.1]),
    # per-batch differential expression
    pb_de_list=0.1,
    bde_location_list=0.05,
    bde_scale_list=0.1,
)

# Random generator derived from the config; passed to run.simulate_dataset.
rng = config.create_rng()

In [None]:
# Run the simulation; it returns the counts plus the group/batch DE factors
# and the gain/loss expression objects that are saved further down.
(
    counts,
    de_facs_group,
    de_facs_be,
    gain_expr_full,
    loss_expr_full,
) = run.simulate_dataset(
    config=config,
    rng=rng,
    full_obs=full_obs,
    dataset=dataset,
)

In [None]:
# Combine simulated counts, observation tables and gene positions into AnnData objects.
adatas = run.counts_to_adata(
    counts_pp=counts,
    observations=full_obs,
    var=vartosave,
)

In [None]:
# Output location: the analysis cells below read the files back from
# savedir / ds_name.
# NOTE(review): placeholder directory - set this to a real path before running.
ds_name = "morecells_2"
savedir = pl.Path("/path/to/save/dir")

# Persist the AnnData objects together with the DE factors (wrapped as
# DataFrames), the gain/loss expression objects and the config.
save_dataset(
    adatas=adatas,
    ds_name=ds_name,
    savedir=savedir,
    config=config,
    de_group=pd.DataFrame(de_facs_group),
    de_batch=pd.DataFrame(de_facs_be),
    gain_expr_full=gain_expr_full,
    loss_expr_full=loss_expr_full,
)



# Analyze simulated data

In [None]:
# Load every file in savedir / ds_name whose stem contains "patient"
# (presumably the per-patient files written by save_dataset above).
# Path.iterdir() yields entries in arbitrary order, and later cells index
# into `adatas` by position (adatas[3]); sorting the paths makes that index
# refer to the same file on every run and every machine.
adatas = []
for f in sorted((savedir / ds_name).iterdir()):
    if "patient" in f.stem:
        print(f.stem)
        adatas.append(sc.read_h5ad(f))

In [None]:
# Merge all per-patient AnnData objects into a single object.
# NOTE(review): AnnData.concatenate is deprecated in recent anndata releases in
# favour of anndata.concat; the defaults differ, so migrate deliberately and
# re-check the obs columns used below (batch, subclone) afterwards.
simadata = adatas[0].concatenate(*adatas[1:])

In [None]:
# QC metrics, then library-size normalisation (target 10,000) and log1p.
# All three calls mutate `simadata` in place, so re-running this cell would
# normalise and log-transform values that are already transformed. Current
# scanpy releases record a "log1p" entry in .uns when log1p runs; use it as
# the "already done" marker so the cell is safe to re-execute.
if "log1p" not in simadata.uns:
    sc.pp.calculate_qc_metrics(simadata, percent_top=None, log1p=True, inplace=True)

    sc.pp.normalize_total(simadata, target_sum=10000)
    sc.pp.log1p(simadata)

In [None]:
# Pick the 2000 most variable genes using malignant cells only, then restrict
# the full object (all cells) to those genes.
is_malignant = simadata.obs["malignant_key"] == "malignant"
maladata = simadata[is_malignant].copy()
sc.pp.highly_variable_genes(maladata, n_top_genes=2000)
simadata = simadata[:, maladata.var["highly_variable"]].copy()

In [None]:
# Build the neighbourhood graph, then compute the UMAP embedding (scanpy defaults).
sc.pp.neighbors(simadata)
sc.tl.umap(simadata)

In [None]:
# Label each cell "<subclone>_<batch>"; any label containing "NA" is
# collapsed to the plain string "NA".
simadata.obs["pat_subclone"] = (
    simadata.obs["subclone"].astype(str) + "_" + simadata.obs["batch"].astype(str)
).map(lambda label: "NA" if "NA" in label else label)

In [None]:
# UMAP panels coloured by batch, subclone label, malignancy, program and
# log library size; the figure is also saved under the dataset name.
sc.pl.umap(
    simadata,
    color=[
        "batch",
        "pat_subclone",
        "malignant_key",
        "program",
        "log1p_total_counts",
    ],
    ncols=2,
    wspace=0.25,
    save=f"{ds_name}.png",
)

In [None]:
# Normalise and log-transform one patient's AnnData for the infercnv cells below.
# Both calls mutate adatas[3] in place, so re-running this cell would transform
# already-transformed values. Current scanpy releases record a "log1p" entry in
# .uns when log1p runs; use it as the "already done" marker.
if "log1p" not in adatas[3].uns:
    sc.pp.normalize_total(adatas[3],target_sum=10000)
    sc.pp.log1p(adatas[3])

In [None]:
# Infer CNVs for this patient, using the "Macro" and "Plasma" cells
# (from the "program" column) as the reference.
infercnvpy.tl.infercnv(
    adatas[3],
    reference_key="program",
    reference_cat=["Macro", "Plasma"],
    window_size=200,
)

In [None]:
# Heatmap of the inferred CNV profile, with cells grouped by subclone.
infercnvpy.pl.chromosome_heatmap(
    adatas[3],
    groupby="subclone",
)

In [None]:
# Simulated CNV heatmap for one patient, for comparison with the inferred one.
# NOTE(review): "patient11" is hard-coded while the infercnv cells use adatas[3];
# confirm both refer to the same patient.
utils.plot_cnv_heatmap(
    dataset=dataset,
    patient="patient11",
    var=adatavar,
)