# Subset cistopic object and create metacells

In [None]:
import os
import sys
import pickle
import logging as log
import warnings
from pathlib import Path

import numpy as np
import scipy as sp
from scipy.sparse import csr_matrix
import pandas as pd
from pandas.api.types import is_numeric_dtype
import scanpy as sc
from pycisTopic.cistopic_class import *

import seaborn as sns
from IPython.display import display, Markdown

from utils import load_cistopic_obj, save_cistopic_obj

## Params

In [None]:
# Root directory containing the scATAC/ and scRNA/ sub-directories used below.
work_dir = Path("/path/to/work_dir")
# Keep only regions accessible in strictly more than this fraction of cells;
# a falsy value (0 / None) disables peak filtering.
filter_peaks_cell_frac = 0.02

In [None]:
# Render the run parameters as a markdown list in the notebook output.
params_md = f"""
**parameters:**
- **filter for peaks present in fraction of cells:** *{filter_peaks_cell_frac}*
- **working directory for output files:** *{work_dir}*
"""
display(Markdown(params_md))

In [None]:
log.info("get paths")

# Modality-specific sub-directories of the working directory.
atac_path, rna_path = (work_dir / subdir for subdir in ("scATAC", "scRNA"))

## 1) Load

In [None]:
log.info("load cistopic object")

# Input cistopic object produced by an earlier step, stored under scATAC/.
cistopic_input = atac_path / "cistopic_obj.pkl"
cistopic_obj = load_cistopic_obj(cistopic_input)

In [None]:
log.info("load metacell annotation")

# Cell -> metacell assignment pickled by the RNA workflow.
# NOTE: unpickling can execute arbitrary code; only load files this pipeline wrote.
metacells = pickle.loads((rna_path / "metacells.pkl").read_bytes())

In [None]:
log.info("load aggregated rna annotation")

# obs table of the metacell-level RNA AnnData; first CSV column is the row index.
rna_obs_csv = rna_path / "anndata_metacells_obs.csv"
rna_obs_annot = pd.read_csv(rna_obs_csv, index_col=0)

## 2) Subset cells

In [None]:
log.info("subset to cells included in metacells")

# Keep only ATAC cells that also have a metacell assignment. The order follows
# iteration order of `metacells`.
cn = set(cistopic_obj.cell_names)
ct_mask = [k for k in metacells.keys() if k in cn]
if not ct_mask:
    # An empty overlap would otherwise silently produce an empty object and fail
    # much later. NOTE(review): presumably this indicates a barcode-format
    # mismatch between the RNA metacell table and the cistopic object.
    raise ValueError(
        "no cells of the cistopic object are present in the metacell annotation; "
        f"example cistopic IDs: {cistopic_obj.cell_names[:5]}, "
        f"example metacell IDs: {list(metacells)[:5]}"
    )
log.info(f"keeping {len(ct_mask)} of {len(cn)} cistopic cells")
cistopic_obj.subset(cells=ct_mask)

# At this point cell_df is the same DataFrame object as cistopic_obj.cell_data.
cell_df = cistopic_obj.cell_data

In [None]:
log.info("add metacell info to cistopic object annotation")

# Look up each cell barcode in the metacell mapping; unmatched barcodes become NaN.
metacell_labels = cell_df.index.map(metacells)
cell_df['metacell'] = metacell_labels

# Print a few IDs from both sources so barcode-format mismatches are easy to spot.
log.info(f"cistopic cell IDs: {cistopic_obj.cell_names[:5]}")
log.info(f"metacells cell IDs: {[*metacells][:5]}")

In [None]:
# Log the object summary after subsetting to cells with a metacell assignment.
# NOTE(review): log.info output only appears if logging is configured at INFO
# level; no logging configuration is visible in this notebook.
log.info(cistopic_obj)

## 3) Filter peaks

In [15]:
# Per-region fraction of cells in which the region is accessible:
# row sums of the binary (regions x cells) matrix divided by the cell count.
bin_mat = cistopic_obj.binary_matrix
sums = bin_mat.sum(axis=1)
peak_freq = np.asarray(sums).ravel() / bin_mat.shape[1]

In [None]:
# Histogram of per-region cell fractions; helps choose filter_peaks_cell_frac.
sns.histplot(peak_freq, binwidth=0.005)

In [17]:
if filter_peaks_cell_frac:
    # Regions accessible in strictly more than the configured fraction of cells.
    region_keep = peak_freq > filter_peaks_cell_frac
    subs_reg = np.asarray(cistopic_obj.region_names)[region_keep]
    # Take the region annotation subset before calling subset(), because
    # cistopic_obj.subset does not subset cistopic_obj.region_data correctly.
    region_data = cistopic_obj.region_data.loc[subs_reg, :]
    cistopic_obj.subset(regions=subs_reg.tolist(), copy=False)
    cistopic_obj.region_data = region_data

In [None]:
# Log the object summary after the (optional) peak filtering step.
log.info(cistopic_obj)

## 4) Aggregate into metacells

In [None]:
n_cells = len(set(metacells.keys()))
n_metacells = len(set(metacells.values()))


def _mean(values):
    """Reduce a numeric metadata column to its mean within one metacell."""
    return values.mean()


def _most_frequent(values):
    """Reduce a non-numeric metadata column to its most frequent value.

    ``value_counts`` ignores missing values, so a metacell whose values are all
    missing yields an empty result. Return NaN for it instead of raising
    ``IndexError``.
    """
    counts = values.value_counts()
    return counts.index[0] if len(counts) > 0 else np.nan


def _column_aggregator(column):
    """Return the per-metacell reducer for *column*: mean if numeric, else mode."""
    return _mean if is_numeric_dtype(column) else _most_frequent


if n_cells == n_metacells:
    # Every cell is its own metacell: keep the per-cell object unchanged.
    log.info("no metacells found, skipping aggregation step...")
else:
    log.info(f"found {n_metacells} metacells for {n_cells} cells")
    log.info("aggregating metacells in cistopic object...")

    ######################
    #  aggregate counts  #
    ######################

    log.info("prepare metacell aggregation...")

    # One metacell label per cell, in the column order of the count matrices.
    groups = np.array(cell_df['metacell'].tolist())
    n_matrix_cells = cistopic_obj.fragment_matrix.shape[1]
    if groups.size != n_matrix_cells:
        raise ValueError(
            f"cell metadata has {groups.size} rows but the fragment matrix has "
            f"{n_matrix_cells} cell columns"
        )

    # Sparse cells x metacells membership matrix: entry (i, j) is 1 when cell i
    # belongs to metacell j. Metacells are in np.unique (sorted) order, which
    # matches the sorted groupby below. A dense equivalent would need
    # n_cells * n_metacells floats.
    metacell_ids, metacell_codes = np.unique(groups, return_inverse=True)
    membership = csr_matrix(
        (
            np.ones(groups.size, dtype=np.int64),
            (np.arange(groups.size), metacell_codes.ravel()),
        ),
        shape=(groups.size, metacell_ids.size),
    )

    log.info("aggregate")

    # Sum the columns of member cells: (regions x cells) @ (cells x metacells).
    agg_mat = cistopic_obj.fragment_matrix @ membership

    log.info(f"...aggregated fragment counts into shape {agg_mat.shape}")

    cistopic_obj.fragment_matrix = csr_matrix(agg_mat)
    cistopic_obj.binary_matrix = csr_matrix((agg_mat > 0).astype(int))


    ##############################
    #  aggregate cell meta-data  #
    ##############################

    log.info("...aggregate metadata")

    # Keep the per-cell index as a column so it is aggregated too, then reuse it
    # as the index of the metacell-level table.
    idx_name = cell_df.index.name or 'index'
    cell_df = cell_df.reset_index(names=idx_name)
    agg_spec = {col: _column_aggregator(values) for col, values in cell_df.items()}
    cell_df = (
        cell_df.groupby("metacell", as_index=False)
        .agg(agg_spec)
        .set_index(idx_name)
    )

    # Guard against silent row/column misalignment between the metadata and
    # the count matrices, e.g. groupby dropping missing labels np.unique keeps.
    if not np.array_equal(cell_df["metacell"].to_numpy(), metacell_ids):
        raise ValueError(
            "aggregated metadata rows are not aligned with the aggregated "
            "count matrix columns"
        )


    ##############################
    #      apply subsetting      #
    ##############################

    log.info("select cells from RNA anndata object")

    # Keep only metacells that also occur in the aggregated RNA annotation.
    metacells_subs = set(rna_obs_annot["metacell"].unique().tolist())
    sub_mask = np.array(
        [x in metacells_subs for x in cell_df["metacell"].tolist()], dtype=bool
    )
    if not sub_mask.any():
        # Otherwise an empty object would be stored without any error.
        raise ValueError(
            "none of the ATAC metacells occur in the RNA annotation; check that "
            "metacell IDs have the same type/format in both inputs"
        )

    cistopic_obj.fragment_matrix = cistopic_obj.fragment_matrix[:, sub_mask]
    cistopic_obj.binary_matrix = cistopic_obj.binary_matrix[:, sub_mask]
    cell_df = cell_df.loc[sub_mask, :]


    log.info("match RNA anndata object barcodes by metacell")

    # Rename each metacell after the RNA obs row carrying the same metacell
    # label. If several RNA rows share a label, the last one wins (dict build).
    rna_obs_metacell_dict = (
        rna_obs_annot.rename_axis(index="barcodes")
        .reset_index()
        .set_index("metacell")["barcodes"]
        .to_dict()
    )
    cell_df.index = cell_df["metacell"].map(rna_obs_metacell_dict)


    ##############################
    #       store in object      #
    ##############################

    cistopic_obj.cell_data = cell_df
    cistopic_obj.cell_names = cell_df.index.tolist()

In [None]:
# Sanity check: the matrices' cell columns, the metadata rows and the cell
# names are expected to agree after the (optional) metacell aggregation.
log.info(f"fragment matrix shape: {cistopic_obj.fragment_matrix.shape}")
log.info(f"binary matrix shape: {cistopic_obj.binary_matrix.shape}")
log.info(f"cell metadata shape: {cistopic_obj.cell_data.shape}")
log.info(f"cell names length: {len(cistopic_obj.cell_names)}")

In [None]:
# Final object summary and the (possibly metacell-aggregated) cell metadata.
log.info(cistopic_obj)
log.info(cistopic_obj.cell_data)

## 5) Save

In [None]:
log.info("save aggregated cistopic object...")

# Persist the peak-filtered (and, if metacells were found, aggregated) object
# under scATAC/, next to the input cistopic_obj.pkl.
save_cistopic_obj(cistopic_obj, atac_path / "cistopic_obj_filt.pkl")

In [None]:
# End-of-notebook marker.
log.info("all done.")