In [1]:
'''
This notebook walks through the core Imagen modules.

Important classes and functions:
- GaussianDiffusionContinuousTimes
- t5_encode_text
- p_mean_variance
- p_sample
- p_sample_loop
- noise_scheduler


==> Reference: Code extracted from the Imagen implementation by Phil Wang.
'''

SyntaxError: invalid syntax (3784184826.py, line 4)

In [None]:
'''
Distribution explanation here: https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material
'''

import torch.nn as nn

class GaussianDiffusionContinuousTimes(nn.Module):
    def __init__(self, *, noise_schedule, timesteps=1000):
        super().__init__()

        # Initialize the noise schedule based on the provided argument
        if noise_schedule == "linear":
            self.log_snr = beta_linear_log_snr  # Use linear noise schedule
        elif noise_schedule == "cosine":
            self.log_snr = alpha_cosine_log_snr  # Use cosine noise schedule
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')  # Raise error for invalid schedule

        self.num_timesteps = timesteps  # Set the number of timesteps

    # Return a tensor of noise levels for each sample in the batch
    def get_times(self, batch_size, noise_level, *, device):
        return torch.full((batch_size,), noise_level, device=device, dtype=torch.float32)  # Return tensor of noise levels

    # Sample random times from a uniform distribution
    def sample_random_times(self, batch_size, *, device):
        return torch.zeros((batch_size,), device=device).float().uniform_(0, 1)  # Sample random times from a uniform distribution

    # Get the condition value based on the times
    def get_condition(self, times):
        """Return the conditioning value for ``times``.

        The schedule's ``self.log_snr`` is wrapped with the ``maybe`` helper
        and the wrapped callable is applied to ``times``.
        NOTE(review): ``maybe`` is defined outside this block; presumably it
        lets a ``None`` input pass through untouched - confirm.
        """
        optional_log_snr = maybe(self.log_snr)
        return optional_log_snr(times)

    # Generate a tensor of sampling timesteps
    def get_sampling_timesteps(self, batch, *, device):
        """Build the (start, end) time pairs visited during sampling.

        A grid of ``self.num_timesteps + 1`` evenly spaced times running from
        1 down to 0 is replicated across the batch, then paired with itself
        shifted by one step.

        Args:
            batch: Number of copies of the grid (the ``b`` axis).
            device: Device on which the grid is created. Keyword-only.

        Returns:
            A tuple with one tensor per step; each tensor stacks that step's
            start times and end times along dim 0.
            NOTE(review): each entry should be (2, batch) assuming ``repeat``
            follows einops semantics - confirm.
        """
        grid = torch.linspace(1., 0., self.num_timesteps + 1, device=device)
        grid = repeat(grid, 't -> b t', b=batch)

        # Every time except the last starts a step; every time except the
        # first ends one.
        step_starts = grid[:, :-1]
        step_ends = grid[:, 1:]

        paired = torch.stack((step_starts, step_ends), dim=0)
        return paired.unbind(dim=-1)

    # Calculate the posterior distribution parameters
    def q_posterior(self, x_start, x_t, t, *, t_next=None):
        """Compute mean, variance and clipped log-variance of the posterior.

        Args:
            x_start: Estimate of the clean sample.
            x_t: Noised sample at time ``t``.
            t: Current times.
            t_next: Target times. When omitted, falls back (via ``default``)
                to ``t - 1 / self.num_timesteps`` clamped below at 0.

        Returns:
            ``(posterior_mean, posterior_variance,
            posterior_log_variance_clipped)``.
        """
        def one_step_back():
            # Evaluated lazily so nothing is computed when t_next is supplied.
            return (t - 1. / self.num_timesteps).clamp(min=0.)

        t_next = default(t_next, one_step_back)

        # Log-SNR at both ends of the step, padded against x_t for broadcasting.
        log_snr = right_pad_dims_to(x_t, self.log_snr(t))
        log_snr_next = right_pad_dims_to(x_t, self.log_snr(t_next))

        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

        # c = 1 - exp(log_snr - log_snr_next), computed through expm1.
        c = -expm1(log_snr - log_snr_next)

        from_current = x_t * (1 - c) / alpha
        from_start = c * x_start
        posterior_mean = alpha_next * (from_current + from_start)

        posterior_variance = (sigma_next ** 2) * c
        posterior_log_variance_clipped = log(posterior_variance, eps=1e-20)

        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    # Sample from the diffusion process
    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to time ``t``.

        Args:
            x_start: Clean sample; its dtype and device are used for any
                tensors created here.
            t: Times as a tensor, or a Python float that is expanded to one
                entry per item along ``x_start``'s first dimension.
            noise: Optional noise tensor; when omitted, ``default`` falls
                back to ``torch.randn_like(x_start)``.

        Returns:
            ``(alpha * x_start + sigma * noise, log_snr, alpha, sigma)``.
            ``log_snr`` is returned before padding, cast to ``x_start``'s
            dtype; ``alpha`` and ``sigma`` come from the padded log-SNR.
        """
        target_dtype = x_start.dtype

        if isinstance(t, float):
            # Lift a scalar time to one time per batch item.
            t = torch.full(
                (x_start.shape[0],), t, device=x_start.device, dtype=target_dtype
            )

        noise = default(noise, lambda: torch.randn_like(x_start))

        log_snr = self.log_snr(t).type(target_dtype)
        alpha, sigma = log_snr_to_alpha_sigma(right_pad_dims_to(x_start, log_snr))

        noised = alpha * x_start + sigma * noise
        return noised, log_snr, alpha, sigma

    # Sample from the diffusion process from a specified start time to an end time
    def q_sample_from_to(self, x_from, from_t, to_t, noise=None):
        """Move a sample from time ``from_t`` to time ``to_t``.

        Args:
            x_from: Sample at ``from_t``; its shape, device and dtype are used
                for any tensors created here.
            from_t: Source times (tensor, or a Python float applied to every
                batch item).
            to_t: Destination times (tensor, or a Python float applied to
                every batch item).
            noise: Optional noise tensor; when omitted, ``default`` falls
                back to ``torch.randn_like(x_from)``.

        Returns:
            ``x_from * (alpha_to / alpha)
            + noise * (sigma_to * alpha - sigma * alpha_to) / alpha``.
        """
        shape, device, dtype = x_from.shape, x_from.device, x_from.dtype
        batch = shape[0]

        def as_batch_times(value):
            # Only Python floats are expanded; anything else is used as given.
            if isinstance(value, float):
                return torch.full((batch,), value, device=device, dtype=dtype)
            return value

        from_t = as_batch_times(from_t)
        to_t = as_batch_times(to_t)

        noise = default(noise, lambda: torch.randn_like(x_from))

        # Coefficients at the source time.
        log_snr_from = right_pad_dims_to(x_from, self.log_snr(from_t))
        alpha, sigma = log_snr_to_alpha_sigma(log_snr_from)

        # Coefficients at the destination time.
        log_snr_dest = right_pad_dims_to(x_from, self.log_snr(to_t))
        alpha_to, sigma_to = log_snr_to_alpha_sigma(log_snr_dest)

        signal_scale = alpha_to / alpha
        noise_numerator = sigma_to * alpha - sigma * alpha_to
        return x_from * signal_scale + noise * noise_numerator / alpha

    # Predict the start from a given velocity
    def predict_start_from_v(self, x_t, t, v):
        """Recover the start-sample estimate from a velocity prediction.

        Args:
            x_t: Noised sample at time ``t``.
            t: Times passed to ``self.log_snr``.
            v: Predicted velocity.

        Returns:
            ``alpha * x_t - sigma * v``, with ``alpha`` and ``sigma`` taken
            from the log-SNR at ``t`` padded against ``x_t``.
        """
        padded_log_snr = right_pad_dims_to(x_t, self.log_snr(t))
        alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr)
        return alpha * x_t - sigma * v

    # Predict the start from a given noise
    def predict_start_from_noise(self, x_t, t, noise):
        """Recover the start-sample estimate from a noise prediction.

        Args:
            x_t: Noised sample at time ``t``.
            t: Times passed to ``self.log_snr``.
            noise: Predicted noise.

        Returns:
            ``(x_t - sigma * noise) / alpha``, where ``alpha`` is clamped to
            at least 1e-8 so the division cannot hit zero.
        """
        padded_log_snr = right_pad_dims_to(x_t, self.log_snr(t))
        alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr)

        denoised_numerator = x_t - sigma * noise
        safe_alpha = alpha.clamp(min=1e-8)
        return denoised_numerator / safe_alpha
