-
Notifications
You must be signed in to change notification settings - Fork 53
/
Gaussian.py
67 lines (58 loc) · 2.72 KB
/
Gaussian.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from torch.distributions import Normal as Gaussian_Torch
from .distribution_utils import DistributionClass
from ..utils import *
class Gaussian(DistributionClass):
    """
    Gaussian distribution class.

    Distributional Parameters
    -------------------------
    loc: torch.Tensor
        Mean of the distribution (often referred to as mu).
    scale: torch.Tensor
        Standard deviation of the distribution (often referred to as sigma).

    Source
    -------------------------
    https://pytorch.org/docs/stable/distributions.html#normal

    Parameters
    -------------------------
    stabilization: str
        Stabilization method for the Gradient and Hessian. Options are "None", "MAD", "L2".
    response_fn: str
        Response function for transforming the distributional parameters to the correct support. Options are
        "exp" (exponential) or "softplus" (softplus).
    loss_fn: str
        Loss function. Options are "nll" (negative log-likelihood) or "crps" (continuous ranked probability score).
        Note that if "crps" is used, the Hessian is set to 1, as the current CRPS version is not twice differentiable.
        Hence, using the CRPS disregards any variation in the curvature of the loss function.
    """
    def __init__(self,
                 stabilization: str = "None",
                 response_fn: str = "exp",
                 loss_fn: str = "nll"
                 ):
        # --- validate user-supplied options up front (guard clauses) ---
        if stabilization not in ("None", "MAD", "L2"):
            raise ValueError("Invalid stabilization method. Please choose from 'None', 'MAD' or 'L2'.")
        if loss_fn not in ["nll", "crps"]:
            raise ValueError("Invalid loss function. Please choose from 'nll' or 'crps'.")

        # --- resolve the response-function name to its callable ---
        available_response_fns = {"exp": exp_fn, "softplus": softplus_fn}
        if response_fn not in available_response_fns:
            raise ValueError(
                "Invalid response function. Please choose from 'exp' or 'softplus'.")
        response_fn = available_response_fns[response_fn]

        # loc is unconstrained (identity); scale is mapped to the positive
        # reals via the chosen response function.
        parameter_transforms = {"loc": identity_fn, "scale": response_fn}
        # NOTE: global side effect — disables argument validation for ALL
        # torch distributions, not just this one.
        torch.distributions.Distribution.set_default_validate_args(False)

        # --- hand everything to the generic distribution base class ---
        super().__init__(distribution=Gaussian_Torch,
                         univariate=True,
                         discrete=False,
                         n_dist_param=len(parameter_transforms),
                         stabilization=stabilization,
                         param_dict=parameter_transforms,
                         distribution_arg_names=list(parameter_transforms.keys()),
                         loss_fn=loss_fn
                         )