Region Cell Counts

In [1]:
#Imports
import os
import numpy as np
import imageio
import matplotlib.pyplot as plt
from tqdm import tqdm

from skimage import io, transform
from skimage.util import img_as_float32

import concurrent.futures

# Module imports
import sys
sys.path.append('E://Documents/Professional/Jupyter notebooks/Projects/Iordonova_lab/')
# sys.path.append('brain_segmentations')

from brain_segmentations.config import *
from brain_segmentations.preprocessing.file_io import *
from brain_segmentations.registration.registration import *
from brain_segmentations.segmentation.segmentation import *



experiment_folder defined in config. Be careful in the future! Iordanova_06082022_SOC-R9-F_NeuN-cFOS/


In [2]:
# Constants and variables

region_id = 406

In [3]:
# Paths
# label_path = 'M://Brain_Registration/brainreg_napari_output/full_brain_dowsampled_tiff_chris/'
label_path = 'M://Brain_Registration/brainreg_napari_output/may10_20ds_fullz_preds/'

atlas_identifier = 'registered_atlas_original_orientation'
atlas_filename = label_path + atlas_identifier + '.tiff'

fullres_folder = 'Z://Collaboration_data/Iordonova_lab/Iordanova_06082022_SOC-R9-F_NeuN-cFOS/561nm_NeuN/'
fos_folder = 'Z://Collaboration_data/Iordonova_lab/Iordanova_06082022_SOC-R9-F_NeuN-cFOS/647nm_cFOS/'
ds_folder = 'M://Brain_Registration/downsampled_20/neun/'

# Get a list of all tif files in the folder 
identifiers = [f[:-4] for f in os.listdir(fullres_folder) if f.endswith('.tif')]

# slice_identifier = '356850_415210_029200'
slice_identifier = '356850_415210_044800' # Eventually will on a loop over all slices in the folder.

In [4]:
# Load the reference atlas in image coordinates
labels = io.imread(atlas_filename).astype(np.int32) 

# Reorder the labels to be more intuitive, where the z-axis is the 3rd dimension
labels = np.moveaxis(labels, 0, -1)

lab_xdim, lab_ydim, lab_zdim = labels.shape
labels.shape


(738, 507, 284)

In [5]:
# #Load an original-resolution tif image.

# tif_filename = fullres_folder + slice_identifier + '.tif'


# '''NOTE:
# Eventuyally we'll replace all fs_img with neun_img
# '''

# # Load the fullsized tif
# fs_img = io.imread(tif_filename)
# plt.imshow(fs_img)

# fs_zdim = len(identifiers)
# fs_zdim

In [6]:
# For a given brain region in the atlas, find the associated images and get the masked images.

# # Get the indices of the label volume that contain this region.
# region_indices = np.where(labels == region_id)
# print(np.shape(region_indices))

# import numpy as np

# Assuming `labels` is your 3D numpy array (volume)

# def get_slices_containing_region(volume, region_id):
#     # Identify where in the volume the region_id is found
#     indices = np.where(volume == region_id)

#     # indices is a tuple of 3 1D arrays (for the 3 dimensions of the volume)
#     # The third element of the tuple gives the indices in the third dimension (slices)
#     slice_indices = indices[2]

#     # Get unique slice indices, as there may be multiple voxels with region_id in a single slice
#     unique_slices = np.unique(slice_indices)

#     return unique_slices

# def test_slices_contiguity(slices):
#     # Calculate the differences between adjacent elements
#     differences = np.diff(slices)
    
#     # Check if all differences are 1 (which indicates contiguity)
#     is_contiguous = np.all(differences == 1)
    
#     return is_contiguous

# def plot_slices(labels, slices, region_id):
#     fig, axs = plt.subplots(1, 5, figsize=(15, 3))

#     # Define the slice indices ensuring we don't exceed the volume boundaries
#     slice_indices = [
#         max(slices[0] - 1, 0),
#         slices[0]+2,
#         slices[len(slices) // 2],
#         slices[-1]-2,
#         min(slices[-1] + 1, labels.shape[2] - 1)
#     ]

#     for ax, slice_index in zip(axs, slice_indices):
#         # Plot the grayscale volume slice
#         ax.imshow(labels[:, :, slice_index], cmap='gray')
        
#         # Overlay the selected region in red
#         overlay = np.where(labels[:, :, slice_index] == region_id, 1, np.nan)
#         ax.imshow(overlay, cmap='Reds', alpha=1, vmin=0, vmax=1)
        
#         ax.set_title(f'Slice {slice_index}')

#     plt.tight_layout()
#     plt.show()


In [7]:

# Test it
# for region_id in range(400,410):#np.unique(labels):  # I'm only using the first 5 unique labels for brevity
#     print(f"Region {region_id}:")
#     atlas_slices_with_region = get_slices_containing_region(labels, region_id)
#     plot_slices(labels, atlas_slices_with_region, region_id)




In [8]:
'''In Registration.py'''
# def map_label_to_img(first_slice, last_slice, identifier_list, label_volume):
#     '''Take a range of slices (first_slice to last_slice inclusive) in the label volume,
#     and map these to the corresponding image identifiers'''

#     # Get the total number of slices in the image set and label volume
#     img_zdim = len(identifier_list)
#     lab_zdim = label_volume.shape[2]

#     # Interpolate to find the corresponding image indices
#     first_img_ind = int(first_slice * img_zdim / lab_zdim)
#     last_img_ind = int((last_slice+1) * img_zdim / lab_zdim)  # +1 to make the range inclusive

#     # Err on the side of including more images by expanding the range
#     first_img_ind = max(0, first_img_ind - 1)
#     last_img_ind = min(img_zdim - 1, last_img_ind + 1)

#     # Return the corresponding image identifiers
#     return identifier_list[first_img_ind : last_img_ind + 1]  # +1 to make the range inclusive


In [9]:

region_id=400

atlas_slices_with_region = get_slices_containing_region(labels, region_id)
print(f"Slices containing region {region_id}: {atlas_slices_with_region}")

# For this region, get the corresponding slices in the full-sized images

# Use the first and last slices of the labels volume containing the region
first_slice = atlas_slices_with_region[0]
last_slice = atlas_slices_with_region[-1]

# Map this to the corresponding slices in the full-sized images with map_img_to_label()

region_img_identifiers = map_label_to_img(first_slice, last_slice, identifiers, labels)
print(f"Image identifiers containing region {region_id}: {region_img_identifiers}")




Slices containing region 400: [ 92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107 108 109
 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181
 182 183 184 185 186 187 188 189 190 191]
Image identifiers containing region 400: ['356850_415210_035800', '356850_415210_035840', '356850_415210_035880', '356850_415210_035920', '356850_415210_035960', '356850_415210_036000', '356850_415210_036040', '356850_415210_036080', '356850_415210_036120', '356850_415210_036160', '356850_415210_036200', '356850_415210_036240', '356850_415210_036280', '356850_415210_036320', '356850_415210_036360', '356850_415210_036400', '356850_415210_036440', '356850_415210_036480', '356850_415210_036520', '356850_415210_036560', '356850_415210_036600', '356850_415210_03

In [10]:
# import re

'''
Added to Registration
'''

# def parse_itk_snap_label_file(file_path):
#     labels = {}
#     with open(file_path, 'r') as file:
#         for line in file:
#             # Skip lines that start with '#' or are empty
#             if line.startswith("#") or line.strip() == "":
#                 continue

#             # Split the line into fields
#             fields = line.split()

#             # Extract the index, RGB values, and label
#             index = int(fields[0])
#             rgb = (int(fields[1]), int(fields[2]), int(fields[3]))
#             # Extract the label using a regular expression
#             label = re.search(r'"(.+)"', line).group(1)

#             # Store the extracted information in a dictionary
#             labels[index] = {'rgb': rgb, 'label': label}


#     return labels


In [11]:
'''
TO UPDATE And RETURN TO REGISTRATION.py
'''


'''
Have to reorganize this pipeline, so that the scaling of the mask happens just once per image
'''
'''
Added to registration.py
'''
# def apply_region_mask(img, slice_identifier, identifiers, labels, region_id, plot=False): 
    
#     '''Apply a mask to an image to only keep the voxels corresponding to a given region'''
    
#     '''  Note: This functions seems a little overloaded, but I need to be able to pass
#       the image as it undergoes upstream image processing steps. I also need the slice_identifier and list of identifiers
#       to map the image to the label volume. Finally, I need the labels and region_id to get the mask.
#     '''

#     # Get the index of this image with respect to
#     (img_ind, lab_ind), this_atlas_slice = map_img_to_label(slice_identifier, identifiers, labels)

#     # Get the mask this region
#     mask = get_mask_from_label(region_id, labels)
    
#     # Get the corresponding slice of the mask for this image
#     mask_slice = mask[:,:,lab_ind]

#     # Use it for the mask we just created with the full-sized image
#     masked_img = apply_mask_to_img(img, mask_slice)

#     return masked_img


In [12]:
# LABEL_PATH = 'Z://Open_data_sets/EBrains/WHS_SD_rat_atlas_v4_pack/WHS_SD_rat_atlas_v4.label'#"your_label_file.txt"

label_data = parse_itk_snap_label_file(LABEL_PATH)

keys_list = list(label_data.keys())


# Run this part to prove to yourself that the labels are correct
# for region_id in keys_list:#range(3):
    
#     region_label = label_data[region_id]['label'] # Prints the label for index 1
#     # print(region_id, region_label)
#     print(f"Region {region_id}: {region_label}")
#     atlas_slices_with_region = get_slices_containing_region(labels, region_id)
#     plot_slices(labels, atlas_slices_with_region, region_id)



In [30]:
# Now run the same operations on the GPU with clesperanto
import pyclesperanto_prototype as cle

plot=False
subregion_masking = False
keys_list = list(label_data.keys()) # These are the region_ids in tha atlas

cell_counts = [] # storing tuples of (region_id, cell_count)


for slice_id in tqdm(identifiers[1151:-1]):

    # Use the example slice_id from midway through the stack
    # slice_id = identifiers[1000]

    # Load this image
    full_img = io.imread(fullres_folder + slice_id + '.tif') # < 2s

    # Once per image, get the slice of the labels volume that corresponds to this image
    (img_ind, lab_ind), this_atlas_slice = map_img_to_label(slice_id, identifiers, labels)

    # Scale the label slice to match the image dimensions 
    scaled_label_slice = transform.resize(this_atlas_slice, full_img.shape, order=0, preserve_range=True)

    # Apply a top hat filter to the image to remove background
    tophat_img = cle.top_hat_box(full_img, radius_x=20, radius_y=20)
    thresh_img = cle.threshold_otsu(tophat_img)
    lab_img = cle.voronoi_otsu_labeling(thresh_img)

    cell_counts.append((slice_id, len(np.unique(lab_img))))

    if(plot):
        # Compare 3 plots side by side (Expensive operation)
        fig, ax = plt.subplots(1, 3, figsize=(15, 5))
        ax[0].imshow(full_img, cmap='gray', vmin=0, vmax=2000)
        ax[0].set_title('Full resolution')
        ax[1].imshow(scaled_label_slice)#, cmap='gray', vmin=0, vmax=0.05)
        ax[1].set_title('Atlas')
        ax[2].imshow(lab_img, cmap='gray', vmin=0, vmax=1)
        ax[2].set_title('Labelled')
        plt.show()

    if(subregion_masking):

        # for region_id in keys_list:#tqdm(keys_list):#range(3):
        for region_id in tqdm(list(np.unique(this_atlas_slice))):

            region_label = label_data[region_id]['label'] # Prints the label for index 1

            # Get the mask for this region (it is already scaled up to the full image size), verify..
            region_mask = get_mask_from_label(region_id, scaled_label_slice)

            # Apply the mask to the image
            masked_img = cle.mask(lab_img, region_mask)#, masked_img)
            cell_counts.append((slice_id, region_id,region_label,len(np.unique(masked_img))))

            if(plot):
                # Compare 3 plots side by side (Expensive operation)
                plt.clf()
                fig, ax = plt.subplots(1, 3, figsize=(15, 5))
                ax[0].imshow(full_img, cmap='gray', vmin=0, vmax=2000)
                ax[0].set_title('Full resolution')
                ax[1].imshow(region_mask)#, cmap='gray', vmin=0, vmax=0.05)
                ax[1].set_title(region_label)
                ax[2].imshow(masked_img, cmap='gray', vmin=0, vmax=1)
                ax[2].set_title('Labelled & Masked')
                plt.show()

        # print(region_label, 'cells: ', len(np.unique(masked_img)))


            # # Get the index of this image with respect to
            # (img_ind, lab_ind), this_atlas_slice = map_img_to_label(slice_identifier, identifiers, labels)

            # # Get the mask this region
            # mask = get_mask_from_label(region_id, labels)
            
            # # Get the corresponding slice of the mask for this image
            # mask_slice = mask[:,:,lab_ind]

            # Use it for the mask we just created with the full-sized image
            # masked_img = apply_mask_to_img(lab_img, mask_slice)


            # # print(f"Region {region_id}: {region_label}")

            # region_mask = apply_region_mask(lab_img, slice_id, identifiers, labels, region_id)


            # # If region mask is binary
            # if len(np.unique(region_mask)) == 2:

                # cell_counts.append(np.nan)


                # seg_img, pos = label_image(region_mask) #~5s
                # cell_counts.append((slice_id, region_id, pos.shape[0]))
                

100%|██████████| 1614/1614 [5:01:38<00:00, 11.21s/it]  


In [27]:
print(len(cell_counts))
counts_arr = np.asarray(cell_counts)

cell_counts_copy = cell_counts.copy()


#Show summary statistics of counts_arr
print(counts_arr)

40808
[['356850_415210_000000' '0' 'Clear Label' '567468']
 ['356850_415210_000000' '406' 'Secondary motor area' '22']
 ['356850_415210_000000' '427' 'Retrosplenial dysgranular area' '1581']
 ...
 ['356850_415210_046040' '152' 'Secondary auditory area, dorsal part'
  '132']
 ['356850_415210_046040' '153' 'Secondary auditory area, ventral part'
  '39']
 ['356850_415210_046040' '180' 'lateral olfactory tract' '9']]


Summary so far:

By sending the computations to the GPU, we've taken a ~20min per image operation down to ~40s per image. 

Also, it seems as though most of the most useful image processing functions are included in cle. Results more promsing than before. However, the counts and segmentations/ labels clearly aren't yet reflecting well the actua cells, so we'll need to fine-tune each of the steps in the Napari plugin. 
(installed in napari-env but not yet tested. )

Currently the expensive operations are still the scaling up of the label to the actual image dimensions, but this happens only once per image, so not too bad. 

Since the very different intensity between regions seems to be throwing off the counts and not solved by the current version of the tophat filter, it is worth trying to perform these steps on each masked region of each image.. of course with the risk of overfitting being present. 

Next steps will be to try and find ideal parameters for each of the image processing steps in the cle-napari plugin, and to translate them into a Python script. -  Reminder there is a helper for this. 


In [None]:



assert 1==2

In [None]:
# Get an example identifier halfway through the list
# slice_id = identifiers[len(identifiers) // 2]

keys_list = list(label_data.keys())

cell_counts = [] # storing tuples of (region_id, cell_count)

for slice_id in tqdm(identifiers):

    # Load this image
    full_img = io.imread(fullres_folder + slice_id + '.tif') # < 2s

    # Full image segmentation
    tophat_img = top_hat_transform(full_img) #~40-s
    thresh_img = threshold_image(tophat_img) #<2s

# This part only took 2s so may nto be as problemeatic as I thought.

# # Get the index of this image with respect to
# (img_ind, lab_ind), this_atlas_slice = map_img_to_label(slice_identifier, identifiers, labels)

# # Scale up the entire volume for this slice. 
# scaled_label = scale_mask(labels[:,:,lab_ind], thresh_img) 
# assert scaled_label.shape == full_img.shape

# # '''The code below has to be in the loop, so that the mask is applied for each region
# # HOWEVER, the scaling should happen just once per image, so it should be outside the loop
# # '''

    for region_id in tqdm(keys_list):#range(3):
    
        region_label = label_data[region_id]['label'] # Prints the label for index 1
        # print(f"Region {region_id}: {region_label}")

        region_mask = apply_region_mask(thresh_img, slice_id, identifiers, labels, region_id)
    
        # If region mask is binary
        if len(np.unique(region_mask)) == 2:

            seg_img, pos = label_image(region_mask) #~5s

            cell_counts.append((slice_id, region_id, pos.shape[0]))
            
            # print(pos.shape)

            # plt.clf()
            # plt.imshow(seg_img, cmap='gray', vmin=0, vmax=1)
            # plt.title(f"Segmentation: {region_label}")
            # plt.show()

        # else:
        #     print('Skipping region_id: ', region_id)

100%|██████████| 223/223 [18:57<00:00,  5.10s/it]
100%|██████████| 223/223 [18:46<00:00,  5.05s/it]s/it]
100%|██████████| 223/223 [18:55<00:00,  5.09s/it]s/it]
100%|██████████| 223/223 [17:38<00:00,  4.75s/it]s/it]
100%|██████████| 223/223 [17:37<00:00,  4.74s/it]03s/it]
100%|██████████| 223/223 [17:45<00:00,  4.78s/it]53s/it]
100%|██████████| 223/223 [17:30<00:00,  4.71s/it]32s/it]
100%|██████████| 223/223 [17:39<00:00,  4.75s/it]85s/it]
100%|██████████| 223/223 [17:35<00:00,  4.73s/it]27s/it]
100%|██████████| 223/223 [18:12<00:00,  4.90s/it]00s/it]
100%|██████████| 223/223 [18:12<00:00,  4.90s/it].64s/it]
100%|██████████| 223/223 [18:10<00:00,  4.89s/it].93s/it]
100%|██████████| 223/223 [18:12<00:00,  4.90s/it].90s/it]
100%|██████████| 223/223 [18:12<00:00,  4.90s/it].16s/it]
100%|██████████| 223/223 [18:12<00:00,  4.90s/it].55s/it]
100%|██████████| 223/223 [18:10<00:00,  4.89s/it].30s/it]
100%|██████████| 223/223 [18:14<00:00,  4.91s/it].98s/it]
100%|██████████| 223/223 [18:12<00:00

KeyboardInterrupt: 

In [None]:
assert 1==2
for slice in tqdm(region_img_identifiers):

   print(slice)
   masked_fs_img, masked_fos_img, this_atlas_slice = process_slice(slice, plot=False)
   
   # Apply image filtering, transfrom and segmentation operations to masked image
   top_hat_transformed = top_hat_transform(masked_fs_img)

   seg_img, pos = segment_image(top_hat_transformed)
   compare_segmentation(masked_fs_img, seg_img, pos)



AssertionError: 

In [None]:
# # Define the number of workers (threads) you want to use
# num_workers = 10#os.cpu_count()
# print('Number of workers: {}'.format(num_workers))

# # Create a ThreadPoolExecutor and run the process_slice function concurrently
# with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
#     results = list(tqdm(executor.map(process_slice, identifiers), total=len(identifiers)))


In [None]:
# Compare the timing per iteration above with instead just analyzing each image on a slice-by-slice basis
# NOTE: No application of atlas in this test. 
for slice in tqdm(identifiers):

#    print(slice)
#    masked_fs_img, masked_fos_img, this_atlas_slice = process_slice(slice, plot=False)

    full_img = load_single_image(slice, folder='neun')

    # Apply image filtering, transfrom and segmentation operations to masked image
    top_hat_transformed = top_hat_transform(full_img)

    seg_img, pos = segment_image(top_hat_transformed)
    compare_segmentation(masked_fs_img, seg_img, pos)


  0%|          | 0/2766 [05:30<?, ?it/s]


NameError: name 'masked_fs_img' is not defined

In [None]:
# Then, compare this method with the gridded approach

for slice in identifiers:#tqdm(identifiers):


    full_img = load_single_image(slice, folder='neun')

    crop_list = crop_image(full_img, crop_size=(1000,1000), mode='grid')

    # print(type(crop_list))
    for this_crop in tqdm(crop_list):
        # Apply image filtering, transfrom and segmentation operations to masked image
        top_hat_transformed = top_hat_transform(this_crop)

        seg_img, pos = segment_image(top_hat_transformed)
        # compare_segmentation(this_crop, seg_img, pos)

100%|██████████| 150/150 [09:11<00:00,  3.68s/it]
100%|██████████| 150/150 [09:21<00:00,  3.74s/it]
100%|██████████| 150/150 [09:18<00:00,  3.73s/it]
100%|██████████| 150/150 [09:19<00:00,  3.73s/it]
100%|██████████| 150/150 [09:20<00:00,  3.74s/it]
100%|██████████| 150/150 [09:21<00:00,  3.74s/it]
100%|██████████| 150/150 [09:12<00:00,  3.69s/it]
100%|██████████| 150/150 [09:17<00:00,  3.72s/it]
100%|██████████| 150/150 [09:13<00:00,  3.69s/it]
100%|██████████| 150/150 [09:10<00:00,  3.67s/it]
100%|██████████| 150/150 [09:10<00:00,  3.67s/it]
100%|██████████| 150/150 [09:10<00:00,  3.67s/it]
100%|██████████| 150/150 [09:11<00:00,  3.67s/it]
100%|██████████| 150/150 [09:09<00:00,  3.66s/it]
100%|██████████| 150/150 [09:07<00:00,  3.65s/it]
100%|██████████| 150/150 [09:16<00:00,  3.71s/it]
100%|██████████| 150/150 [09:11<00:00,  3.68s/it]
100%|██████████| 150/150 [08:56<00:00,  3.58s/it]
100%|██████████| 150/150 [08:59<00:00,  3.59s/it]
100%|██████████| 150/150 [08:58<00:00,  3.59s/it]


KeyboardInterrupt: 