Skip to content

Commit 3249a74

Browse files
Bordawyli
andauthored
rename n_classes (#322)
Co-authored-by: Wenqi Li <wenqil@nvidia.com>
1 parent a75a1d6 commit 3249a74

File tree

24 files changed

+55
-55
lines changed

24 files changed

+55
-55
lines changed

2d_classification/mednist_tutorial.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@
375375
" [LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])\n",
376376
"\n",
377377
"y_pred_trans = Compose([EnsureType(), Activations(softmax=True)])\n",
378-
"y_trans = Compose([EnsureType(), AsDiscrete(to_onehot=True, n_classes=num_class)])"
378+
"y_trans = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=num_class)])"
379379
]
380380
},
381381
{

3d_classification/ignite/densenet_training_dict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def prepare_batch(batch, device=None, non_blocking=False):
126126
# add evaluation metric to the evaluator engine
127127
val_metrics = {metric_name: ROCAUC()}
128128

129-
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, n_classes=2)])
129+
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=2)])
130130
post_pred = Compose([EnsureType(), Activations(softmax=True)])
131131
# Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
132132
# user can add output_transform to return other values

3d_classification/torch/densenet_training_dict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def main():
8181
]
8282
)
8383
post_pred = Compose([EnsureType(), Activations(softmax=True)])
84-
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, n_classes=2)])
84+
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=2)])
8585

8686
# Define dataset, data loader
8787
check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)

3d_segmentation/challenge_baseline/run_net.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,11 @@ def get_xforms(mode="train", keys=("image", "label")):
8181
def get_net():
8282
"""returns a unet model instance."""
8383

84-
n_classes = 2
84+
num_classes = 2
8585
net = monai.networks.nets.BasicUNet(
8686
dimensions=3,
8787
in_channels=1,
88-
out_channels=n_classes,
88+
out_channels=num_classes,
8989
features=(32, 32, 64, 128, 256, 32),
9090
dropout=0.1,
9191
)
@@ -172,7 +172,7 @@ def train(data_folder=".", model_folder="runs"):
172172

173173
# create evaluator (to be used to measure model quality during training
174174
val_post_transform = monai.transforms.Compose(
175-
[EnsureTyped(keys=("pred", "label")), AsDiscreted(keys=("pred", "label"), argmax=(True, False), to_onehot=True, n_classes=2)]
175+
[EnsureTyped(keys=("pred", "label")), AsDiscreted(keys=("pred", "label"), argmax=(True, False), to_onehot=True, num_classes=2)]
176176
)
177177
val_handlers = [
178178
ProgressBar(),

3d_segmentation/spleen_segmentation_3d.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -471,8 +471,8 @@
471471
"best_metric_epoch = -1\n",
472472
"epoch_loss_values = []\n",
473473
"metric_values = []\n",
474-
"post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, n_classes=2)])\n",
475-
"post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, n_classes=2)])\n",
474+
"post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, num_classes=2)])\n",
475+
"post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=2)])\n",
476476
"\n",
477477
"for epoch in range(max_epochs):\n",
478478
" print(\"-\" * 10)\n",
@@ -720,8 +720,8 @@
720720
" nearest_interp=False,\n",
721721
" to_tensor=True,\n",
722722
" ),\n",
723-
" AsDiscreted(keys=\"pred\", argmax=True, to_onehot=True, n_classes=2),\n",
724-
" AsDiscreted(keys=\"label\", to_onehot=True, n_classes=2),\n",
723+
" AsDiscreted(keys=\"pred\", argmax=True, to_onehot=True, num_classes=2),\n",
724+
" AsDiscreted(keys=\"label\", to_onehot=True, num_classes=2),\n",
725725
"])"
726726
]
727727
},

3d_segmentation/spleen_segmentation_3d_lightning.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,8 @@
241241
" norm=Norm.BATCH,\n",
242242
" )\n",
243243
" self.loss_function = DiceLoss(to_onehot_y=True, softmax=True)\n",
244-
" self.post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, n_classes=2)])\n",
245-
" self.post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, n_classes=2)])\n",
244+
" self.post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, num_classes=2)])\n",
245+
" self.post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=2)])\n",
246246
" self.dice_metric = DiceMetric(include_background=False, reduction=\"mean\", get_not_nans=False)\n",
247247
" self.best_val_dice = 0\n",
248248
" self.best_val_epoch = 0\n",

3d_segmentation/unetr_btcv_segmentation_3d.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -681,8 +681,8 @@
681681
"\n",
682682
"max_iterations = 25000\n",
683683
"eval_num = 500\n",
684-
"post_label = AsDiscrete(to_onehot=True, n_classes=14)\n",
685-
"post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=14)\n",
684+
"post_label = AsDiscrete(to_onehot=True, num_classes=14)\n",
685+
"post_pred = AsDiscrete(argmax=True, to_onehot=True, num_classes=14)\n",
686686
"dice_metric = DiceMetric(include_background=True, reduction=\"mean\", get_not_nans=False)\n",
687687
"global_step = 0\n",
688688
"dice_val_best = 0.0\n",

3d_segmentation/unetr_btcv_segmentation_3d_lightning.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -415,8 +415,8 @@
415415
" ).to(device)\n",
416416
"\n",
417417
" self.loss_function = DiceCELoss(to_onehot_y=True, softmax=True)\n",
418-
" self.post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=14)\n",
419-
" self.post_label = AsDiscrete(to_onehot=True, n_classes=14)\n",
418+
" self.post_pred = AsDiscrete(argmax=True, to_onehot=True, num_classes=14)\n",
419+
" self.post_label = AsDiscrete(to_onehot=True, num_classes=14)\n",
420420
" self.dice_metric = DiceMetric(\n",
421421
" include_background=False, reduction=\"mean\", get_not_nans=False\n",
422422
" )\n",

acceleration/automatic_mixed_precision.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,8 @@
352352
" optimizer = torch.optim.Adam(model.parameters(), 1e-4)\n",
353353
" scaler = torch.cuda.amp.GradScaler() if amp else None\n",
354354
"\n",
355-
" post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, n_classes=2)])\n",
356-
" post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, n_classes=2)])\n",
355+
" post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, num_classes=2)])\n",
356+
" post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=2)])\n",
357357
"\n",
358358
" dice_metric = DiceMetric(include_background=False, reduction=\"mean\", get_not_nans=False)\n",
359359
"\n",

acceleration/dataset_type_performance.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@
209209
" loss_function = DiceLoss(to_onehot_y=True, softmax=True)\n",
210210
" optimizer = torch.optim.Adam(model.parameters(), 1e-4)\n",
211211
"\n",
212-
" post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, n_classes=2)])\n",
213-
" post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, n_classes=2)])\n",
212+
" post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, num_classes=2)])\n",
213+
" post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=2)])\n",
214214
"\n",
215215
" dice_metric = DiceMetric(include_background=True, reduction=\"mean\", get_not_nans=False)\n",
216216
"\n",

0 commit comments

Comments
 (0)