In [None]:
from collections.abc import Sequence

import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray

from careamics.dataset_ng.patching_strategies import (
    PatchingStrategy,
    RandomPatchingStrategy,
    StratifiedPatchingStrategy,
)

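# Demoing the Stratified Patching Strategy

This notebook visualises how often each pixel is covered by patches drawn from a
`StratifiedPatchingStrategy`, compares this with a `RandomPatchingStrategy`, and
demonstrates how individual patches can be excluded from sampling.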

In [None]:
def demo_selected_patches(
    patching_strategy: PatchingStrategy,
    data_shapes: Sequence[Sequence[int]],
    epochs: int,
) -> Sequence[NDArray[np.int_]]:
    """Create a map where all the patches have been selected from.

    Every time a patch is selected that area is incremented by 1.
    """
    tracking_arrays = [np.zeros(shape, dtype=int) for shape in data_shapes]
    for _ in range(epochs):
        for index in range(patching_strategy.n_patches):
            patch_spec = patching_strategy.get_patch_spec(index)
            data_idx = patch_spec["data_idx"]
            sample_idx = patch_spec["sample_idx"]
            coord = patch_spec["coords"]
            patch_size = patch_spec["patch_size"]

            patch_slice = [
                slice(c, c + ps) for c, ps in zip(coord, patch_size, strict=True)
            ]
            tracking_arrays[data_idx][sample_idx, ..., *patch_slice] += 1
    return tracking_arrays

In [None]:
# Seed value for the patching strategies' random number generators, so the demo is reproducible.
seed: int = 42

## Comparing the Stratified Patching Strategy to the Random Strategy

In [None]:
# A single data array; patches are taken from its last two (spatial) axes.
data_shapes = [(1, 1, 512, 620)]
patch_size = (64, 64)

# Use the `seed` defined above rather than repeating the literal, so changing the
# seed in one place actually changes both strategies.
stratified_patching = StratifiedPatchingStrategy(data_shapes, patch_size, seed=seed)
random_patching = RandomPatchingStrategy(data_shapes, patch_size, seed=seed)

epochs = 1
stratified_selected = demo_selected_patches(stratified_patching, data_shapes, epochs)
random_selected = demo_selected_patches(random_patching, data_shapes, epochs)

# Each image shows, per pixel, how many times it was covered by a sampled patch.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
fig.suptitle(f"Epochs: {epochs}")
axes[0].imshow(stratified_selected[0][0, 0])
axes[0].set_title("Stratified Patching")
axes[1].imshow(random_selected[0][0, 0])
axes[1].set_title("Random Patching")
fig.tight_layout()

In [None]:
# Accumulate coverage over many epochs to compare the long-run sampling density.
epochs = 200
stratified_selected = demo_selected_patches(stratified_patching, data_shapes, epochs)
random_selected = demo_selected_patches(random_patching, data_shapes, epochs)

fig, axes = plt.subplots(1, 2, figsize=(12, 6))
fig.suptitle(f"Epochs: {epochs}")
panels = (
    ("Stratified Patching", stratified_selected),
    ("Random Patching", random_selected),
)
for ax, (panel_title, selected) in zip(axes, panels):
    ax.imshow(selected[0][0, 0])
    ax.set_title(panel_title)
fig.tight_layout()

In [None]:
# Per-pixel selection frequency (coverage count divided by the number of epochs),
# summarised over the whole image for each strategy.
summaries = (
    ("--- Random Strategy ---", random_selected),
    ("--- Stratified Strategy ---", stratified_selected),
)
for position, (header, selected) in enumerate(summaries):
    if position:
        print("\n")  # visual gap between the two summaries
    per_epoch = selected[0] / epochs
    mean = np.mean(per_epoch)
    std = np.std(per_epoch)
    print(header)
    print("Expected value that a pixel is selected in an epoch")
    print(f"Mean: {mean:.3f}, StdDev: {std:.3f}")

In [None]:
centre_slice = [slice(ps, -ps) for ps in patch_size]

mean = np.mean(random_selected[0][..., *centre_slice]/epochs)
std = np.std(random_selected[0][..., *centre_slice]/epochs)
print("--- Random Strategy ---")
print("Expected value that a central pixel is selected in an epoch")
print(f"Mean: {mean:.3f}, StdDev: {std:.3f}")
print("\n")

mean = np.mean(stratified_selected[0][..., *centre_slice]/epochs)
std = np.std(stratified_selected[0][..., *centre_slice]/epochs)
print("--- Stratified Strategy ---")
print("Expected value that a central pixel is selected in an epoch")
print(f"Mean: {mean:.3f}, StdDev: {std:.3f}")

## Demo patch exclusion

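Excluded patches must lie on a regular grid that has a grid point at (0, 0) and a
spacing equal to the chosen patch size. They are therefore specified by grid
coordinates. For example, with a patch size of (64, 64), grid coordinate (3, 2) covers
rows 192:256 and columns 128:192 (half-open ranges).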

In [None]:
# chose patches to exclude and make mask

exclude_patches = [(3, 2), (5, 6), (4, 6), (2, 8)]
exlc_mask = np.zeros(data_shapes[0], dtype=bool)
for grid_coord in exclude_patches:
    patch_slice = [
        slice(c * ps, (c + 1) * ps)
        for c, ps in zip(grid_coord, patch_size, strict=True)
    ]
    exlc_mask[..., *patch_slice] = True
plt.imshow(exlc_mask[0, 0])
plt.title("Excluded patches map")

In [None]:
# Exclude the grid patches listed above (data array 0, sample 0) from the
# stratified strategy. The return value is not used: later cells reuse
# `stratified_patching`, so the exclusion presumably persists on the object
# itself. Re-running earlier cells afterwards may therefore give different results.
stratified_patching.exclude_patches(
    grid_coords=exclude_patches,
    data_idx=0,
    sample_idx=0,
)

In [None]:
# Coverage maps for the stratified strategy after the exclusion, short vs long run.
stratified_1 = demo_selected_patches(stratified_patching, data_shapes, epochs=1)
stratified_200 = demo_selected_patches(stratified_patching, data_shapes, epochs=200)

fig, axes = plt.subplots(1, 2, figsize=(12, 6))
fig.suptitle("Stratified Patching")
panels = (
    (stratified_1, "Epochs: 1"),
    (stratified_200, "Epochs: 200"),
)
for ax, (selected, panel_title) in zip(axes, panels):
    ax.imshow(selected[0][0, 0])
    ax.set_title(panel_title)
fig.tight_layout()

In [None]:
# Selection frequency of pixels outside the excluded patches; 200 matches the
# number of epochs used to build `stratified_200`.
included_per_epoch = stratified_200[0][~exlc_mask] / 200
mean = np.mean(included_per_epoch)
std = np.std(included_per_epoch)
print("--- Stratified Strategy ---")
print("Expected value that an included pixel is selected in an epoch")
print(f"Mean: {mean:.3f}, StdDev: {std:.3f}")