Skip to content
23 changes: 22 additions & 1 deletion monai/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class Cumulative(ABC):
cum.add(a, b)
cum.add(c, d)
cum.aggregate()
result = cum.get_buffer()
result = cum.get_buffer() # optional
cum.reset()

"""
Expand Down Expand Up @@ -197,6 +197,27 @@ class CumulativeIterationMetric(Cumulative, IterationMetric):
Typically, it computes some intermediate results for every iteration, cumulates in buffers,
then syncs across all the distributed ranks and aggregates for the final result when epoch completed.

For example, `MeanDice` inherits this class and the usage:

.. code-block:: python

dice_metric = DiceMetric(include_background=True, reduction="mean")

for val_data in val_loader:
val_outputs = model(val_data["img"])
val_outputs = [postprocessing_transform(i) for i in decollate_batch(val_outputs)]
# compute metric for current iteration
dice_metric(y_pred=val_outputs, y=val_data["seg"])

# aggregate the final mean dice result
metric = dice_metric.aggregate().item()

# reset the status for next computation round
dice_metric.reset()

And to load `predictions` and `labels` from files, then compute metrics with multi-processing, please refer to:
https://github.com/Project-MONAI/tutorials/blob/master/modules/compute_metric.py.

"""

def __call__(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None): # type: ignore
Expand Down