# import

In [1]:
import warnings, os
import h5py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch, math, csv
from torchvision import transforms
from PIL import Image
import tifffile
from tqdm import tqdm
# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides numpy/pandas warnings raised by the cells below.
warnings.filterwarnings("ignore")
# Restrict CUDA to the device with index 3.
# NOTE(review): presumably this has to run before CUDA is first initialised — verify.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Define file paths and settings

In [2]:
# One record per acquisition (domain), in domain order. Every path list below
# is derived from this table so the four lists cannot get out of step.
_data_root = './row_data/On_time_varied_data'
_datasets = [
    # (folder, acquisition stem, TIFF suffix, render tag, drift-file timestamp)
    (_data_root, '01_25LP_532_100ms_1nMRO_1d8nt_1nMP3_trolox', '.ome.tif', 'render', '240229_094939'),
    (_data_root, '01_25LP_532_100ms_400pMRO_1d10nt_1nMP3_trolox', '.tif', 'render', '240229_094701'),
    (f'{_data_root}/Another_dataset', '01_13LP_532_100ms_1nMRO_1d8nt_1nMP3_trolox', '-002.tif', 'render1000', '240303_101726'),
    (f'{_data_root}/Another_dataset', '02_13LP_532_100ms_200pMRO_1d10nt_1nMP3_trolox', '-001.tif', 'render1000', '240303_102117'),
]

# Raw TIFF image stacks
all_tif_image_stack_paths = [
    f'{folder}/{stem}{tif_suffix}'
    for folder, stem, tif_suffix, _, _ in _datasets
]
# Drift-corrected, picked localizations (HDF5 with a 'locs' dataset)
all_drift_corrected_paths = [
    f'{folder}/{stem}_locs_{render_tag}_picked.hdf5'
    for folder, stem, _, render_tag, _ in _datasets
]
# Localizations before drift correction
all_drift_uncorrected_paths = [
    f'{folder}/Not_drift_corrected/{stem}_locs.hdf5'
    for folder, stem, _, _, _ in _datasets
]
# Per-frame drift trajectories (text files)
all_drift_trajectory_paths = [
    f'{folder}/{stem}_locs_{timestamp}_drift.txt'
    for folder, stem, _, _, timestamp in _datasets
]

# Output root for everything the notebook generates
data_dir = './data'
# Side length, in pixels, of the square box cut around each spot
box_width = 10
# Unit: Pixel, assume that drift speed is less than 130 um/100ms
# Our setting: pixel height/width: 6.5 um, frame rate = 10 Hz
# NOTE(review): box_width / 5 evaluates to 2 px; check that against the figures above.
max_drift_distance = box_width / 5


# Load and save molecular localization and drift data to .csv tables

In [None]:
# `reprocess` is only assigned in a later cell. Keep whatever value is already
# in the session, otherwise fall back to True (the value that later cell uses),
# so that this cell no longer raises NameError on a fresh top-to-bottom run.
reprocess = globals().get('reprocess', True)


def _export_locs_tables(hdf5_paths, reprocess):
    """Write the 'locs' dataset of each HDF5 file to `domain_<n>.csv` next to it.

    `n` is the 1-based position of the file in `hdf5_paths`. An existing CSV is
    kept unless `reprocess` is true.
    """
    for domain_idx, hdf5_path in enumerate(hdf5_paths):
        save_path = f'{os.path.dirname(hdf5_path)}/domain_{domain_idx + 1}.csv'
        print(save_path)
        if os.path.exists(save_path) and not reprocess:
            continue
        # Open read-only and close the handle as soon as the table is in memory.
        with h5py.File(hdf5_path, 'r') as f:
            df = pd.DataFrame(np.array(f['locs']))
        df.to_csv(save_path, index=False)
        print('Domain:', domain_idx + 1, '\n', df)


def _read_drift_table(text_path):
    """Parse a whitespace-separated drift text file into a frame/dx/dy table.

    Blank lines and lines starting with '#' (after stripping whitespace) are
    ignored. The frame number is the 0-based position of the line among the
    remaining lines; the first two columns are read as dx and dy.
    """
    with open(text_path, 'r') as file:
        stripped_lines = (line.strip() for line in file)
        lines = [line for line in stripped_lines if line and not line.startswith('#')]

    data = []
    for i, line in enumerate(lines):
        values = line.split()
        data.append([i, float(values[0]), float(values[1])])
    return pd.DataFrame(data, columns=['frame', 'dx', 'dy'])


# Drift-corrected and uncorrected localizations share the same export logic.
_export_locs_tables(all_drift_corrected_paths, reprocess)
_export_locs_tables(all_drift_uncorrected_paths, reprocess)

for domain_idx, text_path in enumerate(all_drift_trajectory_paths):
    # An empty string marks a domain without a drift file.
    if not text_path:
        continue
    save_path = f'{os.path.dirname(text_path)}/drift_{domain_idx + 1}.csv'
    if os.path.exists(save_path) and not reprocess:
        continue
    df = _read_drift_table(text_path)
    df.to_csv(save_path, index=False)
    print('Domain:', domain_idx + 1, '\n', df)

# Subtract the mean background intensity for each frame, then mask and save the original fluorescent spots in TIFF format (use the uncorrected localization table to get a complete collection of binding events and for blinking correction)

In [None]:
# Euclidean distance between two points in the plane
def calculate_distance(x1, y1, x2, y2):
    """Return the straight-line distance between (x1, y1) and (x2, y2)."""
    delta_x = x2 - x1
    delta_y = y2 - y1
    squared_length = delta_x**2 + delta_y**2
    return math.sqrt(squared_length)

def save_image_patches(original_frame, image, x, y, box_width, patch_save_path, subtracted_patch_save_path, mean_background_intensity_df):
    """Crop a square patch around (x, y) and save it raw and background-subtracted.

    Args:
        original_frame: 0-based frame number; the background is looked up under
            `original_frame + 1` in `mean_background_intensity_df['frame']`.
        image: 2-D array to crop from (rows are y, columns are x).
        x, y: centre of the patch in pixel coordinates.
        box_width: side length of the patch in pixels; the box is clipped to
            the image extent.
        patch_save_path: destination of the raw patch (cast to uint16).
        subtracted_patch_save_path: destination of the masked patch with the
            mean background subtracted.
        mean_background_intensity_df: table with 'frame' and
            'mean_background_intensity' columns.

    Raises:
        ValueError: if the table does not hold exactly one background value for
            the requested frame.
    """
    # Calculate box boundaries, clipped to the image
    half_width = box_width / 2
    box_left = max(0, round(x - half_width))
    box_right = min(image.shape[1], round(x + half_width))
    box_top = max(0, round(y - half_width))
    box_bottom = min(image.shape[0], round(y + half_width))

    # Extract the patch; astype() returns a new array, so `image` is untouched
    image_patch_uint16 = image[box_top:box_bottom, box_left:box_right].astype(np.uint16)

    # Save the raw patch (dirname is empty for a bare file name)
    patch_dir = os.path.dirname(patch_save_path)
    if patch_dir:
        os.makedirs(patch_dir, exist_ok=True)
    Image.fromarray(image_patch_uint16).save(patch_save_path)

    # Look up the mean background of this frame and fail with a readable
    # message instead of the bare `.item()` size error.
    frame_key = original_frame + 1
    matches = mean_background_intensity_df.loc[
        mean_background_intensity_df['frame'] == frame_key,
        'mean_background_intensity',
    ].values
    if matches.size != 1:
        raise ValueError(
            f"expected exactly one mean background intensity for frame {frame_key}, "
            f"found {matches.size}"
        )
    background_mean_intensity = matches.item()

    # Pixels darker than the midpoint between patch mean and background mean
    # are flattened to `background + background_offset`.
    image_patch_mean = np.mean(image_patch_uint16)
    mask_threshold = (image_patch_mean + background_mean_intensity) / 2
    background_offset = 100
    # Work on an explicit copy; the replacement value is cast to uint16 on
    # assignment because the array is still uint16 at this point.
    masked_image_patch = image_patch_uint16.copy()
    masked_image_patch[masked_image_patch < mask_threshold] = background_offset + background_mean_intensity

    subtracted_dir = os.path.dirname(subtracted_patch_save_path)
    if subtracted_dir:
        os.makedirs(subtracted_dir, exist_ok=True)
    # NOTE(review): subtracting a non-integer background presumably promotes
    # the array to a floating-point dtype, so this file is not uint16 — confirm
    # before relying on the dtype downstream.
    subtracted_image_patch = masked_image_patch - background_mean_intensity
    Image.fromarray(subtracted_image_patch).save(subtracted_patch_save_path)

# Main cell: for the selected domain, apply the drift table to the uncorrected
# localizations, link them to the picked groups of the drift-corrected table
# over a window of preceding frames, and save image patches, annotated frame
# images and per-frame background intensities under `data_dir`.

# 1-based index of the domain to process (compared as `process_domain - 1` below)
process_domain = 1
# Number of preceding frames scanned for each localization (background window and linking window)
blinking_correction_th = 10
# read_exist: reuse an existing linked-index CSV; reprocess: regenerate outputs even if they exist
read_exist, reprocess = False, True
# Placeholder; replaced inside the loop by a table with 'domain', 'frame', 'mean_background_intensity'
mean_background_intensity_df = pd.DataFrame()
# Loop through each TIFF file path
for domain_idx, tif_path in enumerate(all_tif_image_stack_paths):
    if domain_idx in [process_domain - 1]: # need to minus 1
        mean_background_intensity_file = f"{data_dir}/spots/domain_{domain_idx + 1}/mean_background_intensity.csv"
        # Read TIFF image stack
        image_stack = tifffile.imread(tif_path)

        # Read original localizations without drift corrections
        locs_path = f'{os.path.dirname(tif_path)}/Not_drift_corrected/domain_{domain_idx + 1}.csv'
        locs_df = pd.read_csv(locs_path)

        # Read drift trajectories; a domain without a drift CSV is skipped silently
        drift_path = f'{os.path.dirname(tif_path)}/drift_{domain_idx + 1}.csv'
        if not os.path.exists(drift_path):
            continue
        drift_df = pd.read_csv(drift_path)

        # Print information about the read data
        print(f"TIFF Path: {tif_path}")
        print(f"Shape of Image Stack: {image_stack.shape}")
        print(f"Data Type of Image Stack: {image_stack.dtype}")

        # Convert image stack to a supported data type (e.g., float32)
        image_stack = image_stack.astype(np.float32)

        corrected_save_path = f'{os.path.dirname(tif_path)}/corrected_{domain_idx + 1}.csv'
        if not os.path.exists(corrected_save_path) or reprocess:
            # Merge locs_df and drift_df on "frame"
            # NOTE(review): pd.merge defaults to an inner join, so localizations
            # in frames absent from drift_df would be dropped — confirm.
            merged_df = pd.merge(locs_df, drift_df, on='frame', suffixes=('', '_drift'))

            # Subtract "dx" from "x" and "dy" from "y"
            merged_df['x_corrected'] = merged_df['x'] - merged_df['dx']
            merged_df['y_corrected'] = merged_df['y'] - merged_df['dy']

            # Number the localizations within each frame (0-based running index)
            merged_df['spot'] = merged_df.groupby('frame').cumcount()

            # Save the corrected table (no columns are actually dropped here)
            merged_df.to_csv(corrected_save_path, index=False)

        linked_save_path = f"{data_dir}/spots/domain_{domain_idx + 1}_indexes_linked.csv"
        if os.path.exists(linked_save_path) and read_exist:
            indexes_df = pd.read_csv(linked_save_path)
        else:
            corrected_df = pd.read_csv(corrected_save_path) # [x_corrected, y_corrected] can replace groups as binding site labels
            indexes_df = corrected_df
            # Initialize the linking columns as "not assigned yet"
            indexes_df['group'] = np.nan
            indexes_df['x_group'] = np.nan
            indexes_df['y_group'] = np.nan
            indexes_df['missed_from_frame'] = np.nan
            indexes_df['missed_from_spot'] = np.nan
        # NOTE(review): this groupby is created once, while indexes_df is
        # modified and extended inside the loop below; whether get_group()
        # reflects those later changes depends on pandas internals — confirm.
        indexes_df_grouped_by_frame = indexes_df.groupby('frame')
        
        # Read the drift-corrected localizations; their 'group', 'x' and 'y' columns drive the linking
        grouped_path = f'{os.path.dirname(tif_path)}/domain_{domain_idx + 1}.csv'
        grouped_df = pd.read_csv(grouped_path)

        # Extract light spots from frames, visiting the grouped localizations in frame order
        total_steps = len(grouped_df)
        grouped_df = grouped_df.sort_values(by='frame')
        for grouped_idx, row in tqdm(grouped_df.iterrows(), total=total_steps):
            # if pd.isna(row['spot']):
            frame, group, group_x, group_y = int(row['frame']), row['group'], row['x'], row['y']
            image = image_stack[frame]
            
            image_save_path = f"{data_dir}/spots/domain_{domain_idx + 1}/images/frame_{frame}.png"
            background_image_save_path = f"{data_dir}/spots/domain_{domain_idx + 1}/images_background/frame_{frame}.tif"
            # fig1/ax1 only exist when the annotated frame image is going to be (re)written
            if not os.path.exists(image_save_path) or reprocess:
                fig1, ax1 = plt.subplots()
                ax1.imshow(image)
                
            # Save background tiff images
            os.makedirs(os.path.dirname(background_image_save_path), exist_ok=True)
            # NOTE(review): this runs once per grouped row, so with reprocess=True
            # the table is reset to empty (and the CSV overwritten) on every
            # iteration; the file then only keeps the last window — confirm intent.
            if reprocess or not os.path.exists(mean_background_intensity_file):
                # Initialize mean_background_intensity_df with the necessary columns
                mean_background_intensity_df = pd.DataFrame(columns=['domain', 'frame', 'mean_background_intensity'])
                mean_background_intensity_df.to_csv(mean_background_intensity_file, index=False)
            else:
                mean_background_intensity_df = pd.read_csv(mean_background_intensity_file)
            # Fill in the background for every frame of the window [frame - th, frame]
            # that is not in the table yet; the table stores frames 1-based (frame_idx + 1).
            for frame_idx in np.arange(frame-blinking_correction_th, frame+1):
                if frame_idx >= 0 and (frame_idx+1 not in mean_background_intensity_df['frame'].values):
                    # NOTE(review): get_group presumably raises KeyError for a frame without localizations — confirm.
                    one_frame_locs_df = indexes_df_grouped_by_frame.get_group(frame_idx)
                    # print(f'Processing background for frame {frame_idx}')
                    # Calculate the background mean intensity
                    # NOTE(review): `image` is image_stack[frame], not image_stack[frame_idx],
                    # while the mask is built from the localizations of frame_idx — confirm.
                    background_mask = np.ones_like(image)
                    # The inner `row` shadows the outer loop variable; the outer values were already unpacked above
                    for _, row in one_frame_locs_df.iterrows():
                        spot, x, y, x_corrected, y_corrected = int(row['spot']), row['x'], row['y'], row['x_corrected'], row['y_corrected']
                        # Calculate box boundaries
                        box_left = max(0, round(x - box_width / 2))
                        box_right = min(image.shape[1], round(x + box_width / 2))
                        box_top = max(0, round(y - box_width / 2))
                        box_bottom = min(image.shape[0], round(y + box_width / 2))
        
                        # Exclude area with light spots
                        background_mask[box_top:box_bottom, box_left:box_right] = 0
        
                    # Calculate the average background intensity of the image_patch
                    # mean_background_intensity = image * background_mask / np.sum(background_mask)
                    
                    # Zero out the pixels inside the spot boxes
                    # print(background_image_save_path)
                    background_image = image * background_mask
                    # print(f"background_image: {background_image}")
                    background_image_uint16 = background_image.astype(np.uint16)
                    # print(f"background_image_uint16: {background_image_uint16}")
                    # fig, ax = plt.subplots()
                    background_image_pil = Image.fromarray(background_image_uint16)
                    # ax.axis('off')
                    # ax.imshow(background_image) # , interpolation='none'
                    # fig.savefig(background_image_save_path, bbox_inches='tight', pad_inches=0)
                    # NOTE(review): the path is keyed by `frame`, so every frame_idx of the window writes the same file
                    background_image_pil.save(background_image_save_path)
                    # Mean over the unmasked pixels only
                    background_mean_intensity = np.sum(background_image_uint16) / np.sum(background_mask)
                    # print(f"background_mean_intensity: {background_mean_intensity}")
                    # Append one row and persist the table immediately
                    current_df_len = len(mean_background_intensity_df)
                    mean_background_intensity_df.loc[current_df_len, 'domain'] = domain_idx + 1
                    mean_background_intensity_df.loc[current_df_len, 'frame'] = frame_idx + 1
                    mean_background_intensity_df.loc[current_df_len, 'mean_background_intensity'] = background_mean_intensity
                    mean_background_intensity_df.to_csv(mean_background_intensity_file, index=False)
                    # print(frame_idx+1, mean_background_intensity_df[mean_background_intensity_df['frame']==frame_idx+1]['frame'].values.item())
                    
                    # For debug
                    # image_read = Image.open(background_image_save_path)
                    # image_read_array = np.array(image_read)  # Convert back to original scale
                    # print(f"image_read_array: {image_read_array}") # the same as background_image
                    # print(stop)

            # Linking: walk from the current frame backwards through the window and
            # try to attach one uncorrected localization per frame to this group.
            existing_groups_in_linked_frames = {}
            for frame_idx in np.arange(frame, frame-1-blinking_correction_th, -1):
                if frame_idx < 0:
                    continue
                one_frame_locs_df = indexes_df_grouped_by_frame.get_group(frame_idx)
                existing_groups_in_linked_frames[f'frame_{frame_idx}'] = one_frame_locs_df['group'].values.tolist()
                # Only search this frame if the group has not been assigned in it already
                if group not in existing_groups_in_linked_frames[f'frame_{frame_idx}']:
                    min_distance = max_drift_distance
                    for corrected_idx, row in one_frame_locs_df.iterrows():
                        original_frame, spot, x, y, x_corrected, y_corrected = int(row['frame']), int(row['spot']), row['x'], row['y'], row['x_corrected'], row['y_corrected']
                        # NOTE(review): the paths use `frame` (current row) while `spot` comes
                        # from frame_idx, and the patch is cropped from image_stack[frame];
                        # spots with the same index from different window frames overwrite
                        # each other — confirm intent.
                        patch_save_path = f"{data_dir}/spots/domain_{domain_idx + 1}/patches/frame_{frame + 1}/spot_{spot + 1}.tif"
                        subtracted_patch_save_path = f"{data_dir}/spots/domain_{domain_idx + 1}/patches_subtracted/frame_{frame + 1}/spot_{spot + 1}.tif"
                        
                        if not os.path.exists(subtracted_patch_save_path) or reprocess:
                            save_image_patches(original_frame, image, x, y, box_width, patch_save_path, subtracted_patch_save_path, mean_background_intensity_df)
                            
                        if not os.path.exists(image_save_path) or reprocess:
                            # save the image with localizations using matplotlib
                            box = [y - box_width/2, x - box_width/2, y + box_width/2, x + box_width/2]
                            rect = plt.Rectangle((box[1], box[0]), box[3] - box[1], box[2] - box[0],
                                                 linewidth=1, edgecolor='g', facecolor='none')
                            ax1.add_patch(rect)
                            ax1.text((box[1] + box[3]) // 2, box[0], str(spot), color='r', fontsize=8)
                            ax1.plot(x, y, 'ro', markersize=1)
        
                        # Distance between this drift-corrected localization and the group position
                        distance = calculate_distance(x_corrected, y_corrected, group_x, group_y)
                        group_value = indexes_df.loc[corrected_idx, 'group']
                        # print(distance < min_distance, np.isnan(group_value), (group not in existing_groups_in_linked_frames[f'frame_{frame_idx}']))
                        # if distance < min_distance:
                        #     print(f'distance: {distance}, min_distance: {min_distance}, group_value: {group_value}')
                        # NOTE(review): once a localization is accepted the group is appended to the
                        # frame's list, so the third condition blocks any later, closer candidate:
                        # the first unassigned localization within max_drift_distance wins, not the nearest.
                        if distance < min_distance and np.isnan(group_value) and (group not in existing_groups_in_linked_frames[f'frame_{frame_idx}']):
                            min_distance = distance
                            indexes_df.loc[corrected_idx, 'group'] = group
                            indexes_df.loc[corrected_idx, 'x_group'] = x_corrected
                            indexes_df.loc[corrected_idx, 'y_group'] = y_corrected
                            existing_groups_in_linked_frames[f'frame_{frame_idx}'].append(group)
                            # Gap between the matched frame and the current frame: add a
                            # synthetic "missed" row (spot = -1) for each frame in between
                            # that does not contain this group.
                            if original_frame < frame - 1:
                                for missed_frame_idx in np.arange(original_frame + 1, frame):
                                    if group not in existing_groups_in_linked_frames[f'frame_{missed_frame_idx}']:
                                        missed_patch_save_path = f"{data_dir}/spots/domain_{domain_idx + 1}/patches/frame_{missed_frame_idx + 1}/missed_frame_{frame + 1}_spot_{spot + 1}.tif"
                                        missed_subtracted_patch_save_path = f"{data_dir}/spots/domain_{domain_idx + 1}/patches_subtracted/frame_{missed_frame_idx + 1}/missed_frame_{frame + 1}_spot_{spot + 1}.tif"
                                        missed_frame_drift_row = drift_df[drift_df['frame']==missed_frame_idx]
                                        dx, dy = missed_frame_drift_row.dx.item(), missed_frame_drift_row.dy.item()
                                        # NOTE(review): x_corrected was computed above as x - dx, so the
                                        # uncorrected position would be x_corrected + dx; the sign here
                                        # is the opposite — confirm.
                                        x_missed, y_missed = x_corrected - dx, y_corrected - dy
                                        missed_image = image_stack[missed_frame_idx]
                                        # Add a new empty row 
                                        current_len = len(indexes_df)
                                        indexes_df.loc[current_len] = [np.nan] * len(indexes_df.columns)
                                        indexes_df.loc[current_len, 'frame'] = missed_frame_idx
                                        indexes_df.loc[current_len, 'group'] = group
                                        indexes_df.loc[current_len, 'spot'] = -1 # undetected spots
                                        indexes_df.loc[current_len, 'x'] = x_missed
                                        indexes_df.loc[current_len, 'y'] = y_missed
                                        indexes_df.loc[current_len, 'dx'] = dx
                                        indexes_df.loc[current_len, 'dy'] = dy
                                        indexes_df.loc[current_len, 'x_group'] = x_corrected
                                        indexes_df.loc[current_len, 'y_group'] = y_corrected
                                        indexes_df.loc[current_len, 'missed_from_frame'] = frame
                                        indexes_df.loc[current_len, 'missed_from_spot'] = spot
                                        
                                        # NOTE(review): unlike the other existence checks this one ignores
                                        # `reprocess`, and the background is looked up for original_frame
                                        # although the patch comes from missed_frame_idx — confirm.
                                        if not os.path.exists(missed_patch_save_path) or not os.path.exists(missed_subtracted_patch_save_path):
                                            save_image_patches(original_frame, missed_image, x_missed, y_missed, box_width, missed_patch_save_path, missed_subtracted_patch_save_path, mean_background_intensity_df)
                                            print(f'Saving missing patches in frame {missed_frame_idx + 1} for frame {frame + 1} group {group} spot {spot + 1}')

            # Write the annotated frame image and release the figure
            if not os.path.exists(image_save_path) or reprocess:
                plt.title(f'Domain {domain_idx + 1} Frame {frame + 1} with bounding boxes and spot indexes')
                os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
                plt.savefig(image_save_path)
                # Clear the contents of ax
                ax1.cla()
                # Close the entire figure
                plt.close(fig1)

            # Periodic checkpoint, keyed on the DataFrame index label of the current row
            if grouped_idx % 10000 == 0:
                indexes_df.to_csv(linked_save_path, index=False)

        # Final save of the linked index table for this domain
        indexes_df.to_csv(linked_save_path, index=False)  
        print("="*50)
