# Bioimage Quality Control (QC) Notebook

This notebook provides a set of simple quality control analyses for microscopy images, allowing users to flag issues such as saturation, crosstalk or inappropriate bit depth.

**QC checks implemented:**
- Histogram oddities
- Background flatness
- Bit depth assessment
- Dynamic range calculation
- Saturation percentage
- Crosstalk estimation between channels

⚠️ **WORK IN PROGRESS** The code is based on a prototype script. 

In [38]:
import bioio_ome_tiff
import numpy as np
import torch
from bioio import BioImage
from scipy.optimize import curve_fit
from skimage.transform import resize

from regression_model import *
from regression_model import CrossTalkRegressionModel

## Detect Odd Histogram Distribution
Flags images with strange histogram distributions (e.g., many zero bins within the main intensity range).

In [39]:
def detect_odd_histogram_distribution(image, bins=256, percentile_threshold=99.99):
    """Count empty histogram bins inside the populated intensity range.

    The histogram spans the image's own minimum and maximum. Bins are
    inspected from the first populated bin up to, but not including, the bin
    in which the cumulative pixel count reaches ``percentile_threshold``
    percent of all pixels. Many empty bins in that range hint at an odd
    intensity distribution.

    Args:
        image: Array-like of pixel intensities; any shape.
        bins: Number of equal-width histogram bins.
        percentile_threshold: Upper cut-off of the inspected range, given as
            a percentage of the total pixel count.

    Returns:
        A ``(zero_bins, zero_bin_ratio)`` tuple: the number of empty bins in
        the inspected range and that number divided by the number of bins in
        the range. Both are 0 for an empty image or when the inspected range
        contains no bins.
    """
    # np.min/np.max raise on empty input, so report "nothing odd" up front.
    if np.size(image) == 0:
        return 0, 0

    hist, _ = np.histogram(image, bins=bins, range=(np.min(image), np.max(image)))

    # Locate the bin in which the cumulative count reaches the requested percentile
    cumulative_hist = np.cumsum(hist)
    total_pixels = cumulative_hist[-1]
    threshold_index = np.searchsorted(cumulative_hist, percentile_threshold / 100 * total_pixels)

    # Every pixel of a non-empty image lies inside [min, max], so at least
    # one bin is guaranteed to be populated here.
    first_non_zero_bin = np.where(hist > 0)[0][0]

    # Count zero bins between the first non-zero bin and the threshold bin
    zero_bins = np.sum(hist[first_non_zero_bin:threshold_index] == 0)

    # Ratio of zero bins to the number of bins in the inspected range
    total_bins_in_range = threshold_index - first_non_zero_bin
    zero_bin_ratio = zero_bins / total_bins_in_range if total_bins_in_range > 0 else 0

    return zero_bins, zero_bin_ratio

## Estimate Background Flatness
Fits a flat plane to each slice of a 3D image and uses the deviation from that plane as a measure of background non-uniformity.

In [40]:
def flat_plane(coords, p0, p1, p2):
    """Evaluate the first-order plane ``p0 * x + p1 * y + p2``.

    Args:
        coords: Pair ``(x, y)`` of coordinates (scalars or arrays).
        p0: Slope along x.
        p1: Slope along y.
        p2: Constant offset.

    Returns:
        The plane value at every coordinate pair.
    """
    x_vals, y_vals = coords
    x_term = p0 * x_vals
    y_term = p1 * y_vals
    return x_term + y_term + p2


def estimate_background_flat_plane_deviation(image_3d):
    """Fit a plane to every slice of a 3D stack and measure the residual spread.

    Each slice ``image_3d[z]`` is fitted independently with ``flat_plane``
    via ``curve_fit``. The per-slice score is the standard deviation of the
    absolute difference between the slice and its fitted plane; the overall
    metric is the mean of those scores.

    Args:
        image_3d: Array whose shape unpacks into ``(depth, height, width)``.

    Returns:
        A ``(background_3d, non_uniformity)`` tuple: the fitted planes as a
        float64 array with the same shape as ``image_3d``, and the mean
        per-slice deviation score.
    """
    depth, height, width = image_3d.shape

    # Fitted background plane for every slice
    background_3d = np.zeros_like(image_3d, dtype=np.float64)

    # The pixel coordinate grid is the same for every slice, so build it
    # (and its flattened form used by the fit) once instead of per slice.
    xx, yy = np.meshgrid(np.arange(width), np.arange(height))
    flat_coords = (xx.ravel(), yy.ravel())

    deviations = []

    for z in range(depth):
        print(f'Checking slice {z} of {depth}')
        image = image_3d[z, :, :]

        # Fit the plane parameters to the current slice, starting from zeros
        params, _ = curve_fit(flat_plane, flat_coords, image.ravel(), p0=np.zeros(3))

        # Evaluate the fitted plane on the full grid and store it
        background_slice = flat_plane((xx, yy), *params).reshape(image.shape)
        background_3d[z, :, :] = background_slice

        # Spread of the absolute residual is this slice's deviation score
        deviation = np.abs(image - background_slice)
        deviations.append(np.std(deviation))

    # The non-uniformity metric could be an average or max deviation across slices
    non_uniformity = np.mean(deviations)  # or max(deviations)

    return background_3d, non_uniformity

## Bit Depth Checker
Warns if the image bit depth is unusually low.

In [41]:
def check_bit_depth(image):
    """Report the storage bit depth of ``image`` and warn when it is 8 bits or less.

    The bit depth is derived from the array dtype (bytes per element times
    eight), not from the values actually present in the image.

    Args:
        image: Array exposing a ``dtype`` with an ``itemsize``.

    Returns:
        The bit depth as an integer.
    """
    bit_depth = 8 * image.dtype.itemsize

    if bit_depth > 8:
        message = f"The image has a bit depth of {bit_depth}-bits, which is adequate for most purposes."
    else:
        message = f"Warning: The image has a low bit depth of {bit_depth}-bits, which may limit image quality."
    print(message)

    return bit_depth

## Dynamic Range Calculation

In [42]:
def calculate_dynamic_range(image):
    """Return the fraction of the dtype's representable range the image uses.

    The result is ``(max - min)`` of the pixel values divided by the full
    span of the image's integer dtype (``iinfo.max - iinfo.min``). For
    unsigned dtypes that span equals ``iinfo.max``.

    Args:
        image: Non-empty array with an integer dtype.

    Returns:
        The normalized dynamic range as a float64 between 0 and 1.

    Raises:
        ValueError: If the dtype is not an integer type (from ``np.iinfo``)
            or the image is empty (from ``np.min``/``np.max``).
    """
    limits = np.iinfo(image.dtype)

    # Work in Python ints: subtracting numpy scalars of a signed dtype can
    # wrap around (e.g. int8 127 - (-128)).
    min_intensity = int(np.min(image))
    max_intensity = int(np.max(image))
    full_scale = int(limits.max) - int(limits.min)

    # Normalized dynamic range
    return np.float64((max_intensity - min_intensity) / full_scale)

## Saturation Percentage
Calculates the percentage of pixels that are fully saturated, i.e. equal to the minimum or maximum value representable by the image's data type.

In [43]:
def calculate_saturation_percentage(image):
    """Return the percentage of pixels sitting at the dtype's minimum or maximum.

    Args:
        image: Array with an integer dtype.

    Returns:
        Percentage (0-100) of pixels equal to ``iinfo.min`` or ``iinfo.max``
        of the image dtype, as a float64. An empty image yields 0.0.

    Raises:
        ValueError: If the dtype is not an integer type (from ``np.iinfo``).
    """
    # Minimum and maximum representable values for the image's data type
    limits = np.iinfo(image.dtype)

    total_pixels = image.size
    if total_pixels == 0:
        # Avoid 0/0 (NaN plus a RuntimeWarning): no pixels, nothing saturated.
        return np.float64(0.0)

    # Count pixels clipped at either end of the representable range
    saturated_pixels = np.sum((image == limits.min) | (image == limits.max))

    return (saturated_pixels / total_pixels) * 100

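## Estimate Crosstalk
Uses a pre-trained regression model to estimate how much of one channel's signal is crosstalk from another channel.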

In [44]:
def normalize_image(img):
    """Min-max scale ``img`` to the range [0, 1].

    Args:
        img: Array exposing ``min()`` and ``max()``.

    Returns:
        ``(img - min) / (max - min)``, or ``img`` itself (unchanged) when the
        maximum is not greater than the minimum, which avoids dividing by zero.
    """
    lowest = img.min()
    highest = img.max()

    # Flat image: nothing to scale
    if not highest > lowest:
        return img

    return (img - lowest) / (highest - lowest)


def crosstalk_test_transforms_fn(mixed_np, source_np):
    """Turn a (mixed, source) image pair into one batched model input tensor.

    Each image is cast to float32, resized to 256 x 256, min-max normalized
    with ``normalize_image`` and converted to a tensor. The two planes are
    then joined along a new leading axis (mixed first, source second) and a
    batch axis is added in front.

    Args:
        mixed_np: NumPy image of the channel that may contain crosstalk.
        source_np: NumPy image of the channel suspected as the source.

    Returns:
        A tensor holding both planes; for 2D inputs its shape is
        ``(1, 2, 256, 256)``.
    """
    target_shape = (256, 256)

    planes = []
    for plane in (mixed_np, source_np):
        resized = resize(plane.astype(np.float32), target_shape)
        scaled = normalize_image(resized)
        planes.append(torch.from_numpy(scaled).unsqueeze(0))

    stacked = torch.cat(planes, dim=0)
    return stacked.unsqueeze(0)


# Trained weights used when no explicit ``model_path`` is given.
_DEFAULT_CROSSTALK_WEIGHTS = './crosstalk_model/crosstalk_regression_model_trained_2025-08-21_01-02-53_256_0.0005.pth'

# Models already loaded, keyed by (weights path, device string), so repeated
# calls do not rebuild the network and re-read the weights from disk.
_crosstalk_model_cache = {}


def _load_crosstalk_model(model_path, device):
    """Return the regression model for ``model_path`` on ``device``, loading it at most once."""
    key = (model_path, str(device))
    model = _crosstalk_model_cache.get(key)
    if model is None:
        model = CrossTalkRegressionModel(initial_filters=128, num_conv_blocks=6)
        # NOTE(review): torch.load unpickles the file; only point this at
        # trusted checkpoints (consider weights_only=True where supported).
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        model.to(device)
        _crosstalk_model_cache[key] = model
    return model


def estimate_crosstalk(channel1, channel2, model_path=_DEFAULT_CROSSTALK_WEIGHTS):
    """Estimate the crosstalk of ``channel2`` into ``channel1`` with the regression model.

    The two images are preprocessed by ``crosstalk_test_transforms_fn``
    (``channel1`` as the mixed image, ``channel2`` as the source) and passed
    through the model in inference mode on CUDA when available, otherwise on
    the CPU. The model is cached per weights path and device, so the weights
    file is read only on the first call.

    Args:
        channel1: NumPy image of the channel that may contain crosstalk.
        channel2: NumPy image of the channel suspected as the source.
        model_path: Path to the trained state dict. It must match the
            architecture built here (``initial_filters=128``,
            ``num_conv_blocks=6``).

    Returns:
        The first element of the model output for the single input sample,
        as a NumPy scalar.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    crosstalk_model = _load_crosstalk_model(model_path, device)

    # Inference only: no gradients needed
    with torch.no_grad():
        # Build the batched input tensor and move it to the model's device
        input_tensor = crosstalk_test_transforms_fn(channel1, channel2)
        input_tensor = input_tensor.to(device)

        # Single forward pass
        output = crosstalk_model(input_tensor)

        # Take the first value of the first (only) sample, back on the CPU
        return output[0].cpu().numpy()[0]

## Example Usage

This block demonstrates how to use the QC functions on an example image. Replace the image path with your own image as needed.

In [48]:
# Open the example OME-TIFF with the OME-TIFF reader; point the path at your
# own file to analyse other data.
img = BioImage('./inputs/Experiment-09-test.ome.tiff', reader=bioio_ome_tiff.Reader)
# Show the dimension layout and the shape of the full pixel array.
print(img.dims)
print(img.data.shape)

# Run the QC checks once per channel.
for c in range(img.dims.C):
    # Request channel c in 'CZYX' order, then index away the leading axis.
    # NOTE(review): presumably this leaves a (Z, Y, X) stack — confirm
    # against the bioio documentation.
    channel = img.get_image_data('CZYX', C=c)
    channel = channel[0, :, :]
    print(f'\n--- Channel {c} ---')
    check_bit_depth(channel)
    dr = calculate_dynamic_range(channel)
    print(f'Dynamic range of Channel {c} is {dr}')
    saturation_percentage = calculate_saturation_percentage(channel)
    print(f'Relative saturation of Channel {c} is {saturation_percentage}%')
    # The flat-plane background check is currently disabled in this example.
    # background_3d, non_uniformity = estimate_background_flat_plane_deviation(channel)
    # print(f"Non-uniformity (Flat Plane Deviation) for Channel {c} is {non_uniformity}")
    zero_bins, zero_bin_ratio = detect_odd_histogram_distribution(channel)
    print(f"Number of zero bins: {zero_bins}")
    print(f"Ratio of zero bins: {zero_bin_ratio:.4f}")
    # Treat every other channel s as a potential crosstalk source for channel c.
    for s in range(img.dims.C):
        if c == s:
            continue
        # Only the first plane of each stack (index 0) is passed to the model.
        # NOTE(review): assumes that plane is representative — confirm for multi-Z data.
        crosstalk = estimate_crosstalk(channel[0], img.get_image_data('CZYX', C=s)[0, 0])
        print(f'An estimated {crosstalk * 100:.1f}% of channel {c} is crosstalk from channel {s}')

<Dimensions [T: 1, C: 2, Z: 1, Y: 2208, X: 2752]>
(1, 2, 1, 2208, 2752)

--- Channel 0 ---
The image has a bit depth of 16-bits, which is adequate for most purposes.
Dynamic range of Channel 0 is 0.2477607385366598
Relative saturation of Channel 0 is 0.0%
Number of zero bins: 0
Ratio of zero bins: 0.0000
Using device: cpu
An estimated 14.5% of channel 0 is crosstalk from channel 1

--- Channel 1 ---
The image has a bit depth of 16-bits, which is adequate for most purposes.
Dynamic range of Channel 1 is 0.3207904173342489
Relative saturation of Channel 1 is 0.0%
Number of zero bins: 0
Ratio of zero bins: 0.0000
Using device: cpu
An estimated 42.8% of channel 1 is crosstalk from channel 0
