# Creating a custom `ImageStack`

You might want to write a custom `ImageStack` class if your data is stored in a format
that is chunked or supports sub-file access, i.e. you want to extract patches during
the training loop without loading all the data into RAM at once.
The image stack has to follow the Python `Protocol` defined in [patch_extractor/image_stack/image_stack_protocol.py](patch_extractor/image_stack/image_stack_protocol.py).
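
To make the idea of sub-file access concrete, here is a small sketch that uses `numpy.memmap` in place of a chunked format such as HDF5 (a deliberate substitution to keep the example free of extra dependencies). The file name and array sizes are made up for illustration. The point is that slicing the lazily opened file only needs the requested region, not the whole image.

```python
import tempfile
from pathlib import Path

import numpy as np

with tempfile.TemporaryDirectory() as tmp_dir:
    # write a toy 64x64 image to disk as raw bytes (hypothetical file name)
    raw_path = Path(tmp_dir) / "toy_image.dat"
    full_image = np.arange(64 * 64, dtype=np.uint16).reshape(64, 64)
    full_image.tofile(raw_path)

    # open the file lazily: no pixel data is copied into RAM at this point
    lazy_image = np.memmap(raw_path, dtype=np.uint16, mode="r", shape=(64, 64))

    # slice out an 8x8 region; np.array copies just that region into RAM
    patch = np.array(lazy_image[10:18, 20:28])

    # release the memory map before the temporary directory is deleted
    del lazy_image

print(patch.shape)
```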

To use a custom `ImageStack` with the `CAREamicsDataset`, we also have to write an
image stack loader function, which follows the protocol defined in [src/careamics/dataset_ng/patch_extractor/image_stack_loader.py](patch_extractor/image_stack_loader.py). It is a callable with the following signature:

```python
# example signature
def custom_image_stack_loader(
    source: Any, axes: str, *args: Any, **kwargs: Any
) -> Sequence[ImageStack]: ...
```

In this demo, we will create a custom image stack and image stack loader for data saved
in an HDF5 file.

In [None]:
from collections.abc import Sequence
from pathlib import Path
from typing import Union

import h5py
import matplotlib.pyplot as plt
import numpy as np
import tifffile
from careamics_portfolio import PortfolioManager
from numpy.typing import DTypeLike, NDArray

from careamics.config import create_care_configuration
from careamics.dataset_ng.dataset import Mode
from careamics.dataset_ng.factory import create_dataset

## Downloading and re-saving data

We will resave some data as HDF5 for the purpose of this demo.

First we download some data that is available using `careamics_portfolio`.

In [None]:
# instantiate data portfolio manager and download the data
data_path = Path("./data")

portfolio = PortfolioManager()
download = portfolio.denoising.CARE_U2OS.download(data_path)

In [None]:
root_path = data_path / "denoising-CARE_U2OS.unzip" / "data" / "U2OS"
train_path = root_path / "train" / "low"
target_path = root_path / "train" / "GT"
test_path = root_path / "test" / "low"
test_target_path = root_path / "test" / "GT"

In [None]:
# checking the train input and target files we have
print(list(train_path.glob("*.tif")))
print(list(target_path.glob("*.tif")))

### Save as HDF5

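We will save all the images in a single HDF5 file. The training inputs and targets go
under the groups "train_input" and "train_target", the test inputs and targets under
"test_input" and "test_target", and each image is stored under its original file name
(without the extension).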

In [None]:
hdf5_file_path = data_path / "CARE_U2OS-train.h5"

if not hdf5_file_path.is_file():
    with h5py.File(name=hdf5_file_path, mode="w") as file:
        train_group = file.create_group("train_input")
        target_group = file.create_group("train_target")
        test_group = file.create_group("test_input")
        test_target_group = file.create_group("test_target")
        for path in train_path.glob("*.tif"):
            image = tifffile.imread(path)
            train_group.create_dataset(name=path.stem, data=image)
        for path in target_path.glob("*.tif"):
            image = tifffile.imread(path)
            target_group.create_dataset(name=path.stem, data=image)
        for path in test_path.glob("*.tif"):
            image = tifffile.imread(path)
            test_group.create_dataset(name=path.stem, data=image)
        for path in test_target_path.glob("*.tif"):
            image = tifffile.imread(path)
            test_target_group.create_dataset(name=path.stem, data=image)

## Defining the image stack

An `ImageStack` must have the attributes `data_shape`, `data_dtype` and `source`, and the
method `extract_patch`.

The `data_shape` attribute should be the shape the data would have once reshaped to match
the axes `SC(Z)YX`.

The `data_dtype` attribute is the data type of the underlying array.

The `source` attribute should have the type `Path`. It is returned alongside the patches by the
`CAREamicsDataset` and can be used to identify where the data came from. In the
future it may be used to automatically save predictions to disk.

The `extract_patch` method needs to return the patch for a given `sample_idx`, `coords`
and `patch_size`, with the axes `C(Z)YX` (the sample dimension is indexed away). So, for
our HDF5 case, the patches need to be reshaped when the `extract_patch` method is called.
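
As a rough illustration of these requirements, the sketch below writes the listed members down as a `typing.Protocol` and checks a toy in-memory class against it. `ImageStackLike` and `InMemoryYXStack` are hypothetical names used only here; this is not the protocol shipped with CAREamics, just a minimal stand-in assuming 2D (`YX`) data.

```python
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class ImageStackLike(Protocol):
    """Hypothetical stand-in listing the members an image stack must provide."""

    @property
    def data_shape(self) -> Sequence[int]: ...

    @property
    def data_dtype(self) -> Any: ...

    @property
    def source(self) -> Path: ...

    def extract_patch(
        self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
    ) -> np.ndarray: ...


class InMemoryYXStack:
    """Toy image stack wrapping a 2D (YX) NumPy array."""

    def __init__(self, array: np.ndarray, source: Path):
        self._array = array
        self.source = source
        # a YX array gains singleton S and C axes when viewed as SCYX
        self.data_shape = (1, 1, *array.shape)
        self.data_dtype = array.dtype

    def extract_patch(
        self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
    ) -> np.ndarray:
        if sample_idx != 0:
            raise IndexError(f"Sample index {sample_idx} out of bounds for size 1")
        y, x = coords
        height, width = patch_size
        # returned patch has axes CYX (singleton channel, no sample axis)
        return self._array[np.newaxis, y : y + height, x : x + width]


stack = InMemoryYXStack(np.arange(36).reshape(6, 6), Path("toy.npy"))
patch = stack.extract_patch(0, coords=(1, 2), patch_size=(3, 3))
print(isinstance(stack, ImageStackLike), patch.shape)
```

The `isinstance` check only verifies that the members exist, not that their types or return shapes are right, so it is a sanity check rather than a guarantee.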

In [None]:
from careamics.dataset.dataset_utils import reshape_array
from careamics.dataset_ng.patch_extractor.image_stack.zarr_image_stack import (
    _reshaped_array_shape,
)


class HDF5ImageStack:

    def __init__(self, image_data: h5py.Dataset, axes: str):
        self._image_data = image_data
        self._original_axes = axes
        self._original_data_shape = image_data.shape
        self.data_shape = _reshaped_array_shape(
            self._original_axes, self._image_data.shape
        )

    @property
    def data_dtype(self) -> DTypeLike:
        return self._image_data.dtype

    @property
    def source(self) -> Path:
        return Path(self._image_data.file.filename + str(self._image_data.name))

    # this method is almost an exact copy of ZarrImageStack.extract_patch
    def extract_patch(
        self, sample_idx: int, coords: Sequence[int], patch_size: Sequence[int]
    ) -> NDArray:
        # original axes assumed to be any subset of STCZYX (containing YX), in any order
        # arguments must be transformed to index data in original axes order
        # to do this: loop through original axes and append correct index/slice
        #   for each case: STCZYX
        #   Note: if any axis is not present in original_axes it is skipped.

        # guard for no S and T in original axes
        if ("S" not in self._original_axes) and ("T" not in self._original_axes):
            if sample_idx not in [0, -1]:
                raise IndexError(
                    f"Sample index {sample_idx} out of bounds for S axis with size "
                    f"{self.data_shape[0]}"
                )

        patch_slice: list[Union[int, slice]] = []
        for d in self._original_axes:
            if d == "S":
                patch_slice.append(self._get_S_index(sample_idx))
            elif d == "T":
                patch_slice.append(self._get_T_index(sample_idx))
            elif d == "C":
                patch_slice.append(slice(None, None))
            elif d == "Z":
                patch_slice.append(slice(coords[0], coords[0] + patch_size[0]))
            elif d == "Y":
                y_idx = 0 if "Z" not in self._original_axes else 1
                patch_slice.append(
                    slice(coords[y_idx], coords[y_idx] + patch_size[y_idx])
                )
            elif d == "X":
                x_idx = 1 if "Z" not in self._original_axes else 2
                patch_slice.append(
                    slice(coords[x_idx], coords[x_idx] + patch_size[x_idx])
                )
            else:
                raise ValueError(f"Unrecognised axis '{d}', axes should be in STCZYX.")

        patch = self._image_data[tuple(patch_slice)]
        patch_axes = self._original_axes.replace("S", "").replace("T", "")
        return reshape_array(patch, patch_axes)[0]  # remove first sample dim

    def _get_T_index(self, sample_idx: int) -> int:
        """Get T index given `sample_idx`."""
        if "T" not in self._original_axes:
            raise ValueError("No 'T' axis specified in original data axes.")
        axis_idx = self._original_axes.index("T")
        dim = self._original_data_shape[axis_idx]

        # S and T are merged into a new sample axis S' of size S*T
        # S_idx = S_idx' // T_size
        # T_idx = S_idx' % T_size
        # - floor divide finds the row (the S index)
        # - modulus finds how far along the row i.e. the column (the T index)
        return sample_idx % dim

    def _get_S_index(self, sample_idx: int) -> int:
        """Get S index given `sample_idx`."""
        if "S" not in self._original_axes:
            raise ValueError("No 'S' axis specified in original data axes.")
        if "T" in self._original_axes:
            T_axis_idx = self._original_axes.index("T")
            T_dim = self._original_data_shape[T_axis_idx]

            # S and T are merged into a new sample axis S' of size S*T
            # S_idx = S_idx' // T_size
            # T_idx = S_idx' % T_size
            # - floor divide finds the row (the S index)
            # - modulus finds how far along the row i.e. the column (the T index)
            return sample_idx // T_dim
        else:
            return sample_idx
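
The index arithmetic in `_get_S_index` and `_get_T_index` can be checked with a small worked example. The sketch assumes that the `S` and `T` axes are merged in C order with `S` as the slower-varying axis, which is what the floor-divide and modulus in the methods above imply. The array sizes are arbitrary.

```python
import numpy as np

S_size, T_size = 3, 4

# toy data with axes STYX
data = np.arange(S_size * T_size * 2 * 2).reshape(S_size, T_size, 2, 2)

# merge S and T into a single sample axis S' of size S*T (axes S'YX)
merged = data.reshape(S_size * T_size, 2, 2)

for flat_idx in range(S_size * T_size):
    # floor divide gives the S index, modulus gives the T index
    s_idx, t_idx = divmod(flat_idx, T_size)
    assert np.array_equal(merged[flat_idx], data[s_idx, t_idx])

print(divmod(7, T_size))
```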

### Now define the image stack loader

The loader needs to have the first two arguments be `source` and `axes`, then any 
additional kwargs are allowed. However, note that the additional kwargs have to be 
shared by both the input and the target when the dataset is initialized.
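
One way to picture why the kwargs are shared is the sketch below, where a single loader is called once for the inputs and once for the targets with the same keyword arguments. This is an assumption about the mechanism made for illustration, not a description of the CAREamics internals; `load_input_and_target` and `toy_loader` are hypothetical names.

```python
from collections.abc import Callable, Sequence
from typing import Any, Optional


def load_input_and_target(
    loader: Callable[..., Sequence[Any]],
    inputs: Any,
    targets: Any,
    axes: str,
    loader_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[Sequence[Any], Sequence[Any]]:
    """Hypothetical helper: call one loader for both inputs and targets."""
    kwargs = loader_kwargs or {}
    # the same keyword arguments are forwarded to both calls
    input_stacks = loader(inputs, axes, **kwargs)
    target_stacks = loader(targets, axes, **kwargs)
    return input_stacks, target_stacks


def toy_loader(source: Sequence[str], axes: str, prefix: str = "") -> list[str]:
    """Toy loader returning strings in place of image stacks."""
    return [f"{prefix}/{name} ({axes})" for name in source]


input_stacks, target_stacks = load_input_and_target(
    toy_loader,
    inputs=["train_input/a"],
    targets=["train_target/a"],
    axes="YX",
    loader_kwargs={"prefix": "demo.h5"},
)
print(input_stacks, target_stacks)
```

If the inputs and targets need different extra information, that information has to travel inside `source` instead of the kwargs.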


In [None]:
# An image stack loader
# both the input and target image stacks must be contained within the same HDF5 file
def hdf5_image_stack_loader(
    source: Sequence[str], axes: str, file: h5py.File
) -> Sequence[HDF5ImageStack]:
    image_stacks: list[HDF5ImageStack] = []
    for data_path in source:
        if data_path not in file:
            raise KeyError(f"Data does not exist at path '{data_path}'")
        image_data = file[data_path]
        if not isinstance(image_data, h5py.Dataset):
            raise TypeError(f"HDF5 node at path '{data_path}' is not a Dataset.")
        image_stacks.append(HDF5ImageStack(image_data, axes=axes))
    return image_stacks

In [None]:
# This is an alternative HDF5 image stack loader
# The input and target data can be stored in separate HDF5 files
# An HDF5Source typed dict is defined so that both the file and the data path
# can be passed together in a single argument

from typing import TypedDict


class HDF5Source(TypedDict):
    file: h5py.File
    data_path: str


def hdf5_image_stack_loader_alt(
    source: Sequence[HDF5Source], axes: str
) -> Sequence[HDF5ImageStack]:
    image_stacks: list[HDF5ImageStack] = []
    for image_stack_source in source:
        data_path = image_stack_source["data_path"]
        file = image_stack_source["file"]
        if data_path not in file:
            raise KeyError(f"Data does not exist at path '{data_path}'")
        image_data = file[data_path]
        if not isinstance(image_data, h5py.Dataset):
            raise TypeError(f"HDF5 node at path '{data_path}' is not a Dataset.")
        image_stacks.append(HDF5ImageStack(image_data, axes=axes))
    return image_stacks

## Now we test it

### Create a configuration for the data

In [None]:
train_files = sorted(train_path.glob("*.tif"))
train_target_files = sorted(target_path.glob("*.tif"))

config = create_care_configuration(
    experiment_name="care_U2OS",
    data_type="custom",
    axes="YX",
    patch_size=[128, 128],
    batch_size=32,
    num_epochs=50,
)

In [None]:
hdf5_file = h5py.File(hdf5_file_path, mode="r")

inputs = sorted([f"train_input/{key}" for key in hdf5_file["train_input"].keys()])
targets = sorted([f"train_target/{key}" for key in hdf5_file["train_target"].keys()])
test_inputs = sorted([f"test_input/{key}" for key in hdf5_file["test_input"].keys()])
test_targets = sorted([f"test_target/{key}" for key in hdf5_file["test_target"].keys()])

dataset = create_dataset(
    config=config.data_config,
    mode=Mode.TRAINING,
    inputs=inputs,
    targets=targets,
    in_memory=False,
    image_stack_loader=hdf5_image_stack_loader,
    image_stack_loader_kwargs={"file": hdf5_file},
)

### Index the dataset and display the result

In [None]:
fig, axes = plt.subplots(1, 2)
train_input, target = dataset[0]
axes[0].imshow(train_input.data[0])
axes[0].set_title("Input")
axes[1].imshow(target.data[0])
axes[1].set_title("Target")

In [None]:
# input and target are ImageRegionData objects
train_input, target

### Test the alternative image stack loader

In [None]:
hdf5_file = h5py.File(hdf5_file_path, mode="r")

data_keys = sorted(hdf5_file["train_input"].keys())

# for the alternative image stack loader we have to construct a list of dicts
# because we defined the source type to be an HDF5Source typed dict
inputs: list[HDF5Source] = [
    {"data_path": f"train_input/{key}", "file": hdf5_file} for key in data_keys
]
targets: list[HDF5Source] = [
    {"data_path": f"train_target/{key}", "file": hdf5_file} for key in data_keys
]

dataset = create_dataset(
    config=config.data_config,
    mode=Mode.TRAINING,
    inputs=inputs,
    targets=targets,
    in_memory=False,
    image_stack_loader=hdf5_image_stack_loader_alt,
    # now we don't have any additional kwargs
)

In [None]:
# display the first item
# note this will be a different patch because of the random patching
fig, axes = plt.subplots(1, 2)
train_input, target = dataset[0]
axes[0].imshow(train_input.data[0])
axes[0].set_title("Input")
axes[1].imshow(target.data[0])
axes[1].set_title("Target")
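
The random patching mentioned in the comment above can be pictured with the following generic sketch. It is not the sampling strategy used by CAREamics, only an illustration of drawing a patch corner uniformly at random so that the patch fits inside the image. `random_patch_coords` is a hypothetical helper and the image shape is made up.

```python
from collections.abc import Sequence

import numpy as np


def random_patch_coords(
    spatial_shape: Sequence[int],
    patch_size: Sequence[int],
    rng: np.random.Generator,
) -> tuple[int, ...]:
    """Draw the starting corner of a patch that lies fully inside the image."""
    return tuple(
        # the upper bound of rng.integers is exclusive, hence the +1
        int(rng.integers(0, dim - size + 1))
        for dim, size in zip(spatial_shape, patch_size)
    )


rng = np.random.default_rng(seed=0)
coords = random_patch_coords((512, 512), (128, 128), rng)
print(coords)
```

Because the corner is redrawn on every call, indexing the same dataset item twice can return different patches unless the random generator is seeded identically.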

### Now let's run a CARE training pipeline and see how it performs

#### Creating the lightning data module for training

In [None]:
from careamics.config.inference_model import InferenceConfig
from careamics.lightning.dataset_ng.data_module import CareamicsDataModule

hdf5_file = h5py.File(hdf5_file_path, mode="r")

train_data_keys = sorted(hdf5_file["train_input"].keys())

inputs: list[HDF5Source] = [
    {"data_path": f"train_input/{key}", "file": hdf5_file} for key in train_data_keys
]
targets: list[HDF5Source] = [
    {"data_path": f"train_target/{key}", "file": hdf5_file} for key in train_data_keys
]

test_data_keys = sorted(hdf5_file["test_input"].keys())
test_inputs: list[HDF5Source] = [
    {"data_path": f"test_input/{key}", "file": hdf5_file} for key in test_data_keys
]
test_targets: list[HDF5Source] = [
    {"data_path": f"test_target/{key}", "file": hdf5_file} for key in test_data_keys
]
config = create_care_configuration(
    experiment_name="care_U2OS",
    data_type="custom",
    axes="YX",
    patch_size=[128, 128],
    batch_size=32,
    num_epochs=50,
)
train_data_module = CareamicsDataModule(
    data_config=config.data_config,
    train_data=inputs,
    train_data_target=targets,
    # for brevity, this demo reuses the training data for validation
    val_data=inputs,
    val_data_target=targets,
    image_stack_loader=hdf5_image_stack_loader_alt,
)

#### Creating the model and the trainer

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from careamics.lightning.dataset_ng.lightning_modules import CAREModule

root = Path("care_stack_loader")

# TODO: replace with N2V!!!
model = CAREModule(config.algorithm_config)

callbacks = [
    ModelCheckpoint(
        dirpath=root / "checkpoints",
        filename="care_baseline",
        save_last=True,
        monitor="val_loss",
        mode="min",
    )
]

trainer = Trainer(max_epochs=50, default_root_dir=root, callbacks=callbacks)

#### Training the model

In [None]:
trainer.fit(model, datamodule=train_data_module)

#### Creating the inference data module

In [None]:

inference_config = InferenceConfig(
    model_config=config,
    data_type="custom",
    tile_size=(128, 128),
    tile_overlap=(32, 32),
    axes="YX",
    batch_size=1,
    image_means=train_data_module.train_dataset.input_stats.means,
    image_stds=train_data_module.train_dataset.input_stats.stds,
)

inf_data_module = CareamicsDataModule(
    data_config=inference_config,
    pred_data=test_inputs,
    image_stack_loader=hdf5_image_stack_loader_alt,
)

#### Running the prediction on the test set

In [None]:
from careamics.dataset_ng.legacy_interoperability import imageregions_to_tileinfos
from careamics.prediction_utils import convert_outputs

predictions = trainer.predict(model, datamodule=inf_data_module)
tile_infos = imageregions_to_tileinfos(predictions)
prediction = convert_outputs(tile_infos, tiled=True)

#### Displaying the predictions

In [None]:
from careamics.utils.metrics import psnr, scale_invariant_psnr

# Load the noisy test images and their ground truth; the first three are displayed
noises = [tifffile.imread(f) for f in sorted(test_path.glob("*.tif"))]
gts = [tifffile.imread(f) for f in sorted(test_target_path.glob("*.tif"))]

fig, ax = plt.subplots(3, 3, figsize=(7, 7))
fig.tight_layout()

for i in range(3):
    pred_image = prediction[i].squeeze()
    psnr_noisy = psnr(
        gts[i],
        noises[i],
        data_range=gts[i].max() - gts[i].min(),
    )
    psnr_result = psnr(
        gts[i],
        pred_image,
        data_range=gts[i].max() - gts[i].min(),
    )

    scale_invariant_psnr_result = scale_invariant_psnr(gts[i], pred_image)

    ax[i, 0].imshow(noises[i], cmap="gray")
    ax[i, 0].title.set_text(f"Noisy\nPSNR: {psnr_noisy:.2f}")

    ax[i, 1].imshow(pred_image, cmap="gray")
    ax[i, 1].title.set_text(
        f"Prediction\nPSNR: {psnr_result:.2f}\n"
        f"Scale invariant PSNR: {scale_invariant_psnr_result:.2f}"
    )

    ax[i, 2].imshow(gts[i], cmap="gray")
    ax[i, 2].title.set_text("Ground-truth")

#### Calculating the metrics on the test set

In [None]:
psnrs = np.zeros((len(prediction), 1))
scale_invariant_psnrs = np.zeros((len(prediction), 1))

for i, (pred, gt) in enumerate(zip(prediction, gts)):
    psnrs[i] = psnr(gt, pred.squeeze(), data_range=gt.max() - gt.min())
    scale_invariant_psnrs[i] = scale_invariant_psnr(gt, pred.squeeze())

print(f"PSNR: {psnrs.mean():.2f} +/- {psnrs.std():.2f}")
print(
    f"Scale invariant PSNR: "
    f"{scale_invariant_psnrs.mean():.2f} +/- {scale_invariant_psnrs.std():.2f}"
)
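
For reference, the standard definition of PSNR is `10 * log10(data_range**2 / MSE)`, where `MSE` is the mean squared error between the ground truth and the prediction. The sketch below is a plain NumPy version of that textbook formula; `psnr_sketch` is a hypothetical name, and the function is not guaranteed to match `careamics.utils.metrics.psnr` in every detail.

```python
import numpy as np


def psnr_sketch(gt: np.ndarray, pred: np.ndarray, data_range: float) -> float:
    """Textbook PSNR in decibels for a given data range."""
    # compute the mean squared error in float64 to avoid integer overflow
    mse = np.mean((gt.astype(np.float64) - pred.astype(np.float64)) ** 2)
    if mse == 0:
        # identical images: PSNR is unbounded
        return float("inf")
    return float(10 * np.log10(data_range**2 / mse))


toy_gt = np.array([[0.0, 1.0], [1.0, 0.0]])
toy_pred = toy_gt + 0.1  # constant error of 0.1, so the MSE is about 0.01
print(psnr_sketch(toy_gt, toy_pred, data_range=toy_gt.max() - toy_gt.min()))
```

A larger `data_range` or a smaller error both raise the value, which is why the same `data_range` should be used when comparing the noisy input and the prediction against the ground truth.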