TorchMetrics is a collection of 100+ PyTorch metrics implementations and an easy-to-use API to create custom metrics. It offers:
- A standardized interface to increase reproducibility
- Reduced boilerplate
- Distributed-training compatible
- Rigorously tested
- Automatic accumulation over batches
- Automatic synchronization between multiple devices
You can use TorchMetrics in any PyTorch model, or within PyTorch Lightning to enjoy additional features:
- Your data will always be placed on the same device as your metrics.
- Native support for logging metrics in Lightning to reduce even more boilerplate.
You can install TorchMetrics using pip or conda:
# Python Package Index (PyPI)
pip install torchmetrics
# Conda
conda install -c conda-forge torchmetrics
If there is no PyTorch wheel for your OS or Python version, you can compile PyTorch from source:
# Optional, if you do not need to compile with GPU support
export USE_CUDA=0 # just to keep it simple
# you can install the latest state from master
pip install git+https://github.com/pytorch/pytorch.git
# OR set a particular PyTorch release
pip install git+https://github.com/pytorch/pytorch.git@<release-tag>
# and finalize with installing TorchMetrics
pip install torchmetrics
Similar to torch.nn, most metrics have both a class-based and a functional version. The functional versions implement the basic operations required for computing each metric. They are simple Python functions that take torch.Tensor inputs and return the corresponding metric as a torch.Tensor. The code snippet below shows a simple example of calculating the accuracy using the functional interface:
import torch
# import our library
import torchmetrics

# simulate a classification problem
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

acc = torchmetrics.functional.accuracy(preds, target, task="multiclass", num_classes=5)
Nearly all functional metrics have a corresponding class-based metric that calls its functional counterpart underneath. The class-based metrics are characterized by having one or more internal metric states (similar to the parameters of a PyTorch module) that allow them to offer additional functionality:
- Accumulation of multiple batches
- Automatic synchronization between multiple devices
- Metric arithmetic
The code below shows how to use the class-based interface:
import torch
# import our library
import torchmetrics

# initialize metric
metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=5)

n_batches = 10
for i in range(n_batches):
    # simulate a classification problem
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))
    # metric on current batch
    acc = metric(preds, target)
    print(f"Accuracy on batch {i}: {acc}")

# metric on all batches using custom accumulation
acc = metric.compute()
print(f"Accuracy on all data: {acc}")

# Resetting the internal state so that the metric is ready for new data
metric.reset()
Accuracy on batch ...
Implementing your own metric is as easy as subclassing a torch.nn.Module. Simply subclass torchmetrics.Metric and do the following:
- Implement __init__, where you call self.add_state for every internal state that is needed for the metric computations
- Implement the update method, where all logic that is necessary for updating the metric states goes
- Implement the compute method, where the final metric computation happens
For practical examples and more information about implementing a metric, please see the dedicated implementation page.
TorchMetrics provides a Devcontainer configuration for Visual Studio Code to use a Docker container as a pre-configured development environment. This avoids the struggle of setting up a development environment and makes the setup reproducible and consistent. Please follow the installation instructions and familiarize yourself with the container tutorials if you want to use it. To use GPUs, you can enable them in the .devcontainer/devcontainer.json file.