
New Image metrics & wrappers

@Borda released this 11 Jan 12:56

TorchMetrics v1.3 is out now! This release introduces seven new metrics across the different subdomains of TorchMetrics and adds some nice features to already established metrics. In this blog post, we present the new metrics with short code samples.

We are happy to see the continued adoption of TorchMetrics in over 19,000 GitHub repositories, and we are proud to announce that we have passed 1,800 GitHub stars.

New metrics

The retrieval domain has received one new metric in this release: RetrievalAUROC. This metric calculates the Area Under the Receiver Operating Characteristic curve for document retrieval data. It works like the standard AUROC metric from classification, but it additionally accepts the indexes argument that all retrieval metrics use to group documents by query.

from torch import tensor
from torchmetrics.retrieval import RetrievalAUROC
indexes = tensor([0, 0, 0, 1, 1, 1, 1])
preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
target = tensor([False, False, True, False, True, False, True])
r_auroc = RetrievalAUROC()
r_auroc(preds, target, indexes=indexes)
# tensor(0.7500)
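Conceptually, the indexes tensor groups predictions by query: AUROC is computed separately within each group and the per-query values are then averaged. The pure-Python sketch below illustrates that idea using the pairwise (Mann-Whitney) formulation of AUROC. It is not the library's implementation, the function names are hypothetical, and it assumes every query has at least one relevant and one non-relevant document (TorchMetrics handles queries without relevant documents through its own options).

```python
from collections import defaultdict
from statistics import mean


def pairwise_auroc(scores, labels):
    """AUROC as the fraction of (relevant, non-relevant) pairs ranked correctly.

    Ties between a relevant and a non-relevant score count as half a pair.
    Assumes at least one relevant and one non-relevant document.
    """
    pos = [s for s, y in zip(scores, labels) if y]
    neg = [s for s, y in zip(scores, labels) if not y]
    correct = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return correct / (len(pos) * len(neg))


def retrieval_auroc(preds, target, indexes):
    """Group documents by query index, score each query, then average."""
    groups = defaultdict(lambda: ([], []))
    for p, t, i in zip(preds, target, indexes):
        groups[i][0].append(p)
        groups[i][1].append(t)
    return mean(pairwise_auroc(s, y) for s, y in groups.values())


# Same data as the example above
indexes = [0, 0, 0, 1, 1, 1, 1]
preds = [0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]
target = [False, False, True, False, True, False, True]
print(retrieval_auroc(preds, target, indexes))  # 0.75 (query 0: 1.0, query 1: 0.5)
```

Query 0 ranks its single relevant document first (AUROC 1.0), while query 1 orders only half of its relevant/non-relevant pairs correctly (AUROC 0.5), which gives the 0.75 reported by RetrievalAUROC.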

The image subdomain receives two new metrics in v1.3, which brings the total number of image-specific metrics in TorchMetrics to 21! As with other image metrics, both work by comparing a predicted image tensor to a ground truth image, but they focus on different properties in their metric calculation.

  • The first metric is SpatialCorrelationCoefficient. As the name indicates, this metric focuses on how well the spatial structure of the predicted image correlates with that of the ground truth image.

    import torch
    torch.manual_seed(42)
    from torchmetrics.image import SpatialCorrelationCoefficient as SCC
    preds = torch.randn([32, 3, 64, 64])
    target = torch.randn([32, 3, 64, 64])
    scc = SCC()
    scc(preds, target)
    # tensor(0.0023)
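  • The second metric is SpatialDistortionIndex, which compares the spatial structure of the images and is especially useful for evaluating multispectral images. Its target is a dictionary holding the lower-resolution multispectral image ("ms") and the panchromatic image ("pan").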

    import torch
    from torchmetrics.image import SpatialDistortionIndex
    preds = torch.rand([16, 3, 32, 32])
    target = {
      'ms': torch.rand([16, 3, 16, 16]),
      'pan': torch.rand([16, 3, 32, 32]),
    }
    sdi = SpatialDistortionIndex()
    sdi(preds, target)
    # e.g. tensor(0.0090); the inputs are random, so the exact value varies
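The idea behind SpatialCorrelationCoefficient is to correlate the high-frequency content (edges and fine detail) of two images rather than their raw intensities. The pure-Python sketch below shows a simplified, global version of that idea on single-channel images: it applies a 3x3 Laplacian-style high-pass kernel and then takes the Pearson correlation of the filtered pixels. The kernel and the single global correlation are assumptions made for illustration; the TorchMetrics implementation works with local windows and its own configurable filter, so its values will differ.

```python
from math import sqrt

# Assumed 3x3 high-pass (Laplacian-style) kernel; its weights sum to zero.
HIGH_PASS = [[-1, -1, -1],
             [-1,  8, -1],
             [-1, -1, -1]]


def high_pass(img):
    """Valid 3x3 filtering (no padding) of a 2D list with HIGH_PASS."""
    h, w = len(img), len(img[0])
    return [
        [
            sum(HIGH_PASS[di][dj] * img[i + di][j + dj] for di in range(3) for dj in range(3))
            for j in range(w - 2)
        ]
        for i in range(h - 2)
    ]


def pearson(xs, ys):
    """Pearson correlation of two equally long sequences (must not be constant)."""
    mx, my = sum(xs) / len(xs), sum(ys) / len(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = sqrt(sum((x - mx) ** 2 for x in xs))
    sy = sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)


def simple_scc(preds, target):
    """Hypothetical simplified SCC: global correlation of high-pass-filtered images."""
    fp = [v for row in high_pass(preds) for v in row]
    ft = [v for row in high_pass(target) for v in row]
    return pearson(fp, ft)


# A small 6x6 test image with some non-linear structure
img = [[(3 * i + 7 * j) % 11 for j in range(6)] for i in range(6)]
brighter = [[v + 5 for v in row] for row in img]
inverted = [[-v for v in row] for row in img]
print(simple_scc(img, brighter))  # approximately 1.0
print(simple_scc(img, inverted))  # approximately -1.0
```

Because the kernel weights sum to zero, a constant brightness offset disappears after filtering, so only the spatial structure influences the score.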

A new wrapper metric called FeatureShare has also been added. It can be seen as a specialized version of MetricCollection for metrics that use a neural network as part of their calculation. For example, FrechetInceptionDistance, InceptionScore and KernelInceptionDistance all use an Inception network by default. When these metrics were combined inside a MetricCollection, the underlying network was still called once per metric, three times in total, which is redundant and wastes resources. In principle, it is enough to call the network once and propagate the resulting features to all metrics, which is exactly what the FeatureShare wrapper does.

import torch
from torchmetrics.wrappers import FeatureShare
from torchmetrics import MetricCollection
from torchmetrics.image import FrechetInceptionDistance, KernelInceptionDistance

def fs_wrapper():
    fs = FeatureShare([FrechetInceptionDistance(), KernelInceptionDistance(subset_size=10, subsets=2)])
    fs.update(torch.randint(255, (50, 3, 64, 64), dtype=torch.uint8), real=True)
    fs.update(torch.randint(255, (50, 3, 64, 64), dtype=torch.uint8), real=False)
    fs.compute()

def mc_wrapper():
    mc = MetricCollection([FrechetInceptionDistance(), KernelInceptionDistance(subset_size=10, subsets=2)])
    mc.update(torch.randint(255, (50, 3, 64, 64), dtype=torch.uint8), real=True)
    mc.update(torch.randint(255, (50, 3, 64, 64), dtype=torch.uint8), real=False)
    mc.compute()

# let's compare (using the IPython %timeit magic)
%timeit fs_wrapper()
# 8.38 s ± 564 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit mc_wrapper()
# 13.8 s ± 232 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In most cases, this will be significantly faster than the equivalent MetricCollection, as the timings in the code example show.
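The core idea behind FeatureShare can be illustrated without any deep-learning dependencies. In the pure-Python sketch below, a hypothetical SharedFeatureExtractor runs an expensive feature function once per batch and hands the result to every registered metric, whereas the naive loop calls the extractor once per metric. All class and method names here are invented for illustration and do not reflect the TorchMetrics implementation.

```python
class CountingExtractor:
    """Stand-in for an expensive network; counts how often it is called."""

    def __init__(self):
        self.calls = 0

    def __call__(self, batch):
        self.calls += 1
        return [sum(x) for x in batch]  # toy "features"


class FeatureSumMetric:
    """Toy metric that only consumes precomputed features."""

    def __init__(self):
        self.total = 0.0

    def update_from_features(self, features):
        self.total += sum(features)


class SharedFeatureExtractor:
    """Hypothetical wrapper: extract features once, then fan them out to all metrics."""

    def __init__(self, extractor, metrics):
        self.extractor = extractor
        self.metrics = metrics

    def update(self, batch):
        features = self.extractor(batch)  # a single forward pass per batch
        for metric in self.metrics:
            metric.update_from_features(features)


def naive_update(extractor, metrics, batch):
    """What an unshared collection does: one extractor call per metric."""
    for metric in metrics:
        metric.update_from_features(extractor(batch))


batch = [[1, 2], [3, 4]]
shared, naive = CountingExtractor(), CountingExtractor()
shared_metrics = [FeatureSumMetric() for _ in range(3)]
naive_metrics = [FeatureSumMetric() for _ in range(3)]
SharedFeatureExtractor(shared, shared_metrics).update(batch)
naive_update(naive, naive_metrics, batch)
print(shared.calls, naive.calls)  # 1 3
```

Both approaches leave the metrics in the same state; only the number of extractor calls differs. The saving assumes the metrics really can consume identical features, which this sketch simply takes for granted.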

Improved features

In v1.2, several new arguments were added to the MeanAveragePrecision metric from the detection package. In v1.3 this metric gets a further small improvement: with extended_summary=True it now also returns confidence scores. The confidence scores are the scores the model assigns to express how confident it is that a given predicted bounding box belongs to a certain class.

import torch
from torchmetrics.detection import MeanAveragePrecision

# enable extended summary (requires a COCO evaluation backend such as pycocotools)
map_metric = MeanAveragePrecision(extended_summary=True)
preds = [
    {
        "boxes": torch.tensor([[0.5, 0.5, 1, 1]]),
        "scores": torch.tensor([1.0]),
        "labels": torch.tensor([0]),
    }
]
target = [
    {"boxes": torch.tensor([[0, 0, 1, 1]]), "labels": torch.tensor([0])}
]
map_metric.update(preds, target)
result = map_metric.compute()

# the new confidence scores can be found under the "scores" key
confidence_scores = result["scores"]
# in this case confidence_scores has shape (10, 101, 1, 4, 3)
# because
#   * by default we evaluate at 10 different IoU thresholds
#   * the PR curve is evaluated at 101 linearly spaced recall locations
#   * there is only 1 class (see the labels tensor)
#   * there are 4 area ranges we evaluate on (small, medium, large and all)
#   * by default `max_detection_thresholds=[1, 10, 100]`, i.e. 3 values
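To make the five dimensions listed in the comments above concrete, the sketch below rebuilds the shape from those default settings and adds a small hypothetical helper that maps an IoU threshold, area range and detection limit to indices in the scores tensor. The ordering of the area ranges (all, small, medium, large) follows the COCO evaluation convention and is an assumption here; check the MeanAveragePrecision documentation before relying on it.

```python
# Defaults described in the comments above (assumed COCO-style values and ordering).
IOU_THRESHOLDS = [round(0.5 + 0.05 * i, 2) for i in range(10)]  # 0.50, 0.55, ..., 0.95
RECALL_THRESHOLDS = [round(0.01 * i, 2) for i in range(101)]    # 0.00, 0.01, ..., 1.00
NUM_CLASSES = 1                                                  # one label in the example
AREA_RANGES = ["all", "small", "medium", "large"]                # assumed COCO ordering
MAX_DETECTIONS = [1, 10, 100]                                    # max_detection_thresholds


def scores_shape():
    """(T, R, K, A, M): IoU thresholds x recall points x classes x areas x max detections."""
    return (len(IOU_THRESHOLDS), len(RECALL_THRESHOLDS), NUM_CLASSES,
            len(AREA_RANGES), len(MAX_DETECTIONS))


def scores_index(iou, area="all", max_det=100, class_idx=0):
    """Hypothetical helper returning (t, k, a, m) for scores[t, :, k, a, m].

    The recall axis is left free so the whole curve can be selected.
    """
    t = min(range(len(IOU_THRESHOLDS)), key=lambda i: abs(IOU_THRESHOLDS[i] - iou))
    return t, class_idx, AREA_RANGES.index(area), MAX_DETECTIONS.index(max_det)


print(scores_shape())      # (10, 101, 1, 4, 3)
print(scores_index(0.75))  # (5, 0, 0, 2)
```

Under these assumptions, result["scores"][5, :, 0, 0, 2] would select the confidence scores at the 101 recall points for IoU 0.75, class 0, all box sizes and up to 100 detections per image.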

From v1.3, all retrieval metrics support an argument called aggregation that determines how the per-query metric values are aggregated across queries (the groups defined by indexes). The supported options are "mean", "median", "max" and "min", with the default "mean" being fully backward compatible with earlier versions of TorchMetrics.

from torch import tensor
from torchmetrics.retrieval import RetrievalHitRate
indexes = tensor([0, 0, 0, 1, 1, 1, 1])
preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
target = tensor([True, False, False, False, True, False, True])
hr2 = RetrievalHitRate(aggregation="max")
hr2(preds, target, indexes=indexes)
# tensor(1.)
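The aggregation step itself is simple: each query (each distinct value in indexes) gets its own score, and those per-query scores are then reduced with the chosen statistic. The pure-Python sketch below reproduces this for a hit rate restricted to the k highest-scored documents per query; k plays the role of RetrievalHitRate's top_k argument, while the function names are hypothetical. For "median" it uses statistics.median_low, which, like torch.median, returns the lower of the two middle values for an even number of queries; that this matches TorchMetrics exactly is an assumption.

```python
from collections import defaultdict
from statistics import mean, median_low

AGGREGATIONS = {"mean": mean, "median": median_low, "max": max, "min": min}


def hit_rate_per_query(preds, target, indexes, k=None):
    """1.0 if any relevant document is among the k highest-scored ones of a query."""
    groups = defaultdict(list)
    for p, t, i in zip(preds, target, indexes):
        groups[i].append((p, t))
    rates = []
    for query in sorted(groups):
        ranked = sorted(groups[query], key=lambda pt: pt[0], reverse=True)
        top = ranked if k is None else ranked[:k]
        rates.append(1.0 if any(t for _, t in top) else 0.0)
    return rates


def aggregated_hit_rate(preds, target, indexes, aggregation="mean", k=None):
    """Reduce the per-query hit rates with the chosen aggregation."""
    return float(AGGREGATIONS[aggregation](hit_rate_per_query(preds, target, indexes, k)))


# Same data as the example above
indexes = [0, 0, 0, 1, 1, 1, 1]
preds = [0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]
target = [True, False, False, False, True, False, True]
print(hit_rate_per_query(preds, target, indexes, k=2))  # [0.0, 1.0]
for agg in AGGREGATIONS:
    print(agg, aggregated_hit_rate(preds, target, indexes, agg, k=2))
# mean 0.5, median 0.0, max 1.0, min 0.0
```

Without a cut-off (k=None) both queries contain a relevant document, so every aggregation returns 1.0, consistent with the "max" example above; limiting each query to its top two documents makes the choice of aggregation visible.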

Finally, the SacreBLEU metric from the text domain now supports even more tokenizers: "ja-mecab", "ko-mecab", "flores101" and "flores200".

Changes and bugfixes

Users should be aware that from v1.3, TorchMetrics only supports PyTorch v1.10 and later (previously v1.8 and later). We always try to support PyTorch releases for up to two years.

This release also includes several bug fixes related to numerical stability in various metrics. For this reason, we always recommend using the most recent version of TorchMetrics for the best experience.

Thank you!

As always, we offer a big thank you to all of our community members for their contributions and feedback. Please open an issue in the repo if you have any recommendations for the next metrics we should tackle.

If you have a question or want to help expand TorchMetrics, please join our Discord server, where you can get guidance in the #torchmetrics channel.

🔥 Check out the documentation and code! 🚀

[1.3.0] - 2024-01-10

Added

  • Added more tokenizers for SacreBLEU metric (#2068)
  • Added support for logging MultiTaskWrapper directly with Lightning's log_dict method (#2213)
  • Added FeatureShare wrapper to share submodules containing feature extractors between metrics (#2120)
  • Added new metrics to the image domain:
    • SpatialDistortionIndex (#2260)
    • CriticalSuccessIndex (#2257)
    • SpatialCorrelationCoefficient (#2248)
  • Added average argument to multiclass versions of PrecisionRecallCurve and ROC (#2084)
  • Added confidence scores when extended_summary=True in MeanAveragePrecision (#2212)
  • Added RetrievalAUROC metric (#2251)
  • Added aggregation argument to retrieval metrics (#2220)
  • Added utility functions in segmentation.utils for future segmentation metrics (#2105)

Changed

  • Changed minimum supported Pytorch version from 1.8 to 1.10 (#2145)
  • Changed x-/y-axis order for PrecisionRecallCurve to be consistent with scikit-learn (#2183)

Deprecated

  • Deprecated metric._update_called (#2141)
  • Deprecated specicity_at_sensitivity in favour of specificity_at_sensitivity (#2199)

Fixed

  • Fixed support for half precision + CPU in metrics requiring topk operator (#2252)
  • Fixed warning incorrectly being raised in Running metrics (#2256)
  • Fixed integration with custom feature extractor in FID metric (#2277)

Full Changelog: v1.2.0...v1.3.0

Key Contributors

@Borda, @HoseinAkbarzadeh, @matsumotosan, @miskfi, @oguz-hanoglu, @SkafteNicki, @stancld, @ywchan2005

New Contributors

If we forgot someone because their commit email does not match their GitHub account, let us know :]