In [None]:
from collections.abc import Sequence

import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray

from careamics.dataset_ng.patching_strategies import (
    PatchingStrategy,
    StratifiedPatchingStrategy,
)
from careamics.dataset_ng.val_split import create_val_split

In [None]:
def demo_selected_patches(
    patching_strategy: PatchingStrategy,
    data_shapes: Sequence[Sequence[int]],
    epochs: int,
) -> Sequence[NDArray[np.int_]]:
    """Create a map where all the patches have been selected from.

    Every time a patch is selected that area is incremented by 1.
    """
    tracking_arrays = [np.zeros(shape, dtype=int) for shape in data_shapes]
    for _ in range(epochs):
        for index in range(patching_strategy.n_patches):
            patch_spec = patching_strategy.get_patch_spec(index)
            data_idx = patch_spec["data_idx"]
            sample_idx = patch_spec["sample_idx"]
            coord = patch_spec["coords"]
            patch_size = patch_spec["patch_size"]

            patch_slice = [
                slice(c, c + ps) for c, ps in zip(coord, patch_size, strict=True)
            ]
            tracking_arrays[data_idx][sample_idx, ..., *patch_slice] += 1
    return tracking_arrays

In [None]:
# Seeded NumPy generator; it is the randomness source handed to
# `create_val_split`, so the train/validation split is reproducible.
rng = np.random.default_rng(seed=42)

In [None]:
# Three arrays of shape (1, 1, Y, X) (presumably sample, channel, Y, X),
# patched with 64x64 patches by the stratified strategy.
data_shapes = [(1, 1, 512, 620), (1, 1, 300, 335), (1, 1, 512, 512)]
patch_size = (64, 64)
stratified_patching = StratifiedPatchingStrategy(data_shapes, patch_size, seed=42)

# Reserve 10% of all patches, rounded up, for validation.
n_val_patches = int(np.ceil(stratified_patching.n_patches * 0.1))
print(
    f"Selecting {n_val_patches} validation patches "
    f"from {stratified_patching.n_patches} total patches."
)
train_patching, val_patching = create_val_split(
    stratified_patching, n_val_patches, rng
)

# Coverage maps: training patches after 1 and after 200 epochs, then the
# validation patches after a single epoch.
train_1, train_200 = (
    demo_selected_patches(train_patching, data_shapes, epochs=n_epochs)
    for n_epochs in (1, 200)
)
val = demo_selected_patches(val_patching, data_shapes, epochs=1)

In [None]:
# Coverage maps produced by `demo_selected_patches`: each pixel shows how many
# times it fell inside a selected patch. Rows are the split/epoch settings,
# columns are the arrays in `data_shapes`.
panel_rows = [
    ("Train epochs 1", train_1),
    ("Train epochs 200", train_200),
    ("Validation", val),
]
# squeeze=False keeps `axes` 2-D even with a single image (one column), so
# `axes[row, col]` indexing always works.
fig, axes = plt.subplots(
    len(panel_rows),
    len(data_shapes),
    figsize=(12, 12),
    constrained_layout=True,
    squeeze=False,
)
for row, (row_label, coverage_maps) in enumerate(panel_rows):
    axes[row, 0].set_ylabel(row_label)
    for col, coverage in enumerate(coverage_maps):
        ax = axes[row, col]
        # [0, 0] picks the first sample and first channel; every demo shape
        # has exactly one of each.
        im = ax.imshow(coverage[0, 0])
        # Without a colorbar the per-pixel counts have no visible scale.
        fig.colorbar(im, ax=ax, label="times selected")
for col in range(len(data_shapes)):
    axes[0, col].set_title(f"Image {col}")
# Render explicitly; otherwise the cell echoes the repr of the last `set_*` call.
plt.show()