# subset anndata and create metacells

In [None]:
import sys
import os
import pickle
import logging as log
from pathlib import Path

import numpy as np
import scipy as sp
import pandas as pd
from pandas.api.types import is_numeric_dtype
import scanpy as sc
import loompy as lp

import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, Markdown

from scanpy_helpers import *

## parameters

In [None]:
# input
anndata_infile = "/path/to/rna.h5ad"

# params
cell_types = []  # list of cell types to subset to (empty: no cell type subsetting)
annot_column = "cell_type"  # cell type annotation column
additional_filters = ""  # filter string using pd.query() syntax on adata.obs
gene_names_col = ""  # adata.var column with gene symbols if not in index
target_ncells = 0  # if >0: choose metacell size to obtain that many metacells (only used when pseudobulk_size is 0)
pseudobulk_size = 10  # group that many cells in one metacell (maximum)
metacells = "cluster"  # method for creating metacells; any value other than "cluster" skips aggregation
metacell_batch = "cell_type"  # adata.obs column to limit metacell groups
subsample = 500  # subsample to that many metacells max per `subsample_by`; values <= 1 are treated as a fraction
subsample_by = "cell_type"  # adata.obs column to subsample by
work_dir = Path("/path/to/work_dir")  # output root; results are written to work_dir / "scRNA"

n_threads = 8  # number of jobs handed to the Scanpy settings

In [None]:
# Set maximum number of jobs for Scanpy.
# The setting is named `n_jobs`; assigning to `njobs` would only create an
# unused attribute on the settings object and leave the job count unchanged.
sc.settings.n_jobs = n_threads

In [None]:
# Output directory for all RNA results of this notebook (created in the save step).
rna_path = work_dir / "scRNA"

## 1) Load anndata

In [None]:
# Read the full input AnnData object from the configured .h5ad file.
ad = sc.read_h5ad(filename=anndata_infile)

In [None]:
# Decide whether ad.X holds full, unscaled counts; fall back to ad.raw if not.
# `is_raw` ends up True for raw counts and False for log-normalised values.
is_raw = True

# Heuristics on ad.X; the first one that matches names the reason for the
# fallback to ad.raw (None means ad.X already looks like full raw counts).
if ad.X.shape[1] < 15000:
    restore_reason = "genes seem to be filtered"
elif ad.X.min() < 0:
    restore_reason = "counts seem to be scaled"
elif ad.X.max() < 100:
    restore_reason = "counts seem to be log-norm"
else:
    restore_reason = None

if restore_reason is not None:
    if ad.raw is not None:
        print(f"{restore_reason}, restoring ad.raw")
        ad = ad.raw.to_adata()
    else:
        # Nothing to restore from. Keep ad.X and let the validation below
        # decide: full log-normalised data is still usable, anything else
        # is rejected with the ValueError.
        print(f"{restore_reason}, but ad.raw is not set: keeping ad.X")

# `>=` so that a matrix with exactly 15000 genes is treated as full, matching
# the `< 15000` "filtered" check above.
if ad.X.shape[1] >= 15000 and ad.X.min() >= 0:
    if ad.X.max() < 100:
        print("full log-normalised counts found")
        is_raw = False
    else:
        print("full raw counts found")
else:
    raise ValueError("no full unscaled counts found in anndata")

# Snapshot the selected matrix as .raw for the downstream steps.
ad.raw = ad

In [None]:
# Mirror the configured annotation column into a string-typed "celltype" column.
if annot_column:
    annotation = ad.obs[annot_column]
    ad.obs["celltype"] = annotation.astype(str)

#### UMAP

In [None]:
# Best-effort overview plot: the notebook must keep running when the UMAP
# cannot be drawn, so the failure is logged instead of raised. Only
# `Exception` is caught so that KeyboardInterrupt / SystemExit still
# propagate.
try:
    with plt.rc_context({"figure.figsize": (15,15), "figure.dpi": 300, "figure.frameon": False}):
        sc.pl.umap(
            ad,
            color=annot_column,
            alpha=0.7,
            size=50,
            add_outline=True,
            outline_width = (0.25, 2.5),
            legend_fontoutline=3,
            legend_loc="on data",
            title = ""
        )
except Exception as err:
    log.warning("skipping UMAP plot: %s", err)

## 2) Subset cells

In [None]:
# Restrict to the requested cell types; an empty list keeps every cell.
ad_sub = ad[ad.obs[annot_column].isin(cell_types)] if cell_types else ad

In [None]:
# Optional extra filtering of cells via a pandas query on the obs table.
if additional_filters:
    passing_cells = ad_sub.obs.query(additional_filters).index
    keep_mask = ad_sub.obs_names.isin(passing_cells)
    ad_sub = ad_sub[keep_mask]

In [None]:
# Index the var table by gene symbol. When the symbols live in a separate
# column, the previous index is preserved in "ID" and the new names are made
# unique.
if gene_names_col:
    ad_sub.var["ID"] = ad_sub.var_names.tolist()
    symbols = ad_sub.var[gene_names_col].tolist()
else:
    symbols = ad_sub.var_names.tolist()

ad_sub.var["gene_symbol"] = symbols
ad_sub.var = ad_sub.var.set_index("gene_symbol")

if gene_names_col:
    ad_sub.var_names_make_unique()

## 3) Create metacells

In [None]:
# Number of cells left after subsetting.
n_cells = ad_sub.X.shape[0]

# Derive the metacell size from the requested number of metacells, but only
# when no explicit size was configured. A size below 2 cells per metacell
# switches metacell creation off.
if target_ncells and not pseudobulk_size:
    pseudobulk_size = round(n_cells / target_ncells)
    if not pseudobulk_size >= 2:
        metacells, pseudobulk_size = "", 0

In [None]:
# Display the metacell size that will be used (0 when metacells were disabled above).
pseudobulk_size

### compute connectivities

In [None]:
# Compute a neighbour graph on the subset while leaving the stored matrix
# untouched: the current matrix is snapshotted into .raw first and copied
# back into .X at the end of the cell.
ad_sub.raw = ad_sub

if is_raw:
    # raw counts: normalise and log-transform before HVG selection / PCA
    sc.pp.normalize_total(ad_sub)
    sc.pp.log1p(ad_sub)
else:
    # data is already log-normalised: only make sure the "log1p" metadata
    # entry exists with a "base" key
    # fix if missing
    if "log1p" not in ad_sub.uns:
        ad_sub.uns["log1p"] = {}
    if "base" not in ad_sub.uns["log1p"]:
        # assume natural log
        ad_sub.uns["log1p"]["base"] = None

sc.pp.highly_variable_genes(ad_sub)

sc.pp.scale(ad_sub)
sc.tl.pca(ad_sub)

sc.pp.neighbors(ad_sub)
# connectivities written to .obsp by the neighbours step above
n_mat = ad_sub.obsp["connectivities"]

# restore the matrix that was snapshotted at the top of this cell
ad_sub.X = ad_sub.raw.X

### process

In [None]:
# Aggregate cells into metacells (optionally within each `metacell_batch`
# group) and record which original barcode went into which metacell.
# Without aggregation every cell is mapped to its own running index.
if metacells == "cluster":
    # raw counts are summed, log-normalised values are averaged
    summarise_by = "sum" if is_raw else "mean"
    min_cells = round(pseudobulk_size / 2)

    batch_note = " by " + metacell_batch if metacell_batch else ""
    log.info(f"create metacells{batch_note} using {summarise_by} with size {pseudobulk_size} and min cells: {min_cells}")

    shared_kwargs = dict(
        max_group_size=pseudobulk_size,
        min_cells=min_cells,
        summarise=summarise_by,
    )
    if metacell_batch:
        agg_res = get_metacells_by_group(ad_sub, metacell_batch, **shared_kwargs)
    else:
        agg_res = get_metacells(ad_sub, **shared_kwargs)

    ad_sub = agg_res.adata
    metacell_bc = agg_res.obs_orig["metacell"].to_dict()
else:
    metacell_bc = {barcode: idx for idx, barcode in enumerate(ad_sub.obs_names)}

In [None]:
# Display the (possibly aggregated) AnnData object.
ad_sub

## subsample

In [None]:
# Optionally subsample the (meta)cells. `subsample` > 1 is an absolute
# number, `subsample` <= 1 is a fraction. With `subsample_by` set, the limit
# is applied separately to each group of that obs column.
if subsample:
    if subsample_by:
        # subsample by column 'subsample_by'
        def sample_rows(x, n=1000, random_state=0):
            """Return at most `n` rows of `x` (n > 1) or a fraction `n` of them (n <= 1).

            `random_state` is fixed by default so that repeated runs select
            the same cells, as `sc.pp.subsample` does in the branch below.
            """
            if n > 1:
                # by total number
                if x.shape[0] > n:
                    return x.sample(n=int(n), random_state=random_state)
                return x
            # by fraction
            if x.shape[0] > 1:
                return x.sample(frac=float(n), random_state=random_state)
            return x

        log.info("cell counts before subsampling")
        log.info(ad_sub.obs.value_counts(subsample_by))

        # group_keys=False keeps the original obs index on the result. The
        # selection then no longer depends on whether pandas prepends the
        # group key as an extra index level, which older pandas versions skip
        # when every group is returned unchanged.
        select_cells = ad_sub.obs.groupby(
            subsample_by,
            group_keys=False,
        ).apply(
            sample_rows,
            n=subsample,
        ).index.tolist()

        ad_sub = ad_sub[ad_sub.obs_names.isin(select_cells)]

        log.info("cell counts after subsampling")
        log.info(ad_sub.obs.value_counts(subsample_by))
    else:
        # subsample across all cells
        if subsample > 1:
            # by total number
            ad_sub = sc.pp.subsample(ad_sub, n_obs=min(int(subsample), ad_sub.shape[0]), copy=True)
        else:
            # by fraction
            ad_sub = sc.pp.subsample(ad_sub, fraction=float(subsample), copy=True)

### filter genes

In [None]:
# Drop genes with fewer than 50 counts in total; the object is modified in place.
sc.pp.filter_genes(ad_sub, min_counts=50, inplace=True)

### leiden clustering

In [None]:
# Snapshot the gene-filtered matrix in .raw; the clustering cell below copies it back into .X.
ad_sub.raw = ad_sub.copy()

In [None]:
# Compute leiden clusters on the final object unless a "leiden" column is
# already present; .X is restored from .raw after clustering.
if "leiden" not in ad_sub.obs:
    if is_raw:
        sc.pp.normalize_total(ad_sub)
        sc.pp.log1p(ad_sub)
    sc.pp.scale(ad_sub)
    sc.tl.pca(ad_sub)
    sc.pp.neighbors(ad_sub)
    sc.tl.leiden(ad_sub)
    ad_sub.X = ad_sub.raw.X

# Without an annotation column, fall back to the leiden clusters as cell type.
# This must use `ad_sub`: the clusters are computed on it and it is the
# object that gets saved. The unsubsetted `ad` has no guaranteed "leiden"
# column.
if not annot_column:
    ad_sub.obs["celltype"] = ad_sub.obs["leiden"].astype(str)

## 4) Save anndata and barcodes

In [None]:
# Display min/max of .X and .raw.X (presumably a sanity check that both hold the unscaled values).
ad_sub.X.min(), ad_sub.X.max(), ad_sub.raw.X.min(), ad_sub.raw.X.max()

In [None]:
# Persist the barcode -> metacell mapping next to the other RNA outputs.
metacell_pkl = rna_path / "metacells.pkl"
rna_path.mkdir(exist_ok=True, parents=True)

with open(metacell_pkl, "wb") as handle:
    pickle.dump(metacell_bc, handle)

In [None]:
# Write the final AnnData object to the output directory.
h5ad_outfile = rna_path / "anndata_metacells.h5ad"
ad_sub.write(h5ad_outfile)

In [None]:
# Export the obs table of the final object as CSV.
obs_csv_outfile = rna_path / "anndata_metacells_obs.csv"
ad_sub.obs.to_csv(obs_csv_outfile)