# Import libraries

In [None]:
%load_ext autoreload
%autoreload 2

import tqdm, sys, os, time, logging, warnings
from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning
warnings.simplefilter('ignore', category=NumbaDeprecationWarning)

import pandas as pd
import numpy as np
import scipy as sp
import scipy.sparse as sps

import scanpy as sc
import anndata as ad
import muon as mu

from sklearn.metrics import adjusted_rand_score as ari

## Load and preprocess individual datasets

### scRNA-seq A

In [None]:
# Read the scRNA-seq A loom file. The loader is called through this notebook's
# own `anndata` import (`ad`) rather than through `mu.ad`, which is only
# reachable if muon happens to expose its internal anndata import.
ad_scrnaseq_A = ad.read_loom("data/processed/scRNAseq_10x_v3_AIBS.loom", sparse=True)

# Turn the literal string 'nan' in the cell metadata into a missing value.
# NOTE(review): pandas documents that before 1.4.0 an explicit `value=None`
# was ignored by `replace` (pad-fill was applied instead) - confirm the
# environment runs pandas >= 1.4.
ad_scrnaseq_A.obs.replace('nan', None, inplace=True)

# Index cells by sample name and genes by gene name.
ad_scrnaseq_A.obs.set_index("sample_name", inplace=True)
ad_scrnaseq_A.var.set_index("gene_name", inplace=True)

In [None]:
# Number of genes with a count > 0 in each cell (row-wise over X), aligned to
# the cell index so it can be merged into `.obs`.
nonzero_gene_counts = pd.Series(
    np.asarray((ad_scrnaseq_A.X > 0).sum(axis=1)).ravel(),
    index=ad_scrnaseq_A.obs.index,
)
# Keep the gene count already present in the metadata; fall back to the
# computed value only where it is missing.
ad_scrnaseq_A.obs['gene.counts'] = ad_scrnaseq_A.obs['gene.counts'].combine_first(nonzero_gene_counts)

# Cells without a doublet score get 0.0. The result is assigned back to the
# column instead of calling `fillna(..., inplace=True)` on the selected
# column: that chained in-place form is deprecated and does not update the
# frame once pandas Copy-on-Write is active.
ad_scrnaseq_A.obs['doublet.score'] = ad_scrnaseq_A.obs['doublet.score'].fillna(0.0)

In [None]:
def filter_cell_A(s, *, min_genes_non_neuronal=1000, min_genes_other=2000,
                  max_doublet_score=0.3):
    """Decide whether one cell of dataset A passes quality control.

    Parameters
    ----------
    s : mapping-like row (e.g. one row of ``.obs``)
        Must provide the keys ``'class_label'``, ``'gene.counts'`` and
        ``'doublet.score'``.
    min_genes_non_neuronal : number, keyword-only
        Minimum ``'gene.counts'`` for cells labelled ``'Non-Neuronal'``.
    min_genes_other : number, keyword-only
        Minimum ``'gene.counts'`` for every other class label.
    max_doublet_score : number, keyword-only
        Largest accepted ``'doublet.score'`` (inclusive).

    Returns
    -------
    bool-like
        ``False`` for cells labelled ``'Low Quality'``; otherwise whether the
        cell meets both the gene-count and the doublet-score criterion.
    """
    label = s['class_label']
    if label == 'Low Quality':
        return False

    # The gene-count threshold depends on the cell class.
    if label == 'Non-Neuronal':
        min_genes = min_genes_non_neuronal
    else:
        min_genes = min_genes_other

    # `&` (not `and`) keeps the result type of the original comparison chain.
    return (s['gene.counts'] >= min_genes) & (s['doublet.score'] <= max_doublet_score)
    
# Evaluate the QC predicate once per cell (one row of `.obs` at a time) and
# keep only the cells for which it holds; all genes are retained.
keep_cells_A = ad_scrnaseq_A.obs.apply(filter_cell_A, axis=1)
ad_scrnaseq_A = ad_scrnaseq_A[keep_cells_A, :]
ad_scrnaseq_A

### snRNA-seq B

In [None]:
# Read the snRNA-seq B loom file. The loader is called through this notebook's
# own `anndata` import (`ad`) rather than through `mu.ad`, and `sparse=True`
# is spelled out so both datasets are loaded the same way.
ad_snrnaseq_B = ad.read_loom("data/processed/snRNAseq_10x_v3_Broad.loom", sparse=True)

# Turn the literal string 'nan' in the cell metadata into a missing value.
# NOTE(review): pandas documents that before 1.4.0 an explicit `value=None`
# was ignored by `replace` (pad-fill was applied instead) - confirm the
# environment runs pandas >= 1.4.
ad_snrnaseq_B.obs.replace('nan', None, inplace=True)

# Index cells by sample name and genes by gene name.
ad_snrnaseq_B.obs.set_index("sample_name", inplace=True)
ad_snrnaseq_B.var.set_index("gene_name", inplace=True)

In [None]:
# Bare expression: as the last statement of a notebook cell it shows the
# repr of the freshly loaded dataset B object.
ad_snrnaseq_B

## Filtering low QC genes & cells

In [None]:
# Number of genes with a count > 0 in each cell (row-wise over X), aligned to
# the cell index so it can be merged into `.obs`.
nonzero_gene_counts = pd.Series(
    np.asarray((ad_snrnaseq_B.X > 0).sum(axis=1)).ravel(),
    index=ad_snrnaseq_B.obs.index,
)
# Keep the gene count already present in the metadata; fall back to the
# computed value only where it is missing.
ad_snrnaseq_B.obs['gene.counts'] = ad_snrnaseq_B.obs['gene.counts'].combine_first(nonzero_gene_counts)

# Cells without a doublet value get 0.0. The result is assigned back to the
# column instead of calling `fillna(..., inplace=True)` on the selected
# column: that chained in-place form is deprecated and does not update the
# frame once pandas Copy-on-Write is active.
ad_snrnaseq_B.obs['Broad.QC.doublet'] = ad_snrnaseq_B.obs['Broad.QC.doublet'].fillna(0.0)

In [None]:
def filter_cell_B(s, *, min_genes_non_neuronal=500, min_genes_other=1000,
                  max_doublet_score=0.3):
    """Decide whether one cell of dataset B passes quality control.

    Parameters
    ----------
    s : mapping-like row (e.g. one row of ``.obs``)
        Must provide the keys ``'class_label'``, ``'gene.counts'`` and
        ``'Broad.QC.doublet'``.
    min_genes_non_neuronal : number, keyword-only
        Minimum ``'gene.counts'`` for cells labelled ``'Non-Neuronal'``.
    min_genes_other : number, keyword-only
        Minimum ``'gene.counts'`` for every other class label.
    max_doublet_score : number, keyword-only
        Largest accepted ``'Broad.QC.doublet'`` value (inclusive).

    Returns
    -------
    bool-like
        ``False`` for cells labelled ``'Low Quality'``; otherwise whether the
        cell meets both the gene-count and the doublet criterion.
    """
    label = s['class_label']
    if label == 'Low Quality':
        return False

    # The gene-count threshold depends on the cell class.
    if label == 'Non-Neuronal':
        min_genes = min_genes_non_neuronal
    else:
        min_genes = min_genes_other

    # `&` (not `and`) keeps the result type of the original comparison chain.
    return (s['gene.counts'] >= min_genes) & (s['Broad.QC.doublet'] <= max_doublet_score)

# Evaluate the QC predicate once per cell (one row of `.obs` at a time) and
# keep only the cells for which it holds; all genes are retained.
keep_cells_B = ad_snrnaseq_B.obs.apply(filter_cell_B, axis=1)
ad_snrnaseq_B = ad_snrnaseq_B[keep_cells_B, :]
ad_snrnaseq_B

In [None]:
# Both datasets must list the same genes in the same order before dataset B's
# gene table is replaced by dataset A's. `Index.equals` also copes with
# indexes of different length, and the explicit raise is not stripped when
# Python runs with -O (a bare `assert` would be).
if not ad_snrnaseq_B.var.index.equals(ad_scrnaseq_A.var.index):
    raise ValueError("Gene indexes of the scRNA-seq and snRNA-seq datasets differ")

# Re-index dataset A's genes by `gene_id`; the previous index becomes a column.
ad_scrnaseq_A.var = ad_scrnaseq_A.var.reset_index().set_index('gene_id')
# Give dataset B its own copy of that table so the two objects do not share
# one DataFrame (an edit to one `.var` would otherwise show up in the other).
ad_snrnaseq_B.var = ad_scrnaseq_A.var.copy()

In [None]:
# For each dataset, report whether any gene index value occurs more than once
# (order: scRNA-seq A, snRNA-seq B).
tuple(adata.var.index.duplicated().any() for adata in (ad_scrnaseq_A, ad_snrnaseq_B))

# Integrate cells from multiple datasets

In [None]:
# Bundle the two filtered datasets as named modalities of one MuData object.
# NOTE(review): `axis=1` is understood to mean the modalities share variables
# (genes) while their observations (cells) differ - confirm against the
# mudata documentation.
modalities = {"scrna": ad_scrnaseq_A, "snrna": ad_snrnaseq_B}
mdata = mu.MuData(modalities, axis=1)
mdata

In [None]:
# Stack the count matrices of both modalities row-wise (scrna rows first,
# then snrna rows) and attach the result to the container as `X`.
# NOTE(review): `X` does not appear to be a documented MuData container, so
# this is presumably a plain Python attribute that `write_h5mu` will not
# persist - confirm.
mdata.X = sps.vstack([mdata[modality].X for modality in ("scrna", "snrna")])

In [None]:
# Ask the container to refresh itself from its modalities (presumably this
# re-synchronises the global annotations - see the mudata `MuData.update`
# documentation), then show its repr.
mdata.update()
mdata

## Save

In [None]:
# Replace missing `Comb.QC` entries of the snRNA-seq modality with the string
# "nan" before saving. The result is assigned back to the column instead of
# calling `fillna(..., inplace=True)` on the selected column: that chained
# in-place form is deprecated and does not update the frame once pandas
# Copy-on-Write is active.
mdata['snrna'].obs['Comb.QC'] = mdata['snrna'].obs['Comb.QC'].fillna("nan")

In [None]:
# Persist the combined, QC-filtered object to an .h5mu file.
mdata.write_h5mu("data/processed/scRNAseq_snRNAseq_filteredQC.h5mu")