In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from thalamus_merfish_analysis.distance_metrics import *

In [3]:
import importlib

import pandas as pd
import anndata as ad
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy import sparse

from thalamus_merfish_analysis import ccf_plots as cplots
from thalamus_merfish_analysis import ccf_images as cimg
from thalamus_merfish_analysis import abc_load as abc
get_ipython().run_line_magic('matplotlib', 'inline') 

In [None]:
adata = ad.read_zarr("/root/capsule/data/nsf_2000_adata1/nsf_2000_adata.zarr")
# log2(1 + CPM) normalization, kept sparse: log2(1 + 0) == 0, so only the stored
# nonzeros need transforming. (`1 + sparse_matrix` is not supported by scipy, and
# later cells call `.toarray()` on slices of adata.X, which needs a sparse X.)
counts = sparse.csr_matrix(adata.X)
library_size = np.asarray(counts.sum(axis=1), dtype=float).ravel()
# cells with zero counts get a scale of 0 instead of a division-by-zero warning
scale = np.divide(1e6, library_size, out=np.zeros_like(library_size), where=library_size > 0)
log_cpm = sparse.csr_matrix(sparse.diags(scale) @ counts)
log_cpm.data = np.log2(1 + log_cpm.data)
adata.X = log_cpm

In [5]:
section_col = 'z_section'
obs = adata.obs

# every z section present in the data, ascending
sections_all = sorted(obs[section_col].unique())
# three example sections used for plotting
sections_3 = [6.4, 7.2, 8.0]

In [6]:
realigned = False
# CCF label column and x/y coordinate suffix depend on which alignment is used
ccf_label, coords = (
    ('parcellation_structure_realigned', 'section') if realigned
    else ('parcellation_structure', 'reconstructed')
)

## distance metrics

In [7]:
import sklearn.metrics
import scipy.spatial.distance as ssd
import scipy.cluster.hierarchy as sch

In [8]:
# Are per-cell NSF factor totals correlated with per-cell totals of adata.X??
expr_totals = np.asarray(adata.X.sum(axis=1)).ravel()
np.corrcoef(obs['nsf_tot'].values, expr_totals)

In [9]:
nsf_cols = [f"nsf{k}" for k in range(30)]
# use raw factor values, or scale by per-cell totals?
#   metrics changed by scaling: cosine
#   metrics unchanged by scaling: braycurtis
# factors = obs[nsf_cols].values * obs['nsf_tot'].values[:,None]
factors = obs[nsf_cols].to_numpy().T        # patterns x cells
loadings = adata.var[nsf_cols].to_numpy().T  # patterns x genes

In [10]:
# per-cell sum over all patterns: how are cells distributed across factors?
np.sum(factors, axis=0)

In [11]:
np.sum(factors, axis=1).mean()  # mean over patterns of each pattern's total

In [12]:
# do the factors behave like distributions over genes? (per-pattern loading totals)
np.sum(loadings, axis=1)

## PP distance clustermaps

In [13]:
# y = ssd.pdist(factors, metric='braycurtis')
# dist = ssd.squareform(y)

# link = sch.linkage(y, method='complete', optimal_ordering=True)
# dend = sch.dendrogram(link, no_plot=True)
# nsf_order = dend["leaves"]

# Bray-Curtis distances between spatial patterns; order patterns by clustering them
dist = pairwise_distances(factors, metric='braycurtis')
nsf_order = order_distances_by_clustering(dist)
plot_ordered_similarity_heatmap(dist, nsf_order)
# sns.heatmap(1-dist[np.ix_(nsf_order, nsf_order)], yticklabels=nsf_order, xticklabels=nsf_order,
#                 cbar_kws=dict(label="Dice coefficient (similarity)"), cmap='rocket_r', vmin=0, vmax=1)
# plt.savefig("/results/nsf_pattern_similarity.pdf", transparent=True)

In [14]:
# unchanged
# y = ssd.pdist(factors * obs['nsf_tot'].values[None,:], metric='braycurtis')

# a bit closer to cosine: scale each pattern to sum to 1 before comparing
pattern_totals = factors.sum(axis=1, keepdims=True)
dist_norm = pairwise_distances(factors / pattern_totals, metric='braycurtis')
plot_ordered_similarity_heatmap(dist_norm, nsf_order)

In [15]:
# cosine distance between patterns, shown in the Bray-Curtis clustering order
dist_cosine = pairwise_distances(factors, metric='cosine')
plot_ordered_similarity_heatmap(dist_cosine, nsf_order)

### gene loadings distance

In [16]:
plt.rcParams['font.size'] = 14

In [19]:
# Optional per-gene normalization: each gene's loadings across patterns sum to 1.
# The original cell computed this and then immediately overwrote it with the raw
# loadings, so the switch is off to keep the same result.
NORMALIZE_LOADINGS_PER_GENE = False
X = (loadings / loadings.sum(axis=0, keepdims=True)
     if NORMALIZE_LOADINGS_PER_GENE else loadings)
# alternative: metric='cosine'
dist_genes = pairwise_distances(X, metric='braycurtis')

plot_ordered_similarity_heatmap(dist_genes, nsf_order)

In [20]:
# patterns re-ordered by clustering on gene-loading distance
order = order_distances_by_clustering(dist_genes)
plt.figure(figsize=(8, 7))
plot_ordered_similarity_heatmap(dist_genes, order, triangular=True, label="Dice coefficient")
plt.axis('image')
plt.title("Gene pattern overlap")
plt.xlabel("NSF pattern")
plt.savefig("/results/nsf_pattern_gene_similarity.pdf", transparent=True)

## CCF overlap

In [None]:
# th_names = abc.get_thalamus_names(level='structure')
# th_subregion_names = list(set(th_names).difference(['TH-unassigned']))
# ccf_label = 'parcellation_structure_realigned'
# obs[ccf_label] = obs[ccf_label].str.replace('-unassigned','')

# for x in th_subregion_names:
#     obs[x] = obs[ccf_label] == x
# th_subregions_found = list(set(th_subregion_names).intersection(obs[ccf_label].unique()))
# obs[th_subregion_names].mean(axis=0).loc[lambda x: x==0]

In [None]:
regions_subset = np.array(
    ['AD', 'AV', 'LD', 'LGd',
     'VPM', 'VPL',  # VP combine?
     'AM', 'VPMpc', 'MD', 'LP', 'PO', 'IAD', 'VAL', 'VM',
     'RE', 'CL', 'PF', 'CM', 'PCN', 'SPA', 'IMD', 'PVT',
     # additional
     'MH', 'LH', 'RT',
     # 'ZI'
     ]
)

In [None]:
# regions in atlas (CCF) order
regions_ccf_order = np.array(
    ("VAL VM VPL VPM VPMpc SPA LGd LP PO AV AM AD IAD LD IMD MD PVT "
     "RE CM PCN CL PF MH LH RT").split())
# regions in the order from the OT clustering
regions_ot_clustering_order = np.array(
    ("PF VPMpc RE CM IMD MD CL PVT SPA AD AV AM IAD PCN VAL "
     "VM PO LP LD VPM VPL LGd MH LH RT").split())

In [41]:
regions_final = (
    "AD AV AM IAD LD VPM VPL LGd "
    "MD CL CM IMD PO LP VAL VM "
    "RE PF VPMpc PCN SPA PVT MH LH RT"
).split()

In [43]:
# regions kept in thalamoseq fig 1 order (regions_final); only NSF patterns are reordered
dist, y_names, x_names = cluster_distances_from_labels(
    obs, x_col=nsf_cols, y_col=ccf_label,
    x_names=range(30), y_names=regions_final)

y_order, x_order = order_distances_x_to_y(dist, reorder_y=False)

plt.figure(figsize=(6, 6))
plot_ordered_similarity_heatmap(
    dist, label="Dice coefficient",
    y_order=y_order, x_order=x_order,
    y_names=y_names, x_names=x_names,
)
plt.axis('image')
plt.title("Spatial pattern overlap")
plt.xlabel("NSF pattern")
plt.savefig("/results/nsf_ccf_similarity_heatmap_ccf_order.pdf", transparent=True)

In [68]:
# Regions (regions_subset) are reordered by clustering here (reorder_y=True), so this
# is NOT the thalamoseq fig 1 order. Save under its own name so the fig-1-order
# heatmap saved to nsf_ccf_similarity_heatmap_ccf_order.pdf is not overwritten.
dist, y_names, x_names = cluster_distances_from_labels(
    obs, y_col=ccf_label, x_col=nsf_cols,
    y_names=regions_subset, x_names=range(30))

y_order, x_order = order_distances_x_to_y(dist, reorder_y=True)

plot_ordered_similarity_heatmap(dist,
                       y_order=y_order, x_order=x_order,
                       y_names=y_names, x_names=x_names)

plt.savefig("/results/nsf_ccf_similarity_heatmap_clustered_order.pdf", transparent=True)

In [69]:
# regions in OT-clustering order; NSF patterns reordered to match
dist, y_names, x_names = cluster_distances_from_labels(
    obs, x_col=nsf_cols, y_col=ccf_label,
    x_names=range(30), y_names=regions_ot_clustering_order)

y_order, x_order = order_distances_x_to_y(dist, reorder_y=False)

plot_ordered_similarity_heatmap(
    dist,
    y_order=y_order, x_order=x_order,
    y_names=y_names, x_names=x_names,
)
plt.savefig("/results/nsf_ccf_similarity_heatmap_ot_order.pdf", transparent=True)

In [70]:
# regions in atlas (CCF) order; NSF patterns reordered to match
dist, y_names, x_names = cluster_distances_from_labels(
    obs, x_col=nsf_cols, y_col=ccf_label,
    x_names=range(30), y_names=regions_ccf_order)

y_order, x_order = order_distances_x_to_y(dist, reorder_y=False)

plot_ordered_similarity_heatmap(
    dist,
    y_order=y_order, x_order=x_order,
    y_names=y_names, x_names=x_names,
)
plt.savefig("/results/nsf_ccf_similarity_heatmap_atlas_order.pdf", transparent=True)

In [71]:

# jaccard
# d = dist_nsf_ccf/(2-dist_nsf_ccf)

### ordered by NSF pattern clustering

In [72]:
# Recompute the region x pattern distances for `regions_subset` explicitly. The
# `dist` left by the previous cells was built over `regions_ccf_order`, yet the
# plot below labels its regions with `regions_subset`, so the labels were wrong.
dist_ccf_nsf, ccf_names, _ = cluster_distances_from_labels(
    obs, y_col=ccf_label, x_col=nsf_cols,
    y_names=regions_subset, x_names=range(30))
ccf_names = np.asarray(ccf_names)
# array form so it can be boolean-masked below
nsf_order = np.asarray(nsf_order)

fig, ax = plt.subplots(figsize=(7,4.5))

# rows: NSF patterns in pattern-clustering order; columns: CCF regions
dist_nsf_ccf = dist_ccf_nsf.T[nsf_order, :]
# keep only patterns whose closest region has distance < 0.9
subset = dist_nsf_ccf.min(axis=1) < 0.9

y_order, ccf_order = order_distances_x_to_y(dist_nsf_ccf[subset], reorder_y=True)
plot_ordered_similarity_heatmap(dist_nsf_ccf[subset],
                       y_order=y_order, x_order=ccf_order,
                       y_names=nsf_order[subset], x_names=ccf_names)
# plt.savefig("/results/nsf_ccf_similarity.pdf")

In [73]:
# transpose: regions (in the clustered order from the previous cell) x all NSF patterns
dist = dist_nsf_ccf[:, ccf_order].T
y_order, x_order = order_distances_x_to_y(dist, reorder_y=False)

plot_ordered_similarity_heatmap(
    dist,
    y_order=y_order, x_order=x_order,
    y_names=regions_subset[ccf_order], x_names=nsf_order,
)

In [36]:
# TODO: try clustering ccf on gene-space similarity (or spatial?) to order instead (sns.clustermap?)

## Spatial plotting over a CCF section

In [15]:

# CCF label image (resampled, left hemisphere only)
ccf_images = abc.get_ccf_labels_image(subset_to_left_hemi=True, realigned=realigned, resampled=True)
# erase right hemisphere
# ccf_images[550:,:,:] = 0
# region edges from a 1-px sectionwise label erosion (background filled with 0)
ccf_boundaries = cimg.sectionwise_label_erosion(ccf_images, return_edges=True, distance_px=1, fill_val=0)

In [16]:
section = 7.2
# per-pattern maximum over the cells of this section
section_max = obs.loc[obs[section_col] == section, nsf_cols].max(axis=0)
# patterns that reach > 0.25 somewhere in the section, strongest first
top_patterns = section_max[section_max > 0.25].sort_values(ascending=False)
top_patterns

In [17]:

def plot_patterns_overlay(obs, patterns, section_z, colors=cplots.glasbey, ax=None):
    """Overlay several NSF patterns on one CCF section, one color per pattern.

    Each cell of section `section_z` is drawn once per pattern in that pattern's
    color. Opacity is the cell's value for the pattern, clipped to [0, 1].
    CCF boundaries are drawn on top.

    Parameters
    ----------
    obs : pandas.DataFrame
        Cell table with `section_col`, 'x_<coords>', 'y_<coords>' and one column
        per entry of `patterns`.
    patterns : sequence of str
        Pattern column names to overlay, in drawing order.
    section_z : float
        Value of `section_col` that selects the section to plot.
    colors : sequence of colors, optional
        Color i is used for patterns[i].
    ax : matplotlib Axes, optional
        Axes to draw on. A new 8x5 figure is created when omitted.

    Returns
    -------
    The matplotlib Axes the overlay was drawn on.

    Notes
    -----
    Reads the notebook globals `section_col`, `coords`, `ccf_images` and
    `ccf_boundaries`.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 5))
    else:
        # format_image_axes() below is called without an ax; presumably it uses the current axes
        plt.sca(ax)
    df = obs.loc[obs[section_col] == section_z]
    rgb_mask = np.array([1, 1, 1, 0])
    alpha_mask = np.array([0, 0, 0, 1])
    for i, col in enumerate(patterns):
        base_rgba = np.array(cplots.to_rgba(colors[i]))
        # alpha = base alpha * pattern value; clip so matplotlib gets a valid RGBA in [0, 1]
        weight = np.clip(df[col].to_numpy(dtype=float), 0, 1)[:, None]
        point_colors = base_rgba[None, :] * (rgb_mask + alpha_mask * weight)
        # empty scatter only provides a full-color legend entry
        ax.scatter([], [], c=colors[i], label=col)
        sns.scatterplot(df,
                        c=point_colors,
                        linewidth=0,
                        ax=ax,
                        x='x_' + coords,
                        y='y_' + coords,
                        s=3,
                        legend=False)
    ax.legend()
    cplots.plot_ccf_section(ccf_images, section_z,
                            palette='dark_outline', boundary_img=ccf_boundaries,
                            legend=False, ax=ax)
    cplots.format_image_axes()
    return ax


In [18]:

n_top = 10  # number of strongest patterns to overlay
plot_patterns_overlay(obs, top_patterns.index[:n_top], section)
plt.savefig("/results/nsf_section_patterns_overlay.pdf", transparent=True)

In [30]:

# shared plotting options for the tiled section figure (small points)
kwargs = {
    'section_col': 'z_section',
    'x_col': 'x_' + coords,
    'y_col': 'y_' + coords,
    's': 0.05,
    # 'face_palette': None,
    # 'edge_color': 'grey',
    'boundary_img': ccf_boundaries,
}

In [31]:
cols = top_patterns.index[:10]

fig, axes = plt.subplots(2, 5, figsize=(10,5))
axes = axes.ravel()

# Fixes: missing comma after `[section]` (SyntaxError), and each pattern now
# draws into its own panel (previously `axes` was never used, so the saved grid was empty).
for col, ax in zip(cols, axes):
    cplots.plot_expression_ccf(obs, col, ccf_images, [section],
                               cmap='Greens',
                               colorbar=False, custom_xy_lims=[2.8, 5.8, 7, 4],
                               ax=ax, **kwargs)
# hide panels left over when fewer than 10 patterns pass the threshold
for ax in axes[len(cols):]:
    ax.axis('off')
fig.savefig("/results/nsf_section_patterns_tiled.pdf", transparent=True)

## plotting individual patterns and genes

In [25]:

# shared plotting options for single-section figures (larger points)
kwargs = {
    'section_col': 'z_section',
    'x_col': 'x_' + coords,
    'y_col': 'y_' + coords,
    's': 1,
    # 'face_palette': None,
    # 'edge_color': 'grey',
    'boundary_img': ccf_boundaries,
}

In [38]:
# MD: candidate NSF patterns n=27, 7, 23 (or 0 for the center)
# (pattern-to-region notes are the author's observations, not computed here)
# NOTE(review): `n` is rebound by the loops in later cells, so this value is not used downstream.
n=7
# VM etc: candidate n=20, but it has little gene overlap
# n=20


### AV

In [39]:
# AV: NSF patterns 22 and 24 on section 8.0
av_patterns = [22, 24]
for pattern in av_patterns:
    cplots.plot_ccf_overlay(
        obs, ccf_images,
        point_hue=f'nsf{pattern}', sections=[8.0],
        categorical=False, point_palette='viridis', legend=None,
        **kwargs,
    )

In [26]:

fig, ax = plt.subplots(figsize=(5, 5))
cplots.plot_expression_ccf_section(
    adata.obs, "nsf24", ccf_images, 8.0,
    ax=ax, cmap="Greens", colorbar=False,
    custom_xy_lims=[2.8, 5.8, 7, 4], **kwargs)
fig.savefig("/results/nsf_pattern_av.pdf", transparent=True)

In [41]:

# loading difference nsf22 - nsf24; show genes favoring nsf24 (largest negated difference)
diff_genes = adata.var["nsf22"].sub(adata.var["nsf24"])
neg_diff = -diff_genes
neg_diff.sort_values(ascending=False).head()

In [42]:
# top 2 genes loading more on nsf24 than on nsf22
cols = (-diff_genes).sort_values(ascending=False).head(2).index
for col in cols:
    fig, ax = plt.subplots(figsize=(5,5))
    gene_x = adata[:, col].X
    gene_x = gene_x.toarray() if hasattr(gene_x, "toarray") else np.asarray(gene_x)
    # X holds log2(1 + CPM): invert exactly (2**x - 1) so the "CPM" colorbar label is true
    adata.obs[col] = 2**gene_x.ravel() - 1
    cplots.plot_expression_ccf_section(adata.obs, col, ccf_images, 8.0,
                                       colorbar=True, custom_xy_lims=[2.8, 5.8, 7, 4], label="CPM",
                                       ax=ax, **kwargs)
    fig.savefig(f"/results/nsf_genes_AV_{col}.pdf", transparent=True)

In [43]:

# loading difference nsf22 - nsf24; show genes favoring nsf22
diff_genes = adata.var["nsf22"].sub(adata.var["nsf24"])
diff_genes.sort_values(ascending=False).head()

In [44]:
# top 2 genes loading more on nsf22 than on nsf24
cols = diff_genes.sort_values(ascending=False).head(2).index
for col in cols:
    fig, ax = plt.subplots(figsize=(5,5))
    gene_x = adata[:, col].X
    gene_x = gene_x.toarray() if hasattr(gene_x, "toarray") else np.asarray(gene_x)
    # X holds log2(1 + CPM): invert exactly (2**x - 1) so the "CPM" colorbar label is true
    adata.obs[col] = 2**gene_x.ravel() - 1
    cplots.plot_expression_ccf_section(adata.obs, col, ccf_images, 8.0,
                                       colorbar=True, custom_xy_lims=[2.8, 5.8, 7, 4], label="CPM",
                                       ax=ax, **kwargs)
    fig.savefig(f"/results/nsf_genes_AV_{col}.pdf", transparent=True)

### MD

In [45]:
md_patterns = [7, 27, 23]
for pattern in md_patterns:
    cplots.plot_ccf_overlay(
        obs, ccf_images,
        point_hue=f'nsf{pattern}', sections=[section],
        categorical=False, point_palette='viridis', legend=None,
        **kwargs,
    )

In [27]:

fig, ax = plt.subplots(figsize=(5, 5))
cplots.plot_expression_ccf(
    adata.obs, "nsf23", ccf_images, [section],
    ax=ax, cmap="Greens", colorbar=False,
    custom_xy_lims=[2.8, 5.8, 7, 4], **kwargs)
fig.savefig("/results/nsf_pattern_md.pdf", transparent=True)

In [47]:

# genes loading on nsf23 more than on nsf27 and nsf7 combined
diff_genes = adata.var["nsf23"].sub(adata.var["nsf27"]).sub(adata.var["nsf7"])
diff_genes.sort_values(ascending=False).head()

In [48]:
# `reload` was never imported (NameError); use importlib explicitly.
# With `%autoreload 2` set at the top this is normally redundant.
importlib.reload(cplots)

In [63]:
# top 3 genes by the nsf23 - nsf27 - nsf7 loading difference
cols = diff_genes.sort_values(ascending=False).head(3).index
for col in cols:
    fig, ax = plt.subplots(figsize=(5,5))
    gene_x = adata[:, col].X
    gene_x = gene_x.toarray() if hasattr(gene_x, "toarray") else np.asarray(gene_x)
    # X holds log2(1 + CPM): invert exactly (2**x - 1) so the "CPM" colorbar label is true
    adata.obs[col] = 2**gene_x.ravel() - 1
    cplots.plot_expression_ccf(adata.obs, col, ccf_images, [section],
                                       colorbar=True, custom_xy_lims=[2.8, 5.8, 7, 4], label="CPM",
                                       ax=ax, **kwargs)
    # fig.savefig(f"/results/nsf_genes_{col}.pdf", transparent=True)

## Gene-expression plots derived from the loadings

Caution: these gene-based reconstructions do not always look like the NSF factors!

In [50]:
# genes x patterns (note: the earlier `loadings` was patterns x genes)
loadings = adata.var[nsf_cols].to_numpy()
# least squares: feat @ loadings.T = X  =>  feat ~= X @ loadings @ pinv(loadings.T @ loadings)
gram_pinv = np.linalg.pinv(loadings.T @ loadings)
lin_weights = loadings @ gram_pinv
adata.obsm['genes_on_loadings'] = adata.X @ loadings
adata.obsm['genes_linear_projection'] = adata.X @ lin_weights

In [51]:
N = loadings.shape[1]
# normalize by gene to compare patterns by gene?
# loadings_norm = loadings/loadings.sum(axis=1, keepdims=True)
loadings_norm = loadings
top_gene = np.zeros(N, dtype=int)
gene_prominence = np.zeros(N)
for n in range(N):
    # per gene: loading on pattern n minus its largest loading on any other pattern
    pattern_prominence = loadings_norm[:, n] - np.delete(loadings_norm, n, axis=1).max(axis=1)
    n_gene = int(np.argmax(pattern_prominence))
    top_gene[n] = n_gene
    gene_prominence[n] = pattern_prominence[n_gene]
    # densify one gene column; np.array() on a sparse slice gives a 0-d object array
    gene_x = adata.X[:, n_gene]
    gene_x = gene_x.toarray() if hasattr(gene_x, "toarray") else np.asarray(gene_x)
    # X holds log2(1 + CPM); 2**x - 1 recovers CPM exactly
    obs[f"nsf{n}_1gene"] = 2**gene_x.ravel() - 1

In [52]:
(gene_prominence > 0.1).nonzero()  # patterns with a clearly dominant gene

In [53]:
# for those patterns: is the most prominent gene also the gene with the largest loading?
prominent_idx = np.nonzero(gene_prominence > 0.1)
matches_max = top_gene == loadings.argmax(axis=0)
matches_max[prominent_idx]

In [54]:

# single-section plotting options with grey-outlined, unfilled CCF regions
kwargs = {
    'section_col': 'z_section',
    'x_col': 'x_' + coords,
    'y_col': 'y_' + coords,
    's': 1,
    'face_palette': None,
    'edge_color': 'grey',
    'boundary_img': ccf_boundaries,
}

### AV

In [55]:
av_patterns = [22, 24]
for pattern in av_patterns:
    print(adata.var_names[top_gene[pattern]])
    cplots.plot_ccf_overlay(
        obs, ccf_images,
        point_hue=f'nsf{pattern}_1gene', sections=[8.0],
        categorical=False, point_palette='Blues', legend=None,
        **kwargs,
    )

### MD

In [56]:
md_patterns = [7, 27, 23]
for pattern in md_patterns:
    print(adata.var_names[top_gene[pattern]])
    cplots.plot_ccf_overlay(
        obs, ccf_images,
        point_hue=f'nsf{pattern}_1gene', sections=[section],
        categorical=False, point_palette='Blues', legend=None,
        **kwargs,
    )

In [57]:
# can't plot with obsm currently

# cplots.plot_ccf_overlay(obs, ccf_images, categorical=False,
#                         point_hue=adata.obsm['genes_on_loadings'][:,0], 
#                         sections=sections_3,
#                         face_palette=None,
#                         edge_color='grey', 
#                         point_palette='viridis', legend=None, 
#                         section_col=section_col, 
#                         x_col=x_coord_col, y_col=y_coord_col, s=3,
#                         boundary_img=ccf_boundaries);

In [58]:

# Reconstruct each pattern from expression in several ways, for comparison:
#   <col>_allgenes       X @ all of the pattern's gene loadings
#   <col>_3genes/_2genes X @ loadings restricted to the pattern's top-3 / top-2 genes
#   <col>_linear_genes   X @ least-squares weights (lin_weights)
for i, x in enumerate(nsf_cols):
    # copy: the top-k zeroing below must never write back into adata.var
    weights = adata.var[x].to_numpy(dtype=float, copy=True)
    # 1-D argsort (ascending). The old (G, 1) column was argsorted along its
    # length-1 axis, which returned all zeros, so only gene 0 was ever zeroed.
    rank = np.argsort(weights)
    obs[x+"_allgenes"] = np.asarray(adata.X @ weights[:, None]).ravel()
    for k in (3, 2):
        top_k = np.zeros_like(weights)
        top_k[rank[-k:]] = weights[rank[-k:]]
        obs[f"{x}_{k}genes"] = np.asarray(adata.X @ top_k[:, None]).ravel()
    obs[x+"_linear_genes"] = np.asarray(adata.X @ lin_weights[:, [i]]).ravel()


In [59]:
md_patterns = [7, 27, 23]
for pattern in md_patterns:
    cplots.plot_ccf_overlay(
        obs, ccf_images,
        point_hue=f'nsf{pattern}_linear_genes', sections=[section],
        categorical=False, point_palette='viridis', legend=None,
        **kwargs,
    )

In [60]:
md_patterns = [7, 27, 23]
for pattern in md_patterns:
    cplots.plot_ccf_overlay(
        obs, ccf_images,
        point_hue=f'nsf{pattern}_allgenes', sections=[section],
        categorical=False, point_palette='viridis', legend=None,
        **kwargs,
    )

In [61]:
md_patterns = [7, 27, 23]
for pattern in md_patterns:
    cplots.plot_ccf_overlay(
        obs, ccf_images,
        point_hue=f'nsf{pattern}_3genes', sections=[section],
        categorical=False, point_palette='viridis', legend=None,
        **kwargs,
    )

In [62]:
md_patterns = [7, 27, 23]
for pattern in md_patterns:
    cplots.plot_ccf_overlay(
        obs, ccf_images,
        point_hue=f'nsf{pattern}_2genes', sections=[section],
        categorical=False, point_palette='viridis', legend=None,
        **kwargs,
    )