In [1]:
import numpy as np
import xml.etree.ElementTree as ET
from skimage.io import imread, imsave
import tifffile
import re
import matplotlib.pyplot as plt
import os
import tkinter as tk
from tkinter import filedialog
import re

In [2]:
def parse_xml_v2(xml_string):
    """Extract per-channel flat-field parameters from the text of an XML file.

    The text is scanned with regular expressions rather than an XML parser:
    every ``<Entry ChannelID="N">...</Entry>`` element is located and the
    first occurrence of each of the fields ``Coefficients``, ``Origin``,
    ``Scale``, ``Mean`` and ``Dims`` is read from its body.

    Parameters
    ----------
    xml_string : str
        Full text of the XML document.

    Returns
    -------
    dict
        Maps each integer channel ID to a dict with the keys
        ``'coefficients'`` (object-dtype NumPy array built from the rows of
        floats), ``'origin'`` and ``'scale'`` (tuples of floats),
        ``'background_mean'`` (float) and ``'dims'`` (tuple of ints).
        The dict is empty when no ``<Entry>`` element is found.

    Raises
    ------
    ValueError
        If an entry lacks one of the expected fields, or a field holds a
        value that cannot be converted to a number.
    """
    entry_regex = re.compile(r'<Entry ChannelID="(.*?)">(.*?)</Entry>', re.DOTALL)

    def _field(pattern, text, label, channel_id):
        # Return the first capture group, or fail with a message that names
        # the missing field instead of an AttributeError on None.
        match = re.search(pattern, text)
        if match is None:
            raise ValueError(
                f"Channel {channel_id}: field '{label}' not found in the flat-field profile"
            )
        return match.group(1)

    def _numbers(text, cast):
        # Split on bare commas: float() and int() ignore surrounding
        # whitespace, so both "1, 2" and "1,2" are accepted.
        return tuple(cast(token) for token in text.split(','))

    channel_data = {}
    for entry_match in entry_regex.finditer(xml_string):
        channel_id = int(entry_match.group(1))
        flatfield_profile = entry_match.group(2)

        # The capture is the inside of "[[...]]", i.e. rows separated by "],[".
        coeffs_text = _field(
            r'Coefficients: \[\[(.*?)\]\]', flatfield_profile, 'Coefficients', channel_id
        )
        coeff_rows = [
            list(_numbers(row, float))
            for row in re.split(r'\]\s*,\s*\[', coeffs_text)
        ]
        # dtype=object keeps rows of different lengths as separate lists.
        coeffs = np.array(coeff_rows, dtype=object)

        origin = _numbers(
            _field(r'Origin: \[(.*?)\]', flatfield_profile, 'Origin', channel_id), float
        )
        scale = _numbers(
            _field(r'Scale: \[(.*?)\]', flatfield_profile, 'Scale', channel_id), float
        )
        background_mean = float(
            _field(r'Mean: (.*?),', flatfield_profile, 'Mean', channel_id)
        )
        dims = _numbers(
            _field(r'Dims: \[(.*?)\]', flatfield_profile, 'Dims', channel_id), int
        )

        channel_data[channel_id] = {
            'coefficients': coeffs,
            'origin': origin,
            'scale': scale,
            'background_mean': background_mean,
            'dims': dims,
        }

    return channel_data


In [3]:
def reconstruct_flatfield_image(channel_data):
    """Evaluate a channel's polynomial flat-field profile on its pixel grid.

    Parameters
    ----------
    channel_data : dict
        Must provide ``'coefficients'`` (iterable of coefficient rows),
        ``'origin'`` and ``'scale'`` (indexable, two entries are read) and
        ``'dims'`` (the shape of the array to build).

    Returns
    -------
    numpy.ndarray
        Float array of shape ``dims`` holding the sum of all polynomial terms.
    """
    polynomial = channel_data['coefficients']
    centre = channel_data['origin']
    step = channel_data['scale']
    shape = channel_data['dims']

    # Index grids: the first varies along axis 0, the second along axis 1.
    axis0_idx, axis1_idx = np.meshgrid(
        np.arange(shape[0]), np.arange(shape[1]), indexing='ij'
    )
    # Shift and scale indices into the coordinates the polynomial is defined in.
    x = (axis1_idx - centre[0]) * step[0]
    y = (axis0_idx - centre[1]) * step[1]

    surface = np.zeros(shape)
    for terms in polynomial:
        # Within one row the x exponent falls from len(row) - 1 to 0 while the
        # y exponent rises from 0, e.g. x^3, y*x^2, x*y^2, y^3 for four terms.
        top_power = len(terms) - 1
        for y_power, weight in enumerate(terms):
            surface += weight * (x ** (top_power - y_power)) * (y ** y_power)
    return surface

In [4]:
def apply_ffc(image, channel_data):
    """Flat-field correct every channel of a channel-first image stack.

    Channel ``c`` of ``image`` (index along axis 0) is corrected with the
    entry ``channel_data[c + 1]``, using its ``"flatfield"`` array and its
    ``'background_mean'`` as the dark level.

    Parameters
    ----------
    image : numpy.ndarray
        Stack whose first axis enumerates the channels.
    channel_data : dict
        Per-channel settings keyed by 1-based channel number.

    Returns
    -------
    numpy.ndarray
        Corrected stack, clipped to ``[0, 65535]`` and cast to ``uint16``.
    """
    n_channels = image.shape[0]
    # Results are accumulated in a float32 buffer before the final cast.
    buffer = np.zeros_like(image, dtype=np.float32)

    for idx in range(n_channels):
        settings = channel_data[idx + 1]
        dark_level = settings['background_mean']
        illumination = settings["flatfield"] - dark_level
        # Gain scales each pixel so the dark-subtracted flat field becomes
        # uniform at its own mean value.
        gain = np.mean(illumination) / illumination
        buffer[idx] = (image[idx] - dark_level) * gain

    upper = 2**16 - 1
    return np.clip(buffer, 0, upper).astype(np.uint16)

def apply_ffc_chann(image, channel_data, channel_index):
    """Flat-field correct a single-channel image.

    Parameters
    ----------
    image : numpy.ndarray
        Image to correct; it is combined element-wise with the channel's
        flat-field array.
    channel_data : dict
        Per-channel settings; the entry ``channel_data[channel_index]`` must
        provide ``"flatfield"`` and ``'background_mean'``.
    channel_index : hashable
        Key of the channel to use in ``channel_data``.

    Returns
    -------
    numpy.ndarray
        Corrected image, clipped to ``[0, 65535]`` and cast to ``uint16``.

    Raises
    ------
    KeyError
        If ``channel_index`` is not a key of ``channel_data``.
    """
    channel_info = channel_data[channel_index]
    dark_field = channel_info['background_mean']

    # Dark-subtracted flat field, computed once and reused for mean and gain.
    illumination = channel_info["flatfield"] - dark_field
    # Pixels where the flat field equals the dark level produce a non-finite gain.
    gain = np.mean(illumination) / illumination

    # No pre-allocated buffer is needed: the expression creates the result array.
    corrected_image = (image - dark_field) * gain
    return np.clip(corrected_image, 0, 2**16 - 1).astype(np.uint16)

In [5]:
def read_xml_data(file_path):
    """Return the whole content of a text file as a string.

    The file is decoded as UTF-8 regardless of the platform's locale default,
    and a leading byte-order mark is dropped. Bytes that are not valid UTF-8
    are replaced with U+FFFD instead of aborting the read.

    Parameters
    ----------
    file_path : str or path-like
        Location of the file to read.

    Returns
    -------
    str
        The decoded file content.

    Raises
    ------
    OSError
        If the file cannot be opened.
    """
    with open(file_path, 'r', encoding='utf-8-sig', errors='replace') as file:
        return file.read()

In [6]:
def validate_image_shape(image_shape, num_channels, channel_dimensions):
    """Check that an image shape is ``(num_channels, *channel_dimensions)``.

    Parameters
    ----------
    image_shape : sequence of int
        Shape of the image to check.
    num_channels : int
        Expected size of the first axis.
    channel_dimensions : sequence of int
        Expected sizes of the remaining axes.

    Returns
    -------
    bool
        ``True`` only when the shape has exactly three axes, the first axis
        equals ``num_channels`` and the remaining axes equal
        ``channel_dimensions``.
    """
    if len(image_shape) != 3:
        return False

    if image_shape[0] != num_channels:
        return False

    # Normalise both sides to tuples: a list never compares equal to a tuple,
    # which would otherwise reject matching dimensions.
    return tuple(image_shape[1:]) == tuple(channel_dimensions)

In [7]:
# Prompt the user to select the XML file
root = tk.Tk()
root.withdraw()
xml_file_path = filedialog.askopenfilename(title="Select the XML file")
if not xml_file_path:
    # A cancelled dialog yields an empty value; fail with a clear message.
    raise FileNotFoundError("No XML file was selected.")

# Read the XML data
xml_data = read_xml_data(xml_file_path)

# Parse the XML data
channel_data = parse_xml_v2(xml_data)
if 1 not in channel_data:
    # Multichannel stacks are corrected with channels 1..N, and channel 1
    # supplies the reference dimensions below.
    raise KeyError("The selected XML file contains no entry for channel 1.")

# Extract the number of channels and their dimensions from the parsed XML data
num_channels = len(channel_data)
channel_dimensions = channel_data[1]['dims']

# Prompt the user to select the input and output folders
input_folder_path = filedialog.askdirectory(title="Select the input folder")
output_folder_path = filedialog.askdirectory(title="Select the output folder")

# All dialogs are done: release the hidden Tk root window.
root.destroy()

if not input_folder_path or not output_folder_path:
    raise FileNotFoundError("Both an input and an output folder must be selected.")

# Generate the flatfield for every channel once
for channel_idx, channel_d in channel_data.items():
    flatfield = reconstruct_flatfield_image(channel_d)
    channel_data[channel_idx]["flatfield"] = flatfield

# Regex to capture channel index if image has only two dimensions
ch_idx_re = re.compile(r"ch(\d+)")

# Iterate through all files in the input folder
for filename in os.listdir(input_folder_path):
    # Only TIFF files are processed
    if not filename.lower().endswith(('.tiff', '.tif')):
        continue

    file_path = os.path.join(input_folder_path, filename)

    # Read the original image
    image = tifffile.imread(file_path)

    if len(image.shape) == 2:  # We capture the channel index from the name
        re_res = ch_idx_re.search(filename)
        if re_res is None:
            # search() returns None when the name holds no "ch<digits>" token
            print(f"Image '{filename}' channel index detection failed. Skipping...")
            continue
        channel_index = int(re_res.group(1))

        if channel_index not in channel_data:
            print(f"Image '{filename}' refers to channel {channel_index}, "
                  "which is not described in the XML file. Skipping...")
            continue

        # The flat field has shape 'dims'; a different image shape cannot be
        # combined with it element-wise.
        if tuple(image.shape) != tuple(channel_data[channel_index]['dims']):
            print(f"Image '{filename}' has an incorrect shape. Skipping...")
            continue

        corrected_image = apply_ffc_chann(image, channel_data, channel_index)

    else:  # Multichannel images are processed at once
        if not validate_image_shape(image.shape, num_channels, channel_dimensions):
            print(f"Image '{filename}' has an incorrect shape. Skipping...")
            continue

        # Apply flat-field correction
        corrected_image = apply_ffc(image, channel_data)

    # Save the corrected image
    output_filename = f"corrected_{filename}"
    output_file_path = os.path.join(output_folder_path, output_filename)
    tifffile.imwrite(output_file_path, corrected_image)
