# 🌀 RealNVP

# 🌀 RealNVP (PyTorch)

在这个 Notebook 中，我们将使用 **PyTorch** 从零开始训练一个 **RealNVP** 模型，用于学习一个二维 toy dataset（make_moons）的概率分布。


In [None]:
# %%
# IPython magics: load the autoreload extension and (mode 2) reload imported
# modules automatically before executing code, so edits to local modules are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# %%
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from sklearn import datasets

## 0. Parameters <a name="parameters"></a>

In [None]:
# %%
# --- Model architecture ---
INPUT_DIM = 2  # each make_moons sample is a 2-D point
COUPLING_LAYERS = 2  # number of affine coupling layers in the flow
COUPLING_DIM = 256  # hidden width of every coupling MLP

# --- Optimisation ---
BATCH_SIZE = 256
EPOCHS = 300
LR = 1e-4
REGULARIZATION = 0.01  # L2 strength; used as the optimizer's weight_decay

## 1. Prepare the data <a name="prepare"></a>

In [None]:
# %%
# Load toy dataset: 30,000 "two moons" points with Gaussian noise (sd 0.05).
# make_moons returns (X, y); only the coordinates X are kept, as float32.
points, _labels = datasets.make_moons(30000, noise=0.05)
data = points.astype("float32")

In [None]:
# %%
# Normalize each feature to zero mean and unit variance
# (corresponds to the Keras Normalization layer).
mean = np.mean(data, axis=0, keepdims=True)
std = np.std(data, axis=0, keepdims=True)
normalized_data = (data - mean) / std

# Quick visual check of the normalized point cloud.
xs, ys = normalized_data[:, 0], normalized_data[:, 1]
plt.scatter(xs, ys, c="green", s=1)
plt.title("Normalized data")
plt.show()

In [None]:
# %%
# Wrap the normalized points in a dataset; every batch is a 1-tuple (x,).
features = torch.from_numpy(normalized_data)
dataset = TensorDataset(features)
loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)

## 2. Build the RealNVP network <a name="build"></a>

In [None]:
# %%
class CouplingNet(nn.Module):
    """Scale (s) and translation (t) networks for one affine coupling layer.

    Both networks are ReLU MLPs with four hidden layers that map
    ``input_dim`` -> ``input_dim``. The scale network ends in ``tanh``, so
    ``s`` lies in (-1, 1) and ``exp(s)`` stays within (1/e, e).

    Args:
        input_dim: Size of the (masked) input vector and of both outputs.
        hidden_dim: Width of every hidden layer.
        reg: L2 penalty strength for the Linear weight matrices. It is stored
            as ``self.reg`` and used by :meth:`regularization_loss`; it is not
            applied automatically. Add that term to the loss yourself, or use
            optimizer weight decay instead.
    """

    def __init__(self, input_dim, hidden_dim, reg):
        super().__init__()
        self.reg = reg

        # Scale network: the final tanh bounds s to (-1, 1).
        self.net_s = self._mlp(input_dim, hidden_dim, out_activation=nn.Tanh())
        # Translation network: linear (unbounded) output.
        self.net_t = self._mlp(input_dim, hidden_dim)

        # Xavier-uniform init for every Linear weight matrix.
        # Biases keep PyTorch's default initialisation.
        # (This loop is only initialisation, not regularisation.)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    @staticmethod
    def _mlp(input_dim, hidden_dim, out_activation=None):
        """Build 4 x (Linear, ReLU) + output Linear (+ optional activation)."""
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        for _ in range(3):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        layers.append(nn.Linear(hidden_dim, input_dim))
        if out_activation is not None:
            layers.append(out_activation)
        return nn.Sequential(*layers)

    def regularization_loss(self):
        """Return ``reg * sum(W ** 2)`` over all Linear weight matrices.

        Biases are excluded. Adding this to the training loss gives a
        per-kernel L2 penalty as an explicit, opt-in term.
        """
        return self.reg * sum(
            m.weight.pow(2).sum() for m in self.modules() if isinstance(m, nn.Linear)
        )

    def forward(self, x):
        """Return ``(s, t)``, each with last dimension ``input_dim``."""
        return self.net_s(x), self.net_t(x)

In [None]:
# %%
class RealNVP(nn.Module):
    """RealNVP normalizing flow built from alternating affine coupling layers.

    ``forward(x, training=True)`` walks the layers in reverse and applies the
    inverse coupling transform (data -> latent). It returns the latent points
    and the log-determinant of that map's Jacobian.

    ``forward(z, training=False)`` walks the layers in order and applies the
    forward transform (latent -> data). In that direction the returned
    log-determinant is always zero, because it is not tracked.

    Args:
        input_dim: Dimensionality of the data.
        coupling_layers: Number of coupling layers (any non-negative int).
        coupling_dim: Hidden width of each layer's ``CouplingNet``.
    """

    def __init__(self, input_dim, coupling_layers, coupling_dim):
        super().__init__()

        self.coupling_layers = coupling_layers

        # Base distribution N(0, I). Its parameters are non-persistent buffers
        # so that ``model.to(device)`` / ``.double()`` move them together with
        # the weights. A Distribution stored as a plain attribute stays on the
        # CPU, which made ``log_loss`` fail with a device mismatch on GPU.
        # ``persistent=False`` keeps the state_dict keys unchanged.
        self.register_buffer("base_loc", torch.zeros(input_dim), persistent=False)
        self.register_buffer(
            "base_scale_tril", torch.eye(input_dim), persistent=False
        )

        # Alternating binary masks: masks[i, j] = (i + j) % 2, i.e.
        # [0, 1], [1, 0], [0, 1], ... for input_dim == 2. Unlike repeating a
        # fixed 2-row pattern, this works for an odd number of layers and for
        # any input_dim.
        layer_idx = torch.arange(coupling_layers).unsqueeze(1)
        dim_idx = torch.arange(input_dim).unsqueeze(0)
        self.register_buffer(
            "masks", ((layer_idx + dim_idx) % 2).float(), persistent=False
        )

        self.couplings = nn.ModuleList(
            [
                CouplingNet(input_dim, coupling_dim, REGULARIZATION)
                for _ in range(coupling_layers)
            ]
        )

    @property
    def base_dist(self):
        """Standard normal N(0, I), on the same device/dtype as the model."""
        return torch.distributions.MultivariateNormal(
            self.base_loc, scale_tril=self.base_scale_tril
        )

    def forward(self, x, training=True):
        """Apply the flow.

        Args:
            x: Batch of points, shape (batch, input_dim).
            training: If True, map data -> latent (the inverse pass used for
                the likelihood). If False, map latent -> data (generation).

        Returns:
            ``(y, log_det)``, where ``log_det`` has shape (batch,).
        """
        log_det = torch.zeros(x.size(0), device=x.device, dtype=x.dtype)
        direction = -1 if training else 1
        # gate is -1 for the inverse pass and 0 for the forward pass. It
        # switches on the exp(-s) factor applied to t and the log-det term.
        gate = (direction - 1) / 2

        for i in range(self.coupling_layers)[::direction]:
            mask = self.masks[i]
            reversed_mask = 1 - mask
            x_masked = x * mask
            s, t = self.couplings[i](x_masked)

            # Only the unmasked coordinates are transformed.
            s = s * reversed_mask
            t = t * reversed_mask

            # inverse (direction=-1): x' = (x - t) * exp(-s)
            # forward (direction=+1): x' = x * exp(s) + t
            x = (
                reversed_mask
                * (x * torch.exp(direction * s) + direction * t * torch.exp(gate * s))
                + x_masked
            )
            log_det = log_det + gate * s.sum(dim=1)

        return x, log_det

    def log_loss(self, x):
        """Return the mean negative log-likelihood of batch ``x`` under the flow."""
        z, log_det = self(x, training=True)
        log_prob = self.base_dist.log_prob(z)
        return -torch.mean(log_prob + log_det)

In [None]:
# %%
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = RealNVP(INPUT_DIM, COUPLING_LAYERS, COUPLING_DIM)
model = model.to(device)

# Adam's weight_decay adds REGULARIZATION * param to every parameter's
# gradient (all of model.parameters(), biases included), acting as the L2
# regulariser.
optimizer = torch.optim.Adam(
    params=model.parameters(), lr=LR, weight_decay=REGULARIZATION
)

## 3. Train the RealNVP network <a name="train"></a>

In [None]:
# %%
# Train by minimising the negative log-likelihood (model.log_loss).
# The reported epoch loss is a per-sample average, so the smaller final
# batch is not over-weighted. The last epoch is always reported.
num_train = len(loader.dataset)
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0

    for (x,) in loader:
        x = x.to(device)

        optimizer.zero_grad()
        loss = model.log_loss(x)
        loss.backward()
        optimizer.step()

        # loss is a batch mean; weight by batch size to get a dataset mean.
        total_loss += loss.item() * x.size(0)

    if epoch % 10 == 0 or epoch == EPOCHS - 1:
        print(f"Epoch {epoch:03d} | Loss: {total_loss / num_train:.4f}")

## 4. Generate samples <a name="generate"></a>

In [None]:
# %%
def generate_samples(model, num_samples):
    """Push the data through the flow and draw new samples from it.

    Runs without gradient tracking. Switches ``model`` to eval mode and does
    not switch it back.

    Args:
        model: A trained ``RealNVP``.
        num_samples: How many points to draw from the base distribution.

    Returns:
        Tuple of NumPy arrays ``(x, z, samples)``:
            x       -- the base samples mapped into data space, g(Z).
            z       -- ``normalized_data`` mapped into latent space, f(X).
            samples -- the raw draws from ``model.base_dist``.
    """
    with torch.no_grad():
        model.eval()

        # Data → latent (inverse pass).
        data_tensor = torch.from_numpy(normalized_data).to(device)
        latent_of_data = model(data_tensor, training=True)[0]

        # Draw from the base distribution N(0, I).
        base_draws = model.base_dist.sample((num_samples,)).to(device)

        # Latent → data (forward pass).
        generated = model(base_draws, training=False)[0]

    return tuple(arr.cpu().numpy() for arr in (generated, latent_of_data, base_draws))

In [None]:
# %%
def display(x, z, samples):
    """Plot a 2x2 grid comparing the data and latent spaces.

    Top row (red): ``normalized_data`` and its latent image ``z``.
    Bottom row (green): base samples and their data-space image ``x``.
    Every panel uses fixed [-2, 2] axis limits.
    """
    _, axes = plt.subplots(2, 2, figsize=(8, 5))

    panels = (
        (axes[0, 0], normalized_data, "r", "Data space X"),
        (axes[0, 1], z, "r", "f(X)"),
        (axes[1, 0], samples, "g", "Latent space Z"),
        (axes[1, 1], x, "g", "g(Z)"),
    )
    for ax, points, colour, title in panels:
        ax.scatter(points[:, 0], points[:, 1], s=1, c=colour)
        ax.set_title(title)
        ax.set_xlim([-2, 2])
        ax.set_ylim([-2, 2])

    plt.tight_layout()
    plt.show()

In [None]:
# %%
# Encode the data, sample the base distribution, decode the samples, then plot.
x, z, samples = generate_samples(model, 3000)
display(x=x, z=z, samples=samples)