In [1]:
import SimpleITK as sitk
import numpy as np

# Read original DICOM or original NRRD (LPS)
img = sitk.ReadImage(r"J:\startup\skull_ct.nrrd")

# Get original metadata
size = img.GetSize()
spacing = img.GetSpacing()
origin = np.array(img.GetOrigin())
direction = np.array(img.GetDirection()).reshape(3, 3)

# LPS -> RAS conversion matrix
lps_to_ras = np.diag([-1, -1, 1])

# Apply direction fix
new_direction = lps_to_ras @ direction
img.SetDirection(new_direction.flatten())

# Apply origin fix
new_origin = lps_to_ras @ origin
img.SetOrigin(tuple(new_origin))

# DO NOT flip the image data again
sitk.WriteImage(img, r"J:\startup\skull_ct_RAS00.nrrd")


In [2]:
print(img.GetDirection())
print(img.GetOrigin())


(-0.9999999999999999, 0.0, 0.0, 0.0, -0.9999999999999999, 0.0, 0.0, 0.0, 1.0)
(119.8255078125, 249.2535078125, -823.18)


In [1]:
import os
import time
import numpy as np
import torch
import nrrd
from collections import OrderedDict

# MONAI imports
from monai.bundle import ConfigParser
from monai.data import decollate_batch, list_data_collate
from monai.utils import convert_to_dst_type, MetaKeys
from torch.cuda.amp import autocast
from monai.inferers import SlidingWindowInfererAdapt
from monai.transforms import (
    Compose, CropForegroundd, EnsureTyped, Invertd, KeepLargestConnectedComponentd,
    Lambdad, LoadImaged, NormalizeIntensityd, Resized, ScaleIntensityRanged,
    Spacingd, Orientationd, ConcatItemsd,
)

# --- System Check ---
print("--- System Configuration ---")
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"✅ GPU Detected: {torch.cuda.get_device_name(0)}")
    print(f"   CUDA Version: {torch.version.cuda}")
    print(f"   Device Count: {torch.cuda.device_count()}")
else:
    device = torch.device("cpu")
    print("⚠️  No GPU detected. Running on CPU (Inference will be slower).")
print("----------------------------")

--- System Configuration ---
✅ GPU Detected: NVIDIA GeForce RTX 2050
   CUDA Version: 12.1
   Device Count: 1
----------------------------


In [3]:
# --- Configuration ---
# Using raw strings (r"...") to handle Windows backslashes correctly
MODEL_FILE = r"J:\startup\whole-head-05mm-v1.0.1\model.pt"
IMAGE_FILE = r"J:\startup\skull_ct_RAS00.nrrd"
RESULT_FILE = r"J:\startup\output_seg02.nrrd"

# Optional additional images (set to None if not used)
IMAGE_FILE_2 = None
IMAGE_FILE_3 = None
IMAGE_FILE_4 = None


# --- Helper Functions ---

def logits2pred(logits, sigmoid=False, dim=1):
    if isinstance(logits, (list, tuple)):
        logits = logits[0]

    if sigmoid:
        pred = torch.sigmoid(logits)
        pred = (pred >= 0.5)
    else:
        pred = torch.softmax(logits, dim=dim)
        pred = torch.argmax(pred, dim=dim, keepdim=True).to(dtype=torch.uint8)

    return pred

def _add_normalization_transforms(ts, key, normalize_mode, intensity_bounds):
    if normalize_mode == "none":
        pass
    elif normalize_mode in ["range", "ct"]:
        ts.append(ScaleIntensityRanged(keys=key, a_min=intensity_bounds[0], a_max=intensity_bounds[1],
                                     b_min=-1, b_max=1, clip=False))
        ts.append(Lambdad(keys=key, func=lambda x: torch.sigmoid(x)))
    elif normalize_mode in ["meanstd", "mri"]:
        ts.append(NormalizeIntensityd(keys=key, nonzero=True, channel_wise=True))
    elif normalize_mode in ["meanstdtanh"]:
        ts.append(NormalizeIntensityd(keys=key, nonzero=True, channel_wise=True))
        ts.append(Lambdad(keys=key, func=lambda x: 3 * torch.tanh(x / 3)))
    elif normalize_mode in ["pet"]:
        ts.append(Lambdad(keys=key, func=lambda x: torch.sigmoid((x - x.min()) / x.std())))
    else:
        raise ValueError("Unsupported normalize_mode" + str(normalize_mode))

In [None]:
run_inference_verbose(
    model_file=r"J:\startup\whole-head-05mm-v1.0.1\model.pt",
    image_file=r"J:\startup\skull_ct_RAS00.nrrd",
    result_file=r"J:\startup\output_seg01.nrrd",
)


In [4]:
@torch.no_grad()
def run_inference(model_file, image_file, result_file, save_mode=None, 
                 image_file_2=None, image_file_3=None, image_file_4=None):
    
    start_time = time.time()
    timing_checkpoints = []  # list of (operation, time) tuples

    # --- Checking for model file ---
    if not os.path.exists(model_file):
        raise ValueError('Cannot find model file:' + str(model_file))

    print(f"Loading model from: {model_file}")
    checkpoint = torch.load(model_file, map_location="cpu")

    if 'config' not in checkpoint:
        raise ValueError('Config not found in checkpoint (not a auto3dseg/segresnet model):' + str(model_file))

    config = checkpoint["config"]
    state_dict = checkpoint["state_dict"]

    epoch = checkpoint.get("epoch", 0)
    best_metric = checkpoint.get("best_metric", 0)
    sigmoid = config.get("sigmoid", False)

    model = ConfigParser(config["network"]).get_parsed_content()
    model.load_state_dict(state_dict, strict=True)

    print(f'Model loaded: Epoch {epoch}, Best Metric {best_metric}')

    # Select device based on previous check
    device = torch.device("cpu") if torch.cuda.device_count() == 0 else torch.device(0)
    print(f"Moving model to device: {device}")
    
    model = model.to(device=device, memory_format=torch.channels_last_3d)
    model.eval()

    # --- BRATS Mode Check ---
    if save_mode == 'brats' or 'brats' in model_file:
        print("Mode: BRATS")
        image_files = []
        for index, img in enumerate([image_file, image_file_2, image_file_3, image_file_4]):
            if img is not None:
                image_files.append(img)

        for img in image_files:
            if img is None or not os.path.exists(img):
                raise ValueError(f'Incorrect image filename for {img}: "{img}"')

        ts = [
            LoadImaged(keys="image", ensure_channel_first=True, dtype=None, allow_missing_keys=True, image_only=False),
            EnsureTyped(keys="image", data_type="tensor", dtype=torch.float, allow_missing_keys=True)
        ]

        if config.get("orientation_ras", False):
            print('Using orientation_ras')
            ts.append(Orientationd(keys="image", axcodes="RAS"))
        
        if config.get("crop_foreground", True):
            print('Using crop_foreground')
            ts.append(CropForegroundd(keys="image", source_key="image", margin=10, allow_smaller=True))

        if config.get("resample_resolution", None) is not None:
            pixdim = list(config["resample_resolution"])
            print(f'Using resample with resample_resolution {pixdim}')
            ts.append(
                Spacingd(
                    keys=["image"], pixdim=list(pixdim), mode=["bilinear"], dtype=torch.float,
                    min_pixdim=np.array(pixdim) * 0.75, max_pixdim=np.array(pixdim) * 1.25, allow_missing_keys=True,
                )
            )

        main_normalize_mode = config["normalize_mode"]
        intensity_bounds = config["intensity_bounds"]
        _add_normalization_transforms(ts, 'image', main_normalize_mode, intensity_bounds)

        inf_transform = Compose(ts)
        roi_size = config["roi_size"]
        sliding_inferrer = SlidingWindowInfererAdapt(roi_size=roi_size, sw_batch_size=1, overlap=0.625, mode="gaussian",
                                                     cache_roi_weight_map=False, progress=True)

        batch_data = inf_transform([{"image": image_files}])
        original_affine = batch_data[0]['image'].meta[MetaKeys.ORIGINAL_AFFINE]
        batch_data = list_data_collate([batch_data])
        data = batch_data["image"].as_subclass(torch.Tensor).to(memory_format=torch.channels_last_3d, device=device)
        timing_checkpoints.append(("Preprocessing", time.time()))

        print('Running Inference ...')
        with autocast(enabled=True):
            logits = sliding_inferrer(inputs=data, network=model)
        timing_checkpoints.append(("Inference", time.time()))

        print(f"Logits shape: {logits.shape}")
        
        try:
            pred = logits2pred(logits, sigmoid=sigmoid)
        except RuntimeError as e:
            if not logits.is_cuda:
                raise e
            print(f"logits2pred failed on GPU, retrying on CPU. Shape: {logits.shape}")
            logits = logits.cpu()
            pred = logits2pred(logits, sigmoid=sigmoid)
        
        print(f"Preds shape: {pred.shape}")
        timing_checkpoints.append(("Logits", time.time()))
        logits = None

        post_transforms = Compose([Invertd(keys="pred", orig_keys="image", transform=inf_transform, nearest_interp=True)])

        batch_data["pred"] = convert_to_dst_type(pred, batch_data["image"], dtype=pred.dtype, device=pred.device)[0]
        pred = [post_transforms(x)["pred"] for x in decollate_batch(batch_data)]
        seg = pred[0]
        print(f"Preds inverted shape: {seg.shape}")
        timing_checkpoints.append(("Preds", time.time()))

        # Merge BRATS channels
        p2 = 2 * seg.any(0).to(dtype=torch.uint8)
        p2[seg[1:].any(0)] = 1
        p2[seg[2:].any(0)] = 3
        seg = p2
        print(f"Updated seg for BRATS: {seg.shape}")

    # --- Standard Mode (Non-BRATS) ---
    else:
        print("Mode: Standard")
        image_files = {}
        for index, img in enumerate([image_file, image_file_2, image_file_3, image_file_4]):
            if img is not None:
                image_files[f"image{index + 1}"] = img

        keys = list(image_files.keys())
        for img in image_files.keys():
            if image_files[img] is None or not os.path.exists(image_files[img]):
                raise ValueError(f'Incorrect image filename for {img}: "{image_files[img]}"')

        loader = LoadImaged(keys=keys, ensure_channel_first=True, dtype=None, allow_missing_keys=True, image_only=False)
        images_loaded = loader(image_files)
        timing_checkpoints.append(("Loading volumes", time.time()))

        if len(keys) > 1:
            image1_shape = images_loaded[keys[0]].shape[1:]
            for idx, img in enumerate(keys[1:]):
                temp_shape = images_loaded[img].shape[-len(image1_shape):]
                if np.any(np.not_equal(image1_shape, temp_shape)):
                    print(f'Volumes do not have the same size - Resizing volume {img}')
                    resizer = Resized(keys=img, spatial_size=image1_shape, mode='bilinear')
                    images_loaded = resizer(images_loaded)
                    timing_checkpoints.append((f"Resizing volume {img}", time.time()))

        main_normalize_mode = config["normalize_mode"]
        intensity_bounds = config["intensity_bounds"]
        
        if len(keys) == 1:
            ts = [
                ConcatItemsd(keys=keys, name="image", dim=0),
                EnsureTyped(keys="image", data_type="tensor", dtype=torch.float, allow_missing_keys=True)
            ]
            _add_normalization_transforms(ts, "image", main_normalize_mode, intensity_bounds)
        else:
            ts = []
            extra_modalities = OrderedDict(config['extra_modalities'])
            normalize_modes = [main_normalize_mode] + list(extra_modalities.values())
            for key, normalize_mode in zip(keys, normalize_modes):
                _add_normalization_transforms(ts, key, normalize_mode, intensity_bounds)
            ts.extend([
                ConcatItemsd(keys=keys, name="image", dim=0),
                EnsureTyped(keys="image", data_type="tensor", dtype=torch.float, allow_missing_keys=True)
            ])

        if config.get("orientation_ras", False):
            print('Using orientation_ras')
            ts.append(Orientationd(keys="image", axcodes="RAS"))
        
        if config.get("crop_foreground", True):
            print('Using crop_foreground')
            ts.append(CropForegroundd(keys="image", source_key="image1", margin=10, allow_smaller=True))

        if config.get("resample_resolution", None) is not None:
            pixdim = list(config["resample_resolution"])
            print(f'Using resample with resample_resolution {pixdim}')
            ts.append(
                Spacingd(
                    keys=["image"], pixdim=list(pixdim), mode=["bilinear"], dtype=torch.float,
                    min_pixdim=np.array(pixdim) * 0.75, max_pixdim=np.array(pixdim) * 1.25, allow_missing_keys=True,
                )
            )

        inf_transform = Compose(ts)
        roi_size = config["roi_size"]
        sliding_inferrer = SlidingWindowInfererAdapt(roi_size=roi_size, sw_batch_size=1, overlap=0.625, mode="gaussian",
                                                     cache_roi_weight_map=False, progress=True)

        batch_data = inf_transform([images_loaded])
        original_affine = batch_data[0]['image'].meta[MetaKeys.ORIGINAL_AFFINE]
        batch_data = list_data_collate([batch_data])
        data = batch_data["image"].as_subclass(torch.Tensor).to(memory_format=torch.channels_last_3d, device=device)
        timing_checkpoints.append(("Preprocessing", time.time()))

        print('Running Inference ...')
        with autocast(enabled=True):
            logits = sliding_inferrer(inputs=data, network=model)
        timing_checkpoints.append(("Inference", time.time()))

        print(f"Logits shape: {logits.shape}")
        try:
            pred = logits2pred(logits, sigmoid=sigmoid)
        except RuntimeError as e:
            if not logits.is_cuda:
                raise e
            print(f"logits2pred failed on GPU, retrying on CPU. Shape: {logits.shape}")
            logits = logits.cpu()
            pred = logits2pred(logits, sigmoid=sigmoid)
        
        print(f"Preds shape: {pred.shape}")
        timing_checkpoints.append(("Logits", time.time()))
        logits = None

        post_transforms_list = [Invertd(keys="pred", orig_keys="image", transform=inf_transform, nearest_interp=True)]
        if 'whole-head' in model_file:
            print("Applying KeepLargestConnectedComponentd (whole-head model detected)")
            post_transforms_list.append(KeepLargestConnectedComponentd(keys="pred", num_components=2)) 
        
        post_transforms = Compose(post_transforms_list)

        batch_data["pred"] = convert_to_dst_type(pred, batch_data["image"], dtype=pred.dtype, device=pred.device)[0]
        pred = [post_transforms(x)["pred"] for x in decollate_batch(batch_data)]
        seg = pred[0][0]

    print(f"Preds inverted shape: {seg.shape}")
    timing_checkpoints.append(("Preds", time.time()))

    seg = seg.cpu().numpy().astype(np.uint8)
    timing_checkpoints.append(("Convert to array", time.time()))

    # Save result
    nrrd_header = nrrd.read_header(image_file)
    nrrd.write(result_file, seg, nrrd_header)
    timing_checkpoints.append(("Save", time.time()))

    print("\n--- Computation Time Log ---")
    previous_start_time = start_time
    for timing_checkpoint in timing_checkpoints:
        print(f"  {timing_checkpoint[0]:<20}: {timing_checkpoint[1] - previous_start_time:.2f} seconds")
        previous_start_time = timing_checkpoint[1]

    print(f'\nALL DONE. Result saved in: {result_file}')


# Execute
run_inference(
    model_file=MODEL_FILE,
    image_file=IMAGE_FILE,
    result_file=RESULT_FILE,
    image_file_2=IMAGE_FILE_2,
    image_file_3=IMAGE_FILE_3,
    image_file_4=IMAGE_FILE_4
)

Loading model from: J:\startup\whole-head-05mm-v1.0.1\model.pt


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
`apex.normalization.InstanceNorm3dNVFuser` is not installed properly, use nn.InstanceNorm3d instead.


Model loaded: Epoch 450, Best Metric 0.9060699939727783
Moving model to device: cuda:0
Mode: Standard
Using crop_foreground
Using resample with resample_resolution [0.45703124999999994, 0.459, 0.458015625]
Running Inference ...


`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
100%|██████████| 125/125 [13:22<00:00,  6.42s/it]


Logits shape: torch.Size([1, 10, 512, 496, 338])
Preds shape: torch.Size([1, 1, 512, 496, 338])
Applying KeepLargestConnectedComponentd (whole-head model detected)
Preds inverted shape: torch.Size([512, 512, 258])

--- Computation Time Log ---
  Loading volumes     : 5.63 seconds
  Preprocessing       : 7.21 seconds
  Inference           : 802.57 seconds
  Logits              : 26.98 seconds
  Preds               : 31.02 seconds
  Convert to array    : 0.91 seconds
  Save                : 2.60 seconds

ALL DONE. Result saved in: J:\startup\output_seg02.nrrd


chatgpt code below


In [5]:
import os
import time
import numpy as np
import torch
import nrrd

from collections import OrderedDict
from torch.cuda.amp import autocast

from monai.bundle import ConfigParser
from monai.data import decollate_batch, list_data_collate
from monai.inferers import SlidingWindowInfererAdapt
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureTyped,
    ConcatItemsd,
    Orientationd,
    CropForegroundd,
    Spacingd,
    Invertd,
    KeepLargestConnectedComponentd,
)
from monai.utils import MetaKeys, convert_to_dst_type

# ------------------------------------------------------------
# Helper: normalization (same logic used in Auto3DSeg)
# ------------------------------------------------------------
def _add_normalization_transforms(ts, key, normalize_mode, intensity_bounds):
    if normalize_mode == "range":
        from monai.transforms import ScaleIntensityRanged
        ts.append(
            ScaleIntensityRanged(
                keys=key,
                a_min=intensity_bounds[0],
                a_max=intensity_bounds[1],
                b_min=0.0,
                b_max=1.0,
                clip=True,
            )
        )
    elif normalize_mode == "meanstd":
        from monai.transforms import NormalizeIntensityd
        ts.append(NormalizeIntensityd(keys=key, nonzero=True, channel_wise=True))


# ------------------------------------------------------------
# Main inference
# ------------------------------------------------------------
@torch.no_grad()
def run_inference_verbose(
    model_file,
    image_file,
    result_file,
    save_mode=None,
    image_file_2=None,
    image_file_3=None,
    image_file_4=None,
):
    t0 = time.time()

    # ------------------ DEVICE ------------------
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    print(f"[DEVICE] {device}")

    # ------------------ LOAD MODEL ------------------
    if not os.path.exists(model_file):
        raise FileNotFoundError(model_file)

    checkpoint = torch.load(model_file, map_location="cpu")
    config = checkpoint["config"]
    state_dict = checkpoint["state_dict"]

    model = ConfigParser(config["network"]).get_parsed_content()
    model.load_state_dict(state_dict, strict=True)
    model.to(device=device, memory_format=torch.channels_last_3d)
    model.eval()

    print(f"[MODEL] Loaded (epoch={checkpoint.get('epoch', '?')}, "
          f"best_metric={checkpoint.get('best_metric', '?')})")

    sigmoid = config.get("sigmoid", False)

    # ------------------ MODE ------------------
    is_brats = save_mode == "brats" or (
        save_mode is None and "brats" in model_file.lower()
    )
    print(f"[MODE] {'BRATS' if is_brats else 'STANDARD'}")

    # ------------------ INPUT FILES ------------------
    input_files = [image_file, image_file_2, image_file_3, image_file_4]
    input_files = [f for f in input_files if f is not None]

    for f in input_files:
        if not os.path.exists(f):
            raise FileNotFoundError(f)

    # ------------------ LOAD ------------------
    if is_brats:
        data_dict = {"image": input_files}
        keys = ["image"]
    else:
        data_dict = {f"image{i+1}": f for i, f in enumerate(input_files)}
        keys = list(data_dict.keys())

    loader = LoadImaged(keys=keys, ensure_channel_first=True, image_only=False)
    loaded = loader(data_dict)

    print("[DATA] Loaded volumes:")
    for k in keys:
        print(f"  {k}: {loaded[k].shape}")

    # ------------------ TRANSFORMS ------------------
    ts = []

    if not is_brats:
        ts.append(ConcatItemsd(keys=keys, name="image", dim=0))

    ts.append(EnsureTyped(keys="image", dtype=torch.float))

    _add_normalization_transforms(
        ts,
        "image",
        config["normalize_mode"],
        config["intensity_bounds"],
    )

    if config.get("orientation_ras", False):
        ts.append(Orientationd(keys="image", axcodes="RAS"))

    if config.get("crop_foreground", True):
        ts.append(CropForegroundd(keys="image", source_key="image"))

    if config.get("resample_resolution") is not None:
        pixdim = list(config["resample_resolution"])
        ts.append(
            Spacingd(
                keys="image",
                pixdim=pixdim,
                mode="bilinear",
            )
        )

    inf_transform = Compose(ts)

    batch = inf_transform([loaded])
    meta = batch[0]["image"].meta
    batch = list_data_collate(batch)

    data = batch["image"].to(device, memory_format=torch.channels_last_3d)

    print(f"[DATA] Final tensor: {tuple(data.shape)}")

    # ------------------ INFERENCE ------------------
    roi_size = config["roi_size"]
    inferer = SlidingWindowInfererAdapt(
        roi_size=roi_size,
        sw_batch_size=1,
        overlap=0.625,
        mode="gaussian",
        progress=True,
    )

    print("[RUN] Inference started")
    with autocast(enabled=use_cuda):
        logits = inferer(data, model)

    print(f"[RUN] Logits shape: {tuple(logits.shape)}")

    # ------------------ PRED ------------------
    from monai.networks.utils import one_hot
    if sigmoid:
        pred = (logits.sigmoid() > 0.5).float()
    else:
        pred = torch.argmax(logits, dim=1, keepdim=True)

    logits = None

    # ------------------ INVERT ------------------
    post = [Invertd(keys="pred", orig_keys="image", transform=inf_transform)]
    if "whole-head" in model_file.lower():
        post.append(KeepLargestConnectedComponentd(keys="pred"))

    post = Compose(post)

    batch["pred"] = convert_to_dst_type(
        pred, batch["image"], device=pred.device
    )[0]

    seg = decollate_batch(batch)[0]["pred"]

    print(f"[POST] Segmentation shape: {tuple(seg.shape)}")

    # ------------------ SAVE NRRD (CORRECT METADATA) ------------------
    seg_np = seg.cpu().numpy().astype(np.uint8)

    header = {
        "space": meta.get("space"),
        "space directions": meta.get("space_directions"),
        "space origin": meta.get("space_origin"),
    }

    nrrd.write(result_file, seg_np, header)
    print(f"[SAVE] {result_file}")

    print(f"[DONE] Total time: {time.time() - t0:.2f}s")


In [6]:
run_inference_verbose(
    model_file=r"J:\startup\whole-head-05mm-v1.0.1\model.pt",
    image_file=r"J:\startup\skull_ct_RAS00.nrrd",
    result_file=r"J:\startup\output_seg03.nrrd",
    
)


[DEVICE] cuda:0


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


[MODEL] Loaded (epoch=450, best_metric=0.9060699939727783)
[MODE] STANDARD
[DATA] Loaded volumes:
  image1: torch.Size([1, 512, 512, 258])


`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.


[DATA] Final tensor: (1, 1, 525, 498, 338)
[RUN] Inference started


100%|██████████| 125/125 [38:44<00:00, 18.60s/it]


[RUN] Logits shape: (1, 10, 525, 498, 338)
[POST] Segmentation shape: (1, 525, 498, 338)


TypeError: iteration over a 0-d array