From 0353617446eefe6bfb8e334345ab350339e36d01 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 5 Oct 2020 11:01:56 -0400 Subject: [PATCH 01/27] removed metric Co-authored-by: Teddy Koker --- pytorch_lightning/metrics/__init__.py | 66 - pytorch_lightning/metrics/classification.py | 866 --------- pytorch_lightning/metrics/converters.py | 410 ----- .../metrics/functional/__init__.py | 34 - .../metrics/functional/classification.py | 1056 ----------- pytorch_lightning/metrics/functional/nlp.py | 103 -- .../metrics/functional/reduction.py | 65 - .../metrics/functional/regression.py | 325 ---- .../metrics/functional/self_supervised.py | 46 - pytorch_lightning/metrics/metric.py | 262 --- pytorch_lightning/metrics/nlp.py | 60 - pytorch_lightning/metrics/regression.py | 361 ---- pytorch_lightning/metrics/self_supervised.py | 85 - pytorch_lightning/metrics/sklearns.py | 1590 ----------------- tests/metrics/__init__.py | 0 tests/metrics/functional/__init__.py | 0 .../metrics/functional/test_classification.py | 485 ----- tests/metrics/functional/test_nlp.py | 66 - tests/metrics/functional/test_reduction.py | 30 - tests/metrics/functional/test_regression.py | 175 -- .../functional/test_self_supervised.py | 35 - tests/metrics/test_aggregation.py | 297 --- tests/metrics/test_classification.py | 237 --- tests/metrics/test_converters.py | 265 --- tests/metrics/test_metrics.py | 323 ---- tests/metrics/test_nlp.py | 29 - tests/metrics/test_regression.py | 69 - tests/metrics/test_sklearn.py | 178 -- 28 files changed, 7518 deletions(-) delete mode 100644 pytorch_lightning/metrics/__init__.py delete mode 100644 pytorch_lightning/metrics/classification.py delete mode 100644 pytorch_lightning/metrics/converters.py delete mode 100644 pytorch_lightning/metrics/functional/__init__.py delete mode 100644 pytorch_lightning/metrics/functional/classification.py delete mode 100644 pytorch_lightning/metrics/functional/nlp.py delete mode 100644 pytorch_lightning/metrics/functional/reduction.py delete mode 100644 pytorch_lightning/metrics/functional/regression.py delete mode 100644 pytorch_lightning/metrics/functional/self_supervised.py delete mode 100644 pytorch_lightning/metrics/metric.py delete mode 100644 pytorch_lightning/metrics/nlp.py delete mode 100644 pytorch_lightning/metrics/regression.py delete mode 100644 pytorch_lightning/metrics/self_supervised.py delete mode 100644 pytorch_lightning/metrics/sklearns.py delete mode 100644 tests/metrics/__init__.py delete mode 100644 tests/metrics/functional/__init__.py delete mode 100644 tests/metrics/functional/test_classification.py delete mode 100644 tests/metrics/functional/test_nlp.py delete mode 100644 tests/metrics/functional/test_reduction.py delete mode 100644 tests/metrics/functional/test_regression.py delete mode 100644 tests/metrics/functional/test_self_supervised.py delete mode 100644 tests/metrics/test_aggregation.py delete mode 100644 tests/metrics/test_classification.py delete mode 100644 tests/metrics/test_converters.py delete mode 100644 tests/metrics/test_metrics.py delete mode 100644 tests/metrics/test_nlp.py delete mode 100644 tests/metrics/test_regression.py delete mode 100644 tests/metrics/test_sklearn.py diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py deleted file mode 100644 index d8cad98998bc3..0000000000000 --- a/pytorch_lightning/metrics/__init__.py +++ /dev/null @@ -1,66 +0,0 @@ -from pytorch_lightning.metrics.classification import ( - Accuracy, - AveragePrecision, - ConfusionMatrix, - F1, - FBeta, - 
Recall, - ROC, - AUROC, - DiceCoefficient, - MulticlassPrecisionRecallCurve, - MulticlassROC, - Precision, - PrecisionRecallCurve, - IoU, -) -from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric -from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric -from pytorch_lightning.metrics.nlp import BLEUScore -from pytorch_lightning.metrics.self_supervised import EmbeddingSimilarity -from pytorch_lightning.metrics.regression import ( - MAE, - MSE, - PSNR, - RMSE, - RMSLE, - SSIM -) -from pytorch_lightning.metrics.sklearns import ( - AUC, - SklearnMetric, -) - -__classification_metrics = [ - "AUC", - "AUROC", - "Accuracy", - "AveragePrecision", - "ConfusionMatrix", - "DiceCoefficient", - "F1", - "FBeta", - "MulticlassPrecisionRecallCurve", - "MulticlassROC", - "Precision", - "PrecisionRecallCurve", - "ROC", - "Recall", - "IoU", -] -__regression_metrics = [ - "MAE", - "MSE", - "PSNR", - "RMSE", - "RMSLE", - "SSIM" -] -__sequence_metrics = ["BLEUScore"] -__selfsuper_metrics = ["EmbeddingSimilarity"] - -__all__ = __regression_metrics \ - + __classification_metrics \ - + __selfsuper_metrics \ - + __sequence_metrics \ - + ["SklearnMetric"] diff --git a/pytorch_lightning/metrics/classification.py b/pytorch_lightning/metrics/classification.py deleted file mode 100644 index d44d93f01a8ac..0000000000000 --- a/pytorch_lightning/metrics/classification.py +++ /dev/null @@ -1,866 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Optional, Sequence, Tuple - -import torch - -from pytorch_lightning.metrics.functional.classification import ( - accuracy, - auroc, - average_precision, - confusion_matrix, - _confmat_normalize, - dice_score, - f1_score, - fbeta_score, - iou, - multiclass_precision_recall_curve, - multiclass_roc, - precision_recall_curve, - roc, - precision_recall -) -from pytorch_lightning.metrics.functional.reduction import class_reduce -from pytorch_lightning.metrics.metric import TensorMetric - - -class Accuracy(TensorMetric): - """ - Computes the accuracy classification score - - Example: - - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 2, 2]) - >>> metric = Accuracy() - >>> metric(pred, target) - tensor(0.7500) - - """ - - def __init__( - self, - num_classes: Optional[int] = None, - class_reduction: str = 'micro', - reduce_group: Any = None, - ): - """ - Args: - num_classes: number of classes - class_reduction: method to reduce metric score over labels - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. 
- - ``'none'``: returns calculated metric per class - - reduce_group: the process group to reduce metric results from DDP - """ - super().__init__(name="accuracy", reduce_group=reduce_group) - self.num_classes = num_classes - assert class_reduction in ('micro', 'macro', 'weighted', 'none') - self.class_reduction = class_reduction - - def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Actual metric computation - - Args: - pred: predicted labels - target: ground truth labels - - Return: - A Tensor with the classification score. - """ - return accuracy(pred=pred, target=target, - num_classes=self.num_classes, - class_reduction='none', - return_state=True) - - @staticmethod - def compute(self, data: Any, output: Any): - tps, sups = output['tps'], output['sups'] - return class_reduce(tps, sups, sups, class_reduction=self.class_reduction) - - -class ConfusionMatrix(TensorMetric): - """ - Computes the confusion matrix C where each entry C_{i,j} is the number of observations - in group i that were predicted in group j. - - Example: - - >>> pred = torch.tensor([0, 1, 2, 2]) - >>> target = torch.tensor([0, 1, 2, 2]) - >>> metric = ConfusionMatrix() - >>> metric(pred, target) - tensor([[1., 0., 0.], - [0., 1., 0.], - [0., 0., 2.]]) - - """ - - def __init__( - self, - num_classes: Optional[int] = None, - normalize: bool = False, - reduce_group: Any = None, - ): - """ - Args: - num_classes: number of classes - normalize: whether to compute a normalized confusion matrix - reduce_group: the process group to reduce metric results from DDP - """ - super().__init__( - name="confusion_matrix", - reduce_group=reduce_group, - ) - self.normalize = normalize - self.num_classes = num_classes - - def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Actual metric computation - - Args: - pred: predicted labels - target: ground truth labels - - Return: - A Tensor with the confusion matrix. 
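The two-stage design of the removed Accuracy metric above (forward returns raw per-class counts, compute applies class_reduce after the DDP sync) comes down to how true positives and supports are reduced. A standalone sketch with made-up counts, not using the removed class_reduce helper:

    import torch

    # Hypothetical per-class state, in the shape Accuracy.forward returned it
    # ({'tps': ..., 'sups': ...}) for a 3-class problem.
    tps = torch.tensor([2., 0., 1.])   # correct predictions per class
    sups = torch.tensor([3., 1., 1.])  # ground-truth samples per class

    micro = tps.sum() / sups.sum()                          # pool counts globally
    macro = (tps / sups).mean()                             # unweighted mean over classes
    weighted = ((tps / sups) * (sups / sups.sum())).sum()   # support-weighted mean
    print(micro, macro, weighted)  # ~0.6000, ~0.5556, ~0.6000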
- """ - return confusion_matrix(pred=pred, target=target, - normalize=False, # we normalize after ddp sync - num_classes=self.num_classes) - - @staticmethod - def compute(self, data: Any, output: Any): - """ Confusion matrix normalization needs to happen after ddp sync """ - confmat = output - if self.normalize: - confmat = _confmat_normalize(confmat) - return confmat - - -class PrecisionRecallCurve(TensorMetric): - """ - Computes the precision recall curve - - Example: - - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 2, 2]) - >>> metric = PrecisionRecallCurve() - >>> prec, recall, thr = metric(pred, target) - >>> prec - tensor([0.3333, 0.0000, 0.0000, 1.0000]) - >>> recall - tensor([1., 0., 0., 0.]) - >>> thr - tensor([1., 2., 3.]) - - """ - - def __init__( - self, - pos_label: int = 1, - reduce_group: Any = None, - ): - """ - Args: - pos_label: positive label indicator - reduce_group: the process group to reduce metric results from DDP - """ - super().__init__( - name="precision_recall_curve", - reduce_group=reduce_group, - ) - - self.pos_label = pos_label - - def forward( - self, - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Actual metric computation - - Args: - pred: predicted labels - target: groundtruth labels - sample_weight: the weights per sample - - Return: - - precision values - - recall values - - threshold values - """ - return precision_recall_curve(pred=pred, target=target, - sample_weight=sample_weight, pos_label=self.pos_label) - - -class Precision(TensorMetric): - """ - Computes the precision score - - Example: - - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 2, 2]) - >>> metric = Precision(num_classes=4, class_reduction='macro') - >>> metric(pred, target) - tensor(0.7500) - - """ - - def __init__( - self, - num_classes: Optional[int] = None, - class_reduction: str = 'micro', - reduce_group: Any = None, - ): - """ - Args: - num_classes: number of classes - class_reduction: method to reduce metric score over labels - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - - ``'none'``: returns calculated metric per class - - reduce_group: the process group to reduce metric results from DDP - """ - super().__init__( - name="precision", - reduce_group=reduce_group, - ) - self.num_classes = num_classes - assert class_reduction in ('micro', 'macro', 'weighted', 'none') - self.class_reduction = class_reduction - - def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Actual metric computation - - Args: - pred: predicted labels - target: ground truth labels - - Return: - A Tensor with the classification score. 
- """ - return precision_recall(pred=pred, target=target, - num_classes=self.num_classes, - class_reduction='none', - return_state=True) - - @staticmethod - def compute(self, data: Any, output: Any): - tps, fps, sups = output['tps'], output['fps'], output['sups'] - return class_reduce(tps, tps + fps, sups, class_reduction=self.class_reduction) - - -class Recall(TensorMetric): - """ - Computes the recall score - - Example: - - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 2, 2]) - >>> metric = Recall() - >>> metric(pred, target) - tensor(0.7500) - - """ - - def __init__( - self, - num_classes: Optional[int] = None, - class_reduction: str = 'micro', - reduce_group: Any = None, - ): - """ - Args: - num_classes: number of classes - class_reduction: method to reduce metric score over labels - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - - ``'none'``: returns calculated metric per class - - reduce_group: the process group to reduce metric results from DDP - """ - super().__init__( - name="recall", - reduce_group=reduce_group, - ) - - self.num_classes = num_classes - assert class_reduction in ('micro', 'macro', 'weighted', 'none') - self.class_reduction = class_reduction - - def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Actual metric computation - - Args: - pred: predicted labels - target: ground truth labels - - Return: - A Tensor with the classification score. - """ - return precision_recall(pred=pred, target=target, - num_classes=self.num_classes, - class_reduction='none', - return_state=True) - - @staticmethod - def compute(self, data: Any, output: Any): - tps, fns, sups = output['tps'], output['fns'], output['sups'] - return class_reduce(tps, tps + fns, sups, class_reduction=self.class_reduction) - - -class AveragePrecision(TensorMetric): - """ - Computes the average precision score - - Example: - - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 2, 2]) - >>> metric = AveragePrecision() - >>> metric(pred, target) - tensor(0.3333) - - """ - - def __init__( - self, - pos_label: int = 1, - reduce_group: Any = None, - ): - """ - Args: - pos_label: positive label indicator - reduce_group: the process group to reduce metric results from DDP - """ - super().__init__( - name="AP", - reduce_group=reduce_group, - ) - - self.pos_label = pos_label - - def forward( - self, pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None - ) -> torch.Tensor: - """ - Actual metric computation - - Args: - pred: predicted labels - target: groundtruth labels - sample_weight: the weights per sample - - Return: - torch.Tensor: classification score - """ - return average_precision(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label) - - -class AUROC(TensorMetric): - """ - Computes the area under curve (AUC) of the receiver operator characteristic (ROC) - - Example: - - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 1, 0]) - >>> metric = AUROC() - >>> metric(pred, target) - tensor(0.5000) - - """ - - def __init__( - self, - pos_label: int = 1, - reduce_group: Any = None, - ): - """ - Args: - pos_label: positive label indicator - reduce_group: the process group to reduce metric results from DDP - """ - super().__init__( - name="auroc", - reduce_group=reduce_group, - ) - - 
self.pos_label = pos_label - - def forward( - self, pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None - ) -> torch.Tensor: - """ - Actual metric computation - - Args: - pred: predicted labels - target: groundtruth labels - sample_weight: the weights per sample - - Return: - torch.Tensor: classification score - """ - return auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label) - - -class FBeta(TensorMetric): - """ - Computes the FBeta Score, which is the weighted harmonic mean of precision and recall. - It ranges between 1 and 0, where 1 is perfect and the worst value is 0. - - Example: - - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 2, 2]) - >>> metric = FBeta(0.25, class_reduction='macro') - >>> metric(pred, target) - tensor(0.7361) - """ - - def __init__( - self, - beta: float, - num_classes: Optional[int] = None, - class_reduction: str = 'micro', - reduce_group: Any = None, - ): - """ - Args: - beta: determines the weight of recall in the combined score. - num_classes: number of classes - class_reduction: method to reduce metric score over labels - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - - ``'none'``: returns calculated metric per class - - reduce_group: the process group to reduce metric results from DDP - """ - super().__init__( - name="fbeta", - reduce_group=reduce_group, - ) - - self.beta = beta - self.num_classes = num_classes - assert class_reduction in ('micro', 'macro', 'weighted', 'none') - self.class_reduction = class_reduction - - def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Actual metric computation - - Args: - pred: predicted labels - target: groundtruth labels - - Return: - torch.Tensor: classification score - """ - return precision_recall(pred=pred, target=target, - num_classes=self.num_classes, - class_reduction='none', - return_state=True) - - @staticmethod - def compute(self, data: Any, output: Any): - """ tps, fps, fns, sups needs to be synced before we do any calculations """ - tps, fps, fns, sups = output['tps'], output['fps'], output['fns'], output['sups'] - - intermidiate_reduction = 'none' if self.class_reduction != "micro" else 'micro' - precision = class_reduce(tps, tps + fps, sups, class_reduction=intermidiate_reduction) - recall = class_reduce(tps, tps + fns, sups, class_reduction=intermidiate_reduction) - - num = (1 + self.beta ** 2) * precision * recall - denom = ((self.beta ** 2) * precision + recall) - if intermidiate_reduction == 'micro': - return torch.sum(num) / torch.sum(denom) - return class_reduce(num, denom, sups, class_reduction=self.class_reduction) - - -class F1(FBeta): - """ - Computes the F1 score, which is the harmonic mean of the precision and recall. - It ranges between 1 and 0, where 1 is perfect and the worst value is 0. 
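FBeta.compute above combines the synced precision/recall counts with the standard F-beta formula, where beta controls the precision/recall trade-off. A small self-contained sketch of that formula on scalar tensors (illustrative only, not the removed metric class):

    import torch

    def fbeta_from_pr(precision: torch.Tensor, recall: torch.Tensor, beta: float) -> torch.Tensor:
        # beta < 1 favours precision, beta > 1 favours recall, beta == 1 gives F1.
        num = (1 + beta ** 2) * precision * recall
        denom = (beta ** 2) * precision + recall
        return num / denom

    p, r = torch.tensor(0.75), torch.tensor(0.50)
    print(fbeta_from_pr(p, r, beta=0.5))  # ~0.68, pulled towards precision
    print(fbeta_from_pr(p, r, beta=1.0))  # 0.60, the harmonic mean (F1)
    print(fbeta_from_pr(p, r, beta=2.0))  # ~0.54, pulled towards recall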
- - Example: - - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 2, 2]) - >>> metric = F1(class_reduction='macro') - >>> metric(pred, target) - tensor(0.6667) - """ - - def __init__( - self, - num_classes: Optional[int] = None, - class_reduction: str = 'micro', - reduce_group: Any = None, - ): - """ - Args: - num_classes: number of classes - class_reduction: method to reduce metric score over labels - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - - ``'none'``: returns calculated metric per class - - reduce_group: the process group to reduce metric results from DDP - """ - super().__init__(beta=1.0, - num_classes=num_classes, - class_reduction=class_reduction, - reduce_group=reduce_group) - self.name = "f1" - - -class ROC(TensorMetric): - """ - Computes the Receiver Operator Characteristic (ROC) - - Example: - - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 2, 2]) - >>> metric = ROC() - >>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE - (tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]), - tensor([0., 0., 0., 1., 1.]), - tensor([4., 3., 2., 1., 0.])) - - """ - - def __init__( - self, - pos_label: int = 1, - reduce_group: Any = None, - ): - """ - Args: - pos_label: positive label indicator - reduce_group: the process group to reduce metric results from DDP - """ - super().__init__( - name="roc", - reduce_group=reduce_group, - ) - - self.pos_label = pos_label - - def forward( - self, pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Actual metric computation - - Args: - pred: predicted labels - target: groundtruth labels - sample_weight: the weights per sample - - Return: - - false positive rate - - true positive rate - - thresholds - """ - return roc(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label) - - -class MulticlassROC(TensorMetric): - """ - Computes the multiclass ROC - - Example: - - >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], - ... [0.05, 0.85, 0.05, 0.05], - ... [0.05, 0.05, 0.85, 0.05], - ... 
[0.05, 0.05, 0.05, 0.85]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> metric = MulticlassROC() - >>> classes_roc = metric(pred, target) - >>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE - ((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), - (tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), - (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])), - (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500]))) - """ - - def __init__( - self, - num_classes: Optional[int] = None, - reduce_group: Any = None, - ): - """ - Args: - num_classes: number of classes - reduce_group: the process group to reduce metric results from DDP - """ - super().__init__( - name="multiclass_roc", - reduce_group=reduce_group, - ) - - self.num_classes = num_classes - - def forward( - self, - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: - """ - Actual metric computation - - Args: - pred: predicted probability for each label - target: groundtruth labels - sample_weight: Weights for each sample defining the sample's impact on the score - - Return: - tuple: A tuple consisting of one tuple per class, holding false positive rate, true positive rate and thresholds - - """ - return multiclass_roc(pred=pred, target=target, sample_weight=sample_weight, num_classes=self.num_classes) - - def aggregate(self, *tensors: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: - """Aggregates results by stacking them instead of concatenating before averaging. - - Returns: - the aggregated results - """ - - return tuple([tuple([torch.stack(tmps).mean(0) for tmps in zip(*_tensors)]) for _tensors in zip(*tensors)]) - - -class MulticlassPrecisionRecallCurve(TensorMetric): - """Computes the multiclass PR Curve - - Example: - - >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], - ... [0.05, 0.85, 0.05, 0.05], - ... [0.05, 0.05, 0.85, 0.05], - ... 
[0.05, 0.05, 0.05, 0.85]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> metric = MulticlassPrecisionRecallCurve() - >>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE - ((tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])), - (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])), - (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])), - (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500]))) - """ - - def __init__( - self, - num_classes: Optional[int] = None, - reduce_group: Any = None, - ): - """ - Args: - num_classes: number of classes - reduce_group: the process group to reduce metric results from DDP - - """ - super().__init__( - name="multiclass_precision_recall_curve", - reduce_group=reduce_group, - ) - - self.num_classes = num_classes - - def forward( - self, - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Actual metric computation - - Args: - pred: predicted probability for each label - target: groundtruth labels - sample_weight: Weights for each sample defining the sample's impact on the score - - Return: - tuple: A tuple consisting of one tuple per class, holding precision, recall and thresholds - - """ - return multiclass_precision_recall_curve( - pred=pred, target=target, sample_weight=sample_weight, num_classes=self.num_classes - ) - - def aggregate(self, *tensors: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: - """Aggregates results by stacking them instead of concatenating before averaging. - - Returns: - the aggregated results - """ - - return tuple([tuple([torch.stack(tmps).mean(0) for tmps in zip(*_tensors)]) for _tensors in zip(*tensors)]) - - -class DiceCoefficient(TensorMetric): - """ - Computes the dice coefficient - - Example: - - >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], - ... [0.05, 0.85, 0.05, 0.05], - ... [0.05, 0.05, 0.85, 0.05], - ... [0.05, 0.05, 0.05, 0.85]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> metric = DiceCoefficient() - >>> metric(pred, target) - tensor(0.3333) - """ - - def __init__( - self, - include_background: bool = False, - nan_score: float = 0.0, - no_fg_score: float = 0.0, - reduction: str = "elementwise_mean", - reduce_group: Any = None, - ): - """ - Args: - include_background: whether to also compute dice for the background - nan_score: score to return, if a NaN occurs during computation (denom zero) - no_fg_score: score to return, if no foreground pixel was found in target - reduction: a method to reduce metric score over labels. 
- - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - reduce_group: the process group to reduce metric results from DDP - """ - super().__init__( - name="dice", - reduce_group=reduce_group, - ) - - self.include_background = include_background - self.nan_score = nan_score - self.no_fg_score = no_fg_score - self.reduction = reduction - - def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Actual metric computation - - Args: - pred: predicted probability for each label - target: groundtruth labels - - Return: - torch.Tensor: the calculated dice coefficient - """ - return dice_score( - pred=pred, - target=target, - bg=self.include_background, - nan_score=self.nan_score, - no_fg_score=self.no_fg_score, - reduction=self.reduction, - ) - - -class IoU(TensorMetric): - """ - Computes the intersection over union. - - Example: - - >>> pred = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0], - ... [0, 0, 1, 1, 1, 0, 0, 0], - ... [0, 0, 0, 0, 0, 0, 0, 0]]) - >>> target = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0], - ... [0, 0, 0, 1, 1, 1, 0, 0], - ... [0, 0, 0, 0, 0, 0, 0, 0]]) - >>> metric = IoU() - >>> metric(pred, target) - tensor(0.7045) - - """ - - def __init__( - self, - ignore_index: Optional[int] = None, - absent_score: float = 0.0, - num_classes: Optional[int] = None, - reduction: str = "elementwise_mean", - ): - """ - Args: - ignore_index: optional int specifying a target class to ignore. If given, this class index does not - contribute to the returned score, regardless of reduction method. Has no effect if given an int that is - not in the range [0, num_classes-1], where num_classes is either given or derived from pred and target. - By default, no index is ignored, and all classes are used. - absent_score: score to use for an individual class, if no instances of the class index were present in - `y_pred` AND no instances of the class index were present in `y_true`. For example, if we have 3 - classes, [0, 0] for `y_pred`, and [0, 2] for `y_true`, then class 1 would be assigned the - `absent_score`. Default is 0.0. - num_classes: Optionally specify the number of classes - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - """ - super().__init__(name="iou") - self.ignore_index = ignore_index - self.absent_score = absent_score - self.num_classes = num_classes - self.reduction = reduction - - def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor, sample_weight: Optional[torch.Tensor] = None): - """ - Actual metric calculation. - """ - return iou( - pred=y_pred, - target=y_true, - ignore_index=self.ignore_index, - absent_score=self.absent_score, - num_classes=self.num_classes, - reduction=self.reduction, - ) diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py deleted file mode 100644 index f62a4709810aa..0000000000000 --- a/pytorch_lightning/metrics/converters.py +++ /dev/null @@ -1,410 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
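The DiceCoefficient and IoU metrics above are both ratios of the same overlap counts. A foreground-only sketch on toy binary masks (the removed classes additionally handle multi-class input, reduction modes, and the ignore/absent-class options):

    import torch

    # Hypothetical binary segmentation masks (1 = foreground).
    pred   = torch.tensor([0, 0, 1, 1, 1, 0, 0, 0])
    target = torch.tensor([0, 0, 0, 1, 1, 1, 0, 0])

    tp = ((pred == 1) & (target == 1)).sum().float()
    fp = ((pred == 1) & (target == 0)).sum().float()
    fn = ((pred == 0) & (target == 1)).sum().float()

    iou  = tp / (tp + fp + fn)            # intersection over union
    dice = 2 * tp / (2 * tp + fp + fn)    # dice coefficient
    print(iou, dice)                      # tensor(0.5000) tensor(0.6667)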
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file provides functions and decorators for automated input and output -conversion to/from :class:`numpy.ndarray` and :class:`torch.Tensor` as well as utilities to -sync tensors between different processes in a DDP scenario, when needed. -""" - -from functools import reduce -import numbers -from typing import Any, Callable, Optional, Union - -import numpy as np -import torch -from torch.utils.data._utils.collate import np_str_obj_array_pattern - -from pytorch_lightning.utilities.apply_func import apply_to_collection - -if torch.distributed.is_available(): - from torch.distributed import ReduceOp -else: - class ReduceOp: - SUM = None - - -def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable: - """ - Decorator function to apply a function to all inputs of a function. - - Args: - func_to_apply: the function to apply to the inputs - *dec_args: positional arguments for the function to be applied - **dec_kwargs: keyword arguments for the function to be applied - - Return: - the decorated function - """ - - def decorator_fn(func_to_decorate): - # actual function applying the give function to inputs - def new_func(*args, **kwargs): - args = func_to_apply(args, *dec_args, **dec_kwargs) - kwargs = func_to_apply(kwargs, *dec_args, **dec_kwargs) - return func_to_decorate(*args, **kwargs) - - return new_func - - return decorator_fn - - -def _apply_to_outputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable: - """ - Decorator function to apply a function to all outputs of a function. - - Args: - func_to_apply: the function to apply to the outputs - *dec_args: positional arguments for the function to be applied - **dec_kwargs: keyword arguments for the function to be applied - - Return: - the decorated function - """ - - def decorator_fn(function_to_decorate): - # actual function applying the give function to outputs - def new_func(*args, **kwargs): - result = function_to_decorate(*args, **kwargs) - return func_to_apply(result, *dec_args, **dec_kwargs) - - return new_func - - return decorator_fn - - -def convert_to_tensor(data: Any, dtype=None, device=None) -> Any: - """ - Maps all kind of collections and numbers to tensors. - - Args: - data: the data to convert to tensor - dtype: data type to convert to - device: device to cast to - - Return: - the converted data - """ - if isinstance(data, numbers.Number): - return torch.tensor([data], dtype=dtype, device=device) - # is not array of object - elif isinstance(data, np.ndarray) and np_str_obj_array_pattern.search(data.dtype.str) is None: - return torch.from_numpy(data).to(device=device, dtype=dtype) - elif isinstance(data, torch.Tensor): - return data.to(device=device, dtype=dtype) - - raise TypeError(f"The given type ('{type(data).__name__}') cannot be converted to a tensor!") - - -def convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray: - """Convert all tensors and numpy arrays to numpy arrays. 
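The _apply_to_inputs / _apply_to_outputs decorators above are generic "convert everything going in / coming out" wrappers that get chained to build the numpy/tensor metric decorators. A condensed sketch of the same idea collapsed into one decorator (names here are illustrative, not the removed API):

    import numpy as np
    import torch

    def tensors_in_tensor_out(fn):
        """Wrap a numpy-based metric so it accepts torch tensors and returns a tensor."""
        def wrapped(*args, **kwargs):
            args = [a.cpu().numpy() if isinstance(a, torch.Tensor) else a for a in args]
            kwargs = {k: v.cpu().numpy() if isinstance(v, torch.Tensor) else v
                      for k, v in kwargs.items()}
            result = fn(*args, **kwargs)
            return torch.as_tensor(result)
        return wrapped

    @tensors_in_tensor_out
    def numpy_accuracy(pred: np.ndarray, target: np.ndarray) -> float:
        return float((pred == target).mean())

    print(numpy_accuracy(torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 2])))  # 0.75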
- - Args: - data: the tensor or array to convert to numpy - - Return: - the resulting numpy array - """ - if isinstance(data, torch.Tensor): - return data.cpu().detach().numpy() - elif isinstance(data, numbers.Number): - return np.array([data]) - elif isinstance(data, np.ndarray): - return data - - raise TypeError("The given type ('%s') cannot be converted to a numpy array!" % type(data).__name__) - - -def _numpy_metric_input_conversion(func_to_decorate: Callable) -> Callable: - """ - Decorator converting all inputs of a function to numpy - - Args: - func_to_decorate: the function whose inputs shall be converted - - Return: - Callable: the decorated function - """ - return _apply_to_inputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)( - func_to_decorate - ) - - -def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable: - """ - Decorator converting all outputs of a function to tensors - - Args: - func_to_decorate: the function whose outputs shall be converted - - Return: - Callable: the decorated function - """ - return _apply_to_outputs(convert_to_tensor)(func_to_decorate) - - -def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable: - """ - Decorator handling the argument conversion for metrics working on numpy. - All inputs of the decorated function will be converted to numpy and all - outputs will be converted to tensors. - - Args: - func_to_decorate: the function whose inputs and outputs shall be converted - - Return: - the decorated function - """ - # applies collection conversion from tensor to numpy to all inputs - # we need to include numpy arrays here, since otherwise they will also be treated as sequences - func_convert_inputs = _numpy_metric_input_conversion(func_to_decorate) - # converts all inputs back to tensors (device doesn't matter here, since this is handled by BaseMetric) - func_convert_in_out = _tensor_metric_output_conversion(func_convert_inputs) - return func_convert_in_out - - -def _tensor_metric_input_conversion(func_to_decorate: Callable) -> Callable: - """ - Decorator converting all inputs of a function to tensors - - Args: - func_to_decorate: the function whose inputs shall be converted - - Return: - Callable: the decorated function - """ - return _apply_to_inputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)( - func_to_decorate - ) - - -def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> Callable: - """ - Decorator converting all numpy arrays and numbers occuring in the outputs of a function to tensors - - Args: - func_to_decorate: the function whose outputs shall be converted - - Return: - Callable: the decorated function - """ - return _apply_to_outputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)( - func_to_decorate - ) - - -def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable: - """ - Decorator Handling the argument conversion for metrics working on tensors. 
- All inputs and outputs of the decorated function will be converted to tensors - - Args: - func_to_decorate: the function whose inputs and outputs shall be converted - - Return: - the decorated function - """ - # converts all inputs to tensor if possible - # we need to include tensors here, since otherwise they will also be treated as sequences - func_convert_inputs = _tensor_metric_input_conversion(func_to_decorate) - # convert all outputs to tensor if possible - return _tensor_metric_output_conversion(func_convert_inputs) - - -def _tensor_collection_metric_conversion(func_to_decorate: Callable) -> Callable: - """ - Decorator Handling the argument conversion for metrics working on tensors. - All inputs of the decorated function and all numpy arrays and numbers in - it's outputs will be converted to tensors - - Args: - func_to_decorate: the function whose inputs and outputs shall be converted - - Return: - the decorated function - """ - # converts all inputs to tensor if possible - # we need to include tensors here, since otherwise they will also be treated as sequences - func_convert_inputs = _tensor_metric_input_conversion(func_to_decorate) - # convert all outputs to tensor if possible - return _tensor_collection_metric_output_conversion(func_convert_inputs) - - -def sync_ddp_if_available( - result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None -) -> torch.Tensor: - """ - Function to reduce the tensors from several ddp processes to one master process - - Args: - result: the value to sync and reduce (typically tensor or number) - group: the process group to gather results from. Defaults to all processes (world) - reduce_op: the reduction operation. Defaults to sum. - Can also be a string of 'avg', 'mean' to calculate the mean during reduction. - - Return: - reduced value - """ - - if torch.distributed.is_available() and torch.distributed.is_initialized(): - divide_by_world_size = False - - if group is None: - group = torch.distributed.group.WORLD - - if reduce_op is None: - reduce_op = torch.distributed.ReduceOp.SUM - elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"): - reduce_op = torch.distributed.ReduceOp.SUM - divide_by_world_size = True - - # sync all processes before reduction - torch.distributed.barrier(group=group) - torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False) - - if divide_by_world_size: - result = result / torch.distributed.get_world_size(group) - - return result - - -def at_least_1d(tensor: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: - """Makes sure the tensor is at least of 1d shape - - Args: - tensor: the tensor or array to check the shape for - - Returns: - the optionally reshaped tensor - """ - if tensor.shape == (): - tensor = tensor.reshape(1, ) - return tensor - - -def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional[Any] = None): - """ - Function to gather all tensors from several ddp processes onto a list that - is broadcasted to all processes - - Args: - result: the value to sync - group: the process group to gather results from. 
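sync_ddp_if_available above implements the 'avg'/'mean' reduction by all-reducing with SUM and dividing by the world size. A stripped-down sketch of that pattern (the removed helper also accepts string reduce ops and an explicit process group):

    import torch
    import torch.distributed as dist

    def sync_mean(value: torch.Tensor) -> torch.Tensor:
        """Average a per-process tensor across all DDP processes.

        Assumes an initialized process group; otherwise the local value is
        returned unchanged.
        """
        if not (dist.is_available() and dist.is_initialized()):
            return value
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        return value / dist.get_world_size()

    # e.g. each rank computes its local batch accuracy, then:
    # global_acc = sync_mean(local_acc)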
Defaults to all processes (world) - - Return: - gathered_result: list with size equal to the process group where - gathered_result[i] corresponds to result tensor from process i - - """ - if torch.distributed.is_available() and torch.distributed.is_initialized(): - if group is None: - group = torch.distributed.group.WORLD - - world_size = torch.distributed.get_world_size(group) - - gathered_result = [torch.zeros_like(result) for _ in range(world_size)] - - # sync and broadcast all - torch.distributed.barrier(group=group) - torch.distributed.all_gather(gathered_result, result, group) - - result = gathered_result - - return result - - -def sync_ddp(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable: - """ - This decorator syncs a functions outputs across different processes for DDP. - - Args: - group: the process group to gather results from. Defaults to all processes (world) - reduce_op: the reduction operation. Defaults to sum - - Return: - the decorated function - - """ - - def decorator_fn(func_to_decorate): - return _apply_to_outputs( - apply_to_collection, torch.Tensor, sync_ddp_if_available, group=group, reduce_op=reduce_op - )(func_to_decorate) - - return decorator_fn - - -def numpy_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable: - """ - This decorator shall be used on all function metrics working on numpy arrays. - It handles the argument conversion and DDP reduction for metrics working on numpy. - All inputs of the decorated function will be converted to numpy and all - outputs will be converted to tensors. - In DDP Training all output tensors will be reduced according to the given rules. - - Args: - group: the process group to gather results from. Defaults to all processes (world) - reduce_op: the reduction operation. Defaults to sum - - Return: - the decorated function - """ - - def decorator_fn(func_to_decorate): - return sync_ddp(group=group, reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate)) - - return decorator_fn - - -def tensor_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable: - """ - This decorator shall be used on all function metrics working on tensors. - It handles the argument conversion and DDP reduction for metrics working on tensors. - All inputs and outputs of the decorated function will be converted to tensors. - In DDP Training all output tensors will be reduced according to the given rules. - - Args: - group: the process group to gather results from. Defaults to all processes (world) - reduce_op: the reduction operation. Defaults to sum - - Return: - the decorated function - """ - - def decorator_fn(func_to_decorate): - return sync_ddp(group=group, reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate)) - - return decorator_fn - - -def tensor_collection_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable: - """ - This decorator shall be used on all function metrics working on tensors and returning collections - that cannot be converted to tensors. - It handles the argument conversion and DDP reduction for metrics working on tensors. - All inputs and outputs of the decorated function will be converted to tensors. - In DDP Training all output tensors will be reduced according to the given rules. - - Args: - group: the process group to gather results from. Defaults to all processes (world) - reduce_op: the reduction operation. 
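Where sync_ddp_if_available reduces to a single tensor, gather_all_tensors_if_available above collects one tensor per rank. A minimal sketch of that all-gather pattern, assuming identically shaped tensors on every rank:

    import torch
    import torch.distributed as dist

    def gather_all(result: torch.Tensor) -> list:
        """Collect one tensor from every DDP process into a list on every process."""
        if not (dist.is_available() and dist.is_initialized()):
            return [result]
        gathered = [torch.zeros_like(result) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, result)
        return gathered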
Defaults to sum - - Return: - the decorated function - """ - - def decorator_fn(func_to_decorate): - return sync_ddp(group=group, reduce_op=reduce_op)(_tensor_collection_metric_conversion(func_to_decorate)) - - return decorator_fn diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py deleted file mode 100644 index 02928c803f19d..0000000000000 --- a/pytorch_lightning/metrics/functional/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -from pytorch_lightning.metrics.functional.classification import ( - accuracy, - auc, - auroc, - average_precision, - confusion_matrix, - dice_score, - f1_score, - fbeta_score, - multiclass_precision_recall_curve, - multiclass_roc, - precision, - precision_recall, - precision_recall_curve, - recall, - roc, - stat_scores, - stat_scores_multiple_classes, - to_categorical, - to_onehot, - iou, -) -from pytorch_lightning.metrics.functional.nlp import bleu_score -from pytorch_lightning.metrics.functional.regression import ( - mae, - mse, - psnr, - rmse, - rmsle, - ssim -) -from pytorch_lightning.metrics.functional.self_supervised import ( - embedding_similarity -) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py deleted file mode 100644 index 20510c4c088f8..0000000000000 --- a/pytorch_lightning/metrics/functional/classification.py +++ /dev/null @@ -1,1056 +0,0 @@ -from functools import wraps -from typing import Callable, Optional, Sequence, Tuple - -import torch -from torch.nn import functional as F - -from pytorch_lightning.metrics.functional.reduction import class_reduce, reduce -from pytorch_lightning.utilities import FLOAT16_EPSILON, rank_zero_warn - - -def to_onehot( - tensor: torch.Tensor, - num_classes: Optional[int] = None, -) -> torch.Tensor: - """ - Converts a dense label tensor to one-hot format - - Args: - tensor: dense label tensor, with shape [N, d1, d2, ...] - num_classes: number of classes C - - Output: - A sparse label tensor with shape [N, C, d1, d2, ...] - - Example: - - >>> x = torch.tensor([1, 2, 3]) - >>> to_onehot(x) - tensor([[0, 1, 0, 0], - [0, 0, 1, 0], - [0, 0, 0, 1]]) - - """ - if num_classes is None: - num_classes = int(tensor.max().detach().item() + 1) - dtype, device, shape = tensor.dtype, tensor.device, tensor.shape - tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], - dtype=dtype, device=device) - index = tensor.long().unsqueeze(1).expand_as(tensor_onehot) - return tensor_onehot.scatter_(1, index, 1.0) - - -def to_categorical( - tensor: torch.Tensor, - argmax_dim: int = 1 -) -> torch.Tensor: - """ - Converts a tensor of probabilities to a dense label tensor - - Args: - tensor: probabilities to get the categorical label [N, d1, d2, ...] - argmax_dim: dimension to apply - - Return: - A tensor with categorical labels [N, d2, ...] - - Example: - - >>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]]) - >>> to_categorical(x) - tensor([1, 0]) - - """ - return torch.argmax(tensor, dim=argmax_dim) - - -def get_num_classes( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, -) -> int: - """ - Calculates the number of classes for a given prediction and target tensor. - - Args: - pred: predicted values - target: true labels - num_classes: number of classes if known - - Return: - An integer that represents the number of classes. 
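to_onehot and to_categorical above are inverse conversions between dense labels and an explicit class dimension. A compact sketch of the round trip, without the extra spatial dimensions the removed helpers support:

    import torch

    labels = torch.tensor([1, 2, 3, 0])

    # Dense labels -> one-hot, via scatter_ along the class dimension
    # (the same mechanism as the removed to_onehot).
    num_classes = int(labels.max()) + 1
    onehot = torch.zeros(labels.size(0), num_classes, dtype=labels.dtype)
    onehot.scatter_(1, labels.unsqueeze(1), 1)

    # One-hot / probabilities -> dense labels again (to_categorical is just argmax).
    recovered = onehot.argmax(dim=1)
    assert torch.equal(recovered, labels)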
- """ - num_target_classes = int(target.max().detach().item() + 1) - num_pred_classes = int(pred.max().detach().item() + 1) - num_all_classes = max(num_target_classes, num_pred_classes) - - if num_classes is None: - num_classes = num_all_classes - elif num_classes != num_all_classes: - rank_zero_warn(f'You have set {num_classes} number of classes which is' - f' different from predicted ({num_pred_classes}) and' - f' target ({num_target_classes}) number of classes', - RuntimeWarning) - return num_classes - - -def stat_scores( - pred: torch.Tensor, - target: torch.Tensor, - class_index: int, argmax_dim: int = 1, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Calculates the number of true positive, false positive, true negative - and false negative for a specific class - - Args: - pred: prediction tensor - target: target tensor - class_index: class to calculate over - argmax_dim: if pred is a tensor of probabilities, this indicates the - axis the argmax transformation will be applied over - - Return: - True Positive, False Positive, True Negative, False Negative, Support - - Example: - - >>> x = torch.tensor([1, 2, 3]) - >>> y = torch.tensor([0, 2, 3]) - >>> tp, fp, tn, fn, sup = stat_scores(x, y, class_index=1) - >>> tp, fp, tn, fn, sup - (tensor(0), tensor(1), tensor(2), tensor(0), tensor(0)) - - """ - if pred.ndim == target.ndim + 1: - pred = to_categorical(pred, argmax_dim=argmax_dim) - - tp = ((pred == class_index) * (target == class_index)).to(torch.long).sum() - fp = ((pred == class_index) * (target != class_index)).to(torch.long).sum() - tn = ((pred != class_index) * (target != class_index)).to(torch.long).sum() - fn = ((pred != class_index) * (target == class_index)).to(torch.long).sum() - sup = (target == class_index).to(torch.long).sum() - - return tp, fp, tn, fn, sup - - -def stat_scores_multiple_classes( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - argmax_dim: int = 1, - reduction: str = 'none', -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Calculates the number of true positive, false positive, true negative - and false negative for each class - - Args: - pred: prediction tensor - target: target tensor - num_classes: number of classes if known - argmax_dim: if pred is a tensor of probabilities, this indicates the - axis the argmax transformation will be applied over - reduction: a method to reduce metric score over labels (default: none) - Available reduction methods: - - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements - - Return: - True Positive, False Positive, True Negative, False Negative, Support - - Example: - - >>> x = torch.tensor([1, 2, 3]) - >>> y = torch.tensor([0, 2, 3]) - >>> tps, fps, tns, fns, sups = stat_scores_multiple_classes(x, y) - >>> tps - tensor([0., 0., 1., 1.]) - >>> fps - tensor([0., 1., 0., 0.]) - >>> tns - tensor([2., 2., 2., 2.]) - >>> fns - tensor([1., 0., 0., 0.]) - >>> sups - tensor([1., 0., 1., 1.]) - - """ - if pred.ndim == target.ndim + 1: - pred = to_categorical(pred, argmax_dim=argmax_dim) - - num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) - - if pred.dtype != torch.bool: - pred = pred.clamp_max(max=num_classes) - if target.dtype != torch.bool: - target = target.clamp_max(max=num_classes) - - possible_reductions = ('none', 'sum', 'elementwise_mean') - if reduction not in possible_reductions: - raise ValueError("reduction type %s not supported" % 
reduction) - - if reduction == 'none': - pred = pred.view((-1, )).long() - target = target.view((-1, )).long() - - tps = torch.zeros((num_classes + 1,), device=pred.device) - fps = torch.zeros((num_classes + 1,), device=pred.device) - tns = torch.zeros((num_classes + 1,), device=pred.device) - fns = torch.zeros((num_classes + 1,), device=pred.device) - sups = torch.zeros((num_classes + 1,), device=pred.device) - - match_true = (pred == target).float() - match_false = 1 - match_true - - tps.scatter_add_(0, pred, match_true) - fps.scatter_add_(0, pred, match_false) - fns.scatter_add_(0, target, match_false) - tns = pred.size(0) - (tps + fps + fns) - sups.scatter_add_(0, target, torch.ones_like(match_true)) - - tps = tps[:num_classes] - fps = fps[:num_classes] - tns = tns[:num_classes] - fns = fns[:num_classes] - sups = sups[:num_classes] - - elif reduction == 'sum' or reduction == 'elementwise_mean': - count_match_true = (pred == target).sum().float() - oob_tp, oob_fp, oob_tn, oob_fn, oob_sup = stat_scores(pred, target, num_classes, argmax_dim) - - tps = count_match_true - oob_tp - fps = pred.nelement() - count_match_true - oob_fp - fns = pred.nelement() - count_match_true - oob_fn - tns = pred.nelement() * (num_classes + 1) - (tps + fps + fns + oob_tn) - sups = pred.nelement() - oob_sup.float() - - if reduction == 'elementwise_mean': - tps /= num_classes - fps /= num_classes - fns /= num_classes - tns /= num_classes - sups /= num_classes - - return tps.float(), fps.float(), tns.float(), fns.float(), sups.float() - - -def accuracy( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - class_reduction: str = 'micro', - return_state: bool = False -) -> torch.Tensor: - """ - Computes the accuracy classification score - - Args: - pred: predicted labels - target: ground truth labels - num_classes: number of classes - class_reduction: method to reduce metric score over labels - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - - ``'none'``: returns calculated metric per class - return_state: returns a internal state that can be ddp reduced - before doing the final calculation - Return: - A Tensor with the accuracy score. - - Example: - - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) - >>> accuracy(x, y) - tensor(0.7500) - - """ - tps, fps, tns, fns, sups = stat_scores_multiple_classes( - pred=pred, target=target, num_classes=num_classes) - if return_state: - return {'tps': tps, 'sups': sups} - return class_reduce(tps, sups, sups, class_reduction=class_reduction) - - -def _confmat_normalize(cm): - """ Normalization function for confusion matrix """ - cm = cm / cm.sum(-1, keepdim=True) - nan_elements = cm[torch.isnan(cm)].nelement() - if nan_elements != 0: - cm[torch.isnan(cm)] = 0 - rank_zero_warn(f'{nan_elements} nan values found in confusion matrix have been replaced with zeros.') - return cm - - -def confusion_matrix( - pred: torch.Tensor, - target: torch.Tensor, - normalize: bool = False, - num_classes: Optional[int] = None -) -> torch.Tensor: - """ - Computes the confusion matrix C where each entry C_{i,j} is the number of observations - in group i that were predicted in group j. 
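The scatter_add bookkeeping in stat_scores_multiple_classes above amounts to per-class boolean masks. A loop-based sketch of the same counts, reproducing the micro accuracy from the docstring example:

    import torch

    pred   = torch.tensor([0, 1, 2, 3])
    target = torch.tensor([0, 1, 2, 2])
    num_classes = 4

    tps  = torch.zeros(num_classes)
    fps  = torch.zeros(num_classes)
    fns  = torch.zeros(num_classes)
    sups = torch.zeros(num_classes)
    for c in range(num_classes):
        tps[c]  = ((pred == c) & (target == c)).sum()
        fps[c]  = ((pred == c) & (target != c)).sum()
        fns[c]  = ((pred != c) & (target == c)).sum()
        sups[c] = (target == c).sum()

    print(tps.sum() / sups.sum())  # micro accuracy -> tensor(0.7500)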
- - Args: - pred: estimated targets - target: ground truth labels - normalize: normalizes confusion matrix - num_classes: number of classes - - Return: - Tensor, confusion matrix C [num_classes, num_classes ] - - Example: - - >>> x = torch.tensor([1, 2, 3]) - >>> y = torch.tensor([0, 2, 3]) - >>> confusion_matrix(x, y) - tensor([[0., 1., 0., 0.], - [0., 0., 0., 0.], - [0., 0., 1., 0.], - [0., 0., 0., 1.]]) - """ - num_classes = get_num_classes(pred, target, num_classes) - - unique_labels = (target.view(-1) * num_classes + pred.view(-1)).to(torch.int) - - bins = torch.bincount(unique_labels, minlength=num_classes ** 2) - cm = bins.reshape(num_classes, num_classes).squeeze().float() - - if normalize: - cm = _confmat_normalize(cm) - - return cm - - -def precision_recall( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - class_reduction: str = 'micro', - return_support: bool = False, - return_state: bool = False -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Computes precision and recall for different thresholds - - Args: - pred: estimated probabilities - target: ground-truth labels - num_classes: number of classes - class_reduction: method to reduce metric score over labels - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - - ``'none'``: returns calculated metric per class - - return_support: returns the support for each class, need for fbeta/f1 calculations - return_state: returns a internal state that can be ddp reduced - before doing the final calculation - - Return: - Tensor with precision and recall - - Example: - - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 2, 2, 2]) - >>> precision_recall(x, y, class_reduction='macro') - (tensor(0.5000), tensor(0.3333)) - - """ - tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target, num_classes=num_classes) - - precision = class_reduce(tps, tps + fps, sups, class_reduction=class_reduction) - recall = class_reduce(tps, tps + fns, sups, class_reduction=class_reduction) - if return_state: - return {'tps': tps, 'fps': fps, 'fns': fns, 'sups': sups} - if return_support: - return precision, recall, sups - return precision, recall - - -def precision( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - class_reduction: str = 'micro', -) -> torch.Tensor: - """ - Computes precision score. - - Args: - pred: estimated probabilities - target: ground-truth labels - num_classes: number of classes - class_reduction: method to reduce metric score over labels - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - - ``'none'``: returns calculated metric per class - - Return: - Tensor with precision. - - Example: - - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) - >>> precision(x, y) - tensor(0.7500) - - """ - return precision_recall(pred=pred, target=target, - num_classes=num_classes, class_reduction=class_reduction)[0] - - -def recall( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - class_reduction: str = 'micro', -) -> torch.Tensor: - """ - Computes recall score. 
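The removed confusion_matrix builds the matrix with a single bincount over jointly encoded (target, pred) indices. The same trick in isolation, reproducing the docstring example:

    import torch

    pred   = torch.tensor([1, 2, 3])
    target = torch.tensor([0, 2, 3])
    num_classes = 4

    # Encode each (target, pred) pair as one index into a flattened C x C grid.
    flat = target * num_classes + pred
    cm = torch.bincount(flat, minlength=num_classes ** 2)
    cm = cm.reshape(num_classes, num_classes).float()
    print(cm)  # row i = true class, column j = predicted class; cm[0, 1] == 1 here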
- - Args: - pred: estimated probabilities - target: ground-truth labels - num_classes: number of classes - class_reduction: method to reduce metric score over labels - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - - ``'none'``: returns calculated metric per class - - Return: - Tensor with recall. - - Example: - - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) - >>> recall(x, y) - tensor(0.7500) - """ - return precision_recall(pred=pred, target=target, - num_classes=num_classes, class_reduction=class_reduction)[1] - - -def fbeta_score( - pred: torch.Tensor, - target: torch.Tensor, - beta: float, - num_classes: Optional[int] = None, - class_reduction: str = 'micro', -) -> torch.Tensor: - """ - Computes the F-beta score which is a weighted harmonic mean of precision and recall. - It ranges between 1 and 0, where 1 is perfect and the worst value is 0. - - Args: - pred: estimated probabilities - target: ground-truth labels - beta: weights recall when combining the score. - beta < 1: more weight to precision. - beta > 1 more weight to recall - beta = 0: only precision - beta -> inf: only recall - num_classes: number of classes - class_reduction: method to reduce metric score over labels - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - - ``'none'``: returns calculated metric per class - - Return: - Tensor with the value of F-score. It is a value between 0-1. - - Example: - - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) - >>> fbeta_score(x, y, 0.2) - tensor(0.7500) - """ - # We need to differentiate at which point to do class reduction - intermidiate_reduction = 'none' if class_reduction != "micro" else 'micro' - - prec, rec, sups = precision_recall(pred=pred, target=target, - num_classes=num_classes, - class_reduction=intermidiate_reduction, - return_support=True) - num = (1 + beta ** 2) * prec * rec - denom = ((beta ** 2) * prec + rec) - if intermidiate_reduction == 'micro': - return torch.sum(num) / torch.sum(denom) - return class_reduce(num, denom, sups, class_reduction=class_reduction) - - -def f1_score( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - class_reduction: str = 'micro', -) -> torch.Tensor: - """ - Computes the F1-score (a.k.a F-measure), which is the harmonic mean of the precision and recall. - It ranges between 1 and 0, where 1 is perfect and the worst value is 0. - - Args: - pred: estimated probabilities - target: ground-truth labels - num_classes: number of classes - class_reduction: method to reduce metric score over labels - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. 
- - ``'none'``: returns calculated metric per class - - Return: - Tensor containing F1-score - - Example: - - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) - >>> f1_score(x, y) - tensor(0.7500) - """ - return fbeta_score(pred=pred, target=target, beta=1., - num_classes=num_classes, class_reduction=class_reduction) - - -def _binary_clf_curve( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py - """ - if sample_weight is not None and not isinstance(sample_weight, torch.Tensor): - sample_weight = torch.tensor(sample_weight, device=pred.device, dtype=torch.float) - - # remove class dimension if necessary - if pred.ndim > target.ndim: - pred = pred[:, 0] - desc_score_indices = torch.argsort(pred, descending=True) - - pred = pred[desc_score_indices] - target = target[desc_score_indices] - - if sample_weight is not None: - weight = sample_weight[desc_score_indices] - else: - weight = 1. - - # pred typically has many tied values. Here we extract - # the indices associated with the distinct values. We also - # concatenate a value for the end of the curve. - distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0] - threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1) - - target = (target == pos_label).to(torch.long) - tps = torch.cumsum(target * weight, dim=0)[threshold_idxs] - - if sample_weight is not None: - # express fps as a cumsum to ensure fps is increasing even in - # the presence of floating point errors - fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs] - else: - fps = 1 + threshold_idxs - tps - - return fps, tps, pred[threshold_idxs] - - -def roc( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary. 
- - Args: - pred: estimated probabilities - target: ground-truth labels - sample_weight: sample weights - pos_label: the label for the positive class - - Return: - false-positive rate (fpr), true-positive rate (tpr), thresholds - - Example: - - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 1, 1]) - >>> fpr, tpr, thresholds = roc(x, y) - >>> fpr - tensor([0., 0., 0., 0., 1.]) - >>> tpr - tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]) - >>> thresholds - tensor([4, 3, 2, 1, 0]) - - """ - fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, - sample_weight=sample_weight, - pos_label=pos_label) - - # Add an extra threshold position - # to make sure that the curve starts at (0, 0) - tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps]) - fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps]) - thresholds = torch.cat([thresholds[0][None] + 1, thresholds]) - - if fps[-1] <= 0: - raise ValueError("No negative samples in targets, false positive value should be meaningless") - - fpr = fps / fps[-1] - - if tps[-1] <= 0: - raise ValueError("No positive samples in targets, true positive value should be meaningless") - - tpr = tps / tps[-1] - - return fpr, tpr, thresholds - - -def multiclass_roc( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - num_classes: Optional[int] = None, -) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: - """ - Computes the Receiver Operating Characteristic (ROC) for multiclass predictors. - - Args: - pred: estimated probabilities - target: ground-truth labels - sample_weight: sample weights - num_classes: number of classes (default: None, computes automatically from data) - - Return: - returns roc for each class. - Number of classes, false-positive rate (fpr), true-positive rate (tpr), thresholds - - Example: - - >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], - ... [0.05, 0.85, 0.05, 0.05], - ... [0.05, 0.05, 0.85, 0.05], - ... [0.05, 0.05, 0.05, 0.85]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE - ((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), - (tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), - (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])), - (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500]))) - """ - num_classes = get_num_classes(pred, target, num_classes) - - class_roc_vals = [] - for c in range(num_classes): - pred_c = pred[:, c] - - class_roc_vals.append(roc(pred=pred_c, target=target, - sample_weight=sample_weight, pos_label=c)) - - return tuple(class_roc_vals) - - -def precision_recall_curve( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Computes precision-recall pairs for different thresholds. 
- 
- Args:
- pred: estimated probabilities
- target: ground-truth labels
- sample_weight: sample weights
- pos_label: the label for the positive class
-
- Return:
- precision, recall, thresholds
-
- Example:
-
- >>> pred = torch.tensor([0, 1, 2, 3])
- >>> target = torch.tensor([0, 1, 1, 0])
- >>> precision, recall, thresholds = precision_recall_curve(pred, target)
- >>> precision
- tensor([0.6667, 0.5000, 0.0000, 1.0000])
- >>> recall
- tensor([1.0000, 0.5000, 0.0000, 0.0000])
- >>> thresholds
- tensor([1, 2, 3])
-
- """
- fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
- sample_weight=sample_weight,
- pos_label=pos_label)
-
- precision = tps / (tps + fps)
- recall = tps / tps[-1]
-
- # stop when full recall attained
- # and reverse the outputs so recall is decreasing
- last_ind = torch.where(tps == tps[-1])[0][0]
- sl = slice(0, last_ind.item() + 1)
-
- # need to call reversed explicitly, since including that to slice would
- # introduce negative strides that are not yet supported in pytorch
- precision = torch.cat([reversed(precision[sl]),
- torch.ones(1, dtype=precision.dtype,
- device=precision.device)])
-
- recall = torch.cat([reversed(recall[sl]),
- torch.zeros(1, dtype=recall.dtype,
- device=recall.device)])
-
- thresholds = torch.tensor(reversed(thresholds[sl]))
-
- return precision, recall, thresholds
-
-
-def multiclass_precision_recall_curve(
- pred: torch.Tensor,
- target: torch.Tensor,
- sample_weight: Optional[Sequence] = None,
- num_classes: Optional[int] = None,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
- """
- Computes precision-recall pairs for different thresholds given multiclass scores.
-
- Args:
- pred: estimated probabilities
- target: ground-truth labels
- sample_weight: sample weight
- num_classes: number of classes
-
- Return:
- number of classes, precision, recall, thresholds
-
- Example:
-
- >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
- ... [0.05, 0.85, 0.05, 0.05],
- ... [0.05, 0.05, 0.85, 0.05],
- ... [0.05, 0.05, 0.05, 0.85]])
- >>> target = torch.tensor([0, 1, 3, 2])
- >>> nb_classes, precision, recall, thresholds = multiclass_precision_recall_curve(pred, target)
- >>> nb_classes
- (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500]))
- >>> precision
- (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500]))
- >>> recall
- (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500]))
- >>> thresholds # doctest: +NORMALIZE_WHITESPACE
- (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500]))
- """
- num_classes = get_num_classes(pred, target, num_classes)
-
- class_pr_vals = []
- for c in range(num_classes):
- pred_c = pred[:, c]
-
- class_pr_vals.append(precision_recall_curve(
- pred=pred_c,
- target=target,
- sample_weight=sample_weight, pos_label=c))
-
- return tuple(class_pr_vals)
-
-
-def auc(
- x: torch.Tensor,
- y: torch.Tensor,
- reorder: bool = True
-) -> torch.Tensor:
- """
- Computes Area Under the Curve (AUC) using the trapezoidal rule
-
- Args:
- x: x-coordinates
- y: y-coordinates
- reorder: reorder coordinates, so they are increasing
-
- Return:
- Tensor containing AUC score (float)
-
- Example:
-
- >>> x = torch.tensor([0, 1, 2, 3])
- >>> y = torch.tensor([0, 1, 2, 2])
- >>> auc(x, y)
- tensor(4.)
- """
- direction = 1.
-
- if reorder:
- # can't use lexsort here since it is not implemented for torch
- order = torch.argsort(x)
- x, y = x[order], y[order]
- else:
- dx = x[1:] - x[:-1]
- if (dx < 0).any():
- if (dx <= 0).all():
- direction = -1.
- else: - raise ValueError("Reordering is not turned on, and " - "the x array is not increasing: %s" % x) - - return direction * torch.trapz(y, x) - - -def auc_decorator(reorder: bool = True) -> Callable: - def wrapper(func_to_decorate: Callable) -> Callable: - @wraps(func_to_decorate) - def new_func(*args, **kwargs) -> torch.Tensor: - x, y = func_to_decorate(*args, **kwargs)[:2] - - return auc(x, y, reorder=reorder) - - return new_func - - return wrapper - - -def multiclass_auc_decorator(reorder: bool = True) -> Callable: - def wrapper(func_to_decorate: Callable) -> Callable: - def new_func(*args, **kwargs) -> torch.Tensor: - results = [] - for class_result in func_to_decorate(*args, **kwargs): - x, y = class_result[:2] - results.append(auc(x, y, reorder=reorder)) - - return torch.cat(results) - - return new_func - - return wrapper - - -def auroc( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., -) -> torch.Tensor: - """ - Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores - - Args: - pred: estimated probabilities - target: ground-truth labels - sample_weight: sample weights - pos_label: the label for the positive class - - Return: - Tensor containing ROCAUC score - - Example: - - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 1, 0]) - >>> auroc(x, y) - tensor(0.5000) - """ - if any(target > 1): - raise ValueError('AUROC metric is meant for binary classification, but' - ' target tensor contains value different from 0 and 1.' - ' Multiclass is currently not supported.') - - @auc_decorator(reorder=True) - def _auroc(pred, target, sample_weight, pos_label): - return roc(pred, target, sample_weight, pos_label) - - return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label) - - -def average_precision( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., -) -> torch.Tensor: - """ - Compute average precision from prediction scores - - Args: - pred: estimated probabilities - target: ground-truth labels - sample_weight: sample weights - pos_label: the label for the positive class - - Return: - Tensor containing average precision score - - Example: - - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) - >>> average_precision(x, y) - tensor(0.3333) - """ - precision, recall, _ = precision_recall_curve(pred=pred, target=target, - sample_weight=sample_weight, - pos_label=pos_label) - # Return the step function integral - # The following works because the last entry of precision is - # guaranteed to be 1, as returned by precision_recall_curve - return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]) - - -def dice_score( - pred: torch.Tensor, - target: torch.Tensor, - bg: bool = False, - nan_score: float = 0.0, - no_fg_score: float = 0.0, - reduction: str = 'elementwise_mean', -) -> torch.Tensor: - """ - Compute dice score from prediction scores - - Args: - pred: estimated probabilities - target: ground-truth labels - bg: whether to also compute dice for the background - nan_score: score to return, if a NaN occurs during computation - no_fg_score: score to return, if no foreground pixel was found in target - reduction: a method to reduce metric score over labels. 
- - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - Return: - Tensor containing dice score - - Example: - - >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], - ... [0.05, 0.85, 0.05, 0.05], - ... [0.05, 0.05, 0.85, 0.05], - ... [0.05, 0.05, 0.05, 0.85]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> dice_score(pred, target) - tensor(0.3333) - - """ - num_classes = pred.shape[1] - bg = (1 - int(bool(bg))) - scores = torch.zeros(num_classes - bg, device=pred.device, dtype=torch.float32) - for i in range(bg, num_classes): - if not (target == i).any(): - # no foreground class - scores[i - bg] += no_fg_score - continue - - tp, fp, tn, fn, sup = stat_scores(pred=pred, target=target, class_index=i) - denom = (2 * tp + fp + fn).to(torch.float) - # nan result - score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else nan_score - - scores[i - bg] += score_cls - return reduce(scores, reduction=reduction) - - -def iou( - pred: torch.Tensor, - target: torch.Tensor, - ignore_index: Optional[int] = None, - absent_score: float = 0.0, - num_classes: Optional[int] = None, - reduction: str = 'elementwise_mean', -) -> torch.Tensor: - """ - Intersection over union, or Jaccard index calculation. - - Args: - pred: Tensor containing predictions - target: Tensor containing targets - ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. Has no effect if given an int that is not in the - range [0, num_classes-1], where num_classes is either given or derived from pred and target. By default, no - index is ignored, and all classes are used. - absent_score: score to use for an individual class, if no instances of the class index were present in - `pred` AND no instances of the class index were present in `target`. For example, if we have 3 classes, - [0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be assigned the `absent_score`. Default is - 0.0. - num_classes: Optionally specify the number of classes - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - Return: - IoU score : Tensor containing single value if reduction is - 'elementwise_mean', or number of classes if reduction is 'none' - - Example: - - >>> target = torch.randint(0, 1, (10, 25, 25)) - >>> pred = torch.tensor(target) - >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] - >>> iou(pred, target) - tensor(0.4914) - - """ - num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) - - tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target, num_classes) - - scores = torch.zeros(num_classes, device=pred.device, dtype=torch.float32) - - for class_idx in range(num_classes): - if class_idx == ignore_index: - continue - - tp = tps[class_idx] - fp = fps[class_idx] - fn = fns[class_idx] - sup = sups[class_idx] - - # If this class is absent in the target (no support) AND absent in the pred (no true or false - # positives), then use the absent_score for this class. - if sup + tp + fp == 0: - scores[class_idx] = absent_score - continue - - denom = tp + fp + fn - # Note that we do not need to worry about division-by-zero here since we know (sup + tp + fp != 0) from above, - # which means ((tp+fn) + tp + fp != 0), which means (2tp + fp + fn != 0). 
Since all vars are non-negative, we - # can conclude (tp + fp + fn > 0), meaning the denominator is non-zero for each class. - score = tp.to(torch.float) / denom - scores[class_idx] = score - - # Remove the ignored class index from the scores. - if ignore_index is not None and ignore_index >= 0 and ignore_index < num_classes: - scores = torch.cat([ - scores[:ignore_index], - scores[ignore_index + 1:], - ]) - - return reduce(scores, reduction=reduction) diff --git a/pytorch_lightning/metrics/functional/nlp.py b/pytorch_lightning/metrics/functional/nlp.py deleted file mode 100644 index 85c33642704cd..0000000000000 --- a/pytorch_lightning/metrics/functional/nlp.py +++ /dev/null @@ -1,103 +0,0 @@ -# referenced from -# Library Name: torchtext -# Authors: torchtext authors and @sluks -# Date: 2020-07-18 -# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -from collections import Counter -from typing import List, Sequence - -import torch - - -def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter: - """ - Counting how many times each word appears in a given text with ngram - - Args: - ngram_input_list: A list of translated text or reference texts - n_gram: gram value ranged 1 to 4 - - Return: - ngram_counter: a collections.Counter object of ngram - """ - - ngram_counter = Counter() - - for i in range(1, n_gram + 1): - for j in range(len(ngram_input_list) - i + 1): - ngram_key = tuple(ngram_input_list[j:(i + j)]) - ngram_counter[ngram_key] += 1 - - return ngram_counter - - -def bleu_score( - translate_corpus: Sequence[str], - reference_corpus: Sequence[str], - n_gram: int = 4, - smooth: bool = False -) -> torch.Tensor: - """ - Calculate BLEU score of machine translated text with one or more references - - Args: - translate_corpus: An iterable of machine translated corpus - reference_corpus: An iterable of iterables of reference corpus - n_gram: Gram value ranged from 1 to 4 (Default 4) - smooth: Whether or not to apply smoothing – Lin et al. 
2004 - - Return: - Tensor with BLEU Score - - Example: - - >>> translate_corpus = ['the cat is on the mat'.split()] - >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] - >>> bleu_score(translate_corpus, reference_corpus) - tensor(0.7598) - - """ - - assert len(translate_corpus) == len(reference_corpus) - numerator = torch.zeros(n_gram) - denominator = torch.zeros(n_gram) - precision_scores = torch.zeros(n_gram) - c = 0.0 - r = 0.0 - - for (translation, references) in zip(translate_corpus, reference_corpus): - c += len(translation) - ref_len_list = [len(ref) for ref in references] - ref_len_diff = [abs(len(translation) - x) for x in ref_len_list] - r += ref_len_list[ref_len_diff.index(min(ref_len_diff))] - translation_counter = _count_ngram(translation, n_gram) - reference_counter = Counter() - - for ref in references: - reference_counter |= _count_ngram(ref, n_gram) - - ngram_counter_clip = translation_counter & reference_counter - - for counter_clip in ngram_counter_clip: - numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip] - - for counter in translation_counter: - denominator[len(counter) - 1] += translation_counter[counter] - - trans_len = torch.tensor(c) - ref_len = torch.tensor(r) - - if min(numerator) == 0.0: - return torch.tensor(0.0) - - if smooth: - precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram)) - else: - precision_scores = numerator / denominator - - log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores) - geometric_mean = torch.exp(torch.sum(log_precision_scores)) - brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len)) - bleu = brevity_penalty * geometric_mean - - return bleu diff --git a/pytorch_lightning/metrics/functional/reduction.py b/pytorch_lightning/metrics/functional/reduction.py deleted file mode 100644 index d0618abd65b96..0000000000000 --- a/pytorch_lightning/metrics/functional/reduction.py +++ /dev/null @@ -1,65 +0,0 @@ -import torch - - -def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: - """ - Reduces a given tensor by a given reduction method - - Args: - to_reduce : the tensor, which shall be reduced - reduction : a string specifying the reduction method ('elementwise_mean', 'none', 'sum') - - Return: - reduced Tensor - - Raise: - ValueError if an invalid reduction parameter was given - """ - if reduction == 'elementwise_mean': - return torch.mean(to_reduce) - if reduction == 'none': - return to_reduce - if reduction == 'sum': - return torch.sum(to_reduce) - raise ValueError('Reduction parameter unknown.') - - -def class_reduce(num: torch.Tensor, - denom: torch.Tensor, - weights: torch.Tensor, - class_reduction: str = 'none') -> torch.Tensor: - """ - Function used to reduce classification metrics of the form `num / denom * weights`. - For example for calculating standard accuracy the num would be number of - true positives per class, denom would be the support per class, and weights - would be a tensor of 1s - - Args: - num: numerator tensor - decom: denominator tensor - weights: weights for each class - class_reduction: reduction method for multiclass problems - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. 
- - ``'none'``: returns calculated metric per class - - """ - valid_reduction = ('micro', 'macro', 'weighted', 'none') - if class_reduction == 'micro': - return torch.sum(num) / torch.sum(denom) - - # For the rest we need to take care of instances where the denom can be 0 - # for some classes which will produce nans for that class - fraction = num / denom - fraction[fraction != fraction] = 0 - if class_reduction == 'macro': - return torch.mean(fraction) - elif class_reduction == 'weighted': - return torch.sum(fraction * (weights / torch.sum(weights))) - elif class_reduction == 'none': - return fraction - - raise ValueError(f'Reduction parameter {class_reduction} unknown.' - f' Choose between one of these: {valid_reduction}') diff --git a/pytorch_lightning/metrics/functional/regression.py b/pytorch_lightning/metrics/functional/regression.py deleted file mode 100644 index 63d4615cae1a1..0000000000000 --- a/pytorch_lightning/metrics/functional/regression.py +++ /dev/null @@ -1,325 +0,0 @@ -from typing import Sequence - -import torch -from torch.nn import functional as F - -from pytorch_lightning.metrics.functional.reduction import reduce - - -def mse( - pred: torch.Tensor, - target: torch.Tensor, - reduction: str = 'elementwise_mean', - return_state: bool = False -) -> torch.Tensor: - """ - Computes mean squared error - - Args: - pred: estimated labels - target: ground truth labels - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - return_state: returns a internal state that can be ddp reduced - before doing the final calculation - - Return: - Tensor with MSE - - Example: - - >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> mse(x, y) - tensor(0.2500) - - """ - mse = F.mse_loss(pred, target, reduction='none') - if return_state: - return {'squared_error': mse.sum(), 'n_observations': torch.tensor(mse.numel())} - mse = reduce(mse, reduction=reduction) - return mse - - -def rmse( - pred: torch.Tensor, - target: torch.Tensor, - reduction: str = 'elementwise_mean', - return_state: bool = False -) -> torch.Tensor: - """ - Computes root mean squared error - - Args: - pred: estimated labels - target: ground truth labels - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - return_state: returns a internal state that can be ddp reduced - before doing the final calculation - - Return: - Tensor with RMSE - - - >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> rmse(x, y) - tensor(0.5000) - - """ - mean_squared_error = mse(pred, target, reduction=reduction) - if return_state: - return {'squared_error': mean_squared_error.sum(), - 'n_observations': torch.tensor(mean_squared_error.numel())} - return torch.sqrt(mean_squared_error) - - -def mae( - pred: torch.Tensor, - target: torch.Tensor, - reduction: str = 'elementwise_mean', - return_state: bool = False -) -> torch.Tensor: - """ - Computes mean absolute error - - Args: - pred: estimated labels - target: ground truth labels - reduction: a method to reduce metric score over labels. 
- - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - return_state: returns a internal state that can be ddp reduced - before doing the final calculation - - Return: - Tensor with MAE - - Example: - - >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> mae(x, y) - tensor(0.2500) - - """ - mae = F.l1_loss(pred, target, reduction='none') - if return_state: - return {'absolute_error': mae.sum(), 'n_observations': torch.tensor(mae.numel())} - mae = reduce(mae, reduction=reduction) - return mae - - -def rmsle( - pred: torch.Tensor, - target: torch.Tensor, - reduction: str = 'elementwise_mean' -) -> torch.Tensor: - """ - Computes root mean squared log error - - Args: - pred: estimated labels - target: ground truth labels - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - Return: - Tensor with RMSLE - - Example: - - >>> x = torch.tensor([0., 1, 2, 3]) - >>> y = torch.tensor([0., 1, 2, 2]) - >>> rmsle(x, y) - tensor(0.1438) - - """ - rmsle = rmse(torch.log(pred + 1), torch.log(target + 1), reduction=reduction) - return rmsle - - -def psnr( - pred: torch.Tensor, - target: torch.Tensor, - data_range: float = None, - base: float = 10.0, - reduction: str = 'elementwise_mean', - return_state: bool = False -) -> torch.Tensor: - """ - Computes the peak signal-to-noise ratio - - Args: - pred: estimated signal - target: groun truth signal - data_range: the range of the data. If None, it is determined from the data (max - min) - base: a base of a logarithm to use (default: 10) - reduction: a method to reduce metric score over labels. 
- - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - return_state: returns a internal state that can be ddp reduced - before doing the final calculation - - Return: - Tensor with PSNR score - - Example: - - >>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) - >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) - >>> psnr(pred, target) - tensor(2.5527) - - """ - if data_range is None: - data_range = target.max() - target.min() - else: - data_range = torch.tensor(float(data_range)) - - if return_state: - return {'data_range': data_range, - 'sum_squared_error': F.mse_loss(pred, target, reduction='none').sum(), - 'n_obs': torch.tensor(target.numel())} - - mse_score = mse(pred.view(-1), target.view(-1), reduction=reduction) - psnr_base_e = 2 * torch.log(data_range) - torch.log(mse_score) - psnr = psnr_base_e * (10 / torch.log(torch.tensor(base))) - return psnr - - -def _gaussian_kernel(channel, kernel_size, sigma, device): - def _gaussian(kernel_size, sigma, device): - gauss = torch.arange( - start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, - step=1, - dtype=torch.float32, - device=device - ) - gauss = torch.exp(-gauss.pow(2) / (2 * pow(sigma, 2))) - return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size) - - gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], device) - gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], device) - kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) - - return kernel.expand(channel, 1, kernel_size[0], kernel_size[1]) - - -def ssim( - pred: torch.Tensor, - target: torch.Tensor, - kernel_size: Sequence[int] = (11, 11), - sigma: Sequence[float] = (1.5, 1.5), - reduction: str = "elementwise_mean", - data_range: float = None, - k1: float = 0.01, - k2: float = 0.03 -) -> torch.Tensor: - """ - Computes Structual Similarity Index Measure - - Args: - pred: estimated image - target: ground truth image - kernel_size: size of the gaussian kernel (default: (11, 11)) - sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - data_range: Range of the image. If ``None``, it is determined from the image (max - min) - k1: Parameter of SSIM. Default: 0.01 - k2: Parameter of SSIM. Default: 0.03 - - Return: - Tensor with SSIM score - - Example: - - >>> pred = torch.rand([16, 1, 16, 16]) - >>> target = pred * 0.75 - >>> ssim(pred, target) - tensor(0.9219) - - """ - if pred.dtype != target.dtype: - raise TypeError( - "Expected `pred` and `target` to have the same data type." - f" Got pred: {pred.dtype} and target: {target.dtype}." - ) - - if pred.shape != target.shape: - raise ValueError( - "Expected `pred` and `target` to have the same shape." - f" Got pred: {pred.shape} and target: {target.shape}." - ) - - if len(pred.shape) != 4 or len(target.shape) != 4: - raise ValueError( - "Expected `pred` and `target` to have BxCxHxW shape." - f" Got pred: {pred.shape} and target: {target.shape}." - ) - - if len(kernel_size) != 2 or len(sigma) != 2: - raise ValueError( - "Expected `kernel_size` and `sigma` to have the length of two." - f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}." - ) - - if any(x % 2 == 0 or x <= 0 for x in kernel_size): - raise ValueError(f"Expected `kernel_size` to have odd positive number. 
Got {kernel_size}.") - - if any(y <= 0 for y in sigma): - raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.") - - if data_range is None: - data_range = max(pred.max() - pred.min(), target.max() - target.min()) - - C1 = pow(k1 * data_range, 2) - C2 = pow(k2 * data_range, 2) - device = pred.device - - channel = pred.size(1) - kernel = _gaussian_kernel(channel, kernel_size, sigma, device) - - # Concatenate - # pred for mu_pred - # target for mu_target - # pred * pred for sigma_pred - # target * target for sigma_target - # pred * target for sigma_pred_target - input_list = torch.cat([pred, target, pred * pred, target * target, pred * target]) # (5 * B, C, H, W) - outputs = F.conv2d(input_list, kernel, groups=channel) - output_list = [outputs[x * pred.size(0): (x + 1) * pred.size(0)] for x in range(len(outputs))] - - mu_pred_sq = output_list[0].pow(2) - mu_target_sq = output_list[1].pow(2) - mu_pred_target = output_list[0] * output_list[1] - - sigma_pred_sq = output_list[2] - mu_pred_sq - sigma_target_sq = output_list[3] - mu_target_sq - sigma_pred_target = output_list[4] - mu_pred_target - - UPPER = 2 * sigma_pred_target + C2 - LOWER = sigma_pred_sq + sigma_target_sq + C2 - - ssim_idx = ((2 * mu_pred_target + C1) * UPPER) / ((mu_pred_sq + mu_target_sq + C1) * LOWER) - - return reduce(ssim_idx, reduction) diff --git a/pytorch_lightning/metrics/functional/self_supervised.py b/pytorch_lightning/metrics/functional/self_supervised.py deleted file mode 100644 index c8c7e83166723..0000000000000 --- a/pytorch_lightning/metrics/functional/self_supervised.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch - - -def embedding_similarity( - batch: torch.Tensor, - similarity: str = 'cosine', - reduction: str = 'none', - zero_diagonal: bool = True -) -> torch.Tensor: - """ - Computes representation similarity - - Example: - - >>> embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]]) - >>> embedding_similarity(embeddings) - tensor([[0.0000, 1.0000, 0.9759], - [1.0000, 0.0000, 0.9759], - [0.9759, 0.9759, 0.0000]]) - - Args: - batch: (batch, dim) - similarity: 'dot' or 'cosine' - reduction: 'none', 'sum', 'mean' (all along dim -1) - zero_diagonal: if True, the diagonals are set to zero - - Return: - A square matrix (batch, batch) with the similarity scores between all elements - If sum or mean are used, then returns (b, 1) with the reduced value for each row - """ - if similarity == 'cosine': - norm = torch.norm(batch, p=2, dim=1) - batch = batch / norm.unsqueeze(1) - - sqr_mtx = batch.mm(batch.transpose(1, 0)) - - if zero_diagonal: - sqr_mtx = sqr_mtx.fill_diagonal_(0) - - if reduction == 'mean': - sqr_mtx = sqr_mtx.mean(dim=-1) - - if reduction == 'sum': - sqr_mtx = sqr_mtx.sum(dim=-1) - - return sqr_mtx diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py deleted file mode 100644 index e97f054f05f89..0000000000000 --- a/pytorch_lightning/metrics/metric.py +++ /dev/null @@ -1,262 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod -from typing import Any, Mapping, Optional, Sequence -import numbers - -import torch -from torch import nn -import numpy as np - -from pytorch_lightning.metrics.converters import ( - at_least_1d, - gather_all_tensors_if_available, - convert_to_tensor, - convert_to_numpy, -) -from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin - - -class Metric(DeviceDtypeModuleMixin, nn.Module, ABC): - """ - Abstract base class for metric implementation. - - Should be used to implement metrics that - - 1. Return multiple Outputs - 2. Handle their own DDP sync - - Metric hooks that can be implemented are - - * input_convert: pre-forward hook that takes care of input conversion - * output_convert: post-forward hook that takes care of output convertion - * ddp_reduce: implementation of ddp sync + aggregation, default is ddp_sync + aggregate - * compute: post-ddp sync for additional metric computations - - ``ddp_reduce`` by default calls the following methods, which can also be overwritten if necessary. - - * ddp_sync: implements how values should be synced across ddp-processes. Defaults to gather all. - * aggregate: implement how values should be aggregated (defaults to mean). - - Call order - - input_convert -> forward -> output_convert -> ddp_reduce (per default being ddp_sync -> aggregate) -> compute - - """ - - def __init__(self, name: str, reduce_group: Optional[Any] = None): - """ - Args: - name: the metric's name - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - - """ - super().__init__() - self.name = name - self._dtype = torch.get_default_dtype() - self._device = torch.device("cpu") - - self.reduce_group = reduce_group - - # Buffer for holding aggregated state after each batch - self._step_vals = [] - - # Register hooks - self.register_forward_pre_hook(self.input_convert) - self.register_forward_hook(self.output_convert) - self.register_forward_hook(self.ddp_reduce) - self.register_forward_hook(self.compute) - - @staticmethod - def input_convert(self, data: Any): - """ - Implement how the inputs should be casted before calling forward - - Args: - data: input to forward method - - Returns: - casted data - """ - return data - - @abstractmethod - def forward(self, *args, **kwargs): - """ - Implements the actual metric computation. 
- - Returns: - metric value or metric state - - """ - raise NotImplementedError - - @staticmethod - def output_convert(self, data: Any, output: Any): - """ - Implement how outputs from forward should be casted - - Args: - data: input to forward method - output: output from forward method - - Returns: - casted outputs - """ - return apply_to_collection(output, (torch.Tensor, np.ndarray), at_least_1d) - - def ddp_sync(self, tensor: Any): - """ - Implement how the outputs from forward should be synced - (per default just gathers all of them and adds them to self._step_vals) - - Args: - tensor: tensor to sync - - Returns: - synced output - - """ - gathered_tensors = apply_to_collection(tensor, torch.Tensor, gather_all_tensors_if_available, self.reduce_group) - return gathered_tensors - - @staticmethod - def ddp_reduce(self, data: Any, output: Any): - """ - Implement how the outputs from forward should be synced and reduced across nodes - - Args: - data: input to forward method - output: output from the `output_convert` hook - - Returns: - synced output - - """ - synced = self.ddp_sync(output) - agg_val = self.aggregate(synced) - self._step_vals.append(agg_val) - return agg_val - - def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor: - """ - Implement aggregation of values on the same device - - Args: - tensors: the values to be aggregated - - Returns: - aggregated values - - """ - # single tensor - if len(tensors) == 1: - tensors = tensors[0] - if isinstance(tensors, Mapping): - return {k: _stack_and_agg(tensors[k]) for k in tensors.keys()} - if isinstance(tensors, list): - return _stack_and_agg(tensors) - if isinstance(tensors, tuple): - return tensors - if isinstance(tensors, torch.Tensor): - return _stack_and_agg(tensors) - - # multiple tensors (from aggregation over batches) - if isinstance(tensors[0], Mapping): - return {k: torch.stack([tensor[k] for tensor in tensors]).sum(0) for k in tensors[0].keys()} - if isinstance(tensors[0], Sequence): - return tuple([torch.stack(tmp).sum(0) for tmp in zip(*tensors)]) - if isinstance(tensors[0], torch.Tensor): - return torch.stack(tensors).sum(0) - - raise TypeError("unknown metric value format to aggregate") - - @staticmethod - def compute(self, data: Any, output: Any): - """ - Implement additionally metric computations to be done after the aggregation - - Args: - data: input to forward method - output: output from the `aggregate` hook - - Returns: - final metric value - - """ - return output - - @property - def aggregated(self) -> torch.Tensor: - aggr = self.aggregate(*self._step_vals if len(self._step_vals) > 1 else self._step_vals) - self.reset() - return self.compute(self, None, aggr) - - def reset(self): - self._step_vals = [] - - -def _stack_and_agg(tensors): - """ Utility function for stacking and aggregating tensors """ - if isinstance(tensors, list): - return torch.sum(torch.stack([t for t in tensors]), 0) - return tensors.squeeze() if tensors.numel() == 1 else tensors - - -class TensorMetric(Metric): - """ - Base class for metric implementation operating directly on tensors. - All inputs and outputs will be casted to tensors if necessary. - Already handles DDP sync and input/output conversions. 
- """ - - @staticmethod - def input_convert(self, data: Any): - data = apply_to_collection( - data, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device - ) - return super(TensorMetric, self).input_convert(self, data) - - @staticmethod - def output_convert(self, data: Any, output: Any): - - output = apply_to_collection( - output, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device - ) - return super(TensorMetric, self).output_convert(self, data, output) - - -class NumpyMetric(Metric): - """ - Base class for metric implementation operating on numpy arrays. - All inputs will be casted to numpy if necessary and all outputs will - be casted to tensors if necessary. - Already handles DDP sync and input/output conversions. - """ - - @staticmethod - def input_convert(self, data: Any): - data = apply_to_collection(data, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy) - return super(NumpyMetric, self).input_convert(self, data) - - @staticmethod - def output_convert(self, data: Any, output: Any): - output = apply_to_collection( - output, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device - ) - - return super(NumpyMetric, self).output_convert(self, data, output) diff --git a/pytorch_lightning/metrics/nlp.py b/pytorch_lightning/metrics/nlp.py deleted file mode 100644 index 38b3632f5490d..0000000000000 --- a/pytorch_lightning/metrics/nlp.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - -from pytorch_lightning.metrics.functional.nlp import bleu_score -from pytorch_lightning.metrics.metric import Metric - - -class BLEUScore(Metric): - """ - Calculate BLEU score of machine translated text with one or more references. - - Example: - - >>> translate_corpus = ['the cat is on the mat'.split()] - >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] - >>> metric = BLEUScore() - >>> metric(translate_corpus, reference_corpus) - tensor(0.7598) - """ - - def __init__(self, n_gram: int = 4, smooth: bool = False): - """ - Args: - n_gram: Gram value ranged from 1 to 4 (Default 4) - smooth: Whether or not to apply smoothing – Lin et al. 
2004 - """ - super().__init__(name="bleu") - self.n_gram = n_gram - self.smooth = smooth - - def forward(self, translate_corpus: list, reference_corpus: list) -> torch.Tensor: - """ - Actual metric computation - - Args: - translate_corpus: An iterable of machine translated corpus - reference_corpus: An iterable of iterables of reference corpus - - Return: - torch.Tensor: BLEU Score - """ - return bleu_score( - translate_corpus=translate_corpus, - reference_corpus=reference_corpus, - n_gram=self.n_gram, - smooth=self.smooth, - ).to(self.device, self.dtype) diff --git a/pytorch_lightning/metrics/regression.py b/pytorch_lightning/metrics/regression.py deleted file mode 100644 index a152d684a43a8..0000000000000 --- a/pytorch_lightning/metrics/regression.py +++ /dev/null @@ -1,361 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Sequence, Any - -import torch - -from pytorch_lightning.metrics.functional.regression import ( - mae, - mse, - psnr, - rmse, - rmsle, - ssim -) -from pytorch_lightning.metrics.metric import Metric - - -class MSE(Metric): - """ - Computes the mean squared loss. - - Example: - - >>> pred = torch.tensor([0., 1, 2, 3]) - >>> target = torch.tensor([0., 1, 2, 2]) - >>> metric = MSE() - >>> metric(pred, target) - tensor(0.2500) - - """ - - def __init__( - self, - reduction: str = 'elementwise_mean', - ): - """ - Args: - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - """ - super().__init__(name='mse') - self.reduction = reduction - - def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Actual metric computation - - Args: - pred: predicted labels - target: ground truth labels - - Return: - A Tensor with the mse loss. - """ - return mse(pred, target, return_state=True) - - @staticmethod - def compute(self, data: Any, output: Any): - sse, n = output['squared_error'], output['n_observations'] - return sse / n - - -class RMSE(Metric): - """ - Computes the root mean squared loss. - - Example: - - >>> pred = torch.tensor([0., 1, 2, 3]) - >>> target = torch.tensor([0., 1, 2, 2]) - >>> metric = RMSE() - >>> metric(pred, target) - tensor(0.5000) - - """ - - def __init__( - self, - reduction: str = 'elementwise_mean', - ): - """ - Args: - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - """ - super().__init__(name='rmse') - self.reduction = reduction - - def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Actual metric computation - - Args: - pred: predicted labels - target: ground truth labels - - Return: - A Tensor with the rmse loss. 
- """ - return rmse(pred, target, reduction='none', return_state=True) - - @staticmethod - def compute(self, data: Any, output: Any): - """ Squaring needs to happend after ddp sync """ - sse, n = output['squared_error'], output['n_observations'] - return torch.sqrt(sse / n) - - -class MAE(Metric): - """ - Computes the mean absolute loss or L1-loss. - - Example: - - >>> pred = torch.tensor([0., 1, 2, 3]) - >>> target = torch.tensor([0., 1, 2, 2]) - >>> metric = MAE() - >>> metric(pred, target) - tensor(0.2500) - - """ - - def __init__( - self, - reduction: str = 'elementwise_mean', - ): - """ - Args: - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - """ - super().__init__(name='mae') - self.reduction = reduction - - def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Actual metric computation - - Args: - pred: predicted labels - target: ground truth labels - - Return: - A Tensor with the mae loss. - """ - return mae(pred, target, return_state=True) - - @staticmethod - def compute(self, data: Any, output: Any): - sae, n = output['absolute_error'], output['n_observations'] - return sae / n - - -class RMSLE(Metric): - """ - Computes the root mean squared log loss. - - Example: - - >>> pred = torch.tensor([0., 1, 2, 3]) - >>> target = torch.tensor([0., 1, 2, 2]) - >>> metric = RMSLE() - >>> metric(pred, target) - tensor(0.1438) - - """ - - def __init__( - self, - reduction: str = 'elementwise_mean', - ): - """ - Args: - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - """ - super().__init__(name='rmsle') - self.reduction = reduction - - def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Actual metric computation - - Args: - pred: predicted labels - target: ground truth labels - - Return: - A Tensor with the rmsle loss. - """ - return mse(torch.log(pred + 1), torch.log(target + 1), - self.reduction, return_state=True) - - @staticmethod - def compute(self, data: Any, output: Any): - """ Squaring needs to happend after ddp sync """ - sse, n = output['squared_error'], output['n_observations'] - return torch.sqrt(sse / n) - - -class PSNR(Metric): - """ - Computes the peak signal-to-noise ratio - - Example: - - >>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) - >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) - >>> metric = PSNR() - >>> metric(pred, target) - tensor(2.5527) - - """ - - def __init__( - self, - data_range: float = None, - base: int = 10, - reduction: str = 'elementwise_mean' - ): - """ - Args: - data_range: the range of the data. If None, it is determined from the data (max - min) - base: a base of a logarithm to use (default: 10) - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - """ - super().__init__(name='psnr') - self.data_range = data_range - self.base = float(base) - self.reduction = reduction - - def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Actual metric computation - - Args: - pred: predicted labels - target: ground truth labels - - Return: - A Tensor with psnr score. 
- """ - return psnr(pred, target, self.data_range, self.base, self.reduction, return_state=True) - - def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor: - """ Special aggregation function as the data range needs to be correctly synced """ - if len(tensors) == 1: - tensors = tensors[0] - output = {'data_range': torch.stack([t for t in tensors['data_range']]).max()} - output.update({k: torch.stack([t for t in tensors[k]]).sum(0) for k in tensors.keys() if k != 'data_range'}) - return output - - output = {'data_range': torch.stack([tensor['data_range'] for tensor in tensors]).max()} - output.update({k: torch.stack([tensor[k] for tensor in tensors]).sum(0) for k in tensors[0].keys() if k != 'data_range'}) - return output - - @staticmethod - def compute(self, data: Any, output: Any): - """ - Compute final value based on the synced data_range, sum of squared errors - and number of samples. - - Args: - data: input to forward method - output: output from the `aggregate` hook - - Returns: - final metric value - - """ - sse, n, data_range = output['sum_squared_error'], output['n_obs'], output['data_range'] - psnr_base_e = 2 * torch.log(data_range) - torch.log(sse / n) - psnr = psnr_base_e * (10 / torch.log(torch.tensor(self.base))) - return psnr - - -class SSIM(Metric): - """ - Computes Structual Similarity Index Measure - - Example: - - >>> pred = torch.rand([16, 1, 16, 16]) - >>> target = pred * 0.75 - >>> metric = SSIM() - >>> metric(pred, target) - tensor(0.9219) - - """ - - def __init__( - self, - kernel_size: Sequence[int] = (11, 11), - sigma: Sequence[float] = (1.5, 1.5), - reduction: str = "elementwise_mean", - data_range: float = None, - k1: float = 0.01, - k2: float = 0.03 - ): - """ - Args: - kernel_size: Size of the gaussian kernel (default: (11, 11)) - sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - data_range: Range of the image. If ``None``, it is determined from the image (max - min) - k1: Parameter of SSIM. Default: 0.01 - k2: Parameter of SSIM. Default: 0.03 - """ - super().__init__(name="ssim") - self.kernel_size = kernel_size - self.sigma = sigma - self.reduction = reduction - self.data_range = data_range - self.k1 = k1 - self.k2 = k2 - - def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Actual metric computation - - Args: - pred: Estimated image - target: Ground truth image - - Return: - A Tensor with SSIM score. - """ - return ssim(pred, target, self.kernel_size, self.sigma, self.reduction, self.data_range, self.k1, self.k2) diff --git a/pytorch_lightning/metrics/self_supervised.py b/pytorch_lightning/metrics/self_supervised.py deleted file mode 100644 index 9e57c15026fcc..0000000000000 --- a/pytorch_lightning/metrics/self_supervised.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any - -import torch - -from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity -from pytorch_lightning.metrics.metric import TensorMetric -from pytorch_lightning.utilities import rank_zero_warn - - -class EmbeddingSimilarity(TensorMetric): - """ - Computes similarity between embeddings - - Example: - >>> embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]]) - >>> embedding_similarity(embeddings) - tensor([[0.0000, 1.0000, 0.9759], - [1.0000, 0.0000, 0.9759], - [0.9759, 0.9759, 0.0000]]) - - """ - def __init__( - self, - similarity: str = 'cosine', - zero_diagonal: bool = True, - reduction: str = 'mean', - reduce_group: Any = None - ): - """ - Args: - similarity: 'dot' or 'cosine' - reduction: 'none', 'sum', 'mean' (all along dim -1) - zero_diagonal: if True, the diagonals are set to zero - reduce_group: the process group to reduce metric results from DDP - - """ - super().__init__(name='embedding_similarity', - reduce_group=reduce_group) - assert similarity in ('dot', 'cosine') - self.similarity = similarity - isinstance(zero_diagonal, bool) - self.zero_diagonal = zero_diagonal - assert reduction in ('none', 'sum', 'mean') - self.reduction = reduction - - rank_zero_warn('Please note that Metric `EmbeddingSimilarity` does not support aggregation.') - - def forward(self, batch: torch.Tensor) -> torch.Tensor: - """ - Actual metric computation - - Args: - batch: tensor containing embeddings with shape (batch_size, dim) - - Return: - A square matrix (batch, batch) with the similarity scores between all elements - If sum or mean are used, then returns (b, 1) with the reduced value for each row - """ - return embedding_similarity(batch, - similarity=self.similarity, - zero_diagonal=self.zero_diagonal, - reduction=self.reduction) - - @staticmethod - def ddp_reduce(self, data: Any, output: Any): - """ reduction for this metric does not make sense """ - return output - - @property - def aggregated(self): - raise ValueError('Metric `EmbeddingSimilarity` does not support aggregation.') diff --git a/pytorch_lightning/metrics/sklearns.py b/pytorch_lightning/metrics/sklearns.py deleted file mode 100644 index d19bd05e634d1..0000000000000 --- a/pytorch_lightning/metrics/sklearns.py +++ /dev/null @@ -1,1590 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Any, List, Optional, Sequence, Tuple, Union - -import numpy as np -import torch - -from pytorch_lightning import _logger as lightning_logger -from pytorch_lightning.metrics.metric import NumpyMetric - -if torch.distributed.is_available(): - from torch.distributed import group -else: - class group: - WORLD = None - - -class SklearnMetric(NumpyMetric): - """ - Bridge between PyTorch Lightning and scikit-learn metrics - - Warning: - Every metric call will cause a GPU synchronization, which may slow down your code - - Note: - The order of targets and predictions may be different from the order typically used in PyTorch - """ - - def __init__( - self, - metric_name: str, - reduce_group: Any = group.WORLD, - **kwargs, - ): - """ - Args: - metric_name: the metric name to import and compute from scikit-learn.metrics - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - **kwargs: additonal keyword arguments (will be forwarded to metric call) - """ - super().__init__( - name=metric_name, - reduce_group=reduce_group, - ) - - self.metric_kwargs = kwargs - lightning_logger.debug( - f"Metric {self.__class__.__name__} is using Sklearn as backend, meaning that" - " every metric call will cause a GPU synchronization, which may slow down your code" - ) - - @property - def metric_fn(self): - import sklearn.metrics - - return getattr(sklearn.metrics, self.name) - - def forward(self, *args, **kwargs) -> Union[np.ndarray, int, float]: - """ - Carries the actual metric computation - - Args: - *args: Positional arguments forwarded to metric call (should be already converted to numpy) - **kwargs: keyword arguments forwarded to metric call (should be already converted to numpy) - - Return: - the metric value (will be converted to tensor by baseclass) - - """ - return self.metric_fn(*args, **kwargs, **self.metric_kwargs) - - -class Accuracy(SklearnMetric): - """ - Calculates the Accuracy Score - - Warning: - Every metric call will cause a GPU synchronization, which may slow down your code - - Example: - - >>> y_pred = torch.tensor([0, 1, 2, 3]) - >>> y_true = torch.tensor([0, 1, 2, 2]) - >>> metric = Accuracy() - >>> metric(y_pred, y_true) - tensor(0.7500) - - """ - - def __init__( - self, - normalize: bool = True, - reduce_group: Any = group.WORLD, - ): - """ - Args: - normalize: If ``False``, return the number of correctly classified samples. - Otherwise, return the fraction of correctly classified samples. - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__(metric_name="accuracy_score", reduce_group=reduce_group, normalize=normalize) - - def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ) -> float: - """ - Computes the accuracy - - Args: - y_pred: the array containing the predictions (already in categorical form) - y_true: the array containing the targets (in categorical form) - sample_weight: Sample weights. 
- - Return: - Accuracy Score - """ - return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight) - - -class AUC(SklearnMetric): - """ - Calculates the Area Under the Curve using the trapoezoidal rule - - Warning: - Every metric call will cause a GPU synchronization, which may slow down your code - - Example: - - >>> y_pred = torch.tensor([0, 1, 2, 3]) - >>> y_true = torch.tensor([0, 1, 2, 2]) - >>> metric = AUC() - >>> metric(y_pred, y_true) - tensor(4.) - """ - - def __init__( - self, - reduce_group: Any = group.WORLD, - ): - """ - Args: - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - - super().__init__(metric_name="auc", reduce_group=reduce_group) - - def forward(self, x: np.ndarray, y: np.ndarray) -> float: - """ - Computes the AUC - - Args: - x: x coordinates. - y: y coordinates. - - Return: - AUC calculated with trapezoidal rule - - """ - return super().forward(x=x, y=y) - - -class AveragePrecision(SklearnMetric): - """ - Calculates the average precision (AP) score. - - """ - - def __init__( - self, - average: Optional[str] = "macro", - reduce_group: Any = group.WORLD, - ): - """ - Args: - average: If None, the scores for each class are returned. Otherwise, this determines the type of - averaging performed on the data: - - * If 'micro': Calculate metrics globally by considering each element of the label indicator - matrix as a label. - * If 'macro': Calculate metrics for each label, and find their unweighted mean. - This does not take label imbalance into account. - * If 'weighted': Calculate metrics for each label, and find their average, weighted by - support (the number of true instances for each label). - * If 'samples': Calculate metrics for each instance, and find their average. - - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__("average_precision_score", reduce_group=reduce_group, average=average) - - def forward( - self, - y_score: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ) -> float: - """ - Args: - y_score: Target scores, can either be probability estimates of the positive class, - confidence values, or binary decisions. - y_true: True binary labels in binary label indicators. - sample_weight: Sample weights. - - Return: - average precision score - """ - return super().forward(y_score=y_score, y_true=y_true, sample_weight=sample_weight) - - -class BalancedAccuracy(SklearnMetric): - """Compute the balanced accuracy score - - Warning: - Every metric call will cause a GPU synchronization, which may slow down your code - - Example: - - >>> y_pred = torch.tensor([0, 0, 0, 1]) - >>> y_true = torch.tensor([0, 0, 1, 1]) - >>> metric = BalancedAccuracy() - >>> metric(y_pred, y_true) - tensor(0.7500) - - """ - - def __init__( - self, - adjusted: bool = False, - reduce_group: Any = group.WORLD, - ): - """ - Args: - adjusted: If ``True``, the result sis adjusted for chance, such that random performance - corresponds to 0 and perfect performance corresponds to 1 - reduce_group: the process group for DDP reduces (only needed for DDP training). 
- Defaults to all processes (world) - """ - super().__init__("balanced_accuracy_score", reduce_group=reduce_group, adjusted=adjusted) - - def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ) -> float: - """ - Args: - y_pred: the array containing the predictions (already in categorical form) - y_true: the array containing the targets (in categorical form) - sample_weight: Sample weights. - - Return: - balanced accuracy score - - """ - return super().forward(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) - - -class CohenKappaScore(SklearnMetric): - """ - Calculates Cohens kappa: a statitic that measures inter-annotator agreement - - Example: - - >>> y_pred = torch.tensor([1, 2, 0, 2]) - >>> y_true = torch.tensor([2, 2, 2, 1]) - >>> metric = CohenKappaScore() - >>> metric(y_pred, y_true) - tensor(-0.3333) - - """ - - def __init__( - self, - labels: Optional[Sequence] = None, - weights: Optional[str] = None, - reduce_group: Any = group.WORLD, - ): - """ - Args: - labels: List of labels to index the matrix. This may be used to reorder - or select a subset of labels. - If none is given, those that appear at least once - in ``y1`` or ``y2`` are used in sorted order. - weights: string indicating weightning type used in scoring. None - means no weighting, string ``linear`` means linear weighted - and ``quadratic`` means quadratic weighted - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__("cohen_kappa_score", reduce_group=reduce_group, labels=labels, weights=weights) - - def forward( - self, - y1: np.ndarray, - y2: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ) -> float: - """ - Args: - y_1: Labels assigned by first annotator - y_2: Labels assigned by second annotator - sample_weight: Sample weights. - - Return: - Cohens kappa score - """ - return super().forward(y1=y1, y2=y2, sample_weight=sample_weight) - - -class ConfusionMatrix(SklearnMetric): - """ - Compute confusion matrix to evaluate the accuracy of a classification - By definition a confusion matrix :math:`C` is such that :math:`C_{i, j}` - is equal to the number of observations known to be in group :math:`i` but - predicted to be in group :math:`j`. - - Example: - - >>> y_pred = torch.tensor([0, 1, 2, 1]) - >>> y_true = torch.tensor([0, 1, 2, 2]) - >>> metric = ConfusionMatrix() - >>> metric(y_pred, y_true) - tensor([[1., 0., 0.], - [0., 1., 0.], - [0., 1., 1.]]) - - """ - - def __init__( - self, - labels: Optional[Sequence] = None, - reduce_group: Any = group.WORLD, - ): - """ - Args: - labels: List of labels to index the matrix. This may be used to reorder - or select a subset of labels. - If none is given, those that appear at least once - in ``y_true`` or ``y_pred`` are used in sorted order. - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__("confusion_matrix", reduce_group=reduce_group, labels=labels) - - def forward(self, y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray: - """ - Args: - y_pred: Estimated targets as returned by a classifier. - y_true: Ground truth (correct) target values. 
- - Return: - Confusion matrix (array of shape [num_classes, num_classes]) - - """ - return super().forward(y_pred=y_pred, y_true=y_true) - - def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor: - return torch.stack(tensors).mean(0) - - -class DCG(SklearnMetric): - """Compute discounted cumulative gain - - Warning: - Every metric call will cause a GPU synchronization, which may slow down your code - - Example: - - >>> y_score = torch.tensor([[.1, .2, .3, 4, 70]]) - >>> y_true = torch.tensor([[10, 0, 0, 1, 5]]) - >>> metric = DCG() - >>> metric(y_score, y_true) - tensor(9.4995) - """ - - def __init__( - self, - k: Optional[int] = None, - log_base: float = 2, - ignore_ties: bool = False, - reduce_group: Any = group.WORLD, - ): - """ - Args: - k: only consider the hightest k score in the ranking - log_base: base of the logarithm used for the discount - ignore_ties: If ``True``, assume there are no ties in y_score for efficiency gains - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__("dcg_score", reduce_group=reduce_group, k=k, log_base=log_base, ignore_ties=ignore_ties) - - def forward( - self, - y_score: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ) -> float: - """ - Args: - y_score: target scores, either probability estimates, confidence values - or or non-thresholded measure of decisions - y_true: Ground truth (correct) target values. - sample_weight: Sample weights. - - Return: - DCG score - - """ - return super().forward(y_true=y_true, y_score=y_score, sample_weight=sample_weight) - - -class F1(SklearnMetric): - r""" - Compute the F1 score, also known as balanced F-score or F-measure - The F1 score can be interpreted as a weighted average of the precision and - recall, where an F1 score reaches its best value at 1 and worst score at 0. - The relative contribution of precision and recall to the F1 score are - equal. The formula for the F1 score is: - - .. math:: - - F_1 = 2 \cdot \frac{precision \cdot recall}{precision + recall} - - In the multi-class and multi-label case, this is the weighted average of - the F1 score of each class. - - Example: - - >>> y_pred = torch.tensor([0, 1, 2, 3]) - >>> y_true = torch.tensor([0, 1, 2, 2]) - >>> metric = F1() - >>> metric(y_pred, y_true) - tensor(0.6667) - - References - - [1] `Wikipedia entry for the F1-score - `_ - """ - - def __init__( - self, - labels: Optional[Sequence] = None, - pos_label: Union[str, int] = 1, - average: Optional[str] = "macro", - reduce_group: Any = group.WORLD, - ): - """ - Args: - labels: Integer array of labels. - pos_label: The class to report if ``average='binary'``. - average: This parameter is required for multiclass/multilabel targets. - If ``None``, the scores for each class are returned. Otherwise, this - determines the type of averaging performed on the data: - - * ``'binary'``: - Only report results for the class specified by ``pos_label``. - This is applicable only if targets (``y_{true,pred}``) are binary. - * ``'micro'``: - Calculate metrics globally by counting the total true positives, - false negatives and false positives. - * ``'macro'``: - Calculate metrics for each label, and find their unweighted - mean. This does not take label imbalance into account. - * ``'weighted'``: - Calculate metrics for each label, and find their average, weighted - by support (the number of true instances for each label). 
This - alters 'macro' to account for label imbalance; it can result in an - F-score that is not between precision and recall. - * ``'samples'``: - Calculate metrics for each instance, and find their average (only - meaningful for multilabel classification where this differs from - :func:`accuracy_score`). - - Note that if ``pos_label`` is given in binary classification with - `average != 'binary'`, only that positive class is reported. This - behavior is deprecated and will change in version 0.18. - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__("f1_score", reduce_group=reduce_group, labels=labels, pos_label=pos_label, average=average) - - def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ) -> Union[np.ndarray, float]: - """ - Args: - y_pred : Estimated targets as returned by a classifier. - y_true: Ground truth (correct) target values. - sample_weight: Sample weights. - - Return: - F1 score of the positive class in binary classification or weighted - average of the F1 scores of each class for the multiclass task. - - """ - return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight) - - -class FBeta(SklearnMetric): - """ - Compute the F-beta score. The `beta` parameter determines the weight of precision in the combined - score. ``beta < 1`` lends more weight to precision, while ``beta > 1`` - favors recall (``beta -> 0`` considers only precision, ``beta -> inf`` - only recall). - - Example: - - >>> y_pred = torch.tensor([0, 1, 2, 3]) - >>> y_true = torch.tensor([0, 1, 2, 2]) - >>> metric = FBeta(beta=0.25) - >>> metric(y_pred, y_true) - tensor(0.7361) - - References: - - [1] R. Baeza-Yates and B. Ribeiro-Neto (2011). - Modern Information Retrieval. Addison Wesley, pp. 327-328. - - [2] `Wikipedia entry for the F1-score - `_ - """ - - def __init__( - self, - beta: float, - labels: Optional[Sequence] = None, - pos_label: Union[str, int] = 1, - average: Optional[str] = "macro", - reduce_group: Any = group.WORLD, - ): - """ - Args: - beta: Weight of precision in harmonic mean. - labels: Integer array of labels. - pos_label: The class to report if ``average='binary'``. - average: This parameter is required for multiclass/multilabel targets. - If ``None``, the scores for each class are returned. Otherwise, this - determines the type of averaging performed on the data: - - * ``'binary'``: - Only report results for the class specified by ``pos_label``. - This is applicable only if targets (``y_{true,pred}``) are binary. - * ``'micro'``: - Calculate metrics globally by counting the total true positives, - false negatives and false positives. - * ``'macro'``: - Calculate metrics for each label, and find their unweighted - mean. This does not take label imbalance into account. - * ``'weighted'``: - Calculate metrics for each label, and find their average, weighted - by support (the number of true instances for each label). This - alters 'macro' to account for label imbalance; it can result in an - F-score that is not between precision and recall. - * ``'samples'``: - Calculate metrics for each instance, and find their average (only - meaningful for multilabel classification where this differs from - :func:`accuracy_score`). - - Note that if ``pos_label`` is given in binary classification with - `average != 'binary'`, only that positive class is reported. This - behavior is deprecated and will change in version 0.18. 
- reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__( - "fbeta_score", reduce_group=reduce_group, beta=beta, labels=labels, pos_label=pos_label, average=average - ) - - def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ) -> Union[np.ndarray, float]: - """ - Args: - y_pred : Estimated targets as returned by a classifier. - y_true: Ground truth (correct) target values. - sample_weight: Sample weights. - - Return: - FBeta score of the positive class in binary classification or weighted - average of the FBeta scores of each class for the multiclass task. - - """ - return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight) - - -class Hamming(SklearnMetric): - """ - Computes the average hamming loss - - Example: - - >>> y_pred = torch.tensor([0, 1, 2, 3]) - >>> y_true = torch.tensor([1, 1, 2, 3]) - >>> metric = Hamming() - >>> metric(y_pred, y_true) - tensor(0.2500) - - """ - - def __init__( - self, - reduce_group: Any = group.WORLD, - ): - """ - Args: - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - - """ - super().__init__("hamming_loss", reduce_group=reduce_group) - - def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ) -> Union[np.ndarray, float]: - """ - Args: - y_pred : Estimated targets as returned by a classifier. - y_true: Ground truth (correct) target values. - sample_weight: Sample weights. - - Return: - Average hamming loss - - """ - return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight) - - -class Hinge(SklearnMetric): - """ - Computes the average hinge loss - - Example: - - >>> pred_decision = torch.tensor([-2.17, -0.97, -0.19, -0.43]) - >>> y_true = torch.tensor([1, 1, 0, 0]) - >>> metric = Hinge() - >>> metric(pred_decision, y_true) - tensor(1.6300) - - """ - - def __init__( - self, - labels: Optional[Sequence] = None, - reduce_group: Any = group.WORLD, - ): - """ - Args: - labels: Integer array of labels. - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__("hinge_loss", reduce_group=reduce_group, labels=labels) - - def forward( - self, - pred_decision: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ) -> float: - """ - Args: - pred_decision : Predicted decisions - y_true: Ground truth (correct) target values. - sample_weight: Sample weights. - - Return: - Average hinge loss - - """ - return super().forward(pred_decision=pred_decision, y_true=y_true, sample_weight=sample_weight) - - -class Jaccard(SklearnMetric): - """ - Calculates jaccard similarity coefficient score - - Example: - - >>> y_pred = torch.tensor([1, 1, 1]) - >>> y_true = torch.tensor([0, 1, 1]) - >>> metric = Jaccard() - >>> metric(y_pred, y_true) - tensor(0.3333) - - """ - - def __init__( - self, - labels: Optional[Sequence] = None, - pos_label: Union[str, int] = 1, - average: Optional[str] = "macro", - reduce_group: Any = group.WORLD, - ): - """ - Args: - labels: Integer array of labels. - pos_label: The class to report if ``average='binary'``. - average: This parameter is required for multiclass/multilabel targets. - If ``None``, the scores for each class are returned. 
Otherwise, this - determines the type of averaging performed on the data: - - * ``'binary'``: - Only report results for the class specified by ``pos_label``. - This is applicable only if targets (``y_{true,pred}``) are binary. - * ``'micro'``: - Calculate metrics globally by counting the total true positives, - false negatives and false positives. - * ``'macro'``: - Calculate metrics for each label, and find their unweighted - mean. This does not take label imbalance into account. - * ``'weighted'``: - Calculate metrics for each label, and find their average, weighted - by support (the number of true instances for each label). This - alters 'macro' to account for label imbalance; it can result in an - F-score that is not between precision and recall. - * ``'samples'``: - Calculate metrics for each instance, and find their average (only - meaningful for multilabel classification where this differs from - :func:`accuracy_score`). - - Note that if ``pos_label`` is given in binary classification with - `average != 'binary'`, only that positive class is reported. This - behavior is deprecated and will change in version 0.18. - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__( - "jaccard_score", reduce_group=reduce_group, labels=labels, pos_label=pos_label, average=average - ) - - def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ) -> Union[np.ndarray, float]: - """ - Args: - y_pred : Estimated targets as returned by a classifier. - y_true: Ground truth (correct) target values. - sample_weight: Sample weights. - - Return: - Jaccard similarity score - - """ - return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight) - - -class Precision(SklearnMetric): - """ - Compute the precision - The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of - true positives and ``fp`` the number of false positives. The precision is - intuitively the ability of the classifier not to label as positive a sample - that is negative. - The best value is 1 and the worst value is 0. - - Example: - - >>> y_pred = torch.tensor([0, 1, 2, 3]) - >>> y_true = torch.tensor([0, 1, 2, 2]) - >>> metric = Precision() - >>> metric(y_pred, y_true) - tensor(0.7500) - - """ - - def __init__( - self, - labels: Optional[Sequence] = None, - pos_label: Union[str, int] = 1, - average: Optional[str] = "macro", - reduce_group: Any = group.WORLD, - ): - """ - Args: - labels: Integer array of labels. - pos_label: The class to report if ``average='binary'``. - average: This parameter is required for multiclass/multilabel targets. - If ``None``, the scores for each class are returned. Otherwise, this - determines the type of averaging performed on the data: - - * ``'binary'``: - Only report results for the class specified by ``pos_label``. - This is applicable only if targets (``y_{true,pred}``) are binary. - * ``'micro'``: - Calculate metrics globally by counting the total true positives, - false negatives and false positives. - * ``'macro'``: - Calculate metrics for each label, and find their unweighted - mean. This does not take label imbalance into account. - * ``'weighted'``: - Calculate metrics for each label, and find their average, weighted - by support (the number of true instances for each label). This - alters 'macro' to account for label imbalance; it can result in an - F-score that is not between precision and recall. 
- * ``'samples'``: - Calculate metrics for each instance, and find their average (only - meaningful for multilabel classification where this differs from - :func:`accuracy_score`). - - Note that if ``pos_label`` is given in binary classification with - `average != 'binary'`, only that positive class is reported. This - behavior is deprecated and will change in version 0.18. - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__( - "precision_score", reduce_group=reduce_group, labels=labels, pos_label=pos_label, average=average - ) - - def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ) -> Union[np.ndarray, float]: - """ - Args: - y_pred : Estimated targets as returned by a classifier. - y_true: Ground truth (correct) target values. - sample_weight: Sample weights. - - Return: - Precision of the positive class in binary classification or weighted - average of the precision of each class for the multiclass task. - - """ - return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight) - - -class Recall(SklearnMetric): - """ - Compute the recall - The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of - true positives and ``fn`` the number of false negatives. The recall is - intuitively the ability of the classifier to find all the positive samples. - The best value is 1 and the worst value is 0. - - Example: - - >>> y_pred = torch.tensor([0, 1, 2, 3]) - >>> y_true = torch.tensor([0, 1, 2, 2]) - >>> metric = Recall() - >>> metric(y_pred, y_true) - tensor(0.6250) - - """ - - def __init__( - self, - labels: Optional[Sequence] = None, - pos_label: Union[str, int] = 1, - average: Optional[str] = "macro", - reduce_group: Any = group.WORLD, - ): - """ - Args: - labels: Integer array of labels. - pos_label: The class to report if ``average='binary'``. - average: This parameter is required for multiclass/multilabel targets. - If ``None``, the scores for each class are returned. Otherwise, this - determines the type of averaging performed on the data: - - * ``'binary'``: - Only report results for the class specified by ``pos_label``. - This is applicable only if targets (``y_{true,pred}``) are binary. - * ``'micro'``: - Calculate metrics globally by counting the total true positives, - false negatives and false positives. - * ``'macro'``: - Calculate metrics for each label, and find their unweighted - mean. This does not take label imbalance into account. - * ``'weighted'``: - Calculate metrics for each label, and find their average, weighted - by support (the number of true instances for each label). This - alters 'macro' to account for label imbalance; it can result in an - F-score that is not between precision and recall. - * ``'samples'``: - Calculate metrics for each instance, and find their average (only - meaningful for multilabel classification where this differs from - :func:`accuracy_score`). - - Note that if ``pos_label`` is given in binary classification with - `average != 'binary'`, only that positive class is reported. This - behavior is deprecated and will change in version 0.18. - reduce_group: the process group for DDP reduces (only needed for DDP training). 
- Defaults to all processes (world) - """ - super().__init__("recall_score", reduce_group=reduce_group, labels=labels, pos_label=pos_label, average=average) - - def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ) -> Union[np.ndarray, float]: - """ - Args: - y_pred : Estimated targets as returned by a classifier. - y_true: Ground truth (correct) target values. - sample_weight: Sample weights. - - Return: - Recall of the positive class in binary classification or weighted - average of the recall of each class for the multiclass task. - - """ - return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight) - - -class PrecisionRecallCurve(SklearnMetric): - """ - Compute precision-recall pairs for different probability thresholds - - Note: - This implementation is restricted to the binary classification task. - - The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of - true positives and ``fp`` the number of false positives. The precision is - intuitively the ability of the classifier not to label as positive a sample - that is negative. - The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of - true positives and ``fn`` the number of false negatives. The recall is - intuitively the ability of the classifier to find all the positive samples. - The last precision and recall values are 1. and 0. respectively and do not - have a corresponding threshold. This ensures that the graph starts on the - x axis. - """ - - def __init__( - self, - pos_label: Union[str, int] = 1, - reduce_group: Any = group.WORLD, - ): - """ - Args: - pos_label: The class to report if ``average='binary'``. - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__("precision_recall_curve", reduce_group=reduce_group, pos_label=pos_label) - - def forward( - self, - probas_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ) -> Union[np.ndarray, float]: - """ - Args: - probas_pred : Estimated probabilities or decision function. - y_true: Ground truth (correct) target values. - sample_weight: Sample weights. - - Returns: - precision: - Precision values such that element i is the precision of - predictions with score >= thresholds[i] and the last element is 1. - recall: - Decreasing recall values such that element i is the recall of - predictions with score >= thresholds[i] and the last element is 0. - thresholds: - Increasing thresholds on the decision function used to compute - precision and recall. - - """ - # only return x and y here, since for now we cannot auto-convert elements of multiple length. - # Will be fixed in native implementation - return np.array(super().forward(probas_pred=probas_pred, y_true=y_true, sample_weight=sample_weight)[:2]) - - def aggregate(self, *tensors: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Aggregates results by stacking them instead of concatenating before averaging. - - Returns: - the aggregated results - """ - return tuple([torch.stack(tmp).mean(0) for tmp in zip(*tensors)]) - - -class ROC(SklearnMetric): - """ - Compute Receiver operating characteristic (ROC) - - Note: - this implementation is restricted to the binary classification task. 
- - Warning: - Every metric call will cause a GPU synchronization, which may slow down your code - - Example: - - >>> y_pred = torch.tensor([0, 1, 2, 3]) - >>> y_true = torch.tensor([0, 1, 2, 2]) - >>> metric = ROC() - >>> fps, tps = metric(y_pred, y_true) - >>> fps - tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]) - >>> tps - tensor([0., 0., 0., 1., 1.]) - - References: - - [1] `Wikipedia entry for the Receiver operating characteristic - `_ - - """ - - def __init__( - self, - pos_label: Union[str, int] = 1, - reduce_group: Any = group.WORLD, - ): - """ - Args: - pos_labels: The class to report if ``average='binary'``. - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__("roc_curve", reduce_group=reduce_group, pos_label=pos_label) - - def forward( - self, - y_score: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ) -> Union[np.ndarray, float]: - """ - Args: - y_score : Target scores, can either be probability estimates of the positive - class or confidence values. - y_true: Ground truth (correct) target values. - sample_weight: Sample weights. - - Returns: - fpr: - Increasing false positive rates such that element i is the false - positive rate of predictions with score >= thresholds[i]. - tpr: - Increasing true positive rates such that element i is the true - positive rate of predictions with score >= thresholds[i]. - thresholds: - Decreasing thresholds on the decision function used to compute - fpr and tpr. `thresholds[0]` represents no instances being predicted - and is arbitrarily set to `max(y_score) + 1`. - - """ - return np.array(super().forward(y_score=y_score, y_true=y_true, sample_weight=sample_weight)[:2]) - - def aggregate(self, *tensors: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Aggregates results by stacking them instead of concatenating before averaging. - - Returns: - the aggregated results - """ - - return tuple([torch.stack(tmp).mean(0) for tmp in zip(*tensors)]) - - -class AUROC(SklearnMetric): - """ - Compute Area Under the Curve (AUC) from prediction scores - - Note: - this implementation is restricted to the binary classification task - or multilabel classification task in label indicator format. - - Warning: - Every metric call will cause a GPU synchronization, which may slow down your code - - """ - - def __init__( - self, - average: Optional[str] = "macro", - reduce_group: Any = group.WORLD, - ): - """ - Args: - average: If None, the scores for each class are returned. Otherwise, this determines the type of - averaging performed on the data: - - * If 'micro': Calculate metrics globally by considering each element of the label indicator - matrix as a label. - * If 'macro': Calculate metrics for each label, and find their unweighted mean. - This does not take label imbalance into account. - * If 'weighted': Calculate metrics for each label, and find their average, weighted by - support (the number of true instances for each label). - * If 'samples': Calculate metrics for each instance, and find their average. - - reduce_group: the process group for DDP reduces (only needed for DDP training). 
- Defaults to all processes (world) - """ - super().__init__("roc_auc_score", reduce_group=reduce_group, average=average) - - def forward( - self, - y_score: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ) -> float: - """ - Args: - y_score: Target scores, can either be probability estimates of the positive class, - confidence values, or binary decisions. - y_true: True binary labels in binary label indicators. - sample_weight: Sample weights. - - Return: - Area Under Receiver Operating Characteristic Curve - """ - return super().forward(y_score=y_score, y_true=y_true, sample_weight=sample_weight) - - -class ExplainedVariance(SklearnMetric): - """ - Calculates explained variance score - - Warning: - Every metric call will cause a GPU synchronization, which may slow down your code - - Example: - - >>> y_pred = torch.tensor([2.5, 0.0, 2, 8]) - >>> y_true = torch.tensor([3, -0.5, 2, 7]) - >>> metric = ExplainedVariance() - >>> metric(y_pred, y_true) - tensor(0.9572) - """ - - def __init__( - self, - multioutput: Optional[Union[str, List[float]]] = "variance_weighted", - reduce_group: Any = group.WORLD, - ): - """ - Args: - multioutput: either one of the strings [‘raw_values’, ‘uniform_average’, 'variance_weighted'] - or an array with shape (n_outputs,) that defines how multiple - output values should be aggregated. - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__("explained_variance_score", reduce_group=reduce_group, multioutput=multioutput) - - def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ): - """ - Args: - y_pred: Estimated target values - y_true: Ground truth (correct) target values. - sample_weight: Sample weights. - - Return: - Explained variance score - - """ - return super().forward(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) - - -class MeanAbsoluteError(SklearnMetric): - """ - Compute absolute error regression loss - - Warning: - Every metric call will cause a GPU synchronization, which may slow down your code - - Example: - - >>> y_pred = torch.tensor([2.5, 0.0, 2, 8]) - >>> y_true = torch.tensor([3, -0.5, 2, 7]) - >>> metric = MeanAbsoluteError() - >>> metric(y_pred, y_true) - tensor(0.5000) - - """ - - def __init__( - self, - multioutput: Optional[Union[str, List[float]]] = "uniform_average", - reduce_group: Any = group.WORLD, - ): - """ - Args: - multioutput: either one of the strings [‘raw_values’, ‘uniform_average’] - or an array with shape (n_outputs,) that defines how multiple - output values should be aggregated. - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__("mean_absolute_error", reduce_group=reduce_group, multioutput=multioutput) - - def forward(self, y_pred: np.ndarray, y_true: np.ndarray, sample_weight: Optional[np.ndarray] = None): - """ - Args: - y_pred: Estimated target values - y_true: Ground truth (correct) target values. - sample_weight: Sample weights. 
- - Return: - Mean absolute error - - """ - return super().forward(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) - - -class MeanSquaredError(SklearnMetric): - """ - Compute mean squared error loss - - Warning: - Every metric call will cause a GPU synchronization, which may slow down your code - - Example: - - >>> y_pred = torch.tensor([2.5, 0.0, 2, 8]) - >>> y_true = torch.tensor([3, -0.5, 2, 7]) - >>> metric = MeanSquaredError() - >>> metric(y_pred, y_true) - tensor(0.3750) - >>> metric = MeanSquaredError(squared=True) - >>> metric(y_pred, y_true) - tensor(0.6124) - - """ - - def __init__( - self, - multioutput: Optional[Union[str, List[float]]] = "uniform_average", - squared: bool = False, - reduce_group: Any = group.WORLD, - ): - """ - Args: - multioutput: either one of the strings [‘raw_values’, ‘uniform_average’] - or an array with shape (n_outputs,) that defines how multiple - output values should be aggregated. - squared: if ``True`` returns the mse value else the rmse value - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__("mean_squared_error", reduce_group=reduce_group, multioutput=multioutput) - self.squared = squared - - def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ): - """ - Args: - y_pred: Estimated target values - y_true: Ground truth (correct) target values. - sample_weight: Sample weights. - - Return: - Mean squared error - - """ - mse = super().forward(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) - if self.squared: - mse = np.sqrt(mse) - return mse - - -class MeanSquaredLogError(SklearnMetric): - """ - Calculates the mean squared log error - - Warning: - Every metric call will cause a GPU synchronization, which may slow down your code - - Example: - - >>> y_pred = torch.tensor([2.5, 5, 4, 8]) - >>> y_true = torch.tensor([3, 5, 2.5, 7]) - >>> metric = MeanSquaredLogError() - >>> metric(y_pred, y_true) - tensor(0.0397) - """ - - def __init__( - self, - multioutput: Optional[Union[str, List[float]]] = "uniform_average", - reduce_group: Any = group.WORLD, - ): - """ - Args: - multioutput: either one of the strings [‘raw_values’, ‘uniform_average’] - or an array with shape (n_outputs,) that defines how multiple - output values should be aggregated. - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__("mean_squared_log_error", reduce_group=reduce_group, multioutput=multioutput) - - def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ): - """ - Args: - y_pred: Estimated target values - y_true: Ground truth (correct) target values. - sample_weight: Sample weights. 
- - Return: - Mean squared log error - - """ - return super().forward(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) - - -class MedianAbsoluteError(SklearnMetric): - """ - Calculates the median absolute error - - Warning: - Every metric call will cause a GPU synchronization, which may slow down your code - - Example: - - >>> y_pred = torch.tensor([2.5, 0.0, 2, 8]) - >>> y_true = torch.tensor([3, -0.5, 2, 7]) - >>> metric = MedianAbsoluteError() - >>> metric(y_pred, y_true) - tensor(0.5000) - """ - - def __init__( - self, - multioutput: Optional[Union[str, List[float]]] = "uniform_average", - reduce_group: Any = group.WORLD, - ): - """ - Args: - multioutput: either one of the strings [‘raw_values’, ‘uniform_average’] - or an array with shape (n_outputs,) that defines how multiple - output values should be aggregated. - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__("median_absolute_error", reduce_group=reduce_group, multioutput=multioutput) - - def forward(self, y_pred: np.ndarray, y_true: np.ndarray): - """ - Args: - y_pred: Estimated target values - y_true: Ground truth (correct) target values. - - Return: - Median absolute error - - """ - return super().forward(y_true=y_true, y_pred=y_pred) - - -class R2Score(SklearnMetric): - """ - Calculates the R^2 score also known as coefficient of determination - - Warning: - Every metric call will cause a GPU synchronization, which may slow down your code - - Example: - - >>> y_pred = torch.tensor([2.5, 0.0, 2, 8]) - >>> y_true = torch.tensor([3, -0.5, 2, 7]) - >>> metric = R2Score() - >>> metric(y_pred, y_true) - tensor(0.9486) - """ - - def __init__( - self, - multioutput: Optional[Union[str, List[float]]] = "uniform_average", - reduce_group: Any = group.WORLD, - ): - """ - Args: - multioutput: either one of the strings [‘raw_values’, ‘uniform_average’, 'variance_weighted'] - or an array with shape (n_outputs,) that defines how multiple - output values should be aggregated. - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__("r2_score", reduce_group=reduce_group, multioutput=multioutput) - - def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ): - """ - Args: - y_pred: Estimated target values - y_true: Ground truth (correct) target values. - sample_weight: Sample weights. - - Return: - R^2 score - - """ - return super().forward(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) - - -class MeanPoissonDeviance(SklearnMetric): - """ - Calculates the mean poisson deviance regression loss - - Warning: - Every metric call will cause a GPU synchronization, which may slow down your code - - Example: - - >>> y_pred = torch.tensor([2, 0.5, 1, 4]) - >>> y_true = torch.tensor([0.5, 0.5, 2., 2.]) - >>> metric = MeanPoissonDeviance() - >>> metric(y_pred, y_true) - tensor(0.9034) - """ - - def __init__( - self, - reduce_group: Any = group.WORLD, - ): - """ - Args: - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__("mean_poisson_deviance", reduce_group=reduce_group) - - def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ): - """ - Args: - y_pred: Estimated target values - y_true: Ground truth (correct) target values. - sample_weight: Sample weights. 
- - Return: - Mean possion deviance - - """ - return super().forward(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) - - -class MeanGammaDeviance(SklearnMetric): - """ - Calculates the mean gamma deviance regression loss - - Warning: - Every metric call will cause a GPU synchronization, which may slow down your code - - Example: - - >>> y_pred = torch.tensor([0.5, 0.5, 2., 2.]) - >>> y_true = torch.tensor([2, 0.5, 1, 4]) - >>> metric = MeanGammaDeviance() - >>> metric(y_pred, y_true) - tensor(1.0569) - """ - - def __init__( - self, - reduce_group: Any = group.WORLD, - ): - """ - Args: - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__("mean_gamma_deviance", reduce_group=reduce_group) - - def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ): - """ - Args: - y_pred: Estimated target values - y_true: Ground truth (correct) target values. - sample_weight: Sample weights. - - Return: - Mean gamma deviance - - """ - return super().forward(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) - - -class MeanTweedieDeviance(SklearnMetric): - """ - Calculates the mean tweedie deviance regression loss - - Warning: - Every metric call will cause a GPU synchronization, which may slow down your code - - Example: - - >>> y_pred = torch.tensor([2, 0.5, 1, 4]) - >>> y_true = torch.tensor([0.5, 0.5, 2., 2.]) - >>> metric = MeanTweedieDeviance() - >>> metric(y_pred, y_true) - tensor(1.8125) - """ - - def __init__( - self, - power: float = 0, - reduce_group: Any = group.WORLD, - ): - """ - Args: - power: tweedie power parameter: - - * power < 0: Extreme stable distribution. Requires: y_pred > 0. - * power = 0 : Normal distribution, output corresponds to mean_squared_error. - y_true and y_pred can be any real numbers. - * power = 1 : Poisson distribution. Requires: y_true >= 0 and y_pred > 0. - * 1 < power < 2 : Compound Poisson distribution. Requires: y_true >= 0 and y_pred > 0. - * power = 2 : Gamma distribution. Requires: y_true > 0 and y_pred > 0. - * power = 3 : Inverse Gaussian distribution. Requires: y_true > 0 and y_pred > 0. - * otherwise : Positive stable distribution. Requires: y_true > 0 and y_pred > 0. - - reduce_group: the process group for DDP reduces (only needed for DDP training). - Defaults to all processes (world) - """ - super().__init__("mean_tweedie_deviance", reduce_group=reduce_group, power=power) - - def forward( - self, - y_pred: np.ndarray, - y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None, - ): - """ - Args: - y_pred: Estimated target values - y_true: Ground truth (correct) target values. - sample_weight: Sample weights. 
- - Return: - Mean tweedie deviance - - """ - return super().forward(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight) diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/tests/metrics/functional/__init__.py b/tests/metrics/functional/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py deleted file mode 100644 index 9afdf84fa8770..0000000000000 --- a/tests/metrics/functional/test_classification.py +++ /dev/null @@ -1,485 +0,0 @@ -from functools import partial - -import pytest -import torch -from sklearn.metrics import ( - accuracy_score as sk_accuracy, - jaccard_score as sk_jaccard_score, - precision_score as sk_precision, - recall_score as sk_recall, - f1_score as sk_f1_score, - fbeta_score as sk_fbeta_score, - confusion_matrix as sk_confusion_matrix, - roc_curve as sk_roc_curve, - roc_auc_score as sk_roc_auc_score, - precision_recall_curve as sk_precision_recall_curve -) - -from pytorch_lightning import seed_everything -from pytorch_lightning.metrics.functional.classification import ( - to_onehot, - to_categorical, - get_num_classes, - stat_scores, - stat_scores_multiple_classes, - accuracy, - confusion_matrix, - precision, - recall, - fbeta_score, - f1_score, - _binary_clf_curve, - dice_score, - average_precision, - auroc, - precision_recall_curve, - roc, - auc, - iou, -) - - -@pytest.mark.parametrize(['sklearn_metric', 'torch_metric', 'only_binary'], [ - pytest.param(sk_accuracy, accuracy, False, id='accuracy'), - pytest.param(partial(sk_jaccard_score, average='macro'), iou, False, id='iou'), - pytest.param(partial(sk_precision, average='micro'), precision, False, id='precision'), - pytest.param(partial(sk_recall, average='micro'), recall, False, id='recall'), - pytest.param(partial(sk_f1_score, average='micro'), f1_score, False, id='f1_score'), - pytest.param(partial(sk_fbeta_score, average='micro', beta=2), - partial(fbeta_score, beta=2), False, id='fbeta_score'), - pytest.param(sk_confusion_matrix, confusion_matrix, False, id='confusion_matrix'), - pytest.param(sk_roc_curve, roc, True, id='roc'), - pytest.param(sk_precision_recall_curve, precision_recall_curve, True, id='precision_recall_curve'), - pytest.param(sk_roc_auc_score, auroc, True, id='auroc') -]) -def test_against_sklearn(sklearn_metric, torch_metric, only_binary): - """Compare PL metrics to sklearn version. 
""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' - - # for metrics with only_binary=False, we try out different combinations of number - # of labels in pred and target (also test binary) - # for metrics with only_binary=True, target is always binary and pred will be - # (unnormalized) class probabilities - class_comb = [(5, 2)] if only_binary else [(10, 10), (5, 10), (10, 5), (2, 2)] - for n_cls_pred, n_cls_target in class_comb: - pred = torch.randint(n_cls_pred, (300,), device=device) - target = torch.randint(n_cls_target, (300,), device=device) - - sk_score = sklearn_metric(target.cpu().detach().numpy(), - pred.cpu().detach().numpy()) - pl_score = torch_metric(pred, target) - - # if multi output - if isinstance(sk_score, tuple): - sk_score = [torch.tensor(sk_s.copy(), dtype=torch.float, device=device) for sk_s in sk_score] - for sk_s, pl_s in zip(sk_score, pl_score): - assert torch.allclose(sk_s, pl_s.float()) - else: - sk_score = torch.tensor(sk_score, dtype=torch.float, device=device) - assert torch.allclose(sk_score, pl_score) - - -@pytest.mark.parametrize('class_reduction', ['micro', 'macro', 'weighted']) -@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [ - pytest.param(sk_precision, precision, id='precision'), - pytest.param(sk_recall, recall, id='recall'), - pytest.param(sk_f1_score, f1_score, id='f1_score'), - pytest.param(partial(sk_fbeta_score, beta=2), partial(fbeta_score, beta=2), id='fbeta_score') -]) -def test_different_reduction_against_sklearn(class_reduction, sklearn_metric, torch_metric): - """ Test metrics where the class_reduction parameter have a correponding - value in sklearn """ - device = 'cuda' if torch.cuda.is_available() else 'cpu' - pred = torch.randint(10, (300,), device=device) - target = torch.randint(10, (300,), device=device) - sk_score = sklearn_metric(target.cpu().detach().numpy(), - pred.cpu().detach().numpy(), - average=class_reduction) - sk_score = torch.tensor(sk_score, dtype=torch.float, device=device) - pl_score = torch_metric(pred, target, class_reduction=class_reduction) - assert torch.allclose(sk_score, pl_score) - - -def test_onehot(): - test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) - expected = torch.stack([ - torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]), - torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)]) - ]) - - assert test_tensor.shape == (2, 5) - assert expected.shape == (2, 10, 5) - - onehot_classes = to_onehot(test_tensor, num_classes=10) - onehot_no_classes = to_onehot(test_tensor) - - assert torch.allclose(onehot_classes, onehot_no_classes) - - assert onehot_classes.shape == expected.shape - assert onehot_no_classes.shape == expected.shape - - assert torch.allclose(expected.to(onehot_no_classes), onehot_no_classes) - assert torch.allclose(expected.to(onehot_classes), onehot_classes) - - -def test_to_categorical(): - test_tensor = torch.stack([ - torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]), - torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)]) - ]).to(torch.float) - - expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) - assert expected.shape == (2, 5) - assert test_tensor.shape == (2, 10, 5) - - result = to_categorical(test_tensor) - - assert result.shape == expected.shape - assert torch.allclose(result, expected.to(result.dtype)) - - -@pytest.mark.parametrize(['pred', 'target', 'num_classes', 'expected_num_classes'], [ - pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 
10), - pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), - pytest.param(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), -]) -def test_get_num_classes(pred, target, num_classes, expected_num_classes): - assert get_num_classes(pred, target, num_classes) == expected_num_classes - - -@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', - 'expected_tn', 'expected_fn', 'expected_support'], [ - pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2), - pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2) -]) -def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support): - tp, fp, tn, fn, sup = stat_scores(pred, target, class_index=4) - - assert tp.item() == expected_tp - assert fp.item() == expected_fp - assert tn.item() == expected_tn - assert fn.item() == expected_fn - assert sup.item() == expected_support - - -@pytest.mark.parametrize(['pred', 'target', 'reduction', 'expected_tp', 'expected_fp', - 'expected_tn', 'expected_fn', 'expected_support'], [ - pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 'none', - [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]), - pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'none', - [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]), - pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'sum', - torch.tensor(2), torch.tensor(2), torch.tensor(14), torch.tensor(2), torch.tensor(4)), - pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'elementwise_mean', - torch.tensor(0.4), torch.tensor(0.4), torch.tensor(2.8), torch.tensor(0.4), torch.tensor(0.8)) -]) -def test_stat_scores_multiclass(pred, target, reduction, expected_tp, expected_fp, expected_tn, expected_fn, expected_support): - tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target, reduction=reduction) - - assert torch.allclose(torch.tensor(expected_tp).to(tp), tp) - assert torch.allclose(torch.tensor(expected_fp).to(fp), fp) - assert torch.allclose(torch.tensor(expected_tn).to(tn), tn) - assert torch.allclose(torch.tensor(expected_fn).to(fn), fn) - assert torch.allclose(torch.tensor(expected_support).to(sup), sup) - - -def test_multilabel_accuracy(): - # Dense label indicator matrix format - y1 = torch.tensor([[0, 1, 1], [1, 0, 1]]) - y2 = torch.tensor([[0, 0, 1], [1, 0, 1]]) - - assert torch.allclose(accuracy(y1, y2, class_reduction='none'), torch.tensor([2 / 3, 1.])) - assert torch.allclose(accuracy(y1, y1, class_reduction='none'), torch.tensor([1., 1.])) - assert torch.allclose(accuracy(y2, y2, class_reduction='none'), torch.tensor([1., 1.])) - assert torch.allclose(accuracy(y2, torch.logical_not(y2), class_reduction='none'), torch.tensor([0., 0.])) - assert torch.allclose(accuracy(y1, torch.logical_not(y1), class_reduction='none'), torch.tensor([0., 0.])) - - # num_classes does not match extracted number from input we expect a warning - with pytest.warns(RuntimeWarning, - match=r'You have set .* number of classes which is' - r' different from predicted (.*) and' - r' target (.*) number of classes'): - _ = accuracy(y2, torch.zeros_like(y2), num_classes=3) - - -def test_accuracy(): - pred = torch.tensor([0, 1, 2, 3]) - target = torch.tensor([0, 1, 2, 2]) - acc = accuracy(pred, target) 
- - assert acc.item() == 0.75 - - pred = torch.tensor([0, 1, 2, 2]) - target = torch.tensor([0, 1, 1, 3]) - acc = accuracy(pred, target) - - assert acc.item() == 0.50 - - -def test_confusion_matrix(): - target = (torch.arange(120) % 3).view(-1, 1) - pred = target.clone() - cm = confusion_matrix(pred, target, normalize=True) - - assert torch.allclose(cm, torch.tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])) - - pred = torch.zeros_like(pred) - cm = confusion_matrix(pred, target, normalize=True) - assert torch.allclose(cm, torch.tensor([[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]])) - - target = torch.LongTensor([0, 0, 0, 0, 0]) - pred = target.clone() - cm = confusion_matrix(pred, target, normalize=False, num_classes=3) - assert torch.allclose(cm, torch.tensor([[5., 0., 0.], [0., 0., 0.], [0., 0., 0.]])) - - # Example taken from https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html - target = torch.LongTensor([0] * 13 + [1] * 16 + [2] * 9) - pred = torch.LongTensor([0] * 13 + [1] * 10 + [2] * 15) - cm = confusion_matrix(pred, target, normalize=False, num_classes=3) - assert torch.allclose(cm, torch.tensor([[13., 0., 0.], [0., 10., 6.], [0., 0., 9.]])) - to_compare = cm / torch.tensor([[13.], [16.], [9.]]) - - cm = confusion_matrix(pred, target, normalize=True, num_classes=3) - assert torch.allclose(cm, to_compare) - - -@pytest.mark.parametrize(['pred', 'target', 'expected_prec', 'expected_rec'], [ - pytest.param(torch.tensor([1., 0., 1., 0.]), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]), - pytest.param(to_onehot(torch.tensor([1., 0., 1., 0.])), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]) -]) -def test_precision_recall(pred, target, expected_prec, expected_rec): - prec = precision(pred, target, class_reduction='none') - rec = recall(pred, target, class_reduction='none') - - assert torch.allclose(torch.tensor(expected_prec).to(prec), prec) - assert torch.allclose(torch.tensor(expected_rec).to(rec), rec) - - -@pytest.mark.parametrize(['pred', 'target', 'beta', 'exp_score'], [ - pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 0.5, [0.5, 0.5]), - pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 1, [0.5, 0.5]), - pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 2, [0.5, 0.5]), -]) -def test_fbeta_score(pred, target, beta, exp_score): - score = fbeta_score(torch.tensor(pred), torch.tensor(target), beta, class_reduction='none') - assert torch.allclose(score, torch.tensor(exp_score)) - - score = fbeta_score(to_onehot(torch.tensor(pred)), torch.tensor(target), beta, class_reduction='none') - assert torch.allclose(score, torch.tensor(exp_score)) - - -@pytest.mark.parametrize(['pred', 'target', 'exp_score'], [ - pytest.param([0., 0., 0., 0.], [1., 1., 1., 1.], [0.0, 0.0]), - pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], [0.5, 0.5]), - pytest.param([1., 0., 1., 0.], [1., 0., 1., 0.], [1.0, 1.0]), -]) -def test_f1_score(pred, target, exp_score): - score = f1_score(torch.tensor(pred), torch.tensor(target), class_reduction='none') - assert torch.allclose(score, torch.tensor(exp_score)) - - score = f1_score(to_onehot(torch.tensor(pred)), torch.tensor(target), class_reduction='none') - assert torch.allclose(score, torch.tensor(exp_score)) - - -@pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [ - pytest.param(1, 1., 42), - pytest.param(None, 1., 42), -]) -def test_binary_clf_curve(sample_weight, pos_label, exp_shape): - # TODO: move back the pred and target to test func arguments - # if you fix the array inside the function, 
you'd also have fix the shape, - # because when the array changes, you also have to fix the shape - seed_everything(0) - pred = torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100 - target = torch.tensor([0, 1] * 50, dtype=torch.int) - if sample_weight is not None: - sample_weight = torch.ones_like(pred) * sample_weight - - fps, tps, thresh = _binary_clf_curve(pred, target, sample_weight, pos_label) - - assert isinstance(tps, torch.Tensor) - assert isinstance(fps, torch.Tensor) - assert isinstance(thresh, torch.Tensor) - assert tps.shape == (exp_shape,) - assert fps.shape == (exp_shape,) - assert thresh.shape == (exp_shape,) - - -@pytest.mark.parametrize(['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], [ - pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4]) -]) -def test_pr_curve(pred, target, expected_p, expected_r, expected_t): - p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target)) - assert p.size() == r.size() - assert p.size(0) == t.size(0) + 1 - - assert torch.allclose(p, torch.tensor(expected_p).to(p)) - assert torch.allclose(r, torch.tensor(expected_r).to(r)) - assert torch.allclose(t, torch.tensor(expected_t).to(t)) - - -@pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [ - pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]), - pytest.param([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]), - pytest.param([1, 1], [1, 0], [0, 1], [0, 1]), - pytest.param([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]), - pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]), -]) -def test_roc_curve(pred, target, expected_tpr, expected_fpr): - fpr, tpr, thresh = roc(torch.tensor(pred), torch.tensor(target)) - - assert fpr.shape == tpr.shape - assert fpr.size(0) == thresh.size(0) - assert torch.allclose(fpr, torch.tensor(expected_fpr).to(fpr)) - assert torch.allclose(tpr, torch.tensor(expected_tpr).to(tpr)) - - -@pytest.mark.parametrize(['pred', 'target', 'expected'], [ - pytest.param([0, 1, 0, 1], [0, 1, 0, 1], 1.), - pytest.param([1, 1, 0, 0], [0, 0, 1, 1], 0.), - pytest.param([1, 1, 1, 1], [1, 1, 0, 0], 0.5), - pytest.param([1, 1, 0, 0], [1, 1, 0, 0], 1.), - pytest.param([0.5, 0.5, 0.5, 0.5], [1, 1, 0, 0], 0.5), -]) -def test_auroc(pred, target, expected): - score = auroc(torch.tensor(pred), torch.tensor(target)).item() - assert score == expected - - -@pytest.mark.parametrize(['x', 'y', 'expected'], [ - pytest.param([0, 1], [0, 1], 0.5), - pytest.param([1, 0], [0, 1], 0.5), - pytest.param([1, 0, 0], [0, 1, 1], 0.5), - pytest.param([0, 1], [1, 1], 1), - pytest.param([0, 0.5, 1], [0, 0.5, 1], 0.5), -]) -def test_auc(x, y, expected): - # Test Area Under Curve (AUC) computation - assert auc(torch.tensor(x), torch.tensor(y)) == expected - - -@pytest.mark.parametrize(['scores', 'target', 'expected_score'], [ - # Check the average_precision_score of a constant predictor is - # the TPR - # Generate a dataset with 25% of positives - # And a constant score - # The precision is then the fraction of positive whatever the recall - # is, as there is only one threshold: - pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25), - # With threshold 0.8 : 1 TP and 2 TN and one FN - pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75), -]) -def test_average_precision(scores, target, expected_score): - assert average_precision(scores, target) == expected_score - - -@pytest.mark.parametrize(['pred', 'target', 'expected'], [ - pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 
1.), - pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.), - pytest.param([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3), - pytest.param([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.), -]) -def test_dice_score(pred, target, expected): - score = dice_score(torch.tensor(pred), torch.tensor(target)) - assert score == expected - - -@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [ - pytest.param(False, 'none', None, torch.Tensor([1, 1, 1])), - pytest.param(False, 'elementwise_mean', None, torch.Tensor([1])), - pytest.param(False, 'none', 0, torch.Tensor([1, 1])), - pytest.param(True, 'none', None, torch.Tensor([0.5, 0.5, 0.5])), - pytest.param(True, 'elementwise_mean', None, torch.Tensor([0.5])), - pytest.param(True, 'none', 0, torch.Tensor([0.5, 0.5])), -]) -def test_iou(half_ones, reduction, ignore_index, expected): - pred = (torch.arange(120) % 3).view(-1, 1) - target = (torch.arange(120) % 3).view(-1, 1) - if half_ones: - pred[:60] = 1 - iou_val = iou( - pred=pred, - target=target, - ignore_index=ignore_index, - reduction=reduction, - ) - assert torch.allclose(iou_val, expected, atol=1e-9) - - -@pytest.mark.parametrize('metric', [auroc]) -def test_error_on_multiclass_input(metric): - """ check that these metrics raise an error if they are used for multiclass problems """ - pred = torch.randint(0, 10, (100, )) - target = torch.randint(0, 10, (100, )) - with pytest.raises(ValueError, match="AUROC metric is meant for binary classification"): - _ = metric(pred, target) - - -# TODO: When the jaccard_score of the sklearn version we use accepts `zero_division` (see -# https://github.com/scikit-learn/scikit-learn/pull/17866), consider adding a test here against our -# `absent_score`. -@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'absent_score', 'num_classes', 'expected'], [ - # Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid - # scores the function can return ([0., 1.] range, inclusive). - # 2 classes, class 0 is correct everywhere, class 1 is absent. - pytest.param([0], [0], None, -1., 2, [1., -1.]), - pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]), - # absent_score not applied if only class 0 is present and it's the only class. - pytest.param([0], [0], None, -1., 1, [1.]), - # 2 classes, class 1 is correct everywhere, class 0 is absent. - pytest.param([1], [1], None, -1., 2, [-1., 1.]), - pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]), - # When 0 index ignored, class 0 does not get a score (not even the absent_score). - pytest.param([1], [1], 0, -1., 2, [1.0]), - # 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score. - pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]), - pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]), - # 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score. - pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]), - pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]), - # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class - # 2 is absent. - pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]), - # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class - # 2 is absent. - pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]), - # Sanity checks with absent_score of 1.0. 
- pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]), - pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]), -]) -def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected): - iou_val = iou( - pred=torch.tensor(pred), - target=torch.tensor(target), - ignore_index=ignore_index, - absent_score=absent_score, - num_classes=num_classes, - reduction='none', - ) - assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val)) - - -# example data taken from -# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py -@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], [ - # Ignoring an index outside of [0, num_classes-1] should have no effect. - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]), - # Ignoring a valid index drops only that index from the result. - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]), - # When reducing to mean or sum, the ignored index does not contribute to the output. - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]), - pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]), -]) -def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected): - iou_val = iou( - pred=torch.tensor(pred), - target=torch.tensor(target), - ignore_index=ignore_index, - num_classes=num_classes, - reduction=reduction, - ) - assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val)) diff --git a/tests/metrics/functional/test_nlp.py b/tests/metrics/functional/test_nlp.py deleted file mode 100644 index 2f1647270ee64..0000000000000 --- a/tests/metrics/functional/test_nlp.py +++ /dev/null @@ -1,66 +0,0 @@ -import pytest -import torch -from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu - -from pytorch_lightning.metrics.functional.nlp import bleu_score - -# example taken from -# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.sentence_bleu -HYPOTHESIS1 = tuple( - "It is a guide to action which ensures that the military always obeys the commands of the party".split() -) -REFERENCE1 = tuple("It is a guide to action that ensures that the military will forever heed Party commands".split()) -REFERENCE2 = tuple( - "It is a guiding principle which makes the military forces always being under the command of the Party".split() -) -REFERENCE3 = tuple("It is the practical guide for the army always to heed the directions of the party".split()) - - -# example taken from -# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu -HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split() -HYP2 = "he read the book because he was interested in world history".split() - -REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split() -REF1B = "It is a guiding principle which makes the military force always being under the command of the Party".split() -REF1C = "It is the practical guide for the army always to heed the 
directions of the party".split() -REF2A = "he was interested in world history because he read the book".split() - -LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]] -HYPOTHESES = [HYP1, HYP2] - -# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction -smooth_func = SmoothingFunction().method2 - - -@pytest.mark.parametrize( - ["weights", "n_gram", "smooth_func", "smooth"], - [ - pytest.param([1], 1, None, False), - pytest.param([0.5, 0.5], 2, smooth_func, True), - pytest.param([0.333333, 0.333333, 0.333333], 3, None, False), - pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True), - ], -) -def test_bleu_score(weights, n_gram, smooth_func, smooth): - nltk_output = sentence_bleu( - [REFERENCE1, REFERENCE2, REFERENCE3], HYPOTHESIS1, weights=weights, smoothing_function=smooth_func - ) - pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth) - assert torch.allclose(pl_output, torch.tensor(nltk_output)) - - nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func) - pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth) - assert torch.allclose(pl_output, torch.tensor(nltk_output)) - - -def test_bleu_empty(): - hyp = [[]] - ref = [[[]]] - assert bleu_score(hyp, ref) == torch.tensor(0.0) - - -def test_no_4_gram(): - hyps = [["My", "full", "pytorch-lightning"]] - refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]] - assert bleu_score(hyps, refs) == torch.tensor(0.0) diff --git a/tests/metrics/functional/test_reduction.py b/tests/metrics/functional/test_reduction.py deleted file mode 100644 index aec54c1806715..0000000000000 --- a/tests/metrics/functional/test_reduction.py +++ /dev/null @@ -1,30 +0,0 @@ -import pytest -import torch - -from pytorch_lightning.metrics.functional.reduction import reduce, class_reduce - - -def test_reduce(): - start_tensor = torch.rand(50, 40, 30) - - assert torch.allclose(reduce(start_tensor, 'elementwise_mean'), torch.mean(start_tensor)) - assert torch.allclose(reduce(start_tensor, 'sum'), torch.sum(start_tensor)) - assert torch.allclose(reduce(start_tensor, 'none'), start_tensor) - - with pytest.raises(ValueError): - reduce(start_tensor, 'error_reduction') - - -def test_class_reduce(): - num = torch.randint(1, 10, (100,)).float() - denom = torch.randint(10, 20, (100,)).float() - weights = torch.randint(1, 100, (100,)).float() - - assert torch.allclose(class_reduce(num, denom, weights, 'micro'), - torch.sum(num) / torch.sum(denom)) - assert torch.allclose(class_reduce(num, denom, weights, 'macro'), - torch.mean(num / denom)) - assert torch.allclose(class_reduce(num, denom, weights, 'weighted'), - torch.sum(num / denom * (weights / torch.sum(weights)))) - assert torch.allclose(class_reduce(num, denom, weights, 'none'), - num / denom) diff --git a/tests/metrics/functional/test_regression.py b/tests/metrics/functional/test_regression.py deleted file mode 100644 index 49a79f9424f13..0000000000000 --- a/tests/metrics/functional/test_regression.py +++ /dev/null @@ -1,175 +0,0 @@ -import numpy as np -import pytest -import torch -from functools import partial -from math import sqrt -from skimage.metrics import ( - peak_signal_noise_ratio as ski_psnr, - structural_similarity as ski_ssim -) -from sklearn.metrics import ( - mean_absolute_error as mae_sk, - mean_squared_error as mse_sk, - mean_squared_log_error as msle_sk -) - -from 
pytorch_lightning.metrics.functional import ( - mae, - mse, - psnr, - rmse, - rmsle, - ssim -) - - -@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [ - pytest.param(mae_sk, mae, id='mean_absolute_error'), - pytest.param(mse_sk, mse, id='mean_squared_error'), - pytest.param(partial(mse_sk, squared=False), rmse, id='root_mean_squared_error'), - pytest.param(lambda x, y: sqrt(msle_sk(x, y)), rmsle, id='root_mean_squared_log_error') -]) -def test_against_sklearn(sklearn_metric, torch_metric): - """Compare PL metrics to sklearn version.""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' - - # iterate over different label counts in predictions and target - pred = torch.rand(300, device=device) - target = torch.rand(300, device=device) - - sk_score = sklearn_metric(target.cpu().detach().numpy(), - pred.cpu().detach().numpy()) - sk_score = torch.tensor(sk_score, dtype=torch.float, device=device) - pl_score = torch_metric(pred, target) - assert torch.allclose(sk_score, pl_score) - - -@pytest.mark.parametrize(['pred', 'target', 'expected'], [ - pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.25), - pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 3.0), -]) -def test_mse(pred, target, expected): - score = mse(torch.tensor(pred), torch.tensor(target)) - assert score.item() == expected - - -@pytest.mark.parametrize(['pred', 'target', 'expected'], [ - pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0), - pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.5), - pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 1.7321), -]) -def test_rmse(pred, target, expected): - score = rmse(torch.tensor(pred), torch.tensor(target)) - assert torch.allclose(score, torch.tensor(expected), atol=1e-3) - - -@pytest.mark.parametrize(['pred', 'target', 'expected'], [ - pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0), - pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.25), - pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 1.5), -]) -def test_mae(pred, target, expected): - score = mae(torch.tensor(pred), torch.tensor(target)) - assert score.item() == expected - - -@pytest.mark.parametrize(['pred', 'target', 'expected'], [ - pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0), - pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.1438), - pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 0.5330), -]) -def test_rmsle(pred, target, expected): - score = rmsle(torch.tensor(pred), torch.tensor(target)) - assert torch.allclose(score, torch.tensor(expected), atol=1e-3) - - -@pytest.mark.parametrize(['pred', 'target'], [ - pytest.param([0., 1., 2., 3.], [0., 1., 2., 3.]), - pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.]), - pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.]), -]) -def test_psnr_with_skimage(pred, target): - score = psnr(pred=torch.tensor(pred), - target=torch.tensor(target), data_range=3) - sk_score = ski_psnr(np.array(pred), np.array(target), data_range=3) - assert torch.allclose(score, torch.tensor(sk_score, dtype=torch.float), atol=1e-3) - - -@pytest.mark.parametrize(['pred', 'target'], [ - pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.]), - pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.]), -]) -def test_psnr_base_e_wider_range(pred, target): - score = psnr(pred=torch.tensor(pred), - target=torch.tensor(target), - data_range=4, - base=2.718281828459045) - sk_score = ski_psnr(np.array(pred), np.array(target), data_range=4) * np.log(10) - assert torch.allclose(score, torch.tensor(sk_score, dtype=torch.float32), atol=1e-3) - - -@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [ - pytest.param(ski_psnr, psnr, id='peak_signal_noise_ratio') 
-]) -def test_psnr_against_sklearn(sklearn_metric, torch_metric): - """Compare PL metrics to sklearn version.""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' - - for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]: - pred = torch.randint(n_cls_pred, (500,), device=device, dtype=torch.float) - target = torch.randint(n_cls_target, (500,), device=device, dtype=torch.float) - - sk_score = sklearn_metric(target.cpu().detach().numpy(), - pred.cpu().detach().numpy(), - data_range=n_cls_target) - sk_score = torch.tensor(sk_score, dtype=torch.float, device=device) - pl_score = torch_metric(pred, target, data_range=n_cls_target) - assert torch.allclose(sk_score, pl_score) - - -@pytest.mark.parametrize(['size', 'channel', 'coef', 'multichannel'], [ - pytest.param(16, 1, 0.9, False), - pytest.param(32, 3, 0.8, True), - pytest.param(48, 4, 0.7, True), - pytest.param(64, 5, 0.6, True) -]) -def test_ssim(size, channel, coef, multichannel): - device = "cuda" if torch.cuda.is_available() else "cpu" - pred = torch.rand(size, channel, size, size, device=device) - target = pred * coef - ssim_idx = ssim(pred, target, data_range=1.0) - np_pred = pred.permute(0, 2, 3, 1).cpu().numpy() - if multichannel is False: - np_pred = np_pred[:, :, :, 0] - np_target = np.multiply(np_pred, coef) - sk_ssim_idx = ski_ssim( - np_pred, np_target, win_size=11, multichannel=multichannel, gaussian_weights=True, data_range=1.0 - ) - assert torch.allclose(ssim_idx, torch.tensor(sk_ssim_idx, dtype=torch.float, device=device), atol=1e-4) - - ssim_idx = ssim(pred, pred) - assert torch.allclose(ssim_idx, torch.tensor(1.0, device=device)) - - -@pytest.mark.parametrize(['pred', 'target', 'kernel', 'sigma'], [ - pytest.param([1, 1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # shape - pytest.param([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # len(shape) - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]), # len(kernel), len(sigma) - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]), # len(kernel), len(sigma) - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]), # len(kernel), len(sigma) - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]), # invalid kernel input - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]), # invalid kernel input - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]), # invalid kernel input - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]), # invalid sigma input - pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]), # invalid sigma input -]) -def test_ssim_invalid_inputs(pred, target, kernel, sigma): - pred_t = torch.rand(pred) - target_t = torch.rand(target, dtype=torch.float64) - with pytest.raises(TypeError): - ssim(pred_t, target_t) - - pred = torch.rand(pred) - target = torch.rand(target) - with pytest.raises(ValueError): - ssim(pred, target, kernel, sigma) diff --git a/tests/metrics/functional/test_self_supervised.py b/tests/metrics/functional/test_self_supervised.py deleted file mode 100644 index 1ef3b43f77b62..0000000000000 --- a/tests/metrics/functional/test_self_supervised.py +++ /dev/null @@ -1,35 +0,0 @@ -import pytest -import torch -from sklearn.metrics import pairwise - -from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity - - -@pytest.mark.parametrize('similarity', ['cosine', 'dot']) -@pytest.mark.parametrize('reduction', ['none', 'mean', 'sum']) -def test_against_sklearn(similarity, reduction): - """Compare PL metrics to sklearn 
version.""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' - - batch = torch.randn(5, 10, device=device) # 100 samples in 10 dimensions - - pl_dist = embedding_similarity(batch, similarity=similarity, - reduction=reduction, zero_diagonal=False) - - def sklearn_embedding_distance(batch, similarity, reduction): - - metric_func = {'cosine': pairwise.cosine_similarity, - 'dot': pairwise.linear_kernel}[similarity] - - dist = metric_func(batch, batch) - if reduction == 'mean': - return dist.mean(axis=-1) - if reduction == 'sum': - return dist.sum(axis=-1) - return dist - - sk_dist = sklearn_embedding_distance(batch.cpu().detach().numpy(), - similarity=similarity, reduction=reduction) - sk_dist = torch.tensor(sk_dist, dtype=torch.float, device=device) - - assert torch.allclose(sk_dist, pl_dist) diff --git a/tests/metrics/test_aggregation.py b/tests/metrics/test_aggregation.py deleted file mode 100644 index 73c1e05118554..0000000000000 --- a/tests/metrics/test_aggregation.py +++ /dev/null @@ -1,297 +0,0 @@ -import pytest -import sys -from collections import namedtuple -from functools import partial -import math - -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -import numpy as np - -from tests.base import EvalModelTemplate -from pytorch_lightning import Trainer -import tests.base.develop_utils as tutils -from pytorch_lightning.metrics import ( - Accuracy, - ConfusionMatrix, - PrecisionRecallCurve, - Precision, - Recall, - AveragePrecision, - AUROC, - FBeta, - F1, - ROC, - MulticlassROC, - MulticlassPrecisionRecallCurve, - DiceCoefficient, - IoU, - MAE, - MSE, - RMSE, - RMSLE, - PSNR, - SSIM, -) - -from sklearn.metrics import ( - accuracy_score, - confusion_matrix, - precision_recall_curve, - precision_score, - recall_score, - average_precision_score, - roc_auc_score, - fbeta_score, - f1_score, - roc_curve, - jaccard_score, - mean_squared_error, - mean_absolute_error, - mean_squared_log_error -) - -from skimage.metrics import ( - peak_signal_noise_ratio, - structural_similarity -) - -# example structure -TestCase = namedtuple('example', ['name', 'lightning_metric', 'comparing_metric', 'test_input']) - -# setup some standard testcases -NB_SAMPLES = 200 -multiclass_example = [(torch.randint(10, (NB_SAMPLES,)), torch.randint(10, (NB_SAMPLES,)))] -binary_example = [(torch.randint(2, (NB_SAMPLES,)), torch.randint(2, (NB_SAMPLES,)))] -multiclass_and_binary_example = [*multiclass_example, *binary_example] -binary_example_logits = (torch.randint(2, (NB_SAMPLES,)), torch.randint(5, (NB_SAMPLES,))) -multiclass_example_probs = (torch.randint(10, (NB_SAMPLES,)), torch.randn((NB_SAMPLES, 10)).softmax(-1)) -regression_example = [(torch.rand((NB_SAMPLES,)), torch.rand((NB_SAMPLES,)))] - - -# construct additional test functions -def root_mean_squared_error(x, y): - return math.sqrt(mean_squared_error(x, y)) - - -def root_mean_squared_log_error(x, y): - return math.sqrt(mean_squared_log_error(x, y)) - - -# Define testcases -# TODO: update remaining metrics and uncomment the corresponding test cases -TESTS = [ - TestCase('accuracy', - Accuracy, - accuracy_score, - multiclass_and_binary_example), - TestCase('confusion matrix without normalize', - ConfusionMatrix, - confusion_matrix, - multiclass_and_binary_example), - TestCase('confusion matrix with normalize', - partial(ConfusionMatrix, normalize=True), - partial(confusion_matrix, normalize='true'), - multiclass_and_binary_example), - # TestCase('precision recall curve', - # PrecisionRecallCurve, - # 
precision_recall_curve, - # binary_example), - TestCase('precision', - Precision, - partial(precision_score, average='micro'), - multiclass_and_binary_example), - TestCase('recall', - Recall, - partial(recall_score, average='micro'), - multiclass_and_binary_example), - # TestCase('average_precision', - # AveragePrecision, - # average_precision_score, - # binary_example), - # TestCase('auroc', - # AUROC, - # roc_auc_score, - # binary_example), - TestCase('f beta', - partial(FBeta, beta=2), - partial(fbeta_score, average='micro', beta=2), - multiclass_and_binary_example), - TestCase('f1', - F1, - partial(f1_score, average='micro'), - multiclass_and_binary_example), - # TestCase('roc', - # ROC, - # roc_curve, - # binary_example), - # TestCase('multiclass roc', - # MulticlassROC, - # multiclass_roc, - # binary_example), - # TestCase('multiclass precision recall curve', - # MulticlassPrecisionRecallCurve, - # multiclass_precision_recall_curve, - # binary_example), - # TestCase('dice coefficient', - # DiceCoefficient, - # partial(f1_score, average='micro'), - # multiclass_and_binary_example), - # TestCase('intersection over union', - # IoU, - # partial(jaccard_score, average='macro'), - # binary_example), - TestCase('mean squared error', - MSE, - mean_squared_error, - regression_example), - TestCase('root mean squared error', - RMSE, - root_mean_squared_error, - regression_example), - TestCase('mean absolute error', - MAE, - mean_absolute_error, - regression_example), - TestCase('root mean squared log error', - RMSLE, - root_mean_squared_log_error, - regression_example), - TestCase('peak signal-to-noise ratio', - partial(PSNR, data_range=10), - partial(peak_signal_noise_ratio, data_range=10), - regression_example), - # TestCase('structual similarity index measure', - # SSIM, - # structural_similarity, - # regression_example) -] - - -# Utility test functions -def _idsfn(test): - """ Return id for current example being tested """ - return test.name - - -def _setup_ddp(rank, worldsize): - """ setup ddp enviroment for testing """ - import os - os.environ['MASTER_ADDR'] = 'localhost' - # initialize the process group - dist.init_process_group("gloo", rank=rank, world_size=worldsize) - - -def comparing_fn(lightning_val, comparing_val, rtol=1e-03, atol=1e-08): - """ function for comparing output, both multi and single output""" - # multi output - if isinstance(comparing_val, tuple): - for l_score, c_score in zip(lightning_val, comparing_val): - assert np.allclose(l_score.numpy(), c_score, rtol, atol) - else: # single output - assert np.allclose(lightning_val.numpy(), comparing_val, rtol, atol) - - -# ===== Tests start here ===== -def _test_ddp_single_batch(rank, worldsize, lightning_metric, comparing_metric, test_inputs): - """ ddp testing function, divide test_inputs equally between all processes """ - _setup_ddp(rank, worldsize) - - # Setup metric for ddp - lightning_metric = lightning_metric() - for test_input in test_inputs: - # rank 0 receives sample 0,2,4,... - # rank 1 receives sample 1,3,5,... 
- lightning_val = lightning_metric(*[ti[rank::2] for ti in test_input]) - - comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)]) - - comparing_fn(lightning_val, comparing_val) - - -@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") -@pytest.mark.parametrize("test", TESTS, ids=_idsfn) -def test_ddp(test): - """Make sure that metrics are correctly sync and reduced in DDP mode""" - tutils.reset_seed() - tutils.set_random_master_port() - - worldsize = 2 - mp.spawn(_test_ddp_single_batch, - args=(worldsize, - test.lightning_metric, - test.comparing_metric, - test.test_input), - nprocs=worldsize) - - -@pytest.mark.parametrize("test", TESTS, ids=_idsfn) -def test_multi_batch(test): - """ test that aggregation works for multiple batches """ - lightning_metric = test.lightning_metric() - comparing_metric = test.comparing_metric - - for test_input in test.test_input: - for i in range(2): # for lightning device in 2 artificially batches - # first batch consist of samples 0,2,4,... - # second batch consist of samples 1,3,5,... - _ = lightning_metric(*[ti[i::2] for ti in test_input]) - lightning_val = lightning_metric.aggregated - comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)]) - - comparing_fn(lightning_val, comparing_val) - - -@pytest.mark.parametrize("test", TESTS, ids=_idsfn) -def test_multi_batch_unequal_sizes(test): - """ test that aggregation works for multiple batches with uneven sizes """ - lightning_metric = test.lightning_metric() - comparing_metric = test.comparing_metric - - for test_input in test.test_input: - for i in range(2): # for lightning device in 2 artificially batches - if i == 0: # allocate 3/4 of data to the first batch - _ = lightning_metric(*[ti[:int(3 / 4 * len(ti))] for ti in test_input]) - else: - _ = lightning_metric(*[ti[int(3 / 4 * len(ti)):] for ti in test_input]) - lightning_val = lightning_metric.aggregated - comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)]) - - comparing_fn(lightning_val, comparing_val) - - -def _test_ddp_multi_batch(rank, worldsize, lightning_metric, comparing_metric, test_inputs): - """ ddp testing function, test that metric works with aggregation over multiple - devices and multiple batches """ - _setup_ddp(rank, worldsize) - - # Setup metric for ddp - lightning_metric = lightning_metric() - for test_input in test_inputs: - for i in range(2): # artificially divide samples between batches and processes - # rank 0, batch 0 consist of samples 0,4,8,... - # rank 0, batch 1 consist of samples 1,5,9,... - # rank 1, batch 0 consist of samples 2,6,10,... - # rank 1, batch 1 consist of samples 3,7,11,... 
- _ = lightning_metric(*[ti[i + worldsize * rank::4] for ti in test_input]) - lightning_val = lightning_metric.aggregated - comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)]) - - comparing_fn(lightning_val, comparing_val) - - -@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") -@pytest.mark.parametrize("test", TESTS, ids=_idsfn) -def test_ddp_multi_batch(test): - """ test that aggregation works fine with in DDP mode and multiple batches """ - tutils.reset_seed() - tutils.set_random_master_port() - - worldsize = 2 - mp.spawn(_test_ddp_multi_batch, - args=(worldsize, - test.lightning_metric, - test.comparing_metric, - test.test_input), - nprocs=worldsize) diff --git a/tests/metrics/test_classification.py b/tests/metrics/test_classification.py deleted file mode 100644 index 0914134adc29e..0000000000000 --- a/tests/metrics/test_classification.py +++ /dev/null @@ -1,237 +0,0 @@ -# NOTE: This file only tests if modules with arguments are running fine. -# The actual metric implementation is tested in functional/test_classification.py -# Especially reduction and reducing across processes won't be tested here! - -import pytest -import torch - -from pytorch_lightning.metrics.classification import ( - Accuracy, - ConfusionMatrix, - PrecisionRecallCurve, - Precision, - Recall, - AveragePrecision, - AUROC, - FBeta, - F1, - ROC, - MulticlassROC, - MulticlassPrecisionRecallCurve, - DiceCoefficient, - IoU, -) - - -@pytest.fixture -def random(): - torch.manual_seed(0) - - -@pytest.mark.parametrize('num_classes', [1, None]) -def test_accuracy(num_classes): - acc = Accuracy(num_classes=num_classes) - assert acc.name == 'accuracy' - - result = acc(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]), - target=torch.tensor([[0, 0, 1], [1, 0, 1]])) - assert isinstance(result, torch.Tensor) - - -@pytest.mark.parametrize(['normalize', 'num_classes'], [ - pytest.param(False, None), - pytest.param(True, None), - pytest.param(False, 3) -]) -def test_confusion_matrix(normalize, num_classes): - conf_matrix = ConfusionMatrix(normalize=normalize, num_classes=num_classes) - assert conf_matrix.name == 'confusion_matrix' - - target = (torch.arange(120) % 3).view(-1, 1) - pred = target.clone() - cm = conf_matrix(pred, target) - assert isinstance(cm, torch.Tensor) - - -@pytest.mark.parametrize(['normalize', 'num_classes'], [ - pytest.param(True, 3) -]) -def test_confusion_matrix_norm(normalize, num_classes): - """ test that user is warned if confusion matrix contains nans that are changed to zeros""" - conf_matrix = ConfusionMatrix(normalize=normalize, num_classes=num_classes) - assert conf_matrix.name == 'confusion_matrix' - - with pytest.warns(UserWarning, match='6 nan values found in confusion matrix have been replaced with zeros.'): - target = torch.LongTensor([0] * 5) - pred = target.clone() - cm = conf_matrix(pred, target) - assert isinstance(cm, torch.Tensor) - - -@pytest.mark.parametrize('pos_label', [1, 2.]) -def test_precision_recall(pos_label): - pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1]) - - pr_curve = PrecisionRecallCurve(pos_label=pos_label) - assert pr_curve.name == 'precision_recall_curve' - - pr = pr_curve(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4]) - - assert isinstance(pr, tuple) - assert len(pr) == 3 - for tmp in pr: - assert isinstance(tmp, torch.Tensor) - - -@pytest.mark.parametrize('num_classes', [1, None]) -def test_precision(num_classes): - precision = Precision(num_classes=num_classes) - assert 
precision.name == 'precision' - - pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1]) - prec = precision(pred=pred, target=target) - assert isinstance(prec, torch.Tensor) - - -@pytest.mark.parametrize('num_classes', [1, None]) -def test_recall(num_classes): - recall = Recall(num_classes=num_classes) - assert recall.name == 'recall' - - pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1]) - rec = recall(pred=pred, target=target) - assert isinstance(rec, torch.Tensor) - - -@pytest.mark.parametrize('pos_label', [1, 2]) -def test_average_precision(pos_label): - avg_prec = AveragePrecision(pos_label=pos_label) - assert avg_prec.name == 'AP' - - pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 0, 1]) - ap = avg_prec(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4]) - assert isinstance(ap, torch.Tensor) - - -@pytest.mark.parametrize('pos_label', [0, 1]) -def test_auroc(pos_label): - auroc = AUROC(pos_label=pos_label) - assert auroc.name == 'auroc' - - pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 1, 0, 1]) - area = auroc(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4]) - assert isinstance(area, torch.Tensor) - - -@pytest.mark.parametrize(['beta', 'num_classes'], [ - pytest.param(0., 1), - pytest.param(0.5, 1), - pytest.param(1., 1), - pytest.param(2., 1), - pytest.param(0., None), - pytest.param(0.5, None), - pytest.param(1., None), - pytest.param(2., None) -]) -def test_fbeta(beta, num_classes): - fbeta = FBeta(beta=beta, num_classes=num_classes) - assert fbeta.name == 'fbeta' - - score = fbeta(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]), - target=torch.tensor([[0, 0, 1], [1, 0, 1]])) - assert isinstance(score, torch.Tensor) - - -@pytest.mark.parametrize('num_classes', [1, None]) -def test_f1(num_classes): - f1 = F1(num_classes=num_classes) - assert f1.name == 'f1' - - score = f1(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]), - target=torch.tensor([[0, 0, 1], [1, 0, 1]])) - assert isinstance(score, torch.Tensor) - - -@pytest.mark.parametrize('pos_label', [1, 2]) -def test_roc(pos_label): - roc = ROC(pos_label=pos_label) - assert roc.name == 'roc' - - pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 4, 3]) - res = roc(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4]) - - assert isinstance(res, tuple) - assert len(res) == 3 - for tmp in res: - assert isinstance(tmp, torch.Tensor) - - -@pytest.mark.parametrize('num_classes', [4, None]) -def test_multiclass_roc(num_classes): - pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], - [0.05, 0.85, 0.05, 0.05], - [0.05, 0.05, 0.85, 0.05], - [0.05, 0.05, 0.05, 0.85]]) - target = torch.tensor([0, 1, 3, 2]) - - multi_roc = MulticlassROC(num_classes=num_classes) - assert multi_roc.name == 'multiclass_roc' - - res = multi_roc(pred, target) - assert isinstance(res, tuple) - - if num_classes is not None: - assert len(res) == num_classes - - for tmp in res: - assert isinstance(tmp, tuple) - assert len(tmp) == 3 - - for _tmp in tmp: - assert isinstance(_tmp, torch.Tensor) - - -@pytest.mark.parametrize('num_classes', [4, None]) -def test_multiclass_pr(num_classes): - pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], - [0.05, 0.85, 0.05, 0.05], - [0.05, 0.05, 0.85, 0.05], - [0.05, 0.05, 0.05, 0.85]]) - target = torch.tensor([0, 1, 3, 2]) - - multi_pr = MulticlassPrecisionRecallCurve(num_classes=num_classes) - assert multi_pr.name == 'multiclass_precision_recall_curve' - - pr = multi_pr(pred, target) - assert isinstance(pr, tuple) - - if num_classes is not None: - 
assert len(pr) == num_classes - - for tmp in pr: - assert isinstance(tmp, tuple) - assert len(tmp) == 3 - - for _tmp in tmp: - assert isinstance(_tmp, torch.Tensor) - - -@pytest.mark.parametrize('include_background', [True, False]) -def test_dice_coefficient(include_background): - dice_coeff = DiceCoefficient(include_background=include_background) - assert dice_coeff.name == 'dice' - - dice = dice_coeff(torch.randint(0, 1, (10, 25, 25)), - torch.randint(0, 1, (10, 25, 25))) - assert isinstance(dice, torch.Tensor) - - -@pytest.mark.parametrize('ignore_index', [0, 1, None]) -def test_iou(ignore_index): - iou = IoU(ignore_index=ignore_index) - assert iou.name == 'iou' - - score = iou(torch.randint(0, 1, (10, 25, 25)), - torch.randint(0, 1, (10, 25, 25))) - - assert isinstance(score, torch.Tensor) diff --git a/tests/metrics/test_converters.py b/tests/metrics/test_converters.py deleted file mode 100644 index c64353ed17bf1..0000000000000 --- a/tests/metrics/test_converters.py +++ /dev/null @@ -1,265 +0,0 @@ -import sys - -import numpy as np -import pytest -import torch -import torch.distributed as dist -import torch.multiprocessing as mp - -import tests.base.develop_utils as tutils -from pytorch_lightning.metrics.converters import ( - _apply_to_inputs, - _apply_to_outputs, - convert_to_tensor, - convert_to_numpy, - _numpy_metric_conversion, - _tensor_metric_conversion, - sync_ddp_if_available, - gather_all_tensors_if_available, - tensor_metric, - numpy_metric -) - - -def test_apply_to_inputs(): - def apply_fn(inputs, factor): - if isinstance(inputs, (float, int)): - return inputs * factor - elif isinstance(inputs, dict): - return {k: apply_fn(v, factor) for k, v in inputs.items()} - elif isinstance(inputs, (tuple, list)): - return [apply_fn(x, factor) for x in inputs] - - @_apply_to_inputs(apply_fn, factor=2.) - def test_fn(*args, **kwargs): - return args, kwargs - - for args in [[], [1., 2.]]: - for kwargs in [{}, {'a': 1., 'b': 2.}]: - result_args, result_kwargs = test_fn(*args, **kwargs) - assert isinstance(result_args, (list, tuple)) - assert isinstance(result_kwargs, dict) - assert len(result_args) == len(args) - assert len(result_kwargs) == len(kwargs) - assert all([k in result_kwargs for k in kwargs.keys()]) - for arg, result_arg in zip(args, result_args): - assert arg * 2. == result_arg - - for key in kwargs.keys(): - arg = kwargs[key] - result_arg = result_kwargs[key] - assert arg * 2. == result_arg - - -def test_apply_to_outputs(): - def apply_fn(inputs, additional_str): - return str(inputs) + additional_str - - @_apply_to_outputs(apply_fn, additional_str='_str') - def test_fn(*args, **kwargs): - return 'dummy' - - assert test_fn() == 'dummy_str' - - -def test_convert_to_tensor(): - for test_item in [1., np.array([1.])]: - result_tensor = convert_to_tensor(test_item) - assert isinstance(result_tensor, torch.Tensor) - assert result_tensor.item() == 1. - - -def test_convert_to_numpy(): - for test_item in [1., torch.tensor([1.])]: - result = convert_to_numpy(test_item) - assert isinstance(result, np.ndarray) - assert result.item() == 1. - - -def test_numpy_metric_conversion(): - @_numpy_metric_conversion - def numpy_test_metric(*args, **kwargs): - for arg in args: - assert isinstance(arg, np.ndarray) - - for v in kwargs.values(): - assert isinstance(v, np.ndarray) - - return 5. - - result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.) - assert isinstance(result, torch.Tensor) - assert result.item() == 5. 
- - -def test_tensor_metric_conversion(): - @_tensor_metric_conversion - def tensor_test_metric(*args, **kwargs): - for arg in args: - assert isinstance(arg, torch.Tensor) - - for v in kwargs.values(): - assert isinstance(v, torch.Tensor) - - return 5. - - result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.) - assert isinstance(result, torch.Tensor) - assert result.item() == 5. - - -def _setup_ddp(rank, worldsize): - import os - - os.environ['MASTER_ADDR'] = 'localhost' - - # initialize the process group - dist.init_process_group("gloo", rank=rank, world_size=worldsize) - - -def _ddp_test_fn(rank, worldsize, add_offset: bool, reduction_mean=False): - _setup_ddp(rank, worldsize) - if add_offset: - tensor = torch.tensor([float(rank)]) - else: - tensor = torch.tensor([1.], ) - if reduction_mean: - reduced_tensor = sync_ddp_if_available(tensor, reduce_op='avg') - - manual_reduction = sum([i for i in range(dist.get_world_size())]) / dist.get_world_size() - assert reduced_tensor.item() == manual_reduction - else: - reduced_tensor = sync_ddp_if_available(tensor) - - assert reduced_tensor.item() == dist.get_world_size(), \ - 'Sync-Reduce does not work properly with DDP and Tensors' - - -def _ddp_test_gather_all_tensors(rank, worldsize): - _setup_ddp(rank, worldsize) - - tensor = torch.tensor([rank]) - gather_tensors = gather_all_tensors_if_available(tensor) - mannual_tensors = [torch.tensor([i]) for i in range(worldsize)] - - for t1, t2 in zip(gather_tensors, mannual_tensors): - assert(t1.equal(t2)) - - -@pytest.mark.skipif(sys.platform == "win32" , reason="DDP not available on windows") -def test_sync_reduce_ddp(): - """Make sure sync-reduce works with DDP""" - tutils.reset_seed() - tutils.set_random_master_port() - - worldsize = 2 - mp.spawn(_ddp_test_fn, args=(worldsize, False), nprocs=worldsize) - - -@pytest.mark.skipif(sys.platform == "win32" , reason="DDP not available on windows") -def test_sync_reduce_ddp_mean(): - """Make sure sync-reduce works with DDP""" - tutils.reset_seed() - tutils.set_random_master_port() - - worldsize = 2 - mp.spawn(_ddp_test_fn, args=(worldsize, True, True), nprocs=worldsize) - - -def test_sync_reduce_simple(): - """Make sure sync-reduce works without DDP""" - tensor = torch.tensor([1.], device='cpu') - - reduced_tensor = sync_ddp_if_available(tensor) - - assert torch.allclose(tensor, reduced_tensor), \ - 'Sync-Reduce does not work properly without DDP and Tensors' - - -@pytest.mark.skipif(sys.platform == "win32" , reason="DDP not available on windows") -def test_gather_all_tensors_ddp(): - """Make sure gather_all_tensors works with DDP""" - tutils.reset_seed() - tutils.set_random_master_port() - - worldsize = 2 - mp.spawn(_ddp_test_gather_all_tensors, args=(worldsize, ), nprocs=worldsize) - - -def _test_tensor_metric(is_ddp: bool): - @tensor_metric() - def tensor_test_metric(*args, **kwargs): - for arg in args: - assert isinstance(arg, torch.Tensor) - - for v in kwargs.values(): - assert isinstance(v, torch.Tensor) - - return 5. - - if is_ddp: - factor = dist.get_world_size() - else: - factor = 1. - - result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.) - assert isinstance(result, torch.Tensor) - assert result.item() == 5. 
* factor - - -def _ddp_test_tensor_metric(rank, worldsize): - _setup_ddp(rank, worldsize) - _test_tensor_metric(True) - - -@pytest.mark.skipif(sys.platform == "win32" , reason="DDP not available on windows") -def test_tensor_metric_ddp(): - tutils.reset_seed() - tutils.set_random_master_port() - - world_size = 2 - mp.spawn(_ddp_test_tensor_metric, args=(world_size,), nprocs=world_size) - # dist.destroy_process_group() - - -def test_tensor_metric_simple(): - _test_tensor_metric(False) - - -def _test_numpy_metric(is_ddp: bool): - @numpy_metric() - def numpy_test_metric(*args, **kwargs): - for arg in args: - assert isinstance(arg, np.ndarray) - - for v in kwargs.values(): - assert isinstance(v, np.ndarray) - - return 5. - - if is_ddp: - factor = dist.get_world_size() - else: - factor = 1. - - result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.) - assert isinstance(result, torch.Tensor) - assert result.item() == 5. * factor - - -def _ddp_test_numpy_metric(rank, worldsize): - _setup_ddp(rank, worldsize) - _test_numpy_metric(True) - - -@pytest.mark.skipif(sys.platform == "win32" , reason="DDP not available on windows") -def test_numpy_metric_ddp(): - tutils.reset_seed() - tutils.set_random_master_port() - world_size = 2 - mp.spawn(_ddp_test_numpy_metric, args=(world_size,), nprocs=world_size) - # dist.destroy_process_group() - - -def test_numpy_metric_simple(): - _test_numpy_metric(False) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py deleted file mode 100644 index f3395a551ef62..0000000000000 --- a/tests/metrics/test_metrics.py +++ /dev/null @@ -1,323 +0,0 @@ -import os -from typing import Any -import numpy as np -import pytest -import torch - -import tests.base.develop_utils as tutils -from tests.base import EvalModelTemplate -from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric -from pytorch_lightning import Trainer - - -class DummyTensorMetric(TensorMetric): - def __init__(self): - super().__init__("dummy") - - def forward(self, input1, input2): - assert isinstance(input1, torch.Tensor) - assert isinstance(input2, torch.Tensor) - return torch.tensor([1.0]) - - -class DummyNumpyMetric(NumpyMetric): - def __init__(self): - super().__init__("dummy") - - def forward(self, input1, input2): - assert isinstance(input1, np.ndarray) - assert isinstance(input2, np.ndarray) - return 1.0 - - -class DummyTensorCollectionMetric(TensorMetric): - def __init__(self): - super().__init__("dummy") - - def forward(self, input1, input2): - assert isinstance(input1, torch.Tensor) - assert isinstance(input2, torch.Tensor) - return 1.0, 2.0, 3.0, 4.0 - - -@pytest.mark.parametrize("metric", [DummyTensorCollectionMetric()]) -def test_collection_metric(metric: Metric): - """ Test that metric.device, metric.dtype works for metric collection """ - input1, input2 = torch.tensor([1.0]), torch.tensor([2.0]) - - def change_and_check_device_dtype(device, dtype): - metric.to(device=device, dtype=dtype) - - metric_val = metric(input1, input2) - assert not isinstance(metric_val, torch.Tensor) - - if device is not None: - assert metric.device in [device, torch.device(device)] - - if dtype is not None: - assert metric.dtype == dtype - - devices = [None, "cpu"] - if torch.cuda.is_available(): - devices += ["cuda:0"] - - for device in devices: - for dtype in [None, torch.float32, torch.float64]: - change_and_check_device_dtype(device=device, dtype=dtype) - - if torch.cuda.is_available(): - metric.cuda(0) - assert metric.device == torch.device("cuda", index=0) - - 
metric.cpu() - assert metric.device == torch.device("cpu") - - metric.type(torch.int8) - assert metric.dtype == torch.int8 - - metric.float() - assert metric.dtype == torch.float32 - - metric.double() - assert metric.dtype == torch.float64 - assert all(out.dtype == torch.float64 for out in metric(input1, input2)) - - if torch.cuda.is_available(): - metric.cuda() - metric.half() - assert metric.dtype == torch.float16 - - -@pytest.mark.parametrize( - "metric", - [ - DummyTensorMetric(), - DummyNumpyMetric(), - ], -) -def test_metric(metric: Metric): - """ Test that metric.device, metric.dtype works for single metric""" - input1, input2 = torch.tensor([1.0]), torch.tensor([2.0]) - - def change_and_check_device_dtype(device, dtype): - metric.to(device=device, dtype=dtype) - - metric_val = metric(input1, input2) - assert isinstance(metric_val, torch.Tensor) - - if device is not None: - assert metric.device in [device, torch.device(device)] - assert metric_val.device in [device, torch.device(device)] - - if dtype is not None: - assert metric.dtype == dtype - assert metric_val.dtype == dtype - - devices = [None, "cpu"] - if torch.cuda.is_available(): - devices += ["cuda:0"] - - for device in devices: - for dtype in [None, torch.float32, torch.float64]: - change_and_check_device_dtype(device=device, dtype=dtype) - - if torch.cuda.is_available(): - metric.cuda(0) - assert metric.device == torch.device("cuda", index=0) - assert metric(input1, input2).device == torch.device("cuda", index=0) - - metric.cpu() - assert metric.device == torch.device("cpu") - assert metric(input1, input2).device == torch.device("cpu") - - metric.float() - assert metric.dtype == torch.float32 - assert metric(input1, input2).dtype == torch.float32 - - metric.double() - assert metric.dtype == torch.float64 - assert metric(input1, input2).dtype == torch.float64 - - if torch.cuda.is_available(): - metric.cuda() - metric.half() - assert metric.dtype == torch.float16 - assert metric(input1, input2).dtype == torch.float16 - - -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -@pytest.mark.parametrize("metric", [DummyTensorMetric, DummyNumpyMetric]) -def test_model_pickable(tmpdir, metric: Metric): - """Make sure that metrics are pickable by including into a model and running in multi-gpu mode""" - tutils.set_random_master_port() - - trainer_options = dict( - default_root_dir=tmpdir, - max_epochs=1, - limit_train_batches=10, - gpus=[0, 1], - distributed_backend="ddp_spawn", - ) - - model = EvalModelTemplate() - model.metric = metric() - model.training_step = model.training_step__using_metrics - - trainer = Trainer(**trainer_options) - result = trainer.fit(model) - - # correct result and ok accuracy - assert result == 1, "ddp model failed to complete" - - -@pytest.mark.parametrize("metric", [DummyTensorMetric(), DummyNumpyMetric()]) -def test_saving_pickable(tmpdir, metric: Metric): - """ Make sure that metrics are pickable by saving and loading them using torch """ - x, y = torch.randn(10,), torch.randn( - 10, - ) - results_before_save = metric(x, y) - - # save metric - save_path = os.path.join(tmpdir, "save_test.ckpt") - torch.save(metric, save_path) - - # load metric - new_metric = torch.load(save_path) - results_after_load = new_metric(x, y) - - # Check metric value is the same - assert results_before_save == results_after_load - - -def test_correct_call_order(): - """ Check that hooks are called in the expected order """ - - class DummyMetric(Metric): - def __init__(self): - 
super().__init__("dummy") - self.call_history = ["init"] - - @staticmethod - def input_convert(self, data: Any): - self.call_history.append("input_convert") - return super(DummyMetric, self).input_convert(self, data) - - def forward(self, tensor1, tensor2): - self.call_history.append("forward") - return tensor1 - tensor2 - - @staticmethod - def output_convert(self, data: Any, output: Any): - self.call_history.append("output_convert") - return super(DummyMetric, self).output_convert(self, data, output) - - def ddp_sync(self, tensor: Any): - self.call_history.append("ddp_sync") - return super().ddp_sync(tensor) - - @staticmethod - def ddp_reduce(self, data: Any, output: Any): - self.call_history.append("ddp_reduce") - return super(DummyMetric, self).ddp_reduce(self, data, output) - - def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor: - self.call_history.append("aggregate") - return super().aggregate(*tensors) - - def reset(self): - self.call_history.append("reset") - return super().reset() - - @property - def aggregated(self) -> torch.Tensor: - self.call_history.append("aggregated") - return super().aggregated - - @staticmethod - def compute(self, data: Any, output: Any): - self.call_history.append("compute") - return super(DummyMetric, self).compute(self, data, output) - - metric = DummyMetric() - assert metric.call_history == ["init"] - result = metric(torch.tensor([2.0]), torch.tensor([1.0])) - assert torch.allclose(result, torch.tensor(1.0)) - assert metric.call_history == [ - "init", - "input_convert", - "forward", - "output_convert", - "ddp_reduce", - "ddp_sync", - "aggregate", - "compute" - ] - aggr = metric.aggregated - assert metric.call_history == [ - "init", - "input_convert", - "forward", - "output_convert", - "ddp_reduce", - "ddp_sync", - "aggregate", - "compute", - "aggregated", - "aggregate", - "reset", - "compute" - ] - assert torch.allclose(aggr, result) - _ = metric(torch.tensor(2.0), torch.tensor(1.0)) - assert metric.call_history == [ - "init", - "input_convert", - "forward", - "output_convert", - "ddp_reduce", - "ddp_sync", - "aggregate", - "compute", - "aggregated", - "aggregate", - "reset", - "compute", - "input_convert", - "forward", - "output_convert", - "ddp_reduce", - "ddp_sync", - "aggregate", - "compute" - ] - - metric = DummyMetric() - _ = metric(torch.tensor([2.0]), torch.tensor([1.0])) - _ = metric(torch.tensor([3.0]), torch.tensor([0.0])) - - aggregated = metric.aggregated - - assert torch.allclose(aggregated, torch.tensor(4.0)) - - assert metric.call_history == [ - "init", - "input_convert", - "forward", - "output_convert", - "ddp_reduce", - "ddp_sync", - "aggregate", - "compute", - "input_convert", - "forward", - "output_convert", - "ddp_reduce", - "ddp_sync", - "aggregate", - "compute", - "aggregated", - "aggregate", - "reset", - "compute", - ] diff --git a/tests/metrics/test_nlp.py b/tests/metrics/test_nlp.py deleted file mode 100644 index e58b1f33988f5..0000000000000 --- a/tests/metrics/test_nlp.py +++ /dev/null @@ -1,29 +0,0 @@ -import pytest -import torch - -from pytorch_lightning.metrics.nlp import BLEUScore - -# example taken from -# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu -HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split() -HYP2 = "he read the book because he was interested in world history".split() - -REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split() -REF1B = 
"It is a guiding principle which makes the military forces always being under the command of the Party".split() -REF1C = "It is the practical guide for the army always to heed the directions of the party".split() -REF2A = "he was interested in world history because he read the book".split() - -LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]] -HYPOTHESES = [HYP1, HYP2] - - -@pytest.mark.parametrize( - ["n_gram", "smooth"], - [pytest.param(1, True), pytest.param(2, False), pytest.param(3, True), pytest.param(4, False),], -) -def test_bleu(smooth, n_gram): - bleu = BLEUScore(n_gram=n_gram, smooth=smooth) - assert bleu.name == "bleu" - - pl_output = bleu(HYPOTHESES, LIST_OF_REFERENCES) - assert isinstance(pl_output, torch.Tensor) diff --git a/tests/metrics/test_regression.py b/tests/metrics/test_regression.py deleted file mode 100644 index 36c408e93c469..0000000000000 --- a/tests/metrics/test_regression.py +++ /dev/null @@ -1,69 +0,0 @@ -# NOTE: This file only tests if modules with arguments are running fine. -# The actual metric implementation is tested in functional/test_regression.py -# Especially reduction and reducing across processes won't be tested here! - -import torch - -from pytorch_lightning.metrics.regression import ( - MAE, MSE, RMSE, RMSLE, PSNR, SSIM -) - - -def test_mse(): - mse = MSE() - assert mse.name == 'mse' - - pred = torch.tensor([0., 1, 2, 3]) - target = torch.tensor([0., 1, 2, 2]) - score = mse(pred, target) - assert isinstance(score, torch.Tensor) - - -def test_rmse(): - rmse = RMSE() - assert rmse.name == 'rmse' - - pred = torch.tensor([0., 1, 2, 3]) - target = torch.tensor([0., 1, 2, 2]) - score = rmse(pred, target) - assert isinstance(score, torch.Tensor) - - -def test_mae(): - mae = MAE() - assert mae.name == 'mae' - - pred = torch.tensor([0., 1, 2, 3]) - target = torch.tensor([0., 1, 2, 2]) - score = mae(pred, target) - assert isinstance(score, torch.Tensor) - - -def test_rmsle(): - rmsle = RMSLE() - assert rmsle.name == 'rmsle' - - pred = torch.tensor([0., 1, 2, 3]) - target = torch.tensor([0., 1, 2, 2]) - score = rmsle(pred, target) - assert isinstance(score, torch.Tensor) - - -def test_psnr(): - psnr = PSNR() - assert psnr.name == 'psnr' - - pred = torch.tensor([0., 1, 2, 3]) - target = torch.tensor([0., 1, 2, 2]) - score = psnr(pred, target) - assert isinstance(score, torch.Tensor) - - -def test_ssim(): - ssim = SSIM() - assert ssim.name == 'ssim' - - pred = torch.rand([16, 1, 16, 16]) - target = pred * 0.75 - score = ssim(pred, target) - assert isinstance(score, torch.Tensor) diff --git a/tests/metrics/test_sklearn.py b/tests/metrics/test_sklearn.py deleted file mode 100644 index 019048056016c..0000000000000 --- a/tests/metrics/test_sklearn.py +++ /dev/null @@ -1,178 +0,0 @@ -import numbers -from functools import partial - -import numpy as np -import pytest -import torch -from sklearn.metrics import ( - accuracy_score as sk_accuracy, - precision_score as sk_precision, - recall_score as sk_recall, - f1_score as sk_f1_score, - fbeta_score as sk_fbeta_score, - confusion_matrix as sk_confusion_matrix, - average_precision_score as sk_average_precision, - auc as sk_auc, - precision_recall_curve as sk_precision_recall_curve, - roc_curve as sk_roc_curve, - roc_auc_score as sk_roc_auc_score, - balanced_accuracy_score as sk_balanced_accuracy_score, - dcg_score as sk_dcg_score, - mean_absolute_error as sk_mean_absolute_error, - mean_squared_error as sk_mean_squared_error, - mean_squared_log_error as sk_mean_squared_log_error, - median_absolute_error as 
sk_median_absolute_error, - r2_score as sk_r2_score, - mean_poisson_deviance as sk_mean_poisson_deviance, - mean_gamma_deviance as sk_mean_gamma_deviance, - mean_tweedie_deviance as sk_mean_tweedie_deviance, - explained_variance_score as sk_explained_variance_score, - cohen_kappa_score as sk_cohen_kappa_score, - hamming_loss as sk_hamming_loss, - hinge_loss as sk_hinge_loss, - jaccard_score as sk_jaccard_score -) - -from pytorch_lightning.metrics.converters import convert_to_numpy -from pytorch_lightning.metrics.sklearns import ( - Accuracy, - AUC, - AveragePrecision, - BalancedAccuracy, - ConfusionMatrix, - CohenKappaScore, - DCG, - F1, - FBeta, - Hamming, - Hinge, - Jaccard, - Precision, - Recall, - PrecisionRecallCurve, - ROC, - AUROC, - MeanAbsoluteError, - MeanSquaredError, - MeanSquaredLogError, - MedianAbsoluteError, - R2Score, - MeanPoissonDeviance, - MeanGammaDeviance, - MeanTweedieDeviance, - ExplainedVariance, -) -from pytorch_lightning.utilities.apply_func import apply_to_collection - - -def _xy_only(func): - def new_func(*args, **kwargs): - return np.array(func(*args, **kwargs)[:2]) - return new_func - - -@pytest.mark.parametrize(['metric_class', 'sklearn_func', 'inputs'], [ - pytest.param(Accuracy(), sk_accuracy, - {'y_pred': torch.randint(10, size=(128,)), - 'y_true': torch.randint(10, size=(128,))}, - id='Accuracy'), - pytest.param(AUC(), sk_auc, - {'x': torch.arange(10, dtype=torch.float) / 10, - 'y': torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.3, 0.5, 0.6, 0.7])}, - id='AUC'), - pytest.param(AveragePrecision(), sk_average_precision, - {'y_score': torch.randint(2, size=(128,)), - 'y_true': torch.randint(2, size=(128,))}, - id='AveragePrecision'), - pytest.param(ConfusionMatrix(), sk_confusion_matrix, - {'y_pred': torch.randint(10, size=(128,)), - 'y_true': torch.randint(10, size=(128,))}, - id='ConfusionMatrix'), - pytest.param(F1(average='macro'), partial(sk_f1_score, average='macro'), - {'y_pred': torch.randint(10, size=(128,)), - 'y_true': torch.randint(10, size=(128,))}, - id='F1'), - pytest.param(FBeta(beta=0.5, average='macro'), partial(sk_fbeta_score, beta=0.5, average='macro'), - {'y_pred': torch.randint(10, size=(128,)), - 'y_true': torch.randint(10, size=(128,))}, - id='FBeta'), - pytest.param(Precision(average='macro'), partial(sk_precision, average='macro'), - {'y_pred': torch.randint(10, size=(128,)), - 'y_true': torch.randint(10, size=(128,))}, - id='Precision'), - pytest.param(Recall(average='macro'), partial(sk_recall, average='macro'), - {'y_pred': torch.randint(10, size=(128,)), - 'y_true': torch.randint(10, size=(128,))}, - id='Recall'), - pytest.param(PrecisionRecallCurve(), _xy_only(sk_precision_recall_curve), - {'probas_pred': torch.rand(size=(128,)), - 'y_true': torch.randint(2, size=(128,))}, - id='PrecisionRecallCurve'), - pytest.param(ROC(), _xy_only(sk_roc_curve), - {'y_score': torch.rand(size=(128,)), - 'y_true': torch.randint(2, size=(128,))}, - id='ROC'), - pytest.param(AUROC(), sk_roc_auc_score, - {'y_score': torch.rand(size=(128,)), - 'y_true': torch.randint(2, size=(128,))}, - id='AUROC'), - pytest.param(BalancedAccuracy(), sk_balanced_accuracy_score, - {'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))}, - id='BalancedAccuracy'), - pytest.param(DCG(), sk_dcg_score, - {'y_score': torch.rand(size=(128, 3)), 'y_true': torch.randint(3, size=(128, 3))}, - id='DCG'), - pytest.param(ExplainedVariance(), sk_explained_variance_score, - {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))}, - 
id='ExplainedVariance'), - pytest.param(MeanAbsoluteError(), sk_mean_absolute_error, - {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))}, - id='MeanAbsolutError'), - pytest.param(MeanSquaredError(), sk_mean_squared_error, - {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))}, - id='MeanSquaredError'), - pytest.param(MeanSquaredLogError(), sk_mean_squared_log_error, - {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))}, - id='MeanSquaredLogError'), - pytest.param(MedianAbsoluteError(), sk_median_absolute_error, - {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))}, - id='MedianAbsoluteError'), - pytest.param(R2Score(), sk_r2_score, - {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))}, - id='R2Score'), - pytest.param(MeanPoissonDeviance(), sk_mean_poisson_deviance, - {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))}, - id='MeanPoissonDeviance'), - pytest.param(MeanGammaDeviance(), sk_mean_gamma_deviance, - {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))}, - id='MeanGammaDeviance'), - pytest.param(MeanTweedieDeviance(), sk_mean_tweedie_deviance, - {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))}, - id='MeanTweedieDeviance'), - pytest.param(CohenKappaScore(), sk_cohen_kappa_score, - {'y1': torch.randint(3, size=(128,)), 'y2': torch.randint(3, size=(128,))}, - id='CohenKappaScore'), - pytest.param(Hamming(), sk_hamming_loss, - {'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))}, - id='Hamming'), - pytest.param(Hinge(), sk_hinge_loss, - {'pred_decision': torch.randn(size=(128,)), 'y_true': torch.randint(2, size=(128,))}, - id='Hinge'), - pytest.param(Jaccard(average='macro'), partial(sk_jaccard_score, average='macro'), - {'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))}, - id='Jaccard') -]) -def test_sklearn_metric(metric_class, sklearn_func, inputs): - numpy_inputs = apply_to_collection(inputs, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy) - - sklearn_result = sklearn_func(**numpy_inputs) - lightning_result = metric_class(**inputs) - - sklearn_result = apply_to_collection( - sklearn_result, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy) - - lightning_result = np.array(apply_to_collection( - lightning_result, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)) - - assert np.allclose(sklearn_result, lightning_result, atol=1e-5) - assert isinstance(lightning_result, type(sklearn_result)) From 8cf72ec2e14a3e28ae52fff20d1a0adc157883c5 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 5 Oct 2020 11:26:16 -0400 Subject: [PATCH 02/27] added new metrics Co-authored-by: Teddy Koker teddy.koker@gmail.com --- docs/source/metrics.rst | 637 ------------------ pytorch_lightning/core/step_result.py | 3 +- pytorch_lightning/metrics/__init__.py | 3 + .../metrics/classification/__init__.py | 1 + .../metrics/classification/accuracy.py | 57 ++ pytorch_lightning/metrics/metric.py | 129 ++++ .../metrics/regression/__init__.py | 0 pytorch_lightning/utilities/distributed.py | 44 ++ tests/metrics/__init__.py | 5 + tests/metrics/classification/__init__.py | 0 tests/metrics/classification/test_accuracy.py | 130 ++++ tests/metrics/test_ddp.py | 45 ++ tests/metrics/test_metric.py | 63 ++ tests/metrics/utils.py | 57 ++ 14 files changed, 535 insertions(+), 639 deletions(-) create mode 100644 pytorch_lightning/metrics/__init__.py create mode 100644 
pytorch_lightning/metrics/classification/__init__.py create mode 100644 pytorch_lightning/metrics/classification/accuracy.py create mode 100644 pytorch_lightning/metrics/metric.py create mode 100644 pytorch_lightning/metrics/regression/__init__.py create mode 100644 tests/metrics/__init__.py create mode 100644 tests/metrics/classification/__init__.py create mode 100644 tests/metrics/classification/test_accuracy.py create mode 100644 tests/metrics/test_ddp.py create mode 100644 tests/metrics/test_metric.py create mode 100644 tests/metrics/utils.py diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index f6a4039fa9c7f..e69de29bb2d1d 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -1,637 +0,0 @@ -.. testsetup:: * - - import torch - from torch.nn import Module - from pytorch_lightning.core.lightning import LightningModule - from pytorch_lightning.metrics import TensorMetric, NumpyMetric - -.. _metrics: - -Metrics -======= -This is a general package for PyTorch Metrics. These can also be used with regular non-lightning PyTorch code. -Metrics are used to monitor model performance. - -In this package, we provide two major pieces of functionality. - -1. A Metric class you can use to implement metrics with built-in distributed (ddp) support which are device agnostic. -2. A collection of ready to use popular metrics. There are two types of metrics: Class metrics and Functional metrics. -3. An interface to call `sklearns metrics `_ - -Example:: - - from pytorch_lightning.metrics.functional import accuracy - - pred = torch.tensor([0, 1, 2, 3]) - target = torch.tensor([0, 1, 2, 2]) - - # calculates accuracy across all GPUs and all Nodes used in training - accuracy(pred, target) - -.. warning:: - The metrics package is still in development! If we're missing a metric or you find a mistake, please send a PR! - to a few metrics. Please feel free to create an issue/PR if you have a proposed metric or have found a bug. - ----------------- - -Implement a metric ------------------- -You can implement metrics as either a PyTorch metric or a Numpy metric (It is recommended to use PyTorch metrics when possible, -since Numpy metrics slow down training). - -Use :class:`TensorMetric` to implement native PyTorch metrics. This class -handles automated DDP syncing and converts all inputs and outputs to tensors. - -Use :class:`NumpyMetric` to implement numpy metrics. This class -handles automated DDP syncing and converts all inputs and outputs to tensors. - -.. warning:: - Numpy metrics might slow down your training substantially, - since every metric computation requires a GPU sync to convert tensors to numpy. - ----------------- - -TensorMetric -^^^^^^^^^^^^ -Here's an example showing how to implement a TensorMetric - -.. testcode:: - - class RMSE(TensorMetric): - def forward(self, x, y): - return torch.sqrt(torch.mean(torch.pow(x-y, 2.0))) - -.. autoclass:: pytorch_lightning.metrics.metric.TensorMetric - :noindex: - ----------------- - -NumpyMetric -^^^^^^^^^^^ -Here's an example showing how to implement a NumpyMetric - -.. testcode:: - - class RMSE(NumpyMetric): - def forward(self, x, y): - return np.sqrt(np.mean(np.power(x-y, 2.0))) - - -.. autoclass:: pytorch_lightning.metrics.metric.NumpyMetric - :noindex: - ----------------- - -Class Metrics -------------- -Class metrics can be instantiated as part of a module definition (even with just -plain PyTorch). - -.. 
testcode:: - - from pytorch_lightning.metrics import Accuracy - - # Plain PyTorch - class MyModule(Module): - def __init__(self): - super().__init__() - self.metric = Accuracy() - - def forward(self, x, y): - y_hat = ... - acc = self.metric(y_hat, y) - - # PyTorch Lightning - class MyModule(LightningModule): - def __init__(self): - super().__init__() - self.metric = Accuracy() - - def training_step(self, batch, batch_idx): - x, y = batch - y_hat = ... - acc = self.metric(y_hat, y) - -These metrics even work when using distributed training: - -.. code-block:: python - - model = MyModule() - trainer = Trainer(gpus=8, num_nodes=2) - - # any metric automatically reduces across GPUs (even the ones you implement using Lightning) - trainer.fit(model) - -Accuracy -^^^^^^^^ - -.. autoclass:: pytorch_lightning.metrics.classification.Accuracy - :noindex: - -AveragePrecision -^^^^^^^^^^^^^^^^ - -.. autoclass:: pytorch_lightning.metrics.classification.AveragePrecision - :noindex: - -AUROC -^^^^^ - -.. autoclass:: pytorch_lightning.metrics.classification.AUROC - :noindex: - -BLEUScore -^^^^^^^^^ - -.. autoclass:: pytorch_lightning.metrics.nlp.BLEUScore - :noindex: - -ConfusionMatrix -^^^^^^^^^^^^^^^ - -.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix - :noindex: - -DiceCoefficient -^^^^^^^^^^^^^^^ - -.. autoclass:: pytorch_lightning.metrics.classification.DiceCoefficient - :noindex: - -EmbeddingSimilarity -^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: pytorch_lightning.metrics.self_supervised.EmbeddingSimilarity - :noindex: - -F1 -^^ - -.. autoclass:: pytorch_lightning.metrics.classification.F1 - :noindex: - -FBeta -^^^^^ - -.. autoclass:: pytorch_lightning.metrics.classification.FBeta - :noindex: - -PrecisionRecallCurve -^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecallCurve - :noindex: - -Precision -^^^^^^^^^ - -.. autoclass:: pytorch_lightning.metrics.classification.Precision - :noindex: - -Recall -^^^^^^ - -.. autoclass:: pytorch_lightning.metrics.classification.Recall - :noindex: - -ROC -^^^ - -.. autoclass:: pytorch_lightning.metrics.classification.ROC - :noindex: - -MAE -^^^ - -.. autoclass:: pytorch_lightning.metrics.regression.MAE - :noindex: - -MSE -^^^ - -.. autoclass:: pytorch_lightning.metrics.regression.MSE - :noindex: - -MulticlassROC -^^^^^^^^^^^^^ - -.. autoclass:: pytorch_lightning.metrics.classification.MulticlassROC - :noindex: - -MulticlassPrecisionRecallCurve -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: pytorch_lightning.metrics.classification.MulticlassPrecisionRecallCurve - :noindex: - -IoU -^^^ - -.. autoclass:: pytorch_lightning.metrics.classification.IoU - :noindex: - -RMSE -^^^^ - -.. autoclass:: pytorch_lightning.metrics.regression.RMSE - :noindex: - -RMSLE -^^^^^ - -.. autoclass:: pytorch_lightning.metrics.regression.RMSLE - :noindex: - -SSIM -^^^^ - -.. autoclass:: pytorch_lightning.metrics.regression.SSIM - :noindex: - ----------------- - -Functional Metrics ------------------- -Functional metrics can be called anywhere (even used with just plain PyTorch). - -.. code-block:: python - - from pytorch_lightning.metrics.functional import accuracy - - pred = torch.tensor([0, 1, 2, 3]) - target = torch.tensor([0, 1, 2, 2]) - - # calculates accuracy across all GPUs and all Nodes used in training - accuracy(pred, target) - -These metrics even work when using distributed training: - -.. 
code-block:: python - - class MyModule(...): - def forward(self, x, y): - return accuracy(x, y) - - model = MyModule() - trainer = Trainer(gpus=8, num_nodes=2) - - # any metric automatically reduces across GPUs (even the ones you implement using Lightning) - trainer.fit(model) - - -accuracy (F) -^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.accuracy - :noindex: - -auc (F) -^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.auc - :noindex: - -auroc (F) -^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.auroc - :noindex: - -average_precision (F) -^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.average_precision - :noindex: - -bleu_score (F) -^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.bleu_score - :noindex: - -confusion_matrix (F) -^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.confusion_matrix - :noindex: - -dice_score (F) -^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.dice_score - :noindex: - -embedding_similarity (F) -^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.embedding_similarity - :noindex: - -f1_score (F) -^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.f1_score - :noindex: - -fbeta_score (F) -^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.fbeta_score - :noindex: - -multiclass_precision_recall_curve (F) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.multiclass_precision_recall_curve - :noindex: - -multiclass_roc (F) -^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.multiclass_roc - :noindex: - -precision (F) -^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.precision - :noindex: - -precision_recall (F) -^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.precision_recall - :noindex: - -precision_recall_curve (F) -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.precision_recall_curve - :noindex: - -recall (F) -^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.recall - :noindex: - -roc (F) -^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.roc - :noindex: - -stat_scores (F) -^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.stat_scores - :noindex: - -iou (F) -^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.iou - :noindex: - -mse (F) -^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.mse - :noindex: - -rmse (F) -^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.rmse - :noindex: - -mae (F) -^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.mae - :noindex: - -rmsle (F) -^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.rmsle - :noindex: - -psnr (F) -^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.psnr - :noindex: - -ssim (F) -^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.ssim - :noindex: - -stat_scores_multiple_classes (F) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.stat_scores_multiple_classes - :noindex: - ----------------- - -Metric pre-processing ---------------------- - -to_categorical (F) -^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.functional.to_categorical - :noindex: - -to_onehot (F) -^^^^^^^^^^^^^ - -.. 
autofunction:: pytorch_lightning.metrics.functional.to_onehot - :noindex: - ----------------- - -Sklearn interface ------------------ - -Lightning supports `sklearns metrics module `_ -as a backend for calculating metrics. Sklearns metrics are well tested and robust, -but requires conversion between pytorch and numpy thus may slow down your computations. - -To use the sklearn backend of metrics simply import as - -.. code-block:: python - - import pytorch_lightning.metrics.sklearns import plm - metric = plm.Accuracy(normalize=True) - val = metric(pred, target) - -Each converted sklearn metric comes has the same interface as its -original counterpart (e.g. accuracy takes the additional `normalize` keyword). -Like the native Lightning metrics, these converted sklearn metrics also come -with built-in distributed (ddp) support. - -SklearnMetric (sk) -^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.SklearnMetric - :noindex: - -Accuracy (sk) -^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.Accuracy - :noindex: - -AUC (sk) -^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.AUC - :noindex: - -AveragePrecision (sk) -^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.AveragePrecision - :noindex: - -BalancedAccuracy (sk) -^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.BalancedAccuracy - :noindex: - -CohenKappaScore (sk) -^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.CohenKappaScore - :noindex: - -ConfusionMatrix (sk) -^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.ConfusionMatrix - :noindex: - -DCG (sk) -^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.DCG - :noindex: - -F1 (sk) -^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.F1 - :noindex: - -FBeta (sk) -^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.FBeta - :noindex: - -Hamming (sk) -^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.Hamming - :noindex: - -Hinge (sk) -^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.Hinge - :noindex: - -Jaccard (sk) -^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.Jaccard - :noindex: - -Precision (sk) -^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.Precision - :noindex: - -Recall (sk) -^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.Recall - :noindex: - -PrecisionRecallCurve (sk) -^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.PrecisionRecallCurve - :noindex: - -ROC (sk) -^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.ROC - :noindex: - -AUROC (sk) -^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.AUROC - :noindex: - -ExplainedVariance (sk) -^^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.ExplainedVariance - :noindex: - -MeanAbsoluteError (sk) -^^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.MeanAbsoluteError - :noindex: - -MeanSquaredError (sk) -^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.MeanSquaredError - :noindex: - -MeanSquaredLogError (sk) -^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.MeanSquaredLogError - :noindex: - -MedianAbsoluteError (sk) -^^^^^^^^^^^^^^^^^^^^^^^^ - -.. 
autofunction:: pytorch_lightning.metrics.sklearns.MedianAbsoluteError - :noindex: - -R2Score (sk) -^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.R2Score - :noindex: - -MeanPoissonDeviance (sk) -^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.MeanPoissonDeviance - :noindex: - -MeanGammaDeviance (sk) -^^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.MeanGammaDeviance - :noindex: - -MeanTweedieDeviance (sk) -^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: pytorch_lightning.metrics.sklearns.MeanTweedieDeviance - :noindex: diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 756fffc60a0b1..50fa6266c7964 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -20,8 +20,7 @@ from torch import Tensor import os -from pytorch_lightning.metrics.converters import sync_ddp_if_available -from typing import Iterable +from pytorch_lightning.utilities.distributed import sync_ddp_if_available class Result(Dict): diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py new file mode 100644 index 0000000000000..5c3b31f543e56 --- /dev/null +++ b/pytorch_lightning/metrics/__init__.py @@ -0,0 +1,3 @@ +from pytorch_lightning.metrics.metric import Metric + +from pytorch_lightning.metrics.classification.accuracy import Accuracy diff --git a/pytorch_lightning/metrics/classification/__init__.py b/pytorch_lightning/metrics/classification/__init__.py new file mode 100644 index 0000000000000..45e66603b2465 --- /dev/null +++ b/pytorch_lightning/metrics/classification/__init__.py @@ -0,0 +1 @@ +from pytorch_lightning.metrics.classification.accuracy import Accuracy diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py new file mode 100644 index 0000000000000..bc2acf5c3d23d --- /dev/null +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -0,0 +1,57 @@ +import math +import functools +from abc import ABC, abstractmethod +from typing import Any, Callable, Optional, Union +from collections.abc import Mapping, Sequence +from collections import namedtuple + +import torch +from torch import nn +from pytorch_lightning.metrics.metric import Metric + + +class Accuracy(Metric): + """ + Computes accuracy. Works with binary, multiclass, and multilabel data. + + preds and targets must be of shape (N, ...) and (N, ...) or (N, num_classes, ...) and (N, ...) 
+ + If preds and targets are the same shape: + If preds are integer values, we perform accuracy with those values + If preds are floating point we threshold at `threshold` + + """ + def __init__(self, threshold=0.5, **kwargs): + super().__init__(**kwargs) + + # change to dist_reduce_fx + self.add_state("correct", torch.tensor(0), reduction=sum) + self.add_state("total", torch.tensor(0), reduction=sum) + + self.threshold = threshold + + def _input_format(self, preds, target): + if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): + raise ValueError( + "preds and target must have same number of dimensions, or one additional dimension for preds" + ) + + if len(preds.shape) == len(target.shape) + 1: + # multi class probabilites + preds = torch.argmax(preds, dim=1) + + if len(preds.shape) == len(target.shape) and preds.dtype == torch.float: + # binary or multilabel probablities + preds = (preds >= self.threshold).long() + + return preds, target + + def update(self, preds, target): + preds, target = self._input_format(preds, target) + assert preds.shape == target.shape + + self.correct += torch.sum(preds == target) + self.total += target.numel() + + def compute(self): + return self.correct.float() / self.total diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py new file mode 100644 index 0000000000000..cc908553f3994 --- /dev/null +++ b/pytorch_lightning/metrics/metric.py @@ -0,0 +1,129 @@ +import functools +from abc import ABC, abstractmethod +from typing import Any, Callable, Optional, Union +from collections.abc import Mapping, Sequence +from collections import namedtuple +from copy import deepcopy + +import torch +from torch import nn + +from pytorch_lightning.utilities.apply_func import apply_to_collection + + +class Metric(nn.Module, ABC): + + def __init__( + self, + compute_on_step: bool = True, + ddp_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + super().__init__() + + self.ddp_sync_on_step = ddp_sync_on_step + self.compute_on_step = compute_on_step + self.process_group = process_group + self._to_sync = True + + self.compute = self.wrap_sync(self.compute) + + # initialize state + self._reductions = {} + self._defaults = {} + + def add_state(self, name, default, reduction=None): + if reduction is None: + # TODO: implement default reduction + raise NotImplementedError("default reduction not implemented") + + setattr(self, name, default) + self._defaults[name] = deepcopy(default) + self._reductions[name] = reduction + + def forward(self, *args, **kwargs): + # add current step + self.update(*args, **kwargs) + + if self.compute_on_step: + self._to_sync = self.ddp_sync_on_step + + # save context before switch + self._cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} + + # call reset, update, compute, on single batch + self.reset() + self.update(*args, **kwargs) + result = self.compute() + + # restore context + for attr, val in self._cache.items(): + setattr(self, attr, val) + self._to_sync = True + + return result + + def sync(self): + input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()} + output_dict = apply_to_collection( + input_dict, + torch.Tensor, + gather_all_tensors_if_available, + group=self.process_group, + ) + + for attr, reduction_fn in self._reductions.items(): + # agregate lists of tensors + reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr] + setattr(self, attr, reduced) + + def wrap_sync(self, 
func): + @functools.wraps(func) + def wrapped_func(*args, **kwargs): + if self._to_sync and torch.distributed.is_available() and torch.distributed.is_initialized(): + self.sync() + return func(*args, **kwargs) + + return wrapped_func + + @abstractmethod + def update(self): + pass + + @abstractmethod + def compute(self): # pylint: disable=E0202 + pass + + def reset(self): + for attr, default in self._defaults.items(): + setattr(self, attr, deepcopy(default)) + + +def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional[Any] = None): + """ + Function to gather all tensors from several ddp processes onto a list that + is broadcasted to all processes + + Args: + result: the value to sync + group: the process group to gather results from. Defaults to all processes (world) + + Return: + gathered_result: list with size equal to the process group where + gathered_result[i] corresponds to result tensor from process i + + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + if group is None: + group = torch.distributed.group.WORLD + + world_size = torch.distributed.get_world_size(group) + + gathered_result = [torch.zeros_like(result) for _ in range(world_size)] + + # sync and broadcast all + torch.distributed.barrier(group=group) + torch.distributed.all_gather(gathered_result, result, group) + + result = gathered_result + return result diff --git a/pytorch_lightning/metrics/regression/__init__.py b/pytorch_lightning/metrics/regression/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 792f54397fd13..be421ad3cabb2 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -16,7 +16,15 @@ import warnings from functools import wraps +import torch from pytorch_lightning import _logger as log +from typing import Union, Optional, Any + +if torch.distributed.is_available(): + from torch.distributed import ReduceOp +else: + class ReduceOp: + SUM = None def rank_zero_only(fn): @@ -63,3 +71,39 @@ def find_free_network_port() -> int: port = s.getsockname()[1] s.close() return port + + +def sync_ddp_if_available( + result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None +) -> torch.Tensor: + """ + Function to reduce the tensors from several ddp processes to one master process + Args: + result: the value to sync and reduce (typically tensor or number) + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to sum. + Can also be a string of 'avg', 'mean' to calculate the mean during reduction. 
+ Return: + reduced value + """ + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + divide_by_world_size = False + + if group is None: + group = torch.distributed.group.WORLD + + if reduce_op is None: + reduce_op = torch.distributed.ReduceOp.SUM + elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"): + reduce_op = torch.distributed.ReduceOp.SUM + divide_by_world_size = True + + # sync all processes before reduction + torch.distributed.barrier(group=group) + torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False) + + if divide_by_world_size: + result = result / torch.distributed.get_world_size(group) + + return result diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py new file mode 100644 index 0000000000000..dff7b6497b136 --- /dev/null +++ b/tests/metrics/__init__.py @@ -0,0 +1,5 @@ +import os + +from tests.metrics.utils import compute_batch, setup_ddp +from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE +from tests.metrics.test_metric import Dummy diff --git a/tests/metrics/classification/__init__.py b/tests/metrics/classification/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py new file mode 100644 index 0000000000000..4e45b7a730ea1 --- /dev/null +++ b/tests/metrics/classification/test_accuracy.py @@ -0,0 +1,130 @@ +import os +import pytest +import torch +import os +import numpy as np +from collections import namedtuple + +from pytorch_lightning.metrics.classification.accuracy import Accuracy +from sklearn.metrics import accuracy_score + +from tests.metrics.utils import compute_batch, setup_ddp +from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE + +torch.manual_seed(42) + +# global vars +num_classes = 5 +threshold = 0.5 +extra_dim = 3 + +def test_accuracy_invalid_shape(): + with pytest.raises(ValueError): + acc = Accuracy() + acc.update(preds=torch.rand(1), target=torch.rand(1, 2, 3)) + +Input = namedtuple('Input', ["preds", "target"]) + +_binary_prob_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) +) + +def _binary_prob_sk_metric(preds, target): + sk_preds = (preds.view(-1).numpy() >= threshold).astype(np.uint8) + sk_target = target.view(-1).numpy() + + return accuracy_score(y_true=sk_target, y_pred=sk_preds) + +_binary_inputs = Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE,)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE,)) +) + +def _binary_sk_metric(preds, target): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return accuracy_score(y_true=sk_target, y_pred=sk_preds) + +_multilabel_prob_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, num_classes)) +) + +def _multilabel_prob_sk_metric(preds, target): + sk_preds = (preds.view(-1).numpy() >= threshold).astype(np.uint8) + sk_target = target.view(-1).numpy() + + return accuracy_score(y_true=sk_target, y_pred=sk_preds) + +_multilabel_inputs = Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, num_classes)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, num_classes)) +) + +def _multilabel_sk_metric(preds, target): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return accuracy_score(y_true=sk_target, 
y_pred=sk_preds) + +_multiclass_prob_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes), + target=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE)) +) + +def _multiclass_prob_sk_metric(preds, target): + sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() + sk_target = target.view(-1).numpy() + + return accuracy_score(y_true=sk_target, y_pred=sk_preds) + +_multiclass_inputs = Input( + preds=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE)) +) + +def _multiclass_sk_metric(preds, target): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return accuracy_score(y_true=sk_target, y_pred=sk_preds) + +_multidim_multiclass_prob_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes, extra_dim), + target=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE, extra_dim)) +) + +def _multidim_multiclass_prob_sk_metric(preds, target): + sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() + sk_target = target.view(-1).numpy() + + return accuracy_score(y_true=sk_target, y_pred=sk_preds) + +_multidim_multiclass_inputs = Input( + preds=torch.randint(high=num_classes, size=(NUM_BATCHES, extra_dim, BATCH_SIZE)), + target=torch.randint(high=num_classes, size=(NUM_BATCHES, extra_dim, BATCH_SIZE)) +) + +def _multidim_multiclass_sk_metric(preds, target): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return accuracy_score(y_true=sk_target, y_pred=sk_preds) + + +@pytest.mark.parametrize("ddp", [True, False]) +@pytest.mark.parametrize("ddp_sync_on_step", [True, False]) +@pytest.mark.parametrize("preds, target, sk_metric", [ + (_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric), + (_binary_inputs.preds, _binary_inputs.target, _binary_sk_metric), + (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _multilabel_prob_sk_metric), + (_multilabel_inputs.preds, _multilabel_inputs.target, _multilabel_sk_metric), + (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _multiclass_prob_sk_metric), + (_multiclass_inputs.preds, _multiclass_inputs.target, _multiclass_sk_metric), + (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target, _multidim_multiclass_prob_sk_metric), + (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target, _multidim_multiclass_sk_metric) +]) +def test_accuracy(ddp, ddp_sync_on_step, preds, target, sk_metric): + compute_batch(preds, target, Accuracy, sk_metric, ddp_sync_on_step, ddp, metric_args={"threshold": threshold}) diff --git a/tests/metrics/test_ddp.py b/tests/metrics/test_ddp.py new file mode 100644 index 0000000000000..8d95e1f25db41 --- /dev/null +++ b/tests/metrics/test_ddp.py @@ -0,0 +1,45 @@ +import pytest +import torch +import os + +from tests.metrics.test_metric import Dummy +from tests.metrics.utils import setup_ddp + +torch.manual_seed(42) + + +def _test_ddp_sum(rank, worldsize): + setup_ddp(rank, worldsize) + dummy = Dummy() + dummy._reductions = {"foo": sum} + dummy.foo = torch.tensor(1) + + dummy.sync() + assert dummy.foo == worldsize + + +def _test_ddp_cat(rank, worldsize): + setup_ddp(rank, worldsize) + dummy = Dummy() + dummy._reductions = {"foo": torch.cat} + dummy.foo = torch.tensor([1]) + dummy.sync() + assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1]))) + + +def _test_ddp_sum_cat(rank, worldsize): + setup_ddp(rank, worldsize) + dummy 
= Dummy() + dummy._reductions = {"foo": torch.cat, "bar": sum} + dummy.foo = torch.tensor([1]) + dummy.bar = torch.tensor(1) + dummy.sync() + assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1]))) + assert dummy.bar == worldsize + + +@pytest.mark.parametrize("process", [_test_ddp_cat, _test_ddp_sum, _test_ddp_sum_cat]) +def test_ddp(process): + torch.multiprocessing.spawn(process, args=(2,), nprocs=2) + + diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py new file mode 100644 index 0000000000000..a02ead1d640a4 --- /dev/null +++ b/tests/metrics/test_metric.py @@ -0,0 +1,63 @@ +import pytest +import torch +from pytorch_lightning.metrics.metric import Metric +import os + +torch.manual_seed(42) + + +class Dummy(Metric): + name = "Dummy" + def __init__(self): + super().__init__() + self.add_state("x", 0, reduction=False) + + def update(self): + pass + + def compute(self): + pass + + +def test_inherit(): + a = Dummy() + + +def test_reset(): + class A(Dummy): + pass + + a = A() + assert a.x == 0 + a.x = 5 + a.reset() + assert a.x == 0 + + +def test_update(): + class A(Dummy): + def update(self, x): + self.x += x + + a = A() + assert a.x == 0 + a.update(1) + assert a.x == 1 + a.update(2) + assert a.x == 3 + + +def test_compute(): + class A(Dummy): + def update(self, x): + self.x += x + + def compute(self): + return self.x + + a = A() + assert a.x == a.compute() + a.update(1) + assert a.x == a.compute() + a.update(2) + assert a.x == a.compute() diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py new file mode 100644 index 0000000000000..19b351cb1f786 --- /dev/null +++ b/tests/metrics/utils.py @@ -0,0 +1,57 @@ +import torch +import numpy as np +import os + +NUM_PROCESSES = 2 +NUM_BATCHES = 10 +BATCH_SIZE = 16 + +def setup_ddp(rank, world_size): + os.environ["MASTER_ADDR"] = 'localhost' + os.environ['MASTER_PORT'] = '8088' + torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) + + +def _compute_batch(rank, preds, target, metric_class, sk_metric, ddp_sync_on_step, worldsize=1, metric_args={}): + + metric = metric_class(compute_on_step=True, ddp_sync_on_step=ddp_sync_on_step, **metric_args) + + # Only use ddp if world size + if worldsize > 1: + setup_ddp(rank, worldsize) + + for i in range(rank, NUM_BATCHES, worldsize): + batch_result = metric(preds[i], target[i]) + + if metric.ddp_sync_on_step: + if rank == 0: + ddp_preds = torch.stack([preds[i + r] for r in range(worldsize)]) + ddp_target = torch.stack([target[i + r] for r in range(worldsize)]) + sk_batch_result = sk_metric(ddp_preds, ddp_target) + assert np.allclose(batch_result.numpy(), sk_batch_result) + else: + sk_batch_result = sk_metric(preds[i], target[i]) + assert np.allclose(batch_result.numpy(), sk_batch_result) + + result = metric.compute() + if rank == 0: + assert isinstance(result, torch.Tensor) + + total_preds = torch.stack([preds[i] for i in range(NUM_BATCHES)]) + total_target = torch.stack([target[i] for i in range(NUM_BATCHES)]) + sk_result = sk_metric(total_preds, total_target) + + assert np.allclose(result.numpy(), sk_result) + else: + assert True + #assert result is None + +def compute_batch(preds, target, metric_class, sk_metric, ddp_sync_on_step, ddp=False, metric_args={}): + if ddp: + torch.multiprocessing.spawn( + _compute_batch, args=(preds, target, metric_class, sk_metric, ddp_sync_on_step, NUM_PROCESSES, metric_args), + nprocs=NUM_PROCESSES + ) + else: + # first args: rank, last args: world size + _compute_batch(0, preds, target, metric_class, sk_metric, 
ddp_sync_on_step, 1, metric_args) From 31f41aa54718256fbdc43799bd335b4c35d1318b Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 5 Oct 2020 11:30:46 -0400 Subject: [PATCH 03/27] pep8 Co-authored-by: Teddy Koker teddy.koker@gmail.com --- .../metrics/classification/accuracy.py | 2 +- tests/metrics/classification/test_accuracy.py | 31 +++++++++++++++++-- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index bc2acf5c3d23d..2acc333f18685 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -25,7 +25,7 @@ def __init__(self, threshold=0.5, **kwargs): super().__init__(**kwargs) # change to dist_reduce_fx - self.add_state("correct", torch.tensor(0), reduction=sum) + self.add_state("correct", torch.tensor(0), reduction=sum) self.add_state("total", torch.tensor(0), reduction=sum) self.threshold = threshold diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py index 4e45b7a730ea1..8aec57cb2e56e 100644 --- a/tests/metrics/classification/test_accuracy.py +++ b/tests/metrics/classification/test_accuracy.py @@ -18,95 +18,112 @@ threshold = 0.5 extra_dim = 3 +Input = namedtuple('Input', ["preds", "target"]) + + def test_accuracy_invalid_shape(): with pytest.raises(ValueError): acc = Accuracy() acc.update(preds=torch.rand(1), target=torch.rand(1, 2, 3)) -Input = namedtuple('Input', ["preds", "target"]) _binary_prob_inputs = Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) ) + def _binary_prob_sk_metric(preds, target): sk_preds = (preds.view(-1).numpy() >= threshold).astype(np.uint8) sk_target = target.view(-1).numpy() return accuracy_score(y_true=sk_target, y_pred=sk_preds) + _binary_inputs = Input( preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE,)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE,)) ) + def _binary_sk_metric(preds, target): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() return accuracy_score(y_true=sk_target, y_pred=sk_preds) + _multilabel_prob_inputs = Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, num_classes)) ) + def _multilabel_prob_sk_metric(preds, target): sk_preds = (preds.view(-1).numpy() >= threshold).astype(np.uint8) sk_target = target.view(-1).numpy() return accuracy_score(y_true=sk_target, y_pred=sk_preds) + _multilabel_inputs = Input( preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, num_classes)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, num_classes)) ) + def _multilabel_sk_metric(preds, target): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() return accuracy_score(y_true=sk_target, y_pred=sk_preds) + _multiclass_prob_inputs = Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes), target=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE)) ) + def _multiclass_prob_sk_metric(preds, target): sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() sk_target = target.view(-1).numpy() return accuracy_score(y_true=sk_target, y_pred=sk_preds) + _multiclass_inputs = Input( preds=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE)) ) + def _multiclass_sk_metric(preds, target): sk_preds 
= preds.view(-1).numpy() sk_target = target.view(-1).numpy() return accuracy_score(y_true=sk_target, y_pred=sk_preds) + _multidim_multiclass_prob_inputs = Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes, extra_dim), target=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE, extra_dim)) ) + def _multidim_multiclass_prob_sk_metric(preds, target): sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() sk_target = target.view(-1).numpy() return accuracy_score(y_true=sk_target, y_pred=sk_preds) + _multidim_multiclass_inputs = Input( preds=torch.randint(high=num_classes, size=(NUM_BATCHES, extra_dim, BATCH_SIZE)), target=torch.randint(high=num_classes, size=(NUM_BATCHES, extra_dim, BATCH_SIZE)) ) + def _multidim_multiclass_sk_metric(preds, target): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() @@ -123,8 +140,16 @@ def _multidim_multiclass_sk_metric(preds, target): (_multilabel_inputs.preds, _multilabel_inputs.target, _multilabel_sk_metric), (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _multiclass_prob_sk_metric), (_multiclass_inputs.preds, _multiclass_inputs.target, _multiclass_sk_metric), - (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target, _multidim_multiclass_prob_sk_metric), - (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target, _multidim_multiclass_sk_metric) + ( + _multidim_multiclass_prob_inputs.preds, + _multidim_multiclass_prob_inputs.target, + _multidim_multiclass_prob_sk_metric + ), + ( + _multidim_multiclass_inputs.preds, + _multidim_multiclass_inputs.target, + _multidim_multiclass_sk_metric + ) ]) def test_accuracy(ddp, ddp_sync_on_step, preds, target, sk_metric): compute_batch(preds, target, Accuracy, sk_metric, ddp_sync_on_step, ddp, metric_args={"threshold": threshold}) From 5519c30242b1e89bca2c885f651eb0b27c391fb8 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 5 Oct 2020 11:34:32 -0400 Subject: [PATCH 04/27] pep8 Co-authored-by: Teddy Koker --- tests/metrics/classification/test_accuracy.py | 2 +- tests/metrics/test_ddp.py | 2 -- tests/metrics/test_metric.py | 1 + tests/metrics/utils.py | 4 +++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py index 8aec57cb2e56e..5367a3dcc4362 100644 --- a/tests/metrics/classification/test_accuracy.py +++ b/tests/metrics/classification/test_accuracy.py @@ -50,7 +50,7 @@ def _binary_sk_metric(preds, target): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() - return accuracy_score(y_true=sk_target, y_pred=sk_preds) + return accuracy_score(y_true=sk_target, y_pred=sk_preds) _multilabel_prob_inputs = Input( diff --git a/tests/metrics/test_ddp.py b/tests/metrics/test_ddp.py index 8d95e1f25db41..049ae8e00b570 100644 --- a/tests/metrics/test_ddp.py +++ b/tests/metrics/test_ddp.py @@ -41,5 +41,3 @@ def _test_ddp_sum_cat(rank, worldsize): @pytest.mark.parametrize("process", [_test_ddp_cat, _test_ddp_sum, _test_ddp_sum_cat]) def test_ddp(process): torch.multiprocessing.spawn(process, args=(2,), nprocs=2) - - diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index a02ead1d640a4..f08ffdb0ce0ca 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -8,6 +8,7 @@ class Dummy(Metric): name = "Dummy" + def __init__(self): super().__init__() self.add_state("x", 0, reduction=False) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 
19b351cb1f786..d40ebc99ef2f9 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -6,6 +6,7 @@ NUM_BATCHES = 10 BATCH_SIZE = 16 + def setup_ddp(rank, world_size): os.environ["MASTER_ADDR"] = 'localhost' os.environ['MASTER_PORT'] = '8088' @@ -44,7 +45,8 @@ def _compute_batch(rank, preds, target, metric_class, sk_metric, ddp_sync_on_ste assert np.allclose(result.numpy(), sk_result) else: assert True - #assert result is None + # assert result is None + def compute_batch(preds, target, metric_class, sk_metric, ddp_sync_on_step, ddp=False, metric_args={}): if ddp: From a068236e143cb836ce1d08117ab8e2b36a7afbc9 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 5 Oct 2020 11:45:55 -0400 Subject: [PATCH 05/27] docs Co-authored-by: Teddy Koker --- pytorch_lightning/utilities/distributed.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index be421ad3cabb2..1f69d6e0e4946 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -78,11 +78,13 @@ def sync_ddp_if_available( ) -> torch.Tensor: """ Function to reduce the tensors from several ddp processes to one master process + Args: result: the value to sync and reduce (typically tensor or number) group: the process group to gather results from. Defaults to all processes (world) reduce_op: the reduction operation. Defaults to sum. Can also be a string of 'avg', 'mean' to calculate the mean during reduction. + Return: reduced value """ From 2175a86064cd28dca914c192108ac7c8879ce3e6 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 5 Oct 2020 11:51:27 -0400 Subject: [PATCH 06/27] docs Co-authored-by: Teddy Koker --- docs/source/metrics.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index e69de29bb2d1d..fa0d45863689e 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -0,0 +1,2 @@ +Metrics +------- \ No newline at end of file From 4a9b646f286aeee541057a5dd8487287ec7e2488 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 5 Oct 2020 12:08:27 -0400 Subject: [PATCH 07/27] win ddp tests skip Co-authored-by: Teddy Koker --- tests/metrics/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index d40ebc99ef2f9..b958ef6081cf9 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -1,6 +1,7 @@ import torch import numpy as np import os +import sys NUM_PROCESSES = 2 NUM_BATCHES = 10 @@ -50,6 +51,9 @@ def _compute_batch(rank, preds, target, metric_class, sk_metric, ddp_sync_on_ste def compute_batch(preds, target, metric_class, sk_metric, ddp_sync_on_step, ddp=False, metric_args={}): if ddp: + if sys.platform == "win32": + pytest.skip("DDP not supported on windows") + torch.multiprocessing.spawn( _compute_batch, args=(preds, target, metric_class, sk_metric, ddp_sync_on_step, NUM_PROCESSES, metric_args), nprocs=NUM_PROCESSES From e56d7e2f8edb3a66a38f7dfc4d3398a8b6dd7fc7 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 5 Oct 2020 12:23:34 -0400 Subject: [PATCH 08/27] win ddp tests skip Co-authored-by: Teddy Koker --- tests/metrics/test_ddp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/metrics/test_ddp.py b/tests/metrics/test_ddp.py index 049ae8e00b570..6950cc12b4c9c 100644 --- a/tests/metrics/test_ddp.py +++ b/tests/metrics/test_ddp.py @@ -38,6 +38,7 @@ def _test_ddp_sum_cat(rank, worldsize): assert dummy.bar == worldsize 
+@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") @pytest.mark.parametrize("process", [_test_ddp_cat, _test_ddp_sum, _test_ddp_sum_cat]) def test_ddp(process): torch.multiprocessing.spawn(process, args=(2,), nprocs=2) From 9ccfb4a9edf270713dd37e9bbc303dcc4ee54a0c Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 5 Oct 2020 12:28:01 -0400 Subject: [PATCH 09/27] win ddp tests skip Co-authored-by: Teddy Koker --- tests/metrics/test_ddp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/metrics/test_ddp.py b/tests/metrics/test_ddp.py index 6950cc12b4c9c..5097186977527 100644 --- a/tests/metrics/test_ddp.py +++ b/tests/metrics/test_ddp.py @@ -1,6 +1,7 @@ import pytest import torch import os +import sys from tests.metrics.test_metric import Dummy from tests.metrics.utils import setup_ddp From 95c9e0419cddace8788ed35de68c65ab174e03e0 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 5 Oct 2020 13:02:49 -0400 Subject: [PATCH 10/27] win ddp tests skip Co-authored-by: Teddy Koker --- tests/metrics/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index b958ef6081cf9..9b4b96647cc94 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -2,6 +2,7 @@ import numpy as np import os import sys +import pytest NUM_PROCESSES = 2 NUM_BATCHES = 10 From a3af97df1ee8361506a63a60c5453d93d6be2729 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 5 Oct 2020 14:33:25 -0400 Subject: [PATCH 11/27] reset in compute, cache compute Co-authored-by: Teddy Koker --- pytorch_lightning/metrics/metric.py | 29 ++++++++++++++++++++++++----- tests/metrics/test_metric.py | 14 +++++++++++--- tests/metrics/utils.py | 15 ++++++--------- 3 files changed, 41 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index cc908553f3994..48e886f6cace8 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -5,6 +5,7 @@ from collections import namedtuple from copy import deepcopy +import os import torch from torch import nn @@ -26,7 +27,9 @@ def __init__( self.process_group = process_group self._to_sync = True - self.compute = self.wrap_sync(self.compute) + self.update = self.wrap_update(self.update) + self.compute = self.wrap_compute(self.compute) + self._computed = None # initialize state self._reductions = {} @@ -60,6 +63,7 @@ def forward(self, *args, **kwargs): for attr, val in self._cache.items(): setattr(self, attr, val) self._to_sync = True + self._computed = None return result @@ -76,18 +80,33 @@ def sync(self): # agregate lists of tensors reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr] setattr(self, attr, reduced) + + def wrap_update(self, update): + @functools.wraps(update) + def wrapped_func(*args, **kwargs): + self._computed = None + return update(*args, **kwargs) + return wrapped_func - def wrap_sync(self, func): - @functools.wraps(func) + def wrap_compute(self, compute): + @functools.wraps(compute) def wrapped_func(*args, **kwargs): + # return cached value + if self._computed is not None: + return self._computed + if self._to_sync and torch.distributed.is_available() and torch.distributed.is_initialized(): self.sync() - return func(*args, **kwargs) + self._computed = compute(*args, **kwargs) + self.reset() + + return self._computed + return wrapped_func @abstractmethod - def update(self): + def update(self) -> None: # pylint: disable=E0202 pass @abstractmethod 
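The metric.py change above makes `compute()` cache its result in `self._computed` and reset the metric state afterwards, while any `update()` call clears that cache. Below is a minimal sketch of how a user-defined metric would behave under this contract, assuming the `Metric` base class from this series and the string `dist_reduce_fx` form of `add_state` that a later commit in this same series introduces; `MySum` is a hypothetical example for illustration, not part of the patch:

    import torch
    from pytorch_lightning.metrics.metric import Metric


    class MySum(Metric):
        def __init__(self):
            super().__init__()
            # running total; summed across processes when running under DDP
            self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, x: torch.Tensor):
            self.total += x

        def compute(self):
            return self.total


    metric = MySum()
    metric.update(torch.tensor(1.0))
    metric.update(torch.tensor(2.0))
    first = metric.compute()          # computes 3.0, caches it in `_computed`, then resets state
    second = metric.compute()         # no `update()` in between, so the cached 3.0 is returned
    metric.update(torch.tensor(4.0))  # any `update()` invalidates the cache again

Because `compute()` resets state, a subsequent `compute()` after the last `update()` above would see only the post-reset accumulation (4.0), which is the behaviour the test_metric.py changes below exercise.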
diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index f08ffdb0ce0ca..6d358c812a775 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -42,10 +42,13 @@ def update(self, x): a = A() assert a.x == 0 + assert a._computed is None a.update(1) + assert a._computed is None assert a.x == 1 a.update(2) assert a.x == 3 + assert a._computed is None def test_compute(): @@ -57,8 +60,13 @@ def compute(self): return self.x a = A() - assert a.x == a.compute() + assert 0 == a.compute() + assert 0 == a.x a.update(1) - assert a.x == a.compute() + assert a._computed is None + assert a.compute() == 1 + assert a._computed == 1 a.update(2) - assert a.x == a.compute() + assert a._computed is None + assert a.compute() == 2 + assert a._computed == 2 diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 9b4b96647cc94..88997006b0014 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -36,18 +36,15 @@ def _compute_batch(rank, preds, target, metric_class, sk_metric, ddp_sync_on_ste sk_batch_result = sk_metric(preds[i], target[i]) assert np.allclose(batch_result.numpy(), sk_batch_result) + # check on all batches on all ranks result = metric.compute() - if rank == 0: - assert isinstance(result, torch.Tensor) + assert isinstance(result, torch.Tensor) - total_preds = torch.stack([preds[i] for i in range(NUM_BATCHES)]) - total_target = torch.stack([target[i] for i in range(NUM_BATCHES)]) - sk_result = sk_metric(total_preds, total_target) + total_preds = torch.stack([preds[i] for i in range(NUM_BATCHES)]) + total_target = torch.stack([target[i] for i in range(NUM_BATCHES)]) + sk_result = sk_metric(total_preds, total_target) - assert np.allclose(result.numpy(), sk_result) - else: - assert True - # assert result is None + assert np.allclose(result.numpy(), sk_result) def compute_batch(preds, target, metric_class, sk_metric, ddp_sync_on_step, ddp=False, metric_args={}): From 426403e1e47a9d1cb0725554104b1400aca74fd2 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 5 Oct 2020 16:25:34 -0400 Subject: [PATCH 12/27] reduce_ops handling Co-authored-by: Teddy Koker --- .../metrics/classification/accuracy.py | 5 +- pytorch_lightning/metrics/metric.py | 65 ++++++++----------- pytorch_lightning/metrics/utils.py | 37 +++++++++++ tests/metrics/test_ddp.py | 8 +-- tests/metrics/test_metric.py | 38 ++++++++++- 5 files changed, 107 insertions(+), 46 deletions(-) create mode 100644 pytorch_lightning/metrics/utils.py diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 2acc333f18685..25eab952bddf0 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -24,9 +24,8 @@ class Accuracy(Metric): def __init__(self, threshold=0.5, **kwargs): super().__init__(**kwargs) - # change to dist_reduce_fx - self.add_state("correct", torch.tensor(0), reduction=sum) - self.add_state("total", torch.tensor(0), reduction=sum) + self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") self.threshold = threshold diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 48e886f6cace8..c64fec40bd1dd 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -10,6 +10,7 @@ from torch import nn from pytorch_lightning.utilities.apply_func import apply_to_collection +from 
pytorch_lightning.metrics.utils import _flatten, gather_all_tensors_if_available class Metric(nn.Module, ABC): @@ -35,14 +36,26 @@ def __init__( self._reductions = {} self._defaults = {} - def add_state(self, name, default, reduction=None): - if reduction is None: - # TODO: implement default reduction - raise NotImplementedError("default reduction not implemented") + def add_state(self, name, default, dist_reduce_fx: Optional[Union[str, Callable]] = None): + if not isinstance(default, torch.Tensor) or (isinstance(default, list) and len(default) != 0): + raise ValueError( + "state variable must be a tensor or any empty list (where you can append tensors)" + ) + + if dist_reduce_fx == "sum": + dist_reduce_fx = lambda x: torch.sum(x, dim=0) + elif dist_reduce_fx == "mean": + dist_reduce_fx = lambda x: torch.mean(x, dim=0) + elif dist_reduce_fx == "cat": + dist_reduce_fx = lambda x: torch.cat(x, dim=0) + elif dist_reduce_fx is not None and not isinstance(dist_reduce_fx, Callable): + raise ValueError( + "`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]" + ) setattr(self, name, default) self._defaults[name] = deepcopy(default) - self._reductions[name] = reduction + self._reductions[name] = dist_reduce_fx def forward(self, *args, **kwargs): # add current step @@ -77,10 +90,16 @@ def sync(self): ) for attr, reduction_fn in self._reductions.items(): - # agregate lists of tensors + # pre-processing ops (stack or flatten for inputs) + if isinstance(output_dict[attr][0], torch.Tensor): + output_dict[attr] = torch.stack(output_dict[attr]) + elif isinstance(output_dict[attr][0], list): + output_dict[attr] = _flatten(output_dict[attr]) + + assert isinstance(reduction_fn, (Callable, None)) reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr] setattr(self, attr, reduced) - + def wrap_update(self, update): @functools.wraps(update) def wrapped_func(*args, **kwargs): @@ -102,7 +121,7 @@ def wrapped_func(*args, **kwargs): self.reset() return self._computed - + return wrapped_func @abstractmethod @@ -116,33 +135,3 @@ def compute(self): # pylint: disable=E0202 def reset(self): for attr, default in self._defaults.items(): setattr(self, attr, deepcopy(default)) - - -def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional[Any] = None): - """ - Function to gather all tensors from several ddp processes onto a list that - is broadcasted to all processes - - Args: - result: the value to sync - group: the process group to gather results from. 
Defaults to all processes (world) - - Return: - gathered_result: list with size equal to the process group where - gathered_result[i] corresponds to result tensor from process i - - """ - if torch.distributed.is_available() and torch.distributed.is_initialized(): - if group is None: - group = torch.distributed.group.WORLD - - world_size = torch.distributed.get_world_size(group) - - gathered_result = [torch.zeros_like(result) for _ in range(world_size)] - - # sync and broadcast all - torch.distributed.barrier(group=group) - torch.distributed.all_gather(gathered_result, result, group) - - result = gathered_result - return result diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py new file mode 100644 index 0000000000000..c330174a7abff --- /dev/null +++ b/pytorch_lightning/metrics/utils.py @@ -0,0 +1,37 @@ +import torch + +from typing import Any, Callable, Optional, Union + + +def _flatten(x): + return [item for sublist in x for item in sublist] + + +def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional[Any] = None): + """ + Function to gather all tensors from several ddp processes onto a list that + is broadcasted to all processes + + Args: + result: the value to sync + group: the process group to gather results from. Defaults to all processes (world) + + Return: + gathered_result: list with size equal to the process group where + gathered_result[i] corresponds to result tensor from process i + + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + if group is None: + group = torch.distributed.group.WORLD + + world_size = torch.distributed.get_world_size(group) + + gathered_result = [torch.zeros_like(result) for _ in range(world_size)] + + # sync and broadcast all + torch.distributed.barrier(group=group) + torch.distributed.all_gather(gathered_result, result, group) + + result = gathered_result + return result diff --git a/tests/metrics/test_ddp.py b/tests/metrics/test_ddp.py index 5097186977527..c52ef40fb10d4 100644 --- a/tests/metrics/test_ddp.py +++ b/tests/metrics/test_ddp.py @@ -12,7 +12,7 @@ def _test_ddp_sum(rank, worldsize): setup_ddp(rank, worldsize) dummy = Dummy() - dummy._reductions = {"foo": sum} + dummy._reductions = {"foo": torch.sum} dummy.foo = torch.tensor(1) dummy.sync() @@ -23,7 +23,7 @@ def _test_ddp_cat(rank, worldsize): setup_ddp(rank, worldsize) dummy = Dummy() dummy._reductions = {"foo": torch.cat} - dummy.foo = torch.tensor([1]) + dummy.foo = [torch.tensor([1])] dummy.sync() assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1]))) @@ -31,8 +31,8 @@ def _test_ddp_cat(rank, worldsize): def _test_ddp_sum_cat(rank, worldsize): setup_ddp(rank, worldsize) dummy = Dummy() - dummy._reductions = {"foo": torch.cat, "bar": sum} - dummy.foo = torch.tensor([1]) + dummy._reductions = {"foo": torch.cat, "bar": torch.sum} + dummy.foo = [torch.tensor([1])] dummy.bar = torch.tensor(1) dummy.sync() assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1]))) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index 6d358c812a775..4553e8bbdb6ce 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -2,6 +2,7 @@ import torch from pytorch_lightning.metrics.metric import Metric import os +import numpy as np torch.manual_seed(42) @@ -11,7 +12,7 @@ class Dummy(Metric): def __init__(self): super().__init__() - self.add_state("x", 0, reduction=False) + self.add_state("x", torch.tensor(0), dist_reduce_fx=None) def update(self): pass @@ -24,6 +25,37 @@ def 
test_inherit(): a = Dummy() +def test_add_state(): + a = Dummy() + + a.add_state("a", torch.tensor(0), "sum") + assert a._reductions["a"](torch.tensor([1, 1])) == 2 + + a.add_state("b", torch.tensor(0), "mean") + assert np.allclose(a._reductions["b"](torch.tensor([1.0, 2.0])).numpy(), 1.5) + + a.add_state("c", torch.tensor(0), "cat") + assert a._reductions["c"]([torch.tensor([1]), torch.tensor([1])]).shape == (2,) + + with pytest.raises(ValueError): + a.add_state("d1", torch.tensor(0), 'xyz') + + with pytest.raises(ValueError): + a.add_state("d2", torch.tensor(0), 42) + + with pytest.raises(ValueError): + a.add_state("d3", [torch.tensor(0)], 'sum') + + with pytest.raises(ValueError): + a.add_state("d4", 42, 'sum') + + def custom_fx(x): + return -1 + + a.add_state("e", torch.tensor(0), custom_fx) + assert a._reductions["e"](torch.tensor([1, 1])) == -1 + + def test_reset(): class A(Dummy): pass @@ -70,3 +102,7 @@ def compute(self): assert a._computed is None assert a.compute() == 2 assert a._computed == 2 + + # called without update, should return cached value + a._computed = 5 + assert a.compute() == 5 From ecabb8d503e14f53632d7bc34fd975a49a4d5d24 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 5 Oct 2020 17:57:02 -0400 Subject: [PATCH 13/27] sync -> sync_dist, type annotations Co-authored-by: Teddy Koker --- pytorch_lightning/metrics/classification/accuracy.py | 6 +++--- pytorch_lightning/metrics/metric.py | 4 ++-- tests/metrics/test_ddp.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 25eab952bddf0..b6b5d6cfb549f 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -21,7 +21,7 @@ class Accuracy(Metric): If preds are floating point we threshold at `threshold` """ - def __init__(self, threshold=0.5, **kwargs): + def __init__(self, threshold: float = 0.5, **kwargs): super().__init__(**kwargs) self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") @@ -29,7 +29,7 @@ def __init__(self, threshold=0.5, **kwargs): self.threshold = threshold - def _input_format(self, preds, target): + def _input_format(self, preds: torch.Tensor, target: torch.Tensor): if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): raise ValueError( "preds and target must have same number of dimensions, or one additional dimension for preds" @@ -45,7 +45,7 @@ def _input_format(self, preds, target): return preds, target - def update(self, preds, target): + def update(self, preds: torch.Tensor, target: torch.Tensor): preds, target = self._input_format(preds, target) assert preds.shape == target.shape diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index c64fec40bd1dd..c1f67705cc475 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -80,7 +80,7 @@ def forward(self, *args, **kwargs): return result - def sync(self): + def sync_dist(self): input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()} output_dict = apply_to_collection( input_dict, @@ -115,7 +115,7 @@ def wrapped_func(*args, **kwargs): return self._computed if self._to_sync and torch.distributed.is_available() and torch.distributed.is_initialized(): - self.sync() + self.sync_dist() self._computed = compute(*args, **kwargs) self.reset() diff --git a/tests/metrics/test_ddp.py 
b/tests/metrics/test_ddp.py index c52ef40fb10d4..f8c0cbb8cd3bc 100644
--- a/tests/metrics/test_ddp.py
+++ b/tests/metrics/test_ddp.py
@@ -15,7 +15,7 @@ def _test_ddp_sum(rank, worldsize):
     dummy._reductions = {"foo": torch.sum}
     dummy.foo = torch.tensor(1)
 
-    dummy.sync()
+    dummy.sync_dist()
     assert dummy.foo == worldsize
 
 
@@ -24,7 +24,7 @@ def _test_ddp_cat(rank, worldsize):
     dummy = Dummy()
     dummy._reductions = {"foo": torch.cat}
     dummy.foo = [torch.tensor([1])]
-    dummy.sync()
+    dummy.sync_dist()
     assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1])))
 
 
@@ -34,7 +34,7 @@ def _test_ddp_sum_cat(rank, worldsize):
     dummy._reductions = {"foo": torch.cat, "bar": torch.sum}
     dummy.foo = [torch.tensor([1])]
     dummy.bar = torch.tensor(1)
-    dummy.sync()
+    dummy.sync_dist()
     assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1])))
     assert dummy.bar == worldsize
 
From 9d62235af34ef64f01a21cd85af680082c17033c Mon Sep 17 00:00:00 2001
From: ananyahjha93 
Date: Mon, 5 Oct 2020 19:02:28 -0400
Subject: [PATCH 14/27] wip docs

Co-authored-by: Teddy Koker 
---
 docs/source/metrics.rst                        | 103 +++++++++++++++++-
 .../metrics/classification/accuracy.py         |  14 ++-
 pytorch_lightning/metrics/metric.py            |  47 ++++++--
 tests/metrics/test_ddp.py                      |   6 +-
 4 files changed, 156 insertions(+), 14 deletions(-)

diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index fa0d45863689e..88db918d85a19 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -1,2 +1,103 @@
+.. testsetup:: *
+
+    import torch
+    from torch.nn import Module
+    from pytorch_lightning.core.lightning import LightningModule
+    from pytorch_lightning.metrics import Metric
+
+.. _metrics:
+
 Metrics
--------
\ No newline at end of file
+=======
+
+This metrics API is independent of PyTorch Lightning.
+
+- mention when reset and compute are called, what forward does, how is it different from update
+
+.. code-block:: python
+
+    from pytorch_lightning import metrics
+
+    train_accuracy = metrics.Accuracy()
+    valid_accuracy = metrics.Accuracy(compute_on_step)
+
+    for epoch in range(epochs):
+        for x, y in train_data:
+            y_hat = model(x)
+
+            # training step accuracy
+            batch_acc = train_accuracy(y_hat, y)
+
+        for x, y in valid_data:
+            y_hat = model(x)
+            valid_accuracy(y_hat, y)
+
+    # total accuracy over all training batches
+    total_train_accuracy = train_accuracy.compute()
+
+    # total accuracy over all validation batches
+    total_valid_accuracy = train_accuracy.compute()
+
+
+These metrics work with DDP in PyTorch and PyTorch Lightning by default.
+Lightning calls .compute() for you at epoch end on its own.
+
+Lightning code snippet for using the metrics API:
+
+.. code-block:: python
+
+    import pytorch_lightning as pl
+
+
+Implementing a Metric
+---------------------
+
+To implement a metric, subclass the ``Metric`` class and implement the following methods:
+
+    - ``__init__()``: Each state variable should be called using ``self.add_state(...)``.
+    - ``update()``: Any code needed to update the state given any inputs to the metric.
+    - ``compute()``: Computes a final value from the state of the metric.
+
+All you need to do is call add_state correctly to implement a custom metric with DDP.
+``reset()`` is called on its own on variables added using ``add_state()``.
+
+Example implementation:
+
+.. 
code-block:: python
+
+    from pytorch_lightning.metrics import Metric
+
+    class MyAccuracy(Metric):
+        def __init__(self, ddp_sync_on_step=False):
+            super().__init__(ddp_sync_on_step=ddp_sync_on_step)
+
+            self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
+            self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+        def update(self, preds: torch.Tensor, target: torch.Tensor):
+            preds, target = self._input_format(preds, target)
+            assert preds.shape == target.shape
+
+            self.correct += torch.sum(preds == target)
+            self.total += target.numel()
+
+        def compute(self):
+            return self.correct.float() / self.total
+
+Metric
+^^^^^^
+
+.. autoclass:: pytorch_lightning.metrics.Metric
+    :noindex:
+
+Classification Metrics
+----------------------
+
+Accuracy
+^^^^^^^^
+
+.. autoclass:: pytorch_lightning.metrics.classification.Accuracy
+    :noindex:
+
+Regression Metrics
+------------------
diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py
index b6b5d6cfb549f..9b6d54e0a81d8 100644
--- a/pytorch_lightning/metrics/classification/accuracy.py
+++ b/pytorch_lightning/metrics/classification/accuracy.py
@@ -21,8 +21,18 @@ class Accuracy(Metric):
     If preds are floating point we threshold at `threshold`
     """
-    def __init__(self, threshold: float = 0.5, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(
+        self,
+        threshold: float = 0.5,
+        compute_on_step: bool = True,
+        ddp_sync_on_step: bool = False,
+        process_group: Optional[Any] = None,
+    ):
+        super().__init__(
+            compute_on_step=compute_on_step,
+            ddp_sync_on_step=ddp_sync_on_step,
+            process_group=process_group,
+        )
 
         self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
         self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index c1f67705cc475..a37f5253c07f5 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -14,7 +14,9 @@
 
 
 class Metric(nn.Module, ABC):
+    """
 
+    """
     def __init__(
         self,
         compute_on_step: bool = True,
@@ -28,15 +30,26 @@ def __init__(
         self.process_group = process_group
         self._to_sync = True
 
-        self.update = self.wrap_update(self.update)
-        self.compute = self.wrap_compute(self.compute)
+        self.update = self._wrap_update(self.update)
+        self.compute = self._wrap_compute(self.compute)
         self._computed = None
 
         # initialize state
         self._reductions = {}
         self._defaults = {}
 
-    def add_state(self, name, default, dist_reduce_fx: Optional[Union[str, Callable]] = None):
+    def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None):
+        """
+        Adds metric state variable. Only used by subclasses.
+
+        Args:
+            name: The name of the state variable. The variable will then be accessible at ``self.name``.
+            default: Default value of the state; can either be a tensor or an empty list. The state will be
+                reset to this value when ``self.reset()`` is called.
+            dist_reduce_fx (Optional): Function to reduce state across multiple GPUs. If value is ``"sum"``,
+                ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``, and ``torch.cat`` respectively,
+                each with argument ``dim=0``. 
+ """ if not isinstance(default, torch.Tensor) or (isinstance(default, list) and len(default) != 0): raise ValueError( "state variable must be a tensor or any empty list (where you can append tensors)" @@ -58,6 +71,8 @@ def add_state(self, name, default, dist_reduce_fx: Optional[Union[str, Callable] self._reductions[name] = dist_reduce_fx def forward(self, *args, **kwargs): + """ + """ # add current step self.update(*args, **kwargs) @@ -80,7 +95,11 @@ def forward(self, *args, **kwargs): return result - def sync_dist(self): + def _sync_dist(self): + """ + Method to synchronize metric state variables across different processes + in distributed training. + """ input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()} output_dict = apply_to_collection( input_dict, @@ -100,22 +119,24 @@ def sync_dist(self): reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr] setattr(self, attr, reduced) - def wrap_update(self, update): + def _wrap_update(self, update): @functools.wraps(update) def wrapped_func(*args, **kwargs): self._computed = None return update(*args, **kwargs) return wrapped_func - def wrap_compute(self, compute): + def _wrap_compute(self, compute): @functools.wraps(compute) def wrapped_func(*args, **kwargs): # return cached value if self._computed is not None: return self._computed - if self._to_sync and torch.distributed.is_available() and torch.distributed.is_initialized(): - self.sync_dist() + if self._to_sync \ + and torch.distributed.is_available() \ + and torch.distributed.is_initialized(): + self._sync_dist() self._computed = compute(*args, **kwargs) self.reset() @@ -126,12 +147,22 @@ def wrapped_func(*args, **kwargs): @abstractmethod def update(self) -> None: # pylint: disable=E0202 + """ + Override this method to update the state variables of your metric class. + """ pass @abstractmethod def compute(self): # pylint: disable=E0202 + """ + Override this method to compute the final metric value from state variables + synchronized across the distributed backend. + """ pass def reset(self): + """ + This method automatically resets the metric state variables to their default value. 
+ """ for attr, default in self._defaults.items(): setattr(self, attr, deepcopy(default)) diff --git a/tests/metrics/test_ddp.py b/tests/metrics/test_ddp.py index f8c0cbb8cd3bc..7af4abc087b0c 100644 --- a/tests/metrics/test_ddp.py +++ b/tests/metrics/test_ddp.py @@ -15,7 +15,7 @@ def _test_ddp_sum(rank, worldsize): dummy._reductions = {"foo": torch.sum} dummy.foo = torch.tensor(1) - dummy.sync_dist() + dummy._sync_dist() assert dummy.foo == worldsize @@ -24,7 +24,7 @@ def _test_ddp_cat(rank, worldsize): dummy = Dummy() dummy._reductions = {"foo": torch.cat} dummy.foo = [torch.tensor([1])] - dummy.sync_dist() + dummy._sync_dist() assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1]))) @@ -34,7 +34,7 @@ def _test_ddp_sum_cat(rank, worldsize): dummy._reductions = {"foo": torch.cat, "bar": torch.sum} dummy.foo = [torch.tensor([1])] dummy.bar = torch.tensor(1) - dummy.sync_dist() + dummy._sync_dist() assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1]))) assert dummy.bar == worldsize From 0947c57ef698d25373c3e7774727e3d3e34b1bca Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Mon, 5 Oct 2020 20:07:57 -0400 Subject: [PATCH 15/27] mean squared error --- pytorch_lightning/metrics/__init__.py | 1 + .../metrics/regression/__init__.py | 1 + .../metrics/regression/mean_squared_error.py | 37 ++++++++++++++ tests/metrics/regression/__init__.py | 0 .../regression/test_mean_squared_error.py | 50 +++++++++++++++++++ 5 files changed, 89 insertions(+) create mode 100644 pytorch_lightning/metrics/regression/mean_squared_error.py create mode 100644 tests/metrics/regression/__init__.py create mode 100644 tests/metrics/regression/test_mean_squared_error.py diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 5c3b31f543e56..5fc9141020f2f 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -1,3 +1,4 @@ from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.metrics.classification.accuracy import Accuracy +from pytorch_lightning.metrics.regression.mean_squared_error import MeanSquaredError diff --git a/pytorch_lightning/metrics/regression/__init__.py b/pytorch_lightning/metrics/regression/__init__.py index e69de29bb2d1d..ea5feeeeb30ca 100644 --- a/pytorch_lightning/metrics/regression/__init__.py +++ b/pytorch_lightning/metrics/regression/__init__.py @@ -0,0 +1 @@ +from pytorch_lightning.metrics.regression.mean_squared_error import MeanSquaredError diff --git a/pytorch_lightning/metrics/regression/mean_squared_error.py b/pytorch_lightning/metrics/regression/mean_squared_error.py new file mode 100644 index 0000000000000..4267f18da1e3a --- /dev/null +++ b/pytorch_lightning/metrics/regression/mean_squared_error.py @@ -0,0 +1,37 @@ +import torch +from typing import Any, Callable, Optional, Union + +from pytorch_lightning.metrics.metric import Metric + + +class MeanSquaredError(Metric): + """ + Computes mean squared error. If ``num_targets`` is ``1``, ``compute()`` will return a single float, + otherwise it will return a tensor of length ``num_targets``. 
+ """ + + def __init__( + self, + compute_on_step: bool = True, + ddp_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + super().__init__( + compute_on_step=compute_on_step, + ddp_sync_on_step=ddp_sync_on_step, + process_group=process_group, + ) + self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + assert preds.shape == target.shape + squared_error = torch.pow(preds - target, 2) + + self.sum_squared_error += torch.sum(squared_error) + + self.total += target.numel() + + def compute(self): + return self.sum_squared_error / self.total + diff --git a/tests/metrics/regression/__init__.py b/tests/metrics/regression/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/metrics/regression/test_mean_squared_error.py b/tests/metrics/regression/test_mean_squared_error.py new file mode 100644 index 0000000000000..47c5d4c34bbfa --- /dev/null +++ b/tests/metrics/regression/test_mean_squared_error.py @@ -0,0 +1,50 @@ +import torch +import pytest +from collections import namedtuple + +from pytorch_lightning.metrics.regression import MeanSquaredError +from sklearn.metrics import mean_squared_error + +from tests.metrics.utils import compute_batch, setup_ddp +from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE + +num_targets = 5 + +Input = namedtuple('Input', ["preds", "target"]) + +_single_target_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.rand(NUM_BATCHES, BATCH_SIZE), +) + + +def _single_target_sk_metric(preds, target): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + return mean_squared_error(sk_preds, sk_target) + + +_multi_target_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), +) + + +def _multi_target_sk_metric(preds, target): + sk_preds = preds.view(-1, num_targets).numpy() + sk_target = target.view(-1, num_targets).numpy() + return mean_squared_error(sk_preds, sk_target) + + +@pytest.mark.parametrize("ddp", [True, False]) +@pytest.mark.parametrize("ddp_sync_on_step", [True, False]) +@pytest.mark.parametrize( + "preds, target, sk_metric", + [ + (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric), + (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric), + ], +) +def test_mean_squared_error_single(ddp, ddp_sync_on_step, preds, target, sk_metric): + compute_batch(preds, target, MeanSquaredError, sk_metric, ddp_sync_on_step, ddp) + From 90eccc74c785767425267498dc9eb37a916aa102 Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Mon, 5 Oct 2020 20:10:15 -0400 Subject: [PATCH 16/27] docstring --- pytorch_lightning/metrics/regression/mean_squared_error.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/metrics/regression/mean_squared_error.py b/pytorch_lightning/metrics/regression/mean_squared_error.py index 4267f18da1e3a..2a00c6751638a 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_error.py @@ -6,10 +6,8 @@ class MeanSquaredError(Metric): """ - Computes mean squared error. If ``num_targets`` is ``1``, ``compute()`` will return a single float, - otherwise it will return a tensor of length ``num_targets``. + Computes mean squared error. 
""" - def __init__( self, compute_on_step: bool = True, From b7fce22b90813ab142e14a3943b9dd4e2553f557 Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Mon, 5 Oct 2020 23:49:24 -0400 Subject: [PATCH 17/27] added mean ___ error metrics --- docs/source/metrics.rst | 20 ++++++++ pytorch_lightning/metrics/__init__.py | 2 +- .../metrics/regression/__init__.py | 2 +- .../metrics/regression/mean_squared_error.py | 35 ------------- .../regression/test_mean_squared_error.py | 50 ------------------- 5 files changed, 22 insertions(+), 87 deletions(-) delete mode 100644 pytorch_lightning/metrics/regression/mean_squared_error.py delete mode 100644 tests/metrics/regression/test_mean_squared_error.py diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 88db918d85a19..2d77ac9678ffe 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -101,3 +101,23 @@ Accuracy Regression Metrics ------------------ + +MeanSquaredError +^^^^^^^^^^^^^^^^ + +.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError + :noindex: + + +MeanAbsoluteError +^^^^^^^^^^^^^^^^ + +.. autoclass:: pytorch_lightning.metrics.regression.MeanAbsoluteError + :noindex: + + +MeanSquaredLogError +^^^^^^^^^^^^^^^^ + +.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError + :noindex: diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 5fc9141020f2f..615a08e27a8e8 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -1,4 +1,4 @@ from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.metrics.classification.accuracy import Accuracy -from pytorch_lightning.metrics.regression.mean_squared_error import MeanSquaredError +from pytorch_lightning.metrics.regression import MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError diff --git a/pytorch_lightning/metrics/regression/__init__.py b/pytorch_lightning/metrics/regression/__init__.py index ea5feeeeb30ca..5d6cd4b156068 100644 --- a/pytorch_lightning/metrics/regression/__init__.py +++ b/pytorch_lightning/metrics/regression/__init__.py @@ -1 +1 @@ -from pytorch_lightning.metrics.regression.mean_squared_error import MeanSquaredError +from pytorch_lightning.metrics.regression.mean_error import MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError diff --git a/pytorch_lightning/metrics/regression/mean_squared_error.py b/pytorch_lightning/metrics/regression/mean_squared_error.py deleted file mode 100644 index 2a00c6751638a..0000000000000 --- a/pytorch_lightning/metrics/regression/mean_squared_error.py +++ /dev/null @@ -1,35 +0,0 @@ -import torch -from typing import Any, Callable, Optional, Union - -from pytorch_lightning.metrics.metric import Metric - - -class MeanSquaredError(Metric): - """ - Computes mean squared error. 
- """ - def __init__( - self, - compute_on_step: bool = True, - ddp_sync_on_step: bool = False, - process_group: Optional[Any] = None, - ): - super().__init__( - compute_on_step=compute_on_step, - ddp_sync_on_step=ddp_sync_on_step, - process_group=process_group, - ) - self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - assert preds.shape == target.shape - squared_error = torch.pow(preds - target, 2) - - self.sum_squared_error += torch.sum(squared_error) - - self.total += target.numel() - - def compute(self): - return self.sum_squared_error / self.total - diff --git a/tests/metrics/regression/test_mean_squared_error.py b/tests/metrics/regression/test_mean_squared_error.py deleted file mode 100644 index 47c5d4c34bbfa..0000000000000 --- a/tests/metrics/regression/test_mean_squared_error.py +++ /dev/null @@ -1,50 +0,0 @@ -import torch -import pytest -from collections import namedtuple - -from pytorch_lightning.metrics.regression import MeanSquaredError -from sklearn.metrics import mean_squared_error - -from tests.metrics.utils import compute_batch, setup_ddp -from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE - -num_targets = 5 - -Input = namedtuple('Input', ["preds", "target"]) - -_single_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE), - target=torch.rand(NUM_BATCHES, BATCH_SIZE), -) - - -def _single_target_sk_metric(preds, target): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return mean_squared_error(sk_preds, sk_target) - - -_multi_target_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), -) - - -def _multi_target_sk_metric(preds, target): - sk_preds = preds.view(-1, num_targets).numpy() - sk_target = target.view(-1, num_targets).numpy() - return mean_squared_error(sk_preds, sk_target) - - -@pytest.mark.parametrize("ddp", [True, False]) -@pytest.mark.parametrize("ddp_sync_on_step", [True, False]) -@pytest.mark.parametrize( - "preds, target, sk_metric", - [ - (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric), - (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric), - ], -) -def test_mean_squared_error_single(ddp, ddp_sync_on_step, preds, target, sk_metric): - compute_batch(preds, target, MeanSquaredError, sk_metric, ddp_sync_on_step, ddp) - From f1c89ad7d3f32da2747d4d19c735085817d0e079 Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Mon, 5 Oct 2020 23:54:00 -0400 Subject: [PATCH 18/27] added mean ___ error metrics --- .../metrics/regression/mean_error.py | 120 ++++++++++++++++++ tests/metrics/regression/test_mean_error.py | 57 +++++++++ 2 files changed, 177 insertions(+) create mode 100644 pytorch_lightning/metrics/regression/mean_error.py create mode 100644 tests/metrics/regression/test_mean_error.py diff --git a/pytorch_lightning/metrics/regression/mean_error.py b/pytorch_lightning/metrics/regression/mean_error.py new file mode 100644 index 0000000000000..428f4c4346528 --- /dev/null +++ b/pytorch_lightning/metrics/regression/mean_error.py @@ -0,0 +1,120 @@ +import torch +from typing import Any, Callable, Optional, Union + +from pytorch_lightning.metrics.metric import Metric + + +class MeanSquaredError(Metric): + """ + Computes mean squared error. 
+ + Example: + + >>> from pytorch_lightning.metrics import MeanSquaredError + >>> target = torch.tensor([2.5, 5.0, 4.0, 8.0]) + >>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0]) + >>> mean_squared_error = MeanSquaredError() + >>> mean_squared_error(preds, target) + tensor(0.8750) + + """ + + def __init__( + self, + compute_on_step: bool = True, + ddp_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + super().__init__( + compute_on_step=compute_on_step, + ddp_sync_on_step=ddp_sync_on_step, + process_group=process_group, + ) + self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + assert preds.shape == target.shape + squared_error = torch.pow(preds - target, 2) + self.sum_squared_error += torch.sum(squared_error) + self.total += target.numel() + + def compute(self): + return self.sum_squared_error / self.total + + +class MeanAbsoluteError(Metric): + """ + Computes mean absolute error. + + Example: + + >>> from pytorch_lightning.metrics import MeanAbsoluteError + >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) + >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) + >>> mean_absolute_error = MeanAbsoluteError() + >>> mean_absolute_error(preds, target) + tensor(0.5) + """ + + def __init__( + self, + compute_on_step: bool = True, + ddp_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + super().__init__( + compute_on_step=compute_on_step, + ddp_sync_on_step=ddp_sync_on_step, + process_group=process_group, + ) + self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + assert preds.shape == target.shape + abs_error = torch.abs(preds - target) + self.sum_abs_error += torch.sum(abs_error) + self.total += target.numel() + + def compute(self): + return self.sum_abs_error / self.total + + +class MeanSquaredLogError(Metric): + """ + Computes mean squared logarithmic error. 
+ + Example: + + >>> from pytorch_lightning.metrics import MeanSquaredLogError + >>> target = torch.tensor([2.5, 5, 4, 8]) + >>> preds = torch.tensor([3, 5, 2.5, 7]) + >>> mean_squared_log_error = MeanSquaredLogError() + >>> mean_squared_log_error(preds, target) + tensor(0.0397) + + """ + + def __init__( + self, + compute_on_step: bool = True, + ddp_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + super().__init__( + compute_on_step=compute_on_step, + ddp_sync_on_step=ddp_sync_on_step, + process_group=process_group, + ) + self.add_state("sum_squared_log_error", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + assert preds.shape == target.shape + squared_log_error = torch.pow(torch.log1p(preds) - torch.log1p(target), 2) + self.sum_squared_log_error += torch.sum(squared_log_error) + self.total += target.numel() + + def compute(self): + return self.sum_squared_log_error / self.total diff --git a/tests/metrics/regression/test_mean_error.py b/tests/metrics/regression/test_mean_error.py new file mode 100644 index 0000000000000..757c6edd56e3c --- /dev/null +++ b/tests/metrics/regression/test_mean_error.py @@ -0,0 +1,57 @@ +import torch +import pytest +from collections import namedtuple +from functools import partial + +from pytorch_lightning.metrics.regression import MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError +from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_squared_log_error + +from tests.metrics.utils import compute_batch, setup_ddp +from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE + +num_targets = 5 + +Input = namedtuple('Input', ["preds", "target"]) + +_single_target_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.rand(NUM_BATCHES, BATCH_SIZE), +) + +_multi_target_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), +) + + +def _single_target_sk_metric(preds, target, sk_fn=mean_squared_error): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + return sk_fn(sk_preds, sk_target) + + +def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error): + sk_preds = preds.view(-1, num_targets).numpy() + sk_target = target.view(-1, num_targets).numpy() + return sk_fn(sk_preds, sk_target) + + +@pytest.mark.parametrize("ddp", [True, False]) +@pytest.mark.parametrize("ddp_sync_on_step", [True, False]) +@pytest.mark.parametrize( + "preds, target, sk_metric", + [ + (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric), + (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric), + ], +) +@pytest.mark.parametrize( + "metric_class, sk_fn", + [ + (MeanSquaredError, mean_squared_error), + (MeanAbsoluteError, mean_absolute_error), + (MeanSquaredLogError, mean_squared_log_error), + ], +) +def test_mean_error(ddp, ddp_sync_on_step, preds, target, sk_metric, metric_class, sk_fn): + compute_batch(preds, target, metric_class, partial(sk_metric, sk_fn=sk_fn), ddp_sync_on_step, ddp) From 6efc50a8eb61141dc959ea4c645b669a20f3df19 Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Tue, 6 Oct 2020 10:37:35 -0400 Subject: [PATCH 19/27] seperated files --- .../metrics/regression/__init__.py | 5 +- .../metrics/regression/mean_absolute_error.py | 45 +++++++ .../metrics/regression/mean_error.py | 120 
------------------ .../metrics/regression/mean_squared_error.py | 43 +++++++ .../regression/mean_squared_log_error.py | 43 +++++++ 5 files changed, 135 insertions(+), 121 deletions(-) create mode 100644 pytorch_lightning/metrics/regression/mean_absolute_error.py delete mode 100644 pytorch_lightning/metrics/regression/mean_error.py create mode 100644 pytorch_lightning/metrics/regression/mean_squared_error.py create mode 100644 pytorch_lightning/metrics/regression/mean_squared_log_error.py diff --git a/pytorch_lightning/metrics/regression/__init__.py b/pytorch_lightning/metrics/regression/__init__.py index 5d6cd4b156068..3b57f4da8ae21 100644 --- a/pytorch_lightning/metrics/regression/__init__.py +++ b/pytorch_lightning/metrics/regression/__init__.py @@ -1 +1,4 @@ -from pytorch_lightning.metrics.regression.mean_error import MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError +from pytorch_lightning.metrics.regression.mean_squared_error import MeanSquaredError +from pytorch_lightning.metrics.regression.mean_absolute_error import MeanAbsoluteError +from pytorch_lightning.metrics.regression.mean_squared_log_error import MeanSquaredLogError + diff --git a/pytorch_lightning/metrics/regression/mean_absolute_error.py b/pytorch_lightning/metrics/regression/mean_absolute_error.py new file mode 100644 index 0000000000000..ffe81fdc15924 --- /dev/null +++ b/pytorch_lightning/metrics/regression/mean_absolute_error.py @@ -0,0 +1,45 @@ +import torch +from typing import Any, Callable, Optional, Union + +from pytorch_lightning.metrics.metric import Metric + + + +class MeanAbsoluteError(Metric): + """ + Computes mean absolute error. + + Example: + + >>> from pytorch_lightning.metrics import MeanAbsoluteError + >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) + >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) + >>> mean_absolute_error = MeanAbsoluteError() + >>> mean_absolute_error(preds, target) + tensor(0.5) + """ + + def __init__( + self, + compute_on_step: bool = True, + ddp_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + super().__init__( + compute_on_step=compute_on_step, + ddp_sync_on_step=ddp_sync_on_step, + process_group=process_group, + ) + self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + assert preds.shape == target.shape + abs_error = torch.abs(preds - target) + self.sum_abs_error += torch.sum(abs_error) + self.total += target.numel() + + def compute(self): + return self.sum_abs_error / self.total + + diff --git a/pytorch_lightning/metrics/regression/mean_error.py b/pytorch_lightning/metrics/regression/mean_error.py deleted file mode 100644 index 428f4c4346528..0000000000000 --- a/pytorch_lightning/metrics/regression/mean_error.py +++ /dev/null @@ -1,120 +0,0 @@ -import torch -from typing import Any, Callable, Optional, Union - -from pytorch_lightning.metrics.metric import Metric - - -class MeanSquaredError(Metric): - """ - Computes mean squared error. 
- - Example: - - >>> from pytorch_lightning.metrics import MeanSquaredError - >>> target = torch.tensor([2.5, 5.0, 4.0, 8.0]) - >>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0]) - >>> mean_squared_error = MeanSquaredError() - >>> mean_squared_error(preds, target) - tensor(0.8750) - - """ - - def __init__( - self, - compute_on_step: bool = True, - ddp_sync_on_step: bool = False, - process_group: Optional[Any] = None, - ): - super().__init__( - compute_on_step=compute_on_step, - ddp_sync_on_step=ddp_sync_on_step, - process_group=process_group, - ) - self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - assert preds.shape == target.shape - squared_error = torch.pow(preds - target, 2) - self.sum_squared_error += torch.sum(squared_error) - self.total += target.numel() - - def compute(self): - return self.sum_squared_error / self.total - - -class MeanAbsoluteError(Metric): - """ - Computes mean absolute error. - - Example: - - >>> from pytorch_lightning.metrics import MeanAbsoluteError - >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) - >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) - >>> mean_absolute_error = MeanAbsoluteError() - >>> mean_absolute_error(preds, target) - tensor(0.5) - """ - - def __init__( - self, - compute_on_step: bool = True, - ddp_sync_on_step: bool = False, - process_group: Optional[Any] = None, - ): - super().__init__( - compute_on_step=compute_on_step, - ddp_sync_on_step=ddp_sync_on_step, - process_group=process_group, - ) - self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - assert preds.shape == target.shape - abs_error = torch.abs(preds - target) - self.sum_abs_error += torch.sum(abs_error) - self.total += target.numel() - - def compute(self): - return self.sum_abs_error / self.total - - -class MeanSquaredLogError(Metric): - """ - Computes mean squared logarithmic error. 
- - Example: - - >>> from pytorch_lightning.metrics import MeanSquaredLogError - >>> target = torch.tensor([2.5, 5, 4, 8]) - >>> preds = torch.tensor([3, 5, 2.5, 7]) - >>> mean_squared_log_error = MeanSquaredLogError() - >>> mean_squared_log_error(preds, target) - tensor(0.0397) - - """ - - def __init__( - self, - compute_on_step: bool = True, - ddp_sync_on_step: bool = False, - process_group: Optional[Any] = None, - ): - super().__init__( - compute_on_step=compute_on_step, - ddp_sync_on_step=ddp_sync_on_step, - process_group=process_group, - ) - self.add_state("sum_squared_log_error", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, target: torch.Tensor): - assert preds.shape == target.shape - squared_log_error = torch.pow(torch.log1p(preds) - torch.log1p(target), 2) - self.sum_squared_log_error += torch.sum(squared_log_error) - self.total += target.numel() - - def compute(self): - return self.sum_squared_log_error / self.total diff --git a/pytorch_lightning/metrics/regression/mean_squared_error.py b/pytorch_lightning/metrics/regression/mean_squared_error.py new file mode 100644 index 0000000000000..74a5255e7f7b3 --- /dev/null +++ b/pytorch_lightning/metrics/regression/mean_squared_error.py @@ -0,0 +1,43 @@ +import torch +from typing import Any, Callable, Optional, Union + +from pytorch_lightning.metrics.metric import Metric + + +class MeanSquaredError(Metric): + """ + Computes mean squared error. + + Example: + + >>> from pytorch_lightning.metrics import MeanSquaredError + >>> target = torch.tensor([2.5, 5.0, 4.0, 8.0]) + >>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0]) + >>> mean_squared_error = MeanSquaredError() + >>> mean_squared_error(preds, target) + tensor(0.8750) + + """ + + def __init__( + self, + compute_on_step: bool = True, + ddp_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + super().__init__( + compute_on_step=compute_on_step, + ddp_sync_on_step=ddp_sync_on_step, + process_group=process_group, + ) + self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + assert preds.shape == target.shape + squared_error = torch.pow(preds - target, 2) + self.sum_squared_error += torch.sum(squared_error) + self.total += target.numel() + + def compute(self): + return self.sum_squared_error / self.total diff --git a/pytorch_lightning/metrics/regression/mean_squared_log_error.py b/pytorch_lightning/metrics/regression/mean_squared_log_error.py new file mode 100644 index 0000000000000..d93c5ad3401c5 --- /dev/null +++ b/pytorch_lightning/metrics/regression/mean_squared_log_error.py @@ -0,0 +1,43 @@ +import torch +from typing import Any, Callable, Optional, Union + +from pytorch_lightning.metrics.metric import Metric + + +class MeanSquaredLogError(Metric): + """ + Computes mean squared logarithmic error. 
+ + Example: + + >>> from pytorch_lightning.metrics import MeanSquaredLogError + >>> target = torch.tensor([2.5, 5, 4, 8]) + >>> preds = torch.tensor([3, 5, 2.5, 7]) + >>> mean_squared_log_error = MeanSquaredLogError() + >>> mean_squared_log_error(preds, target) + tensor(0.0397) + + """ + + def __init__( + self, + compute_on_step: bool = True, + ddp_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + super().__init__( + compute_on_step=compute_on_step, + ddp_sync_on_step=ddp_sync_on_step, + process_group=process_group, + ) + self.add_state("sum_squared_log_error", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + assert preds.shape == target.shape + squared_log_error = torch.pow(torch.log1p(preds) - torch.log1p(target), 2) + self.sum_squared_log_error += torch.sum(squared_log_error) + self.total += target.numel() + + def compute(self): + return self.sum_squared_log_error / self.total From f14a4b3c36399d2e56178467542da5d28f3495ea Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Tue, 6 Oct 2020 10:58:33 -0400 Subject: [PATCH 20/27] accuracy doctest --- pytorch_lightning/metrics/classification/accuracy.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 9b6d54e0a81d8..ec3fe0a6cb843 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -20,7 +20,17 @@ class Accuracy(Metric): If preds are integer values, we perform accuracy with those values If preds are floating point we threshold at `threshold` + Example: + + >>> from pytorch_lightning.metrics import Accuracy + >>> target = torch.tensor([0, 1, 2, 3]) + >>> preds = torch.tensor([0, 2, 1, 3]) + >>> accuracy = Accuracy() + >>> accuracy(preds, target) + tensor(0.5) + """ + def __init__( self, threshold: float = 0.5, From 4f3a958307c8a1ba9e0771d7daef028513c5a245 Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Tue, 6 Oct 2020 12:09:23 -0400 Subject: [PATCH 21/27] gpu fix --- pytorch_lightning/metrics/metric.py | 15 ++++++++++++--- tests/metrics/test_metric.py | 2 +- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index a37f5253c07f5..6814a2b82b88f 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -12,8 +12,10 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.metrics.utils import _flatten, gather_all_tensors_if_available +from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin -class Metric(nn.Module, ABC): + +class Metric(DeviceDtypeModuleMixin, nn.Module, ABC): """ """ @@ -66,7 +68,10 @@ def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Call "`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]" ) - setattr(self, name, default) + if isinstance(default, torch.Tensor): + self.register_buffer(name, default) + else: + setattr(self, name, default) self._defaults[name] = deepcopy(default) self._reductions[name] = dist_reduce_fx @@ -165,4 +170,8 @@ def reset(self): This method automatically resets the metric state variables to their default value. 
""" for attr, default in self._defaults.items(): - setattr(self, attr, deepcopy(default)) + current_val = getattr(self, attr) + if isinstance(current_val, torch.Tensor): + setattr(self, attr, deepcopy(default).to(current_val.device)) + else: + setattr(self, attr, deepcopy(default)) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index 4553e8bbdb6ce..62b9384b219d0 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -62,7 +62,7 @@ class A(Dummy): a = A() assert a.x == 0 - a.x = 5 + a.x = torch.tensor(5) a.reset() assert a.x == 0 From 49a0c7b00359cf29a409bdaf0f69bafae1aebd3e Mon Sep 17 00:00:00 2001 From: Teddy Koker Date: Tue, 6 Oct 2020 12:11:58 -0400 Subject: [PATCH 22/27] remove unnecessary mixin --- pytorch_lightning/metrics/metric.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 6814a2b82b88f..4cf6c6505d283 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -12,10 +12,8 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.metrics.utils import _flatten, gather_all_tensors_if_available -from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin - -class Metric(DeviceDtypeModuleMixin, nn.Module, ABC): +class Metric(nn.Module, ABC): """ """ From acf86082097d774468040452cf117161176b4a9a Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Tue, 6 Oct 2020 14:11:28 -0400 Subject: [PATCH 23/27] metric and accuracy docstring Co-authored-by: Teddy Koker --- docs/source/metrics.rst | 57 ++++++++++++------- .../metrics/classification/accuracy.py | 35 ++++++++++-- pytorch_lightning/metrics/metric.py | 42 ++++++++++++-- .../metrics/regression/mean_absolute_error.py | 14 ++++- .../metrics/regression/mean_squared_error.py | 12 ++++ .../regression/mean_squared_log_error.py | 12 ++++ 6 files changed, 141 insertions(+), 31 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 2d77ac9678ffe..64e4b78d9bfbf 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -10,16 +10,44 @@ Metrics ======= -This metrics API is independent of PytTorch Lightning. +Using a metric with with PyTorch Lightning: +# TODO 1: write an intro for metrics, and lead the user into the lightning example -- mention when reset and compute are called, what forward does, how is it different from update +# expand a bit on this +These metrics work with DDP in PyTorch and PyTorch Lightning by default. + +.. note:: + + For v0.10.0 the user is expected to call ``.compute()`` on the metric at the end each epoch. + This has been shown in the example below. For v1.0 release after this, we will integrate metrics + with logging and ``.compute()`` will be called automatically by PyTorch Lightning. + +.. code-block:: python + + def __init__(self): + ... + self.accuracy = pl.metrics.Accuracy() + + def training_step(self, batch, batch_idx): + logits = self(x) + ... + # log step metric + self.log('train_acc_step', self.accuracy(logits, y)) + ... + + def training_epoch_end(self, outs): + # log epoch metric + self.log('train_acc_epoch', self.accuracy.compute()) + + +This metrics API is independent of PyTorch Lightning. If you please, they can be used with plain PyTorch like so: .. 
code-block:: python from pytorch_lightning import metrics train_accuracy = metrics.Accuracy() - valid_accuracy = metrics.Accuracy(compute_on_step) + valid_accuracy = metrics.Accuracy(compute_on_step=False) for epoch in range(epochs): for x, y in train_data: @@ -38,25 +66,16 @@ This metrics API is independent of PytTorch Lightning. # total accuracy over all validation batches total_valid_accuracy = train_accuracy.compute() - -These metrics work with DDP in PyTorch and PyTorch Lightning by default. -Lihgtning calls .compute() for you at epoch end on its own. - -Lightning code snippet to using the metrics API. - -.. code-block:: python - - import pytorch_lightning as pl - - Implementing a Metric --------------------- +# TODO 3: finalize this!, explain reduction in detail + To implement a metric, subclass the ``Metric`` class and implement the following methods: - - ``__init__()``: Each state variable should be called using ``self.add_state(...)``. - - ``update()``: Any code needed to update the state given any inputs to the metric. - - ``compute()``: Computes a final value from the state of the metric. +- ``__init__()``: Each state variable should be called using ``self.add_state(...)``. +- ``update()``: Any code needed to update the state given any inputs to the metric. +- ``compute()``: Computes a final value from the state of the metric. All you need to do is call add_state correctly to implement a custom metric with DDP. ``reset()`` is called on its own on variables added using ``add_state()``. @@ -110,14 +129,14 @@ MeanSquaredError MeanAbsoluteError -^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^ .. autoclass:: pytorch_lightning.metrics.regression.MeanAbsoluteError :noindex: MeanSquaredLogError -^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^ .. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError :noindex: diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index ec3fe0a6cb843..6e9bfd1b46191 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -13,12 +13,29 @@ class Accuracy(Metric): """ Computes accuracy. Works with binary, multiclass, and multilabel data. + Accepts logits from a model output or integer class values in prediction. + Works with multi-dimensional preds and target. - preds and targets must be of shape (N, ...) and (N, ...) or (N, num_classes, ...) and (N, ...) + Forward accepts - If preds and targets are the same shape: - If preds are integer values, we perform accuracy with those values - If preds are floating point we threshold at `threshold` + - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes + - ``target`` (long tensor): ``(N, ...)`` + + If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument. + This is the case for binary and multi-label logits. + + If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + + Args: + threshold: + Threshold value for binary or multi-label logits. default: 0.5 + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + ddp_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. 
default: None (which selects the entire world) Example: @@ -66,6 +83,13 @@ def _input_format(self, preds: torch.Tensor, target: torch.Tensor): return preds, target def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ preds, target = self._input_format(preds, target) assert preds.shape == target.shape @@ -73,4 +97,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor): self.total += target.numel() def compute(self): + """ + Computes accuracy over state. + """ return self.correct.float() / self.total diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 4cf6c6505d283..15787028f7c6e 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -15,7 +15,32 @@ class Metric(nn.Module, ABC): """ - + Base class for all metrics present in the Metrics API. + + Implements ``add_state()``, ``forward()``, ``reset()`` and a few other things to + handle distributed synchronization and per step metric computation. + + Override ``update()`` and ``compute()`` functions to implement your own metric. Use + ``add_state()`` to register metric state variables which keep track of state on each + call of ``update()`` and are synchronized across processes when ``compute()`` is called. + + Note: + Metric state variables can either be ``torch.Tensors`` or an empty list which can we used + to store `torch.Tensors``. + + Note: + Different metrics only override ``update()`` and not ``forward()``. A call to ``update()`` + is valid, but it won't return the metric value at the current step. A call to ``forward()`` + calls ``update()`` behind the scenes and also return the metric value at the current step. + + Args: + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + ddp_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) """ def __init__( self, @@ -44,11 +69,18 @@ def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Call Args: name: The name of the state variable. The variable will then be accessible at ``self.name``. - default: Default value of the state; can either be a tensor or an empty list. The state will be + default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be reset to this value when ``self.reset()`` is called. dist_reduce_fx (Optional): Function to reduce state accross mutliple GPUs. If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``, and ``torch.cat`` respectively, each with argument ``dim=0``. + + Note: + Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes. + It will be stacked ``torch.Tensor`` across the process dimension if the metric state was a ``torch.Tensor``. + + For the list metric state, passing None to ``dist_reduce_fx`` will return a combined list ``torch.Tensor`` + elements from across all processes. 
""" if not isinstance(default, torch.Tensor) or (isinstance(default, list) and len(default) != 0): raise ValueError( @@ -70,11 +102,13 @@ def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Call self.register_buffer(name, default) else: setattr(self, name, default) + self._defaults[name] = deepcopy(default) self._reductions[name] = dist_reduce_fx def forward(self, *args, **kwargs): """ + Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True. """ # add current step self.update(*args, **kwargs) @@ -99,10 +133,6 @@ def forward(self, *args, **kwargs): return result def _sync_dist(self): - """ - Method to synchronize metric state variables across different processes - in distributed training. - """ input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()} output_dict = apply_to_collection( input_dict, diff --git a/pytorch_lightning/metrics/regression/mean_absolute_error.py b/pytorch_lightning/metrics/regression/mean_absolute_error.py index ffe81fdc15924..75936efb7bc25 100644 --- a/pytorch_lightning/metrics/regression/mean_absolute_error.py +++ b/pytorch_lightning/metrics/regression/mean_absolute_error.py @@ -30,16 +30,26 @@ def __init__( ddp_sync_on_step=ddp_sync_on_step, process_group=process_group, ) + self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ assert preds.shape == target.shape abs_error = torch.abs(preds - target) + self.sum_abs_error += torch.sum(abs_error) self.total += target.numel() def compute(self): + """ + Computes mean absolute error over state. + """ return self.sum_abs_error / self.total - - diff --git a/pytorch_lightning/metrics/regression/mean_squared_error.py b/pytorch_lightning/metrics/regression/mean_squared_error.py index 74a5255e7f7b3..b8c94b56f1071 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_error.py @@ -30,14 +30,26 @@ def __init__( ddp_sync_on_step=ddp_sync_on_step, process_group=process_group, ) + self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ assert preds.shape == target.shape squared_error = torch.pow(preds - target, 2) + self.sum_squared_error += torch.sum(squared_error) self.total += target.numel() def compute(self): + """ + Computes mean squared error over state. 
+ """ return self.sum_squared_error / self.total diff --git a/pytorch_lightning/metrics/regression/mean_squared_log_error.py b/pytorch_lightning/metrics/regression/mean_squared_log_error.py index d93c5ad3401c5..ce371f06b4a5d 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_log_error.py @@ -30,14 +30,26 @@ def __init__( ddp_sync_on_step=ddp_sync_on_step, process_group=process_group, ) + self.add_state("sum_squared_log_error", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ assert preds.shape == target.shape squared_log_error = torch.pow(torch.log1p(preds) - torch.log1p(target), 2) + self.sum_squared_log_error += torch.sum(squared_log_error) self.total += target.numel() def compute(self): + """ + Compute mean squared logarithmic error over state. + """ return self.sum_squared_log_error / self.total From 5679052b884ed30bceaebeecc2b32f932bab38c0 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Tue, 6 Oct 2020 14:57:18 -0400 Subject: [PATCH 24/27] metric docs Co-authored-by: Teddy Koker --- docs/source/metrics.rst | 30 +++++++++++++++++++---------- pytorch_lightning/metrics/metric.py | 23 ++++++++++++++++------ 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 64e4b78d9bfbf..3a2684f656354 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -10,16 +10,25 @@ Metrics ======= -Using a metric with with PyTorch Lightning: -# TODO 1: write an intro for metrics, and lead the user into the lightning example +``pytorch_lightning.metrics`` is a Metrics API created for easy metric development and usage in +PyTorch and PyTorch Lightning. It is rigorously tested for all edge cases and includes a growing list of +common metric implementations. -# expand a bit on this -These metrics work with DDP in PyTorch and PyTorch Lightning by default. +The metrics API provides ``update()``, ``compute()``, ``reset()`` functions to the user. The metric base class inherits +``nn.Module`` which allows us to call ``metric(...)`` directly. The ``forward()`` method of the base ``Metric`` class +serves the dual purpose of calling ``update()`` on its input and simultanously returning the value of the metric over the +provided input. + +These metrics work with DDP in PyTorch and PyTorch Lightning by default. When ``.compute()`` is called in +distributed mode, the internal state of each metric is synced and reduced across each process, so that the +logic present in ``.compute()`` is applied to state information from all processes. + +The example below shows how to use a metric in your ``LightningModule``: .. note:: For v0.10.0 the user is expected to call ``.compute()`` on the metric at the end each epoch. - This has been shown in the example below. For v1.0 release after this, we will integrate metrics + This has been shown in the example below. For v1.0 release, we will integrate metrics with logging and ``.compute()`` will be called automatically by PyTorch Lightning. .. code-block:: python @@ -40,7 +49,7 @@ These metrics work with DDP in PyTorch and PyTorch Lightning by default. self.log('train_acc_epoch', self.accuracy.compute()) -This metrics API is independent of PyTorch Lightning. 
If you please, they can be used with plain PyTorch like so: +This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example: .. code-block:: python @@ -69,16 +78,17 @@ This metrics API is independent of PyTorch Lightning. If you please, they can be Implementing a Metric --------------------- -# TODO 3: finalize this!, explain reduction in detail - -To implement a metric, subclass the ``Metric`` class and implement the following methods: +To implement your custom metric, subclass the base ``Metric`` class and implement the following methods: - ``__init__()``: Each state variable should be called using ``self.add_state(...)``. - ``update()``: Any code needed to update the state given any inputs to the metric. - ``compute()``: Computes a final value from the state of the metric. All you need to do is call add_state correctly to implement a custom metric with DDP. -``reset()`` is called on its own on variables added using ``add_state()``. +``reset()`` is called on metric state variables added using ``add_state()``. + +To see how metric states are synchronized across distributed processes, refer to ``add_state()`` docs +from the base ``Metric`` class. Example implementation: diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 15787028f7c6e..506f5a7f06af9 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -31,7 +31,7 @@ class Metric(nn.Module, ABC): Note: Different metrics only override ``update()`` and not ``forward()``. A call to ``update()`` is valid, but it won't return the metric value at the current step. A call to ``forward()`` - calls ``update()`` behind the scenes and also return the metric value at the current step. + automatically calls ``update()`` and also return the metric value at the current step. Args: compute_on_step: @@ -71,16 +71,27 @@ def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Call name: The name of the state variable. The variable will then be accessible at ``self.name``. default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be reset to this value when ``self.reset()`` is called. - dist_reduce_fx (Optional): Function to reduce state accross mutliple GPUs. If value is ``"sum"``, + dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode. If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``, and ``torch.cat`` respectively, - each with argument ``dim=0``. + each with argument ``dim=0``. The user can also pass a custom function in this parameter. Note: Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes. - It will be stacked ``torch.Tensor`` across the process dimension if the metric state was a ``torch.Tensor``. + However, there won't be any reduction function applied to the synchronized metric state. + + The metric states would be synced as follows + + - If the metric state is ``torch.Tensor``, the synced value will be a stacked ``torch.Tensor`` across + the process dimension if the metric state was a ``torch.Tensor``. The original ``torch.Tensor`` metric + state retains dimension and hence the synchronized output will be of shape ``(num_process, ...)``. + + - If the metric state is a ``list``, the synced value will be a ``list`` containing the + combined elements from all processes. 
+
+        Note:
+            When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow
+            the format discussed in the above note.

-            For the list metric state, passing None to ``dist_reduce_fx`` will return a combined list ``torch.Tensor``
-            elements from across all processes.
         """
         if not isinstance(default, torch.Tensor) or (isinstance(default, list) and len(default) != 0):
             raise ValueError(

From e8cc4029272c775efc2004a73ae32d419655379d Mon Sep 17 00:00:00 2001
From: ananyahjha93
Date: Tue, 6 Oct 2020 15:19:40 -0400
Subject: [PATCH 25/27] pep8, changelog

Co-authored-by: Teddy Koker
---
 CHANGELOG.md | 3 +++
 .../metrics/classification/accuracy.py | 2 +-
 pytorch_lightning/metrics/metric.py | 22 +++++++++++--------
 .../metrics/regression/__init__.py | 1 -
 .../metrics/regression/mean_absolute_error.py | 3 +--
 pytorch_lightning/metrics/utils.py | 12 ++++++++++
 6 files changed, 30 insertions(+), 13 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 433c754777b95..0db2548e2727f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added new Metrics API. ([#3868](https://github.com/PyTorchLightning/pytorch-lightning/pull/3868))
+
 - Enable PyTorch 1.7 compatibility ([#3541](https://github.com/PyTorchLightning/pytorch-lightning/pull/3541))

 - Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528))
@@ -63,6 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Removed

+- Removed old Metrics API. ([#3868](https://github.com/PyTorchLightning/pytorch-lightning/pull/3868))

 ### Fixed

diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py
index 6e9bfd1b46191..50751aa73be51 100644
--- a/pytorch_lightning/metrics/classification/accuracy.py
+++ b/pytorch_lightning/metrics/classification/accuracy.py
@@ -44,7 +44,7 @@ class Accuracy(Metric):
         >>> preds = torch.tensor([0, 2, 1, 3])
         >>> accuracy = Accuracy()
         >>> accuracy(preds, target)
-        tensor(0.5)
+        tensor(0.5000)

     """

diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index 506f5a7f06af9..bb58d11e6d04b 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -11,6 +11,7 @@
 from pytorch_lightning.utilities.apply_func import apply_to_collection

 from pytorch_lightning.metrics.utils import _flatten, gather_all_tensors_if_available
+from pytorch_lightning.metrics.utils import dim_zero_cat, dim_zero_mean, dim_zero_sum


 class Metric(nn.Module, ABC):
@@ -71,9 +72,10 @@ def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Call
             name: The name of the state variable. The variable will then be accessible at ``self.name``.
             default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be
                 reset to this value when ``self.reset()`` is called.
-            dist_reduce_fx (Optional): Function to reduce state across multiple processes in distributed mode. If value is ``"sum"``,
-                ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``, and ``torch.cat`` respectively,
-                each with argument ``dim=0``. The user can also pass a custom function in this parameter.
+            dist_reduce_fx (Optional): Function to reduce state across multiple processes in distributed mode.
+                If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``,
+                and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
+                function in this parameter.

         Note:
             Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes.
@@ -99,11 +101,11 @@ def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Call
             )

         if dist_reduce_fx == "sum":
-            dist_reduce_fx = lambda x: torch.sum(x, dim=0)
+            dist_reduce_fx = dim_zero_sum
         elif dist_reduce_fx == "mean":
-            dist_reduce_fx = lambda x: torch.mean(x, dim=0)
+            dist_reduce_fx = dim_zero_mean
         elif dist_reduce_fx == "cat":
-            dist_reduce_fx = lambda x: torch.cat(x, dim=0)
+            dist_reduce_fx = dim_zero_cat
         elif dist_reduce_fx is not None and not isinstance(dist_reduce_fx, Callable):
             raise ValueError(
                 "`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]"
@@ -177,9 +179,11 @@ def wrapped_func(*args, **kwargs):
             if self._computed is not None:
                 return self._computed

-            if self._to_sync \
-                    and torch.distributed.is_available() \
-                    and torch.distributed.is_initialized():
+            if (
+                self._to_sync
+                and torch.distributed.is_available()
+                and torch.distributed.is_initialized()
+            ):
                 self._sync_dist()

             self._computed = compute(*args, **kwargs)
diff --git a/pytorch_lightning/metrics/regression/__init__.py b/pytorch_lightning/metrics/regression/__init__.py
index 3b57f4da8ae21..c5f235aeff12b 100644
--- a/pytorch_lightning/metrics/regression/__init__.py
+++ b/pytorch_lightning/metrics/regression/__init__.py
@@ -1,4 +1,3 @@
 from pytorch_lightning.metrics.regression.mean_squared_error import MeanSquaredError
 from pytorch_lightning.metrics.regression.mean_absolute_error import MeanAbsoluteError
 from pytorch_lightning.metrics.regression.mean_squared_log_error import MeanSquaredLogError
-
diff --git a/pytorch_lightning/metrics/regression/mean_absolute_error.py b/pytorch_lightning/metrics/regression/mean_absolute_error.py
index 75936efb7bc25..b8e00ede728e1 100644
--- a/pytorch_lightning/metrics/regression/mean_absolute_error.py
+++ b/pytorch_lightning/metrics/regression/mean_absolute_error.py
@@ -4,7 +4,6 @@
 from pytorch_lightning.metrics.metric import Metric


-
 class MeanAbsoluteError(Metric):
     """
     Computes mean absolute error.
@@ -16,7 +15,7 @@ class MeanAbsoluteError(Metric):
         >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
         >>> mean_absolute_error = MeanAbsoluteError()
         >>> mean_absolute_error(preds, target)
-        tensor(0.5)
+        tensor(0.5000)

     """

     def __init__(
diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py
index c330174a7abff..17f8bb65172ee 100644
--- a/pytorch_lightning/metrics/utils.py
+++ b/pytorch_lightning/metrics/utils.py
@@ -3,6 +3,18 @@
 from typing import Any, Callable, Optional, Union


+def dim_zero_cat(x):
+    return torch.cat(x, dim=0)
+
+
+def dim_zero_sum(x):
+    return torch.sum(x, dim=0)
+
+
+def dim_zero_mean(x):
+    return torch.mean(x, dim=0)
+
+
 def _flatten(x):
     return [item for sublist in x for item in sublist]


From 41ddda8cae74d38d7ce3bf0f6390829e435871c8 Mon Sep 17 00:00:00 2001
From: ananyahjha93
Date: Tue, 6 Oct 2020 15:45:56 -0400
Subject: [PATCH 26/27] refactor dist utils, pep8

---
 pytorch_lightning/metrics/metric.py | 8 +++---
 pytorch_lightning/metrics/utils.py | 30 ----------------------
 pytorch_lightning/utilities/distributed.py | 30 ++++++++++++++++++++++
 3 files changed, 34 insertions(+), 34 deletions(-)

diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index bb58d11e6d04b..fe78869911fda 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -10,8 +10,8 @@
 from torch import nn

 from pytorch_lightning.utilities.apply_func import apply_to_collection
-from pytorch_lightning.metrics.utils import _flatten, gather_all_tensors_if_available
-from pytorch_lightning.metrics.utils import dim_zero_cat, dim_zero_mean, dim_zero_sum
+from pytorch_lightning.utilities.distributed import gather_all_tensors_if_available
+from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum


 class Metric(nn.Module, ABC):
@@ -181,8 +181,8 @@ def wrapped_func(*args, **kwargs):

             if (
                 self._to_sync
-                and torch.distributed.is_available()
-                and torch.distributed.is_initialized()
+                and torch.distributed.is_available()  # noqa: W503
+                and torch.distributed.is_initialized()  # noqa: W503
             ):
                 self._sync_dist()

diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py
index 17f8bb65172ee..850b3858b0848 100644
--- a/pytorch_lightning/metrics/utils.py
+++ b/pytorch_lightning/metrics/utils.py
@@ -17,33 +17,3 @@ def dim_zero_mean(x):

 def _flatten(x):
     return [item for sublist in x for item in sublist]
-
-
-def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional[Any] = None):
-    """
-    Function to gather all tensors from several ddp processes onto a list that
-    is broadcasted to all processes
-
-    Args:
-        result: the value to sync
-        group: the process group to gather results from. Defaults to all processes (world)
-
-    Return:
-        gathered_result: list with size equal to the process group where
-            gathered_result[i] corresponds to result tensor from process i
-
-    """
-    if torch.distributed.is_available() and torch.distributed.is_initialized():
-        if group is None:
-            group = torch.distributed.group.WORLD
-
-        world_size = torch.distributed.get_world_size(group)
-
-        gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
-
-        # sync and broadcast all
-        torch.distributed.barrier(group=group)
-        torch.distributed.all_gather(gathered_result, result, group)
-
-        result = gathered_result
-    return result
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 1f69d6e0e4946..a29fd3e5a1059 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -73,6 +73,36 @@ def find_free_network_port() -> int:
     return port


+def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional[Any] = None):
+    """
+    Function to gather all tensors from several ddp processes onto a list that
+    is broadcasted to all processes
+
+    Args:
+        result: the value to sync
+        group: the process group to gather results from. Defaults to all processes (world)
+
+    Return:
+        gathered_result: list with size equal to the process group where
+            gathered_result[i] corresponds to result tensor from process i
+
+    """
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        if group is None:
+            group = torch.distributed.group.WORLD
+
+        world_size = torch.distributed.get_world_size(group)
+
+        gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
+
+        # sync and broadcast all
+        torch.distributed.barrier(group=group)
+        torch.distributed.all_gather(gathered_result, result, group)
+
+        result = gathered_result
+    return result
+
+
 def sync_ddp_if_available(
     result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
 ) -> torch.Tensor:

From 1069381020b39f01b36128d13c060523e34c64d1 Mon Sep 17 00:00:00 2001
From: ananyahjha93
Date: Tue, 6 Oct 2020 15:46:40 -0400
Subject: [PATCH 27/27] refactor dist utils, pep8

---
 pytorch_lightning/core/step_result.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index 50fa6266c7964..b6d46c691ce11 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -14,7 +14,7 @@
 import numbers
 from copy import copy
-from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any, List, Tuple
+from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any, List, Tuple, Iterable

 import torch
 from torch import Tensor
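
The patches above only describe the ``Metric`` base class itself. As a minimal sketch of the workflow they document (register state with ``add_state()`` in ``__init__()``, mutate it in ``update()``, and finalize it in ``compute()``), here is an illustrative custom metric. The class name ``MySum`` and the example values are hypothetical and not part of these patches; the constructor arguments and ``add_state()`` signature follow the docstrings introduced above.

.. code-block:: python

    import torch

    from pytorch_lightning.metrics.metric import Metric


    class MySum(Metric):
        """Toy metric that keeps a running sum of everything passed to update()."""

        def __init__(self, compute_on_step=True, ddp_sync_on_step=False, process_group=None):
            super().__init__(
                compute_on_step=compute_on_step,
                ddp_sync_on_step=ddp_sync_on_step,
                process_group=process_group,
            )
            # register a tensor state; "sum" reduces the state across processes when compute() is called
            self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, values: torch.Tensor):
            # update() only mutates the registered state and returns nothing
            self.total += torch.sum(values)

        def compute(self):
            # compute() turns the (possibly synchronized) state into the final metric value
            return self.total


    my_sum = MySum()
    my_sum(torch.tensor([1.0, 2.0]))  # forward() calls update() and returns the value for this step
    print(my_sum.compute())           # tensor(3.)

Because ``dist_reduce_fx="sum"`` is one of the built-in string reductions, this state is reduced with ``torch.sum(..., dim=0)`` across processes in distributed mode; a custom callable could be passed instead, as the ``add_state()`` docstring above describes.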