To implement your own custom metric, subclass the base `torchmetrics.Metric` class and implement the following methods:

- `__init__()`: Each state variable should be registered by calling `self.add_state(...)`.
- `update()`: Any code needed to update the state given any inputs to the metric.
- `compute()`: Computes a final value from the state of the metric.

We provide the remaining interface, such as `reset()`, which correctly resets all metric states that were added using `add_state`. You should therefore not implement `reset()` yourself. Additionally, adding metric states with `add_state` ensures that states are correctly synchronized in distributed settings (DDP). To see how metric states are synchronized across distributed processes, refer to the `add_state()` docs of the base `Metric` class.
Example implementation:

```python
import torch
from torchmetrics import Metric

class MyAccuracy(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        preds, target = self._input_format(preds, target)
        assert preds.shape == target.shape

        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        return self.correct.float() / self.total
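The metric above is used by alternating `update` and `compute` calls. To illustrate just the accumulation lifecycle, here is a plain-Python mirror of the same logic (a hypothetical `PlainAccuracy` class for illustration only — no tensors, no distributed synchronization):

```python
class PlainAccuracy:
    """Plain-Python mirror of MyAccuracy's accumulation logic.

    Illustration only: real metrics should subclass torchmetrics.Metric
    and register states via self.add_state(...).
    """

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, preds, target):
        # accumulate statistics from this batch into the metric state
        assert len(preds) == len(target)
        self.correct += sum(p == t for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self):
        # final value computed from the accumulated state
        return self.correct / self.total

acc = PlainAccuracy()
acc.update([0, 1, 1], [0, 1, 0])  # 2 of 3 correct
acc.update([1, 0], [1, 0])        # 2 of 2 correct
print(acc.compute())              # 0.8 (4 correct out of 5)
```

Because `update` only accumulates and `compute` only reads the accumulated state, the final value is the accuracy over all batches seen so far, not just the last one.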
Additionally, you may want to set the class properties `is_differentiable`, `higher_is_better` and `full_state_update`. Note that none of them are strictly required for the metric to work.
```python
from typing import Optional

from torchmetrics import Metric

class MyMetric(Metric):
    # Set to True if the metric is differentiable, else set to False
    is_differentiable: Optional[bool] = None

    # Set to True if the metric reaches its optimal value when maximized.
    # Set to False if it reaches its optimal value when minimized.
    higher_is_better: Optional[bool] = True

    # Set to True if the metric during `update` requires access to the global metric
    # state for its calculations. If not, setting this to False indicates that all
    # batch states are independent, and we will optimize the runtime of `forward`.
    full_state_update: bool = True
```
This section briefly describes how metrics work internally. We encourage looking at the source code for more info. Internally, TorchMetrics wraps the user-defined `update()` and `compute()` methods. We do this to automatically synchronize and reduce metric states across multiple devices. More precisely, calling `update()` does the following internally:

- Clears the computed cache.
- Calls the user-defined `update()`.
Similarly, calling `compute()` does the following internally:

- Syncs metric states between processes.
- Reduces the gathered metric states.
- Calls the user-defined `compute()` method on the gathered metric states.
- Caches the computed result.

From a user's standpoint this has one important side effect: computed results are cached. This means that no matter how many times `compute` is called in a row, it will continue to return the same result. The cache is first emptied on the next call to `update`.
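The caching behavior can be sketched in a few lines of plain Python. This hypothetical `CachedMetric` class is not the real torchmetrics implementation, only an illustration of the cache-clearing contract described above:

```python
class CachedMetric:
    """Toy illustration of compute-result caching (not the torchmetrics code)."""

    def __init__(self):
        self.total = 0
        self.count = 0
        self._cache = None

    def update(self, value):
        self._cache = None  # calling update empties the computed cache
        self.total += value
        self.count += 1

    def compute(self):
        if self._cache is None:  # recompute only after the cache was emptied
            self._cache = self.total / self.count
        return self._cache

m = CachedMetric()
m.update(2)
m.update(4)
print(m.compute())  # 3.0 -> computed from state
print(m.compute())  # 3.0 -> same result, served from the cache
m.update(6)
print(m.compute())  # 4.0 -> update emptied the cache, so it is recomputed
```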
`forward` serves the dual purpose of both returning the metric value on the current data and updating the internal metric state for accumulation over multiple batches. The `forward()` method achieves this by combining calls to `update`, `compute` and `reset`. Depending on the class property `full_state_update`, `forward` can behave in two ways:
- If `full_state_update` is `True`, it indicates that the metric during `update` requires access to the full metric state, and we therefore need two calls to `update` to ensure that the metric is calculated correctly:
  - Calls `update()` to update the global metric state (for accumulation over multiple batches).
  - Caches the global state.
  - Calls `reset()` to clear the global metric state.
  - Calls `update()` to update the local metric state.
  - Calls `compute()` to calculate the metric for the current batch.
  - Restores the global state.
- If `full_state_update` is `False` (default), the metric state of one batch is completely independent of the state of other batches, which means that we only need to call `update` once:
  - Caches the global state.
  - Calls `reset` to set the metric to its default state.
  - Calls `update` to update the state with the local batch statistics.
  - Calls `compute` to calculate the metric for the current batch.
  - Reduces the global state and the batch state into a single state that becomes the new global state.
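The `full_state_update=False` path described above can be sketched with a toy sum metric. This hypothetical `SumMetric` class is a simplified stand-in for the real torchmetrics machinery, just to show how `forward` returns the batch value while still accumulating a global state:

```python
import copy

class SumMetric:
    """Toy metric illustrating forward() when full_state_update=False.

    Simplified sketch, not the actual torchmetrics implementation.
    """
    full_state_update = False

    def __init__(self):
        self.total = 0

    def reset(self):
        self.total = 0

    def update(self, x):
        self.total += x

    def compute(self):
        return self.total

    def forward(self, x):
        global_state = copy.deepcopy(self.total)  # 1. cache the global state
        self.reset()                              # 2. reset to the default state
        self.update(x)                            # 3. update with the local batch
        batch_result = self.compute()             # 4. metric for the current batch
        self.total = global_state + self.total    # 5. reduce global + batch state
        return batch_result

m = SumMetric()
print(m.forward(2))  # 2 -> metric for this batch only
print(m.forward(3))  # 3 -> batch value; the global state keeps accumulating
print(m.compute())   # 5 -> accumulated over both batches
```

For a sum, the reduction in step 5 is simple addition; in torchmetrics the reduction depends on the `dist_reduce_fx` given to `add_state`.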
If implementing your own metric, we recommend trying out the metric with the `full_state_update` class property set to both `True` and `False`. If the results are equal, then setting it to `False` will usually give the best performance.
Want to contribute the metric you have implemented? Great, we are always open to adding more metrics to `torchmetrics` as long as they serve a general purpose. However, to keep all our metrics consistent, we request that the implementation and tests be formatted in the following way:
- Start by reading our contribution guidelines.
- First implement the functional backend. This takes care of all the logic that goes into the metric. The code should be put into a single file placed under `torchmetrics/functional/"domain"/"new_metric".py`, where `domain` is the type of metric (classification, regression, nlp, etc.) and `new_metric` is the name of the metric. In this file, there should be the following three functions:
  - `_new_metric_update(...)`: everything that has to do with type/shape checking, and all logic required before distributed syncing, needs to go here.
  - `_new_metric_compute(...)`: all remaining logic.
  - `new_metric(...)`: essentially wraps the `_update` and `_compute` private functions into one public function that makes up the functional interface for the metric.

  Note: The functional accuracy metric is a great example of this division of logic.
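As a sketch of this three-function split, here is a hypothetical "mean error" metric written in plain Python (the names `_mean_error_update`, `_mean_error_compute` and `mean_error` are illustrative, not part of torchmetrics; a real implementation would operate on tensors):

```python
def _mean_error_update(preds, target):
    # type/shape checking and all pre-sync logic lives here
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same length")
    sum_abs_error = sum(abs(p - t) for p, t in zip(preds, target))
    return sum_abs_error, len(preds)

def _mean_error_compute(sum_abs_error, total):
    # all remaining logic: turn accumulated state into the final value
    return sum_abs_error / total

def mean_error(preds, target):
    # public functional interface: wraps the two private helpers
    sum_abs_error, total = _mean_error_update(preds, target)
    return _mean_error_compute(sum_abs_error, total)

print(mean_error([1, 2, 3], [1, 1, 5]))  # 1.0  ((0 + 1 + 2) / 3)
```

Keeping the update and compute halves private makes it easy for the module interface to reuse them without duplicating logic.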
- In a corresponding file placed in `torchmetrics/"domain"/"new_metric".py`, create the module interface:
  - Create a new module metric by subclassing `torchmetrics.Metric`.
  - In the `__init__` of the module, call `self.add_state` for as many metric states as are needed for the metric to properly accumulate metric statistics.
  - The module interface should essentially call the private `_new_metric_update(...)` in its `update` method, and similarly the `_new_metric_compute(...)` function in its `compute`. No logic should really be implemented in the module interface. We do this to avoid maintaining duplicate code.

  Note: The module Accuracy metric that corresponds to the above functional example showcases these steps.
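Continuing the hypothetical "mean error" example, the module interface would only accumulate state and delegate to the private helpers. This plain-Python stand-in (a real version would subclass `torchmetrics.Metric` and register states via `self.add_state(...)`) shows how thin the class should be:

```python
# Private functional helpers (same hypothetical metric as above).
def _mean_error_update(preds, target):
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same length")
    return sum(abs(p - t) for p, t in zip(preds, target)), len(preds)

def _mean_error_compute(sum_abs_error, total):
    return sum_abs_error / total

class MeanError:
    """Illustrative module interface: no metric logic of its own."""

    def __init__(self):
        # with torchmetrics these states would be registered via self.add_state(...)
        self.sum_abs_error = 0.0
        self.total = 0

    def update(self, preds, target):
        # delegate all logic to the private update helper, then accumulate
        err, n = _mean_error_update(preds, target)
        self.sum_abs_error += err
        self.total += n

    def compute(self):
        # delegate to the private compute helper
        return _mean_error_compute(self.sum_abs_error, self.total)

m = MeanError()
m.update([1, 2], [1, 4])   # errors: 0 + 2
m.update([3], [5])         # error: 2
print(m.compute())         # (0 + 2 + 2) / 3 = 4/3
```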
- Remember to add bindings to the relevant `__init__` files.
- Testing is key to keeping `torchmetrics` trustworthy. This is why we have a very rigid testing protocol. This means that in most cases we require the metric to be tested against some other common framework (`sklearn`, `scipy`, etc.):
  - Create a testing file in `unittests/"domain"/test_"new_metric".py`. Only one file is needed, as it is intended to test both the functional and module interfaces.
  - In that file, start by defining a number of test inputs that your metric should be evaluated on.
  - Create a test class `class NewMetric(MetricTester)` that inherits from `tests.helpers.testers.MetricTester`. This test class should essentially implement the `test_"new_metric"_class` and `test_"new_metric"_fn` methods, which test the module interface and the functional interface, respectively.
  - The test class should be parameterized (using `@pytest.mark.parametrize`) by the different test inputs defined initially. Additionally, the `test_"new_metric"_class` method should also be parameterized with a `ddp` parameter so that it gets tested in a distributed setting. If your metric has additional parameters, make sure to also parameterize these so that different combinations of inputs and parameters get tested.
  - (optional) If your metric raises any exceptions, please add tests that showcase this.

  Note: The test file for the accuracy metric shows how to implement such tests.
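The core testing idea — compare your metric against an independent reference implementation over many parameterized inputs — can be sketched without the torchmetrics test harness. The functions and inputs below are hypothetical (a plain loop stands in for `@pytest.mark.parametrize` and `MetricTester`):

```python
def my_mean(values):
    # metric under test
    return sum(values) / len(values)

def reference_mean(values):
    # independent reference implementation (stand-in for sklearn/scipy)
    total = 0.0
    for v in values:
        total += v
    return total / len(values)

# parameterized test inputs, analogous to @pytest.mark.parametrize
TEST_INPUTS = [[1, 2, 3], [0.5, 0.5], [10, -10, 4, 6]]

for values in TEST_INPUTS:
    assert abs(my_mean(values) - reference_mean(values)) < 1e-9
print("all cases match")
```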
If you can only figure out part of the steps, do not be afraid to send a PR. We would much rather receive a working metric that is not formatted exactly like our codebase than not receive it at all. Formatting can always be applied later. We will gladly guide you and/or help implement the rest :]