In [None]:
import math
import torch
import torch.nn as nn
import torch.optim as optim


# ----------------------------
# 1) A 2D RealNVP-style coupling layer (single split)
#    Forward: y1 = z1
#             y2 = z2 * exp(s(z1)) + t(z1)
#    Inverse: z1 = y1
#             z2 = (y2 - t(y1)) * exp(-s(y1))
#    log|det J_forward|  = s(z1)
#    log|det J_inverse|  = -s(y1)
# ----------------------------
class Coupling2D(nn.Module):
    """Single RealNVP-style affine coupling layer for 2-D inputs.

    The first coordinate passes through unchanged and conditions an affine
    transformation of the second coordinate:

        forward:  y1 = z1,  y2 = z2 * exp(s(z1)) + t(z1)
        inverse:  z1 = y1,  z2 = (y2 - t(y1)) * exp(-s(y1))

    The Jacobian is triangular, so log|det J| is ``s(z1)`` for the forward
    map and ``-s(y1)`` for the inverse map.

    Args:
        hidden: Width of the two hidden layers of the scale and shift nets.
        s_scale: Bound on the magnitude of the log-scale ``s``; the raw
            network output is squashed with ``tanh`` and multiplied by this
            value. The default of 0.8 reproduces the previously hard-coded
            constant.
    """

    def __init__(self, hidden: int = 64, s_scale: float = 0.8) -> None:
        super().__init__()
        self.s_scale = s_scale
        # Build s_net before t_net so that parameter initialisation consumes
        # the RNG in the same order as before.
        self.s_net = self._make_net(hidden)
        self.t_net = self._make_net(hidden)

    @staticmethod
    def _make_net(hidden: int) -> nn.Sequential:
        """Return a scalar-to-scalar MLP with two tanh hidden layers."""
        return nn.Sequential(
            nn.Linear(1, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 1),
        )

    def s(self, x1: torch.Tensor) -> torch.Tensor:
        """Return the bounded log-scale for conditioning input ``x1``."""
        # tanh keeps exp(+/-s) within [exp(-s_scale), exp(s_scale)].
        return self.s_scale * torch.tanh(self.s_net(x1))

    def t(self, x1: torch.Tensor) -> torch.Tensor:
        """Return the (unbounded) shift for conditioning input ``x1``."""
        return self.t_net(x1)

    def forward(self, z: torch.Tensor):
        """Map latent ``z`` to data space.

        Args:
            z: Tensor indexed as (batch, feature); columns 0 and 1 are used.

        Returns:
            Tuple ``(y, logdet)`` where ``y`` has two columns and ``logdet``
            is the per-sample log|det J| of the forward map, one value per
            row of ``z``.
        """
        # Slice with ranges (not integers) to keep the column dimension.
        z1 = z[:, 0:1]
        z2 = z[:, 1:2]
        s = self.s(z1)
        t = self.t(z1)
        y = torch.cat([z1, z2 * torch.exp(s) + t], dim=1)
        return y, s.squeeze(1)

    def inverse(self, y: torch.Tensor):
        """Map data ``y`` back to latent space.

        Args:
            y: Tensor indexed as (batch, feature); columns 0 and 1 are used.

        Returns:
            Tuple ``(z, logdet_inv)`` where ``z`` has two columns and
            ``logdet_inv`` is the per-sample log|det J| of the inverse map.
        """
        y1 = y[:, 0:1]
        y2 = y[:, 1:2]
        # y1 equals z1, so s and t can be recomputed exactly from the output.
        s = self.s(y1)
        t = self.t(y1)
        z = torch.cat([y1, (y2 - t) * torch.exp(-s)], dim=1)
        return z, (-s).squeeze(1)


def log_normal_diag(z, sigma=1.0):
    """Return the per-row log-density of ``z`` under N(0, sigma^2 * I).

    Args:
        z: Tensor of shape (batch, d).
        sigma: Standard deviation shared by every dimension.

    Returns:
        Tensor with one log-density value per row of ``z``.
    """
    dim = z.shape[1]
    variance = sigma ** 2
    # Normalisation constant of a d-dimensional isotropic Gaussian.
    log_norm = -0.5 * dim * math.log(2 * math.pi * variance)
    # Squared Euclidean norm of each row.
    sq_norm = z.pow(2).sum(dim=1)
    return log_norm - 0.5 * sq_norm / variance

# Seed the global RNG so the synthetic data and both initialisations repeat.
torch.manual_seed(0)
device = "cpu"

# Ground-truth flow: a freshly initialised coupling layer whose parameters
# are perturbed in place with Gaussian noise scaled by 0.5.
true_flow = Coupling2D(hidden=64).to(device)
n = 5000  # number of synthetic samples

with torch.no_grad():
    for p in true_flow.parameters():
        noise = torch.randn_like(p)
        p.add_(0.5 * noise)

    # Push standard-normal latents through the true flow to obtain the
    # observations; the log-determinant is not needed here.
    z0 = torch.randn(n, 2, device=device)
    y_data = true_flow(z0)[0].detach()


# Trainable flow with the same architecture, optimised with Adam.
model = Coupling2D(hidden=64).to(device)
opt = optim.Adam(model.parameters(), lr=1e-3)

# Standard deviation of the isotropic Gaussian base density.
sigma = 1.0

def batch_nll(y_batch):
    """Return the mean negative log-likelihood of ``y_batch`` under ``model``.

    Uses the change-of-variables formula
    ``log p(y) = log p_z(f^{-1}(y)) + log|det J_{f^{-1}}(y)|`` with the
    module-level ``model`` as the flow and a Gaussian base density whose
    standard deviation is the module-level ``sigma``.
    """
    latent, log_det = model.inverse(y_batch)
    log_prior = log_normal_diag(latent, sigma=sigma)
    return -(log_prior + log_det).mean()

batch_size = 512
steps = 1500

for step in range(steps):
    # Draw a minibatch of row indices uniformly at random (with replacement).
    idx = torch.randint(0, n, (batch_size,), device=device)
    loss = batch_nll(y_data[idx])

    opt.zero_grad()
    loss.backward()
    opt.step()

    if step % 200 != 0:
        continue

    # Every 200 steps, report the NLL over the whole data set without
    # tracking gradients.
    with torch.no_grad():
        nll_full = batch_nll(y_data).item()
    print(f"step {step:4d} | NLL(full) = {nll_full:.4f}")

print("done.")

step    0 | NLL(full) = 7.7374
step  200 | NLL(full) = 3.6060
step  400 | NLL(full) = 3.5657
step  600 | NLL(full) = 3.5555
step  800 | NLL(full) = 3.5494
step 1000 | NLL(full) = 3.5452
step 1200 | NLL(full) = 3.5417
step 1400 | NLL(full) = 3.5389
done.
