In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from thalamus_merfish_analysis import ccf_plots as cplot
from thalamus_merfish_analysis import ccf_images as cimg
from thalamus_merfish_analysis import abc_load as abc

# Render matplotlib figures inline; this call is equivalent to the `%matplotlib inline` line magic.
get_ipython().run_line_magic('matplotlib', 'inline') 

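## Load brain3 data

Load cell metadata with realigned CCF labels, label the thalamus spatial subset, and keep only neuronal classes.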

In [2]:
# Cell metadata with realigned CCF labels, then labelled with the thalamus
# spatial subset (filter_cells=True).
obs = abc.label_thalamus_spatial_subset(
    abc.get_combined_metadata(realigned=True),
    realigned=True,
    filter_cells=True,
)

# Column holding each cell's (realigned) CCF substructure.
ccf_label = 'parcellation_substructure_realigned'

# Non-neuronal classes; every other class is kept as neuronal.
nn_classes = ["31 OPC-Oligo", "30 Astro-Epen", "33 Vascular", "34 Immune"]
obs_neurons = obs.loc[~obs['class'].isin(nn_classes)]

In [5]:
th_names = abc.get_thalamus_substructure_names()
# Sorted rather than list(set(...)): set iteration order for strings changes
# between sessions (hash randomization), which would make the row order of
# every downstream metrics table irreproducible.
th_subregion_names = sorted(set(th_names) - {'TH-unassigned'})


In [81]:
# define celltype lists based on strict spatial subset
# obs_th_neurons = obs_neurons[obs_neurons[ccf_label].isin(th_names)]
obs_th_neurons = obs_neurons[obs_neurons['parcellation_substructure_realigned'].isin(th_names) |
                             obs_neurons['parcellation_substructure'].isin(th_names)]
th_celltypes = dict()
th_celltypes['subclass'] = obs_th_neurons['subclass'].value_counts().loc[lambda x: x>100].index
print(f"{len(th_celltypes['subclass'])=}")

th_celltypes['supertype'] = obs_th_neurons['supertype'].value_counts().loc[lambda x: x>20].index
print(f"{len(th_celltypes['supertype'])=}")

th_celltypes['cluster'] = obs_th_neurons['cluster'].value_counts().loc[lambda x: x>10].index
print(f"{len(th_celltypes['cluster'])=}")

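## Matching metrics

For each (cell type, thalamic nucleus) pair, measure how well the cell type marks the nucleus:

- **precision**: fraction of the cell type's cells that lie inside the nucleus
- **recall/coverage**: fraction of the nucleus's cells that belong to that type
- **F1** and **Jaccard overlap**

Nuclei not matched by any subclass are re-scored at supertype and cluster resolution.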

In [44]:
from sklearn.metrics import f1_score, precision_score, recall_score

def get_nucleus_celltype_match(obs, ccf_label, ccf_name, celltype_label, celltype_name):
    """Score how well a single cell type matches a single CCF nucleus.

    Membership in ``ccf_name`` is treated as ground truth and membership in
    ``celltype_name`` as the prediction (same orientation as
    ``sklearn.metrics.precision_score(nucleus, celltype)``). Undefined ratios
    (empty nucleus and/or empty cell type) are reported as 0.0 -- the value
    sklearn returns by default -- but without emitting warnings.

    Parameters
    ----------
    obs : pd.DataFrame
        Cell metadata containing ``ccf_label`` and ``celltype_label`` columns.
    ccf_label, celltype_label : str
        Column names holding the region and cell-type assignments.
    ccf_name, celltype_name
        The nucleus and cell type to compare.

    Returns
    -------
    dict
        ``nucleus_precision``: fraction of the cell type's cells inside the nucleus.
        ``nucleus_coverage``: recall, i.e. fraction of the nucleus's cells of that type.
        ``nucleus_f1``: harmonic mean of the two.
    """
    nucleus = obs[ccf_label] == ccf_name
    celltype = obs[celltype_label] == celltype_name
    tp = int((nucleus & celltype).sum())
    fp = int((~nucleus & celltype).sum())   # cell type outside the nucleus
    fn = int((nucleus & ~celltype).sum())   # nucleus cells of other types

    def ratio(numerator, denominator):
        # 0/0 (undefined metric) -> 0.0, matching sklearn's default zero_division result
        return numerator / denominator if denominator else 0.0

    record = {
        'nucleus_precision': ratio(tp, tp + fp),
        'nucleus_coverage': ratio(tp, tp + fn),
        # 2*P*R/(P+R) rewritten in counts so it is defined when tp == 0
        'nucleus_f1': ratio(2 * tp, 2 * tp + fp + fn),
    }
    return record
    
def _safe_ratio(numerator, denominator):
    """Return numerator/denominator as a float, or NaN when the denominator is 0.

    NaN never passes a ``>`` threshold, so undefined metrics are dropped just
    as a numpy 0/0 would be, but without the "invalid value" RuntimeWarning.
    """
    return float(numerator / denominator) if denominator else float('nan')


def get_nucleus_celltype_metrics(obs, ccf_label, celltype_label, 
                                 ccf_list=None, celltype_list=None,
                                 precision_threshold=0.5, recall_threshold=0.5,
                                 f1_threshold=0.4):
    """Score how well each cell type matches each CCF nucleus.

    For every (cell type, nucleus) pair, cells of ``celltype_name`` are treated
    as predictions of membership in ``ccf_name``:

    - ``nucleus_precision``: fraction of the cell type's cells inside the nucleus
    - ``nucleus_recall``: fraction of the nucleus's cells that are of the cell type
    - ``nucleus_f1``: harmonic mean of precision and recall
    - ``jaccard``: |nucleus & celltype| / |nucleus | celltype|

    All cells in ``obs`` count toward the negatives, so ``obs`` is deliberately
    not subset to ``celltype_list`` / ``ccf_list``.

    Parameters
    ----------
    obs : pd.DataFrame
        Cell metadata containing ``ccf_label`` and ``celltype_label`` columns.
    ccf_label, celltype_label : str
        Column names holding the region and cell-type assignments.
    ccf_list, celltype_list : iterable, optional
        Regions / cell types to score; default is every unique value of the
        column. Labels absent from ``obs`` are allowed and simply never match.
    precision_threshold, recall_threshold, f1_threshold : float
        A pair is kept if precision > precision_threshold OR
        recall > recall_threshold OR f1 > f1_threshold.

    Returns
    -------
    pd.DataFrame
        One row per kept pair (cell types outer, regions inner, in list order)
        with columns nucleus, celltype, nucleus_precision, nucleus_recall,
        nucleus_f1, jaccard. If no pair passes, the frame is empty but still
        has these columns, so downstream ``.query`` calls keep working.
    """
    columns = ['nucleus', 'celltype', 'nucleus_precision',
               'nucleus_recall', 'nucleus_f1', 'jaccard']
    if celltype_list is None:
        celltype_list = obs[celltype_label].unique()
    if ccf_list is None:
        ccf_list = obs[ccf_label].unique()

    # Count every label and label pair once, instead of rescanning `obs` for
    # each of the len(celltype_list) * len(ccf_list) pairs.
    celltype_sizes = obs[celltype_label].value_counts().to_dict()
    nucleus_sizes = obs[ccf_label].value_counts().to_dict()
    pair_sizes = (obs.groupby([celltype_label, ccf_label], observed=True)
                  .size()
                  .to_dict())

    records = []
    for celltype_name in celltype_list:
        n_celltype = celltype_sizes.get(celltype_name, 0)
        for ccf_name in ccf_list:
            n_nucleus = nucleus_sizes.get(ccf_name, 0)
            tp = pair_sizes.get((celltype_name, ccf_name), 0)
            fp = n_celltype - tp   # cell type outside the nucleus
            fn = n_nucleus - tp    # nucleus cells of other types
            precision = _safe_ratio(tp, tp + fp)
            recall = _safe_ratio(tp, tp + fn)
            jaccard = _safe_ratio(tp, tp + fp + fn)
            # equal to 2*P*R/(P+R), but defined (= 0) when tp == 0
            f1 = _safe_ratio(2 * tp, 2 * tp + fp + fn)
            if (precision > precision_threshold or recall > recall_threshold
                    or f1 > f1_threshold):
                records.append({
                    'nucleus': ccf_name,
                    'celltype': celltype_name,
                    'nucleus_precision': precision,
                    'nucleus_recall': recall,
                    'nucleus_f1': f1,
                    'jaccard': jaccard,
                })
    return pd.DataFrame(records, columns=columns)

In [67]:
# Score each well-populated subclass against every thalamic subregion,
# using the CCF label column currently bound to `ccf_label`.
subclass_df = get_nucleus_celltype_metrics(
    obs=obs_neurons,
    ccf_label=ccf_label,
    celltype_label='subclass',
    celltype_list=th_celltypes['subclass'],
    ccf_list=th_subregion_names,
)

In [68]:
# Keep pairs where >60% of the subclass lies in the nucleus and the subclass
# makes up >15% of the nucleus; show best F1 first.
subclass_set = subclass_df.query("nucleus_precision > 0.6 and nucleus_recall > 0.15")
subclass_set.sort_values(by="nucleus_f1", ascending=False)

In [61]:
# Regions already matched at subclass level; LGd-sh is excluded by hand
# (NOTE(review): reason for the exclusion is not recorded here -- confirm).
subclass_region_names = set(subclass_set['nucleus']) - {'LGd-sh'}
# Remaining regions are re-scored at finer cell-type resolution below.
nonsubclass_region_names = set(th_subregion_names) - subclass_region_names

In [69]:
# Regions without a subclass-level match, re-scored at supertype resolution.
supertype_df = get_nucleus_celltype_metrics(
    obs=obs_neurons,
    ccf_label=ccf_label,
    celltype_label='supertype',
    celltype_list=th_celltypes['supertype'],
    ccf_list=nonsubclass_region_names,
)

In [70]:
# Same precision/recall cut as the subclass table; best F1 first.
supertype_set = supertype_df.query("nucleus_precision > 0.6 and nucleus_recall > 0.15")
supertype_set.sort_values(by="nucleus_f1", ascending=False)

In [71]:
# Regions without a subclass-level match, re-scored at cluster resolution.
cluster_df = get_nucleus_celltype_metrics(
    obs=obs_neurons,
    ccf_label=ccf_label,
    celltype_label='cluster',
    celltype_list=th_celltypes['cluster'],
    ccf_list=nonsubclass_region_names,
)

In [72]:
# Clusters are small, so the recall cut is relaxed to >10% of the nucleus.
cluster_set = cluster_df.query("nucleus_precision > 0.6 and nucleus_recall > 0.1")
cluster_set.sort_values(by="nucleus_f1", ascending=False)

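## Original alignment

Repeat the same matching using the original (not realigned) CCF labels in `parcellation_substructure`.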

In [82]:
# Switch to the original (non-realigned) labels. This rebinds the global
# `ccf_label`, so re-running any earlier cell afterwards uses these labels too.
ccf_label = "parcellation_substructure"

In [83]:
# Subclass vs. thalamic subregion scores on the original CCF labels.
subclass_df = get_nucleus_celltype_metrics(
    obs=obs_neurons,
    ccf_label=ccf_label,
    celltype_label='subclass',
    celltype_list=th_celltypes['subclass'],
    ccf_list=th_subregion_names,
)

In [84]:
# Same cut as for the realigned labels; best F1 first.
subclass_set = subclass_df.query("nucleus_precision > 0.6 and nucleus_recall > 0.15")
subclass_set.sort_values(by="nucleus_f1", ascending=False)

In [85]:
# Regions matched at subclass level (LGd-sh excluded by hand, as above;
# NOTE(review): reason not recorded -- confirm) and the remaining regions.
subclass_region_names = set(subclass_set['nucleus']) - {'LGd-sh'}
nonsubclass_region_names = set(th_subregion_names) - subclass_region_names

In [86]:
# Unmatched regions re-scored at supertype resolution (original labels).
supertype_df = get_nucleus_celltype_metrics(
    obs=obs_neurons,
    ccf_label=ccf_label,
    celltype_label='supertype',
    celltype_list=th_celltypes['supertype'],
    ccf_list=nonsubclass_region_names,
)

In [87]:
# Same precision/recall cut; best F1 first.
supertype_set = supertype_df.query("nucleus_precision > 0.6 and nucleus_recall > 0.15")
supertype_set.sort_values(by="nucleus_f1", ascending=False)

In [88]:
# Unmatched regions re-scored at cluster resolution (original labels).
cluster_df = get_nucleus_celltype_metrics(
    obs=obs_neurons,
    ccf_label=ccf_label,
    celltype_label='cluster',
    celltype_list=th_celltypes['cluster'],
    ccf_list=nonsubclass_region_names,
)

In [89]:
# Relaxed recall cut (>10%) for the smaller cluster-level types.
cluster_set = cluster_df.query("nucleus_precision > 0.6 and nucleus_recall > 0.1")
cluster_set.sort_values(by="nucleus_f1", ascending=False)