# Free-form injective flow with moons manifold

In [3]:
import torch
from tqdm.auto import trange
from sklearn.datasets import make_moons

from ciflows.loss import volume_change_surrogate

import matplotlib.pyplot as plt

In [2]:
dim = 2
latent_dim = dim  # The code below also works for latent_dim < dim
hidden_dim = 128
n_steps = 10000

noise = 0.1

beta = 100
batch_size = 1024
device = "mps"


class SkipConnection(torch.nn.Module):
    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, x, *args, **kwargs):
        return x + self.inner(x, *args, **kwargs)


# Do not use ReLU below, makes training unstable
encoder = SkipConnection(torch.nn.Sequential(
    torch.nn.Linear(dim, hidden_dim), torch.nn.SiLU(),
    torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.SiLU(),
    torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.SiLU(),
    torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.SiLU(),
    torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.SiLU(),
    torch.nn.Linear(hidden_dim, latent_dim)
).to(device))
decoder = SkipConnection(torch.nn.Sequential(
    torch.nn.Linear(latent_dim, hidden_dim), torch.nn.SiLU(),
    torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.SiLU(),
    torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.SiLU(),
    torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.SiLU(),
    torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.SiLU(),
    torch.nn.Linear(hidden_dim, dim)
).to(device))

latent = torch.distributions.Independent(
    torch.distributions.Normal(
        loc=torch.zeros(latent_dim, device=device),
        scale=torch.ones(latent_dim, device=device),
    ),
    1
)

In [8]:

optim = torch.optim.Adam([*encoder.parameters(), *decoder.parameters()], lr=3e-4)

with trange(n_steps) as pbar:
    for step in pbar:
        optim.zero_grad()
        batch, _ = make_moons(batch_size, noise=noise)
        x = torch.from_numpy(batch).float().to(device)
        surrogate_loss, v, xhat = volume_change_surrogate(
            x, encoder, decoder, hutchinson_samples=1
        )
        loss_reconstruction = ((x - xhat) ** 2).sum(-1).mean(-1)
        loss_nll = -latent.log_prob(v) - surrogate_loss
        loss = beta * loss_reconstruction + loss_nll
        loss.mean().backward()
        optim.step()

        if step % 100 == 0:
            with torch.no_grad():
                batch, _ = make_moons(batch_size, noise=noise)
                x = torch.from_numpy(batch).float().to(device)
                surrogate_loss, v, xhat =  volume_change_surrogate(x, encoder, decoder, hutchinson_samples=1)
                nll_out = -latent.log_prob(latent.sample((batch_size,))) - surrogate_loss
                reconstruction = ((x - xhat) ** 2).sum(-1).mean(-1)
                pbar.set_description(f"Reconstruction: {reconstruction:.1e}, NLL: {nll_out.mean():.2f}")

Reconstruction: 2.7e-01, NLL: -3444.53:  41%|████      | 4107/10000 [01:03<01:32, 63.84it/s]

In [5]:
def sample(n_samples):
    z = torch.randn(n_samples, latent_dim, device=device)
    return decoder(z)

(1024, 2)


In [None]:
test_size = 1_000
test_batch, _ = make_moons(test_size, noise=noise)

plot_kwargs = dict(
    s=4,
    alpha=.7
)

plt.scatter(*test_batch.T, label="True", **plot_kwargs)

with torch.no_grad():
    plt.scatter(*sample(test_size).cpu().T, label="FFF", **plot_kwargs)
plt.legend()