In [1]:
import os
########## Dependencies
##### define neural net weights
model_path = "/mnt/local/data2/Bootsma/2D_CTC/src/analysis/publication_code/weights/01_SEE_TC_classify_single_cells.hdf5"

##### define samples to process
img_dir = "/mnt/local/data2/Bootsma/2D_CTC/src/analysis/publication_code/test_data/"

##### define GPU indices to use
# Physical GPU index as a string (the form CUDA_VISIBLE_DEVICES expects).
# CUDA hides every other GPU, and TensorFlow numbers the visible GPU(s) from 0.
# The selected GPU is therefore addressed as '/GPU:0' inside TensorFlow.
gpu_index = "0"  # Specify the GPUs you want to use
# Set the mask before TensorFlow is imported/initialised, otherwise it has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_index

##### define suffix for input tiff
input_tiff_suffix = ".bias_corrected.tiff"

##### define suffix for output table 
# must match the suffix of your physical feature table
# results will be appended into said table
feature_suffix = ".region_features.tsv"

##### Libraries
import sys
import tifffile
import numpy as np
import pandas as pd
import tensorflow as tf
from scipy.ndimage import label
from skimage.measure import regionprops
from concurrent.futures import ThreadPoolExecutor
from tensorflow.keras.models import load_model

sys.path.append('../src/') 
import SEE_TC as ctc
##### load model
# CUDA_VISIBLE_DEVICES (set in the configuration above) exposes only the chosen GPU(s).
# TensorFlow renumbers visible GPUs from 0, so the selected device is always
# '/GPU:0'; '/GPU:{gpu_index}' would name a non-existent device for gpu_index != "0".
with tf.device('/GPU:0'):
    classifier = load_model(model_path)

In [None]:
# Define input files
file_names = ctc.parse_nd2_paths(img_dir, input_tiff_suffix, recursive=True)
if len(file_names) == 0:
    # Fail early with a clear message instead of an IndexError at sample_IDs[0].
    raise FileNotFoundError(
        f"No input files matching {input_tiff_suffix!r} found under {img_dir!r}"
    )
# Sample IDs are the paths minus the trailing suffix. Strip only the end, because
# str.replace would also drop an occurrence elsewhere in the path.
sample_IDs = [
    s[:-len(input_tiff_suffix)] if s.endswith(input_tiff_suffix) else s
    for s in file_names
]
print(len(sample_IDs))
print(sample_IDs[0])


In [None]:
# Output classes of the classifier, in the order of its probability columns.
class_values = ['Cell', 'Doubleton', 'Cluster', 'Debris']
# Feature-table column name for each class probability, e.g. 'probCS_cell'.
probsCS_names = ['probCS_' + class_name.lower() for class_name in class_values]


# For every sample:
#   - classify each segmented object (connected component of the UNET mask) as
#     one of `class_values`;
#   - append the label and class probabilities to the sample's feature table;
#   - export the tiles of objects classified as "Cell" to <sample>.cells.tiff.

tile_size = 32  # edge length (pixels) of the square tile cut around each object
half_tile_size = tile_size // 2
# Channel layout written to *.cells.tiff when the input has fewer channels than
# this. ctc.append_missing_channels appends the missing ones as zeros.
full_channel_names = ['300', '400', '500', '600', '700', 'BF', 'UNET_open']


def _find_channel(channel_names, prefix):
    """Return the index of the first channel whose name starts with `prefix`.

    Raises:
        ValueError: if no channel name starts with `prefix`.
    """
    for idx, name in enumerate(channel_names):
        if name.startswith(prefix):
            return idx
    raise ValueError(f"No channel starting with {prefix!r} in {list(channel_names)}")


def _extract_centered_tile(array, region, size=tile_size):
    """Cut a `size` x `size` tile centred on `region.centroid` out of `array`.

    `array` is 2-D (Y, X) or channel-last 3-D (Y, X, C). Parts of the tile that
    fall outside the image are zero-padded. The returned tile is float64.
    """
    y, x = int(region.centroid[0]), int(region.centroid[1])
    start_y, start_x = y - size // 2, x - size // 2
    end_y, end_x = start_y + size, start_x + size
    height, width = array.shape[0], array.shape[1]

    tile = np.zeros((size, size) + array.shape[2:])
    # Overlap of the tile window with the image, in image coordinates...
    img_y0, img_y1 = max(0, start_y), min(height, end_y)
    img_x0, img_x1 = max(0, start_x), min(width, end_x)
    # ...copied to the same overlap expressed in tile coordinates.
    tile[img_y0 - start_y:img_y1 - start_y, img_x0 - start_x:img_x1 - start_x] = (
        array[img_y0:img_y1, img_x0:img_x1]
    )
    return tile


for sample_ID_i in sample_IDs:
    print("Processing " + sample_ID_i)
    tiff_path_i = sample_ID_i + input_tiff_suffix

    # load image and its channel names
    img_i = tifffile.imread(tiff_path_i)
    c_names = ctc.get_channel_names_tiff(tiff_path_i)

    # channel name recognition: Hoechst ('3...') and the UNET binary mask
    c3 = _find_channel(c_names, '3')
    c_UNET = _find_channel(c_names, 'UNET')

    img_i = img_i.transpose(1, 2, 0)  # channel last
    binary_i = img_i[:, :, c_UNET]

    ######### Handle 20x: the model works at 10x, so rescale other magnifications
    input_magnification = ctc.get_magnification(sample_ID_i + ".nd2")
    if input_magnification != 10:
        scale_factor = input_magnification / 10
        img_i = ctc.downscale_by_factor_of(img_i, scale_factor)  # channel last
        binary_i = ctc.downscale_by_factor_of(binary_i, scale_factor)

    image_array = img_i
    binary_array = binary_i
    # CNN was trained on 0-255 for binary.
    # NOTE(review): without rescaling, binary_array is a view of img_i, so this
    # also writes 255 into the UNET channel of image_array (and therefore of the
    # exported tiles). Kept to preserve existing outputs; confirm it is intended.
    binary_array[binary_array > 0] = 255

    labeled_array, n_obj = label(binary_array)  # label connected mask regions
    if n_obj == 0:
        # np.stack would fail on an empty tile list and abort the whole batch.
        print("No objects found, skipping " + sample_ID_i + "\n")
        continue
    regions = regionprops(labeled_array)  # ordered by label

    print("Normalizing stack for CNN...")
    with ThreadPoolExecutor() as executor:
        tiles = np.stack(
            list(executor.map(lambda r: _extract_centered_tile(image_array, r), regions)),
            axis=0,
        )  # (n_obj, Y, X, C)
        tiles_bins = np.stack(
            list(executor.map(lambda r: _extract_centered_tile(binary_array, r), regions)),
            axis=0,
        )  # (n_obj, Y, X)

    # normalize at the group level, i.e., across tiles. Accounting for outliers
    # (clipping above/below 1.5*IQR)
    tiles_norm = ctc.preprocess_autoEncoder_01(tiles)
    tiles_hoechst = tiles_norm[:, :, :, c3]

    tiles_bins[tiles_bins > 0] = 255
    # Classifier input: (n_obj, Y, X, 2) = binary mask + normalized Hoechst.
    z_stack = np.stack((tiles_bins, tiles_hoechst), axis=-1)

    print("Stack size:")
    print(z_stack.shape)

    ############## APPLY CLEAN SEG ############
    with tf.device('CPU'):
        img_tensor = tf.convert_to_tensor(z_stack)
    # CUDA_VISIBLE_DEVICES renumbers the selected GPU to 0 inside TensorFlow.
    with tf.device('/GPU:0'):
        class_probabilities = classifier.predict(img_tensor)
    # class with the highest probability for each object
    class_assignments = np.array(class_values)[np.argmax(class_probabilities, axis=1)]

    #########################
    print("Appending results to features...")
    indices_of_cells = np.flatnonzero(class_assignments == "Cell")
    print("Cells detected: " + str(len(indices_of_cells)))

    features_path_i = sample_ID_i + feature_suffix
    features_i = pd.read_csv(features_path_i, sep='\t')
    # Assumes one feature-table row per labeled region, in label order.
    if len(features_i) != len(class_assignments):
        raise ValueError(
            f"{features_path_i} has {len(features_i)} rows but "
            f"{len(class_assignments)} objects were classified"
        )
    features_i['class_cleanSeg'] = class_assignments

    probsCS_df = pd.DataFrame(class_probabilities, columns=probsCS_names)
    # drop probability columns left by a previous run before appending fresh ones
    features_i = features_i.drop(columns=features_i.columns.intersection(probsCS_df.columns))
    features_i = pd.concat([features_i, probsCS_df], axis=1)
    features_i.to_csv(features_path_i, sep="\t", index=False)  # write output

    ###### export cell stack for review and class-specific parsing
    print("writing tiles...")
    cell_tiles = tiles[indices_of_cells]  # channel last: (n_cells, Y, X, C)
    ############ fill missing channels with 0's
    if len(c_names) < len(full_channel_names):
        print("APPENDING MISSING CHANNELS AS ZERO...")
        # Append 0's after normalization so as to not introduce errors there.
        # They must be present so the predictive channel is indexed correctly
        # (hard coded).
        cell_tiles = ctc.append_missing_channels(cell_tiles, c_names)
        c_names_write = list(full_channel_names)
    else:
        c_names_write = c_names.copy()
    ############
    tiles_out = np.transpose(cell_tiles, (0, 3, 1, 2)).astype(np.uint16)  # ZCYX
    print(tiles_out.shape)
    tifffile.imwrite(
        sample_ID_i + ".cells.tiff",
        data=tiles_out,
        metadata={
            'axes': 'ZCYX',
            'spacing': 0.65,
            'orientation': 'topleft',
            "Channel": {"Name": c_names_write},
        },
        ome=True,
    )

    print("##### Done #####\n")