In [None]:
import sys
sys.path.insert(0,'/home/cane/Documents/yoseflab/can/resolVI')
from scvi.external import RESOLVI

In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import scvi
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import torch
from pyro.infer import Importance, EmpiricalMarginal, Trace_ELBO, SVI
from pyro.infer.autoguide import AutoDiagonalNormal

In [None]:
# Global notebook configuration: reproducible seed, scanpy figure defaults and
# the directory scanpy writes figures to when `save=` is passed.
scvi.settings.seed = 0
sc.settings.figdir = 'figure4_new/'
sc.set_figure_params(
    dpi=100,
    dpi_save=300,
    format='png',
    frameon=False,
    vector_friendly=True,
    fontsize=14,
    color_map='viridis',
    figsize=None,
)

In [None]:
sys.path.append('..')
import _utils

# Cleanup data and initialize supervised dataset

In [None]:
# Load the saved reference resolVI model from disk.
resolvae = RESOLVI.load('/external_data/other/resolvi_final_other_files/mouse_colitis/original_q25_semisupervised_tier3_d0-d21/resolvae')

In [None]:
# Display the AnnData object attached to the loaded model.
resolvae.adata

In [None]:
# `adata` is the model's own AnnData (an alias, not a copy), so everything
# written to it below is also visible through `resolvae.adata`.
adata = resolvae.adata
adata.obsm["X_resolVI"] = resolvae.get_latent_representation()
# Neighbour graph and UMAP are both computed on the resolVI latent space.
# NOTE(review): method='rapids' presumably requires a GPU/RAPIDS install — confirm.
sc.pp.neighbors(adata, use_rep='X_resolVI', n_neighbors=15, method='rapids')
sc.tl.umap(adata, min_dist=0.3, init_pos='spectral')

In [None]:
# Total raw counts per cell. Summing a scipy sparse matrix along axis 1 returns
# an (n_obs, 1) np.matrix, so flatten it to a 1-D array before storing it as an
# obs column (for a dense layer the flattening is a no-op).
adata.obs['total_counts'] = np.asarray(adata.layers['raw_counts'].sum(axis=1)).ravel()

In [None]:
# The timepoint is the second underscore-separated token of Mouse_ID.
# NOTE(review): IDs presumably look like '<date>_<timepoint>_<mouse>...' — confirm.
# The vectorised .str accessor yields NaN for a missing or malformed ID instead
# of raising in the middle of a Python-level loop.
adata.obs['timepoint'] = adata.obs['Mouse_ID'].str.split('_').str[1]

In [None]:
# resolVI UMAP coloured by the coarse/fine labels and the timepoint.
sc.pl.umap(adata, color=["Tier1", "Tier3", "timepoint"], save='unsupervised_d0_d21.pdf')

In [None]:
# The coordinates in obs['Tier1_umap_x'] / obs['Tier1_umap_y'] are placed in the
# 'X_tsne' slot only so that sc.pl.tsne can draw them; they are not a t-SNE.
# NOTE(review): presumably these are the originally published UMAP coordinates — confirm.
adata.obsm['X_tsne'] = adata.obs[['Tier1_umap_x', 'Tier1_umap_y']].values
sc.pl.tsne(adata, color=["Tier1", "Tier3", "timepoint"], save='original_d0_d21.pdf')

In [None]:
# Sample the 'obs' site from the corrected model (5 samples) and summarise the
# samples with the median; the summary is exposed under the key 'post_sample_q50'.
samples_corr = resolvae.sample_posterior_predictive(
    model=resolvae.module.model_corrected,
    return_sites=['obs'],
    num_samples=5, return_samples=False, batch_size=1000, batch_steps=10, summary_fun={"q50": np.median})
# Transposed so that rows are summary names and columns are sample sites
# (see the .loc[<summary>, <site>] lookups below).
samples_corr = pd.DataFrame(samples_corr).T

In [None]:
# Median generated counts become a layer alongside the raw counts.
adata.layers['generated_expression'] = samples_corr.loc['post_sample_q50', 'obs']

In [None]:
# Posterior summaries of the residual model's latent sites (mixture
# proportions, background, diffusion, inverse dispersion). The result is
# transposed so rows are summary names and columns are sample sites.
samples = pd.DataFrame(
    resolvae.sample_posterior_predictive(
        model=resolvae.module.model_residuals,
        return_sites=[
            'mixture_proportions',
            'per_gene_background',
            'diffusion_mixture_proportion',
            'per_neighbor_diffusion',
            'px_r_inv',
        ],
        num_samples=5,
        return_samples=False,
        batch_size=1000,
        batch_steps=10,
    )
).T

In [None]:
# Store the posterior-mean mixture proportions per cell. The code reads
# column 0 as the true, column 1 as the diffusion and column 2 as the
# background component.
posterior_means = samples.loc['post_sample_means']
mixture = posterior_means['mixture_proportions']
for component, obs_key in enumerate(('true_proportion', 'diffusion_proportion', 'background_proportion')):
    adata.obs[obs_key] = mixture[:, component]

# Per-gene background profile and dispersion (inverse of px_r_inv, with a small
# constant in the denominator to avoid division by zero).
adata.varm['background'] = pd.DataFrame(
    posterior_means['per_gene_background'][0, ...].squeeze().T,
    index=adata.var_names,
)
adata.var['px_r'] = 1 / (1e-6 + posterior_means['px_r_inv'][0, :])

# Overlaid histograms of the three proportions; each label is attached to its
# own histogram so the legend cannot drift out of order.
for obs_key, legend_label in (
    ('background_proportion', 'Background_Proportion'),
    ('true_proportion', 'True_proportions'),
    ('diffusion_proportion', 'Diffusion_Proportion'),
):
    plt.hist(adata.obs[obs_key], bins=30, range=(0, 1), label=legend_label)
plt.legend()
plt.savefig('figure4_new/histogram_proportions.pdf')
plt.show()

In [None]:
# Counts attributed to the "true" mixture component: proportion x total counts.
adata.obs['true_counts'] = adata.obs['true_proportion'] * adata.obs['total_counts']

In [None]:
# Same panels for cells below and above 15 true counts. Both comparisons are
# strict, so a cell with exactly 15 true counts appears in neither plot.
sc.pl.umap(adata[adata.obs['true_counts']<15], color=['true_counts', 'total_counts', 'background_proportion', 'Tier3'], vmax='p95')
sc.pl.umap(adata[adata.obs['true_counts']>15], color=['true_counts', 'total_counts', 'background_proportion', 'Tier3'], vmax='p95')

In [None]:
# Rank marker genes per Tier1 label once, then draw the dotplot for all cells
# and separately for the cells above / below 15 true counts.
sc.tl.rank_genes_groups(adata, groupby='Tier1')
sc.pl.rank_genes_groups_dotplot(adata, standard_scale='var')
for quality_mask, figure_suffix in (
    (adata.obs['true_counts'] > 15, '_marker_genes_hq_coarse.pdf'),
    (adata.obs['true_counts'] < 15, '_marker_genes_lq_coarse.pdf'),
):
    sc.pl.rank_genes_groups_dotplot(adata[quality_mask], standard_scale='var', save=figure_suffix)

In [None]:
# Unique field-of-view key: '<Slice_ID>__<FOV>'.
adata.obs['fov_batch'] = adata.obs['Slice_ID'].astype(str) + '__' + adata.obs['FOV'].astype(str)

In [None]:
# A cell is flagged low quality (1) when more than half of its counts are
# attributed to background AND it has fewer than 20 true counts; otherwise 0.
is_low_quality = (adata.obs['background_proportion'] > 0.5) & (adata.obs['true_counts'] < 20)
adata.obs['low_quality'] = 0
adata.obs.loc[is_low_quality, 'low_quality'] = 1

In [None]:
# Per-slice number ('counts') and fraction ('mean') of low-quality cells.
# The original called `rename_axis(columns=...)` and discarded the result, so
# the column was never renamed; named aggregation builds both columns with the
# intended names in one step.
b = adata.obs.groupby('Slice_ID')['low_quality'].agg(counts='sum', mean='mean')
# The ten slices with the highest fraction of low-quality cells.
list(b.sort_values(by='mean', ascending=False).head(10).index)

In [None]:
# NOTE(review): this replaces the in-memory `adata` with the file that the
# `adata.write_h5ad(...)` cell further down writes, so on a fresh kernel the
# file must already exist from an earlier session — confirm the intended order.
adata = sc.read_h5ad('mouse_colitis/cleaned_up_labels_colitis_d0-d21.h5ad')

In [None]:
# Spatial view of one slice, split into low- and high-quality cells.
# The variable is called `slice_id` so the builtin `slice` is not shadowed for
# the rest of the kernel session.
slice_id = '082421_D21_m2_1_slice_1'
sc.pl.spatial(adata[(adata.obs['Slice_ID']==slice_id) & (adata.obs['low_quality']==1)], spot_size=20, color='Tier1', groups=('EntericNervous', 'Smooth Muscle Cells', 'Epithelial'),
             save='low_quality_spatial_cells.pdf')
sc.pl.spatial(adata[(adata.obs['Slice_ID']==slice_id) & (adata.obs['low_quality']==0)], spot_size=10, color='Tier1', groups=('EntericNervous', 'Smooth Muscle Cells', 'Epithelial'),
             save='high_quality_spatial_cells.pdf')

In [None]:
# Low- vs high-quality cells in space for three slices.
# The loop variable is `slice_id` so the builtin `slice` is not shadowed.
for slice_id in ['100221_D9_m5_2_slice_1', '062921_D9_m5_2_slice_3', '082421_D21_m2_1_slice_1']:
    print(slice_id)
    sc.pl.spatial(adata[(adata.obs['Slice_ID']==slice_id) & (adata.obs['low_quality']==1)], spot_size=20, color='Tier1', groups=('EntericNervous', 'Smooth Muscle Cells', 'Epithelial'))
    sc.pl.spatial(adata[(adata.obs['Slice_ID']==slice_id) & (adata.obs['low_quality']==0)], spot_size=10, color='Tier1', groups=('EntericNervous', 'Smooth Muscle Cells', 'Epithelial'))

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Per-FOV number ('sum') and fraction ('mean') of low-quality cells.
df_fov = adata.obs.groupby('fov_batch')['low_quality'].agg(['sum', 'mean'])

# One timepoint per FOV: take the first unique value within each group.
df_fov['timepoint'] = adata.obs.groupby('fov_batch')['timepoint'].agg(lambda x: x.unique()[0])

# Map each timepoint to a tab10 colour. The index wraps around so more than ten
# timepoints cannot run past the end of the palette.
unique_timepoints = df_fov['timepoint'].unique()
color_map = {tp: plt.cm.tab10(i % 10) for i, tp in enumerate(unique_timepoints)}
colors = df_fov['timepoint'].map(color_map).tolist()

# Explicit RGBA colours are passed, so no colormap normalisation takes place;
# the former `vmax=10` was ignored by matplotlib and has been dropped.
plt.scatter(df_fov['sum'], df_fov['mean'], s=3, c=colors)

# plt.scatter has no automatic legend for categorical colours, so build one
# proxy handle per timepoint.
handles = [plt.Line2D([0], [0], marker='o', color=color_map[tp], linestyle='', label=tp) for tp in unique_timepoints]
plt.legend(handles=handles, title='Timepoint', loc='center left', bbox_to_anchor=(1, 0.5))

plt.xlabel('Sum of Low Quality')
plt.ylabel('Mean of Low Quality')
plt.title('Scatter Plot of Low Quality by FOV Batch')

plt.show()


In [None]:
# Keep only the cells not flagged as low quality (independent copy).
sub = adata[adata.obs['low_quality']==0].copy()

In [None]:
# Larger neighbourhood (100) on the resolVI latent space; this graph drives the
# label smoothing below.
sc.pp.neighbors(sub, use_rep='X_resolVI', n_neighbors=100, method='rapids')

In [None]:
from tqdm import tqdm

In [None]:
# Smooth the Tier1 labels over the kNN graph: every cell receives the label
# carrying the largest summed connectivity among its neighbours, and
# 1 - (winning weight / total weight) is stored as an uncertainty score.
tier1_labels = sub.obs['Tier1']
cell_types = tier1_labels.cat.codes.values
n_categories = len(tier1_labels.cat.categories)
# CSR exposes each row's neighbours and weights as plain array slices, which
# avoids per-row sparse indexing (no copy is made if the matrix is already CSR).
connectivities = sub.obsp['connectivities'].tocsr()
most_common_type = []
weighted_ratios = []

for i in tqdm(range(connectivities.shape[0])):
    row_start, row_stop = connectivities.indptr[i], connectivities.indptr[i + 1]
    neighbors = connectivities.indices[row_start:row_stop]
    neighbor_weights = connectivities.data[row_start:row_stop]
    neighbor_types = cell_types[neighbors]

    # Missing labels have category code -1. Indexing the count vector with -1
    # would silently credit their weight to the last category, so they are
    # excluded from the vote (their weight still counts towards the total).
    labelled = neighbor_types >= 0
    weighted_counts = np.bincount(
        neighbor_types[labelled], weights=neighbor_weights[labelled], minlength=n_categories
    )

    most_common = weighted_counts.argmax()
    most_common_type.append(most_common)
    total_weight = neighbor_weights.sum()
    if total_weight > 0:
        weighted_ratios.append(weighted_counts[most_common] / total_weight)
    else:
        # No weighted neighbours: ratio 0, i.e. maximal uncertainty.
        weighted_ratios.append(0)

sub.obs['smoothed_tier1'] = sub.obs['Tier1'].cat.categories[most_common_type]
sub.obs['uncertainty_tier1'] = 1 - np.array(weighted_ratios)

In [None]:
# Smoothed coarse labels next to their uncertainty.
# NOTE(review): vmin=0.4 presumably only affects the continuous uncertainty panel — confirm.
sc.pl.umap(sub, color=['smoothed_tier1', 'uncertainty_tier1'], vmin=0.4)

In [None]:
# Smooth the Tier3 labels over the kNN graph: every cell receives the label
# carrying the largest summed connectivity among its neighbours, and
# 1 - (winning weight / total weight) is stored as an uncertainty score.
tier3_labels = sub.obs['Tier3']
cell_types = tier3_labels.cat.codes.values
n_categories = len(tier3_labels.cat.categories)
# CSR exposes each row's neighbours and weights as plain array slices, which
# avoids per-row sparse indexing (no copy is made if the matrix is already CSR).
connectivities = sub.obsp['connectivities'].tocsr()
most_common_type = []
weighted_ratios = []

for i in tqdm(range(connectivities.shape[0])):
    row_start, row_stop = connectivities.indptr[i], connectivities.indptr[i + 1]
    neighbors = connectivities.indices[row_start:row_stop]
    neighbor_weights = connectivities.data[row_start:row_stop]
    neighbor_types = cell_types[neighbors]

    # Missing labels have category code -1. Indexing the count vector with -1
    # would silently credit their weight to the last category, so they are
    # excluded from the vote (their weight still counts towards the total).
    labelled = neighbor_types >= 0
    weighted_counts = np.bincount(
        neighbor_types[labelled], weights=neighbor_weights[labelled], minlength=n_categories
    )

    most_common = weighted_counts.argmax()
    most_common_type.append(most_common)
    total_weight = neighbor_weights.sum()
    if total_weight > 0:
        weighted_ratios.append(weighted_counts[most_common] / total_weight)
    else:
        # No weighted neighbours: ratio 0, i.e. maximal uncertainty.
        weighted_ratios.append(0)

sub.obs['smoothed_tier3'] = sub.obs['Tier3'].cat.categories[most_common_type]
sub.obs['uncertainty_tier3'] = 1 - np.array(weighted_ratios)

In [None]:
# Smoothed fine labels next to their uncertainty.
sc.pl.umap(sub, color=['smoothed_tier3', 'uncertainty_tier3'])

In [None]:
# Lookup from fine (Tier3) to coarse (Tier1) label, built from the label pairs
# that occur in `sub`.
# NOTE(review): if a Tier3 label occurs with more than one Tier1 label, the last
# pair wins in to_dict() — confirm the mapping is one-to-one.
fine_coarse_dictionary = sub.obs[['Tier3', 'Tier1']].drop_duplicates().set_index('Tier3').to_dict()['Tier1']

In [None]:
# Marker genes per coarse label, computed on the high-quality subset.
sc.tl.rank_genes_groups(sub, groupby='Tier1')
sc.pl.rank_genes_groups_dotplot(sub, n_genes=3)

In [None]:
# Seed labels for semi-supervised training: a cell keeps its smoothed label
# only where the neighbourhood vote was confident; every other cell (including
# those absent from `sub`) stays 'unassigned'.
confident_coarse = (sub.obs['uncertainty_tier1'] < 0.05).values
confident_fine = confident_coarse & (sub.obs['uncertainty_tier3'] < 0.4).values

# Fine labels need a confident coarse AND a confident fine vote.
adata.obs['redo_celltyping'] = 'unassigned'
adata.obs.loc[sub.obs_names[confident_fine], 'redo_celltyping'] = sub.obs.loc[confident_fine, 'smoothed_tier3']

# Coarse labels only need a confident coarse vote.
adata.obs['redo_celltyping_coarse'] = 'unassigned'
adata.obs.loc[sub.obs_names[confident_coarse], 'redo_celltyping_coarse'] = sub.obs.loc[confident_coarse, 'smoothed_tier1']

In [None]:
# Compare the new seed labels with the original Tier3 labels, highlighting
# three cell types per figure so the panels stay readable.
unique_cell_types = list(sub.obs['Tier3'].unique())
i = 0
while i < len(unique_cell_types):
    groups = unique_cell_types[i:i+3]
    print(groups)
    sc.pl.umap(adata, color=['redo_celltyping', 'Tier3'], groups=groups, size=20)
    i += 3

In [None]:
# Original Tier1 label for every cell, overwritten with 'assigned' where the
# smoothed coarse label has uncertainty below 0.1; the remaining labels are the
# cells that were NOT confidently assigned.
adata.obs['undone_celltyping_coarse'] = adata.obs['Tier1'].astype(str)
adata.obs.loc[sub[sub.obs['uncertainty_tier1']<0.1].obs_names, 'undone_celltyping_coarse'] = 'assigned'

In [None]:
# Overview of the new seed labels; the second call highlights only the
# 'unassigned' cells.
sc.pl.umap(adata, color=['redo_celltyping', 'redo_celltyping_coarse'], ncols=1, save='redid_celltyping_unsupervised.pdf')
sc.pl.umap(adata, color=['redo_celltyping', 'redo_celltyping_coarse'], ncols=1, groups=['unassigned'])

In [None]:
# Persist the annotated reference data; an earlier cell reads this same path.
adata.write_h5ad('mouse_colitis/cleaned_up_labels_colitis_d0-d21.h5ad')

# Transfer learning

In [None]:
from scipy.sparse import csr_matrix

In [None]:
# Load the query dataset (labelled 'D35' below).
query_data = sc.read_h5ad('/external_data/other/resolvi_final_other_files/ibd_moffitt/adata_day35.h5ad')

In [None]:
# Raw counts arrive as a separate header-less CSV and are stored as a sparse layer.
# NOTE(review): assumes the CSV rows/columns are ordered like query_data obs/var — confirm.
query_data.layers['raw_counts'] = csr_matrix(pd.read_csv('/external_data/other/resolvi_final_other_files/ibd_moffitt/X_raw_day35.csv', header=None).values)

In [None]:
# Every query cell gets the same timepoint label.
query_data.obs['timepoint'] = 'D35'

In [None]:
# Prefix the cell names so query cells can be told apart from reference cells.
query_data.obs_names = 'query_' + query_data.obs_names

In [None]:
# Spatial coordinates are taken from the 'x' and 'y' obs columns.
query_data.obsm['spatial'] = query_data.obs[['x', 'y']].values

In [None]:
# Rebinds `resolvae` to the semi-supervised reference model (a different model
# from the one loaded at the top of the notebook).
resolvae = RESOLVI.load('mouse_colitis/original_q25_semisupervised_redo3/resolvae')

In [None]:
# Prepare the query AnnData against the reference model.
# NOTE(review): presumably aligns genes/fields with the reference setup and
# modifies query_data in place — confirm against the RESOLVI API.
resolvae.prepare_query_anndata(query_data, resolvae)

In [None]:
# Query model on top of the reference: by default the saved model is loaded;
# flip `retrain` to True to fine-tune on the query data and overwrite the save.
retrain = False
if not retrain:
    resolvae_query = RESOLVI.load('mouse_colitis/original_q25_semisupervised_redo3/resolvae_query')
else:
    resolvae_query = resolvae.load_query_data(query_data, reference_model=resolvae)
    resolvae_query.train(
        max_epochs=100,
        weight_decay=0.,
        batch_size=1024,
        n_epochs_kl_warmup=0,
    )
    resolvae_query.save(
        'mouse_colitis/original_q25_semisupervised_redo3/resolvae_query',
        save_anndata=True,
        overwrite=True,
    )

In [None]:
# Training ELBO of the query model.
plt.plot(resolvae_query.history['elbo_train'])
plt.show()

In [None]:
# Reload the annotated reference data written in the clean-up section.
adata = sc.read('mouse_colitis/cleaned_up_labels_colitis_d0-d21.h5ad')

In [None]:
# Stack reference and query cells; the 'source' column records the origin.
# NOTE(review): AnnData.concatenate is deprecated in recent anndata releases in
# favour of anndata.concat — confirm the installed version before migrating.
all_data = adata.concatenate(query_data, join='outer', batch_key='source', batch_categories=['reference', 'query'])

In [None]:
# Cells without a 'cluster' value get the placeholder 'unassigned'.
# NOTE(review): if 'cluster' is categorical and lacks an 'unassigned' category,
# pandas fillna raises — confirm the dtype after concatenation.
all_data.obs['cluster'] = all_data.obs['cluster'].fillna("unassigned")

In [None]:
# Drop the cells whose 'cluster' is 'Mast cell'. The result is copied because
# the following cells assign to it (obs_names, obsm, obs); assigning on an
# AnnData view triggers an implicit copy with a warning.
all_data = all_data[all_data.obs['cluster']!='Mast cell'].copy()

In [None]:
# Remove the '-reference' / '-query' suffix appended to the cell names by
# concatenate. Splitting once from the right removes only that suffix, so a cell
# name that itself contains '-' is no longer truncated at its first dash.
all_data.obs_names = [i.rsplit('-', 1)[0] for i in all_data.obs_names]

In [None]:
# Attach the combined reference+query data to the query-trained model.
resolvae_both = resolvae_query.load_query_data(all_data, reference_model=resolvae_query)

In [None]:
# Soft predictions: one probability column per cell type. The hard label is the
# column with the highest probability, and that probability is kept as well.
all_data.obsm['celltype_predicted'] = resolvae_both.predict(soft=True, num_samples=5, batch_size=3000)
all_data.obs['predicted_celltype'] = all_data.obsm['celltype_predicted'].idxmax(axis=1)
all_data.obs['predicted_celltype_prob'] = all_data.obsm['celltype_predicted'].max(axis=1)

In [None]:
# Map each predicted fine label to its coarse label using the Tier3 -> Tier1
# pairs found in the reference data; fine labels without a pair become NaN.
fine_coarse_dictionary = adata.obs[['Tier3', 'Tier1']].drop_duplicates().set_index('Tier3').to_dict()['Tier1']
all_data.obs['predicted_celltype_coarse'] = all_data.obs['predicted_celltype'].map(fine_coarse_dictionary)

In [None]:
# Joint latent representation of reference and query cells.
all_data.obsm["X_resolVI"] = resolvae_both.get_latent_representation()

In [None]:
sc.pp.neighbors(all_data, use_rep='X_resolVI', method='rapids')
sc.tl.umap(all_data, min_dist=0.3)

In [None]:
# NOTE(review): this replaces the `all_data` computed above with a saved file
# that a later cell writes to the same path — on a fresh kernel the file must
# already exist; confirm the intended execution order.
all_data = sc.read_h5ad('/external_data/other/resolvi_final_other_files/figure4/query_trained_data_d0_35.h5ad')

In [None]:
# Turn the 'unassigned' placeholder into a missing value.
# NOTE(review): presumably so these cells are drawn as NA in the plot below — confirm.
all_data.obs.loc[all_data.obs['redo_celltyping'] == 'unassigned', 'redo_celltyping'] = None

In [None]:
# Reference timepoints only (D35 is left out of this figure).
sc.pl.umap(all_data[all_data.obs['timepoint'].isin(['D0', 'D3', 'D21', 'D9'])],
           color=["predicted_celltype", "predicted_celltype_prob", "redo_celltyping"], ncols=1, save='semisupervised_celltypes_predicted.pdf')

In [None]:
# NOTE(review): read top to bottom, `sub` is still the high-quality subset of the
# reference data, and the visible code only adds 'predicted_celltype_coarse' to
# `all_data` — this cell presumably relies on a `sub` created in the Downstream
# section during an earlier session; confirm the intended execution order.
fibroblast = sub[(sub.obs['Tier1']=='Fibroblast') & (sub.obs['predicted_celltype_coarse']=='Fibroblast')].copy()
# Marker genes per original fine label, plotted for all fibroblasts, then for
# the seed cells and for the cells that carried no seed label.
sc.tl.rank_genes_groups(fibroblast, groupby='Tier3')
sc.pl.rank_genes_groups_dotplot(fibroblast, standard_scale='var', dendrogram=True, n_genes=5)
sc.pl.rank_genes_groups_dotplot(fibroblast[fibroblast.obs['redo_celltyping']!='unassigned'], standard_scale='var', dendrogram=True, n_genes=5, save='marker_genes_seed_fine.pdf')
sc.pl.rank_genes_groups_dotplot(fibroblast[fibroblast.obs['redo_celltyping']=='unassigned'], standard_scale='var', dendrogram=True, n_genes=5, save='marker_genes_other_fine.pdf')

In [None]:
# Cells where the prediction disagrees with the original fine label.
redo_data = fibroblast[fibroblast.obs['predicted_celltype']!=fibroblast.obs['Tier3']].copy()
sc.pl.rank_genes_groups_dotplot(redo_data, standard_scale='var', dendrogram=True, n_genes=5, save='marker_genes_disagreement_original.pdf')
# Overwrite 'Tier3' with the prediction and redraw.
# NOTE(review): presumably the dotplot groups by the 'Tier3' column recorded by
# rank_genes_groups, so the same genes are shown under the predicted labels — confirm.
redo_data.obs['Tier3'] = redo_data.obs['predicted_celltype']
sc.pl.rank_genes_groups_dotplot(redo_data, standard_scale='var', dendrogram=True, n_genes=5, save='marker_genes_disagreement_resolvi.pdf')

In [None]:
# Persist the combined data (the path is read back by several cells).
all_data.write_h5ad('/external_data/other/resolvi_final_other_files/figure4/query_trained_data_d0_35.h5ad')

In [None]:
# Reload the combined data and attach it to the saved query model, here given
# as a path rather than an in-memory model.
all_data = sc.read_h5ad('/external_data/other/resolvi_final_other_files/figure4/query_trained_data_d0_35.h5ad')
resolvae_both = RESOLVI.load_query_data(all_data, reference_model='mouse_colitis/original_q25_semisupervised_redo3/resolvae_query')

In [None]:
# Posterior summaries (mean and median over 10 samples) from the corrected
# model. With these summary_fun names, the code below reads the results under
# the keys 'post_sample_post_sample_means' and 'post_sample_post_sample_q50'.
samples_corr = resolvae_both.sample_posterior_predictive(
    model=resolvae_both.module.model_corrected,
    return_sites=['px_rate', 'obs'],
    summary_fun={"post_sample_means": np.mean, "post_sample_q50": np.median},
    num_samples=10, return_samples=False, batch_size=2000, batch_steps=10)
# Transposed so that rows are summary names and columns are sample sites.
samples_corr = pd.DataFrame(samples_corr).T

# Same summaries for the latent sites of the residual model.
samples = resolvae_both.sample_posterior_predictive(
    model=resolvae_both.module.model_residuals,
    return_sites=[
        'mixture_proportions', 'mean_poisson', 'per_gene_background', 
        'diffusion_mixture_proportion', 'per_neighbor_diffusion', 'px_r_inv'
        ],
    summary_fun={"post_sample_means": np.mean, "post_sample_q50": np.median},
    num_samples=10, return_samples=False, batch_size=2000, batch_steps=10)
samples = pd.DataFrame(samples).T

In [None]:
from scipy.sparse import csr_matrix

In [None]:
# Store the posterior medians as sparse layers: sampled counts ('obs') as
# generated expression and the rate ('px_rate') as imputed expression.
corrected_q50 = samples_corr.loc['post_sample_post_sample_q50']
for layer_name, site in (
    ('generated_expression', 'obs'),
    ('imputed_expression', 'px_rate'),
):
    all_data.layers[layer_name] = csr_matrix(corrected_q50[site])

In [None]:
# Median mixture proportions per cell; after transposing, row 0 is read as the
# true, row 1 as the diffusion and row 2 as the background component.
mixture_q50 = samples.loc['post_sample_post_sample_q50', 'mixture_proportions'].T
for component, obs_key in enumerate(('true_proportions', 'diffusion_proportions', 'background_proportions')):
    all_data.obs[obs_key] = mixture_q50[component, :]

In [None]:
# Total raw counts per cell. A sparse row sum has shape (n_obs, 1); flatten it
# so that the obs column receives a plain 1-D array.
all_data.obs['total_counts'] = np.asarray(all_data.layers['raw_counts'].sum(axis=1)).ravel()

In [None]:
# Counts attributed to the "true" mixture component: proportion x total counts.
all_data.obs['true_counts'] = all_data.obs['true_proportions'] * all_data.obs['total_counts']

In [None]:
# Same panels for confident (> 0.6) and unconfident (< 0.6) predictions.
sc.pl.umap(all_data[all_data.obs['predicted_celltype_prob']>0.6], color=["Tier3", "predicted_celltype", "predicted_celltype_prob", "true_counts"])
sc.pl.umap(all_data[all_data.obs['predicted_celltype_prob']<0.6], color=["Tier3", "predicted_celltype", "predicted_celltype_prob", "true_counts"])

In [None]:
# Overwrites the file this section started from, now with the added layers and columns.
all_data.write_h5ad('/external_data/other/resolvi_final_other_files/figure4/query_trained_data_d0_35.h5ad')

# Downstream

In [None]:
all_data = sc.read_h5ad('/external_data/other/resolvi_final_other_files/figure4/query_trained_data_d0_35.h5ad')

In [None]:
# X becomes the generated expression, normalised per cell and log1p-transformed.
# The 'generated_expression' layer itself is left untouched.
all_data.X = all_data.layers['generated_expression'].copy()
sc.pp.normalize_total(all_data)
sc.pp.log1p(all_data)

In [None]:
# Keep the cells with more than 20 estimated true counts.
sub = all_data[all_data.obs['true_counts']>20].copy()

In [None]:
# 'counts' layer: raw counts, normalised per cell and log1p-transformed, so
# despite its name it no longer holds integer counts.
sub.layers['counts'] = sub.layers['raw_counts'].copy()
sc.pp.normalize_total(sub, layer='counts')
sc.pp.log1p(sub, layer='counts')

In [None]:
# Attach the filtered data to the saved query model.
resolvae_both = RESOLVI.load_query_data(sub, reference_model='mouse_colitis/original_q25_semisupervised_redo3/resolvae_query')

In [None]:
# The timepoint is the second underscore-separated token of Mouse_ID.
# NOTE(review): IDs presumably look like '<date>_<timepoint>_<mouse>...' — confirm.
# The vectorised .str accessor yields NaN for a missing or malformed ID instead
# of raising in the middle of a Python-level loop.
sub.obs['timepoint'] = sub.obs['Mouse_ID'].str.split('_').str[1]

In [None]:
# One UMAP row per timepoint, restricted to cells whose predicted label has
# probability above 0.4.
confident_prediction = sub.obs['predicted_celltype_prob'] > 0.4
for i in ['D0', 'D35', 'D3', 'D9', 'D21']:
    at_timepoint = sub.obs['timepoint'] == i
    sc.pl.umap(
        sub[at_timepoint & confident_prediction],
        color=['predicted_celltype_prob', 'Tier3', 'predicted_celltype'],
        vmax='p99',
        ncols=3,
        size=2,
        title=i,
    )
    

In [None]:
# Fibroblasts on which the original coarse label and the predicted coarse label agree.
fibroblast = sub[(sub.obs['Tier1']=='Fibroblast') & (sub.obs['predicted_celltype_coarse']=='Fibroblast')].copy()

In [None]:
# Marker genes per predicted fine label, shown on X and on the generated layer.
sc.tl.rank_genes_groups(fibroblast, groupby='predicted_celltype')
sc.pl.rank_genes_groups_dotplot(fibroblast, standard_scale='var', dendrogram=True, n_genes=5)
sc.pl.rank_genes_groups_dotplot(fibroblast, standard_scale='var', dendrogram=True, n_genes=5, layer='generated_expression')

In [None]:
# Marker genes per timepoint, ranked on the 'generated_expression' layer.
# NOTE(review): in the visible code this layer is stored without normalisation
# or log-transform, while rank_genes_groups presumably expects log data — confirm.
sc.tl.rank_genes_groups(fibroblast, groupby='timepoint', layer='generated_expression')
sc.pl.rank_genes_groups_dotplot(fibroblast, standard_scale='var', dendrogram=True, n_genes=10, layer='generated_expression')
sc.pl.rank_genes_groups_dotplot(fibroblast, standard_scale='var', dendrogram=True, n_genes=10)

In [None]:
# Epithelial cells on which the original and predicted coarse labels agree.
epithelial = sub[(sub.obs['Tier1']=='Epithelial') & (sub.obs['predicted_celltype_coarse']=='Epithelial')].copy()

In [None]:
sc.tl.rank_genes_groups(epithelial, groupby='predicted_celltype')
sc.pl.rank_genes_groups_dotplot(epithelial, standard_scale='var', dendrogram=True, n_genes=5)
sc.pl.rank_genes_groups_dotplot(epithelial, standard_scale='var', dendrogram=True, n_genes=5, layer='generated_expression')

In [None]:
# Predicted vs original fine labels of the epithelial cells, one figure per slice.
# The loop variable is `slice_id` so the builtin `slice` is not shadowed.
for slice_id in epithelial.obs['Slice_ID'].unique():
    sc.pl.spatial(epithelial[epithelial.obs['Slice_ID']==slice_id], spot_size=15, color=['predicted_celltype', 'Tier3'], title=slice_id, ncols=2)

In [None]:
# Fibroblast labels in space for one D0 and one D35 example slice.
for slice_name in ('062921_D0_m3a_1_slice_2', '072523_D35_m6_1_slice_3'):
    sc.pl.spatial(
        fibroblast[fibroblast.obs['Slice_ID']==slice_name],
        spot_size=15,
        color=['predicted_celltype', 'Tier3'],
        title=slice_name,
    )

In [None]:
# Epithelial labels in space for one D0 and one D35 example slice.
for slice_name in ('062921_D0_m3a_1_slice_2', '072523_D35_m6_1_slice_3'):
    sc.pl.spatial(
        epithelial[epithelial.obs['Slice_ID']==slice_name],
        spot_size=7,
        color=['predicted_celltype', 'Tier3'],
        title=slice_name,
    )

In [None]:
# Expression (from X) of four genes in the two example slices, with one colour
# cap per gene.
for slice_name in ('062921_D0_m3a_1_slice_2', '072523_D35_m6_1_slice_3'):
    sc.pl.spatial(
        sub[sub.obs['Slice_ID']==slice_name],
        spot_size=10,
        color=['Il22ra1', 'Rorc', 'Il10rb', 'Edn1'],
        title=slice_name,
        vmax=[3, 4, 3, 2],
    )

In [None]:
# Generated expression of selected genes in the two example slices, with one
# colour cap per panel.
# NOTE(review): 'Il18' is listed twice — presumably one entry was meant to be a
# different gene; confirm.
for slice_name in ('062921_D0_m3a_1_slice_2', '072523_D35_m6_1_slice_3'):
    sc.pl.spatial(
        sub[sub.obs['Slice_ID']==slice_name],
        spot_size=7,
        layer='generated_expression',
        color=['Il18', 'Tlr1', 'Il18', 'Acvr1b', 'L1cam'],
        title=slice_name,
        vmax=[3, 2, 3, 4, 2],
    )

In [None]:
# Persist the filtered data; the Plots section below starts from this file.
sub.write_h5ad('/external_data/other/resolvi_final_other_files/figure4/query_trained_data_d0_35_filtered.h5ad')

# Plots

In [None]:
sub = sc.read_h5ad('/external_data/other/resolvi_final_other_files/figure4/query_trained_data_d0_35_filtered.h5ad')

In [None]:
sc.pl.umap(sub, color='timepoint', save='timepoint_d0-35.pdf')

In [None]:
# Re-attach the reloaded data to the saved query model.
resolvae_both = RESOLVI.load_query_data(sub, reference_model='mouse_colitis/original_q25_semisupervised_redo3/resolvae_query')

In [None]:
import squidpy as sq

In [None]:
# Spatial neighbour graph built per slice (library_key='Slice_ID') with 30
# neighbours and the diagonal set; it is read from
# sub.obsp['spatial_connectivities'] by the cells below.
sq.gr.spatial_neighbors(sub, coord_type="generic", library_key='Slice_ID', n_neighs=30, delaunay=False, transform=None, set_diag=True)

In [None]:
# Neighbourhood composition: for every cell, sum the predicted cell-type
# probabilities over all cells with a non-zero spatial connectivity to it.
# The binarised graph times the probability table gives this in one sparse
# product, replacing a Python loop over every cell (and an unused list).
cell_types_prediction = sub.obsm['celltype_predicted'].values
connectivities = sub.obsp['spatial_connectivities']
neighbor_indicator = (connectivities != 0).astype(np.float64)
weighted_ratios = np.asarray(neighbor_indicator @ cell_types_prediction, dtype=np.float64)

In [None]:
# Store the composition as a table with one column per predicted cell type.
sub.obsm['celltypes_neighborhood'] = pd.DataFrame(weighted_ratios, columns=sub.obsm['celltype_predicted'].columns, index=sub.obs_names)

In [None]:
# Connectivity-weighted neighbourhood composition: for every cell, the weighted
# average of the neighbours' predicted cell-type probabilities.
#
# The original loop computed `neighbor_types @ neighbor_weights`, i.e.
# (k, n_types) @ (k,), whose inner dimensions do not match; the weights must
# multiply from the left. A sparse product does this for all cells at once.
cell_types_prediction = sub.obsm['celltype_predicted'].values
connectivities = sub.obsp['spatial_connectivities']

weighted_counts = np.asarray(connectivities @ cell_types_prediction, dtype=np.float64)
total_weight = np.asarray(connectivities.sum(axis=1), dtype=np.float64).reshape(-1, 1)

# Cells without any weighted neighbour keep an all-zero row rather than NaN.
weighted_ratios = np.zeros(cell_types_prediction.shape)
np.divide(weighted_counts, total_weight, out=weighted_ratios, where=total_weight > 0)

In [None]:
sub.obsm['celltypes_neighborhood_weighted'] = pd.DataFrame(weighted_ratios, columns=sub.obsm['celltype_predicted'].columns, index=sub.obs_names)

In [None]:
# kNN graph on the (unweighted) neighbourhood composition. This overwrites the
# default neighbour graph in `sub`.
sc.pp.neighbors(sub, use_rep='celltypes_neighborhood', n_neighbors=50, method='rapids')
# The UMAP is computed on a copy and only its coordinates are kept, in the
# 'X_tsne' slot, so the expression-based 'X_umap' is preserved; sc.pl.tsne
# below therefore shows the niche UMAP, not a t-SNE.
sub.obsm['X_tsne'] = sc.tl.umap(sub, min_dist=0.5, init_pos='spectral', copy=True).obsm['X_umap']
sc.pl.tsne(sub, color=['Tier1', 'Tier3'])

In [None]:
# NOTE: the clustering algorithm is Louvain, although the result is stored
# under the key 'leiden_spatial'.
sc.tl.louvain(sub, resolution=0.8, key_added='leiden_spatial', flavor='rapids')
sc.pl.tsne(sub, color=['leiden_spatial'])

In [None]:
# Fixed colour per niche label, used when plotting the niche annotations.
# NOTE(review): 'MU3' and 'others' share the colour '#d16100' — confirm this is intended.
palette = {
    'FOL': '#7a4900',
    'FOL1': '#1ce6ff',
    'FOL2': '#ff34ff',
    'FOL3': '#ff4a46',
    'LUM': '#008941',
    'ME': '#997d87',
    'ME1': '#a30059',
    'ME2': '#ffdbe5',
    'ME3': '#ffff00',
    'ME4': '#0000a6',
    'MES1': '#00c2a0',
    'MES2': '#b79762',
    'MES3': '#004d43',
    'MU': '#8fb0ff',
    'MU1': '#006fa6',
    'MU2': '#5a0007',
    'MU3': '#d16100',
    'MU4': '#6a3a4c',
    'MU5': '#63ffac',
    'MU6': '#b903aa',
    'MU7': '#3b5dff',
    'MU8': '#4a3b53',
    'MU9': '#ff2f80',
    'MU10': '#61615a',
    'MU11': '#ba0900',
    'SM': '#6b7900',
    'SM1': '#1b4400',
    'SM2': '#ffaa92',
    'SM3': '#ff90c9',
    'other': '#4fc601',
    'others': '#d16100'
}

In [None]:
# Copy of the original niche labels in which the D35 cells and the cells
# labelled 'others' are blanked out, so only trusted reference labels remain
# for the label transfer below.
blank_label = (sub.obs['timepoint'] == 'D35') | (sub.obs['Leiden_neigh'] == 'others')
sub.obs['Leiden_neigh_re'] = sub.obs['Leiden_neigh']
sub.obs.loc[blank_label, 'Leiden_neigh_re'] = None

In [None]:
# NOTE(review): 'leiden_neigh_resolvi' is only created by a later cell in this
# notebook, so on a fresh top-to-bottom run these plots need that column to be
# present in the loaded file — confirm the intended execution order.
sc.pl.tsne(sub, color=['timepoint', 'predicted_celltype_coarse', 'leiden_neigh_resolvi'], save='niches_embedding.pdf')

In [None]:
# Same embedding restricted to the predicted fibroblasts.
sc.pl.tsne(sub[sub.obs['predicted_celltype_coarse']=='Fibroblast'], color=['timepoint', 'predicted_celltype', 'leiden_neigh_resolvi'], save='niches_embedding_fibroblast.pdf')

In [None]:
# Same embedding restricted to the predicted 'Fibro 2' cells.
sc.pl.tsne(sub[sub.obs['predicted_celltype']=='Fibro 2'], color=['timepoint', 'predicted_celltype', 'leiden_neigh_resolvi'], save='niches_embedding_fibro2.pdf')

In [None]:
# Transfer niche labels over the neighbourhood-composition kNN graph: each cell
# receives the label with the largest summed connectivity among its neighbours.
# Blanked-out labels have category code -1; see the notes in the loop.
cell_types = sub.obs['Leiden_neigh_re'].cat.codes.values
most_common_type = []
connectivities = sub.obsp['connectivities']
weighted_ratios = []

for i in tqdm(range(connectivities.shape[0])):
    neighbors = connectivities[i].nonzero()[1]
    neighbor_types = cell_types[neighbors]
    neighbor_weights = connectivities[i, neighbors].toarray().flatten()
    unique_types, counts = np.unique(neighbor_types, return_counts=True)
    weighted_counts = np.zeros(len(sub.obs['Leiden_neigh_re'].cat.categories))

    for j, type_index in enumerate(unique_types):
        # A code of -1 (no label) addresses the LAST slot of weighted_counts.
        weighted_counts[type_index] = neighbor_weights[neighbor_types == type_index].sum()

    # The last slot is left out of the vote, so unlabelled neighbours (and the
    # last category itself) can never win.
    # NOTE(review): this presumably relies on 'others' being the last category
    # and having no remaining members — confirm.
    most_common = weighted_counts[:-1].argmax()
    total_weight = neighbor_weights.sum()
    if weighted_counts[:-1].sum() > 0:
        most_common_type.append(most_common)
        # The ratio is relative to ALL neighbour weight, unlabelled included.
        weighted_ratios.append(weighted_counts[most_common] / total_weight)
    else:
        # No labelled neighbour: -1 selects the last category below.
        most_common_type.append(-1)
        weighted_ratios.append(0)

In [None]:
# Index the categories with the winning codes (-1 selects the last category)
# and store 1 - ratio as an uncertainty score.
sub.obs['leiden_neigh_resolvi'] = sub.obs['Leiden_neigh_re'].cat.categories[most_common_type]
sub.obs['uncertainty_leiden_neigh'] = 1 - np.array(weighted_ratios)

In [None]:
# Transfer uncertainty, the original niche labels and the transferred labels.
sc.pl.tsne(sub, color=['uncertainty_leiden_neigh', 'Leiden_neigh', 'leiden_neigh_resolvi'], palette=palette, ncols=1)

In [None]:
sc.pl.tsne(sub, color=['Leiden_neigh_re'], legend_loc='on data', legend_fontsize='small')

In [None]:
# D35 cells only: their 'Leiden_neigh_re' was blanked out above, so the
# transferred labels and the spatial clustering are the informative panels.
sc.pl.tsne(sub[sub.obs['timepoint']=='D35'], color=['Leiden_neigh_re', 'leiden_neigh_resolvi', 'leiden_spatial', 'uncertainty_leiden_neigh'], legend_loc='on data', legend_fontsize='small', title='D35')

In [None]:
# Transferred labels and spatial clusters, one figure per timepoint.
for i in ['D35', 'D0', 'D3', 'D9', 'D21']:
    sc.pl.tsne(sub[sub.obs['timepoint']==i], color=['leiden_neigh_resolvi', 'leiden_spatial'], legend_loc='on data', legend_fontsize='small', title=i)

In [None]:
# Selected predicted cell types highlighted, one figure per timepoint.
for i in ['D35', 'D0', 'D3', 'D9', 'D21']:
    sc.pl.tsne(sub[sub.obs['timepoint']==i], color='predicted_celltype', groups=['Fibro 2', 'Stem cells', 'Fibro 6', 'Fibro 4', 'Fibro 7', 'Colonocytes', 'Goblet 1', 'Goblet 2', 'TA'])

In [None]:
# Persist the niche-annotated data (the f-string has no placeholders; it is a
# plain path).
sub.write_h5ad(f'figure4_new/processed_adata_all_final_niche_final.h5ad')

In [None]:
sub = sc.read_h5ad('figure4_new/processed_adata_all_final_niche_final.h5ad')

In [None]:
# Re-attach the reloaded data to the saved query model.
resolvae_both = RESOLVI.load_query_data(sub, reference_model='mouse_colitis/original_q25_semisupervised_redo3/resolvae_query')

In [None]:
# Display order of the fine cell-type labels used for the dotplot columns below.
# NOTE(review): the labels must match the prediction column names exactly
# (e.g. the double space in 'Repair associated  (Arg1+)'); labels that are not
# found are reported by the cell that builds the probability columns.
tier2_3_ordered = ['Stem cells',
       'TA',
       'Colonocytes',
       'Goblet 1',
       'Goblet 2',
       'EEC',
       'M cells',
       'Epithelial (Clu+)',
       'IAE 1',
       'IAE 2',
       'IAE 3',
       'Repair associated  (Arg1+)',
       'Fibro 6',
       'Fibro 2',
       'Fibro 13',
       'Fibro 7',
       'Fibro 4',
       'Fibro 1',
       'Fibro 15',
       'Fibro 5',
       'Fibro 12',
       'FRC',
       'Pericyte 1',
       'Pericyte 2',
       'IAF 1',
       'IAF 2',
       'IAF 3',
       'IAF 4',
       'IAF 5',
       
       'Monocyte', 'Macrophage (Itgax+)','Macrophage (Mrc1+)', 'Macrophage (Lyve1+)','Macrophage (Cxcl10+)',
       'cDC1', 'DC (Fscn1+)','DC (Ccl22+)','Neutrophil 1','Neutrophil 2', 'Mast cell',
       'T (Cd4+ Ccr7+)','Treg','T (Mki67+)','Tfh','T (Cd8+)','NK','ILC2','ILC3-LTi-like',
       'B cell 1','B cell 2','B cell (Aicda+)','Plasma cell',
        
       'SMC 1','SMC 2', 'SMC Mesentery',
       'IASMC 1', 'IASMC 2', 'IASMC 3', 
       
       'Lymphatic EC',
       'Lymphatic EC (Ccl21a+)', 'Arterial EC', 'Venous EC',
       'Capillarial EC', 
       
       'Glia (Apod+)', 'Glia (Gfap+)', 'Glia (Gfra3+)',
       'Glia (Mpz+)', 'Neuron (Chat+)', 'Neuron (Nos1+)', 
       'ICC 1','ICC 2', 'Adipose']

In [None]:
# Copy each cell type's prediction probability into an obs column (values below
# 1e-3 are set to 0) so it can be used as a dotplot variable.
# Labels without a prediction column are printed and skipped.
columns_probabilities = []
predicted_probabilities = sub.obsm['celltype_predicted']
for ct in tier2_3_ordered:
    # Explicit membership test instead of a bare `except:`, which also hid
    # unrelated errors such as typos or interrupts.
    if ct not in predicted_probabilities.columns:
        print(ct)
        continue
    values = predicted_probabilities[ct]
    # .mask returns a new Series, so the prediction table in obsm is no longer
    # modified in place; NaN entries are left untouched, as before.
    sub.obs[f'prediction_probability_{ct}'] = values.mask(values < 1e-3, 0)
    columns_probabilities.append(f'prediction_probability_{ct}')

In [None]:
# Display order of the niche labels on the dotplot rows below.
list_neighborhoods = ['MU1', 'MU2', 'MU3', 'MU4', 'MU5', 'MU6', 'MU7', 'MU8', 'MU9', 'MU10', 'MU11','LUM',
       'SM1', 'SM2', 'SM3', 'ME1', 'ME2', 'ME3', 'ME4',
       'FOL1', 'FOL2', 'FOL3','MES1', 'MES2', 'MES3'
]

In [None]:
# Cells with a specific original niche label whose transferred label differs
# from it; their cell-type composition is shown once grouped by the original
# label and once grouped by the transferred label.
has_specific_niche = ~sub.obs['Leiden_neigh'].isin(['ME', 'FOL', 'MU', 'SM', 'other', 'others'])
labels_disagree = sub.obs['leiden_neigh_resolvi'].astype(str) != sub.obs['Leiden_neigh'].astype(str)
for groupby_key, figure_name in (
    ('Leiden_neigh', 'original_niches.pdf'),
    ('leiden_neigh_resolvi', 'resolvi_niches.pdf'),
):
    sc.pl.dotplot(
        sub[has_specific_niche & labels_disagree],
        var_names=columns_probabilities,
        groupby=groupby_key,
        standard_scale='var',
        dendrogram=False,
        categories_order=list_neighborhoods,
        smallest_dot=6.,
        save=figure_name,
    )

In [None]:
# Composition per transferred niche label, all cells except those labelled 'others'.
sc.pl.dotplot(sub[sub.obs['leiden_neigh_resolvi']!='others'], var_names=columns_probabilities, groupby='leiden_neigh_resolvi', standard_scale='var', dendrogram=False,
             categories_order=list_neighborhoods, smallest_dot=6., save='resolvi_niches_all.pdf')

In [None]:
# Composition per original niche label, excluding the unspecific labels.
sc.pl.dotplot(sub[~sub.obs['Leiden_neigh'].isin(['ME', 'FOL', 'MU', 'SM', 'other', 'others'])], var_names=columns_probabilities, groupby='Leiden_neigh', standard_scale='var', dendrogram=False,
             categories_order=list_neighborhoods, smallest_dot=6., save='original_niches_all.pdf')

In [None]:
# Differential niche abundance between D35 (group1) and D0 (group2), restricted
# to the cells predicted as 'Fibro 2' (passed as positional indices).
# NOTE(review): neighbor_key='index_neighbor' presumably names a field set up by
# the model — confirm against the RESOLVI API.
da = resolvae_both.differential_niche_abundance(
    groupby='timepoint', group1='D35', group2='D0', subset_idx=np.where(sub.obs['predicted_celltype']=='Fibro 2')[0], neighbor_key='index_neighbor', test_mode='three',
    delta=0.05, pseudocounts=3e-2)
da.head(1)

In [None]:
da.head(20)

In [None]:
import decoupler as dc

In [None]:
# Volcano plot of the differential niche abundance table:
# `lfc_mean` on the x-axis, `proba_not_de` on the y-axis.
fibro2_volcano_kwargs = dict(
    x='lfc_mean',
    y='proba_not_de',
    sign_thr=0.5,
    lFCs_thr=0.5,
    top=30,
    figsize=(10, 10),
    save='figure4_new/fibro2_da.pdf',
)
dc.plot_volcano_df(da, **fibro2_volcano_kwargs)
plt.show()

In [None]:
# Fine cell-type label for cells whose coarse label is 'Fibroblast'; None elsewhere.
is_fibroblast = (sub.obs['predicted_celltype_coarse'] == 'Fibroblast').to_numpy()
sub.obs['sub_fibro'] = [
    fine_label if keep else None
    for fine_label, keep in zip(sub.obs['predicted_celltype'], is_fibroblast)
]

In [None]:
# Spatial map of the `sub_fibro` labels on a single D0 slice.
slice_id_d0 = '082421_D0_m6_1_slice_1'
with plt.rc_context({"figure.figsize": (15, 15), "figure.dpi": 300}):
    sc.pl.spatial(
        sub[sub.obs['Slice_ID'] == slice_id_d0],
        color=['sub_fibro'],
        layer='generated_expression',
        spot_size=15,
        title=slice_id_d0,
        ncols=1,
        save='distribution_fibro2_d0.pdf',
    )

In [None]:
# Spatial map of the `sub_fibro` labels on a single D35 slice.
slice_id_d35 = '072523_D35_m6_1_slice_3'
with plt.rc_context({"figure.figsize": (15, 15), "figure.dpi": 300}):
    sc.pl.spatial(
        sub[sub.obs['Slice_ID'] == slice_id_d35],
        color=['sub_fibro'],
        layer='generated_expression',
        spot_size=15,
        title=slice_id_d35,
        ncols=1,
        save='distribution_fibro2_d35.pdf',
    )

In [None]:
# Differential expression between niches MU3 (group1) and MU1 (group2),
# computed on D0 cells whose coarse predicted cell type is 'Fibroblast'.
fibroblasts_d0 = sub[
    (sub.obs['predicted_celltype_coarse'] == 'Fibroblast') & (sub.obs['timepoint'] == 'D0')
]
de_result_importance = resolvae_both.differential_expression(
    adata=fibroblasts_d0,
    groupby='leiden_neigh_resolvi',
    group1='MU3',
    group2='MU1',
    weights='importance',
    mode='change',
    test_mode='three',
    delta=0.05,
    pseudocounts=1e-2,
    filter_outlier_cells=True,
)
de_result_importance.head(50)

In [None]:
# Volcano plot of the DE table currently bound to `de_result_importance`.
dc.plot_volcano_df(
    de_result_importance,
    x='lfc_mean', y='proba_not_de',
    sign_thr=0.3, lFCs_thr=0.4,
    top=30, figsize=(10, 10),
    save='figure4_new/fibroblast_resolvi_de.pdf',
)
plt.show()

In [None]:
# Differential expression between timepoints D0 (group1) and D35 (group2)
# for cells predicted as 'Colonocytes'.
# NOTE(review): the result is named `de_result_uniform`, yet the call passes
# weights='importance' — confirm which weighting was intended.
colonocytes = sub[sub.obs['predicted_celltype'] == 'Colonocytes']
de_result_uniform = resolvae_both.differential_expression(
    adata=colonocytes,
    groupby='timepoint',
    group1='D0',
    group2='D35',
    weights='importance',
    mode='change',
    test_mode='three',
    delta=0.05,
    pseudocounts=1e-2,
    filter_outlier_cells=True,
)
de_result_uniform.head(30)

In [None]:
# Volcano plot of the colonocyte D0-vs-D35 DE table (`de_result_uniform`).
colonocyte_volcano_kwargs = {
    'x': 'lfc_mean',
    'y': 'proba_not_de',
    'sign_thr': 0.3,
    'lFCs_thr': 0.4,
    'top': 60,
    'figsize': (10, 10),
    'save': 'figure4_new/colonocyte_resolvi_de.pdf',
}
dc.plot_volcano_df(de_result_uniform, **colonocyte_volcano_kwargs)
plt.show()

In [None]:
# For each source column, add a `<column>_fibroblast` copy that keeps the value
# for cells with Tier1 == 'Fibroblast' and holds None for every other cell.
tier1_is_fibroblast = [tier1 == 'Fibroblast' for tier1 in sub.obs['Tier1']]
for source_key in ('leiden_spatial', 'leiden_neigh_resolvi', 'Tier3'):
    sub.obs[f'{source_key}_fibroblast'] = [
        value if keep else None
        for value, keep in zip(sub.obs[source_key], tier1_is_fibroblast)
    ]

In [None]:
# Spatial plots of the three fibroblast-only label columns for every slice whose
# ID contains 'D0' or 'D35'.
# The loop variable is named `slice_id` so that it does not shadow the Python
# builtin `slice` in the notebook's global namespace.
for slice_id in sub.obs['Slice_ID'].unique():
    if ('D0' in slice_id) or ('D35' in slice_id):
        sc.pl.spatial(
            sub[sub.obs['Slice_ID'] == slice_id],
            spot_size=7,
            color=['leiden_neigh_resolvi_fibroblast', 'leiden_spatial_fibroblast', 'Tier3_fibroblast'],
            title=slice_id,
        )

In [None]:
# Per-cell share of selected cell types in `celltypes_neighborhood`:
# the cell type's column divided by the row sum over all columns.
neighborhood_counts = sub.obsm['celltypes_neighborhood']
neighborhood_totals = neighborhood_counts.sum(1)
for obs_key, celltype in (
    ('neighborhood_colonocyte', 'Colonocytes'),
    ('neighborhood_stem', 'Stem cells'),
    ('neighborhood_ta', 'TA'),
):
    sub.obs[obs_key] = neighborhood_counts[celltype] / neighborhood_totals

In [None]:
# Plot `leiden_spatial` and the three neighborhood fractions on the 'X_tsne' embedding.
tsne_color_keys = ['leiden_spatial', 'neighborhood_colonocyte', 'neighborhood_stem', 'neighborhood_ta']
sc.pl.tsne(sub, color=tsne_color_keys, cmap='Reds', vmax=0.7)

In [None]:
# Spatial maps of the three neighborhood fractions on a single D0 slice.
zonation_slice_d0 = '082421_D0_m6_1_slice_1'
with plt.rc_context({"figure.figsize": (15, 15), "figure.dpi": 300}):
    sc.pl.spatial(
        sub[sub.obs['Slice_ID'] == zonation_slice_d0],
        color=['neighborhood_colonocyte', 'neighborhood_stem', 'neighborhood_ta'],
        layer='generated_expression',
        spot_size=8,
        title=zonation_slice_d0,
        ncols=3,
        save='zonation_epithelial_d0.pdf',
    )

In [None]:
# Spatial maps of the three neighborhood fractions on a single D35 slice.
# Saved under a D35-specific file name: this cell previously reused
# 'zonation_epithelial_d0.pdf', which overwrote the figure written for the D0 slice.
with plt.rc_context({"figure.figsize": (15, 15), "figure.dpi": (300)}):
    sc.pl.spatial(
        sub[sub.obs['Slice_ID']=='072523_D35_m6_1_slice_3'], spot_size=8, layer='generated_expression',
        color=['neighborhood_colonocyte', 'neighborhood_stem', 'neighborhood_ta'], ncols=3, save='zonation_epithelial_d35.pdf'
    )

In [None]:
# Signed score: colonocyte neighborhood fraction minus stem-cell neighborhood fraction.
colonocyte_fraction = sub.obs['neighborhood_colonocyte']
stem_fraction = sub.obs['neighborhood_stem']
sub.obs['neighborhood_stem_colonocyte'] = colonocyte_fraction - stem_fraction

# Spatial map of the score on a single D0 slice, diverging colormap fixed to [-1, 1].
with plt.rc_context({"figure.figsize": (15, 15), "figure.dpi": 300}):
    sc.pl.spatial(
        sub[sub.obs['Slice_ID'] == '082421_D0_m6_1_slice_1'],
        color=['neighborhood_stem_colonocyte'],
        layer='generated_expression',
        spot_size=8,
        cmap='bwr',
        vmin=-1,
        vmax=1,
        ncols=3,
        save='gradient_epithelial_d0.pdf',
    )

In [None]:
# Spatial map of `neighborhood_stem_colonocyte` on a single D35 slice,
# diverging colormap fixed to [-1, 1].
with plt.rc_context({"figure.figsize": (15, 15), "figure.dpi": 300}):
    sc.pl.spatial(
        sub[sub.obs['Slice_ID'] == '072523_D35_m6_1_slice_3'],
        color=['neighborhood_stem_colonocyte'],
        layer='generated_expression',
        spot_size=8,
        cmap='bwr',
        vmin=-1,
        vmax=1,
        ncols=3,
        save='gradient_epithelial_d35.pdf',
    )

In [None]:
# Binarise the score at 0.5 and store it as the strings 'True' / 'False'.
sub.obs['neighborhood_stem_colonocyte_threshold'] = sub.obs['neighborhood_stem_colonocyte'].gt(0.5).astype(str)

# Show the thresholded label on the D35 slice first, then on the D0 slice.
for threshold_slice_id in ('072523_D35_m6_1_slice_3', '082421_D0_m6_1_slice_1'):
    sc.pl.spatial(
        sub[sub.obs['Slice_ID'] == threshold_slice_id],
        color=['neighborhood_stem_colonocyte_threshold'],
        layer='generated_expression',
        spot_size=8,
        cmap='bwr',
        vmin=-1,
        vmax=1,
        ncols=3,
    )

In [None]:
# Differential expression between D35 colonocytes above ('True', group1) and
# at or below ('False', group2) the 0.5 threshold on `neighborhood_stem_colonocyte`.
# Note: this rebinds `de_result_importance`, replacing the earlier fibroblast result.
colonocytes_d35 = sub[
    (sub.obs['timepoint'] == 'D35') & (sub.obs['predicted_celltype'] == 'Colonocytes')
]
de_result_importance = resolvae_both.differential_expression(
    adata=colonocytes_d35,
    groupby='neighborhood_stem_colonocyte_threshold',
    group1='True',
    group2='False',
    weights='importance',
    mode='change',
    test_mode='three',
    delta=0.05,
    pseudocounts=1e-2,
    filter_outlier_cells=True,
)
de_result_importance.head(50)

In [None]:
# Volcano plot of the DE table currently bound to `de_result_importance`.
dc.plot_volcano_df(
    de_result_importance,
    x='lfc_mean', y='proba_not_de',
    sign_thr=0.5, lFCs_thr=0.1,
    top=20, figsize=(10, 10),
    save='figure4_new/colonocyte_niches_resolvi_de.pdf',
)
plt.show()

# Benchmark

In [None]:
from harmony import harmonize

# PCA on `sub`, then a Harmony-adjusted copy of the PCA embedding using
# `Slice_ID` as the batch key; stored under obsm['X_pca_harmony'].
sc.tl.pca(sub)
pca_embedding = sub.obsm['X_pca']
sub.obsm['X_pca_harmony'] = harmonize(pca_embedding, sub.obs, batch_key='Slice_ID')

In [None]:
# Attach `sub` as query data to the saved reference model and store its latent
# representation under obsm['X_resolvi_unsupervised'].
# NOTE(review): this path is relative, whereas the model is loaded from an absolute
# path at the top of the notebook — confirm the working directory.
# NOTE(review): no `train()` call is visible before the latent representation is
# read — confirm the query model is meant to be used without fine-tuning.
reference_model_dir = 'mouse_colitis/original_q25_semisupervised_tier3_d0-d21/resolvae'
resolvae = RESOLVI.load_query_data(sub, reference_model_dir)
sub.obsm['X_resolvi_unsupervised'] = resolvae.get_latent_representation()

In [None]:
# Reload the processed AnnData from disk.
# NOTE(review): this rebinds `sub`; the obsm entries added in the two cells above
# ('X_pca_harmony', 'X_resolvi_unsupervised') are only available afterwards if
# they were saved into this file — confirm.
processed_adata_path = 'figure4_new/processed_adata_all_final_niche_final.h5ad'
sub = sc.read_h5ad(processed_adata_path)

In [None]:
# Keep only the cells from timepoint D0.
is_d0 = sub.obs['timepoint'] == 'D0'
sub_d0 = sub[is_d0]

In [None]:
# UMAP of the D0 cells coloured by coarse predicted cell type.
sc.pl.umap(
    sub_d0,
    color='predicted_celltype_coarse',
    save='celltypes_d0.pdf',
)

In [None]:
# Keep only the D0 cells whose `low_quality` flag equals 0.
passes_quality = sub_d0.obs['low_quality'] == 0
sub_d0 = sub_d0[passes_quality]

In [None]:
from contextlib import contextmanager

# Import from the public `scib_metrics.benchmark` namespace rather than the
# private `_core` module, whose location is an implementation detail.
from scib_metrics.benchmark import BatchCorrection, Benchmarker

# Batch-correction metrics to compute: every metric listed here is enabled
# except `pcr_comparison`.
batch_correction = BatchCorrection(
    silhouette_batch=True,
    ilisi_knn=True,
    kbet_per_label=True,
    graph_connectivity=True,
    pcr_comparison=False,
)

In [None]:
# Benchmark the four embeddings on the filtered D0 cells, with `Slice_ID` as the
# batch key and `Tier3` as the label key; 'X_pca' is the pre-integration embedding.
embeddings_to_compare = ['X_pca', 'X_pca_harmony', 'X_resolVI', 'X_resolvi_unsupervised']
bm = Benchmarker(
    sub_d0,
    batch_key="Slice_ID",
    label_key="Tier3",
    embedding_obsm_keys=embeddings_to_compare,
    pre_integrated_embedding_obsm_key='X_pca',
    batch_correction_metrics=batch_correction,
    n_jobs=12,
)
bm.benchmark()

In [None]:
# Use normal (non-bold) font weight for text, axes titles and axes labels.
plt.rcParams.update({
    'font.weight': 'normal',
    'axes.titleweight': 'normal',
    'axes.labelweight': 'normal',
})

# Render the benchmark results table without min-max scaling and save it to the figure directory.
bm.plot_results_table(min_max_scale=False, save_dir='figure4_new/')