In [None]:
import dask.array as da
import matplotlib.pyplot as plt
import numpy as np
import torch
import zarr

import aind_cloud_fusion.blend as blend
import aind_cloud_fusion.fusion as fusion
import aind_cloud_fusion.geometry as geometry
import aind_cloud_fusion.io as io

In [None]:
# Custom two-tile Dataset used to exercise fusion on a small exaSPIM example.
class ExaspimTestDataset(io.Dataset):
    """Two adjacent exaSPIM tiles (488 channel) from data asset 'exaSPIM_659146'.

    Each tile is opened lazily with ``dask.array.from_zarr`` and placed in the
    output frame by a hand-written 3x4 affine (ZYX order).
    """

    # Assumes data asset 'exaSPIM_659146' is mounted at this location.
    DATA_ROOT = '/root/capsule/data/exaSPIM_659146_2023-11-10_14-02-06/SPIM.ome.zarr'
    # Array key read inside each tile group (presumably OME-Zarr pyramid level 5 — confirm).
    PYRAMID_LEVEL = 5
    # Tile order here defines the integer tile ids (0, 1, ...).
    TILE_NAMES = (
        'tile_x_0002_y_0001_z_0000_ch_488.zarr',
        'tile_x_0002_y_0002_z_0000_ch_488.zarr',
    )

    @property
    def tile_volumes_zyx(self) -> dict[int, io.LazyArray]:
        """Open every tile lazily as a 3-D ZYX volume.

        Returns:
            Mapping tile id (index into ``TILE_NAMES``) -> ``io.ZarrArray``.
        """
        tile_arrays: dict[int, io.LazyArray] = {}
        for tile_id, tile_name in enumerate(self.TILE_NAMES):
            tile_path = f'{self.DATA_ROOT}/{tile_name}/{self.PYRAMID_LEVEL}'
            tile_zarr = da.from_zarr(tile_path)
            # Source arrays are 5-D; take index 0 of the two leading axes
            # (presumably t, c per OME-Zarr) to get a ZYX volume.
            tile_arrays[tile_id] = io.ZarrArray(tile_zarr[0, 0, :, :, :])
        return tile_arrays

    @property
    def tile_transforms_zyx(self) -> dict[int, list[geometry.Transform]]:
        """Per-tile transform chains, each a list as the annotation requires.

        Tile 0 is the identity; tile 1 is translated along Y by 281.25.
        """
        identity = np.array([[1., 0., 0., 0.],
                             [0., 1., 0., 0.],
                             [0., 0., 1., 0.]])
        # Separate copy so the two Affine objects never share one buffer.
        y_shifted = identity.copy()
        # NOTE(review): units of this offset (voxels at this pyramid level vs.
        # physical units) are not visible here — confirm against fusion.
        y_shifted[1, 3] = 281.25
        return {
            0: [geometry.Affine(identity)],
            1: [geometry.Affine(y_shifted)],
        }

    @property
    def tile_resolution_zyx(self) -> tuple[float, float, float]:
        """Voxel size (z, y, x) reported for every tile."""
        # NOTE(review): this does not depend on PYRAMID_LEVEL; if these values
        # describe level 0 they may need scaling for the level read above — confirm.
        return (1.0, 0.748, 0.748)


In [None]:
# Application Object: DATASET
DATASET = ExaspimTestDataset()

# Application Object: OUTPUT_PARAMS
# Output resolution is taken from the dataset so the two cannot drift apart.
OUTPUT_PARAMS = io.OutputParameters(
        path='fused_exaspim.zarr',
        chunksize=(1, 1, 100, 100, 100),
        resolution_zyx=DATASET.tile_resolution_zyx,
)

# Application Object: RUNTIME_PARAMS
# CPU only; pool_size=1 presumably means a single worker — raise to parallelize.
RUNTIME_PARAMS = io.RuntimeParameters(
        use_gpus=False,
        devices=[torch.device("cpu")],
        pool_size=1,
)

# Application Parameter: CELL_SIZE
# Size (z, y, x) of each output cell handed to fusion — TODO confirm units (voxels).
CELL_SIZE = [100, 100, 100]

# Application Parameter: POST_REG_TFMS
# Post-registration transforms passed to run_fusion; none are needed for this test.
POST_REG_TFMS: list[geometry.Affine] = []

# Application Object: BLENDING_MODULE
BLENDING_MODULE = blend.MaskedBlending()

# Run fusion
fusion.run_fusion(
        DATASET,
        OUTPUT_PARAMS,
        RUNTIME_PARAMS,
        CELL_SIZE,
        POST_REG_TFMS,
        BLENDING_MODULE,
)

In [None]:
# Inspect the volume written by the fusion cell above. The middle Z plane is
# expected to contain both tiles (tile 1 is offset along Y) and hence the seam.
# Assumes the fused array is stored under '<path>/0' — confirm.
fused_path = f"{OUTPUT_PARAMS.path}/0"
fused_arr = zarr.open(fused_path, mode="r")
print(f"fused array shape: {fused_arr.shape}")

# Read a single plane instead of materializing the whole 5-D array.
mid_z = fused_arr.shape[2] // 2
fused_plane_yx = fused_arr[0, 0, mid_z, :, :]

fig, ax = plt.subplots()
im = ax.imshow(fused_plane_yx, cmap="gray")
ax.set(title=f"Fused volume, z = {mid_z}", xlabel="x (voxels)", ylabel="y (voxels)")
fig.colorbar(im, ax=ax)
plt.show()