-
-
Notifications
You must be signed in to change notification settings - Fork 100
/
core.py
496 lines (412 loc) · 14.9 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
import torch
import torch.nn as nn
import numpy as np
from . import distributions
from . import utils
class NormalizingFlow(nn.Module):
    """
    Normalizing Flow model to approximate target distribution
    """

    def __init__(self, q0, flows, p=None):
        """Constructor

        Args:
          q0: Base distribution
          flows: List of flows
          p: Target distribution
        """
        super().__init__()
        self.q0 = q0
        self.flows = nn.ModuleList(flows)
        self.p = p

    def forward(self, z):
        """Transforms latent variable z to the flow variable x

        Args:
          z: Batch in the latent space

        Returns:
          Batch in the space of the target distribution
        """
        for flow in self.flows:
            z, _ = flow(z)
        return z

    def forward_and_log_det(self, z):
        """Transforms latent variable z to the flow variable x and
        computes log determinant of the Jacobian

        Args:
          z: Batch in the latent space

        Returns:
          Batch in the space of the target distribution,
          log determinant of the Jacobian of the forward map
        """
        log_det = torch.zeros(len(z), dtype=z.dtype, device=z.device)
        for flow in self.flows:
            z, log_d = flow(z)
            # The log-determinant of a composition is the SUM of the
            # per-flow log-determinants (mirrors inverse_and_log_det).
            log_det += log_d
        return z, log_det

    def inverse(self, x):
        """Transforms flow variable x to the latent variable z

        Args:
          x: Batch in the space of the target distribution

        Returns:
          Batch in the latent space
        """
        for flow in reversed(self.flows):
            x, _ = flow.inverse(x)
        return x

    def inverse_and_log_det(self, x):
        """Transforms flow variable x to the latent variable z and
        computes log determinant of the Jacobian

        Args:
          x: Batch in the space of the target distribution

        Returns:
          Batch in the latent space, log determinant of the
          Jacobian of the inverse map
        """
        log_det = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        for flow in reversed(self.flows):
            x, log_d = flow.inverse(x)
            log_det += log_d
        return x, log_det

    def forward_kld(self, x):
        """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

        Args:
          x: Batch sampled from target distribution

        Returns:
          Estimate of forward KL divergence averaged over batch
        """
        return -torch.mean(self.log_prob(x))

    def reverse_kld(self, num_samples=1, beta=1.0, score_fn=True):
        """Estimates reverse KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

        Args:
          num_samples: Number of samples to draw from base distribution
          beta: Annealing parameter, see [arXiv 1505.05770](https://arxiv.org/abs/1505.05770)
          score_fn: Flag whether to include score function in gradient, see [arXiv 1703.09194](https://arxiv.org/abs/1703.09194)

        Returns:
          Estimate of the reverse KL divergence averaged over latent samples
        """
        z, log_q_ = self.q0(num_samples)
        # Accumulate into a fresh tensor instead of mutating q0's output.
        log_q = torch.zeros_like(log_q_)
        log_q += log_q_
        for flow in self.flows:
            z, log_det = flow(z)
            log_q -= log_det
        if not score_fn:
            # Re-evaluate log q(z) through the inverse pass with parameter
            # gradients disabled, dropping the score-function term.
            log_q = self._log_prob_without_param_grad(z)
        log_p = self.p.log_prob(z)
        return torch.mean(log_q) - beta * torch.mean(log_p)

    def reverse_alpha_div(self, num_samples=1, alpha=1, dreg=False):
        """Alpha divergence when sampling from q

        Args:
          num_samples: Number of samples to draw
          alpha: Order of the alpha divergence
          dreg: Flag whether to use Double Reparametrized Gradient estimator, see [arXiv 1810.04152](https://arxiv.org/abs/1810.04152)

        Returns:
          Alpha divergence
        """
        z, log_q = self.q0(num_samples)
        for flow in self.flows:
            z, log_det = flow(z)
            log_q -= log_det
        log_p = self.p.log_prob(z)
        if dreg:
            w_const = torch.exp(log_p - log_q).detach()
            log_q = self._log_prob_without_param_grad(z)
            # log(w) taken directly: torch.log(torch.exp(.)) overflows to
            # inf / underflows to -inf for large |log_p - log_q|.
            log_w = log_p - log_q
            w_alpha = w_const**alpha
            w_alpha = w_alpha / torch.mean(w_alpha)
            weights = (1 - alpha) * w_alpha + alpha * w_alpha**2
            loss = -alpha * torch.mean(weights * log_w)
        else:
            loss = np.sign(alpha - 1) * torch.logsumexp(alpha * (log_p - log_q), 0)
        return loss

    def sample(self, num_samples=1):
        """Samples from flow-based approximate distribution

        Args:
          num_samples: Number of samples to draw

        Returns:
          Samples, log probability
        """
        z, log_q = self.q0(num_samples)
        for flow in self.flows:
            z, log_det = flow(z)
            log_q -= log_det
        return z, log_q

    def log_prob(self, x):
        """Get log probability for batch

        Args:
          x: Batch

        Returns:
          log probability
        """
        log_q = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        z = x
        for flow in reversed(self.flows):
            z, log_det = flow.inverse(z)
            log_q += log_det
        log_q += self.q0.log_prob(z)
        return log_q

    def _log_prob_without_param_grad(self, x):
        """Return self.log_prob(x) with this model's parameter gradients off.

        Assumes utils.set_requires_grad toggles requires_grad on the
        parameters of self. They are switched back on even if the inverse
        pass raises.
        """
        utils.set_requires_grad(self, False)
        try:
            return self.log_prob(x)
        finally:
            utils.set_requires_grad(self, True)

    def save(self, path):
        """Save state dict of model

        Args:
          path: Path including filename where to save model
        """
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load model from state dict

        Note: torch.load unpickles the file; only load trusted checkpoints.

        Args:
          path: Path including filename where to load model from
        """
        self.load_state_dict(torch.load(path))
class ClassCondFlow(nn.Module):
    """
    Class conditional normalizing Flow model
    """

    def __init__(self, q0, flows):
        """Constructor

        Args:
          q0: Base distribution
          flows: List of flows
        """
        super().__init__()
        self.q0 = q0
        self.flows = nn.ModuleList(flows)

    def forward_kld(self, x, y):
        """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

        Args:
          x: Batch sampled from target distribution
          y: Classes of x, passed to the base distribution

        Returns:
          Estimate of forward KL divergence averaged over batch
        """
        return -torch.mean(self.log_prob(x, y))

    def sample(self, num_samples=1, y=None):
        """Samples from flow-based approximate distribution

        Args:
          num_samples: Number of samples to draw
          y: Classes to sample from, will be sampled uniformly if None

        Returns:
          Samples, log probability
        """
        z, log_q = self.q0(num_samples, y)
        for flow in self.flows:
            z, log_det = flow(z)
            log_q -= log_det
        return z, log_q

    def log_prob(self, x, y):
        """Get log probability for batch

        Args:
          x: Batch
          y: Classes of x

        Returns:
          log probability
        """
        log_q = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        z = x
        # Map x back to the latent space, undoing the flows last-to-first.
        for flow in reversed(self.flows):
            z, log_det = flow.inverse(z)
            log_q += log_det
        log_q += self.q0.log_prob(z, y)
        return log_q

    def save(self, path):
        """Save state dict of model

        Args:
          path: Path including filename where to save model
        """
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load model from state dict

        Note: torch.load unpickles the file; only load trusted checkpoints.

        Args:
          path: Path including filename where to load model from
        """
        self.load_state_dict(torch.load(path))
class MultiscaleFlow(nn.Module):
    """
    Normalizing Flow model with multiscale architecture, see RealNVP or Glow paper
    """

    def __init__(self, q0, flows, merges, transform=None, class_cond=True):
        """Constructor

        Args:
          q0: List of base distribution
          flows: List of list of flows for each level
          merges: List of merge/split operations (forward pass must do merge)
          transform: Initial transformation of inputs
          class_cond: Flag, indicated whether model has class conditional
            base distributions
        """
        super().__init__()
        self.q0 = nn.ModuleList(q0)
        self.num_levels = len(self.q0)
        self.flows = torch.nn.ModuleList([nn.ModuleList(flow) for flow in flows])
        self.merges = torch.nn.ModuleList(merges)
        self.transform = transform
        self.class_cond = class_cond

    def forward_kld(self, x, y=None):
        """Estimates forward KL divergence, see [arXiv 1912.02762](https://arxiv.org/abs/1912.02762)

        Args:
          x: Batch sampled from target distribution
          y: Batch of targets, if applicable

        Returns:
          Estimate of forward KL divergence averaged over batch
        """
        return -torch.mean(self.log_prob(x, y))

    def forward(self, x, y=None):
        """Get negative log-likelihood for maximum likelihood training

        Args:
          x: Batch of data
          y: Batch of targets, if applicable

        Returns:
          Negative log-likelihood of the batch
        """
        return -self.log_prob(x, y)

    def sample(self, num_samples=1, y=None, temperature=None):
        """Samples from flow-based approximate distribution

        Args:
          num_samples: Number of samples to draw
          y: Classes to sample from, will be sampled uniformly if None
          temperature: Temperature parameter for temp annealed sampling

        Returns:
          Samples, log probability
        """
        if temperature is not None:
            self.set_temperature(temperature)
        try:
            for i, base in enumerate(self.q0):
                if self.class_cond:
                    z_, log_q_ = base(num_samples, y)
                else:
                    z_, log_q_ = base(num_samples)
                if i == 0:
                    log_q = log_q_
                    z = z_
                else:
                    # Merge the new level's sample into the running state.
                    log_q += log_q_
                    z, log_det = self.merges[i - 1]([z, z_])
                    log_q -= log_det
                for flow in self.flows[i]:
                    z, log_det = flow(z)
                    log_q -= log_det
            if self.transform is not None:
                z, log_det = self.transform(z)
                log_q -= log_det
        finally:
            # Restore the base distributions even if sampling raised, so a
            # failed call does not leave a temperature set.
            if temperature is not None:
                self.reset_temperature()
        return z, log_q

    def log_prob(self, x, y=None):
        """Get log probability for batch

        Args:
          x: Batch
          y: Classes of x; only used when class_cond is True

        Returns:
          log probability
        """
        log_q = 0
        z = x
        if self.transform is not None:
            z, log_det = self.transform.inverse(z)
            log_q += log_det
        for i in range(len(self.q0) - 1, -1, -1):
            for flow in reversed(self.flows[i]):
                z, log_det = flow.inverse(z)
                log_q += log_det
            if i > 0:
                # Split off the part of z that belongs to level i.
                [z, z_], log_det = self.merges[i - 1].inverse(z)
                log_q += log_det
            else:
                z_ = z
            if self.class_cond:
                log_q += self.q0[i].log_prob(z_, y)
            else:
                log_q += self.q0[i].log_prob(z_)
        return log_q

    def save(self, path):
        """Save state dict of model

        Args:
          path: Path including filename where to save model
        """
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load model from state dict

        Note: torch.load unpickles the file; only load trusted checkpoints.

        Args:
          path: Path including filename where to load model from
        """
        self.load_state_dict(torch.load(path))

    def set_temperature(self, temperature):
        """Set temperature for temperature annealed sampling

        All base distributions are checked before any is modified, so a
        failure leaves every base distribution unchanged.

        Args:
          temperature: Temperature parameter

        Raises:
          NotImplementedError: if a base distribution has no temperature
        """
        if not all(hasattr(q0, "temperature") for q0 in self.q0):
            raise NotImplementedError(
                "One base function does not "
                "support temperature annealed sampling"
            )
        for q0 in self.q0:
            q0.temperature = temperature

    def reset_temperature(self):
        """
        Set temperature values of base distributions back to None
        """
        self.set_temperature(None)
class NormalizingFlowVAE(nn.Module):
    """
    VAE using normalizing flows to express approximate distribution
    """

    def __init__(self, prior, q0=None, flows=None, decoder=None):
        """Constructor of normalizing flow model

        Args:
          prior: Prior distribution of the VAE, e.g. Gaussian
          q0: Base Encoder; a new ``distributions.Dirac()`` is created for
            this instance when None
          flows: Flows to transform output of base encoder
          decoder: Optional decoder
        """
        super().__init__()
        if q0 is None:
            # Built per instance: a module used as a default argument would
            # be created once at import time and shared by every model.
            q0 = distributions.Dirac()
        self.prior = prior
        self.decoder = decoder
        self.flows = nn.ModuleList(flows)
        self.q0 = q0

    def forward(self, x, num_samples=1):
        """Takes data batch, samples num_samples for each data point from base distribution

        Args:
          x: data batch
          num_samples: number of samples to draw for each data point

        Returns:
          latent variables for each batch and sample, log_q, and log_p
        """
        z, log_q = self.q0(x, num_samples=num_samples)
        # Flatten batch and sample dim
        z = z.view(-1, *z.size()[2:])
        log_q = log_q.view(-1, *log_q.size()[2:])
        for flow in self.flows:
            z, log_det = flow(z)
            log_q -= log_det
        log_p = self.prior.log_prob(z)
        if self.decoder is not None:
            log_p += self.decoder.log_prob(x, z)
        # Separate batch and sample dimension again
        z = z.view(-1, num_samples, *z.size()[1:])
        log_q = log_q.view(-1, num_samples, *log_q.size()[1:])
        log_p = log_p.view(-1, num_samples, *log_p.size()[1:])
        return z, log_q, log_p