-
Notifications
You must be signed in to change notification settings - Fork 0
/
distribs.py
81 lines (63 loc) · 2.11 KB
/
distribs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import tensorflow as tf
import numpy as np
#Categorical Distribution
#Categorical Distribution
class CategoricalDistrib():
    """Categorical distribution over the last axis of `logits`.

    Args:
        logits: Tensor of unnormalized log-probabilities, shape (..., n_categories).
        alpha: Temperature; logits are divided by alpha before every operation,
            so alpha > 1 flattens the distribution and alpha < 1 sharpens it.
    """
    #--------------------------
    # Constructor
    #--------------------------
    def __init__(self, logits, alpha=1.0):
        self.logits = logits
        # Temperature-scaled logits; all methods below operate on these.
        self.logits_over_alpha = logits / alpha
        self.alpha = alpha
    #--------------------------
    # Negative log prob
    #--------------------------
    def neg_logp(self, x):
        """Return -log p(x) per batch element.

        Softmax cross-entropy against a one-hot target equals the negative
        log-likelihood of the categorical distribution.
        `x` is an integer tensor of category indices.
        """
        one_hot_x = tf.one_hot(x, self.logits_over_alpha.get_shape().as_list()[-1])
        return tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.logits_over_alpha, labels=one_hot_x)
    #--------------------------
    # Entropy
    #--------------------------
    def entropy(self):
        """Return the entropy H = -sum(p * log p) over the last axis.

        Uses the max-subtraction trick so exp() cannot overflow:
        log p = a0 - log z0 where a0 = logits - max(logits).
        """
        a0 = self.logits_over_alpha - tf.reduce_max(self.logits_over_alpha, axis=-1, keepdims=True)
        ea0 = tf.exp(a0)
        z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(p0 * (tf.log(z0) - a0), axis=-1)
    #--------------------------
    # Sample (greedy)
    #--------------------------
    def sample_greedy(self):
        """Return the most likely category (deterministic argmax).

        FIX: the previous implementation added Gumbel noise
        (-log(-log(u)) with u ~ Uniform), which is the Gumbel-max *sampling*
        trick — stochastic and identical in distribution to `sample()`,
        not greedy. Greedy selection is the plain argmax of the logits
        (temperature scaling does not change the argmax, but we keep the
        scaled logits for consistency with the other methods).
        """
        return tf.argmax(self.logits_over_alpha, axis=-1)
    #--------------------------
    # Sample
    #--------------------------
    def sample(self):
        """Draw one stochastic sample per row; returns shape (batch,)."""
        # tf.multinomial returns shape (batch, 1); strip the trailing axis.
        return tf.multinomial(self.logits_over_alpha, 1)[:, 0]
#Diagonal Gaussian Distribution
#Diagonal Gaussian Distribution
class DiagGaussianDistrib():
    """Gaussian distribution with diagonal covariance.

    Args:
        mean: Tensor of per-dimension means, shape (..., dim).
        logstd: Tensor of per-dimension log standard deviations, broadcastable
            against `mean`.
    """
    #--------------------------
    # Constructor
    #--------------------------
    def __init__(self, mean, logstd):
        self.mean = mean
        self.logstd = logstd
        # Cache std = exp(logstd) once; reused by neg_logp() and sample().
        self.std = tf.exp(self.logstd)
    #--------------------------
    # Negative log prob
    #--------------------------
    def neg_logp(self, x):
        """Return -log N(x | mean, diag(std^2)), summed over the last axis.

        -log p = 0.5 * sum(((x - mu)/sigma)^2)   (Mahalanobis term)
               + 0.5 * dim * log(2*pi)           (normalization constant)
               + sum(log sigma)                  (log-determinant term)
        """
        mahalanobis = 0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std), axis=-1)
        norm_const = 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[-1])
        return mahalanobis + norm_const + tf.reduce_sum(self.logstd, axis=-1)
    #--------------------------
    # Entropy
    #--------------------------
    def entropy(self):
        """Return the entropy: sum over dims of logstd + 0.5*log(2*pi*e)."""
        per_dim_const = 0.5 * np.log(2.0 * np.pi * np.e)
        return tf.reduce_sum(self.logstd + per_dim_const, axis=-1)
    #--------------------------
    # Sample
    #--------------------------
    def sample(self):
        """Draw a reparameterized sample: mean + std * eps, eps ~ N(0, I)."""
        noise = tf.random_normal(tf.shape(self.mean))
        return self.mean + self.std * noise