# distributions.py
import torch
from torch.distributions import Distribution, Normal

import rlkit.torch.pytorch_util as ptu


class TanhNormal(Distribution):
    """
    Represents the distribution of X where
        X ~ tanh(Z)
        Z ~ N(mean, std)

    Note: this is not very numerically stable.
    """
    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob.
        """
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)
        self.epsilon = epsilon
    def sample_n(self, n, return_pre_tanh_value=False):
        # Draws n samples without reparameterization. Note that
        # Distribution.sample_n is deprecated in newer PyTorch releases in
        # favor of .sample((n,)); it is kept here to match the original
        # interface.
        z = self.normal.sample_n(n)
        if return_pre_tanh_value:
            return torch.tanh(z), z
        else:
            return torch.tanh(z)
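
    # Change-of-variables identity used by log_prob below: if x = tanh(z) with
    # z ~ N(mean, std), then
    #     log p_X(x) = log p_Z(z) - log|d tanh(z) / dz| = log p_Z(z) - log(1 - x^2),
    # where z = arctanh(x). epsilon keeps the second log finite when x * x is
    # numerically 1.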
    def log_prob(self, value, pre_tanh_value=None):
        """
        :param value: some value, x
        :param pre_tanh_value: arctanh(x), if already available
        :return: log probability of x under this distribution
        """
        if pre_tanh_value is None:
            # arctanh(x) = 0.5 * log((1 + x) / (1 - x))
            pre_tanh_value = torch.log(
                (1 + value) / (1 - value)
            ) / 2
        return self.normal.log_prob(pre_tanh_value) - torch.log(
            1 - value * value + self.epsilon
        )
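
    # Note: a common, more numerically stable alternative (not used here)
    # rewrites log(1 - tanh(z)^2) as 2 * (log(2) - z - softplus(-2 * z)),
    # which avoids the cancellation as |tanh(z)| approaches 1.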

    def sample(self, return_pretanh_value=False):
        """
        Gradients will and should *not* pass through this operation.

        See https://github.com/pytorch/pytorch/issues/4620 for discussion.
        """
        # Normal.sample() does not track gradients, so .detach() is redundant
        # here, but it makes the no-gradient intent explicit.
        z = self.normal.sample().detach()

        if return_pretanh_value:
            return torch.tanh(z), z
        else:
            return torch.tanh(z)

    def rsample(self, return_pretanh_value=False):
        """
        Sampling in the reparameterization case.
        """
        # Reparameterization trick: z = mean + std * eps with eps ~ N(0, 1),
        # so gradients flow through z back to normal_mean and normal_std.
        z = (
            self.normal_mean +
            self.normal_std *
            Normal(
                ptu.zeros(self.normal_mean.size()),
                ptu.ones(self.normal_std.size())
            ).sample()
        )
        z.requires_grad_()

        if return_pretanh_value:
            return torch.tanh(z), z
        else:
            return torch.tanh(z)
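

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module):
# builds a TanhNormal from plain tensors, draws a reparameterized sample, and
# backpropagates through its log-probability. Shapes and values are arbitrary
# assumptions; running this requires rlkit's pytorch_util to be importable.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    mean = torch.zeros(3, requires_grad=True)
    std = 0.5 * torch.ones(3)
    dist = TanhNormal(mean, std)

    # Reparameterized sample: gradients flow from `action` back to `mean`.
    action, pre_tanh = dist.rsample(return_pretanh_value=True)
    log_prob = dist.log_prob(action, pre_tanh_value=pre_tanh)

    # Example objective: maximize the log-probability of the sampled action.
    (-log_prob.sum()).backward()
    print("action:", action)
    print("log_prob:", log_prob)
    print("grad wrt mean:", mean.grad)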