In [2]:
import matplotlib as mpl
# Configure Matplotlib before pyplot is imported further down this cell.
# NOTE(review): the `%matplotlib inline` magic at the end of this cell selects a
# backend again, so this 'Agg' choice is probably overridden inside Jupyter —
# confirm which backend is actually intended.
mpl.use('Agg')
# Font type 42 (TrueType) embedding for PDF/PS exports, so text is stored as
# fonts rather than Type 3 outlines (commonly done to keep text editable).
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
# Render text with Matplotlib's own engine rather than an external LaTeX install.
mpl.rcParams['text.usetex'] = False
mpl.rcParams['font.sans-serif'] = 'Arial'
mpl.rcParams['font.family'] = 'sans-serif'
# Default figure resolution in dots per inch.
mpl.rcParams['figure.dpi'] = 300

from mpl_toolkits.mplot3d import Axes3D

import time, os, sys, math
from types import SimpleNamespace
from pprint import pprint
from datetime import datetime
from itertools import compress

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import colors

from cellpose import utils, models
import skimage.io
from skimage import data, img_as_int, img_as_uint, img_as_float, filters, feature, morphology, color, exposure, measure
import cv2
from scipy import ndimage as nd

import pickle


%matplotlib inline

In [3]:
## INPUT PARAMS

params = SimpleNamespace()

# This dataset channel order is:
# 0 - C1 - SRSF2 Green
# 1 - C2 - DRAQ5 Red
# NOTE(review): the paths below point at DAXX / HCT116 data; confirm the channel
# labels above still describe these images.

# Make adjustments here:
# Searched recursively for input TIFs whose names contain `params.channel`.
params.parent_path = '/lab/solexa_young/young_ata4/CONDENSATE_DISEASE/imaging/data/210522-SALL1-TEST/daxx_hct_tif/'
# Searched recursively for "*_seg.npy" nuclear-mask files matching each sample name.
params.nuclei_path = '/lab/solexa_young/young_ata4/CONDENSATE_DISEASE/imaging/210726_cancer_nuclear_output/'
# Spot spreadsheets and overlay figures are written here.
params.output_path = '/lab/solexa_young/young_ata4/CONDENSATE_DISEASE/imaging/210805_daxx_hct_output/'

params.channel = 'Airyscan'  # substring every input file name must contain
params.channel_idx = 0  # scaffold: index along axis 1 of a (z, channel, row, col) stack
params.client_idx = 0  # client: same as scaffold here (scaffold == client analysis)

params.log_tm = 3  # LoG spot threshold = mean + log_tm * std of the LoG response
params.erode_selem = np.ones(shape=(21,21))  # footprint used when eroding nuclei

# makedirs also creates missing parent folders; exist_ok makes re-running the cell safe.
os.makedirs(params.output_path, exist_ok=True)



In [4]:
from cellpose import plot, io

def parse_input_path(params, suffix="_Airyscan Processing.tif"):
    """Find the input TIFs for one channel and derive a sample name for each.

    Parameters
    ----------
    params : SimpleNamespace
        Must provide ``parent_path`` (directory searched recursively, including
        subdirectories) and ``channel`` (a substring identifying this channel's
        files, e.g. "488 GFP", "561 CY3", "642 CY5", "Airyscan").
    suffix : str
        Trailing part of the file name stripped to obtain the sample name.
        If a file name does not contain it, the name minus its extension is used.

    Returns
    -------
    (file_list, sample_names)
        ``file_list`` is the sorted list of matching file paths (any file whose
        name contains ".tif" and ``channel``). ``sample_names[i]`` is the sample
        name for ``file_list[i]``.
    """
    parent_path = params.parent_path
    channel = params.channel

    file_list = sorted(
        os.path.join(root, file)
        for root, _dirs, files in os.walk(parent_path)
        for file in files
        if ".tif" in file and channel in file  # try with czi
    )

    sample_names = []
    for path in file_list:
        basename = os.path.basename(path)
        idx = basename.find(suffix)
        # find() returns -1 when the suffix is missing; slicing with -1 would
        # silently drop the last character, so fall back to stripping the extension.
        if idx < 0:
            sample_names.append(os.path.splitext(basename)[0])
        else:
            sample_names.append(basename[:idx])

    return file_list, sample_names


def _spot_centroid_rc(spot):
    """Return a regionprops spot's (row, col) centroid rounded to integer pixel indices."""
    return int(round(spot.centroid[0])), int(round(spot.centroid[1]))


def find_spot_volume(params, channel_idx=None, client_idx=None, seg_type='nuc', erode_nuc_mask_flag=False, save=False):
    """Detect spots in every sample image and tabulate per-spot partition ratios.

    For each name in ``params.sample_names``:
    - load the single matching file in ``params.file_list``;
    - pick ``channel_idx`` and max-project to 2D;
    - call spots with ``find_spots_LoG`` on a median-subtracted image;
    - discard spots whose centroid lies outside the chosen mask.

    For each kept spot, record its mean raw intensity divided by ``c_out``. ``c_out``
    is the mean raw intensity of the spot's nucleus with all kept spot pixels excluded.

    Parameters
    ----------
    params : SimpleNamespace
        Needs ``sample_names``, ``file_list``, ``channel``, ``log_tm`` and
        ``output_path``. Also needs ``nuclei_path`` for 'nuc'/'cyt' and
        ``erode_selem`` when ``erode_nuc_mask_flag`` is True.
    channel_idx : int or None
        Channel selected on axis 1 of a 4-D image before max projection.
        For a 3-D image the whole stack is max-projected along axis 0.
        If None, the image is used without selection or projection.
    client_idx : int or None
        Currently unused: the scaffold channel is also analysed as the client.
        Kept for interface compatibility.
    seg_type : {'nuc', 'cyt', 'all'}
        - 'nuc': spots inside nuclei.
        - 'cyt': spots outside nuclei; the whole non-nuclear area is one region.
        - 'all': entire image.
    erode_nuc_mask_flag : bool
        Erode each nucleus with ``params.erode_selem`` to avoid calling signal
        near the nuclear membrane.
    save : bool
        If True, write ``<sample>_<seg_type>_spot_data.xlsx`` and
        ``<sample>_<seg_type>_spots.png`` to ``params.output_path``. Otherwise
        the figure is shown.

    Raises
    ------
    ValueError
        If ``seg_type`` is not one of 'nuc', 'cyt', 'all'.

    Notes
    -----
    - Per-nucleus values are looked up by label value rather than by list position,
      so non-contiguous labels (e.g. a nucleus removed by erosion) map correctly.
    - 'nuc_mean_intesity' is the mean intensity over all pixels of the spot's label.
      The previous version used only the last connected component of that label.
    """
    if seg_type not in ('nuc', 'cyt', 'all'):
        raise ValueError(f'ERROR: Could not recognize {seg_type} as a segmentation type')

    # 'nuc_mean_intesity' (sic) keeps the column name written by earlier runs so
    # existing readers of the spreadsheets keep working.
    columns = ['spot_id', 'nuc_id', 'voxels', 'mean_intensity', 'c_out',
               'partition_ratio', 'centroid_r', 'centroid_c',
               # 'centroid_z',
               'nuc_mean_intesity']

    for sample in params.sample_names:
        print()
        print(f'Running {sample} ...')

        img_path = [p for p in params.file_list if sample in p and params.channel in p]
        if len(img_path) != 1:
            print(f'Warning: expected exactly one image for {sample}, found {len(img_path)}; skipping')
            continue

        img0 = skimage.io.imread(img_path[0])
        print(f'Image has shape {img0.shape}')

        if channel_idx is not None:
            if img0.ndim != 3:
                # assumes a (z, channel, row, col) stack — TODO confirm for other inputs
                img0 = img0[:, channel_idx, :, :]
            img0 = max_project(img0)

        img_integer = img0  # projected image in its original dtype
        img0 = img_as_float(img0)
        print(f'Channel shape is {img0.shape}')

        # Multiple options for preprocessing ("double", "subtraction")
        filtered_img = preprocess_img(img0, kind="subtraction", med_filter_size=(30, 30), sigma=2)

        # Segmentation type
        if seg_type == 'all':
            nuclear_mask = np.ones(shape=img0.shape, dtype=bool)
            nuclear_mask_for_display = nuclear_mask.astype(int)
        else:
            nuclear_mask = load_nuclear_segmentation(params, sample)
            if erode_nuc_mask_flag:
                # Specific to this dataset, as there is strong INSR signal at the nuclear periphery
                nuclear_mask = erode_nuc_mask(nuclear_mask, selem=params.erode_selem)
            nuclear_mask_for_display = nuclear_mask.copy()
            if seg_type == 'cyt':
                # Invert: every pixel outside all nuclei becomes True.
                nuclear_mask = ~(nuclear_mask > 0)

        # Integer label image for all per-nucleus lookups; boolean masks
        # ('all' / 'cyt') become a single region with label 1.
        label_mask = np.asarray(nuclear_mask).astype(int)

        # print nuclear intensities
        get_nuclear_intensity(img_integer, nuclear_mask)

        # Scaffold and client are the same channel, so the filtered image serves as both.
        spots, spots_raw, _ = find_spots_LoG(filtered_img, filtered_intensity_img=filtered_img,
                                             raw_intensity_img=img0, tm=params.log_tm, plot_flag=False)

        # Keep spots whose centroid lies inside the mask. Mark each kept spot's pixels
        # so they can be excluded from the per-nucleus background estimate.
        inverted_total_spot_mask = np.ones(shape=img0.shape, dtype=bool)
        spots_to_keep = np.ones(len(spots), dtype=bool)
        for idx, spot in enumerate(spots):
            r, c = _spot_centroid_rc(spot)
            if label_mask[r, c] < 1:
                spots_to_keep[idx] = False
            else:
                inverted_total_spot_mask[spot.coords[:, 0], spot.coords[:, 1]] = False  # 2D
        # spots and spots_raw are regionprops of the same label image, so they share an order.
        spots = list(compress(spots_raw, spots_to_keep))

        # c_out: mean raw intensity of each nucleus with all spot pixels excluded.
        c_out_raw = {}
        for n in np.unique(label_mask):
            if n > 0:  # because 0 is background
                combined_mask = (label_mask == n) & inverted_total_spot_mask
                c_out_raw[int(n)] = np.mean(img0[combined_mask])

        # Mean intensity of each whole nucleus, in the image's original dtype scale.
        nuc_mean_by_label = {
            region.label: region.mean_intensity
            for region in measure.regionprops(label_mask, intensity_image=img_integer)
        }

        filtered_labeled_img = np.zeros(shape=img0.shape)
        records = []
        for idx, spot in enumerate(spots):
            r, c = _spot_centroid_rc(spot)
            nuc_id = int(label_mask[r, c])
            filtered_labeled_img[spot.coords[:, 0], spot.coords[:, 1]] = idx + 1  # 2D

            mean_intensity = spot.mean_intensity
            nuc_c_out = c_out_raw[nuc_id]
            records.append({
                'spot_id': idx + 1,
                'nuc_id': nuc_id,
                'voxels': spot.area,
                'mean_intensity': mean_intensity,
                'c_out': nuc_c_out,
                'partition_ratio': mean_intensity / nuc_c_out,
                'centroid_r': r,
                'centroid_c': c,
                'nuc_mean_intesity': nuc_mean_by_label.get(nuc_id, np.nan),
            })
        # Build the table once; DataFrame.append was removed in pandas 2.0.
        spot_data = pd.DataFrame(records, columns=columns)

        if save:
            spot_data.to_excel(os.path.join(params.output_path, f'{sample}_{seg_type}_spot_data.xlsx'), header=True, index=False)

        # Figure panels: nuclear labels | contrast-equalised channel | spot overlay.
        fig, ax = plt.subplots(1, 3, figsize=(15, 5))
        under_img = exposure.equalize_adapthist(img0)  # 2D
        overlay = color.label2rgb(filtered_labeled_img, image=under_img, bg_label=0, bg_color=[0, 0, 0], alpha=0.5)  # 2D

        # plt.get_cmap replaces mpl.cm.get_cmap (deprecated in Matplotlib 3.7, later removed).
        overlay_nuc_mask = color.label2rgb(nuclear_mask_for_display, colors=plt.get_cmap('tab20b').colors,
                                           bg_label=0, bg_color=[0, 0, 0])
        ax[0].imshow(overlay_nuc_mask)
        text_offset = 15
        for region in measure.regionprops(nuclear_mask_for_display):
            ax[0].text(region.centroid[1] - text_offset, region.centroid[0] + text_offset,
                       str(region.label), fontsize=15, color='w')

        ax[1].imshow(under_img, cmap='gray')
        ax[2].imshow(overlay)
        for a in ax:
            a.set_axis_off()
        fig.tight_layout()

        if save:
            fig.savefig(os.path.join(params.output_path, f'{sample}_{seg_type}_spots.png'), dpi=300)
            plt.close(fig)
        else:
            plt.show()
        print('Completed')
    
    
def preprocess_img(img0, kind="subtraction", max_filter_size=10, med_filter_size=(10,10), sigma=1):
    """Enhance punctate signal in `img0` before spot calling.

    kind="double"
        Maximum filter of size `max_filter_size`, then median filter of size
        `med_filter_size`.
    kind="subtraction"
        Subtract a median-filtered background (size `med_filter_size`), clip
        negatives to 0, then Gaussian-smooth with `sigma`.

    Returns a new array; `img0` is not modified.

    Raises
    ------
    ValueError
        If `kind` is neither "double" nor "subtraction" (previously this
        surfaced as an UnboundLocalError).
    """
    if kind == "double":
        # A scalar max_filter_size applies to every axis; this may throw errors
        # if the z-dimension is smaller than 10.
        output_img = nd.maximum_filter(img0, size=max_filter_size)
        return nd.median_filter(output_img, size=med_filter_size)

    if kind == "subtraction":
        background = nd.median_filter(img0, size=med_filter_size)
        # NOTE(review): expects a float image — unsigned ints would wrap around here.
        output_img = img0 - background
        output_img[output_img < 0] = 0
        return filters.gaussian(output_img, sigma=sigma)

    raise ValueError(f'Unknown preprocessing kind {kind!r}; expected "double" or "subtraction"')


def find_spots_LoG(img, filtered_intensity_img, raw_intensity_img, tm=4, selem=np.ones(shape=(3,3)), plot_flag=True, sigma=10.0):
    r"""Call spots by thresholding the negated Laplacian-of-Gaussian (LoG) response of `img`.

    Pixels whose -LoG response exceeds ``mean + tm * std`` (computed over the
    whole response image) are kept. They are cleaned with a binary opening by
    `selem` and labelled into connected regions.

    The radius of each blob is approximately :math:`\sqrt{2}\sigma` for
    a 2-D image and :math:`\sqrt{3}\sigma` for a 3-D image.

    Parameters
    ----------
    img : ndarray
        Image the LoG is computed on (the caller passes a 2-D preprocessed image).
    filtered_intensity_img, raw_intensity_img : ndarray
        Intensity images attached to the two returned region lists.
    tm : float
        Threshold in standard deviations above the mean of the -LoG response.
    selem : ndarray
        Footprint for the binary opening (2-D).
    plot_flag : bool
        Show the input, the -LoG response and the binary mask side by side.
    sigma : float
        Gaussian width of the LoG. The default 10.0 is the previously hard-coded value.

    Returns
    -------
    spots, spots_raw : list
        regionprops of the same labelled regions (same order), measured on
        `filtered_intensity_img` and `raw_intensity_img` respectively.
    labeled_img : ndarray
        Label image of the detected spots.
    """
    gl_img = -nd.gaussian_laplace(img, sigma=sigma)
    threshold = np.mean(gl_img) + tm * np.std(gl_img)

    binary_img = gl_img > threshold
    binary_img = morphology.binary_opening(binary_img, selem)  # 2D
    labeled_img = measure.label(binary_img)
    spots = measure.regionprops(labeled_img, intensity_image=filtered_intensity_img)
    spots_raw = measure.regionprops(labeled_img, intensity_image=raw_intensity_img)

    if plot_flag:
        def _to_2d(arr):
            # Max-project 3-D stacks along axis 0. 2-D images are shown as-is, because
            # projecting them would collapse to 1-D and break imshow.
            return np.max(arr, axis=0) if arr.ndim == 3 else arr

        plt.close()
        fig, ax = plt.subplots(1, 3, figsize=(20, 7))
        ax[0].imshow(exposure.equalize_adapthist(_to_2d(img)), cmap='gray')
        ax[1].imshow(_to_2d(gl_img), cmap='gray')
        ax[2].imshow(_to_2d(binary_img), cmap='gray')
        for a in ax:
            a.set_axis_off()
        fig.tight_layout()
        plt.show()

    return spots, spots_raw, labeled_img


def find_spots_threshold(img, intensity_img, nuclear_mask, tm=4, selem=np.ones(shape=(3,3)), plot_flag=True):
    
    #threshold of just nuclei
    nuc_intensities=img*img_as_float(nuclear_mask)
    print(nuc_intensities[nuc_intensities > 0 ])
    threshold = np.mean(nuc_intensities[nuc_intensities > 0 ]) + tm*np.std(img*nuclear_mask)
    print(threshold)
    
    binary_img = np.zeros(shape=img.shape, dtype=bool)
    binary_img[img >= threshold] = True
    
#     binary_img = morphology.opening(binary_img, selem=np.ones(shape=(1, 3,3)))  # 3D
    binary_img = morphology.binary_opening(binary_img, selem)  # 2D
    labeled_img = measure.label(binary_img)
    spots = measure.regionprops(labeled_img, intensity_image=intensity_img)
    
    if plot_flag:
        plt.close()
        fig, ax = plt.subplots(1,3,figsize=(20,60))
        ax = ax.flatten()
        ax[0].imshow(exposure.equalize_adapthist(max_project(img)), cmap='gray')
        ax[1].imshow(max_project(gl_img), cmap='gray')
        ax[2].imshow(max_project(binary_img), cmap='gray')
        for a in ax:
            a.set_axis_off()
        plt.tight_layout()
        plt.show()
        
    return spots, labeled_img

def get_nuclear_intensity(intensity_img, nuclear_mask):
    """Print the mean of `intensity_img` over each connected region of `nuclear_mask`.

    Regions come from ``measure.label(nuclear_mask)``. Values are printed one
    per line in regionprops order. Nothing is returned.
    """
    regions = measure.regionprops(measure.label(nuclear_mask), intensity_image=intensity_img)
    region_means = [region.mean_intensity for region in regions]
    for value in region_means:
        print(value)
    
    
def fix_mask_labels(mask):
    """Renumber the positive labels of `mask` to 1..K, preserving their order.

    Pixels <= 0 (background) are left unchanged; e.g. labels {0, 2, 5, 9}
    become {0, 1, 2, 3}. As before, `mask` is modified in place and also returned.

    The previous loop-based version never terminated when `mask` had no 0
    pixels, and rescanned the whole image once per label.
    """
    labels = np.unique(mask)
    positive = labels[labels > 0]
    # Sorted distinct positive integers are exactly 1..K iff the largest equals K.
    if positive.size == 0 or positive[-1] == positive.size:
        return mask

    # Position of each pixel's label within the sorted positive labels, +1 -> new label.
    new_labels = np.searchsorted(positive, mask) + 1
    mask[...] = np.where(mask > 0, new_labels, mask)
    return mask


def load_nuclear_segmentation(params, sample_name):
    """Load the nuclear label mask for `sample_name` from ``params.nuclei_path``.

    Searches ``params.nuclei_path`` recursively for exactly one file whose
    name contains both "_seg.npy" and `sample_name`. Returns its ``'masks'``
    entry.

    Raises
    ------
    FileNotFoundError
        If no matching file exists (previously this called sys.exit(0)).
    ValueError
        If several files match. Previously this returned None, which failed later.
    """
    print(sample_name)
    matches = sorted(
        os.path.join(root, file)
        for root, _dirs, files in os.walk(params.nuclei_path)
        for file in files
        if "_seg.npy" in file and sample_name in file
    )

    if not matches:
        raise FileNotFoundError(f'Could not find nuclei file for {sample_name} in {params.nuclei_path}')
    if len(matches) > 1:
        raise ValueError(f'Found multiple nuclei files for {sample_name}: {matches}')

    # The file holds a 0-d object array wrapping a dict (hence .item()), which
    # requires allow_pickle=True.
    # NOTE(review): unpickling can execute code — only point nuclei_path at trusted files.
    seg = np.load(matches[0], allow_pickle=True).item()
    return seg['masks']


def erode_nuc_mask(mask, selem=np.ones(shape=(3,3))):
    """Erode every labelled nucleus in `mask` independently and return a new label image.

    Each positive label is eroded with the footprint `selem`, hole-filled, and
    written back with its original label into an int array of the same shape.
    Labels that erode away entirely disappear from the output, so the labels
    may no longer be contiguous. Where filled regions overlap, the later
    (higher) label wins.
    """
    output_labeled_mask = np.zeros(shape=mask.shape, dtype=int)

    for label in np.unique(mask):
        if label <= 0:  # 0 is background
            continue
        # Footprint passed positionally: the `selem=` keyword is deprecated in
        # scikit-image (renamed `footprint`).
        eroded = morphology.binary_erosion(mask == label, selem)
        eroded = nd.binary_fill_holes(eroded)
        output_labeled_mask[eroded] = int(label)

    return output_labeled_mask
    


def max_project(img):
    """Return the maximum-intensity projection of `img` along its first axis (axis 0)."""
    return np.amax(img, axis=0)


def find_img_channel_name(file_name, marker='Conf ', width=3):
    """Extract a channel tag such as 'ch488' from a microscope file name.

    Returns 'ch' followed by the `width` characters that come right after
    `marker`. The defaults match our microscope's naming format, e.g.
    '..._Conf 488...' -> 'ch488'.

    Raises
    ------
    ValueError
        If `marker` does not occur in `file_name`. Previously this returned a
        meaningless slice taken from the start of the name.
    """
    str_idx = file_name.find(marker)
    if str_idx < 0:
        raise ValueError(f'Could not find {marker!r} in file name {file_name!r}')
    start = str_idx + len(marker)
    return 'ch' + file_name[start:start + width]


def get_file_extension(file_path):
    """Return the extension of `file_path` including the leading dot ('' when there is none)."""
    _stem, extension = os.path.splitext(file_path)
    return extension


def get_sample_name(file_path, extension='.nd'):
    """Derive a sample name from a file path.

    - Takes the base name up to the first "_w".
    - If there is no "_w", uses the base name without its extension. Previously a
      missing "_w" silently dropped the last character.
    - Removes "_Airyscan Processing".
    - Replaces every "." with "dot".

    `extension` is not used; it is kept for interface compatibility.
    """
    basename = os.path.basename(file_path)
    idx = basename.find("_w")
    if idx >= 0:
        sample_name = basename[:idx]
    else:
        sample_name = os.path.splitext(basename)[0]

    sample_name = sample_name.replace("_Airyscan Processing", "")
    sample_name = sample_name.replace(".", "dot")
    return sample_name


def axis_off(ax):
    """Turn off the axis of `ax` by calling its ``set_axis_off()`` method; returns None."""
    hide_axis = ax.set_axis_off
    hide_axis()



In [5]:
### MAIN ###

# Discover the input images and derive one sample name per image.
file_list, sample_names = parse_input_path(params)
params.file_list = file_list
params.sample_names = sample_names

# Call spots inside (non-eroded) nuclei and save spreadsheets + figures to params.output_path.
find_spot_volume(
    params,
    channel_idx=params.channel_idx,
    client_idx=params.client_idx,
    seg_type='nuc',
    erode_nuc_mask_flag=False,
    save=True,
)

print(f'Completed at {datetime.now()}')




Running 210716 HCT116 DAXX-M_2021_07_19__17_50_19 ...
Image has shape (9, 2, 1560, 1560)
Channel shape is (1560, 1560)
210716 HCT116 DAXX-M_2021_07_19__17_50_19
0.06999626726390444
0.05889423076923077
0.7216476398382421
0.37036948406241027
64.99302140537156
32.62294703642448
0.06277268678573046
0.06611570247933884
0.04952503919579452
0.03762185511645555
Completed

Running 210716 HCT116 DAXX-M_2021_07_19__17_53_23 ...
Image has shape (10, 2, 1560, 1560)
Channel shape is (1560, 1560)
210716 HCT116 DAXX-M_2021_07_19__17_53_23
0.3171916754393068
0.3953139241863707
0.7670097753609115
0.35560269419174245
0.509967181232527
1.0087679592281513
13.035638238260674
72.79212598425197
180.04348231491267
0.3557521572536681
8.495742879904402
54.434630738522955
114.98776727595171
2.031355830909598
0.6367571884984026
0.6605118059753439
0.3596070760699009
183.09405792348167
5.696126603279944
1.0093001261034047
0.3521508544490277
0.5001380834023751
0.26210319892143646
0.4592725764071359
0.998860868882080