In [None]:
# Reload edited modules automatically before each cell execution
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output
%matplotlib inline

# Import

In [None]:
import sys
from logging import INFO, StreamHandler, getLogger

# Root logger: report INFO and above. A stdout handler is attached only when no handler
# is registered yet, so re-running this cell does not duplicate every log line.
logger = getLogger()
logger.setLevel(INFO)
if not logger.hasHandlers():
    stdout_handler = StreamHandler(sys.stdout)
    logger.addHandler(stdout_handler)

In [None]:
import os
import math
import pathlib

import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.distributions import Categorical, Normal

from src.frameworks.ddpm import DDPMConfig, DDPM

# Global figure styling: serif fonts, then the "tableau-colorblind10" style sheet
plt.rcParams.update({"font.family": "serif"})
plt.style.use("tableau-colorblind10")

# Define constants

In [None]:
# Project root, taken from the first entry of PYTHONPATH.
# - os.pathsep (":" on POSIX, ";" on Windows) replaces a hard-coded ":" so that Windows
#   drive letters such as "C:\..." are not split apart.
# - An unset PYTHONPATH is treated like an empty one, which resolves to the current
#   working directory instead of raising KeyError.
ROOT_DIR = pathlib.Path(
    os.environ.get("PYTHONPATH", "").split(os.pathsep)[0]
).resolve()

# Run an SDE

$$
dx = -\mu x dt + \sigma dW
$$
- 初期分布を混合正規分布に取る (二峰分布)

In [None]:
# Simulate an Ornstein-Uhlenbeck (OU) process with the Euler-Maruyama scheme, starting
# from a two-component Gaussian mixture.

# Fixed seed so that "Restart & Run All" reproduces the same sample paths
seed = 0
torch.manual_seed(seed)

# Mixed Gaussian
init_means = [4.0, -3.0]
init_stds = [2.0, 1.0]
gauss_weights = [0.6, 0.4]

b = 2_000  # n_batches (number of independent sample paths)
n_steps = 1_000
dt = 1.0 / n_steps

# OU process: dx = -mu x dt + sigma dW
mu = 5.0
sigma = np.sqrt(2.0 * mu)

# Pick the mixture component of every path, then draw its initial state from that component
idxs = Categorical(torch.tensor(gauss_weights)).sample((b,))
y0 = torch.normal(
    mean=torch.tensor(init_means)[idxs], std=torch.tensor(init_stds)[idxs]
)
# Brownian increments of one time step: dW ~ N(0, dt)
dW = math.sqrt(dt) * torch.randn(size=(b, n_steps))

# yt[:, k] holds the state of every path after k steps (column 0 is the initial state)
yt = torch.zeros((b, n_steps + 1))
current = y0
yt[:, 0] = current  # slice assignment copies the values, no explicit clone needed

for i in range(n_steps):
    # Euler-Maruyama update: drift -mu * x * dt plus diffusion sigma * dW
    current = current - mu * current * dt + sigma * dW[:, i]
    yt[:, i + 1] = current

In [None]:
# Plot the first 100 simulated sample paths against time t in [0, 1]
ts = torch.linspace(0, 1, n_steps + 1)

plt.rcParams["font.size"] = 14
# Create the Figure and Axes explicitly instead of relying on pyplot's implicit state
fig, ax = plt.subplots()

interval = 4  # draw every 4th time step only
for i in range(100):
    ax.plot(ts[::interval], yt[i][::interval], lw=0.5, alpha=0.5)

ax.set_xlabel(r"Time, $t$")
ax.set_ylabel(r"State, $x$")
plt.show()

In [None]:
# Histograms of the simulated states at selected time steps (one panel per step)
plt.rcParams["font.size"] = 14
fig, axes = plt.subplots(2, 6, sharex=True, sharey=True, figsize=(20, 6))

for t, ax in zip([0, 1, 2, 5, 10, 20, 50, 100, 200, 500, 750, 1000], axes.flatten()):
    data = yt[:, t].numpy().flatten()
    ax.hist(data, range=(-10, 10), bins=100, density=True, alpha=1.0)
    ax.set_xlim(-10, 10)
    # Use this simulation's own step count; the DDPM `config` object is only created
    # further down the notebook and does not exist yet on a fresh kernel
    ax.set_title(f"Timestep={t:04}/{n_steps}")
    ax.set_ylabel("PDF")
    ax.axvline(0, color="gray")

plt.tight_layout()
plt.show()

# Make a DDPM instance

- 分散保存型 (Variance-Preserving; VP) の拡散モデルを扱う．
- 順過程は下の様に書ける
$$
\begin{align}
dx_t &= - \mu_t x_t \; dt + \sqrt{2 \mu_t} \; dW \quad (t \in [0,1])
\end{align}
$$

In [None]:
# Initial distribution: two-component Gaussian mixture
gauss_weights = [0.6, 0.4]
init_means = [4.0, -3.0]
init_stds = [2.0, 1.0]

# OU process: dx = -mu x dt + sigma dW, with sigma^2 = 2 mu
mu = 5.0
sigma = np.sqrt(mu * 2.0)

# beta and mu are related by mu = beta / 2 (i.e., beta = 2 mu). Giving the start and the
# end of the beta schedule the same value is meant to keep mu above constant as well.
config = DDPMConfig(
    n_channels=1,
    n_spaces=1,
    n_timesteps=1_000,
    start_beta=mu * 2,
    end_beta=mu * 2,
)


class ExactScoreFunc(torch.nn.Module):
    """Closed-form noise predictor for an OU process started from a Gaussian mixture.

    For the linear SDE ``dx = -mu x dt + sigma dW`` the state at time ``t`` given ``x_0``
    is Gaussian with mean ``m_t * x_0`` and variance ``v_t``, where

        m_t = exp(-mu t),    v_t = sigma^2 / (2 mu) * (1 - exp(-2 mu t)).

    A Gaussian-mixture initial distribution therefore stays a Gaussian mixture whose
    i-th component has mean ``init_means[i] * m_t`` and variance
    ``init_stds[i]^2 * m_t^2 + v_t``. This module evaluates the potential (negative
    log-density), the score (gradient of the log-density) and the corresponding noise
    estimate of that mixture.

    Time is discretised as ``linspace(0, 1, n_t + 1)`` with the first point (t = 0)
    dropped, so ``t_index = k`` refers to time ``(k + 1) / n_t``.

    All densities are evaluated in log space (log-sum-exp / softmax), so inputs far in
    the tails give finite results instead of 0/0 = NaN. The precomputed tables are
    registered as non-persistent buffers: they follow ``.to(device)`` but are not part
    of ``state_dict()``.

    Args:
        mu: Drift coefficient of the OU process.
        sigma: Diffusion coefficient of the OU process.
        n_t: Number of time steps on [0, 1].
        init_means: Means of the initial mixture components.
        init_stds: Standard deviations of the initial mixture components.
        gauss_weights: Mixture weights of the initial components.
    """

    def __init__(
        self,
        mu: float,
        sigma: float,
        n_t: int,
        init_means: list[float],
        init_stds: list[float],
        gauss_weights: list[float],
    ):
        super().__init__()

        assert len(init_means) == len(init_stds) == len(gauss_weights)

        mu = float(mu)
        sigma = float(sigma)  # also accepts NumPy scalars such as np.sqrt(...)

        times = torch.linspace(0, 1, n_t + 1, dtype=torch.float32)

        # Analytic solution of the linear SDE (OU process); the t = 0 entry is dropped
        m_t = torch.exp(-mu * times)[1:]
        v_t = (sigma**2 / (2 * mu) * (1.0 - torch.exp(-2 * mu * times)))[1:]

        tables = {
            "m_t": m_t,
            "v_t": v_t,
            "s_t": torch.sqrt(v_t),
            "init_means": torch.tensor(init_means, dtype=torch.float32),
            "init_stds": torch.tensor(init_stds, dtype=torch.float32),
            "gauss_weights": torch.tensor(gauss_weights, dtype=torch.float32),
        }
        for name, value in tables.items():
            self.register_buffer(name, value, persistent=False)

    def _weighted_log_components(
        self, yt: torch.Tensor, t_index: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Evaluate every mixture component at ``yt``.

        Returns:
            A tuple ``(log_wp, diff, var)`` where the last dimension runs over the
            mixture components: ``log_wp`` is ``log(weight_i * N(yt; mean_i, var_i))``,
            ``diff`` is ``yt - mean_i`` and ``var`` is ``var_i``.
        """
        # (batch,) -> (batch, 1, 1, 1): add channel, space and component dims
        m_t = torch.index_select(self.m_t, dim=0, index=t_index)[:, None, None, None]
        v_t = torch.index_select(self.v_t, dim=0, index=t_index)[:, None, None, None]

        mean = self.init_means * m_t  # (batch, 1, 1, n_components)
        var = self.init_stds**2 * m_t**2 + v_t  # (batch, 1, 1, n_components)
        diff = yt[..., None] - mean  # (batch, channel, space, n_components)

        log_wp = (
            torch.log(self.gauss_weights)
            - 0.5 * diff**2 / var
            - 0.5 * torch.log(2 * math.pi * var)
        )
        return log_wp, diff, var

    def potential(self, yt: torch.Tensor, t_index: torch.Tensor) -> torch.Tensor:
        """Return the negative log-density of the mixture at ``yt``.

        Args:
            yt: States with dims (batch, channel, space).
            t_index: 1-D integer tensor with one time index per batch element.

        Returns:
            Tensor with the same shape as ``yt``.
        """
        assert yt.ndim == 3  # batch, channel, and space
        log_wp, _, _ = self._weighted_log_components(yt, t_index)
        return -torch.logsumexp(log_wp, dim=-1)

    def score(self, yt: torch.Tensor, t_index: torch.Tensor) -> torch.Tensor:
        """Return the score, i.e. the derivative of the log-density with respect to ``yt``.

        Args:
            yt: States with dims (batch, channel, space).
            t_index: 1-D integer tensor with one time index per batch element.

        Returns:
            Tensor with the same shape as ``yt``.
        """
        assert yt.ndim == 3  # batch, channel, and space
        log_wp, diff, var = self._weighted_log_components(yt, t_index)
        # softmax gives weight_i * p_i / sum_j(weight_j * p_j) without forming the
        # (possibly underflowing) densities themselves
        responsibilities = torch.softmax(log_wp, dim=-1)
        return -(responsibilities * diff / var).sum(dim=-1)

    def forward(
        self, yt: torch.Tensor, t_index: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Return the noise estimate ``-s_t * score`` as a float32 tensor.

        Extra keyword arguments are accepted and ignored.
        """
        assert yt.ndim == 3  # batch, channel, and space
        scores = self.score(yt, t_index)
        s_t = torch.index_select(self.s_t, dim=0, index=t_index)[:, None, None]
        # add channel and space dims
        return (-s_t * scores).to(torch.float32)  # estimate noise


# Analytic noise predictor, built from the OU and mixture parameters defined above
score_network = ExactScoreFunc(
    n_t=config.n_timesteps,
    mu=mu,
    sigma=sigma,
    gauss_weights=gauss_weights,
    init_means=init_means,
    init_stds=init_stds,
)

# DDPM that is handed the analytic predictor as its `neural_net`
ddpm = DDPM(neural_net=score_network, config=config)

# Check variance preserving

- 分散保存型であることの確認
$$
{\rm Var}\left[x_t\right] = e^{-2 \mu t} \, {\rm Var}\left[x_0\right] + \left(1 - e^{-2 \mu t}\right)
$$
- 第二項が拡散モデル内部で用いるノイズの分散 (標準偏差の二乗) を表す
- 初期データ $x_0$ の分散が 1 に規格化されていると仮定すれば，全分散は常に 1 となる

In [None]:
# Plot the decay factor, the noise std and the sum of their squares over the time steps
plt.rcParams["font.size"] = 14
plt.plot(ddpm.decays, label=r"${\rm decay} = e^{-\mu t}$")
plt.plot(ddpm.stds, label=r"${\rm std} = \sqrt{1-e^{-2\mu t}}$")
# Named `total_vars` (not `vars`) so that the Python builtin `vars()` is not shadowed
total_vars = ddpm.decays**2 + ddpm.stds**2
plt.plot(total_vars, label="total vars")
plt.axhline(0.0, ls="--")
plt.xlabel("Num. of Time Steps")
plt.legend()
plt.show()

# Forward process

In [None]:
# Draw initial states from the Gaussian mixture, then add channel and space dims
b = 10_000
idxs = Categorical(torch.tensor(gauss_weights)).sample((b,))
y0 = torch.normal(
    mean=torch.tensor(init_means)[idxs], std=torch.tensor(init_stds)[idxs]
)
y0 = y0[:, None, None]

plt.rcParams["font.size"] = 14
fig, axes = plt.subplots(2, 6, sharex=True, sharey=True, figsize=(20, 6))

# t = -1 stands for the clean data (titled "t=0"); every other t is passed to the DDPM
# forward sampler as the time index of all batch elements
panel_steps = torch.arange(-1, config.n_timesteps + 1, 100)
for t, ax in zip(panel_steps, axes.flatten()):
    if t != -1:
        noise = torch.randn_like(y0)
        t_index = torch.full((b,), fill_value=t, dtype=torch.long)
        yt = ddpm._forward_sample_y(y0=y0, t_index=t_index, noise=noise)
    else:
        yt = y0.detach().clone()

    data = yt[:, 0, 0].numpy().flatten()
    ax.hist(data, range=(-10, 10), bins=100, density=True)
    ax.set_xlim(-10, 10)
    ax.set_title(f"t={t+1}\nmean={np.mean(data):.2f},std={np.std(data):.2f}")
    ax.set_ylabel("PDF")
    ax.axvline(0, color="k")

plt.tight_layout()
plt.show()

# Backward process

- 逆過程では，正規分布から始まり，混合正規分布へと発展させる
- 混合正規分布の FPE の解は解析的に書ける
- この結果を利用して，厳密にスコア関数を与えて，逆 SDE を解いている

In [None]:
# Put the wrapped network into evaluation mode before sampling
ddpm.net.eval()
# Run the backward (sampling) process with a batch of 10,000.
# NOTE(review): the result is indexed by integer step and `.keys()` is called on it in the
# cells below, so it is presumably a dict {step index -> samples}; passing
# n_return_steps=config.n_timesteps presumably keeps every step - confirm against DDPM.
lst_img = ddpm.backward_sample_y(n_batches=10_000, n_return_steps=config.n_timesteps)

In [None]:
plt.rcParams["font.size"] = 14
fig, axes = plt.subplots(2, 6, sharex=True, sharey=True, figsize=(20, 6))

# Grid on which the potential is evaluated, with channel and space dims added.
# It is identical for every panel, so it is built once.
yt = torch.linspace(-10, 10, 101)[:, None, None]

panel_steps = [0, 1, 2, 5, 10, 20, 50, 100, 200, 500, 750, 1000]
for t, ax in zip(panel_steps, axes.flatten()):
    # Histogram of the backward samples stored for this step
    data = lst_img[int(t)][:, 0, 0].numpy().flatten()
    ax.hist(data, range=(-10, 10), bins=100, density=True, alpha=0.3)
    ax.set_xlim(-10, 10)
    ax.set_title(f"Timestep={t:04}/{config.n_timesteps}")
    ax.set_ylabel("PDF")
    ax.axvline(0, color="gray")

    # Dashed markers at the means of the initial mixture components
    for x in init_means:
        ax.axvline(x, ls="--", lw=0.5, alpha=0.5, color="gray")

    if t >= config.n_timesteps:
        # Last step: quadratic potential 0.5 * y^2
        potential = 0.5 * yt**2
    else:
        # Same time index for every grid point
        dummy_t = torch.full(size=(yt.shape[0],), fill_value=t).long()
        potential = score_network.potential(yt, dummy_t)

    # Overlay the potential on a secondary y-axis
    ax2 = ax.twinx()
    ax2.plot(yt.squeeze(), potential.squeeze(), ls="-", lw=2.0, color="coral")
    ax2.set_ylim(0, 20)

plt.tight_layout()
plt.show()

- 各パスを調べてみると，単峰の正規分布から二峰の混合正規分布に分かれている
- 逆 SDE を解いているため，時刻は $t=1.0$ から $t=0.0$ に向かっている
- SDE は Euler-Maruyama 法で解いている

In [None]:
# Trajectories of the first (up to) 100 backward samples, one column per stored step
ts = sorted(lst_img.keys())
yt = torch.stack([lst_img[t][:100].squeeze() for t in ts], dim=1)

plt.rcParams["font.size"] = 14
# Create the Figure and Axes explicitly instead of relying on pyplot's implicit state
fig, ax = plt.subplots()

# Step indices rescaled so that the largest one maps to 1; converted to an array
# explicitly because a plain Python list does not support division
times = np.asarray(ts, dtype=float) / max(ts)

interval = 4  # draw every 4th stored step only
for path in yt:
    ax.plot(times[::interval], path[::interval], lw=0.5, alpha=0.5)

ax.set_xlabel(r"Time, $t$")
ax.set_ylabel(r"State, $x$")
plt.show()