From 6eee290cb9ee7b8f9bcbad74f630ee5067417aea Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 23 May 2022 14:59:36 +0000 Subject: [PATCH 01/18] Update mil pipeline Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../panda_mil_train_evaluate_pytorch_gpu.py | 212 +++++++++--------- 1 file changed, 109 insertions(+), 103 deletions(-) diff --git a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py index c37bd5196d..45da2ee931 100644 --- a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py +++ b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py @@ -20,74 +20,23 @@ import torch.multiprocessing as mp from monai.data import Dataset, load_decathlon_datalist -from monai.data.image_reader import WSIReader +from monai.data.wsi_reader import WSIReader from monai.metrics import Cumulative, CumulativeAverage -from monai.transforms import Transform, Compose, LoadImageD, RandFlipd, RandRotate90d, ScaleIntensityRangeD, ToTensord -from monai.apps.pathology.transforms import TileOnGridd +from monai.transforms import ( + Transform, + Compose, + LoadImaged, + RandFlipd, + GridPatchd, + RandRotate90d, + ScaleIntensityRanged, + ToTensord, +) +from monai.transforms.spatial.array import GridPatch +from monai.transforms.spatial.dictionary import GridPatchd from monai.networks.nets import milmodel -def parse_args(): - - parser = argparse.ArgumentParser(description="Multiple Instance Learning (MIL) example of classification from WSI.") - parser.add_argument( - "--data_root", default="/PandaChallenge2020/train_images/", help="path to root folder of images" - ) - parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file") - - parser.add_argument("--num_classes", default=5, type=int, help="number of output classes") - parser.add_argument("--mil_mode", default="att_trans", help="MIL algorithm") - parser.add_argument( - "--tile_count", default=44, type=int, help="number of patches (instances) to extract from WSI image" - ) - parser.add_argument("--tile_size", default=256, type=int, help="size of square patch (instance) in pixels") - - parser.add_argument("--checkpoint", default=None, help="load existing checkpoint") - parser.add_argument( - "--validate", - action="store_true", - help="run only inference on the validation set, must specify the checkpoint argument", - ) - - parser.add_argument("--logdir", default=None, help="path to log directory to store Tensorboard logs") - - parser.add_argument("--epochs", default=50, type=int, help="number of training epochs") - parser.add_argument("--batch_size", default=4, type=int, help="batch size, the number of WSI images per gpu") - parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate") - - parser.add_argument("--weight_decay", default=0, type=float, help="optimizer weight decay") - parser.add_argument("--amp", action="store_true", help="use AMP, recommended") - parser.add_argument( - "--val_every", - default=1, - type=int, - help="run validation after this number of epochs, default 1 to run every epoch", - ) - parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading") - - ###for multigpu - parser.add_argument("--distributed", action="store_true", help="use multigpu training, recommended") - parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training") - parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training") - parser.add_argument( - "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training" - ) - parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend") - - parser.add_argument( - "--quick", action="store_true", help="use a small subset of data for debugging" - ) # for debugging - - args = parser.parse_args() - - print("Argument values:") - for k, v in vars(args).items(): - print(k, "=>", v) - print("-----------------") - - return args - - def train_epoch(model, loader, optimizer, scaler, epoch, args): """One train epoch over the dataset""" @@ -278,35 +227,12 @@ def __call__(self, data): return d -def main(): - - args = parse_args() - - if args.dataset_json is None: - # download default json datalist - resource = "https://drive.google.com/uc?id=1L6PtKBlHHyUgTE4rVhRuOLTQKgD4tBRK" - dst = "./datalist_panda_0.json" - if not os.path.exists(dst): - gdown.download(resource, dst, quiet=False) - args.dataset_json = dst - - if args.distributed: - ngpus_per_node = torch.cuda.device_count() - args.optim_lr = ngpus_per_node * args.optim_lr / 2 # heuristic to scale up learning rate in multigpu setup - args.world_size = ngpus_per_node * args.world_size - - print("Multigpu", ngpus_per_node, "rescaled lr", args.optim_lr) - mp.spawn(main_worker, nprocs=ngpus_per_node, args=(args,)) - else: - main_worker(0, args) - - def list_data_collate(batch: collections.abc.Sequence): - ''' - Combine instances from a list of dicts into a single dict, by stacking them along first dim - [{'image' : 3xHxW}, {'image' : 3xHxW}, {'image' : 3xHxW}...] - > {'image' : Nx3xHxW} - followed by the default collate which will form a batch BxNx3xHxW - ''' + """ + Combine instances from a list of dicts into a single dict, by stacking them along first dim + [{'image' : 3xHxW}, {'image' : 3xHxW}, {'image' : 3xHxW}...] - > {'image' : Nx3xHxW} + followed by the default collate which will form a batch BxNx3xHxW + """ for i, item in enumerate(batch): data = item[0] @@ -352,27 +278,27 @@ def main_worker(gpu, args): train_transform = Compose( [ - LoadImageD(keys=["image"], reader=WSIReader, backend="TiffFile", dtype=np.uint8, level=1, image_only=True), + LoadImaged(keys=["image"], reader=WSIReader, backend="cupy", dtype=np.uint8, level=1, image_only=True), LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes), - TileOnGridd( + GridPatchd( keys=["image"], - tile_count=args.tile_count, - tile_size=args.tile_size, - random_offset=True, - background_val=255, - return_list_of_dicts=True, + patch_size=args.tile_size, + start_pos="random", + max_num_patches=args.tile_count, + overlap=0.0, + pad_opts={"constant_values": 255}, ), RandFlipd(keys=["image"], spatial_axis=0, prob=0.5), RandFlipd(keys=["image"], spatial_axis=1, prob=0.5), RandRotate90d(keys=["image"], prob=0.5), - ScaleIntensityRangeD(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)), + ScaleIntensityRanged(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)), ToTensord(keys=["image", "label"]), ] ) valid_transform = Compose( [ - LoadImageD(keys=["image"], reader=WSIReader, backend="TiffFile", dtype=np.uint8, level=1, image_only=True), + LoadImaged(keys=["image"], reader=WSIReader, backend="TiffFile", dtype=np.uint8, level=1, image_only=True), LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes), TileOnGridd( keys=["image"], @@ -382,7 +308,7 @@ def main_worker(gpu, args): background_val=255, return_list_of_dicts=True, ), - ScaleIntensityRangeD(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)), + ScaleIntensityRanged(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)), ToTensord(keys=["image", "label"]), ] ) @@ -540,5 +466,85 @@ def main_worker(gpu, args): print("ALL DONE") +def parse_args(): + + parser = argparse.ArgumentParser(description="Multiple Instance Learning (MIL) example of classification from WSI.") + parser.add_argument( + "--data_root", default="/PandaChallenge2020/train_images/", help="path to root folder of images" + ) + parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file") + + parser.add_argument("--num_classes", default=5, type=int, help="number of output classes") + parser.add_argument("--mil_mode", default="att_trans", help="MIL algorithm") + parser.add_argument( + "--tile_count", default=44, type=int, help="number of patches (instances) to extract from WSI image" + ) + parser.add_argument("--tile_size", default=256, type=int, help="size of square patch (instance) in pixels") + + parser.add_argument("--checkpoint", default=None, help="load existing checkpoint") + parser.add_argument( + "--validate", + action="store_true", + help="run only inference on the validation set, must specify the checkpoint argument", + ) + + parser.add_argument("--logdir", default=None, help="path to log directory to store Tensorboard logs") + + parser.add_argument("--epochs", default=50, type=int, help="number of training epochs") + parser.add_argument("--batch_size", default=4, type=int, help="batch size, the number of WSI images per gpu") + parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate") + + parser.add_argument("--weight_decay", default=0, type=float, help="optimizer weight decay") + parser.add_argument("--amp", action="store_true", help="use AMP, recommended") + parser.add_argument( + "--val_every", + default=1, + type=int, + help="run validation after this number of epochs, default 1 to run every epoch", + ) + parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading") + + ###for multigpu + parser.add_argument("--distributed", action="store_true", help="use multigpu training, recommended") + parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training") + parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training") + parser.add_argument( + "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training" + ) + parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend") + + parser.add_argument( + "--quick", action="store_true", help="use a small subset of data for debugging" + ) # for debugging + + args = parser.parse_args() + + print("Argument values:") + for k, v in vars(args).items(): + print(k, "=>", v) + print("-----------------") + + return args + + if __name__ == "__main__": - main() + + args = parse_args() + + if args.dataset_json is None: + # download default json datalist + resource = "https://drive.google.com/uc?id=1L6PtKBlHHyUgTE4rVhRuOLTQKgD4tBRK" + dst = "./datalist_panda_0.json" + if not os.path.exists(dst): + gdown.download(resource, dst, quiet=False) + args.dataset_json = dst + + if args.distributed: + ngpus_per_node = torch.cuda.device_count() + args.optim_lr = ngpus_per_node * args.optim_lr / 2 # heuristic to scale up learning rate in multigpu setup + args.world_size = ngpus_per_node * args.world_size + + print("Multigpu", ngpus_per_node, "rescaled lr", args.optim_lr) + mp.spawn(main_worker, nprocs=ngpus_per_node, args=(args,)) + else: + main_worker(0, args) From 53252952c4a9a15abd14b5d534844aa9a6e5f84a Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 23 May 2022 15:01:49 +0000 Subject: [PATCH 02/18] Fix a typo Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../panda_mil_train_evaluate_pytorch_gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py index 45da2ee931..1ce03f6493 100644 --- a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py +++ b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py @@ -278,7 +278,7 @@ def main_worker(gpu, args): train_transform = Compose( [ - LoadImaged(keys=["image"], reader=WSIReader, backend="cupy", dtype=np.uint8, level=1, image_only=True), + LoadImaged(keys=["image"], reader=WSIReader, backend="cucim", dtype=np.uint8, level=1, image_only=True), LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes), GridPatchd( keys=["image"], From afb66f731d25057a42ebe0f7e711604883c0689f Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 23 May 2022 15:06:44 +0000 Subject: [PATCH 03/18] Tiffile to cucim Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../panda_mil_train_evaluate_pytorch_gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py index 1ce03f6493..e01d7e0833 100644 --- a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py +++ b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py @@ -298,7 +298,7 @@ def main_worker(gpu, args): valid_transform = Compose( [ - LoadImaged(keys=["image"], reader=WSIReader, backend="TiffFile", dtype=np.uint8, level=1, image_only=True), + LoadImaged(keys=["image"], reader=WSIReader, backend="cucim", dtype=np.uint8, level=1, image_only=True), LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes), TileOnGridd( keys=["image"], From de9f7ebab7330a90a1ff79192b8f6cdf197aa8f9 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 23 May 2022 15:15:05 +0000 Subject: [PATCH 04/18] Update to RandGridPatchd Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../panda_mil_train_evaluate_pytorch_gpu.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py index e01d7e0833..33281cc3b5 100644 --- a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py +++ b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py @@ -280,12 +280,10 @@ def main_worker(gpu, args): [ LoadImaged(keys=["image"], reader=WSIReader, backend="cucim", dtype=np.uint8, level=1, image_only=True), LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes), - GridPatchd( + RandGridPatchd( keys=["image"], patch_size=args.tile_size, - start_pos="random", max_num_patches=args.tile_count, - overlap=0.0, pad_opts={"constant_values": 255}, ), RandFlipd(keys=["image"], spatial_axis=0, prob=0.5), @@ -300,13 +298,11 @@ def main_worker(gpu, args): [ LoadImaged(keys=["image"], reader=WSIReader, backend="cucim", dtype=np.uint8, level=1, image_only=True), LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes), - TileOnGridd( + GridPatchd( keys=["image"], - tile_count=None, - tile_size=args.tile_size, - random_offset=False, - background_val=255, - return_list_of_dicts=True, + patch_size=args.tile_size, + max_num_patches=args.tile_count, + pad_opts={"constant_values": 255}, ), ScaleIntensityRanged(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)), ToTensord(keys=["image", "label"]), From a54a17d06d1f154df56c62d52c1af60bdf41d5cd Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 23 May 2022 15:15:59 +0000 Subject: [PATCH 05/18] Fix import Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../panda_mil_train_evaluate_pytorch_gpu.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py index 33281cc3b5..293c1d32ef 100644 --- a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py +++ b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py @@ -27,6 +27,7 @@ Compose, LoadImaged, RandFlipd, + RandGridPatchd, GridPatchd, RandRotate90d, ScaleIntensityRanged, From 554ed71e3e8e8a4cfbfbcb4077f29bc7f111634c Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 23 May 2022 15:17:22 +0000 Subject: [PATCH 06/18] Add sort_key Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../panda_mil_train_evaluate_pytorch_gpu.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py index 293c1d32ef..ebd0ee0ee4 100644 --- a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py +++ b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py @@ -285,6 +285,7 @@ def main_worker(gpu, args): keys=["image"], patch_size=args.tile_size, max_num_patches=args.tile_count, + sort_key="max", pad_opts={"constant_values": 255}, ), RandFlipd(keys=["image"], spatial_axis=0, prob=0.5), @@ -303,6 +304,7 @@ def main_worker(gpu, args): keys=["image"], patch_size=args.tile_size, max_num_patches=args.tile_count, + sort_key="max", pad_opts={"constant_values": 255}, ), ScaleIntensityRanged(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)), From 0724dfe3b2aa5f283625dd40950c34d8b2f64bb8 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 23 May 2022 18:23:42 +0000 Subject: [PATCH 07/18] sort key max to min Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../panda_mil_train_evaluate_pytorch_gpu.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py index ebd0ee0ee4..a790695651 100644 --- a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py +++ b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py @@ -285,7 +285,7 @@ def main_worker(gpu, args): keys=["image"], patch_size=args.tile_size, max_num_patches=args.tile_count, - sort_key="max", + sort_key="min", pad_opts={"constant_values": 255}, ), RandFlipd(keys=["image"], spatial_axis=0, prob=0.5), @@ -303,8 +303,7 @@ def main_worker(gpu, args): GridPatchd( keys=["image"], patch_size=args.tile_size, - max_num_patches=args.tile_count, - sort_key="max", + sort_key="min", pad_opts={"constant_values": 255}, ), ScaleIntensityRanged(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)), From 7f34968f8df6903122b9add6f79037a1f18af54f Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 23 May 2022 18:56:03 +0000 Subject: [PATCH 08/18] Fix patch_size Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../panda_mil_train_evaluate_pytorch_gpu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py index a790695651..6be73ee933 100644 --- a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py +++ b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py @@ -283,7 +283,7 @@ def main_worker(gpu, args): LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes), RandGridPatchd( keys=["image"], - patch_size=args.tile_size, + patch_size=(args.tile_size, args.tile_size), max_num_patches=args.tile_count, sort_key="min", pad_opts={"constant_values": 255}, @@ -302,7 +302,7 @@ def main_worker(gpu, args): LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes), GridPatchd( keys=["image"], - patch_size=args.tile_size, + patch_size=(args.tile_size, args.tile_size), sort_key="min", pad_opts={"constant_values": 255}, ), From a99fe581557b7efe0e68af12b0dd165d5ee887c2 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Mon, 23 May 2022 19:37:42 +0000 Subject: [PATCH 09/18] Stack patch locations Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../panda_mil_train_evaluate_pytorch_gpu.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py index 6be73ee933..953e9a293f 100644 --- a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py +++ b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py @@ -238,6 +238,7 @@ def list_data_collate(batch: collections.abc.Sequence): for i, item in enumerate(batch): data = item[0] data["image"] = torch.stack([ix["image"] for ix in item], dim=0) + data["patch"]['location'] = [ix["patch"]['location'] for ix in item] batch[i] = data return default_collate(batch) From 8116255b0adb187897dde98697abaa0c63296cf1 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 27 May 2022 02:50:37 +0000 Subject: [PATCH 10/18] Remove location Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../panda_mil_train_evaluate_pytorch_gpu.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py index 953e9a293f..e65c865a24 100644 --- a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py +++ b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py @@ -33,8 +33,6 @@ ScaleIntensityRanged, ToTensord, ) -from monai.transforms.spatial.array import GridPatch -from monai.transforms.spatial.dictionary import GridPatchd from monai.networks.nets import milmodel @@ -209,7 +207,7 @@ class LabelEncodeIntegerGraded(Transform): """ - def __init__(self, num_classes, keys=["label"]): + def __init__(self, num_classes, keys=("label",)): super().__init__() self.keys = keys self.num_classes = num_classes @@ -238,7 +236,6 @@ def list_data_collate(batch: collections.abc.Sequence): for i, item in enumerate(batch): data = item[0] data["image"] = torch.stack([ix["image"] for ix in item], dim=0) - data["patch"]['location'] = [ix["patch"]['location'] for ix in item] batch[i] = data return default_collate(batch) @@ -285,7 +282,7 @@ def main_worker(gpu, args): RandGridPatchd( keys=["image"], patch_size=(args.tile_size, args.tile_size), - max_num_patches=args.tile_count, + num_patches=args.tile_count, sort_key="min", pad_opts={"constant_values": 255}, ), From a07e1217c4fa31e406778d93801f571cf6b563b6 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 27 May 2022 12:58:28 +0000 Subject: [PATCH 11/18] Change to fix_num_patches Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../panda_mil_train_evaluate_pytorch_gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py index e65c865a24..4cf14c882d 100644 --- a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py +++ b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py @@ -282,7 +282,7 @@ def main_worker(gpu, args): RandGridPatchd( keys=["image"], patch_size=(args.tile_size, args.tile_size), - num_patches=args.tile_count, + fix_num_patches=args.tile_count, sort_key="min", pad_opts={"constant_values": 255}, ), From 975edc4fe68393a7cd4df811751eb716dd6f49f8 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 31 May 2022 12:34:54 +0000 Subject: [PATCH 12/18] Update grid patch Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../panda_mil_train_evaluate_pytorch_gpu.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py index 4cf14c882d..912177fd6a 100644 --- a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py +++ b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py @@ -282,9 +282,9 @@ def main_worker(gpu, args): RandGridPatchd( keys=["image"], patch_size=(args.tile_size, args.tile_size), - fix_num_patches=args.tile_count, - sort_key="min", - pad_opts={"constant_values": 255}, + num_patches=args.tile_count, + sort_fn="min", + constant_values=255, ), RandFlipd(keys=["image"], spatial_axis=0, prob=0.5), RandFlipd(keys=["image"], spatial_axis=1, prob=0.5), @@ -301,8 +301,8 @@ def main_worker(gpu, args): GridPatchd( keys=["image"], patch_size=(args.tile_size, args.tile_size), - sort_key="min", - pad_opts={"constant_values": 255}, + sort_fn="min", + constant_values=255, ), ScaleIntensityRanged(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)), ToTensord(keys=["image", "label"]), From e21b81d3d242c63c948cb1da655fb74e3b898c91 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Thu, 2 Jun 2022 00:02:12 +0000 Subject: [PATCH 13/18] Add threshold_filter Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../panda_mil_train_evaluate_pytorch_gpu.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py index 912177fd6a..8c664fe764 100644 --- a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py +++ b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py @@ -275,6 +275,10 @@ def main_worker(gpu, args): training_list = training_list[:16] validation_list = validation_list[:16] + def threshold_filter(patch): + thresh = 0.999 * 3 * 255 * args.tile_size * args.tile_size + return patch.sum() < thresh + train_transform = Compose( [ LoadImaged(keys=["image"], reader=WSIReader, backend="cucim", dtype=np.uint8, level=1, image_only=True), @@ -284,6 +288,7 @@ def main_worker(gpu, args): patch_size=(args.tile_size, args.tile_size), num_patches=args.tile_count, sort_fn="min", + filter_fn=threshold_filter, constant_values=255, ), RandFlipd(keys=["image"], spatial_axis=0, prob=0.5), @@ -302,6 +307,7 @@ def main_worker(gpu, args): keys=["image"], patch_size=(args.tile_size, args.tile_size), sort_fn="min", + filter_fn=threshold_filter, constant_values=255, ), ScaleIntensityRanged(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)), From 10ef5fa680bdf0b308dceb5b65d7c07cd96e8362 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Thu, 2 Jun 2022 00:53:13 +0000 Subject: [PATCH 14/18] Update threshold and add pad_mode=None Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../panda_mil_train_evaluate_pytorch_gpu.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py index 8c664fe764..e0ab543e5e 100644 --- a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py +++ b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py @@ -35,7 +35,7 @@ ) from monai.networks.nets import milmodel - + def train_epoch(model, loader, optimizer, scaler, epoch, args): """One train epoch over the dataset""" @@ -238,8 +238,7 @@ def list_data_collate(batch: collections.abc.Sequence): data["image"] = torch.stack([ix["image"] for ix in item], dim=0) batch[i] = data return default_collate(batch) - - + def main_worker(gpu, args): args.gpu = gpu @@ -275,10 +274,6 @@ def main_worker(gpu, args): training_list = training_list[:16] validation_list = validation_list[:16] - def threshold_filter(patch): - thresh = 0.999 * 3 * 255 * args.tile_size * args.tile_size - return patch.sum() < thresh - train_transform = Compose( [ LoadImaged(keys=["image"], reader=WSIReader, backend="cucim", dtype=np.uint8, level=1, image_only=True), @@ -288,7 +283,7 @@ def threshold_filter(patch): patch_size=(args.tile_size, args.tile_size), num_patches=args.tile_count, sort_fn="min", - filter_fn=threshold_filter, + pad_mode=None, constant_values=255, ), RandFlipd(keys=["image"], spatial_axis=0, prob=0.5), @@ -307,7 +302,8 @@ def threshold_filter(patch): keys=["image"], patch_size=(args.tile_size, args.tile_size), sort_fn="min", - filter_fn=threshold_filter, + threshold=0.999 * 3 * 255 * args.tile_size * args.tile_size, + pad_mode=None, constant_values=255, ), ScaleIntensityRanged(keys=["image"], a_min=np.float32(255), a_max=np.float32(0)), From ecb1ba11196f4ecfa79cc41cc86a61128b010b39 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Jun 2022 00:59:10 +0000 Subject: [PATCH 15/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../panda_mil_train_evaluate_pytorch_gpu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py index e0ab543e5e..ebe8112d2c 100644 --- a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py +++ b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py @@ -35,7 +35,7 @@ ) from monai.networks.nets import milmodel - + def train_epoch(model, loader, optimizer, scaler, epoch, args): """One train epoch over the dataset""" @@ -238,7 +238,7 @@ def list_data_collate(batch: collections.abc.Sequence): data["image"] = torch.stack([ix["image"] for ix in item], dim=0) batch[i] = data return default_collate(batch) - + def main_worker(gpu, args): args.gpu = gpu From a7749734940980026842f043ef3007006a2bfdb3 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Tue, 7 Jun 2022 03:15:27 +0000 Subject: [PATCH 16/18] Update mil pipeline with new grid patch Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../panda_mil_train_evaluate_pytorch_gpu.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py index ebe8112d2c..f8d3727d6c 100644 --- a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py +++ b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py @@ -239,6 +239,7 @@ def list_data_collate(batch: collections.abc.Sequence): batch[i] = data return default_collate(batch) + def main_worker(gpu, args): args.gpu = gpu @@ -282,7 +283,8 @@ def main_worker(gpu, args): keys=["image"], patch_size=(args.tile_size, args.tile_size), num_patches=args.tile_count, - sort_fn="min", + filter_high_values=True, + return_location=False, pad_mode=None, constant_values=255, ), @@ -301,8 +303,9 @@ def main_worker(gpu, args): GridPatchd( keys=["image"], patch_size=(args.tile_size, args.tile_size), - sort_fn="min", + filter_high_values=True, threshold=0.999 * 3 * 255 * args.tile_size * args.tile_size, + return_location=False, pad_mode=None, constant_values=255, ), From 34195802d42c255797754099a9104e296e9fa526 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Thu, 9 Jun 2022 00:27:21 +0000 Subject: [PATCH 17/18] Update GridPatch args Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../panda_mil_train_evaluate_pytorch_gpu.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py index f8d3727d6c..ebe490265f 100644 --- a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py +++ b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py @@ -283,8 +283,7 @@ def main_worker(gpu, args): keys=["image"], patch_size=(args.tile_size, args.tile_size), num_patches=args.tile_count, - filter_high_values=True, - return_location=False, + sort_fn="min", pad_mode=None, constant_values=255, ), @@ -303,9 +302,7 @@ def main_worker(gpu, args): GridPatchd( keys=["image"], patch_size=(args.tile_size, args.tile_size), - filter_high_values=True, threshold=0.999 * 3 * 255 * args.tile_size * args.tile_size, - return_location=False, pad_mode=None, constant_values=255, ), From 1a453054a82040c1c7ffbd19f5ca4a21375853a3 Mon Sep 17 00:00:00 2001 From: Behrooz <3968947+drbeh@users.noreply.github.com> Date: Fri, 8 Jul 2022 15:55:13 +0000 Subject: [PATCH 18/18] Update LabelEncodeIntegerGraded and sort imports Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> --- .../panda_mil_train_evaluate_pytorch_gpu.py | 47 ++++++++++--------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py index ebe490265f..3d4ea7cb4d 100644 --- a/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py +++ b/pathology/multiple_instance_learning/panda_mil_train_evaluate_pytorch_gpu.py @@ -1,39 +1,36 @@ -import os -import time -import shutil import argparse import collections.abc -import gdown +import os +import shutil +import time +import gdown import numpy as np -from sklearn.metrics import cohen_kappa_score - import torch -import torch.nn as nn -from torch.cuda.amp import GradScaler, autocast - -from torch.utils.tensorboard import SummaryWriter -from torch.utils.data.distributed import DistributedSampler -from torch.utils.data.dataloader import default_collate - import torch.distributed as dist import torch.multiprocessing as mp - +import torch.nn as nn +from monai.config import KeysCollection from monai.data import Dataset, load_decathlon_datalist from monai.data.wsi_reader import WSIReader from monai.metrics import Cumulative, CumulativeAverage +from monai.networks.nets import milmodel from monai.transforms import ( - Transform, Compose, + GridPatchd, LoadImaged, + MapTransform, RandFlipd, RandGridPatchd, - GridPatchd, RandRotate90d, ScaleIntensityRanged, ToTensord, ) -from monai.networks.nets import milmodel +from sklearn.metrics import cohen_kappa_score +from torch.cuda.amp import GradScaler, autocast +from torch.utils.data.dataloader import default_collate +from torch.utils.data.distributed import DistributedSampler +from torch.utils.tensorboard import SummaryWriter def train_epoch(model, loader, optimizer, scaler, epoch, args): @@ -194,7 +191,7 @@ def save_checkpoint(model, epoch, args, filename="model.pt", best_acc=0): print("Saving checkpoint", filename) -class LabelEncodeIntegerGraded(Transform): +class LabelEncodeIntegerGraded(MapTransform): """ Convert an integer label to encoded array representation of length num_classes, with 1 filled in up to label index, and 0 otherwise. For example for num_classes=5, @@ -202,14 +199,18 @@ class LabelEncodeIntegerGraded(Transform): Args: num_classes: the number of classes to convert to encoded format. - keys: keys of the corresponding items to be transformed - Defaults to ``['label']``. + keys: keys of the corresponding items to be transformed. Defaults to ``'label'``. + allow_missing_keys: don't raise exception if key is missing. """ - def __init__(self, num_classes, keys=("label",)): - super().__init__() - self.keys = keys + def __init__( + self, + num_classes: int, + keys: KeysCollection = "label", + allow_missing_keys: bool = False, + ): + super().__init__(keys, allow_missing_keys) self.num_classes = num_classes def __call__(self, data):