diff --git a/modules/dynunet_pipeline/transforms.py b/modules/dynunet_pipeline/transforms.py index 9ef5ce07c7..52b95afcb3 100644 --- a/modules/dynunet_pipeline/transforms.py +++ b/modules/dynunet_pipeline/transforms.py @@ -67,11 +67,22 @@ def get_task_transforms(mode, task_id, pos_sample_num, neg_sample_num, num_sampl RandFlipd(["image", "label"], spatial_axis=[0], prob=0.5), RandFlipd(["image", "label"], spatial_axis=[1], prob=0.5), RandFlipd(["image", "label"], spatial_axis=[2], prob=0.5), + CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)), + EnsureTyped(keys=["image", "label"]), + ] + elif mode == "validation": + other_transforms = [ + CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)), + EnsureTyped(keys=["image", "label"]), ] - - return Compose(load_transforms + sample_transforms + other_transforms) else: - return Compose(load_transforms + sample_transforms) + other_transforms = [ + CastToTyped(keys=["image"], dtype=(np.float32)), + EnsureTyped(keys=["image"]), + ] + + all_transforms = load_transforms + sample_transforms + other_transforms + return Compose(all_transforms) def resample_image(image, shape, anisotrophy_flag): resized_channels = []