In [None]:
# Import dependencies
import os
import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import scvelo as scv

from moscot.problems.time import TemporalProblem

import cellrank as cr
import scanpy as sc
from cellrank.kernels import RealTimeKernel

# Scanpy figure defaults and CellRank logging level.
cr.settings.verbosity = 2
sc.settings.set_figure_params(dpi=100, frameon=False)

import matplotlib.pyplot as plt

# Initialize random seeds. Both the stdlib `random` module and numpy's legacy
# global RNG are seeded, because they are independent generators.
import random
SEED = 111
random.seed(SEED)
np.random.seed(SEED)

# Silence UserWarnings (library deprecation/formatting noise).
import warnings
warnings.simplefilter("ignore", category=UserWarning)

# Print the date and time of this run, for provenance.
import datetime
e = datetime.datetime.now()
print("Current date and time = %s" % e)

# Set the working directory; every relative path below resolves against it.
# The default is machine-specific, so it can be overridden with the
# ANALYSIS_WDIR environment variable when running elsewhere.
#wdir = "/ceph/project/tendonhca/akurjan/analysis/"
wdir = os.environ.get(
    "ANALYSIS_WDIR",
    "/mnt/da8aa2c4-0136-465b-87a2-d12a59afec55/akurjan/analysis/notebooks/developmental/",
)
os.chdir(wdir)

# Folder structure (relative to `wdir`).
INPUT_FOLDERNAME = "scVI/results/"
RESULTS_FOLDERNAME = "MoscotCR/results/"
FIGURES_FOLDERNAME = "MoscotCR/figures/"

# exist_ok makes this idempotent and avoids the exists()/makedirs() race.
os.makedirs(RESULTS_FOLDERNAME, exist_ok=True)
os.makedirs(FIGURES_FOLDERNAME, exist_ok=True)

# Figures saved by scanpy and scvelo via `save=...` go to the figures folder.
sc.settings.figdir = scv.settings.figdir = FIGURES_FOLDERNAME

def savesvg(fname: str, fig, folder: str=FIGURES_FOLDERNAME) -> None:
    """Write ``fig`` to ``folder/fname`` as an SVG vector image.

    The figure layout is tightened before writing.

    Parameters
    ----------
    fname : file name (including extension) inside ``folder``.
    fig : a matplotlib figure (anything providing ``tight_layout``/``savefig``).
    folder : destination directory; defaults to ``FIGURES_FOLDERNAME``.
    """
    target = os.path.join(folder, fname)
    fig.tight_layout()
    fig.savefig(target, format='svg')

# Scanpy verbosity: errors (0), warnings (1), info (2), hints (3)
sc.settings.verbosity = 3
scv.set_figure_params('scvelo')
# Record package versions in the notebook output.
sc.logging.print_versions()

In [None]:
# Load the integrated developmental dataset.
input_path = os.path.join(INPUT_FOLDERNAME, 'dev_scgen_harmony_latent.h5ad')
adata = sc.read_h5ad(input_path)
# Gene names as plain strings, then de-duplicated.
adata.var_names = adata.var_names.astype('str')
adata.var_names_make_unique()
adata

In [None]:
# Peek at the top-left corner of the expression matrix.
print(adata.X[:5, :5])

In [None]:
# Keep only cell types connected to the main graph; the rest
# (e.g. 'Immune Cells', commented out below) are removed.
keep = [
    'ABI3BP GAS2 Fibroblasts',
    'COL3A1 PI16 Fibroblasts',
    'FGF14 THBS4 Fibroblasts',
    'COL6A6 FNDC1 Fibroblasts',
    'NEGR1 SCN7A Fibroblasts',
    'vasEndothelial Cells',
    'Smooth Myocytes',
    'SCX FGF14 THBS4 FSTL5 Progenitors',
    'RUNX2 THBS2 COL11A1 Progenitors',
    'COL6A6 FSTL1 DCLK1 Progenitors',
    'MKX TNMD ABI3BP GAS2 Progenitors',
    'SOX5 CREB5 Chondrocyte Progenitors',
    'lymEndothelial Cells',
    'Nervous System Cells',
    #'Immune Cells',
    'Chondrocytes',
    'Skeletal Myocytes',
    'Satellite Cells',
    'MSC Precursors',
    'Embryonic Chondrocytes',
]

in_kept_types = adata.obs['C_scANVI'].isin(keep)
adata = adata[in_kept_types].copy()
adata

In [None]:
# Re-display the subset AnnData to check its size.
adata

In [None]:
# Drop genes with fewer than 20 counts, then cells with fewer than 200 detected genes.
# NOTE(review): both thresholds are applied to adata.X — confirm X still holds counts here.
sc.pp.filter_genes(adata, min_counts=20)
sc.pp.filter_cells(adata, min_genes=200)

# Moscot

In [None]:
# Cell-type annotation on the ForceAtlas2 layout.
sc.pl.embedding(adata, color='C_scANVI', basis='draw_graph_fa')

In [None]:
# moscot temporal problem over the filtered cells (time key is set later in prepare()).
tp = TemporalProblem(adata)

In [None]:
# Human proliferation/apoptosis gene-set scores, used by moscot to inform the marginals.
tp = tp.score_genes_for_marginals(gene_set_proliferation="human",
                                  gene_set_apoptosis="human")

In [None]:
# Proliferation/apoptosis scores (colour scale centred at 0, capped at the 99th percentile).
sc.pl.embedding(
    adata,
    basis="draw_graph_fa",
    color=["proliferation", "apoptosis"],
    vmax="p99",
    vcenter=0,
    s=10,
    save='proliferation_apoptosis_scores.svg',
)

In [None]:
# Numeric time point per cell, stored as a categorical; used as time_key in prepare().
age_as_float = adata.obs["ageint"].astype(float)
adata.obs["agefloat"] = age_as_float.astype("category")
adata.obs["agefloat"]

In [None]:
# Original age labels, for comparison with agefloat.
adata.obs["age"]

In [None]:
# Top-left corner of X after filtering.
print(adata.X[:10, :10])

In [None]:
# Same corner of the .raw matrix, for comparison with X.
print(adata.raw.X[:10, :10])

In [None]:
# Fixed 20-colour palette kept for re-colouring categories
# (not referenced elsewhere in the visible cells).
original_palette = [
    '#1f77b4', '#ff7f0e', '#279e68', '#d62728', '#aa40fc',
    '#8c564b', '#e377c2', '#b5bd61', '#17becf', '#aec7e8',
    '#ffbb78', '#98df8a', '#ff9896', '#c5b0d5', '#c49c94',
    '#f7b6d2', '#dbdb8d', '#9edae5', '#ad494a', '#8c6d31',
]

In [None]:
# Set up OT subproblems over the agefloat time points, using the
# batch-corrected latent space as the joint representation for the cost.
tp = tp.prepare(
    time_key="agefloat",
    joint_attr='corrected_latent',
)

In [None]:
# Entropic OT with slightly unbalanced marginals (tau_a/tau_b < 1),
# solved in mini-batches on the GPU.
tp = tp.solve(
    epsilon=1e-3,
    tau_a=0.99,
    tau_b=0.999,
    scale_cost="mean",
    batch_size=1200,
    device='gpu',
)

In [None]:
# Persist the solved problem so it can be reloaded without re-solving.
tp_path = os.path.join(RESULTS_FOLDERNAME, 'moscottp_correctedlatent')
tp.save(tp_path, overwrite=True)

In [None]:
# Inspect the AnnData after solving and saving the temporal problem.
adata

In [None]:
#tp = tp.load(os.path.join(RESULTS_FOLDERNAME, 'moscottp_correctedlatent'))
#tp

In [None]:
import moscot.plotting as mtp

In [None]:
# Cell-type flows from the first (6.5) to the last (20) time point;
# `threshold` filters out weak transitions.
tp.sankey(source=6.5, target=20,
          source_groups="C_scANVI", target_groups="C_scANVI",
          threshold=0.2)

In [None]:
sankey_path = os.path.join(FIGURES_FOLDERNAME, 'sankey_corrlatent.svg')
mtp.sankey(tp, figsize=(40, 8), dpi=100, save=sankey_path)

In [None]:
# Posterior growth rates estimated by the solved OT problem, shown on the FA layout.
adata.obs["posterior_growth_rates"] = tp.posterior_growth_rates
sc.pl.embedding(
    adata,
    basis="draw_graph_fa",
    color=["posterior_growth_rates"],
    vmax="p99",
    cmap='viridis',
    s=10,
    save='posterior_growth.svg',
)

## Checking individual cell costs (the higher the cost, the less likely a cell is to have an ancestor or descendant)

The costs look acceptable, so there is no need to remove any cells and rerun the analysis.

In [None]:
# Copy per-cell source/target costs from the problem into obs for plotting.
for cost_attr in ("cell_costs_source", "cell_costs_target"):
    adata.obs[cost_attr] = getattr(tp, cost_attr)

In [None]:
cost_columns = ["cell_costs_source", "cell_costs_target"]
sc.pl.embedding(adata, basis="draw_graph_fa", color=cost_columns,
                cmap='viridis', s=20, save='source_target.svg')

## Visualising likely ancestor cells:

In [None]:
# Time points (weeks) in the order they were linked by tp.prepare();
# assumed to match the agefloat categories — each adjacent pair is one subproblem.
keys = [6.5, 7.2, 8.4, 9.0, 9.3, 12.0, 17.0, 20.0]

# Backward direction: likely ancestor cell types at t1 for the cell types at t2.
for t1, t2 in zip(keys[:-1], keys[1:]):
    dict_key = f'{t1}_{t2}'
    # Dedicated key so the forward (descendant) transitions do not overwrite these.
    ancestors_key = f"ancestors_{t1}_{t2}"
    tp.cell_transition(t1, t2, "C_scANVI", "C_scANVI", forward=False,
                       key_added=ancestors_key)
    # moscot uses `save` as a literal path, so prefix the figures folder explicitly
    # (as done for the sankey plot); otherwise the SVG lands in the working directory.
    mtp.cell_transition(tp, fontsize=8, figsize=(5, 5), return_fig=True, key=ancestors_key,
                        save=os.path.join(FIGURES_FOLDERNAME, f'{dict_key}_ancestors.svg'))

## Visualising likely descendant cells:

In [None]:
# Forward direction: likely descendant cell types at t2 for the cell types at t1.
for t1, t2 in zip(keys[:-1], keys[1:]):
    dict_key = f'{t1}_{t2}'
    transition_key = f"transitions_{t1}_{t2}"
    tp.cell_transition(t1, t2, "C_scANVI", "C_scANVI", forward=True,
                       key_added=transition_key)
    # moscot uses `save` as a literal path, so prefix the figures folder explicitly
    # (as done for the sankey plot); otherwise the SVG lands in the working directory.
    mtp.cell_transition(tp, fontsize=8, figsize=(5, 5), return_fig=True, key=transition_key,
                        save=os.path.join(FIGURES_FOLDERNAME, f'{dict_key}_descendants.svg'))

In [None]:
#adata.obs.drop(columns=['cell_costs_source', 
#                        'cell_costs_target'], inplace=True)

In [None]:
#adata.write(os.path.join(RESULTS_FOLDERNAME, 'moscot_cellrank.h5ad'))

# CellRank

In [None]:
# Wrap the solved moscot temporal problem in a CellRank real-time kernel.
tmk = RealTimeKernel.from_moscot(tp)

**Higher `conn_weight`** increases the relative importance of the kNN connectivities when computing transitions within the same block (here, the same developmental time point). This is useful when transitions within a time point are biologically justified, or when the dynamics of each stage are relatively self-contained.

**Lower `conn_weight`** reduces the emphasis on within-time-point connectivities, which can leave relatively more probability mass for transitions between time points. This is more appropriate when the stages form a developmental continuum in which transitions between stages matter as much as transitions within them.

In [None]:
# conn_weight: weight of the within-time-point (kNN connectivity) transitions —
# see the markdown note above for how it was chosen.
# Also tested but not used: conn_kwargs={"n_neighbors": 80, "metric": "correlation"}
tmk.compute_transition_matrix(
    threshold="auto_local",
    self_transitions="all",
    conn_weight=0.2,
)

In [None]:
# Random walks on the transition matrix, started from cells at agefloat == 6.5.
tmk.plot_random_walks(
    start_ixs={"agefloat": 6.5},
    max_iter=1000,
    seed=0,
    basis="draw_graph_fa",
    size=30,
    dpi=150,
)

In [None]:
# Outgoing flow from MSC Precursors over time; flows below min_flow are not drawn.
ax = tmk.plot_single_flow(
    cluster_key="C_scANVI",
    time_key="agefloat",
    cluster="MSC Precursors",
    min_flow=0.20,
    xticks_step_size=1,
    show=False,
)
# Rotate existing tick labels in place (set_xticklabels would freeze them with a FixedFormatter).
ax.tick_params(axis='x', labelrotation=90)
ax.legend(loc='best')
fig = ax.get_figure()
fig.set_size_inches(10, 5)
# Save next to the other figures rather than in the working directory.
fig.savefig(os.path.join(FIGURES_FOLDERNAME, 'msc_prec_single_flow.svg'), dpi=300)

In [None]:
# Outgoing flow from SCX FGF14 THBS4 FSTL5 Progenitors; flows below min_flow are not drawn.
ax = tmk.plot_single_flow(
    cluster_key="C_scANVI",
    time_key="agefloat",
    cluster='SCX FGF14 THBS4 FSTL5 Progenitors',
    min_flow=0.37,
    xticks_step_size=1,
    show=False,
)
# Rotate existing tick labels in place (set_xticklabels would freeze them with a FixedFormatter).
ax.tick_params(axis='x', labelrotation=90)
ax.legend(loc='best')
fig = ax.get_figure()
fig.set_size_inches(10, 5)
# Save next to the other figures rather than in the working directory.
fig.savefig(os.path.join(FIGURES_FOLDERNAME, 'scx_prog_single_flow.svg'), dpi=300)

In [None]:
# Outgoing flow from MKX TNMD ABI3BP GAS2 Progenitors; flows below min_flow are not drawn.
ax = tmk.plot_single_flow(
    cluster_key="C_scANVI",
    time_key="agefloat",
    cluster='MKX TNMD ABI3BP GAS2 Progenitors',
    min_flow=0.25,
    xticks_step_size=1,
    show=False,
)
# Rotate existing tick labels in place (set_xticklabels would freeze them with a FixedFormatter).
ax.tick_params(axis='x', labelrotation=90)
ax.legend(loc='best')
fig = ax.get_figure()
fig.set_size_inches(10, 5)
# Save next to the other figures rather than in the working directory.
fig.savefig(os.path.join(FIGURES_FOLDERNAME, 'mkx_prog_single_flow.svg'), dpi=300)

In [None]:
# Outgoing flow from COL6A6 FSTL1 DCLK1 Progenitors; flows below min_flow are not drawn.
ax = tmk.plot_single_flow(
    cluster_key="C_scANVI",
    time_key="agefloat",
    cluster='COL6A6 FSTL1 DCLK1 Progenitors',
    min_flow=0.2,
    xticks_step_size=1,
    show=False,
)
# Rotate existing tick labels in place (set_xticklabels would freeze them with a FixedFormatter).
ax.tick_params(axis='x', labelrotation=90)
ax.legend(loc='right')
fig = ax.get_figure()
fig.set_size_inches(10, 5)
# Save next to the other figures rather than in the working directory.
fig.savefig(os.path.join(FIGURES_FOLDERNAME, 'col6_prog_single_flow.svg'), dpi=300)

In [None]:
# Outgoing flow from RUNX2 THBS2 COL11A1 Progenitors; flows below min_flow are not drawn.
ax = tmk.plot_single_flow(
    cluster_key="C_scANVI",
    time_key="agefloat",
    cluster='RUNX2 THBS2 COL11A1 Progenitors',
    min_flow=0.25,
    xticks_step_size=1,
    show=False,
)
# Rotate existing tick labels in place (set_xticklabels would freeze them with a FixedFormatter).
ax.tick_params(axis='x', labelrotation=90)
ax.legend(loc='best')
fig = ax.get_figure()
fig.set_size_inches(10, 5)
# Save next to the other figures rather than in the working directory.
fig.savefig(os.path.join(FIGURES_FOLDERNAME, 'runx2_prog_single_flow.svg'), dpi=300)

In [None]:
# Outgoing flow from SOX5 CREB5 Chondrocyte Progenitors; flows below min_flow are not drawn.
ax = tmk.plot_single_flow(
    cluster_key="C_scANVI",
    time_key="agefloat",
    cluster='SOX5 CREB5 Chondrocyte Progenitors',
    min_flow=0.2,
    xticks_step_size=1,
    show=False,
)
# Rotate existing tick labels in place (set_xticklabels would freeze them with a FixedFormatter).
ax.tick_params(axis='x', labelrotation=90)
ax.legend(loc='right')
fig = ax.get_figure()
fig.set_size_inches(10, 5)
# Save next to the other figures rather than in the working directory.
fig.savefig(os.path.join(FIGURES_FOLDERNAME, 'sox5_prog_single_flow.svg'), dpi=300)

In [None]:
# Outgoing flow from Embryonic Chondrocytes; flows below min_flow are not drawn.
ax = tmk.plot_single_flow(
    cluster_key="C_scANVI",
    time_key="agefloat",
    cluster='Embryonic Chondrocytes',
    min_flow=0.4,
    xticks_step_size=1,
    show=False,
)
# Rotate existing tick labels in place (set_xticklabels would freeze them with a FixedFormatter).
ax.tick_params(axis='x', labelrotation=90)
ax.legend(loc='right')
fig = ax.get_figure()
fig.set_size_inches(10, 5)
# Save next to the other figures rather than in the working directory.
fig.savefig(os.path.join(FIGURES_FOLDERNAME, 'chondro_prog_single_flow.svg'), dpi=300)

## Identifying Probable Terminal and Initial States

In [None]:
# GPCCA estimator on the real-time kernel.
g2 = cr.estimators.GPCCA(tmk)
# Schur decomposition targeting largest-magnitude eigenvalues ('LM').
g2.compute_schur(which='LM')
# Real part of the spectrum, used to choose the number of macrostates.
g2.plot_spectrum(real_only=True)

In [None]:
# 11 macrostates, named after the dominant C_scANVI cell type.
g2.compute_macrostates(cluster_key="C_scANVI", n_states=11)
g2.plot_macrostates(
    which="all",
    basis='draw_graph_fa',
    legend_loc="right",
    s=100,
    save='devfibros_11macrostates.svg',
)

In [None]:
# Age composition of each macrostate.
g2.plot_macrostate_composition(figsize=(7, 4), key="age")

In [None]:
# Coarse-grained transition matrix between macrostates, with values annotated.
g2.plot_coarse_T(annotate=True)

In [None]:
# Continuous membership of every macrostate, one panel each.
g2.plot_macrostates(
    which="all",
    discrete=False,
    same_plot=False,
    basis='draw_graph_fa',
    legend_loc="right",
    ncols=3,
    figsize=(9, 9),
    save='devfibros_macrostates_separate2.svg',
)

In [None]:
#g2.predict_terminal_states(method='top_n', n_states=9)
#g2.plot_macrostates(which="terminal", legend_loc="right", s=100, basis='draw_graph_fa',
#                    figsize=(5,5), save='terminal2.svg')

In [None]:
# Terminal states chosen manually from the computed macrostates.
terminal_states = [
    'FGF14 THBS4 Fibroblasts',
    'COL3A1 PI16 Fibroblasts_2',
    'Smooth Myocytes',
    'vasEndothelial Cells',
    'Nervous System Cells',
    'lymEndothelial Cells',
    'Skeletal Myocytes',
    'ABI3BP GAS2 Fibroblasts',
    'Chondrocytes',
]
g2.set_terminal_states(states=terminal_states)

g2.plot_macrostates(
    which="terminal",
    basis='draw_graph_fa',
    legend_loc="right",
    s=100,
    save='manual_terminal2.svg',
)

In [None]:
# Terminal-state labels stored in adata.obs (presumably written by set_terminal_states above).
adata.obs['term_states_fwd'].cat.categories

In [None]:
# Continuous membership of each terminal state, one panel each.
g2.plot_macrostates(
    which="terminal",
    discrete=False,
    same_plot=False,
    basis='draw_graph_fa',
    legend_loc="right",
    ncols=3,
    figsize=(9, 9),
    cmap='Greens',
    save='devfibros_terminal_separate2.svg',
)

In [None]:
# Highlight all MSC Precursors (value 1) against a faint background (0.2).
is_msc = adata.obs['C_scANVI'] == 'MSC Precursors'
selected_cells = adata[is_msc]
ind = selected_cells.obs_names
adata.obs['selected'] = 0.2
# One vectorised label-based write instead of a per-cell .loc loop.
adata.obs.loc[ind, 'selected'] = 1
sc.pl.embedding(adata, basis='draw_graph_fa',
                color='selected',
                cmap='Greys', vmin=0, s=50)

In [None]:
# Restrict the highlight to MSC Precursors from the 6.5w and 7.2w samples;
# `selected` (list of obs names) is reused below as the initial-state cell set.
adata.obs['selected'] = 0.2
early_ages = ['6.5w', '7.2w']
# Same order as before: all 6.5w cells first, then all 7.2w cells.
selected = [
    name
    for age in early_ages
    for name in selected_cells[selected_cells.obs['age'] == age].obs_names
]
# One vectorised label-based write instead of a per-cell .loc loop.
adata.obs.loc[selected, 'selected'] = 1
sc.pl.embedding(adata, basis='draw_graph_fa',
                color='selected',
                cmap='Greys', vmin=0, s=50)

In [None]:
# Alternative multi-population initial states (not used; kept for reference):
#initial_dict = {
#    'MSC Precursors': selected_cells[selected_cells.obs['age'] == '6.5w'].obs_names,
#    'SOX9 SCX Progenitors': selected_cells[selected_cells.obs['C_scANVI_original'] == 'SOX9 SCX Progenitors'].obs_names,
#    'NTNG1 COL6A6 Progenitors': selected_cells[selected_cells.obs['C_scANVI_original'] == 'NTNG1 COL6A6 Progenitors'].obs_names,
#}

# Initial state: the early MSC Precursors selected in the previous cell.
initial_states = {'MSC Precursors': selected}
g2.set_initial_states(initial_states)

In [None]:
g2.plot_macrostates(
    which="initial",
    basis='draw_graph_fa',
    legend_loc="right",
    s=100,
    figsize=(5, 5),
    save='devfibros_initial_fa.svg',
)

In [None]:
g2.compute_fate_probabilities()
# One panel per terminal state on the FA layout.
g2.plot_fate_probabilities(
    basis='draw_graph_fa',
    same_plot=False,
    ncols=3,
    figsize=(9, 9),
    cmap='Purples',
    save='fateprobs_fa.svg',
)

In [None]:
# Circular projection of fate probabilities, coloured by age label and cell type.
cr.pl.circular_projection(
    adata,
    keys=["age", "C_scANVI"],
    legend_loc="right",
    s=10,
    wspace=0.5,
    figsize=(30, 15),
)

In [None]:
# Same projection with numeric age (colour range 6.5–20) instead of the age label.
cr.pl.circular_projection(
    adata,
    keys=["ageint", "C_scANVI"],
    legend_loc="right",
    s=10,
    wspace=0.5,
    figsize=(30, 15),
    vmin=6.5,
    vmax=20,
    save='circular.svg',
)

In [None]:
# Lineage driver genes computed from adata.X (not .raw), written to CSV.
lin_drivers2 = g2.compute_lineage_drivers(use_raw=False)
drivers_csv = os.path.join(RESULTS_FOLDERNAME, 'devfibros_exptime_lineagedrivers.csv')
lin_drivers2.to_csv(drivers_csv)

In [None]:
# Top-8 driver genes for each terminal state, one figure per state.
for state in g2.terminal_states.cat.categories:
    print(state)
    g2.plot_lineage_drivers(
        state,
        n_genes=8,
        basis='draw_graph_fa',
        perc=[2, 98],
        use_raw=False,
        cmap='Blues',
        figsize=(15, 10),
        save=f'dev_exptime_lindrivers{state}.svg',
    )
    plt.show()

In [None]:
# Names of the computed macrostates.
g2.macrostates.cat.categories

In [None]:
# Cell-type categories remaining after filtering.
adata.obs['C_scANVI'].cat.categories

In [None]:
# kNN graph on the X_diff representation, then PAGA over cell types.
# NOTE(review): this replaces the default neighbors graph stored in adata —
# confirm nothing downstream still needs the previous one.
sc.pp.neighbors(adata, use_rep="X_diff", n_neighbors=20)
sc.tl.paga(adata, groups="C_scANVI")

In [None]:
# Mean fate probabilities per cell type, as bar plots (one panel per cell type).
fate_lineages = g2.terminal_states.cat.categories
cell_types = adata.obs['C_scANVI'].cat.categories
cr.pl.aggregate_fate_probabilities(
    adata, mode="bar", lineages=fate_lineages, cluster_key="C_scANVI",
    clusters=cell_types, ncols=6, figsize=(20, 20), save='barfateprobs.svg',
)

In [None]:
# PAGA graph on the FA layout with fate-probability pie charts per node.
paga_style = dict(
    node_size_scale=4,
    edge_width_scale=2,
    max_edge_width=3,
    figsize=(9, 9),
    threshold=0.2,
    title="PAGA",
)
cr.pl.aggregate_fate_probabilities(
    adata,
    mode="paga_pie",
    cluster_key="C_scANVI",
    backward=False,
    basis="draw_graph_fa",
    legend_kwargs={"loc": "top right out"},
    legend_loc="top left out",
    save='alllineages_diffPAGA_fa02.svg',
    **paga_style,
)

In [None]:
# Cell type x lineage heatmap of aggregated fate probabilities.
heatmap_lineages = g2.terminal_states.cat.categories
heatmap_clusters = adata.obs['C_scANVI'].cat.categories
cr.pl.aggregate_fate_probabilities(
    adata, mode="heatmap", lineages=heatmap_lineages, cluster_key="C_scANVI",
    clusters=heatmap_clusters, ncols=6, cmap='Reds', figsize=(15, 6),
    save='heatmap_fateprobs.svg',
)

In [None]:
# Same aggregation as a hierarchically clustered heatmap.
clustermap_lineages = g2.terminal_states.cat.categories
clustermap_clusters = adata.obs['C_scANVI'].cat.categories
cr.pl.aggregate_fate_probabilities(
    adata, mode="clustermap", lineages=clustermap_lineages, cluster_key="C_scANVI",
    clusters=clustermap_clusters, ncols=6, cmap='Reds', figsize=(15, 9),
    save='clustermap_fateprobs.svg',
)

In [None]:
# One violin plot per terminal state: its fate probability across cell types.
for lineage in g2.terminal_states.cat.categories:
    cr.pl.aggregate_fate_probabilities(
        adata, mode="violin", lineages=[lineage], cluster_key="C_scANVI",
        clusters=adata.obs['C_scANVI'].cat.categories,
        save=f'{lineage}_violinfateprobs.svg',
    )

In [None]:
# PAGA graph coloured by a single lineage's fate probability, one figure per terminal state.
paga_lineage_style = dict(
    node_size_scale=3,
    edge_width_scale=2,
    max_edge_width=3,
    figsize=(9, 7),
    cmap='cividis',
    vmin=0,
    threshold=0.2,
    title="PAGA",
)
for lineage in g2.terminal_states.cat.categories:
    cr.pl.aggregate_fate_probabilities(
        adata,
        mode="paga",
        lineages=[lineage],
        cluster_key="C_scANVI",
        legend_loc=None,
        backward=False,
        basis="draw_graph_fa",
        save=f'{lineage}_diffPAGA_fa02.svg',
        **paga_lineage_style,
    )