# Particle statistics calculations

The git repository for this notebook includes a standard development environment that downloads the necessary dataset and installs all required packages. If you are using VS Code, you can run the _Dev Containers: Reopen in Container_ command to run this notebook locally within a tested environment.

In [None]:
import gzip
import pickle
from pathlib import Path

import matplotlib
import numpy as np
import pandas as pd
import PIL
import seaborn as sns
from ipywidgets import Button, interact, interactive
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from transformers import pipeline

# Root directory of the LZ-05 dataset; every input file below is resolved against it.
DATA_HOME = Path("./data") / "LZ-05"

# Global plot styling. `sns.set` is only a backwards-compatible alias of
# `sns.set_theme`, which is the supported spelling and takes the same arguments.
sns.set_theme(
    context="notebook",
    style="ticks",
    font="Arial",
    font_scale=1.1,
    rc={
        # Keep text as text (not outlines) in exported SVG files.
        "svg.fonttype": "none",
        "lines.linewidth": 1.6,
        "figure.autolayout": True,
    },
)

In [None]:
# Experiment metadata sheet; only rows whose "Active" column equals "Yes" are kept.
experiments_path = DATA_HOME / "Experiments.xlsx"
excel_file = pd.read_excel(experiments_path).query('Active == "Yes"')
excel_file

In [None]:
# Map each listed filename to its path under DATA_HOME, then drop the entries
# whose file is not actually present on disk.
candidate_paths = {
    filename: DATA_HOME / filename for filename in excel_file["Filename"]
}
sample_image_files = {
    filename: image_path
    for filename, image_path in candidate_paths.items()
    if image_path.exists()
}

In [None]:
# Open every sample image, keyed by its filename.
# NOTE(review): `import PIL` on its own does not load the `PIL.Image` submodule;
# this presumably works because another import already pulled it in — confirm.
imgs_RGB = {
    img_id: PIL.Image.open(img_path)
    for img_id, img_path in tqdm(sample_image_files.items())
}

# Downscale each image to a quarter of its width and height (integer division),
# using bilinear resampling.
imgs_small_RGB = {
    img_id: full_img.resize(
        (full_img.width // 4, full_img.height // 4),
        resample=PIL.Image.BILINEAR,
    )
    for img_id, full_img in tqdm(imgs_RGB.items())
}

Preview the downscaled images (select an image to display it)

In [None]:
def show_image(id):
    """Draw the downscaled image stored under key ``id`` and use the key as the title.

    The parameter is called ``id`` because the ``interact`` call below passes the
    selected value by that keyword.
    """
    plt.imshow(imgs_small_RGB[id])
    plt.title(id)


# One selectable entry per image file found on disk.
interact(show_image, id=sample_image_files.keys())

Generate or load masks for all images

In [None]:
# Mask-generation pipeline backed by the "facebook/sam-vit-base" checkpoint.
# The device is fixed to "cuda", so this presumably needs a CUDA-capable GPU.
mask_generator = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-base",
    device="cuda",
    points_per_crop=64,
    pred_iou_thresh=0.2,
    stability_score_thresh=0.2,
    crops_nms_thresh=0.1,
    points_per_batch=128,
)

In [None]:
# Cache of generated masks: a gzip-compressed pickle holding {image id: mask array}.
# This must be a Path rather than a plain string, because the code below calls
# .exists(), .with_name() and .rename() on it.
mask_file = Path("out/masks.pkl.gz")

if mask_file.exists():
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load a cache that this notebook (or someone you trust) produced.
    with gzip.open(mask_file, "rb") as f:
        # Ignore cached entries for images that are not part of the current run.
        masks = {
            img_id: img_masks
            for img_id, img_masks in pickle.load(f).items()
            if img_id in imgs_small_RGB
        }
else:
    masks = {}

# Run the (expensive) mask generator only for images missing from the cache.
new_masks = {
    img_id: np.array(mask_generator(img)["masks"])
    for img_id, img in tqdm(imgs_small_RGB.items())
    if img_id not in masks
}

# Save masks to pickle file (only when something new was generated).
if new_masks:
    masks.update(new_masks)
    # Make sure the output directory exists before renaming or writing into it.
    mask_file.parent.mkdir(parents=True, exist_ok=True)
    if mask_file.exists():
        # Keep the previous cache as a timestamped backup next to the new file
        # (named like "masks.pkl_<timestamp>.gz"). The rewritten cache only
        # contains entries for the images of the current run.
        timestamp = pd.Timestamp.now().strftime("%Y%m%d%H%M%S")
        backup_file = mask_file.with_name(
            f"{mask_file.stem}_{timestamp}{mask_file.suffix}"
        )
        mask_file.rename(backup_file)
    with gzip.open(mask_file, "wb") as f:
        pickle.dump(masks, f)

# Number of masks per image.
{img_id: len(img_masks) for img_id, img_masks in masks.items()}

In [None]:
def show_hist(id):
    """Plot a 100-bin histogram of log10(mask sum) over all masks of image ``id``.

    Each mask's sum is used as its area (its pixel count, assuming the masks are
    boolean or 0/1 arrays).
    """
    mask_sums = [sample_mask.sum() for sample_mask in masks[id]]
    plt.hist(np.log10(mask_sums), bins=100)


# One selectable entry per image file found on disk.
interact(show_hist, id=sample_image_files.keys())

In [None]:
# Keep only masks whose area (the sum of the mask's entries) lies strictly
# between the two bounds, i.e. drop masks that are too small as well as too big.
MAX_MASK_AREA = 1500
MIN_MASK_AREA = 50


def _filter_masks_by_area(img_masks, min_area=MIN_MASK_AREA, max_area=MAX_MASK_AREA):
    """Return the masks of one image whose area is strictly inside (min_area, max_area).

    Boolean indexing is used instead of rebuilding the array from a list so that
    the trailing dimensions and dtype survive even when no mask is kept:
    ``np.array([])`` would collapse to shape ``(0,)``, which breaks later
    per-image reductions such as ``sum(axis=0)``.
    """
    img_masks = np.asarray(img_masks)
    areas = np.array([mask.sum() for mask in img_masks])
    keep = (areas > min_area) & (areas < max_area)
    return img_masks[keep]


masks_filtered = {
    img_id: _filter_masks_by_area(img_masks) for img_id, img_masks in masks.items()
}

# Number of masks kept per image.
{img_id: len(img_masks) for img_id, img_masks in masks_filtered.items()}

## Visualising the location of masks for each image

In [None]:
# Same masks as `masks_filtered`, cast to uint8 so they can be summed numerically.
mask_tensors = {
    img_id: img_masks.astype(np.uint8)
    for img_id, img_masks in tqdm(masks_filtered.items())
}

In [None]:
def show_mask_overlap(sample_name):
    """Show how many retained masks cover each pixel of image ``sample_name``.

    The masks are summed along the first axis, so a pixel's value is the number
    of masks that include it. Returns ``None`` on purpose: ``interact`` displays
    any non-None return value, which previously printed the title's ``Text`` repr.
    """
    plt.imshow(mask_tensors[sample_name].sum(axis=0))
    plt.title(sample_name)


# The trailing semicolon keeps the cell from echoing the function's repr.
interact(show_mask_overlap, sample_name=mask_tensors.keys());

In [None]:
# Give every mask a random integer label in 1..15 and add the labelled masks into
# one image per sample. Pixels covered by several masks receive the sum of their
# labels; pixels covered by none stay 0.
# A seeded generator keeps the labelling (and therefore the colours in the plot)
# identical across runs.
label_rng = np.random.default_rng(0)
all_included = {
    img_id: np.sum(
        img_masks.astype(int)
        * label_rng.integers(1, 16, size=(img_masks.shape[0], 1, 1)),
        axis=0,
    )
    for img_id, img_masks in tqdm(mask_tensors.items())
}

In [None]:
# Colour map for the label images: the first entry (background, label 0) is black,
# the remaining entries are tab20's colours.
# A new ListedColormap is built instead of assigning to `.colors` on the copy
# returned by the registry: that assignment is not a documented way to modify a
# colormap and has no effect once the colormap's lookup table has been built.
_tab20_colors = tuple(matplotlib.colormaps["tab20"].colors)
cmap = matplotlib.colors.ListedColormap(
    ((0, 0, 0, 1),) + _tab20_colors[1:],
    name="tab20_black_background",
)

def show_fn(img_id):
    """Show image ``img_id`` next to its labelled mask image, with a save button.

    Left panel: the downscaled RGB image. Right panel: the combined label image
    drawn with ``cmap`` and titled with the number of retained masks. The button
    writes the whole figure as a transparent SVG named after the image
    (".jpg" replaced by ".svg").
    """
    fig, (img_ax, mask_ax) = plt.subplots(1, 2, figsize=(12, 6))
    img_ax.imshow(imgs_small_RGB[img_id])
    mask_ax.imshow(all_included[img_id], cmap=cmap, interpolation="none")
    mask_ax.set_title(f"{len(masks_filtered[img_id])} masks")

    svg_name = img_id.replace(".jpg", ".svg")

    def save_figure(_button):
        # Click handler; the clicked button itself is not needed.
        fig.savefig(svg_name, format="svg", transparent=True)

    save_button = Button(description=f"Save to {svg_name}")
    save_button.on_click(save_figure)
    display(save_button)


interactive(show_fn, img_id=all_included.keys())

In [None]:
# One small frame per image: the area (sum) of every retained mask, tagged with
# the image's filename.
mask_areas = [
    pd.DataFrame(
        {
            "mask_area": [mask.sum() for mask in img_masks],
            "Filename": img_id,
        },
    )
    for img_id, img_masks in masks_filtered.items()
]

# Number of retained masks per image, indexed by filename.
particle_counts = pd.DataFrame(
    {"# particles": [len(img_masks) for img_masks in masks_filtered.values()]},
    index=masks_filtered.keys(),
)

# Experiment metadata joined with the per-image particle count.
count_df = excel_file.merge(particle_counts, left_on="Filename", right_index=True)

# Experiment metadata joined with one row per mask, plus the diameter of the
# circle with the same area: d = 2 * sqrt(area / pi).
# NOTE(review): 0.906 is presumably the pixels-per-µm scale of the downscaled
# images — confirm against the acquisition settings.
mask_df = excel_file.merge(pd.concat(mask_areas), on="Filename").assign(
    **{
        "Diameter (µm)": lambda merged: np.sqrt(merged["mask_area"] / np.pi)
        * 2
        / 0.906
    }
)

In [None]:
# Display the per-mask table (experiment metadata, mask area and diameter).
mask_df

In [None]:
# Display the per-image table (experiment metadata and particle count).
count_df

In [None]:
# Particle count versus number of pulses, one line per pulse duration.
# NOTE(review): the 2400-pulse cutoff is not explained in this notebook —
# confirm why runs at or above it are excluded.
count_line_data = count_df.query("`Number of pulses` < 2400").rename(
    columns={"Number of pulses": "# pulses"}
)
sns.lineplot(
    data=count_line_data,
    x="# pulses",
    y="# particles",
    hue="Pulse duration",
)
plt.savefig(
    "out/num_particles vs num_pulses.svg",
    transparent=True,
    bbox_inches="tight",
    pad_inches=0.1,
)

In [None]:
# Distribution of particle counts per pulse number, grouped by pulse duration.
# Outliers are hidden (showfliers=False) and the x axis keeps its numeric
# spacing (native_scale=True).
count_box_data = count_df.query("`Number of pulses` < 2400").rename(
    columns={"Number of pulses": "# pulses"}
)
sns.boxplot(
    data=count_box_data,
    x="# pulses",
    y="# particles",
    hue="Pulse duration",
    native_scale=True,
    fliersize=1,
    showfliers=False,
)

In [None]:
# Distribution of particle diameters per pulse number, grouped by pulse duration.
# Outliers are hidden (showfliers=False) and the x axis keeps its numeric
# spacing (native_scale=True).
diameter_box_data = mask_df.query("`Number of pulses` < 2400").rename(
    columns={"Number of pulses": "# pulses"}
)
sns.boxplot(
    data=diameter_box_data,
    x="# pulses",
    y="Diameter (µm)",
    hue="Pulse duration",
    native_scale=True,
    fliersize=1,
    showfliers=False,
)
plt.savefig(
    "out/diameter vs num_pulses.svg",
    transparent=True,
    bbox_inches="tight",
    pad_inches=0.1,
)

In [None]:
# Particle diameter versus number of pulses, one line per pulse duration.
diameter_line_data = mask_df.query("`Number of pulses` < 2400").rename(
    columns={"Number of pulses": "# pulses"}
)
sns.lineplot(
    data=diameter_line_data,
    x="# pulses",
    y="Diameter (µm)",
    hue="Pulse duration",
)

In [None]:
# Write both result tables to Excel workbooks, without the index column.
excel_outputs = {
    "out/mask_counts.xlsx": count_df,
    "out/mask_areas.xlsx": mask_df,
}
for xlsx_path, frame in excel_outputs.items():
    frame.to_excel(xlsx_path, index=False)