From 57f8dbf2eb9655a7ff58826190a8a987a82ad3fb Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 13 Aug 2020 21:17:55 +0800 Subject: [PATCH] 895 Update temp directory function in all tests and examples (#899) * [DLMED] update tempdir * [DLMED] update tempdir in examples * [DLMED] fix typo * [MONAI] python code formatting Co-authored-by: monai-bot --- .../segmentation_3d/unet_evaluation_array.py | 92 +++--- .../segmentation_3d/unet_evaluation_dict.py | 130 ++++---- .../segmentation_3d/unet_training_array.py | 255 +++++++-------- .../segmentation_3d/unet_training_dict.py | 277 ++++++++-------- .../unet_evaluation_array.py | 148 +++++---- .../unet_evaluation_dict.py | 160 +++++----- .../unet_training_array.py | 232 +++++++------- .../unet_training_dict.py | 296 +++++++++--------- examples/workflows/unet_evaluation_dict.py | 128 ++++---- examples/workflows/unet_training_dict.py | 237 +++++++------- tests/test_arraydataset.py | 175 +++++------ tests/test_cachedataset.py | 47 ++- tests/test_cachedataset_parallel.py | 27 +- tests/test_check_md5.py | 13 +- tests/test_csv_saver.py | 36 +-- tests/test_data_stats.py | 38 ++- tests/test_data_statsd.py | 40 ++- tests/test_dataset.py | 86 +++-- tests/test_handler_checkpoint_loader.py | 48 ++- tests/test_handler_classification_saver.py | 52 ++- tests/test_handler_segmentation_saver.py | 80 +++-- tests/test_handler_stats.py | 44 ++- tests/test_handler_tb_image.py | 27 +- tests/test_handler_tb_stats.py | 86 +++-- tests/test_img2tensorboard.py | 52 ++- tests/test_integration_sliding_window.py | 16 +- tests/test_load_decathalon_datalist.py | 144 +++++---- tests/test_load_nifti.py | 12 +- tests/test_load_niftid.py | 13 +- tests/test_load_numpy.py | 51 ++- tests/test_load_numpyd.py | 51 ++- tests/test_load_png.py | 12 +- tests/test_load_pngd.py | 14 +- tests/test_nifti_dataset.py | 162 +++++----- tests/test_nifti_rw.py | 123 ++++---- tests/test_nifti_saver.py | 68 ++-- tests/test_persistentdataset.py | 52 ++- tests/test_plot_2d_or_3d_image.py | 11 +- tests/test_png_rw.py | 89 +++--- tests/test_png_saver.py | 62 ++-- 40 files changed, 1782 insertions(+), 1904 deletions(-) diff --git a/examples/segmentation_3d/unet_evaluation_array.py b/examples/segmentation_3d/unet_evaluation_array.py index 50bf7f0a0a..332667a8ea 100644 --- a/examples/segmentation_3d/unet_evaluation_array.py +++ b/examples/segmentation_3d/unet_evaluation_array.py @@ -11,7 +11,6 @@ import logging import os -import shutil import sys import tempfile from glob import glob @@ -33,58 +32,57 @@ def main(): config.print_config() logging.basicConfig(stream=sys.stdout, level=logging.INFO) - tempdir = tempfile.mkdtemp() - print(f"generating synthetic data to {tempdir} (this may take a while)") - for i in range(5): - im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1) + with tempfile.TemporaryDirectory() as tempdir: + print(f"generating synthetic data to {tempdir} (this may take a while)") + for i in range(5): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1) - n = nib.Nifti1Image(im, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz")) + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz")) - n = nib.Nifti1Image(seg, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) + n = nib.Nifti1Image(seg, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) - images = sorted(glob(os.path.join(tempdir, "im*.nii.gz"))) - segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) + images = sorted(glob(os.path.join(tempdir, "im*.nii.gz"))) + segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) - # define transforms for image and segmentation - imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()]) - segtrans = Compose([AddChannel(), ToTensor()]) - val_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False) - # sliding window inference for one image at every iteration - val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available()) - dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean") + # define transforms for image and segmentation + imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()]) + segtrans = Compose([AddChannel(), ToTensor()]) + val_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False) + # sliding window inference for one image at every iteration + val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available()) + dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean") - device = torch.device("cuda:0") - model = UNet( - dimensions=3, - in_channels=1, - out_channels=1, - channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2), - num_res_units=2, - ).to(device) + device = torch.device("cuda:0") + model = UNet( + dimensions=3, + in_channels=1, + out_channels=1, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + num_res_units=2, + ).to(device) - model.load_state_dict(torch.load("best_metric_model.pth")) - model.eval() - with torch.no_grad(): - metric_sum = 0.0 - metric_count = 0 - saver = NiftiSaver(output_dir="./output") - for val_data in val_loader: - val_images, val_labels = val_data[0].to(device), val_data[1].to(device) - # define sliding window size and batch size for windows inference - roi_size = (96, 96, 96) - sw_batch_size = 4 - val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) - value = dice_metric(y_pred=val_outputs, y=val_labels) - metric_count += len(value) - metric_sum += value.item() * len(value) - val_outputs = (val_outputs.sigmoid() >= 0.5).float() - saver.save_batch(val_outputs, val_data[2]) - metric = metric_sum / metric_count - print("evaluation metric:", metric) - shutil.rmtree(tempdir) + model.load_state_dict(torch.load("best_metric_model.pth")) + model.eval() + with torch.no_grad(): + metric_sum = 0.0 + metric_count = 0 + saver = NiftiSaver(output_dir="./output") + for val_data in val_loader: + val_images, val_labels = val_data[0].to(device), val_data[1].to(device) + # define sliding window size and batch size for windows inference + roi_size = (96, 96, 96) + sw_batch_size = 4 + val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) + value = dice_metric(y_pred=val_outputs, y=val_labels) + metric_count += len(value) + metric_sum += value.item() * len(value) + val_outputs = (val_outputs.sigmoid() >= 0.5).float() + saver.save_batch(val_outputs, val_data[2]) + metric = metric_sum / metric_count + print("evaluation metric:", metric) if __name__ == "__main__": diff --git a/examples/segmentation_3d/unet_evaluation_dict.py b/examples/segmentation_3d/unet_evaluation_dict.py index 1ec736a033..fa86760a30 100644 --- a/examples/segmentation_3d/unet_evaluation_dict.py +++ b/examples/segmentation_3d/unet_evaluation_dict.py @@ -11,7 +11,6 @@ import logging import os -import shutil import sys import tempfile from glob import glob @@ -34,71 +33,70 @@ def main(): monai.config.print_config() logging.basicConfig(stream=sys.stdout, level=logging.INFO) - tempdir = tempfile.mkdtemp() - print(f"generating synthetic data to {tempdir} (this may take a while)") - for i in range(5): - im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) - - n = nib.Nifti1Image(im, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz")) - - n = nib.Nifti1Image(seg, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) - - images = sorted(glob(os.path.join(tempdir, "im*.nii.gz"))) - segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) - val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)] - - # define transforms for image and segmentation - val_transforms = Compose( - [ - LoadNiftid(keys=["img", "seg"]), - AsChannelFirstd(keys=["img", "seg"], channel_dim=-1), - ScaleIntensityd(keys="img"), - ToTensord(keys=["img", "seg"]), - ] - ) - val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) - # sliding window inference need to input 1 image in every iteration - val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate) - dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean") - - # try to use all the available GPUs - devices = get_devices_spec(None) - model = UNet( - dimensions=3, - in_channels=1, - out_channels=1, - channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2), - num_res_units=2, - ).to(devices[0]) - - model.load_state_dict(torch.load("best_metric_model.pth")) - - # if we have multiple GPUs, set data parallel to execute sliding window inference - if len(devices) > 1: - model = torch.nn.DataParallel(model, device_ids=devices) - - model.eval() - with torch.no_grad(): - metric_sum = 0.0 - metric_count = 0 - saver = NiftiSaver(output_dir="./output") - for val_data in val_loader: - val_images, val_labels = val_data["img"].to(devices[0]), val_data["seg"].to(devices[0]) - # define sliding window size and batch size for windows inference - roi_size = (96, 96, 96) - sw_batch_size = 4 - val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) - value = dice_metric(y_pred=val_outputs, y=val_labels) - metric_count += len(value) - metric_sum += value.item() * len(value) - val_outputs = (val_outputs.sigmoid() >= 0.5).float() - saver.save_batch(val_outputs, val_data["img_meta_dict"]) - metric = metric_sum / metric_count - print("evaluation metric:", metric) - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + print(f"generating synthetic data to {tempdir} (this may take a while)") + for i in range(5): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) + + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz")) + + n = nib.Nifti1Image(seg, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) + + images = sorted(glob(os.path.join(tempdir, "im*.nii.gz"))) + segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) + val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)] + + # define transforms for image and segmentation + val_transforms = Compose( + [ + LoadNiftid(keys=["img", "seg"]), + AsChannelFirstd(keys=["img", "seg"], channel_dim=-1), + ScaleIntensityd(keys="img"), + ToTensord(keys=["img", "seg"]), + ] + ) + val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) + # sliding window inference need to input 1 image in every iteration + val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate) + dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean") + + # try to use all the available GPUs + devices = get_devices_spec(None) + model = UNet( + dimensions=3, + in_channels=1, + out_channels=1, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + num_res_units=2, + ).to(devices[0]) + + model.load_state_dict(torch.load("best_metric_model.pth")) + + # if we have multiple GPUs, set data parallel to execute sliding window inference + if len(devices) > 1: + model = torch.nn.DataParallel(model, device_ids=devices) + + model.eval() + with torch.no_grad(): + metric_sum = 0.0 + metric_count = 0 + saver = NiftiSaver(output_dir="./output") + for val_data in val_loader: + val_images, val_labels = val_data["img"].to(devices[0]), val_data["seg"].to(devices[0]) + # define sliding window size and batch size for windows inference + roi_size = (96, 96, 96) + sw_batch_size = 4 + val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) + value = dice_metric(y_pred=val_outputs, y=val_labels) + metric_count += len(value) + metric_sum += value.item() * len(value) + val_outputs = (val_outputs.sigmoid() >= 0.5).float() + saver.save_batch(val_outputs, val_data["img_meta_dict"]) + metric = metric_sum / metric_count + print("evaluation metric:", metric) if __name__ == "__main__": diff --git a/examples/segmentation_3d/unet_training_array.py b/examples/segmentation_3d/unet_training_array.py index 3aac22bff7..0b7e1bfd55 100644 --- a/examples/segmentation_3d/unet_training_array.py +++ b/examples/segmentation_3d/unet_training_array.py @@ -11,7 +11,6 @@ import logging import os -import shutil import sys import tempfile from glob import glob @@ -35,133 +34,135 @@ def main(): logging.basicConfig(stream=sys.stdout, level=logging.INFO) # create a temporary directory and 40 random image, mask paris - tempdir = tempfile.mkdtemp() - print(f"generating synthetic data to {tempdir} (this may take a while)") - for i in range(40): - im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1) - - n = nib.Nifti1Image(im, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz")) - - n = nib.Nifti1Image(seg, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) - - images = sorted(glob(os.path.join(tempdir, "im*.nii.gz"))) - segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) - - # define transforms for image and segmentation - train_imtrans = Compose( - [ - ScaleIntensity(), - AddChannel(), - RandSpatialCrop((96, 96, 96), random_size=False), - RandRotate90(prob=0.5, spatial_axes=(0, 2)), - ToTensor(), - ] - ) - train_segtrans = Compose( - [ - AddChannel(), - RandSpatialCrop((96, 96, 96), random_size=False), - RandRotate90(prob=0.5, spatial_axes=(0, 2)), - ToTensor(), - ] - ) - val_imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()]) - val_segtrans = Compose([AddChannel(), ToTensor()]) - - # define nifti dataset, data loader - check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans) - check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) - im, seg = monai.utils.misc.first(check_loader) - print(im.shape, seg.shape) - - # create a training data loader - train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans) - train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available()) - # create a validation data loader - val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans) - val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available()) - dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean") - - # create UNet, DiceLoss and Adam optimizer - device = torch.device("cuda:0") - model = monai.networks.nets.UNet( - dimensions=3, - in_channels=1, - out_channels=1, - channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2), - num_res_units=2, - ).to(device) - loss_function = monai.losses.DiceLoss(sigmoid=True) - optimizer = torch.optim.Adam(model.parameters(), 1e-3) - - # start a typical PyTorch training - val_interval = 2 - best_metric = -1 - best_metric_epoch = -1 - epoch_loss_values = list() - metric_values = list() - writer = SummaryWriter() - for epoch in range(5): - print("-" * 10) - print(f"epoch {epoch + 1}/{5}") - model.train() - epoch_loss = 0 - step = 0 - for batch_data in train_loader: - step += 1 - inputs, labels = batch_data[0].to(device), batch_data[1].to(device) - optimizer.zero_grad() - outputs = model(inputs) - loss = loss_function(outputs, labels) - loss.backward() - optimizer.step() - epoch_loss += loss.item() - epoch_len = len(train_ds) // train_loader.batch_size - print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}") - writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step) - epoch_loss /= step - epoch_loss_values.append(epoch_loss) - print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}") - - if (epoch + 1) % val_interval == 0: - model.eval() - with torch.no_grad(): - metric_sum = 0.0 - metric_count = 0 - val_images = None - val_labels = None - val_outputs = None - for val_data in val_loader: - val_images, val_labels = val_data[0].to(device), val_data[1].to(device) - roi_size = (96, 96, 96) - sw_batch_size = 4 - val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) - value = dice_metric(y_pred=val_outputs, y=val_labels) - metric_count += len(value) - metric_sum += value.item() * len(value) - metric = metric_sum / metric_count - metric_values.append(metric) - if metric > best_metric: - best_metric = metric - best_metric_epoch = epoch + 1 - torch.save(model.state_dict(), "best_metric_model.pth") - print("saved new best metric model") - print( - "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format( - epoch + 1, metric, best_metric, best_metric_epoch + with tempfile.TemporaryDirectory() as tempdir: + print(f"generating synthetic data to {tempdir} (this may take a while)") + for i in range(40): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1) + + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz")) + + n = nib.Nifti1Image(seg, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) + + images = sorted(glob(os.path.join(tempdir, "im*.nii.gz"))) + segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) + + # define transforms for image and segmentation + train_imtrans = Compose( + [ + ScaleIntensity(), + AddChannel(), + RandSpatialCrop((96, 96, 96), random_size=False), + RandRotate90(prob=0.5, spatial_axes=(0, 2)), + ToTensor(), + ] + ) + train_segtrans = Compose( + [ + AddChannel(), + RandSpatialCrop((96, 96, 96), random_size=False), + RandRotate90(prob=0.5, spatial_axes=(0, 2)), + ToTensor(), + ] + ) + val_imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()]) + val_segtrans = Compose([AddChannel(), ToTensor()]) + + # define nifti dataset, data loader + check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans) + check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) + im, seg = monai.utils.misc.first(check_loader) + print(im.shape, seg.shape) + + # create a training data loader + train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans) + train_loader = DataLoader( + train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available() + ) + # create a validation data loader + val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans) + val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available()) + dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean") + + # create UNet, DiceLoss and Adam optimizer + device = torch.device("cuda:0") + model = monai.networks.nets.UNet( + dimensions=3, + in_channels=1, + out_channels=1, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + num_res_units=2, + ).to(device) + loss_function = monai.losses.DiceLoss(sigmoid=True) + optimizer = torch.optim.Adam(model.parameters(), 1e-3) + + # start a typical PyTorch training + val_interval = 2 + best_metric = -1 + best_metric_epoch = -1 + epoch_loss_values = list() + metric_values = list() + writer = SummaryWriter() + for epoch in range(5): + print("-" * 10) + print(f"epoch {epoch + 1}/{5}") + model.train() + epoch_loss = 0 + step = 0 + for batch_data in train_loader: + step += 1 + inputs, labels = batch_data[0].to(device), batch_data[1].to(device) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_function(outputs, labels) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + epoch_len = len(train_ds) // train_loader.batch_size + print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}") + writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step) + epoch_loss /= step + epoch_loss_values.append(epoch_loss) + print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}") + + if (epoch + 1) % val_interval == 0: + model.eval() + with torch.no_grad(): + metric_sum = 0.0 + metric_count = 0 + val_images = None + val_labels = None + val_outputs = None + for val_data in val_loader: + val_images, val_labels = val_data[0].to(device), val_data[1].to(device) + roi_size = (96, 96, 96) + sw_batch_size = 4 + val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) + value = dice_metric(y_pred=val_outputs, y=val_labels) + metric_count += len(value) + metric_sum += value.item() * len(value) + metric = metric_sum / metric_count + metric_values.append(metric) + if metric > best_metric: + best_metric = metric + best_metric_epoch = epoch + 1 + torch.save(model.state_dict(), "best_metric_model.pth") + print("saved new best metric model") + print( + "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format( + epoch + 1, metric, best_metric, best_metric_epoch + ) ) - ) - writer.add_scalar("val_mean_dice", metric, epoch + 1) - # plot the last model output as GIF image in TensorBoard with the corresponding image and label - plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image") - plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label") - plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output") - shutil.rmtree(tempdir) - print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}") - writer.close() + writer.add_scalar("val_mean_dice", metric, epoch + 1) + # plot the last model output as GIF image in TensorBoard with the corresponding image and label + plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image") + plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label") + plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output") + + print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}") + writer.close() if __name__ == "__main__": diff --git a/examples/segmentation_3d/unet_training_dict.py b/examples/segmentation_3d/unet_training_dict.py index 42188356d0..42737fb44c 100644 --- a/examples/segmentation_3d/unet_training_dict.py +++ b/examples/segmentation_3d/unet_training_dict.py @@ -11,7 +11,6 @@ import logging import os -import shutil import sys import tempfile from glob import glob @@ -43,145 +42,145 @@ def main(): logging.basicConfig(stream=sys.stdout, level=logging.INFO) # create a temporary directory and 40 random image, mask paris - tempdir = tempfile.mkdtemp() - print(f"generating synthetic data to {tempdir} (this may take a while)") - for i in range(40): - im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) - - n = nib.Nifti1Image(im, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz")) - - n = nib.Nifti1Image(seg, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) - - images = sorted(glob(os.path.join(tempdir, "img*.nii.gz"))) - segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) - train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])] - val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])] - - # define transforms for image and segmentation - train_transforms = Compose( - [ - LoadNiftid(keys=["img", "seg"]), - AsChannelFirstd(keys=["img", "seg"], channel_dim=-1), - ScaleIntensityd(keys="img"), - RandCropByPosNegLabeld( - keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4 - ), - RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]), - ToTensord(keys=["img", "seg"]), - ] - ) - val_transforms = Compose( - [ - LoadNiftid(keys=["img", "seg"]), - AsChannelFirstd(keys=["img", "seg"], channel_dim=-1), - ScaleIntensityd(keys="img"), - ToTensord(keys=["img", "seg"]), - ] - ) - - # define dataset, data loader - check_ds = monai.data.Dataset(data=train_files, transform=train_transforms) - # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training - check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate) - check_data = monai.utils.misc.first(check_loader) - print(check_data["img"].shape, check_data["seg"].shape) - - # create a training data loader - train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) - # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training - train_loader = DataLoader( - train_ds, - batch_size=2, - shuffle=True, - num_workers=4, - collate_fn=list_data_collate, - pin_memory=torch.cuda.is_available(), - ) - # create a validation data loader - val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) - val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate) - dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean") - - # create UNet, DiceLoss and Adam optimizer - device = torch.device("cuda:0") - model = monai.networks.nets.UNet( - dimensions=3, - in_channels=1, - out_channels=1, - channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2), - num_res_units=2, - ).to(device) - loss_function = monai.losses.DiceLoss(sigmoid=True) - optimizer = torch.optim.Adam(model.parameters(), 1e-3) - - # start a typical PyTorch training - val_interval = 2 - best_metric = -1 - best_metric_epoch = -1 - epoch_loss_values = list() - metric_values = list() - writer = SummaryWriter() - for epoch in range(5): - print("-" * 10) - print(f"epoch {epoch + 1}/{5}") - model.train() - epoch_loss = 0 - step = 0 - for batch_data in train_loader: - step += 1 - inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device) - optimizer.zero_grad() - outputs = model(inputs) - loss = loss_function(outputs, labels) - loss.backward() - optimizer.step() - epoch_loss += loss.item() - epoch_len = len(train_ds) // train_loader.batch_size - print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}") - writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step) - epoch_loss /= step - epoch_loss_values.append(epoch_loss) - print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}") - - if (epoch + 1) % val_interval == 0: - model.eval() - with torch.no_grad(): - metric_sum = 0.0 - metric_count = 0 - val_images = None - val_labels = None - val_outputs = None - for val_data in val_loader: - val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device) - roi_size = (96, 96, 96) - sw_batch_size = 4 - val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) - value = dice_metric(y_pred=val_outputs, y=val_labels) - metric_count += len(value) - metric_sum += value.item() * len(value) - metric = metric_sum / metric_count - metric_values.append(metric) - if metric > best_metric: - best_metric = metric - best_metric_epoch = epoch + 1 - torch.save(model.state_dict(), "best_metric_model.pth") - print("saved new best metric model") - print( - "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format( - epoch + 1, metric, best_metric, best_metric_epoch + with tempfile.TemporaryDirectory() as tempdir: + print(f"generating synthetic data to {tempdir} (this may take a while)") + for i in range(40): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) + + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz")) + + n = nib.Nifti1Image(seg, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) + + images = sorted(glob(os.path.join(tempdir, "img*.nii.gz"))) + segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) + train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])] + val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])] + + # define transforms for image and segmentation + train_transforms = Compose( + [ + LoadNiftid(keys=["img", "seg"]), + AsChannelFirstd(keys=["img", "seg"], channel_dim=-1), + ScaleIntensityd(keys="img"), + RandCropByPosNegLabeld( + keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4 + ), + RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]), + ToTensord(keys=["img", "seg"]), + ] + ) + val_transforms = Compose( + [ + LoadNiftid(keys=["img", "seg"]), + AsChannelFirstd(keys=["img", "seg"], channel_dim=-1), + ScaleIntensityd(keys="img"), + ToTensord(keys=["img", "seg"]), + ] + ) + + # define dataset, data loader + check_ds = monai.data.Dataset(data=train_files, transform=train_transforms) + # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training + check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate) + check_data = monai.utils.misc.first(check_loader) + print(check_data["img"].shape, check_data["seg"].shape) + + # create a training data loader + train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) + # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training + train_loader = DataLoader( + train_ds, + batch_size=2, + shuffle=True, + num_workers=4, + collate_fn=list_data_collate, + pin_memory=torch.cuda.is_available(), + ) + # create a validation data loader + val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) + val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate) + dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean") + + # create UNet, DiceLoss and Adam optimizer + device = torch.device("cuda:0") + model = monai.networks.nets.UNet( + dimensions=3, + in_channels=1, + out_channels=1, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + num_res_units=2, + ).to(device) + loss_function = monai.losses.DiceLoss(sigmoid=True) + optimizer = torch.optim.Adam(model.parameters(), 1e-3) + + # start a typical PyTorch training + val_interval = 2 + best_metric = -1 + best_metric_epoch = -1 + epoch_loss_values = list() + metric_values = list() + writer = SummaryWriter() + for epoch in range(5): + print("-" * 10) + print(f"epoch {epoch + 1}/{5}") + model.train() + epoch_loss = 0 + step = 0 + for batch_data in train_loader: + step += 1 + inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_function(outputs, labels) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + epoch_len = len(train_ds) // train_loader.batch_size + print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}") + writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step) + epoch_loss /= step + epoch_loss_values.append(epoch_loss) + print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}") + + if (epoch + 1) % val_interval == 0: + model.eval() + with torch.no_grad(): + metric_sum = 0.0 + metric_count = 0 + val_images = None + val_labels = None + val_outputs = None + for val_data in val_loader: + val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device) + roi_size = (96, 96, 96) + sw_batch_size = 4 + val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) + value = dice_metric(y_pred=val_outputs, y=val_labels) + metric_count += len(value) + metric_sum += value.item() * len(value) + metric = metric_sum / metric_count + metric_values.append(metric) + if metric > best_metric: + best_metric = metric + best_metric_epoch = epoch + 1 + torch.save(model.state_dict(), "best_metric_model.pth") + print("saved new best metric model") + print( + "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format( + epoch + 1, metric, best_metric, best_metric_epoch + ) ) - ) - writer.add_scalar("val_mean_dice", metric, epoch + 1) - # plot the last model output as GIF image in TensorBoard with the corresponding image and label - plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image") - plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label") - plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output") - shutil.rmtree(tempdir) - print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}") - writer.close() + writer.add_scalar("val_mean_dice", metric, epoch + 1) + # plot the last model output as GIF image in TensorBoard with the corresponding image and label + plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image") + plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label") + plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output") + + print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}") + writer.close() if __name__ == "__main__": diff --git a/examples/segmentation_3d_ignite/unet_evaluation_array.py b/examples/segmentation_3d_ignite/unet_evaluation_array.py index 343980594e..3a498f0ba7 100644 --- a/examples/segmentation_3d_ignite/unet_evaluation_array.py +++ b/examples/segmentation_3d_ignite/unet_evaluation_array.py @@ -11,7 +11,6 @@ import logging import os -import shutil import sys import tempfile from glob import glob @@ -35,80 +34,79 @@ def main(): config.print_config() logging.basicConfig(stream=sys.stdout, level=logging.INFO) - tempdir = tempfile.mkdtemp() - print(f"generating synthetic data to {tempdir} (this may take a while)") - for i in range(5): - im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1) - - n = nib.Nifti1Image(im, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz")) - - n = nib.Nifti1Image(seg, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) - - images = sorted(glob(os.path.join(tempdir, "im*.nii.gz"))) - segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) - - # define transforms for image and segmentation - imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()]) - segtrans = Compose([AddChannel(), ToTensor()]) - ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False) - - device = torch.device("cuda:0") - net = UNet( - dimensions=3, - in_channels=1, - out_channels=1, - channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2), - num_res_units=2, - ) - net.to(device) - - # define sliding window size and batch size for windows inference - roi_size = (96, 96, 96) - sw_batch_size = 4 - - def _sliding_window_processor(engine, batch): - net.eval() - with torch.no_grad(): - val_images, val_labels = batch[0].to(device), batch[1].to(device) - seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net) - return seg_probs, val_labels - - evaluator = Engine(_sliding_window_processor) - - # add evaluation metric to the evaluator engine - MeanDice(sigmoid=True, to_onehot_y=False).attach(evaluator, "Mean_Dice") - - # StatsHandler prints loss at every iteration and print metrics at every epoch, - # we don't need to print loss for evaluator, so just print metrics, user can also customize print functions - val_stats_handler = StatsHandler( - name="evaluator", - output_transform=lambda x: None, # no need to print loss value, so disable per iteration output - ) - val_stats_handler.attach(evaluator) - - # for the array data format, assume the 3rd item of batch data is the meta_data - file_saver = SegmentationSaver( - output_dir="tempdir", - output_ext=".nii.gz", - output_postfix="seg", - name="evaluator", - batch_transform=lambda x: x[2], - output_transform=lambda output: predict_segmentation(output[0]), - ) - file_saver.attach(evaluator) - - # the model was trained by "unet_training_array" example - ckpt_saver = CheckpointLoader(load_path="./runs/net_checkpoint_100.pth", load_dict={"net": net}) - ckpt_saver.attach(evaluator) - - # sliding window inference for one image at every iteration - loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available()) - state = evaluator.run(loader) - print(state) - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + print(f"generating synthetic data to {tempdir} (this may take a while)") + for i in range(5): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1) + + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz")) + + n = nib.Nifti1Image(seg, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) + + images = sorted(glob(os.path.join(tempdir, "im*.nii.gz"))) + segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) + + # define transforms for image and segmentation + imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()]) + segtrans = Compose([AddChannel(), ToTensor()]) + ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False) + + device = torch.device("cuda:0") + net = UNet( + dimensions=3, + in_channels=1, + out_channels=1, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + num_res_units=2, + ) + net.to(device) + + # define sliding window size and batch size for windows inference + roi_size = (96, 96, 96) + sw_batch_size = 4 + + def _sliding_window_processor(engine, batch): + net.eval() + with torch.no_grad(): + val_images, val_labels = batch[0].to(device), batch[1].to(device) + seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net) + return seg_probs, val_labels + + evaluator = Engine(_sliding_window_processor) + + # add evaluation metric to the evaluator engine + MeanDice(sigmoid=True, to_onehot_y=False).attach(evaluator, "Mean_Dice") + + # StatsHandler prints loss at every iteration and print metrics at every epoch, + # we don't need to print loss for evaluator, so just print metrics, user can also customize print functions + val_stats_handler = StatsHandler( + name="evaluator", + output_transform=lambda x: None, # no need to print loss value, so disable per iteration output + ) + val_stats_handler.attach(evaluator) + + # for the array data format, assume the 3rd item of batch data is the meta_data + file_saver = SegmentationSaver( + output_dir="tempdir", + output_ext=".nii.gz", + output_postfix="seg", + name="evaluator", + batch_transform=lambda x: x[2], + output_transform=lambda output: predict_segmentation(output[0]), + ) + file_saver.attach(evaluator) + + # the model was trained by "unet_training_array" example + ckpt_saver = CheckpointLoader(load_path="./runs/net_checkpoint_100.pth", load_dict={"net": net}) + ckpt_saver.attach(evaluator) + + # sliding window inference for one image at every iteration + loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available()) + state = evaluator.run(loader) + print(state) if __name__ == "__main__": diff --git a/examples/segmentation_3d_ignite/unet_evaluation_dict.py b/examples/segmentation_3d_ignite/unet_evaluation_dict.py index 2ea3a4ba3a..05f49aae4a 100644 --- a/examples/segmentation_3d_ignite/unet_evaluation_dict.py +++ b/examples/segmentation_3d_ignite/unet_evaluation_dict.py @@ -11,7 +11,6 @@ import logging import os -import shutil import sys import tempfile from glob import glob @@ -35,86 +34,85 @@ def main(): monai.config.print_config() logging.basicConfig(stream=sys.stdout, level=logging.INFO) - tempdir = tempfile.mkdtemp() - print(f"generating synthetic data to {tempdir} (this may take a while)") - for i in range(5): - im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) - - n = nib.Nifti1Image(im, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz")) - - n = nib.Nifti1Image(seg, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) - - images = sorted(glob(os.path.join(tempdir, "im*.nii.gz"))) - segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) - val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)] - - # define transforms for image and segmentation - val_transforms = Compose( - [ - LoadNiftid(keys=["img", "seg"]), - AsChannelFirstd(keys=["img", "seg"], channel_dim=-1), - ScaleIntensityd(keys="img"), - ToTensord(keys=["img", "seg"]), - ] - ) - val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) - - device = torch.device("cuda:0") - net = UNet( - dimensions=3, - in_channels=1, - out_channels=1, - channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2), - num_res_units=2, - ) - net.to(device) - - # define sliding window size and batch size for windows inference - roi_size = (96, 96, 96) - sw_batch_size = 4 - - def _sliding_window_processor(engine, batch): - net.eval() - with torch.no_grad(): - val_images, val_labels = batch["img"].to(device), batch["seg"].to(device) - seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net) - return seg_probs, val_labels - - evaluator = Engine(_sliding_window_processor) - - # add evaluation metric to the evaluator engine - MeanDice(sigmoid=True, to_onehot_y=False).attach(evaluator, "Mean_Dice") - - # StatsHandler prints loss at every iteration and print metrics at every epoch, - # we don't need to print loss for evaluator, so just print metrics, user can also customize print functions - val_stats_handler = StatsHandler( - name="evaluator", - output_transform=lambda x: None, # no need to print loss value, so disable per iteration output - ) - val_stats_handler.attach(evaluator) - - # convert the necessary metadata from batch data - SegmentationSaver( - output_dir="tempdir", - output_ext=".nii.gz", - output_postfix="seg", - name="evaluator", - batch_transform=lambda batch: batch["img_meta_dict"], - output_transform=lambda output: predict_segmentation(output[0]), - ).attach(evaluator) - # the model was trained by "unet_training_dict" example - CheckpointLoader(load_path="./runs/net_checkpoint_50.pth", load_dict={"net": net}).attach(evaluator) - - # sliding window inference for one image at every iteration - val_loader = DataLoader( - val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available() - ) - state = evaluator.run(val_loader) - print(state) - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + print(f"generating synthetic data to {tempdir} (this may take a while)") + for i in range(5): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) + + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz")) + + n = nib.Nifti1Image(seg, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) + + images = sorted(glob(os.path.join(tempdir, "im*.nii.gz"))) + segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) + val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)] + + # define transforms for image and segmentation + val_transforms = Compose( + [ + LoadNiftid(keys=["img", "seg"]), + AsChannelFirstd(keys=["img", "seg"], channel_dim=-1), + ScaleIntensityd(keys="img"), + ToTensord(keys=["img", "seg"]), + ] + ) + val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) + + device = torch.device("cuda:0") + net = UNet( + dimensions=3, + in_channels=1, + out_channels=1, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + num_res_units=2, + ) + net.to(device) + + # define sliding window size and batch size for windows inference + roi_size = (96, 96, 96) + sw_batch_size = 4 + + def _sliding_window_processor(engine, batch): + net.eval() + with torch.no_grad(): + val_images, val_labels = batch["img"].to(device), batch["seg"].to(device) + seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net) + return seg_probs, val_labels + + evaluator = Engine(_sliding_window_processor) + + # add evaluation metric to the evaluator engine + MeanDice(sigmoid=True, to_onehot_y=False).attach(evaluator, "Mean_Dice") + + # StatsHandler prints loss at every iteration and print metrics at every epoch, + # we don't need to print loss for evaluator, so just print metrics, user can also customize print functions + val_stats_handler = StatsHandler( + name="evaluator", + output_transform=lambda x: None, # no need to print loss value, so disable per iteration output + ) + val_stats_handler.attach(evaluator) + + # convert the necessary metadata from batch data + SegmentationSaver( + output_dir="tempdir", + output_ext=".nii.gz", + output_postfix="seg", + name="evaluator", + batch_transform=lambda batch: batch["img_meta_dict"], + output_transform=lambda output: predict_segmentation(output[0]), + ).attach(evaluator) + # the model was trained by "unet_training_dict" example + CheckpointLoader(load_path="./runs/net_checkpoint_50.pth", load_dict={"net": net}).attach(evaluator) + + # sliding window inference for one image at every iteration + val_loader = DataLoader( + val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available() + ) + state = evaluator.run(val_loader) + print(state) if __name__ == "__main__": diff --git a/examples/segmentation_3d_ignite/unet_training_array.py b/examples/segmentation_3d_ignite/unet_training_array.py index 3840736349..0d2a175a0e 100644 --- a/examples/segmentation_3d_ignite/unet_training_array.py +++ b/examples/segmentation_3d_ignite/unet_training_array.py @@ -11,7 +11,6 @@ import logging import os -import shutil import sys import tempfile from glob import glob @@ -41,121 +40,122 @@ def main(): logging.basicConfig(stream=sys.stdout, level=logging.INFO) # create a temporary directory and 40 random image, mask paris - tempdir = tempfile.mkdtemp() - print(f"generating synthetic data to {tempdir} (this may take a while)") - for i in range(40): - im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1) - - n = nib.Nifti1Image(im, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz")) - - n = nib.Nifti1Image(seg, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) - - images = sorted(glob(os.path.join(tempdir, "im*.nii.gz"))) - segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) - - # define transforms for image and segmentation - train_imtrans = Compose( - [ScaleIntensity(), AddChannel(), RandSpatialCrop((96, 96, 96), random_size=False), ToTensor()] - ) - train_segtrans = Compose([AddChannel(), RandSpatialCrop((96, 96, 96), random_size=False), ToTensor()]) - val_imtrans = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), ToTensor()]) - val_segtrans = Compose([AddChannel(), Resize((96, 96, 96)), ToTensor()]) - - # define nifti dataset, data loader - check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans) - check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) - im, seg = monai.utils.misc.first(check_loader) - print(im.shape, seg.shape) - - # create a training data loader - train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans) - train_loader = DataLoader(train_ds, batch_size=5, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available()) - # create a validation data loader - val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans) - val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available()) - - # create UNet, DiceLoss and Adam optimizer - net = monai.networks.nets.UNet( - dimensions=3, - in_channels=1, - out_channels=1, - channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2), - num_res_units=2, - ) - loss = monai.losses.DiceLoss(sigmoid=True) - lr = 1e-3 - opt = torch.optim.Adam(net.parameters(), lr) - device = torch.device("cuda:0") - - # Ignite trainer expects batch=(img, seg) and returns output=loss at every iteration, - # user can add output_transform to return other values, like: y_pred, y, etc. - trainer = create_supervised_trainer(net, opt, loss, device, False) - - # adding checkpoint handler to save models (network params and optimizer stats) during training - checkpoint_handler = ModelCheckpoint("./runs/", "net", n_saved=10, require_empty=False) - trainer.add_event_handler( - event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={"net": net, "opt": opt} - ) - - # StatsHandler prints loss at every iteration and print metrics at every epoch, - # we don't set metrics for trainer here, so just print loss, user can also customize print functions - # and can use output_transform to convert engine.state.output if it's not a loss value - train_stats_handler = StatsHandler(name="trainer") - train_stats_handler.attach(trainer) - - # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler - train_tensorboard_stats_handler = TensorBoardStatsHandler() - train_tensorboard_stats_handler.attach(trainer) - - validation_every_n_epochs = 1 - # Set parameters for validation - metric_name = "Mean_Dice" - # add evaluation metric to the evaluator engine - val_metrics = {metric_name: MeanDice(sigmoid=True, to_onehot_y=False)} - - # Ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration, - # user can add output_transform to return other values - evaluator = create_supervised_evaluator(net, val_metrics, device, True) - - @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs)) - def run_validation(engine): - evaluator.run(val_loader) - - # add early stopping handler to evaluator - early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer) - evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) - - # add stats event handler to print validation stats via evaluator - val_stats_handler = StatsHandler( - name="evaluator", - output_transform=lambda x: None, # no need to print loss value, so disable per iteration output - global_epoch_transform=lambda x: trainer.state.epoch, - ) # fetch global epoch number from trainer - val_stats_handler.attach(evaluator) - - # add handler to record metrics to TensorBoard at every validation epoch - val_tensorboard_stats_handler = TensorBoardStatsHandler( - output_transform=lambda x: None, # no need to plot loss value, so disable per iteration output - global_epoch_transform=lambda x: trainer.state.epoch, - ) # fetch global epoch number from trainer - val_tensorboard_stats_handler.attach(evaluator) - - # add handler to draw the first image and the corresponding label and model output in the last batch - # here we draw the 3D output as GIF format along Depth axis, at every validation epoch - val_tensorboard_image_handler = TensorBoardImageHandler( - batch_transform=lambda batch: (batch[0], batch[1]), - output_transform=lambda output: predict_segmentation(output[0]), - global_iter_transform=lambda x: trainer.state.epoch, - ) - evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler) - - train_epochs = 30 - state = trainer.run(train_loader, train_epochs) - print(state) - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + print(f"generating synthetic data to {tempdir} (this may take a while)") + for i in range(40): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1) + + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz")) + + n = nib.Nifti1Image(seg, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) + + images = sorted(glob(os.path.join(tempdir, "im*.nii.gz"))) + segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) + + # define transforms for image and segmentation + train_imtrans = Compose( + [ScaleIntensity(), AddChannel(), RandSpatialCrop((96, 96, 96), random_size=False), ToTensor()] + ) + train_segtrans = Compose([AddChannel(), RandSpatialCrop((96, 96, 96), random_size=False), ToTensor()]) + val_imtrans = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), ToTensor()]) + val_segtrans = Compose([AddChannel(), Resize((96, 96, 96)), ToTensor()]) + + # define nifti dataset, data loader + check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans) + check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available()) + im, seg = monai.utils.misc.first(check_loader) + print(im.shape, seg.shape) + + # create a training data loader + train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans) + train_loader = DataLoader( + train_ds, batch_size=5, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available() + ) + # create a validation data loader + val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans) + val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available()) + + # create UNet, DiceLoss and Adam optimizer + net = monai.networks.nets.UNet( + dimensions=3, + in_channels=1, + out_channels=1, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + num_res_units=2, + ) + loss = monai.losses.DiceLoss(sigmoid=True) + lr = 1e-3 + opt = torch.optim.Adam(net.parameters(), lr) + device = torch.device("cuda:0") + + # Ignite trainer expects batch=(img, seg) and returns output=loss at every iteration, + # user can add output_transform to return other values, like: y_pred, y, etc. + trainer = create_supervised_trainer(net, opt, loss, device, False) + + # adding checkpoint handler to save models (network params and optimizer stats) during training + checkpoint_handler = ModelCheckpoint("./runs/", "net", n_saved=10, require_empty=False) + trainer.add_event_handler( + event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={"net": net, "opt": opt} + ) + + # StatsHandler prints loss at every iteration and print metrics at every epoch, + # we don't set metrics for trainer here, so just print loss, user can also customize print functions + # and can use output_transform to convert engine.state.output if it's not a loss value + train_stats_handler = StatsHandler(name="trainer") + train_stats_handler.attach(trainer) + + # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler + train_tensorboard_stats_handler = TensorBoardStatsHandler() + train_tensorboard_stats_handler.attach(trainer) + + validation_every_n_epochs = 1 + # Set parameters for validation + metric_name = "Mean_Dice" + # add evaluation metric to the evaluator engine + val_metrics = {metric_name: MeanDice(sigmoid=True, to_onehot_y=False)} + + # Ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration, + # user can add output_transform to return other values + evaluator = create_supervised_evaluator(net, val_metrics, device, True) + + @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs)) + def run_validation(engine): + evaluator.run(val_loader) + + # add early stopping handler to evaluator + early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer) + evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) + + # add stats event handler to print validation stats via evaluator + val_stats_handler = StatsHandler( + name="evaluator", + output_transform=lambda x: None, # no need to print loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch, + ) # fetch global epoch number from trainer + val_stats_handler.attach(evaluator) + + # add handler to record metrics to TensorBoard at every validation epoch + val_tensorboard_stats_handler = TensorBoardStatsHandler( + output_transform=lambda x: None, # no need to plot loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch, + ) # fetch global epoch number from trainer + val_tensorboard_stats_handler.attach(evaluator) + + # add handler to draw the first image and the corresponding label and model output in the last batch + # here we draw the 3D output as GIF format along Depth axis, at every validation epoch + val_tensorboard_image_handler = TensorBoardImageHandler( + batch_transform=lambda batch: (batch[0], batch[1]), + output_transform=lambda output: predict_segmentation(output[0]), + global_iter_transform=lambda x: trainer.state.epoch, + ) + evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler) + + train_epochs = 30 + state = trainer.run(train_loader, train_epochs) + print(state) if __name__ == "__main__": diff --git a/examples/segmentation_3d_ignite/unet_training_dict.py b/examples/segmentation_3d_ignite/unet_training_dict.py index 30a36b97f2..4d59c23af2 100644 --- a/examples/segmentation_3d_ignite/unet_training_dict.py +++ b/examples/segmentation_3d_ignite/unet_training_dict.py @@ -11,7 +11,6 @@ import logging import os -import shutil import sys import tempfile from glob import glob @@ -49,153 +48,154 @@ def main(): logging.basicConfig(stream=sys.stdout, level=logging.INFO) # create a temporary directory and 40 random image, mask paris - tempdir = tempfile.mkdtemp() - print(f"generating synthetic data to {tempdir} (this may take a while)") - for i in range(40): - im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) - - n = nib.Nifti1Image(im, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz")) - - n = nib.Nifti1Image(seg, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) - - images = sorted(glob(os.path.join(tempdir, "img*.nii.gz"))) - segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) - train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])] - val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])] - - # define transforms for image and segmentation - train_transforms = Compose( - [ - LoadNiftid(keys=["img", "seg"]), - AsChannelFirstd(keys=["img", "seg"], channel_dim=-1), - ScaleIntensityd(keys="img"), - RandCropByPosNegLabeld( - keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4 - ), - RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]), - ToTensord(keys=["img", "seg"]), - ] - ) - val_transforms = Compose( - [ - LoadNiftid(keys=["img", "seg"]), - AsChannelFirstd(keys=["img", "seg"], channel_dim=-1), - ScaleIntensityd(keys="img"), - ToTensord(keys=["img", "seg"]), - ] - ) - - # define dataset, data loader - check_ds = monai.data.Dataset(data=train_files, transform=train_transforms) - # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training - check_loader = DataLoader( - check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available() - ) - check_data = monai.utils.misc.first(check_loader) - print(check_data["img"].shape, check_data["seg"].shape) - - # create a training data loader - train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) - # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training - train_loader = DataLoader( - train_ds, - batch_size=2, - shuffle=True, - num_workers=4, - collate_fn=list_data_collate, - pin_memory=torch.cuda.is_available(), - ) - # create a validation data loader - val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) - val_loader = DataLoader( - val_ds, batch_size=5, num_workers=8, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available() - ) - - # create UNet, DiceLoss and Adam optimizer - net = monai.networks.nets.UNet( - dimensions=3, - in_channels=1, - out_channels=1, - channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2), - num_res_units=2, - ) - loss = monai.losses.DiceLoss(sigmoid=True) - lr = 1e-3 - opt = torch.optim.Adam(net.parameters(), lr) - device = torch.device("cuda:0") - - # Ignite trainer expects batch=(img, seg) and returns output=loss at every iteration, - # user can add output_transform to return other values, like: y_pred, y, etc. - def prepare_batch(batch, device=None, non_blocking=False): - return _prepare_batch((batch["img"], batch["seg"]), device, non_blocking) - - trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch) - - # adding checkpoint handler to save models (network params and optimizer stats) during training - checkpoint_handler = ModelCheckpoint("./runs/", "net", n_saved=10, require_empty=False) - trainer.add_event_handler( - event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={"net": net, "opt": opt} - ) - - # StatsHandler prints loss at every iteration and print metrics at every epoch, - # we don't set metrics for trainer here, so just print loss, user can also customize print functions - # and can use output_transform to convert engine.state.output if it's not loss value - train_stats_handler = StatsHandler(name="trainer") - train_stats_handler.attach(trainer) - - # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler - train_tensorboard_stats_handler = TensorBoardStatsHandler() - train_tensorboard_stats_handler.attach(trainer) - - validation_every_n_iters = 5 - # set parameters for validation - metric_name = "Mean_Dice" - # add evaluation metric to the evaluator engine - val_metrics = {metric_name: MeanDice(sigmoid=True, to_onehot_y=False)} - - # Ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration, - # user can add output_transform to return other values - evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch) - - @trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters)) - def run_validation(engine): - evaluator.run(val_loader) - - # add early stopping handler to evaluator - early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer) - evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) - - # add stats event handler to print validation stats via evaluator - val_stats_handler = StatsHandler( - name="evaluator", - output_transform=lambda x: None, # no need to print loss value, so disable per iteration output - global_epoch_transform=lambda x: trainer.state.epoch, - ) # fetch global epoch number from trainer - val_stats_handler.attach(evaluator) - - # add handler to record metrics to TensorBoard at every validation epoch - val_tensorboard_stats_handler = TensorBoardStatsHandler( - output_transform=lambda x: None, # no need to plot loss value, so disable per iteration output - global_epoch_transform=lambda x: trainer.state.iteration, - ) # fetch global iteration number from trainer - val_tensorboard_stats_handler.attach(evaluator) - - # add handler to draw the first image and the corresponding label and model output in the last batch - # here we draw the 3D output as GIF format along the depth axis, every 2 validation iterations. - val_tensorboard_image_handler = TensorBoardImageHandler( - batch_transform=lambda batch: (batch["img"], batch["seg"]), - output_transform=lambda output: predict_segmentation(output[0]), - global_iter_transform=lambda x: trainer.state.epoch, - ) - evaluator.add_event_handler(event_name=Events.ITERATION_COMPLETED(every=2), handler=val_tensorboard_image_handler) - - train_epochs = 5 - state = trainer.run(train_loader, train_epochs) - print(state) - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + print(f"generating synthetic data to {tempdir} (this may take a while)") + for i in range(40): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) + + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz")) + + n = nib.Nifti1Image(seg, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) + + images = sorted(glob(os.path.join(tempdir, "img*.nii.gz"))) + segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) + train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])] + val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])] + + # define transforms for image and segmentation + train_transforms = Compose( + [ + LoadNiftid(keys=["img", "seg"]), + AsChannelFirstd(keys=["img", "seg"], channel_dim=-1), + ScaleIntensityd(keys="img"), + RandCropByPosNegLabeld( + keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4 + ), + RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]), + ToTensord(keys=["img", "seg"]), + ] + ) + val_transforms = Compose( + [ + LoadNiftid(keys=["img", "seg"]), + AsChannelFirstd(keys=["img", "seg"], channel_dim=-1), + ScaleIntensityd(keys="img"), + ToTensord(keys=["img", "seg"]), + ] + ) + + # define dataset, data loader + check_ds = monai.data.Dataset(data=train_files, transform=train_transforms) + # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training + check_loader = DataLoader( + check_ds, batch_size=2, num_workers=4, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available() + ) + check_data = monai.utils.misc.first(check_loader) + print(check_data["img"].shape, check_data["seg"].shape) + + # create a training data loader + train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) + # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training + train_loader = DataLoader( + train_ds, + batch_size=2, + shuffle=True, + num_workers=4, + collate_fn=list_data_collate, + pin_memory=torch.cuda.is_available(), + ) + # create a validation data loader + val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) + val_loader = DataLoader( + val_ds, batch_size=5, num_workers=8, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available() + ) + + # create UNet, DiceLoss and Adam optimizer + net = monai.networks.nets.UNet( + dimensions=3, + in_channels=1, + out_channels=1, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + num_res_units=2, + ) + loss = monai.losses.DiceLoss(sigmoid=True) + lr = 1e-3 + opt = torch.optim.Adam(net.parameters(), lr) + device = torch.device("cuda:0") + + # Ignite trainer expects batch=(img, seg) and returns output=loss at every iteration, + # user can add output_transform to return other values, like: y_pred, y, etc. + def prepare_batch(batch, device=None, non_blocking=False): + return _prepare_batch((batch["img"], batch["seg"]), device, non_blocking) + + trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch) + + # adding checkpoint handler to save models (network params and optimizer stats) during training + checkpoint_handler = ModelCheckpoint("./runs/", "net", n_saved=10, require_empty=False) + trainer.add_event_handler( + event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={"net": net, "opt": opt} + ) + + # StatsHandler prints loss at every iteration and print metrics at every epoch, + # we don't set metrics for trainer here, so just print loss, user can also customize print functions + # and can use output_transform to convert engine.state.output if it's not loss value + train_stats_handler = StatsHandler(name="trainer") + train_stats_handler.attach(trainer) + + # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler + train_tensorboard_stats_handler = TensorBoardStatsHandler() + train_tensorboard_stats_handler.attach(trainer) + + validation_every_n_iters = 5 + # set parameters for validation + metric_name = "Mean_Dice" + # add evaluation metric to the evaluator engine + val_metrics = {metric_name: MeanDice(sigmoid=True, to_onehot_y=False)} + + # Ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration, + # user can add output_transform to return other values + evaluator = create_supervised_evaluator(net, val_metrics, device, True, prepare_batch=prepare_batch) + + @trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters)) + def run_validation(engine): + evaluator.run(val_loader) + + # add early stopping handler to evaluator + early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer) + evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) + + # add stats event handler to print validation stats via evaluator + val_stats_handler = StatsHandler( + name="evaluator", + output_transform=lambda x: None, # no need to print loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.epoch, + ) # fetch global epoch number from trainer + val_stats_handler.attach(evaluator) + + # add handler to record metrics to TensorBoard at every validation epoch + val_tensorboard_stats_handler = TensorBoardStatsHandler( + output_transform=lambda x: None, # no need to plot loss value, so disable per iteration output + global_epoch_transform=lambda x: trainer.state.iteration, + ) # fetch global iteration number from trainer + val_tensorboard_stats_handler.attach(evaluator) + + # add handler to draw the first image and the corresponding label and model output in the last batch + # here we draw the 3D output as GIF format along the depth axis, every 2 validation iterations. + val_tensorboard_image_handler = TensorBoardImageHandler( + batch_transform=lambda batch: (batch["img"], batch["seg"]), + output_transform=lambda output: predict_segmentation(output[0]), + global_iter_transform=lambda x: trainer.state.epoch, + ) + evaluator.add_event_handler( + event_name=Events.ITERATION_COMPLETED(every=2), handler=val_tensorboard_image_handler + ) + + train_epochs = 5 + state = trainer.run(train_loader, train_epochs) + print(state) if __name__ == "__main__": diff --git a/examples/workflows/unet_evaluation_dict.py b/examples/workflows/unet_evaluation_dict.py index d9f3237701..d1aa600b09 100644 --- a/examples/workflows/unet_evaluation_dict.py +++ b/examples/workflows/unet_evaluation_dict.py @@ -11,7 +11,6 @@ import logging import os -import shutil import sys import tempfile from glob import glob @@ -43,77 +42,76 @@ def main(): logging.basicConfig(stream=sys.stdout, level=logging.INFO) # create a temporary directory and 40 random image, mask paris - tempdir = tempfile.mkdtemp() - print(f"generating synthetic data to {tempdir} (this may take a while)") - for i in range(5): - im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) - n = nib.Nifti1Image(im, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz")) - n = nib.Nifti1Image(seg, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) + with tempfile.TemporaryDirectory() as tempdir: + print(f"generating synthetic data to {tempdir} (this may take a while)") + for i in range(5): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz")) + n = nib.Nifti1Image(seg, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) - images = sorted(glob(os.path.join(tempdir, "im*.nii.gz"))) - segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) - val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)] + images = sorted(glob(os.path.join(tempdir, "im*.nii.gz"))) + segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) + val_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)] - # define transforms for image and segmentation - val_transforms = Compose( - [ - LoadNiftid(keys=["image", "label"]), - AsChannelFirstd(keys=["image", "label"], channel_dim=-1), - ScaleIntensityd(keys="image"), - ToTensord(keys=["image", "label"]), - ] - ) + # define transforms for image and segmentation + val_transforms = Compose( + [ + LoadNiftid(keys=["image", "label"]), + AsChannelFirstd(keys=["image", "label"], channel_dim=-1), + ScaleIntensityd(keys="image"), + ToTensord(keys=["image", "label"]), + ] + ) - # create a validation data loader - val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) - val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4) + # create a validation data loader + val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) + val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4) - # create UNet, DiceLoss and Adam optimizer - device = torch.device("cuda:0") - net = monai.networks.nets.UNet( - dimensions=3, - in_channels=1, - out_channels=1, - channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2), - num_res_units=2, - ).to(device) + # create UNet, DiceLoss and Adam optimizer + device = torch.device("cuda:0") + net = monai.networks.nets.UNet( + dimensions=3, + in_channels=1, + out_channels=1, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + num_res_units=2, + ).to(device) - val_post_transforms = Compose( - [ - Activationsd(keys="pred", sigmoid=True), - AsDiscreted(keys="pred", threshold_values=True), - KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), + val_post_transforms = Compose( + [ + Activationsd(keys="pred", sigmoid=True), + AsDiscreted(keys="pred", threshold_values=True), + KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), + ] + ) + val_handlers = [ + StatsHandler(output_transform=lambda x: None), + CheckpointLoader(load_path="./runs/net_key_metric=0.9101.pth", load_dict={"net": net}), + SegmentationSaver( + output_dir="./runs/", + batch_transform=lambda batch: batch["image_meta_dict"], + output_transform=lambda output: output["pred"], + ), ] - ) - val_handlers = [ - StatsHandler(output_transform=lambda x: None), - CheckpointLoader(load_path="./runs/net_key_metric=0.9101.pth", load_dict={"net": net}), - SegmentationSaver( - output_dir="./runs/", - batch_transform=lambda batch: batch["image_meta_dict"], - output_transform=lambda output: output["pred"], - ), - ] - evaluator = SupervisedEvaluator( - device=device, - val_data_loader=val_loader, - network=net, - inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5), - post_transform=val_post_transforms, - key_val_metric={ - "val_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"])) - }, - additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))}, - val_handlers=val_handlers, - # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation - amp=True if monai.config.get_torch_version_tuple() >= (1, 6) else False, - ) - evaluator.run() - shutil.rmtree(tempdir) + evaluator = SupervisedEvaluator( + device=device, + val_data_loader=val_loader, + network=net, + inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5), + post_transform=val_post_transforms, + key_val_metric={ + "val_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"])) + }, + additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))}, + val_handlers=val_handlers, + # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation + amp=True if monai.config.get_torch_version_tuple() >= (1, 6) else False, + ) + evaluator.run() if __name__ == "__main__": diff --git a/examples/workflows/unet_training_dict.py b/examples/workflows/unet_training_dict.py index 5a974a403f..f185ad02ad 100644 --- a/examples/workflows/unet_training_dict.py +++ b/examples/workflows/unet_training_dict.py @@ -11,7 +11,6 @@ import logging import os -import shutil import sys import tempfile from glob import glob @@ -53,127 +52,127 @@ def main(): logging.basicConfig(stream=sys.stdout, level=logging.INFO) # create a temporary directory and 40 random image, mask paris - tempdir = tempfile.mkdtemp() - print(f"generating synthetic data to {tempdir} (this may take a while)") - for i in range(40): - im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) - n = nib.Nifti1Image(im, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz")) - n = nib.Nifti1Image(seg, np.eye(4)) - nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) - - images = sorted(glob(os.path.join(tempdir, "img*.nii.gz"))) - segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) - train_files = [{"image": img, "label": seg} for img, seg in zip(images[:20], segs[:20])] - val_files = [{"image": img, "label": seg} for img, seg in zip(images[-20:], segs[-20:])] - - # define transforms for image and segmentation - train_transforms = Compose( - [ - LoadNiftid(keys=["image", "label"]), - AsChannelFirstd(keys=["image", "label"], channel_dim=-1), - ScaleIntensityd(keys="image"), - RandCropByPosNegLabeld( - keys=["image", "label"], label_key="label", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4 + with tempfile.TemporaryDirectory() as tempdir: + print(f"generating synthetic data to {tempdir} (this may take a while)") + for i in range(40): + im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1) + n = nib.Nifti1Image(im, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz")) + n = nib.Nifti1Image(seg, np.eye(4)) + nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz")) + + images = sorted(glob(os.path.join(tempdir, "img*.nii.gz"))) + segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz"))) + train_files = [{"image": img, "label": seg} for img, seg in zip(images[:20], segs[:20])] + val_files = [{"image": img, "label": seg} for img, seg in zip(images[-20:], segs[-20:])] + + # define transforms for image and segmentation + train_transforms = Compose( + [ + LoadNiftid(keys=["image", "label"]), + AsChannelFirstd(keys=["image", "label"], channel_dim=-1), + ScaleIntensityd(keys="image"), + RandCropByPosNegLabeld( + keys=["image", "label"], label_key="label", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4 + ), + RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]), + ToTensord(keys=["image", "label"]), + ] + ) + val_transforms = Compose( + [ + LoadNiftid(keys=["image", "label"]), + AsChannelFirstd(keys=["image", "label"], channel_dim=-1), + ScaleIntensityd(keys="image"), + ToTensord(keys=["image", "label"]), + ] + ) + + # create a training data loader + train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5) + # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training + train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4) + # create a validation data loader + val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0) + val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4) + + # create UNet, DiceLoss and Adam optimizer + device = torch.device("cuda:0") + net = monai.networks.nets.UNet( + dimensions=3, + in_channels=1, + out_channels=1, + channels=(16, 32, 64, 128, 256), + strides=(2, 2, 2, 2), + num_res_units=2, + ).to(device) + loss = monai.losses.DiceLoss(sigmoid=True) + opt = torch.optim.Adam(net.parameters(), 1e-3) + lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1) + + val_post_transforms = Compose( + [ + Activationsd(keys="pred", sigmoid=True), + AsDiscreted(keys="pred", threshold_values=True), + KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), + ] + ) + val_handlers = [ + StatsHandler(output_transform=lambda x: None), + TensorBoardStatsHandler(log_dir="./runs/", output_transform=lambda x: None), + TensorBoardImageHandler( + log_dir="./runs/", + batch_transform=lambda x: (x["image"], x["label"]), + output_transform=lambda x: x["pred"], ), - RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]), - ToTensord(keys=["image", "label"]), + CheckpointSaver(save_dir="./runs/", save_dict={"net": net}, save_key_metric=True), ] - ) - val_transforms = Compose( - [ - LoadNiftid(keys=["image", "label"]), - AsChannelFirstd(keys=["image", "label"], channel_dim=-1), - ScaleIntensityd(keys="image"), - ToTensord(keys=["image", "label"]), - ] - ) - - # create a training data loader - train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5) - # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training - train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4) - # create a validation data loader - val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0) - val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4) - - # create UNet, DiceLoss and Adam optimizer - device = torch.device("cuda:0") - net = monai.networks.nets.UNet( - dimensions=3, - in_channels=1, - out_channels=1, - channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2), - num_res_units=2, - ).to(device) - loss = monai.losses.DiceLoss(sigmoid=True) - opt = torch.optim.Adam(net.parameters(), 1e-3) - lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1) - - val_post_transforms = Compose( - [ - Activationsd(keys="pred", sigmoid=True), - AsDiscreted(keys="pred", threshold_values=True), - KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), - ] - ) - val_handlers = [ - StatsHandler(output_transform=lambda x: None), - TensorBoardStatsHandler(log_dir="./runs/", output_transform=lambda x: None), - TensorBoardImageHandler( - log_dir="./runs/", batch_transform=lambda x: (x["image"], x["label"]), output_transform=lambda x: x["pred"] - ), - CheckpointSaver(save_dir="./runs/", save_dict={"net": net}, save_key_metric=True), - ] - - evaluator = SupervisedEvaluator( - device=device, - val_data_loader=val_loader, - network=net, - inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5), - post_transform=val_post_transforms, - key_val_metric={ - "val_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"])) - }, - additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))}, - val_handlers=val_handlers, - # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation - amp=True if monai.config.get_torch_version_tuple() >= (1, 6) else False, - ) - - train_post_transforms = Compose( - [ - Activationsd(keys="pred", sigmoid=True), - AsDiscreted(keys="pred", threshold_values=True), - KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), + + evaluator = SupervisedEvaluator( + device=device, + val_data_loader=val_loader, + network=net, + inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5), + post_transform=val_post_transforms, + key_val_metric={ + "val_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"])) + }, + additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))}, + val_handlers=val_handlers, + # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation + amp=True if monai.config.get_torch_version_tuple() >= (1, 6) else False, + ) + + train_post_transforms = Compose( + [ + Activationsd(keys="pred", sigmoid=True), + AsDiscreted(keys="pred", threshold_values=True), + KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]), + ] + ) + train_handlers = [ + LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True), + ValidationHandler(validator=evaluator, interval=2, epoch_level=True), + StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]), + TensorBoardStatsHandler(log_dir="./runs/", tag_name="train_loss", output_transform=lambda x: x["loss"]), + CheckpointSaver(save_dir="./runs/", save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True), ] - ) - train_handlers = [ - LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True), - ValidationHandler(validator=evaluator, interval=2, epoch_level=True), - StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]), - TensorBoardStatsHandler(log_dir="./runs/", tag_name="train_loss", output_transform=lambda x: x["loss"]), - CheckpointSaver(save_dir="./runs/", save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True), - ] - - trainer = SupervisedTrainer( - device=device, - max_epochs=5, - train_data_loader=train_loader, - network=net, - optimizer=opt, - loss_function=loss, - inferer=SimpleInferer(), - post_transform=train_post_transforms, - key_train_metric={"train_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))}, - train_handlers=train_handlers, - # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP training - amp=True if monai.config.get_torch_version_tuple() >= (1, 6) else False, - ) - trainer.run() - - shutil.rmtree(tempdir) + + trainer = SupervisedTrainer( + device=device, + max_epochs=5, + train_data_loader=train_loader, + network=net, + optimizer=opt, + loss_function=loss, + inferer=SimpleInferer(), + post_transform=train_post_transforms, + key_train_metric={"train_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))}, + train_handlers=train_handlers, + # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP training + amp=True if monai.config.get_torch_version_tuple() >= (1, 6) else False, + ) + trainer.run() if __name__ == "__main__": diff --git a/tests/test_arraydataset.py b/tests/test_arraydataset.py index 0935ee5a4f..5bd5059fd5 100644 --- a/tests/test_arraydataset.py +++ b/tests/test_arraydataset.py @@ -10,7 +10,6 @@ # limitations under the License. import os -import shutil import tempfile import unittest @@ -59,106 +58,104 @@ class TestArrayDataset(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, img_transform, label_transform, indexes, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=(128, 128, 128)), np.eye(4)) - tempdir = tempfile.mkdtemp() - test_image1 = os.path.join(tempdir, "test_image1.nii.gz") - test_seg1 = os.path.join(tempdir, "test_seg1.nii.gz") - test_image2 = os.path.join(tempdir, "test_image2.nii.gz") - test_seg2 = os.path.join(tempdir, "test_seg2.nii.gz") - nib.save(test_image, test_image1) - nib.save(test_image, test_seg1) - nib.save(test_image, test_image2) - nib.save(test_image, test_seg2) - test_images = [test_image1, test_image2] - test_segs = [test_seg1, test_seg2] - test_labels = [1, 1] - dataset = ArrayDataset(test_images, img_transform, test_segs, label_transform, test_labels, None) - self.assertEqual(len(dataset), 2) - dataset.set_random_state(1234) - data1 = dataset[0] - data2 = dataset[1] - - self.assertTupleEqual(data1[indexes[0]].shape, expected_shape) - self.assertTupleEqual(data1[indexes[1]].shape, expected_shape) - np.testing.assert_allclose(data1[indexes[0]], data1[indexes[1]]) - self.assertTupleEqual(data2[indexes[0]].shape, expected_shape) - self.assertTupleEqual(data2[indexes[1]].shape, expected_shape) - np.testing.assert_allclose(data2[indexes[0]], data2[indexes[0]]) - - dataset = ArrayDataset(test_images, img_transform, test_segs, label_transform, test_labels, None) - dataset.set_random_state(1234) - _ = dataset[0] - data2_new = dataset[1] - np.testing.assert_allclose(data2[indexes[0]], data2_new[indexes[0]], atol=1e-3) - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + test_image1 = os.path.join(tempdir, "test_image1.nii.gz") + test_seg1 = os.path.join(tempdir, "test_seg1.nii.gz") + test_image2 = os.path.join(tempdir, "test_image2.nii.gz") + test_seg2 = os.path.join(tempdir, "test_seg2.nii.gz") + nib.save(test_image, test_image1) + nib.save(test_image, test_seg1) + nib.save(test_image, test_image2) + nib.save(test_image, test_seg2) + test_images = [test_image1, test_image2] + test_segs = [test_seg1, test_seg2] + test_labels = [1, 1] + dataset = ArrayDataset(test_images, img_transform, test_segs, label_transform, test_labels, None) + self.assertEqual(len(dataset), 2) + dataset.set_random_state(1234) + data1 = dataset[0] + data2 = dataset[1] + + self.assertTupleEqual(data1[indexes[0]].shape, expected_shape) + self.assertTupleEqual(data1[indexes[1]].shape, expected_shape) + np.testing.assert_allclose(data1[indexes[0]], data1[indexes[1]]) + self.assertTupleEqual(data2[indexes[0]].shape, expected_shape) + self.assertTupleEqual(data2[indexes[1]].shape, expected_shape) + np.testing.assert_allclose(data2[indexes[0]], data2[indexes[0]]) + + dataset = ArrayDataset(test_images, img_transform, test_segs, label_transform, test_labels, None) + dataset.set_random_state(1234) + _ = dataset[0] + data2_new = dataset[1] + np.testing.assert_allclose(data2[indexes[0]], data2_new[indexes[0]], atol=1e-3) @parameterized.expand([TEST_CASE_4]) def test_default_none(self, img_transform, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=(128, 128, 128)), np.eye(4)) - tempdir = tempfile.mkdtemp() - test_image1 = os.path.join(tempdir, "test_image1.nii.gz") - test_image2 = os.path.join(tempdir, "test_image2.nii.gz") - nib.save(test_image, test_image1) - nib.save(test_image, test_image2) - test_images = [test_image1, test_image2] - dataset = ArrayDataset(test_images, img_transform) - self.assertEqual(len(dataset), 2) - dataset.set_random_state(1234) - data1 = dataset[0] - data2 = dataset[1] - self.assertTupleEqual(data1.shape, expected_shape) - self.assertTupleEqual(data2.shape, expected_shape) - - dataset = ArrayDataset(test_images, img_transform) - dataset.set_random_state(1234) - _ = dataset[0] - data2_new = dataset[1] - np.testing.assert_allclose(data2, data2_new, atol=1e-3) - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + test_image1 = os.path.join(tempdir, "test_image1.nii.gz") + test_image2 = os.path.join(tempdir, "test_image2.nii.gz") + nib.save(test_image, test_image1) + nib.save(test_image, test_image2) + test_images = [test_image1, test_image2] + dataset = ArrayDataset(test_images, img_transform) + self.assertEqual(len(dataset), 2) + dataset.set_random_state(1234) + data1 = dataset[0] + data2 = dataset[1] + self.assertTupleEqual(data1.shape, expected_shape) + self.assertTupleEqual(data2.shape, expected_shape) + + dataset = ArrayDataset(test_images, img_transform) + dataset.set_random_state(1234) + _ = dataset[0] + data2_new = dataset[1] + np.testing.assert_allclose(data2, data2_new, atol=1e-3) @parameterized.expand([TEST_CASE_4]) def test_dataloading_img(self, img_transform, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=(128, 128, 128)), np.eye(4)) - tempdir = tempfile.mkdtemp() - test_image1 = os.path.join(tempdir, "test_image1.nii.gz") - test_image2 = os.path.join(tempdir, "test_image2.nii.gz") - nib.save(test_image, test_image1) - nib.save(test_image, test_image2) - test_images = [test_image1, test_image2] - dataset = ArrayDataset(test_images, img_transform) - self.assertEqual(len(dataset), 2) - dataset.set_random_state(1234) - loader = DataLoader(dataset, batch_size=10, num_workers=1) - imgs = next(iter(loader)) # test batching - np.testing.assert_allclose(imgs.shape, [2] + list(expected_shape)) - - dataset.set_random_state(1234) - new_imgs = next(iter(loader)) # test batching - np.testing.assert_allclose(imgs, new_imgs, atol=1e-3) + with tempfile.TemporaryDirectory() as tempdir: + test_image1 = os.path.join(tempdir, "test_image1.nii.gz") + test_image2 = os.path.join(tempdir, "test_image2.nii.gz") + nib.save(test_image, test_image1) + nib.save(test_image, test_image2) + test_images = [test_image1, test_image2] + dataset = ArrayDataset(test_images, img_transform) + self.assertEqual(len(dataset), 2) + dataset.set_random_state(1234) + loader = DataLoader(dataset, batch_size=10, num_workers=1) + imgs = next(iter(loader)) # test batching + np.testing.assert_allclose(imgs.shape, [2] + list(expected_shape)) + + dataset.set_random_state(1234) + new_imgs = next(iter(loader)) # test batching + np.testing.assert_allclose(imgs, new_imgs, atol=1e-3) @parameterized.expand([TEST_CASE_4]) def test_dataloading_img_label(self, img_transform, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=(128, 128, 128)), np.eye(4)) - tempdir = tempfile.mkdtemp() - test_image1 = os.path.join(tempdir, "test_image1.nii.gz") - test_image2 = os.path.join(tempdir, "test_image2.nii.gz") - test_label1 = os.path.join(tempdir, "test_label1.nii.gz") - test_label2 = os.path.join(tempdir, "test_label2.nii.gz") - nib.save(test_image, test_image1) - nib.save(test_image, test_image2) - nib.save(test_image, test_label1) - nib.save(test_image, test_label2) - test_images = [test_image1, test_image2] - test_labels = [test_label1, test_label2] - dataset = ArrayDataset(test_images, img_transform, test_labels, img_transform) - self.assertEqual(len(dataset), 2) - dataset.set_random_state(1234) - loader = DataLoader(dataset, batch_size=10, num_workers=1) - data = next(iter(loader)) # test batching - np.testing.assert_allclose(data[0].shape, [2] + list(expected_shape)) - - dataset.set_random_state(1234) - new_data = next(iter(loader)) # test batching - np.testing.assert_allclose(data[0], new_data[0], atol=1e-3) + with tempfile.TemporaryDirectory() as tempdir: + test_image1 = os.path.join(tempdir, "test_image1.nii.gz") + test_image2 = os.path.join(tempdir, "test_image2.nii.gz") + test_label1 = os.path.join(tempdir, "test_label1.nii.gz") + test_label2 = os.path.join(tempdir, "test_label2.nii.gz") + nib.save(test_image, test_image1) + nib.save(test_image, test_image2) + nib.save(test_image, test_label1) + nib.save(test_image, test_label2) + test_images = [test_image1, test_image2] + test_labels = [test_label1, test_label2] + dataset = ArrayDataset(test_images, img_transform, test_labels, img_transform) + self.assertEqual(len(dataset), 2) + dataset.set_random_state(1234) + loader = DataLoader(dataset, batch_size=10, num_workers=1) + data = next(iter(loader)) # test batching + np.testing.assert_allclose(data[0].shape, [2] + list(expected_shape)) + + dataset.set_random_state(1234) + new_data = next(iter(loader)) # test batching + np.testing.assert_allclose(data[0], new_data[0], atol=1e-3) if __name__ == "__main__": diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py index 5713a6df14..7450d7dfdc 100644 --- a/tests/test_cachedataset.py +++ b/tests/test_cachedataset.py @@ -10,7 +10,6 @@ # limitations under the License. import os -import shutil import tempfile import unittest @@ -30,29 +29,29 @@ class TestCacheDataset(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, transform, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) - tempdir = tempfile.mkdtemp() - nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz")) - test_data = [ - { - "image": os.path.join(tempdir, "test_image1.nii.gz"), - "label": os.path.join(tempdir, "test_label1.nii.gz"), - "extra": os.path.join(tempdir, "test_extra1.nii.gz"), - }, - { - "image": os.path.join(tempdir, "test_image2.nii.gz"), - "label": os.path.join(tempdir, "test_label2.nii.gz"), - "extra": os.path.join(tempdir, "test_extra2.nii.gz"), - }, - ] - dataset = CacheDataset(data=test_data, transform=transform, cache_rate=0.5) - data1 = dataset[0] - data2 = dataset[1] - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz")) + test_data = [ + { + "image": os.path.join(tempdir, "test_image1.nii.gz"), + "label": os.path.join(tempdir, "test_label1.nii.gz"), + "extra": os.path.join(tempdir, "test_extra1.nii.gz"), + }, + { + "image": os.path.join(tempdir, "test_image2.nii.gz"), + "label": os.path.join(tempdir, "test_label2.nii.gz"), + "extra": os.path.join(tempdir, "test_extra2.nii.gz"), + }, + ] + dataset = CacheDataset(data=test_data, transform=transform, cache_rate=0.5) + data1 = dataset[0] + data2 = dataset[1] + if transform is None: self.assertEqual(data1["image"], os.path.join(tempdir, "test_image1.nii.gz")) self.assertEqual(data2["label"], os.path.join(tempdir, "test_label2.nii.gz")) diff --git a/tests/test_cachedataset_parallel.py b/tests/test_cachedataset_parallel.py index 93335aceed..c32ca72516 100644 --- a/tests/test_cachedataset_parallel.py +++ b/tests/test_cachedataset_parallel.py @@ -10,7 +10,6 @@ # limitations under the License. import os -import shutil import tempfile import unittest @@ -32,19 +31,19 @@ class TestCacheDatasetParallel(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, num_workers, dataset_size, transform): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) - tempdir = tempfile.mkdtemp() - nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) - test_data = [ - { - "image": os.path.join(tempdir, "test_image1.nii.gz"), - "label": os.path.join(tempdir, "test_label1.nii.gz"), - "extra": os.path.join(tempdir, "test_extra1.nii.gz"), - } - ] * dataset_size - dataset = CacheDataset(data=test_data, transform=transform, cache_rate=1, num_workers=num_workers,) - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) + test_data = [ + { + "image": os.path.join(tempdir, "test_image1.nii.gz"), + "label": os.path.join(tempdir, "test_label1.nii.gz"), + "extra": os.path.join(tempdir, "test_extra1.nii.gz"), + } + ] * dataset_size + dataset = CacheDataset(data=test_data, transform=transform, cache_rate=1, num_workers=num_workers,) + self.assertEqual(len(dataset._cache), dataset.cache_num) for i in range(dataset.cache_num): self.assertIsNotNone(dataset._cache[i]) diff --git a/tests/test_check_md5.py b/tests/test_check_md5.py index e252ae9544..679a299084 100644 --- a/tests/test_check_md5.py +++ b/tests/test_check_md5.py @@ -10,7 +10,6 @@ # limitations under the License. import os -import shutil import tempfile import unittest @@ -31,14 +30,12 @@ class TestCheckMD5(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, md5_value, expected_result): test_image = np.ones((64, 64, 3)) - tempdir = tempfile.mkdtemp() - filename = os.path.join(tempdir, "test_file.png") - Image.fromarray(test_image.astype("uint8")).save(filename) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, "test_file.png") + Image.fromarray(test_image.astype("uint8")).save(filename) - result = check_md5(filename, md5_value) - self.assertTrue(result == expected_result) - - shutil.rmtree(tempdir) + result = check_md5(filename, md5_value) + self.assertTrue(result == expected_result) if __name__ == "__main__": diff --git a/tests/test_csv_saver.py b/tests/test_csv_saver.py index 9a92c49d72..d1ff1975ed 100644 --- a/tests/test_csv_saver.py +++ b/tests/test_csv_saver.py @@ -11,7 +11,7 @@ import csv import os -import shutil +import tempfile import unittest import numpy as np @@ -22,25 +22,21 @@ class TestCSVSaver(unittest.TestCase): def test_saved_content(self): - default_dir = os.path.join(".", "tempdir") - shutil.rmtree(default_dir, ignore_errors=True) - - saver = CSVSaver(output_dir=default_dir, filename="predictions.csv") - - meta_data = {"filename_or_obj": ["testfile" + str(i) for i in range(8)]} - saver.save_batch(torch.zeros(8), meta_data) - saver.finalize() - filepath = os.path.join(default_dir, "predictions.csv") - self.assertTrue(os.path.exists(filepath)) - with open(filepath, "r") as f: - reader = csv.reader(f) - i = 0 - for row in reader: - self.assertEqual(row[0], "testfile" + str(i)) - self.assertEqual(np.array(row[1:]).astype(np.float32), 0.0) - i += 1 - self.assertEqual(i, 8) - shutil.rmtree(default_dir) + with tempfile.TemporaryDirectory() as tempdir: + saver = CSVSaver(output_dir=tempdir, filename="predictions.csv") + meta_data = {"filename_or_obj": ["testfile" + str(i) for i in range(8)]} + saver.save_batch(torch.zeros(8), meta_data) + saver.finalize() + filepath = os.path.join(tempdir, "predictions.csv") + self.assertTrue(os.path.exists(filepath)) + with open(filepath, "r") as f: + reader = csv.reader(f) + i = 0 + for row in reader: + self.assertEqual(row[0], "testfile" + str(i)) + self.assertEqual(np.array(row[1:]).astype(np.float32), 0.0) + i += 1 + self.assertEqual(i, 8) if __name__ == "__main__": diff --git a/tests/test_data_stats.py b/tests/test_data_stats.py index 7d0a6013c4..35c15c8050 100644 --- a/tests/test_data_stats.py +++ b/tests/test_data_stats.py @@ -11,7 +11,6 @@ import logging import os -import shutil import tempfile import unittest @@ -117,25 +116,24 @@ def test_value(self, input_param, input_data, expected_print): @parameterized.expand([TEST_CASE_7]) def test_file(self, input_data, expected_print): - tempdir = tempfile.mkdtemp() - filename = os.path.join(tempdir, "test_data_stats.log") - handler = logging.FileHandler(filename, mode="w") - input_param = { - "prefix": "test data", - "data_shape": True, - "value_range": True, - "data_value": True, - "additional_info": lambda x: np.mean(x), - "logger_handler": handler, - } - transform = DataStats(**input_param) - _ = transform(input_data) - handler.stream.close() - transform._logger.removeHandler(handler) - with open(filename, "r") as f: - content = f.read() - self.assertEqual(content, expected_print) - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, "test_data_stats.log") + handler = logging.FileHandler(filename, mode="w") + input_param = { + "prefix": "test data", + "data_shape": True, + "value_range": True, + "data_value": True, + "additional_info": lambda x: np.mean(x), + "logger_handler": handler, + } + transform = DataStats(**input_param) + _ = transform(input_data) + handler.stream.close() + transform._logger.removeHandler(handler) + with open(filename, "r") as f: + content = f.read() + self.assertEqual(content, expected_print) if __name__ == "__main__": diff --git a/tests/test_data_statsd.py b/tests/test_data_statsd.py index 7dc28024aa..79a4ab5ff5 100644 --- a/tests/test_data_statsd.py +++ b/tests/test_data_statsd.py @@ -11,7 +11,6 @@ import logging import os -import shutil import tempfile import unittest @@ -130,26 +129,25 @@ def test_value(self, input_param, input_data, expected_print): @parameterized.expand([TEST_CASE_8]) def test_file(self, input_data, expected_print): - tempdir = tempfile.mkdtemp() - filename = os.path.join(tempdir, "test_stats.log") - handler = logging.FileHandler(filename, mode="w") - input_param = { - "keys": "img", - "prefix": "test data", - "data_shape": True, - "value_range": True, - "data_value": True, - "additional_info": lambda x: np.mean(x), - "logger_handler": handler, - } - transform = DataStatsd(**input_param) - _ = transform(input_data) - handler.stream.close() - transform.printer._logger.removeHandler(handler) - with open(filename, "r") as f: - content = f.read() - self.assertEqual(content, expected_print) - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, "test_stats.log") + handler = logging.FileHandler(filename, mode="w") + input_param = { + "keys": "img", + "prefix": "test data", + "data_shape": True, + "value_range": True, + "data_value": True, + "additional_info": lambda x: np.mean(x), + "logger_handler": handler, + } + transform = DataStatsd(**input_param) + _ = transform(input_data) + handler.stream.close() + transform.printer._logger.removeHandler(handler) + with open(filename, "r") as f: + content = f.read() + self.assertEqual(content, expected_print) if __name__ == "__main__": diff --git a/tests/test_dataset.py b/tests/test_dataset.py index bc6f7078c1..93b531bb41 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -10,7 +10,6 @@ # limitations under the License. import os -import shutil import tempfile import unittest @@ -28,53 +27,52 @@ class TestDataset(unittest.TestCase): @parameterized.expand([TEST_CASE_1]) def test_shape(self, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) - tempdir = tempfile.mkdtemp() - nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz")) - test_data = [ - { - "image": os.path.join(tempdir, "test_image1.nii.gz"), - "label": os.path.join(tempdir, "test_label1.nii.gz"), - "extra": os.path.join(tempdir, "test_extra1.nii.gz"), - }, - { - "image": os.path.join(tempdir, "test_image2.nii.gz"), - "label": os.path.join(tempdir, "test_label2.nii.gz"), - "extra": os.path.join(tempdir, "test_extra2.nii.gz"), - }, - ] - test_transform = Compose( - [ - LoadNiftid(keys=["image", "label", "extra"]), - SimulateDelayd(keys=["image", "label", "extra"], delay_time=[1e-7, 1e-6, 1e-5]), + with tempfile.TemporaryDirectory() as tempdir: + nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz")) + test_data = [ + { + "image": os.path.join(tempdir, "test_image1.nii.gz"), + "label": os.path.join(tempdir, "test_label1.nii.gz"), + "extra": os.path.join(tempdir, "test_extra1.nii.gz"), + }, + { + "image": os.path.join(tempdir, "test_image2.nii.gz"), + "label": os.path.join(tempdir, "test_label2.nii.gz"), + "extra": os.path.join(tempdir, "test_extra2.nii.gz"), + }, ] - ) - dataset = Dataset(data=test_data, transform=test_transform) - data1 = dataset[0] - data2 = dataset[1] + test_transform = Compose( + [ + LoadNiftid(keys=["image", "label", "extra"]), + SimulateDelayd(keys=["image", "label", "extra"], delay_time=[1e-7, 1e-6, 1e-5]), + ] + ) + dataset = Dataset(data=test_data, transform=test_transform) + data1 = dataset[0] + data2 = dataset[1] - self.assertTupleEqual(data1["image"].shape, expected_shape) - self.assertTupleEqual(data1["label"].shape, expected_shape) - self.assertTupleEqual(data1["extra"].shape, expected_shape) - self.assertTupleEqual(data2["image"].shape, expected_shape) - self.assertTupleEqual(data2["label"].shape, expected_shape) - self.assertTupleEqual(data2["extra"].shape, expected_shape) + self.assertTupleEqual(data1["image"].shape, expected_shape) + self.assertTupleEqual(data1["label"].shape, expected_shape) + self.assertTupleEqual(data1["extra"].shape, expected_shape) + self.assertTupleEqual(data2["image"].shape, expected_shape) + self.assertTupleEqual(data2["label"].shape, expected_shape) + self.assertTupleEqual(data2["extra"].shape, expected_shape) - dataset = Dataset(data=test_data, transform=LoadNiftid(keys=["image", "label", "extra"])) - data1_simple = dataset[0] - data2_simple = dataset[1] + dataset = Dataset(data=test_data, transform=LoadNiftid(keys=["image", "label", "extra"])) + data1_simple = dataset[0] + data2_simple = dataset[1] - self.assertTupleEqual(data1_simple["image"].shape, expected_shape) - self.assertTupleEqual(data1_simple["label"].shape, expected_shape) - self.assertTupleEqual(data1_simple["extra"].shape, expected_shape) - self.assertTupleEqual(data2_simple["image"].shape, expected_shape) - self.assertTupleEqual(data2_simple["label"].shape, expected_shape) - self.assertTupleEqual(data2_simple["extra"].shape, expected_shape) - shutil.rmtree(tempdir) + self.assertTupleEqual(data1_simple["image"].shape, expected_shape) + self.assertTupleEqual(data1_simple["label"].shape, expected_shape) + self.assertTupleEqual(data1_simple["extra"].shape, expected_shape) + self.assertTupleEqual(data2_simple["image"].shape, expected_shape) + self.assertTupleEqual(data2_simple["label"].shape, expected_shape) + self.assertTupleEqual(data2_simple["extra"].shape, expected_shape) if __name__ == "__main__": diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py index b759abaa90..5d1d19eb39 100644 --- a/tests/test_handler_checkpoint_loader.py +++ b/tests/test_handler_checkpoint_loader.py @@ -10,7 +10,6 @@ # limitations under the License. import logging -import shutil import sys import tempfile import unittest @@ -34,14 +33,13 @@ def test_one_save_one_load(self): data2["weight"] = torch.tensor([0.2]) net2.load_state_dict(data2) engine = Engine(lambda e, b: None) - tempdir = tempfile.mkdtemp() - CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine) - engine.run([0] * 8, max_epochs=5) - path = tempdir + "/net_final_iteration=40.pth" - CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine) - engine.run([0] * 8, max_epochs=1) - torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1) - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine) + engine.run([0] * 8, max_epochs=5) + path = tempdir + "/net_final_iteration=40.pth" + CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine) + engine.run([0] * 8, max_epochs=1) + torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1) def test_two_save_one_load(self): logging.basicConfig(stream=sys.stdout, level=logging.INFO) @@ -55,15 +53,14 @@ def test_two_save_one_load(self): data2["weight"] = torch.tensor([0.2]) net2.load_state_dict(data2) engine = Engine(lambda e, b: None) - tempdir = tempfile.mkdtemp() - save_dict = {"net": net1, "opt": optimizer} - CheckpointSaver(save_dir=tempdir, save_dict=save_dict, save_final=True).attach(engine) - engine.run([0] * 8, max_epochs=5) - path = tempdir + "/checkpoint_final_iteration=40.pth" - CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine) - engine.run([0] * 8, max_epochs=1) - torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1) - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + save_dict = {"net": net1, "opt": optimizer} + CheckpointSaver(save_dir=tempdir, save_dict=save_dict, save_final=True).attach(engine) + engine.run([0] * 8, max_epochs=5) + path = tempdir + "/checkpoint_final_iteration=40.pth" + CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine) + engine.run([0] * 8, max_epochs=1) + torch.testing.assert_allclose(net2.state_dict()["weight"], 0.1) def test_save_single_device_load_multi_devices(self): logging.basicConfig(stream=sys.stdout, level=logging.INFO) @@ -77,14 +74,13 @@ def test_save_single_device_load_multi_devices(self): net2.load_state_dict(data2) net2 = torch.nn.DataParallel(net2) engine = Engine(lambda e, b: None) - tempdir = tempfile.mkdtemp() - CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine) - engine.run([0] * 8, max_epochs=5) - path = tempdir + "/net_final_iteration=40.pth" - CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine) - engine.run([0] * 8, max_epochs=1) - torch.testing.assert_allclose(net2.state_dict()["module.weight"], 0.1) - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine) + engine.run([0] * 8, max_epochs=5) + path = tempdir + "/net_final_iteration=40.pth" + CheckpointLoader(load_path=path, load_dict={"net": net2}).attach(engine) + engine.run([0] * 8, max_epochs=1) + torch.testing.assert_allclose(net2.state_dict()["module.weight"], 0.1) if __name__ == "__main__": diff --git a/tests/test_handler_classification_saver.py b/tests/test_handler_classification_saver.py index 2b4a5046f5..3b05092adc 100644 --- a/tests/test_handler_classification_saver.py +++ b/tests/test_handler_classification_saver.py @@ -11,7 +11,7 @@ import csv import os -import shutil +import tempfile import unittest import numpy as np @@ -23,32 +23,30 @@ class TestHandlerClassificationSaver(unittest.TestCase): def test_saved_content(self): - default_dir = os.path.join(".", "tempdir") - shutil.rmtree(default_dir, ignore_errors=True) - - # set up engine - def _train_func(engine, batch): - return torch.zeros(8) - - engine = Engine(_train_func) - - # set up testing handler - saver = ClassificationSaver(output_dir=default_dir, filename="predictions.csv") - saver.attach(engine) - - data = [{"filename_or_obj": ["testfile" + str(i) for i in range(8)]}] - engine.run(data, max_epochs=1) - filepath = os.path.join(default_dir, "predictions.csv") - self.assertTrue(os.path.exists(filepath)) - with open(filepath, "r") as f: - reader = csv.reader(f) - i = 0 - for row in reader: - self.assertEqual(row[0], "testfile" + str(i)) - self.assertEqual(np.array(row[1:]).astype(np.float32), 0.0) - i += 1 - self.assertEqual(i, 8) - shutil.rmtree(default_dir) + with tempfile.TemporaryDirectory() as tempdir: + + # set up engine + def _train_func(engine, batch): + return torch.zeros(8) + + engine = Engine(_train_func) + + # set up testing handler + saver = ClassificationSaver(output_dir=tempdir, filename="predictions.csv") + saver.attach(engine) + + data = [{"filename_or_obj": ["testfile" + str(i) for i in range(8)]}] + engine.run(data, max_epochs=1) + filepath = os.path.join(tempdir, "predictions.csv") + self.assertTrue(os.path.exists(filepath)) + with open(filepath, "r") as f: + reader = csv.reader(f) + i = 0 + for row in reader: + self.assertEqual(row[0], "testfile" + str(i)) + self.assertEqual(np.array(row[1:]).astype(np.float32), 0.0) + i += 1 + self.assertEqual(i, 8) if __name__ == "__main__": diff --git a/tests/test_handler_segmentation_saver.py b/tests/test_handler_segmentation_saver.py index 0d0dda65c6..16da580338 100644 --- a/tests/test_handler_segmentation_saver.py +++ b/tests/test_handler_segmentation_saver.py @@ -10,7 +10,7 @@ # limitations under the License. import os -import shutil +import tempfile import unittest import numpy as np @@ -28,54 +28,50 @@ class TestHandlerSegmentationSaver(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) def test_saved_content(self, output_ext): - default_dir = os.path.join(".", "tempdir") - shutil.rmtree(default_dir, ignore_errors=True) + with tempfile.TemporaryDirectory() as tempdir: - # set up engine - def _train_func(engine, batch): - return torch.randint(0, 255, (8, 1, 2, 2)).float() + # set up engine + def _train_func(engine, batch): + return torch.randint(0, 255, (8, 1, 2, 2)).float() - engine = Engine(_train_func) + engine = Engine(_train_func) - # set up testing handler - saver = SegmentationSaver(output_dir=default_dir, output_postfix="seg", output_ext=output_ext, scale=255) - saver.attach(engine) + # set up testing handler + saver = SegmentationSaver(output_dir=tempdir, output_postfix="seg", output_ext=output_ext, scale=255) + saver.attach(engine) - data = [{"filename_or_obj": ["testfile" + str(i) for i in range(8)]}] - engine.run(data, max_epochs=1) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg" + output_ext) - self.assertTrue(os.path.exists(os.path.join(default_dir, filepath))) - shutil.rmtree(default_dir) + data = [{"filename_or_obj": ["testfile" + str(i) for i in range(8)]}] + engine.run(data, max_epochs=1) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg" + output_ext) + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) def test_save_resized_content(self, output_ext): - default_dir = os.path.join(".", "tempdir") - shutil.rmtree(default_dir, ignore_errors=True) - - # set up engine - def _train_func(engine, batch): - return torch.randint(0, 255, (8, 1, 2, 2)).float() - - engine = Engine(_train_func) - - # set up testing handler - saver = SegmentationSaver(output_dir=default_dir, output_postfix="seg", output_ext=output_ext, scale=255) - saver.attach(engine) - - data = [ - { - "filename_or_obj": ["testfile" + str(i) for i in range(8)], - "spatial_shape": [(28, 28)] * 8, - "affine": [np.diag(np.ones(4)) * 5] * 8, - "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, - } - ] - engine.run(data, max_epochs=1) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg" + output_ext) - self.assertTrue(os.path.exists(os.path.join(default_dir, filepath))) - shutil.rmtree(default_dir) + with tempfile.TemporaryDirectory() as tempdir: + + # set up engine + def _train_func(engine, batch): + return torch.randint(0, 255, (8, 1, 2, 2)).float() + + engine = Engine(_train_func) + + # set up testing handler + saver = SegmentationSaver(output_dir=tempdir, output_postfix="seg", output_ext=output_ext, scale=255) + saver.attach(engine) + + data = [ + { + "filename_or_obj": ["testfile" + str(i) for i in range(8)], + "spatial_shape": [(28, 28)] * 8, + "affine": [np.diag(np.ones(4)) * 5] * 8, + "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, + } + ] + engine.run(data, max_epochs=1) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg" + output_ext) + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) if __name__ == "__main__": diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py index b17bf5828e..dab5a0ea14 100644 --- a/tests/test_handler_stats.py +++ b/tests/test_handler_stats.py @@ -12,7 +12,6 @@ import logging import os import re -import shutil import tempfile import unittest from io import StringIO @@ -116,32 +115,31 @@ def test_loss_file(self): key_to_handler = "test_logging" key_to_print = "myLoss" - tempdir = tempfile.mkdtemp() - filename = os.path.join(tempdir, "test_loss_stats.log") - handler = logging.FileHandler(filename, mode="w") + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, "test_loss_stats.log") + handler = logging.FileHandler(filename, mode="w") - # set up engine - def _train_func(engine, batch): - return torch.tensor(0.0) + # set up engine + def _train_func(engine, batch): + return torch.tensor(0.0) - engine = Engine(_train_func) + engine = Engine(_train_func) - # set up testing handler - stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print, logger_handler=handler) - stats_handler.attach(engine) + # set up testing handler + stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print, logger_handler=handler) + stats_handler.attach(engine) - engine.run(range(3), max_epochs=2) - handler.stream.close() - stats_handler.logger.removeHandler(handler) - with open(filename, "r") as f: - output_str = f.read() - grep = re.compile(f".*{key_to_handler}.*") - has_key_word = re.compile(f".*{key_to_print}.*") - for idx, line in enumerate(output_str.split("\n")): - if grep.match(line): - if idx in [1, 2, 3, 6, 7, 8]: - self.assertTrue(has_key_word.match(line)) - shutil.rmtree(tempdir) + engine.run(range(3), max_epochs=2) + handler.stream.close() + stats_handler.logger.removeHandler(handler) + with open(filename, "r") as f: + output_str = f.read() + grep = re.compile(f".*{key_to_handler}.*") + has_key_word = re.compile(f".*{key_to_print}.*") + for idx, line in enumerate(output_str.split("\n")): + if grep.match(line): + if idx in [1, 2, 3, 6, 7, 8]: + self.assertTrue(has_key_word.match(line)) def test_exception(self): logging.basicConfig(level=logging.INFO) diff --git a/tests/test_handler_tb_image.py b/tests/test_handler_tb_image.py index dc2b84e3d4..e1cda3be65 100644 --- a/tests/test_handler_tb_image.py +++ b/tests/test_handler_tb_image.py @@ -10,8 +10,6 @@ # limitations under the License. import glob -import os -import shutil import tempfile import unittest @@ -28,25 +26,22 @@ class TestHandlerTBImage(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_tb_image_shape(self, shape): - tempdir = tempfile.mkdtemp() - shutil.rmtree(tempdir, ignore_errors=True) + with tempfile.TemporaryDirectory() as tempdir: - # set up engine - def _train_func(engine, batch): - return torch.zeros((1, 1, 10, 10)) + # set up engine + def _train_func(engine, batch): + return torch.zeros((1, 1, 10, 10)) - engine = Engine(_train_func) + engine = Engine(_train_func) - # set up testing handler - stats_handler = TensorBoardImageHandler(log_dir=tempdir) - engine.add_event_handler(Events.ITERATION_COMPLETED, stats_handler) + # set up testing handler + stats_handler = TensorBoardImageHandler(log_dir=tempdir) + engine.add_event_handler(Events.ITERATION_COMPLETED, stats_handler) - data = zip(np.random.normal(size=(10, 4, *shape)), np.random.normal(size=(10, 4, *shape))) - engine.run(data, epoch_length=10, max_epochs=1) + data = zip(np.random.normal(size=(10, 4, *shape)), np.random.normal(size=(10, 4, *shape))) + engine.run(data, epoch_length=10, max_epochs=1) - self.assertTrue(os.path.exists(tempdir)) - self.assertTrue(len(glob.glob(tempdir)) > 0) - shutil.rmtree(tempdir) + self.assertTrue(len(glob.glob(tempdir)) > 0) if __name__ == "__main__": diff --git a/tests/test_handler_tb_stats.py b/tests/test_handler_tb_stats.py index de6da0faf8..ab356e74b4 100644 --- a/tests/test_handler_tb_stats.py +++ b/tests/test_handler_tb_stats.py @@ -10,8 +10,6 @@ # limitations under the License. import glob -import os -import shutil import tempfile import unittest @@ -23,57 +21,51 @@ class TestHandlerTBStats(unittest.TestCase): def test_metrics_print(self): - tempdir = tempfile.mkdtemp() - shutil.rmtree(tempdir, ignore_errors=True) + with tempfile.TemporaryDirectory() as tempdir: - # set up engine - def _train_func(engine, batch): - return batch + 1.0 + # set up engine + def _train_func(engine, batch): + return batch + 1.0 - engine = Engine(_train_func) + engine = Engine(_train_func) - # set up dummy metric - @engine.on(Events.EPOCH_COMPLETED) - def _update_metric(engine): - current_metric = engine.state.metrics.get("acc", 0.1) - engine.state.metrics["acc"] = current_metric + 0.1 + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get("acc", 0.1) + engine.state.metrics["acc"] = current_metric + 0.1 - # set up testing handler - stats_handler = TensorBoardStatsHandler(log_dir=tempdir) - stats_handler.attach(engine) - engine.run(range(3), max_epochs=2) - # check logging output - - self.assertTrue(os.path.exists(tempdir)) - shutil.rmtree(tempdir) + # set up testing handler + stats_handler = TensorBoardStatsHandler(log_dir=tempdir) + stats_handler.attach(engine) + engine.run(range(3), max_epochs=2) + # check logging output + self.assertTrue(len(glob.glob(tempdir)) > 0) def test_metrics_writer(self): - tempdir = tempfile.mkdtemp() - shutil.rmtree(tempdir, ignore_errors=True) - - # set up engine - def _train_func(engine, batch): - return batch + 1.0 - - engine = Engine(_train_func) - - # set up dummy metric - @engine.on(Events.EPOCH_COMPLETED) - def _update_metric(engine): - current_metric = engine.state.metrics.get("acc", 0.1) - engine.state.metrics["acc"] = current_metric + 0.1 - - # set up testing handler - writer = SummaryWriter(log_dir=tempdir) - stats_handler = TensorBoardStatsHandler( - writer, output_transform=lambda x: {"loss": x * 2.0}, global_epoch_transform=lambda x: x * 3.0 - ) - stats_handler.attach(engine) - engine.run(range(3), max_epochs=2) - # check logging output - self.assertTrue(os.path.exists(tempdir)) - self.assertTrue(len(glob.glob(tempdir)) > 0) - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + + # set up engine + def _train_func(engine, batch): + return batch + 1.0 + + engine = Engine(_train_func) + + # set up dummy metric + @engine.on(Events.EPOCH_COMPLETED) + def _update_metric(engine): + current_metric = engine.state.metrics.get("acc", 0.1) + engine.state.metrics["acc"] = current_metric + 0.1 + + # set up testing handler + writer = SummaryWriter(log_dir=tempdir) + stats_handler = TensorBoardStatsHandler( + writer, output_transform=lambda x: {"loss": x * 2.0}, global_epoch_transform=lambda x: x * 3.0 + ) + stats_handler.attach(engine) + engine.run(range(3), max_epochs=2) + # check logging output + self.assertTrue(len(glob.glob(tempdir)) > 0) if __name__ == "__main__": diff --git a/tests/test_img2tensorboard.py b/tests/test_img2tensorboard.py index decf36c404..99761b4d11 100644 --- a/tests/test_img2tensorboard.py +++ b/tests/test_img2tensorboard.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tempfile import unittest import numpy as np @@ -21,32 +20,31 @@ class TestImg2Tensorboard(unittest.TestCase): def test_write_gray(self): - with tempfile.TemporaryDirectory() as out_dir: - nparr = np.ones(shape=(1, 32, 32, 32), dtype=np.float32) - summary_object_np = make_animated_gif_summary( - tag="test_summary_nparr.png", - image=nparr, - max_out=1, - animation_axes=(3,), - image_axes=(1, 2), - scale_factor=253.0, - ) - assert isinstance( - summary_object_np, tensorboard.compat.proto.summary_pb2.Summary - ), "make_animated_gif_summary must return a tensorboard.summary object from numpy array" - - tensorarr = torch.tensor(nparr) - summary_object_tensor = make_animated_gif_summary( - tag="test_summary_tensorarr.png", - image=tensorarr, - max_out=1, - animation_axes=(3,), - image_axes=(1, 2), - scale_factor=253.0, - ) - assert isinstance( - summary_object_tensor, tensorboard.compat.proto.summary_pb2.Summary - ), "make_animated_gif_summary must return a tensorboard.summary object from tensor input" + nparr = np.ones(shape=(1, 32, 32, 32), dtype=np.float32) + summary_object_np = make_animated_gif_summary( + tag="test_summary_nparr.png", + image=nparr, + max_out=1, + animation_axes=(3,), + image_axes=(1, 2), + scale_factor=253.0, + ) + assert isinstance( + summary_object_np, tensorboard.compat.proto.summary_pb2.Summary + ), "make_animated_gif_summary must return a tensorboard.summary object from numpy array" + + tensorarr = torch.tensor(nparr) + summary_object_tensor = make_animated_gif_summary( + tag="test_summary_tensorarr.png", + image=tensorarr, + max_out=1, + animation_axes=(3,), + image_axes=(1, 2), + scale_factor=253.0, + ) + assert isinstance( + summary_object_tensor, tensorboard.compat.proto.summary_pb2.Summary + ), "make_animated_gif_summary must return a tensorboard.summary object from tensor input" if __name__ == "__main__": diff --git a/tests/test_integration_sliding_window.py b/tests/test_integration_sliding_window.py index 097b07433a..8c864214c0 100644 --- a/tests/test_integration_sliding_window.py +++ b/tests/test_integration_sliding_window.py @@ -10,7 +10,6 @@ # limitations under the License. import os -import shutil import tempfile import unittest @@ -77,14 +76,13 @@ def tearDown(self): os.remove(self.seg_name) def test_training(self): - tempdir = tempfile.mkdtemp() - output_file = run_test( - batch_size=2, img_name=self.img_name, seg_name=self.seg_name, output_dir=tempdir, device=self.device - ) - output_image = nib.load(output_file).get_fdata() - np.testing.assert_allclose(np.sum(output_image), 33621) - np.testing.assert_allclose(output_image.shape, (28, 25, 63, 1)) - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + output_file = run_test( + batch_size=2, img_name=self.img_name, seg_name=self.seg_name, output_dir=tempdir, device=self.device + ) + output_image = nib.load(output_file).get_fdata() + np.testing.assert_allclose(np.sum(output_image), 33621) + np.testing.assert_allclose(output_image.shape, (28, 25, 63, 1)) if __name__ == "__main__": diff --git a/tests/test_load_decathalon_datalist.py b/tests/test_load_decathalon_datalist.py index b476f7d68d..82a0f72c1b 100644 --- a/tests/test_load_decathalon_datalist.py +++ b/tests/test_load_decathalon_datalist.py @@ -11,7 +11,6 @@ import json import os -import shutil import tempfile import unittest @@ -20,85 +19,82 @@ class TestLoadDecathalonDatalist(unittest.TestCase): def test_seg_values(self): - tempdir = tempfile.mkdtemp() - test_data = { - "name": "Spleen", - "description": "Spleen Segmentation", - "labels": {"0": "background", "1": "spleen"}, - "training": [ - {"image": "spleen_19.nii.gz", "label": "spleen_19.nii.gz"}, - {"image": "spleen_31.nii.gz", "label": "spleen_31.nii.gz"}, - ], - "test": ["spleen_15.nii.gz", "spleen_23.nii.gz"], - } - json_str = json.dumps(test_data) - file_path = os.path.join(tempdir, "test_data.json") - with open(file_path, "w") as json_file: - json_file.write(json_str) - result = load_decathalon_datalist(file_path, True, "training", tempdir) - self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz")) - self.assertEqual(result[0]["label"], os.path.join(tempdir, "spleen_19.nii.gz")) - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + test_data = { + "name": "Spleen", + "description": "Spleen Segmentation", + "labels": {"0": "background", "1": "spleen"}, + "training": [ + {"image": "spleen_19.nii.gz", "label": "spleen_19.nii.gz"}, + {"image": "spleen_31.nii.gz", "label": "spleen_31.nii.gz"}, + ], + "test": ["spleen_15.nii.gz", "spleen_23.nii.gz"], + } + json_str = json.dumps(test_data) + file_path = os.path.join(tempdir, "test_data.json") + with open(file_path, "w") as json_file: + json_file.write(json_str) + result = load_decathalon_datalist(file_path, True, "training", tempdir) + self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz")) + self.assertEqual(result[0]["label"], os.path.join(tempdir, "spleen_19.nii.gz")) def test_cls_values(self): - tempdir = tempfile.mkdtemp() - test_data = { - "name": "ChestXRay", - "description": "Chest X-ray classification", - "labels": {"0": "background", "1": "chest"}, - "training": [{"image": "chest_19.nii.gz", "label": 0}, {"image": "chest_31.nii.gz", "label": 1}], - "test": ["chest_15.nii.gz", "chest_23.nii.gz"], - } - json_str = json.dumps(test_data) - file_path = os.path.join(tempdir, "test_data.json") - with open(file_path, "w") as json_file: - json_file.write(json_str) - result = load_decathalon_datalist(file_path, False, "training", tempdir) - self.assertEqual(result[0]["image"], os.path.join(tempdir, "chest_19.nii.gz")) - self.assertEqual(result[0]["label"], 0) - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + test_data = { + "name": "ChestXRay", + "description": "Chest X-ray classification", + "labels": {"0": "background", "1": "chest"}, + "training": [{"image": "chest_19.nii.gz", "label": 0}, {"image": "chest_31.nii.gz", "label": 1}], + "test": ["chest_15.nii.gz", "chest_23.nii.gz"], + } + json_str = json.dumps(test_data) + file_path = os.path.join(tempdir, "test_data.json") + with open(file_path, "w") as json_file: + json_file.write(json_str) + result = load_decathalon_datalist(file_path, False, "training", tempdir) + self.assertEqual(result[0]["image"], os.path.join(tempdir, "chest_19.nii.gz")) + self.assertEqual(result[0]["label"], 0) def test_seg_no_basedir(self): - tempdir = tempfile.mkdtemp() - test_data = { - "name": "Spleen", - "description": "Spleen Segmentation", - "labels": {"0": "background", "1": "spleen"}, - "training": [ - { - "image": os.path.join(tempdir, "spleen_19.nii.gz"), - "label": os.path.join(tempdir, "spleen_19.nii.gz"), - }, - { - "image": os.path.join(tempdir, "spleen_31.nii.gz"), - "label": os.path.join(tempdir, "spleen_31.nii.gz"), - }, - ], - "test": [os.path.join(tempdir, "spleen_15.nii.gz"), os.path.join(tempdir, "spleen_23.nii.gz")], - } - json_str = json.dumps(test_data) - file_path = os.path.join(tempdir, "test_data.json") - with open(file_path, "w") as json_file: - json_file.write(json_str) - result = load_decathalon_datalist(file_path, True, "training", None) - self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz")) - self.assertEqual(result[0]["label"], os.path.join(tempdir, "spleen_19.nii.gz")) + with tempfile.TemporaryDirectory() as tempdir: + test_data = { + "name": "Spleen", + "description": "Spleen Segmentation", + "labels": {"0": "background", "1": "spleen"}, + "training": [ + { + "image": os.path.join(tempdir, "spleen_19.nii.gz"), + "label": os.path.join(tempdir, "spleen_19.nii.gz"), + }, + { + "image": os.path.join(tempdir, "spleen_31.nii.gz"), + "label": os.path.join(tempdir, "spleen_31.nii.gz"), + }, + ], + "test": [os.path.join(tempdir, "spleen_15.nii.gz"), os.path.join(tempdir, "spleen_23.nii.gz")], + } + json_str = json.dumps(test_data) + file_path = os.path.join(tempdir, "test_data.json") + with open(file_path, "w") as json_file: + json_file.write(json_str) + result = load_decathalon_datalist(file_path, True, "training", None) + self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_19.nii.gz")) + self.assertEqual(result[0]["label"], os.path.join(tempdir, "spleen_19.nii.gz")) def test_seg_no_labels(self): - tempdir = tempfile.mkdtemp() - test_data = { - "name": "Spleen", - "description": "Spleen Segmentation", - "labels": {"0": "background", "1": "spleen"}, - "test": ["spleen_15.nii.gz", "spleen_23.nii.gz"], - } - json_str = json.dumps(test_data) - file_path = os.path.join(tempdir, "test_data.json") - with open(file_path, "w") as json_file: - json_file.write(json_str) - result = load_decathalon_datalist(file_path, True, "test", tempdir) - self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_15.nii.gz")) - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + test_data = { + "name": "Spleen", + "description": "Spleen Segmentation", + "labels": {"0": "background", "1": "spleen"}, + "test": ["spleen_15.nii.gz", "spleen_23.nii.gz"], + } + json_str = json.dumps(test_data) + file_path = os.path.join(tempdir, "test_data.json") + with open(file_path, "w") as json_file: + json_file.write(json_str) + result = load_decathalon_datalist(file_path, True, "test", tempdir) + self.assertEqual(result[0]["image"], os.path.join(tempdir, "spleen_15.nii.gz")) if __name__ == "__main__": diff --git a/tests/test_load_nifti.py b/tests/test_load_nifti.py index 4ecc156ce0..a3466d6da6 100644 --- a/tests/test_load_nifti.py +++ b/tests/test_load_nifti.py @@ -10,7 +10,6 @@ # limitations under the License. import os -import shutil import tempfile import unittest @@ -41,11 +40,11 @@ class TestLoadNifti(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_shape(self, input_param, filenames, expected_shape): test_image = np.random.randint(0, 2, size=[128, 128, 128]) - tempdir = tempfile.mkdtemp() - for i, name in enumerate(filenames): - filenames[i] = os.path.join(tempdir, name) - nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) - result = LoadNifti(**input_param)(filenames) + with tempfile.TemporaryDirectory() as tempdir: + for i, name in enumerate(filenames): + filenames[i] = os.path.join(tempdir, name) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) + result = LoadNifti(**input_param)(filenames) if isinstance(result, tuple): result, header = result @@ -54,7 +53,6 @@ def test_shape(self, input_param, filenames, expected_shape): if input_param["as_closest_canonical"]: np.testing.asesrt_allclose(header["original_affine"], np.eye(4)) self.assertTupleEqual(result.shape, expected_shape) - shutil.rmtree(tempdir) if __name__ == "__main__": diff --git a/tests/test_load_niftid.py b/tests/test_load_niftid.py index 31e78599bf..d46c8a865c 100644 --- a/tests/test_load_niftid.py +++ b/tests/test_load_niftid.py @@ -10,7 +10,6 @@ # limitations under the License. import os -import shutil import tempfile import unittest @@ -30,14 +29,14 @@ class TestLoadNiftid(unittest.TestCase): def test_shape(self, input_param, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) test_data = dict() - tempdir = tempfile.mkdtemp() - for key in KEYS: - nib.save(test_image, os.path.join(tempdir, key + ".nii.gz")) - test_data.update({key: os.path.join(tempdir, key + ".nii.gz")}) - result = LoadNiftid(**input_param)(test_data) + with tempfile.TemporaryDirectory() as tempdir: + for key in KEYS: + nib.save(test_image, os.path.join(tempdir, key + ".nii.gz")) + test_data.update({key: os.path.join(tempdir, key + ".nii.gz")}) + result = LoadNiftid(**input_param)(test_data) + for key in KEYS: self.assertTupleEqual(result[key].shape, expected_shape) - shutil.rmtree(tempdir) if __name__ == "__main__": diff --git a/tests/test_load_numpy.py b/tests/test_load_numpy.py index ae6ae910f0..d65087531b 100644 --- a/tests/test_load_numpy.py +++ b/tests/test_load_numpy.py @@ -10,7 +10,6 @@ # limitations under the License. import os -import shutil import tempfile import unittest @@ -22,70 +21,60 @@ class TestLoadNumpy(unittest.TestCase): def test_npy(self): test_data = np.random.randint(0, 256, size=[3, 4, 4]) - tempdir = tempfile.mkdtemp() - filepath = os.path.join(tempdir, "test_data.npy") - np.save(filepath, test_data) + with tempfile.TemporaryDirectory() as tempdir: + filepath = os.path.join(tempdir, "test_data.npy") + np.save(filepath, test_data) - result = LoadNumpy()(filepath) + result = LoadNumpy()(filepath) self.assertTupleEqual(result[1]["spatial_shape"], test_data.shape) self.assertTupleEqual(result[0].shape, test_data.shape) np.testing.assert_allclose(result[0], test_data) - shutil.rmtree(tempdir) - def test_npz1(self): test_data1 = np.random.randint(0, 256, size=[3, 4, 4]) - tempdir = tempfile.mkdtemp() - filepath = os.path.join(tempdir, "test_data.npy") - np.save(filepath, test_data1) + with tempfile.TemporaryDirectory() as tempdir: + filepath = os.path.join(tempdir, "test_data.npy") + np.save(filepath, test_data1) - result = LoadNumpy()(filepath) + result = LoadNumpy()(filepath) self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape) self.assertTupleEqual(result[0].shape, test_data1.shape) np.testing.assert_allclose(result[0], test_data1) - shutil.rmtree(tempdir) - def test_npz2(self): test_data1 = np.random.randint(0, 256, size=[3, 4, 4]) test_data2 = np.random.randint(0, 256, size=[3, 4, 4]) - tempdir = tempfile.mkdtemp() - filepath = os.path.join(tempdir, "test_data.npz") - np.savez(filepath, test_data1, test_data2) + with tempfile.TemporaryDirectory() as tempdir: + filepath = os.path.join(tempdir, "test_data.npz") + np.savez(filepath, test_data1, test_data2) - result = LoadNumpy()(filepath) + result = LoadNumpy()(filepath) self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape) self.assertTupleEqual(result[0].shape, (2, 3, 4, 4)) np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2])) - shutil.rmtree(tempdir) - def test_npz3(self): test_data1 = np.random.randint(0, 256, size=[3, 4, 4]) test_data2 = np.random.randint(0, 256, size=[3, 4, 4]) - tempdir = tempfile.mkdtemp() - filepath = os.path.join(tempdir, "test_data.npz") - np.savez(filepath, test1=test_data1, test2=test_data2) + with tempfile.TemporaryDirectory() as tempdir: + filepath = os.path.join(tempdir, "test_data.npz") + np.savez(filepath, test1=test_data1, test2=test_data2) - result = LoadNumpy(npz_keys=["test1", "test2"])(filepath) + result = LoadNumpy(npz_keys=["test1", "test2"])(filepath) self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape) self.assertTupleEqual(result[0].shape, (2, 3, 4, 4)) np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2])) - shutil.rmtree(tempdir) - def test_npy_pickle(self): test_data = {"test": np.random.randint(0, 256, size=[3, 4, 4])} - tempdir = tempfile.mkdtemp() - filepath = os.path.join(tempdir, "test_data.npy") - np.save(filepath, test_data, allow_pickle=True) + with tempfile.TemporaryDirectory() as tempdir: + filepath = os.path.join(tempdir, "test_data.npy") + np.save(filepath, test_data, allow_pickle=True) - result = LoadNumpy(data_only=True, dtype=None)(filepath).item() + result = LoadNumpy(data_only=True, dtype=None)(filepath).item() self.assertTupleEqual(result["test"].shape, test_data["test"].shape) np.testing.assert_allclose(result["test"], test_data["test"]) - shutil.rmtree(tempdir) - if __name__ == "__main__": unittest.main() diff --git a/tests/test_load_numpyd.py b/tests/test_load_numpyd.py index 666387e9d4..9abe0b0daf 100644 --- a/tests/test_load_numpyd.py +++ b/tests/test_load_numpyd.py @@ -10,7 +10,6 @@ # limitations under the License. import os -import shutil import tempfile import unittest @@ -22,70 +21,60 @@ class TestLoadNumpyd(unittest.TestCase): def test_npy(self): test_data = np.random.randint(0, 256, size=[3, 4, 4]) - tempdir = tempfile.mkdtemp() - filepath = os.path.join(tempdir, "test_data.npy") - np.save(filepath, test_data) + with tempfile.TemporaryDirectory() as tempdir: + filepath = os.path.join(tempdir, "test_data.npy") + np.save(filepath, test_data) - result = LoadNumpyd(keys="mask")({"mask": filepath}) + result = LoadNumpyd(keys="mask")({"mask": filepath}) self.assertTupleEqual(result["mask_meta_dict"]["spatial_shape"], test_data.shape) self.assertTupleEqual(result["mask"].shape, test_data.shape) np.testing.assert_allclose(result["mask"], test_data) - shutil.rmtree(tempdir) - def test_npz1(self): test_data1 = np.random.randint(0, 256, size=[3, 4, 4]) - tempdir = tempfile.mkdtemp() - filepath = os.path.join(tempdir, "test_data.npy") - np.save(filepath, test_data1) + with tempfile.TemporaryDirectory() as tempdir: + filepath = os.path.join(tempdir, "test_data.npy") + np.save(filepath, test_data1) - result = LoadNumpyd(keys="mask")({"mask": filepath}) + result = LoadNumpyd(keys="mask")({"mask": filepath}) self.assertTupleEqual(result["mask_meta_dict"]["spatial_shape"], test_data1.shape) self.assertTupleEqual(result["mask"].shape, test_data1.shape) np.testing.assert_allclose(result["mask"], test_data1) - shutil.rmtree(tempdir) - def test_npz2(self): test_data1 = np.random.randint(0, 256, size=[3, 4, 4]) test_data2 = np.random.randint(0, 256, size=[3, 4, 4]) - tempdir = tempfile.mkdtemp() - filepath = os.path.join(tempdir, "test_data.npz") - np.savez(filepath, test_data1, test_data2) + with tempfile.TemporaryDirectory() as tempdir: + filepath = os.path.join(tempdir, "test_data.npz") + np.savez(filepath, test_data1, test_data2) - result = LoadNumpyd(keys="mask")({"mask": filepath}) + result = LoadNumpyd(keys="mask")({"mask": filepath}) self.assertTupleEqual(result["mask_meta_dict"]["spatial_shape"], test_data1.shape) self.assertTupleEqual(result["mask"].shape, (2, 3, 4, 4)) np.testing.assert_allclose(result["mask"], np.stack([test_data1, test_data2])) - shutil.rmtree(tempdir) - def test_npz3(self): test_data1 = np.random.randint(0, 256, size=[3, 4, 4]) test_data2 = np.random.randint(0, 256, size=[3, 4, 4]) - tempdir = tempfile.mkdtemp() - filepath = os.path.join(tempdir, "test_data.npz") - np.savez(filepath, test1=test_data1, test2=test_data2) + with tempfile.TemporaryDirectory() as tempdir: + filepath = os.path.join(tempdir, "test_data.npz") + np.savez(filepath, test1=test_data1, test2=test_data2) - result = LoadNumpyd(keys="mask", npz_keys=["test1", "test2"])({"mask": filepath}) + result = LoadNumpyd(keys="mask", npz_keys=["test1", "test2"])({"mask": filepath}) self.assertTupleEqual(result["mask_meta_dict"]["spatial_shape"], test_data1.shape) self.assertTupleEqual(result["mask"].shape, (2, 3, 4, 4)) np.testing.assert_allclose(result["mask"], np.stack([test_data1, test_data2])) - shutil.rmtree(tempdir) - def test_npy_pickle(self): test_data = {"test": np.random.randint(0, 256, size=[3, 4, 4])} - tempdir = tempfile.mkdtemp() - filepath = os.path.join(tempdir, "test_data.npy") - np.save(filepath, test_data, allow_pickle=True) + with tempfile.TemporaryDirectory() as tempdir: + filepath = os.path.join(tempdir, "test_data.npy") + np.save(filepath, test_data, allow_pickle=True) - result = LoadNumpyd(keys="mask", dtype=None)({"mask": filepath})["mask"].item() + result = LoadNumpyd(keys="mask", dtype=None)({"mask": filepath})["mask"].item() self.assertTupleEqual(result["test"].shape, test_data["test"].shape) np.testing.assert_allclose(result["test"], test_data["test"]) - shutil.rmtree(tempdir) - if __name__ == "__main__": unittest.main() diff --git a/tests/test_load_png.py b/tests/test_load_png.py index 2a85638b91..929ee1536d 100644 --- a/tests/test_load_png.py +++ b/tests/test_load_png.py @@ -10,7 +10,6 @@ # limitations under the License. import os -import shutil import tempfile import unittest @@ -31,18 +30,17 @@ class TestLoadPNG(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, data_shape, filenames, expected_shape, meta_shape): test_image = np.random.randint(0, 256, size=data_shape) - tempdir = tempfile.mkdtemp() - for i, name in enumerate(filenames): - filenames[i] = os.path.join(tempdir, name) - Image.fromarray(test_image.astype("uint8")).save(filenames[i]) - result = LoadPNG()(filenames) + with tempfile.TemporaryDirectory() as tempdir: + for i, name in enumerate(filenames): + filenames[i] = os.path.join(tempdir, name) + Image.fromarray(test_image.astype("uint8")).save(filenames[i]) + result = LoadPNG()(filenames) self.assertTupleEqual(result[1]["spatial_shape"], meta_shape) self.assertTupleEqual(result[0].shape, expected_shape) if result[0].shape == test_image.shape: np.testing.assert_allclose(result[0], test_image) else: np.testing.assert_allclose(result[0], np.tile(test_image, [result[0].shape[0], 1, 1])) - shutil.rmtree(tempdir) if __name__ == "__main__": diff --git a/tests/test_load_pngd.py b/tests/test_load_pngd.py index 81c3754b7b..6be3197d8f 100644 --- a/tests/test_load_pngd.py +++ b/tests/test_load_pngd.py @@ -10,7 +10,6 @@ # limitations under the License. import os -import shutil import tempfile import unittest @@ -29,15 +28,14 @@ class TestLoadPNGd(unittest.TestCase): @parameterized.expand([TEST_CASE_1]) def test_shape(self, input_param, expected_shape): test_image = np.random.randint(0, 256, size=[128, 128, 3]) - tempdir = tempfile.mkdtemp() - test_data = dict() - for key in KEYS: - Image.fromarray(test_image.astype("uint8")).save(os.path.join(tempdir, key + ".png")) - test_data.update({key: os.path.join(tempdir, key + ".png")}) - result = LoadPNGd(**input_param)(test_data) + with tempfile.TemporaryDirectory() as tempdir: + test_data = dict() + for key in KEYS: + Image.fromarray(test_image.astype("uint8")).save(os.path.join(tempdir, key + ".png")) + test_data.update({key: os.path.join(tempdir, key + ".png")}) + result = LoadPNGd(**input_param)(test_data) for key in KEYS: self.assertTupleEqual(result[key].shape, expected_shape) - shutil.rmtree(tempdir) if __name__ == "__main__": diff --git a/tests/test_nifti_dataset.py b/tests/test_nifti_dataset.py index 0523a0d350..459bc94a1d 100644 --- a/tests/test_nifti_dataset.py +++ b/tests/test_nifti_dataset.py @@ -10,7 +10,6 @@ # limitations under the License. import os -import shutil import tempfile import unittest @@ -38,85 +37,88 @@ def __call__(self, data): class TestNiftiDataset(unittest.TestCase): def test_dataset(self): - tempdir = tempfile.mkdtemp() - full_names, ref_data = [], [] - for filename in FILENAMES: - test_image = np.random.randint(0, 2, size=(4, 4, 4)) - ref_data.append(test_image) - save_path = os.path.join(tempdir, filename) - full_names.append(save_path) - nib.save(nib.Nifti1Image(test_image, np.eye(4)), save_path) - - # default loading no meta - dataset = NiftiDataset(full_names) - for d, ref in zip(dataset, ref_data): - np.testing.assert_allclose(d, ref, atol=1e-3) - - # loading no meta, int - dataset = NiftiDataset(full_names, dtype=np.float16) - for d, _ in zip(dataset, ref_data): - self.assertEqual(d.dtype, np.float16) - - # loading with meta, no transform - dataset = NiftiDataset(full_names, image_only=False) - for d_tuple, ref in zip(dataset, ref_data): - d, meta = d_tuple - np.testing.assert_allclose(d, ref, atol=1e-3) - np.testing.assert_allclose(meta["original_affine"], np.eye(4)) - - # loading image/label, no meta - dataset = NiftiDataset(full_names, seg_files=full_names, image_only=True) - for d_tuple, ref in zip(dataset, ref_data): - img, seg = d_tuple - np.testing.assert_allclose(img, ref, atol=1e-3) - np.testing.assert_allclose(seg, ref, atol=1e-3) - - # loading image/label, no meta - dataset = NiftiDataset(full_names, transform=lambda x: x + 1, image_only=True) - for d, ref in zip(dataset, ref_data): - np.testing.assert_allclose(d, ref + 1, atol=1e-3) - - # set seg transform, but no seg_files - with self.assertRaises(TypeError): - dataset = NiftiDataset(full_names, seg_transform=lambda x: x + 1, image_only=True) - _ = dataset[0] - - # set seg transform, but no seg_files - with self.assertRaises(TypeError): - dataset = NiftiDataset(full_names, seg_transform=lambda x: x + 1, image_only=True) - _ = dataset[0] - - # loading image/label, with meta - dataset = NiftiDataset( - full_names, transform=lambda x: x + 1, seg_files=full_names, seg_transform=lambda x: x + 2, image_only=False - ) - for d_tuple, ref in zip(dataset, ref_data): - img, seg, meta = d_tuple - np.testing.assert_allclose(img, ref + 1, atol=1e-3) - np.testing.assert_allclose(seg, ref + 2, atol=1e-3) - np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3) - - # loading image/label, with meta - dataset = NiftiDataset( - full_names, transform=lambda x: x + 1, seg_files=full_names, labels=[1, 2, 3], image_only=False - ) - for idx, (d_tuple, ref) in enumerate(zip(dataset, ref_data)): - img, seg, label, meta = d_tuple - np.testing.assert_allclose(img, ref + 1, atol=1e-3) - np.testing.assert_allclose(seg, ref, atol=1e-3) - np.testing.assert_allclose(idx + 1, label) - np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3) - - # loading image/label, with sync. transform - dataset = NiftiDataset( - full_names, transform=RandTest(), seg_files=full_names, seg_transform=RandTest(), image_only=False - ) - for d_tuple, ref in zip(dataset, ref_data): - img, seg, meta = d_tuple - np.testing.assert_allclose(img, seg, atol=1e-3) - self.assertTrue(not np.allclose(img, ref)) - np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3) - shutil.rmtree(tempdir) + with tempfile.TemporaryDirectory() as tempdir: + full_names, ref_data = [], [] + for filename in FILENAMES: + test_image = np.random.randint(0, 2, size=(4, 4, 4)) + ref_data.append(test_image) + save_path = os.path.join(tempdir, filename) + full_names.append(save_path) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), save_path) + + # default loading no meta + dataset = NiftiDataset(full_names) + for d, ref in zip(dataset, ref_data): + np.testing.assert_allclose(d, ref, atol=1e-3) + + # loading no meta, int + dataset = NiftiDataset(full_names, dtype=np.float16) + for d, _ in zip(dataset, ref_data): + self.assertEqual(d.dtype, np.float16) + + # loading with meta, no transform + dataset = NiftiDataset(full_names, image_only=False) + for d_tuple, ref in zip(dataset, ref_data): + d, meta = d_tuple + np.testing.assert_allclose(d, ref, atol=1e-3) + np.testing.assert_allclose(meta["original_affine"], np.eye(4)) + + # loading image/label, no meta + dataset = NiftiDataset(full_names, seg_files=full_names, image_only=True) + for d_tuple, ref in zip(dataset, ref_data): + img, seg = d_tuple + np.testing.assert_allclose(img, ref, atol=1e-3) + np.testing.assert_allclose(seg, ref, atol=1e-3) + + # loading image/label, no meta + dataset = NiftiDataset(full_names, transform=lambda x: x + 1, image_only=True) + for d, ref in zip(dataset, ref_data): + np.testing.assert_allclose(d, ref + 1, atol=1e-3) + + # set seg transform, but no seg_files + with self.assertRaises(TypeError): + dataset = NiftiDataset(full_names, seg_transform=lambda x: x + 1, image_only=True) + _ = dataset[0] + + # set seg transform, but no seg_files + with self.assertRaises(TypeError): + dataset = NiftiDataset(full_names, seg_transform=lambda x: x + 1, image_only=True) + _ = dataset[0] + + # loading image/label, with meta + dataset = NiftiDataset( + full_names, + transform=lambda x: x + 1, + seg_files=full_names, + seg_transform=lambda x: x + 2, + image_only=False, + ) + for d_tuple, ref in zip(dataset, ref_data): + img, seg, meta = d_tuple + np.testing.assert_allclose(img, ref + 1, atol=1e-3) + np.testing.assert_allclose(seg, ref + 2, atol=1e-3) + np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3) + + # loading image/label, with meta + dataset = NiftiDataset( + full_names, transform=lambda x: x + 1, seg_files=full_names, labels=[1, 2, 3], image_only=False + ) + for idx, (d_tuple, ref) in enumerate(zip(dataset, ref_data)): + img, seg, label, meta = d_tuple + np.testing.assert_allclose(img, ref + 1, atol=1e-3) + np.testing.assert_allclose(seg, ref, atol=1e-3) + np.testing.assert_allclose(idx + 1, label) + np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3) + + # loading image/label, with sync. transform + dataset = NiftiDataset( + full_names, transform=RandTest(), seg_files=full_names, seg_transform=RandTest(), image_only=False + ) + for d_tuple, ref in zip(dataset, ref_data): + img, seg, meta = d_tuple + np.testing.assert_allclose(img, seg, atol=1e-3) + self.assertTrue(not np.allclose(img, ref)) + np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3) if __name__ == "__main__": diff --git a/tests/test_nifti_rw.py b/tests/test_nifti_rw.py index 86d4db2bb1..9a2f8ac75c 100644 --- a/tests/test_nifti_rw.py +++ b/tests/test_nifti_rw.py @@ -10,7 +10,6 @@ # limitations under the License. import os -import shutil import tempfile import unittest @@ -107,75 +106,71 @@ def test_consistency(self): os.remove(test_image) def test_write_2d(self): - out_dir = tempfile.mkdtemp() - image_name = os.path.join(out_dir, "test.nii.gz") - img = np.arange(6).reshape((2, 3)) - write_nifti(img, image_name, affine=np.diag([1]), target_affine=np.diag([1.4])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[0, 1, 2], [3.0, 4, 5]]) - np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) - - image_name = os.path.join(out_dir, "test1.nii.gz") - img = np.arange(5).reshape((1, 5)) - write_nifti(img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 1, 3, 5])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[0, 2, 4]]) - np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 1, 1])) - shutil.rmtree(out_dir) + with tempfile.TemporaryDirectory() as out_dir: + image_name = os.path.join(out_dir, "test.nii.gz") + img = np.arange(6).reshape((2, 3)) + write_nifti(img, image_name, affine=np.diag([1]), target_affine=np.diag([1.4])) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), [[0, 1, 2], [3.0, 4, 5]]) + np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) + + image_name = os.path.join(out_dir, "test1.nii.gz") + img = np.arange(5).reshape((1, 5)) + write_nifti(img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 1, 3, 5])) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), [[0, 2, 4]]) + np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 1, 1])) def test_write_3d(self): - out_dir = tempfile.mkdtemp() - image_name = os.path.join(out_dir, "test.nii.gz") - img = np.arange(6).reshape((1, 2, 3)) - write_nifti(img, image_name, affine=np.diag([1]), target_affine=np.diag([1.4])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[[0, 1, 2], [3, 4, 5]]]) - np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) - - image_name = os.path.join(out_dir, "test1.nii.gz") - img = np.arange(5).reshape((1, 1, 5)) - write_nifti(img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 2, 3, 5])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[[0, 2, 4]]]) - np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) - shutil.rmtree(out_dir) + with tempfile.TemporaryDirectory() as out_dir: + image_name = os.path.join(out_dir, "test.nii.gz") + img = np.arange(6).reshape((1, 2, 3)) + write_nifti(img, image_name, affine=np.diag([1]), target_affine=np.diag([1.4])) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), [[[0, 1, 2], [3, 4, 5]]]) + np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) + + image_name = os.path.join(out_dir, "test1.nii.gz") + img = np.arange(5).reshape((1, 1, 5)) + write_nifti(img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 2, 3, 5])) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), [[[0, 2, 4]]]) + np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) def test_write_4d(self): - out_dir = tempfile.mkdtemp() - image_name = os.path.join(out_dir, "test.nii.gz") - img = np.arange(6).reshape((1, 1, 3, 2)) - write_nifti(img, image_name, affine=np.diag([1.4, 1]), target_affine=np.diag([1, 1.4, 1])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[[[0, 1], [2, 3], [4, 5]]]]) - np.testing.assert_allclose(out.affine, np.diag([1, 1.4, 1, 1])) - - image_name = os.path.join(out_dir, "test1.nii.gz") - img = np.arange(5).reshape((1, 1, 5, 1)) - write_nifti(img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 2, 3, 5])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[[[0], [2], [4]]]]) - np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) - shutil.rmtree(out_dir) + with tempfile.TemporaryDirectory() as out_dir: + image_name = os.path.join(out_dir, "test.nii.gz") + img = np.arange(6).reshape((1, 1, 3, 2)) + write_nifti(img, image_name, affine=np.diag([1.4, 1]), target_affine=np.diag([1, 1.4, 1])) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), [[[[0, 1], [2, 3], [4, 5]]]]) + np.testing.assert_allclose(out.affine, np.diag([1, 1.4, 1, 1])) + + image_name = os.path.join(out_dir, "test1.nii.gz") + img = np.arange(5).reshape((1, 1, 5, 1)) + write_nifti(img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 2, 3, 5])) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), [[[[0], [2], [4]]]]) + np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) def test_write_5d(self): - out_dir = tempfile.mkdtemp() - image_name = os.path.join(out_dir, "test.nii.gz") - img = np.arange(12).reshape((1, 1, 3, 2, 2)) - write_nifti(img, image_name, affine=np.diag([1]), target_affine=np.diag([1.4])) - out = nib.load(image_name) - np.testing.assert_allclose( - out.get_fdata(), - np.array([[[[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]], [[8.0, 9.0], [10.0, 11.0]]]]]), - ) - np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) - - image_name = os.path.join(out_dir, "test1.nii.gz") - img = np.arange(10).reshape((1, 1, 5, 1, 2)) - write_nifti(img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 2, 3, 5])) - out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), np.array([[[[[0.0, 1.0]], [[4.0, 5.0]], [[8.0, 9.0]]]]])) - np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) - shutil.rmtree(out_dir) + with tempfile.TemporaryDirectory() as out_dir: + image_name = os.path.join(out_dir, "test.nii.gz") + img = np.arange(12).reshape((1, 1, 3, 2, 2)) + write_nifti(img, image_name, affine=np.diag([1]), target_affine=np.diag([1.4])) + out = nib.load(image_name) + np.testing.assert_allclose( + out.get_fdata(), + np.array([[[[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]], [[8.0, 9.0], [10.0, 11.0]]]]]), + ) + np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) + + image_name = os.path.join(out_dir, "test1.nii.gz") + img = np.arange(10).reshape((1, 1, 5, 1, 2)) + write_nifti(img, image_name, affine=np.diag([1, 1, 1, 3, 3]), target_affine=np.diag([1.4, 2.0, 2, 3, 5])) + out = nib.load(image_name) + np.testing.assert_allclose(out.get_fdata(), np.array([[[[[0.0, 1.0]], [[4.0, 5.0]], [[8.0, 9.0]]]]])) + np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) if __name__ == "__main__": diff --git a/tests/test_nifti_saver.py b/tests/test_nifti_saver.py index d8a560a636..a40650920d 100644 --- a/tests/test_nifti_saver.py +++ b/tests/test_nifti_saver.py @@ -10,7 +10,7 @@ # limitations under the License. import os -import shutil +import tempfile import unittest import numpy as np @@ -21,52 +21,46 @@ class TestNiftiSaver(unittest.TestCase): def test_saved_content(self): - default_dir = os.path.join(".", "tempdir") - shutil.rmtree(default_dir, ignore_errors=True) + with tempfile.TemporaryDirectory() as tempdir: - saver = NiftiSaver(output_dir=default_dir, output_postfix="seg", output_ext=".nii.gz") + saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz") - meta_data = {"filename_or_obj": ["testfile" + str(i) for i in range(8)]} - saver.save_batch(torch.zeros(8, 1, 2, 2), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") - self.assertTrue(os.path.exists(os.path.join(default_dir, filepath))) - shutil.rmtree(default_dir) + meta_data = {"filename_or_obj": ["testfile" + str(i) for i in range(8)]} + saver.save_batch(torch.zeros(8, 1, 2, 2), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) def test_saved_resize_content(self): - default_dir = os.path.join(".", "tempdir") - shutil.rmtree(default_dir, ignore_errors=True) + with tempfile.TemporaryDirectory() as tempdir: - saver = NiftiSaver(output_dir=default_dir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32) + saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32) - meta_data = { - "filename_or_obj": ["testfile" + str(i) for i in range(8)], - "affine": [np.diag(np.ones(4)) * 5] * 8, - "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, - } - saver.save_batch(torch.randint(0, 255, (8, 8, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") - self.assertTrue(os.path.exists(os.path.join(default_dir, filepath))) - shutil.rmtree(default_dir) + meta_data = { + "filename_or_obj": ["testfile" + str(i) for i in range(8)], + "affine": [np.diag(np.ones(4)) * 5] * 8, + "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, + } + saver.save_batch(torch.randint(0, 255, (8, 8, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) def test_saved_3d_resize_content(self): - default_dir = os.path.join(".", "tempdir") - shutil.rmtree(default_dir, ignore_errors=True) + with tempfile.TemporaryDirectory() as tempdir: - saver = NiftiSaver(output_dir=default_dir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32) + saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32) - meta_data = { - "filename_or_obj": ["testfile" + str(i) for i in range(8)], - "spatial_shape": [(10, 10, 2)] * 8, - "affine": [np.diag(np.ones(4)) * 5] * 8, - "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, - } - saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") - self.assertTrue(os.path.exists(os.path.join(default_dir, filepath))) - shutil.rmtree(default_dir) + meta_data = { + "filename_or_obj": ["testfile" + str(i) for i in range(8)], + "spatial_shape": [(10, 10, 2)] * 8, + "affine": [np.diag(np.ones(4)) * 5] * 8, + "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, + } + saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) if __name__ == "__main__": diff --git a/tests/test_persistentdataset.py b/tests/test_persistentdataset.py index 3057eff3ee..36c7f9de67 100644 --- a/tests/test_persistentdataset.py +++ b/tests/test_persistentdataset.py @@ -10,7 +10,6 @@ # limitations under the License. import os -import shutil import tempfile import unittest @@ -46,34 +45,33 @@ class TestDataset(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_shape(self, transform, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) - tempdir = tempfile.mkdtemp() - nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz")) - test_data = [ - { - "image": os.path.join(tempdir, "test_image1.nii.gz"), - "label": os.path.join(tempdir, "test_label1.nii.gz"), - "extra": os.path.join(tempdir, "test_extra1.nii.gz"), - }, - { - "image": os.path.join(tempdir, "test_image2.nii.gz"), - "label": os.path.join(tempdir, "test_label2.nii.gz"), - "extra": os.path.join(tempdir, "test_extra2.nii.gz"), - }, - ] + with tempfile.TemporaryDirectory() as tempdir: + nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) + nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz")) + test_data = [ + { + "image": os.path.join(tempdir, "test_image1.nii.gz"), + "label": os.path.join(tempdir, "test_label1.nii.gz"), + "extra": os.path.join(tempdir, "test_extra1.nii.gz"), + }, + { + "image": os.path.join(tempdir, "test_image2.nii.gz"), + "label": os.path.join(tempdir, "test_label2.nii.gz"), + "extra": os.path.join(tempdir, "test_extra2.nii.gz"), + }, + ] - dataset_precached = PersistentDataset(data=test_data, transform=transform, cache_dir=tempdir) - data1_precached = dataset_precached[0] - data2_precached = dataset_precached[1] + dataset_precached = PersistentDataset(data=test_data, transform=transform, cache_dir=tempdir) + data1_precached = dataset_precached[0] + data2_precached = dataset_precached[1] - dataset_postcached = PersistentDataset(data=test_data, transform=transform, cache_dir=tempdir) - data1_postcached = dataset_postcached[0] - data2_postcached = dataset_postcached[1] - shutil.rmtree(tempdir) + dataset_postcached = PersistentDataset(data=test_data, transform=transform, cache_dir=tempdir) + data1_postcached = dataset_postcached[0] + data2_postcached = dataset_postcached[1] if transform is None: self.assertEqual(data1_precached["image"], os.path.join(tempdir, "test_image1.nii.gz")) diff --git a/tests/test_plot_2d_or_3d_image.py b/tests/test_plot_2d_or_3d_image.py index e2cd978398..df14f3ed50 100644 --- a/tests/test_plot_2d_or_3d_image.py +++ b/tests/test_plot_2d_or_3d_image.py @@ -10,8 +10,6 @@ # limitations under the License. import glob -import os -import shutil import tempfile import unittest @@ -35,14 +33,11 @@ class TestPlot2dOr3dImage(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_tb_image_shape(self, shape): - tempdir = tempfile.mkdtemp() - shutil.rmtree(tempdir, ignore_errors=True) + with tempfile.TemporaryDirectory() as tempdir: - plot_2d_or_3d_image(torch.zeros(shape), 0, SummaryWriter(log_dir=tempdir)) + plot_2d_or_3d_image(torch.zeros(shape), 0, SummaryWriter(log_dir=tempdir)) - self.assertTrue(os.path.exists(tempdir)) - self.assertTrue(len(glob.glob(tempdir)) > 0) - shutil.rmtree(tempdir, ignore_errors=True) + self.assertTrue(len(glob.glob(tempdir)) > 0) if __name__ == "__main__": diff --git a/tests/test_png_rw.py b/tests/test_png_rw.py index cbde526374..94908ade3e 100644 --- a/tests/test_png_rw.py +++ b/tests/test_png_rw.py @@ -10,7 +10,6 @@ # limitations under the License. import os -import shutil import tempfile import unittest @@ -22,63 +21,57 @@ class TestPngWrite(unittest.TestCase): def test_write_gray(self): - out_dir = tempfile.mkdtemp() - image_name = os.path.join(out_dir, "test.png") - img = np.random.rand(2, 3) - img_save_val = (255 * img).astype(np.uint8) - write_png(img, image_name, scale=255) - out = np.asarray(Image.open(image_name)) - np.testing.assert_allclose(out, img_save_val) - shutil.rmtree(out_dir) + with tempfile.TemporaryDirectory() as out_dir: + image_name = os.path.join(out_dir, "test.png") + img = np.random.rand(2, 3) + img_save_val = (255 * img).astype(np.uint8) + write_png(img, image_name, scale=255) + out = np.asarray(Image.open(image_name)) + np.testing.assert_allclose(out, img_save_val) def test_write_gray_1height(self): - out_dir = tempfile.mkdtemp() - image_name = os.path.join(out_dir, "test.png") - img = np.random.rand(1, 3) - img_save_val = (65535 * img).astype(np.uint16) - write_png(img, image_name, scale=65535) - out = np.asarray(Image.open(image_name)) - np.testing.assert_allclose(out, img_save_val) - shutil.rmtree(out_dir) + with tempfile.TemporaryDirectory() as out_dir: + image_name = os.path.join(out_dir, "test.png") + img = np.random.rand(1, 3) + img_save_val = (65535 * img).astype(np.uint16) + write_png(img, image_name, scale=65535) + out = np.asarray(Image.open(image_name)) + np.testing.assert_allclose(out, img_save_val) def test_write_gray_1channel(self): - out_dir = tempfile.mkdtemp() - image_name = os.path.join(out_dir, "test.png") - img = np.random.rand(2, 3, 1) - img_save_val = (255 * img).astype(np.uint8).squeeze(2) - write_png(img, image_name, scale=255) - out = np.asarray(Image.open(image_name)) - np.testing.assert_allclose(out, img_save_val) - shutil.rmtree(out_dir) + with tempfile.TemporaryDirectory() as out_dir: + image_name = os.path.join(out_dir, "test.png") + img = np.random.rand(2, 3, 1) + img_save_val = (255 * img).astype(np.uint8).squeeze(2) + write_png(img, image_name, scale=255) + out = np.asarray(Image.open(image_name)) + np.testing.assert_allclose(out, img_save_val) def test_write_rgb(self): - out_dir = tempfile.mkdtemp() - image_name = os.path.join(out_dir, "test.png") - img = np.random.rand(2, 3, 3) - img_save_val = (255 * img).astype(np.uint8) - write_png(img, image_name, scale=255) - out = np.asarray(Image.open(image_name)) - np.testing.assert_allclose(out, img_save_val) - shutil.rmtree(out_dir) + with tempfile.TemporaryDirectory() as out_dir: + image_name = os.path.join(out_dir, "test.png") + img = np.random.rand(2, 3, 3) + img_save_val = (255 * img).astype(np.uint8) + write_png(img, image_name, scale=255) + out = np.asarray(Image.open(image_name)) + np.testing.assert_allclose(out, img_save_val) def test_write_2channels(self): - out_dir = tempfile.mkdtemp() - image_name = os.path.join(out_dir, "test.png") - img = np.random.rand(2, 3, 2) - img_save_val = (255 * img).astype(np.uint8) - write_png(img, image_name, scale=255) - out = np.asarray(Image.open(image_name)) - np.testing.assert_allclose(out, img_save_val) - shutil.rmtree(out_dir) + with tempfile.TemporaryDirectory() as out_dir: + image_name = os.path.join(out_dir, "test.png") + img = np.random.rand(2, 3, 2) + img_save_val = (255 * img).astype(np.uint8) + write_png(img, image_name, scale=255) + out = np.asarray(Image.open(image_name)) + np.testing.assert_allclose(out, img_save_val) def test_write_output_shape(self): - out_dir = tempfile.mkdtemp() - image_name = os.path.join(out_dir, "test.png") - img = np.random.rand(2, 2, 3) - write_png(img, image_name, (4, 4), scale=255) - out = np.asarray(Image.open(image_name)) - np.testing.assert_allclose(out.shape, (4, 4, 3)) - shutil.rmtree(out_dir) + with tempfile.TemporaryDirectory() as out_dir: + image_name = os.path.join(out_dir, "test.png") + img = np.random.rand(2, 2, 3) + write_png(img, image_name, (4, 4), scale=255) + out = np.asarray(Image.open(image_name)) + np.testing.assert_allclose(out.shape, (4, 4, 3)) if __name__ == "__main__": diff --git a/tests/test_png_saver.py b/tests/test_png_saver.py index 1cbaec72a7..9a2e573f45 100644 --- a/tests/test_png_saver.py +++ b/tests/test_png_saver.py @@ -10,7 +10,7 @@ # limitations under the License. import os -import shutil +import tempfile import unittest import torch @@ -20,48 +20,40 @@ class TestPNGSaver(unittest.TestCase): def test_saved_content(self): - default_dir = os.path.join(".", "tempdir") - shutil.rmtree(default_dir, ignore_errors=True) + with tempfile.TemporaryDirectory() as tempdir: - saver = PNGSaver(output_dir=default_dir, output_postfix="seg", output_ext=".png", scale=255) + saver = PNGSaver(output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255) - meta_data = {"filename_or_obj": ["testfile" + str(i) for i in range(8)]} - saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") - self.assertTrue(os.path.exists(os.path.join(default_dir, filepath))) - shutil.rmtree(default_dir) + meta_data = {"filename_or_obj": ["testfile" + str(i) for i in range(8)]} + saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) def test_saved_content_three_channel(self): - default_dir = os.path.join(".", "tempdir") - shutil.rmtree(default_dir, ignore_errors=True) + with tempfile.TemporaryDirectory() as tempdir: - saver = PNGSaver(output_dir=default_dir, output_postfix="seg", output_ext=".png", scale=255) + saver = PNGSaver(output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255) - meta_data = {"filename_or_obj": ["testfile" + str(i) for i in range(8)]} - saver.save_batch(torch.randint(1, 200, (8, 3, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") - self.assertTrue(os.path.exists(os.path.join(default_dir, filepath))) - shutil.rmtree(default_dir) + meta_data = {"filename_or_obj": ["testfile" + str(i) for i in range(8)]} + saver.save_batch(torch.randint(1, 200, (8, 3, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) def test_saved_content_spatial_size(self): - - default_dir = os.path.join(".", "tempdir") - shutil.rmtree(default_dir, ignore_errors=True) - - saver = PNGSaver(output_dir=default_dir, output_postfix="seg", output_ext=".png", scale=255) - - meta_data = { - "filename_or_obj": ["testfile" + str(i) for i in range(8)], - "spatial_shape": [(4, 4) for i in range(8)], - } - saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") - self.assertTrue(os.path.exists(os.path.join(default_dir, filepath))) - - shutil.rmtree(default_dir) + with tempfile.TemporaryDirectory() as tempdir: + + saver = PNGSaver(output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255) + + meta_data = { + "filename_or_obj": ["testfile" + str(i) for i in range(8)], + "spatial_shape": [(4, 4) for i in range(8)], + } + saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) if __name__ == "__main__":