Simplify SDE schedulers code
JonathanCrabbe committed Dec 21, 2023
1 parent 794dcc4 commit e2eda22
Showing 3 changed files with 127 additions and 158 deletions.
2 changes: 2 additions & 0 deletions src/fdiff/models/score_models.py
@@ -105,6 +105,7 @@ def training_step(
on_epoch=True,
on_step=True,
)
assert isinstance(loss, torch.Tensor)
return loss

def validation_step(
@@ -133,6 +134,7 @@ def set_loss_fn(self) -> tuple[Callable, Callable]:
# Depending on the scheduler, get the right loss function

if isinstance(self.noise_scheduler, DDPMScheduler):
assert hasattr(self.noise_scheduler, "config")
scheduler_config = self.noise_scheduler.config
self.max_time = scheduler_config.num_train_timesteps
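The two added `assert` statements are runtime type guards that double as type narrowing for static checkers such as mypy. A minimal sketch of the pattern; the helper below is hypothetical, not repository code:

```python
import torch

def compute_loss(values: list[torch.Tensor]) -> torch.Tensor:
    # sum() falls back to the int 0 for an empty sequence, so the static
    # type here is `torch.Tensor | int`.
    loss = sum(v.mean() for v in values)
    # The assert narrows the type back to torch.Tensor for the checker
    # and fails fast at runtime if the assumption is ever violated.
    assert isinstance(loss, torch.Tensor)
    return loss
```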

187 changes: 70 additions & 117 deletions src/fdiff/schedulers/sde.py
@@ -2,29 +2,30 @@
import abc
import math
from collections import namedtuple
from typing import Optional

import numpy as np
import torch
from torch import device

SamplingOutput = namedtuple("SamplingOutput", ["prev_sample"])


class SDE(abc.ABC):
"""SDE abstract class. Functions are designed for a mini-batch of inputs."""

def __init__(self, fourier_noise_scaling: bool = False):
def __init__(self, fourier_noise_scaling: bool = False, eps: float = 1e-5):
"""Construct an SDE.
Args:
fourier_noise_scaling: whether to rescale the noise variance per Fourier frequency.
eps: smallest diffusion time, used as the final timestep when sampling.
"""
super().__init__()
self.noise_scaling = fourier_noise_scaling
self.eps = eps
self.G: Optional[torch.Tensor] = None

@property
@abc.abstractmethod
def T(self) -> float:
"""End time of the SDE."""
return 1.0

@abc.abstractmethod
def marginal_prob(
@@ -33,30 +34,57 @@ def marginal_prob(
"""Parameters to determine the marginal distribution of the SDE, $p_t(x)$."""

@abc.abstractmethod
def prior_sampling(self, shape: tuple[int, ...]) -> torch.Tensor:
"""Generate one sample from the prior distribution, $p_T(x)$."""
def step(
self, model_output: torch.Tensor, timestep: float, sample: torch.Tensor
) -> SamplingOutput:
...

def initialize(self, max_len: int, device: str | device) -> None:
"""Finish the initialization of the scheduler by setting G (scaling diagonal) and the device.
def set_noise_scaling(self, max_len: int) -> None:
"""Finish the initialization of the scheduler by setting G (scaling diagonal)
Args:
max_len (_type_): _description_
device (_type_): _description_
max_len (int): number of time steps of the time series
"""
if not self.noise_scaling:
# We will get the identity by putting G in the diagonal
G = torch.ones(max_len, device=device)
else:
G = 1 / (math.sqrt(2 * max_len)) * torch.ones(max_len, device=device)

G = torch.ones(max_len)
if self.noise_scaling:
G = 1 / (math.sqrt(2)) * G
# Double the variance for the first component
G[0] *= math.sqrt(2)
# Double the variance for the middle component if max_len is even
if max_len % 2 == 0:
G[max_len // 2] *= math.sqrt(2)

self.G = G # Tensor of size (max_len)
self.G_matrix = torch.diag(G) # Tensor of size (max_len, max_len)
assert G.shape[0] == max_len

# Set the device
self.device = device
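With Fourier noise scaling enabled, `set_noise_scaling` halves the variance of every component except the first (DC) bin and, for even lengths, the middle (Nyquist) bin, the two bins of a real signal's DFT that are real-valued. A standalone sketch mirroring the logic above; the function name is ours:

```python
import math

import torch

def build_scaling_diagonal(max_len: int, fourier_noise_scaling: bool = True) -> torch.Tensor:
    # Mirrors SDE.set_noise_scaling.
    G = torch.ones(max_len)
    if fourier_noise_scaling:
        # Halve the variance (divide the std by sqrt(2)) everywhere...
        G = G / math.sqrt(2)
        # ...except for the DC component...
        G[0] *= math.sqrt(2)
        # ...and the Nyquist component when max_len is even.
        if max_len % 2 == 0:
            G[max_len // 2] *= math.sqrt(2)
    return G

G = build_scaling_diagonal(6)
print(G)  # tensor([1.0000, 0.7071, 0.7071, 1.0000, 0.7071, 0.7071])
```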
def set_timesteps(self, num_diffusion_steps: int) -> None:
self.timesteps = torch.linspace(1.0, self.eps, num_diffusion_steps)
self.step_size = self.timesteps[0] - self.timesteps[1]

def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
x0 = original_samples
mean, _ = self.marginal_prob(x0, timesteps)

# Note that the std is not used here because the noise has been scaled prior to calling the function
sample = mean + noise
return sample
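Note the contract of `add_noise`: it only adds the mean, so the caller must scale the white noise by the marginal std beforehand. A usage sketch with the `VEScheduler` defined further down; shapes and constructor defaults are illustrative:

```python
import torch

from fdiff.schedulers.sde import VEScheduler  # this module

scheduler = VEScheduler(fourier_noise_scaling=True)

x0 = torch.randn(8, 100, 3)  # (batch_size, max_len, n_channels)
t = torch.rand(8)            # one diffusion time per sample

mean, std = scheduler.marginal_prob(x0, t)  # std: (batch_size, max_len)
z = torch.randn_like(x0)
noise = std.unsqueeze(-1) * z               # scale *before* calling add_noise
x_t = scheduler.add_noise(x0, noise, t)     # = mean + noise
```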

def prior_sampling(self, shape: tuple[int, ...]) -> torch.Tensor:
# Reshape the G matrix to be (1, max_len, max_len)
scaling_matrix = self.G_matrix.view(
-1, self.G_matrix.shape[0], self.G_matrix.shape[1]
)

z = torch.randn(*shape)
# Return G@z where z \sim N(0,I)
return torch.matmul(scaling_matrix, z)
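Since the prior sample is G @ z with z ~ N(0, I), its covariance is G Gᵀ = diag(G²). A quick Monte-Carlo check, reusing the `build_scaling_diagonal` sketch above:

```python
import torch

torch.manual_seed(0)
max_len = 8
G = build_scaling_diagonal(max_len)
G_matrix = torch.diag(G)

z = torch.randn(100_000, max_len, 1)  # z ~ N(0, I)
x = torch.matmul(G_matrix, z)         # prior samples G @ z

empirical_var = x.var(dim=0).squeeze()
print(torch.allclose(empirical_var, G**2, atol=0.02))  # True, up to MC error
```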


class VEScheduler(SDE):
@@ -73,72 +101,33 @@ def __init__(
sigma_max: largest sigma.
N: number of discretization steps
"""
super().__init__(fourier_noise_scaling)
super().__init__(fourier_noise_scaling=fourier_noise_scaling, eps=eps)
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.eps = eps

self.device = None
self.G = None

@property
def T(self) -> float:
return 1.0

def marginal_prob(
self, x: torch.Tensor, t: torch.Tensor
) -> tuple[
torch.Tensor, torch.Tensor
]: # perturbation kernel P(X(t)|X(0)) parameters
if self.G is None:
self.initialize(x.shape[1], x.device)
self.set_noise_scaling(x.shape[1])
assert self.G is not None

sigma_min = torch.tensor(self.sigma_min).type_as(t)
sigma_max = torch.tensor(self.sigma_max).type_as(t)
std = (sigma_min * (sigma_max / sigma_min) ** t).view(-1, 1) * self.G
std = (sigma_min * (sigma_max / sigma_min) ** t).view(-1, 1) * self.G.to(
x.device
)
mean = x
return mean, std
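The VE perturbation kernel keeps the mean at x and interpolates the std geometrically between sigma_min at t = 0 and sigma_max at t = 1, before the per-component G scaling. A quick numeric check with illustrative sigma values:

```python
import torch

sigma_min, sigma_max = 0.01, 50.0  # illustrative values
t = torch.tensor([0.0, 0.5, 1.0])
std = sigma_min * (sigma_max / sigma_min) ** t
print(std)  # tensor([ 0.0100,  0.7071, 50.0000])
```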

def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
x0 = original_samples
mean, _ = self.marginal_prob(x0, timesteps)

# Note that the std is not used here because the noise has been scaled prior to calling the function
sample = mean + noise
return sample

def get_sigma(
self, timestep: torch.Tensor | float | np.ndarray
) -> torch.Tensor | float | np.ndarray:
return torch.tensor(
self.sigma_min * (self.sigma_max / self.sigma_min) ** timestep,
device=self.device,
)

def set_timesteps(self, num_diffusion_steps: int) -> None:
self.timesteps = torch.linspace(
1.0, self.eps, num_diffusion_steps, device=self.device
)
self.step_size = self.timesteps[0] - self.timesteps[1]

def prior_sampling(self, shape: tuple[int, ...]) -> torch.Tensor:
# Reshape the G matrix to be (1, max_len, max_len)
scaling_matrix = self.G_matrix.view(
-1, self.G_matrix.shape[0], self.G_matrix.shape[1]
)
scaling_matrix = self.sigma_max * scaling_matrix

z = torch.randn(*shape, device=self.device)
# Return G@z where z \sim N(0,I)
return torch.matmul(scaling_matrix, z)
# In the case of VESDE, the prior is scaled by the maximum noise std
return self.sigma_max * super().prior_sampling(shape)

def step(
self, model_output: torch.Tensor, timestep: torch.Tensor, sample: torch.Tensor
self, model_output: torch.Tensor, timestep: float, sample: torch.Tensor
) -> SamplingOutput:
"""Single denoising step, used for sampling.
@@ -157,9 +146,9 @@ def step(
* (self.sigma_max / self.sigma_min) ** (timestep)
)

diffusion = torch.diag_embed(sqrt_derivative * self.G)
diffusion = torch.diag_embed(sqrt_derivative * self.G).to(device=sample.device)

# Compute drift for the reverse
# Compute drift for the reverse: f(x,t) - G(x,t)G(x,t)^{T}*score
drift = -(
torch.matmul(diffusion * diffusion, model_output)
) # Notice that the drift of the forward is 0
@@ -169,7 +158,7 @@
assert self.step_size > 0
x = (
sample
- drift * self.step_size
- drift * self.step_size # - sign because of reverse time
+ torch.sqrt(self.step_size) * torch.matmul(diffusion, z)
)
output = SamplingOutput(prev_sample=x)
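`step` performs one Euler-Maruyama update of the reverse-time SDE: x ← x − drift·Δt + √Δt·(diffusion @ z). A minimal end-to-end sampling sketch; the score network is a placeholder and the constructor defaults are assumed:

```python
import torch

from fdiff.schedulers.sde import VEScheduler  # this module

scheduler = VEScheduler(fourier_noise_scaling=True)
scheduler.set_noise_scaling(max_len=100)
scheduler.set_timesteps(num_diffusion_steps=1000)

def score_fn(x: torch.Tensor, t: float) -> torch.Tensor:
    # Placeholder for a trained score network s_theta(x, t).
    return torch.zeros_like(x)

x = scheduler.prior_sampling((8, 100, 3))  # x_T = sigma_max * G @ z
for t in scheduler.timesteps:
    x = scheduler.step(score_fn(x, float(t)), float(t), x).prev_sample
```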
@@ -191,25 +180,17 @@ def __init__(
N: number of discretization steps
G: tensor of size max_len
"""
super().__init__(fourier_noise_scaling)
super().__init__(fourier_noise_scaling=fourier_noise_scaling, eps=eps)
self.beta_0 = beta_min
self.beta_1 = beta_max
self.eps = eps

# To be initialized later
self.device = None
self.G = None

@property
def T(self) -> float:
return 1.0

def marginal_prob(
self, x: torch.Tensor, t: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
# first check if G has been init.
if self.G is None:
self.initialize(x.shape[1], x.device)
self.set_noise_scaling(x.shape[1])
assert self.G is not None

# Compute -1/2*\int_0^t \beta(s) ds
log_mean_coeff = (
@@ -220,48 +201,19 @@ def marginal_prob(
torch.exp(log_mean_coeff[(...,) + (None,) * len(x.shape[1:])]) * x
) # mean: (batch_size, max_len, n_channels)

std = (
torch.sqrt((1.0 - torch.exp(2.0 * log_mean_coeff.view(-1, 1)))) * self.G
std = torch.sqrt(
(1.0 - torch.exp(2.0 * log_mean_coeff.view(-1, 1)))
) * self.G.to(
x.device
) # std: (batch_size, max_len)

return mean, std
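With the linear schedule beta(s) = beta_0 + s·(beta_1 − beta_0), the integral in the comment gives log_mean_coeff = −0.25·t²·(beta_1 − beta_0) − 0.5·t·beta_0, so at t = 1 the mean coefficient is nearly zero while the std approaches 1: this is what makes the SDE variance-preserving. A numeric check with illustrative beta values:

```python
import math

beta_0, beta_1 = 0.1, 20.0  # illustrative beta_min / beta_max
t = 1.0

# -1/2 * int_0^t beta(s) ds, with beta(s) = beta_0 + s * (beta_1 - beta_0)
log_mean_coeff = -0.25 * t**2 * (beta_1 - beta_0) - 0.5 * t * beta_0

mean_coeff = math.exp(log_mean_coeff)                # ~0.0066: signal almost gone
std = math.sqrt(1.0 - math.exp(2 * log_mean_coeff))  # ~1.0000: variance preserved
print(mean_coeff, std)
```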

def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
x0 = original_samples
mean, _ = self.marginal_prob(x0, timesteps)

# Note that the std is not used here because the noise has been scaled prior to calling the function
sample = mean + noise
return sample

def get_beta(self, timestep: torch.Tensor | float | np.ndarray) -> torch.Tensor:
return torch.tensor(
self.beta_0 + timestep * (self.beta_1 - self.beta_0), device=self.device
)

def set_timesteps(self, num_diffusion_steps: int) -> None:
self.timesteps = torch.linspace(
1.0, self.eps, num_diffusion_steps, device=self.device
)
self.step_size = self.timesteps[0] - self.timesteps[1]

def prior_sampling(self, shape: tuple | list | torch.Size) -> torch.Tensor:
# Reshape the G matrix to be (1, max_len, max_len)
scaling_matrix = self.G_matrix.view(
-1, self.G_matrix.shape[0], self.G_matrix.shape[1]
)
z = torch.randn(*shape, device=self.device)

# Return G@z where z \sim N(0,I)
return torch.matmul(scaling_matrix, z)
def get_beta(self, timestep: float) -> float:
return self.beta_0 + timestep * (self.beta_1 - self.beta_0)

def step(
self, model_output: torch.Tensor, timestep: torch.Tensor, sample: torch.Tensor
self, model_output: torch.Tensor, timestep: float, sample: torch.Tensor
) -> SamplingOutput:
"""Single denoising step, used for sampling.
@@ -274,10 +226,11 @@
SamplingOutput: _description_
"""
beta = self.get_beta(timestep)
diffusion = torch.diag_embed(torch.sqrt(beta).view(-1, 1) * self.G)
assert self.G is not None
diffusion = torch.diag_embed(math.sqrt(beta) * self.G).to(device=sample.device)

# Compute drift
drift = -0.5 * beta.view(-1, 1, 1) * sample - (
drift = -0.5 * beta * sample - (
torch.matmul(diffusion * diffusion, model_output)
)

