-
Notifications
You must be signed in to change notification settings - Fork 257
/
Copy pathtrain_variational_autoencoder_pytorch.py
205 lines (173 loc) · 7.16 KB
/
train_variational_autoencoder_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""Fit a variational autoencoder to MNIST.
Notes:
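- run https://github.com/altosaar/proximity_vi/blob/master/get_binary_mnist.py to download the binarized MNIST file, and place binarized_mnist.hdf5 in the directory named by the $DAT environment variable
- tensors are laid out as (batch, sample, latent): the batch dimension comes first (outermost), then the sample dimension, then the latent dimension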
"""
import logging
import os
import pathlib

import h5py
import nomen
import numpy as np
import torch
import torch.utils
import torch.utils.data
import yaml
from torch import nn
# Default hyperparameters as a YAML document. The __main__ block parses this
# string with yaml and wraps the result in nomen.Config, after which each key
# is read as an attribute (cfg.latent_size, cfg.batch_size, cfg.use_gpu, ...).
# NOTE(review): train_dir is the literal text "$TMPDIR" here; confirm the
# environment variable is expanded before the value is used as a path.
config = """
latent_size: 128
data_size: 784
learning_rate: 0.001
batch_size: 128
test_batch_size: 512
max_iterations: 100000
log_interval: 5000
n_samples: 128
use_gpu: true
train_dir: $TMPDIR
"""
class Model(nn.Module):
    """Bernoulli model parameterized by a generative network with Gaussian latents for MNIST.

    forward(z, x) returns the joint log density log p(z) + log p(x | z), where
    p(z) is a standard normal over `latent_size` dimensions and p(x | z) is a
    Bernoulli whose logits come from a neural network applied to z.
    """

    def __init__(self, latent_size, data_size, batch_size, device):
        """Build the prior and the generative network.

        Args:
            latent_size: dimensionality of the latent variable z.
            data_size: dimensionality of an observation x.
            batch_size: not used by this class; kept so existing callers work.
            device: device on which the prior's parameters are first allocated.
        """
        super().__init__()
        # Register the prior's loc/scale as buffers so that model.to(...)
        # moves them together with the network weights. A plain Distribution
        # attribute is invisible to nn.Module and stays on the device it was
        # built on. persistent=False keeps state_dict (and so existing
        # checkpoints) unchanged.
        self.register_buffer('p_z_loc', torch.zeros(latent_size, device=device),
                             persistent=False)
        self.register_buffer('p_z_scale', torch.ones(latent_size, device=device),
                             persistent=False)
        self.log_p_x = BernoulliLogProb()
        self.generative_network = NeuralNetwork(input_size=latent_size,
                                                output_size=data_size,
                                                hidden_size=latent_size * 2)

    @property
    def p_z(self):
        """Standard-normal prior p(z), built on the buffers' current device."""
        return torch.distributions.Normal(self.p_z_loc, self.p_z_scale)

    def forward(self, z, x):
        """Return log p(x, z) for each latent sample.

        Assumes z is (batch, n_samples, latent_size) and x is
        (batch, data_size) -- TODO confirm with callers. x gains a sample
        dimension at position 1 to match the logits. Both log densities are
        summed over the last dimension.
        """
        log_p_z = self.p_z.log_prob(z).sum(-1)
        logits = self.generative_network(z)
        # Insert the sample dimension into x so it broadcasts against logits.
        logits, x = torch.broadcast_tensors(logits, x.unsqueeze(1))
        log_p_x = self.log_p_x(logits, x).sum(-1)
        return log_p_z + log_p_x
class Variational(nn.Module):
    """Approximate posterior q(z | x) parameterized by an inference network.

    The network maps x to the location and a pre-softplus scale of a
    diagonal Gaussian. Samples are drawn with the reparameterization trick,
    so gradients flow through them.
    """

    def __init__(self, latent_size, data_size):
        super().__init__()
        # The network emits 2 * latent_size values per input: the first half
        # is the location, the second half is mapped through softplus to
        # give a positive scale.
        self.inference_network = NeuralNetwork(input_size=data_size,
                                               output_size=2 * latent_size,
                                               hidden_size=2 * latent_size)
        self.log_q_z = NormalLogProb()
        self.softplus = nn.Softplus()

    def forward(self, x, n_samples=1):
        """Draw `n_samples` latents per input and return (z, log q(z | x)).

        A singleton sample axis is inserted at position 1 of the network
        output, so loc/scale broadcast against noise shaped
        (loc.shape[0], n_samples, loc.shape[-1]). log q is summed over the
        last dimension.
        """
        params = self.inference_network(x).unsqueeze(1)
        loc, scale_arg = params.chunk(2, dim=-1)
        scale = self.softplus(scale_arg)
        noise_shape = (loc.shape[0], n_samples, loc.shape[-1])
        noise = torch.randn(noise_shape, device=loc.device)
        # Reparameterization: z is a deterministic function of (loc, scale, noise).
        z = loc + scale * noise
        log_density = self.log_q_z(loc, scale, z).sum(-1)
        return z, log_density
class NeuralNetwork(nn.Module):
    """Three-layer fully connected network with ReLU between linear layers.

    Layer widths are input_size -> hidden_size -> hidden_size -> output_size.
    The final layer has no activation, so outputs are unconstrained.
    """

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        widths = (input_size, hidden_size, hidden_size, output_size)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            # A ReLU separates each linear layer from the one before it. This
            # keeps the Sequential indices (net.0, net.2, net.4 are Linear)
            # and thus the state_dict keys stable for saved checkpoints.
            if layers:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the network to `input` (its last dimension must be input_size)."""
        output = self.net(input)
        return output
class NormalLogProb(nn.Module):
    """Elementwise log density of a Gaussian N(loc, scale**2) evaluated at z."""

    def forward(self, loc, scale, z):
        """Return log N(z; loc, scale**2) for every element, with no reduction."""
        variance = scale ** 2
        log_normalizer = -0.5 * torch.log(2 * np.pi * variance)
        squared_distance = (z - loc) ** 2
        return log_normalizer - squared_distance / (2 * variance)
class BernoulliLogProb(nn.Module):
    """Elementwise Bernoulli log-likelihood of `target` given `logits`.

    log Bernoulli(target; sigmoid(logits)) equals the negative binary cross
    entropy. It is computed from logits (via BCEWithLogitsLoss) with no
    reduction, so the result has one entry per element. BCEWithLogitsLoss
    requires target to have the same shape as logits.
    """

    def __init__(self):
        super().__init__()
        self.bce_with_logits = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, logits, target):
        """Return log p(target | logits) for every element."""
        cross_entropy = self.bce_with_logits(logits, target)
        return torch.neg(cross_entropy)
def cycle(iterable):
    """Yield items from `iterable` forever, re-iterating it on every pass.

    Unlike itertools.cycle, which replays a cached copy of the first pass,
    this calls iter(iterable) afresh each time. A DataLoader with
    shuffle=True therefore produces a new ordering every epoch.

    If a pass yields nothing (an empty iterable, or a one-shot iterator that
    is already exhausted), the generator stops instead of spinning forever
    without yielding.
    """
    while True:
        empty_pass = True
        for item in iterable:
            empty_pass = False
            yield item
        if empty_pass:
            return
def load_binary_mnist(cfg, **kwcfg):
    """Load binarized MNIST from $DAT/binarized_mnist.hdf5 into DataLoaders.

    Args:
        cfg: config object providing `batch_size` (training loader) and
            `test_batch_size` (validation and test loaders).
        **kwcfg: extra keyword arguments forwarded to every DataLoader
            (e.g. num_workers, pin_memory). Previously these were accepted
            but silently ignored.

    Returns:
        (train_loader, val_loader, test_loader). Only the training loader
        shuffles. Each batch is a one-element list holding the data tensor.

    Raises:
        KeyError: if the DAT environment variable is not set.
    """
    path = os.path.join(os.environ['DAT'], 'binarized_mnist.hdf5')
    # Read every split fully into memory, then close the file handle.
    with h5py.File(path, 'r') as f:
        x_train = f['train'][:]
        x_val = f['valid'][:]
        x_test = f['test'][:]

    def _make_loader(array, batch_size, shuffle):
        # Wrap one in-memory split as a TensorDataset-backed DataLoader.
        dataset = torch.utils.data.TensorDataset(torch.from_numpy(array))
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                           shuffle=shuffle, **kwcfg)

    train_loader = _make_loader(x_train, cfg.batch_size, shuffle=True)
    val_loader = _make_loader(x_val, cfg.test_batch_size, shuffle=False)
    test_loader = _make_loader(x_test, cfg.test_batch_size, shuffle=False)
    return train_loader, val_loader, test_loader
def evaluate(n_samples, model, variational, eval_data):
    """Estimate the average ELBO and log marginal likelihood on a dataset.

    Args:
        n_samples: number of latent samples drawn per data point.
        model: module called as model(z, x), returning log p(x, z) with the
            sample dimension at position 1.
        variational: module called as variational(x, n_samples), returning
            (z, log q(z | x)).
        eval_data: DataLoader whose batches hold the data tensor at index 0.

    Returns:
        (average ELBO, average importance-sampled log p(x)), each as a Python
        float divided by len(eval_data.dataset).

    Both modules are put in eval mode during the pass and restored to their
    previous training/eval mode afterwards. The pass runs under
    torch.no_grad(), so callers need not wrap it themselves.
    """
    device = next(model.parameters()).device
    was_training = (model.training, variational.training)
    model.eval()
    variational.eval()
    total_log_p_x = 0.0
    total_elbo = 0.0
    try:
        with torch.no_grad():
            for batch in eval_data:
                x = batch[0].to(device)
                z, log_q_z = variational(x, n_samples)
                log_p_x_and_z = model(z, x)
                # importance sampling of approximate marginal likelihood with q(z)
                # as the proposal, and logsumexp in the sample dimension
                elbo = log_p_x_and_z - log_q_z
                log_p_x = torch.logsumexp(elbo, dim=1) - np.log(n_samples)
                # average over sample dimension, sum over minibatch; accumulate
                # across batches in a Python float (double precision)
                total_elbo += elbo.mean(1).sum().item()
                # sum over minibatch
                total_log_p_x += log_p_x.sum().item()
    finally:
        model.train(was_training[0])
        variational.train(was_training[1])
    n_data = len(eval_data.dataset)
    return total_elbo / n_data, total_log_p_x / n_data
if __name__ == '__main__':
    # yaml.load without an explicit Loader is rejected by PyYAML >= 6; the
    # safe loader is all this scalar-only config needs.
    dictionary = yaml.safe_load(config)
    cfg = nomen.Config(dictionary)
    # Fall back to the CPU when a GPU is requested but CUDA is unavailable,
    # instead of failing on the first tensor placed on "cuda:0".
    use_cuda = bool(cfg.use_gpu) and torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    model = Model(latent_size=cfg.latent_size,
                  data_size=cfg.data_size,
                  batch_size=cfg.batch_size,
                  device=device)
    variational = Variational(latent_size=cfg.latent_size,
                              data_size=cfg.data_size)
    model.to(device)
    variational.to(device)
    optimizer = torch.optim.RMSprop(list(model.parameters()) +
                                    list(variational.parameters()),
                                    lr=cfg.learning_rate,
                                    centered=True)
    kwargs = {'num_workers': 0, 'pin_memory': False} if cfg.use_gpu else {}
    train_data, valid_data, test_data = load_binary_mnist(cfg, **kwargs)
    # Expand environment variables such as $TMPDIR (a no-op if cfg already
    # holds an expanded path) and make sure the directory exists before saving.
    train_dir = pathlib.Path(os.path.expandvars(str(cfg.train_dir)))
    train_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = train_dir / 'best_state_dict'
    best_valid_elbo = -np.inf
    # NOTE(review): this counts every non-improving validation round in total,
    # not consecutive ones -- it is never reset when the ELBO improves.
    num_no_improvement = 0
    saved_checkpoint = False
    for step, batch in enumerate(cycle(train_data)):
        # Honour the configured iteration budget; previously it was ignored
        # and training could only end through early stopping.
        if step >= cfg.max_iterations:
            break
        x = batch[0].to(device)
        model.zero_grad()
        variational.zero_grad()
        z, log_q_z = variational(x)
        log_p_x_and_z = model(z, x)
        # average over sample dimension
        elbo = (log_p_x_and_z - log_q_z).mean(1)
        # sum over batch dimension
        loss = -elbo.sum(0)
        loss.backward()
        optimizer.step()
        if step % cfg.log_interval == 0:
            print(f'step:\t{step}\ttrain elbo: {elbo.detach().cpu().numpy().mean():.2f}')
            with torch.no_grad():
                valid_elbo, valid_log_p_x = evaluate(cfg.n_samples, model, variational, valid_data)
            # evaluate() may leave the modules in eval mode; resume training mode.
            model.train()
            variational.train()
            print(f'step:\t{step}\t\tvalid elbo: {valid_elbo:.2f}\tvalid log p(x): {valid_log_p_x:.2f}')
            if valid_elbo > best_valid_elbo:
                best_valid_elbo = valid_elbo
                states = {'model': model.state_dict(),
                          'variational': variational.state_dict()}
                torch.save(states, checkpoint_path)
                saved_checkpoint = True
            else:
                num_no_improvement += 1
                if num_no_improvement > 5:
                    break
    # Evaluate the best checkpoint on the test set, whether training ended by
    # early stopping or by reaching max_iterations.
    if saved_checkpoint:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model'])
        variational.load_state_dict(checkpoint['variational'])
        with torch.no_grad():
            test_elbo, test_log_p_x = evaluate(cfg.n_samples, model, variational, test_data)
        print(f'step:\t{step}\t\ttest elbo: {test_elbo:.2f}\ttest log p(x): {test_log_p_x:.2f}')