# Computational imaging of aerosol reactivity outcomes

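Run the following cell if you haven't yet downloaded the microscopy data, e.g. if running in Colab. It downloads and extracts the data and creates an empty `.env` file. Add a line `DATA_HOME=<path to the extracted data>` to `.env`, or set the `DATA_HOME` environment variable, because later cells read the image directory from it.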

In [None]:
# Download the archived microscopy data from Zenodo and unpack it into the
# current directory.
# `touch .env` makes sure a .env file exists for the next cell to read.
# NOTE(review): the file is empty afterwards; a DATA_HOME=... entry still has
# to be added (DATA_HOME is read via environ['DATA_HOME'] further down) unless
# DATA_HOME is already set in the environment.
!wget -O data.tar.gz 'https://zenodo.org/records/15632556/files/data.tar.gz?download=1'
!tar -xzf data.tar.gz
!touch .env

In [None]:
# Apply environment variables from .env file
from os import environ
from pathlib import Path

# Load simple KEY=VALUE pairs from a local .env file into os.environ.
# - Blank lines, '#' comments and lines without '=' are skipped.
# - Whitespace around keys and values is stripped, so "KEY = value" works.
# - Only the first '=' splits, so values may themselves contain '='.
# - A missing .env file is tolerated: the variables may already be set in the
#   environment. The file handle is closed by the `with` block.
env_file = Path('.env')
if env_file.exists():
    with env_file.open('r', encoding='utf-8') as env_fh:
        for line in env_fh:
            line = line.strip()
            if not line or line.startswith('#') or '=' not in line:
                continue
            var, val = line.split('=', 1)
            environ[var.strip()] = val.strip()

In [None]:
import gzip
import math
import pickle
import re

import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import seaborn as sns
from ipywidgets import Button, Layout, interact
from scipy.stats import multivariate_normal, norm
from skimage.measure import perimeter
from sklearn.mixture import GaussianMixture
from tqdm.auto import tqdm
from transformers import pipeline

# Number of Gaussian-mixture components used to cluster droplet colours.
N_CLUSTERS = 3
# Root directory of the microscopy images; raises KeyError if DATA_HOME is unset.
DATA_HOME = Path(environ['DATA_HOME'])

# Shared plot styling. SVG text is written as text rather than paths so that
# exported figures stay editable.
sns.set_theme(style="ticks", context="notebook", font_scale=1.2, font="Arial")
matplotlib.rcParams.update({"svg.fonttype": "none"})

In [None]:
# Map each image's file stem (used as the sample id everywhere below) to its
# path, scanning DATA_HOME recursively for .jpg files in sorted path order.
# NOTE(review): two images with the same stem in different sub-directories
# collide here; the one later in sorted order wins.
sample_image_files = dict(
    (jpg_path.stem, jpg_path) for jpg_path in sorted(DATA_HOME.glob('**/*.jpg'))
)
sample_image_files

In [None]:
# `import PIL` alone does not guarantee that the Image submodule is loaded.
import PIL.Image


def _load_rgb_image(path):
    """Open the image at *path*, decode it as RGB and release the file handle.

    PIL.Image.open is lazy and keeps the file open until the pixels are read.
    With many images this can exhaust file descriptors, so the conversion is
    done inside the ``with`` block. ``convert`` returns an independent
    in-memory copy. Forcing RGB also guarantees the three channels the
    per-channel code below indexes (greyscale or CMYK JPEGs would not have them).
    """
    with PIL.Image.open(path) as img:
        return img.convert('RGB')


imgs_RGB = {sample_id: _load_rgb_image(path) for sample_id, path in tqdm(sample_image_files.items())}
# scaled down to 1/4 size
imgs_small_RGB = {sample_id: img.resize((img.width // 4, img.height // 4)) for sample_id, img in tqdm(imgs_RGB.items())}
imgs_small_HSV = {sample_id: img.convert('HSV') for sample_id, img in tqdm(imgs_small_RGB.items())}
imgs_small_LAB = {sample_id: img.convert('LAB') for sample_id, img in tqdm(imgs_small_RGB.items())}

In [None]:
def save_button(fn, *args, **kwargs):
    """Display a 'Save all' button that runs ``fn(sample_id, *args, **kwargs)``
    for every loaded sample.

    The plotting functions passed here save their figures as a side effect, so
    clicking the button exports one figure set per image. Figures opened during
    each call are closed afterwards so that they do not pile up across all
    samples (matplotlib warns once more than ``figure.max_open_warning`` figures
    are open).
    """
    # Imported here so the helper does not depend on the import in the last cell.
    from IPython.display import display

    def save_all(_button):
        for sample_id in tqdm(imgs_small_RGB):
            open_before = set(plt.get_fignums())
            fn(sample_id, *args, **kwargs)
            for fig_num in set(plt.get_fignums()) - open_before:
                plt.close(fig_num)

    button = Button(description='Save all')
    button.on_click(save_all)
    display(button)

## Image visualisation in different colour spaces

In [None]:
@interact(id=list(imgs_small_RGB.keys()))
def show_all_colourspaces(id):
    """Plot every channel of image *id* in the RGB, HSV and LAB colour spaces.

    Draws a 3x3 grid (one row per colour space, one column per channel) with a
    horizontal colourbar under each panel, and saves it to
    ``out/colourspaces_<id>.svg`` and ``.png``.
    """
    # The out/ directory is otherwise only created by the mask-caching cell,
    # which runs later. Create it here so saving works on a fresh run.
    Path('out').mkdir(parents=True, exist_ok=True)

    fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(15, 12))
    spaces = [
        ('RGB', imgs_small_RGB, ['R', 'G', 'B']),
        ('HSV', imgs_small_HSV, ['H', 'S', 'V']),
        ('LAB', imgs_small_LAB, ['L', 'A', 'B']),
    ]
    for row, (space, imgs, channel_names) in enumerate(spaces):
        img = np.array(imgs[id])
        for col, channel_name in enumerate(channel_names):
            ax = axs[row, col]
            channel_image = ax.imshow(img[..., col], cmap='gray')
            ax.set_title(f'{space} - {channel_name}')
            fig.colorbar(channel_image, ax=ax, orientation='horizontal')
    fig.suptitle(id, fontsize=20)
    fig.savefig(f'out/colourspaces_{id}.svg', bbox_inches='tight', transparent=True)
    fig.savefig(f'out/colourspaces_{id}.png', bbox_inches='tight', transparent=True, dpi=300)

save_button(show_all_colourspaces)

## Image segmentation

In [None]:
# Use the GPU when one is available. Otherwise fall back to the (much slower)
# CPU, so the notebook still runs on CPU-only machines and Colab runtimes
# instead of failing when the model is moved to CUDA.
import torch

mask_generator = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-base",
    device="cuda" if torch.cuda.is_available() else "cpu",
    # Permissive thresholds: many candidate regions are kept here and pruned
    # later by the size and circularity filters.
    points_per_crop=128,
    pred_iou_thresh=0.2,
    stability_score_thresh=0.2,
    crops_nms_thresh=0.1,
    points_per_batch=128,
)

In [None]:
mask_file = Path("out/masks.pkl.gz")
mask_file.parent.mkdir(parents=True, exist_ok=True)

# Full on-disk cache of SAM masks, keyed by sample id. It may hold images that
# are not loaded in this session. Those entries are kept and written back, so
# generating masks for a subset of images never drops the others' cache.
# NOTE: pickle.load can execute arbitrary code; only load a masks file you
# produced yourself.
if mask_file.exists():
    with gzip.open(mask_file, "rb") as f:
        all_masks = pickle.load(f)
else:
    all_masks = {}

# Masks for the images loaded in this session only.
masks = {id: mask for id, mask in all_masks.items() if id in imgs_small_RGB}

new_masks = {
    id: np.array(mask_generator(img)["masks"])
    for id, img in tqdm(imgs_small_RGB.items())
    if id not in masks
}

# Save masks to pickle file
if new_masks:
    masks.update(new_masks)
    all_masks.update(new_masks)
    if mask_file.exists():
        # make a backup of current mask_file with a timestamp
        backup_file = mask_file.with_name(
            f'{mask_file.stem}_{pd.Timestamp.now().strftime("%Y%m%d%H%M%S")}{mask_file.suffix}'
        )
        mask_file.rename(backup_file)
    with gzip.open(mask_file, "wb") as f:
        pickle.dump(all_masks, f)
{id: len(img_masks) for id, img_masks in masks.items()}

### Convert images from PIL objects to NumPy arrays

In [None]:
# Swap every PIL image for its numpy array, mutating the existing dicts in
# place (only values change, so iterating over a snapshot of the keys is safe).
for image_dict in (imgs_RGB, imgs_small_RGB, imgs_small_HSV, imgs_small_LAB):
    for sample_id in list(image_dict):
        image_dict[sample_id] = np.array(image_dict[sample_id])

## Filter masks by size and circularity
This step filters out regions that the segmentation model may have selected but that are too small or too large to be droplets, such as parts of the background and dust particles.

In [None]:
MAX_MASK_AREA = 1500
MIN_MASK_AREA = 36


def _filter_by_area(img_masks, min_area=MIN_MASK_AREA, max_area=MAX_MASK_AREA):
    """Keep the masks whose pixel count lies strictly between the two bounds.

    *img_masks* is a stack of masks, one per region. Boolean indexing keeps
    the (n, H, W) shape even when no mask survives, so downstream reductions
    over axes 1 and 2 still work on a (0, H, W) result.
    """
    img_masks = np.asarray(img_masks)
    if img_masks.ndim != 3:
        # nothing to filter (e.g. the generator produced no masks)
        return img_masks
    areas = img_masks.sum(axis=(1, 2))
    return img_masks[(areas > min_area) & (areas < max_area)]


masks_filtered = {
    id: _filter_by_area(img_masks)
    for id, img_masks in tqdm(masks.items())
}
{id: len(img_masks) for id, img_masks in masks_filtered.items()}

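Because the microdroplets are circular, regions that deviate significantly from circularity can also be eliminated, regardless of size. Circularity is measured as $4\pi A / P^2$ for a region of area $A$ and perimeter $P$; it equals 1 for a perfect circle.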

In [None]:
MIN_CIRCULARITY = 0.8

def circularity(mask):
    """
    Calculate the circularity 4*pi*area / perimeter**2 of a binary mask.
    Circularity of perfect circle: 1.0. A mask with zero perimeter gives 0.0.
    The perimeter is skimage's discrete estimate, so values for small masks
    may deviate from the ideal.
    """
    area = mask.sum()
    p = perimeter(mask)
    if p == 0:
        return 0.0
    return 4 * np.pi * area / (p * p)


def _is_droplet(mask):
    """True if *mask* is within the area bounds and round enough to be a droplet."""
    return MIN_MASK_AREA < mask.sum() < MAX_MASK_AREA and circularity(mask) > MIN_CIRCULARITY


def _filter_droplets(img_masks):
    """Keep droplet-like masks, preserving the (n, H, W) shape even when empty."""
    img_masks = np.asarray(img_masks)
    if img_masks.ndim != 3:
        return img_masks
    keep = np.array([_is_droplet(mask) for mask in img_masks], dtype=bool)
    return img_masks[keep]


masks_filtered = {
    id: _filter_droplets(img_masks)
    for id, img_masks in tqdm(masks_filtered.items())
}
{id: len(img_masks) for id, img_masks in masks_filtered.items()}

## Visualising the detected region masks before and after filtering

In [None]:
@interact(id=list(masks_filtered.keys()))
def show_masks(id):
    """Compare the segmentation masks of image *id* before and after filtering.

    Left panel: the image. The middle and right panels show, for every pixel,
    how many unfiltered or filtered masks cover it. The figure is saved to
    ``out/masks_<id>.png`` and ``.svg``.
    """
    fig, (img_ax, *coverage_axes) = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
    img_ax.imshow(imgs_small_RGB[id])
    img_ax.set_title(id)

    panels = (
        (coverage_axes[0], masks[id], 'unfiltered'),
        (coverage_axes[1], masks_filtered[id], 'filtered'),
    )
    for ax, region_masks, label in panels:
        coverage = np.sum(region_masks, axis=0)
        peak = np.max(coverage)
        shown = ax.imshow(coverage, cmap='gray', vmin=0, vmax=peak)
        cbar = fig.colorbar(shown, ax=ax, orientation='horizontal')
        cbar.set_label('# regions at point')
        cbar.set_ticks(np.arange(0, peak + 1, 1))
        ax.set_title(f'{len(region_masks)} {label} masks')
    fig.tight_layout()

    fig.savefig(f'out/masks_{id}.png', bbox_inches='tight', transparent=True)
    fig.savefig(f'out/masks_{id}.svg', bbox_inches='tight', transparent=True)

save_button(show_masks)

# Choosing representative colours for each droplet

In [None]:
def avg_region_color(img, masks):
    """
    Mean colour of each masked region, rescaled from 0-255 to 0-1.

    Returns one row per mask and one column per image channel.
    """
    means = [img[region].mean(axis=0) for region in masks]
    return np.array(means) / 255

In [None]:
def max_region_color(img, masks):
    """
    Colour (rescaled to 0-1) of the pixel with the largest Euclidean norm in
    each mask. Intended for RGB images, where the norm acts as a brightness proxy.
    """
    picked = []
    for region in masks:
        region_pixels = img[region]
        brightest = np.linalg.norm(region_pixels, axis=-1).argmax()
        picked.append(region_pixels[brightest])
    return np.array(picked) / 255

In [None]:
def max_value_region_color(img, masks):
    """
    HSV colour (rescaled to 0-1) of the pixel with the largest V (channel 2)
    in each mask. Intended for HSV images.
    """
    def brightest_pixel(pixels):
        return pixels[np.argmax(pixels[:, 2])]

    return np.array([brightest_pixel(img[region]) for region in masks]) / 255

In [None]:
def max_sat_region_color(img, masks):
    """
    HSV colour (rescaled to 0-1) of the pixel with the largest S (channel 1)
    in each mask. Intended for HSV images.
    """
    def most_saturated_pixel(pixels):
        return pixels[np.argmax(pixels[:, 1])]

    return np.array([most_saturated_pixel(img[region]) for region in masks]) / 255

In [None]:
def max_SV_region_color(img, masks):
    """
    Return a bright, saturated HSV colour for each mask, rescaled to 0-1.

    The masked pixels are ranked by V (value). Among the brightest quarter of
    them, the pixel with the largest S (saturation) is chosen. Masks with fewer
    than 4 pixels consider all of their pixels.
    Meant for use with HSV images only.
    """
    region_pixels = []
    for m in masks:
        pixels = img[m]
        # size of the "top 25% by V" group; 0 would mean "all pixels" in the
        # slice below, which is spelled out explicitly here
        n_top = len(pixels) // 4 or len(pixels)
        # indices of the top-25% pixels by V
        top_v = np.argsort(pixels[:, 2])[-n_top:]
        # among those, the pixel with the highest S
        best = top_v[pixels[top_v, 1].argmax()]
        region_pixels.append(pixels[best])
    return np.array(region_pixels) / 255

In [None]:
def max_SV_AB_color(img_hsv, img_lab, masks):
    """
    Return the LAB colour (rescaled to 0-1) of a bright, saturated pixel in each mask.

    The pixel is chosen in *img_hsv*: the masked pixels are ranked by V, and
    among the brightest quarter (all pixels if the mask has fewer than 4) the
    one with the largest S is selected. Its colour is then read from
    *img_lab*, which must be pixel-aligned with *img_hsv*.
    """
    region_pixels = []
    for m in masks:
        hsv_pixels = img_hsv[m]
        n_top = len(hsv_pixels) // 4 or len(hsv_pixels)
        # indices of the top-25% pixels by V
        top_v = np.argsort(hsv_pixels[:, 2])[-n_top:]
        # among those, the pixel with the highest S
        best = top_v[hsv_pixels[top_v, 1].argmax()]
        region_pixels.append(img_lab[m][best])
    return np.array(region_pixels) / 255

In [None]:
def max_SV_RGB_color(img_hsv, img_rgb, masks):
    """
    Return the RGB colour (rescaled to 0-1) of a bright, saturated pixel in each mask.

    The pixel is chosen in *img_hsv*: the masked pixels are ranked by V, and
    among the brightest quarter (all pixels if the mask has fewer than 4) the
    one with the largest S is selected. Its colour is then read from
    *img_rgb*, which must be pixel-aligned with *img_hsv*.
    """
    region_pixels = []
    for m in masks:
        hsv_pixels = img_hsv[m]
        n_top = len(hsv_pixels) // 4 or len(hsv_pixels)
        # indices of the top-25% pixels by V
        top_v = np.argsort(hsv_pixels[:, 2])[-n_top:]
        # among those, the pixel with the highest S
        best = top_v[hsv_pixels[top_v, 1].argmax()]
        region_pixels.append(img_rgb[m][best])
    return np.array(region_pixels) / 255

In [None]:
def max_V_RGB_color(img_hsv, img_rgb, masks):
    """
    RGB colour (rescaled to 0-1) of the pixel with the largest HSV value (V)
    in each mask. *img_hsv* and *img_rgb* must be pixel-aligned.
    """
    chosen = []
    for region in masks:
        brightest = np.argmax(img_hsv[region][:, 2])
        chosen.append(img_rgb[region][brightest])
    return np.array(chosen) / 255

In [None]:
def max_SV_prod_AB_color(img_hsv, img_lab, masks):
    """
    LAB colour (rescaled to 0-1) of the pixel with the largest product S * V
    in each mask. S and V come from *img_hsv*; the returned colour is read
    from the pixel-aligned *img_lab*.
    """
    chosen = []
    for region in masks:
        # np.prod promotes small integer dtypes, so uint8 products do not overflow
        sv_product = np.prod(img_hsv[region][:, 1:], axis=-1)
        chosen.append(img_lab[region][sv_product.argmax()])
    return np.array(chosen) / 255

In [None]:
def centre_region_color(img, masks):
    """
    Colour (rescaled to 0-1) of the pixel at the centroid of each mask.

    The centroid's row and column are the mean pixel coordinates, truncated
    to int. NOTE(review): for non-convex masks the centroid may lie outside
    the mask itself.
    """
    def centre_pixel(region):
        rows, cols = np.nonzero(region)
        return img[int(rows.mean()), int(cols.mean())]

    return np.array([centre_pixel(region) for region in masks]) / 255

In [None]:
def add_scalebar(ax, length, text, color='black', box=False):
    """Draw a horizontal scale bar with a centred bold label on *ax*.

    *length* is in data units (pixels on an ``imshow`` axis). The bar is 10
    data units tall. Its left edge sits 25% of the x-range in from the right
    x-limit, and it is 10% of the y-range from the first y-limit (the bottom
    edge for an ``imshow`` axis with the default upper origin). The label is
    offset 1.5 bar-heights from the bar. With ``box=True`` a translucent white
    rectangle is drawn behind bar and label first.
    """
    bar_height = 10
    y_first, y_second = ax.get_ylim()
    x_left, x_right = ax.get_xlim()
    bar_x = x_right - (x_right - x_left) * 0.25
    bar_y = y_first + (y_second - y_first) * 0.1

    if box:
        backdrop = patches.Rectangle(
            (bar_x - length / 10, bar_y - bar_height * 7),
            length * 1.2,
            bar_height * 10,
            facecolor='white',
            edgecolor='none',
            alpha=0.5,
            transform=ax.transData,
        )
        ax.add_patch(backdrop)

    bar = patches.Rectangle(
        (bar_x, bar_y),
        length,
        bar_height,
        facecolor=color,
        edgecolor='none',
        transform=ax.transData,
    )
    ax.add_patch(bar)
    ax.text(
        bar_x + length / 2,
        bar_y - bar_height * 1.5,
        text,
        color=color,
        fontweight='bold',
        horizontalalignment='center',
        transform=ax.transData,
    )

In [None]:
def _to_rgb(colors, mode):
    """Convert an (n, 3) array of 0-1 colours in PIL colour *mode* to 0-1 RGB.

    Values are rounded, not truncated, back to 0-255. The inputs are 8-bit
    values (or their means) divided by 255, and ``x / 255 * 255`` can land
    just below the integer, which truncation would turn into an off-by-one
    error. Clipping keeps out-of-range values from hitting an ill-defined
    float-to-uint8 cast.
    """
    as_uint8 = np.clip(np.rint(np.asarray(colors) * 255), 0, 255).astype(np.uint8)
    # PIL converts whole images, so treat the n colours as an (n, 1) image.
    img = PIL.Image.fromarray(as_uint8[:, None, :], mode=mode)
    return np.array(img.convert('RGB')).astype(float)[:, 0, :] / 255


def hsv_to_rgb(hsv):
    """Convert an (n, 3) array of 0-1 (PIL-encoded) HSV colours to 0-1 RGB."""
    return _to_rgb(hsv, 'HSV')


def lab_to_rgb(lab):
    """Convert an (n, 3) array of 0-1 (PIL-encoded) LAB colours to 0-1 RGB."""
    return _to_rgb(lab, 'LAB')


def make_visualisation(colors, masks):
    """Paint each mask with its colour on a black float canvas.

    ``colors[i]`` fills the pixels of ``masks[i]``; where masks overlap, the
    later mask wins. The canvas size comes from the (n, H, W) mask stack
    itself, so an empty stack of shape (0, H, W), e.g. a cluster with no
    droplets, gives an all-black image instead of raising IndexError.
    """
    masks = np.asarray(masks)
    result = np.zeros(masks.shape[1:3] + (3,))
    for mask, color in zip(masks, colors):
        result[mask] = color
    return result

def _colours_with_renderer(sample_id, colours, to_rgb=None):
    """Return ``(colours, render)`` for one image.

    ``render(c)`` converts *c* to RGB with *to_rgb* (when given) and paints it
    onto the filtered masks of *sample_id*, which are looked up at render time.
    """
    def render(c):
        rgb = c if to_rgb is None else to_rgb(c)
        return make_visualisation(rgb, masks_filtered[sample_id])
    return colours, render


# Colour-picking methods: each maps a sample id to (per-droplet colours,
# function that renders those colours as an RGB image of the droplets).
methods = {
    'avg': lambda sample_id: _colours_with_renderer(
        sample_id, avg_region_color(imgs_small_RGB[sample_id], masks_filtered[sample_id])),
    'max': lambda sample_id: _colours_with_renderer(
        sample_id, max_region_color(imgs_small_RGB[sample_id], masks_filtered[sample_id])),
    'max_V': lambda sample_id: _colours_with_renderer(
        sample_id, max_value_region_color(imgs_small_HSV[sample_id], masks_filtered[sample_id]),
        hsv_to_rgb),
    'max_S': lambda sample_id: _colours_with_renderer(
        sample_id, max_sat_region_color(imgs_small_HSV[sample_id], masks_filtered[sample_id]),
        hsv_to_rgb),
    'max_SV': lambda sample_id: _colours_with_renderer(
        sample_id, max_SV_region_color(imgs_small_HSV[sample_id], masks_filtered[sample_id]),
        hsv_to_rgb),
    'max_SV_AB': lambda sample_id: _colours_with_renderer(
        sample_id,
        max_SV_AB_color(imgs_small_HSV[sample_id], imgs_small_LAB[sample_id], masks_filtered[sample_id]),
        lab_to_rgb),
    'max_SV_RGB': lambda sample_id: _colours_with_renderer(
        sample_id,
        max_SV_RGB_color(imgs_small_HSV[sample_id], imgs_small_RGB[sample_id], masks_filtered[sample_id])),
    'max_V_RGB': lambda sample_id: _colours_with_renderer(
        sample_id,
        max_V_RGB_color(imgs_small_HSV[sample_id], imgs_small_RGB[sample_id], masks_filtered[sample_id])),
    'max_SV_prod_AB': lambda sample_id: _colours_with_renderer(
        sample_id,
        max_SV_prod_AB_color(imgs_small_HSV[sample_id], imgs_small_LAB[sample_id], masks_filtered[sample_id]),
        lab_to_rgb),
    'centre': lambda sample_id: _colours_with_renderer(
        sample_id, centre_region_color(imgs_small_RGB[sample_id], masks_filtered[sample_id])),
}

In [None]:
@interact(id=list(imgs_small_RGB.keys()), ticks=True, save=False, **{**{m: False for m in methods}, 'centre': True})
def show_image(id, ticks, save, **kwargs):
    """Show image *id* next to one panel per selected colour-picking method.

    Every keyword argument named after a key of ``methods`` toggles that
    method's panel. Returns a dict mapping each selected method to its
    per-droplet colours. When *save* is true, the figure is written to
    ``out/segment_colours_<id>.svg`` and ``.png``.
    """
    selected = [name for name in methods if kwargs.get(name)]
    n_panels = len(selected) + 1
    ncols = math.ceil(math.sqrt(n_panels))
    nrows = math.ceil(n_panels / ncols)
    fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=(5 * ncols, 5 * nrows), sharex=True, sharey=True, squeeze=False)
    axs = axs.flatten()
    axs[0].imshow(imgs_small_RGB[id])

    colors = {}
    method_axes = axs[1:]
    for ax, method_name in zip(method_axes, selected):
        region_colors, render = methods[method_name](id)
        ax.imshow(np.minimum(render(region_colors), 1.0))
        ax.set_title(method_name)
        add_scalebar(ax, 200, '200 µm', color='white')
        colors[method_name] = region_colors
    # hide the leftover cells when the grid is not completely filled
    for ax in method_axes[len(selected):]:
        ax.axis('off')

    if not ticks:
        for ax in axs:
            ax.axis('off')

    fig.tight_layout()
    add_scalebar(axs[0], 200, '200 µm', color='black')
    if save:
        fig.savefig(f'out/segment_colours_{id}.svg', bbox_inches='tight', transparent=True)
        fig.savefig(f'out/segment_colours_{id}.png', bbox_inches='tight', transparent=True, dpi=300)

    return colors

save_button(show_image, False, save=True, avg=True, max_V=True, max_S=True, max_SV=True, centre=True)

In [None]:
@interact(id=list(imgs_small_RGB.keys()))
def size_histogram(id):
    """Show image *id* beside a histogram of its droplet diameters.

    Each diameter is that of the circle with the same area as a filtered mask.
    The figure is saved to ``out/diameter_distribution_<id>.svg`` and ``.png``.
    """
    # Pixel-area to µm² conversion: 1000 px x 1000 px = 1000,000 px^2 = 1 mm^2
    # = 1000,000 um^2, i.e. 1 px = 1 µm, the same scale as the
    # 200 px = '200 µm' scale bar drawn below.
    um2_per_pixel = 1.0

    fig, (ax_img, ax_hist) = plt.subplots(ncols=2, figsize=(8, 3.5))
    ax_img.imshow(imgs_small_RGB[id])
    ax_img.axis('off')
    add_scalebar(ax_img, 200, '200 µm', color='black', box=True)

    region_masks = masks_filtered[id]
    # an image with no surviving masks has nothing to sum over axes (1, 2)
    area_pixels = region_masks.sum(axis=(1, 2)) if len(region_masks) else np.zeros(0)
    area_um2 = area_pixels * um2_per_pixel
    diameter_um = 2 * np.sqrt(area_um2 / np.pi)
    sns.histplot(diameter_um, ax=ax_hist)
    ax_hist.set_xlabel('Diameter (um)')
    ax_hist.set_ylabel('Count')
    ax_hist.set_title(id)
    fig.tight_layout()
    fig.savefig(f'out/diameter_distribution_{id}.svg', bbox_inches='tight', transparent=True)
    fig.savefig(f'out/diameter_distribution_{id}.png', bbox_inches='tight', transparent=True, dpi=300)

save_button(size_histogram)

### Visualizing the fitted Gaussian distributions

In [None]:
@interact(id=list(imgs_small_RGB.keys()), chosen_method=list(methods.keys()))
def visualize_distributions(id, save=False, chosen_method='centre'):
    """Fit a Gaussian mixture to the droplet colours of image *id* and plot it.

    Colours come from the *chosen_method* entry of ``methods`` via
    ``show_image``, which also draws the segmented image. A pair plot of the
    colours is overlaid, for each mixture component, with its marginal density
    on the diagonal and one density contour on the lower off-diagonal panels,
    drawn in the component's mean colour. With *save*, the pair plot is written
    to ``out/pairplot_<id>_<method>_<N>clusters.svg`` and ``.png``.

    Returns ``(colors, gm)``: the per-droplet colours and the fitted
    ``GaussianMixture``.
    """
    colors = show_image(id, ticks=False, save=False, **{chosen_method: True})[chosen_method]
    # NOTE(review): the columns are labelled R/G/B, but for the HSV- and
    # LAB-based methods they hold H/S/V or L/A/B values, and the mean colours
    # used for the overlays below are then not true RGB colours either.
    data = pd.DataFrame(colors, columns=['R', 'G', 'B'])

    # Fixed random_state: visualize_points refits through this function, so a
    # fixed seed keeps cluster assignments and ordering consistent between
    # runs and between the saved pair-plot and simplex figures.
    gm = GaussianMixture(n_components=N_CLUSTERS, covariance_type="full", random_state=0)
    gm.fit(colors)

    mus, covs, weights = gm.means_, gm.covariances_, gm.weights_

    grid = sns.pairplot(data, markers=".", plot_kws={"alpha": 0.5})
    axes = grid.axes

    for i in range(N_CLUSTERS):
        cluster_color = mus[i]
        for j in range(3):
            # For diagonal plots, show marginal distribution scaled by the
            # component weight relative to the heaviest component
            marginal_var = covs[i, j, j]
            x = np.linspace(0.0, 1.0, 100)
            y = norm.pdf(x, mus[i, j], np.sqrt(marginal_var))
            axes[j, j].plot(
                x, y / y.max() * weights[i] / weights.max(), color=cluster_color, lw=2
            )

            for k in range(j + 1, 3):
                # For off-diagonal plots, show PDF contours using full covariance
                cov_2d = covs[i][[j, k]][:, [j, k]]
                mu_2d = mus[i][[j, k]]

                # +/- 3 standard deviations around the mean on both axes
                x, y = np.mgrid[mu_2d[0]-3*np.sqrt(cov_2d[0, 0]):mu_2d[0]+3*np.sqrt(cov_2d[0, 0]):100j,
                                mu_2d[1]-3*np.sqrt(cov_2d[1, 1]):mu_2d[1]+3*np.sqrt(cov_2d[1, 1]):100j]
                pos = np.dstack((x, y))

                rv = multivariate_normal(mu_2d, cov_2d)
                z = rv.pdf(pos)
                # peak-normalise, then scale by the component weight so a single
                # shared contour level is meaningful across components
                z /= z.max()
                z *= weights[i]

                axes[k, j].contour(x, y, z, levels=[min(weights) / 1.5], colors=[cluster_color], linewidths=2)

    if save:
        # save the pair-plot figure explicitly rather than "the current figure"
        grid.savefig(f'out/pairplot_{id}_{chosen_method}_{N_CLUSTERS}clusters.svg', bbox_inches='tight', transparent=True)
        grid.savefig(f'out/pairplot_{id}_{chosen_method}_{N_CLUSTERS}clusters.png', bbox_inches='tight', transparent=True, dpi=300)

    return colors, gm

save_button(visualize_distributions, save=True)

### Visualization of droplet identity on a 2D simplex

Each of the individual images shows only the droplets assigned to one cluster, i.e. each droplet appears in the image of the cluster to which it has the highest probability of belonging.

In [None]:
def simplex_coords(probs):
    """
    Map rows of three class probabilities (p0, p1, p2) to 2-D points on the
    triangle with corners (0, 0), (1, 0) and (0.5, sqrt(3)/2) for classes
    0, 1 and 2 respectively. Returns the x and y coordinate arrays.
    """
    p1 = probs[:, 1]
    p2 = probs[:, 2]
    x = p1 + 0.5 * p2
    y = p2 * (np.sqrt(3) / 2)
    return x, y

@interact(id=list(imgs_small_RGB.keys()), chosen_method=list(methods.keys()))
def visualize_points(id, save=False, chosen_method='centre'):
    """Show droplet cluster memberships of image *id* on a 2-D simplex.

    Top-left: each droplet placed by its three mixture-component probabilities
    (see ``simplex_coords``), with small random jitter so that overlapping
    points stay visible. The other three panels show the droplets whose most
    likely component is 0, 1 and 2, painted in their colours. The simplex and
    the 2x2 layout only make sense for ``N_CLUSTERS == 3``.
    """
    if N_CLUSTERS != 3:
        raise ValueError(f'visualize_points requires N_CLUSTERS == 3, got {N_CLUSTERS}')

    colors, gm = visualize_distributions(id, save=False, chosen_method=chosen_method)
    # Get probabilities for each point
    probs = gm.predict_proba(colors)
    assignment = probs.argmax(axis=-1)

    # Add jitter to prevent overlapping
    x_coords, y_coords = simplex_coords(probs) + np.random.normal(0, 0.05, size=(2, len(colors)))

    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))

    simplex_ax = axes[0, 0]
    # NOTE(review): `colors` are only true RGB for the RGB-based methods.
    simplex_ax.scatter(x_coords, y_coords, color=colors, alpha=1.0, s=5)
    simplex_ax.plot([0, 1], [0, 0], 'k-', alpha=0.5)
    simplex_ax.plot([0, 0.5], [0, np.sqrt(3)/2], 'k-', alpha=0.5)
    simplex_ax.plot([1, 0.5], [0, np.sqrt(3)/2], 'k-', alpha=0.5)

    region_masks = masks_filtered[id]
    for cluster, ax in enumerate(axes.flatten()[1:]):
        in_cluster = assignment == cluster
        if not in_cluster.any():
            # no droplet has this component as its most likely one
            ax.axis('off')
            continue
        sum_mask = make_visualisation(colors[in_cluster], region_masks[in_cluster])
        ax.imshow(np.minimum(sum_mask, 1.0))
        add_scalebar(ax, 200, '200 µm', color='white')
        ax.axis('off')

    simplex_ax.axis('off')
    fig.tight_layout(pad=0.2)

    if save:
        fig.savefig(f'out/simplex_{id}_{chosen_method}_{N_CLUSTERS}clusters.svg', bbox_inches='tight', transparent=True)
        fig.savefig(f'out/simplex_{id}_{chosen_method}_{N_CLUSTERS}clusters.png', bbox_inches='tight', transparent=True, dpi=300)

save_button(visualize_points, save=True)

In [None]:
# Gather droplet diameters for all images.
# A droplet's diameter is that of the circle with the same pixel area as its
# mask. Values are labelled in µm on the assumption that 1 px = 1 µm, the scale
# used by the 200 px = '200 µm' scale bars in this notebook.
series_pattern = re.compile(r"(.+?) \d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}")
rows = []
# The loop variable is deliberately not called `masks`: that global holds the
# unfiltered SAM masks read by show_masks, and rebinding it here would break
# that plot if it were re-run afterwards.
for image_id, image_masks in masks_filtered.items():
    if len(image_masks) == 0:
        continue  # nothing survived filtering for this image
    area_pixels = image_masks.sum(axis=(1, 2))
    diameter_um = 2 * np.sqrt(area_pixels / np.pi)
    # Extract series (prefix before timestamp)
    m = series_pattern.match(image_id)
    series = m.group(1) if m else image_id
    for d in diameter_um:
        rows.append({"image_id": image_id, "series": series, "diameter": d})

sizes_df = pd.DataFrame(rows)

# Sort image_id by series and timestamp for logical order
sizes_df["timestamp"] = sizes_df["image_id"].str.extract(
    r"(\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2})", expand=False
)
image_order = list(sizes_df.sort_values(["series", "timestamp"])["image_id"].unique())
sizes_df["image_id"] = pd.Categorical(sizes_df["image_id"], categories=image_order, ordered=True)

series_to_sample = {
    'ZCY-010 together': 'Acid/base 1',
    'ZCY-009-1': 'Acid/base 2',
    'ZCY-009-2': 'Acid/base 3',
    'ZL-48a': 'Acid/base 4',
    'ZL-48b': 'Acid/base 5',
    'ZCY-014c': 'Redox',
}
# Series without an entry keep their raw name instead of NaN, so they still
# get their own hue level.
sizes_df['experiment'] = sizes_df['series'].map(series_to_sample).fillna(sizes_df['series'])

# One droplet count and one tick label per image, both in image_order. The
# top and bottom panels are drawn at the same integer positions, so bars,
# boxes and labels line up.
droplet_counts = sizes_df.groupby("image_id", observed=False).size().reindex(image_order)
timestamp_of = dict(zip(sizes_df["image_id"].astype(str), sizes_df["timestamp"]))
tick_labels = [
    timestamp_of[img_id] if isinstance(timestamp_of[img_id], str) else img_id
    for img_id in image_order
]
positions = np.arange(len(image_order))

fig = plt.figure(figsize=(0.5 * len(image_order), 10))
gs = gridspec.GridSpec(2, 1, height_ratios=[2, 1], hspace=0.05)

# Droplet size distribution (top); seaborn places category i at x = i
ax0 = fig.add_subplot(gs[0])
sns.boxenplot(data=sizes_df, x="image_id", y="diameter", hue="experiment", order=image_order, ax=ax0, dodge=False, flier_kws={"alpha": 0.5, "s": 1})
ax0.set_title("Droplet size distribution per image")
ax0.set_xlabel("")
ax0.set_ylabel("Diameter (µm)")
ax0.tick_params(axis='x', rotation=90, labelbottom=False)
ax0.grid(True, axis='y', linestyle='--', alpha=0.5)
ax0.legend(title="Sample", loc='upper right', bbox_to_anchor=(1, 1))

# Droplet count per image (bottom)
ax1 = fig.add_subplot(gs[1], sharex=ax0)
ax1.bar(positions, droplet_counts.to_numpy(), color='gray', alpha=0.7)
ax1.set_xticks(positions)
ax1.set_xticklabels(tick_labels)
ax1.set_ylabel("Droplet count")
ax1.set_xlabel("Sample")
ax1.tick_params(axis='x', rotation=90)
ax1.grid(True, axis='y', linestyle='--', alpha=0.5)

plt.setp(ax0.get_xticklabels(), visible=False)
fig.tight_layout(rect=[0, 0, 1, 1])
fig.savefig('out/droplet_size_and_count_distribution.svg', bbox_inches='tight', transparent=True)
fig.savefig('out/droplet_size_and_count_distribution.png', bbox_inches='tight', transparent=True, dpi=300)

### Summary statistics table

In [None]:
def summary_stats(df, groupby):
    """Per-group count, mean, std, min and max of ``diameter``, rounded to 2 decimals."""
    statistics = ['count', 'mean', 'std', 'min', 'max']
    diameters_by_group = df.groupby(groupby)["diameter"]
    return diameters_by_group.agg(statistics).round(2)

from IPython.display import display, Markdown

# Grouped tables first (per image, then per series), then one overall row.
for heading, grouping in (
    ("### Summary statistics per image", "image_id"),
    ("### Summary statistics per series", "series"),
):
    display(Markdown(heading))
    display(summary_stats(sizes_df, grouping))
display(Markdown("### Overall summary statistics"))
display(sizes_df["diameter"].agg(['count', 'mean', 'std', 'min', 'max']).round(2).to_frame().T)