In [1]:
import torch

a = torch.tensor(0.7) 
b = torch.tensor(2.0)

base = torch.distributions.Normal(0.0, 1.0)

def forward(z):
    # x = f(z)
    return a * z + b

def inverse(x):
    # z = f^{-1}(x)
    return (x - b) / a

def log_prob_x(x):
    z = inverse(x)
    return base.log_prob(z) - torch.log(torch.abs(a))

def reconstruct_x(x):
    # x -> z -> x_hat
    z = inverse(x)
    x_hat = forward(z)
    return x_hat

z = base.sample((5,))
x = forward(z)

x_hat = reconstruct_x(x)

err = (x_hat - x).abs()

print("z:", z)
print("x:", x)
print("x_hat:", x_hat)
print("abs error:", err)
print("max abs error:", err.max().item())

print("log p(x):", log_prob_x(x))


z: tensor([-0.3040,  0.2222,  0.8658,  0.2976, -1.6826])
x: tensor([1.7872, 2.1556, 2.6060, 2.2083, 0.8222])
x_hat: tensor([1.7872, 2.1556, 2.6060, 2.2083, 0.8222])
abs error: tensor([0., 0., 0., 0., 0.])
max abs error: 0.0
log p(x): tensor([-0.6085, -0.5870, -0.9370, -0.6065, -1.9778])


In [2]:
import math
import torch
import torch.nn as nn

# ----------------------------
# 1D Affine Flow: x = a z + b, with a = exp(s) > 0
# ----------------------------
class AffineFlow1D(nn.Module):
    def __init__(self):
        super().__init__()
        self.s = nn.Parameter(torch.tensor(0.0))  # log a
        self.b = nn.Parameter(torch.tensor(0.0))  # shift

        self.base = torch.distributions.Normal(loc=0.0, scale=1.0)

    @property
    def a(self):
        return torch.exp(self.s)

    def sample(self, n: int, device=None):
        if device is None:
            device = self.s.device
        z = self.base.sample((n,)).to(device)
        x = self.a * z + self.b
        return x

    def log_prob(self, x: torch.Tensor):
        # inverse: z = (x - b) / a
        z = (x - self.b) / self.a
        # log p(x) = log p(z) - log|a|
        return self.base.log_prob(z) - self.s  # 1D, so no sum needed

def main():
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ----------------------------
    # Target data (what we want the flow to fit)
    # ----------------------------
    true_mu = 2.0
    true_sigma = 0.7
    target = torch.distributions.Normal(true_mu, true_sigma)

    n_data = 5000
    x_data = target.sample((n_data,)).to(device)

    # ----------------------------
    # Flow model
    # ----------------------------
    flow = AffineFlow1D().to(device)
    opt = torch.optim.Adam(flow.parameters(), lr=5e-2)

    # ----------------------------
    # Train by maximum likelihood: maximize log_prob(data)
    # ----------------------------
    for step in range(1, 501):
        opt.zero_grad()
        nll = -flow.log_prob(x_data).mean()  # negative log-likelihood
        nll.backward()
        opt.step()

        if step % 100 == 0:
            with torch.no_grad():
                est_mu = flow.b.item()
                est_sigma = flow.a.item()  # since z~N(0,1), x~N(b, a^2)
            print(f"step {step:4d} | NLL {nll.item():.4f} | a={flow.a.item():.4f} b={flow.b.item():.4f} "
                  f"| est_mu={est_mu:.4f} est_sigma={est_sigma:.4f}")

    # ----------------------------
    # Quick sanity check: sample and compare moments
    # ----------------------------
    with torch.no_grad():
        xs = flow.sample(200000, device=device)
        print("\nTarget:   mu=%.3f sigma=%.3f" % (true_mu, true_sigma))
        print("Learned:  mu=%.3f sigma=%.3f" % (xs.mean().item(), xs.std(unbiased=False).item()))
        print("DONE")
        
if __name__ == "__main__":
    main()



step  100 | NLL 1.0543 | a=0.7008 b=1.9704 | est_mu=1.9704 est_sigma=0.7008
step  200 | NLL 1.0536 | a=0.6940 b=1.9948 | est_mu=1.9948 est_sigma=0.6940
step  300 | NLL 1.0536 | a=0.6940 b=1.9947 | est_mu=1.9947 est_sigma=0.6940
step  400 | NLL 1.0536 | a=0.6940 b=1.9947 | est_mu=1.9947 est_sigma=0.6940
step  500 | NLL 1.0536 | a=0.6940 b=1.9947 | est_mu=1.9947 est_sigma=0.6940

Target:   mu=2.000 sigma=0.700
Learned:  mu=1.994 sigma=0.695
DONE
