In [None]:
import cospar as cs
import pandas as pd
import scipy
import numpy as np
import os
import scanpy as sc
from cospar.plotting import _utils as pl_util

%reload_ext autoreload
%autoreload 2

import pandas as pd
import torch
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Global CoSpar figure defaults for every plot in this notebook (PNG, 4x3.5 figure, 100 dpi).
cs.settings.set_figure_params(format="png", figsize=[4, 3.5], dpi=100, fontsize=12, pointsize=3)

### CoSpar Basics

In [None]:
# Load the Weinreb AnnData object; the path is relative to the notebook's working directory.
adata = sc.read_h5ad('./datasets/Weinreb/adata_used.h5ad')
adata

In [None]:
# Derive a clone identifier by dropping the first 6 characters of each `clones` label.
# NOTE(review): presumably strips a fixed 6-character prefix — confirm against the actual label format.
adata.obs['Clone_ID'] = [name[6:] for name in adata.obs.clones.values]

In [None]:
# Inspect the per-cell metadata, now including the Clone_ID column.
adata.obs

In [None]:
# Apply expm1 (the inverse of log1p) to the raw matrix and store it in sparse COO form.
# NOTE(review): assumes adata.raw.X holds log1p-transformed values — TODO confirm.
RNA_count_matrix = scipy.sparse.coo_matrix(np.expm1(adata.raw.X))

In [None]:
# Build the CoSpar AnnData object from the expm1-transformed matrix, reusing the
# gene/cell names, time points, manual state labels and the UMAP/PCA coordinates
# already stored on `adata`.
adata_orig = cs.pp.initialize_adata_object(
    X_state=RNA_count_matrix,
    gene_names=adata.raw.var_names,
    cell_names=adata.obs.index,
    time_info=adata.obs.Time_point.values,
    state_info=adata.obs.label_man.values,
    X_emb=adata.obsm['X_umap'],
    X_pca=adata.obsm['X_pca'],
    data_des="cospar",
)

In [None]:
# Show a summary of the freshly initialised CoSpar object.
adata_orig

In [None]:
# cs.hf.update_time_ordering(adata_orig, updated_ordering=["3", "10", "17"])

In [None]:
# Preprocessing: pick highly variable genes, drop the ones correlated with
# cell-cycle genes (threshold 0.5), then attach the clone matrix built from
# (cell name, Clone_ID) pairs taken from `adata.obs`.
cs.pp.get_highly_variable_genes(adata_orig)
cs.pp.remove_cell_cycle_correlated_genes(adata_orig, corr_threshold=0.5, confirm_change=True)
cs.pp.get_X_clone(adata_orig, adata.obs.index.values, adata.obs.Clone_ID.values)

In [None]:
# Plot the cells of a single clone (entry 1 of the clone list) on the embedding.
cs.pl.clones_on_manifold(
    adata_orig,
    selected_clone_list=[1],
    color_list=["black", "red", "blue"],
    clone_markersize=8,
)

In [None]:
# Fate labels passed as `selected_fates` to the clonal analyses below.
selected_fates = [
    "Ery", "Meg", "Eos", "Mast", "DC",
    "Mono", "Neu", "pDC", "Ly", "Baso",
]

In [None]:
# Heatmap of barcode (clone) counts across the selected fates, without log transform.
cs.pl.barcode_heatmap(
    adata_orig,
    selected_fates=selected_fates,
    color_bar=True,
    log_transform=False,
    fig_height=4
)

In [None]:
# cs.pl.barcode_heatmap(
#     adata_orig,
#     selected_fates=selected_fates,
#     color_bar=True,
#     log_transform=False,
#     binarize=True,
#     fig_height=4
# )

In [None]:
# Compute fate coupling from the clone matrix (source="X_clone"), then plot it.
cs.tl.fate_coupling(adata_orig, selected_fates=selected_fates, source="X_clone") 
cs.pl.fate_coupling(adata_orig, source="X_clone")

Strong coupling implies the existence of bi-potent or multi-potent cell states at the time of barcoding. You can visualize the fate hierarchy using a simple neighbor-joining method.

In [None]:
# Compute the fate hierarchy from the clone matrix and plot it. Results are keyed by
# the source name (see 'fate_hierarchy_X_clone', read by post_process_adata below).
cs.tl.fate_hierarchy(adata_orig, selected_fates=selected_fates, source="X_clone")
cs.pl.fate_hierarchy(adata_orig, source="X_clone")

In [None]:
# UMAP of the original object coloured by manual label, restricted to the 'Meg' group.
sc.pl.umap(adata, color="label_man", groups='Meg')

In [None]:
# Run the two-sided clonal fate-bias test once per fate, plot it, and collect the
# result tables in the same order as `selected_fates`.
df_results = []

for fate in selected_fates:
    cs.tl.clonal_fate_bias(adata_orig, selected_fate=fate, alternative="two-sided")
    cs.pl.clonal_fate_bias(adata_orig)
    # The same uns key is read on every iteration, so grab it right after each call.
    df_results.append(adata_orig.uns["clonal_fate_bias"])

In [None]:
# Write one CSV per fate. `df_results[idx]` is the table computed for `selected_fates[idx]`.
# DataFrame.to_csv does not create missing parent directories, so make sure the
# output folder exists before writing (a fresh checkout would otherwise fail here).
os.makedirs('./data', exist_ok=True)
for idx, fate in enumerate(selected_fates):
    df_results[idx].to_csv(f'./data/fate_bias_weinreb_cospar_{fate}.csv')

### Transition Map Inference

In [None]:
# Infer the transition map from multi-time clonal data on a copy of the CoSpar object,
# using the clonal time points 2.0/4.0/6.0 with 6.0 as the later time point.
adata_all_latert = adata_orig.copy()
adata_all_latert = cs.tmap.infer_Tmap_from_multitime_clones(
    adata_all_latert,
    clonal_time_points=['2.0', '4.0', '6.0'],
    later_time_point='6.0',
    smooth_array=[20, 15, 10, 5],
    max_iter_N=10,
)

In [None]:
# Show a summary of the object returned by the multi-time inference.
adata_all_latert

In [None]:
# Second transition map inferred from clonal information alone (method 'naive'),
# again on a separate copy; no later time point is fixed here.
adata_clone_latert = adata_orig.copy()
adata_clone_latert = cs.tmap.infer_Tmap_from_clonal_info_alone(
    adata_clone_latert,
    method='naive',
    clonal_time_points=['2.0', '4.0', '6.0'],
    later_time_point=None
)

In [None]:
# Show a summary of the object returned by the clonal-only inference.
adata_clone_latert

In [None]:
def post_process_adata(adata):
    """Convert the keys of the stored clonal fate-hierarchy maps to strings.

    Rebuilds ``adata.uns['fate_hierarchy_X_clone']['parent_map']`` and
    ``['node_mapping']`` so that every key is ``str(key)``; values are kept
    unchanged. (Presumably needed before saving the object, since non-string
    dictionary keys may not serialise — TODO confirm.)

    The maps are read from the ``adata`` argument itself. The previous version
    read them from the global ``adata_all_latert``, so any other object passed
    in silently received that object's hierarchy.

    Parameters
    ----------
    adata
        Object with a dict-like ``uns`` containing a 'fate_hierarchy_X_clone'
        entry that holds 'parent_map' and 'node_mapping' mappings.

    Returns
    -------
    The same ``adata`` object, modified in place.

    Raises
    ------
    KeyError
        If the fate-hierarchy entry or either map is missing from ``adata.uns``.
    """
    hierarchy = adata.uns['fate_hierarchy_X_clone']

    parent_map = {str(key): value for key, value in hierarchy['parent_map'].items()}
    node_mapping = {str(key): value for key, value in hierarchy['node_mapping'].items()}

    hierarchy['parent_map'] = parent_map
    hierarchy['node_mapping'] = node_mapping

    return adata

In [None]:
# adata_all_latert = post_process_adata(adata_all_latert)
# adata_clone_latert = post_process_adata(adata_clone_latert)

# cs.hf.save_map(adata_all_latert)
# cs.hf.save_map(adata_clone_latert)

### Others

In [None]:
# Coarse-grain the clone matrix over cell clusters for the three time points and the
# selected fates; returns the coarse matrix and the corresponding cluster list.
coarse_X_clone, mega_cluster_list = cs.tl.coarse_grain_clone_over_cell_clusters(
    adata_orig,
    selected_times=['2.0', '4.0', '6.0'],
    selected_fates=selected_fates
)

In [None]:
# Inspect the coarse-grained clone matrix.
coarse_X_clone

In [None]:
# Dimensions of the coarse-grained clone matrix.
coarse_X_clone.shape

In [None]:
# Reorder via CoSpar's private plotting helper (the spelling `hierachical` is the library's own).
# NOTE(review): the first argument is the row index range of the matrix; the exact
# return structure is defined in cospar.plotting._utils — confirm before relying on it.
coarse_X_clone_new = pl_util.custom_hierachical_ordering(np.arange(coarse_X_clone.shape[0]), coarse_X_clone)
coarse_X_clone_new

In [None]:
# Dimensions of the reordered result.
coarse_X_clone_new.shape

### Transition Map Visualization

In [None]:
# Compute backward fate maps for Neu and Mono from the inferred transition map...
cs.tl.fate_map(
    adata_all_latert,
    selected_fates=["Neu", "Mono"],
    source="transition_map",
    map_backward=True,
)

# ...and plot only the Neu map, with the target states shown and no histogram.
cs.pl.fate_map(
    adata_all_latert,
    selected_fates=["Neu"],
    source="transition_map",
    plot_target_state=True,
    show_histogram=False,
)

In [None]:
# Fate potency from the transition map (backward map, "norm-sum" method, fate_count=True), then plot.
cs.tl.fate_potency(
    adata_all_latert,
    source="transition_map",
    map_backward=True,
    method="norm-sum",
    fate_count=True,
)

cs.pl.fate_potency(adata_all_latert, source="transition_map")

In [None]:
# Fate bias between Neu and Mono from the transition map, with no pseudo count.
cs.tl.fate_bias(
    adata_all_latert,
    selected_fates=["Neu", "Mono"],
    source="transition_map",
    pseudo_count=0,
)

# Plot the bias without marking the target states.
cs.pl.fate_bias(
    adata_all_latert,
    selected_fates=["Neu", "Mono"],
    source="transition_map",
    plot_target_state=False,
)

In [None]:
# Plot the Ery fate-map column on the embedding and save the figure as 'ery.svg'.
# NOTE(review): the visible cells above call cs.tl.fate_map only for Neu and Mono —
# confirm that 'fate_map_transition_map_Ery' exists after a fresh Restart & Run All.
cs.pl.embedding(adata_all_latert, color=["fate_map_transition_map_Ery"], save='ery.svg')

In [None]:
# Identify progenitor cells for Neu vs. Mono from the backward transition map, using
# bias thresholds of 0.5 for both fates and a minimum summed fate probability of 0.2.
cs.tl.progenitor(
    adata_all_latert,
    selected_fates=["Neu", "Mono"],
    source="transition_map",
    map_backward=True,
    bias_threshold_A=0.5,
    bias_threshold_B=0.5,
    sum_fate_prob_thresh=0.2,
    avoid_target_states=True,
)

# Show the two progenitor columns side by side on the embedding.
cs.pl.embedding(adata_all_latert, color=['progenitor_transition_map_Neu', 'progenitor_transition_map_Mono'])

In [None]:
# numpy is already imported at the top of the notebook; this re-import is redundant but harmless.
import numpy as np

# Per-cell trajectory columns for the two fates.
# NOTE(review): presumably written by cs.tl.progenitor above — confirm.
cell_group_A = np.array(adata_all_latert.obs["diff_trajectory_transition_map_Neu"])
cell_group_B = np.array(adata_all_latert.obs["diff_trajectory_transition_map_Mono"])

# Differential expression between the two groups at FDR 0.05; one table per group.
dge_gene_A, dge_gene_B = cs.tl.differential_genes(
    adata_all_latert, cell_group_A=cell_group_A, cell_group_B=cell_group_B, FDR_cutoff=0.05
)

In [None]:
# Inspect the differential-expression table for group A (the Neu trajectory).
dge_gene_A

In [None]:
# Plot the first two genes of the group-A table on the embedding (figures are not saved).
selected_genes = dge_gene_A["gene"][:2]
cs.pl.gene_expression_on_manifold(
    adata_all_latert, selected_genes=selected_genes, color_bar=True, savefig=False
)

In [None]:
# First 20 genes from each differential-expression table.
gene_list = list(dge_gene_A["gene"][:20]) + list(dge_gene_B["gene"][:20])

# Fate groups for the heatmap; a nested list merges several fates into one column.
# Kept under its own name so the notebook-wide flat `selected_fates` list is not
# rebound to a different structure (re-running earlier cells would otherwise break).
heatmap_fates = [
    "Neu",
    "Mono",
    ["Baso", "Eos"],
    ["Mast", "Ery", "Meg"],
]
# Display names, one per entry of `heatmap_fates`, in the same order.
renames = ["Neu", "Mon", "Baso-Eos", "Mast-Ery-Meg"]

gene_expression_matrix = cs.pl.gene_expression_heatmap(
    adata_all_latert,
    selected_genes=gene_list,
    selected_fates=heatmap_fates,
    rename_fates=renames,
    fig_width=12,
)

In [None]:
# Expression dynamics of four marker genes for the Neu fate.
gene_name_list = ["Gata1", "Mpo", "Elane", "S100a8"]
selected_fate = "Neu"

# compute_new=True requests a fresh computation rather than reusing an earlier result.
cs.pl.gene_expression_dynamics(
    adata_all_latert,
    selected_fate,
    gene_name_list,
    traj_threshold=0.2,
    invert_PseudoTime=False,
    compute_new=True,
    gene_exp_percentile=99,
    n_neighbors=15,
    plot_raw_data=False,
)

### Benchmark with Gillespie

#### Init Gillespie Results

In [None]:
# Switch into the package directory before loading the checkpoint
# (presumably so the pickled model can resolve its own modules — TODO confirm).
# NOTE(review): absolute, user-specific paths; this cell will not run on another machine as written.
import os 
os.chdir('/ssd/users/mingzegao/clonaltrans/clonaltrans/')

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code —
# only load checkpoints from a trusted source.
path = '../trails/checkpoints/WeinrebDynamicRates/0301_112802/model_last.pt'
model = torch.load(path, map_location='cpu')
device = torch.device('cpu')

# Move to the repository root; the relative './trails/...' path used below resolves from here.
import os 
os.chdir('/ssd/users/mingzegao/clonaltrans/')

In [None]:
# Read the annotation table whose directory and file name come from the model's data-loader config.
anno = pd.read_csv(os.path.join(
    model.config['data_loader']['args']['data_dir'], 
    model.config['data_loader']['args']['annots']
))
# Keep as many population names as the size of the third axis of model.N.
cluster_names = anno['populations'].values[:model.N.shape[2]]

In [None]:
# Location of the Gillespie results, relative to the working directory set when the model was loaded.
gillespie_dir = './trails/checkpoints/WeinrebDynamicGillespie/0301_211649/models/'

#### Performance Comparison

In [None]:
from clonaltrans.pl import with_cospar

In [None]:
# Copy the clone annotations from the original object by position (.values), which
# assumes both objects list the cells in the same order — TODO confirm.
adata_all_latert.obs['meta_clones'] = adata.obs['meta_clones'].values
adata_all_latert.obs['Clone_ID'] = adata.obs['Clone_ID'].values

# Meta-clone membership next to the Mono fate map on the embedding.
cs.pl.embedding(adata_all_latert, color=["meta_clones", "fate_map_transition_map_Mono"])

In [None]:
# Cells assigned to meta-clone '0' (the label is compared as a string).
adata.obs[adata.obs['meta_clones'] == '0']

In [None]:
# Flat list of fate labels passed to the CoSpar / Gillespie comparison calls below.
selected_fates = [
    "Ery", "Meg", "Eos", "Mast", "DC",
    "Mono", "Neu", "pDC", "Ly", "Baso",
]

In [None]:
# Fate probabilities from the model and the Gillespie results directory
# (see clonaltrans.pl.get_fate_prob for the exact structure).
from clonaltrans.pl import get_fate_prob
aggre = get_fate_prob(model, cluster_names, gillespie_dir)

In [None]:
# Inspect the entry stored under the key 'clone_0'.
aggre['clone_0']

In [None]:
# Comparison with CoSpar across all selected fates; output is saved under the tag
# 'weinrebcosparfate' (see clonaltrans.pl.with_cospar_all for details).
from clonaltrans.pl import with_cospar_all
with_cospar_all(
    adata_all_latert, 
    adata, 
    model, 
    cluster_names,
    gillespie_dir,
    selected_fates,
    save='weinrebcosparfate'
)

In [None]:
# Same comparison with show_fate=False, saved under the tag 'weinrebcosparprog'.
# The import repeats the one in the previous cell and is redundant but harmless.
from clonaltrans.pl import with_cospar_all
with_cospar_all(
    adata_all_latert, 
    adata, 
    model, 
    cluster_names,
    gillespie_dir,
    selected_fates,
    show_fate=False,
    save='weinrebcosparprog'
)

In [None]:
# CoSpar vs. Gillespie comparison for 'prog_1' and the 'Mono' fate
# (argument meanings inferred from usage — confirm against clonaltrans.pl.with_cospar).
with_cospar(
    adata_all_latert, 
    adata, 
    'prog_1', 
    'Mono', 
    model, 
    cluster_names,
    gillespie_dir
)

In [None]:
# CoSpar vs. Gillespie comparison for 'prog_2' and the 'Mono' fate
# (argument meanings inferred from usage — confirm against clonaltrans.pl.with_cospar).
with_cospar(
    adata_all_latert, 
    adata, 
    'prog_2', 
    'Mono', 
    model, 
    cluster_names,
    gillespie_dir
)

In [None]:
# Attach the clone annotations by position, taken from the original object.
adata_all_latert.obs['Clone_ID'] = adata.obs['Clone_ID'].values
adata_all_latert.obs['meta_clones'] = adata.obs['meta_clones'].values

# Keep cells whose Mono fate-map value is >= 0 (this also drops NaN rows),
# then restrict them to the 'prog_2' state.
mono_defined = adata_all_latert.obs['fate_map_transition_map_Mono'] >= 0
df = adata_all_latert.obs[mono_defined]
df = df[df['state_info'] == 'prog_2']

# Mean Mono fate-map value per meta-clone, as a one-column DataFrame.
mono_mean_by_clone = df.groupby('meta_clones')['fate_map_transition_map_Mono'].mean()
cospar_bias = pd.DataFrame(mono_mean_by_clone)
cospar_bias

In [None]:
# CoSpar vs. Gillespie comparison for 'prog_3' and the 'Mono' fate
# (argument meanings inferred from usage — confirm against clonaltrans.pl.with_cospar).
with_cospar(
    adata_all_latert, 
    adata, 
    'prog_3', 
    'Mono', 
    model, 
    cluster_names,
    gillespie_dir
)

In [None]:
# CoSpar vs. Gillespie comparison for 'prog_2' and the 'Meg' fate
# (argument meanings inferred from usage — confirm against clonaltrans.pl.with_cospar).
with_cospar(
    adata_all_latert, 
    adata, 
    'prog_2', 
    'Meg', 
    model, 
    cluster_names,
    gillespie_dir
)

In [None]:
# CoSpar vs. Gillespie comparison for 'prog_Baso_Meg_Ery_Mast' and the 'Ery' fate;
# this is the only single-population call that passes `save` (tag 'eryexample').
# (Argument meanings inferred from usage — confirm against clonaltrans.pl.with_cospar.)
with_cospar(
    adata_all_latert, 
    adata, 
    'prog_Baso_Meg_Ery_Mast', 
    'Ery', 
    model, 
    cluster_names,
    gillespie_dir,
    save='eryexample'
)

In [None]:
# CoSpar vs. Gillespie comparison for 'prog_2' and the 'Baso' fate
# (argument meanings inferred from usage — confirm against clonaltrans.pl.with_cospar).
with_cospar(
    adata_all_latert, 
    adata, 
    'prog_2', 
    'Baso', 
    model, 
    cluster_names,
    gillespie_dir
)

In [None]:
# CoSpar vs. Gillespie comparison for 'prog_Baso_Eos' and the 'Baso' fate
# (argument meanings inferred from usage — confirm against clonaltrans.pl.with_cospar).
with_cospar(
    adata_all_latert, 
    adata, 
    'prog_Baso_Eos', 
    'Baso', 
    model, 
    cluster_names,
    gillespie_dir
)

In [None]:
# CoSpar vs. Gillespie comparison for 'prog_Ly_pDC' and the 'Ly' fate
# (argument meanings inferred from usage — confirm against clonaltrans.pl.with_cospar).
with_cospar(
    adata_all_latert, 
    adata, 
    'prog_Ly_pDC', 
    'Ly', 
    model, 
    cluster_names,
    gillespie_dir
)

In [None]:
# CoSpar vs. Gillespie comparison for 'prog_4' and the 'Neu' fate
# (argument meanings inferred from usage — confirm against clonaltrans.pl.with_cospar).
with_cospar(
    adata_all_latert, 
    adata, 
    'prog_4', 
    'Neu', 
    model, 
    cluster_names,
    gillespie_dir
)