Skip to content

Commit

Permalink
New metric: Normalized Mutual Information Score (#2029)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
4 people committed Aug 31, 2023
1 parent 39d90d0 commit b8e1b23
Show file tree
Hide file tree
Showing 12 changed files with 440 additions and 8 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `RandScore` metric to cluster package ([#2025](https://github.com/Lightning-AI/torchmetrics/pull/2025)


- Added `NormalizedMutualInfoScore` metric to cluster package ([#2029](https://github.com/Lightning-AI/torchmetrics/pull/2029)


### Changed

-
Expand Down
21 changes: 21 additions & 0 deletions docs/source/clustering/normalized_mutual_info_score.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
.. customcarditem::
:header: Normalized Mutual Information Score
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/clustering.svg
:tags: Clustering

.. include:: ../links.rst

###################################
Normalized Mutual Information Score
###################################

Module Interface
________________

.. autoclass:: torchmetrics.clustering.NormalizedMutualInfoScore
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.clustering.normalized_mutual_info_score
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -151,5 +151,6 @@
.. _DIOU: https://arxiv.org/abs/1911.08287v1
.. _GIOU: https://arxiv.org/abs/1902.09630
.. _Mutual Information Score: https://en.wikipedia.org/wiki/Mutual_information
.. _Normalized Mutual Information Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html
.. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools
.. _Rand Score: https://link.springer.com/article/10.1007/BF01908075
2 changes: 2 additions & 0 deletions src/torchmetrics/clustering/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.clustering.mutual_info_score import MutualInfoScore
from torchmetrics.clustering.normalized_mutual_info_score import NormalizedMutualInfoScore
from torchmetrics.clustering.rand_score import RandScore

__all__ = [
"MutualInfoScore",
"NormalizedMutualInfoScore",
"RandScore",
]
4 changes: 2 additions & 2 deletions src/torchmetrics/clustering/mutual_info_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ class MutualInfoScore(Metric):
"""

is_differentiable: bool = True
higher_is_better: bool = True
full_state_update: bool = True
higher_is_better: Optional[bool] = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
preds: List[Tensor]
target: List[Tensor]
Expand Down
126 changes: 126 additions & 0 deletions src/torchmetrics/clustering/normalized_mutual_info_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Literal, Optional, Sequence, Union

from torch import Tensor

from torchmetrics.clustering.mutual_info_score import MutualInfoScore
from torchmetrics.functional.clustering.normalized_mutual_info_score import (
_validate_average_method_arg,
normalized_mutual_info_score,
)
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["NormalizedMutualInfoScore.plot"]


class NormalizedMutualInfoScore(MutualInfoScore):
r"""Compute `Normalized Mutual Information Score`_.
.. math::
NMI(U,V) = \frac{MI(U,V)}{M_p(U,V)}
Where :math:`U` is a tensor of target values, :math:`V` is a tensor of predictions, :math:`M_p(U,V)` is the
generalized mean of order :math:`p` of :math:`U` and :math:`V`, and :math:`MI(U,V)` is the mutual information score
between clusters :math:`U` and :math:`V`. The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields
the same mutual information score.
This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not
be available in practice since clustering in generally is used for unsupervised learning.
As input to ``forward`` and ``update`` the metric accepts the following input:
- ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels
- ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels
As output of ``forward`` and ``compute`` the metric returns the following output:
- ``nmi_score`` (:class:`~torch.Tensor`): A tensor with the Normalized Mutual Information Score
Args:
average_method: Method used to calculate generalized mean for normalization
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example:
>>> import torch
>>> from torchmetrics.clustering import NormalizedMutualInfoScore
>>> preds = torch.tensor([2, 1, 0, 1, 0])
>>> target = torch.tensor([0, 2, 1, 1, 0])
>>> nmi_score = NormalizedMutualInfoScore("arithmetic")
>>> nmi_score(preds, target)
tensor(0.4744)
"""

is_differentiable: bool = True
higher_is_better: Optional[bool] = None
full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 0.0
preds: List[Tensor]
target: List[Tensor]
contingency: Tensor

def __init__(
self, average_method: Literal["min", "geometric", "arithmetic", "max"] = "arithmetic", **kwargs: Any
) -> None:
super().__init__(**kwargs)
_validate_average_method_arg(average_method)
self.average_method = average_method

def compute(self) -> Tensor:
"""Compute normalized mutual information over state."""
return normalized_mutual_info_score(dim_zero_cat(self.preds), dim_zero_cat(self.target), self.average_method)

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.clustering import NormalizedMutualInfoScore
>>> metric = NormalizedMutualInfoScore()
>>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.clustering import NormalizedMutualInfoScore
>>> metric = NormalizedMutualInfoScore()
>>> for _ in range(10):
... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())
"""
return self._plot(val, ax)
7 changes: 6 additions & 1 deletion src/torchmetrics/functional/clustering/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score
from torchmetrics.functional.clustering.normalized_mutual_info_score import normalized_mutual_info_score
from torchmetrics.functional.clustering.rand_score import rand_score

__all__ = ["mutual_info_score", "rand_score"]
__all__ = [
"mutual_info_score",
"normalized_mutual_info_score",
"rand_score",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal

import torch
from torch import Tensor

from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score
from torchmetrics.functional.clustering.utils import calculate_entropy, calculate_generalized_mean, check_cluster_labels


def _validate_average_method_arg(
average_method: Literal["min", "geometric", "arithmetic", "max"] = "arithmetic"
) -> None:
if average_method not in ("min", "geometric", "arithmetic", "max"):
raise ValueError(
"Expected argument `average_method` to be one of `min`, `geometric`, `arithmetic`, `max`,"
f"but got {average_method}"
)


def normalized_mutual_info_score(
preds: Tensor, target: Tensor, average_method: Literal["min", "geometric", "arithmetic", "max"] = "arithmetic"
) -> Tensor:
"""Compute normalized mutual information between two clusterings.
Args:
preds: predicted cluster labels
target: ground truth cluster labels
average_method: normalizer computation method
Returns:
Scalar tensor with normalized mutual info score between 0.0 and 1.0
Example:
>>> from torchmetrics.functional.clustering import normalized_mutual_info_score
>>> target = torch.tensor([0, 3, 2, 2, 1])
>>> preds = torch.tensor([1, 3, 2, 0, 1])
>>> normalized_mutual_info_score(preds, target, "arithmetic")
tensor(0.7919)
"""
check_cluster_labels(preds, target)
_validate_average_method_arg(average_method)
mutual_info = mutual_info_score(preds, target)
if torch.allclose(mutual_info, torch.tensor(0.0), atol=torch.finfo().eps):
return mutual_info

normalizer = calculate_generalized_mean(
torch.stack([calculate_entropy(preds), calculate_entropy(target)]), average_method
)

return mutual_info / normalizer
77 changes: 75 additions & 2 deletions src/torchmetrics/functional/clustering/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,87 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Optional, Union

import torch
from torch import Tensor
from torch import Tensor, tensor
from typing_extensions import Literal

from torchmetrics.utilities.checks import _check_same_shape


def calculate_entropy(x: Tensor) -> Tensor:
"""Calculate entropy for a tensor of labels.
Final calculation of entropy is performed in log form to account for roundoff error.
Args:
x: labels
Returns:
entropy: entropy of tensor
Example:
>>> from torchmetrics.functional.clustering.utils import calculate_entropy
>>> labels = torch.tensor([1, 3, 2, 2, 1])
>>> calculate_entropy(labels)
tensor(1.0549)
"""
if len(x) == 0:
return tensor(1.0, device=x.device)

p = torch.bincount(torch.unique(x, return_inverse=True)[1])
p = p[p > 0]

if p.size() == 1:
return tensor(0.0, device=x.device)

n = p.sum()
return -torch.sum((p / n) * (torch.log(p) - torch.log(n)))


def calculate_generalized_mean(x: Tensor, p: Union[int, Literal["min", "geometric", "arithmetic", "max"]]) -> Tensor:
"""Return generalized (power) mean of a tensor.
Args:
x: tensor
p: power
Returns:
generalized_mean: generalized mean
Example (p="min"):
>>> from torchmetrics.functional.clustering.utils import calculate_generalized_mean
>>> x = torch.tensor([1, 3, 2, 2, 1])
>>> calculate_generalized_mean(x, "min")
tensor(1)
Example (p="geometric"):
>>> from torchmetrics.functional.clustering.utils import calculate_generalized_mean
>>> x = torch.tensor([1, 3, 2, 2, 1])
>>> calculate_generalized_mean(x, "geometric")
tensor(1.6438)
"""
if torch.is_complex(x) or torch.any(x <= 0.0):
raise ValueError("`x` must contain positive real numbers")

if isinstance(p, str):
if p == "min":
return x.min()
if p == "geometric":
return torch.exp(torch.mean(x.log()))
if p == "arithmetic":
return x.mean()
if p == "max":
return x.max()

raise ValueError("'method' must be 'min', 'geometric', 'arirthmetic', or 'max'")

return torch.mean(torch.pow(x, p)) ** (1.0 / p)


def calculate_contingency_matrix(
preds: Tensor, target: Tensor, eps: Optional[float] = None, sparse: bool = False
) -> Tensor:
Expand Down
Loading

0 comments on commit b8e1b23

Please sign in to comment.