<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Normalizing_Flows.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class PlanarFlow(nn.Module):
    """Single planar flow layer: ``f(x) = x + u_hat * tanh(w . x + b)``.

    The raw parameter ``u`` is reparameterized into ``u_hat`` on every
    forward pass so that ``w . u_hat > -1``. With a ``tanh`` non-linearity
    this is the sufficient condition for the layer to be invertible, and it
    keeps the Jacobian determinant ``1 + psi . u_hat`` positive so its
    logarithm stays finite.

    Args:
        input_dim: Size of the last (feature) dimension of the inputs.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.u = nn.Parameter(torch.randn(input_dim))
        self.w = nn.Parameter(torch.randn(input_dim))
        self.b = nn.Parameter(torch.randn(1))

    def _constrained_u(self):
        """Return ``u_hat``, a copy of ``u`` shifted along ``w`` so ``w . u_hat > -1``."""
        wu = self.w @ self.u
        # m(a) = -1 + softplus(a) maps the real line onto (-1, inf).
        target = nn.functional.softplus(wu) - 1
        # Guard against a division by zero if w collapses to the zero vector.
        w_norm_sq = (self.w @ self.w).clamp_min(1e-12)
        return self.u + (target - wu) * self.w / w_norm_sq

    def forward(self, x):
        """Apply the flow to ``x``.

        Args:
            x: Tensor whose last dimension has size ``input_dim`` and which
                has at least one leading batch dimension.

        Returns:
            A tuple ``(z, log_det)``: ``z`` is the transformed tensor with the
            same shape as ``x``; ``log_det`` is ``log|det(dz/dx)|`` with one
            value per sample (the shape of ``x`` without its last dimension).
        """
        u_hat = self._constrained_u()
        linear = x @ self.w + self.b  # Shape: [...]
        activation = torch.tanh(linear)  # computed once, reused below
        # psi = h'(w . x + b) * w, with h = tanh  ->  Shape: [..., input_dim]
        psi = (1 - activation ** 2).unsqueeze(-1) * self.w
        z = x + u_hat * activation.unsqueeze(-1)
        # Matrix determinant lemma: det(I + u_hat psi^T) = 1 + psi . u_hat
        log_det = torch.log(torch.abs(1 + psi @ u_hat))
        return z, log_det

class NormalizingFlow(nn.Module):
    """Composition of ``n_flows`` planar flow layers applied in sequence.

    Args:
        input_dim: Size of the last (feature) dimension of the inputs.
        n_flows: Number of ``PlanarFlow`` layers to chain.
    """

    def __init__(self, input_dim, n_flows):
        super().__init__()
        self.flows = nn.ModuleList([PlanarFlow(input_dim) for _ in range(n_flows)])

    def forward(self, x):
        """Push ``x`` through every flow and sum the log-determinants.

        Args:
            x: Tensor whose last dimension is the feature dimension.

        Returns:
            A tuple ``(z, log_det_jacobians)``: ``z`` is the output of the last
            flow (``x`` itself when there are no flows) and
            ``log_det_jacobians`` is the per-sample sum of each layer's
            ``log|det J|``.
        """
        # Start from a tensor (not Python 0) so the result is always a
        # per-sample tensor on x's device/dtype, even with zero flows.
        log_det_jacobians = x.new_zeros(x.shape[:-1])
        for flow in self.flows:
            x, log_det_jacobian = flow(x)
            # Out-of-place add: never mutates a tensor owned by the autograd graph.
            log_det_jacobians = log_det_jacobians + log_det_jacobian
        return x, log_det_jacobians

# Example usage: fit a small stack of planar flows to samples drawn from a
# standard normal, printing the loss after every optimisation step.
INPUT_DIM = 2
N_FLOWS = 5
BATCH_SIZE = 32
N_STEPS = 100
LEARNING_RATE = 0.001

model = NormalizingFlow(input_dim=INPUT_DIM, n_flows=N_FLOWS)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

for _step in range(N_STEPS):
    # Fresh random batch each step.
    x = torch.randn(BATCH_SIZE, INPUT_DIM)
    z, log_det_jacobians = model(x)

    # Per-sample objective: unnormalised standard-normal log-density of z
    # plus the accumulated log|det J| of the flow.
    log_likelihood = -0.5 * z.pow(2).sum(dim=1) + log_det_jacobians
    loss = -log_likelihood.mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f'Loss: {loss.item():.4f}')