In [1]:
# %load_ext autoreload
# %autoreload 2

In [2]:
import pandas as pd
import anndata as ad
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from thalamus_merfish_analysis import ccf_plots as cplots
from thalamus_merfish_analysis import ccf_images as cimg
from thalamus_merfish_analysis import abc_load as abc
get_ipython().run_line_magic('matplotlib', 'inline') 

In [3]:
# Load the AnnData object from its zarr store.
adata = ad.read_zarr("/root/capsule/data/nsf_2000_adata1/nsf_2000_adata.zarr")
# Rescale every row of X to sum to 1e6, then store log2(1 + value) back into X.
# NOTE(review): presumably X holds raw counts, which would make this log2(1 + CPM) -- confirm.
# NOTE(review): `.toarray()` implies X is a scipy sparse matrix; what sparse / dense and
# `1 + ...` return depends on the scipy version -- verify the type of X after this cell.
adata.X = np.log2(1 + adata.X*1e6/np.sum(adata.X.toarray(), axis=1, keepdims=True))

In [4]:
# Name of the obs column that identifies the section each cell belongs to.
section_col = 'z_section'

# Per-cell metadata table; this is a reference to adata.obs, not a copy.
obs = adata.obs

# Every section value present in the data, in ascending order.
sections_all = list(obs[section_col].unique())
sections_all.sort()

# Three example sections.
sections_3 = [6.4, 7.2, 8.0]

In [5]:
# Switch between the realigned CCF labels (paired with 'section' coordinates) and
# the default labels (paired with 'reconstructed' coordinates).
realigned = False
ccf_label, coords = (
    ('parcellation_structure_realigned', 'section')
    if realigned
    else ('parcellation_structure', 'reconstructed')
)

## distance metrics

In [6]:
import sklearn.metrics
import scipy.spatial.distance as ssd
import scipy.cluster.hierarchy as sch

In [7]:
# Sanity check: are the per-cell factor totals ('nsf_tot') correlated with the
# per-cell sums of the transformed expression matrix X?
# The cell displays the correlation-coefficient matrix returned by np.corrcoef.
np.corrcoef(obs['nsf_tot'].values, adata.X.sum(axis=1).squeeze())

In [8]:
# Names of the 30 NSF columns: "nsf0" .. "nsf29".
nsf_cols = ["nsf%d" % idx for idx in range(30)]

# Open question: use the factors as stored, or scale them by the per-cell totals, i.e.
#   obs[nsf_cols].values * obs['nsf_tot'].values[:,None]
# Metrics changed by that choice: cosine. Metrics unchanged: braycurtis.

# Both arrays are transposed so that each row corresponds to one NSF column.
factors = np.transpose(obs[nsf_cols].values)
loadings = np.transpose(adata.var[nsf_cols].values)

In [9]:
# clustering of cells to factors:
# sum over the factor axis (rows of `factors`), giving one total per cell.
factors.sum(axis=0)

In [10]:
# Total of each factor over all cells, averaged across the factors.
factors.sum(axis=1).mean()

In [11]:
# factors as distributions over genes:
# per-factor sum of the gene loadings (rows of `loadings` are factors).
loadings.sum(axis=1)

## plotting

In [12]:

# CCF label image used as the anatomical backdrop of the section plots.
ccf_polygons = abc.get_ccf_labels_image(realigned=realigned, resampled=True)

# erase right hemisphere: zero everything from index 550 on along the first axis
ccf_polygons[550:, :, :] = 0

# Boundary image derived from the labels via a 1 px section-wise erosion.
ccf_boundaries = cimg.sectionwise_label_erosion(
    ccf_polygons,
    distance_px=1,
    fill_val=0,
    return_edges=True,
)

In [13]:
# Bare reference: evaluating the name only displays the function object's repr,
# nothing is plotted. NOTE(review): looks like a leftover inspection cell.
cplots.plot_ccf_section_raster

In [14]:
# Example section used by most of the plots below.
section = 7.2

# Peak value of every NSF column among the cells of that section ...
in_section = obs[section_col] == section
pattern_peaks = obs.loc[in_section, nsf_cols].max(axis=0)

# ... keeping only the columns that exceed 0.25 somewhere, strongest first.
top_patterns = pattern_peaks[pattern_peaks > 0.25].sort_values(ascending=False)
top_patterns

In [15]:
# Re-import the plotting module so that edits to its source made after the first
# import are picked up without restarting the kernel (development convenience).
from importlib import reload
reload(cplots)

In [16]:

def plot_patterns_overlay(obs, patterns, section_z, colors=cplots.glasbey):
    """Overlay several patterns on one section, one color per pattern.

    Every cell of the section is drawn once per pattern in that pattern's
    color, with the point's alpha multiplied by the cell's value in the pattern
    column, so cells with a low value are nearly transparent. The CCF outline
    of the section is drawn on the same axes.

    Reads the notebook globals `section_col`, `coords`, `ccf_polygons` and
    `ccf_boundaries`.

    Parameters
    ----------
    obs : DataFrame
        Per-cell table containing `section_col`, the 'x_'/'y_' + `coords`
        columns and every column named in `patterns`.
    patterns : iterable of str
        Column names to overlay (any iterable; it is not indexed).
    section_z
        Value of `section_col` selecting the section to draw.
    colors : sequence
        Colors accepted by `cplots.to_rgba`, indexed by pattern position; must
        hold at least as many entries as `patterns`.
    """
    df = obs.loc[obs[section_col] == section_z]
    fig, ax = plt.subplots(figsize=(8, 5))
    for i, col in enumerate(patterns):
        # Empty artist on this axes: contributes only a legend entry for the pattern.
        ax.scatter([], [], c=colors[i], label=col)
        # Alpha must stay within [0, 1] for matplotlib; clip instead of failing
        # on values slightly outside that range.
        alpha = np.clip(df[col].values, 0, 1)
        # Rows of (r, g, b, a * alpha): color channels untouched, alpha scaled per cell.
        rgba = (np.array(cplots.to_rgba(colors[i]))[None, :] *
                (np.array([1, 1, 1, 0]) +
                 np.array([0, 0, 0, 1]) * alpha[:, None]))
        sns.scatterplot(data=df,
                        c=rgba,
                        linewidth=0,
                        ax=ax,
                        x='x_'+coords,
                        y='y_'+coords,
                        s=3,
                        legend=False)
    ax.legend()
    cplots.plot_ccf_section_raster(ccf_polygons, section_z,
                                   palette='dark_outline', boundary_img=ccf_boundaries,
                                   legend=False, ax=ax)
    cplots.format_image_axes(axes=False)


In [17]:

# Overlay the (up to) ten strongest patterns of the example section and save the figure.
overlay_pdf = "/results/nsf_section_patterns_overlay.pdf"
strongest_patterns = top_patterns.index[:10]
plot_patterns_overlay(obs, strongest_patterns, section)
plt.savefig(overlay_pdf, transparent=True)

In [18]:

# Keyword arguments shared by the section plots that follow (small points, s=0.1).
kwargs = {
    'section_col': 'z_section',
    'x_col': f'x_{coords}',
    'y_col': f'y_{coords}',
    's': 0.1,
    'shape_palette': 'dark_outline',
    'boundary_img': ccf_boundaries,
}

In [19]:
# One panel per pattern (up to ten) on a 2 x 5 grid, all for the example section.
cols = top_patterns.index[:10]

fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(10,5))
axes = axes.ravel()

for panel, col in zip(axes, cols):
    cplots.plot_expression_ccf_section(
        obs, col, ccf_polygons, section,
        ax=panel,
        colorbar=False,
        set_lims=[2.8, 5.8, 7, 4],
        **kwargs,
    )
fig.savefig("/results/nsf_section_patterns_tiled.pdf", transparent=True)

In [20]:
# Selected patterns, listed by 1-based number and converted to the 0-based column
# names (13 is noted separately and is not part of the list).
patterns = [f"nsf{number - 1}" for number in (1, 2, 3, 4, 18, 21, 24)]

## PP distance clustermaps

In [21]:
# Pairwise Bray-Curtis distances between the rows of `factors`
# (condensed form in `y`, square form in `dist`).
y = ssd.pdist(factors, metric='braycurtis')
dist = ssd.squareform(y)

# Complete-linkage clustering; the dendrogram's leaf order is kept as `nsf_order`
# and reused by later cells to order rows and columns.
link = sch.linkage(y, method='complete', optimal_ordering=True)
dend = sch.dendrogram(link, no_plot=True)
nsf_order = dend["leaves"]

# Plot similarity (1 - distance) with both axes in leaf order, then save.
ordered_similarity = 1 - dist[np.ix_(nsf_order, nsf_order)]
sns.heatmap(
    ordered_similarity,
    xticklabels=nsf_order,
    yticklabels=nsf_order,
    cmap='rocket_r',
    cbar_kws={"label": "Dice coefficient (similarity)"},
)
plt.savefig("/results/nsf_pattern_similarity.pdf", transparent=True)

In [22]:
def plot_ordered_distances(dist, xorder, yorder=None):
    """Draw `dist` as a heatmap with its rows and columns rearranged.

    Columns follow `xorder`; rows follow `yorder`, or `xorder` when `yorder`
    is None. The orders are also used as the tick labels.
    """
    row_order = xorder if yorder is None else yorder
    rearranged = dist[np.ix_(row_order, xorder)]
    sns.heatmap(rearranged, yticklabels=row_order, xticklabels=xorder)

In [23]:
# Bray-Curtis again, after normalizing every factor (row) to unit sum; the result
# is a bit closer to cosine. Scaling by the per-cell totals instead, i.e.
#   ssd.pdist(factors * obs['nsf_tot'].values[None,:], metric='braycurtis')
# left the distances unchanged.
y = ssd.pdist(factors / factors.sum(axis=1, keepdims=True), metric='braycurtis')
dist = ssd.squareform(y)

# Use the shared helper rather than repeating the ordered-heatmap call inline.
plot_ordered_distances(dist, nsf_order)

In [24]:
# Cosine distance between the factors, shown in the Bray-Curtis leaf order.
y = ssd.pdist(factors, metric='cosine')
dist = ssd.squareform(y)

# Use the shared helper rather than repeating the ordered-heatmap call inline.
plot_ordered_distances(dist, nsf_order)

### gene loadings distance

In [25]:
# Cosine distance between the factors' gene loadings, shown in the same
# `nsf_order` as the factor-based heatmaps so the two can be compared.
y = ssd.pdist(loadings, metric='cosine')
dist = ssd.squareform(y)

# Use the shared helper rather than repeating the ordered-heatmap call inline.
plot_ordered_distances(dist, nsf_order)

In [26]:
# Re-cluster on the current condensed distances `y` with single linkage and show
# the current `dist` in the resulting leaf order.
link = sch.linkage(y, method='single', optimal_ordering=True)
dend = sch.dendrogram(link, no_plot=True)
order = dend["leaves"]

# Use the shared helper rather than repeating the ordered-heatmap call inline.
plot_ordered_distances(dist, order)

## CCF overlap

In [27]:
# Thalamus names at the 'structure' level, without the 'TH-unassigned' entry.
# The result comes from a set, so its order is arbitrary.
th_names = abc.get_thalamus_names(level='structure')
excluded_names = ['TH-unassigned']
th_subregion_names = list(set(th_names).difference(excluded_names))

In [28]:
# From this cell on the realigned labels are used, whatever `realigned` was set to.
ccf_label = 'parcellation_structure_realigned'
# Strip the '-unassigned' suffix from the label values (written back into obs).
obs[ccf_label] = obs[ccf_label].str.replace('-unassigned','')

# Add one boolean column per subregion name: True where the cell's label equals that name.
for region_name in th_subregion_names:
    obs[region_name] = obs[ccf_label].eq(region_name)

In [29]:
# Subregion names that actually occur among the labels (set-derived, arbitrary order).
labels_present = obs[ccf_label].unique()
th_subregions_found = list(set(th_subregion_names).intersection(labels_present))

# Display the subregions that contain no cells at all (mean membership of zero).
region_fractions = obs[th_subregion_names].mean(axis=0)
region_fractions[region_fractions == 0]

In [30]:
# Regions compared against the NSF patterns, in display order.
regions_subset = np.array([
    'AD', 'AV', 'LD', 'LGd',
    # VP combine?
    'VPM', 'VPL',
    'AM', 'VPMpc', 'MD', 'LP', 'PO',
    'IAD', 'VAL', 'VM', 'RE',
    'CL', 'PF', 'CM', 'PCN',
    'SPA', 'IMD', 'PVT',
    # additional
    'MH', 'LH', 'RT',
    # 'ZI'
])

In [31]:
# TODO: lookup sort from cross-species alignment
# Region membership as floats, one row per entry of `regions_subset`.
regions = obs[regions_subset].values.T.astype(float)

# Bray-Curtis distance from every factor (rows, in `nsf_order`) to every region (columns).
# For Jaccard instead: d = dist_nsf_ccf/(2-dist_nsf_ccf)
ordered_factors = factors[nsf_order]
dist_nsf_ccf = sklearn.metrics.pairwise_distances(
    ordered_factors,
    regions,
    metric='braycurtis',
)

In [32]:

def plot_distance_matrix_sorted_columns(D, y_labels, x_labels, reorder_y=False):
    """Plot 1 - D as a heatmap with the columns grouped by their closest row.

    Parameters
    ----------
    D : 2-D array
        Distances; rows correspond to `y_labels`, columns to `x_labels`.
    y_labels, x_labels : sequences
        Tick labels for the rows and columns of D (before reordering).
    reorder_y : bool
        If True, rows that are no column's closest row, or whose first matching
        column is farther away than 0.5, are moved to sit right after the row
        that is closest to their own closest column.

    Returns
    -------
    x_order : array of int
        Column order used in the plot.
    """
    # For every column: index of its closest row and the distance to that row.
    argmin = D.argmin(axis=0)
    min_dist = D.min(axis=0)
    # Sort key = closest-row index + distance: columns sharing a closest row end up
    # adjacent, nearest first. NOTE(review): this only groups cleanly if the
    # distances stay below 1 -- confirm for the metric in use.
    x_order = np.argsort(argmin + min_dist)
    # y_show = np.unique(argmin) if show_matched_y_only else slice(None)
    
    
    order = list(range(D.shape[0]))
    if reorder_y:
        # reorder rows without a match
        # NOTE(review): list(argmin).index(i) picks the FIRST column whose closest
        # row is i, not the best-matching one -- confirm that this is intended.
        y_old = [i for i in range(D.shape[0]) if i not in np.unique(argmin) 
                 or min_dist[list(argmin).index(i)]>0.5]
        # For each such row: take its closest column, then that column's closest row.
        y_new = D[:, D[y_old, :].argmin(axis=1)].argmin(axis=0)
        # Process targets from the largest index down; -1 is a temporary placeholder
        # marking the insertion slot while the row is removed from its old position.
        for i in np.argsort(y_new)[::-1]:
            inew = y_new[i] + 1
            order = order[:inew] + [-1] + order[inew:]
            order.remove(y_old[i])
            order[order.index(-1)] = y_old[i]

    # NOTE(review): the colorbar and y-axis labels are fixed strings, also when the
    # function is called with other row labels (e.g. on a transposed matrix).
    sns.heatmap(1-D[order,:][:,x_order], yticklabels=np.array(y_labels)[order], xticklabels=np.array(x_labels)[x_order], 
                cbar_kws=dict(label="Dice coefficient (similarity)"), cmap='rocket_r')
    plt.ylabel("NSF patterns")
    return x_order


In [33]:
fig, ax = plt.subplots(figsize=(7,4.5))

# Array form of the leaf order, so it can be indexed with a boolean mask.
nsf_order = np.array(nsf_order)
# Keep only the patterns whose smallest distance to any region is below 0.9.
subset = dist_nsf_ccf.min(axis=1) < 0.9
ccf_order = plot_distance_matrix_sorted_columns(
    D=dist_nsf_ccf[subset],
    y_labels=nsf_order[subset],
    x_labels=regions_subset,
    reorder_y=True,
)

fig.savefig("/results/nsf_ccf_similarity.pdf")

In [34]:

fig, ax = plt.subplots(figsize=(7,6))
# Transposed view: regions (in `ccf_order`) become the rows, all patterns the columns.
regions_by_patterns = dist_nsf_ccf[:,ccf_order].T
x_order = plot_distance_matrix_sorted_columns(
    D=regions_by_patterns,
    y_labels=regions_subset[ccf_order],
    x_labels=nsf_order,
)

In [35]:
# Cosine distance from every factor (rows, in `nsf_order`) to every region (columns).
# Untested alternative: normalize first with `regions /= regions.sum(axis=0)`.
dist = sklearn.metrics.pairwise_distances(factors[nsf_order], regions, metric='cosine')

fig, ax = plt.subplots(figsize=(10,6))
# The columns of `dist` follow `regions_subset` (the rows of `regions`), so the tick
# labels must be taken from it too. `th_subregions_found` is a set-derived list with
# a different order (and possibly length) and would mislabel the columns.
sns.heatmap(dist[:,ccf_order], yticklabels=nsf_order, xticklabels=regions_subset[ccf_order], ax=ax)

In [36]:
# TODO: try clustering ccf on gene-space similarity (or spatial?) to order instead (sns.clustermap?)

## plotting individual patterns and genes

In [37]:

# Shared keyword arguments for the single-pattern and single-gene plots (larger points, s=1).
kwargs = {
    'section_col': 'z_section',
    'x_col': f'x_{coords}',
    'y_col': f'y_{coords}',
    's': 1,
    'shape_palette': 'dark_outline',
    'boundary_img': ccf_boundaries,
}

In [38]:
# Notes on pattern indices of interest:
# MD
# n=27, 7, 3 (or 0 for center)
# NOTE(review): this value is overwritten by the `for n in ...` loops of later cells.
n=7
# VM etc
# little gene overlap
# n=20


### AV

In [39]:
# AV: show patterns 22 and 24 on section 8.0.
av_overlay_opts = dict(categorical=False, point_palette='viridis', legend=None)
for n in (22, 24):
    cplots.plot_ccf_overlay(
        obs, ccf_polygons,
        point_hue=f'nsf{n}',
        sections=[8.0],
        **av_overlay_opts,
        **kwargs,
    )

In [40]:

# Stand-alone figure of pattern nsf24 on section 8.0, saved for the AV panel.
av_pattern_pdf = "/results/nsf_pattern_av.pdf"
fig, ax = plt.subplots(figsize=(5,5))
cplots.plot_expression_ccf_section(
    adata.obs, "nsf24", ccf_polygons, 8.0,
    ax=ax,
    colorbar=False,
    set_lims=[2.8, 5.8, 7, 4],
    **kwargs,
)
fig.savefig(av_pattern_pdf, transparent=True)

In [41]:

# Per-gene loading difference, nsf22 minus nsf24.
diff_genes = adata.var["nsf22"] - adata.var["nsf24"]

# Negate so that the genes loading more on nsf24 than on nsf22 sort to the top.
toward_nsf24 = -diff_genes
toward_nsf24.sort_values(ascending=False).head()

In [42]:
# Plot the two genes whose loading on nsf24 most exceeds their loading on nsf22.
cols = (-diff_genes).sort_values(ascending=False).head(2).index
for col in cols:
    fig, ax = plt.subplots(figsize=(5,5))
    # X was stored as log2(1 + value) when the data was loaded, so both steps are
    # undone here; 2**X alone would plot value + 1 under the "CPM" label.
    adata.obs[col] = 2**adata[:,col].X.toarray().squeeze() - 1
    cplots.plot_expression_ccf_section(adata.obs, col, ccf_polygons, 8.0,
                                       colorbar=True, set_lims=[2.8, 5.8, 7, 4], label="CPM",
                                       ax=ax, **kwargs)
    fig.savefig(f"/results/nsf_genes_AV_{col}.pdf", transparent=True)

In [43]:

# Same loading difference (nsf22 minus nsf24), now read in the other direction:
# the genes loading more on nsf22 than on nsf24 come first.
diff_genes = adata.var["nsf22"] - adata.var["nsf24"]
toward_nsf22 = diff_genes.sort_values(ascending=False)
toward_nsf22.head()

In [44]:
# Plot the two genes whose loading on nsf22 most exceeds their loading on nsf24.
cols = diff_genes.sort_values(ascending=False).head(2).index
for col in cols:
    fig, ax = plt.subplots(figsize=(5,5))
    # X was stored as log2(1 + value) when the data was loaded, so both steps are
    # undone here; 2**X alone would plot value + 1 under the "CPM" label.
    adata.obs[col] = 2**adata[:,col].X.toarray().squeeze() - 1
    cplots.plot_expression_ccf_section(adata.obs, col, ccf_polygons, 8.0,
                                       colorbar=True, set_lims=[2.8, 5.8, 7, 4], label="CPM",
                                       ax=ax, **kwargs)
    fig.savefig(f"/results/nsf_genes_AV_{col}.pdf", transparent=True)

### MD

In [45]:
# MD: show patterns 7, 27 and 23 on the example section.
md_overlay_opts = dict(categorical=False, point_palette='viridis', legend=None)
for n in (7, 27, 23):
    cplots.plot_ccf_overlay(
        obs, ccf_polygons,
        point_hue=f'nsf{n}',
        sections=[section],
        **md_overlay_opts,
        **kwargs,
    )

In [46]:

# Stand-alone figure of pattern nsf23 on the example section, saved for the MD panel.
md_pattern_pdf = "/results/nsf_pattern_md.pdf"
fig, ax = plt.subplots(figsize=(5,5))
cplots.plot_expression_ccf_section(
    adata.obs, "nsf23", ccf_polygons, section,
    ax=ax,
    colorbar=False,
    set_lims=[2.8, 5.8, 7, 4],
    **kwargs,
)
fig.savefig(md_pattern_pdf, transparent=True)

In [47]:

# Per-gene loading on nsf23 with the loadings on nsf27 and nsf7 subtracted;
# the genes most specific to nsf23 come first.
diff_genes = adata.var["nsf23"] - adata.var["nsf27"] - adata.var["nsf7"]
nsf23_specific = diff_genes.sort_values(ascending=False)
nsf23_specific.head()

In [48]:
# Reload the plotting module again to pick up source edits (development convenience).
# NOTE(review): relies on `reload` having been imported by an earlier cell.
reload(cplots)

In [63]:
# Plot the three genes most specific to nsf23 according to `diff_genes`.
cols = diff_genes.sort_values(ascending=False).head(3).index
for col in cols:
    fig, ax = plt.subplots(figsize=(5,5))
    # X was stored as log2(1 + value) when the data was loaded, so both steps are
    # undone here; 2**X alone would plot value + 1 under the "CPM" label.
    adata.obs[col] = 2**adata[:,col].X.toarray().squeeze() - 1
    cplots.plot_expression_ccf_section(adata.obs, col, ccf_polygons, section,
                                       colorbar=True, set_lims=[2.8, 5.8, 7, 4], label="CPM",
                                       ax=ax, **kwargs)
    # fig.savefig(f"/results/nsf_genes_{col}.pdf", transparent=True)

## gene plots from loadings

Caution: gene plots derived from the loadings do not always look like the corresponding factors!

In [50]:
# NOTE(review): this rebinds `loadings` WITHOUT the transpose used where the name is
# first defined, so from here on its rows are genes and its columns are factors.
# Earlier cells re-run after this one will see the new orientation.
loadings = adata.var[nsf_cols].to_numpy()
# Projection weights: loadings @ pinv(loadings.T @ loadings).
lin_weights = loadings @ np.linalg.pinv(loadings.T @ loadings)
# feat @ loadings.T = X, so feat ~= X @ lin_weights
# Two per-cell projections of X onto the factors, stored in obsm: the plain
# loading-weighted sums, and the linear estimate based on `lin_weights`.
adata.obsm['genes_on_loadings'] = adata.X @ loadings
adata.obsm['genes_linear_projection'] = adata.X @ lin_weights

In [51]:
# Number of patterns. Assumes `loadings` is genes x patterns here, i.e. that the
# cell computing `lin_weights` has already run -- TODO confirm run order.
N = loadings.shape[1]
# normalize by gene to compare patterns by gene?
# loadings_norm = loadings/loadings.sum(axis=1, keepdims=True)
loadings_norm = loadings

# For each pattern: the gene whose loading most exceeds its largest loading on any
# other pattern (`top_gene`), and the size of that margin (`gene_prominence`).
top_gene = np.zeros(N, dtype=int)
gene_prominence = np.zeros(N)
for n in range(N):
    other_patterns = np.delete(loadings_norm, n, axis=1)
    pattern_prominence = loadings_norm[:,n] - other_patterns.max(axis=1)
    n_gene = pattern_prominence.argmax()
    top_gene[n], gene_prominence[n] = n_gene, pattern_prominence[n_gene]
    # Expression of that single gene with the log2 undone, stored per cell.
    obs[f"nsf{n}_1gene"] = 2**np.array(adata.X[:,n_gene])

In [52]:
# Indices of the patterns whose top gene has a prominence above 0.1.
np.nonzero(gene_prominence>0.1)

In [53]:
# For those patterns (prominence above 0.1): is the most prominent gene also the
# gene with the largest loading on the pattern?
(top_gene == loadings.argmax(axis=0))[np.nonzero(gene_prominence>0.1)]

In [54]:

# Shared keyword arguments for the gene-based overlays below (points drawn with s=1).
kwargs = {
    'section_col': 'z_section',
    'x_col': f'x_{coords}',
    'y_col': f'y_{coords}',
    's': 1,
    'shape_palette': 'dark_outline',
    'boundary_img': ccf_boundaries,
}

### AV

In [55]:
# AV: for patterns 22 and 24, print the name of the pattern's top gene and show that
# gene's per-cell column on section 8.0.
av_gene_opts = dict(categorical=False, point_palette='Blues', legend=None)
for n in (22, 24):
    print(adata.var_names[top_gene[n]])
    cplots.plot_ccf_overlay(
        obs, ccf_polygons,
        point_hue=f'nsf{n}_1gene',
        sections=[8.0],
        **av_gene_opts,
        **kwargs,
    )

### MD

In [56]:
# MD: for patterns 7, 27 and 23, print the name of the pattern's top gene and show
# that gene's per-cell column on the example section.
md_gene_opts = dict(categorical=False, point_palette='Blues', legend=None)
for n in (7, 27, 23):
    print(adata.var_names[top_gene[n]])
    cplots.plot_ccf_overlay(
        obs, ccf_polygons,
        point_hue=f'nsf{n}_1gene',
        sections=[section],
        **md_gene_opts,
        **kwargs,
    )

In [57]:
# can't plot with obsm currently

# cplots.plot_ccf_overlay(obs, ccf_polygons, categorical=False,
#                         point_hue=adata.obsm['genes_on_loadings'][:,0], 
#                         sections=sections_3,
#                         shape_palette='dark_outline', 
#                         point_palette='viridis', legend=None, 
#                         section_col=section_col, 
#                         x_col=x_coord_col, y_col=y_coord_col, s=3,
#                         boundary_img=ccf_boundaries);

In [58]:

# For every pattern, project the expression matrix onto (a) all of the pattern's gene
# loadings, (b) only its 3 and (c) only its 2 largest loadings, and (d) the linear
# projection weights of the pattern.
for i, x in enumerate(nsf_cols):
    # Work on a copy: the entries zeroed below must not write through to adata.var.
    weights = adata.var[x].to_numpy().copy()[:, None]
    obs[x+"_allgenes"] = adata.X @ weights
    # Rank the genes by loading along the gene axis. Calling argsort() on the
    # (genes, 1) column would sort along the length-1 last axis and return all zeros.
    by_loading = weights[:, 0].argsort()
    # Zero everything except the 3 largest loadings ...
    weights[by_loading[:-3]] = 0
    obs[x+"_3genes"] = adata.X @ weights
    # ... then everything except the 2 largest.
    weights[by_loading[:-2]] = 0
    obs[x+"_2genes"] = adata.X @ weights
    obs[x+"_linear_genes"] = adata.X @ lin_weights[:,[i]]


In [59]:
# MD patterns, reconstructed from genes with the linear projection weights.
for hue_col in [f'nsf{n}_linear_genes' for n in (7, 27, 23)]:
    cplots.plot_ccf_overlay(
        obs, ccf_polygons,
        categorical=False,
        point_hue=hue_col,
        sections=[section],
        point_palette='viridis',
        legend=None,
        **kwargs,
    )

In [60]:
# MD patterns, reconstructed from all genes weighted by the pattern's loadings.
for hue_col in [f'nsf{n}_allgenes' for n in (7, 27, 23)]:
    cplots.plot_ccf_overlay(
        obs, ccf_polygons,
        categorical=False,
        point_hue=hue_col,
        sections=[section],
        point_palette='viridis',
        legend=None,
        **kwargs,
    )

In [61]:
# MD patterns, using the "_3genes" projection columns.
for hue_col in [f'nsf{n}_3genes' for n in (7, 27, 23)]:
    cplots.plot_ccf_overlay(
        obs, ccf_polygons,
        categorical=False,
        point_hue=hue_col,
        sections=[section],
        point_palette='viridis',
        legend=None,
        **kwargs,
    )

In [62]:
# MD patterns, using the "_2genes" projection columns.
for hue_col in [f'nsf{n}_2genes' for n in (7, 27, 23)]:
    cplots.plot_ccf_overlay(
        obs, ccf_polygons,
        categorical=False,
        point_hue=hue_col,
        sections=[section],
        point_palette='viridis',
        legend=None,
        **kwargs,
    )