From 9c9a6b7ccb4e95be798555bf9689d56638501a49 Mon Sep 17 00:00:00 2001
From: Boyu Shen
Date: Wed, 29 Oct 2025 00:50:21 -0400
Subject: [PATCH] feat: Add 2D data processing support with do_2d flag

- Add do_2d parameter to DataConfig and InferenceDataConfig
- Modify MonaiConnectomicsDataset to handle 2D dimensions when do_2d=True
- Update model wrapper to squeeze/unsqueeze depth dimension for 2D models
- Disable sliding window inference for 2D models, use direct inference
- Make TTA flip axes dynamic based on data dimensions (2D vs 3D)
- Switch evaluation metric from Jaccard to Adapted Rand Error for instance segmentation
- Update monai2d_worm.yaml config to enable 2D processing

This enables seamless 2D data processing while maintaining 3D compatibility.
---
 connectomics/config/hydra_config.py       |  7 +++
 connectomics/data/dataset/dataset_base.py |  6 ++-
 connectomics/lightning/lit_model.py       | 57 +++++++++++++++--------
 connectomics/models/arch/monai_models.py  | 13 +++++-
 scripts/main.py                           |  1 +
 tutorials/monai2d_worm.yaml               | 10 ++--
 6 files changed, 70 insertions(+), 24 deletions(-)

diff --git a/connectomics/config/hydra_config.py b/connectomics/config/hydra_config.py
index b3a5cbb5..0ca4785c 100644
--- a/connectomics/config/hydra_config.py
+++ b/connectomics/config/hydra_config.py
@@ -329,10 +329,14 @@ class DataConfig:
     - Multi-channel label transformations
     - Train/validation splitting options
     - Caching and performance optimization
+    - 2D data support with do_2d parameter
     """
 
     # Dataset type
     dataset_type: Optional[str] = None  # Type of dataset: None (volume), 'filename', 'tile', etc.
+
+    # 2D data support
+    do_2d: bool = False  # Enable 2D data processing (extract 2D slices from 3D volumes)
 
     # Paths - Volume-based datasets
     train_image: Optional[str] = None
@@ -755,6 +759,9 @@ class InferenceDataConfig:
         default_factory=list
     )  # Axis permutation for test data (e.g., [2,1,0] for xyz->zyx)
     output_path: str = "results/"
+
+    # 2D data support
+    do_2d: bool = False  # Enable 2D data processing for inference
 
 
 @dataclass
diff --git a/connectomics/data/dataset/dataset_base.py b/connectomics/data/dataset/dataset_base.py
index 3ce46664..82166240 100644
--- a/connectomics/data/dataset/dataset_base.py
+++ b/connectomics/data/dataset/dataset_base.py
@@ -70,7 +70,11 @@ def __init__(
         super().__init__(data=data_dicts, transform=transforms)
 
         # Store connectomics-specific parameters
-        self.sample_size = ensure_tuple_rep(sample_size, 3)
+        # For 2D data, use 2D dimensions; otherwise use 3D
+        if do_2d:
+            self.sample_size = ensure_tuple_rep(sample_size, 2)
+        else:
+            self.sample_size = ensure_tuple_rep(sample_size, 3)
         self.mode = mode
         self.iter_num = iter_num
         self.valid_ratio = valid_ratio
diff --git a/connectomics/lightning/lit_model.py b/connectomics/lightning/lit_model.py
index cc632a55..a43abe56 100644
--- a/connectomics/lightning/lit_model.py
+++ b/connectomics/lightning/lit_model.py
@@ -156,6 +156,15 @@ def _setup_sliding_window_inferer(self):
         if not hasattr(self.cfg, 'inference'):
             return
 
+        # For 2D models with do_2d=True, disable sliding window inference
+        if getattr(self.cfg.data, 'do_2d', False):
+            warnings.warn(
+                "Sliding-window inference disabled for 2D models with do_2d=True. "
" + "Using direct inference instead.", + UserWarning, + ) + return + roi_size = self._resolve_inferer_roi_size() if roi_size is None: warnings.warn( @@ -200,12 +209,20 @@ def _resolve_inferer_roi_size(self) -> Optional[Tuple[int, ...]]: if hasattr(self.cfg, 'model') and hasattr(self.cfg.model, 'output_size'): output_size = getattr(self.cfg.model, 'output_size', None) if output_size: - return tuple(int(v) for v in output_size) + roi_size = tuple(int(v) for v in output_size) + # For 2D models with do_2d=True, convert to 3D ROI size + if getattr(self.cfg.data, 'do_2d', False) and len(roi_size) == 2: + roi_size = (1,) + roi_size # Add depth dimension + return roi_size if hasattr(self.cfg, 'data') and hasattr(self.cfg.data, 'patch_size'): patch_size = getattr(self.cfg.data, 'patch_size', None) if patch_size: - return tuple(int(v) for v in patch_size) + roi_size = tuple(int(v) for v in patch_size) + # For 2D models with do_2d=True, convert to 3D ROI size + if getattr(self.cfg.data, 'do_2d', False) and len(roi_size) == 2: + roi_size = (1,) + roi_size # Add depth dimension + return roi_size return None @@ -367,6 +384,10 @@ def _predict_with_tta(self, images: torch.Tensor, mask: Optional[torch.Tensor] = f"Expected shapes: (D, H, W), (B, D, H, W), or (B, C, D, H, W)" ) + # For 2D models with do_2d=True, squeeze the depth dimension if present + if getattr(self.cfg.data, 'do_2d', False) and images.size(2) == 1: # [B, C, 1, H, W] -> [B, C, H, W] + images = images.squeeze(2) + # Get TTA configuration (default to no augmentation if not configured) if hasattr(self.cfg, 'inference') and hasattr(self.cfg.inference, 'test_time_augmentation'): tta_flip_axes_config = getattr(self.cfg.inference.test_time_augmentation, 'flip_axes', None) @@ -385,23 +406,21 @@ def _predict_with_tta(self, images: torch.Tensor, mask: Optional[torch.Tensor] = ensemble_result = self._apply_tta_preprocessing(pred) else: if tta_flip_axes_config == 'all' or tta_flip_axes_config == []: - # "all" or []: All 8 flips (all combinations of Z, Y, X) - # IMPORTANT: MONAI Flip spatial_axis behavior for (B, C, D, H, W) tensors: - # spatial_axis=[0] flips C (channel) - WRONG for TTA! - # spatial_axis=[1] flips D (depth/Z) - CORRECT - # spatial_axis=[2] flips H (height/Y) - CORRECT - # spatial_axis=[3] flips W (width/X) - CORRECT - # Must use [1, 2, 3] for [D, H, W] flips, NOT [0, 1, 2]! 
-                    tta_flip_axes = [
-                        [],           # No flip
-                        [1],          # Flip Z (depth)
-                        [2],          # Flip Y (height)
-                        [3],          # Flip X (width)
-                        [1, 2],       # Flip Z+Y
-                        [1, 3],       # Flip Z+X
-                        [2, 3],       # Flip Y+X
-                        [1, 2, 3],    # Flip Z+Y+X
-                    ]
+                    # "all" or []: All flips (all combinations of spatial axes)
+                    # Determine spatial axes based on data dimensions
+                    if images.dim() == 5:  # 3D data: [B, C, D, H, W]
+                        spatial_axes = [1, 2, 3]  # [D, H, W]
+                    elif images.dim() == 4:  # 2D data: [B, C, H, W]
+                        spatial_axes = [1, 2]  # [H, W]
+                    else:
+                        raise ValueError(f"Unsupported data dimensions: {images.dim()}")
+
+                    # Generate all combinations of spatial axes (plus a no-flip baseline)
+                    from itertools import combinations
+                    tta_flip_axes = [[]]  # No flip baseline
+                    for r in range(1, len(spatial_axes) + 1):
+                        for combo in combinations(spatial_axes, r):
+                            tta_flip_axes.append(list(combo))
                 elif isinstance(tta_flip_axes_config, (list, tuple)):
                     # Custom list: Add no-flip baseline + user-specified flips
                     tta_flip_axes = [[]] + list(tta_flip_axes_config)
diff --git a/connectomics/models/arch/monai_models.py b/connectomics/models/arch/monai_models.py
index 18dd3b5f..e205b876 100644
--- a/connectomics/models/arch/monai_models.py
+++ b/connectomics/models/arch/monai_models.py
@@ -36,7 +36,18 @@ def __init__(self, model: nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass through MONAI model."""
-        return self.model(x)
+        # For 2D models, squeeze the singleton depth dimension if present
+        squeeze_depth = x.dim() == 5 and x.size(2) == 1  # input is [B, C, 1, H, W]
+        if squeeze_depth:
+            x = x.squeeze(2)  # [B, C, 1, H, W] -> [B, C, H, W]
+
+        # Forward through model
+        output = self.model(x)
+
+        # Add the depth dimension back if needed for sliding window inference
+        if squeeze_depth and output.dim() == 4:  # [B, C, H, W] -> [B, C, 1, H, W]
+            output = output.unsqueeze(2)
+        return output
 
 
 def _check_monai_available():
diff --git a/scripts/main.py b/scripts/main.py
index 28152279..152f4719 100644
--- a/scripts/main.py
+++ b/scripts/main.py
@@ -779,6 +779,7 @@ def setup(self, stage=None):
             cache_rate=cfg.data.cache_rate if use_cache else 0.0,
             iter_num=iter_num_for_dataset,
             sample_size=tuple(cfg.data.patch_size),
+            do_2d=cfg.data.do_2d,
         )
         # Setup datasets based on mode
         if mode == "train":
diff --git a/tutorials/monai2d_worm.yaml b/tutorials/monai2d_worm.yaml
index f7150e4b..251788a3 100644
--- a/tutorials/monai2d_worm.yaml
+++ b/tutorials/monai2d_worm.yaml
@@ -58,6 +58,9 @@ model:
 
 # Data - Using automatic 80/20 train/val split (DeepEM-style)
 data:
+  # 2D data support
+  do_2d: true  # Enable 2D data processing (extract 2D slices from 3D volumes)
+
   # Volume configuration
   train_image: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/imagesTr/*.tif
   train_label: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/labelsTr/*.tif
@@ -152,8 +155,9 @@ monitor:
 
 # Inference - MONAI SlidingWindowInferer
 inference:
   data:
-    test_image: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/imagesTr/Image96_00002_0000.tif
-    test_label: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/labelsTr/Image96_00002.tif
+    do_2d: true  # Enable 2D data processing for inference
+    test_image: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/imagesTs/*.tif
+    test_label: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/labelsTs/*.tif
     test_resolution: [5, 5]
     output_path: outputs/monai2d_worm/results/
@@ -188,7 +192,7 @@
 # Evaluation
 evaluation:
   enabled: true  # Use eval mode for BatchNorm
-  metrics: [jaccard]  # Metrics to compute
+  metrics: [adapted_rand]  # Adapted Rand Error for instance segmentation

 # NOTE: batch_size=1 for inference
 # During training: batch_size controls how many random patches to load
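
Illustration (not part of the diff): the TTA change above replaces the hard-coded list of eight 3D flips with flip combinations generated from the detected spatial axes, so a 2D batch gets 4 variants and a 3D batch gets 8. Below is a small standalone sketch of that logic, assuming PyTorch tensors shaped [B, C, D, H, W] or [B, C, H, W]; the helper name flip_axis_combinations is illustrative and not part of the repository.

    from itertools import combinations
    from typing import List

    import torch


    def flip_axis_combinations(images: torch.Tensor) -> List[List[int]]:
        """Return a no-flip baseline plus every combination of spatial axes."""
        if images.dim() == 5:          # 3D data: [B, C, D, H, W]
            spatial_axes = [1, 2, 3]
        elif images.dim() == 4:        # 2D data: [B, C, H, W]
            spatial_axes = [1, 2]
        else:
            raise ValueError(f"Unsupported data dimensions: {images.dim()}")

        axes: List[List[int]] = [[]]   # no-flip baseline
        for r in range(1, len(spatial_axes) + 1):
            axes.extend(list(combo) for combo in combinations(spatial_axes, r))
        return axes


    print(flip_axis_combinations(torch.zeros(1, 1, 64, 64)))          # [[], [1], [2], [1, 2]]
    print(len(flip_axis_combinations(torch.zeros(1, 1, 8, 64, 64))))  # 8

The same combinations are what the patched _predict_with_tta builds in place before looping over the flip variants.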
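Illustration (not part of the diff): the monai_models.py change lets a 2D network accept the 5D patches that a 3D sliding-window pipeline produces, by squeezing the singleton depth axis before the forward pass and restoring it afterwards. A minimal self-contained sketch of that wrapper pattern follows; the class name Depth2DWrapper and the toy Conv2d model are illustrative, not the repository's actual classes.

    import torch
    import torch.nn as nn


    class Depth2DWrapper(nn.Module):
        """Let a 2D network consume [B, C, 1, H, W] patches from a 3D pipeline."""

        def __init__(self, model: nn.Module):
            super().__init__()
            self.model = model

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            squeeze_depth = x.dim() == 5 and x.size(2) == 1
            if squeeze_depth:
                x = x.squeeze(2)           # [B, C, 1, H, W] -> [B, C, H, W]
            out = self.model(x)
            if squeeze_depth and out.dim() == 4:
                out = out.unsqueeze(2)     # restore [B, C, 1, H, W]
            return out


    # A toy 2D model: shapes round-trip through the wrapper with depth preserved.
    net = Depth2DWrapper(nn.Conv2d(1, 2, kernel_size=3, padding=1))
    print(net(torch.zeros(1, 1, 1, 64, 64)).shape)  # torch.Size([1, 2, 1, 64, 64])
    print(net(torch.zeros(1, 1, 64, 64)).shape)     # torch.Size([1, 2, 64, 64])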