-
-
Notifications
You must be signed in to change notification settings - Fork 99
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added function to create periodic features by concatenating sin and cos of features * Removed unused referenced * Added flow test module, test for glow * Fixed mixing warning * Old test module replaced * Test for coupling layers added * Added test for planar flow and nsf wrapper * Fixed multidim planar flow * Fix radial multidim case * Added tests for residual flow * Added link to docu in readme * Compatibility with most recent pytorch version
- Loading branch information
1 parent
bbcd6c1
commit 2e3c8ba
Showing
28 changed files
with
449 additions
and
272 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import unittest | ||
import torch | ||
|
||
from torch.testing import assert_close | ||
from normflows.flows import MaskedAffineFlow, CCAffineConst | ||
from normflows.nets import MLP | ||
from normflows.flows.flow_test import FlowTest | ||
|
||
|
||
class CouplingTest(FlowTest): | ||
def test_mask_affine(self): | ||
batch_size = 5 | ||
for latent_size in [2, 7]: | ||
with self.subTest(latent_size=latent_size): | ||
b = torch.Tensor([1 if i % 2 == 0 else 0 for i in range(latent_size)]) | ||
s = MLP([latent_size, 2 * latent_size, latent_size], init_zeros=True) | ||
t = MLP([latent_size, 2 * latent_size, latent_size], init_zeros=True) | ||
flow = MaskedAffineFlow(b, t, s) | ||
inputs = torch.randn((batch_size, latent_size)) | ||
self.checkForwardInverse(flow, inputs) | ||
|
||
def test_cc_affine(self): | ||
batch_size = 5 | ||
for shape in [(5,), (2, 3, 4)]: | ||
for num_classes in [2, 5]: | ||
with self.subTest(shape=shape, num_classes=num_classes): | ||
flow = CCAffineConst(shape, num_classes) | ||
x = torch.randn((batch_size,) + shape) | ||
y = torch.rand((batch_size,) + (num_classes,)) | ||
x_, log_det = flow(x, y) | ||
x__, log_det_ = flow(x_, y) | ||
|
||
assert x_.dtype == x.dtype | ||
assert x__.dtype == x.dtype | ||
|
||
assert x_.shape == x.shape | ||
assert x__.shape == x.shape | ||
|
||
assert_close(x__, x) | ||
id_ld = log_det + log_det_ | ||
assert_close(id_ld, torch.zeros_like(id_ld)) | ||
|
||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import unittest | ||
import torch | ||
|
||
from normflows.flows import GlowBlock | ||
from normflows.flows.flow_test import FlowTest | ||
|
||
|
||
class GlowTest(FlowTest): | ||
def test_glow(self): | ||
img_size = (4, 4) | ||
hidden_channels = 8 | ||
for batch_size, channels, scale, split_mode, use_lu, net_actnorm in [ | ||
(1, 3, True, "channel", True, False), | ||
(2, 3, True, "channel_inv", True, False), | ||
(1, 4, True, "channel_inv", True, True), | ||
(2, 4, True, "channel", True, False), | ||
(1, 4, False, "channel", False, False), | ||
(1, 4, True, "checkerboard", True, True), | ||
(3, 5, False, "checkerboard", False, True) | ||
]: | ||
with self.subTest(batch_size=batch_size, channels=channels, | ||
scale=scale, split_mode=split_mode, | ||
use_lu=use_lu, net_actnorm=net_actnorm): | ||
inputs = torch.rand((batch_size, channels) + img_size) | ||
flow = GlowBlock(channels, hidden_channels, | ||
scale=scale, split_mode=split_mode, | ||
use_lu=use_lu, net_actnorm=net_actnorm) | ||
self.checkForwardInverse(flow, inputs) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import unittest | ||
import torch | ||
|
||
from torch.testing import assert_close | ||
|
||
|
||
class FlowTest(unittest.TestCase): | ||
""" | ||
Generic test case for flow modules | ||
""" | ||
def assertClose(self, actual, expected, atol=None, rtol=None): | ||
assert_close(actual, expected, atol=atol, rtol=rtol) | ||
|
||
def checkForward(self, flow, inputs): | ||
# Do forward transform | ||
outputs, log_det = flow(inputs) | ||
# Check type | ||
assert outputs.dtype == inputs.dtype | ||
# Check shape | ||
assert outputs.shape == inputs.shape | ||
# Return results | ||
return outputs, log_det | ||
|
||
def checkInverse(self, flow, inputs): | ||
# Do inverse transform | ||
outputs, log_det = flow.inverse(inputs) | ||
# Check type | ||
assert outputs.dtype == inputs.dtype | ||
# Check shape | ||
assert outputs.shape == inputs.shape | ||
# Return results | ||
return outputs, log_det | ||
|
||
def checkForwardInverse(self, flow, inputs, atol=None, rtol=None): | ||
# Check forward | ||
outputs, log_det = self.checkForward(flow, inputs) | ||
# Check inverse | ||
input_, log_det_ = self.checkInverse(flow, outputs) | ||
# Check identity | ||
self.assertClose(input_, inputs, atol, rtol) | ||
ld_id = log_det + log_det_ | ||
self.assertClose(ld_id, torch.zeros_like(ld_id), atol, rtol) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.