Merged
26 changes: 25 additions & 1 deletion connectomics/data/io/io.py
@@ -281,7 +281,31 @@ def read_volume(
    if image_suffix in ["h5", "hdf5"]:
        data = read_hdf5(filename, dataset)
    elif "tif" in image_suffix:
        data = imageio.volread(filename).squeeze()
        # Check if the filename contains glob patterns
        if "*" in filename or "?" in filename:
            # Expand the glob pattern to get matching files
            file_list = sorted(glob.glob(filename))
            if len(file_list) == 0:
                raise FileNotFoundError(f"No TIFF files found matching pattern: {filename}")

            # Read each file and stack along the depth dimension
            volumes = []
            for filepath in file_list:
                vol = imageio.volread(filepath).squeeze()
                # imageio.volread returns a multi-page TIFF as (D, H, W) and a single page as (H, W);
                # ensure every volume is at least 3D (D, H, W)
                if vol.ndim == 2:
                    vol = vol[np.newaxis, ...]  # Add depth dimension: (H, W) -> (1, H, W)
                # vol.ndim == 3 means (D, H, W), which is what we want
                volumes.append(vol)

            # Stack all volumes along the depth (first) dimension:
            # each volume is (D_i, H, W), so the result is (sum(D_i), H, W)
            data = np.concatenate(volumes, axis=0)
        else:
            # Single file or multi-page TIFF
            data = imageio.volread(filename).squeeze()

    if data.ndim == 4:
        # Convert (D, C, H, W) to (C, D, H, W) order
        data = data.transpose(1, 0, 2, 3)
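
With this change, read_volume accepts either a literal TIFF path or a glob pattern. A minimal standalone sketch of the stacking behavior (the helper name read_tiff_glob is illustrative, not part of the diff; the imageio, glob, and numpy usage mirrors the code above):

import glob

import imageio
import numpy as np


def read_tiff_glob(pattern: str) -> np.ndarray:
    """Stack all TIFF files matching `pattern` along the depth axis."""
    file_list = sorted(glob.glob(pattern))
    if not file_list:
        raise FileNotFoundError(f"No TIFF files found matching pattern: {pattern}")
    volumes = []
    for filepath in file_list:
        vol = imageio.volread(filepath).squeeze()
        if vol.ndim == 2:  # single page: (H, W) -> (1, H, W)
            vol = vol[np.newaxis, ...]
        volumes.append(vol)  # multi-page files are already (D, H, W)
    # All files must share the same (H, W) for the concatenation to succeed
    return np.concatenate(volumes, axis=0)  # (sum(D_i), H, W)


# e.g. read_tiff_glob("slices/*.tif") stacks slices/000.tif, slices/001.tif, ...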
192 changes: 160 additions & 32 deletions connectomics/lightning/lit_model.py
@@ -521,38 +521,66 @@ def _apply_postprocessing(self, data: np.ndarray) -> np.ndarray:
        from connectomics.decoding.postprocess import apply_binary_postprocessing

        # Process each sample in batch
        batch_size = data.shape[0] if data.ndim >= 4 else 1

        # Handle different input shapes
        if data.ndim == 2:  # (H, W) -> (1, 1, H, W)
            data = data[np.newaxis, np.newaxis, ...]
        elif data.ndim == 3:  # (D, H, W) or (C, H, W) -> assume (D, H, W) and add batch dim
            data = data[np.newaxis, ...]  # (1, D, H, W)
        # Handle both 4D (B, C, H, W) for 2D data and 5D (B, C, D, H, W) for 3D data
        print(f" DEBUG: _apply_postprocessing - input data shape: {data.shape}, ndim: {data.ndim}")
        if data.ndim == 4:
            # 2D data: (B, C, H, W)
            batch_size = data.shape[0]
            print(f" DEBUG: _apply_postprocessing - detected 2D data, batch_size: {batch_size}")
        elif data.ndim == 5:
            # 3D data: (B, C, D, H, W)
            batch_size = data.shape[0]
            print(f" DEBUG: _apply_postprocessing - detected 3D data, batch_size: {batch_size}")
        elif data.ndim == 3:
            # Single 3D volume: (C, D, H, W) or (D, H, W) - add batch dimension
            batch_size = 1
            data = data[np.newaxis, ...]  # (1, C, D, H, W) or (1, D, H, W)
            print(f" DEBUG: _apply_postprocessing - single 3D sample, added batch dimension")
        elif data.ndim == 2:
            # Single 2D image: (H, W) - add batch and channel dimensions
            batch_size = 1
            data = data[np.newaxis, np.newaxis, ...]  # (1, 1, H, W)
            print(f" DEBUG: _apply_postprocessing - single 2D sample, added batch and channel dimensions")
        else:
            batch_size = 1

        # Ensure we have at least 4D: (B, ...) where ... can be (D, H, W) or (C, D, H, W)
        # Ensure we have at least 4D: (B, ...) where ... can be (C, H, W) for 2D or (C, D, H, W) for 3D
        results = []
        for batch_idx in range(batch_size):
            sample = data[batch_idx]  # (C, D, H, W) or (D, H, W)

            # Extract foreground probability (handle both 3D and 4D)
            if sample.ndim == 4:  # (C, D, H, W)
                foreground_prob = sample[0]  # Use first channel
            else:  # (D, H, W) - already single channel
            sample = data[batch_idx]  # (C, H, W) for 2D or (C, D, H, W) for 3D
            print(f" DEBUG: _apply_postprocessing - processing batch_idx {batch_idx}, sample shape: {sample.shape}")

            # Extract foreground probability (always use first channel if channel dimension exists)
            if sample.ndim == 4:  # (C, D, H, W) - 3D with channel
                foreground_prob = sample[0]  # Use first channel -> (D, H, W)
            elif sample.ndim == 3:
                # Could be (C, H, W) for 2D or (D, H, W) for 3D without channel
                # If first dim is small (<=4), assume it's channel (2D), otherwise depth (3D)
                if sample.shape[0] <= 4:
                    foreground_prob = sample[0]  # (C, H, W) -> use first channel -> (H, W)
                else:
                    foreground_prob = sample  # (D, H, W) - already single channel
            elif sample.ndim == 2:  # (H, W) - 2D single channel
                foreground_prob = sample
            else:
                foreground_prob = sample

            # Apply binary postprocessing
            processed = apply_binary_postprocessing(foreground_prob, binary_config)

            # Expand dims to maintain shape consistency
            if sample.ndim == 4:
            # Expand dims to maintain shape consistency with original sample structure
            if sample.ndim == 4:  # (C, D, H, W) -> processed is (D, H, W)
                processed = processed[np.newaxis, ...]  # (1, D, H, W)
            else:
                processed = processed  # Keep (D, H, W)
            elif sample.ndim == 3 and sample.shape[0] <= 4:  # (C, H, W) -> processed is (H, W)
                processed = processed[np.newaxis, ...]  # (1, H, W)
            # else: processed is already correct shape (D, H, W) or (H, W)

            results.append(processed)

        # Stack results back into batch
        print(f" DEBUG: _apply_postprocessing - stacking {len(results)} results, shapes: {[r.shape for r in results]}")
        data = np.stack(results, axis=0)
        print(f" DEBUG: _apply_postprocessing - after stacking, data shape: {data.shape}")

        # Step 2: Apply scaling if configured (support both new and legacy names)
        intensity_scale = getattr(postprocessing, 'intensity_scale', None)
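
The subtle part above is the channel-versus-depth heuristic: a 3D array is ambiguous between (C, H, W) and (D, H, W), and the code resolves it by treating a leading dimension of at most 4 as channels. A self-contained sketch of just that rule (the function name is illustrative):

import numpy as np


def extract_foreground(sample: np.ndarray) -> np.ndarray:
    """Reduce a sample to a single-channel probability map."""
    if sample.ndim == 4:  # (C, D, H, W): take the first channel
        return sample[0]
    if sample.ndim == 3:
        # Ambiguous: a small leading dim (<= 4) is assumed to be channels (2D data),
        # anything larger is assumed to be depth (3D data)
        return sample[0] if sample.shape[0] <= 4 else sample
    return sample  # (H, W) is already single channel


assert extract_foreground(np.zeros((2, 64, 64))).shape == (64, 64)       # (C, H, W)
assert extract_foreground(np.zeros((32, 64, 64))).shape == (32, 64, 64)  # (D, H, W)

Note the trade-off: a genuine 3D volume with depth <= 4 would be misread as channels; that limitation is inherent to any purely shape-based rule.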
@@ -651,13 +679,29 @@ def _apply_decode_mode(self, data: np.ndarray) -> np.ndarray:
        }

        # Process each sample in batch
        batch_size = data.shape[0] if data.ndim == 5 else 1
        # Handle both 4D (B, C, H, W) for 2D data and 5D (B, C, D, H, W) for 3D data
        print(f" DEBUG: _apply_decode_mode - input data shape: {data.shape}, ndim: {data.ndim}")
        if data.ndim == 4:
            data = data[np.newaxis, ...]  # Add batch dimension
            # 2D data: (B, C, H, W)
            batch_size = data.shape[0]
            print(f" DEBUG: _apply_decode_mode - detected 2D data, batch_size: {batch_size}")
        elif data.ndim == 5:
            # 3D data: (B, C, D, H, W)
            batch_size = data.shape[0]
            print(f" DEBUG: _apply_decode_mode - detected 3D data, batch_size: {batch_size}")
        else:
            # Single sample: add batch dimension
            batch_size = 1
            print(f" DEBUG: _apply_decode_mode - single sample, adding batch dimension")
            if data.ndim == 3:
                data = data[np.newaxis, ...]  # (C, H, W) -> (1, C, H, W)
            elif data.ndim == 2:
                data = data[np.newaxis, np.newaxis, ...]  # (H, W) -> (1, 1, H, W)

        results = []
        for batch_idx in range(batch_size):
            sample = data[batch_idx]  # (C, D, H, W)
            sample = data[batch_idx]  # (C, H, W) for 2D or (C, D, H, W) for 3D
            print(f" DEBUG: _apply_decode_mode - processing batch_idx {batch_idx}, sample shape: {sample.shape}")

            # Apply each decode mode sequentially
            for decode_cfg in decode_modes:
@@ -718,8 +762,10 @@ def _apply_decode_mode(self, data: np.ndarray) -> np.ndarray:
            results.append(sample)

        # Stack results back into batch
        decoded = np.stack(results, axis=0) if len(results) > 1 else results[0]

        # Always preserve batch dimension, even for batch_size=1
        print(f" DEBUG: _apply_decode_mode - stacking {len(results)} results, shapes: {[r.shape for r in results]}")
        decoded = np.stack(results, axis=0)
        print(f" DEBUG: _apply_decode_mode - final decoded shape: {decoded.shape}")
        return decoded
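
The behavioral change here is that decoded outputs now always keep the batch axis, so downstream writers can index samples uniformly. A compact sketch of the normalization convention (names are illustrative, not from the diff):

import numpy as np


def ensure_batched(data: np.ndarray) -> np.ndarray:
    """Normalize input to a batched layout so per-sample loops can index data[b]."""
    if data.ndim in (4, 5):  # already (B, C, H, W) or (B, C, D, H, W)
        return data
    if data.ndim == 3:  # (C, H, W) -> (1, C, H, W)
        return data[np.newaxis, ...]
    if data.ndim == 2:  # (H, W) -> (1, 1, H, W)
        return data[np.newaxis, np.newaxis, ...]
    raise ValueError(f"Unsupported ndim: {data.ndim}")


assert ensure_batched(np.zeros((64, 64))).shape == (1, 1, 64, 64)
assert ensure_batched(np.zeros((1, 1, 8, 64, 64))).shape == (1, 1, 8, 64, 64)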

    def _resolve_output_filenames(self, batch: Dict[str, Any]) -> List[str]:
@@ -742,26 +788,59 @@ def _resolve_output_filenames(self, batch: Dict[str, Any]) -> List[str]:

        meta = batch.get('image_meta_dict')
        filenames: List[Optional[str]] = []

        print(f" DEBUG: _resolve_output_filenames - meta type: {type(meta)}, batch_size: {batch_size}")

        if isinstance(meta, dict):
        # Handle different metadata structures
        if isinstance(meta, list):
            # Multiple metadata dicts (one per sample in the batch)
            print(f" DEBUG: _resolve_output_filenames - meta is list with {len(meta)} items")
            for idx, meta_item in enumerate(meta):
                if isinstance(meta_item, dict):
                    filename = meta_item.get('filename_or_obj')
                    if filename is not None:
                        filenames.append(filename)
                    else:
                        print(f" DEBUG: _resolve_output_filenames - meta_item[{idx}] has no filename_or_obj")
                else:
                    print(f" DEBUG: _resolve_output_filenames - meta_item[{idx}] is not a dict: {type(meta_item)}")
            # Update batch_size from metadata if we have a list
            batch_size = max(batch_size, len(filenames))
            print(f" DEBUG: _resolve_output_filenames - extracted {len(filenames)} filenames from list")
        elif isinstance(meta, dict):
            # Single metadata dict
            print(f" DEBUG: _resolve_output_filenames - meta is dict")
            meta_filenames = meta.get('filename_or_obj')
            if isinstance(meta_filenames, (list, tuple)):
                filenames = list(meta_filenames)
                filenames = [f for f in meta_filenames if f is not None]
            elif meta_filenames is not None:
                filenames = [meta_filenames]
        elif isinstance(meta, list):
            for meta_item in meta:
                if isinstance(meta_item, dict):
                    filenames.append(meta_item.get('filename_or_obj'))
            # Update batch_size from metadata if we have a list
            batch_size = max(batch_size, len(filenames))
            # Update batch_size from metadata
            if len(filenames) > 0:
                batch_size = max(batch_size, len(filenames))
            print(f" DEBUG: _resolve_output_filenames - extracted {len(filenames)} filenames from dict")
        else:
            # Handle the case where meta is None or another type.
            # This can happen if metadata wasn't preserved through transforms;
            # we fall back to generated filenames based on batch_size
            print(f" DEBUG: _resolve_output_filenames - meta is None or unexpected type: {type(meta)}")
            pass

        resolved_names: List[str] = []
        for idx in range(batch_size):
            if idx < len(filenames) and filenames[idx]:
                resolved_names.append(Path(str(filenames[idx])).stem)
            else:
                # Generate a fallback filename - this shouldn't happen if metadata is preserved correctly
                resolved_names.append(f"volume_{self.global_step}_{idx}")

        print(f" DEBUG: _resolve_output_filenames - returning {len(resolved_names)} resolved names: {resolved_names[:3]}...")

        # Always return exactly batch_size filenames
        if len(resolved_names) < batch_size:
            print(f" WARNING: _resolve_output_filenames - Only {len(resolved_names)} filenames but batch_size is {batch_size}, padding with fallback names")
            while len(resolved_names) < batch_size:
                resolved_names.append(f"volume_{self.global_step}_{len(resolved_names)}")

        return resolved_names
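
For reference, the two metadata layouts this resolver now handles look roughly like the following (MONAI-style image_meta_dict entries; the sample filenames are hypothetical):

# Layout 1: a single dict with batched values (typical after default collation)
batch_a = {"image_meta_dict": {"filename_or_obj": ["vol_001.tif", "vol_002.tif"]}}

# Layout 2: a list of per-sample dicts (one metadata dict per batch element)
batch_b = {"image_meta_dict": [
    {"filename_or_obj": "vol_001.tif"},
    {"filename_or_obj": "vol_002.tif"},
]}

# Both should resolve to the stems ["vol_001", "vol_002"]; anything missing
# falls back to f"volume_{global_step}_{idx}".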

@@ -799,8 +878,42 @@ def _write_outputs(
        if hasattr(self.cfg.inference, 'postprocessing'):
            output_transpose = getattr(self.cfg.inference.postprocessing, 'output_transpose', [])

        # Determine the actual batch size from the predictions.
        # Handle both batched (B, ...) and unbatched (...) predictions
        print(f" DEBUG: _write_outputs - predictions shape: {predictions.shape}, ndim: {predictions.ndim}, filenames count: {len(filenames)}")

        if predictions.ndim >= 4:
            # Has batch dimension: (B, C, D, H, W) or (B, C, H, W) or (B, D, H, W)
            actual_batch_size = predictions.shape[0]
        elif predictions.ndim == 3:
            # Could be batched 2D data (B, H, W) or a single 3D volume (D, H, W).
            # If the first dimension matches the number of filenames, it's batched 2D data
            if len(filenames) > 0 and predictions.shape[0] == len(filenames):
                # Batched 2D data: (B, H, W) where B matches the number of filenames
                actual_batch_size = predictions.shape[0]
                print(f" DEBUG: _write_outputs - detected batched 2D data (B, H, W) with batch_size={actual_batch_size}")
            else:
                # Single 3D volume: (D, H, W) - treat as batch_size=1
                actual_batch_size = 1
                predictions = predictions[np.newaxis, ...]  # Add batch dimension
                print(f" DEBUG: _write_outputs - detected single 3D volume, added batch dimension")
        elif predictions.ndim == 2:
            # Single 2D image: (H, W) - treat as batch_size=1
            actual_batch_size = 1
            predictions = predictions[np.newaxis, ...]  # Add batch dimension
        else:
            # Unexpected shape, default to batch_size=1
            actual_batch_size = 1
            if predictions.ndim < 2:
                predictions = predictions[np.newaxis, ...]  # Add batch dimension

        # Ensure we don't exceed the actual batch size
        batch_size = min(actual_batch_size, len(filenames))
        print(f" DEBUG: _write_outputs - actual_batch_size: {actual_batch_size}, batch_size: {batch_size}, will save {batch_size} predictions")

        # Save predictions
        for idx, name in enumerate(filenames):
        for idx in range(batch_size):
            name = filenames[idx]
            prediction = predictions[idx]

            # Apply output transpose if specified
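
The ambiguous case in the batch-size inference above is a 3D prediction array, which can be either batched 2D data (B, H, W) or a single 3D volume (D, H, W); the diff disambiguates by comparing the leading dimension against the filename count. A minimal sketch of that rule (the function name is illustrative):

import numpy as np


def infer_batched(predictions: np.ndarray, n_filenames: int) -> np.ndarray:
    """Return predictions with an explicit leading batch dimension."""
    if predictions.ndim >= 4:  # already batched: (B, C, ...) or (B, D, H, W)
        return predictions
    if predictions.ndim == 3 and predictions.shape[0] == n_filenames:
        return predictions  # batched 2D data: (B, H, W)
    return predictions[np.newaxis, ...]  # single volume or image -> batch of one


assert infer_batched(np.zeros((4, 64, 64)), n_filenames=4).shape == (4, 64, 64)
assert infer_batched(np.zeros((32, 64, 64)), n_filenames=1).shape == (1, 32, 64, 64)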
@@ -1303,16 +1416,28 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTPUT:
        labels = batch.get('label')
        mask = batch.get('mask')  # Get test mask if available

        # Get batch size from images
        actual_batch_size = images.shape[0]
        print(f" DEBUG: test_step - images shape: {images.shape}, batch_size: {actual_batch_size}")

        # Always use TTA (handles the no-transform case) + sliding window.
        # TTA preprocessing (activation, masking) is applied regardless of flip augmentation.
        # Note: TTA always returns a plain tensor, not a dict (deep supervision is not supported in test mode)
        predictions = self._predict_with_tta(images, mask=mask)

        # Convert predictions to numpy for saving/decoding
        predictions_np = predictions.detach().cpu().float().numpy()
        print(f" DEBUG: test_step - predictions_np shape: {predictions_np.shape}")

        # Resolve filenames once for all saving operations
        filenames = self._resolve_output_filenames(batch)
        print(f" DEBUG: test_step - filenames count: {len(filenames)}, filenames: {filenames[:5]}...")

        # Ensure the filenames list matches the actual batch size;
        # if we don't have enough filenames, generate default ones
        while len(filenames) < actual_batch_size:
            filenames.append(f"volume_{self.global_step}_{len(filenames)}")
        print(f" DEBUG: test_step - after padding, filenames count: {len(filenames)}")

        # Check if we should save intermediate predictions (before decoding)
        save_intermediate = False
@@ -1324,10 +1449,13 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTPUT:
            self._write_outputs(predictions_np, filenames, suffix="tta_prediction")

        # Apply decode mode (instance segmentation decoding)
        print(f" DEBUG: test_step - before decode, predictions_np shape: {predictions_np.shape}")
        decoded_predictions = self._apply_decode_mode(predictions_np)
        print(f" DEBUG: test_step - after decode, decoded_predictions shape: {decoded_predictions.shape}")

        # Apply postprocessing (scaling and dtype conversion) if configured
        postprocessed_predictions = self._apply_postprocessing(decoded_predictions)
        print(f" DEBUG: test_step - after postprocess, postprocessed_predictions shape: {postprocessed_predictions.shape}")

        # Save final decoded and postprocessed predictions
        self._write_outputs(postprocessed_predictions, filenames, suffix="prediction")
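
Putting the pieces together, the test-time flow this diff settles on is: TTA prediction, optional intermediate save, decode, postprocess, final save, with exactly one filename per batch element. A schematic of the order of operations (method names are from the diff; the surrounding class code is elided, so this is a sketch, not the verbatim implementation):

# predictions = self._predict_with_tta(images, mask=mask)      # (B, C, ...) tensor
# predictions_np = predictions.detach().cpu().float().numpy()
# filenames = self._resolve_output_filenames(batch)            # padded to B entries
# if save_intermediate:
#     self._write_outputs(predictions_np, filenames, suffix="tta_prediction")
# decoded = self._apply_decode_mode(predictions_np)            # keeps the batch axis
# final = self._apply_postprocessing(decoded)                  # scaling / dtype / binary ops
# self._write_outputs(final, filenames, suffix="prediction")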