-
Notifications
You must be signed in to change notification settings - Fork 0
/
experiment_utils.py
57 lines (46 loc) · 1.66 KB
/
experiment_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
def dist_to_dirac(samples, theta_true, metrics=("mse", "mmd"), scaled=False):
    """Average per-coordinate discrepancy between a sample cloud and a point.

    For each of the first ``d = len(theta_true)`` columns of ``samples`` a
    per-coordinate score is computed; the returned value for each metric is
    the mean of those ``d`` scores.

    * ``"mse"``: ``mean((samples[:, j] - theta_true[j]) ** 2)``.
    * ``"mmd"``: ``var(samples[:, j]) + (mean(samples[:, j]) - theta_true[j]) ** 2``
      where ``var`` is torch's default unbiased (``n - 1``) variance. With
      ``scaled=True`` it is further divided by the coordinate's standard
      deviation. NOTE(review): the "mmd" label is kept from the original
      code; confirm which discrepancy definition it is meant to match.

    With a single sample the unbiased variance, and hence ``"mmd"``, is NaN.

    Args:
        samples: 2-D tensor indexed as ``samples[:, j]``; column ``j`` holds
            the values of coordinate ``j``. Columns beyond ``len(theta_true)``
            are ignored.
        theta_true: reference point. If it has more than one dimension only
            ``theta_true[0]`` is used.
        metrics: iterable of metric names (``"mse"``, ``"mmd"``) or a single
            name. Duplicates are ignored; result order follows first use.
        scaled: divide the ``"mmd"`` score by the per-coordinate std.

    Returns:
        dict mapping each requested metric name to a 0-d tensor.

    Raises:
        ValueError: on an unknown metric name, an empty reference point, or
            when ``samples`` has fewer columns than ``theta_true`` entries.
    """
    if isinstance(metrics, str):
        metrics = (metrics,)
    # Deduplicate while keeping the caller's order (the result dict follows it).
    metrics = tuple(dict.fromkeys(metrics))
    unknown = [m for m in metrics if m not in ("mse", "mmd")]
    if unknown:
        raise ValueError(f"unknown metric(s) {unknown}; supported: 'mse', 'mmd'")
    if not metrics:
        return {}

    if theta_true.ndim > 1:
        theta_true = theta_true[0]
    n_coords = len(theta_true)
    if n_coords == 0:
        raise ValueError("theta_true must have at least one coordinate")
    if samples.shape[1] < n_coords:
        raise ValueError(
            f"samples has {samples.shape[1]} columns but theta_true has "
            f"{n_coords} entries"
        )

    # Only the first n_coords columns are compared with theta_true.
    coords = samples[:, :n_coords]
    per_coord = {}
    if "mse" in metrics:
        per_coord["mse"] = (coords - theta_true).square().mean(dim=0)
    if "mmd" in metrics:
        var = coords.var(dim=0)
        mmd = var + (coords.mean(dim=0) - theta_true).square()
        if scaled:
            mmd = mmd / var.sqrt()
        per_coord["mmd"] = mmd
    return {m: per_coord[m].mean() for m in metrics}
def _matrix_pow(matrix: torch.Tensor, p: float) -> torch.Tensor:
    r"""
    Power of a (batch of) diagonalizable matrices via eigendecomposition.

    Computes ``V @ diag(L ** p) @ V^{-1}`` with ``L, V = torch.linalg.eig(matrix)``.
    The reconstruction is done entirely in complex arithmetic. Only the real
    part of the final product is returned, so that:

    * complex-conjugate eigenpairs, which ``eig`` can report for a real
      matrix, are recombined correctly. This includes nearly-symmetric
      matrices with close eigenvalues. Truncating each factor to its real
      part would instead give a singular or wrong eigenvector matrix.
    * tiny negative eigenvalues caused by round-off in a symmetric PSD matrix
      do not turn a fractional power into NaN. Their contribution is purely
      imaginary and is dropped.

    When the ``p``-th power is genuinely complex, only its real part is
    returned. An example is a fractional power of a matrix with clearly
    negative eigenvalues.

    Args:
        matrix: square matrix, or a batch of square matrices in the last two
            dimensions.
        p: power
    Returns:
        Real tensor with the same shape as ``matrix`` (real dtype matching a
        real-valued input).
    """
    eigvals, eigvecs = torch.linalg.eig(matrix)
    powered = eigvecs @ torch.diag_embed(eigvals.pow(p)) @ torch.linalg.inv(eigvecs)
    return powered.real
def gaussien_wasserstein(ref_mu, ref_cov, X2):
    r"""
    2-Wasserstein distance between a reference Gaussian and Gaussian fits.

    Each sample set in ``X2`` is summarised by its empirical mean and
    covariance (``torch.cov`` defaults, i.e. unbiased). The closed-form
    distance between :math:`\mathcal{N}(\mu_1, \Sigma_1)` and
    :math:`\mathcal{N}(\mu_2, \Sigma_2)` is then evaluated:

    .. math::
        W_2^2 = \lVert \mu_1 - \mu_2 \rVert^2
        + \operatorname{tr}\bigl(\Sigma_1 + \Sigma_2
        - 2 (\Sigma_1^{1/2} \Sigma_2 \Sigma_1^{1/2})^{1/2}\bigr)

    The function returns :math:`W_2`, not its square.

    Args:
        ref_mu: reference mean, shape ``(d,)`` (or broadcastable against
            ``(batch, d)``).
        ref_cov: reference covariance, shape ``(d, d)`` (or broadcastable
            against ``(batch, d, d)``).
        X2: samples of shape ``(batch, n, d)``, giving one distance per batch
            entry. A single sample set of shape ``(n, d)`` is also accepted.
    Returns:
        Tensor of shape ``(batch,)``, or a 0-d tensor for 2-D ``X2``.
    """
    # Accept a single (n, d) sample set by adding a temporary batch axis.
    unbatched = X2.ndim == 2
    if unbatched:
        X2 = X2.unsqueeze(0)

    mean2 = torch.mean(X2, dim=1)
    # torch.cov treats rows as variables, hence the transpose of each (n, d) slice.
    cov2 = torch.func.vmap(lambda x: torch.cov(x.mT))(X2)

    sqrtcov1 = _matrix_pow(ref_cov, 0.5)
    cross = sqrtcov1 @ cov2 @ sqrtcov1
    # Symmetric in exact arithmetic. Symmetrise explicitly so that round-off
    # asymmetry does not reach the general (torch.linalg.eig-based) matrix power.
    cross = 0.5 * (cross + cross.mT)
    covterm = torch.func.vmap(torch.trace)(
        ref_cov + cov2 - 2 * _matrix_pow(cross, 0.5)
    )

    # When the two Gaussians (nearly) coincide, round-off can leave a tiny
    # negative squared distance; clamp so the result is 0 rather than NaN.
    sq_dist = (torch.linalg.norm(ref_mu - mean2, dim=-1) ** 2 + covterm).clamp_min(0)
    dist = sq_dist.sqrt()
    return dist[0] if unbatched else dist