In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc

In [None]:
# Input object: Xenium replicates plus reference sections ("no_peyers" in the
# filename suggests Peyer's patches were already removed -- TODO confirm).
# NOTE(review): absolute machine-local path; make configurable before sharing.
ADATA_PATH = '/mnt/sata1/Analysis_Alex/timecourse_replicates/analysis/cleaned/full_xenium_replicates_and_reference_no_peyers.h5ad'
adata = sc.read(ADATA_PATH)

In [None]:
# Timepoint -> the batch (sample) names belonging to it; two batches per timepoint.
batches = {
    'day6': ['day6_SI', 'day6_SI_r2'],
    'day8': ['day8_SI_Ctrl', 'day8_SI_r2'],
    'day30': ['day30_SI', 'day30_SI_r2'],
    'day90': ['day90_SI', 'day90_SI_r2'],
}

In [None]:

# Reverse lookup: batch name -> timepoint key. If a batch were listed under
# several timepoints, the last one in `batches` order would win.
inverted_batches = {
    batch_name: timepoint_key
    for timepoint_key, batch_names in batches.items()
    for batch_name in batch_names
}

In [None]:
# Tag each cell with its timepoint (None for batches not in `inverted_batches`)
# and keep only cells from mapped batches.
adata.obs['timepoint'] = [inverted_batches.get(i) for i in adata.obs['batch']]
# .copy() materialises the subset. Without it `adata` stays an AnnData view of
# the full object, and later writes to .obs must implicitly convert the view
# (with a warning) while the whole parent object is kept alive in memory.
adata = adata[~adata.obs['timepoint'].isna()].copy()

In [None]:
def transformation(x, a=0.1, b=0.1, c=0.5, d=2.5, f=4, w=1):
    """Map values element-wise through a double-exponential curve.

    Computes ``a * exp(b * (x - w)) - c * exp(-d * (x - w)) + f``. For positive
    ``a, b, c, d`` the curve is strictly increasing in ``x``; ``w`` shifts it
    along the x-axis and ``f`` shifts it vertically.

    Parameters
    ----------
    x : array-like or scalar
        Input values (e.g. a pandas Series); converted to a numpy array.
    a, b, c, d, f, w : float
        Curve parameters (see formula above).

    Returns
    -------
    numpy.ndarray (or numpy scalar for scalar input) of transformed values.
    """
    offset = np.asarray(x) - w
    growth = a * np.exp(b * offset)
    decay = c * np.exp(-d * offset)
    return growth - decay + f
def filter_adata_expressed_in_n_cells(adata, percent=0.01):
    """Keep only genes detected (value > 0) in more than a given fraction of cells.

    Parameters
    ----------
    adata : AnnData-like
        Object with a cells x genes matrix ``.X`` (dense numpy or scipy-sparse)
        that supports column subsetting via ``adata[:, mask]``.
    percent : float, default 0.01
        Minimum detection rate as a *fraction* of cells (0.01 == 1 %), despite
        the function/parameter names. A gene is kept only if its detection
        rate is strictly greater than this value.

    Returns
    -------
    ``adata[:, keep]`` -- for AnnData this is typically a view, so call
    ``.copy()`` before modifying it.
    """
    detected = adata.X > 0
    # With scipy-sparse matrices, .mean(axis=0) can come back as a (1, n_genes)
    # np.matrix; flatten so the column mask is always a plain 1-D boolean array
    # instead of relying on the container to special-case 2-D masks.
    detection_rate = np.asarray(np.mean(detected, axis=0)).ravel()
    keep = detection_rate > percent
    return adata[:, keep]


In [None]:
# Per-batch gating of cells into 'IEL' vs 'LP' by thresholding the transformed
# epithelial distance; each cutoff is chosen by eye from a scatter plot of that
# distance against the crypt-villus axis, coloured by Epcam.
# NOTE(review): results that depend on input() cannot be reproduced with
# "Restart & Run All". Fill LP_CUTOFFS with {batch_name: cutoff} to skip the
# prompt; batches not listed are still prompted for, and all cutoffs used are
# printed at the end so they can be pasted back here.
LP_CUTOFFS = {}

test_adatas = []  # one .obs DataFrame per batch, each carrying a 'condition' column
unique_batches = np.unique(adata.obs['batch'])

for batch in unique_batches:
    batch_key = str(batch)
    print(batch)
    sub_adata = filter_adata_expressed_in_n_cells(adata[adata.obs['batch'] == batch]).copy()

    sc.pp.normalize_total(sub_adata, target_sum=1e4)
    sc.pp.log1p(sub_adata)

    sub_adata.obs["epithelial_distance_transformed"] = transformation(
        sub_adata.obs["epithelial_distance_clipped"]
    )

    # Epcam column as a flat 1-D vector; densify first if X is sparse, since
    # matplotlib cannot use a sparse matrix as a colour array.
    epcam = sub_adata.X[:, sub_adata.var.index == 'Epcam']
    if hasattr(epcam, 'toarray'):
        epcam = epcam.toarray()
    epcam = np.asarray(epcam).ravel()

    fig, ax = plt.subplots()
    # NOTE(review): after normalize_total(target_sum=1e4) + log1p every value is
    # at most log1p(1e4) ~= 9.2, so vmax=40 leaves most of the colour map
    # unused -- confirm this is intended.
    ax.scatter(
        sub_adata.obs['epithelial_distance_transformed'],
        sub_adata.obs['crypt_villi_axis'],
        c=epcam,
        s=2,
        vmax=40,
    )
    ax.set(title=batch_key, xlabel='epithelial_distance_transformed', ylabel='crypt_villi_axis')
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per batch

    if batch_key in LP_CUTOFFS:
        cutoff = float(LP_CUTOFFS[batch_key])
    else:
        # Parse immediately so a typo fails here, not halfway through labelling.
        cutoff = float(input('Enter cutoff for LP cells: '))
        LP_CUTOFFS[batch_key] = cutoff

    # Strictly below the cutoff -> 'IEL'; everything else -> 'LP' (NaN distances
    # compare False and therefore end up as 'LP').
    sub_adata.obs['condition'] = np.where(
        sub_adata.obs['epithelial_distance_transformed'] < cutoff, 'IEL', 'LP'
    )

    test_adatas.append(sub_adata.obs)

print('LP cutoffs used:', LP_CUTOFFS)



In [None]:
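# NOTE(review): this whole cell is commented out and never runs; delete it or
# restore it deliberately. If restored: `batches = []` would overwrite the
# timepoint -> batch-name dict defined earlier, the mean_* lists collect raw
# value arrays rather than means, and the `for group in groupby` loop appends
# the same IEL/LP arrays once per group.
# from tqdm.notebook import tqdm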

# conditions = []
# batches = []
# mean_cxcl9_IEL = []
# mean_cxcl9_LP = []
# mean_cxcl10_IEL = []
# mean_cxcl10_LP = []

# for j in tqdm(range(len(unique_batches))):
#     ad_ = adata[adata.obs['batch'] == unique_batches[j]]
#     sub_adata = filter_adata_expressed_in_n_cells(ad_)
#     sub_adata = sub_adata.copy()

#     sc.pp.normalize_total(sub_adata, target_sum=1e4)
#     sc.pp.log1p(sub_adata)

#     sub_adata.obs = test_adatas[j]

#     current_df = sub_adata.obs.copy()

#     groupby = current_df.groupby('condition')
#     for group in groupby:
#         conditions.append(group[0])
#         batches.append(unique_batches[j])
#         ids_iel = current_df['condition'] == 'IEL'
#         ids_lp = current_df['condition'] == 'LP'
#         mean_cxcl9_IEL.append(sub_adata.X[ids_iel, sub_adata.var.index == 'Cxcl9'])
#         mean_cxcl9_LP.append(sub_adata.X[ids_lp, sub_adata.var.index == 'Cxcl9'])
#         mean_cxcl10_IEL.append(sub_adata.X[ids_iel, sub_adata.var.index == 'Cxcl10'])
#         mean_cxcl10_LP.append(sub_adata.X[ids_lp, sub_adata.var.index == 'Cxcl10'])

In [None]:
# Attach the per-batch IEL/LP labels to the full object, aligned on cell names.
# An index-aligned assignment (rather than DataFrame.merge) keeps this cell
# re-runnable: merging a second time -- or onto an object that already has a
# 'condition' column -- yields 'condition_x'/'condition_y' and no 'condition'.
condition_labels = pd.concat(test_adatas, axis=0)['condition']
assert condition_labels.index.is_unique, 'duplicate cell names across batches'
adata.obs['condition'] = condition_labels.reindex(adata.obs.index)

In [None]:
# Grouping key "<timepoint>_<condition>" (e.g. "day6_IEL") for the dot plot.
# Stored as a categorical in chronological order: left as plain strings it
# would presumably be converted by scanpy into a categorical with lexically
# sorted categories, which puts day30 before day6.
TIMEPOINT_ORDER = ['day6', 'day8', 'day30', 'day90']
CONDITION_ORDER = ['IEL', 'LP']

timepoint_region = (
    adata.obs['timepoint'].astype(str) + '_' + adata.obs['condition'].astype(str)
)
preferred_order = [f'{t}_{c}' for t in TIMEPOINT_ORDER for c in CONDITION_ORDER]
present = set(timepoint_region)
# Only categories that actually occur; unexpected combinations (e.g. a 'nan'
# condition) are appended at the end instead of being silently dropped.
categories = [g for g in preferred_order if g in present] + sorted(present - set(preferred_order))
adata.obs['timepoint_region'] = timepoint_region.astype(pd.CategoricalDtype(categories))

In [None]:
# Dot plot of Cxcl9 / Cxcl10 per timepoint x compartment, saved as a PDF panel.
# NOTE(review): set_figure_params changes scanpy/matplotlib defaults for every
# later plot in this kernel, including dpi=400 for inline display.
# NOTE(review): plots adata.X of the full object (use_raw=False); the
# normalize_total/log1p calls earlier only modified per-batch copies -- confirm
# adata.X is already on the intended scale.
FIGURE_DIR = Path('figures')
FIGURE_DIR.mkdir(parents=True, exist_ok=True)  # savefig fails if the folder is missing

sc.set_figure_params(figsize=(20, 20), dpi=400)
fig = sc.pl.dotplot(
    adata,
    var_names=['Cxcl9', 'Cxcl10'],
    groupby='timepoint_region',
    use_raw=False,
    log=True,
    return_fig=True,
    show=False,
)
fig.savefig(str(FIGURE_DIR / 'panel_f_dynamics.pdf'))
plt.show()