# In [None]:
import math
import torch
import torch.nn as nn
import matplotlib.pyplot as plt


# ----------------------------
# Utilities
# ----------------------------
def set_seed(seed: int = 0):
    """Seed torch's CPU generator and the generators of all CUDA devices.

    Args:
        seed: integer seed applied to every generator.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


def get_device():
    """Return ``torch.device("cuda")`` when CUDA is available, else the CPU device."""
    backend = "cpu"
    if torch.cuda.is_available():
        backend = "cuda"
    return torch.device(backend)


# ----------------------------
# Target distribution: 2D mixture of Gaussians (4 modes)
# ----------------------------
def sample_target(n: int, device):
    """
    Draw samples from a 2D mixture of four Gaussians, so the multi-modal
    structure is easy to see by eye.

    The component means are (2, 2), (2, -2), (-2, 2), (-2, -2). Each component
    is picked uniformly at random and has the same isotropic standard deviation
    of 0.35.

    Args:
        n: number of samples to draw.
        device: torch device on which the means, random draws and result live.

    Returns:
        Tensor of shape (n, 2).
    """
    centers = torch.tensor(
        [[2.0 * sx, 2.0 * sy] for sx in (1.0, -1.0) for sy in (1.0, -1.0)],
        device=device,
    )
    sigma = 0.35  # standard deviation shared by all components

    # Draw the component indices first and the noise second; this order
    # determines the random stream, so keep it.
    component = torch.randint(0, 4, (n,), device=device)
    noise = sigma * torch.randn(n, 2, device=device)
    return centers.index_select(0, component) + noise


# ----------------------------
# Small MLP to produce s(.) or t(.)
# Input: scalar (conditioned dim), Output: scalar (for transformed dim)
# ----------------------------
class ScalarMLP(nn.Module):
    """
    Tanh MLP with two hidden layers that maps one scalar feature to one scalar.

    The coupling layers use it to produce the scale ``s(.)`` and the shift
    ``t(.)`` of the transformed dimension from the conditioning dimension.

    Args:
        hidden: width of both hidden layers.
    """

    def __init__(self, hidden=64):
        super().__init__()
        # Produces Linear, Tanh, Linear, Tanh, Linear in this order. That keeps
        # the parameter-init RNG order and the state_dict keys
        # (net.0 / net.2 / net.4) of the plain stacked definition.
        widths = [1, hidden, hidden, 1]
        modules = []
        for i, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if i:
                modules.append(nn.Tanh())
            modules.append(nn.Linear(fan_in, fan_out))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Map an input of shape (N, 1) to an output of shape (N, 1)."""
        out = self.net(x)
        return out


# ----------------------------
# 2D RealNVP-style coupling layer:
# x1 = z1
# x2 = z2 * exp(s(z1)) + t(z1)
# Invertible because exp(s)>0
# ----------------------------
class CouplingLayer2D(nn.Module):
    """
    RealNVP-style affine coupling layer on 2D points.

    Let ``c`` be the conditioning coordinate and ``u`` the transformed one.
    The forward map z -> x is

        x_c = z_c
        x_u = z_u * exp(s(z_c)) + t(z_c)

    It is invertible because exp(s) > 0. The Jacobian is triangular with
    diagonal (1, exp(s)), so log|det dx/dz| = s per sample.

    Args:
        hidden: hidden width of the s- and t-networks.
        clamp_s: s is clamped to [-clamp_s, clamp_s] so exp(s) cannot blow up,
            which keeps training stable.
        swap: False conditions on dim 0 and transforms dim 1. True conditions
            on dim 1 and transforms dim 0; this is implemented by swapping the
            two columns before and after the update.
    """

    def __init__(self, hidden=64, clamp_s=5.0, swap=False):
        super().__init__()
        # s_net is built before t_net; keep this order (parameter-init RNG).
        self.s_net = ScalarMLP(hidden)
        self.t_net = ScalarMLP(hidden)
        self.clamp_s = clamp_s
        self.swap = swap

    def _maybe_swap(self, pts):
        """Return ``pts`` with columns 0 and 1 exchanged when ``self.swap`` is set."""
        if not self.swap:
            return pts
        return pts[:, [1, 0]]

    def _st(self, cond):
        """Return the clamped scale ``s`` and the shift ``t`` for a (N,) conditioner."""
        column = cond.unsqueeze(1)  # (N, 1), the shape ScalarMLP consumes
        scale = torch.clamp(
            self.s_net(column).squeeze(1), -self.clamp_s, self.clamp_s
        )
        shift = self.t_net(column).squeeze(1)
        return scale, shift

    def _couple(self, pts, inverse):
        """Apply the affine update (or its inverse if ``inverse``) to ``pts``.

        Returns:
            (out, logdet), where ``logdet`` is ``s`` for the forward map and
            ``-s`` for the inverse map, per sample.
        """
        pts = self._maybe_swap(pts)
        keep, moving = pts[:, 0], pts[:, 1]
        s, t = self._st(keep)
        if inverse:
            moved = (moving - t) * torch.exp(-s)
            logdet = -s
        else:
            moved = moving * torch.exp(s) + t
            logdet = s
        out = self._maybe_swap(torch.stack([keep, moved], dim=1))
        return out, logdet

    def forward(self, z):
        """
        Map z -> x.

        Returns:
            (x, logdet), where logdet = log|det dx/dz| per sample, shape (N,).
        """
        return self._couple(z, inverse=False)

    def inverse(self, x):
        """
        Map x -> z.

        Returns:
            (z, logdet_inv), where logdet_inv = log|det dz/dx| per sample,
            shape (N,).
        """
        return self._couple(x, inverse=True)


# ----------------------------
# Flow model: stack coupling layers
# ----------------------------
class RealNVP2D(nn.Module):
    """
    Normalizing flow on R^2 built from a stack of 2D affine coupling layers.

    Odd-indexed layers use ``swap=True``, so consecutive layers condition on
    alternating coordinates. The base distribution is an independent standard
    normal on each coordinate.

    Args:
        num_layers: number of coupling layers.
        hidden: hidden width of each layer's s/t networks.
    """

    def __init__(self, num_layers=6, hidden=64):
        super().__init__()
        self.layers = nn.ModuleList(
            CouplingLayer2D(hidden=hidden, swap=(k % 2 == 1))
            for k in range(num_layers)
        )

        # base distribution: standard normal in R^2 (log-prob summed over dims)
        self.base = torch.distributions.Normal(0.0, 1.0)

    def _param_dtype(self):
        """Return the parameters' dtype, or torch's default dtype if there are none."""
        first = next(self.parameters(), None)
        return torch.get_default_dtype() if first is None else first.dtype

    def log_prob(self, x):
        """
        Exact log-density of ``x`` under the flow.

        log p(x) = log p(z) + log|det dz/dx|, where z = f^{-1}(x) is obtained
        by applying each layer's inverse from the last layer to the first.

        Args:
            x: tensor of shape (N, 2).

        Returns:
            Tensor of shape (N,).
        """
        z = x
        # new_zeros follows x's dtype as well as its device. A float32
        # accumulator would silently downcast the log-dets of a float64 model.
        logdet_sum = x.new_zeros(x.shape[0])
        for layer in reversed(self.layers):
            z, logdet_inv = layer.inverse(z)
            logdet_sum = logdet_sum + logdet_inv

        # base log prob: sum over dims
        logp_z = self.base.log_prob(z).sum(dim=1)
        return logp_z + logdet_sum

    def sample(self, n, device):
        """
        Draw ``n`` samples by pushing base noise through the layers (z -> x).

        The noise is drawn from ``self.base``, then moved to ``device`` and cast
        to the parameters' dtype, so sampling also works after ``.double()``.

        Returns:
            Tensor of shape (n, 2) on ``device``.
        """
        z = self.base.sample((n, 2)).to(device=device, dtype=self._param_dtype())
        x = z
        for layer in self.layers:
            x, _ = layer(x)
        return x

    def check_invertibility(self, device):
        """
        Round-trip 256 random points through the first layer (forward, then inverse).

        Returns:
            Maximum absolute reconstruction error as a Python float, or 0.0
            when the flow has no layers.
        """
        if len(self.layers) == 0:
            return 0.0
        first = self.layers[0]
        with torch.no_grad():
            z0 = torch.randn(256, 2, device=device, dtype=self._param_dtype())
            x, _ = first(z0)
            z1, _ = first.inverse(x)
            return (z1 - z0).abs().max().item()


# ----------------------------
# Training + visualization
# ----------------------------
def _train_flow(flow, data, *, device, steps, batch_size, lr, log_every):
    """Fit ``flow`` to ``data`` with Adam by minimizing the mean NLL over minibatches.

    Minibatch indices are drawn uniformly with replacement at every step.
    The current loss is printed every ``log_every`` steps.
    """
    optimizer = torch.optim.Adam(flow.parameters(), lr=lr)
    n_points = data.shape[0]
    for step in range(1, steps + 1):
        batch = data[torch.randint(0, n_points, (batch_size,), device=device)]

        optimizer.zero_grad()
        nll = -flow.log_prob(batch).mean()
        nll.backward()
        optimizer.step()

        if step % log_every == 0:
            print(f"step {step:4d} | NLL {nll.item():.4f}")


def _scatter_panel(ax, pts, title):
    """Scatter the 2D points ``pts`` on ``ax`` with a title and equal axis scaling."""
    ax.scatter(pts[:, 0], pts[:, 1], s=5)
    ax.set_title(title)
    ax.set_aspect("equal", "box")


def main():
    """Train a RealNVP2D flow on the four-mode Gaussian mixture, then plot
    target samples next to flow samples."""
    set_seed(0)
    device = get_device()
    print("device:", device)

    flow = RealNVP2D(num_layers=6, hidden=64).to(device)

    # Quick single-layer round-trip sanity check before training.
    print(
        "invertibility check (layer 0) max |z_rec - z| =",
        flow.check_invertibility(device),
    )

    x_data = sample_target(20000, device=device)
    _train_flow(
        flow,
        x_data,
        device=device,
        steps=2000,
        batch_size=512,
        lr=1e-3,
        log_every=200,
    )

    # Draw flow samples before target samples; both consume the random stream.
    with torch.no_grad():
        x_model = flow.sample(5000, device=device).cpu()
        x_true = sample_target(5000, device=device).cpu()

    _fig, (ax_true, ax_model) = plt.subplots(1, 2, figsize=(10, 4))
    _scatter_panel(ax_true, x_true, "Target samples")
    _scatter_panel(ax_model, x_model, "Flow samples")

    plt.tight_layout()
    plt.show()

    print("DONE")


if __name__ == "__main__":
    # Entry point: run the train-and-plot demo.
    main()
