In [9]:
from aind_large_scale_prediction.generator.utils import (
    concatenate_lazy_data, recover_global_position, unpad_global_coords)
from aind_large_scale_prediction.io import ImageReaderFactory
from aind_large_scale_prediction.generator.dataset import create_data_loader
import numpy as np
import logging
from aind_brain_segmentation.model.network import Neuratt
import multiprocessing

prediction_chunksize = (112, 112, 112)
target_size_mb = 2048
n_workers = 0
super_chunksize = None
scale = 3
image_path = 's3://aind-open-data/SmartSPIM_761339_2025-01-10_21-19-15_stitched_2025-01-12_05-39-33/image_tile_fusing/OMEZarr/Ex_488_Em_525.zarr'
checkpoint_path = "/data/smartspim_brain_seg_models/whole_brain_seg/whole_brain_seg/cfelpja3/checkpoints/best_model.ckpt"

device = None

pin_memory = True
if device is not None:
    pin_memory = False
    multiprocessing.set_start_method("spawn", force=True)
    
axis_pad = 8
overlap_prediction_chunksize = (axis_pad, axis_pad, axis_pad)

lazy_data = (
    ImageReaderFactory()
    .create(data_path=str(image_path), parse_path=False, multiscale=scale)
    .as_dask_array()
)

logger = logging.Logger(name="log")

print("Loaded lazy data: ", lazy_data)
batch_size = 1
dtype = np.float32
zarr_data_loader, zarr_dataset = create_data_loader(
    lazy_data=lazy_data,
    target_size_mb=target_size_mb,
    prediction_chunksize=prediction_chunksize,
    overlap_prediction_chunksize=overlap_prediction_chunksize,
    n_workers=n_workers,
    batch_size=batch_size,
    dtype=dtype,  # Allowed data type to process with pytorch cuda
    super_chunksize=super_chunksize,
    lazy_callback_fn=None,  # partial_lazy_deskewing,
    logger=logger,
    device=device,
    pin_memory=pin_memory,
    override_suggested_cpus=False,
    drop_last=True,
    locked_array=False,
)

# Creating model
segmentation_model = Neuratt()

if checkpoint_path:
    print(f"Loading path from {checkpoint_path}")
    segmentation_model = Neuratt.load_from_checkpoint(checkpoint_path)

total_batches = sum(zarr_dataset.internal_slice_sum) / batch_size
print("Total batches: ", total_batches)

Loaded lazy data:  dask.array<from-zarr, shape=(1, 1, 464, 1296, 944), dtype=uint16, chunksize=(1, 1, 128, 128, 128), chunktype=numpy.ndarray>
Estimating super chunksize. Provided super chunksize: None - Target MB: 2048
Estimated chunksize to fit in memory 2048 MiB: (560, 1008, 1008)
Adding overlap area to super chunk size: (560, 1008, 1008) - (576, 1024, 1024)
Loading path from /data/smartspim_brain_seg_models/whole_brain_seg/whole_brain_seg/cfelpja3/checkpoints/best_model.ckpt
Total batches:  650.0


In [10]:
lazy_data

Unnamed: 0,Array,Chunk
Bytes,1.06 GiB,4.00 MiB
Shape,"(1, 1, 464, 1296, 944)","(1, 1, 128, 128, 128)"
Dask graph,352 chunks in 2 graph layers,352 chunks in 2 graph layers
Data type,uint16 numpy.ndarray,uint16 numpy.ndarray
"Array Chunk Bytes 1.06 GiB 4.00 MiB Shape (1, 1, 464, 1296, 944) (1, 1, 128, 128, 128) Dask graph 352 chunks in 2 graph layers Data type uint16 numpy.ndarray",1  1  944  1296  464,

Unnamed: 0,Array,Chunk
Bytes,1.06 GiB,4.00 MiB
Shape,"(1, 1, 464, 1296, 944)","(1, 1, 128, 128, 128)"
Dask graph,352 chunks in 2 graph layers,352 chunks in 2 graph layers
Data type,uint16 numpy.ndarray,uint16 numpy.ndarray


In [11]:
# chunk_indices = zarr_dataset.lazy_data.shape[-3:] // np.array(prediction_chunksize)
# output_shape = tuple(
#     n * c - (n - 1) * o
#     for n, c, o in zip(chunk_indices, prediction_chunksize, overlap_prediction_chunksize)
# )

# chunk_indices_grid = np.array(
#     np.meshgrid(
#         np.arange(chunk_indices[0]),
#         np.arange(chunk_indices[1]),
#         np.arange(chunk_indices[2]),
#         indexing='ij'
#     )
# ).reshape(3, -1).T

# print(chunk_indices, output_shape, chunk_indices_grid)

In [12]:
# global_start = np.array([ix, iy, iz]) * (np.array(chunk_size) - np.array(overlap_size))
# global_end = global_start + valid_end - valid_start

In [13]:
import zarr

output_seg_path = "/results/intermediate_seg.zarr"

output_intermediate_seg = zarr.open(
    output_seg_path,
    "w",
    shape=(
        1,
        1,
    )
    + zarr_dataset.lazy_data.shape[-3:],
    chunks=(
        1,
        1,
    )
    + (128, 128, 128),
    dtype=np.float32,
)
shape = zarr_dataset.lazy_data.shape[-3:]
print(shape)

(512, 1408, 1024)


In [14]:
segmentation_model.eval()

Neuratt(
  (encoder_path): EncoderPath(
    (conv_1): ConvolutionalBlock(
      (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv_2): ConvolutionalBlock(
      (0): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (drop_3): Dropout(p=0.2, inplace=False)
    (conv_3): ConvolutionalBlock(
      (0): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (drop_4): Dropout(p=0.2, inplace=False)
    (conv_4): ConvolutionalBlock(
      (0): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), paddi

In [15]:
import torch
cuda_device = torch.device(0)

In [16]:
for i, sample in enumerate(zarr_data_loader):
    block = sample.batch_tensor[None, ...].to(cuda_device)
    pred_mask, prob_mask = segmentation_model.predict(
        batch=block,
        threshold=0.5,
    )

    (
        global_coord_pos,
        global_coord_positions_start,
        global_coord_positions_end,
    ) = recover_global_position(
        super_chunk_slice=sample.batch_super_chunk[0],
        internal_slices=sample.batch_internal_slice,
    )

    unpadded_global_slice, unpadded_local_slice = unpad_global_coords(
        global_coord_pos=global_coord_pos[-3:],
        block_shape=block.shape[-3:],
        overlap_prediction_chunksize=overlap_prediction_chunksize[-3:],
        dataset_shape=zarr_dataset.lazy_data.shape[-3:],  # zarr_dataset.lazy_data.shape,
    )
    pred_mask = torch.squeeze(pred_mask)
    pred_mask = pred_mask.detach().cpu()

    # print(unpadded_global_slice, unpadded_local_slice, output_intermediate_seg[unpadded_global_slice].shape, pred_mask[unpadded_local_slice].shape)
    # s
    # print(output_intermediate_seg[unpadded_global_slice].shape, unpadded_global_slice)
    unpadded_global_slice = (slice(0, 1), slice(0, 1), ) + unpadded_global_slice
    output_intermediate_seg[unpadded_global_slice] = pred_mask[unpadded_local_slice][None, None, ...]
     
    print(
        f"Tensor shape: {sample.batch_tensor.shape} - Pred mask -> {pred_mask.shape} - unpadded_global_slice: {unpadded_global_slice}"
    )

Tensor shape: torch.Size([1, 120, 120, 120]) - Pred mask -> torch.Size([120, 120, 120]) - unpadded_global_slice: (slice(0, 1, None), slice(0, 1, None), slice(0, 112, None), slice(0, 112, None), slice(0, 112, None))
Tensor shape: torch.Size([1, 120, 120, 128]) - Pred mask -> torch.Size([120, 120, 128]) - unpadded_global_slice: (slice(0, 1, None), slice(0, 1, None), slice(0, 112, None), slice(0, 112, None), slice(112, 224, None))
Tensor shape: torch.Size([1, 120, 120, 128]) - Pred mask -> torch.Size([120, 120, 128]) - unpadded_global_slice: (slice(0, 1, None), slice(0, 1, None), slice(0, 112, None), slice(0, 112, None), slice(224, 336, None))
Tensor shape: torch.Size([1, 120, 120, 128]) - Pred mask -> torch.Size([120, 120, 128]) - unpadded_global_slice: (slice(0, 1, None), slice(0, 1, None), slice(0, 112, None), slice(0, 112, None), slice(336, 448, None))
Tensor shape: torch.Size([1, 120, 120, 128]) - Pred mask -> torch.Size([120, 120, 128]) - unpadded_global_slice: (slice(0, 1, None), s

In [28]:
# for i, sample in enumerate(zarr_data_loader):
#     block = sample.batch_tensor[None, ...].to(cuda_device)
#     pred_mask, prob_mask = segmentation_model.predict(
#         batch=block,
#         threshold=0.5,
#     )

#     (
#         global_coord_pos,
#         global_coord_positions_start,
#         global_coord_positions_end,
#     ) = recover_global_position(
#         super_chunk_slice=sample.batch_super_chunk[0],
#         internal_slices=sample.batch_internal_slice,
#     )

#     global_coord_positions_start = np.array(global_coord_positions_start[0])
#     global_coord_positions_end = np.array(global_coord_positions_end[0])

#     # start_condition = np.where(
#     #     (global_coord_positions_start > 0) &
#     #     (global_coord_positions_start != shape)
#     # )[0]

#     # global_coord_positions_start[start_condition] += (axis_pad * (2**len(start_condition)))

#     # end_condition = np.where(
#     #     (global_coord_positions_end > 0) &
#     #     (global_coord_positions_end != shape)
#     # )[0]

#     # global_coord_positions_end[start_condition] += (axis_pad * (2**len(start_condition)))

#     start_condition = np.where(
#         (global_coord_positions_start == 0)
#     )[0]

#     if len(start_condition) == len(shape):
#         new_global_stop = global_coord_positions_end

#     else:
        
        

#     new_global_coords = []
#     for i in range(len(global_coord_positions_start)):
#         new_global_coords.append(
#             slice(
#                 global_coord_positions_start[i],
#                 global_coord_positions_end[i],
#             )
#         )

#     new_global_coords = (slice(0, 1), slice(0, 1), ) + tuple(new_global_coords)

#     # output_intermediate_seg[new_global_coords] = pred_mask
    
#     # print(global_coord_positions_start, global_coord_positions_end, global_coord_pos, new_global_coords, shape)

#     # Pinned?: {sample.batch_tensor.is_pinned()} - dtype: {sample.batch_tensor.dtype}
    
#     print(
#         f"Tensor shape: {sample.batch_tensor.shape} - Pred mask -> {pred_mask.shape} Global coords: {global_coord_pos} - new global {new_global_coords}"
#     )

Tensor shape: torch.Size([1, 120, 120, 120]) - Pred mask -> torch.Size([1, 1, 120, 120, 120]) Global coords: (slice(0, 120, None), slice(0, 120, None), slice(0, 120, None)) - new global (slice(0, 1, None), slice(0, 1, None), slice(0, 120, None), slice(0, 120, None), slice(0, 120, None))
Tensor shape: torch.Size([1, 120, 120, 128]) - Pred mask -> torch.Size([1, 1, 120, 120, 128]) Global coords: (slice(0, 120, None), slice(0, 120, None), slice(104, 232, None)) - new global (slice(0, 1, None), slice(0, 1, None), slice(0, 120, None), slice(0, 120, None), slice(120, 248, None))
Tensor shape: torch.Size([1, 120, 120, 128]) - Pred mask -> torch.Size([1, 1, 120, 120, 128]) Global coords: (slice(0, 120, None), slice(0, 120, None), slice(216, 344, None)) - new global (slice(0, 1, None), slice(0, 1, None), slice(0, 120, None), slice(0, 120, None), slice(232, 360, None))
Tensor shape: torch.Size([1, 120, 120, 128]) - Pred mask -> torch.Size([1, 1, 120, 120, 128]) Global coords: (slice(0, 120, None

In [20]:
position = np.array([120, 120, 120])
shape = np.array([512, 1408, 1024])

condition = np.where( (position > 0) & (position != shape) )[0]
position[condition]

array([120, 120])