7 changes: 7 additions & 0 deletions connectomics/config/hydra_config.py
@@ -329,10 +329,14 @@ class DataConfig:
- Multi-channel label transformations
- Train/validation splitting options
- Caching and performance optimization
- 2D data support with do_2d parameter
"""

# Dataset type
dataset_type: Optional[str] = None # Type of dataset: None (volume), 'filename', 'tile', etc.

# 2D data support
do_2d: bool = False # Enable 2D data processing (extract 2D slices from 3D volumes)

# Paths - Volume-based datasets
train_image: Optional[str] = None
@@ -755,6 +759,9 @@ class InferenceDataConfig:
default_factory=list
) # Axis permutation for test data (e.g., [2,1,0] for xyz->zyx)
output_path: str = "results/"

# 2D data support
do_2d: bool = False # Enable 2D data processing for inference


@dataclass
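The new flag is a plain dataclass field, so it can be set from YAML or constructed directly. A minimal sketch, assuming the remaining `DataConfig` fields keep their defaults and that the module path matches the file shown above:

```python
# Hedged sketch: constructing DataConfig with the new 2D flag. The import path is
# inferred from the file location above; the paths are illustrative only.
from connectomics.config.hydra_config import DataConfig

cfg_data = DataConfig(
    do_2d=True,                             # process 2D slices instead of 3D volumes
    train_image="datasets/imagesTr/*.tif",  # illustrative glob
    train_label="datasets/labelsTr/*.tif",
)
print(cfg_data.do_2d)  # True
```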
6 changes: 5 additions & 1 deletion connectomics/data/dataset/dataset_base.py
@@ -70,7 +70,11 @@ def __init__(
super().__init__(data=data_dicts, transform=transforms)

# Store connectomics-specific parameters
self.sample_size = ensure_tuple_rep(sample_size, 3)
# For 2D data, use 2D dimensions; otherwise use 3D
if do_2d:
self.sample_size = ensure_tuple_rep(sample_size, 2)
else:
self.sample_size = ensure_tuple_rep(sample_size, 3)
self.mode = mode
self.iter_num = iter_num
self.valid_ratio = valid_ratio
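The 2D/3D branch exists because MONAI's `ensure_tuple_rep` requires the input length to match the requested dimensionality (scalars are repeated). A small illustration:

```python
from monai.utils import ensure_tuple_rep

print(ensure_tuple_rep(256, 2))             # (256, 256)      - scalar repeated for 2D
print(ensure_tuple_rep((64, 256, 256), 3))  # (64, 256, 256)  - already 3D

# A 2D sample_size with dim=3 raises ValueError, which is why do_2d must
# switch the target length to 2 before calling ensure_tuple_rep.
try:
    ensure_tuple_rep((256, 256), 3)
except ValueError as exc:
    print(exc)
```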
57 changes: 38 additions & 19 deletions connectomics/lightning/lit_model.py
@@ -156,6 +156,15 @@ def _setup_sliding_window_inferer(self):
if not hasattr(self.cfg, 'inference'):
return

# For 2D models with do_2d=True, disable sliding window inference
if getattr(self.cfg.data, 'do_2d', False):
warnings.warn(
"Sliding-window inference disabled for 2D models with do_2d=True. "
"Using direct inference instead.",
UserWarning,
)
return

roi_size = self._resolve_inferer_roi_size()
if roi_size is None:
warnings.warn(
@@ -200,12 +209,20 @@ def _resolve_inferer_roi_size(self) -> Optional[Tuple[int, ...]]:
if hasattr(self.cfg, 'model') and hasattr(self.cfg.model, 'output_size'):
output_size = getattr(self.cfg.model, 'output_size', None)
if output_size:
return tuple(int(v) for v in output_size)
roi_size = tuple(int(v) for v in output_size)
# For 2D models with do_2d=True, convert to 3D ROI size
if getattr(self.cfg.data, 'do_2d', False) and len(roi_size) == 2:
roi_size = (1,) + roi_size # Add depth dimension
return roi_size

if hasattr(self.cfg, 'data') and hasattr(self.cfg.data, 'patch_size'):
patch_size = getattr(self.cfg.data, 'patch_size', None)
if patch_size:
return tuple(int(v) for v in patch_size)
roi_size = tuple(int(v) for v in patch_size)
# For 2D models with do_2d=True, convert to 3D ROI size
if getattr(self.cfg.data, 'do_2d', False) and len(roi_size) == 2:
roi_size = (1,) + roi_size # Add depth dimension
return roi_size

return None

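Both branches above pad a 2D ROI to 3D in the same way; a standalone restatement of that conversion, with illustrative sizes:

```python
# Restatement of the ROI-size handling above: a 2D ROI gains a singleton depth so
# the sliding-window inferer still receives a 3D ROI matching [B, C, 1, H, W] data.
def pad_roi_for_2d(roi_size, do_2d):
    roi_size = tuple(int(v) for v in roi_size)
    if do_2d and len(roi_size) == 2:
        roi_size = (1,) + roi_size  # add depth dimension
    return roi_size

print(pad_roi_for_2d([256, 256], do_2d=True))       # (1, 256, 256)
print(pad_roi_for_2d([32, 256, 256], do_2d=False))  # (32, 256, 256)
```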
@@ -367,6 +384,10 @@ def _predict_with_tta(self, images: torch.Tensor, mask: Optional[torch.Tensor] =
f"Expected shapes: (D, H, W), (B, D, H, W), or (B, C, D, H, W)"
)

# For 2D models with do_2d=True, squeeze the depth dimension if present
if getattr(self.cfg.data, 'do_2d', False) and images.size(2) == 1: # [B, C, 1, H, W] -> [B, C, H, W]
images = images.squeeze(2)

# Get TTA configuration (default to no augmentation if not configured)
if hasattr(self.cfg, 'inference') and hasattr(self.cfg.inference, 'test_time_augmentation'):
tta_flip_axes_config = getattr(self.cfg.inference.test_time_augmentation, 'flip_axes', None)
@@ -385,23 +406,21 @@ def _predict_with_tta(images: torch.Tensor, mask: Optional[torch.Tensor] =
ensemble_result = self._apply_tta_preprocessing(pred)
else:
if tta_flip_axes_config == 'all' or tta_flip_axes_config == []:
# "all" or []: All 8 flips (all combinations of Z, Y, X)
# IMPORTANT: MONAI Flip spatial_axis behavior for (B, C, D, H, W) tensors:
# spatial_axis=[0] flips C (channel) - WRONG for TTA!
# spatial_axis=[1] flips D (depth/Z) - CORRECT
# spatial_axis=[2] flips H (height/Y) - CORRECT
# spatial_axis=[3] flips W (width/X) - CORRECT
# Must use [1, 2, 3] for [D, H, W] flips, NOT [0, 1, 2]!
tta_flip_axes = [
[], # No flip
[1], # Flip Z (depth)
[2], # Flip Y (height)
[3], # Flip X (width)
[1, 2], # Flip Z+Y
[1, 3], # Flip Z+X
[2, 3], # Flip Y+X
[1, 2, 3], # Flip Z+Y+X
]
# "all" or []: All flips (all combinations of spatial axes)
# Determine spatial axes based on data dimensions
if images.dim() == 5: # 3D data: [B, C, D, H, W]
spatial_axes = [1, 2, 3] # [D, H, W]
elif images.dim() == 4: # 2D data: [B, C, H, W]
spatial_axes = [1, 2] # [H, W]
else:
raise ValueError(f"Unsupported data dimensions: {images.dim()}")

# Generate all combinations of spatial axes
from itertools import combinations
tta_flip_axes = [[]] # No flip baseline
for r in range(1, len(spatial_axes) + 1):
for combo in combinations(spatial_axes, r):
tta_flip_axes.append(list(combo))
elif isinstance(tta_flip_axes_config, (list, tuple)):
# Custom list: Add no-flip baseline + user-specified flips
tta_flip_axes = [[]] + list(tta_flip_axes_config)
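The flip-axes enumeration above replaces the hard-coded 8-entry table with all subsets of the spatial axes, so 2D inputs get 4 TTA variants and 3D inputs get 8. A standalone sketch of the same enumeration:

```python
from itertools import combinations

def all_flip_axes(ndim):
    """Enumerate TTA flip-axis combinations for a [B, C, ...] tensor of rank ndim."""
    if ndim == 5:            # [B, C, D, H, W]
        spatial_axes = [1, 2, 3]
    elif ndim == 4:          # [B, C, H, W]
        spatial_axes = [1, 2]
    else:
        raise ValueError(f"Unsupported data dimensions: {ndim}")
    axes = [[]]              # no-flip baseline
    for r in range(1, len(spatial_axes) + 1):
        for combo in combinations(spatial_axes, r):
            axes.append(list(combo))
    return axes

print(len(all_flip_axes(5)))  # 8 combinations for 3D
print(len(all_flip_axes(4)))  # 4 combinations for 2D
```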
13 changes: 12 additions & 1 deletion connectomics/models/arch/monai_models.py
@@ -36,7 +36,18 @@ def __init__(self, model: nn.Module):

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through MONAI model."""
return self.model(x)
# For 2D models, squeeze the depth dimension if present
squeezed = False
if x.dim() == 5 and x.size(2) == 1: # [B, C, 1, H, W] -> [B, C, H, W]
x = x.squeeze(2)
squeezed = True

# Forward through model
output = self.model(x)

# For 2D models, restore the squeezed depth dimension so sliding window inference
# still sees a 3D-shaped output
if squeezed and output.dim() == 4: # [B, C, H, W] -> [B, C, 1, H, W]
output = output.unsqueeze(2)

return output


def _check_monai_available():
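A hedged usage sketch of the shape round-trip the wrapper performs, using a stock 2D MONAI network; the network choice and tensor sizes are illustrative, not part of this PR:

```python
import torch
from monai.networks.nets import BasicUNet

net = BasicUNet(spatial_dims=2, in_channels=1, out_channels=2)

patch = torch.randn(1, 1, 1, 64, 64)   # [B, C, 1, H, W] as produced by a 3D pipeline
x = patch.squeeze(2)                    # -> [1, 1, 64, 64] for the 2D network
out = net(x)                            # -> [1, 2, 64, 64]
out = out.unsqueeze(2)                  # -> [1, 2, 1, 64, 64], back to a 3D-shaped output
print(out.shape)
```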
1 change: 1 addition & 0 deletions scripts/main.py
@@ -779,6 +779,7 @@ def setup(self, stage=None):
cache_rate=cfg.data.cache_rate if use_cache else 0.0,
iter_num=iter_num_for_dataset,
sample_size=tuple(cfg.data.patch_size),
do_2d=cfg.data.do_2d,
)
# Setup datasets based on mode
if mode == "train":
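A minimal sketch of how the two forwarded fields relate at the call site; `cfg` is a stand-in for the loaded Hydra config, with field names taken from this PR:

```python
from types import SimpleNamespace

# Stand-in for the loaded Hydra config (illustrative values).
cfg = SimpleNamespace(data=SimpleNamespace(patch_size=[256, 256], do_2d=True))

dataset_kwargs = dict(
    sample_size=tuple(cfg.data.patch_size),  # (256, 256) for a 2D run, three values for 3D
    do_2d=cfg.data.do_2d,
)
print(dataset_kwargs)  # {'sample_size': (256, 256), 'do_2d': True}
```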
10 changes: 7 additions & 3 deletions tutorials/monai2d_worm.yaml
@@ -58,6 +58,9 @@ model:

# Data - Using automatic 80/20 train/val split (DeepEM-style)
data:
# 2D data support
do_2d: true # Enable 2D data processing (extract 2D slices from 3D volumes)

# Volume configuration
train_image: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/imagesTr/*.tif
train_label: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/labelsTr/*.tif
@@ -152,8 +155,9 @@ monitor:
# Inference - MONAI SlidingWindowInferer
inference:
data:
test_image: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/imagesTr/Image96_00002_0000.tif
test_label: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/labelsTr/Image96_00002.tif
do_2d: true # Enable 2D data processing for inference
test_image: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/imagesTs/*.tif
test_label: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/imagesTs/*.tif
test_resolution: [5, 5]
output_path: outputs/monai2d_worm/results/

Expand Down Expand Up @@ -188,7 +192,7 @@ inference:
# Evaluation
evaluation:
enabled: true # Use eval mode for BatchNorm
metrics: [jaccard] # Metrics to compute
metrics: [adapted_rand] # Metrics to compute (adapted_rand for instance segmentation)

# NOTE: batch_size=1 for inference
# During training: batch_size controls how many random patches to load