In [None]:
from collections.abc import Sequence
from pathlib import Path

import numpy as np
import zarr
from numpy.typing import NDArray
from zarr.storage import FSStore

In [None]:
from careamics.dataset_ng.patch_extractor import PatchExtractor
from careamics.dataset_ng.patch_extractor.image_stack import ZarrImageStack

In [None]:
def create_zarr_array(
    file_path: Path,
    data_path: str,
    data: NDArray,
    dtype=np.uint16,
    overwrite: bool = False,
):
    """Write `data` as a single-chunk zarr array inside the store at `file_path`.

    Parameters
    ----------
    file_path : Path
        Location of the zarr store on the local file system.
    data_path : str
        Path of the array inside the store, e.g. ``"group_1/image_1.1"``.
    data : NDArray
        Array to write; its shape is used both as the array shape and as the
        chunk shape (one chunk covering the whole array).
    dtype : numpy dtype-like or None, default np.uint16
        On-disk dtype; `data` is cast to it on write. Pass ``None`` to keep
        ``data.dtype``.
    overwrite : bool, default False
        Forwarded to ``zarr.create``. If True, zarr deletes pre-existing data
        at `data_path` before creating the array.
    """
    store = FSStore(url=file_path.resolve())
    try:
        array = zarr.create(
            store=store,
            shape=data.shape,
            chunks=data.shape,  # only 1 chunk
            dtype=data.dtype if dtype is None else dtype,
            path=data_path,
            overwrite=overwrite,
        )
        array[...] = data
    finally:
        # Close the store even if creation or the write fails.
        store.close()

def create_zarr_group(file_path: Path, group_path: str):
    """Open (creating if needed) the zarr group `group_path` in the store at `file_path`.

    Parameters
    ----------
    file_path : Path
        Location of the zarr store on the local file system.
    group_path : str
        Path of the group inside the store.
    """
    store = FSStore(url=file_path.resolve())
    try:
        zarr.open_group(store, path=group_path)
    finally:
        # Release the store even if opening/creating the group raises.
        store.close()

def create_zarr(
    file_path: Path, data_paths: Sequence[str], data: Sequence[NDArray]
):
    """Write each array in `data` into the store at `file_path` at the matching path.

    ``data_paths[i]`` is the in-store path for ``data[i]``. Parent groups are
    not created explicitly here.
    NOTE(review): this relies on zarr creating the intermediate groups of a
    nested path (e.g. ``group_1`` for ``"group_1/image_1.1"``) when the array
    is created; confirm for the installed zarr version.

    Parameters
    ----------
    file_path : Path
        Location of the zarr store on the local file system.
    data_paths : Sequence[str]
        In-store path of each array.
    data : Sequence[NDArray]
        Arrays to write, one per entry of `data_paths`.

    Raises
    ------
    ValueError
        If `data_paths` and `data` have different lengths. A plain ``zip``
        would silently drop the unmatched entries.
    """
    if len(data_paths) != len(data):
        raise ValueError(
            f"Got {len(data_paths)} data paths but {len(data)} arrays; "
            "they must match one-to-one."
        )
    for data_path, array in zip(data_paths, data):
        create_zarr_array(file_path=file_path, data_path=data_path, data=array)

In [None]:
# Demo store location, built from the home directory so the notebook does not
# hard-code a user name.
dir_path = Path.home() / "Documents" / "Data"
file_name = "test_ngff_image.zarr"
file_path = dir_path / file_name

# Fixed seed so the random test images are reproducible across runs.
SEED = 42
rng = np.random.default_rng(SEED)

data_paths = [
    "image_1",
    "group_1/image_1.1",
    "group_1/image_1.2",
]
data_shapes = [
    (1, 3, 64, 64),
    (1, 3, 32, 48),
    (1, 3, 32, 32),
]
# Values in [1, 255); generated as uint8, but written to disk as uint16.
data = [rng.integers(1, 255, size=shape, dtype=np.uint8) for shape in data_shapes]

# zarr.create is called without overwrite=True, so re-creating existing arrays
# would fail. Only write the store once. With the fixed seed, `data` matches
# the on-disk content provided the store was written by this cell.
if not file_path.exists():
    create_zarr(file_path, data_paths, data)

In [None]:
# Re-open the store read-only for the loading and extraction cells below.
store = FSStore(mode="r", url=file_path.resolve())

In [None]:
# Every key present in the store (metadata files and chunk keys).
[key for key in store.keys()]

In [None]:
# Open the root group read-only, matching the store's mode="r", and show the
# nested group created when the nested arrays were written.
zarr.open_group(store, mode="r")["group_1"]

In [None]:
def custom_image_stack_loader(
    store: FSStore, data_paths: Sequence[str], axes: str
) -> list[ZarrImageStack]:
    """Build one `ZarrImageStack` per array path in an already-open zarr store.

    Parameters
    ----------
    store : FSStore
        Open zarr store holding the arrays.
    data_paths : Sequence[str]
        In-store paths of the arrays to wrap.
    axes : str
        Axes string passed unchanged to every `ZarrImageStack`, e.g. ``"SCYX"``.
        (Previously written as ``axes="str"``: a default value that was meant
        to be the type annotation.)

    Returns
    -------
    list[ZarrImageStack]
        One image stack per entry of `data_paths`, in the same order.
    """
    return [
        ZarrImageStack(store=store, data_path=data_path, axes=axes)
        for data_path in data_paths
    ]

In [None]:
# 4-character axes string, matching the 4-D `data_shapes` used to write the store.
image_stacks = custom_image_stack_loader(store, data_paths, axes="SCYX")

In [None]:
# Wrap all image stacks in a single extractor.
# NOTE(review): patches are presumably addressed by position in `image_stacks`;
# confirm against PatchExtractor.
patch_extractor = PatchExtractor(image_stacks)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# NOTE(review): arguments assumed to be (image-stack index, sample index,
# patch start coords (Y, X), patch size (Y, X)); confirm against
# PatchExtractor.extract_patch.
patch = patch_extractor.extract_patch(0, 0, (8, 16), (16, 16))

fig, ax = plt.subplots()
# Move the leading (channel) axis last, as imshow expects for multi-channel images.
ax.imshow(np.moveaxis(patch, 0, -1))
ax.set_title("16x16 patch from image stack 0, sample 0")
plt.show()