In [1]:
import anndata as ad
import matplotlib.pyplot as plt
import numpy as np

from thalamus_merfish_analysis import ccf_plots as cplots
from thalamus_merfish_analysis import ccf_images as cimg
from thalamus_merfish_analysis import abc_load as abc
from thalamus_merfish_analysis import distance_metrics as dm
# Inline figures (exported form of the `%matplotlib inline` line magic).
get_ipython().run_line_magic('matplotlib', 'inline')
# Default region set used by ccf_plots: the thalamus names provided by abc_load.
cplots.CCF_REGIONS_DEFAULT = abc.get_thalamus_names()

In [2]:
# Load NSF results and log-normalize expression:
#   X <- log2(1 + 1e3 * counts / per-cell total counts)
adata = ad.read_zarr("/root/capsule/data/nsf_2000_adata/nsf_2000_adata.zarr")
# Densify once (previously .toarray() was called twice, building two dense copies).
counts = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)
totals = counts.sum(axis=1, keepdims=True)
scaled = counts * 1e3
# Cells with zero total counts map to 0 instead of producing 0/0 = NaN.
scaled = np.divide(scaled, totals, out=np.zeros_like(scaled), where=totals > 0)
adata.X = np.log2(1 + scaled)
del counts, totals, scaled

In [3]:
# Restrict to cells within the thalamus coordinate bounds (abc_load helper).
adata = abc.filter_by_thalamus_coords(adata)
# Independent copy of the per-cell metadata; later cells add columns to `obs`
# without modifying adata.obs.
obs = adata.obs.copy()

In [4]:
# Select the CCF parcellation column and the matching coordinate system:
# realigned parcellation <-> "section" coords, otherwise "reconstructed" coords.
realigned = False
coords = 'section' if realigned else 'reconstructed'
ccf_label = 'parcellation_structure_realigned' if realigned else 'parcellation_structure'

In [5]:
# obs column names for the chosen coordinate system ('x_<coords>', 'y_<coords>', 'z_<coords>').
x_col, y_col, section_col = (f"{axis}_{coords}" for axis in "xyz")

# Every section position present in the data, ascending.
sections_all = sorted(obs[section_col].unique())

In [6]:
nsf_cols = [f"nsf{i}" for i in range(30)]
# normalized or scaling by totals?
# metrics changed: cosine
# metrics unchanged: braycurtis
# factors = obs[nsf_cols].values * obs['nsf_tot'].values[:,None]
factors = obs[nsf_cols].values.T
loadings = adata.var[nsf_cols].values.T

## distance metrics

In [7]:
# Are per-cell factor totals correlated with per-cell expression totals?
# NOTE(review): adata.X is already log2(1 + scaled counts) at this point (load
# cell), so this row sum is not the raw count total — interpret with care.
np.corrcoef(obs['nsf_tot'].values,
            adata.X.sum(axis=1).squeeze())

In [8]:
# Per-cell sum over all patterns (how strongly each cell is assigned to factors).
np.sum(factors, axis=0)

In [9]:
np.mean(np.sum(factors, axis=1))  # average per-pattern total over cells

In [10]:
# Per-pattern loading totals (do the factors behave like distributions over genes?).
np.sum(loadings, axis=1)

## Pattern-pattern (NSF x NSF) distance clustermaps

In [11]:
from sklearn.metrics import pairwise_distances

# Bray-Curtis distance between NSF patterns (rows of `factors`). The clustering
# order derived here (`nsf_order`) is reused by the heatmaps that follow.
dist = pairwise_distances(factors, metric='braycurtis')
nsf_order = dm.order_distances_by_clustering(dist)

fig = dm.plot_ordered_similarity_heatmap(
    dist,
    nsf_order,
    label='similarity (Bray-Curtis)',
)

In [12]:
# Scaling factors by per-cell nsf_tot left Bray-Curtis unchanged:
# y = ssd.pdist(factors * obs['nsf_tot'].values[None,:], metric='braycurtis')
# Rescaling each pattern to sum to 1 instead gives something a bit closer to cosine.
dist_norm = pairwise_distances(
    factors / factors.sum(axis=1, keepdims=True),
    metric='braycurtis',
)
fig = dm.plot_ordered_similarity_heatmap(
    dist_norm, nsf_order, label='similarity (Bray-Curtis)')

In [13]:
# Cosine distance between the raw factor rows, shown in the same nsf_order.
dist_cosine = pairwise_distances(factors, metric='cosine')
fig = dm.plot_ordered_similarity_heatmap(
    dist_cosine, nsf_order, label='similarity (cosine)')

### gene loadings distance

In [14]:
# Bray-Curtis distance between NSF patterns computed on their gene loadings
# (rows of `loadings`), displayed in the factor-based nsf_order.
# Optional per-gene normalization (not applied):
#   X = loadings / loadings.sum(axis=0, keepdims=True)
X = loadings
dist_genes = pairwise_distances(X, metric='braycurtis')  # alternative: metric='cosine'

fig = dm.plot_ordered_similarity_heatmap(dist_genes,
                                         nsf_order,
                                         label='similarity (Bray-Curtis)')

In [15]:
# Order patterns by clustering their gene-loading distances, and draw the
# triangular similarity heatmap in that order.
order = dm.order_distances_by_clustering(dist_genes)
dm.plot_ordered_similarity_heatmap(
    dist_genes,
    order,
    label='similarity (Bray-Curtis)',
    triangular=True,
)
plt.gca().set_title("Gene pattern overlap")
plt.gca().axis('image')
plt.gca().set_xlabel("NSF pattern")
plt.savefig("/results/nsf_pattern_gene_similarity.pdf", transparent=True)

## CCF overlap

In [16]:
# Disabled exploration, kept for reference (uncomment to run): builds one boolean
# column per thalamic subregion in `obs`, then lists subregions with no cells
# (boolean column mean == 0).
# th_names = abc.get_thalamus_names(level='structure')
# th_subregion_names = list(set(th_names).difference(['TH-unassigned']))
# ccf_label = 'parcellation_structure_realigned'
# obs[ccf_label] = obs[ccf_label].str.replace('-unassigned','')

# for x in th_subregion_names:
#     obs[x] = obs[ccf_label] == x
# th_subregions_found = list(set(th_subregion_names).intersection(obs[ccf_label].unique()))
# obs[th_subregion_names].mean(axis=0).loc[lambda x: x==0]

In [17]:
# Thalamic regions used for the region x pattern overlap heatmaps (NumPy string array).
# Curation notes: VPM/VPL might be combined ("VP combine?"); MH, LH and RT are
# additional regions; ZI is left out.
regions_subset = np.array(
    "AD AV LD LGd VPM VPL AM VPMpc MD LP PO IAD VAL VM RE CL PF CM PCN SPA IMD PVT".split()
    + ["MH", "LH", "RT"]
)

In [18]:
# Two fixed row orders for the region x pattern heatmaps: an atlas (CCF) order,
# and the order labelled "ot_clustering".
regions_ccf_order = np.array(
    "VAL VM VPL VPM VPMpc SPA LGd LP PO AV AM AD IAD LD IMD MD PVT "
    "RE CM PCN CL PF MH LH RT".split()
)
regions_ot_clustering_order = np.array(
    "PF VPMpc RE CM IMD MD CL PVT SPA AD AV AM IAD PCN VAL "
    "VM PO LP LD VPM VPL LGd MH LH RT".split()
)

In [19]:
# Final region row order (plain list), used by the "Spatial pattern overlap" figure.
regions_final = (
    "AD AV AM IAD LD VPM VPL LGd MD CL CM IMD PO LP VAL VM "
    "RE PF VPMpc PCN SPA PVT MH LH RT"
).split()

In [20]:
# Region x pattern Dice overlap. Rows keep the fixed regions_final order
# (thalamoseq fig 1); only the NSF patterns (columns) are reordered.
dist, y_names, x_names = dm.cluster_distances_from_labels(
    obs,
    y_col=ccf_label,
    x_col=nsf_cols,
    y_names=regions_final,
    x_names=range(30),
)
y_order, x_order = dm.order_distances_x_to_y(dist, reorder_y=False)

dm.plot_ordered_similarity_heatmap(
    dist,
    label="Dice coefficient",
    y_order=y_order,
    x_order=x_order,
    y_names=y_names,
    x_names=x_names,
)
plt.gca().axis('image')
plt.gca().set_title("Spatial pattern overlap")
plt.gca().set_xlabel("NSF pattern")
plt.savefig("/results/nsf_ccf_similarity_heatmap_ccf_order.pdf", transparent=True)

In [21]:
# Region x pattern Dice overlap on `regions_subset`, with the region (row) order
# chosen by dm.order_distances_x_to_y (reorder_y=True) rather than kept fixed.
dist, y_names, x_names = dm.cluster_distances_from_labels(
    obs, y_col=ccf_label, x_col=nsf_cols,
    y_names=regions_subset, x_names=range(30))

y_order, x_order = dm.order_distances_x_to_y(dist, reorder_y=True)

dm.plot_ordered_similarity_heatmap(dist,
                       y_order=y_order, x_order=x_order,
                       y_names=y_names, x_names=x_names)

# Distinct filename: nsf_ccf_similarity_heatmap_ccf_order.pdf is already written
# by the fixed-order (regions_final) figure and would otherwise be overwritten.
plt.savefig("/results/nsf_ccf_similarity_heatmap_reordered.pdf", transparent=True)

In [22]:
# Region x pattern Dice overlap with rows fixed in the "ot_clustering" region order.
dist, y_names, x_names = dm.cluster_distances_from_labels(
    obs,
    y_col=ccf_label,
    x_col=nsf_cols,
    y_names=regions_ot_clustering_order,
    x_names=range(30),
)
y_order, x_order = dm.order_distances_x_to_y(dist, reorder_y=False)
dm.plot_ordered_similarity_heatmap(
    dist, y_order=y_order, x_order=x_order, y_names=y_names, x_names=x_names)
plt.savefig("/results/nsf_ccf_similarity_heatmap_ot_order.pdf", transparent=True)

In [23]:
# Region x pattern Dice overlap with rows fixed in the atlas (CCF) region order.
dist, y_names, x_names = dm.cluster_distances_from_labels(
    obs,
    y_col=ccf_label,
    x_col=nsf_cols,
    y_names=regions_ccf_order,
    x_names=range(30),
)
y_order, x_order = dm.order_distances_x_to_y(dist, reorder_y=False)
dm.plot_ordered_similarity_heatmap(
    dist, y_order=y_order, x_order=x_order, y_names=y_names, x_names=x_names)
plt.savefig("/results/nsf_ccf_similarity_heatmap_atlas_order.pdf", transparent=True)

In [24]:

# jaccard
# d = dist_nsf_ccf/(2-dist_nsf_ccf)
# NOTE(review): S/(2-S) maps a Dice *similarity* S to a Jaccard similarity.
# If dist_nsf_ccf holds Dice distances d = 1 - S (the heatmaps display it as
# "similarity"), the Jaccard distance is 2*d/(1+d) instead — confirm before use.

### ordered by NSF pattern clustering

In [25]:
# Pattern x region Dice distances restricted to `regions_subset`, with patterns in
# the NSF clustering order. Recomputed here rather than reusing `dist` from the
# cell above: that `dist` was built on regions_ccf_order (same regions, different
# order), so labelling its columns with regions_subset mislabelled every region.
dist, y_names, x_names = dm.cluster_distances_from_labels(
    obs, y_col=ccf_label, x_col=nsf_cols,
    y_names=regions_subset, x_names=range(30))

# Transpose to pattern x region and put patterns in nsf_order.
dist_nsf_ccf = dist.T[nsf_order, :]
# Keep patterns whose closest region has distance < 0.9.
subset = dist_nsf_ccf.min(axis=1) < 0.9

y_order, ccf_order = dm.order_distances_x_to_y(dist_nsf_ccf[subset], reorder_y=True)
fig = dm.plot_ordered_similarity_heatmap(
    dist_nsf_ccf[subset],
    y_order=y_order,
    x_order=ccf_order,
    y_names=nsf_order[subset],
    # region labels returned with `dist`, so they match its rows exactly
    x_names=y_names,
)

In [26]:
# Region x pattern view of dist_nsf_ccf, with regions in `ccf_order` from the
# previous cell. Stored under a new name so the region x pattern `dist` used by
# the previous cell is not overwritten (re-running that cell stays valid).
dist_ccf_nsf_sorted = dist_nsf_ccf[:, ccf_order].T
y_order, x_order = dm.order_distances_x_to_y(dist_ccf_nsf_sorted, reorder_y=False)

fig = dm.plot_ordered_similarity_heatmap(
            dist_ccf_nsf_sorted,
            y_order=y_order,
            x_order=x_order,
            # labels come from the same call that produced `dist`; a hard-coded
            # regions_subset only matches when `dist` was built from regions_subset
            y_names=np.asarray(y_names)[ccf_order],
            x_names=nsf_order)

In [27]:
# TODO: try clustering ccf on gene-space similarity (or spatial?) to order instead (sns.clustermap?)
#       (the region orders above come from dm.order_distances_x_to_y on Dice overlap)

## Spatial plots of NSF patterns over a CCF section

In [28]:

# Resampled CCF label volume, left hemisphere only, realigned per the `realigned` flag.
ccf_images = abc.get_ccf_labels_image(
    resampled=True,
    realigned=realigned,
    subset_to_left_hemi=True,
)
# Per-section region boundaries (return_edges=True) for outline overlays.
ccf_boundaries = cimg.sectionwise_label_erosion(
    ccf_images, distance_px=1, fill_val=0, return_edges=True)

In [29]:
section = 7.2
# Patterns whose maximum value within this section exceeds 0.25, strongest first.
# np.isclose instead of == so float rounding in stored z coordinates cannot
# silently drop the whole section.
in_section = np.isclose(obs[section_col].to_numpy(dtype=float), section)
top_patterns = (
    obs.loc[in_section, nsf_cols]
    .max(axis=0)
    .loc[lambda x: x > 0.25]
    .sort_values(ascending=False)
)
top_patterns

In [30]:
# Overlay up to ten of the strongest patterns (names sorted) as colour channels.
patterns = top_patterns.index[:10].sort_values()
obs_plot = obs.copy()
# obs_plot[patterns] /= obs_plot[patterns].values.sum(axis=1, keepdims=True)
fig = cplots.plot_multichannel_overlay(
    obs_plot, patterns, section,
    section_col=section_col, normalize_by=None,
    ccf_images=ccf_images, boundary_img=ccf_boundaries,
)

In [31]:
# Tile the ten strongest patterns of `section` (from top_patterns) in a 2 x 5 grid.
kwargs = dict(
    # Follow the chosen coordinate system like x/y do; the previous hard-coded
    # "z_section" mixed coordinate systems whenever realigned=False.
    section_col=section_col,
    x_col="x_" + coords,
    y_col="y_" + coords,
    point_size=0.5,
    boundary_img=ccf_boundaries,
)
cols = top_patterns.index[:10]

fig, axes = plt.subplots(2, 5, figsize=(10, 5))
axes = axes.ravel()

for i, col in enumerate(cols):
    cplots.plot_section_overlay(
        obs,
        ccf_images=ccf_images,
        section=section,
        point_hue=col,
        point_palette="Greens",
        custom_xy_lims=cplots.XY_LIMS_TH_LEFT_HEMI,
        ax=axes[i],
        legend=None,
        **kwargs,
    )

## plotting individual patterns and genes

In [32]:

# Shared plotting options for the single-pattern / single-gene CCF overlays below.
kwargs = dict(
    # Follow the chosen coordinate system like x/y do; the previous hard-coded
    # 'z_section' mixed coordinate systems whenever realigned=False.
    section_col=section_col,
    x_col='x_' + coords,
    y_col='y_' + coords,
    point_size=2,
    boundary_img=ccf_boundaries,
    custom_xy_lims=cplots.XY_LIMS_TH_LEFT_HEMI,
)

In [33]:
# Candidate patterns noted during exploration:
#   MD: 27, 7, 23 (or 0 for the center)
#   VM etc.: 20 (little gene overlap)
n = 7


### AV

In [34]:
for n in (22, 24):  # AV patterns
    fig = cplots.plot_ccf_overlay(
        obs, ccf_images,
        categorical=False,
        point_hue='nsf%d' % n,
        sections=[8.0],
        point_palette='viridis',
        legend=None,
        **kwargs,
    )

In [35]:

diff_genes = adata.var["nsf22"] - adata.var["nsf24"]
(-diff_genes).sort_values(ascending=False).head()

In [36]:
# Plot and save the two genes with the largest nsf24-over-nsf22 loading on the AV section.
cols = (-diff_genes).sort_values(ascending=False).head(2).index
for col in cols:
    expr = adata[:, col].X
    # adata.X was densified in the load cell, so this slice can be a plain ndarray
    # view with no .toarray(); only densify when it is actually sparse.
    expr = expr.toarray() if hasattr(expr, "toarray") else np.asarray(expr)
    # 2**X inverts the log2 applied at load time
    adata.obs[col] = 2**expr.squeeze()
    fig = cplots.plot_ccf_overlay(adata.obs,
                            ccf_images,
                            categorical=False,
                            point_hue=col,
                            sections=[8.0],
                            point_palette='Blues',
                            legend=None,
                            **kwargs)

    fig[0].savefig(f"/results/nsf_genes_AV_{col}.pdf", transparent=True)

In [37]:

diff_genes = adata.var["nsf22"] - adata.var["nsf24"]
(diff_genes).sort_values(ascending=False).head()

In [38]:
# Plot and save the two genes with the largest nsf22-over-nsf24 loading on the AV section.
cols = diff_genes.sort_values(ascending=False).head(2).index
for col in cols:
    expr = adata[:, col].X
    # adata.X was densified in the load cell, so this slice can be a plain ndarray
    # view with no .toarray(); only densify when it is actually sparse.
    expr = expr.toarray() if hasattr(expr, "toarray") else np.asarray(expr)
    # 2**X inverts the log2 applied at load time
    adata.obs[col] = 2**expr.squeeze()
    fig = cplots.plot_ccf_overlay(adata.obs,
                            ccf_images,
                            categorical=False,
                            point_hue=col,
                            sections=[8.0],
                            point_palette='Blues',
                            legend=None,
                            **kwargs)
    fig[0].savefig(f"/results/nsf_genes_AV_{col}.pdf", transparent=True)

### MD

In [39]:
# MD settings: region tag for output filenames and the section to plot.
region, section_MD = 'MD', 7.2

In [40]:
# Candidate MD patterns on the MD section.
for n in (7, 27, 23):
    figs = cplots.plot_ccf_overlay(
        obs, ccf_images,
        categorical=False,
        point_hue=f'nsf{n}',
        sections=[section_MD],
        point_palette='Greens',
        legend=None,
        **kwargs,
    )

In [41]:
# Save pattern nsf23 on the MD section.
n = 23
figs = cplots.plot_ccf_overlay(
    obs, ccf_images,
    categorical=False,
    point_hue=f'nsf{n}',
    sections=[section_MD],
    point_palette='Greens',
    legend=None,
    **kwargs,
)
figs[0].savefig(f"/results/nsf_pattern_nsf{n}_{region}.pdf", transparent=True)

In [42]:

diff_genes = adata.var["nsf23"] - adata.var["nsf27"] - adata.var["nsf7"]
diff_genes.sort_values(ascending=False).head()

In [43]:
# Plot and save the top three nsf23-specific genes on the MD section.
cols = diff_genes.sort_values(ascending=False).head(3).index
for col in cols:
    expr = adata[:, col].X
    # adata.X was densified in the load cell, so this slice can be a plain ndarray
    # view with no .toarray(); only densify when it is actually sparse.
    expr = expr.toarray() if hasattr(expr, "toarray") else np.asarray(expr)
    # 2**X inverts the log2 applied at load time
    adata.obs[col] = 2**expr.squeeze()
    figs = cplots.plot_ccf_overlay(
                            adata.obs,
                            ccf_images,
                            categorical=False,
                            point_hue=col,
                            sections=[section_MD],
                            point_palette='Greens',
                            legend=None,
                            **kwargs)
    figs[0].savefig(f"/results/nsf_genes_{region}_{col}.pdf", transparent=True)

## gene plots from loadings

Caution: these gene-based reconstructions do not always look like the NSF factors they are derived from!

In [44]:
# Gene-major loadings (one row per gene, one column per pattern) — note this is
# the transpose of the earlier pattern-major `loadings`.
loadings = adata.var[nsf_cols].to_numpy()
adata.obsm['genes_on_loadings'] = adata.X @ loadings
# Least-squares inverse: if feat @ loadings.T ~= X, then feat ~= X @ lin_weights.
lin_weights = loadings @ np.linalg.pinv(loadings.T @ loadings)
adata.obsm['genes_linear_projection'] = adata.X @ lin_weights

In [45]:
# For each pattern n, pick the gene whose loading on n most exceeds its largest
# loading on any other pattern ("prominence"), and store that gene's expression
# (2**X, inverting the load-time log2) as obs[f"nsf{n}_1gene"].
N = loadings.shape[1]
# normalize by gene to compare patterns by gene?
# loadings_norm = loadings/loadings.sum(axis=1, keepdims=True)
loadings_norm = loadings
top_gene, gene_prominence = np.zeros(N, dtype=int), np.zeros(N)
for n in range(N):
    pattern_prominence = (
        loadings_norm[:, n]
        - np.delete(loadings_norm, n, axis=1).max(axis=1)
    )
    top_gene[n] = n_gene = np.argmax(pattern_prominence)
    gene_prominence[n] = pattern_prominence[n_gene]
    obs[f"nsf{n}_1gene"] = 2 ** np.array(adata.X[:, n_gene])

In [46]:
(gene_prominence > 0.1).nonzero()  # patterns with a clearly dominant single gene

In [47]:
# For the prominent patterns: is the most-prominent gene also the top-loading gene?
np.equal(top_gene, loadings.argmax(axis=0))[(gene_prominence > 0.1).nonzero()]

In [48]:

# Shared plotting options for the gene-based overlays below.
kwargs = dict(
    # Follow the chosen coordinate system like x/y do; the previous hard-coded
    # 'z_section' mixed coordinate systems whenever realigned=False.
    section_col=section_col,
    x_col='x_' + coords,
    y_col='y_' + coords,
    point_size=1,
    face_palette=None,
    edge_color='grey',
    boundary_img=ccf_boundaries,
    custom_xy_lims=cplots.XY_LIMS_TH_LEFT_HEMI,
)

### AV

In [49]:
for n in (22, 24):  # AV patterns: single most-prominent gene for each
    print(adata.var_names[top_gene[n]])
    cplots.plot_ccf_overlay(
        obs, ccf_images,
        categorical=False,
        point_hue=f'nsf{n}_1gene',
        sections=[8.0],
        point_palette='Blues',
        legend=None,
        **kwargs,
    )

### MD

In [50]:
# MD patterns: single most-prominent gene for each, on the MD section
# (section_MD rather than the generic `section` from the overlay cell, so
# exploring another section there does not silently move these plots).
for n in [7, 27, 23]:
    print(adata.var_names[top_gene[n]])
    cplots.plot_ccf_overlay(obs, ccf_images, categorical=False,
                        point_hue=f'nsf{n}_1gene', sections=[section_MD],
                        point_palette='Blues', legend=None,
                        **kwargs);

In [52]:

# Per-pattern gene-based reconstructions stored in `obs`:
#   <pattern>_allgenes     : X @ full loading vector
#   <pattern>_3genes       : X @ loading vector truncated to its 3 largest genes
#   <pattern>_2genes       : X @ loading vector truncated to its 2 largest genes
#   <pattern>_linear_genes : X @ least-squares weights (lin_weights column i)
for i, x in enumerate(nsf_cols):
    # 1-D copy: the truncation below must not write back into adata.var
    # (Series.to_numpy() may return a view), and the global `loadings` and
    # `order` are left intact.
    w = adata.var[x].to_numpy(copy=True)
    obs[x + "_allgenes"] = adata.X @ w
    # argsort the 1-D vector; on the former (n_genes, 1) column, argsort ran along
    # the length-1 axis and returned all zeros, so only gene 0 was ever zeroed.
    w_order = np.argsort(w)
    w[w_order[:-3]] = 0
    obs[x + "_3genes"] = adata.X @ w
    w[w_order[:-2]] = 0
    obs[x + "_2genes"] = adata.X @ w
    obs[x + "_linear_genes"] = adata.X @ lin_weights[:, i]


In [53]:
# MD patterns reconstructed with least-squares gene weights, on the MD section
# (section_MD rather than the generic `section` from the overlay cell).
for n in [7, 27, 23]:
    cplots.plot_ccf_overlay(obs, ccf_images, categorical=False,
                        point_hue=f'nsf{n}_linear_genes', sections=[section_MD],
                        point_palette='viridis', legend=None,
                        **kwargs);

In [54]:
# MD patterns reconstructed from the full loading vector, on the MD section
# (section_MD rather than the generic `section` from the overlay cell).
for n in [7, 27, 23]:
    cplots.plot_ccf_overlay(obs, ccf_images, categorical=False,
                        point_hue=f'nsf{n}_allgenes', sections=[section_MD],
                        point_palette='viridis', legend=None,
                        **kwargs);

In [55]:
# MD patterns reconstructed from their 3 top-loading genes, on the MD section
# (section_MD rather than the generic `section` from the overlay cell).
for n in [7, 27, 23]:
    cplots.plot_ccf_overlay(obs, ccf_images, categorical=False,
                        point_hue=f'nsf{n}_3genes', sections=[section_MD],
                        point_palette='viridis', legend=None,
                        **kwargs);

In [56]:
# MD patterns reconstructed from their 2 top-loading genes, on the MD section
# (section_MD rather than the generic `section` from the overlay cell).
for n in [7, 27, 23]:
    cplots.plot_ccf_overlay(obs, ccf_images, categorical=False,
                        point_hue=f'nsf{n}_2genes', sections=[section_MD],
                        point_palette='viridis', legend=None,
                        **kwargs);