From e2eda22676a623782bfc4d88b9fda9e287f0ccc7 Mon Sep 17 00:00:00 2001 From: JonathanCrabbe Date: Thu, 21 Dec 2023 12:25:03 +0000 Subject: [PATCH] Simplify SDE schedulers code --- src/fdiff/models/score_models.py | 2 + src/fdiff/schedulers/sde.py | 187 ++++++++++++------------------- tests/test_schedulers.py | 96 +++++++++------- 3 files changed, 127 insertions(+), 158 deletions(-) diff --git a/src/fdiff/models/score_models.py b/src/fdiff/models/score_models.py index a835a43..b190c89 100644 --- a/src/fdiff/models/score_models.py +++ b/src/fdiff/models/score_models.py @@ -105,6 +105,7 @@ def training_step( on_epoch=True, on_step=True, ) + assert isinstance(loss, torch.Tensor) return loss def validation_step( @@ -133,6 +134,7 @@ def set_loss_fn(self) -> tuple[Callable, Callable]: # Depending on the scheduler, get the right loss function if isinstance(self.noise_scheduler, DDPMScheduler): + assert hasattr(self.noise_scheduler, "config") scheduler_config = self.noise_scheduler.config self.max_time = scheduler_config.num_train_timesteps diff --git a/src/fdiff/schedulers/sde.py b/src/fdiff/schedulers/sde.py index 3d14bfc..3a1bdf1 100644 --- a/src/fdiff/schedulers/sde.py +++ b/src/fdiff/schedulers/sde.py @@ -2,10 +2,9 @@ import abc import math from collections import namedtuple +from typing import Optional -import numpy as np import torch -from torch import device SamplingOutput = namedtuple("SamplingOutput", ["prev_sample"]) @@ -13,18 +12,20 @@ class SDE(abc.ABC): """SDE abstract class. Functions are designed for a mini-batch of inputs.""" - def __init__(self, fourier_noise_scaling: bool = False): + def __init__(self, fourier_noise_scaling: bool = False, eps: float = 1e-5): """Construct an SDE. Args: N: number of discretization time steps. """ super().__init__() self.noise_scaling = fourier_noise_scaling + self.eps = eps + self.G: Optional[torch.Tensor] = None @property - @abc.abstractmethod def T(self) -> float: """End time of the SDE.""" + return 1.0 @abc.abstractmethod def marginal_prob( @@ -33,30 +34,57 @@ def marginal_prob( """Parameters to determine the marginal distribution of the SDE, $p_t(x)$.""" @abc.abstractmethod - def prior_sampling(self, shape: tuple[int, ...]) -> torch.Tensor: - """Generate one sample from the prior distribution, $p_T(x)$.""" + def step( + self, model_output: torch.Tensor, timestep: float, sample: torch.Tensor + ) -> SamplingOutput: + ... - def initialize(self, max_len: int, device: str | device) -> None: - """Finish the initialization of the scheduler by setting G (scaling diagonal) and the device. + def set_noise_scaling(self, max_len: int) -> None: + """Finish the initialization of the scheduler by setting G (scaling diagonal) Args: - max_len (_type_): _description_ - device (_type_): _description_ + max_len (int): number of time steps of the time series """ - if not self.noise_scaling: - # We will get the identity by putting G in the diagonal - G = torch.ones(max_len, device=device) - else: - G = 1 / (math.sqrt(2 * max_len)) * torch.ones(max_len, device=device) + + G = torch.ones(max_len) + if self.noise_scaling: + G = 1 / (math.sqrt(2)) * G # Double the variance for the first component G[0] *= math.sqrt(2) + # Double the variance for the middle component if max_len is even + if max_len % 2 == 0: + G[max_len // 2] *= math.sqrt(2) self.G = G # Tensor of size (max_len) self.G_matrix = torch.diag(G) # Tensor of size (max_len, max_len) assert G.shape[0] == max_len - # Set the device - self.device = device + def set_timesteps(self, num_diffusion_steps: int) -> None: + self.timesteps = torch.linspace(1.0, self.eps, num_diffusion_steps) + self.step_size = self.timesteps[0] - self.timesteps[1] + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + x0 = original_samples + mean, _ = self.marginal_prob(x0, timesteps) + + # Note that the std is not used here because the noise has been scaled prior to calling the function + sample = mean + noise + return sample + + def prior_sampling(self, shape: tuple[int, ...]) -> torch.Tensor: + # Reshape the G matrix to be (1, max_len, max_len) + scaling_matrix = self.G_matrix.view( + -1, self.G_matrix.shape[0], self.G_matrix.shape[1] + ) + + z = torch.randn(*shape) + # Return G@z where z \sim N(0,I) + return torch.matmul(scaling_matrix, z) class VEScheduler(SDE): @@ -73,17 +101,9 @@ def __init__( sigma_max: largest sigma. N: number of discretization steps """ - super().__init__(fourier_noise_scaling) + super().__init__(fourier_noise_scaling=fourier_noise_scaling, eps=eps) self.sigma_min = sigma_min self.sigma_max = sigma_max - self.eps = eps - - self.device = None - self.G = None - - @property - def T(self) -> float: - return 1.0 def marginal_prob( self, x: torch.Tensor, t: torch.Tensor @@ -91,54 +111,23 @@ def marginal_prob( torch.Tensor, torch.Tensor ]: # perturbation kernel P(X(t)|X(0)) parameters if self.G is None: - self.initialize(x.shape[1], x.device) + self.set_noise_scaling(x.shape[1]) + assert self.G is not None sigma_min = torch.tensor(self.sigma_min).type_as(t) sigma_max = torch.tensor(self.sigma_max).type_as(t) - std = (sigma_min * (sigma_max / sigma_min) ** t).view(-1, 1) * self.G + std = (sigma_min * (sigma_max / sigma_min) ** t).view(-1, 1) * self.G.to( + x.device + ) mean = x return mean, std - def add_noise( - self, - original_samples: torch.Tensor, - noise: torch.Tensor, - timesteps: torch.Tensor, - ) -> torch.Tensor: - x0 = original_samples - mean, _ = self.marginal_prob(x0, timesteps) - - # Note that the std is not used here because the noise has been scaled prior to calling the function - sample = mean + noise - return sample - - def get_sigma( - self, timestep: torch.Tensor | float | np.ndarray - ) -> torch.Tensor | float | np.ndarray: - return torch.tensor( - self.sigma_min * (self.sigma_max / self.sigma_min) ** timestep, - device=self.device, - ) - - def set_timesteps(self, num_diffusion_steps: int) -> None: - self.timesteps = torch.linspace( - 1.0, self.eps, num_diffusion_steps, device=self.device - ) - self.step_size = self.timesteps[0] - self.timesteps[1] - def prior_sampling(self, shape: tuple[int, ...]) -> torch.Tensor: - # Reshape the G matrix to be (1, max_len, max_len) - scaling_matrix = self.G_matrix.view( - -1, self.G_matrix.shape[0], self.G_matrix.shape[1] - ) - scaling_matrix = self.sigma_max * scaling_matrix - - z = torch.randn(*shape, device=self.device) - # Return G@z where z \sim N(0,I) - return torch.matmul(scaling_matrix, z) + # In the case of VESDE, the prior is scaled by the maximum noise std + return self.sigma_max * super().prior_sampling(shape) def step( - self, model_output: torch.Tensor, timestep: torch.Tensor, sample: torch.Tensor + self, model_output: torch.Tensor, timestep: float, sample: torch.Tensor ) -> SamplingOutput: """Single denoising step, used for sampling. @@ -157,9 +146,9 @@ def step( * (self.sigma_max / self.sigma_min) ** (timestep) ) - diffusion = torch.diag_embed(sqrt_derivative * self.G) + diffusion = torch.diag_embed(sqrt_derivative * self.G).to(device=sample.device) - # Compute drift for the reverse + # Compute drift for the reverse: f(x,t) - G(x,t)G(x,t)^{T}*score drift = -( torch.matmul(diffusion * diffusion, model_output) ) # Notice that the drift of the forward is 0 @@ -169,7 +158,7 @@ def step( assert self.step_size > 0 x = ( sample - - drift * self.step_size + - drift * self.step_size # - sign because of reverse time + torch.sqrt(self.step_size) * torch.matmul(diffusion, z) ) output = SamplingOutput(prev_sample=x) @@ -191,25 +180,17 @@ def __init__( N: number of discretization steps G: tensor of size max_len """ - super().__init__(fourier_noise_scaling) + super().__init__(fourier_noise_scaling=fourier_noise_scaling, eps=eps) self.beta_0 = beta_min self.beta_1 = beta_max - self.eps = eps - - # To be initialized later - self.device = None - self.G = None - - @property - def T(self) -> float: - return 1.0 def marginal_prob( self, x: torch.Tensor, t: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: # first check if G has been init. if self.G is None: - self.initialize(x.shape[1], x.device) + self.set_noise_scaling(x.shape[1]) + assert self.G is not None # Compute -1/2*\int_0^t \beta(s) ds log_mean_coeff = ( @@ -220,48 +201,19 @@ def marginal_prob( torch.exp(log_mean_coeff[(...,) + (None,) * len(x.shape[1:])]) * x ) # mean: (batch_size, max_len, n_channels) - std = ( - torch.sqrt((1.0 - torch.exp(2.0 * log_mean_coeff.view(-1, 1)))) * self.G + std = torch.sqrt( + (1.0 - torch.exp(2.0 * log_mean_coeff.view(-1, 1))) + ) * self.G.to( + x.device ) # std: (batch_size, max_len) return mean, std - def add_noise( - self, - original_samples: torch.Tensor, - noise: torch.Tensor, - timesteps: torch.Tensor, - ) -> torch.Tensor: - x0 = original_samples - mean, _ = self.marginal_prob(x0, timesteps) - - # Note that the std is not used here because the noise has been scaled prior to calling the function - sample = mean + noise - return sample - - def get_beta(self, timestep: torch.Tensor | float | np.ndarray) -> torch.Tensor: - return torch.tensor( - self.beta_0 + timestep * (self.beta_1 - self.beta_0), device=self.device - ) - - def set_timesteps(self, num_diffusion_steps: int) -> None: - self.timesteps = torch.linspace( - 1.0, self.eps, num_diffusion_steps, device=self.device - ) - self.step_size = self.timesteps[0] - self.timesteps[1] - - def prior_sampling(self, shape: tuple | list | torch.Size) -> torch.Tensor: - # Reshape the G matrix to be (1, max_len, max_len) - scaling_matrix = self.G_matrix.view( - -1, self.G_matrix.shape[0], self.G_matrix.shape[1] - ) - z = torch.randn(*shape, device=self.device) - - # Return G@z where z \sim N(0,I) - return torch.matmul(scaling_matrix, z) + def get_beta(self, timestep: float) -> float: + return self.beta_0 + timestep * (self.beta_1 - self.beta_0) def step( - self, model_output: torch.Tensor, timestep: torch.Tensor, sample: torch.Tensor + self, model_output: torch.Tensor, timestep: float, sample: torch.Tensor ) -> SamplingOutput: """Single denoising step, used for sampling. @@ -274,10 +226,11 @@ def step( SamplingOutput: _description_ """ beta = self.get_beta(timestep) - diffusion = torch.diag_embed(torch.sqrt(beta).view(-1, 1) * self.G) + assert self.G is not None + diffusion = torch.diag_embed(math.sqrt(beta) * self.G).to(device=sample.device) # Compute drift - drift = -0.5 * beta.view(-1, 1, 1) * sample - ( + drift = -0.5 * beta * sample - ( torch.matmul(diffusion * diffusion, model_output) ) diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py index e517c9e..3f58d78 100644 --- a/tests/test_schedulers.py +++ b/tests/test_schedulers.py @@ -1,11 +1,12 @@ from copy import deepcopy +import pytest import pytorch_lightning as pl import torch from fdiff.models.score_models import ScoreModule from fdiff.sampling.sampler import DiffusionSampler -from fdiff.schedulers.sde import VPScheduler +from fdiff.schedulers.sde import SDE, VEScheduler, VPScheduler from fdiff.utils.dataclasses import DiffusableBatch from .test_datamodules import DummyDatamodule @@ -19,22 +20,21 @@ low = 0 high = 10 num_samples = 48 -beta_0 = 0.01 -beta_1 = 20 +beta_min = 0.01 +beta_max = 20 batch_size = 50 -def test_noise_adder() -> None: - """Test the noise adder.""" - # Set the parameters - beta_min = 0.01 - beta_max = 1 - max_len = 20 - G = torch.ones(max_len) - G[0] *= 2 - +@pytest.mark.parametrize( + "scheduler_type", + [ + VEScheduler, + VPScheduler, + ], +) +def test_forward(scheduler_type: SDE) -> None: # Create the SDE - scheduler = VPScheduler(beta_min=beta_min, beta_max=beta_max) + scheduler: SDE = scheduler_type() # Create a dummy time series x = torch.randn(size=(batch_size, max_len, n_channels), device="cpu") @@ -44,42 +44,39 @@ def test_noise_adder() -> None: assert x_noisy.shape == x.shape - beta = scheduler.get_beta(timestep=timesteps) - # Check that each element of beta is between beta_min and beta_max - assert torch.all(beta >= beta_min) - assert torch.all(beta <= beta_max) +@pytest.mark.parametrize( + "scheduler_type", + [ + VEScheduler, + VPScheduler, + ], +) +def test_backward(scheduler_type: SDE) -> None: + t = 0.5 + + scheduler: SDE = scheduler_type() + scheduler.set_noise_scaling(max_len=max_len) scheduler.set_timesteps(num_diffusion_steps=1000) + noise = torch.randn(size=(batch_size, max_len, n_channels), device="cpu") model_output = torch.randn(size=(batch_size, max_len, n_channels), device="cpu") - timesteps = torch.ones(size=(batch_size,), device="cpu") * 0.5 - scheduler_output = scheduler.step(model_output, timestep=timesteps, sample=x_noisy) - assert scheduler_output.prev_sample.shape == x_noisy.shape + scheduler_output = scheduler.step(model_output, timestep=t, sample=noise) + assert scheduler_output.prev_sample.shape == noise.shape -def instantiate_score_model() -> ScoreModule: - noise_scheduler = VPScheduler( - beta_min=beta_0, beta_max=beta_1, fourier_noise_scaling=False - ) - score_model = ScoreModule( - n_channels=n_channels, - max_len=max_len, - noise_scheduler=noise_scheduler, - d_model=d_model, - n_head=n_head, - num_layers=num_layers, - num_training_steps=10, - ) - return score_model - - -def instantiate_trainer() -> pl.Trainer: - return pl.Trainer(max_epochs=1, accelerator="cpu") - -def test_score_module_with_vpsde() -> None: +@pytest.mark.parametrize( + "scheduler_type", + [ + VEScheduler, + VPScheduler, + ], +) +def test_training(scheduler_type: SDE) -> None: torch.manual_seed(42) - score_model = instantiate_score_model() + noise_scheduler = scheduler_type() + score_model = instantiate_score_model(noise_scheduler) # Check that the forward call produces tensor of the right shape X = torch.randn((batch_size, max_len, n_channels)) @@ -118,3 +115,20 @@ def test_score_module_with_vpsde() -> None: # Check the shape of the samples assert samples.shape == (num_samples, max_len, n_channels) + + +def instantiate_score_model(scheduler: SDE) -> ScoreModule: + score_model = ScoreModule( + n_channels=n_channels, + max_len=max_len, + noise_scheduler=scheduler, + d_model=d_model, + n_head=n_head, + num_layers=num_layers, + num_training_steps=10, + ) + return score_model + + +def instantiate_trainer() -> pl.Trainer: + return pl.Trainer(max_epochs=1, accelerator="cpu")