In [None]:
import sys
import scanpy as sc 
import random
import glob
import os
import scipy as sp
import csv

import pandas as pd
import numpy as np
import seaborn as sns 
import matplotlib.pyplot as plt

from signaturescoring import score_signature
from signaturescoring.utils.utils import get_least_variable_genes_per_bin_v1

from scanpy.preprocessing._utils import _get_mean_var

# Add the directory two levels up (relative to the working directory) to the
# module search path; presumably this is where the project-local `data`
# package lives -- confirm against the repository layout.
sys.path.append('../..')
from data.load_data import load_datasets, load_dgex_genes_for_mal_cells

# Set scanpy's logging verbosity to level 2.
sc.settings.verbosity = 2

In [None]:
adata = load_datasets('luad')

# Guarantee that `adata.uns['log1p']` exists and that its 'base' entry is None.
# NOTE(review): presumably a workaround for scanpy routines that read
# `uns['log1p']['base']` -- confirm.
if 'log1p' not in adata.uns_keys():
    adata.uns['log1p'] = {}
adata.uns['log1p']['base'] = None

In [None]:
def get_bins_info(adata, nbins=25):
    """Bin genes by the rank of their mean expression and summarise the bins.

    Parameters
    ----------
    adata
        Object exposing ``X`` (passed to scanpy's ``_get_mean_var``) and
        ``var_names`` (used as the gene index).
    nbins : int, default 25
        Number of equal-width intervals the gene ranks are cut into.

    Returns
    -------
    bin_info : pandas.DataFrame
        One row per non-empty bin with the columns ``nr_genes`` (gene count)
        and ``percent_tot_genes`` (share of all genes in percent, rounded to
        two decimals).
    gene_means : pandas.Series
        Mean value per gene, sorted ascending and indexed by gene name.
    gene_bins : pandas.Series
        Integer bin label per gene, in the same order as ``gene_means``.
    """
    # Only the means are needed for binning; the variances are discarded.
    means, _ = _get_mean_var(adata.X)
    gene_means = pd.Series(means, index=adata.var_names, name='mean').sort_values()

    # Rank first so that bins are defined on ranks rather than raw means; ties
    # share the lowest rank of their group (method="min").
    ranked_gene_means = gene_means.rank(method="min")
    # Honour the caller-supplied number of bins.
    gene_bins = pd.cut(ranked_gene_means, nbins, labels=False)

    genes_per_bin = gene_bins.value_counts().sort_index()
    bin_info = pd.concat(
        [genes_per_bin, (genes_per_bin / len(gene_means) * 100).round(2)],
        axis=1,
    )
    bin_info.columns = ['nr_genes', 'percent_tot_genes']

    return bin_info, gene_means, gene_bins

In [None]:
# Bin the genes by ranked mean expression (default number of bins) and show
# how many genes, and what share of all genes, fall into each bin.
bin_info, gene_means, gene_bins = get_bins_info(adata)
bin_info

In [None]:
# Call the signaturescoring helper with the per-gene bin labels.
# NOTE(review): the arguments `100` and method='seurat' presumably request the
# 100 least variable genes per bin using a Seurat-style variability measure --
# confirm against the helper's documentation.
least_variable_genes_per_bin = get_least_variable_genes_per_bin_v1(adata, gene_bins, 100, method='seurat')

In [None]:
# Display the helper's result for inspection.
least_variable_genes_per_bin

In [None]:
# Difference between each gene's bin label and that of the preceding gene:
# non-zero exactly where the label changes (the first entry is NaN).
gene_bins_diff = gene_bins.diff(1)

In [None]:
# Positions at which a new bin starts. The bin labels never decrease along the
# mean-sorted genes, so every positive jump is a boundary. Testing for `== 1`
# would miss boundaries where one or more empty bins are skipped (jump >= 2),
# which can happen when many genes share the same rank.
bin_idx = np.where(gene_bins_diff > 0)[0]

In [None]:
# Add the total number of genes as the closing boundary of the last bin.
bin_idx = np.concatenate((bin_idx, [len(gene_bins)]))
bin_idx

In [None]:
# Sort the per-gene means in ascending order so positional slices line up with
# the bin boundaries (`get_bins_info` already returns them sorted by mean).
gene_means = gene_means.sort_values()

In [None]:
# For every expression bin, plot the ascending mean-expression curve of the
# genes in that bin and mark the genes returned by the helper with dotted red
# vertical lines.
# NOTE(review): assumes the i-th entry of `least_variable_genes_per_bin`
# belongs to the i-th bin delimited by `bin_idx` -- confirm the helper's key
# order.
for i, (bin_key, ref_genes) in enumerate(least_variable_genes_per_bin.items()):
    start = 0 if i == 0 else bin_idx[i - 1]
    stop = bin_idx[i]
    bin_means = gene_means.iloc[start:stop]

    # Explicit figure/axes instead of the implicit pyplot state machine.
    fig, ax = plt.subplots(figsize=(10, 5))
    bin_means.plot(ax=ax)
    for gene in ref_genes:
        # The curve is drawn against positions within the bin, so translate
        # the gene name into its position in this slice.
        ax.axvline(bin_means.index.get_loc(gene), c='r', ls=':')
    ax.set_title(f'Expression bin {bin_key}.')
    ax.set_xlabel('Gene (ascending mean expression)')
    ax.set_ylabel('Mean expression')

    # Render now and release the figure; otherwise one open figure per bin
    # accumulates and matplotlib warns once more than 20 are open.
    plt.show()
    plt.close(fig)