### Test wrapper on moons

Test that `flowtorch` can be used with transformations from `nflows`, via the `shapeflow` wrapper.

This example is adapted from [nflows](https://github.com/bayesiains/nflows).

In [None]:
import matplotlib.pyplot as plt
import sklearn.datasets as datasets

import torch
from torch import nn
from torch import optim

from nflows.flows.base import Flow
from nflows.distributions.normal import StandardNormal
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.flows import realnvp, autoregressive

import shapeflow as sf

import flowtorch.distributions as ftdist

from shapeflow import ModuleBijector, WrapInverseModel

In [None]:
# Plot one sample of the two-moons data that the flow will be trained on.
x, y = datasets.make_moons(128, noise=0.1)
fig, ax = plt.subplots()
ax.scatter(x[:, 0], x[:, 1])
ax.set(title="Two-moons training sample", xlabel="$x_1$", ylabel="$x_2$");

In [None]:
# Set up model: a stack of nflows transforms wrapped as a flowtorch bijector,
# on top of a standard 2-D Gaussian base distribution.
num_layers = 5
dims = 2
hidden_features = 4  # hidden size passed to each MaskedAffineAutoregressiveTransform
# nflows base distribution; not used by the flowtorch Flow below (kept for reference).
base_dist_nflows = StandardNormal(shape=[dims])
base_dist = torch.distributions.MultivariateNormal(torch.zeros(dims), torch.eye(dims))

transforms = []
for _ in range(num_layers):
    # Alternate the feature order between autoregressive layers.
    transforms.append(ReversePermutation(features=dims))
    transforms.append(
        MaskedAffineAutoregressiveTransform(
            features=dims, hidden_features=hidden_features
        )
    )
transform = CompositeTransform(transforms)
# NOTE(review): presumably WrapInverseModel adapts the nflows transform to the
# direction flowtorch expects for a bijector — confirm against shapeflow.
bijector = WrapInverseModel(model=transform)

flow = ftdist.Flow(bijector=bijector, base_dist=base_dist)
optimizer = optim.Adam(flow.parameters())

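Test that the wrapper works: train the flow by maximum likelihood on fresh two-moons samples, and plot the learned density every 1000 iterations.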

In [None]:
num_iter = 5000
plot_every = 1000  # plot the learned density every this many iterations
batch_size = 128
grid_steps = 100  # resolution of the density-evaluation grid along each axis

# The evaluation grid does not change during training, so build it once.
# `steps` is passed explicitly: recent PyTorch releases require it for linspace,
# and the reshape below must agree with it.
xline = torch.linspace(-1.5, 2.5, steps=grid_steps)
yline = torch.linspace(-0.75, 1.25, steps=grid_steps)
xgrid, ygrid = torch.meshgrid(xline, yline, indexing="ij")
xyinput = torch.cat([xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)], dim=1)

for i in range(num_iter):
    # Draw a fresh random batch from the two-moons distribution every step.
    x, y = datasets.make_moons(batch_size, noise=0.1)
    x = torch.tensor(x, dtype=torch.float32)
    optimizer.zero_grad()
    # Maximum-likelihood training: minimize the mean negative log-density.
    loss = -flow.log_prob(value=x).mean()

    loss.backward()
    optimizer.step()

    if (i + 1) % plot_every == 0:
        with torch.no_grad():
            zgrid = flow.log_prob(xyinput).exp().reshape(grid_steps, grid_steps)
        fig, ax = plt.subplots()
        ax.contourf(xgrid.numpy(), ygrid.numpy(), zgrid.numpy())
        ax.set(title="iteration {}".format(i + 1), xlabel="$x_1$", ylabel="$x_2$")
        plt.show()