In [None]:
import sys
import scanpy as sc
import anndata
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import os
import cell2location
import scvi
from matplotlib import rcParams
rcParams['pdf.fonttype'] = 42 # enables correct plotting of text
import seaborn as sns

In [None]:
# Output locations for this analysis.
# `results_folder` already ends with a path separator, so the sub-folders are
# built with os.path.join to avoid a doubled "//" in the resulting paths.
results_folder = './results/mouseDSS/'
# Folder holding the reference-signature regression model and its AnnData.
ref_run_name = os.path.join(results_folder, 'reference_signatures')
# Folder holding the spatial cell2location model and its AnnData.
run_name = os.path.join(results_folder, 'cell2location_map')

In [None]:
# Load the reference AnnData; later cells register its 'donor' and 'Cluster'
# columns with scvi and fit the RegressionModel on it.
adata_ref = sc.read_h5ad("mouseDSSstromal.h5ad")

In [None]:
# Load the Visium sample and compute QC metrics.
adata_vis = sc.read_visium("./rawfiles/V19S23-097_B1/")
adata_vis.var_names_make_unique()
# Flag genes whose name starts with "mt-" so that calculate_qc_metrics
# adds the pct_counts_mt column used by the filter below.
adata_vis.var["mt"] = adata_vis.var_names.str.startswith("mt-")
sc.pp.calculate_qc_metrics(adata_vis, qc_vars=["mt"], inplace=True)

# Filter observations by total counts (lower bound, then upper bound).
sc.pp.filter_cells(adata_vis, min_counts=4000)
print(f"#cells after min filter: {adata_vis.n_obs}")
sc.pp.filter_cells(adata_vis, max_counts=32000)
print(f"#cells after max filter: {adata_vis.n_obs}")
# Boolean indexing returns a view of the parent AnnData; copy it so the
# in-place operations that follow act on a real object.
adata_vis = adata_vis[adata_vis.obs["pct_counts_mt"] < 20].copy()
print(f"#cells after MT filter: {adata_vis.n_obs}")
sc.pp.filter_genes(adata_vis, min_cells=10)

# Keep only the genes listed in genestokeep.csv: lncRNA and mitochondrial
# protein-coding genes are removed, as in the original paper.
allowed_genes = list(pd.read_csv('genestokeep.csv')["Genes"])
# np.isin replaces the deprecated np.in1d; the result is the same for 1-D input.
keep_mask = np.isin(adata_vis.var_names.values.astype(str), allowed_genes)
# Copy again: assigning to .obs on a view would trigger an implicit copy
# (and an ImplicitModificationWarning).
adata_vis = adata_vis[:, keep_mask].copy()
# Label every observation with the first key of uns['spatial']; this column
# is used later as the batch key for the spatial model.
adata_vis.obs['sample'] = list(adata_vis.uns['spatial'].keys())[0]
del allowed_genes, keep_mask

In [None]:
# Register the reference AnnData with scvi before fitting the regression model:
#   batch_key  -> obs column 'donor' (10X reaction / sample / batch)
#   labels_key -> obs column 'Cluster' (cell type; covariate used for
#                 constructing signatures)
scvi.data.setup_anndata(
    adata=adata_ref,
    batch_key='donor',
    labels_key='Cluster',
)
# Print a summary of what was registered.
scvi.data.view_anndata_setup(adata_ref)

In [None]:
from cell2location.models import RegressionModel

# Build the regression model on the reference data and train it on the full
# dataset (train_size=1) using the GPU.
mod = RegressionModel(adata_ref)
mod.train(
    max_epochs=250,
    batch_size=2500,
    train_size=1,
    lr=0.001,
    use_gpu=True,
)
# Plot the training history; NOTE(review): the argument 20 is presumably the
# number of initial epochs to skip in the plot — confirm against cell2location docs.
mod.plot_history(20)

In [None]:
# Export the posterior of the trained regression model into adata_ref.
adata_ref = mod.export_posterior(
    adata_ref,
    sample_kwargs={
        'num_samples': 1000,
        'batch_size': 2500,
        'use_gpu': True,
    },
)

# Save the model and the annotated reference under ref_run_name.
mod.save(ref_run_name, overwrite=True)
adata_file = f"{ref_run_name}/sc.h5ad"
adata_ref.write(adata_file)

# Load both back from disk, so the following cells work with the saved state.
mod = cell2location.models.RegressionModel.load(ref_run_name, adata_ref)
adata_ref = sc.read_h5ad(adata_file)

In [None]:
# Collect the per-cluster mean expression columns (one per factor) into
# inf_aver. They are read from varm['means_per_cluster_mu_fg'] when that key
# exists, otherwise from adata_ref.var.
factor_names = adata_ref.uns['mod']['factor_names']
signature_columns = [f'means_per_cluster_mu_fg_{name}' for name in factor_names]

signature_source = (
    adata_ref.varm['means_per_cluster_mu_fg']
    if 'means_per_cluster_mu_fg' in adata_ref.varm.keys()
    else adata_ref.var
)
inf_aver = signature_source[signature_columns].copy()

# Rename the columns to the bare factor names and preview the top-left corner.
inf_aver.columns = factor_names
inf_aver.iloc[0:5, 0:5]

In [None]:
# Genes present in both the spatial data and the reference signatures.
intersect = np.intersect1d(adata_vis.var_names, inf_aver.index)

# Subset both objects to the shared genes, in the same order.
inf_aver = inf_aver.loc[intersect, :].copy()
adata_vis = adata_vis[:, intersect].copy()

# Register the spatial AnnData with scvi, using obs['sample'] as batch key,
# and print the resulting setup summary.
scvi.data.setup_anndata(
    adata=adata_vis,
    batch_key="sample",
)
scvi.data.view_anndata_setup(adata_vis)

In [None]:
# Build the spatial mapping model from the Visium data and the reference
# signatures in inf_aver.
mod = cell2location.models.Cell2location(
    adata_vis,
    cell_state_df=inf_aver,
    detection_alpha=200,
)

# Train on the full dataset: train_size=1 and batch_size=None (no minibatching).
mod.train(
    max_epochs=30000,
    batch_size=None,
    train_size=1,
    use_gpu=True,
)

# Plot the training history; NOTE(review): 1000 is presumably the number of
# initial epochs omitted from the plot — confirm against cell2location docs.
mod.plot_history(1000)
plt.legend(labels=['full data training']);

In [None]:
# Export the posterior of the spatial model into adata_vis; the sampling batch
# size is set to the number of observations held by the model's AnnData.
adata_vis = mod.export_posterior(
    adata_vis,
    sample_kwargs={
        'num_samples': 1000,
        'batch_size': mod.adata.n_obs,
        'use_gpu': True,
    },
)

# Save the model and the annotated spatial data under run_name.
mod.save(run_name, overwrite=True)
adata_file = f"{run_name}/sp.h5ad"
adata_vis.write(adata_file)
# Display the path that was written.
adata_file

In [None]:
# Load the saved spatial model and the saved AnnData back from run_name.
adata_file = f"{run_name}/sp.h5ad"
mod = cell2location.models.Cell2location.load(run_name, adata_vis)
adata_vis = sc.read_h5ad(adata_file)

In [None]:
# Copy the 'q05_cell_abundance_w_sf' table from obsm into obs, using the
# factor names stored in uns['mod'] as the obs column names.
abundance_columns = adata_vis.uns['mod']['factor_names']
adata_vis.obs[abundance_columns] = adata_vis.obsm['q05_cell_abundance_w_sf']

In [None]:
# Attach the posterior summary stored in adata_vis.uns['mod'] to the model;
# the next cell reads mod.samples["post_sample_q05"] from it.
mod.samples = adata_vis.uns['mod']

In [None]:
# Compute expected expression per cell type from the "post_sample_q05"
# posterior summary and the model's AnnData.
q05_samples = mod.samples["post_sample_q05"]
expected_dict = mod.module.model.compute_expected_per_cell_type(q05_samples, mod.adata)

# Store one layer per factor, named after the factor, holding the matching
# entry of expected_dict['mu'].
expected_mu = expected_dict['mu']
for factor_idx, factor_name in enumerate(mod.factor_names_):
    adata_vis.layers[factor_name] = expected_mu[factor_idx]

In [None]:
# Write the final spatial AnnData (including the obs columns and layers added
# in the cells above) to disk.
adata_vis.write('DSSmousewithcelltypredictions.h5ad')