# get metacell h5ad

In [None]:
import gc
import logging as log
from pathlib import Path

import numpy as np
import scipy as sc
import pandas as pd
import scanpy as sc

import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, Markdown

from scanpy_helpers import *

In [None]:
# Handle on the root logger (getLogger() without a name returns the root).
# NOTE(review): the other cells call the module-level `log.info(...)` /
# `log.warning(...)` helpers rather than this object, and no log level is
# configured in this notebook — confirm that INFO messages are actually
# emitted in the environment this runs in.
logger = log.getLogger()

In [None]:
# Set maximum number of jobs for Scanpy.
# The scanpy setting is spelled `n_jobs`; assigning to `njobs` merely creates
# an unrelated attribute that scanpy never reads, leaving the default in place.
sc.settings.n_jobs = 4

## Params

input

In [None]:
# h5ad file that holds the raw count matrix (read in the "get raw counts" section).
anndata_raw = "/path/to/rna_raw.h5ad"
# h5ad file with the annotated object; its obs/obsm/obsp/uns and its cell
# selection are carried over onto the raw counts further down.
anndata_annot = "/path/to/rna.h5ad"

output

In [None]:
# Destination for the merged object (raw counts combined with the annotations).
anndata_merge = "/path/to/rna_merged.h5ad"
# Destination for the metacell AnnData.
metacell_rna_h5ad = "/path/to/rna_metacells.h5ad"

params

In [None]:
annotation_col = "cell_type_obs_column"  # adata.obs column used as PAGA groups and for plotting
additional_filter = ""  # optional filter string in pandas DataFrame.query() format, applied to adata.obs
use_cell_types = []  # list of cell types to subset to (empty list = keep all)
metacell_batch = "batch"  # restrict metacells by adata.obs column (empty string = no grouping)

## 1) Load RNA anndata

In [None]:
# Progress message marking the start of the loading step.
log.info("load rna anndata")

In [None]:
# Read the annotated AnnData from the path configured in `anndata_annot`.
ad = sc.read_h5ad(anndata_annot)

### subset

In [None]:
# Optionally restrict the object to the cell types listed in `use_cell_types`.
if use_cell_types:
    log.info(f"subset rna anndata to: {use_cell_types}")
    # Honour the configured annotation column instead of a hard-coded name;
    # keep the legacy "cell type" column as a fallback when the configured
    # column is not present in obs.
    subset_col = annotation_col if annotation_col in ad.obs.columns else "cell type"
    # A boolean mask avoids the round trip through obs names.
    ad = ad[ad.obs[subset_col].isin(use_cell_types)]

In [None]:
# Optional user-supplied obs filter (DataFrame.query() syntax); skipped when empty.
if additional_filter:
    cells_passing_filter = ad.obs.query(additional_filter).index
    ad = ad[cells_passing_filter]

In [None]:
# Show the AnnData summary as the cell output.
ad

### get raw counts

In [None]:
# Progress message marking the start of the raw-count loading step.
log.info("load raw counts")

In [None]:
# Read the AnnData configured in `anndata_raw`, expected to carry the counts.
ad_raw = sc.read_h5ad(anndata_raw)

In [None]:
# Fall back to the `.raw` slot when one is present and the largest value in X
# is below 100.
# NOTE(review): the threshold of 100 looks like a heuristic for "X is not a
# count matrix" — confirm it is appropriate for the datasets in use.
take_counts_from_raw = bool(ad_raw.raw) and ad_raw.X.max() < 100
if take_counts_from_raw:
    log.warning("found anndata.raw attribute, getting raw counts from there")
    ad_raw = ad_raw.raw.to_adata()

In [None]:
# Keep only the cells present in the annotated object, in the same order.
# Copy explicitly: the annotation slots are overwritten right afterwards, and
# writing to a view would otherwise trigger an implicit view-to-actual copy
# while still holding on to the full parent object.
ad_raw = ad_raw[ad.obs_names, :].copy()

In [None]:
# Carry the annotation slots of the annotated object over onto the raw-count
# object, in the order obs, obsm, obsp, uns.
for _slot in ("obs", "obsm", "obsp", "uns"):
    setattr(ad_raw, _slot, getattr(ad, _slot))

In [None]:
# From here on `ad` refers to the raw-count object carrying the annotations.
ad = ad_raw

In [None]:
# Run a garbage-collection pass now that `ad` no longer points at the
# previously loaded annotated object.
gc.collect()

In [None]:
# Persist the merged object to the path configured in `anndata_merge`.
ad.write(anndata_merge)

In [None]:
# Show the merged AnnData summary as the cell output.
ad

### remove batches with <=10 cells

In [None]:
# Drop batches that are too small to build metacells from. Only runs when a
# batch column is configured.
if metacell_batch:
    batch_vc = ad.obs[metacell_batch].value_counts()
    # The cut-off keeps batches with more than 10 cells, i.e. batches with
    # exactly 10 cells are removed as well; the message states that precisely.
    log.warning(f"removing {batch_vc[batch_vc <= 10].size} batches with <=10 cells")
    batch_sel = batch_vc[batch_vc > 10].index.tolist()

    # Materialise the subset so that later in-place steps operate on a real
    # object rather than on a view.
    ad = ad[ad.obs[metacell_batch].isin(batch_sel)].copy()

## 2) Get metacells

In [None]:
# Progress message marking the gene-filtering step.
log.info("filter genes")

In [None]:
# Keep only genes detected in at least 100 cells; the call has no `inplace`
# override and its return value is unused, so `ad` itself is filtered.
sc.pp.filter_genes(ad, min_cells=100)

In [None]:
# Progress message marking the metacell computation step.
log.info("compute metacells")

In [None]:
# Build metacells with identical settings in both modes; when a batch column
# is configured the per-group helper is used with that column.
# NOTE(review): both helpers come from the star import of `scanpy_helpers`;
# the meaning of the keyword arguments is defined there.
metacell_kwargs = dict(max_group_size=15, min_cells=5, summarise="sum")
ad_meta = (
    get_metacells_by_group(ad, metacell_batch, **metacell_kwargs)
    if metacell_batch
    else get_metacells(ad, **metacell_kwargs)
)

In [None]:
# Select the rows of `ad` whose names match the metacell names and carry their
# embeddings, pairwise graphs and unstructured annotations over to `ad_meta`.
# NOTE(review): this relies on every metacell name being an obs name of `ad`
# — confirm against the scanpy_helpers implementation.
ad_red = ad[ad_meta.obs_names, :]

ad_meta.obsm = ad_red.obsm
ad_meta.obsp = ad_red.obsp
ad_meta.uns = ad_red.uns

# Best-effort preview plot. A failure must not stop the notebook, but it is
# logged instead of being swallowed silently, and KeyboardInterrupt /
# SystemExit are no longer caught.
# NOTE(review): "annot" is a fixed column name here, not `annotation_col` —
# confirm which column is intended.
try:
    sc.pl.umap(ad_meta, color="annot")
except Exception:
    log.warning("could not plot metacell UMAP preview", exc_info=True)

In [None]:
# Show the metacell AnnData summary as the cell output.
ad_meta

In [None]:
# Persist the metacell object to the path configured in `metacell_rna_h5ad`.
# This happens before any normalisation in the following section.
ad_meta.write(metacell_rna_h5ad)

In [None]:
# Snapshot the current (not yet normalised) state into `.raw`.
# NOTE(review): this runs after the file has been written, so the saved h5ad
# does not include it, and `.raw` is assigned again after normalisation in
# the next section — confirm whether this cell is still needed.
ad_meta.raw = ad_meta

## 3) Check embeddings

In [None]:
# Progress message marking the embedding check.
log.info("plot embeddings")

In [None]:
# Normalise totals per observation with scanpy's default settings, then apply
# log1p; both calls are made without capturing a return value.
sc.pp.normalize_total(ad_meta)
sc.pp.log1p(ad_meta)

In [None]:
# Store the state after normalize_total + log1p in `.raw`; the next cell
# reads X back from there.
ad_meta.raw = ad_meta

In [None]:
# Reset X from `.raw` before processing.
# NOTE(review): presumably this makes the cell safe to re-run, since scaling
# below overwrites X — confirm.
ad_meta.X = ad_meta.raw.X

# Prefer batch-aware HVG selection and fall back to the plain variant when
# that fails. Only ordinary errors trigger the fallback: catching
# BaseException would also intercept KeyboardInterrupt / SystemExit and keep
# running after the user tried to stop the kernel. The traceback is attached
# so the reason for the fallback is visible in the log.
try:
    sc.pp.highly_variable_genes(ad_meta, batch_key="batch")
except Exception:
    log.warning("could not use 'batch' for HVG calculation", exc_info=True)
    sc.pp.highly_variable_genes(ad_meta)

sc.pp.scale(ad_meta)
sc.pp.pca(ad_meta)
sc.pp.neighbors(ad_meta, n_neighbors=15)
# sc.external.pp.bbknn(ad_meta, batch_key='batch')

In [None]:
# Cluster, compute a first UMAP, and build the PAGA graph over the groups in
# `annotation_col`.
sc.tl.leiden(ad_meta, resolution=2)
sc.tl.umap(ad_meta)
sc.tl.paga(ad_meta, groups = annotation_col)
# show=False: the plot is created without being displayed; the return value
# is kept in `fig` but not used again in this notebook.
fig = sc.pl.umap(ad_meta, color = "leiden", show=False)

# NOTE(review): the PAGA plot is drawn before the two init_pos="paga" calls;
# scanpy appears to take the initial positions from the PAGA plot layout, so
# the statement order here likely matters — confirm before reordering.
fig = sc.pl.paga(ad_meta, threshold=0.3, show=False)
# Recompute UMAP and compute the force-directed layout, both seeded from PAGA.
sc.tl.umap(ad_meta, init_pos="paga")
sc.tl.draw_graph(ad_meta, init_pos="paga")

In [None]:
# Of the candidate obs columns, plot only those that exist on the metacell
# object, preserving the candidate order.
candidate_cols = (annotation_col, "batch", "leiden", "pcw")
available_obs_cols = set(ad_meta.obs.columns)
plot_cols = [name for name in candidate_cols if name in available_obs_cols]
log.info(f"plotting data on embeddings: {plot_cols}")

In [None]:
# Best-effort UMAP of the selected obs columns. Plotting errors are logged
# with their traceback and do not stop the notebook; KeyboardInterrupt and
# SystemExit are deliberately not caught so the run can still be aborted.
try:
    with plt.rc_context({"figure.figsize": (5,5), "figure.facecolor": "white"}):
        sc.pl.umap(ad_meta, color=plot_cols, palette=sc.plotting.palettes.default_20)
except Exception:
    log.exception("could not plot UMAP")

In [None]:
# Best-effort force-directed (draw_graph) embedding of the selected obs
# columns, one panel per row. Plotting errors are logged with their traceback
# and do not stop the notebook; KeyboardInterrupt and SystemExit are
# deliberately not caught so the run can still be aborted.
try:
    with plt.rc_context({"figure.figsize": (5,5), "figure.facecolor": "white"}):
        sc.pl.draw_graph(ad_meta, color=plot_cols, ncols=1, palette=sc.plotting.palettes.default_20, frameon=True)
except Exception:
    log.exception("could not plot FA embedding")