-
Notifications
You must be signed in to change notification settings - Fork 387
/
pearson.py
146 lines (123 loc) · 5.18 KB
/
pearson.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Tuple
import torch
from torch import Tensor
from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.checks import _check_same_shape
def _pearson_corrcoef_update(
    preds: Tensor,
    target: Tensor,
    mean_x: Tensor,
    mean_y: Tensor,
    var_x: Tensor,
    var_y: Tensor,
    corr_xy: Tensor,
    num_prior: Tensor,
    num_outputs: int,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Fold a new batch into the running statistics needed for the Pearson correlation coefficient.

    The shapes of ``preds`` and ``target`` are validated before anything is updated.
    ``var_x``, ``var_y``, ``corr_xy`` and ``num_prior`` are accumulated in place (and returned);
    the returned means are newly computed tensors.

    Args:
        preds: estimated scores
        target: ground truth scores
        mean_x: current mean estimate of x tensor
        mean_y: current mean estimate of y tensor
        var_x: current variance estimate of x tensor
        var_y: current variance estimate of y tensor
        corr_xy: current covariance estimate between x and y tensor
        num_prior: current number of observed observations
        num_outputs: Number of outputs in multioutput setting

    Returns:
        The updated ``(mean_x, mean_y, var_x, var_y, corr_xy, num_prior)``.

    """
    # Validate inputs first
    _check_same_shape(preds, target)
    _check_data_shape_to_num_outputs(preds, target, num_outputs)

    batch_size = preds.shape[0]
    # Incremental formulas are used once earlier observations exist, or when the batch holds a
    # single sample; otherwise the batch statistics are taken directly.
    incremental = num_prior.mean() > 0 or batch_size == 1

    if incremental:
        total = num_prior + batch_size
        new_mean_x = (num_prior * mean_x + preds.sum(0)) / total
        new_mean_y = (num_prior * mean_y + target.sum(0)) / total
        var_x_delta = ((preds - new_mean_x) * (preds - mean_x)).sum(0)
        var_y_delta = ((target - new_mean_y) * (target - mean_y)).sum(0)
    else:
        new_mean_x = preds.mean(0).to(mean_x.dtype)
        new_mean_y = target.mean(0).to(mean_y.dtype)
        # ``var`` is scaled by (batch_size - 1) so that a sum of squared deviations is accumulated
        var_x_delta = preds.var(0) * (batch_size - 1)
        var_y_delta = target.var(0) * (batch_size - 1)

    num_prior += batch_size
    var_x += var_x_delta
    var_y += var_y_delta
    # Cross term pairs the new x mean with the previous y mean
    corr_xy += ((preds - new_mean_x) * (target - mean_y)).sum(0)

    return new_mean_x, new_mean_y, var_x, var_y, corr_xy, num_prior
def _pearson_corrcoef_compute(
var_x: Tensor,
var_y: Tensor,
corr_xy: Tensor,
nb: Tensor,
) -> Tensor:
"""Compute the final pearson correlation based on accumulated statistics.
Args:
var_x: variance estimate of x tensor
var_y: variance estimate of y tensor
corr_xy: covariance estimate between x and y tensor
nb: number of observations
"""
var_x /= nb - 1
var_y /= nb - 1
corr_xy /= nb - 1
# if var_x, var_y is float16 and on cpu, make it bfloat16 as sqrt is not supported for float16
# on cpu, remove this after https://github.com/pytorch/pytorch/issues/54774 is fixed
if var_x.dtype == torch.float16 and var_x.device == torch.device("cpu"):
var_x = var_x.bfloat16()
var_y = var_y.bfloat16()
bound = math.sqrt(torch.finfo(var_x.dtype).eps)
if (var_x < bound).any() or (var_y < bound).any():
rank_zero_warn(
"The variance of predictions or target is close to zero. This can cause instability in Pearson correlation"
"coefficient, leading to wrong results. Consider re-scaling the input if possible or computing using a"
f"larger dtype (currently using {var_x.dtype}).",
UserWarning,
)
corrcoef = (corr_xy / (var_x * var_y).sqrt()).squeeze()
return torch.clamp(corrcoef, -1.0, 1.0)
def pearson_corrcoef(preds: Tensor, target: Tensor) -> Tensor:
    """Compute pearson correlation coefficient.

    Fresh zero-initialised accumulators are created, updated once with the full input, and then
    reduced to the final coefficient.

    Args:
        preds: estimated scores
        target: ground truth scores

    Example (single output regression):
        >>> from torchmetrics.functional.regression import pearson_corrcoef
        >>> target = torch.tensor([3, -0.5, 2, 7])
        >>> preds = torch.tensor([2.5, 0.0, 2, 8])
        >>> pearson_corrcoef(preds, target)
        tensor(0.9849)

    Example (multi output regression):
        >>> from torchmetrics.functional.regression import pearson_corrcoef
        >>> target = torch.tensor([[3, -0.5], [2, 7]])
        >>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
        >>> pearson_corrcoef(preds, target)
        tensor([1., 1.])

    """
    state_size = preds.shape[1] if preds.ndim == 2 else 1
    # Six independent accumulators sharing the dtype and device of ``preds``
    mean_x, mean_y, var_x, var_y, corr_xy, nb = (
        torch.zeros(state_size, dtype=preds.dtype, device=preds.device) for _ in range(6)
    )
    num_outputs = 1 if preds.ndim == 1 else preds.shape[-1]
    # The updated means are not needed for the final computation
    _, _, var_x, var_y, corr_xy, nb = _pearson_corrcoef_update(
        preds, target, mean_x, mean_y, var_x, var_y, corr_xy, nb, num_outputs=num_outputs
    )
    return _pearson_corrcoef_compute(var_x, var_y, corr_xy, nb)