In [None]:
# Load packages
import warnings
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import squidpy as sq
from loguru import logger
import seaborn as sns

import scvi
from scvi.external.stereoscope import RNAStereoscope, SpatialStereoscope
from scvi.model import CondSCVI, DestVI
import torch
from scvi.external import Tangram
from tqdm import tqdm
from scipy.sparse import csr_matrix


# Read the spatial transcriptomics (ST) and single-cell RNA-seq (scRNA) data

In [None]:
# Load the single-cell reference dataset and make its gene names unique.
sc_adata = sc.read_h5ad(
    r"C:\Users\rafaelo\OneDrive - NTNU\Documents\Projects\STNav\data\processed\PipelineRun_2024_06_03-11_53_39_AM\scRNA\Files\raw_adata.h5ad"
)
sc_adata.var_names_make_unique()

# Gene filter: keep only genes that pass the min_counts=10 threshold.
sc.pp.filter_genes(sc_adata, min_counts=10)

# Snapshot X into a "counts" layer before X is normalised further down.
sc_adata.layers["counts"] = sc_adata.X.copy()

# Pick the 10,000 most variable genes (computed on the "counts" layer)
# and subset the object to them.
sc.pp.highly_variable_genes(
    sc_adata,
    n_top_genes=10000,
    subset=True,
    layer="counts",
    flavor="seurat_v3",
)

# Normalise X per cell, then apply log1p.
# NOTE(review): 10e4 is 1e5 (100,000), not 1e4 -- confirm this target is intended.
sc.pp.normalize_total(sc_adata, target_sum=10e4)
sc.pp.log1p(sc_adata)

# Store the current (normalised, log-transformed) state in .raw.
sc_adata.raw = sc_adata

In [None]:
# Load the spatial transcriptomics dataset and make its gene names unique.
st_adata = sc.read_h5ad(
    r"C:\Users\rafaelo\OneDrive - NTNU\Documents\Projects\STNav\data\processed\PipelineRun_2024_06_03-11_53_39_AM\ST\Files\raw_adata.h5ad"
)
st_adata.var_names_make_unique()

# Optional gene/cell filtering, currently disabled:
# sc.pp.filter_genes(st_adata, min_counts=15)
# sc.pp.filter_cells(st_adata, min_genes=5)

# Snapshot X into a "counts" layer before X is normalised below.
st_adata.layers["counts"] = st_adata.X.copy()

# Normalise X per observation, then apply log1p.
# NOTE(review): 10e4 is 1e5 (100,000), not 1e4 -- confirm this target is intended.
sc.pp.normalize_total(st_adata, target_sum=10e4)
sc.pp.log1p(st_adata)

# Store the current (normalised, log-transformed) state in .raw.
st_adata.raw = st_adata

In [None]:
# Restrict both datasets to the genes they have in common, so the
# single-cell and spatial objects share the same feature set.

# Make gene names unique *before* computing the intersection, so the names
# used for indexing below are exactly the ones that were intersected.
st_adata.var_names_make_unique()
sc_adata.var_names_make_unique()

intersect = np.intersect1d(sc_adata.var_names, st_adata.var_names)
if len(intersect) == 0:
    # Fail early with a clear message instead of producing empty objects
    # that only break later during model setup/training.
    raise ValueError(
        "No genes are shared between the scRNA and ST datasets; "
        "check that both use the same gene identifiers."
    )

# .copy() materialises the subset instead of keeping a view.
st_adata = st_adata[:, intersect].copy()
sc_adata = sc_adata[:, intersect].copy()

# Number of shared genes.
G = len(intersect)

In [None]:
# Cast the spatial "counts" layer to float64 (replaces the layer in place on st_adata).
st_adata.layers["counts"] = st_adata.layers["counts"].astype('float64')

In [None]:
# Register the single-cell data with CondSCVI: use the "counts" layer as input
# and the transferred level-3 annotation column as the label key.
CondSCVI.setup_anndata(
    sc_adata,
    layer="counts",
    labels_key="ann_level_3_transferred_label",
)

# Build the single-cell model, show its registered setup, and train it
# with the library's default training settings.
sc_model = CondSCVI(sc_adata, weight_obs=False)
sc_model.view_anndata_setup()
sc_model.train()

# Plot the recorded training ELBO, skipping the first five entries.
elbo_train = sc_model.history["elbo_train"]
elbo_train.iloc[5:].plot()
plt.show()


In [None]:
# sq.gr.spatial_neighbors(st_adata)
# st_adata.obsp["spatial_connectivities"] = csr_matrix(st_adata.obsp["spatial_connectivities"])

# st_adata.X = st_adata.obsp["spatial_connectivities"].dot(st_adata.X)
# st_adata.layers["counts"] = st_adata.X

In [None]:
# Sanity check: display row 150, columns 10..149 of the spatial "counts" layer.
# The .toarray() call assumes the layer is stored as a sparse matrix -- TODO confirm.
st_adata.layers["counts"][150,10:150].toarray()

In [None]:
# Cast the spatial "counts" layer to int32 before it is registered with DestVI.
# NOTE(review): any non-integer values would be truncated by this cast -- confirm
# the layer holds integer counts at this point.
st_adata.layers["counts"] = st_adata.layers["counts"].astype('int32')

In [None]:
# Register the spatial data with DestVI, using the "counts" layer as input.
DestVI.setup_anndata(
    st_adata,
    layer="counts",
)

# Build the spatial model from the trained single-cell model.
st_model = DestVI.from_rna_model(
    st_adata,
    sc_model,
)

# Show what was registered for the spatial model.
st_model.view_anndata_setup()

In [None]:
# Train the spatial DestVI model for at most 2500 epochs (long-running cell).
st_model.train(max_epochs=2500)


In [None]:
# Store the model's estimated proportions on the spatial object under obsm["proportions"].
st_adata.obsm["proportions"] = st_model.get_proportions()


In [None]:
# Preview the first five rows of the stored proportions table.
st_adata.obsm["proportions"].head(5)


In [None]:
# Cell types to visualise. Each entry must be a column of the proportions table.
# NOTE(review): presumably these must match the categories of the label column
# used to train the single-cell model -- confirm against the preview above.
ct_list = ["B cells", "CD8 T cells", "Monocytes"]

proportions = st_adata.obsm["proportions"]

# Validate up front so a bad name fails before any obs column is written,
# and so the error lists what is actually available.
missing = [ct for ct in ct_list if ct not in proportions.columns]
if missing:
    raise KeyError(
        f"Cell types {missing} not found in proportions; "
        f"available columns: {list(proportions.columns)}"
    )

for ct in ct_list:
    values = proportions[ct].values
    # Clip to [0, 99th percentile] so a few extreme spots do not dominate
    # the colour scale when plotting.
    st_adata.obs[ct] = np.clip(values, 0, np.quantile(values, 0.99))

In [None]:
# Plot the "spatial" embedding once per entry in ct_list, coloured with the "Reds"
# colormap and marker size 80. The preceding cell writes an obs column for each of
# these names, which is presumably what `color` resolves to -- confirm.
sc.pl.embedding(st_adata, basis="spatial", color=ct_list, cmap="Reds", s=80)
