Skip to content
Merged
2 changes: 2 additions & 0 deletions detection/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ python3 luna16_training.py \
-e ./config/environment_luna16_fold${i}.json \
-c ./config/config_train_luna16_16g.json
```
If you are tuning hyper-parameters, please also add `--verbose` flag.
Details about matched anchors during training will be printed out.

For each fold, 95% of the training data is used for training, while the rest 5% is used for validation and model selection.
The training and validation curves for 300 epochs of 10 folds are shown below. The upper row shows the training losses for box regression and classification. The bottom row shows the validation mAP and mAR for IoU ranging from 0.1 to 0.5.
Expand Down
6 changes: 3 additions & 3 deletions detection/generate_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
RandRotateBox90d,
RandZoomBoxd,
ConvertBoxModed,
StandardizeEmptyBoxd,
)


Expand Down Expand Up @@ -70,7 +71,6 @@ def generate_detection_train_transform(
Return:
training transform for detection
"""
amp = True
if amp:
compute_dtype = torch.float16
else:
Expand All @@ -82,6 +82,7 @@ def generate_detection_train_transform(
EnsureChannelFirstd(keys=[image_key]),
EnsureTyped(keys=[image_key, box_key], dtype=torch.float32),
EnsureTyped(keys=[label_key], dtype=torch.long),
StandardizeEmptyBoxd(box_keys=[box_key], box_ref_image_keys=image_key),
Orientationd(keys=[image_key], axcodes="RAS"),
intensity_transform,
EnsureTyped(keys=[image_key], dtype=torch.float16),
Expand Down Expand Up @@ -216,7 +217,6 @@ def generate_detection_val_transform(
Return:
validation transform for detection
"""
amp = True
if amp:
compute_dtype = torch.float16
else:
Expand All @@ -228,6 +228,7 @@ def generate_detection_val_transform(
EnsureChannelFirstd(keys=[image_key]),
EnsureTyped(keys=[image_key, box_key], dtype=torch.float32),
EnsureTyped(keys=[label_key], dtype=torch.long),
StandardizeEmptyBoxd(box_keys=[box_key], box_ref_image_keys=image_key),
Orientationd(keys=[image_key], axcodes="RAS"),
intensity_transform,
ConvertBoxToStandardModed(box_keys=[box_key], mode=gt_box_mode),
Expand Down Expand Up @@ -272,7 +273,6 @@ def generate_detection_inference_transform(
Return:
validation transform for detection
"""
amp = True
if amp:
compute_dtype = torch.float16
else:
Expand Down
9 changes: 8 additions & 1 deletion detection/luna16_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ def main():
default="./config/config_train.json",
help="config json file that stores hyper-parameters",
)
parser.add_argument(
"-v",
"--verbose",
default=False,
action="store_true",
help="whether to print verbose detail during training, recommand True when you are not sure about hyper-parameters",
)
args = parser.parse_args()

set_determinism(seed=0)
Expand Down Expand Up @@ -188,7 +195,7 @@ def main():
)

# 3) build detector
detector = RetinaNetDetector(network=net, anchor_generator=anchor_generator, debug=False).to(device)
detector = RetinaNetDetector(network=net, anchor_generator=anchor_generator, debug=args.verbose).to(device)

# set training components
detector.set_atss_matcher(num_candidates=4, center_in_gt=False)
Expand Down