Skip to content

Commit

Permalink
[Issue #36] Added forward and inverse method with log determinant to …
Browse files Browse the repository at this point in the history
…MultiscaleFlow
  • Loading branch information
VincentStimper committed Jun 2, 2023
1 parent 705edcc commit ef73663
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
49 changes: 49 additions & 0 deletions normflows/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,55 @@ def forward(self, x, y=None):
"""
return -self.log_prob(x, y)

def forward_and_log_det(self, z):
"""Get observed variable x from list of latent variables z
Args:
z: List of latent variables
Returns:
Observed variable x, log determinant of Jacobian
"""
log_det = torch.zeros(len(z[0]), dtype=z[0].dtype, device=z[0].device)
for i in range(len(self.q0)):
if i == 0:
z_ = z[0]
else:
z_, log_det_ = self.merges[i - 1]([z_, z[i]])
log_det += log_det_
for flow in self.flows[i]:
z_, log_det_ = flow(z_)
log_det += log_det_
if self.transform is not None:
z_, log_det_ = self.transform(z_)
log_det += log_det_
return z_, log_det

def inverse_and_log_det(self, x):
"""Get latent variable z from observed variable x
Args:
x: Observed variable
Returns:
List of latent variables z, log determinant of Jacobian
"""
log_det = torch.zeros(len(x), dtype=x.dtype, device=x.device)
if self.transform is not None:
x, log_det_ = self.transform.inverse(x)
log_det += log_det_
z = [None] * len(self.q0)
for i in range(len(self.q0) - 1, -1, -1):
for flow in reversed(self.flows[i]):
x, log_det_ = flow.inverse(x)
log_det += log_det_
if i == 0:
z[i] = x
else:
[x, z[i]], log_det_ = self.merges[i - 1].inverse(x)
log_det += log_det_
return z, log_det

def sample(self, num_samples=1, y=None, temperature=None):
"""Samples from flow-based approximate distribution
Expand Down
6 changes: 6 additions & 0 deletions normflows/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ def test_multiscale_flow(self):
fwd = model.forward(x, y)
fwd_kld = model.forward_kld(x, y)
assert_close(torch.mean(fwd), fwd_kld)
z, log_det = model.inverse_and_log_det(x)
x_, log_det_ = model.forward_and_log_det(z)
assert len(z) == L
assert x_.shape == (batch_size,) + (input_shape)
assert_close(x_, x)
assert_close(log_det, -log_det_)


def test_normalizing_flow_vae(self):
Expand Down

0 comments on commit ef73663

Please sign in to comment.