diff --git a/modules/inverse_transforms_and_test_time_augmentations.ipynb b/modules/inverse_transforms_and_test_time_augmentations.ipynb index e446323c12..80e3c8cf81 100644 --- a/modules/inverse_transforms_and_test_time_augmentations.ipynb +++ b/modules/inverse_transforms_and_test_time_augmentations.ipynb @@ -123,7 +123,6 @@ " Dataset,\n", " pad_list_data_collate,\n", " TestTimeAugmentation,\n", - " decollate_batch,\n", ")\n", "from monai.inferers import sliding_window_inference\n", "from monai.losses import DiceLoss\n", @@ -228,8 +227,7 @@ " def __call__(self, data):\n", " d = dict(data)\n", " im = d[self.label_key]\n", - " _im = im.detach().cpu().numpy()\n", - " q = np.sum((_im > 0).reshape(-1, _im.shape[-1]), axis=0)\n", + " q = np.sum((im.array > 0).reshape(-1, im.array.shape[-1]), axis=0)\n", " _slice = np.where(q == np.max(q))[0][0]\n", " for key in self.keys:\n", " d[key] = d[key][..., _slice]\n", @@ -247,7 +245,7 @@ " fname = os.path.basename(\n", " data[key + \"_meta_dict\"][\"filename_or_obj\"])\n", " path = os.path.join(self.path, key, fname)\n", - " nib.save(nib.Nifti1Image(data[key].detach().cpu().numpy(), np.eye(4)), path)\n", + " nib.save(nib.Nifti1Image(data[key].array, np.eye(4)), path)\n", " d[key] = path\n", " return d\n", "\n", @@ -443,7 +441,7 @@ "def infer_seg(images, model, roi_size=(96, 96), sw_batch_size=4):\n", " val_outputs = sliding_window_inference(\n", " images, roi_size, sw_batch_size, model)\n", - " return torch.stack([post_trans(i) for i in decollate_batch(val_outputs)])\n", + " return pad_list_data_collate([post_trans(i) for i in val_outputs])\n", "\n", "\n", "# Create network, loss fn., etc.\n",