🚀 Feature
First of all, thank you very much for this awesome project. It helps a lot in evaluating deep learning models with standard, out-of-the-box metrics across several application domains.
My request(s) focus on classification metrics (also applied to semantic segmentation). We may need to split this into several issues; do not hesitate to do so.
Motivation
I ran some tests to check whether all metrics taking the same inputs work as expected. I have multiple scenarios in mind, since semantic segmentation is a pixel-wise classification task:
- Binary classification
- Multi-class classification
- Binary semantic segmentation
- Multi-class semantic segmentation
While running those tests (shown below), I noticed some inconsistencies between the different metrics (shape of inputs, handling of the binary scenario, parameter names, parameter default values, etc.). The idea would be to further standardize the interface of each classification metric.
Pitch
Let us go through the different scenarios:
Binary classification
```python
import torch
from torchmetrics import *

# Inputs
num_classes = 2
logits = torch.randn((10, num_classes))
targets = torch.randint(0, num_classes, (10, ))
probabilities = torch.softmax(logits, dim=1)

# Compute classification metrics
metrics = MetricCollection({
    'acc': Accuracy(num_classes=num_classes),
    'average_precision': AveragePrecision(num_classes=num_classes),
    'auroc': AUROC(num_classes=num_classes),
    'binned_average_precision': BinnedAveragePrecision(num_classes=num_classes, thresholds=5),
    'binned_precision_recall_curve': BinnedPrecisionRecallCurve(num_classes=num_classes, thresholds=5),
    'binned_recall_at_fixed_precision': BinnedRecallAtFixedPrecision(num_classes=num_classes, min_precision=0.5, thresholds=5),
    'calibration_error': CalibrationError(),
    'cohen_kappa': CohenKappa(num_classes=num_classes),
    'confusion_matrix': ConfusionMatrix(num_classes=num_classes),
    'f1': F1(num_classes=num_classes),
    'f2': FBeta(num_classes=num_classes, beta=2),
    'hamming_distance': HammingDistance(),
    'hinge': Hinge(),
    'iou': IoU(num_classes=num_classes),
    # 'kl_divergence': KLDivergence(),
    'matthews_correlation_coef': MatthewsCorrcoef(num_classes=num_classes),
    'precision': Precision(num_classes=num_classes),
    'precision_recall_curve': PrecisionRecallCurve(num_classes=num_classes),
    'recall': Recall(num_classes=num_classes),
    'roc': ROC(num_classes=num_classes),
    'specificity': Specificity(num_classes=num_classes),
    'stat_scores': StatScores(num_classes=num_classes)
})

metrics.update(probabilities, targets)
metrics.compute()
```

Binary classification behaves as expected, except for the `KLDivergence` metric, which complains with the following exception:

```
RuntimeError: Predictions and targets are expected to have the same shape
```
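If I understand the `KLDivergence` API correctly, it compares two distributions of identical shape, so the integer class targets would have to be turned into per-class distributions first. A minimal sketch of what I mean (the smoothing is only there to avoid zero probabilities, and I am not sure this is the intended usage):

```python
import torch
import torch.nn.functional as F
from torchmetrics import KLDivergence

num_classes = 2
logits = torch.randn((10, num_classes))
targets = torch.randint(0, num_classes, (10,))
probabilities = torch.softmax(logits, dim=1)

# Turn the integer targets into (smoothed) one-hot distributions so that
# predictions and targets have the same (10, num_classes) shape.
target_dist = F.one_hot(targets, num_classes).float() + 1e-2
target_dist = target_dist / target_dist.sum(dim=1, keepdim=True)

kl = KLDivergence()
kl.update(probabilities, target_dist)
print(kl.compute())
```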
I have several questions regarding the binary scenario:
- What is considered the best practice? Having `num_classes` set to 2? Having it set to 1? Some metrics complain about the latter (`ConfusionMatrix`, `IoU`).
- Is setting `num_classes` to 2 considered `binary` or `multi-class` mode? I imagine it depends on the metric.
- Some metrics accept `num_classes=None` for binary mode, others don't. It is not really clear to me how to properly handle binary classification with all those metrics (see the sketch after this list for the two representations I have in mind).
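To make the two possible binary conventions explicit, here is a small sketch of the input shapes I have in mind (not tied to any particular metric):

```python
import torch

# (a) "2-class" representation: one column per class, rows sum to 1.
logits_2col = torch.randn(10, 2)
probs_2col = torch.softmax(logits_2col, dim=1)   # shape (10, 2)

# (b) "true binary" representation: a single probability of the positive class.
logits_1col = torch.randn(10)
probs_1col = torch.sigmoid(logits_1col)          # shape (10,)

targets = torch.randint(0, 2, (10,))             # shape (10,)

# Some metrics accept (b) (typically with num_classes=None), others only
# work with (a); this is exactly the inconsistency described above.
```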
Multi-class classification
This mode seems to work as expected and to be consistent, since num_classes is clearly defined here. Hence the same question: is computing the metrics with 2 classes equivalent to binary mode? If not, how should that mode be handled properly?
Code that works properly (except for `KLDivergence`, same issue as mentioned above):
```python
import torch
from torchmetrics import *

# Inputs
num_classes = 5
logits = torch.randn((10, num_classes))
targets = torch.randint(0, num_classes, (10, ))
probabilities = torch.softmax(logits, dim=1)

# Compute classification metrics
metrics = MetricCollection({
    'acc': Accuracy(num_classes=num_classes),
    'average_precision': AveragePrecision(num_classes=num_classes),
    'auroc': AUROC(num_classes=num_classes),
    'binned_average_precision': BinnedAveragePrecision(num_classes=num_classes, thresholds=5),
    'binned_precision_recall_curve': BinnedPrecisionRecallCurve(num_classes=num_classes, thresholds=5),
    'binned_recall_at_fixed_precision': BinnedRecallAtFixedPrecision(num_classes=num_classes, min_precision=0.5, thresholds=5),
    'calibration_error': CalibrationError(),
    'cohen_kappa': CohenKappa(num_classes=num_classes),
    'confusion_matrix': ConfusionMatrix(num_classes=num_classes),
    'f1': F1(num_classes=num_classes),
    'f2': FBeta(num_classes=num_classes, beta=2),
    'hamming_distance': HammingDistance(),
    'hinge': Hinge(),
    'iou': IoU(num_classes=num_classes),
    # 'kl_divergence': KLDivergence(),
    'matthews_correlation_coef': MatthewsCorrcoef(num_classes=num_classes),
    'precision': Precision(num_classes=num_classes),
    'precision_recall_curve': PrecisionRecallCurve(num_classes=num_classes),
    'recall': Recall(num_classes=num_classes),
    'roc': ROC(num_classes=num_classes),
    'specificity': Specificity(num_classes=num_classes),
    'stat_scores': StatScores(num_classes=num_classes)
})

metrics.update(probabilities, targets)
metrics.compute()
```

Binary semantic segmentation
Since TorchMetrics supports extra dimensions for logits and targets, these classification metrics may also be used for semantic segmentation tasks. For that type of task, however, I ran into some issues.
Let us take an example of 10 images of 32x32 pixels.
```python
import torch
from torchmetrics import *

# Inputs
num_classes = 2
logits = torch.randn((10, num_classes, 32, 32))
targets = torch.randint(0, num_classes, (10, 32, 32))
probabilities = torch.softmax(logits, dim=1)

# Compute classification metrics
metrics = MetricCollection({
    'acc': Accuracy(num_classes=num_classes),
    'average_precision': AveragePrecision(num_classes=num_classes),
    'auroc': AUROC(num_classes=num_classes),
    # 'binned_average_precision': BinnedAveragePrecision(num_classes=num_classes, thresholds=5),
    # 'binned_precision_recall_curve': BinnedPrecisionRecallCurve(num_classes=num_classes, thresholds=5),
    # 'binned_recall_at_fixed_precision': BinnedRecallAtFixedPrecision(num_classes=num_classes, min_precision=0.5, thresholds=5),
    'calibration_error': CalibrationError(),
    'cohen_kappa': CohenKappa(num_classes=num_classes),
    'confusion_matrix': ConfusionMatrix(num_classes=num_classes),
    'f1': F1(num_classes=num_classes, mdmc_average='global'),
    'f2': FBeta(num_classes=num_classes, beta=2, mdmc_average='global'),
    'hamming_distance': HammingDistance(),
    # 'hinge': Hinge(),
    'iou': IoU(num_classes=num_classes),
    # 'kl_divergence': KLDivergence(),
    'matthews_correlation_coef': MatthewsCorrcoef(num_classes=num_classes),
    'precision': Precision(num_classes=num_classes, mdmc_average='global'),
    'precision_recall_curve': PrecisionRecallCurve(num_classes=num_classes),
    'recall': Recall(num_classes=num_classes, mdmc_average='global'),
    'roc': ROC(num_classes=num_classes),
    'specificity': Specificity(num_classes=num_classes, mdmc_average='global'),
    'stat_scores': StatScores(num_classes=num_classes, mdmc_reduce='global')
})

metrics.update(probabilities, targets)
metrics.compute()
```

- First of all, `BinnedAveragePrecision`, `BinnedPrecisionRecallCurve` and `BinnedRecallAtFixedPrecision` fail with the following exception, while the documentation (https://torchmetrics.readthedocs.io/en/latest/references/modules.html#binnedrecallatfixedprecision) states that `forward` should accept this input format (logits of shape `(N, C, ...)` and targets of shape `(N, ...)`); a possible workaround is sketched right after this list:

  ```
  RuntimeError: The size of tensor a (2) must match the size of tensor b (32) at non-singleton dimension 2
  ```
- `Accuracy` has a default value of `'global'` for `mdmc_average`, while other metrics (`Precision`, `FBeta`, `Specificity`, etc.) have it set to `None` -> this needs to be consistent: either everything defaults to `None` or to `'global'`. I would argue that defaulting to `'global'` would be ideal, since the user could then use those metrics for classification or semantic segmentation tasks seamlessly.
- I noticed that `StatScores` uses an `mdmc_reduce` parameter, while the other metrics call it `mdmc_average`. It would be suitable to be consistent about the name of this parameter. The easiest change would be to definitively adopt `mdmc_average` (since only `StatScores` would need to be updated).
- Once again, how should we properly deal with `binary` data? For this example I use a 2-class workaround, but is there a better way of handling `binary` semantic segmentation tasks?
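Regarding the Binned* failure above, a possible workaround is to flatten the spatial dimensions so that the binned metrics see a plain classification batch. This is only a sketch (the names `flat_probs`, `flat_targets` and `binned_ap` are illustrative), and I have not verified that treating pixels as independent samples gives exactly the intended result:

```python
import torch
from torchmetrics import BinnedAveragePrecision

num_classes = 2
logits = torch.randn((10, num_classes, 32, 32))
targets = torch.randint(0, num_classes, (10, 32, 32))
probabilities = torch.softmax(logits, dim=1)

# Flatten the spatial dimensions: (N, C, H, W) -> (N*H*W, C) and (N, H, W) -> (N*H*W,)
flat_probs = probabilities.permute(0, 2, 3, 1).reshape(-1, num_classes)
flat_targets = targets.reshape(-1)

binned_ap = BinnedAveragePrecision(num_classes=num_classes, thresholds=5)
binned_ap.update(flat_probs, flat_targets)
print(binned_ap.compute())
```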
Multi-class semantic segmentation
The scenario is the same as the previous one, with num_classes set to a value higher than 2.
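For completeness, the only change compared to the previous snippet is the shape of the inputs (a sketch; the `MetricCollection` stays identical):

```python
import torch

# Same setup as the binary segmentation example, only with more classes.
num_classes = 5  # any value > 2
logits = torch.randn((10, num_classes, 32, 32))
targets = torch.randint(0, num_classes, (10, 32, 32))
probabilities = torch.softmax(logits, dim=1)
# ... update/compute with the same MetricCollection as in the previous example
```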
Conclusion
This issue gathers different topics that may need to be treated separately, but they all relate to API consistency.
We can start a discussion about it, but my main question, given that long text, is how to properly deal with binary classification and binary semantic segmentation:
- Is using 2 classes a suitable workaround?
- Or should we be compliant with the `BCEWithLogitsLoss` interface, for example, meaning that the `probabilities` and `targets` have the exact same shape (for instance `(10, 32, 32)` in my example)? See the sketch right after this list.
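To illustrate the second option, this is the kind of input format I have in mind for binary segmentation (a sketch only; this is not how the current metrics expect their inputs):

```python
import torch

# BCEWithLogitsLoss-style binary segmentation: a single channel per pixel,
# so logits, probabilities and targets share the exact same (10, 32, 32) shape.
logits = torch.randn((10, 32, 32))
targets = torch.randint(0, 2, (10, 32, 32)).float()
probabilities = torch.sigmoid(logits)

loss = torch.nn.BCEWithLogitsLoss()(logits, targets)  # works with same-shape inputs
```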
I would like to hear your thoughts about that. :)
And sorry for the long issue, but I wanted to be as clear as possible by providing meaningful examples.
Alternatives
Document in more detail how each metric should be used for binary tasks.
Additional context
The main idea behind all these requests is to ease the use of these metrics and standardize their interfaces (shape of inputs, the `num_classes` parameter, etc.). I understand that this is a tough topic, but it matters.