In [5]:
import dask.array as da
import numpy as np
import zarr
from pathlib import Path

In [6]:
# Location and name of the processed dataset inspected in this notebook.
data_loc = Path('/data')
dataset_name = 'HCR_783551_2025-10-08_13-00-00_processed_2025-10-22_17-55-54'
channels = ['405', '488', '561', '638']

# The fused channels and the segmentation mask both sit under this folder.
tile_fusing_dir = data_loc.joinpath(dataset_name, 'image_tile_fusing')

print(f'data location: {data_loc / dataset_name}')

print("---- fused channels ----")
# NOTE: `ch`, `zarr_path` and `dask_array` keep the values of the last
# iteration after the loop; later cells in this notebook read them.
for ch in channels:
    zarr_path = tile_fusing_dir.joinpath('fused', f'channel_{ch}.zarr', '0')
    dask_array = da.from_zarr(str(zarr_path))
    print(f"channel {ch} shape ", dask_array.shape)

print("---- segmentation mask ----")
mask_path = tile_fusing_dir.joinpath('cell_body_segmentation', 'segmentation_mask.zarr', '0')
mask_array = da.from_zarr(str(mask_path))
print("mask shape ", mask_array.shape)


data location: /data/HCR_783551_2025-10-08_13-00-00_processed_2025-10-22_17-55-54
---- fused channels ----
channel 405 shape  (1, 1, 965, 9252, 9241)
channel 488 shape  (1, 1, 965, 9252, 9241)
channel 561 shape  (1, 1, 965, 9252, 9241)
channel 638 shape  (1, 1, 965, 9252, 9241)
---- segmentation mask ----
mask shape  (1, 1, 965, 9252, 9240)


In [17]:
import json
from pathlib import Path
from typing import Iterable, Tuple, Union

try:
    import s3fs  # optional; only needed for s3:// sources
except ImportError:  # pragma: no cover
    s3fs = None  # type: ignore


def get_highest_resolution_spacing(
    zarr_root: Union[str, Path]
) -> Tuple[float, float, float]:
    """
    Return the finest-resolution (Z, Y, X) voxel spacing stored in an OME-Zarr.

    The spacing is read from the ``scale`` coordinate transformation of every
    dataset listed in the first ``multiscales`` entry of ``<zarr_root>/.zattrs``;
    the smallest (Z, Y, X) tuple, compared element by element, is returned.

    Parameters
    ----------
    zarr_root :
        Filesystem path (e.g. ``/data/sample.zarr``) or S3 URI
        (e.g. ``s3://bucket/sample.zarr``). S3 URIs must be passed as ``str``:
        ``pathlib`` collapses the ``//`` after the scheme.

    Returns
    -------
    Tuple[float, float, float]
        The last three ``scale`` entries of the highest-resolution level, as
        floats, in the unit declared by the OME-Zarr axes metadata.

    Raises
    ------
    FileNotFoundError
        If the `.zattrs` metadata file cannot be located.
    ValueError
        If the metadata is missing the multiscale definition or spatial scales.
    RuntimeError
        If S3 access is requested but the optional `s3fs` dependency is unavailable.
    """
    root_text = str(zarr_root)

    # Detect the scheme on the raw string. Converting to Path first turns
    # "s3://bucket/key" into "s3:/bucket/key", so the S3 branch was never taken.
    if root_text.startswith("s3://"):
        if s3fs is None:
            raise RuntimeError("s3fs is required to read s3:// paths")
        zattrs_uri = root_text.rstrip("/") + "/.zattrs"
        fs = s3fs.S3FileSystem()
        with fs.open(zattrs_uri, "r") as handle:
            zattrs = json.load(handle)
    else:
        with (Path(zarr_root) / ".zattrs").open("r") as handle:
            zattrs = json.load(handle)

    multiscales = zattrs.get("multiscales", [])
    if not multiscales:
        raise ValueError("OME-Zarr metadata missing 'multiscales' block")

    datasets = multiscales[0].get("datasets", [])
    if not datasets:
        raise ValueError("OME-Zarr metadata missing 'datasets' entries")

    def spatial_scale(dataset: dict) -> Tuple[float, float, float]:
        for transform in dataset.get("coordinateTransformations", []):
            if transform.get("type") == "scale":
                scale = transform.get("scale", [])
                if len(scale) >= 3:
                    # Use the last three entries; OME order is usually (t, c, z, y, x).
                    # Coerce to float so integer scales in the JSON still match
                    # the declared return type.
                    z, y, x = (float(value) for value in scale[-3:])
                    return (z, y, x)
        raise ValueError("Dataset is missing a spatial 'scale' transform")

    # Highest resolution corresponds to the smallest spatial scale factors.
    return min(spatial_scale(dataset) for dataset in datasets)

In [None]:
import numpy as np
import zarr
import dask.array as da
from dask.distributed import Client, LocalCluster, performance_report
from dask import delayed
import dask
from multiprocessing import Pool
import s3fs

import numcodecs
from numcodecs import Blosc
from dataclasses import dataclass



@dataclass
class OutputParameters:
    """Settings describing the Zarr array that the upscaling functions write."""

    # Destination of the output store; a value starting with 's3' is opened
    # through s3fs by initialize_output_volume, anything else as a local group.
    path: str
    # Chunk shape of the 5-D output array; passed as `chunks` to create_dataset.
    chunksize: tuple[int, int, int, int, int]
    # Presumably the voxel spacing in (z, y, x) order; not read by any function
    # visible in this notebook section -- TODO confirm intended use.
    resolution_zyx: tuple[float, float, float]
    # Element type of the output array. NOTE(review): the default is the scalar
    # class np.uint16 rather than an np.dtype instance.
    dtype: np.dtype = np.uint16
    # Passed as `dimension_separator` to create_dataset.
    dimension_separator: str = "."
    # Passed as `compressor` to create_dataset; defaults to None.
    compressor: numcodecs.blosc.Blosc = None  # e.g. Blosc(cname='zstd', clevel=1, shuffle=Blosc.SHUFFLE)

def initialize_output_volume(
    output_params: OutputParameters,
    output_volume_size: tuple[int, int, int],
) -> zarr.core.Array:
    """
    Create the 5-D output array ``"0"`` inside the store at ``output_params.path``.

    Inputs
    ------
    output_params: OutputParameters describing destination path, chunking,
        dtype, dimension separator and compressor.
    output_volume_size: (z, y, x) size of the output volume; the array is
        created with shape ``(1, 1, z, y, x)``.

    Returns
    -------
    The newly created Zarr array (fill value 0). An existing array ``"0"`` in
    the target group is overwritten.
    """

    if output_params.path.startswith(("s3://", "s3a://")):
        # Cloud execution. Only the S3 store is opened here: previously a local
        # group was also created at the literal "s3..." path before this branch.
        s3 = s3fs.S3FileSystem(
            config_kwargs={
                'max_pool_connections': 50,
                's3': {
                    'multipart_threshold': 64 * 1024 * 1024,  # 64 MB, avoid multipart upload for small chunks
                    'max_concurrent_requests': 20,  # Increased from 10 -> 20.
                },
                'retries': {
                    'total_max_attempts': 100,
                    'mode': 'adaptive',
                },
            }
        )
        store = s3fs.S3Map(root=output_params.path, s3=s3)
        # mode='a' keeps whatever already exists in the bucket prefix; only the
        # "0" array is replaced below via overwrite=True.
        out_group = zarr.open(store=store, mode='a')
    else:
        # Local execution: mode="w" starts from an empty group.
        out_group = zarr.open_group(output_params.path, mode="w")

    size_z, size_y, size_x = (
        output_volume_size[0],
        output_volume_size[1],
        output_volume_size[2],
    )
    output_volume = out_group.create_dataset(
        "0",
        shape=(1, 1, size_z, size_y, size_x),
        chunks=output_params.chunksize,
        dtype=output_params.dtype,
        compressor=output_params.compressor,
        dimension_separator=output_params.dimension_separator,
        overwrite=True,
        fill_value=0,
    )

    return output_volume



def sum_chunk_data(chunk):
    """
    Total all elements of a lazily evaluated chunk.

    Parameters:
    - chunk: An array-like object whose ``sum()`` result exposes ``compute()``
      (for example a dask array).

    Returns:
    - The computed sum of every element in the chunk.
    """
    lazy_total = chunk.sum()
    return lazy_total.compute()

def process_chunk(params):
    """
    Upscale one input window by value repetition and write it to the output.

    Parameters:
    - params: 8-tuple ``(z_idx, y_idx, x_idx, input_zarr, output_zarr,
      upscale_factors_zyx, current_chunk, total_chunks)``.
      ``z_idx``/``y_idx``/``x_idx`` are the start of a window of up to
      128 x 128 x 128 voxels in the input; ``input_zarr`` is a 3-D (zyx) or
      5-D (tczyx, index 0 of t and c is read) array; ``output_zarr`` is a 5-D
      array that receives the result at ``[0, 0, ...]``;
      ``upscale_factors_zyx`` holds the integer repeat count per axis.
      ``current_chunk`` and ``total_chunks`` are progress counters and are
      not used for the computation.

    Raises:
    - ValueError: if ``input_zarr`` is neither 3-D nor 5-D.
    """
    (z_idx, y_idx, x_idx, input_zarr, output_zarr,
     upscale_factors_zyx, _current_chunk, _total_chunks) = params

    block = 128  # edge length of the input window along every spatial axis
    window = (
        slice(z_idx, z_idx + block),
        slice(y_idx, y_idx + block),
        slice(x_idx, x_idx + block),
    )

    ndim = len(input_zarr.shape)
    if ndim == 5:  # tczyx
        chunk = input_zarr[(0, 0) + window]
    elif ndim == 3:  # zyx
        chunk = input_zarr[window]
    else:
        # Raise instead of calling exit(): terminating the interpreter from a
        # helper kills the notebook kernel (or a pool worker) without a traceback.
        raise ValueError(
            f"len(input_zarr.shape) not compatible: {ndim} (expected 3 or 5)"
        )

    # Upscale the chunk by repeating every voxel factor-many times per axis.
    factor_z, factor_y, factor_x = (
        upscale_factors_zyx[0],
        upscale_factors_zyx[1],
        upscale_factors_zyx[2],
    )
    upscaled = chunk
    for axis, factor in enumerate((factor_z, factor_y, factor_x)):
        upscaled = np.repeat(upscaled, factor, axis=axis)

    # Lazy (dask) chunks must be materialised before writing; eager arrays
    # (NumPy, zarr slices) are already in memory.
    if hasattr(upscaled, "compute"):
        upscaled = upscaled.compute()

    # Calculate the indices for placing the upscaled chunk in the output.
    z_new, y_new, x_new = z_idx * factor_z, y_idx * factor_y, x_idx * factor_x
    size_z, size_y, size_x = upscaled.shape
    output_zarr[
        0,
        0,
        z_new:z_new + size_z,
        y_new:y_new + size_y,
        x_new:x_new + size_x,
    ] = upscaled


def upscale_zarr(input_path, output_params, upscale_factors_zyx=(1, 4, 4)):
    """
    Upscale a Zarr volume by integer factors along (z, y, x) using value
    repetition and write the result to a new Zarr array.

    The input is read in windows of 128 x 128 x 128 voxels; each window is
    upscaled and written by ``process_chunk``.

    Parameters:
    - input_path: Path to the input Zarr array (3-D zyx, or 5-D tczyx of which
      index 0 of t and c is read).
    - output_params: OutputParameters describing the output store (path,
      chunk size, dtype, compressor, ...).
    - upscale_factors_zyx: Integer repeat count per spatial axis.

    Raises:
    - ValueError: if the input array is neither 3-D nor 5-D.
    """
    # Open the input Zarr file
    input_zarr = da.from_zarr(input_path)

    ndim = len(input_zarr.shape)
    if ndim not in (3, 5):
        # Fail before any cluster or output store is created.
        raise ValueError(
            f"len(input_zarr.shape) not compatible: {ndim} (expected 3 or 5)"
        )
    z, y, x = input_zarr.shape[-3:]

    # Calculate the shape of the upscaled volume
    new_shape = (
        1,
        1,
        z * upscale_factors_zyx[0],
        y * upscale_factors_zyx[1],
        x * upscale_factors_zyx[2],
    )

    print(f"Upscaling {input_path} from size {input_zarr.shape} by {upscale_factors_zyx} to new shape {new_shape} with {output_params.chunksize} chunk size and dtype: {output_params.dtype} ")

    block = 128  # must match the window size used by process_chunk
    # Ceiling division per axis, kept in integer arithmetic.
    total_chunks = (-(-z // block)) * (-(-y // block)) * (-(-x // block))

    # The context managers shut the workers down when the run finishes or
    # fails; previously every call left 32 worker processes behind.
    with LocalCluster(n_workers=32, threads_per_worker=1, processes=True) as cluster, Client(cluster):
        output_zarr = initialize_output_volume(output_params, new_shape[-3:])

        # Process and upscale each chunk
        current_chunk = 0
        for z_idx in range(0, z, block):
            for y_idx in range(0, y, block):
                for x_idx in range(0, x, block):
                    current_chunk += 1
                    process_chunk(
                        (
                            z_idx,
                            y_idx,
                            x_idx,
                            input_zarr,
                            output_zarr,
                            upscale_factors_zyx,
                            current_chunk,
                            total_chunks,
                        )
                    )

    print("Upscaling completed.")



In [3]:


def upscale_zarr_to_fixed_shape(input_path, output_params, upscale_factors_zyx=(1, 4, 4), fixed_shape= (None, None, None)):
    """
    Upscale a Zarr volume by integer factors along (z, y, x) using value
    repetition and write it into an output array of a chosen shape.

    The input is read in windows of 128 x 128 x 128 voxels. Where the output
    is larger than the upscaled data, the remainder keeps the fill value (0);
    where it is smaller, the upscaled data is cropped.

    Parameters:
    - input_path: Path to the input Zarr array (3-D zyx, or 5-D tczyx of which
      index 0 of t and c is read).
    - output_params: OutputParameters describing the output store (path,
      chunk size, dtype, compressor, ...).
    - upscale_factors_zyx: Integer repeat count per spatial axis.
    - fixed_shape: Desired (z, y, x) size of the output. A ``None`` entry
      means "use the upscaled size for this axis".

    Raises:
    - ValueError: if the input array is neither 3-D nor 5-D, or if
      ``fixed_shape`` does not have exactly three entries.
    """
    # Open the input Zarr file
    input_zarr = da.from_zarr(input_path)

    ndim = len(input_zarr.shape)
    if ndim not in (3, 5):
        # Fail before any cluster or output store is created.
        raise ValueError(
            f"len(input_zarr.shape) not compatible: {ndim} (expected 3 or 5)"
        )
    if len(fixed_shape) != 3:
        raise ValueError(
            f"fixed_shape must have 3 entries (z, y, x), got {len(fixed_shape)}"
        )
    z, y, x = input_zarr.shape[-3:]
    factor_z, factor_y, factor_x = (
        upscale_factors_zyx[0],
        upscale_factors_zyx[1],
        upscale_factors_zyx[2],
    )

    # Resolve the output size per axis. Previously only fixed_shape[0] was
    # checked, so a partially specified shape was ignored or produced None sizes.
    upscaled_zyx = (z * factor_z, y * factor_y, x * factor_x)
    out_z, out_y, out_x = (
        natural if requested is None else int(requested)
        for natural, requested in zip(upscaled_zyx, fixed_shape)
    )
    new_shape = (1, 1, out_z, out_y, out_x)

    print(f"Upscaling {input_path} from size {input_zarr.shape} by {upscale_factors_zyx} to new shape {new_shape} with {output_params.chunksize} chunk size and dtype: {output_params.dtype} ")

    block = 128  # edge length of the input window along every spatial axis
    # Leading index selecting t=0, c=0 for 5-D inputs; empty for 3-D inputs.
    lead = (0, 0) if ndim == 5 else ()

    # The context managers shut the workers down when the run finishes or
    # fails; previously every call left 32 worker processes behind.
    with LocalCluster(n_workers=32, threads_per_worker=1, processes=True) as cluster, Client(cluster):
        output_zarr = initialize_output_volume(output_params, new_shape[-3:])

        # Process and upscale each chunk. Start offsets only grow, so once a
        # window starts beyond the output extent the rest of that axis is skipped.
        for z_idx in range(0, z, block):
            z_new = z_idx * factor_z
            if z_new >= out_z:
                break
            for y_idx in range(0, y, block):
                y_new = y_idx * factor_y
                if y_new >= out_y:
                    break
                for x_idx in range(0, x, block):
                    x_new = x_idx * factor_x
                    if x_new >= out_x:
                        break

                    # Extract the current chunk
                    chunk = input_zarr[
                        lead
                        + (
                            slice(z_idx, z_idx + block),
                            slice(y_idx, y_idx + block),
                            slice(x_idx, x_idx + block),
                        )
                    ]

                    # Upscale the chunk by repeating every voxel per axis.
                    upscaled_chunk = chunk
                    for axis, factor in enumerate((factor_z, factor_y, factor_x)):
                        upscaled_chunk = np.repeat(upscaled_chunk, factor, axis=axis)

                    # Crop to the part that fits inside the output so the
                    # written block and the target slice always agree in shape.
                    size_z = min(upscaled_chunk.shape[0], out_z - z_new)
                    size_y = min(upscaled_chunk.shape[1], out_y - y_new)
                    size_x = min(upscaled_chunk.shape[2], out_x - x_new)

                    output_zarr[
                        0,
                        0,
                        z_new:z_new + size_z,
                        y_new:y_new + size_y,
                        x_new:x_new + size_x,
                    ] = upscaled_chunk[:size_z, :size_y, :size_x].compute()

    print("Upscaling completed.")


In [21]:
# Read the voxel spacing from an explicitly chosen fused channel rather than
# from the loop variable `ch` left over from an earlier cell. The last entry
# of `channels` is the value that loop variable ends with.
reference_channel = channels[-1]
zarr_path = data_loc / dataset_name / 'image_tile_fusing' / 'fused' / f'channel_{reference_channel}.zarr'
full_scale_voxel_spacing_zyx = get_highest_resolution_spacing(zarr_path)
print("full scale voxel spacing (zyx): ", full_scale_voxel_spacing_zyx)

full scale voxel spacing (zyx):  (1.0, 1.0, 1.0)


In [22]:
# Let's try upscaling the segmentation mask to match the fused channel sizes.
# Open the reference fused channel explicitly instead of relying on the
# `dask_array` variable left over from a loop in an earlier cell.
reference_channel_path = data_loc / dataset_name / 'image_tile_fusing' / 'fused' / f'channel_{channels[-1]}.zarr' / '0'
full_res_shape = da.from_zarr(str(reference_channel_path)).shape  # shape of the fused channels

mask_path = data_loc / dataset_name / 'image_tile_fusing'/ 'cell_body_segmentation'/ 'segmentation_mask_orig_res.zarr' / '0'
mask_array = da.from_zarr(str(mask_path))
print("mask shape ", mask_array.shape)
# Spatial (z, y, x) part of the 5-D (t, c, z, y, x) fused shape.
fixed_shape = tuple(full_res_shape[2:5])
print("fixed shape for upscaling: ", fixed_shape)


test_path = '/scratch/output_segmentation_mask_test.zarr'
compressor = Blosc(cname='zstd', clevel=1, shuffle=Blosc.SHUFFLE)
# NOTE(review): the output dtype is fixed to uint16. If the mask stores label
# ids above 65535 they would not survive the write -- confirm against
# mask_array.dtype before using the result.
upload_params = OutputParameters(
    path=test_path,
    chunksize=(1, 1, 128, 128, 128),
    resolution_zyx=full_scale_voxel_spacing_zyx,
    dtype=np.uint16,
    compressor=compressor,
)

upscale_zarr_to_fixed_shape(mask_path, upload_params, upscale_factors_zyx=(1, 4, 4), fixed_shape=fixed_shape)

mask shape  (1, 1, 965, 2313, 2310)
fixed shape for upscaling:  (965, 9252, 9241)
Upscaling /data/HCR_783551_2025-10-08_13-00-00_processed_2025-10-22_17-55-54/image_tile_fusing/cell_body_segmentation/segmentation_mask_orig_res.zarr/0 from size (1, 1, 965, 2313, 2310) by (1, 4, 4) to new shape (1, 1, 965, 9252, 9241) with (1, 1, 128, 128, 128) chunk size and dtype: <class 'numpy.uint16'> 




Upscaling completed.


In [25]:
# Re-open array "0" of the freshly written store to confirm its shape.
upscaled_level0_path = Path(test_path) / '0'
upscaled_masks = da.from_zarr(upscaled_level0_path)
print("upscaled mask shape: ", upscaled_masks.shape)

upscaled mask shape:  (1, 1, 965, 9252, 9241)
