# 0_subset_abc_atlas_to_thalamus

To subset the published ABC Atlas MERFISH dataset (3,739,961 cells) to cells that belong to the thalamus and/or the zona incerta (XXX cells), we used a combination of spatial and taxonomy-based labeling.

Functions for this process can be found in the `abc_load.py` & `abc_load_base.py` modules of the `thalamus-merfish-analysis` Python package.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib.lines import Line2D

import seaborn as sns
import scipy.ndimage as ndi

%matplotlib inline

In [12]:
from thalamus_merfish_analysis import abc_load as abc
from thalamus_merfish_analysis import ccf_images as cimg

## Load cell metadata DataFrame

In [4]:
cells_df = abc.get_combined_metadata(drop_unused=False)
print(f'n_cells in the whole-brain ABC Atlas: {cells_df.shape[0]}')

## 1) Spatial filtering based on ARA parcellations & CCFv3 image volumes

As an initial spatial filtering step, we used a rasterized anatomical annotation 
image volume: the CCFv3 coordinate space, resampled (10x10x??um) and aligned into 
the MERFISH data space. Each voxel in this volume holds a parcellation index that 
uniquely identifies its 3D Allen Reference Atlas (ARA) anatomical parcellation.

All the detailed steps outlined below are wrapped in the `label_thalamus_spatial_subset()` 
function from the `abc_load` module. We confirm this by showing that the number
of cells remaining after this single function call (227,417) matches the
number of cells remaining after detailed step 1f (also 227,417).

In [5]:
cells_df_th_subset = abc.label_thalamus_spatial_subset(cells_df,
                                                       distance_px=20,
                                                       cleanup_mask=True,
                                                       drop_end_sections=True,
                                                       filter_cells=True)

print(f'n_cells after label_thalamus_spatial_subset(): {cells_df_th_subset.shape[0]}')

### 1a) Load & view CCF parcellation/annotation image volumes

`imshow()` colors each pixel by its parcellation_index. 
You can see that each ARA parcellation has its own color (displayed in grayscale).

In [6]:
# loads the whole-brain CCF resampled_annotation image volumes
ccf_img = abc.get_ccf_labels_image(resampled=True)

# 1100x1100 pixels per section, where each pixel is 10um x 10um; 76 sections
print(f'{ccf_img.shape=}')

# display an example section of the image volume that contains the thalamus 
zindex = 36 
plt.imshow(ccf_img[:,:,zindex].T, cmap='gray')  # transpose to plot in correct orientation
plt.title(f'example section: {zindex}')
xy_labels = 'resampled pixels (1px = 10um)'
plt.xlabel(xy_labels); plt.ylabel(xy_labels)
plt.show()

### 1b) Generate binary mask for TH+ZI parcellation regions

We generated a binary TH+ZI mask that included all voxels labelled with a thalamus 
(‘TH’) or a zona incerta (‘ZI’) associated parcellation index.
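
As a toy illustration of the masking idea, the sketch below applies the same `np.isin` / `np.where` pattern to a tiny made-up label volume. The array and the selected index values are hypothetical and are not real parcellation indices.

```python
import numpy as np

# hypothetical 2x3x1 label volume; the integers stand in for parcellation indices
toy_labels = np.array([[[5], [7], [0]],
                       [[7], [9], [5]]])
toy_selected = [5, 9]  # hypothetical indices for the regions of interest

# True wherever a voxel's label is one of the selected indices
toy_mask = np.isin(toy_labels, toy_selected)
# keep the selected labels and zero out everything else
toy_subset = np.where(toy_mask, toy_labels, 0)
```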

In [7]:
# get all parcellation names that are in either TH or ZI
ccf_regions_to_select = ['TH', 'ZI'] # TH = thalamus ; ZI = zona incerta
ccf_regions = abc.get_ccf_names(top_nodes=ccf_regions_to_select, 
                                level='substructure')
print(f'{ccf_regions=}')

# convert parcellation names to the unique parcellation_index used in the image volume
ccf_index = abc.get_ccf_index(level='substructure')
reverse_lookup = pd.Series(ccf_index.index.values, index=ccf_index)
th_zi_index_values = reverse_lookup.loc[ccf_regions]
print(f'{th_zi_index_values=}')

In [8]:
# if you want to subset by a different ARA parcellation region, you can change 
# the `level` param {'division', 'structure', 'substructure'} to view alternative
# options to add to `ccf_regions_to_select` in the previous cell
all_division_names = abc.get_ccf_names(top_nodes=None, level='division')
print(all_division_names)

In [9]:
# generate a binary TH+ZI mask
th_zi_mask = np.isin(ccf_img, th_zi_index_values)

# subset the ccf image volume using the binary mask
th_zi_img = np.where(th_zi_mask, ccf_img, 0)


# display the subsetted image volume and TH+ZI binary mask
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
xy_labels = 'resampled pixels (1px=10um)'

ax1.imshow(ccf_img[:,:,zindex].T, cmap='gray')
ax1.set_title('CCFv3 parcellation substructures - whole-brain')
ax1.set_xlabel(xy_labels); ax1.set_ylabel(xy_labels)

ax2.imshow(th_zi_img[:,:,zindex].T, cmap='gray')
ax2.set_title('parcellation substructures - TH & ZI')
ax2.set_xlabel(xy_labels)

ax3.imshow(th_zi_mask[:,:,zindex].T, cmap='gray')
ax3.set_title('TH+ZI binary mask')
ax3.set_xlabel(xy_labels)

plt.show()

### 1c) Dilate TH+ZI binary mask by 20px (200um)
To capture TH or ZI cells that may be slightly misaligned with the CCF parcellations, 
we dilated the binary mask by 20 pixels (200um).
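
The internals of `abc.sectionwise_dilation()` are not shown in this notebook. A minimal sketch of what a per-section dilation could look like, assuming a boolean `(x, y, z)` mask and using `scipy.ndimage.binary_dilation`, is given below. The function name `sectionwise_dilation_sketch` is hypothetical, and the repeated 3x3 structuring element is an assumption that produces a square neighborhood rather than a circular one.

```python
import numpy as np
import scipy.ndimage as ndi

def sectionwise_dilation_sketch(mask_3d, distance_px):
    """Dilate each z-section of a boolean (x, y, z) mask independently."""
    dilated = np.zeros_like(mask_3d, dtype=bool)
    # full 3x3 structuring element; repeating it distance_px times grows the
    # mask by distance_px pixels along x and y (square, not circular)
    struct = ndi.generate_binary_structure(2, 2)
    for z in range(mask_3d.shape[2]):
        dilated[:, :, z] = ndi.binary_dilation(mask_3d[:, :, z],
                                               structure=struct,
                                               iterations=distance_px)
    return dilated

toy = np.zeros((9, 9, 2), dtype=bool)
toy[4, 4, 0] = True  # single voxel in section 0; section 1 stays empty
grown = sectionwise_dilation_sketch(toy, distance_px=2)
```

Dilating each section separately keeps the mask from growing along z, where sections are spaced much further apart than pixels are in x and y.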

In [10]:
# 20px=200um
th_zi_mask_dilated = abc.sectionwise_dilation(th_zi_mask, distance_px=20)


# display the subsetted image volume, TH+ZI binary mask, and dilated mask
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 5))
xy_labels = 'resampled pixels (1px=10um)'

ax1.imshow(ccf_img[:,:,zindex].T, cmap='gray')
ax1.set_title('CCFv3 parcellation substructures - whole-brain')
ax1.set_xlabel(xy_labels); ax1.set_ylabel(xy_labels)

ax2.imshow(th_zi_img[:,:,zindex].T, cmap='gray')
ax2.set_title('parcellation substructures - TH & ZI')
ax2.set_xlabel(xy_labels)

ax3.imshow(th_zi_mask[:,:,zindex].T, cmap='gray')
ax3.set_title('TH+ZI binary mask')
ax3.set_xlabel(xy_labels)

ax4.imshow(th_zi_mask_dilated[:,:,zindex].T, cmap='gray')
ax4.set_title('TH+ZI binary mask, dilated by 20px')
ax4.set_xlabel(xy_labels)

plt.show()

### 1d) Clean up the binary mask

In [14]:
th_zi_mask_dilated = cimg.cleanup_mask_regions(th_zi_mask_dilated, 
                                               area_ratio_thresh=0.1)
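
The cleanup step removes small disconnected regions from each section of the mask. A minimal sketch of that idea follows, modeled on the exploratory cells at the end of this notebook. The function name `cleanup_mask_sketch` is hypothetical, and this is not necessarily how `cimg.cleanup_mask_regions()` is implemented.

```python
import numpy as np
import scipy.ndimage as ndi

def cleanup_mask_sketch(mask_3d, area_ratio_thresh=0.1):
    """Per section, drop connected regions much smaller than the largest one."""
    cleaned = np.zeros_like(mask_3d, dtype=bool)
    for z in range(mask_3d.shape[2]):
        labeled, n_regions = ndi.label(mask_3d[:, :, z])
        if n_regions == 0:
            continue  # nothing to keep in an empty section
        # pixel count of each labeled region (labels run from 1 to n_regions)
        areas = np.bincount(labeled.ravel(), minlength=n_regions + 1)[1:]
        # labels whose area is at least the threshold fraction of the largest
        keep = np.flatnonzero(areas / areas.max() >= area_ratio_thresh) + 1
        cleaned[:, :, z] = np.isin(labeled, keep)
    return cleaned

toy = np.zeros((10, 10, 1), dtype=bool)
toy[0:5, 0:5, 0] = True  # large region: 25 px
toy[8, 8, 0] = True      # speck: 1 px, ratio 1/25 = 0.04 < 0.1
cleaned = cleanup_mask_sketch(toy)
```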

### 1e) Label & filter using spatial mask

This reduces the dataset to 233,340 cells from 16 coronal sections.

In [16]:
# label cells that fall within the TH+ZI mask
field_name='TH_ZI_dataset'
coords = ['x_reconstructed','y_reconstructed','z_reconstructed']
resolutions = np.array([10e-3, 10e-3, 200e-3])  # voxel size in mm: 10um in x/y, 200um between sections
cells_df[field_name] = th_zi_mask_dilated[cimg.image_index_from_coords(cells_df[coords], resolutions)]
# cells_df = abc_load_base._label_masked_cells(cells_df, th_zi_mask_dilated, coords, resolutions, field_name=field_name)

# filter out cells that do not fall within the TH+ZI mask
th_zi_cells_df = cells_df[cells_df[field_name]].copy().drop(columns=[field_name])

print(f'n_cells after spatial filtering: {th_zi_cells_df.shape[0]}')
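
Labeling relies on converting each cell's reconstructed coordinates (in mm) into integer voxel indices of the mask. The sketch below shows one way this could work. The helper `image_index_sketch` is hypothetical, and rounding to the nearest voxel is an assumption; `cimg.image_index_from_coords()` may round or handle out-of-bounds coordinates differently.

```python
import numpy as np
import pandas as pd

def image_index_sketch(coords_df, resolutions_mm):
    """Convert mm coordinates to integer voxel indices usable for array indexing."""
    idx = np.rint(coords_df.to_numpy() / resolutions_mm).astype(int)
    # a tuple of per-axis index arrays selects one voxel per row
    return tuple(idx.T)

toy_cells = pd.DataFrame({'x_reconstructed': [0.02, 0.05],
                          'y_reconstructed': [0.01, 0.03],
                          'z_reconstructed': [0.2, 0.4]})
toy_mask = np.zeros((10, 10, 5), dtype=bool)
toy_mask[2, 1, 1] = True  # voxel containing the first toy cell

toy_resolutions = np.array([10e-3, 10e-3, 200e-3])  # mm per voxel in x, y, z
in_mask = toy_mask[image_index_sketch(toy_cells, toy_resolutions)]
```

Indexing the mask with a tuple of index arrays returns one boolean per cell, which can be stored directly as a DataFrame column.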

### 1f) Drop the anterior-most section and posterior-most section

At this stage, we also filtered out all cells from the anterior-most section and 
the posterior-most section that contained TH and ZI cells due to poor alignment 
between the CCFv3 TH+ZI parcellation regions and cells mapping to thalamic cell 
types, as determined by visual inspection. This further reduced the dataset to 
227,955 cells and 14 coronal sections. 

In [17]:
# anterior-most section that is dropped: z_section=8.4
# posterior-most section that is dropped: z_section=4.8
th_zi_cells_df = th_zi_cells_df.query("5.0 <= z_section <= 8.2")

print(f'n_cells after end section drop: {th_zi_cells_df.shape[0]}')
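
The query above hard-codes the section range. As a sketch of an alternative that derives the end sections from the data itself, the toy example below drops the first and last sections present in a made-up DataFrame. It assumes the end sections to drop are exactly the minimum and maximum `z_section` values remaining after spatial filtering.

```python
import pandas as pd

toy_df = pd.DataFrame({'z_section': [4.8, 5.0, 5.0, 6.2, 8.2, 8.4]})

# find the first and last sections that still contain cells
z_min, z_max = toy_df['z_section'].min(), toy_df['z_section'].max()
# keep only cells strictly between the two end sections
trimmed = toy_df[(toy_df['z_section'] > z_min) & (toy_df['z_section'] < z_max)]
```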

## 2) Cell-type filtering based on mapped subclasses

### visualize which cells are captured by the TH mask

In [93]:
def plot_th_mask_with_cell_overlay(cell_df, th_mask, th_zi_sections, 
                                   column='class', trim_to_th=True, show_mask=True):

    # make a fixed colormap for displaying thalamus mask
    if show_mask:
        cmap_th_mask = colors.ListedColormap(['black', 'white'])
    else:
        cmap_th_mask = colors.ListedColormap(['white', 'white'])
    bounds=[0,1]
    norm = colors.BoundaryNorm(bounds, cmap_th_mask.N)

    # define legend elements for the column colors
    color_col = column+'_color'
    # build the category -> color mapping row-wise so pairs cannot be misaligned
    cat_color_mapping = (cell_df.drop_duplicates(subset=column)
                                .set_index(column)[color_col].to_dict())
    # sort the dict & the categories list by category
    cat_color_mapping = dict(sorted(cat_color_mapping.items()))
    categories = sorted(cat_color_mapping)
    legend_elements = [Line2D([0], [0], 
                              lw=0, marker='o', markersize=10,
                              markerfacecolor=cat_color_mapping[cat],
                              color=cat_color_mapping[cat], 
                              label=cat) 
                       for cat in cat_color_mapping]
    
    n_col = 2
    n_row = int(np.ceil(len(th_zi_sections) / n_col))
    fig, axes = plt.subplots(n_row, n_col, figsize=(n_col*6,n_row*3.75))
    axes = axes.ravel()

    

    for i, sec in enumerate(th_zi_sections):
        ax = axes[i]

        curr_sec_cell_df = cell_df[(cell_df['z_reconstructed']==sec)]

        ax.imshow(th_mask[:,:,int(np.rint(sec/0.2))].T, extent=[0, 11, 11, 0], zorder=0,
                  cmap=cmap_th_mask, norm=norm)
        sc = ax.scatter(curr_sec_cell_df['x_reconstructed'], curr_sec_cell_df['y_reconstructed'],
                        color=curr_sec_cell_df[color_col],  # follow the `column` argument
                        s=0.5, marker='.', zorder=1)
        ax.set_title('z='+str(sec))

        ax.set_xlabel('x_reconstructed')
        ax.set_ylabel('y_reconstructed')
        if trim_to_th:
            x_min = 2; x_max = 9; y_min = 7.5; y_max = 3.5
            ax.set_xlim((x_min,x_max))
            ax.set_ylim((y_min,y_max))
        else:
            ax.set_xlim((0,11))
            ax.set_ylim((11,0))    
        ax.axis('off')
        ax.set_xticks([])
        ax.set_yticks([])

        if i==7:
            ax.legend(legend_elements, categories, title=column, 
                      loc='center left', bbox_to_anchor=(1.1,0.5))

In [45]:
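# th_zi_sections = sorted(th_df['z_section'].unique())
th_zi_sections = sorted(cells_df['z_reconstructed'].unique())[19:35]
# `th_mask` is not defined earlier in this notebook; we assume it refers to the
# undilated TH+ZI binary mask generated in step 1b
th_mask = th_zi_mask
th_mask_dilated = abc.sectionwise_dilation(th_mask, 10)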

In [46]:
nn_classes = ['30 Astro-Epen', '31 OPC-Oligo', '32 OEC', '33 Vascular', '34 Immune']
th_zi_classes = ['12 HY GABA', '13 CNU-HYa Glut','17 MH-LH Glut','18 TH Glut']

#### neuronal classes

In [47]:
neuronal_cells_df = cells_df[(~cells_df['class'].isin(nn_classes))
                             & (cells_df['z_reconstructed'].isin(th_zi_sections))
                            ]  #& (cells_df['thal_subclass'])
plot_th_mask_with_cell_overlay(neuronal_cells_df, th_mask_dilated, th_zi_sections, column='class')

#### thalamus classes only

In [20]:
neuronal_th_cells_df = cells_df[(~cells_df['class'].isin(nn_classes))
                                & (cells_df['z_reconstructed'].isin(th_zi_sections))
                                & (cells_df['class'].isin(th_zi_classes))
                               ]  
plot_th_mask_with_cell_overlay(neuronal_th_cells_df, th_mask_dilated, th_zi_sections, column='class')

#### thalamus subclasses only

In [21]:
neuronal_th_cells_df = cells_df[(~cells_df['class'].isin(nn_classes))
                                & (cells_df['z_reconstructed'].isin(th_zi_sections))
                                & (cells_df['thal_subclass'])
                               ]  
plot_th_mask_with_cell_overlay(neuronal_th_cells_df, th_mask_dilated, th_zi_sections, column='class')

#### all cells

In [42]:
th_sections_df = cells_df[(cells_df['z_reconstructed'].isin(th_zi_sections))
                               ]  
plot_th_mask_with_cell_overlay(th_sections_df, th_mask_dilated, th_zi_sections, column='class')

### view actual abc_load steps

#### CCF raster image volume

In [78]:
sec = 6.8
ind = int(np.rint(sec/0.2))
plt.imshow(th_mask[200:901,350:750:,ind].T, cmap='gray') #, extent=(0,11,0,11))
ax = plt.gca()
ax.axis('off')
ax.set_xticks([])
ax.set_yticks([])

#### subset & make mask from abc_load

In [84]:
dist_px = 20
cells_df_abc, mask_img = abc.label_thalamus_spatial_subset(cells_df,
                                                            flip_y=False,
                                                            distance_px=dist_px,
                                                            cleanup_mask=True,
                                                            drop_end_sections=True,
                                                            filter_cells=True,
                                                            realigned=False)

In [85]:
th_zi_sections_example = sorted(cells_df['z_reconstructed'].unique())[25:29]

#### dilated TH mask

In [89]:
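# round before casting: float division (e.g. 5.8/0.2) can land just below an integer
th_start_ind = int(np.rint(th_zi_sections_example[0]/0.2))
th_end_ind = int(np.rint(th_zi_sections_example[-1]/0.2))+1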
mask_img_subset = mask_img[:,:,th_start_ind:th_end_ind]


n_col = 2
n_row = int(np.ceil(mask_img_subset.shape[2] / n_col))
fig, axes = plt.subplots(n_row, n_col, figsize=(n_col*6,n_row*3.75))
axes = axes.ravel()

x_min = 2.5; x_max = 8.5; y_min = 7.5; y_max = 4

for z in np.arange(mask_img_subset.shape[2]):
    ax = axes[z]
    ax.imshow(np.squeeze(mask_img_subset[:,:,z]).T,cmap='gray', extent=(0,11,0,11))
    ax.set_xlim(2,9)
    ax.set_ylim(3.5,7.5)
    ax.axis('off')
    ax.set_xticks([])
    ax.set_yticks([])

#### all cells with mask

In [91]:
th_sections_df = cells_df[(cells_df['z_reconstructed'].isin(th_zi_sections_example))
                         ]  
plot_th_mask_with_cell_overlay(th_sections_df, mask_img, th_zi_sections_example, 
                               column='class', trim_to_th=True)

#### TH+ZI subset 

##### with mask

In [82]:
# th_mask_dilated = abc.sectionwise_dilation(th_mask, dist_px)

cells_df_th = cells_df_abc[(cells_df_abc['z_reconstructed'].isin(th_zi_sections))
                           ]

plot_th_mask_with_cell_overlay(cells_df_th, mask_img, th_zi_sections, column='class', trim_to_th=True)

##### no mask

In [94]:
plot_th_mask_with_cell_overlay(cells_df_th, mask_img, th_zi_sections_example, 
                               column='class', trim_to_th=True, show_mask=False)

#### TH+ZI neurons with mask

In [98]:
nn_classes = ['30 Astro-Epen', '31 OPC-Oligo', '32 OEC', '33 Vascular', '34 Immune']
# th_zi_classes = ['12 HY GABA', '13 CNU-HYa Glut','17 MH-LH Glut','18 TH Glut']
neuronal_cells_df = cells_df_th[(~cells_df_th['class'].isin(nn_classes))
                                & (cells_df_th['z_reconstructed'].isin(th_zi_sections_example))
                               ]

In [99]:
plot_th_mask_with_cell_overlay(neuronal_cells_df, mask_img, th_zi_sections_example, 
                               column='class', trim_to_th=True, show_mask=True)

In [100]:
plot_th_mask_with_cell_overlay(neuronal_cells_df, mask_img, th_zi_sections_example, 
                               column='class', trim_to_th=True, show_mask=False)

#### TH+ZI+MB neurons

In [103]:
neuronal_th_cells_df = abc.filter_by_class_thalamus(neuronal_cells_df, 
                                                    filter_nonneuronal=True,
                                                    filter_midbrain=False, 
                                                    filter_other_nonTH=True)

In [104]:
plot_th_mask_with_cell_overlay(neuronal_th_cells_df, mask_img, th_zi_sections_example, 
                               column='class', trim_to_th=True, show_mask=False)

### examine end sections with poor match

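These plots use the dilated mask `th_mask_dilated` from above. They were originally generated with a 300um dilation, whereas the cell above sets `distance_px=10` (100um).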

In [22]:
z = 8.4
plt.imshow(th_mask_dilated[:,:, int(np.rint(z/0.2))].T, extent=[0, 11, 11, 0])
sns.scatterplot(data=cells_df.query(f"thal_subclass & z_reconstructed=={z}"), x=coords[0], y=coords[1])
plt.ylim(7, 5)
plt.xlim(4, 7)

In [23]:
z = 4.8
plt.imshow(th_mask_dilated[:,:, int(np.rint(z/0.2))].T, extent=[0, 11, 11, 0])
sns.scatterplot(data=cells_df.query(f"thal_subclass & z_reconstructed>{z-0.1} & z_reconstructed<{z+0.1}"), x=coords[0], y=coords[1], s=4)
plt.ylim(6, 4)
plt.xlim(2, 8)

In [24]:
z = 5.6
plt.imshow(th_mask_dilated[:,:, int(np.rint(z/0.2))].T, extent=[0, 11, 11, 0])
sns.scatterplot(data=cells_df.query(f"thal_subclass & z_reconstructed>{z-0.1} & z_reconstructed<{z+0.1}"), x=coords[0], y=coords[1], s=4)
# plt.ylim(7, 5)
# plt.xlim(4, 7)

# For reference, do not rerun

In [25]:
# to break a "run all"
assert False

### Using a true dilation by radius 

This takes much longer to compute, and the resulting cell selection does not look much different.
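
To illustrate the distinction, the sketch below compares a single dilation with a disk-shaped structuring element (a Euclidean, "true radius" dilation) against repeated dilation with a 3x3 element (a square neighborhood). The helper `disk_footprint` is hypothetical, and this is not necessarily how the `true_radius=True` option is implemented in `abc.sectionwise_dilation()`.

```python
import numpy as np
import scipy.ndimage as ndi

def disk_footprint(radius_px):
    """Boolean disk of the given radius, for use as a structuring element."""
    r = np.arange(-radius_px, radius_px + 1)
    xx, yy = np.meshgrid(r, r, indexing='ij')
    return xx**2 + yy**2 <= radius_px**2

toy = np.zeros((11, 11), dtype=bool)
toy[5, 5] = True

# one pass with a disk: Euclidean ("true radius") dilation
disk_dilated = ndi.binary_dilation(toy, structure=disk_footprint(3))
# three passes with a full 3x3 element: square (Chebyshev) dilation
square_dilated = ndi.binary_dilation(toy, structure=np.ones((3, 3), dtype=bool),
                                     iterations=3)
```

The disk result is a subset of the square result and differs only near the corners, which is consistent with the two approaches selecting similar sets of cells.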

In [None]:
data = dict()
for radius in [1,5,10,15,20,30]:
    mask = abc.sectionwise_dilation(th_mask, radius, true_radius=True)
    abc.label_masked_cells(cells_df, mask, coords, resolutions, field_name=field_name)
    # key by dilation radius in mm (1px = 0.01mm); value is the per-section
    # fraction of thal_subclass cells that fall inside the dilated mask
    data[radius*0.01] = (cells_df.loc[cells_df['thal_subclass']]
        .groupby('z_reconstructed')[field_name].mean().loc[lambda x: x>0])

In [None]:
for i, xy in data.items():
    plt.plot(xy, label=i)

plt.ylabel('Proportion selected')
plt.xlabel('section z')
plt.legend(title='Dilation radius')

## full 10um ccf

In [32]:
# very slow to do operations with this large image volume

# ccf_img = abc.get_ccf_labels_image()
# # takes about 1 min
# th_mask = np.isin(ccf_img, th_zi_ind)

# Experiments?

## mask into polygon

In [None]:
# import skimage as ski

# z = th_zi_sections[7]
# mask_ind = int(np.rint(z/0.2))
# ex_th_mask = np.squeeze(th_mask_dilated[:,:,mask_ind]).T

# contours = ski.measure.find_contours(ex_th_mask)

# # Display the image and plot all contours found
# fig, ax = plt.subplots()#figsize=(20,30))
# ax.imshow(ex_th_mask, cmap='gray')#, extent=(0,11,0,11))

# for contour in contours:
#     ax.plot(contour[:, 1], contour[:, 0], linewidth=1)
    
#     poly_coords = ski.measure.approximate_polygon(contour, tolerance=1.0)
#     ax.plot(poly_coords[:, 1], poly_coords[:, 0], linewidth=1)
    
# ax.axis('image')
# # ax.set_xlim(250,900)
# # ax.set_ylim(750,375)
# # ax.set_xlim(580,620)
# # ax.set_ylim(450,375)
# # ax.set_xticks([])
# # ax.set_yticks([])
# plt.show()

# print(f'{len(contour)=}')
# print(f'{len(poly_coords)=}')

In [40]:
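# round before casting: float division (e.g. 5.8/0.2) can land just below an integer
th_start_ind = int(np.rint(th_zi_sections[0]/0.2))
th_end_ind = int(np.rint(th_zi_sections[-1]/0.2))+1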
th_mask_subset = th_mask_dilated[:,:,th_start_ind:th_end_ind]

th_mask_sobel_edge = np.zeros_like(th_mask_subset)

In [41]:
n_col = 2
n_row = int(np.ceil(th_mask_subset.shape[2] / n_col))
fig, axes = plt.subplots(n_row, n_col, figsize=(12,30))
axes = axes.ravel()

x_min = 2.5; x_max = 8.5; y_min = 7.5; y_max = 4

for z in np.arange(th_mask_subset.shape[2]):
    ax = axes[z]
    ax.imshow(np.squeeze(th_mask_subset[:,:,z]).T,cmap='gray', extent=(0,11,0,11))
    ax.set_xlim(2,9)
    ax.set_ylim(3.5,7.5)
    # plt.xlim(3.8,4.2)
    # plt.ylim(4.5,5)

In [None]:
import skimage as ski
import skimage.filters  # make sure the filters submodule is loaded

n_col = 2
n_row = int(np.ceil(th_mask_subset.shape[2] / n_col))
fig, axes = plt.subplots(n_row, n_col, figsize=(12,30))
axes = axes.ravel()

x_min = 2.5; x_max = 8.5; y_min = 7.5; y_max = 4

for z in np.arange(th_mask_subset.shape[2]):
    # note: despite the variable names, this uses the Roberts cross operator,
    # not a Sobel filter
    sobel_edge = ski.filters.roberts(np.squeeze(th_mask_subset[:,:,z]))
    th_mask_sobel_edge[:,:,z] = sobel_edge

    ax = axes[z]
    ax.imshow(sobel_edge.T,cmap='gray', extent=(0,11,0,11))
    ax.set_xlim(2.5,8.5)
    ax.set_ylim(4,7)
    # plt.xlim(3.8,4.2)
    # plt.ylim(4.5,5)

## remove small spurious regions

In [34]:
dilated_th_mask = abc.sectionwise_dilation(th_mask, distance_px=20)

In [35]:
plt.imshow(dilated_th_mask[:,:,42].T,cmap='gray') #, extent=(0,11,0,11))

In [36]:
# regions smaller than this fraction of the largest region's area are removed
max_area_ratio=0.1
mask_img = dilated_th_mask[:,:,32]

labeled_mask, n_features = ndi.label(mask_img)

# calculate the area of the largest region
largest_region = np.argmax(ndi.sum(mask_img, labeled_mask, range(n_features + 1)))
largest_area = np.sum(labeled_mask==largest_region)

# filter out regions with area ratio smaller than the specified threshold
regions_to_keep = [label for label 
                   in range(1, n_features + 1) 
                   if np.sum(labeled_mask==label) / largest_area >= max_area_ratio
                  ]

# make a new mask with only the remaining objects
new_mask_img = np.isin(labeled_mask, regions_to_keep)

plt.imshow(new_mask_img.T,cmap='gray')

In [37]:
mask_img = dilated_th_mask

new_mask_img = np.zeros_like(mask_img)
for sec in range(mask_img.shape[2]):
    mask_2d = mask_img[:,:,sec]

    labeled_mask, n_regions = ndi.label(mask_2d)

    # calculate the area of the largest region
    largest_region = np.argmax(ndi.sum(mask_2d, labeled_mask, range(n_regions+1)))
    largest_area = np.sum(labeled_mask==largest_region)

    # filter out regions with area ratio smaller than the specified threshold
    regions_to_keep = [label for label 
                       in range(1, n_regions+1) 
                       if np.sum(labeled_mask==label) / largest_area >= max_area_ratio
                      ]
    # make a new mask with only the remaining objects
    new_mask_img[:,:,sec] = np.isin(labeled_mask, regions_to_keep)
    
plt.imshow(new_mask_img[:,:,32].T,cmap='gray')