In [1]:
from aind_large_scale_prediction.generator.utils import (
    concatenate_lazy_data, recover_global_position, unpad_global_coords)
from aind_large_scale_prediction.io import ImageReaderFactory
from aind_large_scale_prediction.generator.dataset import create_data_loader
import numpy as np
import logging
from aind_brain_segmentation.model.network import Neuratt
import multiprocessing

prediction_chunksize = (112, 112, 112)
target_size_mb = 2048
n_workers = 0
super_chunksize = None
scale = 3
image_path = 's3://aind-open-data/SmartSPIM_761339_2025-01-10_21-19-15_stitched_2025-01-12_05-39-33/image_tile_fusing/OMEZarr/Ex_488_Em_525.zarr'
checkpoint_path = "/data/smartspim_brain_seg_models/whole_brain_seg/whole_brain_seg/cfelpja3/checkpoints/best_model.ckpt"

device = None

pin_memory = True
if device is not None:
    pin_memory = False
    multiprocessing.set_start_method("spawn", force=True)
    
axis_pad = 8
overlap_prediction_chunksize = (axis_pad, axis_pad, axis_pad)

lazy_data = (
    ImageReaderFactory()
    .create(data_path=str(image_path), parse_path=False, multiscale=scale)
    .as_dask_array()
)

logger = logging.Logger(name="log")

print("Loaded lazy data: ", lazy_data)
batch_size = 1
dtype = np.float32
zarr_data_loader, zarr_dataset = create_data_loader(
    lazy_data=lazy_data,
    target_size_mb=target_size_mb,
    prediction_chunksize=prediction_chunksize,
    overlap_prediction_chunksize=overlap_prediction_chunksize,
    n_workers=n_workers,
    batch_size=batch_size,
    dtype=dtype,  # Allowed data type to process with pytorch cuda
    super_chunksize=super_chunksize,
    lazy_callback_fn=None,  # partial_lazy_deskewing,
    logger=logger,
    device=device,
    pin_memory=pin_memory,
    override_suggested_cpus=False,
    drop_last=True,
    locked_array=False,
)

# Creating model
segmentation_model = Neuratt()

if checkpoint_path:
    print(f"Loading path from {checkpoint_path}")
    segmentation_model = Neuratt.load_from_checkpoint(checkpoint_path)

total_batches = sum(zarr_dataset.internal_slice_sum) / batch_size
print("Total batches: ", total_batches)

  from .autonotebook import tqdm as notebook_tqdm


Loaded lazy data:  dask.array<from-zarr, shape=(1, 1, 464, 1296, 944), dtype=uint16, chunksize=(1, 1, 128, 128, 128), chunktype=numpy.ndarray>
Estimating super chunksize. Provided super chunksize: None - Target MB: 2048
Estimated chunksize to fit in memory 2048 MiB: (560, 1008, 1008)
Adding overlap area to super chunk size: (560, 1008, 1008) - (576, 1024, 1024)
Loading path from /data/smartspim_brain_seg_models/whole_brain_seg/whole_brain_seg/cfelpja3/checkpoints/best_model.ckpt
Total batches:  650.0


In [2]:
lazy_data

Unnamed: 0,Array,Chunk
Bytes,1.06 GiB,4.00 MiB
Shape,"(1, 1, 464, 1296, 944)","(1, 1, 128, 128, 128)"
Dask graph,352 chunks in 2 graph layers,352 chunks in 2 graph layers
Data type,uint16 numpy.ndarray,uint16 numpy.ndarray
"Array Chunk Bytes 1.06 GiB 4.00 MiB Shape (1, 1, 464, 1296, 944) (1, 1, 128, 128, 128) Dask graph 352 chunks in 2 graph layers Data type uint16 numpy.ndarray",1  1  944  1296  464,

Unnamed: 0,Array,Chunk
Bytes,1.06 GiB,4.00 MiB
Shape,"(1, 1, 464, 1296, 944)","(1, 1, 128, 128, 128)"
Dask graph,352 chunks in 2 graph layers,352 chunks in 2 graph layers
Data type,uint16 numpy.ndarray,uint16 numpy.ndarray


In [11]:
# import zarr

# output_seg_path = "/results/intermediate_seg.zarr"

# output_intermediate_seg = zarr.open(
#     output_seg_path,
#     "w",
#     shape=(
#         1,
#         1,
#     )
#     + zarr_dataset.lazy_data.shape[-3:],
#     chunks=(
#         1,
#         1,
#     )
#     + (128, 128, 128),
#     dtype=np.float32,
# )
# shape = zarr_dataset.lazy_data.shape[-3:]

In [3]:
segmentation_model.eval()

Neuratt(
  (encoder_path): EncoderPath(
    (conv_1): ConvolutionalBlock(
      (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv_2): ConvolutionalBlock(
      (0): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (drop_3): Dropout(p=0.2, inplace=False)
    (conv_3): ConvolutionalBlock(
      (0): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=same)
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (drop_4): Dropout(p=0.2, inplace=False)
    (conv_4): ConvolutionalBlock(
      (0): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), paddi

In [6]:
import torch
cuda_device = torch.device(0)

In [7]:
for i, sample in enumerate(zarr_data_loader):
    block = sample.batch_tensor[None, ...].to(cuda_device)
    pred_mask, prob_mask = segmentation_model.predict(
        batch=block,
        threshold=0.5,
    )
    
    print(
        f"Batch [{i} {sample.batch_tensor.shape} - Pinned?: {sample.batch_tensor.is_pinned()} - dtype: {sample.batch_tensor.dtype} - device: {sample.batch_tensor.device} - Pred mask -> {pred_mask.shape}"
    )

Batch [0 torch.Size([1, 120, 120, 120]) - Pinned?: True - dtype: torch.float32 - device: cpu - Pred mask -> torch.Size([1, 1, 120, 120, 120])
Batch [1 torch.Size([1, 120, 120, 128]) - Pinned?: True - dtype: torch.float32 - device: cpu - Pred mask -> torch.Size([1, 1, 120, 120, 128])
Batch [2 torch.Size([1, 120, 120, 128]) - Pinned?: True - dtype: torch.float32 - device: cpu - Pred mask -> torch.Size([1, 1, 120, 120, 128])
Batch [3 torch.Size([1, 120, 120, 128]) - Pinned?: True - dtype: torch.float32 - device: cpu - Pred mask -> torch.Size([1, 1, 120, 120, 128])
Batch [4 torch.Size([1, 120, 120, 128]) - Pinned?: True - dtype: torch.float32 - device: cpu - Pred mask -> torch.Size([1, 1, 120, 120, 128])
Batch [5 torch.Size([1, 120, 120, 128]) - Pinned?: True - dtype: torch.float32 - device: cpu - Pred mask -> torch.Size([1, 1, 120, 120, 128])
Batch [6 torch.Size([1, 120, 120, 128]) - Pinned?: True - dtype: torch.float32 - device: cpu - Pred mask -> torch.Size([1, 1, 120, 120, 128])
Batch 