import torch
import torch.nn as nn
import math
import numpy as np
from torch.distributions.normal import Normal
from numpy.random import logistic
import torch.nn.functional as F
from numpy import triu, ones
## For numerical stability
EPS_FACTOR = 1e-10
#torch.manual_seed(0)
#np.random.seed(0)
class BayesLinear(nn.Module):
    def __init__(self, h1, h2, bypass=True):
        super().__init__()
        self.bypass = bypass
        self.input_size = h1
        self.output_size = h2
        self.total_dim = h1 * h2  ## Seen as independent Gaussians, this is the full dimension of the weight distribution
        ## The init for the weight and bias parameters w.r.t. the Normal dist
        mu_init_std = math.sqrt(2) * math.sqrt(2 / (h1 + h2))  ## normal_ takes a standard deviation, not a variance
        self.weight_mu = nn.Parameter(torch.Tensor(h1, h2).normal_(0, mu_init_std))
        self.weight_rho = nn.Parameter(torch.Tensor(h1, h2).uniform_(-5, -4))
        self.bias_mu = nn.Parameter(torch.Tensor(h2).normal_(0, 0.001))
        self.bias_rho = nn.Parameter(torch.Tensor(h2).uniform_(-5, -4))
        ## The init for the layer mean and variance
        layer_mean = h2 / 2 if h2 != 1 else 1
        self.layer_mu = nn.Parameter(torch.Tensor(1).normal_(layer_mean, 10))  ## This can be changed for a different init of the layer mean; 0.5 was used as the spread previously
        self.layer_rho = nn.Parameter(torch.Tensor(1).uniform_(1, 10))
        ## The init if the layer has a bypass parameter
        if bypass:
            self.bypass_a = nn.Parameter(torch.Tensor(1).normal_(0, 1))
        ## Some prior parameters under here
        self.tau = torch.tensor(1.)  ## The temperature
        self.prior_weight_mu = torch.tensor(0.)
        self.prior_weight_sigma = torch.tensor(2.)
        self.prior_layer_mu = torch.tensor(h2 / 2)  ## This can be changed for a different prior belief on the size of the network
        prior_var = (h2 - 1) / 2 if h2 != 1 else 1.0
        self.prior_layer_sigma = torch.tensor(prior_var)
        self.prior_layer_pi = self.layer_pi(self.prior_layer_mu, self.prior_layer_sigma)
        self.prior_bypass_pi = torch.tensor(0.5)
        ## The upper triangular matrix used to build the soft mask in forward()
        self.U = torch.from_numpy(triu(ones((h2, h2)))).float()
    ## Density of a Normal(mu, sigma) truncated to the interval [1, output_size]
    def truncated_normal(self, x, mu, sigma):
        mu_n = mu.clone().detach()
        sigma_n = sigma.clone().detach()
        norm = Normal(mu_n, sigma_n)
        left_val = norm.cdf(1)
        right_val = norm.cdf(self.output_size)

        def log_normal(x, mu, sigma):
            return -1 / 2 * torch.log(2 * math.pi * sigma**2 + EPS_FACTOR) - 1 / (2 * sigma**2 + EPS_FACTOR) * (x - mu)**2

        p1 = log_normal(x, mu, sigma)
        p2 = torch.log(right_val - left_val) if right_val != left_val else torch.tensor(0.)
        return torch.exp(p1 - p2)
    ## The parameter pi for the layer-bypass Gumbel sample later; bypass_a should be passed in here
    def bypass_pi(self, x):
        return torch.sigmoid(x)
    ## Converts the rho parameter to sigma (softplus)
    def sigma(self, x):
        return torch.log1p(torch.exp(x))
    ## Gives the probability vector for the layer size; this is later used as pi when sampling from the Gumbel/Concrete
    def layer_pi(self, mu, sigma):
        input_x = torch.arange(1, self.output_size + 1).float()
        p1 = self.truncated_normal(input_x, mu, sigma)
        p2 = torch.sum(p1)
        return p1 / p2
    ## s_p here is the sampled simplex vector
    def calc_layer_VI(self, s_p, post_lay_pi):
        ## First calculate the log prior term (a Concrete/Gumbel-softmax log-density)
        p1 = torch.log(torch.arange(1, self.output_size).float() + EPS_FACTOR).sum()
        p2 = (self.output_size - 1) * torch.log(self.tau)
        p3 = (torch.log(self.prior_layer_pi + EPS_FACTOR) - (self.tau + 1) * torch.log(s_p + EPS_FACTOR)).sum()
        p4 = self.output_size * torch.log((self.prior_layer_pi * torch.pow(s_p + EPS_FACTOR, -self.tau)).sum() + EPS_FACTOR)
        log_prior = p1 + p2 + p3 - p4
        ## The variational posterior term is computed the same way, just with the posterior parameters
        p1 = torch.log(torch.arange(1, self.output_size).float() + EPS_FACTOR).sum()
        p2 = (self.output_size - 1) * torch.log(self.tau)
        p3 = (torch.log(post_lay_pi + EPS_FACTOR) - (self.tau + 1) * torch.log(s_p + EPS_FACTOR)).sum()
        p4 = self.output_size * torch.log((post_lay_pi * torch.pow(s_p + EPS_FACTOR, -self.tau)).sum() + EPS_FACTOR)
        log_post = p1 + p2 + p3 - p4
        return log_post - log_prior
    def calc_bypass_VI(self, s_p, post_bypass_pi):
        ## I am a little unsure whether the article states this correctly; should (1 - pi) not enter somewhere?
        ## First calculate the log prior term (a binary Concrete log-density)
        p1 = torch.log(self.tau)
        p2 = torch.log(self.prior_bypass_pi + EPS_FACTOR) - (self.tau + 1) * (torch.log(s_p + EPS_FACTOR) + torch.log(1 - s_p + EPS_FACTOR))
        p3 = 2 * torch.log(self.prior_bypass_pi * torch.pow(s_p + EPS_FACTOR, -self.tau) + torch.pow(1 - s_p + EPS_FACTOR, -self.tau) + EPS_FACTOR)
        log_prior = p1 + p2 - p3
        ## And now the variational posterior term
        p1 = torch.log(self.tau)
        p2 = torch.log(post_bypass_pi + EPS_FACTOR) - (self.tau + 1) * (torch.log(s_p + EPS_FACTOR) + torch.log(1 - s_p + EPS_FACTOR))
        p3 = 2 * torch.log(post_bypass_pi * torch.pow(s_p + EPS_FACTOR, -self.tau) + torch.pow(1 - s_p + EPS_FACTOR, -self.tau) + EPS_FACTOR)
        log_post = p1 + p2 - p3
        return log_post - log_prior
    ## Closed-form KL between N(mu, sigma^2) and the N(0, prior_weight_sigma^2) prior:
    ## KL = 1/2 * [ d*log(sigma_p^2) - sum(log sigma^2) - d + sum(sigma^2/sigma_p^2) + sum(mu^2/sigma_p^2) ]
    def calc_normal_KL(self, mu, sigma):
        dim = mu.numel()  ## Number of independent Gaussians; works for both the weight matrix and the bias vector
        p1 = 2 * dim * torch.log(self.prior_weight_sigma + EPS_FACTOR)
        p2 = 2 * torch.log(sigma + EPS_FACTOR).sum()
        p3 = dim
        p4 = (sigma**2 / self.prior_weight_sigma**2).sum()
        p5 = (mu**2 / self.prior_weight_sigma**2).sum()
        KL = 1 / 2 * (p1 - p2 - p3 + p4 + p5)
        return KL
    ## Should just return a sample of the Concrete categorical distribution
    def concrete_cat_sample(self, layer_pi):
        ## Standard Gumbel noise
        eps = -torch.log(-torch.log(torch.Tensor(self.output_size).uniform_(0, 1) + EPS_FACTOR) + EPS_FACTOR)
        ## Calc the sample here
        p1 = (torch.log(layer_pi + EPS_FACTOR) + eps) / self.tau
        p2 = torch.exp(p1)
        p3 = p2.sum()
        return p2 / p3
    def forward(self, x):
        if self.bypass:
            bypass_pi = self.bypass_pi(self.bypass_a)
            self.by_pi = bypass_pi
            ## Logistic noise for the binary Concrete relaxation; cast to float32 to match the parameters
            random_logistic = torch.tensor(logistic(0, 1)).float()
            gamma = 1 / (1 + torch.exp(-(torch.log(bypass_pi + EPS_FACTOR) - torch.log(1 - bypass_pi + EPS_FACTOR) + random_logistic) / self.tau))
        lay_mu = self.layer_mu
        lay_sig = self.sigma(self.layer_rho)
        lay_pi = self.layer_pi(lay_mu, lay_sig)
        self.lay_pi = lay_pi
        s_p = self.concrete_cat_sample(lay_pi)
        s_p = torch.unsqueeze(s_p, 0)
        soft_mask = torch.mm(self.U, s_p.t()).t()  ## Entry j is sum_{k >= j} s_p[k], a soft "keep" weight per neuron
        w_mu = self.weight_mu
        w_sigma = self.sigma(self.weight_rho)
        weight = w_mu + w_sigma * torch.randn_like(w_mu)  ## Reparameterisation trick for the weights
        b_mu = self.bias_mu
        b_sigma = self.sigma(self.bias_rho)
        bias = b_mu + b_sigma * torch.randn_like(b_mu)  ## Reparameterisation trick for the bias
        ## There is no bias term for the layer and bypass variables; they only affect the regular neurons
        self.layer_VI = self.calc_layer_VI(s_p, lay_pi)
        self.weight_VI = self.calc_normal_KL(w_mu, w_sigma)
        self.bias_VI = self.calc_normal_KL(b_mu, b_sigma)
        if self.bypass:
            ## Stored as a tensor so this VI term can contribute gradients to bypass_a
            self.bypass_VI = self.calc_bypass_VI(gamma, bypass_pi)
            ## The skip connection requires input_size == output_size
            return (1 - gamma) * F.linear(x, torch.t(weight), bias) * soft_mask + gamma * x
        else:
            return F.linear(x, torch.t(weight), bias) * soft_mask
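
## ---------------------------------------------------------------------------
## Minimal usage sketch (an illustrative assumption, not part of the module
## above): it shows how two BayesLinear layers might be stacked and how the
## VI terms stored during forward() could be collected into a KL penalty for
## an ELBO-style loss. The layer sizes, the random data, the zero targets,
## and the 1e-3 KL weight are all arbitrary choices for this example; note
## that the bypass skip connection needs input_size == output_size.
## ---------------------------------------------------------------------------
if __name__ == "__main__":
    layer1 = BayesLinear(4, 4, bypass=True)
    layer2 = BayesLinear(4, 1, bypass=False)
    x = torch.randn(8, 4)
    target = torch.zeros(8, 1)
    out = layer2(layer1(x))
    ## forward() stores the per-layer variational terms as attributes
    kl = sum(l.layer_VI + l.weight_VI + l.bias_VI for l in (layer1, layer2))
    kl = kl + layer1.bypass_VI
    loss = F.mse_loss(out, target) + 1e-3 * kl
    loss.backward()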