# Figure 2 | Transcriptomic diversity across the thalamus.

In [1]:
# IPython autoreload: with mode 2, imported modules are reloaded before each
# cell executes, so edits to thalamus_merfish_analysis take effect without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import seaborn as sns

import sys
sys.path.append("/code/")
from thalamus_merfish_analysis import abc_load as abc
from thalamus_merfish_analysis import ccf_images as cimg
from thalamus_merfish_analysis import ccf_plots as cplots
from thalamus_merfish_analysis import ccf_erode as cerd
from thalamus_merfish_analysis import diversity_metrics as divmet
from thalamus_merfish_analysis import diversity_plots as dplot
import thalamus_merfish_analysis.distance_metrics as dm


# Figure-wide matplotlib defaults: fonttype 42 (TrueType) keeps text in the
# saved PS/PDF panels editable; 7 pt base font size for all figures.
from matplotlib import rcParams
rcParams.update({
    'ps.fonttype': 42,
    'pdf.fonttype': 42,
    'font.size': 7,
})

import matplotlib.pyplot as plt
get_ipython().run_line_magic('matplotlib', 'inline') 

## Load thalamus data

In [3]:
# Cell metadata: the combined ABC-atlas table, and the standard thalamus
# subset loaded as metadata only (data_structure='obs').
# NOTE(review): obs_wmb is not used in the cells visible here.
obs_wmb = abc.get_combined_metadata()
obs_th = abc.load_standard_thalamus(data_structure='obs')

In [4]:
# Every spatial column used below lives in the `_reconstructed` coordinate space.
coords = '_reconstructed'
x_col, y_col, z_col = (f'{axis}{coords}' for axis in 'xyz')
section_col = 'brain_section_label'

In [5]:
# CCF structure-label image (realigned version)
ccf_images = abc.get_ccf_labels_image(realigned=True)

# Collapse substructures into their parent structures
# (e.g. AMd + AMv -> AM; LGd-co + LGd-ip + LGd-sh -> LGd)
ccf_images = cerd.merge_substructures(ccf_images, ccf_level='structure')

# Boundaries are only needed on sections that contain thalamic cells.
# Section z-positions are turned into integer image-slice indices by dividing
# by 0.2 and rounding.
sections_all = sorted(obs_th[z_col].unique())
sections_int = np.rint(np.array(sections_all) / 0.2).astype(int)

# Precompute structure edges (return_edges=True, 1 px) for those sections only
ccf_boundaries = cimg.sectionwise_label_erosion(
    ccf_images,
    distance_px=1,
    fill_val=0,
    return_edges=True,
    section_list=sections_int,
)

In [7]:
# Published ABC Atlas taxonomy color palettes, one per taxonomy level
abc_palettes = {
    tax_level: abc.get_taxonomy_palette(tax_level)
    for tax_level in ('neurotransmitter', 'class', 'subclass', 'supertype', 'cluster')
}

# Override the cluster level with the thalamus-specific palette, which tries to
# increase color contrast between spatially neighboring clusters.
abc_palettes['cluster'] = abc.get_thalamus_cluster_palette()

# CCF palette at substructure resolution
ccf_palette = abc.get_ccf_palette('substructure')
# TODO: earlier workaround for a key error, kept disabled:
# ccf_palette['ZI'] = ccf_palette.pop('ZI-unassigned')

In [8]:
# Label each cell with the CCF structure it falls in after eroding structure
# borders by distance_px=5 (50 um per the original note; 5 px is also the
# default). Presumably this limits the influence of cells near uncertain
# borders on the similarity metrics computed below.
ccf_metrics_level = "structure"
obs_erode, ccf_label_eroded = cerd.label_cells_by_eroded_ccf(
    obs_th, ccf_images, ccf_level=ccf_metrics_level, distance_px=5
)
ccf_label = ccf_label_eroded

# Section 6.6 has poor alignment between PF cell types and the PF CCF
# structure, so every cell in that section is set to 'unassigned'.
# np.isclose (not ==) so the match cannot silently select nothing if the
# section coordinate is stored as float32 or carries rounding error.
MISALIGNED_SECTION = 6.6
is_misaligned = np.isclose(obs_erode['z_section'], MISALIGNED_SECTION)
obs_erode.loc[is_misaligned, ccf_label_eroded] = 'unassigned'

In [9]:
# Set the module-level cplots.CCF_REGIONS_DEFAULT to the thalamus structure
# names (presumably used as the default region list by ccf_plots functions).
# if you reload thalamus_merfish_analysis.ccf_plots after this cell has been run, 
# cplots.CCF_REGIONS_DEFAULT will be reset to None & you'll need to re-run this cell
cplots.CCF_REGIONS_DEFAULT = abc.get_thalamus_names()

In [10]:
import os

# Output directory for every figure panel saved below. Create it up front:
# Figure.savefig raises if the target directory does not exist.
results_dir = '../../results'
os.makedirs(results_dir, exist_ok=True)

## Fig. 2A-C | Similarity heatmaps: Thalamic nuclei vs cell types

In [11]:
# Thalamic nuclei (CCF structure names) used as heatmap rows in Fig. 2A-B
regions_to_plot = (
    "AD AV AM IAD LD VPM VPL LGd "
    "MD CL CM IMD PO LP VAL VM "
    "RE PF VPMpc PCN SPA PVT MH LH RT"
).split()

In [12]:
from matplotlib.colors import BoundaryNorm, LinearSegmentedColormap

# Candidate heatmap colormap: seaborn's reversed "rocket", sampled at 256
# points, with one pure-white RGBA color prepended at the low end.
rocket_r = sns.color_palette("rocket_r", as_cmap=True)
n_colors_in_rocket = 256
colors = rocket_r(np.linspace(0, 1, n_colors_in_rocket))

white = np.array([[1, 1, 1, 1]])  # a single RGBA row
new_colors = np.concatenate((white, colors))

# Continuous version
custom_cmap = LinearSegmentedColormap.from_list("custom_rocket_r_with_white", new_colors)

# Discrete 11-color version. Its BoundaryNorm starts with a tiny [0, 1e-8) bin
# (effectively "exactly zero") followed by edges np.linspace(0.1, 1, 10).
n_bins = 11
custom_cmap_discrete = LinearSegmentedColormap.from_list(
    "custom_rocket_r_with_white", new_colors, N=n_bins
)
norm = BoundaryNorm(
    np.concatenate(([0, 0.00000001], np.linspace(0.1, 1, n_bins - 1))),
    ncolors=n_bins,
)

# Quick visual check on a 1 x 12 strip of values
data = [np.concatenate((np.array([0.0, 0.2]), np.random.random(10)))]
print(data)
plt.imshow(data, cmap=custom_cmap_discrete, norm=norm)
plt.colorbar()
plt.show()

In [13]:
# Scratch: candidate bin edges = a leading 0 followed by n_bins evenly spaced values on [0, 1]
np.concatenate(([0], np.linspace(0, 1, num=n_bins)))

In [14]:
# Scratch: a sample of uniform [0, 1) values like those used in the colormap previews
np.random.random(size=10)

In [15]:

# Candidate purple heatmap colormaps built from seaborn's cubehelix palette.
# Shared cubehelix parameters (light=1.0 makes the light end (near-)white).
purple_cubehelix_kws = dict(start=-0.2, rot=0.6, dark=0.1, light=1.0)

# 256-level continuous colormap, with a palette-strip preview
cmap_purple_256 = sns.cubehelix_palette(256, as_cmap=True, **purple_cubehelix_kws)
sns.palplot(sns.cubehelix_palette(256, **purple_cubehelix_kws))
plt.show()

# 11-level colormap, with a palette-strip preview
cmap_purple_11 = sns.cubehelix_palette(11, as_cmap=True, **purple_cubehelix_kws)
sns.palplot(sns.cubehelix_palette(11, **purple_cubehelix_kws))
plt.show()

# Quick visual check of cmap_purple_11 on a 1 x 12 strip of values
data = [np.concatenate([np.array([0.0, 0.2]), np.random.random(10)])]
print(data)
plt.imshow(data, cmap=cmap_purple_11)
plt.colorbar()
plt.show()

# Sample the 256-level purple colormap and rebuild it as a
# LinearSegmentedColormap so it can be discretized below.
colors = cmap_purple_256(np.linspace(0, 1, 256))
custom_cmap = LinearSegmentedColormap.from_list("custom_purple_256", colors)

# Discrete 16-color version with 16 equal-width bins over [0, 0.8]
# (the same upper limit as the heatmap vmax).
n_bins = 16
custom_purple_discrete = LinearSegmentedColormap.from_list(
    "custom_purple_256", colors, N=n_bins
)
norm = BoundaryNorm(np.linspace(0, 0.8, n_bins + 1), ncolors=n_bins)

# Quick visual check of the discrete colormap; 0.0/0.05/0.1 probe the lowest bins
data = [np.concatenate([np.array([0.0, 0.05, 0.1]), np.random.random(10)])]
print(data)
plt.imshow(data, cmap=custom_purple_discrete, norm=norm)
plt.colorbar()
plt.show()

In [16]:
# Plot the luminance of each of the 256 colors of cmap_purple_256
# (e.g. to check that it changes monotonically along the colormap).
cmap_values = cmap_purple_256(np.linspace(0, 1, 256))

# Rec. 709 luma weights applied to the colormap's RGB channels.
# NOTE: this is a luminance proxy on gamma-encoded RGB, not CIE L* lightness.
lightness = 0.2126 * cmap_values[:, 0] + 0.7152 * cmap_values[:, 1] + 0.0722 * cmap_values[:, 2]

# A single scatter call draws every point in its own colormap color;
# s is in points^2, so 36 matches the former markersize=6.
# No legend: no artist carries a label, so plt.legend() would only warn.
plt.scatter(np.arange(256), lightness, c=cmap_values, s=36)
plt.xlabel("Colormap Index")
plt.ylabel("Lightness (Perceptual)")
plt.title("Perceptual Lightness of cmap_purple_256")
plt.show()


In [30]:
# Colormap for the Fig. 2A-C similarity heatmaps.
# Other options tried: 'viridis', 'rocket_r', custom_cmap, custom_cmap_discrete
cmap = cmap_purple_256

# Shared color range for all three heatmaps
vmin, vmax = 0, 0.8

### Fig. 2A | Similarity heatmap: Thalamic nuclei vs subclasses

In [31]:
# Fig. 2A: similarity between thalamic nuclei (rows, eroded CCF labels) and
# cell-type subclasses (columns).
taxonomy_level = 'subclass'
dist, y_names, x_names = dm.cluster_distances_from_labels(
    obs_erode,
    y_col=ccf_label,
    x_col=taxonomy_level,
    y_names=sorted(regions_to_plot),
    x_names=sorted(obs_erode[taxonomy_level].unique()),
)

# Rows are passed in alphabetical order and not reordered (reorder_y=False);
# columns are ordered by order_distances_x_to_y.
# NOTE(review): the output file names say "alpha-reorderedY" although rows are
# not reordered for this panel; names left unchanged for downstream use.
y_order, x_order = dm.order_distances_x_to_y(dist, reorder_y=False)

# Column tick labels: first 3 characters of each subclass name
# (presumably the numeric subclass ID; confirm against the taxonomy labels).
x_names_ids_only = [label[:3] for label in x_names]

# Discrete alternative: cmap=custom_purple_discrete, norm=norm
fig_hm_subclass = dm.plot_ordered_similarity_heatmap(
    dist,
    y_order=y_order,
    x_order=x_order,
    y_names=y_names,
    x_names=x_names_ids_only,
    cmap=cmap,
    vmin=vmin,
    vmax=vmax,
)
fig_hm_subclass.set_size_inches(1.8, 3)
ax_hm_subclass = fig_hm_subclass.gca()
ax_hm_subclass.set_xlabel(taxonomy_level)
ax_hm_subclass.set_ylabel('CCF Structure')

fig2a_stem = f'{results_dir}/fig2A_similarity_heatmap_nuclei_vs_{taxonomy_level}_alpha-reorderedY'
fig_hm_subclass.savefig(f'{fig2a_stem}.pdf', transparent=True, bbox_inches='tight')
fig_hm_subclass.savefig(f'{fig2a_stem}.png', dpi=600, transparent=True, bbox_inches='tight')

# Separate strip of colored x-axis labels (taxonomy palette), in the same
# column order as the heatmap.
fig_xlabel_colors = dm.plot_heatmap_xlabel_colors(
    x_names, x_order, abc_palettes[taxonomy_level]
)
fig_xlabel_colors.set_size_inches(1.8, 3)
fig_xlabel_colors.savefig(
    f'{results_dir}/fig2A_similarity_heatmap_nuclei_vs_{taxonomy_level}_xaxis_colors.pdf',
    transparent=True,
    bbox_inches='tight',
)

### Fig. 2B | Similarity heatmap: Thalamic nuclei vs supertypes

In [19]:
# Fig. 2B: similarity between thalamic nuclei (rows, eroded CCF labels) and
# cell-type supertypes (columns).
taxonomy_level = 'supertype'
dist, y_names, x_names = dm.cluster_distances_from_labels(
    obs_erode,
    y_col=ccf_label,
    x_col=taxonomy_level,
    y_names=sorted(regions_to_plot),
    x_names=sorted(obs_erode[taxonomy_level].unique()),
)

# Both rows and columns are reordered here (reorder_y=True).
y_order, x_order = dm.order_distances_x_to_y(dist, reorder_y=True)

# Column tick labels: first 4 characters of each supertype name
# (presumably the numeric supertype ID; confirm against the taxonomy labels).
x_names_ids_only = [name[0:4] for name in x_names]

# Discrete alternative: cmap=custom_purple_discrete, norm=norm
fig_hm_supertype = dm.plot_ordered_similarity_heatmap(
    dist,
    y_order=y_order,
    x_order=x_order,
    y_names=y_names,
    x_names=x_names_ids_only,
    cmap=cmap,
    vmin=vmin,
    vmax=vmax,
)
fig_hm_supertype.set_size_inches(5.8, 3)
fig_hm_supertype.gca().set_xlabel(taxonomy_level)
fig_hm_supertype.gca().set_ylabel('CCF Structure')

fig_hm_supertype.savefig(f'{results_dir}/fig2B_similarity_heatmap_nuclei_vs_{taxonomy_level}_alpha-reorderedY.pdf',
                         transparent=True, bbox_inches='tight')
fig_hm_supertype.savefig(f'{results_dir}/fig2B_similarity_heatmap_nuclei_vs_{taxonomy_level}_alpha-reorderedY.png',
                         dpi=600, transparent=True, bbox_inches='tight')

# Separate strip of colored x-axis labels (taxonomy palette), in the same
# column order as the heatmap.
fig_xlabel_colors = dm.plot_heatmap_xlabel_colors(x_names,
                                                  x_order,
                                                  abc_palettes[taxonomy_level])
fig_xlabel_colors.set_size_inches(5.7, 3)
fig_xlabel_colors.savefig(f'{results_dir}/fig2B_similarity_heatmap_nuclei_vs_{taxonomy_level}_xaxis_colors.pdf',
                          transparent=True, bbox_inches='tight')

### Fig. 2C | Similarity heatmap: Thalamic nuclei vs clusters

In [20]:
# Hand-picked row order for the cluster heatmap (the same 25 nuclei as
# regions_to_plot, arranged manually; rows are not reordered below).
regions_to_plot_manual_clusters = (
    "AD AM AV CM IAD IMD "
    "LD LGd LP PVT MD CL "
    "PCN VAL VM VPM VPMpc "
    "PO VPL RE "
    "LH MH PF RT SPA"
).split()

# Fig. 2C: similarity between thalamic nuclei (rows, eroded CCF labels) and
# cell-type clusters (columns).
taxonomy_level = 'cluster'
dist, y_names, x_names = dm.cluster_distances_from_labels(
    obs_erode,
    y_col=ccf_label,
    x_col=taxonomy_level,
    y_names=regions_to_plot_manual_clusters,
    x_names=sorted(obs_erode[taxonomy_level].unique()),
)

# Keep the manual row order (reorder_y=False); only columns are reordered.
# NOTE(review): file names below still say "alpha-reorderedY"; left unchanged
# for downstream use.
y_order, x_order = dm.order_distances_x_to_y(dist, reorder_y=False)

# Column tick labels: first 4 characters of each cluster name
# (presumably the numeric cluster ID; confirm against the taxonomy labels).
x_names_ids_only = [label[:4] for label in x_names]

# Discrete alternative: cmap=custom_purple_discrete, norm=norm
fig_hm_cluster = dm.plot_ordered_similarity_heatmap(
    dist,
    y_order=y_order,
    x_order=x_order,
    y_names=y_names,
    x_names=x_names_ids_only,
    cmap=cmap,
    vmin=vmin,
    vmax=vmax,
)
fig_hm_cluster.set_size_inches(8.5, 3)
ax_hm_cluster = fig_hm_cluster.gca()
ax_hm_cluster.set_xlabel(taxonomy_level)
ax_hm_cluster.set_ylabel('CCF Structure')

fig2c_stem = f'{results_dir}/fig2C_similarity_heatmap_nuclei_vs_{taxonomy_level}_alpha-reorderedY'
fig_hm_cluster.savefig(f'{fig2c_stem}.pdf', transparent=True, bbox_inches='tight')
fig_hm_cluster.savefig(f'{fig2c_stem}.png', dpi=600, transparent=True, bbox_inches='tight')

# Separate strip of colored x-axis labels (taxonomy palette), in the same
# column order as the heatmap.
fig_xlabel_colors = dm.plot_heatmap_xlabel_colors(
    x_names, x_order, abc_palettes[taxonomy_level]
)
fig_xlabel_colors.set_size_inches(6.3, 3)
fig_xlabel_colors.savefig(
    f'{results_dir}/fig2C_similarity_heatmap_nuclei_vs_{taxonomy_level}_xaxis_colors.pdf',
    transparent=True,
    bbox_inches='tight',
)

## Fig. 2D-E | Examples of cluster–nucleus correspondence

In [21]:
# Shared keyword arguments for the annotated-cluster overlay plots (Fig. 2D-E)
kwargs_cluster_annotations = {
    'section_col': z_col,
    'x_col': x_col,
    'y_col': y_col,
    'point_size': 1.5,
    'figsize': (4, 2),
    'face_palette': None,
    'edge_color': 'silver',
}

### Fig. 2D | VAL & VM

In [22]:
# Keep only left-hemisphere cells so several sections fit side by side
obs_th_left = obs_th.loc[obs_th['left_hemisphere']]

# Left-hemisphere CCF labels, with substructures collapsed into structures
# (e.g. AMd + AMv -> AM; LGd-co + LGd-ip + LGd-sh -> LGd).
# NOTE(review): unlike the full-brain image, realigned=True is not passed
# here; confirm the default matches the coordinates being plotted.
ccf_images_left = cerd.merge_substructures(
    abc.get_ccf_labels_image(subset_to_left_hemi=True),
    ccf_level='structure',
)

# Boundaries only for sections containing left-hemisphere thalamic cells,
# converted to integer slice indices (z / 0.2, rounded).
sections_all = sorted(obs_th_left[z_col].unique())
sections_int = np.rint(np.array(sections_all) / 0.2).astype(int)

ccf_boundaries_left = cimg.sectionwise_label_erosion(
    ccf_images_left,
    distance_px=1,
    fill_val=0,
    return_edges=True,
    section_list=sections_int,
)

In [23]:
# Fig. 2D: clusters annotated to VAL and VM, overlaid on the CCF in several
# left-hemisphere sections.
nucleus = ['VAL', 'VM']
obs_annot = abc.get_obs_from_annotations(nucleus, obs_th_left, taxonomy_level='cluster')
# Drop a couple of clusters that seem to appear here only because of CCF misalignment
cat_to_drop = ['2686 TH Prkcd Grin2c Glut_13', '2684 TH Prkcd Grin2c Glut_13']
obs_annot = obs_annot.loc[~obs_annot['cluster'].isin(cat_to_drop)]

# VAL/VM cell types across five sections, anterior to posterior:
# C57BL6J-638850.42, C57BL6J-638850.40, C57BL6J-638850.39, C57BL6J-638850.38, C57BL6J-638850.37
# (full candidate list was [7.8, 7.6, 7.2, 7.0, 6.8, 6.6])
sections_to_plot = [7.6, 7.2, 7.0, 6.8, 6.6]
nuclei_highlight = nucleus
taxonomy_level = 'cluster'

figs_annot = cplots.plot_ccf_overlay(obs_annot,
                                     ccf_images_left,
                                     boundary_img=ccf_boundaries_left,
                                     bg_cells=obs_th_left,
                                     ccf_highlight=nuclei_highlight,
                                     point_hue=taxonomy_level,
                                     sections=sections_to_plot,
                                     point_palette=abc_palettes[taxonomy_level],
                                     legend='cells',
                                     **kwargs_cluster_annotations)

for i, sec in enumerate(sections_to_plot):
    # round before int(): int() truncates, so a float product such as sec*10
    # landing just below a whole number would otherwise lose 1 in the file name
    sec_str = int(round(sec * 10))
    figs_annot[i].savefig(f'{results_dir}/fig2D_cluster_annotations_z{sec_str}_VAL_VM.pdf',
                          transparent=True, bbox_inches='tight')
    figs_annot[i].savefig(f'{results_dir}/fig2D_cluster_annotations_z{sec_str}_VAL_VM.png',
                          transparent=True, bbox_inches='tight', dpi=1200)

### Fig. 2E | Anterior thalamic nuclei (ATN: AD, AV, AM)

In [24]:
# Fig. 2E: clusters annotated to the anterior thalamic nuclei (ATN: AD, AM, AV),
# overlaid on the CCF in two left-hemisphere sections.
nucleus = ['AD', 'AM', 'AV']
obs_annot = abc.get_obs_from_annotations(nucleus, obs_th_left, taxonomy_level='cluster')

# ATN cell types in two sections, anterior to posterior:
# C57BL6J-638850.44, C57BL6J-638850.43
sections_to_plot = [8.0, 7.8]
# Highlight the same structures the cells were selected from. ccf_images_left
# was merged to structure level (AMd + AMv -> AM), so substructure names such
# as 'AMd'/'AMv' presumably no longer match any label in that image.
nuclei_highlight = nucleus
taxonomy_level = 'cluster'

plt.rcParams.update({'font.size': 7})
figs_annot = cplots.plot_ccf_overlay(obs_annot,
                                     ccf_images_left,
                                     boundary_img=ccf_boundaries_left,
                                     bg_cells=obs_th_left,
                                     ccf_highlight=nuclei_highlight,
                                     point_hue=taxonomy_level,
                                     sections=sections_to_plot,
                                     point_palette=abc_palettes[taxonomy_level],
                                     legend='cells',
                                     **kwargs_cluster_annotations)

for i, sec in enumerate(sections_to_plot):
    # round before int(): int() truncates toward zero
    sec_str = int(round(sec * 10))
    figs_annot[i].savefig(f'{results_dir}/fig2E_cluster_annotations_z{sec_str}_ATN_left.pdf',
                          transparent=True, bbox_inches='tight')
    figs_annot[i].savefig(f'{results_dir}/fig2E_cluster_annotations_z{sec_str}_ATN_left.png',
                          transparent=True, bbox_inches='tight', dpi=1200)

## Fig. 2F-G | Cluster diversity metrics

In [25]:
# Per-nucleus diversity metrics, computed on the eroded CCF labels
th_ccf_metrics = divmet.calculate_diversity_metrics(obs_erode, ccf_label=ccf_label)

# From here on `cmap` is the sequential "mako_r" colormap (it replaces the
# heatmap colormap chosen for Fig. 2A-C)
cmap = sns.color_palette(palette="mako_r", as_cmap=True)

### Fig. 2F | Per-nucleus cluster diversity
Aggregated per nucleus across all sections.

In [26]:
# Fig. 2F: per-nucleus cluster count normalized by the number of cells,
# painted onto the CCF for the example thalamic sections.
sections_to_plot = cplots.TH_EXAMPLE_Z_SECTIONS
plt.rcParams.update({'font.size': 7})
figs_clust_count_norm2cells = cplots.plot_metrics_ccf(ccf_images_left,
                                                      th_ccf_metrics['count_norm2cells_cluster'],
                                                      sections_to_plot,
                                                      ccf_level=ccf_metrics_level,
                                                      vmin=0,
                                                      vmax=0.15,
                                                      cmap=cmap,
                                                      cb_label='cluster count / # cells'
                                                      )

for i, fig in enumerate(figs_clust_count_norm2cells):
    fig.set_size_inches(3.5, 2)
    # round before int(): int() truncates toward zero
    sec_id = int(round(sections_to_plot[i] * 10))
    fig.savefig(f'{results_dir}/fig2F_ccf_cluster_count_norm2cells_sec{sec_id}.png',
                transparent=True, bbox_inches='tight', dpi=300)

# Colorbar panel: saved once from the last section's figure (all figures use
# the same vmin/vmax/cmap). Index the list explicitly instead of relying on the
# leaked loop variable, which would be stale or undefined if no figures were made.
figs_clust_count_norm2cells[-1].savefig(f'{results_dir}/fig2F_ccf_cluster_count_norm2cells_colorbar.pdf',
                                        transparent=True, bbox_inches='tight')

### Fig. 2G | Local cluster diversity
Local neighborhood = the 15 nearest neighbors of each cell within the same section.

In [27]:
# Local diversity: inverse Simpson's index (ISI) computed over each cell's
# 15 nearest neighbors, stored under the metric name 'isi'
local_isi_df = divmet.calculate_local_diversity_metric(
    obs_erode,
    divmet.inverse_simpsons_index,
    metric_name='isi',
    n_neighbors=15,
)

In [28]:
# Fig. 2G: local cluster ISI per cell, shown for the example thalamic sections
metric_name = 'local_isi_cluster'

# subset to left hemisphere for plotting
obs_erode_left = obs_erode.loc[obs_erode['left_hemisphere']]

figs_local_isi = []
for section in cplots.TH_EXAMPLE_Z_SECTIONS:
    fig = dplot.plot_local_metric_ccf_section(obs_erode_left, local_isi_df, ccf_images_left,
                                              section, metric_name, cmap=cmap)
    fig.set_size_inches(3.5, 2)
    # round before int(): int() truncates toward zero
    fig.savefig(f'{results_dir}/fig2G_ccf_local_isi_cluster_sec{int(round(section * 10))}.png',
                transparent=True, bbox_inches='tight', dpi=300)
    figs_local_isi.append(fig)

# Colorbar panel: previously rewritten on every iteration (only the last write
# survived); now saved once from the last section's figure, as in Fig. 2F.
figs_local_isi[-1].savefig(f'{results_dir}/fig2G_ccf_local_isi_cluster_colorbar.pdf',
                           transparent=True, bbox_inches='tight')