In [4]:
import random
from collections import namedtuple

import numpy as np
import numpy.ma as ma
import scipy.stats as sp_stats

import celescope.tools.emptydrop_cr.sgt as cr_sgt  # # modified sgt.py
import celescope.tools.emptydrop_cr.stats as cr_stats  # # modified stats.py
from celescope.tools.matrix import CountMatrix

# Seed both the stdlib and the NumPy global generators so reruns are repeatable
random.seed(0)
np.random.seed(0)

# How many extra barcodes (beyond the initial cell calls) are kept as candidates
N_CANDIDATE_BARCODES = 20_000

# Partition count: upper bound on the barcodes considered for ambient estimation
N_PARTITIONS = 90_000

# Top fraction of the partitions that is skipped when estimating the ambient profile
MAX_OCCUPIED_PARTITIONS_FRAC = 0.5

# A candidate barcode needs at least this many UMIs
MIN_UMIS = 500

# ... and at least this fraction of the median UMI count of the initial cell calls
MIN_UMI_FRAC_OF_MEDIAN = 0.01

# Adjusted p-value threshold below which a barcode is called non-ambient
MAX_ADJ_PVALUE = 0.01

In [41]:
def est_background_profile_sgt(matrix, use_bcs):
    """Build a Good-Turing-smoothed background expression profile.

    The profile is estimated from the selected barcodes only, but it spans
    every feature whose total count over the whole matrix is non-zero.

    Args:
      matrix (scipy.sparse.csc_matrix): Sparse matrix of all counts
      use_bcs (np.array(int)): Indices of barcodes to use (col indices into matrix)
    Returns:
      tuple: (use_feats, bg_profile_p) where use_feats (np.array(int)) holds the
        indices of the features with a non-zero total and bg_profile_p
        (np.array(float)) holds one estimated probability per entry of use_feats.
    """
    # Per-feature totals over all barcodes; only features seen somewhere are kept
    feature_totals = np.asarray(matrix.sum(1))
    detected_feats = np.flatnonzero(feature_totals)

    # Smoothed profile over the detected features, from the selected barcodes
    return detected_feats, estimate_profile_sgt(matrix, use_bcs, detected_feats)

def estimate_profile_sgt(matrix, barcode_indices, nz_feat):
    """ Estimate a gene expression profile by Simple Good Turing.

    The counts of the selected barcodes are summed per feature. Features observed
    in the selection receive their SGT-smoothed proportion; the probability mass
    reserved for unobserved events (p0) is split equally among the requested
    features that have a zero count in the selection.

    Args:
      matrix (sparse matrix): Sparse matrix of all counts (features are rows, barcodes are columns)
      barcode_indices (np.array(int)): Barcode indices to use
      nz_feat (np.array(int)): Indices of features that are non-zero at least once
    Returns:
      profile (np.array(float)): Estimated probabilities of length len(nz_feat).
    Raises:
      AssertionError: If the resulting profile does not sum to 1.
      Whatever cr_sgt.sgt_proportions raises is propagated unchanged.
    """
    # Initial profile estimate: per-feature totals over the selected barcodes
    prof_mat = matrix[:, barcode_indices]

    profile = np.ravel(prof_mat[nz_feat, :].sum(axis=1))
    observed_feat = np.flatnonzero(profile)
    zero_feat = np.flatnonzero(profile == 0)

    # Simple Good Turing estimate
    p_smoothed, p0 = cr_sgt.sgt_proportions(profile[observed_feat])

    if len(zero_feat) > 0:
        # Distribute p0 equally among the zero elements.
        p0_i = p0/len(zero_feat)
    else:
        # Every requested feature was observed, so there is nothing to hand p0 to
        # (and dividing by len(zero_feat) would be a division by zero). Rescale
        # the smoothed proportions so the profile is still a distribution.
        # NOTE(review): renormalising is a modelling choice - confirm it is the
        # desired treatment of the unseen mass in this corner case.
        p_smoothed = np.asarray(p_smoothed, dtype=float)
        p_smoothed = p_smoothed / p_smoothed.sum()
        p0_i = 0.0

    profile_p = np.repeat(p0_i, len(nz_feat))
    profile_p[observed_feat] = p_smoothed

    assert np.isclose(profile_p.sum(), 1.0)
    return profile_p

In [6]:
# Directory holding the raw feature-barcode matrix analysed in this notebook
raw_matrix_dir = "/SGRNJ06/randd/USER/cjj/celedev/rna/20230703/5_U7/Mus_0614PZ_SGOT_1_5lib/05.count/Mus_0614PZ_SGOT_1_5lib_raw_feature_bc_matrix"

# Load it through celescope's CountMatrix and pull out the underlying matrix
count_matrix = CountMatrix.from_matrix_dir(matrix_dir=raw_matrix_dir)
raw_mat = count_matrix.get_matrix()

2023-07-03 17:41:15,138 - celescope.tools.matrix.from_matrix_dir - INFO - start...
2023-07-03 17:41:22,607 - celescope.tools.matrix.from_matrix_dir - INFO - done. time used: 0:00:07.464178


In [7]:
# Total UMI count per barcode: column sums of the raw matrix with singleton axes dropped
umis_per_bc = np.asarray(raw_mat.sum(axis=0)).squeeze()

In [8]:
# Barcode indices ordered from the lowest to the highest UMI total
bc_order = umis_per_bc.argsort()

In [14]:
len(bc_order)

10530

In [9]:
# Indices of the barcodes that carry at least one UMI
nz_bcs = np.nonzero(np.ravel(umis_per_bc))[0]

In [15]:
len(nz_bcs)

10530

In [25]:
N_PARTITIONS

90000

In [175]:
# Rank barcodes by descending UMI count and keep the ranks between
# N_PARTITIONS * MAX_OCCUPIED_PARTITIONS_FRAC and N_PARTITIONS as the barcodes
# used for ambient estimation. The slice is empty whenever the matrix holds
# fewer barcodes than the first of those ranks.
first_empty_rank = int(N_PARTITIONS*MAX_OCCUPIED_PARTITIONS_FRAC)
empty_bcs = bc_order[::-1][first_empty_rank:N_PARTITIONS]
# Alternative that was tried with 18000 partitions instead of N_PARTITIONS:
#empty_bcs = bc_order[::-1][int(18000*MAX_OCCUPIED_PARTITIONS_FRAC):18000]

In [176]:
len(empty_bcs)

0

In [177]:
# Barcodes that are both in the empty-partition selection and have at least one UMI.
# assume_unique=True relies on neither index array containing duplicates.
use_bcs = np.intersect1d(empty_bcs, nz_bcs, assume_unique=True)

In [178]:
len(use_bcs)

0

In [179]:
    # Start from an empty profile so that both names are always bound, whether
    # there are no barcodes to estimate from or the Good-Turing estimate fails.
    eval_features = np.zeros(0, dtype=int)
    ambient_profile_p = np.zeros(0)

    if len(use_bcs) > 0:
        try:
            # Get used "Gene" features (eval_features)
            # and the smoothed prob profile per "Gene" (ambient_profile_p)
            eval_features, ambient_profile_p = est_background_profile_sgt(raw_mat.tocsc(), use_bcs)
        except cr_sgt.SimpleGoodTuringError as e:
            # Report the failure and keep the empty profile instead of leaving the
            # names unbound (or silently reusing values from an earlier run).
            print(str(e))

In [180]:
eval_features

array([], dtype=int64)

In [181]:
ambient_profile_p

array([], dtype=float64)

In [182]:
    # First-round cell calling with cr_stats' ordmag filter on the per-barcode
    # UMI totals, asking for 3000 recovered cells. The call yields the selected
    # barcode indices, a metrics dict and a message (unused here).
    (gg_filtered_indices,
     gg_filtered_metrics,
     _msg) = cr_stats.filter_cellular_barcodes_ordmag(umis_per_bc, recovered_cells=3000)

In [183]:
    # Report the first-round cell-calling metrics, one "key: value" pair per line
    print('Cell-called barcodes metrics:')
    print('\n'.join('{}: {}'.format(key, value) for key, value in gg_filtered_metrics.items()))
    print('==============================')

    # Sorted column indices of the barcodes called as cells in the first round
    orig_cell_bc_set = set(gg_filtered_indices)
    orig_cells = np.flatnonzero(np.fromiter((bc in orig_cell_bc_set for bc in range(raw_mat.shape[1])), dtype=bool))

    # No good incoming cell calls.
    # orig_cells holds barcode indices, so its length is tested: the sum of the
    # indices would also be 0 when barcode 0 is the only cell that was called.
    if len(orig_cells) == 0:
        print('Error: No original cells are selected!')

    # Look at non-cell barcodes above a minimum UMI count
    eval_bcs = ma.array(np.arange(raw_mat.shape[1]))
    eval_bcs[orig_cells] = ma.masked

    median_initial_umis = np.median(umis_per_bc[orig_cells])

    # UMI floor for candidates: the larger of the absolute minimum and the
    # configured fraction of the median UMI count of the initial cells
    min_umi_frac_of_median = MIN_UMI_FRAC_OF_MEDIAN
    min_umis_nonambient = MIN_UMIS
    min_umis = int(max(min_umis_nonambient, round(np.ceil(median_initial_umis * min_umi_frac_of_median))))

    print('Median UMIs of initial cell calls: {}'.format(median_initial_umis))
    print('Min UMIs: {}'.format(min_umis))

    eval_bcs[umis_per_bc < min_umis] = ma.masked
    n_unmasked_bcs = len(eval_bcs) - eval_bcs.mask.sum()

    # Take the top N_CANDIDATE_BARCODES by UMI count, of barcodes that pass the above criteria
    # For evaluation of non-ambient bcs using background info estimated from SGT
    # (the masked argsort places masked barcodes last, so the first n_unmasked_bcs
    # entries are the unmasked ones in ascending UMI order)
    eval_bcs = np.argsort(ma.masked_array(umis_per_bc, mask=eval_bcs.mask))[:n_unmasked_bcs][-N_CANDIDATE_BARCODES:]

    if len(eval_bcs) == 0:
        print('Warning: no eval bcs are selected to evaluate non-empty bcs from SGT results!')
        print('Output bcs from 1st round cell calling ONLY.')
        # return orig_cells, gg_filtered_metrics, None
    else:
        assert not np.any(np.isin(eval_bcs, orig_cells))
        print('Number of candidate bcs: {}'.format(len(eval_bcs)))
        print('Range candidate bc umis: {}, {}'.format(umis_per_bc[eval_bcs].min(), umis_per_bc[eval_bcs].max()))

        # Candidate barcodes restricted to the features covered by the ambient profile
        eval_mat = raw_mat.tocsc()[eval_features, :][:, eval_bcs]

        if len(ambient_profile_p) == 0:
            # No ambient profile could be estimated: nothing can be tested, so the
            # likelihoods are NaN and every candidate keeps a p-value of 1
            obs_loglk = np.repeat(np.nan, len(eval_bcs))
            pvalues = np.repeat(1, len(eval_bcs))
            sim_loglk = np.repeat(np.nan, len(eval_bcs))
        else:
            # Compute observed log-likelihood of barcodes being generated from ambient RNA
            # NOTE(review): pvalues and sim_loglk are not assigned on this path.
            obs_loglk = cr_stats.eval_multinomial_loglikelihoods(eval_mat, ambient_profile_p)

Cell-called barcodes metrics:
filtered_bcs: 5562
filtered_bcs_var: 62928.3964
max_filtered_bcs: 18000.0
Median UMIs of initial cell calls: 1437.0
Min UMIs: 500
Number of candidate bcs: 1334
Range candidate bc umis: 500, 744


In [173]:
obs_loglk

array([nan, nan, nan, ..., nan, nan, nan])

In [None]:
eval_mat

In [129]:
# Size in GiB of one densified barcode column of eval_mat (rows x bytes per element)
gb_per_bc = float(eval_mat.shape[0] * eval_mat.dtype.itemsize) / 2**30

In [130]:
gb_per_bc

0.0

In [184]:
eval_bcs

array([  655,  5900,  4573, ...,  2671, 10368,  3999])

In [111]:
def eval_multinomial_loglikelihoods(matrix, profile_p, max_mem_gb=0.1):
    """Compute the multinomial log PMF for many barcodes.

    Barcodes are converted to dense rows chunk by chunk so that one chunk takes
    roughly max_mem_gb of memory; each barcode is scored against profile_p with
    its own total count as the number of trials.

    Args:
      matrix (scipy.sparse.csc_matrix): Matrix of UMI counts (feature x barcode)
      profile_p (np.ndarray(float)): Multinomial probability vector
      max_mem_gb (float): Try to bound memory usage.
    Returns:
      log_likelihoods (np.ndarray(float)): Log-likelihood for each barcode. All
        NaN when the matrix has no feature rows, since there is no multinomial
        to evaluate in that case.
    """
    num_bcs = matrix.shape[1]

    # Memory taken by one barcode once its column is made dense
    gb_per_bc = float(matrix.shape[0] * matrix.dtype.itemsize) / (1024**3)
    if gb_per_bc == 0:
        # Zero feature rows (e.g. an empty ambient profile upstream): the chunk
        # size below would divide by zero, so report "undefined" per barcode.
        return np.full(num_bcs, np.nan)

    # Always process at least one barcode per chunk, however small the budget
    bcs_per_chunk = max(1, int(round(max_mem_gb/gb_per_bc)))

    loglk = np.zeros(num_bcs)

    for chunk_start in range(0, num_bcs, bcs_per_chunk):
        chunk = slice(chunk_start, chunk_start+bcs_per_chunk)
        # Dense (barcode x feature) block for this chunk
        matrix_chunk = matrix[:, chunk].transpose().toarray()
        # Number of trials = total UMI count of each barcode
        n = matrix_chunk.sum(1)
        loglk[chunk] = sp_stats.multinomial.logpmf(matrix_chunk, n, p=profile_p)
    return loglk

In [119]:
obs_loglk

array([-1581.09090462, -1567.27135016, -1509.93730578, ...,
       -1936.09485175, -2130.44639483, -2213.87755668])

In [150]:
np.zeros(1)

array([0.])