# Preprocessing Notebook

Applying a high-pass Gaussian filter, followed by Richardson-Lucy deconvolution and then a low-pass Gaussian blur, appears to resolve dots best. This sequence also does not appear to increase the intensity of "fake dots", and it separates nearby dots well. The method was adapted from the following reference.

1. Moffitt, J. R. et al. High-throughput single-cell gene-expression profiling with multiplexed error-robust fluorescence in situ hybridization. PNAS 113, 11046–11051 (2016).


In [1]:
# When True, later cells use hard-coded test values instead of prompting with input().
debug = True

When `debug` is `True`, the notebook skips the interactive prompts and runs on hard-coded test inputs.

In [2]:
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2


#Switch Directories to import functions
#---------------------------------------------
old_cwd = os.getcwd()
os.chdir('/home/nrezaee/data-pipeline/')
#---------------------------------------------


try:
    #Importing Pipeline Functions
    #---------------------------------------------
    from datapipeline_tools.dot_detection.preprocessing.preprocess import blur_back_subtract_3d
    from datapipeline_tools.dot_detection.preprocessing.preprocess import blur_3d
    from datapipeline_tools.dot_detection.preprocessing.preprocess import tophat_3d

    from datapipeline_tools_dev.dot_detection.helpers.background_subtraction import get_shifted_background

    import load_tiff
    #---------------------------------------------
finally:
    #Switch Back
    #---------------------------------------------
    # Done in a finally block: if an import fails, the working directory is
    # still restored, so re-running this cell does not record the pipeline
    # directory as old_cwd.
    os.chdir(old_cwd)
    #---------------------------------------------

In [3]:
import plotly.express as px
import plotly.graph_objects as go
import numpy as np

def plot_2d_image(img_2d, zmax):
    """
    Show one 2d image as a 700x700 Plotly figure.

    Inputs:
        img_2d : 2d image handed to plotly.express.imshow
        zmax : upper intensity limit forwarded to imshow
    Outputs:
        None; the figure is displayed with fig.show()
    """

    #Options forwarded to px.imshow
    #-------------------------------------------
    imshow_options = dict(
        width=700,
        height=700,
        binary_string=True,
        binary_compression_level=4,
        binary_backend='pil',
        zmax=zmax,
    )
    #-------------------------------------------

    px.imshow(img_2d, **imshow_options).show()

## Get Tiff file Path

In [4]:

# if debug == True:
    
#     #Get inputs about tiff src
#     #---------------------------------------------
#     personal = 'Lex'
#     exp_name = '/groups/CaiLab/personal/Lex/raw/20k_dash_063021_3t3/'
#     hyb = '0'
#     position = '0'
#     num_channels = '5'
#     #---------------------------------------------

#     #Create tiff src
#     #---------------------------------------------
#     tiff_src = os.path.join('/groups/CaiLab/personal', personal, 'raw', exp_name, 'HybCycle_' + hyb, 'MMStack_Pos' + str(position) + '.ome.tif')
#     print(f'{tiff_src=}')
#     #---------------------------------------------
    
    
# else:
#     #Get inputs about tiff src
#     #---------------------------------------------
#     personal = input('Personal [nrezaee] : ')
#     exp_name = input('Experiment Name [2020-08-08-takei]: ');
#     hyb = input('HybCycle [2] : ');
#     position = input('Position [0]: ');
#     num_channels = input('Number of Channels in tiff [0]: ');
#     #---------------------------------------------

#     #Create tiff src
#     #---------------------------------------------
#     tiff_src = os.path.join('/groups/CaiLab/personal', personal, 'raw', exp_name, 'HybCycle_' + hyb, 'MMStack_Pos' + str(position) + '.ome.tif')
#     print(f'{tiff_src=}')
#     #---------------------------------------------

### Load Tiff

In [None]:
import sys

#Import load tiff function from pipeline
#---------------------------------------------
from datapipeline_tools.load_tiff import tiffy
#---------------------------------------------

#Load tiff
#---------------------------------------------
# NOTE(review): the source directory and the "3" argument are hard-coded.
# What "3" means depends on tiffy.load, which is not visible here -- confirm.
tiff = tiffy.load("/groups/CaiLab/personal/Lex/raw/20k_dash_063021_3t3/notebook_pyfiles/aberr_corrected", "3")
#---------------------------------------------

# Report the dimensions of the loaded stack
print(f'{tiff.shape=}')

In [None]:
# Exchange the first two axes; later cells index the result as tiff[:, ch].
# Presumably this turns a (c, z, ...) stack into (z, c, ...) -- confirm against tiffy.load.
tiff = np.swapaxes(tiff,0,1)

### Get Specific Channel

In [None]:
# Channel to inspect: fixed to 0 in debug mode, otherwise asked interactively
ch = 0 if debug else int(input('Channel to Visualize (Index starting at Zero): '))

# Every slice along the first axis, restricted to the chosen channel on axis 1
tiff_3d = tiff[:, ch]

### Function for plotting tiff 3d

In [None]:
def plot_tiff_3d(tiff_ch, vmax=10, figsize=(10,10), log = True):
    """
    Draw every slice of a 3d stack in its own grayscale matplotlib figure.

    Inputs:
        tiff_ch : 3d image stack; one figure is made per entry along its first axis
        vmax : upper limit of the gray scale passed to imshow (larger values give a darker picture)
        figsize : size of each figure
        log : when True the natural log of the stack is shown instead of the raw values
    Outputs:
        None; the figures are created as a side effect
    """

    #Optionally take the log of the stack for visualization
    #----------------------------------------
    stack = np.log(tiff_ch) if log == True else tiff_ch
    #----------------------------------------

    #One figure per slice
    #----------------------------------------
    for plane in stack:
        plt.figure(figsize=figsize)
        plt.imshow(plane, cmap='gray', vmax=vmax)
    #----------------------------------------
    
#Plot Raw image (log scale, since plot_tiff_3d defaults to log=True)
# NOTE(review): this always shows channel 0 rather than the selected `ch` / tiff_3d -- confirm.
#----------------------------------------
plot_tiff_3d(tiff[:, 0], figsize= (10,10), vmax=8)
#----------------------------------------

## Background Subtract the Image

In [None]:
def yn_choice(message, default='y'):
    """
    Ask a yes/no question with input() and return the answer as a bool.

    Inputs:
        message : text shown before the "(Y/n)" or "(y/N)" hint
        default : 'y' or 'yes' (any case) makes an empty answer count as yes;
                  any other value makes an empty answer count as no
    Outputs:
        True when the stripped, lower-cased answer is 'y' or 'yes', or when it
        is empty and the default is yes; False otherwise
    """
    default_is_yes = default.lower() in ('y', 'yes')
    choices = 'Y/n' if default_is_yes else 'y/N'
    answer = input("%s (%s) " % (message, choices)).strip().lower()

    #An empty answer falls back to the default
    if answer == '':
        return default_is_yes
    return answer in ('y', 'yes')

# Decide whether to background-subtract: always on in debug mode, otherwise ask the user
if debug:
    bool_back_sub = True
else:
    bool_back_sub = yn_choice('Do you want background subtraction?')
    print(f'{bool_back_sub=}')
    
# Without background subtraction the selected channel is passed on unchanged
if bool_back_sub == False:
    back_subtracted_3d = tiff_3d

### Get the Shifted Background

In [None]:
if bool_back_sub == True:
    
    # Analysis name is forwarded to get_shifted_background (the input prompt says it is used to get offsets)
    if debug:
        analysis_name = 'lex_align'
    else:
        analysis_name = input('Analysis Name of Experiment (in order to get offsets): ')

    #Get Background
    #----------------------------------------
    # NOTE(review): personal, exp_name and position are only assigned in the
    # commented-out "Get Tiff file Path" cell, so this line raises NameError on
    # a fresh run unless they are defined elsewhere -- confirm.
    final_background_src = os.path.join('/groups/CaiLab/personal', personal, 'raw', exp_name, 'final_background', 'MMStack_Pos' + str(position) + '.ome.tif')
    # NOTE(review): tiffy.load gets one argument here but two in the "Load Tiff" cell -- verify the signature.
    back_tiff = tiffy.load(final_background_src)
    back_ch = back_tiff[:, ch]
    #----------------------------------------

    #Shift the Background
    #----------------------------------------
    # NOTE(review): tiff_src is likewise only assigned in commented-out code.
    # shifted_back_ch is not read by any later cell in this view; the blur and
    # subtraction cells use the unshifted back_ch -- confirm that is intended.
    shifted_back_ch = get_shifted_background(back_ch, tiff_src, analysis_name)
    #----------------------------------------
    
    #Plot the unshifted background channel (log scale, plot_tiff_3d default)
    #----------------------------------------
    plot_tiff_3d(back_ch, figsize= (10,10), vmax=7)
    #----------------------------------------


### Subtract the Image and plot Raw Image, Background, and Subtracted Image

In [None]:
if bool_back_sub == True:
    
    #Blur Background Beforehand
    #----------------------------------------
    # NOTE(review): this blurs back_ch, not shifted_back_ch from the previous cell -- confirm.
    background_blur_kernel_size = 1
    back_ch_blurred = blur_3d(back_ch, background_blur_kernel_size).astype(np.uint16)
    #----------------------------------------
    
    
    #Subtract the blurred background from the selected channel
    #----------------------------------------
    # cv2.subtract is used rather than "-"; presumably for its saturating arithmetic on unsigned data -- confirm.
    back_subtracted_3d = cv2.subtract(tiff_3d, back_ch_blurred)
    #----------------------------------------
    
    #plot_tiff_3d(back_subtracted_3d, log = False)

    # Save one three-panel check figure per z slice (raw, background, subtracted)
    back_check_dir_dst = 'Background_Subtraction_Checks'
    os.makedirs(back_check_dir_dst, exist_ok = True)
    for i in range(len(tiff_3d)):

        # NOTE(review): figures are never closed, so one stays open per z slice.
        fig, axs = plt.subplots(1, 3, figsize = (20,20))
        # Raw and background panels are shown as natural log; zero pixels become -inf
        axs[0].imshow(np.log(tiff_3d[i]), cmap='gray', vmax= 8)
        axs[0].title.set_text('Raw image Z ' + str(i)) 

        axs[1].imshow(np.log(back_ch_blurred[i]), cmap='gray', vmax = 8)
        axs[1].title.set_text('Background image Z ' + str(i)) 

        # Subtracted panel is shown in raw intensities (no log)
        axs[2].imshow(back_subtracted_3d[i], cmap='gray', vmax=1200)
        axs[2].title.set_text('Back Subtracted image image Z ' + str(i)) 

        fig_dst = os.path.join(back_check_dir_dst, 'z_' + str(i) + '.png')
        fig.savefig(fig_dst)

In [None]:
#Plot the background-subtracted image (raw intensities, no log)
#----------------------------------------
plot_tiff_3d(back_subtracted_3d, figsize= (10,10), vmax=3000, log=False)
#----------------------------------------

# Skimage Rolling ball

In [None]:
# from skimage import color, data, restoration, util
# image = []
# for i in range(2):
#     background = restoration.rolling_ball(back_subtracted_3d[i], radius=5, num_threads=8)
#     rb_processed = back_subtracted_3d[i]-background
#     image.append(rb_processed)
# rb_processed = np.array(image)

In [None]:
# plot_2d_image(rb_processed[0], zmax=2000)

In [None]:
# plot_2d_image(back_subtracted_3d[0], zmax=2000)

# High pass mean filter

In [None]:
import cv2
from skimage import util
import warnings

warnings.filterwarnings("ignore")

def high_pass_mean(img, kern=25):
    """A high pass mean filter

    Every 2D plane is blurred with a kern x kern averaging kernel, the blur is
    subtracted from the plane and negative differences are set to zero.

    Parameters
    ----------
    img = array indexed as z,c,x,y
    kern = int, side length of the square averaging kernel

    Returns
    -------
    array of the filtered planes, indexed as z,c,x,y

    NOTE(review): both operands go through skimage's img_as_int before the
    subtraction, so the output is on that converted scale rather than the
    input scale -- confirm this rescaling is intended.
    """
    #averaging kernel whose weights sum to one
    box = np.ones((kern,kern),np.float32)/kern**2

    def _filter_plane(plane):
        #mean filter
        background = cv2.filter2D(plane,-1,box)
        #subtract on a signed type so the difference can drop below zero
        residual = util.img_as_int(plane)-util.img_as_int(background)
        #set negative values to zero
        residual[residual<0]=0
        return residual

    return np.array([
        [_filter_plane(img[z][c]) for c in range(img.shape[1])]
        for z in range(img.shape[0])
    ])

In [None]:
hpmf_image = high_pass_mean(tiff, kern=25)

In [None]:
plot_2d_image(hpmf_image[1][0], zmax=1000)

A high pass mean filter matches the output of Nick's rolling ball. Note that ImageJ and skimage define rolling ball differently. ImageJ applies a mean blur with a circular selection element and then subtracts the blurred image from the original image. Skimage computes the local minimum over a given area using a circular selection element and then subtracts that value from the original image.

# High pass gaussian filter

In [None]:
import cv2
from skimage import util
import warnings

warnings.filterwarnings("ignore")

def high_pass_gaussian(img, kern=9, sigma=4):
    """A high pass gaussian filter

    Every 2D plane is blurred with a kern x kern gaussian, the blur is
    subtracted from the plane and negative differences are set to zero.

    Parameters
    ----------
    img = array indexed as z,c,x,y
    kern = int, side length of the gaussian kernel handed to cv2.GaussianBlur
    sigma = gaussian standard deviation handed to cv2.GaussianBlur as sigmaX.
        The default of 4 reproduces earlier results: the blur call used to pass
        cv2.BORDER_DEFAULT (numerically 4) in the sigmaX position. Pass 0 to
        let OpenCV derive sigma from kern.

    Returns
    -------
    array of the filtered planes, indexed as z,c,x,y

    NOTE(review): both operands go through skimage's img_as_int before the
    subtraction, so the output is on that converted scale rather than the
    input scale -- confirm this rescaling is intended.
    """
    #blur the image and subtract
    z_slice = []
    for z in range(img.shape[0]):
        channel_slice = []
        for c in range(img.shape[1]):
            plane = img[z][c]
            #gaussian filter; sigma and border mode are each passed in their own slot
            blur = cv2.GaussianBlur(plane, (kern,kern), sigma, borderType=cv2.BORDER_DEFAULT)
            #subtract on a signed type so the difference can drop below zero
            filtered = util.img_as_int(plane)-util.img_as_int(blur)
            #set negative values to zero
            filtered[filtered<0]=0
            channel_slice.append(filtered)
        z_slice.append(channel_slice)
    return np.array(z_slice)

In [None]:
hpgb_image = high_pass_gaussian(tiff, kern=9)

In [None]:
plot_2d_image(hpgb_image[1][0], zmax=700)

# Deconvolution using Richardson-Lucy algorithm: Use after some form of high pass filter to prevent fake dots from becoming real dots

In [None]:
from skimage import color, data, restoration, util
from tqdm import tqdm

def gaussian_kernel(size, sigma, normalize=True):
    """Generates a size x size gaussian kernel centered on the grid.

    With normalize=True (default) the values are rescaled so that the sum of
    all values equals 1. Because the gaussian is truncated to a finite grid,
    the analytic 1/(2*pi*sigma**2) prefactor alone leaves the sum below 1.

    Parameters
    ----------
    size = int, number of rows and columns
    sigma = number, standard deviation in pixels (must not be 0)
    normalize = bool, rescale the kernel to sum to 1; False returns the
        analytic gaussian values

    Returns
    -------
    2D float array of shape (size, size)

    Raises
    ------
    ValueError if sigma is 0
    """
    if sigma == 0:
        raise ValueError("sigma must be non-zero")

    #center of the grid: a pixel for odd sizes, between pixels for even sizes
    center = (size - 1) / 2
    X, Y = np.mgrid[0:size, 0:size]
    exponent = ((X - center)**2 + (Y - center)**2) / (2 * sigma * sigma)
    kern = np.exp(-exponent) / (2*np.pi*(sigma**2))

    if normalize:
        #compensate for the part of the gaussian that falls outside the grid
        kern = kern / kern.sum()

    return kern

def RL_deconvolution(image, kernel=1, sigma=(1.8,1.6,1.5,1.3), microscope = "boc", iterations=20):
    """Assuming a gaussian psf, images are deconvoluted using the richardson-lucy algorithm

    Every 2D plane image[z][c] is deconvolved with a gaussian psf whose sigma
    depends on the channel index c.

    Parameters
    ----------
    image = array indexed as z,c,x,y (a single z is passed as length 1 on the first axis)
    kernel = int, side length of the square psf
    sigma = define sigma at each channel (750nm,647nm,555nm,488nm); only used
        when microscope is None
    microscope = use preset sigmas for defined scope ("boc" and "lb"); None
        means use the sigma argument instead
    iterations = int, number of richardson-lucy iterations per plane

    Returns
    -------
    psf, img_arr
        psf = the psf built for the last channel
        img_arr = deconvolved stack converted with img_as_uint, indexed as z,c,x,y

    A single-z stack used to take its own branch, which indexed the z axis with
    the channel counter (image[c]). It now goes through the same path as any
    other stack, so its result keeps the leading z axis of length 1.

    Raises
    ------
    KeyError if microscope is not None and not a known preset
    ValueError if image has no channels
    IndexError if image has more channels than there are sigmas
    """
    # defined sigma from testing
    sigma_dict = {"boc":[1.8,1.6,1.5,1.3],"lb":[2.0,1.7,1.3,1.2]}

    #a preset microscope takes precedence over the sigma argument
    ch_sigma = sigma if microscope is None else sigma_dict[microscope]

    n_channels = image.shape[1]
    if n_channels == 0:
        raise ValueError("image has no channels to deconvolve")
    if n_channels > len(ch_sigma):
        raise IndexError(
            "image has %d channels but only %d sigmas are defined"
            % (n_channels, len(ch_sigma))
        )

    #the psf depends only on the channel, so build each one once
    psfs = [gaussian_kernel(kernel, ch_sigma[c]) for c in range(n_channels)]

    z_slice = []
    #go across z's and channels
    for z in tqdm(range(image.shape[0])):
        channel_slice = []
        for c in range(n_channels):
            #small offset added before deconvolution; presumably keeps pixels away from exact zero
            adj_img = util.img_as_float(image[z][c]) + 1E-4
            deconvolved_RL = restoration.richardson_lucy(adj_img, psfs[c], iterations)
            channel_slice.append(deconvolved_RL)
        z_slice.append(channel_slice)

    img_arr = util.img_as_uint(np.array(z_slice))
    return psfs[-1], img_arr

In [None]:
#good old Lucy
# microscope="lb" selects RL_deconvolution's preset sigmas, so the sigma tuple passed here is not used.
# Only the first four channels are deconvolved; the returned psf is the one built for the last of them.
psf, rl_img_hpmf = RL_deconvolution(hpmf_image[:,:4,:,:], kernel = 5,
                                    sigma=(1.8,1.6,1.5,1.3), microscope = "lb")

In [None]:
#good old Lucy
# microscope="lb" selects RL_deconvolution's preset sigmas, so the sigma tuple passed here is not used.
# Only the first four channels are deconvolved; psf is overwritten with the one built for the last of them.
psf, rl_img_hpgb = RL_deconvolution(hpgb_image[:,:4,:,:], kernel = 5,
                                    sigma=(1.8,1.6,1.5,1.3), microscope = "lb")

In [None]:
plt.imshow(psf)

In [None]:
# #good old Lucy
# psf, rl_img_nofilt = RL_deconvolution(back_subtracted_3d, kernel = 5,sigma = 2)

In [None]:
# fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

# # Plot the surface.
# X = np.arange(0, 10, 1)
# Y = np.arange(0,10, 1)
# X, Y = np.meshgrid(X, Y)
# ax.plot_surface(X, Y, psf, cmap="coolwarm", linewidth=0, antialiased=False)

# # rotate the axes and update
# for angle in range(40, 360):
#     ax.view_init(40, angle)
#     plt.draw()
#     plt.pause(.01)

# HPMF after deconvolution 

In [None]:
# import cv2
# image = []
# kernel = np.ones((25,25),np.float32)/625
# #blur the image with 25 pixel kernel
# for i in range(2):
#     image.append(cv2.filter2D(rl_img_nofilt[i],-1,kernel))
# blur_hpmf_aft_rl = np.array(image)

In [None]:
# from skimage import util
# #subtract blurred background from real image
# hpmf_aft_rl_image = []
# for i in range(len(blur_hpmf_aft_rl)):
#     new = util.img_as_int(rl_img_nofilt[i])-util.img_as_int(blur_hpmf_aft_rl[i])
#     #make negative values 0
#     new[new<0]=0
#     hpmf_aft_rl_image.append(new)
# hpmf_aft_rl_image = np.array(hpmf_aft_rl_image)

In [None]:
# plot_2d_image(hpmf_aft_rl_image[0], zmax=1000)

# Compare all before low pass filter

In [None]:
plot_2d_image(rl_img_hpmf[0][1], zmax=3000)

In [None]:
plot_2d_image(rl_img_hpgb[0][1], zmax=1000)

In [None]:
plot_2d_image(tiff[0][1], zmax=4000)

In [None]:
#plot_2d_image(rl_img_nofilt[0], zmax=4000)

In [None]:
#plot_2d_image(hpmf_aft_rl_image[0], zmax=1000)

# Low pass filter

In [None]:
def low_pass_gaussian(image, kern=3, sigma=4):
    """A low pass gaussian blur

    Every 2D plane image[z][c] is blurred with cv2.GaussianBlur.

    Parameters
    ----------
    image = array indexed as z,c,x,y
    kern = int, side length of the gaussian kernel handed to cv2.GaussianBlur
    sigma = gaussian standard deviation handed to cv2.GaussianBlur as sigmaX.
        The default of 4 reproduces earlier results: the blur call used to pass
        cv2.BORDER_DEFAULT (numerically 4) in the sigmaX position. Pass 0 to
        let OpenCV derive sigma from kern.

    Returns
    -------
    array of the blurred planes, indexed as z,c,x,y
    """
    return np.array([
        [
            #sigma and border mode are each passed in their own slot
            cv2.GaussianBlur(image[z][c], (kern,kern), sigma, borderType=cv2.BORDER_DEFAULT)
            for c in range(image.shape[1])
        ]
        for z in range(image.shape[0])
    ])

In [None]:
#lpgf: low pass gaussian filter on both deconvolved stacks
# rl_img_lpmf comes from the high-pass-mean path, rl_img_lpgb from the high-pass-gaussian path
rl_img_lpmf = low_pass_gaussian(rl_img_hpmf, kern = 3)
rl_img_lpgb = low_pass_gaussian(rl_img_hpgb, kern = 3)

In [None]:
plot_2d_image(rl_img_lpmf[0][1], zmax=3000)

In [None]:
plot_2d_image(rl_img_lpgb[0][2], zmax=400)

In [None]:
plot_2d_image(tiff[0][2], zmax=1000)

## Blur the Image

**cv2.blur** is used on each z slice.

Supplemental Information:
https://www.geeksforgeeks.org/python-opencv-cv2-blur-method/

In [None]:
# blur_kernel_size = 1

# #Blur image 3d 
# #----------------------------------------
# blurred_3d = blur_3d(back_subtracted_3d, blur_kernel_size)
# #----------------------------------------

# #Plot Blurred image
# #----------------------------------------
# plot_tiff_3d(blurred_3d, figsize= (10,10), vmax=1500, log=False)
# #----------------------------------------

# Tophat

**cv2.tophat** is used on each z slice.

Supplemental information:

https://www.geeksforgeeks.org/top-hat-and-black-hat-transform-using-python-opencv/
https://theailearner.com/tag/top-hat-transform-opencv/

In [None]:
# tophat_kernel_size = 6


# #Tophat 3d
# #----------------------------------------
# tophatted_3d = tophat_3d(rolling_ball_3d, tophat_kernel_size)
# #----------------------------------------

# #Plot tiff 3d
# #----------------------------------------
# plot_tiff_3d(tophatted_3d, figsize= (10,10), vmax=1000, log= False)
# #----------------------------------------