From 2e3c8bac852694af68f42ff283b2d834aa19bdb9 Mon Sep 17 00:00:00 2001
From: Vincent Stimper
Date: Mon, 19 Dec 2022 11:08:06 +0100
Subject: [PATCH] Dev vs (#15)

* Added function to create periodic features by concatenating sin and cos of features
* Removed unused references
* Added flow test module, test for glow
* Fixed mixing warning
* Old test module replaced
* Test for coupling layers added
* Added test for planar flow and nsf wrapper
* Fixed multidim planar flow
* Fix radial multidim case
* Added tests for residual flow
* Added link to docu in readme
* Compatibility with most recent pytorch version
---
 README.md                                      |   3 +-
 normflows/flows/affine/autoregressive_test.py  |  57 +---------
 normflows/flows/affine/coupling.py             |  10 +-
 normflows/flows/affine/coupling_test.py        |  46 ++++++++
 normflows/flows/affine/glow.py                 |   6 +-
 normflows/flows/affine/glow_test.py            |  32 ++++++
 normflows/flows/base.py                        |   4 +
 normflows/flows/flow_test.py                   |  42 +++++++
 normflows/flows/mixing.py                      |   4 +-
 .../flows/neural_spline/autoregressive.py      |   4 +-
 .../neural_spline/autoregressive_test.py       |  15 ++-
 normflows/flows/neural_spline/coupling.py      |   1 -
 .../flows/neural_spline/coupling_test.py       | 107 +++--------------
 normflows/flows/neural_spline/flow_test.py     |  31 -----
 normflows/flows/neural_spline/wrapper.py       |   4 +-
 normflows/flows/neural_spline/wrapper_test.py  |  44 +++++++
 normflows/flows/planar.py                      |  45 +++-----
 normflows/flows/planar_test.py                 |  23 ++++
 normflows/flows/radial.py                      |   5 +-
 normflows/flows/radial_test.py                 |  19 ++++
 normflows/flows/residual.py                    |  31 ++++-
 normflows/flows/residual_test.py               |  51 +++++++++
 normflows/nets/cnn.py                          |   1 -
 normflows/nets/lipschitz.py                    |  10 +-
 normflows/nets/made_test.py                    |  38 +++----
 normflows/nets/mlp.py                          |   1 -
 normflows/utils/nn.py                          |  65 ++++++++++-
 normflows/utils/splines_test.py                |  22 ++--
 28 files changed, 449 insertions(+), 272 deletions(-)
 create mode 100644 normflows/flows/affine/coupling_test.py
 create mode 100644 normflows/flows/affine/glow_test.py
 create mode 100644 normflows/flows/flow_test.py
 delete mode 100644 normflows/flows/neural_spline/flow_test.py
 create mode 100644 normflows/flows/neural_spline/wrapper_test.py
 create mode 100644 normflows/flows/planar_test.py
 create mode 100644 normflows/flows/radial_test.py
 create mode 100644 normflows/flows/residual_test.py

diff --git a/README.md b/README.md
index f088373..1cd5dd6 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,8 @@
 This is a PyTorch implementation of normalizing flows. Many popular flow architectures are implemented,
 see the [list below](#implemented-flows). The package can be easily [installed via pip](#installation).
-The basic usage is described [here](#usage). There are several sample use cases implemented in the
+The basic usage is described [here](#usage), and [full documentation](https://vincentstimper.github.io/normalizing-flows/)
+is available as well.
There are several sample use cases implemented in the [`example` folder](https://github.com/VincentStimper/normalizing-flows/tree/master/example), including [Glow](https://github.com/VincentStimper/normalizing-flows/blob/master/example/glow.ipynb), a [VAE](https://github.com/VincentStimper/normalizing-flows/blob/master/example/vae.py), and diff --git a/normflows/flows/affine/autoregressive_test.py b/normflows/flows/affine/autoregressive_test.py index 78435c7..6c54238 100644 --- a/normflows/flows/affine/autoregressive_test.py +++ b/normflows/flows/affine/autoregressive_test.py @@ -2,12 +2,12 @@ import unittest from normflows.flows.affine import autoregressive -from normflows.flows.neural_spline.flow_test import FlowTest +from normflows.flows.flow_test import FlowTest class MaskedAffineAutoregressiveTest(FlowTest): - def test_forward(self): - batch_size = 10 + def test_maf(self): + batch_size = 3 features = 20 inputs = torch.randn(batch_size, features) for use_residual_blocks, random_mask in [ @@ -18,61 +18,14 @@ def test_forward(self): with self.subTest( use_residual_blocks=use_residual_blocks, random_mask=random_mask ): - transform = autoregressive.MaskedAffineAutoregressive( + flow = autoregressive.MaskedAffineAutoregressive( features=features, hidden_features=30, num_blocks=5, use_residual_blocks=use_residual_blocks, random_mask=random_mask, ) - outputs, logabsdet = transform(inputs) - self.assert_tensor_is_good(outputs, [batch_size, features]) - self.assert_tensor_is_good(logabsdet, [batch_size]) - - def test_inverse(self): - batch_size = 10 - features = 20 - inputs = torch.randn(batch_size, features) - for use_residual_blocks, random_mask in [ - (False, False), - (False, True), - (True, False), - ]: - with self.subTest( - use_residual_blocks=use_residual_blocks, random_mask=random_mask - ): - transform = autoregressive.MaskedAffineAutoregressive( - features=features, - hidden_features=30, - num_blocks=5, - use_residual_blocks=use_residual_blocks, - random_mask=random_mask, - ) - outputs, logabsdet = transform.inverse(inputs) - self.assert_tensor_is_good(outputs, [batch_size, features]) - self.assert_tensor_is_good(logabsdet, [batch_size]) - - def test_forward_inverse_are_consistent(self): - batch_size = 10 - features = 20 - inputs = torch.randn(batch_size, features) - self.eps = 1e-6 - for use_residual_blocks, random_mask in [ - (False, False), - (False, True), - (True, False), - ]: - with self.subTest( - use_residual_blocks=use_residual_blocks, random_mask=random_mask - ): - transform = autoregressive.MaskedAffineAutoregressive( - features=features, - hidden_features=30, - num_blocks=5, - use_residual_blocks=use_residual_blocks, - random_mask=random_mask, - ) - self.assert_forward_inverse_are_consistent(transform, inputs) + self.checkForwardInverse(flow, inputs) if __name__ == "__main__": diff --git a/normflows/flows/affine/coupling.py b/normflows/flows/affine/coupling.py index 5eb6700..b596a73 100644 --- a/normflows/flows/affine/coupling.py +++ b/normflows/flows/affine/coupling.py @@ -2,7 +2,7 @@ import torch from torch import nn -from ..base import Flow +from ..base import Flow, zero_log_det_like_z from ..reshape import Split, Merge @@ -140,8 +140,8 @@ def forward(self, z): else: raise NotImplementedError("This scale map is not implemented.") else: - z2 += param - log_det = 0 + z2 = z2 + param + log_det = zero_log_det_like_z(z2) return [z1, z2], log_det def inverse(self, z): @@ -164,8 +164,8 @@ def inverse(self, z): else: raise NotImplementedError("This scale map is not 
implemented.") else: - z2 -= param - log_det = 0 + z2 = z2 - param + log_det = zero_log_det_like_z(z2) return [z1, z2], log_det diff --git a/normflows/flows/affine/coupling_test.py b/normflows/flows/affine/coupling_test.py new file mode 100644 index 0000000..6584ac2 --- /dev/null +++ b/normflows/flows/affine/coupling_test.py @@ -0,0 +1,46 @@ +import unittest +import torch + +from torch.testing import assert_close +from normflows.flows import MaskedAffineFlow, CCAffineConst +from normflows.nets import MLP +from normflows.flows.flow_test import FlowTest + + +class CouplingTest(FlowTest): + def test_mask_affine(self): + batch_size = 5 + for latent_size in [2, 7]: + with self.subTest(latent_size=latent_size): + b = torch.Tensor([1 if i % 2 == 0 else 0 for i in range(latent_size)]) + s = MLP([latent_size, 2 * latent_size, latent_size], init_zeros=True) + t = MLP([latent_size, 2 * latent_size, latent_size], init_zeros=True) + flow = MaskedAffineFlow(b, t, s) + inputs = torch.randn((batch_size, latent_size)) + self.checkForwardInverse(flow, inputs) + + def test_cc_affine(self): + batch_size = 5 + for shape in [(5,), (2, 3, 4)]: + for num_classes in [2, 5]: + with self.subTest(shape=shape, num_classes=num_classes): + flow = CCAffineConst(shape, num_classes) + x = torch.randn((batch_size,) + shape) + y = torch.rand((batch_size,) + (num_classes,)) + x_, log_det = flow(x, y) + x__, log_det_ = flow(x_, y) + + assert x_.dtype == x.dtype + assert x__.dtype == x.dtype + + assert x_.shape == x.shape + assert x__.shape == x.shape + + assert_close(x__, x) + id_ld = log_det + log_det_ + assert_close(id_ld, torch.zeros_like(id_ld)) + + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/normflows/flows/affine/glow.py b/normflows/flows/affine/glow.py index 2e4e41f..3ea3f7d 100644 --- a/normflows/flows/affine/glow.py +++ b/normflows/flows/affine/glow.py @@ -49,11 +49,11 @@ def __init__( kernel_size = (3, 1, 3) num_param = 2 if scale else 1 if "channel" == split_mode: - channels_ = (channels // 2,) + 2 * (hidden_channels,) - channels_ += (num_param * ((channels + 1) // 2),) - elif "channel_inv" == split_mode: channels_ = ((channels + 1) // 2,) + 2 * (hidden_channels,) channels_ += (num_param * (channels // 2),) + elif "channel_inv" == split_mode: + channels_ = (channels // 2,) + 2 * (hidden_channels,) + channels_ += (num_param * ((channels + 1) // 2),) elif "checkerboard" in split_mode: channels_ = (channels,) + 2 * (hidden_channels,) channels_ += (num_param * channels,) diff --git a/normflows/flows/affine/glow_test.py b/normflows/flows/affine/glow_test.py new file mode 100644 index 0000000..79a809b --- /dev/null +++ b/normflows/flows/affine/glow_test.py @@ -0,0 +1,32 @@ +import unittest +import torch + +from normflows.flows import GlowBlock +from normflows.flows.flow_test import FlowTest + + +class GlowTest(FlowTest): + def test_glow(self): + img_size = (4, 4) + hidden_channels = 8 + for batch_size, channels, scale, split_mode, use_lu, net_actnorm in [ + (1, 3, True, "channel", True, False), + (2, 3, True, "channel_inv", True, False), + (1, 4, True, "channel_inv", True, True), + (2, 4, True, "channel", True, False), + (1, 4, False, "channel", False, False), + (1, 4, True, "checkerboard", True, True), + (3, 5, False, "checkerboard", False, True) + ]: + with self.subTest(batch_size=batch_size, channels=channels, + scale=scale, split_mode=split_mode, + use_lu=use_lu, net_actnorm=net_actnorm): + inputs = torch.rand((batch_size, channels) + img_size) + flow = GlowBlock(channels, 
hidden_channels, + scale=scale, split_mode=split_mode, + use_lu=use_lu, net_actnorm=net_actnorm) + self.checkForwardInverse(flow, inputs) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/normflows/flows/base.py b/normflows/flows/base.py index 775aa33..db84b50 100644 --- a/normflows/flows/base.py +++ b/normflows/flows/base.py @@ -76,3 +76,7 @@ def forward(self, inputs): def inverse(self, inputs): funcs = (flow.inverse for flow in self._flows[::-1]) return self._cascade(inputs, funcs) + + +def zero_log_det_like_z(z): + return torch.zeros(z.shape[0], dtype=z.dtype, device=z.device) \ No newline at end of file diff --git a/normflows/flows/flow_test.py b/normflows/flows/flow_test.py new file mode 100644 index 0000000..d6df59c --- /dev/null +++ b/normflows/flows/flow_test.py @@ -0,0 +1,42 @@ +import unittest +import torch + +from torch.testing import assert_close + + +class FlowTest(unittest.TestCase): + """ + Generic test case for flow modules + """ + def assertClose(self, actual, expected, atol=None, rtol=None): + assert_close(actual, expected, atol=atol, rtol=rtol) + + def checkForward(self, flow, inputs): + # Do forward transform + outputs, log_det = flow(inputs) + # Check type + assert outputs.dtype == inputs.dtype + # Check shape + assert outputs.shape == inputs.shape + # Return results + return outputs, log_det + + def checkInverse(self, flow, inputs): + # Do inverse transform + outputs, log_det = flow.inverse(inputs) + # Check type + assert outputs.dtype == inputs.dtype + # Check shape + assert outputs.shape == inputs.shape + # Return results + return outputs, log_det + + def checkForwardInverse(self, flow, inputs, atol=None, rtol=None): + # Check forward + outputs, log_det = self.checkForward(flow, inputs) + # Check inverse + input_, log_det_ = self.checkInverse(flow, outputs) + # Check identity + self.assertClose(input_, inputs, atol, rtol) + ld_id = log_det + log_det_ + self.assertClose(ld_id, torch.zeros_like(ld_id), atol, rtol) \ No newline at end of file diff --git a/normflows/flows/mixing.py b/normflows/flows/mixing.py index cf81067..60b23f8 100644 --- a/normflows/flows/mixing.py +++ b/normflows/flows/mixing.py @@ -70,7 +70,7 @@ def __init__(self, num_channels, use_lu=False): super().__init__() self.num_channels = num_channels self.use_lu = use_lu - Q = torch.qr(torch.randn(self.num_channels, self.num_channels))[0] + Q, _ = torch.linalg.qr(torch.randn(self.num_channels, self.num_channels)) if use_lu: P, L, U = torch.lu_unpack(*Q.lu()) self.register_buffer("P", P) # remains fixed during optimization @@ -149,7 +149,7 @@ def __init__(self, num_channels, use_lu=True): super().__init__() self.num_channels = num_channels self.use_lu = use_lu - Q = torch.qr(torch.randn(self.num_channels, self.num_channels))[0] + Q, _ = torch.linalg.qr(torch.randn(self.num_channels, self.num_channels)) if use_lu: P, L, U = torch.lu_unpack(*Q.lu()) self.register_buffer("P", P) # remains fixed during optimization diff --git a/normflows/flows/neural_spline/autoregressive.py b/normflows/flows/neural_spline/autoregressive.py index c81a011..1c9cc60 100644 --- a/normflows/flows/neural_spline/autoregressive.py +++ b/normflows/flows/neural_spline/autoregressive.py @@ -11,7 +11,7 @@ from ..affine.autoregressive import Autoregressive from normflows.nets import made as made_module from normflows.utils import splines -from normflows.utils.nn import PeriodicFeatures +from normflows.utils.nn import PeriodicFeaturesElementwise class 
MaskedPiecewiseRationalQuadraticAutoregressive(Autoregressive): @@ -50,7 +50,7 @@ def __init__( scale_pf = np.pi / tail_bound[ind_circ] else: scale_pf = np.pi / tail_bound - preprocessing = PeriodicFeatures(features, ind_circ, scale_pf) + preprocessing = PeriodicFeaturesElementwise(features, ind_circ, scale_pf) else: preprocessing = None diff --git a/normflows/flows/neural_spline/autoregressive_test.py b/normflows/flows/neural_spline/autoregressive_test.py index 4ceea38..2882cbf 100644 --- a/normflows/flows/neural_spline/autoregressive_test.py +++ b/normflows/flows/neural_spline/autoregressive_test.py @@ -1,23 +1,22 @@ """ Tests for the autoregressive transforms. -Code taken from https://github.com/bayesiains/nsf +Code partially taken from https://github.com/bayesiains/nsf """ import torch import unittest from normflows.flows.neural_spline import autoregressive -from normflows.flows.neural_spline.flow_test import FlowTest +from normflows.flows.flow_test import FlowTest class MaskedPiecewiseRationalQuadraticAutoregressiveFlowTest(FlowTest): - def test_forward_inverse_are_consistent(self): - batch_size = 10 - features = 20 + def test_mprqas(self): + batch_size = 5 + features = 10 inputs = torch.rand(batch_size, features) - self.eps = 1e-3 - transform = autoregressive.MaskedPiecewiseRationalQuadraticAutoregressive( + flow = autoregressive.MaskedPiecewiseRationalQuadraticAutoregressive( num_bins=10, features=features, hidden_features=30, @@ -25,7 +24,7 @@ def test_forward_inverse_are_consistent(self): use_residual_blocks=True, ) - self.assert_forward_inverse_are_consistent(transform, inputs) + self.checkForwardInverse(flow, inputs) if __name__ == "__main__": diff --git a/normflows/flows/neural_spline/coupling.py b/normflows/flows/neural_spline/coupling.py index e30750a..2ef30b9 100644 --- a/normflows/flows/neural_spline/coupling.py +++ b/normflows/flows/neural_spline/coupling.py @@ -7,7 +7,6 @@ import numpy as np import torch from torch import nn -from torch.nn import functional as F from ..base import Flow from ... import utils diff --git a/normflows/flows/neural_spline/coupling_test.py b/normflows/flows/neural_spline/coupling_test.py index e73a79e..cfc11be 100644 --- a/normflows/flows/neural_spline/coupling_test.py +++ b/normflows/flows/neural_spline/coupling_test.py @@ -1,22 +1,19 @@ """ Tests for the coupling Transforms. 
-Code taken from https://github.com/bayesiains/nsf +Code partially taken from https://github.com/bayesiains/nsf """ -import itertools import torch import unittest -from torch import nn - from normflows import nets as nn_ from normflows.flows.neural_spline import coupling -from normflows.flows.neural_spline.flow_test import FlowTest +from normflows.flows.flow_test import FlowTest from normflows import utils -def create_coupling_transform(cls, shape, **kwargs): +def create_coupling_transform(shape, **kwargs): if len(shape) == 1: def create_net(in_features, out_features): @@ -34,104 +31,32 @@ def create_net(in_channels, out_channels): mask = utils.masks.create_mid_split_binary_mask(shape[0]) - return cls(mask=mask, transform_net_create_fn=create_net, **kwargs), mask + return coupling.PiecewiseRationalQuadraticCoupling(mask=mask, transform_net_create_fn=create_net, **kwargs), mask -batch_size = 10 +batch_size = 5 class PiecewiseCouplingTransformTest(FlowTest): - classes = [coupling.PiecewiseRationalQuadraticCoupling] - shapes = [[20], [2, 4, 4]] - def test_forward(self): - for shape in self.shapes: - for cls in self.classes: - inputs = torch.rand(batch_size, *shape) - transform, mask = create_coupling_transform(cls, shape) - outputs, logabsdet = transform(inputs) - with self.subTest(cls=cls, shape=shape): - self.assert_tensor_is_good(outputs, [batch_size] + shape) - self.assert_tensor_is_good(logabsdet, [batch_size]) - self.assertEqual( - outputs[:, mask <= 0, ...], inputs[:, mask <= 0, ...] - ) - - def test_forward_unconstrained(self): - batch_size = 10 - for shape in self.shapes: - for cls in self.classes: - inputs = 3.0 * torch.randn(batch_size, *shape) - transform, mask = create_coupling_transform(cls, shape, tails="linear") - outputs, logabsdet = transform(inputs) - with self.subTest(cls=cls, shape=shape): - self.assert_tensor_is_good(outputs, [batch_size] + shape) - self.assert_tensor_is_good(logabsdet, [batch_size]) - self.assertEqual( - outputs[:, mask <= 0, ...], inputs[:, mask <= 0, ...] - ) - - def test_inverse(self): + def test_rqcs(self): for shape in self.shapes: - for cls in self.classes: - inputs = torch.rand(batch_size, *shape) - transform, mask = create_coupling_transform(cls, shape) - outputs, logabsdet = transform(inputs) - with self.subTest(cls=cls, shape=shape): - self.assert_tensor_is_good(outputs, [batch_size] + shape) - self.assert_tensor_is_good(logabsdet, [batch_size]) - self.assertEqual( - outputs[:, mask <= 0, ...], inputs[:, mask <= 0, ...] - ) - - def test_inverse_unconstrained(self): - for shape in self.shapes: - for cls in self.classes: - inputs = 3.0 * torch.randn(batch_size, *shape) - transform, mask = create_coupling_transform(cls, shape, tails="linear") - outputs, logabsdet = transform(inputs) - with self.subTest(cls=cls, shape=shape): - self.assert_tensor_is_good(outputs, [batch_size] + shape) - self.assert_tensor_is_good(logabsdet, [batch_size]) - self.assertEqual( - outputs[:, mask <= 0, ...], inputs[:, mask <= 0, ...] 
- ) - - def test_forward_inverse_are_consistent(self): - for shape in self.shapes: - for cls in self.classes: - inputs = torch.rand(batch_size, *shape) - transform, mask = create_coupling_transform(cls, shape) - with self.subTest(cls=cls, shape=shape): - self.eps = 1e-4 - self.assert_forward_inverse_are_consistent(transform, inputs) + for tails in [None, "linear"]: + with self.subTest(shape=shape, tails=tails): + inputs = torch.rand(batch_size, *shape) + flow, _ = create_coupling_transform(shape, tails=tails) + self.checkForwardInverse(flow, inputs) - def test_forward_inverse_are_consistent_unconstrained(self): - self.eps = 1e-5 - for shape in self.shapes: - for cls in self.classes: - inputs = 3.0 * torch.randn(batch_size, *shape) - transform, mask = create_coupling_transform(cls, shape, tails="linear") - with self.subTest(cls=cls, shape=shape): - self.eps = 1e-4 - self.assert_forward_inverse_are_consistent(transform, inputs) - - def test_forward_unconditional(self): + def test_rqcs_unconditional(self): for shape in self.shapes: - for cls in self.classes: + with self.subTest(shape=shape): inputs = torch.rand(batch_size, *shape) img_shape = shape[1:] if len(shape) > 1 else None - transform, mask = create_coupling_transform( - cls, shape, apply_unconditional_transform=True, img_shape=img_shape + flow, _ = create_coupling_transform( + shape, apply_unconditional_transform=True, img_shape=img_shape ) - outputs, logabsdet = transform(inputs) - with self.subTest(cls=cls, shape=shape): - self.assert_tensor_is_good(outputs, [batch_size] + shape) - self.assert_tensor_is_good(logabsdet, [batch_size]) - self.assertNotEqual( - outputs[:, mask <= 0, ...], inputs[:, mask <= 0, ...] - ) + self.checkForwardInverse(flow, inputs) if __name__ == "__main__": diff --git a/normflows/flows/neural_spline/flow_test.py b/normflows/flows/neural_spline/flow_test.py deleted file mode 100644 index 2163901..0000000 --- a/normflows/flows/neural_spline/flow_test.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -import torchtestcase - -from normflows import flows - - -class FlowTest(torchtestcase.TorchTestCase): - """Base test for NSF flows.""" - - def assert_tensor_is_good(self, tensor, shape=None): - self.assertIsInstance(tensor, torch.Tensor) - self.assertFalse(torch.isnan(tensor).any()) - self.assertFalse(torch.isinf(tensor).any()) - if shape is not None: - self.assertEqual(tensor.shape, torch.Size(shape)) - - def assert_forward_inverse_are_consistent(self, transform, inputs): - inverse = flows.Reverse(transform) - identity = flows.Composite([inverse, transform]) - outputs, logabsdet = identity(inputs) - - self.assert_tensor_is_good(outputs, shape=inputs.shape) - self.assert_tensor_is_good(logabsdet, shape=inputs.shape[:1]) - self.assertEqual(outputs, inputs) - self.assertEqual(logabsdet, torch.zeros(inputs.shape[:1])) - - def assertNotEqual(self, first, second, msg=None): - if (self._eps and (first - second).abs().max().item() < self._eps) or ( - not self._eps and torch.equal(first, second) - ): - self._fail_with_message(msg, "The tensors are _not_ different!") diff --git a/normflows/flows/neural_spline/wrapper.py b/normflows/flows/neural_spline/wrapper.py index 5c6d845..fe39886 100644 --- a/normflows/flows/neural_spline/wrapper.py +++ b/normflows/flows/neural_spline/wrapper.py @@ -7,7 +7,7 @@ from .autoregressive import MaskedPiecewiseRationalQuadraticAutoregressive from ...nets.resnet import ResidualNet from ...utils.masks import create_alternating_binary_mask -from ...utils.nn import PeriodicFeatures +from 
...utils.nn import PeriodicFeaturesElementwise from ...utils.splines import DEFAULT_MIN_DERIVATIVE @@ -128,7 +128,7 @@ def __init__( def transform_net_create_fn(in_features, out_features): if len(ind_circ_id) > 0: - pf = PeriodicFeatures(in_features, ind_circ_id, scale_pf) + pf = PeriodicFeaturesElementwise(in_features, ind_circ_id, scale_pf) else: pf = None net = ResidualNet( diff --git a/normflows/flows/neural_spline/wrapper_test.py b/normflows/flows/neural_spline/wrapper_test.py new file mode 100644 index 0000000..274693f --- /dev/null +++ b/normflows/flows/neural_spline/wrapper_test.py @@ -0,0 +1,44 @@ +import unittest +import torch +import numpy as np + +from normflows.flows import CoupledRationalQuadraticSpline, \ + AutoregressiveRationalQuadraticSpline, \ + CircularCoupledRationalQuadraticSpline, \ + CircularAutoregressiveRationalQuadraticSpline +from normflows.flows.flow_test import FlowTest + + +class NsfWrapperTest(FlowTest): + def test_normal_nsf(self): + batch_size = 3 + hidden_units = 128 + hidden_layers = 2 + for latent_size in [2, 5]: + for flow_cls in [CoupledRationalQuadraticSpline, + AutoregressiveRationalQuadraticSpline]: + with self.subTest(latent_size=latent_size, flow_cls=flow_cls): + flow = flow_cls(latent_size, hidden_units, hidden_layers) + inputs = torch.randn((batch_size, latent_size)) + self.checkForwardInverse(flow, inputs) + + def test_circular_nsf(self): + batch_size = 3 + hidden_units = 128 + hidden_layers = 2 + params = [(2, [1], torch.tensor([5., np.pi])), + (5, [0, 3], torch.tensor([np.pi, 5., 4., 6., 3.])), + (2, [1], torch.tensor([5., np.pi]))] + for latent_size, ind_circ, tail_bound in params: + for flow_cls in [CircularCoupledRationalQuadraticSpline, + CircularAutoregressiveRationalQuadraticSpline]: + with self.subTest(latent_size=latent_size, ind_circ=ind_circ, + tail_bound=tail_bound, flow_cls=flow_cls): + flow = flow_cls(latent_size, hidden_units, hidden_layers, + ind_circ, tail_bound=tail_bound) + inputs = 6 * torch.rand((batch_size, latent_size)) - 3 + self.checkForwardInverse(flow, inputs) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/normflows/flows/planar.py b/normflows/flows/planar.py index 884fc54..aac7ccb 100644 --- a/normflows/flows/planar.py +++ b/normflows/flows/planar.py @@ -49,44 +49,33 @@ def __init__(self, shape, act="tanh", u=None, w=None, b=None): raise NotImplementedError("Nonlinearity is not implemented.") def forward(self, z): - lin = torch.sum(self.w * z, list(range(1, self.w.dim()))) + self.b + lin = torch.sum(self.w * z, list(range(1, self.w.dim())), + keepdim=True) + self.b + inner = torch.sum(self.w * self.u) + u = self.u + (torch.log(1 + torch.exp(inner)) - 1 - inner) \ + * self.w / torch.sum(self.w ** 2) # constraint w.T * u > -1 if self.act == "tanh": - inner = torch.sum(self.w * self.u) - u = self.u + ( - torch.log(1 + torch.exp(inner)) - 1 - inner - ) * self.w / torch.sum(self.w**2) h_ = lambda x: 1 / torch.cosh(x) ** 2 elif self.act == "leaky_relu": - inner = torch.sum(self.w * self.u) - u = self.u + ( - torch.log(1 + torch.exp(inner)) - 1 - inner - ) * self.w / torch.sum( - self.w**2 - ) # constraint w.T * u neq -1, use > h_ = lambda x: (x < 0) * (self.h.negative_slope - 1.0) + 1.0 - z_ = z + u * self.h(lin.unsqueeze(1)) - log_det = torch.log(torch.abs(1 + torch.sum(self.w * u) * h_(lin))) + z_ = z + u * self.h(lin) + log_det = torch.log(torch.abs(1 + torch.sum(self.w * u) * h_(lin.reshape(-1)))) return z_, log_det def inverse(self, z): if self.act != "leaky_relu": raise 
NotImplementedError("This flow has no algebraic inverse.")
-        lin = torch.sum(self.w * z, list(range(2, self.w.dim())), keepdim=True) + self.b
-        inner = torch.sum(self.w * self.u)
-        a = ((lin + self.b) / (1 + inner) < 0) * (
+        lin = torch.sum(self.w * z, list(range(1, self.w.dim()))) + self.b
+        a = (lin < 0) * (
             self.h.negative_slope - 1.0
         ) + 1.0  # absorb leakyReLU slope into u
-        u = a * (
-            self.u
-            + (torch.log(1 + torch.exp(inner)) - 1 - inner)
-            * self.w
-            / torch.sum(self.w**2)
-        )
-        z_ = z - 1 / (1 + inner) * (lin + u * self.b)
-        log_det = -torch.log(torch.abs(1 + torch.sum(self.w * u)))
-        if log_det.dim() == 0:
-            log_det = log_det.unsqueeze(0)
-        if log_det.dim() == 1:
-            log_det = log_det.unsqueeze(1)
+        inner = torch.sum(self.w * self.u)
+        u = self.u + (torch.log(1 + torch.exp(inner)) - 1 - inner) \
+            * self.w / torch.sum(self.w ** 2)
+        dims = [-1] + (u.dim() - 1) * [1]
+        u = a.reshape(*dims) * u
+        inner_ = torch.sum(self.w * u, list(range(1, self.w.dim())))
+        z_ = z - u * (lin / (1 + inner_)).reshape(*dims)
+        log_det = -torch.log(torch.abs(1 + inner_))
         return z_, log_det
diff --git a/normflows/flows/planar_test.py b/normflows/flows/planar_test.py
new file mode 100644
index 0000000..05dbdc5
--- /dev/null
+++ b/normflows/flows/planar_test.py
@@ -0,0 +1,23 @@
+import unittest
+import torch
+
+from normflows.flows import Planar
+from normflows.flows.flow_test import FlowTest
+
+
+class PlanarTest(FlowTest):
+    def test_planar(self):
+        batch_size = 3
+        for latent_size in [(2,), (5,), (2, 3, 4)]:
+            for act in ["tanh", "leaky_relu"]:
+                with self.subTest(latent_size=latent_size, act=act):
+                    flow = Planar(latent_size, act=act)
+                    inputs = torch.randn((batch_size, *latent_size))
+                    if act == "leaky_relu":
+                        self.checkForwardInverse(flow, inputs)
+                    else:
+                        self.checkForward(flow, inputs)
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
diff --git a/normflows/flows/radial.py b/normflows/flows/radial.py
index ee59885..49cfdf3 100644
--- a/normflows/flows/radial.py
+++ b/normflows/flows/radial.py
@@ -37,9 +37,10 @@ def __init__(self, shape, z_0=None):
     def forward(self, z):
         beta = torch.log(1 + torch.exp(self.beta)) - torch.abs(self.alpha)
         dz = z - self.z_0
-        r = torch.norm(dz, dim=list(range(1, self.z_0.dim())))
+        r = torch.linalg.vector_norm(dz, dim=list(range(1, self.z_0.dim())), keepdim=True)
         h_arr = beta / (torch.abs(self.alpha) + r)
         h_arr_ = -beta * r / (torch.abs(self.alpha) + r) ** 2
-        z_ = z + h_arr.unsqueeze(1) * dz
+        z_ = z + h_arr * dz
         log_det = (self.d - 1) * torch.log(1 + h_arr) + torch.log(1 + h_arr + h_arr_)
+        log_det = log_det.reshape(-1)
         return z_, log_det
diff --git a/normflows/flows/radial_test.py b/normflows/flows/radial_test.py
new file mode 100644
index 0000000..c2c9936
--- /dev/null
+++ b/normflows/flows/radial_test.py
@@ -0,0 +1,19 @@
+import unittest
+import torch
+
+from normflows.flows import Radial
+from normflows.flows.flow_test import FlowTest
+
+
+class RadialTest(FlowTest):
+    def test_radial(self):
+        batch_size = 3
+        for latent_size in [(2,), (5, 2), (2, 3, 4)]:
+            with self.subTest(latent_size=latent_size):
+                flow = Radial(latent_size)
+                inputs = torch.randn((batch_size, *latent_size))
+                self.checkForward(flow, inputs)
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
diff --git a/normflows/flows/residual.py b/normflows/flows/residual.py
index 1ebcd99..128d7bb 100644
--- a/normflows/flows/residual.py
+++ b/normflows/flows/residual.py
@@ -16,16 +16,33 @@ class Residual(Flow):
     """

     def __init__(
-        self,
net, n_exact_terms=2, n_samples=1, reduce_memory=True, reverse=True
+        self,
+        net,
+        reverse=True,
+        reduce_memory=True,
+        geom_p=0.5,
+        lamb=2.0,
+        n_power_series=None,
+        exact_trace=False,
+        brute_force=False,
+        n_samples=1,
+        n_exact_terms=2,
+        n_dist="geometric"
     ):
         """Constructor

         Args:
           net: Neural network, must be Lipschitz continuous with L < 1
-          n_exact_terms: Number of terms always included in the power series
-          n_samples: Number of samples used to estimate power series
-          reduce_memory: Flag, if true Neumann series and precomputations, for backward pass in forward pass are done
           reverse: Flag, if true the map ```f(x) = x + net(x)``` is applied in the inverse pass, otherwise it is done in forward
+          reduce_memory: Flag, if true, the Neumann series and the precomputations for the backward pass are done during the forward pass
+          geom_p: Parameter of the geometric distribution used for the Neumann series
+          lamb: Parameter of the Poisson distribution used for the Neumann series
+          n_power_series: Number of terms in the Neumann series; if None, the number of terms is sampled
+          exact_trace: Flag, if true the trace of the Jacobian is computed exactly
+          brute_force: Flag, if true the Jacobian is computed exactly for 2D inputs
+          n_samples: Number of samples used to estimate the power series
+          n_exact_terms: Number of terms always included in the power series
+          n_dist: Distribution used for the power series, either "geometric" or "poisson"
         """
         super().__init__()
         self.reverse = reverse
@@ -35,6 +52,12 @@ def __init__(
             n_exact_terms=n_exact_terms,
             neumann_grad=reduce_memory,
             grad_in_forward=reduce_memory,
+            exact_trace=exact_trace,
+            geom_p=geom_p,
+            lamb=lamb,
+            n_power_series=n_power_series,
+            brute_force=brute_force,
+            n_dist=n_dist,
         )

     def forward(self, z):
diff --git a/normflows/flows/residual_test.py b/normflows/flows/residual_test.py
new file mode 100644
index 0000000..9ecd829
--- /dev/null
+++ b/normflows/flows/residual_test.py
@@ -0,0 +1,51 @@
+import unittest
+import torch
+
+from normflows.flows import Residual
+from normflows.nets import LipschitzMLP, LipschitzCNN
+from normflows.flows.flow_test import FlowTest
+
+
+class ResidualTest(FlowTest):
+    def test_residual_mlp(self):
+        batch_size = 3
+        hidden_units = 128
+        hidden_layers = 2
+        for latent_size in [2, 5]:
+            for reduce_memory in [True, False]:
+                for exact_trace in [True, False]:
+                    with self.subTest(latent_size=latent_size, reduce_memory=reduce_memory,
+                                      exact_trace=exact_trace):
+                        layer = [latent_size] + [hidden_units] * hidden_layers + [latent_size]
+                        net = LipschitzMLP(layer, init_zeros=exact_trace, lipschitz_const=0.9)
+                        flow = Residual(net, reduce_memory=reduce_memory,
+                                        exact_trace=exact_trace)
+                        inputs = torch.randn((batch_size, latent_size))
+                        if exact_trace:
+                            self.checkForwardInverse(flow, inputs, atol=1e-4, rtol=1e-4)
+                        else:
+                            outputs, _ = self.checkForward(flow, inputs)
+                            inputs_, _ = self.checkInverse(flow, outputs)
+                            self.assertClose(inputs, inputs_, atol=1e-4, rtol=1e-4)
+
+    def test_residual_cnn(self):
+        batch_size = 1
+        hidden_units = 128
+        kernel_size = 1
+        img_size = (4, 4)
+        for latent_size in [3, 4]:
+            for reduce_memory in [True, False]:
+                with self.subTest(latent_size=latent_size, reduce_memory=reduce_memory):
+                    channels = [latent_size, hidden_units, latent_size]
+                    net = LipschitzCNN(channels, 2 * [kernel_size])
+                    flow = Residual(net, reduce_memory=reduce_memory)
+                    inputs = torch.randn((batch_size, latent_size, *img_size))
+                    outputs, _ = self.checkForward(flow, inputs)
+                    inputs_, _ = self.checkInverse(flow, outputs)
+                    self.assertClose(inputs, inputs_, atol=1e-4, rtol=1e-4)
+
+
+if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/normflows/nets/cnn.py b/normflows/nets/cnn.py index 93b470e..71e99b2 100644 --- a/normflows/nets/cnn.py +++ b/normflows/nets/cnn.py @@ -1,4 +1,3 @@ -import torch from torch import nn from .. import utils diff --git a/normflows/nets/lipschitz.py b/normflows/nets/lipschitz.py index fe15d22..edb5130 100644 --- a/normflows/nets/lipschitz.py +++ b/normflows/nets/lipschitz.py @@ -119,7 +119,7 @@ def __init__( n_iterations=max_lipschitz_iter, atol=lipschitz_tolerance, rtol=lipschitz_tolerance, - zero_init=init_zeros if i == self.n_layers - 1 else False, + zero_init=init_zeros if i == (self.n_layers - 1) else False, ), ] @@ -307,6 +307,7 @@ def __init__( n_iterations=None, atol=None, rtol=None, + zero_init=False, **unused_kwargs ): del unused_kwargs @@ -329,7 +330,7 @@ def __init__( self.bias = nn.Parameter(torch.Tensor(out_channels)) else: self.register_parameter("bias", None) - self.reset_parameters() + self.reset_parameters(zero_init) self.register_buffer("initialized", torch.tensor(0)) self.register_buffer("spatial_dims", torch.tensor([1.0, 1.0])) self.register_buffer("scale", torch.tensor(0.0)) @@ -344,8 +345,11 @@ def compute_domain_codomain(self): domain, codomain = self.domain, self.codomain return domain, codomain - def reset_parameters(self): + def reset_parameters(self, zero_init=False): init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if zero_init: + # normalize cannot handle zero weight in some cases. + self.weight.data.div_(1000) if self.bias is not None: fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) diff --git a/normflows/nets/made_test.py b/normflows/nets/made_test.py index 6017f4c..e949678 100644 --- a/normflows/nets/made_test.py +++ b/normflows/nets/made_test.py @@ -1,22 +1,22 @@ """ Tests for MADE. 
-Code taken from https://github.com/bayesiains/nsf +Code partially taken from https://github.com/bayesiains/nsf """ import torch -import torchtestcase import unittest from normflows.nets import made +from torch.testing import assert_close -class ShapeTest(torchtestcase.TorchTestCase): +class ShapeTest(unittest.TestCase): def test_unconditional(self): - features = 100 - hidden_features = 200 - num_blocks = 5 + features = 10 + hidden_features = 20 + num_blocks = 3 output_multiplier = 3 - batch_size = 16 + batch_size = 4 inputs = torch.randn(batch_size, features) @@ -37,16 +37,16 @@ def test_unconditional(self): random_mask=random_mask, ) outputs = model(inputs) - self.assertEqual(outputs.dim(), 2) - self.assertEqual(outputs.shape[0], batch_size) - self.assertEqual(outputs.shape[1], output_multiplier * features) + assert outputs.dim() == 2 + assert outputs.shape[0] == batch_size + assert outputs.shape[1] == output_multiplier * features -class ConnectivityTest(torchtestcase.TorchTestCase): +class ConnectivityTest(unittest.TestCase): def test_gradients(self): features = 10 - hidden_features = 256 - num_blocks = 20 + hidden_features = 32 + num_blocks = 5 output_multiplier = 3 for use_residual_blocks, random_mask in [ @@ -72,11 +72,11 @@ def test_gradients(self): outputs[0, k].backward() depends = inputs.grad.data[0] != 0.0 dim = k // output_multiplier - self.assertEqual(torch.all(depends[dim:] == 0), 1) + assert torch.all(depends[dim:] == 0) == 1 def test_total_mask_sequential(self): features = 10 - hidden_features = 50 + hidden_features = 32 num_blocks = 5 output_multiplier = 1 @@ -102,11 +102,11 @@ def test_total_mask_sequential(self): total_mask = model.final_layer.mask @ total_mask total_mask = (total_mask > 0).float() reference = torch.tril(torch.ones([features, features]), -1) - self.assertEqual(total_mask, reference) + assert_close(total_mask, reference) def test_total_mask_random(self): features = 10 - hidden_features = 50 + hidden_features = 32 num_blocks = 5 output_multiplier = 1 @@ -120,11 +120,11 @@ def test_total_mask_random(self): ) total_mask = model.initial_layer.mask for block in model.blocks: - self.assertIsInstance(block, made.MaskedFeedforwardBlock) + assert isinstance(block, made.MaskedFeedforwardBlock) total_mask = block.linear.mask @ total_mask total_mask = model.final_layer.mask @ total_mask total_mask = (total_mask > 0).float() - self.assertEqual(torch.triu(total_mask), torch.zeros([features, features])) + assert_close(torch.triu(total_mask), torch.zeros([features, features])) if __name__ == "__main__": diff --git a/normflows/nets/mlp.py b/normflows/nets/mlp.py index f5ad18c..c6446e0 100644 --- a/normflows/nets/mlp.py +++ b/normflows/nets/mlp.py @@ -1,4 +1,3 @@ -import torch from torch import nn from .. 
import utils
diff --git a/normflows/utils/nn.py b/normflows/utils/nn.py
index 4a33d0d..65c2dec 100644
--- a/normflows/utils/nn.py
+++ b/normflows/utils/nn.py
@@ -27,16 +27,16 @@ class ActNorm(nn.Module):
     """
    ActNorm layer with just one forward pass
     """
-
-    def __init__(self, shape, logscale_factor=None):
+    def __init__(self, shape):
         """Constructor

         Args:
           shape: Same as shape in flows.ActNorm
-          logscale_factor: Same as shape in flows.ActNorm
         """
         super().__init__()
-        self.actNorm = flows.ActNorm(shape, logscale_factor=logscale_factor)
+        self.actNorm = flows.ActNorm(shape)

     def forward(self, input):
         out, _ = self.actNorm(input)
@@ -61,9 +61,14 @@ def forward(self, x):
         return torch.min(torch.exp(x), one)


-class PeriodicFeatures(nn.Module):
+class PeriodicFeaturesElementwise(nn.Module):
     """
-    Converts a specified part of the input to periodic features
+    Converts a specified part of the input to periodic features by
+    replacing those features f with
+    w1 * sin(scale * f) + w2 * cos(scale * f).
+
+    Note that this operation is done elementwise and, therefore,
+    some information about the feature can be lost.
     """

     def __init__(self, ndim, ind, scale=1.0, bias=False, activation=None):
@@ -76,7 +81,7 @@ def __init__(self, ndim, ind, scale=1.0, bias=False, activation=None):
           bias: Flag, whether to add a bias
           activation: Function or None, activation function to be applied
         """
-        super(PeriodicFeatures, self).__init__()
+        super(PeriodicFeaturesElementwise, self).__init__()

         # Set up indices and permutations
         self.ndim = ndim
@@ -125,6 +130,54 @@ def forward(self, inputs):
         return out[..., self.inv_perm]


+class PeriodicFeaturesCat(nn.Module):
+    """
+    Converts a specified part of the input to periodic features by
+    replacing those features f with [sin(scale * f), cos(scale * f)].
+
+    Note that this increases the number of features and their order
+    is changed.
+    """
+
+    def __init__(self, ndim, ind, scale=1.0):
+        """Constructor
+
+        Args:
+          ndim: Int, number of dimensions
+          ind: Iterable, indices of input elements to convert to
+            periodic features
+          scale: Scalar or iterable, used to scale inputs before
+            converting them to periodic features
+        """
+        super(PeriodicFeaturesCat, self).__init__()
+
+        # Set up indices and permutations
+        self.ndim = ndim
+        if torch.is_tensor(ind):
+            self.register_buffer("ind", ind.long())
+        else:
+            self.register_buffer("ind", torch.tensor(ind, dtype=torch.long))
+
+        ind_ = []
+        for i in range(self.ndim):
+            if i not in self.ind:
+                ind_ += [i]
+        self.register_buffer("ind_", torch.tensor(ind_, dtype=torch.long))
+
+        if torch.is_tensor(scale):
+            self.register_buffer("scale", scale)
+        else:
+            self.scale = scale
+
+    def forward(self, inputs):
+        inputs_ = inputs[..., self.ind]
+        inputs_ = self.scale * inputs_
+        inputs_sin = torch.sin(inputs_)
+        inputs_cos = torch.cos(inputs_)
+        out = torch.cat((inputs_sin, inputs_cos,
+                         inputs[..., self.ind_]), -1)
+        return out
+
+
 def tile(x, n):
     x_ = x.reshape(-1)
     x_ = x_.repeat(n)
diff --git a/normflows/utils/splines_test.py b/normflows/utils/splines_test.py
index 4b46535..64f023b 100644
--- a/normflows/utils/splines_test.py
+++ b/normflows/utils/splines_test.py
@@ -1,10 +1,12 @@
+import unittest
+
 import torch
-import torchtestcase

 from normflows.utils import splines
+from torch.testing import assert_close


-class RationalQuadraticSplineTest(torchtestcase.TorchTestCase):
+class RationalQuadraticSplineTest(unittest.TestCase):
     def test_forward_inverse_are_consistent(self):
         num_bins = 10
         shape = [2, 3, 4]
@@ -26,12 +28,12 @@ def call_spline_fn(inputs, inverse=False):
         outputs, logabsdet = call_spline_fn(inputs, inverse=False)
         inputs_inv, logabsdet_inv = call_spline_fn(outputs, inverse=True)

-        self.eps = 1e-4
-        self.assertEqual(inputs, inputs_inv)
-        self.assertEqual(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet))
+        assert_close(inputs, inputs_inv, atol=3e-5, rtol=3e-5)
+        assert_close(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet),
+                     atol=2e-4, rtol=2e-4)


-class UnconstrainedRationalQuadraticSplineTest(torchtestcase.TorchTestCase):
+class UnconstrainedRationalQuadraticSplineTest(unittest.TestCase):
     def test_forward_inverse_are_consistent(self):
         num_bins = 10
         shape = [2, 3, 4]
@@ -49,10 +51,10 @@ def call_spline_fn(inputs, inverse=False):
                 inverse=inverse,
             )

-        inputs = 3 * torch.randn(*shape)  # Note inputs are outside [0,1].
+        inputs = torch.randn(*shape)  # Note inputs are outside [0,1].
         outputs, logabsdet = call_spline_fn(inputs, inverse=False)
         inputs_inv, logabsdet_inv = call_spline_fn(outputs, inverse=True)

-        self.eps = 1e-4
-        self.assertEqual(inputs, inputs_inv)
-        self.assertEqual(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet))
+        assert_close(inputs, inputs_inv, atol=3e-5, rtol=3e-5)
+        assert_close(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet),
+                     atol=2e-4, rtol=2e-4)
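For illustration, a minimal usage sketch of the PeriodicFeaturesCat module introduced by this patch; it is not part of the patch itself, assumes a normflows installation that includes this branch, and all names come from the diff above:

    import torch
    from normflows.utils.nn import PeriodicFeaturesCat

    # Replace feature 0 of a 3-dimensional input with its sin/cos pair;
    # features 1 and 2 are passed through unchanged.
    pf = PeriodicFeaturesCat(ndim=3, ind=[0], scale=1.0)
    x = torch.randn(5, 3)
    y = pf(x)
    # The output is ordered [sin(x0), cos(x0), x1, x2], so the feature
    # count grows from ndim to ndim + len(ind).
    print(y.shape)  # torch.Size([5, 4])

Flows built on such feature maps can be validated with the new FlowTest.checkForwardInverse from normflows/flows/flow_test.py, which checks dtype and shape and asserts that the inverse pass recovers the inputs with a vanishing total log-determinant.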