In [1]:
import torch
import matplotlib.pyplot as plt
import math
from typing import Tuple

# Beta schedule

In [2]:
def betas_for_alpha_bar(num_diffusion_time_steps: int,
                        max_beta: float = 0.999,
                        alpha_transform_type: str = "cosine"):
    """
    Discretize a continuous alpha_bar(t) curve, defined for t in [0, 1], into a sequence of betas.

    alpha_bar(t) is the cumulative product of (1 - beta) up to time t.  For step i out of N the
    beta is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`.

    Args:
        num_diffusion_time_steps (`int`): the number of betas to produce.
        max_beta (`float`): upper bound applied to every beta; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the alpha_bar curve to discretize.
                     Choose from `cosine` or `exp`.

    Returns:
        betas (`torch.Tensor`): float32 tensor of shape `(num_diffusion_time_steps,)`.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """

    def _cosine_alpha_bar(t):
        # Squared cosine curve with a 0.008 offset.
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp_alpha_bar(t):
        # Exponential decay curve.
        return math.exp(t * -12.0)

    if alpha_transform_type == "cosine":
        alpha_bar_fn = _cosine_alpha_bar
    elif alpha_transform_type == "exp":
        alpha_bar_fn = _exp_alpha_bar
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    total = num_diffusion_time_steps
    capped_betas = [
        min(1 - alpha_bar_fn((step + 1) / total) / alpha_bar_fn(step / total), max_beta)
        for step in range(total)
    ]
    return torch.tensor(capped_betas, dtype=torch.float32)


def make_beta_schedule(schedule: str,
                       num_steps: int,
                       beta_start: float = 1e-4,
                       beta_end: float = 2e-2) -> torch.Tensor:
    """Build the sequence of betas for the requested schedule.

    Args:
        schedule (str):         One of "linear", "scaled_linear", "squaredcos_cap_v2" or "sigmoid".
        num_steps (int):        The number of time steps (length of the returned tensor).
        beta_start (float):     The first beta.  Not used by "squaredcos_cap_v2".
        beta_end (float):       The last beta.  Not used by "squaredcos_cap_v2".

    Returns:
        torch.Tensor:           The betas, one per time step.

    Raises:
        ValueError:             If `schedule` is not one of the supported names.
    """
    if schedule == "linear":
        return torch.linspace(beta_start, beta_end, num_steps, dtype=torch.float32)

    if schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model:
        # interpolate linearly in sqrt(beta) space, then square.
        sqrt_betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_steps, dtype=torch.float32)
        return sqrt_betas ** 2

    if schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        return betas_for_alpha_bar(num_steps)

    if schedule == "sigmoid":
        # GeoDiff sigmoid schedule: squash a symmetric ramp into [beta_start, beta_end].
        ramp = torch.linspace(-6, 6, num_steps, dtype=torch.float32)
        return torch.sigmoid(ramp) * (beta_end - beta_start) + beta_start

    raise ValueError(f"schedule '{schedule}' is unknown.")

In [3]:
# Name of the beta schedule used for the rest of the notebook.  make_beta_schedule accepts
# "linear", "scaled_linear", "squaredcos_cap_v2" or "sigmoid".
BETA_SCHEDULE = "linear"

In [4]:
# Build the beta schedule for 1000 diffusion steps and plot it.
betas = make_beta_schedule(BETA_SCHEDULE, 1000, beta_start=1e-4, beta_end=0.02)

beta_fig, beta_ax = plt.subplots()
beta_ax.plot(betas, label="beta")
beta_ax.legend()
plt.show()

## Alphas

$$\alpha_t = 1 - \beta_t$$
$$\bar{\alpha}_t = \prod_{s=0}^{t}\alpha_s$$
$$\vec{x}_t = \sqrt{\bar{\alpha}_t} \cdot \vec{x}_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \vec{\epsilon}_0$$

In [5]:
alphas = 1.0 - betas
# alphas_cumprod[t] is alpha_bar_t, the running product of the alphas up to step t.
alphas_cumprod = torch.cumprod(alphas, dim=0)
# Coefficient applied to x_0 in the forward process.
sqrt_alphas_cumprod = alphas_cumprod ** 0.5
# Despite its name this is sqrt(1 - alpha_bar_t), the coefficient applied to the noise;
# it is not a cumulative product of the betas.
sqrt_betas_cumprod = (1 - alphas_cumprod) ** 0.5

plt.plot(sqrt_alphas_cumprod, label="sqrt_alphas_cumprod (coeff for x_0)")
plt.plot(sqrt_betas_cumprod, label="sqrt_betas_cumprod (coeff for noise)")
plt.legend()
# Render the figure (the cell previously ended with an empty plt.plot() call).
plt.show()

## Add noise

$$\vec{x}_t = \sqrt{\bar{\alpha}_t} \cdot \vec{x}_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \vec{\epsilon}_0$$
$$\vec{x}_t \sim q(\vec{x}_t|\vec{x}_0) = N(\vec{x}_t; \sqrt{\bar{\alpha}_t} \vec{x}_0, (1 - \bar{\alpha}_t) \vec{I}) $$

In [6]:
def add_noise(x: torch.Tensor,
              noise: torch.Tensor,
              timesteps: torch.Tensor) -> torch.Tensor:
    """
    Add noise to the input tensor following the forward diffusion process:
        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

    The coefficients are read from the notebook-level tables `sqrt_alphas_cumprod` and
    `sqrt_betas_cumprod`, which must be defined before this function is called.

    Args:
        x (torch.Tensor):               The original samples.  Shape: (batch_size, ...).
        noise (torch.Tensor):           The noise tensor to be added.  Shape should be the same as x.
        timesteps (torch.Tensor):       Time steps for each batch.  Shape: (batch_size,).

    Returns:
        torch.Tensor:                   The noisy samples.  Same shape as x.

    Raises:
        AssertionError:                 If the shapes disagree or a time step lies outside the schedule.
    """
    assert x.size() == noise.size(), \
        f"The size of x ({x.size()}) and noise ({noise.size()}) should be the same."

    timesteps = timesteps.long()

    assert timesteps.dim() == 1, \
        f"The timesteps should be a 1D tensor, but got {timesteps.dim()}D."
    assert x.size(0) == timesteps.size(0), \
        f"The batch size of x ({x.size(0)}) and timesteps ({timesteps.size(0)}) should be the same."

    # Validate against the actual schedule length rather than a hard-coded 1000, so a
    # shorter schedule reports an out-of-range step here instead of failing while indexing.
    num_train_steps = sqrt_alphas_cumprod.size(0)
    if timesteps.numel() > 0:  # min()/max() are undefined for an empty batch
        assert 0 <= timesteps.min() <= timesteps.max() < num_train_steps, \
            (f"The timesteps should be in the range of [0, {num_train_steps - 1}], "
             f"but got {timesteps.min()} to {timesteps.max()}.")

    # One coefficient per sample, shaped (batch_size, 1, ..., 1) so it broadcasts over x.
    coeff_shape = (timesteps.size(0),) + (1,) * (x.dim() - 1)
    # Index on the timesteps' device, then move the result next to x.
    signal_coeff = sqrt_alphas_cumprod.to(timesteps.device)[timesteps].to(x.device).view(coeff_shape)
    noise_coeff = sqrt_betas_cumprod.to(timesteps.device)[timesteps].to(x.device).view(coeff_shape)

    return signal_coeff * x + noise_coeff * noise

In [7]:
# Fix the RNG seed so the demo tensors are the same after every Restart & Run All.
torch.manual_seed(0)
x_0 = torch.randn(3, 10, 16)
noise = torch.randn_like(x_0)
# One random time step per sample, drawn from the full length of the beta schedule.
timesteps = torch.randint(0, len(betas), (x_0.size(0),))
x_t = add_noise(x_0, noise, timesteps)
assert x_0.shape == x_t.shape
print("Time steps:", timesteps)

## Posterior mean and variance

$$q(\vec{x}_{t - 1} | \vec{x}_t, \vec{x}_0) = N(\vec{x}_{t - 1}; \tilde{\mu}_t (\vec{x}_t, \vec{x}_0), \tilde{\beta}_t \vec{I})$$

where

$$ \tilde{\mu}_t (\vec{x}_t, \vec{x}_0) = \frac{\sqrt{\bar{\alpha}_{t-1}} \  \beta_t}{1-\bar{\alpha}_t} \vec{x}_0 \  + \ \frac{\sqrt{\alpha_t} \ (1 - \bar{\alpha}_{t-1})}{1-\bar{\alpha}_t} \vec{x}_t$$

and

$$ \tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t - 1}}{1 - \bar{\alpha}_t} \beta_t$$

In [8]:
# Shift the cumulative products one step to the right so that entry t holds alpha_bar_{t-1};
# the first entry (no previous step) is set to 1.
alphas_cumprod_prev = torch.cat((torch.ones(1), alphas_cumprod[:-1]), dim=0)
assert alphas_cumprod_prev.size() == alphas_cumprod.size()
print(alphas_cumprod[:8])
print(alphas_cumprod_prev[:8])

In [9]:
# Coefficient of x_0 in the posterior mean: sqrt(alpha_bar_{t-1}) * beta_t / (1 - alpha_bar_t).
posterior_mean_coeff1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# Coefficient of x_t in the posterior mean: (1 - alpha_bar_{t-1}) * sqrt(alpha_t) / (1 - alpha_bar_t).
posterior_mean_coeff2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)

In [10]:
# Posterior variance beta_tilde_t = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t.
# The first entry is 0 because alphas_cumprod_prev[0] is 1.
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

In [11]:
# Left panel: posterior variance.  Right panel: the two posterior-mean coefficients.
posterior_fig, (variance_ax, mean_coeff_ax) = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))

variance_ax.plot(posterior_variance, label="posterior_variance")
variance_ax.legend()

mean_coeff_ax.plot(posterior_mean_coeff1, label="posterior_mean_coeff1")
mean_coeff_ax.plot(posterior_mean_coeff2, label="posterior_mean_coeff2")
mean_coeff_ax.legend()

plt.show()

In [12]:
# posterior_variance[0] is 0, whose log is -inf, so the entry for step 0 is replaced by the
# value at step 1 before taking the log.
posterior_log_variance_clipped = torch.cat((posterior_variance[1:2], posterior_variance[1:])).log()

log_variance_fig, log_variance_ax = plt.subplots()
log_variance_ax.plot(posterior_log_variance_clipped, label="posterior_log_variance_clipped")
log_variance_ax.legend()
plt.show()

In [13]:
def extract_into_tensor(arr: torch.Tensor,
                        timesteps: torch.Tensor,
                        broadcast_shape: torch.Size) -> torch.Tensor:
    """Look up one value per batch element in a 1-D tensor and expand it to a target shape.

    Args:
        arr:                            the 1-D tensor holding one value per time step.
        timesteps:                      a tensor of indices into `arr`, one per batch element.
        broadcast_shape:                a larger shape of K dimensions with the batch dimension equal to
                                        the length of timesteps.

    Returns:
        a float32 tensor of shape `broadcast_shape`, on the device of `timesteps`, in which every
        element of batch entry i equals `arr[timesteps[i]]`.
    """
    device = timesteps.device
    gathered = arr.to(device=device)[timesteps].float()

    # Append trailing singleton dimensions until the rank matches the target shape.
    missing_dims = len(broadcast_shape) - gathered.dim()
    if missing_dims > 0:
        gathered = gathered.reshape(*gathered.shape, *([1] * missing_dims))

    # Adding zeros of the target shape materializes the broadcast.
    return gathered + torch.zeros(broadcast_shape, device=device)

def q_posterior_mean_variance(samples: torch.Tensor,
                              noisy_samples: torch.Tensor,
                              timesteps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Evaluate the diffusion posterior q(x_{t-1} | x_t, x_0) for a batch.

    See equations (6) and (7) of the DDPM paper https://arxiv.org/abs/2006.11239.  The per-step
    coefficients come from the notebook-level tables `posterior_mean_coeff1`,
    `posterior_mean_coeff2` and `posterior_log_variance_clipped`.

    Args:
        samples (torch.Tensor):         The original samples (x_0).  Shape: (batch_size, ...).
        noisy_samples (torch.Tensor):   The noisy samples (x_t) returned by add_noise.  Shape: (batch_size, ...).
        timesteps (torch.Tensor):       Time steps for each batch.  Shape: (batch_size,).

    Returns:    Tuple of two tensors, both with the shape of `samples`:
        torch.Tensor:                   The mean of the posterior.
        torch.Tensor:                   The clipped log variance of the posterior.
    """
    target_shape = samples.shape

    x0_coeff = extract_into_tensor(posterior_mean_coeff1, timesteps, target_shape)
    xt_coeff = extract_into_tensor(posterior_mean_coeff2, timesteps, target_shape)
    posterior_mean = x0_coeff * samples + xt_coeff * noisy_samples

    posterior_log_variance = extract_into_tensor(posterior_log_variance_clipped, timesteps, target_shape)

    return posterior_mean, posterior_log_variance

In [14]:
# Posterior mean and clipped log variance of q(x_{t-1} | x_t, x_0) for the demo batch.
x_tm1_mean, x_tm1_logvar = q_posterior_mean_variance(x_0, x_t, timesteps)

In [15]:
# Average element-wise difference between the posterior mean and x_t.
print((x_tm1_mean - x_t).mean())

In [16]:
# Display the first entry of the shifted cumulative product (set to 1.0 when it was built).
alphas_cumprod_prev[0]

In the DDIM paper equation (16), the variance is calculated as
$$\sigma^2_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \left(1 - \frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}} \right). $$
Since the latter term $$1 - \frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}} = \beta_t,$$
we have
$$\sigma^2_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \beta_t = \tilde{\beta}_t.$$

In [17]:
# Variance written in the DDIM eq. (16) form; it should match posterior_variance.
var2 = (1 - alphas_cumprod_prev) / (1 - alphas_cumprod) * (1 - alphas_cumprod / alphas_cumprod_prev)
# Compare absolute differences: a signed max would also pass when var2 is far below
# posterior_variance.
assert (var2 - posterior_variance).abs().max() < 1e-6

## DDIM Denoise

In [18]:
# Inference timesteps:
def update_time_steps(num_training_steps: int,
                      num_inference_steps: int,
                      timestep_spacing: str = "leading",
                      steps_offset: int = 0) -> torch.Tensor:
    """Update the discrete time steps used for the diffusion chain (to be run before inference).

    Args:
        num_training_steps (int):       The number of training steps.
        num_inference_steps (int):      The number of inference steps.  Must be in
                                        [1, num_training_steps].
        timestep_spacing (str):         The spacing of the time steps. Could be one of "linspace", "leading", or
                                        "trailing". Defaults to "leading".
        steps_offset (int):             The offset added to the inference steps.  Only applied for
                                        "leading" spacing.  Defaults to 0.
    Returns:
        time_steps (torch.Tensor):      The time steps in descending order, as an int64 tensor of shape
                                        (num_inference_steps,).

    Raises:
        AssertionError:                 If num_inference_steps is not in [1, num_training_steps].
        ValueError:                     If timestep_spacing is not supported.
    """
    # Reject zero or negative step counts up front; they would otherwise surface as a ZeroDivisionError.
    assert 0 < num_inference_steps <= num_training_steps, \
        (f"The number of inference steps ({num_inference_steps}) should be between 1 and "
         f"the number of training steps ({num_training_steps}).")

    if timestep_spacing == "linspace":
        time_steps = torch.linspace(0, num_training_steps - 1, num_inference_steps).flip(0).round().long()
    elif timestep_spacing == "leading":
        step_ratio = num_training_steps // num_inference_steps
        # Integer index times integer ratio is already an int64 tensor, so no rounding or cast is needed.
        time_steps = (torch.arange(0, num_inference_steps) * step_ratio).flip(0)
        time_steps += steps_offset
    elif timestep_spacing == "trailing":
        step_ratio = num_training_steps / num_inference_steps
        # Build exactly num_inference_steps values from an integer index.  A float-step arange
        # can produce one element too many when the ratio is not exactly representable
        # (e.g. 1000 / 3).
        step_index = torch.arange(num_inference_steps, dtype=torch.float64)
        time_steps = (num_training_steps - step_index * step_ratio).round().long()
        time_steps -= 1
    else:
        raise ValueError(
            f"{timestep_spacing} is not supported. Please make sure to "
            f"choose one of 'leading' or 'trailing' or 'linspace'."
        )

    assert time_steps.shape == (num_inference_steps,)

    return time_steps

# Ten evenly spaced inference steps out of 1000 training steps, using "leading" spacing.
inference_timesteps = update_time_steps(num_training_steps=1000,
                                        num_inference_steps=10,
                                        timestep_spacing="leading")
print(inference_timesteps)

In [19]:
# Hand-picked current / previous time steps for the three samples in the batch.
# NOTE(review): x_t was produced by add_noise with the random `timesteps`, not with these
# values - confirm that the mismatch is intended for this demo.
current_timestep = torch.tensor([300, 400, 200], dtype=torch.long)
prev_timestep = torch.tensor([200, 300, 100], dtype=torch.long)

# Schedule values at the current step ...
sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[current_timestep]
sqrt_betas_cumprod_t = sqrt_betas_cumprod[current_timestep]
# ... and alpha_bar at the previous inference step.
alphas_cumprod_prev_t = alphas_cumprod[prev_timestep]

## Calculate variance

In [20]:
# Lower and upper bounds for the log variance at every time step: the clipped posterior
# log variance and log(beta_t).
min_log_variance = posterior_log_variance_clipped
max_log_variance = torch.log(betas)
assert torch.all(min_log_variance <= max_log_variance)

# Random stand-in for a predicted interpolation value, one per time step.  As written the
# values lie in [-1, -0.5), so the interpolation weight below stays in [0, 0.25).
# NOTE(review): if this is meant to cover [-1, 1] the divisor would be 500.0 - confirm.
log_var = torch.randint(0, 1000, (1000,)) / 2000.0 - 1.0
# Interpolate between the two bounds with weight 0.5 * (value + 1).
log_var = 0.5 * (log_var + 1) * (max_log_variance - min_log_variance) + min_log_variance
variance = torch.exp(log_var)

plt.plot(min_log_variance, label='min_log_variance')
plt.plot(max_log_variance, label='max_log_variance')
plt.plot(log_var, label='log_var')
plt.legend()
plt.show()

## Predict $x_0$ by DDIM Paper Equation (12)

$$\hat{x}_0 = \frac{\vec{x}_t - \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon_\theta^{(t)}(\vec{x}_t) }{\sqrt{\bar{\alpha}_t}} $$

In [21]:
# Use the true noise as a stand-in for the model's epsilon prediction.
model_output = noise
# Reshape the per-sample coefficients to (batch, 1, ..., 1) so they broadcast over x_t.
# The batch dimension is inferred (-1) instead of being hard-coded to 3.
sqrt_betas_cumprod_t = sqrt_betas_cumprod_t.view(-1, *([1] * (x_t.dim() - 1)))
sqrt_alphas_cumprod_t = sqrt_alphas_cumprod_t.view(-1, *([1] * (x_t.dim() - 1)))
# x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
pred_original_sample = (x_t - sqrt_betas_cumprod_t * model_output) / sqrt_alphas_cumprod_t
assert pred_original_sample.shape == x_t.shape

## Predict direction pointing to $x_t$

In [27]:
# eta scales the standard deviation of the noise used in the DDIM update.
eta = 0.1
sigma = eta * (variance[current_timestep] ** 0.5)
# Reshape the per-sample values to (batch, 1, ..., 1) so they broadcast over x_t.
# The batch dimension is inferred (-1) instead of being hard-coded to 3.
alphas_cumprod_prev_t = alphas_cumprod_prev_t.view(-1, *([1] * (x_t.dim() - 1)))
sigma = sigma.view(-1, *([1] * (x_t.dim() - 1)))
# Direction term: sqrt(1 - alpha_bar_prev - sigma^2) * eps
pred_sample_direction = ((1 - alphas_cumprod_prev_t - sigma**2) ** 0.5) * model_output

## Calculate the previous sample

In [31]:
# DDIM update: x_prev = sqrt(alpha_bar_prev) * x0_hat + direction + sigma * fresh_noise
prev_sample = (alphas_cumprod_prev_t ** 0.5) * pred_original_sample + pred_sample_direction
# The direction term already subtracts sigma**2, so for eta > 0 the matching random noise
# has to be added back; it was missing here.
if eta > 0:
    prev_sample = prev_sample + sigma * torch.randn_like(model_output)