In [None]:
## Notebook: cluster region-specific AnnData files using marker-gene seed labels, scANVI, and scVIVA ###

In [None]:
import os
from pathlib import Path
import ast

import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
import scvi

from spida.P.setup_adata import multi_round_clustering, _calc_embeddings

import matplotlib.pyplot as plt
import seaborn as sns
from spida.pl import plot_categorical, plot_continuous
plt.rcParams['axes.facecolor'] = 'white'

# Seed scvi-tools so the model training cells below are repeatable after
# Restart Kernel -> Run All.
scvi.settings.seed = 0

from datetime import datetime
# Timestamp used as a prefix for every figure file written by this run.
# NOTE: the ":" in the format is not a valid filename character on every filesystem.
current_datetime = datetime.now().strftime("%Y-%m-%d_%H:%M")

In [None]:
#parameters
# Experiment code; selects the input AnnData and is looked up in EXP_TO_REF to pick the reference.
EXPERIMENT = "PU" 
# Prefix of the per-experiment directory name "{prefix}_{EXPERIMENT}".
prefix = "BICAN_BG"
# NOTE(review): `suffix` is not referenced in the cells visible here — confirm it is still needed.
suffix = "proseg_fv38_filt"
# Root directory containing the per-experiment AnnData files.
output_dir = "/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation"
# Root directory for figures; an EXPERIMENT sub-directory is appended later.
image_path = "/home/x-aklein2/projects/aklein/BICAN/BG/images/annotations"
# Root directory for saved models; an EXPERIMENT sub-directory is appended later.
model_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/models"

In [None]:
# Input AnnData for this experiment.
adata_path = Path(f"{output_dir}/{prefix}_{EXPERIMENT}/{EXPERIMENT}.h5ad")
# Output AnnData; written after the scANVI step and overwritten again after each scVIVA step.
scviva_adata_path = Path(f"{output_dir}/{prefix}_{EXPERIMENT}/{EXPERIMENT}_scVIVA.h5ad")

In [None]:
# Maps each experiment code to the reference region code used in the reference
# file names. Several experiment codes (CAH, CAB, CAT) share the "CA" reference.
EXP_TO_REF = {
    "PU": "PU",
    "CAH": "CA",
    "CAB": "CA",
    "CAT": "CA",
    "GP": "GP",
    "STH": "STH",
    "NAC": "NAC",
    "MGM": "MGM",
}
# Fail with an actionable message when the notebook is parameterised with an
# experiment that has no reference mapping (still a KeyError, as before).
if EXPERIMENT not in EXP_TO_REF:
    raise KeyError(
        f"Unknown EXPERIMENT {EXPERIMENT!r}; expected one of {sorted(EXP_TO_REF)}"
    )
REF_EXP = EXP_TO_REF[EXPERIMENT]

In [None]:
# Reference AnnData for the matched region (not read in the cells shown here).
ref_adata_path = Path(f"/home/x-aklein2/projects/aklein/BICAN/data/reference/AIT/AIT_{REF_EXP}_filtered.h5ad")

# Per-subclass DEG summary for the matched reference region.
deg_path = Path(f"/home/x-aklein2/projects/aklein/BICAN/data/reference/DEGs/{REF_EXP}/summary_subclass.csv")
degs = pd.read_csv(deg_path)

# Figures go into a per-experiment sub-directory.
# NOTE: this rebinds `image_path` / `model_path`, so re-running the cell without
# re-running the parameters cell appends EXPERIMENT a second time.
image_path = Path(image_path) / EXPERIMENT
image_path.mkdir(parents=True, exist_ok=True)

# Base path for saved models; later cells swap the suffix per model type.
model_path = Path(model_path) / EXPERIMENT / "clustering.pt"
model_path.parent.mkdir(parents=True, exist_ok=True)

In [None]:
#papermill_description=Reading AnnData
# Load the per-experiment AnnData; the bare `adata` below shows its summary as the cell output.
adata = ad.read_h5ad(adata_path)
adata

## Marker Genes

In [None]:
# From SCVI SCANVI tutorial 
def get_score(normalized_adata, gene_set):
    """Returns the score per cell given a dictionary of + and - genes.

    Works whether ``.X`` is a sparse matrix or a dense array (``.X`` is dense
    after a zero-centred ``sc.pp.scale``).

    Parameters
    ----------
    normalized_adata
      anndata dataset that has been log normalized and scaled to mean 0, std 1
    gene_set
      a dictionary with two keys: 'positive' and 'negative'
      each key should contain a list of genes
      for each gene in gene_set['positive'], its expression will be added to the score
      for each gene in gene_set['negative'], its expression will be subtracted from its score

    Returns
    -------
    array of length of n_cells containing the score per cell
    """

    def _gene_expression(gene):
        # Pull one gene's column and return it as a flat dense vector.
        values = normalized_adata[:, gene].X
        if hasattr(values, "toarray"):
            # Sparse matrices need an explicit densify; ndarrays have no .toarray().
            values = values.toarray()
        return np.asarray(values).ravel()

    score = np.zeros(normalized_adata.n_obs)
    for gene in gene_set["positive"]:
        score += _gene_expression(gene)
    for gene in gene_set["negative"]:
        score -= _gene_expression(gene)
    return score


def get_cell_mask(normalized_adata, gene_set, n_top=20):
    """Get cell mask.

    Calculates the score per cell for a list of genes, then returns a mask for
    the cells with the highest ``n_top`` scores.

    Parameters
    ----------
    normalized_adata
      anndata dataset that has been log normalized and scaled to mean 0, std 1
    gene_set
      a dictionary with two keys: 'positive' and 'negative'
      each key should contain a list of genes
      for each gene in gene_set['positive'], its expression will be added to the score
      for each gene in gene_set['negative'], its expression will be subtracted from its score
    n_top
      number of top-scoring cells to select (default 20); if it exceeds the
      number of cells, every cell is selected

    Returns
    -------
    Boolean mask for the cells with the top ``n_top`` scores over the entire dataset

    Raises
    ------
    ValueError
      if ``n_top`` is negative
    """
    if n_top < 0:
        raise ValueError(f"n_top must be non-negative, got {n_top}")
    score = get_score(normalized_adata, gene_set)
    n_obs = normalized_adata.n_obs
    n_selected = min(n_top, n_obs)
    mask = np.zeros(n_obs, dtype=bool)
    if n_selected > 0:
        # argsort is ascending, so the best-scoring cells are at the end.
        # An explicit start index avoids the `[-0:]` pitfall, which would select every cell.
        mask[np.argsort(score)[n_obs - n_selected:]] = True
    return mask

In [None]:
markers = {}
# One DEG-summary row per cell type; the gene lists are stored as Python-literal
# strings, so they are parsed with ast.literal_eval.
deg_columns = zip(degs['cell_type'], degs['top_upregulated'], degs['top_downregulated'])
for _ct, up_literal, down_literal in deg_columns:
    positive_markers = ast.literal_eval(up_literal)
    negative_markers = ast.literal_eval(down_literal)
    print(_ct, positive_markers, negative_markers)
    markers[_ct] = {"positive" : positive_markers, "negative": negative_markers}

In [None]:
# Temporarily replace X with a copy of the 'volume_norm' layer and scale it in place
# (max_value=10) so the marker scores below are computed on scaled expression.
adata.X = adata.layers['volume_norm'].copy()
sc.pp.scale(adata, max_value=10)

In [None]:
#papermill_description=Calculating seed labels with marker genes
# Top-scoring-cell mask for every cell type, computed on the matrix currently in adata.X.
masks = {
    cell_type: get_cell_mask(adata, gene_set)
    for cell_type, gene_set in markers.items()
}
# Expose each mask as an obs column named "<cell type>_mask".
for cell_type, mask in masks.items():
    adata.obs[cell_type + "_mask"] = mask
# Keep the scaled matrix as a layer, then put the 'volume_norm' values back into X.
adata.layers['scaled'] = adata.X.copy()
adata.X = adata.layers['volume_norm'].copy()

In [None]:
ncols = 5
# At least one row, so plt.subplots() always receives a valid grid.
nrows = max(1, int(np.ceil(len(markers) / ncols)))
# squeeze=False keeps `ax` two-dimensional even when there is a single row;
# otherwise 2-D indexing fails for 5 or fewer cell types.
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(25, 5 * nrows), squeeze=False)
panels = ax.ravel()  # row-major: panels[i] is ax[i // ncols, i % ncols]
for i, _cell_type in enumerate(markers.keys()):
    mask_col = _cell_type + "_mask"
    # Background: all cells coloured by brain region.
    plot_categorical(adata, cluster_col="brain_region",
                     coord_base="X_base_round1_umap",
                     ax=panels[i],
                     max_points=5000, show=False
                    )
    # Overlay: only the seed cells selected for this cell type, drawn in red.
    adata_sub = adata[adata.obs[mask_col].to_numpy()].copy()
    adata_sub.uns[mask_col + "_colors"] = ["red", "red"]
    plot_categorical(adata_sub, s=10, cluster_col=mask_col,
                     coord_base="X_base_round1_umap",
                     ax=panels[i], show=False,
                     scatter_kws={"alpha":1, "edgecolor":"k", "linewidth":0.5}
                    )
# Drop the unused trailing axes of the grid (no dependence on the loop variable,
# so this also works when `markers` is empty).
for unused_ax in panels[len(markers):]:
    fig.delaxes(unused_ax)

fig.savefig(image_path / f"{current_datetime}_00_base_umap_subclass_masks.png", dpi=300, bbox_inches='tight')
plt.close(fig)

In [None]:
# Start every cell as "Unknown". dtype=object is required: a plain string array
# built from "Unknown" is fixed-width (7 characters) and would silently truncate
# longer cell-type names on assignment, potentially merging distinct labels.
seed_labels = np.full(adata.n_obs, "Unknown", dtype=object)
# Later entries in `masks` overwrite earlier ones for cells selected by both.
for _cell_type, mask in masks.items():
    seed_labels[mask] = _cell_type
adata.obs["seed_labels"] = seed_labels

In [None]:
# Register the 'counts' layer with scVI, using dataset_id as the batch key and the
# marker-derived seed labels as the labels key.
scvi.model.SCVI.setup_anndata(adata, batch_key="dataset_id", labels_key="seed_labels", layer="counts")
# Build the (not yet trained) scVI model with n_latent=50 and n_layers=4.
scvi_model = scvi.model.SCVI(adata, n_latent=50, n_layers=4)

In [None]:
%%time
#papermill_description=Training SCVI model 

scvi_model_dir = model_path.with_suffix(".scvi.pt")
if scvi_model_dir.exists():
    # `load` is a classmethod that returns a NEW model, so the result must be
    # bound; calling it on the instance and discarding the return value would
    # leave `scvi_model` untrained.
    scvi_model = scvi.model.SCVI.load(scvi_model_dir, adata=adata)
else:
    scvi_model.train(100)
    scvi_model.save(scvi_model_dir, overwrite=True)

In [None]:
%%time
#papermill_description=Training SCANVI model 

scanvi_model_dir = model_path.with_suffix(".scanvi.pt")
if not scanvi_model_dir.exists():
    # No cached model: initialise scANVI from the scVI model, treating
    # "Unknown" seed labels as unlabelled, then train and cache it.
    scanvi_model = scvi.model.SCANVI.from_scvi_model(scvi_model, "Unknown")
    scanvi_model.train(25)
    scanvi_model.save(scanvi_model_dir, overwrite=True)
else:
    # Reuse the cached model.
    scanvi_model = scvi.model.SCANVI.load(scanvi_model_dir, adata=adata)

In [None]:
# Keys under which the scANVI outputs are stored on `adata`.
SCANVI_LATENT_KEY = "X_scANVI"
SCANVI_PREDICTIONS_KEY = "C_scANVI"

# Latent representation goes to .obsm, the per-cell predicted label to .obs.
adata.obsm[SCANVI_LATENT_KEY] = scanvi_model.get_latent_representation(adata)
adata.obs[SCANVI_PREDICTIONS_KEY] = scanvi_model.predict(adata)

In [None]:
#papermill_description=Calculating SCANVI embeddings
# Embedding settings for the scANVI latent space; later cells read the result
# under the "scANVI_umap" coordinate key.
scanvi_embedding_kwargs = dict(
    use_rep=SCANVI_LATENT_KEY,
    knn=30,
    min_dist=0.25,
    leiden_res=1,
    key_added="scANVI_",
)
_calc_embeddings(adata, **scanvi_embedding_kwargs)

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(10, 10))

# (obs column, draw text annotations?) for each panel, in row-major order.
scanvi_panels = [
    ("seed_labels", True),
    (SCANVI_PREDICTIONS_KEY, True),
    ("donor", False),
    ("replicate", False),
]
for panel_ax, (column, annotate) in zip(axes.flatten(), scanvi_panels):
    plot_categorical(adata, cluster_col=column, coord_base="scANVI_umap", show=False, coding=True, text_anno=annotate, ax=panel_ax)

plt.savefig(image_path / f"{current_datetime}_00_SCANVI_umap.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()


In [None]:
# Checkpoint: save the scANVI-annotated AnnData (the same path is overwritten again later).
adata.write_h5ad(scviva_adata_path)

## scVIVA

In [None]:
# Reload the AnnData from scviva_adata_path — presumably so the scVIVA section
# can be started from this cell; confirm.
adata = ad.read_h5ad(scviva_adata_path)
adata

In [None]:
# Re-declared with the same values as in the scANVI section — presumably so this
# section can run on its own after the reload; confirm.
SCANVI_LATENT_KEY = "X_scANVI"
SCANVI_PREDICTIONS_KEY = "C_scANVI"
# Key arguments passed to both SCVIVA.preprocessing_anndata and SCVIVA.setup_anndata below.
setup_kwargs = {
    "sample_key" : "dataset_id",
    "labels_key" : SCANVI_PREDICTIONS_KEY,
    "cell_coordinates_key" : "spatial",
    "expression_embedding_key" : SCANVI_LATENT_KEY,
}

In [None]:
# Run scVIVA's AnnData preprocessing with k_nn=20 and the shared key arguments.
# NOTE(review): presumably this computes the per-cell spatial-neighbour (niche)
# fields the model needs — confirm against the scvi-tools documentation.
scvi.external.SCVIVA.preprocessing_anndata(
    adata, 
    k_nn=20,
    **setup_kwargs
)

In [None]:
# Register the 'counts' layer with scVIVA, using dataset_id as the batch key plus
# the same key arguments that were passed to the preprocessing call.
scvi.external.SCVIVA.setup_anndata(
    adata, 
    layer='counts', 
    batch_key="dataset_id",
    **setup_kwargs
)

In [None]:
# Build the scVIVA model with n_layers=2 and n_latent=20; the bare name below
# displays the model summary as the cell output.
scviva_model = scvi.external.SCVIVA(adata, n_layers=2, n_latent=20)
scviva_model

In [None]:
#papermill_description=Training scVIVA model 
# Train for up to 600 epochs with early stopping enabled, validating every epoch;
# batch size 512, learning rate 5e-4. Unlike the scVI/scANVI cells, this model is
# not cached on disk in the cells shown here, so it is retrained on every run.
scviva_model.train(
    max_epochs=600, 
    early_stopping=True,
    check_val_every_n_epoch=1,
    batch_size=512, 
    plan_kwargs = {
        "lr": 5e-4
    }
)

In [None]:
# Plot each validation curve recorded during training.
validation_history_keys = (
    "elbo_validation",
    "niche_compo_validation",
    "niche_reconst_validation",
    "kl_local_validation",
    "reconstruction_loss_validation",
)
for history_key in validation_history_keys:
    scviva_model.history[history_key].plot()

In [None]:
# Store the scVIVA latent representation; it is used as `use_rep` for the clustering below.
adata.obsm["X_scVIVA"] = scviva_model.get_latent_representation()

In [None]:
#papermill_description=Calculating scVIVA embeddings
# Two clustering rounds on the scVIVA latent space with Harmony enabled; later
# cells read the results under the "scVIVA_round1_*" / "scVIVA_round2_leiden" keys.
scviva_latent_clustering_kwargs = {
    "layer": None,
    "use_rep": "X_scVIVA",
    "key_added": "scVIVA_",
    "num_rounds": 2,
    "leiden_res": [0.75, 0.5],
    "min_dist": 0.25,
    "knn": 50,
    "min_group_size": 50,
    "run_harmony": True,
    "batch_key": ["replicate", "donor"],
    "harmony_nclust": 20,
    "max_iter_harmony": 20,
}
multi_round_clustering(adata, **scviva_latent_clustering_kwargs)

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(10, 10))

# One panel per obs column, all drawn on the round-1 scVIVA UMAP (row-major order).
scviva_panel_columns = ["scVIVA_round1_leiden", SCANVI_PREDICTIONS_KEY, "donor", "replicate"]
for panel_ax, column in zip(axes.flatten(), scviva_panel_columns):
    plot_categorical(adata, cluster_col=column, coord_base="scVIVA_round1_umap", show=False, coding=True, text_anno=False, ax=panel_ax)

plt.savefig(image_path / f"{current_datetime}_00_scVIVA_umap.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()


In [None]:
fig, axes = plt.subplots(1,2, figsize=(8,4))

# Round-1 vs round-2 Leiden labels, both shown on the round-1 UMAP.
scviva_round_columns = ("scVIVA_round1_leiden", "scVIVA_round2_leiden")
for panel_ax, leiden_col in zip(axes, scviva_round_columns):
    plot_categorical(adata, cluster_col=leiden_col, coord_base="scVIVA_round1_umap", show=False, ax=panel_ax)

plt.savefig(image_path / f"{current_datetime}_00_scVIVA_umap_multiround.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Checkpoint: overwrite the output file, now including the scVIVA latent clustering.
adata.write_h5ad(scviva_adata_path)

In [None]:
### Generated Expression?
# Store the model's normalized expression as a layer; it is the input of the
# expression-based clustering below.
# NOTE(review): get_normalized_expression may return a DataFrame rather than an
# array — confirm the layer assignment behaves as intended.
adata.layers["scVIVA_normalized"] = scviva_model.get_normalized_expression()

In [None]:
#papermill_description=Calculating scVIVA embeddings
# Same clustering settings as for the latent space, but run on the
# "scVIVA_normalized" layer (no `use_rep`); later cells read the results under
# the "scVIVA_expr_round1_*" / "scVIVA_expr_round2_leiden" keys.
scviva_expr_clustering_kwargs = {
    "layer": "scVIVA_normalized",
    "use_rep": None,
    "key_added": "scVIVA_expr_",
    "num_rounds": 2,
    "leiden_res": [0.75, 0.5],
    "min_dist": 0.25,
    "knn": 50,
    "min_group_size": 50,
    "run_harmony": True,
    "batch_key": ["replicate", "donor"],
    "harmony_nclust": 20,
    "max_iter_harmony": 20,
}
multi_round_clustering(adata, **scviva_expr_clustering_kwargs)

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(10, 10))

# One panel per obs column, all drawn on the round-1 expression-based UMAP (row-major order).
expr_panel_columns = ["scVIVA_expr_round1_leiden", SCANVI_PREDICTIONS_KEY, "donor", "replicate"]
for panel_ax, column in zip(axes.flatten(), expr_panel_columns):
    plot_categorical(adata, cluster_col=column, coord_base="scVIVA_expr_round1_umap", show=False, coding=True, text_anno=False, ax=panel_ax)

plt.savefig(image_path / f"{current_datetime}_00_scVIVA_expr_umap.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()


In [None]:
fig, axes = plt.subplots(1,2, figsize=(8,4))

# Round-1 vs round-2 Leiden labels, both shown on the round-1 expression-based UMAP.
expr_round_columns = ("scVIVA_expr_round1_leiden", "scVIVA_expr_round2_leiden")
for panel_ax, leiden_col in zip(axes, expr_round_columns):
    plot_categorical(adata, cluster_col=leiden_col, coord_base="scVIVA_expr_round1_umap", show=False, ax=panel_ax)

plt.savefig(image_path / f"{current_datetime}_00_scVIVA_expr_umap_multiround.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Final save: overwrite the output file with the expression layer and its clustering added.
adata.write_h5ad(scviva_adata_path)