# Purpose
- Re-create mfishtools in Python (together with Hannah's code)
- Validate against the R inhibitory gene panel selection (no subsampling)
- Compare results with 240926_for_python_validation.Rmd

In [55]:
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sys
from importlib import reload

from mfishtoolspy import mfishtools_archived as mft_r


In [56]:
# Options for gene panel selection / cluster grouping (inhibitory panel).
gene_panel_selection_ops = dict(
    panel_name='inhibitory',
    full_panel_size=28,
    starting_genes=["Gad2", "Slc17a7", "Pvalb", "Sst", "Vip", "Cck", "Tac1", "Npy",
                    "Crh", "Necab1", "Ptprt", "Kirrel3", "Penk", "Hpse", "Calb2", "Chodl"],
    layer_1234_filter=True,   # keep inhibitory types enriched in L1-L4 (preprocessing cell)
    use_supertypes=False,
    blend_supertypes=False,
    remove_bad_genes=False,
    other_as_subclass=True,   # merge the remaining "L5" inhibitory types per subclass
)

# Alternative configuration (pan-neuronal panel):
# gene_panel_selection_ops = dict(
#     panel_name='pan_neuronal',
#     full_panel_size=30,
#     starting_genes=["Gad2","Slc17a7","Tac2","Tac1","Calb1","Npy","Cck","Vip","Crh","Calb2","Penk","Oprm1","Pvalb","Ptprt","Kirrel3","Sst","Ndnf","Nos1","Baz1a","Sncg","Mybpc1","Lamp5","Hpse","Etv1","Rorb","Agmat","Chat","Adamts2"],
#     layer_1234_filter=True,
#     use_supertypes=False,
#     blend_supertypes=True,
#     remove_bad_genes=True,
#     other_as_subclass=True,
# )

In [59]:
# Input folder (Tasic 2018 v1 feather files) and the folder holding the R reference results.
data_folder, output_folder = (
    Path('//allen/programs/mindscope/workgroups/omfish/hannahs/mfish_project/gene_panels/L23_inh_panel/Mm_VISp_14236_20180912'),
    Path(r'\\allen\programs\mindscope\workgroups\learning\jinho\gene_panel_selection\mfishtoolspy\results'.replace('\\', '/')),
)

In [61]:
# Read the cell annotation (indexed by sample_id) and the expression data
# (tasic 2018 v1; indexed by gene, one column per sample_id).
# Reading data_t.feather takes about 1 minute.
annotation = pd.read_feather(data_folder / 'anno.feather').set_index('sample_id')
data = pd.read_feather(data_folder / 'data_t.feather').set_index('gene')

# Later cells attach annotation columns to the cells in `data`, so both must list the
# same cells in the same order. Index.equals returns False on a length mismatch instead
# of failing inside an element-wise `==` of two arrays.
assert annotation.index.equals(data.columns), \
    'annotation rows and data columns are not the same cells in the same order'

In [62]:
# Read the manual cell type -> supertype assignment and attach it as
# annotation['supertype_label']. Not strictly needed here (use_supertypes is False),
# but the use_supertypes branch of the preprocessing cell reads that column.
supertype_folder = Path(r'\\allen\programs\mindscope\workgroups\learning\jinho\gene_panel_selection\mfishtoolspy\data'.replace('\\', '/'))
supertype_fn = supertype_folder / 'tasic2018_supertypes_manual_v2.xlsx'
sheet_name = 'all_supertypes_v2'
supertype = (
    pd.read_excel(supertype_fn, sheet_name=sheet_name)
    .rename(columns={'Cell Type': 'cell_type', 'Supertype': 'supertype'})
)
# Replace non-breaking spaces (\xa0) with plain spaces so names match cluster_label.
supertype['cell_type'] = supertype['cell_type'].str.replace('\xa0', ' ')
supertype['supertype'] = supertype['supertype'].str.replace('\xa0', ' ')

# Every cell type in the sheet must be a known cluster label. Vectorized isin instead of
# scanning the whole cluster_label array once per cell type; list offenders on failure.
unknown_cell_types = supertype.loc[~supertype['cell_type'].isin(annotation['cluster_label']), 'cell_type'].tolist()
assert not unknown_cell_types, f'cell types not found in annotation.cluster_label: {unknown_cell_types}'
supertype = supertype.set_index('cell_type')

annotation['supertype_label'] = annotation['cluster_label'].map(supertype['supertype'])

  annotation['supertype_label'] = annotation.cluster_label.map(supertype.supertype)


In [63]:
# Preprocessing specific to match with Hannah's code:
#  - GABAergic types with >= gabaergic_layer_threshold of their cells in L1-L4 are kept,
#  - types with >= L6_layer_threshold of their cells in L5-L6/L6/L6b are "L6" types (dropped),
#  - the remaining GABAergic types ("L5" types) are either merged per subclass into
#    "L5 Inh <subclass>" labels (other_as_subclass) or kept as they are.
# The resulting cluster labels are collected in `keep_clusts`.
keep_class = ['GABAergic']
gabaergic_layer_threshold = 0.15
L6_layer_threshold = 0.75
L1234_labels = ['L1', 'L1-L2/3', 'L1-L4', 'L2/3', 'L2/3-L4', 'L4']
L6_labels = ['L5-L6', 'L6', 'L6b']

# This cell relabels annotation['cluster_label'] in place. Without a reset, a second run
# would re-classify the merged "L5 Inh ..." groups and rename them again. Keep the
# untouched labels once and always start from them.
if 'cluster_label_raw' not in annotation.columns:
    annotation['cluster_label_raw'] = annotation['cluster_label']
annotation['cluster_label'] = annotation['cluster_label_raw']

keep_types = []
# Defaults so the branches below never read an undefined (or stale, from an earlier run)
# name, e.g. when 'GABAergic' is not in keep_class or other_as_subclass is False.
L5_inh_types = set()
L5_inh_cluster_labels = []
if gene_panel_selection_ops['layer_1234_filter']:
    if 'Glutamatergic' in keep_class:
        L234_exc_subclasses = ['L2/3 IT', 'L4']
        L5_exc_subclasses = ['L5 IT', 'L5 PT', 'NP']
        L234_exc_types = annotation[annotation['subclass_label'].isin(L234_exc_subclasses)].cluster_label.unique()
        # NOTE(review): L5_exc_types is computed but not used below.
        L5_exc_types = annotation[annotation['subclass_label'].isin(L5_exc_subclasses)].cluster_label.unique()
        keep_types.extend(L234_exc_types)
    if 'GABAergic' in keep_class:
        layer_df = annotation.query('class_label=="GABAergic"')[['layer_label', 'cluster_label']].copy()
        # layer x cluster cell counts -> fraction of each cluster's cells in each layer
        layer_table = layer_df.groupby(['layer_label', 'cluster_label']).size().unstack(fill_value=0)
        prop_table = layer_table.div(layer_table.sum(axis=0), axis=1)
        L1234_prop_sum = prop_table.loc[L1234_labels].sum(axis=0)
        L1234_inh_types = set(L1234_prop_sum[L1234_prop_sum >= gabaergic_layer_threshold].index.values)
        not_L1234_inh_types = set(layer_df.cluster_label).difference(L1234_inh_types)
        L6_prop_sum = prop_table.loc[L6_labels].sum(axis=0)
        L6_inh_types = set(L6_prop_sum[L6_prop_sum >= L6_layer_threshold].index.values)
        L5_inh_types = not_L1234_inh_types.difference(L6_inh_types)
        keep_types.extend(L1234_inh_types)

    if gene_panel_selection_ops['other_as_subclass']:
        # Merge each L5 inhibitory type into "L5 Inh <first word of its label>".
        L5_inh_rename = {label: f"L5 Inh {label.split(' ')[0]}" for label in L5_inh_types}
        annotation['cluster_label'] = annotation['cluster_label'].map(lambda label: L5_inh_rename.get(label, label))
        L5_inh_cluster_labels = np.unique(list(L5_inh_rename.values()))
        keep_types.extend(L5_inh_cluster_labels)
    else:
        keep_types.extend(L5_inh_types)

    if gene_panel_selection_ops['use_supertypes']:
        keep_clusts = annotation.query('cluster_label in @keep_types').supertype_label.unique()
        # Supertypes of the L5 inhibitory types.
        # NOTE(review): there is no `L5_inh_types` column; `supertype_label` is assumed to be
        # the intended one -- TODO confirm. Matched on the raw labels because the merge above
        # may already have renamed these types in `cluster_label`.
        L5_inh_types = annotation.query('cluster_label_raw in @L5_inh_types').supertype_label.unique()
        annotation['cluster_label_original'] = annotation['cluster_label']
        annotation['cluster_label'] = annotation['supertype_label']
        # Keep the merged "L5 Inh ..." labels rather than their supertypes. Assign through
        # .loc: assigning into `annotation.query(...)[col]` only modifies a temporary copy.
        is_L5_inh = annotation['cluster_label_original'].isin(L5_inh_cluster_labels)
        annotation.loc[is_L5_inh, 'cluster_label'] = annotation.loc[is_L5_inh, 'cluster_label_original']
    else:
        keep_clusts = annotation.query('cluster_label in @keep_types').cluster_label.unique()
else:
    # Later cells require keep_clusts; fail here rather than with a NameError (or by
    # silently reusing a value left over from an earlier run) further down.
    raise NotImplementedError("keep_clusts is only defined when gene_panel_selection_ops['layer_1234_filter'] is True")
    


In [64]:
# Drop (and report) starting genes that are absent from the expression data.
genes_in_data = set(data.index.values)
missing_starting_genes = [
    gene for gene in gene_panel_selection_ops['starting_genes'] if gene not in genes_in_data
]
if missing_starting_genes:
    print(f'{missing_starting_genes} are not in the data')
    gene_panel_selection_ops['starting_genes'] = [
        gene for gene in gene_panel_selection_ops['starting_genes'] if gene not in missing_starting_genes
    ]

In [65]:
# CPM -> log2(CPM + 1). Takes about 9 s.
data_log2 = data.add(1).pipe(np.log2)

In [66]:
# Per-cluster summaries on the log2 scale (genes x clusters):
#   median_per_cluster -- median expression of each gene within each cluster
#   prop_expr          -- fraction of a cluster's cells with expression > expr_thresh
cluster_names = annotation.cluster_label.unique()
expr_thresh = 1
assert np.all(data_log2.columns == annotation.index.values)
# cells x genes, plus each cell's cluster label as an extra column
data_log2_cluster = data_log2.copy().T
data_log2_cluster['cluster_label'] = annotation['cluster_label']
median_per_cluster = data_log2_cluster.groupby('cluster_label').median().T
# Binarize first, then average per cluster. This replaces
# groupby('cluster_label').apply(lambda x: (x > expr_thresh).mean(axis=0)), which could not
# evaluate the lambda on the string grouping column and depended on pandas' handling of
# grouping columns inside apply (deprecated in pandas 2.2 via `include_groups`).
prop_expr = (data_log2.T > expr_thresh).groupby(data_log2_cluster['cluster_label']).mean().T
assert np.all(prop_expr.index.values == median_per_cluster.index.values)
assert np.all(prop_expr.index.values == data_log2.index.values)



In [67]:
# R reference outputs (genes x clusters) for the same two per-cluster summaries.
prop_expr_r_fn, median_per_cluster_r_fn = (
    output_folder / fn for fn in ('prop_expr_tmp_R.csv', 'median_expr_tmp_R.csv')
)
prop_expr_r, median_per_cluster_r = (
    pd.read_csv(fn, index_col=0) for fn in (prop_expr_r_fn, median_per_cluster_r_fn)
)

In [68]:
# Largest |python - R| difference in expression proportions. Align R to the python labels
# explicitly: a bare `prop_expr - prop_expr_r` outer-joins the labels, and any label present
# in only one frame becomes NaN, which np.max on a DataFrame silently skips.
assert set(prop_expr.index) == set(prop_expr_r.index) and set(prop_expr.columns) == set(prop_expr_r.columns)
np.max(np.abs(prop_expr.values - prop_expr_r.loc[prop_expr.index, prop_expr.columns].values))

5.551115123125783e-16

In [70]:
# Put median_per_cluster_r in the same gene (row) and cluster (column) order as
# median_per_cluster. The comparison that follows is positional (.values), so both
# axes must match, not only the columns.
median_per_cluster_r = median_per_cluster_r.loc[median_per_cluster.index, median_per_cluster.columns]
median_per_cluster_r.head()

Unnamed: 0,Astro Aqp4,CR Lhx5,Endo Ctla2a,Endo Cytl1,L2/3 IT VISp Adamts2,L2/3 IT VISp Agmat,L2/3 IT VISp Rrad,L4 IT VISp Rspo1,L5 IT VISp Batf3,L5 IT VISp Col27a1,...,Vip Crispld2 Kcne4,Vip Igfbp4 Mab21l1,Vip Igfbp6 Car10,Vip Igfbp6 Pltp,Vip Lect1 Oxtr,Vip Lmo1 Myl1,Vip Ptprt Pkp2,Vip Pygm C1ql1,Vip Rspo1 Itga4,Vip Rspo4 Rxfp1 Chat
0610005C13Rik,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0610006L08Rik,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0610007P14Rik,0.0,6.247789,0.0,0.0,6.914645,6.657416,6.141694,6.418927,6.173059,6.312999,...,5.870013,6.374446,6.279457,6.55238,6.466378,6.474331,6.894218,6.418571,6.642808,6.875275
0610009B22Rik,0.814706,0.0,0.0,0.0,5.769752,6.076098,4.468977,5.78896,5.990343,6.339431,...,5.918368,6.000401,5.269012,5.617239,6.258903,6.428014,6.458077,5.990274,6.16506,6.249834
0610009E02Rik,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [71]:
# Positional comparison of the medians: require identical gene/cluster order first.
assert median_per_cluster_r.index.equals(median_per_cluster.index)
assert median_per_cluster_r.columns.equals(median_per_cluster.columns)
np.max(np.abs(median_per_cluster.values - median_per_cluster_r.values))

4.973799150320701e-14

## Medians and expression proportions match R
- Next: step into the gene-filtering function

In [83]:
# Candidate-gene filtering. median_per_cluster holds log2(x + 1) values;
# 2**m - 1 maps them back to the linear scale before filtering.
linear_median_per_cluster = 2**median_per_cluster - 1
non_neuronal_clusters = annotation.query('class_label=="Non-Neuronal"').cluster_label.unique()
run_genes, keep_genes = mft_r.filter_panel_genes_archived(
    linear_median_per_cluster,
    prop_expr=prop_expr,
    on_clusters=list(keep_clusts),
    off_clusters=list(non_neuronal_clusters),
    starting_genes=gene_panel_selection_ops['starting_genes'],
    num_binary_genes=300,
    min_on=10,
    max_on=300,
    max_off=10,
    min_length=1400,
    max_fraction_on_clusters=0.5,
    exclude_families=["LOC", "Fam", "RIK", "RPS", "RPL", "\\-", "Gm", "Rnf", "BC0"],
)

1263 total genes pass constraints prior to binary score calculation.


In [80]:
# Genes R passed on to panel building (the CSV's row names).
run_genes_r_fn = output_folder / 'run_genes_tmp_inh_R.csv'
run_genes_r = pd.read_csv(run_genes_r_fn, index_col=0).index.values

In [81]:
# Number of genes R kept (run_genes_r is a 1-D array of gene names)
run_genes_r.shape[0]

313

In [84]:
# Number of genes the python filter kept; compare with the R count in the previous cell.
len(run_genes)

313

In [85]:
# Genes kept by only one of the two implementations (symmetric difference).
# An empty set means the python and R gene lists agree in both directions.
set(run_genes_r) ^ set(run_genes)

set()

# Everything up to and including gene filtering matches R


## Testing keep_genes, on_clusters, and beta scores

In [86]:
genes = median_per_cluster.index
# One 'keep' flag per gene, indexed by the genes of median_per_cluster
# (assumes keep_genes is ordered like those rows -- TODO confirm).
keep_genes_df = pd.Series(keep_genes, index=genes, name='keep').to_frame()
# keep_genes_df.to_csv(output_folder / 'keep_genes_python.csv')

In [87]:
# Python keep flags, one row per gene
keep_genes_df

Unnamed: 0_level_0,keep
gene,Unnamed: 1_level_1
0610005C13Rik,False
0610006L08Rik,False
0610007P14Rik,False
0610009B22Rik,False
0610009E02Rik,False
...,...
n-R5s142,False
n-R5s143,False
n-R5s144,False
n-R5s146,False


In [88]:
# R keep flags: the unnamed first column holds gene names, column 'x' the flag.
keep_genes_r = (
    pd.read_csv(output_folder / 'keepGenes_tmp_inh_R.csv')
    .rename(columns={'Unnamed: 0': 'gene', 'x': 'keep_r'})
    .set_index('gene')
)

In [89]:
# R keep flags, one row per gene
keep_genes_r

Unnamed: 0_level_0,keep_r
gene,Unnamed: 1_level_1
0610005C13Rik,False
0610006L08Rik,False
0610007P14Rik,False
0610009B22Rik,False
0610009E02Rik,False
...,...
n-R5s142,False
n-R5s143,False
n-R5s144,False
n-R5s146,False


In [90]:
# Same genes in the same order? Index.equals returns False on a length mismatch,
# whereas `Index == Index` raises ValueError.
keep_genes_r.index.equals(keep_genes_df.index)

True

In [91]:
# Side-by-side R ('keep_r') and python ('keep') flags for genes present in both
keep_genes_merged = keep_genes_r.join(keep_genes_df, how='inner')

In [92]:
# R and python keep flags side by side
keep_genes_merged

Unnamed: 0_level_0,keep_r,keep
gene,Unnamed: 1_level_1,Unnamed: 2_level_1
0610005C13Rik,False,False
0610006L08Rik,False,False
0610007P14Rik,False,False
0610009B22Rik,False,False
0610009E02Rik,False,False
...,...,...
n-R5s142,False,False
n-R5s143,False,False
n-R5s144,False,False
n-R5s146,False,False


In [93]:
# Row positions where python and R disagree on keeping a gene (empty -> identical)
np.nonzero(keep_genes_merged['keep'].to_numpy() != keep_genes_merged['keep_r'].to_numpy())

(array([], dtype=int64),)

In [94]:
# Positional element-wise check of the python vs R keep flags
python_keep_flags = keep_genes_df.to_numpy()
r_keep_flags = keep_genes_r.to_numpy()
np.equal(python_keep_flags, r_keep_flags).all()

True

## keep_genes matches R up to this point
- So any remaining discrepancy should be in the beta score calculation (checked next)

In [97]:
from mfishtoolspy import filtering

on_clusters = list(keep_clusts)
# Beta scores on the kept genes / on-clusters, first with flag True (compared with R's
# topBeta_scores) then with flag False (compared with R's topBeta).
top_beta_scores, top_beta = (
    filtering.get_beta_score(prop_expr.loc[keep_genes, on_clusters], flag)
    for flag in (True, False)
)

In [99]:
# R reference beta outputs (one row per gene, value in column 'x')
top_beta_scores_r, top_beta_r = (
    pd.read_csv(output_folder / fn, index_col=0)
    for fn in ('topBeta_scores_tmp_inh_R.csv', 'topBeta_tmp_inh_R.csv')
)

In [101]:
# R topBeta values, one row per gene
top_beta_r

Unnamed: 0,x
AA414768,619
AF529169,287
AK129341,571
AU022252,690
AU041133,1148
...,...
Zkscan14,926
Znhit6,1043
Zpbp,941
Zscan18,1210


In [109]:
# Positional equality of R and python topBeta values
np.all(top_beta_r['x'].to_numpy() == top_beta)

True

In [102]:
# Largest positional |R - python| difference in the beta scores
beta_score_diff = top_beta_scores_r['x'].to_numpy() - top_beta_scores
np.max(np.abs(beta_score_diff))

1.2501111257279263e-12

# Testing gene panel building now

In [110]:
# IDs of the cells whose cluster label is one of keep_clusts
keep_sampled_cells = annotation.index[annotation['cluster_label'].isin(keep_clusts)].values


In [113]:
reload(mft_r)  # re-execute the module so edits to mfishtools_archived are picked up

# Inputs restricted to the filtered genes (run_genes) and the kept cells/clusters.
panel_map_data = data_log2.loc[run_genes, keep_sampled_cells].copy()
panel_median_data = median_per_cluster.loc[run_genes, keep_clusts].copy()
panel_cluster_call = annotation.loc[keep_sampled_cells, 'cluster_label'].copy()

built_panel = mft_r.build_mapping_based_marker_panel_archived(
    map_data=panel_map_data,
    median_data=panel_median_data,
    cluster_call=panel_cluster_call,
    panel_size=gene_panel_selection_ops['full_panel_size'],
    current_panel=gene_panel_selection_ops['starting_genes'].copy(),
    num_subsample=None,
    panel_min=3,  # should not affect
    optimize='fraction_correct',  # optimize='correlation_distance' was validated as well
)
# Timing notes (pan-inhibitory):
#   7 min 17 s initially
#   6 min 46 s after indexing the cluster_label map
#   7 min 39 s after removing the cluster_distance re-calculation (weird...)
#   7 min 17 s after flattening the cluster_distance dataframe (reverted for readability)
# cluster_distance is left as a dataframe because it is easier to read.

Added Nkain3 with 0.662, now matching [16].
Added Oprm1 with 0.685, now matching [17].
Added Htr1f with 0.701, now matching [18].
Added Il1rapl2 with 0.713, now matching [19].
Added Tpbg with 0.723, now matching [20].
Added Ngf with 0.731, now matching [21].
Added Fibcd1 with 0.739, now matching [22].
Added Cpne4 with 0.745, now matching [23].
Added Sfrp2 with 0.750, now matching [24].
Added Grm8 with 0.755, now matching [25].
Added Pfkfb3 with 0.761, now matching [26].
Added Tox with 0.764, now matching [27].


In [116]:
# R panel-building result (no subsampling, fraction_correct), indexed by gene
built_panel_r_fn = output_folder / 'gene_panel_selection_results_nosubsample_fraction_correct_R.csv'
built_panel_r = pd.read_csv(built_panel_r_fn, index_col=0)

In [119]:
# Same genes, in the same order, as the R panel. List equality also handles a length
# mismatch (False) instead of relying on an element-wise array comparison.
list(built_panel_r.index) == list(built_panel)

True

In [121]:
# R accuracy after each gene added on top of the starting panel (the first
# len(starting_genes) rows are the starting genes themselves).
built_panel_r.iloc[len(gene_panel_selection_ops['starting_genes']):]['Accuracy']

Gene
Nkain3      66.167825
Oprm1       68.466061
Htr1f       70.069482
Il1rapl2    71.298771
Tpbg        72.314270
Ngf         73.115981
Fibcd1      73.864244
Cpne4       74.469980
Sfrp2       74.986638
Grm8        75.485480
Pfkfb3      76.091217
Tox         76.411901
Name: Accuracy, dtype: float64

# Python results match R; validated with both `fraction_correct` and `correlation_distance`