Add spatial shape constraint docs and test for SwinUNETR (#6771)#8817
Add spatial shape constraint docs and test for SwinUNETR (#6771)#8817Cado87 wants to merge 2 commits intoProject-MONAI:devfrom
Conversation
…AI#6771) Signed-off-by: Adrian Caderno <adriancaderno@gmail.com>
📝 WalkthroughWalkthroughUpdated Estimated code review effort🎯 2 (Simple) | ⏱️ ~12 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
🧹 Nitpick comments (1)
tests/networks/nets/test_swin_unetr.py (1)
93-103: Consider adding one non-defaultpatch_sizecase.This verifies the default (
patch_size=2) path well. Add one invalid-shape assertion forpatch_size=3to directly cover the generalizedpatch_size ** 5contract.Proposed test extension
def test_invalid_input_shape(self): # spatial dims not divisible by patch_size**5 (default patch_size=2, so must be divisible by 32) net = SwinUNETR(in_channels=1, out_channels=2, feature_size=24, spatial_dims=3) with self.assertRaises(ValueError): net(torch.randn(1, 1, 33, 64, 64)) # 33 is not divisible by 32 net_2d = SwinUNETR(in_channels=1, out_channels=2, feature_size=24, spatial_dims=2) with self.assertRaises(ValueError): net_2d(torch.randn(1, 1, 48, 33)) # 33 is not divisible by 32 + + net_2d_patch3 = SwinUNETR(in_channels=1, out_channels=2, feature_size=24, spatial_dims=2, patch_size=3) + with self.assertRaises(ValueError): + net_2d_patch3(torch.randn(1, 1, 33, 33)) # 33 is not divisible by 3**5 (243)As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/networks/nets/test_swin_unetr.py` around lines 93 - 103, Add a new invalid-shape assertion in the same test (test_invalid_input_shape) that instantiates SwinUNETR with a non-default patch_size (e.g., patch_size=3) and calls the model with an input tensor whose spatial dimensions are not divisible by that patch_size (e.g., use torch.randn(1, 1, 10, 48, 48) for spatial_dims=3) wrapped in self.assertRaises(ValueError); this will exercise the SwinUNETR constructor/check logic (SwinUNETR) for non-default patch sizes and ensure the generalized patch_size validation path is covered.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@tests/networks/nets/test_swin_unetr.py`:
- Around line 93-103: Add a new invalid-shape assertion in the same test
(test_invalid_input_shape) that instantiates SwinUNETR with a non-default
patch_size (e.g., patch_size=3) and calls the model with an input tensor whose
spatial dimensions are not divisible by that patch_size (e.g., use
torch.randn(1, 1, 10, 48, 48) for spatial_dims=3) wrapped in
self.assertRaises(ValueError); this will exercise the SwinUNETR
constructor/check logic (SwinUNETR) for non-default patch sizes and ensure the
generalized patch_size validation path is covered.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 8e650413-a978-4d3a-999c-17bfa4eb13df
📒 Files selected for processing (2)
monai/networks/nets/swin_unetr.pytests/networks/nets/test_swin_unetr.py
Fixes #6771 .
Description
Adds documentation of spatial shape constraints for
SwinUNETR. Each inputspatial dimension must be divisible by
patch_size ** 5(32 by default withpatch_size=2). The runtime validation logic already existed in_check_input_size()but was undocumented. This PR adds a
Spatial Shape Constraintssection to theclass docstring, updates the
patch_sizearg description in__init__, and addsa test to verify that
forward()raisesValueErrorfor invalid spatial shapes.Types of changes
./runtests.sh -f -u --net --coverage../runtests.sh --quick --unittests --disttests.make htmlcommand in thedocs/folder.