Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
    http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

In [1]:
import torch
import numpy as np
from monai.transforms import (
    Compose,
    DeleteItemsd,
    EnsureChannelFirstd,
    EnsureTyped,
    LoadImaged,
    Orientationd,
    RandAdjustContrastd,
    RandGaussianNoised,
    RandGaussianSmoothd,
    RandRotated,
    RandScaleIntensityd,
    RandShiftIntensityd,
)
from monai.apps.detection.transforms.dictionary import (
    AffineBoxToImageCoordinated,
    BoxToMaskd,
    ClipBoxToImaged,
    ConvertBoxToStandardModed,
    MaskToBoxd,
    RandCropBoxByPosNegLabeld,
    RandFlipBoxd,
    RandRotateBox90d,
    RandZoomBoxd,
    StandardizeEmptyBoxd,
)

from generate_transforms import generate_detection_train_transform as generate_detection_train_transform_v1

# old version
def generate_detection_train_transform(
    image_key,
    box_key,
    label_key,
    gt_box_mode,
    intensity_transform,
    patch_size,
    batch_size,
    affine_lps_to_ras=False,
    amp=True,
):
    """
    Build the legacy ("old version") training transform pipeline for box detection.

    The returned ``Compose`` runs, in order: image loading and standardization,
    conversion of ground-truth boxes into image coordinates, patch cropping with
    box-aware spatial augmentation, a rotation applied through a temporary
    ellipse mask rendered from the boxes, intensity augmentation, and a final
    dtype cast.

    Args:
        image_key: the key to represent images in the input json files
        box_key: the key to represent boxes in the input json files
        label_key: the key to represent box labels in the input json files
        gt_box_mode: ground truth box mode in the input json files
        intensity_transform: transform to scale image intensities,
            usually ScaleIntensityRanged for CT images, and NormalizeIntensityd for MR images.
        patch_size: cropped patch size for training
        batch_size: number of cropped patches from each image
        affine_lps_to_ras: Usually False.
            Set True only when the original images were read by itkreader with affine_lps_to_ras=True
        amp: whether to use half precision (selects the dtype of the final
            image/box cast: float16 if True, float32 otherwise)

    Return:
        training transform for detection
    """
    # dtype used by the very last image/box cast of the pipeline
    compute_dtype = torch.float16 if amp else torch.float32

    # Stage 1: read the image, normalize layout/dtypes, reorient, rescale intensity.
    load_stage = [
        LoadImaged(keys=[image_key], image_only=False, meta_key_postfix="meta_dict"),
        EnsureChannelFirstd(keys=[image_key]),
        EnsureTyped(keys=[image_key, box_key], dtype=torch.float32),
        EnsureTyped(keys=[label_key], dtype=torch.long),
        StandardizeEmptyBoxd(box_keys=[box_key], box_ref_image_keys=image_key),
        Orientationd(keys=[image_key], axcodes="RAS"),
        intensity_transform,
        # NOTE(review): this cast is float16 regardless of `amp`; with amp=False the
        # image is still rounded to half precision here before the final float32 cast.
        # Confirm this is an intentional memory saving.
        EnsureTyped(keys=[image_key], dtype=torch.float16),
    ]

    # Stage 2: bring ground-truth boxes into the standard mode and image coordinates.
    box_stage = [
        ConvertBoxToStandardModed(box_keys=[box_key], mode=gt_box_mode),
        AffineBoxToImageCoordinated(
            box_keys=[box_key],
            box_ref_image_keys=image_key,
            image_meta_key_postfix="meta_dict",
            affine_lps_to_ras=affine_lps_to_ras,
        ),
    ]

    # Stage 3: sample `batch_size` patches, then zoom / clip / flip / rotate-by-90
    # the image and its boxes together.
    flip_transforms = [
        RandFlipBoxd(
            image_keys=[image_key],
            box_keys=[box_key],
            box_ref_image_keys=[image_key],
            prob=0.5,
            spatial_axis=axis,
        )
        for axis in (0, 1, 2)
    ]
    crop_stage = [
        RandCropBoxByPosNegLabeld(
            image_keys=[image_key],
            box_keys=box_key,
            label_keys=label_key,
            spatial_size=patch_size,
            whole_box=True,
            num_samples=batch_size,
            pos=1,
            neg=1,
        ),
        RandZoomBoxd(
            image_keys=[image_key],
            box_keys=[box_key],
            box_ref_image_keys=[image_key],
            prob=0.2,
            min_zoom=0.7,
            max_zoom=1.4,
            padding_mode="constant",
            keep_size=True,
        ),
        ClipBoxToImaged(
            box_keys=box_key,
            label_keys=[label_key],
            box_ref_image_keys=image_key,
            remove_empty=True,
        ),
        *flip_transforms,
        RandRotateBox90d(
            image_keys=[image_key],
            box_keys=[box_key],
            box_ref_image_keys=[image_key],
            prob=0.75,
            max_k=3,
            spatial_axes=(0, 1),
        ),
    ]

    # Stage 4: arbitrary-angle rotation. Boxes are rendered into an ellipse mask,
    # rotated alongside the image, converted back to boxes, and the mask is dropped.
    rotation_stage = [
        BoxToMaskd(
            box_keys=[box_key],
            label_keys=[label_key],
            box_mask_keys=["box_mask"],
            box_ref_image_keys=image_key,
            min_fg_label=0,
            ellipse_mask=True,
        ),
        RandRotated(
            keys=[image_key, "box_mask"],
            mode=["nearest", "nearest"],
            prob=0.2,
            range_x=np.pi / 6,
            range_y=np.pi / 6,
            range_z=np.pi / 6,
            keep_size=True,
            padding_mode="zeros",
        ),
        MaskToBoxd(
            box_keys=[box_key],
            label_keys=[label_key],
            box_mask_keys=["box_mask"],
            min_fg_label=0,
        ),
        DeleteItemsd(keys=["box_mask"]),
    ]

    # Stage 5: image-only intensity augmentation.
    intensity_stage = [
        RandGaussianNoised(keys=[image_key], prob=0.1, mean=0, std=0.1),
        RandGaussianSmoothd(
            keys=[image_key],
            prob=0.1,
            sigma_x=(0.5, 1.0),
            sigma_y=(0.5, 1.0),
            sigma_z=(0.5, 1.0),
        ),
        RandScaleIntensityd(keys=[image_key], prob=0.15, factors=0.25),
        RandShiftIntensityd(keys=[image_key], prob=0.15, offsets=0.1),
        RandAdjustContrastd(keys=[image_key], prob=0.3, gamma=(0.7, 1.5)),
    ]

    # Stage 6: final dtypes handed to the training loop.
    cast_stage = [
        EnsureTyped(keys=[image_key, box_key], dtype=compute_dtype),
        EnsureTyped(keys=[label_key], dtype=torch.long),
    ]

    # Order is significant: Compose seeds its random transforms sequentially.
    return Compose(load_stage + box_stage + crop_stage + rotation_stage + intensity_stage + cast_stage)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import json
from monai.data import Dataset, load_decathlon_datalist
from monai.transforms import ScaleIntensityRanged
from monai.utils import set_determinism

set_determinism(seed=24)

environment_file = "/workspace/Code/tutorials/detection/config/environment_luna16_fold0.json"
config_file = "/workspace/Code/tutorials/detection/config/config_train_luna16_16g.json"

# Read both JSON configs inside `with` blocks so the file handles are closed
# right away instead of being left open until garbage collection.
with open(environment_file, "r") as env_fp:
    env_dict = json.load(env_fp)
with open(config_file, "r") as config_fp:
    config_dict = json.load(config_fp)

# Merge the two configs. A key present in both takes the value from
# `config_file`, matching the original copy order (env first, then config).
args = {**env_dict, **config_dict}

# Linearly map intensities in [-1024, 300] to [0, 1] and clip values outside that
# window (presumably CT Hounsfield units for LUNA16; confirm against the data).
intensity_transform = ScaleIntensityRanged(
    keys=["image"],
    a_min=-1024,
    a_max=300.0,
    b_min=0.0,
    b_max=1.0,
    clip=True,
)

train_transforms_old = generate_detection_train_transform(
    "image",
    "box",
    "label",
    args["gt_box_mode"],
    intensity_transform,
    args["patch_size"],
    args["batch_size"],
    affine_lps_to_ras=True,
    amp=True,
)

# The imported v1 generator takes an extra positional key ("points") after the label key.
train_transforms_new_v1 = generate_detection_train_transform_v1(
    "image",
    "box",
    "label",
    "points",
    args["gt_box_mode"],
    intensity_transform,
    args["patch_size"],
    args["batch_size"],
    affine_lps_to_ras=True,
    amp=True,
)


train_data = load_decathlon_datalist(
    args["data_list_file_path"],
    is_segmentation=True,
    data_list_key="training",
    base_dir=args["data_base_dir"],
)
# Two datasets over the same data list, differing only in their transform, so
# their outputs can be compared side by side.
train_ds_old = Dataset(
    data=train_data,
    transform=train_transforms_old,
)
train_ds_new_v1 = Dataset(
    data=train_data,
    transform=train_transforms_new_v1,
)

In [3]:
import numpy as np

np.set_printoptions(precision=3, suppress=True)


def _print_first_box(dataset, title, show_type=False):
    """Print the "box" entry of the first crop of the first sample in `dataset`.

    Prints nothing if `dataset` yields no samples. When `show_type` is True, the
    Python type of the box container is printed as well.
    """
    for index, sample in enumerate(dataset):
        print(f"{title} box {index}: ")
        first_crop_box = sample[0]["box"]
        print(first_crop_box)
        if show_type:
            print(type(first_crop_box))
        return  # only the first sample is inspected


_print_first_box(train_ds_old, "old version")
_print_first_box(train_ds_new_v1, "new version v1", show_type=True)

old version box 0: 
metatensor([[117.,   0.,  63., 122.,   1.,  65.],
        [130.,  27.,  42., 144.,  41.,  51.]], dtype=torch.float16)
new version v1 box 0: 
metatensor([[117.,   0.,  63., 122.,   1.,  65.],
        [130.,  27.,  42., 144.,  41.,  51.]])
<class 'monai.data.meta_tensor.MetaTensor'>
