# Run dsb normalization in muon to denoise and normalize ADT counts

In [None]:
import h5py
import scipy.sparse as scs
import pandas as pd
import anndata as ad
import scanpy as sc
import os
import glob
import muon as mu
import mudata as md
from mudata import MuData
import numpy as np
import seaborn as sns
from muon import prot as pt

In [None]:
# Experiment cache directory (not referenced by the cells shown in this notebook)
data_path = '/home/jupyter/data/preRA_teaseq/EXP-00243/cache/'

In [None]:
# Collect the raw RNA count .h5 files (recursive search under the cache dir).
# NOTE(review): raw_h5_path has no trailing slash, so the pattern expands to
# '.../raw_counts/cache*/**/*.h5' and would also match sibling directories whose
# names start with 'cache'; it also picks up every .h5 below, not just one per
# well — confirm both are intended.
raw_h5_path = '/home/jupyter/data/preRA_teaseq/EXP-00243/raw_counts/cache'
raw_files = glob.glob(raw_h5_path + '*/**/*.h5', recursive=True)
len(raw_files)

In [None]:
# helper functions
# define a function to read a sparse count matrix stored in an h5 group
def read_mat(h5_con, group='matrix'):
    """Build a scipy CSC sparse count matrix from one group of an h5 file.

    Parameters
    ----------
    h5_con : h5py.File or mapping
        Open HDF5 handle (anything indexable as ``h5_con[group][key][:]``).
    group : str, default 'matrix'
        Top-level group holding the ``data``, ``indices``, ``indptr`` and
        ``shape`` datasets. ``'matrix'`` is the gene-expression matrix;
        pass ``'prot'`` to read the ADT matrix from the same layout.

    Returns
    -------
    scipy.sparse.csc_matrix
        Matrix with the dimensions stored in ``<group>/shape``. Rows are
        features and columns are barcodes; callers transpose it before
        building AnnData.
    """
    grp = h5_con[group]
    return scs.csc_matrix(
        (grp['data'][:],      # non-zero count values
         grp['indices'][:],   # row index of each value
         grp['indptr'][:]),   # column start offsets into data/indices
        shape=tuple(grp['shape'][:]),  # matrix dimensions
    )

# define a function to read the observation (i.e. per-cell metadata) table
def _decode_utf8(values):
    """Decode an array of byte strings to ``str``; return anything else unchanged.

    Only the first element is inspected (the array is assumed homogeneous).
    Empty arrays are returned as-is instead of raising IndexError.
    """
    if len(values) > 0 and isinstance(values[0], (bytes, bytearray)):
        return [x.decode('UTF-8') for x in values]
    return values


def read_obs(h5con):
    """Build a per-barcode observation (metadata) table from an h5 file.

    Parameters
    ----------
    h5con : h5py.File or mapping
        Open handle whose ``matrix`` group holds a ``barcodes`` dataset and,
        optionally, an ``observations`` group with one dataset per column.

    Returns
    -------
    pandas.DataFrame
        One row per barcode: a ``barcodes`` column followed by one column per
        dataset under ``matrix/observations``. Byte strings are decoded as
        UTF-8; other dtypes are stored unchanged.
    """
    matrix = h5con['matrix']
    # Initialise the DataFrame with the cell barcodes
    obs_df = pd.DataFrame({'barcodes': _decode_utf8(matrix['barcodes'][:])})

    # Without an observations group, return the barcode-only table rather
    # than failing with KeyError.
    if 'observations' not in matrix:
        return obs_df

    observations = matrix['observations']
    for col in observations.keys():
        obs_df[col] = _decode_utf8(observations[col][:])
    return obs_df
# define a function to construct an AnnData object from the RNA part of an h5 file
def read_h5_anndata(h5_file):
    """Load the gene-expression (``matrix`` group) data of an h5 file as AnnData.

    Parameters
    ----------
    h5_file : str
        Path to an HDF5 file with a ``matrix`` group holding the sparse
        counts, ``features/name`` and the barcode metadata read by
        ``read_obs``.

    Returns
    -------
    anndata.AnnData
        Barcodes x genes (the stored matrix is transposed). ``obs`` comes
        from ``read_obs``; ``var_names`` are the gene names, made unique.
    """
    # The context manager closes the HDF5 handle (previously leaked). Every
    # dataset is copied into memory with ``[:]`` before the file is closed.
    with h5py.File(h5_file, mode='r') as h5_con:
        mat = read_mat(h5_con)
        genes = [x.decode('UTF-8')
                 for x in h5_con['matrix']['features']['name'][:]]
        obs_df = read_obs(h5_con)
    adata = ad.AnnData(mat.T, obs=obs_df)
    # Assign gene names after construction so they line up with the columns
    adata.var_names = genes
    adata.var_names_make_unique()
    return adata

# helper functions for extracting ADT data from h5 files
def read_adt(h5_con):
    """Return the ADT counts stored in the ``prot`` group as a CSC sparse matrix.

    The ``prot`` group uses the same CSC layout as the RNA matrix:
    ``data`` holds the non-zero counts, ``indices`` their row positions,
    ``indptr`` the column start offsets and ``shape`` the dimensions.
    """
    prot = h5_con['prot']
    components = (prot['data'][:], prot['indices'][:], prot['indptr'][:])
    dims = tuple(prot['shape'][:])
    return scs.csc_matrix(components, shape=dims)

# define a function to construct an AnnData object from the ADT part of an h5 file
def read_h5_adt_anndata(h5_file):
    """Load the ADT (``prot`` group) data of an h5 file as AnnData.

    Parameters
    ----------
    h5_file : str
        Path to an HDF5 file with a ``prot`` group (sparse counts and
        ``features/id``) and the ``matrix`` barcode metadata read by
        ``read_obs``.

    Returns
    -------
    anndata.AnnData
        Barcodes x ADT features (the stored matrix is transposed);
        ``var_names`` are the ``prot/features/id`` values, made unique.
    """
    # The context manager closes the HDF5 handle (previously leaked). Every
    # dataset is copied into memory with ``[:]`` before the file is closed.
    with h5py.File(h5_file, mode='r') as h5_con:
        mat = read_adt(h5_con)
        # ADT feature labels come from 'id' (the RNA reader uses 'name')
        features = [x.decode('UTF-8')
                    for x in h5_con['prot']['features']['id'][:]]
        obs_df = read_obs(h5_con)
    adata = ad.AnnData(mat.T, obs=obs_df)
    # Assign feature names after construction so they line up with the columns
    adata.var_names = features
    adata.var_names_make_unique()
    return adata

# combine the RNA and ADT readers into one MuData object for a single CITE-seq h5 file
def read_cite_mudata(h5_file):
    """Load one CITE-seq h5 file as a MuData with ``rna`` and ``ADT`` modalities.

    The file is read twice: once for the gene-expression matrix and once
    for the ADT matrix.
    """
    modalities = {
        "rna": read_h5_anndata(h5_file),
        "ADT": read_h5_adt_anndata(h5_file),
    }
    return MuData(modalities)

# define a function to read several h5 files and combine them into one MuData object
def read_cite_mudatas(h5_files):
    """Load several CITE-seq h5 files and concatenate them into one MuData.

    Parameters
    ----------
    h5_files : iterable of str
        Paths to h5 files. Each file is read for both RNA and ADT.

    Returns
    -------
    MuData
        ``rna`` and ``ADT`` modalities. Each is the ``AnnData.concatenate``
        of the per-file objects, in input order.

    Raises
    ------
    ValueError
        If ``h5_files`` is empty.
    """
    # Materialise once. The input is iterated twice (RNA, then ADT), so a
    # generator would be exhausted by the first pass and ``adt[0]`` would
    # raise IndexError.
    h5_files = list(h5_files)
    if not h5_files:
        raise ValueError('h5_files must contain at least one file')
    # NOTE(review): AnnData.concatenate is deprecated in newer anndata
    # releases in favour of anndata.concat; kept here to preserve the
    # existing obs-name suffixes and 'batch' column.
    rna = [read_h5_anndata(filename) for filename in h5_files]
    join_rna = rna[0].concatenate(rna[1:])
    adt = [read_h5_adt_anndata(filename) for filename in h5_files]
    join_adt = adt[0].concatenate(adt[1:])
    return MuData({"rna": join_rna, "ADT": join_adt})

In [None]:
# helper functions

# define a function to construct an AnnData object from a raw RNA h5 file
def read_raw_rna(h5_file):
    """Load a raw RNA h5 file as AnnData indexed by trimmed barcodes.

    Unlike ``read_h5_anndata``, no ``observations`` group is read. ``obs``
    holds only the barcode index.

    Parameters
    ----------
    h5_file : str
        Path to an HDF5 file with a ``matrix`` group (sparse counts,
        ``features/name`` and ``barcodes``).

    Returns
    -------
    anndata.AnnData
        Barcodes x genes. The obs index is named ``barcodes`` and holds each
        barcode truncated at its first '-' (e.g. 'AAACCC-1' -> 'AAACCC').
        ``var_names`` are the gene names, made unique.
    """
    # The context manager closes the HDF5 handle (previously leaked). Every
    # dataset is copied into memory with ``[:]`` before the file is closed.
    with h5py.File(h5_file, mode='r') as h5_con:
        mat = read_mat(h5_con)
        genes = [x.decode('UTF-8')
                 for x in h5_con['matrix']['features']['name'][:]]
        bc = [x.decode('UTF-8') for x in h5_con['matrix']['barcodes'][:]]

    # Initialise the obs DataFrame with the cell barcodes
    obs_df = pd.DataFrame({'barcodes': bc})
    adata = ad.AnnData(mat.T, obs=obs_df)
    # Assign gene names after construction so they line up with the columns
    adata.var_names = genes
    adata.var_names_make_unique()
    # Keep only the text before the first '-' and use it as the obs index;
    # presumably so barcodes match the ADT 'cell_barcode' index — confirm.
    adata.obs['barcodes'] = adata.obs['barcodes'].str.split('-').str[0]
    adata.obs = adata.obs.set_index('barcodes')
    return adata


# build an AnnData object from a raw ADT tag-count CSV table
def read_raw_adt(adt_file):
    """Read a raw ADT tag-count CSV into AnnData.

    Parameters
    ----------
    adt_file : str
        Path to a CSV with a ``cell_barcode`` column (used as the row index)
        and a ``total`` column. Every remaining column becomes a variable.

    Returns
    -------
    anndata.AnnData
        Barcodes x tags. The obs index is named ``barcodes`` and the
        ``total`` column is excluded.
    """
    table = pd.read_csv(adt_file, index_col='cell_barcode')
    # Name the index 'barcodes', the same name read_raw_rna gives its obs index
    table.index.name = 'barcodes'
    return ad.AnnData(table.drop('total', axis=1))

In [None]:
# raw_rna = read_raw_rna(raw_files[1])

## Load the raw (pre-filtering) data, which still contains empty droplets

In [None]:
# Collect the raw ADT tag-count tables (files ending in '_Tag_Counts.csv')
# using the same cache-directory glob prefix as the RNA .h5 search above.
raw_adt_files = glob.glob(
    raw_h5_path + '*/**/*_Tag_Counts.csv', recursive=True)
len(raw_adt_files)

In [None]:
# Pair raw RNA and ADT files by well id (e.g. 'c1w1').
rna_files_tb = pd.DataFrame({'rna_files': raw_files})
adt_files_tb = pd.DataFrame({'adt_files': raw_adt_files})
# Extract the well id: RNA paths use lower case ('c<d>w<d>'), ADT paths use
# upper case ('C<d>W<d>'), so the ADT ids are lower-cased to match.
rna_files_tb['well_id'] = rna_files_tb['rna_files'].str.extract(r'(c\dw\d)')
adt_files_tb['well_id'] = adt_files_tb['adt_files'].str.extract(r'(C\dW\d)')
adt_files_tb['well_id'] = adt_files_tb['well_id'].str.lower()
# Drop paths with no recognisable well id before merging. pandas merge
# matches NaN keys to each other, which would silently pair unrelated files.
rna_files_tb = rna_files_tb.dropna(subset=['well_id'])
adt_files_tb = adt_files_tb.dropna(subset=['well_id'])
# One row per matched RNA/ADT file pair
files_tb = rna_files_tb.merge(adt_files_tb, on='well_id')

In [None]:
# Inspect the RNA/ADT file pairing before loading any data
files_tb

In [None]:
# define a function to construct a MuData object from matched raw RNA/ADT files
def read_raw_cite(file_table):
    """Load matched raw RNA h5 files and raw ADT tables into one MuData.

    Parameters
    ----------
    file_table : pandas.DataFrame
        One row per well with an ``rna_files`` column (raw RNA .h5 paths)
        and an ``adt_files`` column (raw ADT tag-count CSV paths).

    Returns
    -------
    MuData
        ``rna`` and ``prot`` modalities. Each is the ``anndata.concat`` of
        the per-file objects, in row order.

    Raises
    ------
    ValueError
        If ``file_table`` has no rows.
    """
    if file_table.empty:
        raise ValueError('file_table must contain at least one row')
    # index_unique='_' appends '_<key>' to each obs name, where the key is the
    # object's position in the list by default. Both modalities are built in
    # the same row order, so the suffixes line up between RNA and ADT.
    rna = [read_raw_rna(filename) for filename in file_table['rna_files']]
    join_rna = ad.concat(rna, index_unique="_")
    adt = [read_raw_adt(filename) for filename in file_table['adt_files']]
    join_adt = ad.concat(adt, index_unique='_')
    # create the MuData object
    raw_mudata = MuData({"rna": join_rna, "prot": join_adt})
    return raw_mudata

In [None]:
# Load raw (unfiltered) RNA + ADT for the first two matched wells only.
# NOTE(review): .loc[0:1] is label-inclusive (rows 0 and 1); confirm two wells
# are enough to model the empty-droplet background used by dsb below.
mdata_raw = read_raw_cite(files_tb.loc[0:1, :])

In [None]:
mdata_raw

In [None]:
mdata_raw['rna'].obs

In [None]:
mdata_raw['prot'].var_names

In [None]:
# Barcodes present in both the RNA and the ADT modality
mdata_raw['rna'].obs.index.intersection(mdata_raw['prot'].obs.index)

In [None]:
# Keep only barcodes present in every modality (in-place intersection).
# This does not apply any RNA QC filter; droplets are still unfiltered.
mu.pp.intersect_obs(mdata_raw)

In [None]:
mdata_raw.update()

In [None]:
mdata_raw

In [None]:
# Clean ADT feature names: replace '-' and '.' with '_' (e.g. 'HLA-DR' -> 'HLA_DR'),
# presumably to match the naming in the filtered object loaded below — confirm.
mdata_raw['prot'].var.index =  mdata_raw['prot'].var.index.str.replace('-|\\.', '_', regex=True)

In [None]:
mdata_raw.update()

## Load the QC-filtered cell data

In [None]:
# Load the QC-filtered, 3-modality MuData. dsb results are added to its 'prot'
# modality below, and the object is written back to this same path at the end.
joint_mudata_fl = mu.read_h5mu(
    '/home/jupyter/data/preRA_teaseq/EXP-00243/totalVI/' +
    "PreRA_teaseq_qc_filtered_cells_3modality_rmBR2024.h5mu")

In [None]:
joint_mudata_fl

In [None]:
# Peek at a slice of the stored ADT counts (sparse layer, densified for display)
joint_mudata_fl['prot'].layers['counts'][1:50, 1:50].toarray()

In [None]:
joint_mudata_fl['prot'].var.index

In [None]:
# ADT distributions of features 5-9 before dsb
mu.pl.histogram(joint_mudata_fl['prot'],
                joint_mudata_fl['prot'].var_names[5:10], bins=50)

In [None]:
# Per-barcode log10(total RNA counts + 1) in the raw data. np.array(...).reshape(-1)
# flattens the (n, 1) matrix returned by the sparse row sum.
mdata_raw['rna'].obs["log10umi"] = np.array(
    np.log10(mdata_raw['rna'].X.sum(axis=1) + 1)).reshape(-1)

In [None]:
# Distribution of log10 RNA UMIs across all raw droplets
mu.pl.histogram(mdata_raw['rna'], ['log10umi'], bins=50)

In [None]:
# Zoom in on droplets with log10umi >= 1.5
mu.pl.histogram(mdata_raw['rna'][mdata_raw['rna'].obs.log10umi >= 1.5], ['log10umi'], bins=50)

In [None]:
# Plot the tentative empty-droplet window 1.5 <= log10umi <= 2.6; the same bounds
# are passed as empty_counts_range to pt.pp.dsb below.
mu.pl.histogram(mdata_raw['rna'][(mdata_raw['rna'].obs.log10umi >= 1.5)&
                (mdata_raw['rna'].obs.log10umi <= 2.6)], ['log10umi'], bins=50)

In [None]:
joint_mudata_fl['rna'].layers['counts']

In [None]:
# Per-cell log10(total RNA + 1) for the filtered cells.
# NOTE(review): this sums .X, not layers['counts']; if X is normalized or
# log-transformed, this is not a raw UMI total — confirm which was intended.
joint_mudata_fl['rna'].obs["log10umi"] = np.array(
    np.log10(joint_mudata_fl['rna'].X.sum(axis=1) + 1)).reshape(-1)

In [None]:
mu.pl.histogram(joint_mudata_fl['rna'], ['log10umi'], bins=50)

In [None]:
# Identify isotype-control antibodies in the raw ADT panel (names containing
# lower-case 'iso'; str.contains is case-sensitive)
isotypes = mdata_raw['prot'].var_names[mdata_raw['prot'].var_names.str.contains(
    'iso')]
isotypes

In [None]:
# Same selection on the filtered object. This overwrites `isotypes` and is
# the list passed to dsb below.
isotypes = joint_mudata_fl['prot'].var_names[joint_mudata_fl['prot'].var_names.str.contains(
    'iso')]
isotypes

In [None]:
# ADT features present in the raw data but missing from the filtered object
mdata_raw['prot'].var_names[~mdata_raw['prot'].var_names.isin(
    joint_mudata_fl['prot'].var_names)]

In [None]:
# ADT features present in the filtered object but missing from the raw data
joint_mudata_fl['prot'].var_names[~joint_mudata_fl['prot'].var_names.isin(
    mdata_raw['prot'].var_names)]

In [None]:
mdata_raw['prot'].X

In [None]:
# Check whether X still equals the stored counts layer.
# NOTE(review): if both are scipy sparse, `==` builds a boolean sparse matrix and
# scipy warns it is inefficient; `(X != counts).nnz == 0` gives a single bool.
joint_mudata_fl['prot'].X == joint_mudata_fl['prot'].layers['counts']

In [None]:
mdata_raw

In [None]:
joint_mudata_fl['prot'].var_names

In [None]:
# Run dsb: empty droplets are raw barcodes whose log10 RNA total counts fall in
# empty_counts_range; isotype controls feed the technical-noise component;
# add_layer=True stores the result in layers['dsb'] instead of overwriting X.
# NOTE(review): raw obs names carry '_0'/'_1' suffixes from read_raw_cite, so
# they may not overlap the filtered cell names — confirm real cells are
# excluded from the empty-droplet set as intended.
pt.pp.dsb(joint_mudata_fl, mdata_raw,  empty_counts_range=(1.5, 2.6), 
          isotype_controls=isotypes, add_layer=True,
          random_state=1)

In [None]:
# NOTE(review): with add_layer=True above, X is presumably still the pre-dsb
# matrix, so this likely replots the same values as the earlier histogram;
# plot the 'dsb' layer if a post-dsb view was intended — confirm.
mu.pl.histogram(joint_mudata_fl['prot'],
                joint_mudata_fl['prot'].var_names[5:10], bins=50)

In [None]:
# CD3 vs CD19 on raw counts, then on dsb-normalized values
sc.pl.scatter(joint_mudata_fl['prot'], x="CD3", y="CD19", layers='counts')
sc.pl.scatter(joint_mudata_fl['prot'], x="CD3", y="CD19", layers='dsb')

In [None]:
joint_mudata_fl['prot'].layers['counts']

In [None]:
# Replace X with a copy of the dsb values; raw counts remain in layers['counts']
joint_mudata_fl['prot'].X  = joint_mudata_fl['prot'].layers['dsb'].copy()

In [None]:
joint_mudata_fl.update()

# Run the analysis on the dsb-normalized data

In [None]:
# Run the downstream analysis on the protein modality, whose X now holds dsb
# values. `prot` refers to the AnnData stored in joint_mudata_fl.mod, so the
# PCA/UMAP results below are stored on it.
prot = joint_mudata_fl.mod['prot']

In [None]:
# The 'dsb' layer was already created by pt.pp.dsb(add_layer=True) above;
# no extra copy is needed here.

In [None]:
sc.tl.pca(prot)

In [None]:
sc.pl.pca_variance_ratio(prot, log=True)

In [None]:
sc.pl.pca(prot, color='CD3')

In [None]:
# Protein-only neighbour graph (first 20 PCs) and UMAP embedding
sc.pp.neighbors(prot, n_pcs=20)
sc.tl.umap(prot)

In [None]:
# Marker overview; colour scale clipped to the 1st-99th percentile
sc.pl.umap(prot, color=['CD3', 'CD19', 'HLA_DR',
           'CD56', 'CD14'], vmin='p1', vmax='p99')

In [None]:
joint_mudata_fl.update()

In [None]:
joint_mudata_fl

In [None]:
# Save the dataset.
# NOTE(review): this overwrites the input file read above. Its 'prot' X now
# holds dsb values, so re-running this notebook from that file would feed
# dsb-normalized X (not counts) into dsb again if pt.pp.dsb reads .X —
# consider writing to a new filename.
joint_mudata_fl.write(
    '/home/jupyter/data/preRA_teaseq/EXP-00243/totalVI/' +
    "PreRA_teaseq_qc_filtered_cells_3modality_rmBR2024.h5mu")