In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# R_HOME is exported before milopy is imported; presumably milopy's R bridge reads it at import time — TODO confirm.
# NOTE(review): machine-specific absolute path; will not resolve on another machine.
os.environ['R_HOME'] = '/home/cane/miniconda3/envs/hub_paper/lib/R'
import milopy

In [None]:
from torch import distributions

In [None]:
import sys
# Prepend a local directory to sys.path so it takes precedence when resolving the import below.
# NOTE(review): presumably a development checkout that provides scvi.external.RESOLVI — confirm; the path is machine-specific.
sys.path.insert(0,'/home/cane/Documents/yoseflab/can/resolVI')
from scvi.external import RESOLVI

In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import scvi
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import torch
from pyro.infer import Importance, EmpiricalMarginal, Trace_ELBO, SVI
from pyro.infer.autoguide import AutoDiagonalNormal

In [None]:
# Global plotting configuration: reset matplotlib to its original defaults,
# configure scanpy, then apply the figure style in one rcParams update.
sns.reset_orig()
sc.settings._vector_friendly = True
sc.settings.n_jobs = -1
# p9.theme_set(p9.theme_classic)

plt.rcParams.update({
    # Output / figure
    "svg.fonttype": "none",
    "pdf.fonttype": 42,
    "savefig.transparent": True,
    "figure.figsize": (4, 4),
    # Axes
    "axes.titlesize": 15,
    "axes.titleweight": 500,
    "axes.titlepad": 8.0,
    "axes.labelsize": 14,
    "axes.labelweight": 500,
    "axes.linewidth": 1.2,
    "axes.labelpad": 6.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    # Fonts
    "font.size": 11,
    # 'font.family': 'sans-serif',
    "font.sans-serif": ['Helvetica', "Computer Modern Sans Serif", "DejaVU Sans"],
    "font.weight": 500,
    # Ticks (x)
    "xtick.labelsize": 12,
    "xtick.minor.size": 1.375,
    "xtick.major.size": 2.75,
    "xtick.major.pad": 2,
    "xtick.minor.pad": 2,
    # Ticks (y)
    "ytick.labelsize": 12,
    "ytick.minor.size": 1.375,
    "ytick.major.size": 2.75,
    "ytick.major.pad": 2,
    "ytick.minor.pad": 2,
    # Legend
    "legend.fontsize": 12,
    "legend.handlelength": 1.4,
    "legend.numpoints": 1,
    "legend.scatterpoints": 3,
    # Lines
    "lines.linewidth": 1.7,
})
DPI = 300

In [None]:
# scanpy figure defaults (applied after the rcParams above) and the directory used by `save=` arguments.
sc.set_figure_params(
    dpi=100,
    dpi_save=300,
    format='png',
    frameon=False,
    vector_friendly=True,
    fontsize=14,
    color_map='viridis',
    figsize=None,
)
sc.settings.figdir = 'figure4_new/'

In [None]:
# Load the processed AnnData object that every later cell works from.
sub = sc.read('figure4_new/processed_adata_all_final_niche_final.h5ad')

In [None]:
# Average the per-cell 'celltype_predicted' values within each slice,
# then rescale every slice (row) so that it sums to 1.
tmp = sub.obsm['celltype_predicted'].groupby(sub.obs['Slice_ID']).mean()
row_totals = tmp.sum(axis=1)
scaled = tmp.div(row_totals, axis=0)

In [None]:
import anndata as ad

In [None]:
# One row per slice: Slice_ID -> timepoint lookup taken from the cell-level metadata.
meta_info = (
    sub.obs[['Slice_ID', 'timepoint']]
    .drop_duplicates()
    .set_index('Slice_ID')
)

# Wrap the per-slice proportions in an AnnData and attach the timepoint (aligned on Slice_ID).
ad_scaled = ad.AnnData(scaled)
ad_scaled.obs['timepoint'] = meta_info['timepoint']

In [None]:
# Map each 'Tier3' label to its 'Tier1' label, using the unique pairs found in obs.
tier_pairs = sub.obs[['Tier3', 'Tier1']].drop_duplicates()
fine_coarse_dictionary = tier_pairs.set_index('Tier3')['Tier1'].to_dict()

In [None]:
# Tier3 labels whose Tier1 label is 'Epithelial' / 'Fibroblast'.
epithelial = [
    fine for fine, coarse in fine_coarse_dictionary.items()
    if coarse == 'Epithelial'
]
fibroblast = [
    fine for fine, coarse in fine_coarse_dictionary.items()
    if coarse == 'Fibroblast'
]

In [None]:
# Dot plot of the fibroblast columns of the per-slice proportions, grouped by timepoint.
sc.pl.dotplot(
    ad_scaled,
    var_names=fibroblast,
    groupby='timepoint',
    standard_scale='var',
    expression_cutoff=0.01,
    swap_axes=True,
    smallest_dot=30,
    save='fibroblast_cell_timepoint.pdf',
)

In [None]:
# Hand-picked epithelial labels (and their display order); replaces the list derived from the Tier1 mapping.
epithelial = [
    'Stem cells',
    'Colonocytes',
    'TA',
    'Goblet 2',
    'M cells',
    'EEC',
    'Goblet 1',
    'IAE 1',
    'IAE 2',
    'IAE 3',
    'Repair associated  (Arg1+)',
    'Epithelial (Clu+)',
]

In [None]:
# Dot plot of the epithelial columns of the per-slice proportions, grouped by timepoint.
sc.pl.dotplot(
    ad_scaled,
    var_names=epithelial,
    groupby='timepoint',
    standard_scale='var',
    expression_cutoff=0.01,
    swap_axes=True,
    smallest_dot=30,
    save='epithelial_cell_timepoint.pdf',
)

In [None]:
# Attach the per-slice metadata columns (joined on the Slice_ID index).
# Columns coming from meta_info are dropped first so that re-running this cell
# does not merge them a second time and produce `timepoint_x` / `timepoint_y`.
scaled = (
    scaled
    .drop(columns=meta_info.columns, errors='ignore')
    .merge(meta_info, how='inner', left_index=True, right_index=True)
)

In [None]:
# Cast the row index to plain strings before the frame is wrapped in an AnnData below.
scaled.index = scaled.index.astype(str)

In [None]:
# Split the merged frame by dtype: float32 columns become X, everything else becomes obs.
is_float32 = scaled.dtypes == 'float32'
proportional_analysis = sc.AnnData(scaled[scaled.columns[is_float32]])
proportional_analysis.obs = scaled[scaled.columns[~is_float32]]

In [None]:
# PCA on the per-slice proportions, shown for several component pairs and coloured by timepoint.
sc.tl.pca(proportional_analysis)
sc.pl.pca(
    proportional_analysis,
    color='timepoint',
    size=100,
    components=['1, 2', '3, 4', '1, 3', '2, 3'],
    save='ratio_pca_plot.pdf',
)

In [None]:
# PC1/PC2 scatter of the slices coloured by timepoint, with an 'x' marker at the
# mean position of each timepoint.
timepoints = proportional_analysis.obs['timepoint']

# Size the palette from the categories it is indexed by. Sizing it from
# .unique() would raise IndexError whenever the categorical carries unused
# categories.
timepoint_categories = timepoints.cat.categories
palette = sns.color_palette("deep", len(timepoint_categories))
timepoint_palette = dict(zip(timepoint_categories, palette))

fig, ax = plt.subplots(figsize=(6, 8))

pca_coords = proportional_analysis.obsm['X_pca'][:, :2]
# observed=True: only timepoints that actually occur get a mean marker
# (an unobserved category would yield NaN coordinates and an empty legend entry).
means = (
    pd.DataFrame(pca_coords, columns=['PC1', 'PC2'], index=timepoints.index)
    .groupby(timepoints, observed=True)
    .mean()
)

for timepoint, mean_coords in means.iterrows():
    ax.scatter(
        mean_coords['PC1'], mean_coords['PC2'],
        color=timepoint_palette[timepoint],
        marker='x', s=300, linewidths=3,
        label=f"Mean {timepoint}",
    )

sc.pl.pca(proportional_analysis, color='timepoint', size=200, components=['1, 2'],
          palette=timepoint_palette,
          save=False, show=False, ax=ax)

ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
# The legend is anchored outside the axes; a tight bounding box keeps it in the saved file.
fig.savefig('figure4_new/proportion_plot_timepoints.svg', bbox_inches='tight')
plt.show()

In [None]:
import milopy
import milopy.core as milo

In [None]:
# Keep only the D0 and D35 cells. Take an explicit copy: the following cells write
# neighbours, uns entries and obs columns into `sub_sub`, and doing that on an
# AnnData view relies on an implicit copy triggered by the first write.
sub_sub = sub[sub.obs['timepoint'].isin(['D0', 'D35'])].copy()

In [None]:
# Milo workflow on the D0/D35 subset, with neighbours computed in the resolVI latent space.
# kNN graph (100 neighbours) on obsm['X_resolVI'].
# NOTE(review): method='rapids' presumably needs a GPU / RAPIDS install — confirm.
sc.pp.neighbors(sub_sub, n_neighbors=100, use_rep='X_resolVI', method='rapids')
# Define neighbourhoods on that graph (prop=0.1 is presumably the fraction of cells
# sampled as neighbourhood index cells — verify against the milopy docs).
milo.make_nhoods(sub_sub, prop=0.1)
# Count cells per neighbourhood and sample, where one sample is one Slice_ID.
milo.count_nhoods(sub_sub, sample_col="Slice_ID")
# Differential-abundance test with timepoint as the only term in the design.
milo.DA_nhoods(sub_sub, design="~ timepoint")
# Per-neighbourhood table stored by the steps above; not used again in the visible cells.
milo_results = sub_sub.uns["nhood_adata"].obs
milopy.utils.build_nhood_graph(sub_sub)

In [None]:
# Neighbourhood graph for the D0 vs D35 comparison, filtered by the thresholds passed below.
milopy.plot.plot_nhood_graph(
    sub_sub,
    alpha=0.05,
    min_size=0.1,
    min_logFC=0.5,
    save='ct_ct_embedding_d0_d35.pdf',
)

In [None]:
# Annotate each neighbourhood from the cells' 'predicted_celltype'; the cells below read the
# resulting 'nhood_annotation' and 'nhood_annotation_frac' columns of uns['nhood_adata'].obs.
milopy.utils.annotate_nhoods(sub_sub, anno_col='predicted_celltype')

In [None]:
# Distribution of 'nhood_annotation_frac' across neighbourhoods.
nhood_annotation_frac = sub_sub.uns['nhood_adata'].obs["nhood_annotation_frac"]
plt.hist(nhood_annotation_frac);
plt.xlabel("celltype fraction")

In [None]:
# Relabel neighbourhoods whose annotation fraction is below 0.4 as "Mixed".
nhood_obs = sub_sub.uns['nhood_adata'].obs
low_purity = nhood_obs["nhood_annotation_frac"] < 0.4
nhood_obs.loc[low_purity, "nhood_annotation"] = "Mixed"

In [None]:
# Violin plot of neighbourhood logFC per annotation, for a selected set of cell types.
violin_order = [
    'Stem cells', 'TA', 'Colonocytes',
    'Fibro 2', 'Fibro 6', 'Fibro 4', 'Fibro 7', 'IAF 2',
    'Plasma cell',
]
sc.pl.violin(
    sub_sub.uns['nhood_adata'],
    "logFC",
    groupby="nhood_annotation",
    order=violin_order,
    rotation=90,
    show=False,
)
# Dashed reference line at logFC = 0.
plt.axhline(y=0, color='black', linestyle='--')
plt.tight_layout()
plt.savefig('figure4_new/milo_da_d0_d35.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Hand-picked fibroblast labels (and their display order) for the violin plot that follows.
fibroblast = [
    'Fibro 12', 'Fibro 5', 'Fibro 15', 'Fibro 7',
    'Fibro 2', 'Fibro 13', 'Fibro 6', 'Fibro 4',
    'Fibro 1', 'IAF 2', 'IAF 3',
]

In [None]:
# Violin plot of neighbourhood logFC per annotation, restricted to the labels in `fibroblast`.
sc.pl.violin(sub_sub.uns['nhood_adata'], "logFC", groupby="nhood_annotation", rotation=90, show=False,
             order=fibroblast)
# Dashed reference line at logFC = 0.
plt.axhline(y=0, color='black', linestyle='--')
plt.tight_layout()
# Saved under its own file name: the preceding violin cell already writes
# 'figure4_new/milo_da_d0_d35.pdf', and reusing that path overwrote its figure.
plt.savefig('figure4_new/milo_da_d0_d35_fibroblast.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Copy the neighbourhood logFC and FDR columns into the cell-level obs
# (pandas aligns on the index; rows without a matching label become NaN).
nhood_obs = sub_sub.uns['nhood_adata'].obs
for obs_key, nhood_key in (('milo_lfc', 'logFC'), ('FDR', 'FDR')):
    sub_sub.obs[obs_key] = nhood_obs[nhood_key]

In [None]:
from sklearn.impute import KNNImputer

In [None]:
# Latent coordinates plus a 'milo_lfc' column, used as input to the KNN imputer.
# The frame must be indexed by cell name: with the default RangeIndex, assigning the
# obs-indexed Series below aligns on labels that never match, so the whole
# 'milo_lfc' column would become NaN.
latent = pd.DataFrame(sub_sub.obsm['X_resolVI'], index=sub_sub.obs_names)
# Use string column names throughout; recent scikit-learn rejects frames whose
# column names mix int and str.
latent.columns = latent.columns.astype(str)
# logFC where FDR < 0.05, 0 where it is not; cells without a logFC stay NaN and are imputed later.
latent['milo_lfc'] = sub_sub.obs['milo_lfc'] * (sub_sub.obs['FDR'] < 0.05)

In [None]:
# Let pandas display up to 500 rows when rendering a frame.
pd.set_option("display.max_rows", 500)

In [None]:
# Fill the missing 'milo_lfc' values from the 20 nearest rows of `latent`.
imputer = KNNImputer(n_neighbors=20)
b = imputer.fit_transform(latent)
# KNNImputer removes columns that are entirely NaN at fit time. If that happened,
# the last output column would no longer be 'milo_lfc', so fail loudly instead of
# silently storing a latent dimension.
if b.shape[1] != latent.shape[1]:
    raise ValueError(
        "KNNImputer dropped at least one all-NaN column; "
        "check that latent['milo_lfc'] contains observed values."
    )
# 'milo_lfc' is the last column of `latent`, hence the last column of the output.
sub_sub.obs['milo_lfc_imputed'] = b[:, -1]

In [None]:
# Keep imputed logFC values with |logFC| > 1.5 and blank out the rest (NaN).
# Vectorised replacement for the per-element Python loop; the column is always float.
sub_sub.obs['milo_lfc_thresholded'] = sub_sub.obs['milo_lfc_imputed'].where(
    sub_sub.obs['milo_lfc_imputed'].abs() > 1.5
)

In [None]:
# Keep the 'Tier3' label for a selected set of cell types; every other cell gets None.
highlighted_celltypes = [
    'Stem cells', 'TA', 'Colonocytes',
    'Fibro 2', 'Fibro 6', 'Fibro 4', 'Fibro 7', 'IAF 2',
    'Plasma cell',
]
sub_sub.obs['highlight_celltype'] = [
    label if label in highlighted_celltypes else None
    for label in sub_sub.obs['Tier3']
]

In [None]:
# Spatial map of the thresholded Milo logFC for the slice '082421_D0_m6_1_slice_1'.
d0_slice = sub_sub[sub_sub.obs['Slice_ID']=='082421_D0_m6_1_slice_1']
sc.pl.spatial(
    d0_slice,
    color=['milo_lfc_thresholded'],
    layer='generated_expression',
    title='082421_D0_m6_1_slice_1',
    spot_size=8,
    ncols=1,
    cmap='seismic',
    vmin=-5,
    vmax=5,
    save='d0_milo_spatial_plot.pdf',
)

In [None]:
# Spatial map of the thresholded Milo logFC for the slice '072523_D35_m6_1_slice_3'.
d35_slice = sub_sub[sub_sub.obs['Slice_ID']=='072523_D35_m6_1_slice_3']
sc.pl.spatial(
    d35_slice,
    color=['milo_lfc_thresholded'],
    layer='generated_expression',
    title='072523_D35_m6_1_slice_3',
    spot_size=8,
    ncols=1,
    cmap='seismic',
    vmin=-5,
    vmax=5,
    save='d35_milo_spatial_plot.pdf',
)

In [None]:
# 'generated_expression' of selected genes in the epithelial cells of slice
# '082421_D0_m6_1_slice_1'; the vmax list gives one colour-scale maximum per gene.
d0_epithelial = (
    (sub.obs['Slice_ID'] == '082421_D0_m6_1_slice_1')
    & (sub.obs['predicted_celltype_coarse'] == 'Epithelial')
)
with plt.rc_context({"figure.figsize": (8, 8), "figure.dpi": 300}):
    sc.pl.spatial(
        sub[d0_epithelial],
        layer='generated_expression',
        color=[
            'Cldn23', 'Bmp2', 'Il22ra1', 'Cldn4', 'Oasl1',
            'Edn1', 'Timp3', 'Bmp3', 'Dusp1', 'Itgb6',
            'Tnfaip3', 'Itgav', 'Nf2', 'Yap1', 'Ltbr',
        ],
        vmax=[
            8, 18, 14, 14, 14,
            16, 8, 4, 5, 4,
            5, 4, 6, 5, 4,
        ],
        spot_size=15,
        ncols=5,
        cmap='Reds',
        save='epithelial_gexp_d0.pdf',
    )

In [None]:
# 'generated_expression' of selected genes in the epithelial cells of slice
# '072523_D35_m6_1_slice_3'; the vmax list gives one colour-scale maximum per gene.
d35_epithelial = (
    (sub.obs['Slice_ID'] == '072523_D35_m6_1_slice_3')
    & (sub.obs['predicted_celltype_coarse'] == 'Epithelial')
)
with plt.rc_context({"figure.figsize": (8, 8), "figure.dpi": 300}):
    sc.pl.spatial(
        sub[d35_epithelial],
        layer='generated_expression',
        color=[
            'Cldn23', 'Bmp2', 'Il22ra1', 'Cldn4', 'Oasl1',
            'Edn1', 'Timp3', 'Bmp3', 'Dusp1', 'Itgb6',
            'Tnfaip3', 'Itgav', 'Nf2', 'Yap1', 'Ltbr',
        ],
        vmax=[
            8, 18, 14, 14, 14,
            16, 8, 4, 5, 4,
            5, 4, 6, 5, 4,
        ],
        spot_size=15,
        ncols=5,
        cmap='Reds',
        save='epithelial_gexp_d35.pdf',
    )

In [None]:
# 'generated_expression' of selected genes in the fibroblasts of slice
# '082421_D0_m6_1_slice_1'; the vmax list gives one colour-scale maximum per gene.
d0_fibroblast = (
    (sub_sub.obs['Slice_ID'] == '082421_D0_m6_1_slice_1')
    & (sub_sub.obs['predicted_celltype_coarse'] == 'Fibroblast')
)
with plt.rc_context({"figure.figsize": (8, 8), "figure.dpi": 300}):
    sc.pl.spatial(
        sub_sub[d0_fibroblast],
        layer='generated_expression',
        color=[
            'Vegfa', 'Tnc', 'F3', 'Adamdec1', 'Col27a1',
            'Igfbp5', 'Tgfbr3', 'Bmp2', 'Bmp5', 'Bmp7',
        ],
        vmax=[
            4, 2, 2, 8, 2,
            8, 4, 1, 3, 1.5,
        ],
        spot_size=15,
        ncols=5,
        cmap='Reds',
        save='fibroblast_gexp_d0.pdf',
    )

In [None]:
# 'generated_expression' of selected genes in the fibroblasts of slice
# '072523_D35_m6_1_slice_3'; the vmax list gives one colour-scale maximum per gene.
d35_fibroblast = (
    (sub_sub.obs['Slice_ID'] == '072523_D35_m6_1_slice_3')
    & (sub_sub.obs['predicted_celltype_coarse'] == 'Fibroblast')
)
with plt.rc_context({"figure.figsize": (8, 8), "figure.dpi": 300}):
    sc.pl.spatial(
        sub_sub[d35_fibroblast],
        layer='generated_expression',
        color=[
            'Vegfa', 'Tnc', 'F3', 'Adamdec1', 'Col27a1',
            'Igfbp5', 'Tgfbr3', 'Bmp2', 'Bmp5', 'Bmp7',
        ],
        vmax=[
            4, 2, 2, 8, 2,
            8, 4, 1, 3, 1.5,
        ],
        spot_size=15,
        ncols=5,
        cmap='Reds',
        save='fibroblast_gexp_d35.pdf',
    )

In [None]:
# Independent copy of the D0/D35 subset for a second Milo run, so `sub_sub` keeps its own results.
sub_sub2 = sub_sub.copy()
# Store the t-SNE coordinates under the 'X_umap' key.
# NOTE(review): presumably so that downstream plotting, which looks up 'X_umap', lays out on t-SNE — confirm.
sub_sub2.obsm['X_umap'] = sub_sub2.obsm['X_tsne']

In [None]:
# Second Milo run: same steps as for `sub_sub`, but with neighbours computed on
# obsm['celltypes_neighborhood'] instead of the resolVI latent space.
# Store the values of obsm['celltypes_neighborhood'] as a plain array under a new key for use_rep.
sub_sub2.obsm['celltypes_neighborhood_'] = sub_sub2.obsm['celltypes_neighborhood'].values
# NOTE(review): method='rapids' presumably needs a GPU / RAPIDS install — confirm.
sc.pp.neighbors(sub_sub2, n_neighbors=100, use_rep='celltypes_neighborhood_', method='rapids')
milo.make_nhoods(sub_sub2, prop=0.1)
# Samples are slices (Slice_ID); the design contains timepoint only.
milo.count_nhoods(sub_sub2, sample_col="Slice_ID")
milo.DA_nhoods(sub_sub2, design="~ timepoint")
# Rebinds milo_results to this run's table; not used again in the visible cells.
milo_results = sub_sub2.uns["nhood_adata"].obs
milopy.utils.build_nhood_graph(sub_sub2)
# NOTE(review): `save` has no file extension here, unlike the '.pdf' used for the
# first neighbourhood-graph plot — confirm which output format is intended.
milopy.plot.plot_nhood_graph(sub_sub2, alpha=0.05, min_size=0.1, min_logFC=0.5, save='niche_d0_d35')