21 changes: 14 additions & 7 deletions REFACTORING_PLAN.md
@@ -259,6 +259,18 @@ connectomics/lightning/
└── lit_data.py # LightningDataModule (684 lines, existing)
```

**Migration Steps:**
1. Create new module files
2. Move functionality in logical chunks
3. Update imports in `lit_model.py`
4. Add integration tests for each module
5. Update documentation

**Success Criteria:**
- [ ] Each file < 500 lines
- [ ] Clear separation of concerns
- [ ] All existing tests pass
- [ ] Documentation updated

**Note:** Multi-task learning was integrated into `deep_supervision.py` (not a separate module) since the logic is tightly coupled with deep supervision.
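
For step 3, the import updates in `lit_model.py` would pull from the new modules, e.g. (only `deep_supervision.py` and `lit_data.py` are named here; the imported symbols are placeholders, not real API):

```python
# Hypothetical post-split imports; module names follow the tree above,
# symbol names are placeholders.
from connectomics.lightning.deep_supervision import apply_deep_supervision
from connectomics.lightning.lit_data import ConnectomicsDataModule
```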

**Completed Actions:**
@@ -408,7 +420,6 @@ class DataConfig:
- [x] `deep_supervision_clamp_min: float` (default: -20.0)
- [x] `deep_supervision_clamp_max: float` (default: 20.0)
- [x] Validation logic with warning for insufficient weights
- [x] Backward compatible (defaults match old behavior)
- [ ] Other hardcoded values (target interpolation, rejection sampling) - Future work

**Status:** ✅ Phase 2.3 (Deep Supervision) completed. Users can now customize deep supervision weights and clamping ranges via config.
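
As a quick illustration of how these knobs could be consumed, here is a minimal sketch (the `cfg.model` location and the zero-padding fallback are assumptions, not the actual implementation):

```python
import warnings

import torch

def clamp_ds_logits(logits: torch.Tensor, cfg) -> torch.Tensor:
    """Clamp deep-supervision logits to the configurable range (defaults -20/20)."""
    lo = getattr(cfg.model, "deep_supervision_clamp_min", -20.0)
    hi = getattr(cfg.model, "deep_supervision_clamp_max", 20.0)
    return logits.clamp(lo, hi)

def validate_ds_weights(weights: list, num_scales: int) -> list:
    """Warn about insufficient weights, then pad with zeros (assumed fallback)."""
    if len(weights) < num_scales:
        warnings.warn(
            f"Got {len(weights)} deep-supervision weights for {num_scales} scales; "
            "padding the remainder with 0.0."
        )
        weights = list(weights) + [0.0] * (num_scales - len(weights))
    return weights
```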
@@ -500,14 +511,12 @@ def build_test_transforms(cfg: Config, keys: list[str] = None) -> Compose:
- File size reduced from 791 to 727 lines (-64 lines, ~8% reduction)
- Eliminated ~80% code duplication
- Single source of truth for shared transform logic
- Backward compatible (same public API)

**Action Items:**
- [x] Extract shared logic into `_build_eval_transforms_impl()`
- [x] Identify val/test-specific differences (4 key differences)
- [x] Create mode-specific branching with clear comments
- [x] Keep wrapper functions for API compatibility
- [x] Backward compatible (public API unchanged)

**Status:** ✅ Phase 2.5 complete. Code duplication eliminated while preserving all functionality.
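
The resulting shape of the public API is roughly the following (a sketch — `Config`, `Compose`, and `_build_eval_transforms_impl` are the names from `build.py`; the bodies here are illustrative):

```python
def build_val_transforms(cfg: Config, keys: list[str] = None) -> Compose:
    # Thin wrapper kept so the public API is unchanged; shared logic lives in the impl.
    return _build_eval_transforms_impl(cfg, keys=keys, mode="val")

def build_test_transforms(cfg: Config, keys: list[str] = None) -> Compose:
    return _build_eval_transforms_impl(cfg, keys=keys, mode="test")
```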

@@ -996,10 +1005,8 @@ See Priority 1.3 above for full details.

### Mitigation Strategies
1. **Comprehensive testing** before and after each change
2. **Feature flags** for backward compatibility
3. **Deprecation warnings** before removal
4. **Rollback plan** for each phase
5. **User communication** via release notes
2. **Rollback plan** for each phase
3. **User communication** via release notes

---

6 changes: 0 additions & 6 deletions connectomics/config/hydra_config.py
@@ -392,9 +392,6 @@ class DataConfig:
default_factory=list
) # Axis permutation for training data (e.g., [2,1,0] for xyz->zyx)
val_transpose: List[int] = field(default_factory=list) # Axis permutation for validation data
test_transpose: List[int] = field(
default_factory=list
) # Axis permutation for test data (deprecated, use inference.data.test_transpose)

# Dataset statistics (for auto-planning)
target_spacing: Optional[List[float]] = None # Target voxel spacing [z, y, x] in mm
@@ -868,9 +865,6 @@ class TestTimeAugmentationConfig:
flip_axes: Any = (
None # TTA flip strategy: "all" (8 flips), null (no aug), or list like [[0], [1], [2]]
)
act: Optional[str] = (
None # Single activation for all channels: 'softmax', 'sigmoid', 'tanh', None (deprecated, use channel_activations)
)
channel_activations: Optional[List[Any]] = (
None # Per-channel activations: [[start_ch, end_ch, 'activation'], ...] e.g., [[0, 2, 'softmax'], [2, 3, 'sigmoid'], [3, 4, 'tanh']]
)
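
For illustration, the `[[start_ch, end_ch, 'activation'], ...]` format above could be applied like this (a sketch; the helper itself is hypothetical, not this module's API):

```python
import torch

def apply_channel_activations(pred: torch.Tensor, channel_activations) -> torch.Tensor:
    """pred: (B, C, ...) logits; channel_activations: [[start, end, name], ...]."""
    acts = {
        "softmax": lambda x: torch.softmax(x, dim=1),
        "sigmoid": torch.sigmoid,
        "tanh": torch.tanh,
    }
    out = pred.clone()
    for start, end, name in channel_activations or []:
        out[:, start:end] = acts[name](pred[:, start:end])
    return out

# e.g. apply_channel_activations(logits, [[0, 2, "softmax"], [2, 3, "sigmoid"], [3, 4, "tanh"]])
```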
6 changes: 3 additions & 3 deletions connectomics/config/hydra_utils.py
@@ -244,8 +244,8 @@ def resolve_data_paths(cfg: Config) -> Config:
Supported paths:
- Training: cfg.data.train_path + cfg.data.train_image/train_label/train_mask
- Validation: cfg.data.val_path + cfg.data.val_image/val_label/val_mask
- Testing (legacy): cfg.data.test_path + cfg.data.test_image/test_label/test_mask
- Inference (primary): cfg.inference.data.test_path + cfg.inference.data.test_image/test_label/test_mask
- Testing: cfg.data.test_path + cfg.data.test_image/test_label/test_mask
- Inference: cfg.inference.data.test_path + cfg.inference.data.test_image/test_label/test_mask

Args:
cfg: Config object to resolve paths for
@@ -316,7 +316,7 @@ def _combine_path(base_path: str, file_path: Optional[Union[str, List[str]]]) ->
cfg.data.val_mask = _combine_path(cfg.data.val_path, cfg.data.val_mask)
cfg.data.val_json = _combine_path(cfg.data.val_path, cfg.data.val_json)

# Resolve test paths (legacy support for cfg.data.test_path)
# Resolve test paths
if cfg.data.test_path:
cfg.data.test_image = _combine_path(cfg.data.test_path, cfg.data.test_image)
cfg.data.test_label = _combine_path(cfg.data.test_path, cfg.data.test_label)
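
From these call sites, `_combine_path` presumably behaves like the following sketch (inferred; the real implementation may handle extra cases such as absolute paths):

```python
import os
from typing import List, Optional, Union

def _combine_path(
    base_path: str, file_path: Optional[Union[str, List[str]]]
) -> Optional[Union[str, List[str]]]:
    # Join a base directory with a single file or a list of files; pass None through.
    if file_path is None:
        return None
    if isinstance(file_path, list):
        return [os.path.join(base_path, f) for f in file_path]
    return os.path.join(base_path, file_path)
```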
17 changes: 6 additions & 11 deletions connectomics/data/augment/build.py
@@ -73,7 +73,7 @@ def build_train_transforms(
# Load images first (unless using pre-cached dataset)
if not skip_loading:
# Use appropriate loader based on dataset type
dataset_type = getattr(cfg.data, "dataset_type", "volume") # Default to volume for backward compatibility
dataset_type = getattr(cfg.data, "dataset_type", "volume")

if dataset_type == "filename":
# For filename-based datasets (PNG, JPG, etc.), use MONAI's LoadImaged
@@ -94,12 +94,9 @@
transforms.append(ApplyVolumetricSplitd(keys=keys))

# Apply resize if configured (before cropping)
# Check data_transform first (new), then fall back to image_transform.resize (legacy)
resize_size = None
if hasattr(cfg.data, "data_transform") and hasattr(cfg.data.data_transform, "resize") and cfg.data.data_transform.resize is not None:
resize_size = cfg.data.data_transform.resize
elif hasattr(cfg.data.image_transform, "resize") and cfg.data.image_transform.resize is not None:
resize_size = cfg.data.image_transform.resize

if resize_size:
# Use bilinear for images, nearest for labels/masks
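
The branch below the fold presumably builds something like this (a sketch using MONAI's `Resized`; the per-key mode mapping is an assumption based on the comment above, and 3D volumes would use `trilinear` rather than `bilinear` for image keys):

```python
from monai.transforms import Resized

if resize_size:
    transforms.append(
        Resized(
            keys=keys,
            spatial_size=resize_size,
            # Smooth interpolation for images, nearest for labels/masks.
            mode=["bilinear" if k == "image" else "nearest" for k in keys],
        )
    )
```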
@@ -247,7 +244,7 @@ def _build_eval_transforms_impl(
transforms = []

# Load images first - use appropriate loader based on dataset type
dataset_type = getattr(cfg.data, "dataset_type", "volume") # Default to volume for backward compatibility
dataset_type = getattr(cfg.data, "dataset_type", "volume")

if dataset_type == "filename":
# For filename-based datasets (PNG, JPG, etc.), use MONAI's LoadImaged
@@ -260,17 +257,15 @@
if mode == "val":
transpose_axes = cfg.data.val_transpose if cfg.data.val_transpose else []
else: # mode == "test"
# Check both data.test_transpose and inference.data.test_transpose
# Use inference.data.test_transpose
transpose_axes = []
if cfg.data.test_transpose:
transpose_axes = cfg.data.test_transpose
if (
hasattr(cfg, "inference")
and hasattr(cfg.inference, "data")
and hasattr(cfg.inference.data, "test_transpose")
and cfg.inference.data.test_transpose
):
transpose_axes = cfg.inference.data.test_transpose # inference takes precedence
transpose_axes = cfg.inference.data.test_transpose

transforms.append(
LoadVolumed(keys=keys, transpose_axes=transpose_axes if transpose_axes else None)
@@ -455,8 +450,8 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str], do_2d: bo
List of MONAI transforms
"""
transforms = []
# Get preset mode (default to "some" for backward compatibility)

# Get preset mode
preset = getattr(aug_cfg, "preset", "some")

# Helper function to check if augmentation should be applied
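
The helper itself is hidden by the fold; a hypothetical version consistent with the `preset` default might look like this (the "all"/"some"/"none" semantics are an assumption):

```python
def _should_apply(aug_cfg, name: str, preset: str) -> bool:
    # "none" disables everything, "all" enables everything; under "some",
    # only augmentations explicitly enabled in the config are applied.
    if preset == "none":
        return False
    if preset == "all":
        return True
    aug = getattr(aug_cfg, name, None)
    return aug is not None and getattr(aug, "enabled", False)
```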
34 changes: 2 additions & 32 deletions connectomics/lightning/callbacks.py
@@ -428,25 +428,7 @@ def create_callbacks(cfg) -> list:
callbacks.append(vis_callback)

# Model checkpoint callback
# Support both new unified config (training.checkpoint_*) and old separate config (checkpoint.*)
if hasattr(cfg, 'checkpoint') and cfg.checkpoint is not None:
# Old config style (backward compatibility)
monitor = getattr(cfg.checkpoint, 'monitor', 'val/loss')
default_filename = f'epoch={{epoch:03d}}-{monitor}={{{monitor}:.4f}}'
filename = getattr(cfg.checkpoint, 'filename', default_filename)

checkpoint_callback = ModelCheckpoint(
monitor=monitor,
mode=getattr(cfg.checkpoint, 'mode', 'min'),
save_top_k=getattr(cfg.checkpoint, 'save_top_k', 3),
save_last=getattr(cfg.checkpoint, 'save_last', True),
dirpath=getattr(cfg.checkpoint, 'dirpath', 'checkpoints'),
filename=filename,
verbose=True
)
callbacks.append(checkpoint_callback)
elif hasattr(cfg, 'monitor') and hasattr(cfg.monitor, 'checkpoint'):
# New unified config style (monitor.checkpoint.*)
if hasattr(cfg, 'monitor') and hasattr(cfg.monitor, 'checkpoint'):
monitor = getattr(cfg.monitor.checkpoint, 'monitor', 'val/loss')
filename = getattr(cfg.monitor.checkpoint, 'filename', None)
if filename is None:
@@ -465,19 +447,7 @@
callbacks.append(checkpoint_callback)

# Early stopping callback
# Support both new unified config (training.early_stopping_*) and old separate config (early_stopping.*)
if hasattr(cfg, 'early_stopping') and cfg.early_stopping is not None and cfg.early_stopping.enabled:
# Old config style (backward compatibility)
early_stop_callback = EarlyStopping(
monitor=getattr(cfg.early_stopping, 'monitor', 'val/loss'),
patience=getattr(cfg.early_stopping, 'patience', 10),
mode=getattr(cfg.early_stopping, 'mode', 'min'),
min_delta=getattr(cfg.early_stopping, 'min_delta', 0.0),
verbose=True
)
callbacks.append(early_stop_callback)
elif hasattr(cfg, 'monitor') and hasattr(cfg.monitor, 'early_stopping') and getattr(cfg.monitor.early_stopping, 'enabled', False):
# New unified config style (monitor.early_stopping.*)
if hasattr(cfg, 'monitor') and hasattr(cfg.monitor, 'early_stopping') and getattr(cfg.monitor.early_stopping, 'enabled', False):
early_stop_callback = EarlyStopping(
monitor=getattr(cfg.monitor.early_stopping, 'monitor', 'val/loss'),
patience=getattr(cfg.monitor.early_stopping, 'patience', 10),
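
With the old branches gone, `create_callbacks` reads only the unified `monitor.*` layout. A sketch of that config shape (key names mirror the defaults visible in the removed code; a real config carries many more sections):

```python
from omegaconf import OmegaConf

cfg_fragment = OmegaConf.create({
    "monitor": {
        "checkpoint": {
            "monitor": "val/loss",  # metric to track
            "mode": "min",
            "save_top_k": 3,
            "save_last": True,
            "dirpath": "checkpoints",
        },
        "early_stopping": {
            "enabled": True,
            "monitor": "val/loss",
            "patience": 10,
            "mode": "min",
            "min_delta": 0.0,
        },
    }
})
```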