In [None]:
########## Dependencies
# Input/profile directories default to the original lab locations but can be
# overridden with the CTC_ND2_PATH / CTC_FF_PATH environment variables, so the
# notebook can run on another machine without editing this cell.
import os

##### define input directory
nd2_path = os.path.join(
    os.environ.get("CTC_ND2_PATH", "/mnt/local/data2/Bootsma/2D_CTC/src/analysis/publication_code/test_data/"),
    "",  # guarantee a trailing separator
)

##### define correction profile directory
# A trailing separator is required: per-instrument folders are built later by
# plain string concatenation (FF_path + serial + "/").
FF_path = os.path.join(
    os.environ.get("CTC_FF_PATH", "/mnt/dho-nas06/zhaolab/long_term_storage/CTC/images/FF_profiles/"),
    "",
)

##### Libraries
import os
import sys
import math
from pathlib import Path

import tifffile
import numpy as np
import pandas as pd
import dask.array as da
from skimage.restoration import rolling_ball


sys.path.append('../src/') 
import SEE_TC as ctc

In [None]:
# Functions to extract a fixed-size, zero-padded tile centred on a region's centroid.
tile_size = 32  # edge length of the square tiles, in pixels
half_tile_size = tile_size // 2  # kept for backward compatibility with other cells


def _tile_slices(centroid, image_shape):
    """Return ``(tile_slices, image_slices)`` addressing the in-bounds part of a tile window.

    The window is ``tile_size`` x ``tile_size`` and centred on ``centroid``
    given as (row, col); coordinates are truncated with ``int``. Only the part
    of the window inside ``image_shape[:2]`` is addressed, so the caller's
    zero-initialised tile provides the padding. Using ``end = start + tile_size``
    keeps both slice pairs the same length even for an odd ``tile_size``.
    """
    half = tile_size // 2
    row, col = int(centroid[0]), int(centroid[1])
    start_y, start_x = row - half, col - half
    end_y, end_x = start_y + tile_size, start_x + tile_size
    height, width = image_shape[0], image_shape[1]
    tile_slices = (
        slice(max(0, -start_y), tile_size - max(0, end_y - height)),
        slice(max(0, -start_x), tile_size - max(0, end_x - width)),
    )
    image_slices = (
        slice(max(0, start_y), min(height, end_y)),
        slice(max(0, start_x), min(width, end_x)),
    )
    return tile_slices, image_slices


def extract_tile(region, image=None):
    """Return a zero-padded float tile of shape (tile_size, tile_size, C) around ``region``.

    Parameters
    ----------
    region : object
        Anything with a ``centroid`` attribute given as (row, col) — presumably a
        ``skimage.measure.regionprops`` entry; TODO confirm against callers.
    image : ndarray of shape (H, W, C), optional
        Source image. Defaults to the notebook-level global ``image_array``
        for backward compatibility with existing calls.
    """
    if image is None:
        image = image_array
    tile = np.zeros((tile_size, tile_size, image.shape[2]))
    tile_slices, image_slices = _tile_slices(region.centroid, image.shape)
    tile[tile_slices] = image[image_slices]
    return tile


def extract_tile_bin(region, binary=None):
    """Return a zero-padded float tile of shape (tile_size, tile_size) from a 2-D mask.

    Parameters
    ----------
    region : object
        Anything with a ``centroid`` attribute given as (row, col).
    binary : ndarray of shape (H, W), optional
        Source mask. Defaults to the notebook-level global ``binary_array``.
        Bounds are taken from the mask itself, so it need not share the image's
        shape.
    """
    if binary is None:
        binary = binary_array
    tile_bin = np.zeros((tile_size, tile_size))
    tile_slices, image_slices = _tile_slices(region.centroid, binary.shape)
    tile_bin[tile_slices] = binary[image_slices]
    return tile_bin

def _tile_starts(length, tile_len):
    """Start offsets of the windows that tile one axis of size ``length``.

    ``ceil(length / tile_len)`` windows are placed with an integer stride, and
    the last window is pinned flush to the end of the axis so the floor division
    in the stride can never leave trailing pixels uncovered. When a single
    window suffices (axis no longer than the tile, or empty) it starts at 0.
    """
    n_tiles = math.ceil(length / tile_len)
    if n_tiles <= 1:
        return [0]
    stride = (length - tile_len) // (n_tiles - 1)
    starts = [i * stride for i in range(n_tiles)]
    starts[-1] = length - tile_len
    return starts


def tile_with_known_tiles(small_array, large_shape):
    """Mosaic a 2-D ``small_array`` over ``large_shape``, averaging overlapping tiles.

    Tiles are laid out on a regular grid of ``ceil(large / small)`` positions per
    axis, spanning the full target from the first to the last pixel; every
    output pixel is the mean of all tiles covering it.

    Parameters
    ----------
    small_array : ndarray of shape (h, w)
        The profile to repeat.
    large_shape : tuple (H, W)
        Shape of the output mosaic. If it is smaller than the profile along an
        axis, the profile is cropped on that axis.

    Returns
    -------
    ndarray of float64 with shape ``large_shape``.
    """
    h, w = small_array.shape
    n_rows, n_cols = large_shape[0], large_shape[1]
    # Crop when the target is smaller than the profile (single window on that axis).
    patch = small_array[:min(h, n_rows), :min(w, n_cols)]
    patch_h, patch_w = patch.shape

    row_starts = _tile_starts(n_rows, h)
    col_starts = _tile_starts(n_cols, w)

    sum_array = np.zeros(large_shape)
    count_array = np.zeros(large_shape)
    for y0 in row_starts:
        for x0 in col_starts:
            sum_array[y0:y0 + patch_h, x0:x0 + patch_w] += patch
            count_array[y0:y0 + patch_h, x0:x0 + patch_w] += 1

    # Guard against division by zero for any uncovered (e.g. empty) region.
    return sum_array / np.maximum(count_array, 1)

def _find_ff_profile(ff_dir, kind, prefix):
    """Return the lexicographically first ``FF_profile_<kind>_<prefix>*`` file in ``ff_dir``, or None.

    None is returned when ``prefix`` is empty or no matching regular file exists,
    so callers can skip the channel instead of hitting an IndexError.
    """
    if not prefix:
        return None
    pattern = "FF_profile_" + kind + "_" + prefix + "*"
    matches = sorted(p for p in Path(ff_dir).glob(pattern) if p.is_file())
    return str(matches[0]) if matches else None


def _subtract_rolling_ball(img, radius=21, chunk=5000):
    """Subtract a rolling-ball background from a 2-D image, estimated chunk-wise with dask.

    Chunks overlap by ``radius`` pixels (reflected at the borders) to avoid
    introducing tiling artifacts; the background is materialised explicitly
    with ``.compute()`` before subtraction.
    """
    img_dask = da.from_array(img, chunks=(chunk, chunk))
    background = img_dask.map_overlap(
        rolling_ball,
        depth=radius,
        boundary='reflect',
        dtype=img.dtype,
        radius=radius,  # forwarded to rolling_ball
    ).compute()
    return img - background


def apply_FF_correction_impute(image, binary, channels, FF_path = FF_path):
    """Flat-field correct each channel of ``image`` and remove rolling-ball background.

    For every non-brightfield channel, the profiles ``FF_profile_BG_<p>*`` and
    ``FF_profile_FG_<p>*`` (``<p>`` = first character of the channel name) are
    read from ``FF_path`` and mosaicked to the image size with
    ``tile_with_known_tiles``. The correction formula from Kask et al. is then
    applied using the mean background intensity of that channel. Values above
    65535 are clipped, and a rolling-ball background (radius 21) is subtracted.

    Parameters
    ----------
    image : ndarray of shape (H, W, C)
        Channel-last image.
    binary : ndarray
        Foreground mask handed to ``ctc.split_foreground_background``;
        presumably (H, W) matching ``image`` — TODO confirm.
    channels : list of str
        Channel names aligned with the last axis of ``image``. ``'BF'`` is
        copied through unchanged; other names are truncated to 3 characters.
    FF_path : str
        Directory holding the flat-field profile TIFFs.

    Returns
    -------
    ndarray
        Same shape and dtype as ``image``. Channels whose BG or FG profile is
        missing are left as zeros.
    """
    print("FF correction using masks from: "+FF_path)
    print(image.shape, binary.shape, channels)
    channels = [s if s == 'BF' else s[:3] for s in channels]
    print(channels)
    large_shape = (image.shape[0], image.shape[1])  # mosaic target for the FF profiles

    _, img_bg = ctc.split_foreground_background(
        np.expand_dims(image, 0), np.expand_dims(binary, 0)
    )
    # NOTE(review): assumes img_bg holds one value per element of `image` (plus a
    # leading singleton axis). Reshaping instead of squeeze() keeps the channel
    # axis when C == 1 — confirm against SEE_TC.
    img_bg = np.reshape(img_bg, image.shape)
    img_pp = np.zeros_like(image)

    for c, channel_c in enumerate(channels):
        print("Processing "+channel_c)
        if channel_c == "BF":  # no processing on brightfield
            img_pp[:, :, c] = image[:, :, c]
            continue

        prefix = channel_c[:1]
        bg_path = _find_ff_profile(FF_path, "BG", prefix)
        fg_path = _find_ff_profile(FF_path, "FG", prefix)
        # Skip the channel when EITHER profile is missing.
        if bg_path is None or fg_path is None:
            print("skipping channel "+channel_c+", no input")
            print(bg_path)
            print(fg_path)
            print("###############\n")
            continue

        FF_BG_norm = tile_with_known_tiles(tifffile.imread(bg_path), large_shape)
        FF_FG_norm = tile_with_known_tiles(tifffile.imread(fg_path), large_shape)
        FF_FG_norm[FF_FG_norm == 0] = 1  # avoid division by zero

        img_c = image[:, :, c]
        avg_BG_est = np.nanmean(img_bg[:, :, c], axis=(0, 1))
        # apply formula from Kask et al.
        img_corrected = ((img_c - (avg_BG_est * FF_BG_norm)) / FF_FG_norm) + avg_BG_est
        img_corrected[img_corrected > 65535] = 65535  # clip values that exceed the uint16 range

        img_pp[:, :, c] = _subtract_rolling_ball(img_corrected)
    return img_pp

# correct image, extract features, save both
def _ordered_channel_indices(c_names, order="34567B"):
    """Indices of the first channel name starting with each character of ``order``; absent channels are dropped."""
    indices = []
    for lead in order:
        idx = next((i for i, s in enumerate(c_names) if s.startswith(lead)), None)
        if idx is not None:
            indices.append(idx)
    return indices


def correct_and_extract(file_i, target_scope = "NA", scope_specific = False, force_correction = True):
    """Flat-field correct one sample and write ``<file_i>.bias_corrected.tiff``.

    Inputs are located by appending ``.nd2``, ``.bin_open.tiff`` and
    ``.image_metadata.csv`` to ``file_i``. The metadata CSV's ``serial_number``
    column selects the per-instrument profile folder ``FF_path + serial + "/"``.

    Parameters
    ----------
    file_i : str
        Sample path prefix.
    target_scope : str
        Substring that must appear in the profile folder when ``scope_specific``.
    scope_specific : bool
        If True, skip samples whose profile folder does not contain ``target_scope``.
    force_correction : bool
        True: compute and write the correction when no corrected file exists
        yet (an existing file is left untouched and the sample skipped).
        False: re-load an existing corrected file instead of computing one.

    Returns
    -------
    ndarray or None
        Newly computed: channel-first (C+1, Y, X) stack, with the binary appended
        as the last channel. Re-loaded: the stored stack transposed to channel-last.
        None when the sample is skipped.
    """
    bin_path_i = file_i + ".bin_open.tiff"
    img_path_i = file_i + ".nd2"
    meta_path_i = file_i + ".image_metadata.csv"
    out_path_i = file_i + ".bias_corrected.tiff"

    meta_i = pd.read_csv(meta_path_i)
    serial_i = str(meta_i['serial_number'].unique()[0])
    FF_path_i = FF_path + serial_i + "/"

    if scope_specific and target_scope not in FF_path_i:
        print("Skipping: " + file_i)
        return None

    # Missing inputs must stop processing, not just be reported.
    if not os.path.isfile(img_path_i):
        print("skipping sample (" + img_path_i + ") no image")
        print("###############\n")
        return None
    if not os.path.isfile(bin_path_i):
        print("skipping sample (" + bin_path_i + ") no binary")
        print("###############\n")
        return None

    if os.path.isfile(out_path_i):
        if force_correction:
            print("skipping sample (" + file_i + ") already processed")
            print("###############\n")
            return None
        # extract features, don't correct, use corrected image to extract
        print("FF correction exists, only extracting features")
        img_pp = np.transpose(tifffile.imread(out_path_i), (1, 2, 0))
        print(img_pp.shape)
        return img_pp

    if not force_correction:
        print("skipping sample (" + file_i + ") no FF correction to reuse")
        print("###############\n")
        return None

    img_i, c_names = ctc.extract_all_array_nd2(img_path_i)
    img_i = np.transpose(img_i, (1, 2, 0))  # channel last
    print("Input image shape: " + str(img_i.shape))
    bin_i = tifffile.imread(bin_path_i)
    print("Input binary shape: " + str(bin_i.shape))

    # Account for UNET padding: crop the binary whenever EITHER axis differs.
    d0_img, d1_img = img_i.shape[0], img_i.shape[1]
    if bin_i.shape[0] != d0_img or bin_i.shape[1] != d1_img:
        print("CROPPING BIN TO FIT IMG")
        bin_i = bin_i[:d0_img, :d1_img]
    if tuple(bin_i.shape[:2]) != (d0_img, d1_img):
        raise ValueError(
            "binary " + str(bin_i.shape) + " is smaller than image " + str(img_i.shape)
        )

    # Arrange channels in a fixed order, adapting to missing channels.
    c_idx = _ordered_channel_indices(c_names)
    c_names_out = [c_names[i] for i in c_idx]
    img_i = img_i[:, :, c_idx]

    print("\n=============applying FF correction...\n=============")
    img_pp = apply_FF_correction_impute(img_i, bin_i, c_names_out, FF_path=FF_path_i)

    # export image with the binary appended as an extra channel
    img_pp = np.concatenate((img_pp, np.expand_dims(bin_i, -1)), axis=-1)
    img_pp = np.transpose(img_pp, (2, 0, 1))  # channel first
    print("Writing FF Corrected Image")
    tifffile.imwrite(out_path_i,
        data=img_pp,
        metadata={
            'axes': 'CYX',
            'spacing': 0.65,
            'orientation': 'topleft',
            "Channel": {"Name": [str(n) for n in c_names_out] + ["UNET_open"]}},
        ome=True)
    return img_pp


In [None]:
##### define input samples
BIN_SUFFIX = ".bin_open.tiff"
bin_paths = ctc.parse_nd2_paths(nd2_path, BIN_SUFFIX, recursive=True)
# Strip only the trailing suffix; str.replace would also alter earlier matches in the path.
file_names = [s[:-len(BIN_SUFFIX)] if s.endswith(BIN_SUFFIX) else s for s in bin_paths]
if not file_names:
    raise FileNotFoundError(f"no '*{BIN_SUFFIX}' files found under {nd2_path}")
print(f"input: {len(file_names)} {file_names[0]}")

In [None]:
##### correct instrument bias
# Each call writes its result next to the inputs; return values are not kept.
for sample_prefix in file_names:
    correct_and_extract(sample_prefix)