# In [None]:
# Workflow: Concatenate BioIO time-series and save as lazy-loadable OME-Zarr

from bioio import BioImage
from bioio.writers import OmeZarrWriterV3
import xarray as xr
from typing import Union
from pathlib import Path


class TimeSeriesReader:
    """Concatenate the image files of one folder along "T" and write them out.

    The pipeline is ``load_data`` -> ``concatenate_and_chunk`` ->
    ``build_metadata`` -> ``save_to_zarr``; ``run`` performs all four in order.

    Public attributes:
        folder_path, pattern, file_paths: input folder, glob pattern and the
            sorted paths of the regular files matching it.
        time_chunks: divisor used to derive the number of frames per "T" chunk.
        data_arrays, metadata_list, timestamps, total_frames: per-file results
            collected by ``load_data``.
        combined, combined_metadata: results of ``concatenate_and_chunk`` and
            ``build_metadata``.
    """

    # Dimensions that are kept whole inside every chunk. Chunking them at
    # size 1 would turn every single pixel into its own chunk.
    _PLANE_DIMS = ("Y", "X")

    def __init__(
        self,
        folder_path: Union[str, Path],
        time_chunks: int = 10,
        pattern: str = "*",
    ):
        """Collect the input files; nothing is opened yet.

        Args:
            folder_path: Folder holding the time-series files.
            time_chunks: Divisor applied to the total frame count to get the
                number of frames per "T" chunk (see ``concatenate_and_chunk``).
            pattern: Glob pattern selecting the files inside ``folder_path``.
                The default ``"*"`` takes every regular file.

        Raises:
            ValueError: If ``time_chunks`` is smaller than 1.
        """
        if time_chunks < 1:
            raise ValueError(f"time_chunks must be >= 1, got {time_chunks!r}")

        self.folder_path = Path(folder_path)
        self.time_chunks = time_chunks
        self.pattern = pattern
        # Directories are skipped; paths are ordered by plain (lexicographic)
        # string sort, which therefore defines the concatenation order.
        self.file_paths = sorted(
            str(p) for p in self.folder_path.glob(pattern) if p.is_file()
        )
        self.data_arrays = []
        self.metadata_list = []
        self.timestamps = []
        self.total_frames = 0
        self.combined = None
        self.combined_metadata = {}

    @staticmethod
    def _chunk_spec(dims, frames_per_chunk: int) -> dict:
        """Return the per-dimension chunk sizes used for the combined array.

        "T" gets ``frames_per_chunk``, the dims in ``_PLANE_DIMS`` get -1
        (one chunk spanning the whole axis) and every other dim gets 1.
        """
        spec = {
            dim: (-1 if dim in TimeSeriesReader._PLANE_DIMS else 1)
            for dim in dims
            if dim != "T"
        }
        spec["T"] = frames_per_chunk
        return spec

    @staticmethod
    def _pixel_size(meta, attr: str) -> float:
        """Return ``meta.<attr>``, or 1.0 when it is missing or ``None``."""
        value = getattr(meta, attr, None)
        return 1.0 if value is None else value

    def load_data(self):
        """Open every file in ``file_paths`` and collect its array and metadata.

        Fills ``data_arrays``, ``metadata_list``, ``timestamps`` and
        ``total_frames``. Previously collected results are discarded first, so
        calling this method (or ``run``) again does not duplicate frames.
        """
        self.data_arrays = []
        self.metadata_list = []
        self.timestamps = []
        self.total_frames = 0

        for path in self.file_paths:
            img = BioImage(path)
            # NOTE(review): bioio also appears to offer a dask-backed
            # `xarray_dask_data`; consider it if this load is too
            # memory-hungry -- confirm against the bioio documentation.
            da_xr = img.xarray_data
            # An array without a "T" dimension counts as a single frame.
            n_frames = da_xr.sizes.get("T", 1)
            self.total_frames += n_frames
            self.data_arrays.append(da_xr)
            self.metadata_list.append(img.metadata)

            timestamps = getattr(img.metadata, "timestamps", None)
            if timestamps is not None:
                self.timestamps.extend(timestamps)
            else:
                # Keep one placeholder per frame so positions stay aligned.
                self.timestamps.extend([None] * n_frames)

    def concatenate_and_chunk(self):
        """Join the loaded arrays along "T" and chunk the result.

        Each "T" chunk holds ``total_frames // time_chunks`` frames (at least
        one), so the number of chunks can exceed ``time_chunks`` when the
        division is not exact. Y and X stay whole within a chunk; all other
        dims are chunked at size 1. The "T" coordinate is replaced by the
        collected timestamps only when at least one of them is not ``None``;
        otherwise the coordinates produced by the concatenation are kept.

        Raises:
            ValueError: If no arrays have been loaded.
        """
        if not self.data_arrays:
            raise ValueError(
                f"No images loaded from {self.folder_path}; call load_data() "
                "on a folder that contains image files first"
            )

        combined = xr.concat(self.data_arrays, dim="T")
        frames_per_chunk = max(1, self.total_frames // self.time_chunks)
        combined = combined.chunk(self._chunk_spec(combined.dims, frames_per_chunk))

        # A list made only of None placeholders carries no information and
        # would merely overwrite whatever "T" coordinate the inputs had.
        if any(stamp is not None for stamp in self.timestamps):
            combined = combined.assign_coords(T=("T", self.timestamps))

        self.combined = combined

    def build_metadata(self):
        """Assemble ``combined_metadata`` from the file list and per-file metadata."""
        self.combined_metadata = {
            "source_files": self.file_paths,
            "concatenated_from": len(self.file_paths),
            "original_metadata": self.metadata_list,
        }

    def save_to_zarr(self, output_path: Union[str, Path]):
        """Write ``combined`` to ``output_path`` through ``OmeZarrWriterV3``.

        Pixel sizes are looked up as ``physical_size_z/y/x`` attributes on the
        first entry of ``metadata_list``; each falls back to 1.0 when the
        attribute is missing or ``None``, or when no metadata was collected.

        Raises:
            RuntimeError: If ``concatenate_and_chunk`` has not produced a
                combined array yet.
        """
        if self.combined is None:
            raise RuntimeError(
                "Nothing to save: call concatenate_and_chunk() before save_to_zarr()"
            )

        shape = tuple(self.combined.sizes[dim] for dim in self.combined.dims)
        dtype = self.combined.dtype

        first_meta = self.metadata_list[0] if self.metadata_list else None
        px_size_z = self._pixel_size(first_meta, "physical_size_z")
        px_size_y = self._pixel_size(first_meta, "physical_size_y")
        px_size_x = self._pixel_size(first_meta, "physical_size_x")

        # Match the scale factors to the axes: T, C, Z, Y, X
        # NOTE(review): the axes are hard-coded while `shape` and
        # `dimension_order` follow `combined.dims`; this assumes the combined
        # array has exactly these five dims in this order -- TODO confirm.
        scale_factors = [1.0, 1.0, px_size_z, px_size_y, px_size_x]

        # NOTE(review): the constructor arguments and the `write_image` call
        # are kept exactly as before; verify both against the installed
        # writer's API.
        writer = OmeZarrWriterV3(
            store=str(output_path),
            shape=shape,
            dtype=dtype,
            scale_factors=scale_factors,
            axes_names=["T", "C", "Z", "Y", "X"],
        )
        writer.write_image(
            image=self.combined,
            dimension_order="".join(self.combined.dims),
            metadata=self.combined_metadata,
        )

    def run(self, output_path: Union[str, Path]):
        """Run the whole pipeline: load, concatenate, build metadata, save."""
        self.load_data()
        self.concatenate_and_chunk()
        self.build_metadata()
        self.save_to_zarr(output_path)


# Example usage: collate one embryo folder into a single store.
if __name__ == "__main__":
    input_dir = "/mnt/Data1/Yovan/physio2025/ms2_pipeline_lite/test_data/MCP-mSG_His-RFP_r1-close(002)_embryo02"
    output_store = "/mnt/Data1/Yovan/physio2025/ms2_pipeline_lite/test_data/MCP-mSG_His-RFP_r1-close(002)_embryo02/collated_dataset.zarr"

    # The store is written inside the input folder itself.
    TimeSeriesReader(input_dir, time_chunks=20).run(output_store)

    # Lazy load example (sketch only: the store name below is a placeholder,
    # not the path written above).
    # import xarray as xr
    # ds = xr.open_zarr("output_merged.zarr", consolidated=True)
    # data = ds["0"]  # "0" is presumably the default dataset group -- confirm
