-
Notifications
You must be signed in to change notification settings - Fork 387
/
multitask.py
253 lines (219 loc) · 11.6 KB
/
multitask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# this is just a bypass for this module name collision with built-in one
from typing import Any, Dict, Optional, Sequence, Union
from torch import Tensor, nn
from torchmetrics.collections import MetricCollection
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
from torchmetrics.wrappers.abstract import WrapperMetric
# Skip the `plot` doctest when matplotlib is not installed, since the example
# exercises plotting functionality that would otherwise raise at import time.
if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["MultitaskWrapper.plot"]
class MultitaskWrapper(WrapperMetric):
    """Wrapper class for computing different metrics on different tasks in the context of multitask learning.

    In multitask learning the different tasks require different metrics to be evaluated. This wrapper allows
    for easy evaluation in such cases by supporting multiple predictions and targets through a dictionary.
    Note that only metrics where the signature of `update` follows the standard `preds, target` is supported.

    Args:
        task_metrics:
            Dictionary associating each task to a Metric or a MetricCollection. The keys of the dictionary represent the
            names of the tasks, and the values represent the metrics to use for each task.

    Raises:
        TypeError:
            If argument `task_metrics` is not a dictionary
        TypeError:
            If not all values in the `task_metrics` dictionary are instances of `Metric` or `MetricCollection`

    Example (with a single metric per class):
        >>> import torch
        >>> from torchmetrics.wrappers import MultitaskWrapper
        >>> from torchmetrics.regression import MeanSquaredError
        >>> from torchmetrics.classification import BinaryAccuracy
        >>>
        >>> classification_target = torch.tensor([0, 1, 0])
        >>> regression_target = torch.tensor([2.5, 5.0, 4.0])
        >>> targets = {"Classification": classification_target, "Regression": regression_target}
        >>>
        >>> classification_preds = torch.tensor([0, 0, 1])
        >>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
        >>> preds = {"Classification": classification_preds, "Regression": regression_preds}
        >>>
        >>> metrics = MultitaskWrapper({
        ...     "Classification": BinaryAccuracy(),
        ...     "Regression": MeanSquaredError()
        ... })
        >>> metrics.update(preds, targets)
        >>> metrics.compute()
        {'Classification': tensor(0.3333), 'Regression': tensor(0.8333)}

    Example (with several metrics per task):
        >>> import torch
        >>> from torchmetrics import MetricCollection
        >>> from torchmetrics.wrappers import MultitaskWrapper
        >>> from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError
        >>> from torchmetrics.classification import BinaryAccuracy, BinaryF1Score
        >>>
        >>> classification_target = torch.tensor([0, 1, 0])
        >>> regression_target = torch.tensor([2.5, 5.0, 4.0])
        >>> targets = {"Classification": classification_target, "Regression": regression_target}
        >>>
        >>> classification_preds = torch.tensor([0, 0, 1])
        >>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
        >>> preds = {"Classification": classification_preds, "Regression": regression_preds}
        >>>
        >>> metrics = MultitaskWrapper({
        ...     "Classification": MetricCollection(BinaryAccuracy(), BinaryF1Score()),
        ...     "Regression": MetricCollection(MeanSquaredError(), MeanAbsoluteError())
        ... })
        >>> metrics.update(preds, targets)
        >>> metrics.compute()
        {'Classification': {'BinaryAccuracy': tensor(0.3333), 'BinaryF1Score': tensor(0.)},
         'Regression': {'MeanSquaredError': tensor(0.8333), 'MeanAbsoluteError': tensor(0.6667)}}

    """

    # The wrapper itself is never differentiable; differentiability is a property of the wrapped metrics.
    is_differentiable = False

    def __init__(
        self,
        task_metrics: Dict[str, Union[Metric, MetricCollection]],
    ) -> None:
        # Validate before calling super().__init__() so a clear TypeError is raised on bad input.
        self._check_task_metrics_type(task_metrics)
        super().__init__()
        # ModuleDict registers each task metric as a submodule (device placement, state_dict, etc.).
        self.task_metrics = nn.ModuleDict(task_metrics)

    @staticmethod
    def _check_task_metrics_type(task_metrics: Dict[str, Union[Metric, MetricCollection]]) -> None:
        """Check that ``task_metrics`` is a dict whose values are all ``Metric`` or ``MetricCollection``.

        Raises:
            TypeError: If ``task_metrics`` is not a dict or any value has the wrong type.

        """
        if not isinstance(task_metrics, dict):
            raise TypeError(f"Expected argument `task_metrics` to be a dict. Found task_metrics = {task_metrics}")

        for metric in task_metrics.values():
            if not isinstance(metric, (Metric, MetricCollection)):
                raise TypeError(
                    "Expected each task's metric to be a Metric or a MetricCollection. "
                    f"Found a metric of type {type(metric)}"
                )

    def update(self, task_preds: Dict[str, Tensor], task_targets: Dict[str, Tensor]) -> None:
        """Update each task's metric with its corresponding pred and target.

        Args:
            task_preds: Dictionary associating each task to a Tensor of pred.
            task_targets: Dictionary associating each task to a Tensor of target.

        Raises:
            ValueError: If the keys of ``task_preds`` and ``task_targets`` do not match the wrapped tasks.

        """
        if not self.task_metrics.keys() == task_preds.keys() == task_targets.keys():
            raise ValueError(
                "Expected arguments `task_preds` and `task_targets` to have the same keys as the wrapped `task_metrics`"
                f". Found task_preds.keys() = {task_preds.keys()}, task_targets.keys() = {task_targets.keys()} "
                f"and self.task_metrics.keys() = {self.task_metrics.keys()}"
            )

        for task_name, metric in self.task_metrics.items():
            pred = task_preds[task_name]
            target = task_targets[task_name]
            metric.update(pred, target)

    def compute(self) -> Dict[str, Any]:
        """Compute metrics for all tasks."""
        return {task_name: metric.compute() for task_name, metric in self.task_metrics.items()}

    def forward(self, task_preds: Dict[str, Tensor], task_targets: Dict[str, Tensor]) -> Dict[str, Any]:
        """Call underlying forward methods for all tasks and return the result as a dictionary."""
        # This method is overridden because we do not need the complex version defined in Metric, that relies on the
        # value of full_state_update, and that also accumulates the results. Here, all computations are handled by the
        # underlying metrics, which all have their own value of full_state_update, and which all accumulate the results
        # by themselves.
        return {
            task_name: metric(task_preds[task_name], task_targets[task_name])
            for task_name, metric in self.task_metrics.items()
        }

    def reset(self) -> None:
        """Reset all underlying metrics."""
        for metric in self.task_metrics.values():
            metric.reset()
        super().reset()

    def plot(
        self, val: Optional[Union[Dict, Sequence[Dict]]] = None, axes: Optional[Sequence[_AX_TYPE]] = None
    ) -> Sequence[_PLOT_OUT_TYPE]:
        """Plot a single or multiple values from the metric.

        All tasks' results are plotted on individual axes.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            axes: Sequence of matplotlib axis objects. If provided, will add the plots to the provided axis objects.
                If not provided, will create them.

        Returns:
            Sequence of tuples with Figure and Axes object for each task.

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.wrappers import MultitaskWrapper
            >>> from torchmetrics.regression import MeanSquaredError
            >>> from torchmetrics.classification import BinaryAccuracy
            >>>
            >>> classification_target = torch.tensor([0, 1, 0])
            >>> regression_target = torch.tensor([2.5, 5.0, 4.0])
            >>> targets = {"Classification": classification_target, "Regression": regression_target}
            >>>
            >>> classification_preds = torch.tensor([0, 0, 1])
            >>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
            >>> preds = {"Classification": classification_preds, "Regression": regression_preds}
            >>>
            >>> metrics = MultitaskWrapper({
            ...     "Classification": BinaryAccuracy(),
            ...     "Regression": MeanSquaredError()
            ... })
            >>> metrics.update(preds, targets)
            >>> value = metrics.compute()
            >>> fig_, ax_ = metrics.plot(value)

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.wrappers import MultitaskWrapper
            >>> from torchmetrics.regression import MeanSquaredError
            >>> from torchmetrics.classification import BinaryAccuracy
            >>>
            >>> classification_target = torch.tensor([0, 1, 0])
            >>> regression_target = torch.tensor([2.5, 5.0, 4.0])
            >>> targets = {"Classification": classification_target, "Regression": regression_target}
            >>>
            >>> classification_preds = torch.tensor([0, 0, 1])
            >>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
            >>> preds = {"Classification": classification_preds, "Regression": regression_preds}
            >>>
            >>> metrics = MultitaskWrapper({
            ...     "Classification": BinaryAccuracy(),
            ...     "Regression": MeanSquaredError()
            ... })
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metrics(preds, targets))
            >>> fig_, ax_ = metrics.plot(values)

        """
        if axes is not None:
            if not isinstance(axes, Sequence):
                raise TypeError(f"Expected argument `axes` to be a Sequence. Found type(axes) = {type(axes)}")

            if not all(isinstance(ax, _AX_TYPE) for ax in axes):
                raise TypeError("Expected each ax in argument `axes` to be a matplotlib axis object")

            if len(axes) != len(self.task_metrics):
                raise ValueError(
                    "Expected argument `axes` to be a Sequence of the same length as the number of tasks. "
                    f"Found len(axes) = {len(axes)} and {len(self.task_metrics)} tasks"
                )

        val = val if val is not None else self.compute()
        fig_axs = []
        for i, (task_name, task_metric) in enumerate(self.task_metrics.items()):
            ax = axes[i] if axes is not None else None
            if isinstance(val, dict):
                f, a = task_metric.plot(val[task_name], ax=ax)
            elif isinstance(val, Sequence):
                f, a = task_metric.plot([v[task_name] for v in val], ax=ax)
            else:
                raise TypeError(
                    "Expected argument `val` to be None or of type Dict or Sequence[Dict]. "
                    f"Found type(val) = {type(val)}"
                )
            fig_axs.append((f, a))
        return fig_axs