In [16]:
from typing import Callable, List, Tuple

import lightning
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions.distribution import Distribution

# Coupling layer

Based on:

- https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/09-normalizing-flows.html (VPN)
- https://sebastiancallh.github.io/post/affine-normalizing-flows/
- https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial11/NF_image_modeling.html

In [14]:
class AffineCouplingLayer(nn.Module):
    """Affine coupling layer with a half swap.

    ``f`` splits its input into ``(x2, x1)``, keeps ``x1`` unchanged and
    transforms ``x2`` with a shift ``t`` and log-scale ``s`` predicted from
    ``x1``::

        z = cat(x1, x2 * exp(s) + t)    where (t, s) = theta(x1)

    Because the untouched half is moved to the front of ``z``, stacking several
    layers that share the same ``split`` alternates which half is transformed.

    Args:
        theta: module mapping the conditioning half to a pair ``(t, s)`` that
            broadcasts against the transformed half.
        split: function splitting a tensor into two parts along the last
            dimension (results are re-joined with ``torch.cat(..., dim=-1)``).
            ``g`` re-applies it to the half-swapped ``z``, so it must produce
            parts of equal size for ``g`` to invert ``f``.
    """

    def __init__(
        self,
        theta: nn.Module,
        split: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
    ):
        super().__init__()
        self.theta = theta
        self.split = split

    def f(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """f : x -> z. The inverse of g.

        Returns:
            ``(z, log_det)``, where ``log_det = s.sum(-1)`` is the log absolute
            determinant of the Jacobian of this transformation.
        """
        x2, x1 = self.split(x)
        t, s = self.theta(x1)
        z1, z2 = x1, x2 * torch.exp(s) + t
        # The Jacobian is block-triangular with diagonal blocks I and
        # diag(exp(s)), so log|det J| = sum(s).
        log_det = s.sum(-1)
        return torch.cat((z1, z2), dim=-1), log_det

    def g(self, z: torch.Tensor) -> torch.Tensor:
        """g : z -> x. The inverse of f."""
        z1, z2 = self.split(z)
        t, s = self.theta(z1)
        x1, x2 = z1, (z2 - t) * torch.exp(-s)
        # Undo the swap done in f, where x was split as (x2, x1).
        return torch.cat((x2, x1), dim=-1)

In [15]:
class NormalizingFlow(nn.Module):
    """Stack of invertible flows on top of a latent (base) distribution.

    Observations ``x`` are mapped to latents ``z`` by applying each flow's
    ``f`` in order. Sampling maps latents back through each flow's ``g`` in
    reverse order.

    Args:
        latent: base distribution ``p(z)``. It is not an ``nn.Module``, so
            ``.to(device)`` on this flow does not move it.
        flows: modules exposing ``f(x) -> (z, log_abs_det)`` and
            ``g(z) -> x``.
    """

    def __init__(self, latent: Distribution, flows: List[nn.Module]):
        super().__init__()
        self.latent = latent
        # nn.ModuleList (not a plain list) registers every flow as a submodule,
        # so its parameters appear in .parameters(), move with .to(), and are
        # saved in state_dict(). Iteration, len() and reversed() still work.
        self.flows = nn.ModuleList(flows)

    def latent_log_prob(self, z: torch.Tensor) -> torch.Tensor:
        """Log-density of ``z`` under the latent distribution."""
        return self.latent.log_prob(z)

    def latent_sample(self, num_samples: int = 1) -> torch.Tensor:
        """Draw ``num_samples`` samples from the latent distribution."""
        return self.latent.sample((num_samples,))

    def sample(self, num_samples: int = 1) -> torch.Tensor:
        """Sample a new observation x by sampling z from
        the latent distribution and passing it through g."""
        return self.g(self.latent_sample(num_samples))

    def f(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Maps observation x to latent variable z.

        Additionally computes the log absolute determinant of the Jacobian
        of the whole transformation, one value per entry along dim 0 of x.
        Inverse of g.
        """
        z = x
        # The identity map has log|det| = log 1 = 0, so start from zeros.
        sum_log_abs_det = torch.zeros(x.size(0), device=x.device, dtype=x.dtype)
        for flow in self.flows:
            z, log_abs_det = flow.f(z)
            sum_log_abs_det = sum_log_abs_det + log_abs_det

        return z, sum_log_abs_det

    def g(self, z: torch.Tensor) -> torch.Tensor:
        """Maps latent variable z to observation x.

        Runs under ``torch.no_grad()``, so the result carries no autograd
        history. Inverse of f.
        """
        with torch.no_grad():
            x = z
            for flow in reversed(self.flows):
                x = flow.g(x)

            return x

    def g_steps(self, z: torch.Tensor) -> List[torch.Tensor]:
        """Maps latent variable z to observation x and stores intermediate results.

        Returns a list that starts with ``z`` and then holds the output of each
        inverse flow (``len(self) + 1`` entries). Unlike ``g``, gradients are
        tracked.
        """
        xs = [z]
        for flow in reversed(self.flows):
            xs.append(flow.g(xs[-1]))

        return xs

    def log_prob(self, x: torch.Tensor) -> torch.Tensor:
        """Computes log p(x) using the change of variable formula.

        ``log p(x) = log p_z(f(x)) + log|det df/dx|``, one value per entry along
        dim 0 of ``x``.
        """
        z, log_abs_det = self.f(x)
        latent_lp = self.latent_log_prob(z)
        # A factorized latent (e.g. a scalar Normal) yields per-element
        # log-probs. Sum them so the result matches log_abs_det's per-sample
        # shape instead of broadcasting against it.
        if latent_lp.dim() > log_abs_det.dim():
            latent_lp = latent_lp.flatten(start_dim=log_abs_det.dim()).sum(-1)
        return latent_lp + log_abs_det

    def __len__(self) -> int:
        return len(self.flows)

In [18]:
class ThetaNetwork(nn.Module):
    """MLP that predicts ``num_params`` tensors, each with ``out_dim`` features.

    With ``num_params=2`` it returns the ``(t, s)`` pair that
    ``AffineCouplingLayer`` unpacks from its ``theta``.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of each returned tensor.
        num_hidden: number of extra ``hidden_dim -> hidden_dim`` layers.
        hidden_dim: width of the hidden layers.
        num_params: number of tensors returned by ``forward``.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_hidden: int,
        hidden_dim: int,
        num_params: int,
    ):
        super().__init__()
        self.input = nn.Linear(in_dim, hidden_dim)
        self.hidden = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_hidden)]
        )

        self.num_params = num_params
        self.out_dim = out_dim
        self.dims = nn.Linear(hidden_dim, out_dim * num_params)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Returns ``num_params`` tensors of shape ``(*x.shape[:-1], out_dim)``."""
        h = F.leaky_relu(self.input(x))
        for layer in self.hidden:
            h = F.leaky_relu(layer(h))

        # Output features are laid out as (out_dim, num_params). Keep every
        # leading dim of x rather than assuming exactly one batch dim.
        batch_params = self.dims(h).reshape(*h.shape[:-1], self.out_dim, self.num_params)
        return list(batch_params.unbind(dim=-1))

In [19]:
def SplitFunc(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split ``x`` into two equal halves along its last dimension.

    ``AffineCouplingLayer.g`` re-applies the same split to its half-swapped
    input, so both halves must have the same size for ``g`` to invert ``f``.
    An odd last dimension is therefore rejected.

    Raises:
        ValueError: if the last dimension of ``x`` is odd.
    """
    size = x.size(-1)
    if size % 2 != 0:
        raise ValueError(f"SplitFunc needs an even last dimension, got {size}")
    first, second = x.chunk(2, dim=-1)
    return first, second

In [20]:
# Model configuration. The 42s are placeholder values carried over from the
# first draft. TODO: tune.
HALF_DIM = 42                # size of each half produced by SplitFunc
DATA_DIM = 2 * HALF_DIM      # dimensionality of one observation x
NUM_HIDDEN = 42              # hidden layers per ThetaNetwork
HIDDEN_DIM = 42              # width of those hidden layers
NUM_COUPLING_LAYERS = 3
SEED = 0

torch.manual_seed(SEED)  # reproducible weight initialisation

flow_model = NormalizingFlow(
    # Independent(..., 1) treats the DATA_DIM coordinates as one event:
    # sample() yields DATA_DIM-vectors and log_prob() one value per sample.
    latent=torch.distributions.Independent(
        torch.distributions.Normal(loc=torch.zeros(DATA_DIM), scale=torch.ones(DATA_DIM)),
        1,
    ),
    flows=[
        AffineCouplingLayer(
            theta=ThetaNetwork(
                in_dim=HALF_DIM,   # conditions on one half ...
                out_dim=HALF_DIM,  # ... and predicts t and s for the other half
                num_hidden=NUM_HIDDEN,
                hidden_dim=HIDDEN_DIM,
                num_params=2,      # AffineCouplingLayer unpacks exactly (t, s)
            ),
            split=SplitFunc,
        )
        # A fresh, independently initialised theta network for every layer.
        for _ in range(NUM_COUPLING_LAYERS)
    ],
)
flow_model