In [1]:
import anndata as ad
import pandas as pd
import scanpy as sc
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors
import seaborn as sns

In [2]:
# Read the annotated MERFISH atlas file into memory.
merData = ad.read_h5ad(
    "/data/merfish_609882_AIT17.1_annotated_TH_ZI_only_2023-02-16_00-00-00/"
    "atlas_brain_609882_AIT17_1_annotated_TH_ZI_only.h5ad"
)

# Keep only the cells whose division label is the neuronal division listed here.
neuronal_divisions = ["3 PAL-sAMY-TH-HY-MB-HB neuronal"]
is_neuronal = merData.obs["division_id_label"].isin(neuronal_divisions)
merData = merData[is_neuronal]

In [3]:
# Level-1 labels represented by more than 100 cells in the neuronal MERFISH subset.
level1_counts = merData.obs['Level1_id_label'].value_counts()
level1_counts[level1_counts > 100]

In [4]:
# Load sequencing data (fully in memory, not file-backed)
seqData = ad.read_h5ad("/data/rnaseq_AIT17.2_2022-12-15_12-00-00/rnaseq/processed.U19_TH-EPI.postQC.h5ad", backed=False)

# Restrict to cells labelled '8_TH' at Level 1.
# Slicing an AnnData returns a view of `seqData`; a later cell adds a column to
# `data.obs`, which would force an implicit copy (with a warning). Copy
# explicitly so `data` is an independent object from the start.
data = seqData[seqData.obs['Level1_id_label']=='8_TH'].copy()

In [5]:
# Display the summary of `data`, the '8_TH' subset of the sequencing data.
data

In [7]:
# Cell counts per `nt_type_label` value in the '8_TH' subset.
data.obs['nt_type_label'].value_counts()

In [9]:
# Cell counts per supertype, keeping only supertypes that actually occur in
# the subset (value_counts can report labels with a count of zero).
supertypes = (
    data.obs['supertype_id_label']
    .value_counts()
    .pipe(lambda counts: counts[counts > 0])
)
supertypes.sort_index()

In [10]:
# Sorted list of the distinct `Level2_id_label` values present in the subset.
sorted(data.obs['Level2_id_label'].unique())

In [11]:
# Level-2 types that should be broken down into their individual supertypes.
expand = [
    '65_AD Serpinb7 Glut',
    '66_AV Col27a1 Glut',
    '69_TH Prkcd Grin2c Glut',
]

# Start from the Level-2 label for every cell ...
data.obs['grouped_cell_type'] = data.obs['Level2_id_label'].astype(str)

# ... then, for cells belonging to an expanded type, substitute the supertype label.
needs_split = data.obs['grouped_cell_type'].isin(expand)
data.obs.loc[needs_split, 'grouped_cell_type'] = data.obs.loc[needs_split, 'supertype_id_label']

data.obs['grouped_cell_type'].value_counts()

In [12]:
# Size of the full (ordered) pairwise comparison grid between grouped cell types.
data.obs['grouped_cell_type'].nunique(dropna=False) ** 2

In [13]:
# Cell counts per `method` value in the '8_TH' subset.
data.obs['method'].value_counts()

In [23]:
import dprime
import scipy.spatial.distance as distance
from diskcache import Cache

# On-disk cache used to memoize the (expensive) d-prime computation below.
cache = Cache("/scratch/cache")

@cache.memoize()
def tx_dprime(cluster_label, features=None, type_list=None, n_folds=5, r=3, zero_inflated=True, n_subsample=1000, **kwargs):
    """Compute a pairwise d' matrix between cell types of the notebook-level ``data``.

    Cells are drawn with replacement so that every group of ``cluster_label``
    contributes ``n_subsample`` cells; the expression matrix of the sampled
    cells is then handed to the ``dprime`` package.

    Parameters
    ----------
    cluster_label : str
        Column of ``data.obs`` holding the type label of each cell.
    features : optional
        Variable selector used as ``adata[:, features]``; ``None`` keeps all
        variables.
    type_list : optional
        Types to compare. Defaults to the unique labels of the sampled cells.
    n_folds, r
        Forwarded unchanged to the d-prime routine.
    zero_inflated : bool
        ``True`` calls ``dprime.zinb_dprime``; ``False`` calls
        ``dprime.negative_binomial_dprime``.
    n_subsample : int
        Number of cells sampled (with replacement) per group.
    **kwargs
        Forwarded unchanged to the d-prime routine.

    Returns
    -------
    pandas.DataFrame
        Square matrix of absolute d' values with ``type_list`` as both index
        and columns.

    Notes
    -----
    The function is memoized on disk. It reads the global ``data`` and samples
    randomly, neither of which is an argument, so a cached result is presumably
    returned unchanged even after ``data`` is modified (confirm against the
    diskcache documentation); clear the cache in that case.
    """
    # Import under a private alias: a later cell rebinds the notebook-level
    # name `dprime` to a DataFrame, which would break a global lookup here.
    import dprime as dprime_lib

    # `data` is only read here. It must never be assigned inside this function,
    # otherwise the notebook-level AnnData would be replaced by a matrix and
    # every subsequent call would fail.
    sampled_index = data.obs.groupby(cluster_label).sample(n_subsample, replace=True).index
    adata = data[sampled_index]
    type_labels = adata.obs[cluster_label]
    counts = adata.X if features is None else adata[:, features].X
    if type_list is None:
        type_list = type_labels.unique()

    # d-prime calculation
    if zero_inflated:
        dprime_results = dprime_lib.zinb_dprime(
            counts, type_list, type_labels, n_folds=n_folds,
            r=r, **kwargs)
    else:
        dprime_results = dprime_lib.negative_binomial_dprime(
            counts, type_list, type_labels, n_folds=n_folds,
            r=r, **kwargs)

    # The per-pair results are treated as a condensed distance vector and
    # expanded into a symmetric square matrix.
    dprime_mat = distance.squareform(
        [np.abs(dprime_results[k]["dprime"]) for k in dprime_results])
    return pd.DataFrame(dprime_mat, index=type_list, columns=type_list)


In [24]:
# Pairwise d' between all grouped cell types, using the non-zero-inflated
# (negative binomial) branch of tx_dprime with r=1. tx_dprime is memoized on disk.
result = tx_dprime(cluster_label='grouped_cell_type', zero_inflated=False, r=1)

In [45]:
# Persist the d' matrix; the plotting section below reloads it from this file.
result.to_csv("../resources/grouped_supertype_dprime.csv")

## Load the saved d' matrix and plot

In [2]:
# Reload the saved d' matrix; the first CSV column is used as the row index.
result = pd.read_csv("../resources/grouped_supertype_dprime.csv", index_col=0)

In [13]:
import re

# Take the type names from `result` (the loaded d' matrix). The name `dprime`
# is only bound to a DataFrame in the following cell; at this point it is
# either unbound or the imported module, neither of which has `.columns`.
labels = result.columns

# Strip the leading numeric ID: split each name on spaces/underscores, drop the
# first token and re-join with spaces,
# e.g. '65_AD Serpinb7 Glut' -> 'AD Serpinb7 Glut'.
labels = [' '.join(re.split(' |_', x)[1:]) for x in labels]
labels

In [14]:
# Work on a copy so `result` keeps the raw labels read from the CSV.
# NOTE(review): this rebinds the name `dprime` (imported as a module earlier in
# the notebook) to a DataFrame.
dprime = result.copy()
dprime.index, dprime.columns = labels, labels

# Manual display order: AD, AV, the TH Prkcd Grin2c supertypes, then the rest.
th_supertypes = [
    f'TH Prkcd Grin2c Glut {i}'
    for i in (9, 13, 10, 4, 6, 7, 11, 12, 8, 14, 1, 2, 3, 5)
]
order = [
    'AD Serpinb7 Glut',
    'AV Col27a1 Glut',
    *th_supertypes,
    'PVT-PT Ntrk1 Glut',
    'CM-IAD-CL-PCN Glut',
    'RE-Xi Nox4 Glut',
    'MG-POL-SGN Glut',
    'PF Fzd5 Glut',
    'MH Tac2 Glut',
    'LH Pou4f1 Sox1 Glut',
]
dprime = dprime.loc[order, order]

# Append a parenthesised tag to three of the supertype names.
region_suffix = {
    'TH Prkcd Grin2c Glut 9': ' (AM)',
    'TH Prkcd Grin2c Glut 10': ' (MD)',
    'TH Prkcd Grin2c Glut 13': ' (VM/VAL)',
}
order = [name + region_suffix.get(name, '') for name in order]
dprime.index = order
dprime.columns = order

In [15]:
# Heatmap of the d' matrix in the manual order; the colour scale is clipped to [0, 4].
fig, ax = plt.subplots(figsize=(7, 7))
sns.heatmap(
    dprime,
    cmap='viridis_r',
    vmin=0,
    vmax=4,
    cbar=True,
    cbar_kws={'label': "distinctness d'"},
    ax=ax,
)
ax.axis('image')

In [19]:
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform

# Condensed (upper-triangle) form of the square d' matrix.
X = squareform(dprime.to_numpy())
# Average-linkage hierarchical clustering on those distances.
Z = hierarchy.average(X)
# Leaf order after optimal leaf ordering; these are integer positions into `dprime`.
order = hierarchy.leaves_list(hierarchy.optimal_leaf_ordering(Z, X))

In [20]:
# Same heatmap, with rows and columns permuted into the clustering-derived order.
fig, ax = plt.subplots(figsize=(7, 7))
clustered = dprime.iloc[order, order]
sns.heatmap(
    clustered,
    cmap='viridis_r',
    vmin=0,
    vmax=4,
    cbar=True,
    cbar_kws={'label': "distinctness d'"},
    ax=ax,
)
ax.axis('image')

In [31]:
# Row labels that are rendered in bold on the annotated heatmap.
bold = [
    "AD Serpinb7 Glut",
    "AV Col27a1 Glut",
    "TH Prkcd Grin2c Glut 9 (AM)",
    "TH Prkcd Grin2c Glut 13 (VM/VAL)",
    "TH Prkcd Grin2c Glut 10 (MD)",
]

In [34]:
from matplotlib.patches import Rectangle

# Re-cluster with single linkage and derive the row/column order from it.
Z = hierarchy.linkage(X, method='single')
order = hierarchy.leaves_list(hierarchy.optimal_leaf_ordering(Z, X))

plt.figure(figsize=(7,6))
sns.heatmap(dprime.iloc[order, order], cmap='viridis_r', vmin=0, vmax=4, cbar=True, cbar_kws=dict(label="distinctness d'"))
plt.axis('image')

ax = plt.gca()
args = dict(linewidth=1.5, edgecolor='red', facecolor='none')
k = dprime.shape[0]
# Red outlines spanning the full width of selected rows.
# NOTE(review): the row offsets (4, 10, 18) are hard-coded and only line up
# with the intended rows for the ordering produced above -- re-check them if
# the data or the linkage method changes.
boxes = [
    ax.add_patch(Rectangle((0,4), k, 1, **args)),
    ax.add_patch(Rectangle((0,10), k, 1, **args)),
    ax.add_patch(Rectangle((0,18), k, 3, **args)),
]
labels = dprime.index[order]
# Render the highlighted types in bold via mathtext; spaces must be escaped
# there, hence the raw-string r'\ '.
labels = [r"$\mathbf{" + x.replace(' ', r'\ ') + "}$" if x in bold else x for x in labels]
# Pin one tick at the centre of every row before labelling, so the number of
# labels always matches the number of ticks even if the heatmap thinned them.
ax.set_yticks(np.arange(k) + 0.5)
ax.set_yticklabels(labels)
ax.tick_params(labelbottom=False)
plt.show()

In [51]:
# The two supertypes compared in the cells below.
types = [
    '277 TH Prkcd Grin2c Glut_3',
    '280 TH Prkcd Grin2c Glut_6',
]

In [30]:
# d' between the two supertypes in `types` (negative-binomial branch, r=1).
# tx_dprime's first positional parameter is `cluster_label` and it reads the
# notebook-level `data` itself, so no dataset is passed here; passing one
# positionally together with cluster_label= raises a TypeError.
facs_dprime = tx_dprime(cluster_label='supertype_id_label', type_list=types, zero_inflated=False, r=1)

In [31]:
# Display the d' matrix computed for the supertypes listed in `types`.
facs_dprime

In [54]:
# higher r / lower dispersion
# tx_dprime reads the notebook-level `data`; its first positional parameter is
# `cluster_label`, so no dataset argument is passed (doing so raises TypeError).
tx_dprime(cluster_label='supertype_id_label', type_list=types, zero_inflated=False, r=10)

In [32]:
# Same comparison restricted to the first 10,000 variables of seqData.
# tx_dprime reads the notebook-level `data`; its first positional parameter is
# `cluster_label`, so no dataset argument is passed (doing so raises TypeError).
tx_dprime(cluster_label='supertype_id_label', features=seqData.var_names[:10000], type_list=types, zero_inflated=False, r=1)

In [41]:
# Flag the top 10,000 highly variable genes on seqData; the next cell reads the
# resulting `highly_variable` column from seqData.var.
# NOTE(review): scanpy documents flavor='seurat_v3' as expecting raw counts --
# confirm that seqData.X holds counts.
sc.pp.highly_variable_genes(seqData, flavor='seurat_v3', n_top_genes=10000)

In [42]:
# Same comparison restricted to the genes flagged `highly_variable` in seqData.var.
# tx_dprime reads the notebook-level `data`; its first positional parameter is
# `cluster_label`, so no dataset argument is passed (doing so raises TypeError).
tx_dprime(cluster_label='supertype_id_label', features=seqData.var.query('highly_variable').index, type_list=types, zero_inflated=False, r=1)

In [43]:
# Alternative selection of 10,000 highly variable genes via scanpy's
# experimental API. inplace=False is passed; the result is later queried as a
# DataFrame with a `highly_variable` column.
hvg = sc.experimental.pp.highly_variable_genes(seqData, n_top_genes=10000, inplace=False)

In [45]:
# Same comparison restricted to the genes flagged `highly_variable` in `hvg`.
# tx_dprime reads the notebook-level `data`; its first positional parameter is
# `cluster_label`, so no dataset argument is passed (doing so raises TypeError).
tx_dprime(cluster_label='supertype_id_label', features=hvg.query('highly_variable').index, type_list=types, zero_inflated=False, r=1)