In [1]:
import torch
import json
from generate_transforms import (
    generate_detection_train_transform,
    generate_detection_train_transform_v1,
    generate_detection_train_transform_v2,
)
from monai.data import Dataset, load_decathlon_datalist
from monai.transforms import ScaleIntensityRanged
from monai.utils import set_determinism

set_determinism(seed=24)

# NOTE(review): absolute, machine-specific location; consider making this configurable.
CONFIG_DIR = "/workspace/Code/tutorials/detection/config"
environment_file = f"{CONFIG_DIR}/environment_luna16_fold0.json"
config_file = f"{CONFIG_DIR}/config_train_luna16_16g.json"

# Use context managers so both file handles are closed after parsing
# (json.load(open(...)) leaves closing to the garbage collector).
with open(environment_file, "r") as env_fh:
    env_dict = json.load(env_fh)
with open(config_file, "r") as cfg_fh:
    config_dict = json.load(cfg_fh)

# Merge both configs into one flat dict. config_dict is unpacked last, so on a
# duplicate key its value wins (same precedence as the original two loops).
args = {**env_dict, **config_dict}

# Map "image" intensities linearly from the window [-1024, 300] onto [0, 1];
# anything outside the window is clipped to the nearest bound.
# (Window presumably in Hounsfield units for LUNA16 CT — confirm.)
intensity_transform = ScaleIntensityRanged(
    keys=["image"],
    clip=True,
    # source window
    a_min=-1024, a_max=300.0,
    # target range
    b_min=0.0, b_max=1.0,
)

# All three pipelines must receive identical settings so that differences in
# their output boxes reflect the transform implementation only. Build the shared
# arguments once instead of copy-pasting them into each call.
shared_args = (
    args["gt_box_mode"],
    intensity_transform,
    args["patch_size"],
    args["batch_size"],
)
shared_kwargs = {"affine_lps_to_ras": True, "amp": True}

# Baseline pipeline: data keys are ("image", "box", "label").
train_transforms_old = generate_detection_train_transform(
    "image", "box", "label", *shared_args, **shared_kwargs
)

# v1 / v2 take one extra key name, "point", before the shared arguments
# (its role is defined in generate_transforms — not visible here).
# NOTE(review): with amp=True the recorded output shows old boxes as float16,
# but v1 as the default float dtype and v2 as float64 — verify v1/v2 honor `amp`.
train_transforms_new_v1 = generate_detection_train_transform_v1(
    "image", "box", "label", "point", *shared_args, **shared_kwargs
)
train_transforms_new_v2 = generate_detection_train_transform_v2(
    "image", "box", "label", "point", *shared_args, **shared_kwargs
)

# Read the "training" split of the datalist; relative entries are resolved
# against args["data_base_dir"].
train_data = load_decathlon_datalist(
    args["data_list_file_path"],
    base_dir=args["data_base_dir"],
    data_list_key="training",
    is_segmentation=True,
)

# One Dataset per pipeline, all wrapping the same list of samples so their
# outputs can be compared item by item.
train_ds_old = Dataset(transform=train_transforms_old, data=train_data)
train_ds_new_v1 = Dataset(transform=train_transforms_new_v1, data=train_data)
train_ds_new_v2 = Dataset(transform=train_transforms_new_v2, data=train_data)

  from torch.distributed.optim import ZeroRedundancyOptimizer
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import numpy as np
import torch

# The boxes printed below are torch (Meta)Tensors, whose repr follows torch's
# print options, not numpy's (the recorded v2 output shows 4 decimals despite
# precision=3). Set both so numpy arrays and tensors print the same way.
np.set_printoptions(precision=3, suppress=True)
torch.set_printoptions(precision=3, sci_mode=False)


def show_first_box(label, dataset):
    """Print the "box" entry of the first result for item 0 of `dataset`.

    Indexing item 0 runs the dataset's transform once; ``[0]`` then takes the
    first element of that result (presumably the first of several random
    crops — confirm against the transform pipeline).
    """
    if len(dataset) == 0:
        print(f"{label}: dataset is empty")
        return
    first_item = dataset[0]
    print(f"{label} box 0: ")
    print(first_item[0]["box"])


for ds_label, ds in (
    ("old version", train_ds_old),
    ("new version v1", train_ds_new_v1),
    ("new version v2", train_ds_new_v2),
):
    show_first_box(ds_label, ds)

old version box 0: 
metatensor([[117.,   0.,  63., 122.,   1.,  65.],
        [130.,  27.,  42., 144.,  41.,  51.]], dtype=torch.float16)
new version v1 box 0: 
metatensor([[117.,   0.,  63., 122.,   1.,  65.],
        [130.,  27.,  42., 144.,  41.,  51.]])
new version v2 box 0: 
metatensor([[112.3040,   0.0000,  57.1339, 129.3148,   2.9282,  70.1371],
        [128.4112,  24.6803,  40.4263, 146.4434,  42.5999,  54.2102]],
       dtype=torch.float64)
