In [23]:
from torch import nn
import torch.nn.functional as F
import torch
import math

from einops import rearrange, reduce
from torch.amp import autocast

from tqdm.auto import tqdm



In [19]:
def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def normalize_to_neg_one_to_one(img):
    return img * 2 - 1

def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5

def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

In [31]:
def cosine_beta_schedule(timesteps, s = 0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
    alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

In [82]:
class GaussianDiffusion1D(nn.Module):
    def __init__(
        self,
        model,
        *,
        seq_length = 256,
        device = 'cpu',
        # dimensionality of the latent space. In this case, it is 256 dimensions in the peptide
        # latent space

        timesteps = 1000,
        # amount of timesteps when training the diffusion model
        # MORE timesteps ==> better quality, slower training
        # FEWER timesteps ==> lower quality, faster training

        sampling_timesteps = 200,
        # must be sampling_timesteps < timesteps, allows faster sampling

        schedule_fn_kwargs = dict(),
        # s for cosine controls curve shape

        # ddim_sampling_eta = 0.,
        # lower ==> deterministic sampling, less diversity
        # higher ==> stochastic sampling, more diversity 

        min_snr_loss_weight = False,
        min_snr_gamma = 5,

        s = 0.008
        # cosine_beta_schedule parameter

    ):
        # Model INIT
        super().__init__()
        self.model = model
        self.channels = self.model.channels
        self.seq_length = seq_length
        self.device = device

        # create betas schedule, derived constants
        betas = cosine_beta_schedule(timesteps, s)
        alphas = 1. - betas #alphas
        alphas_cumprod = torch.cumprod(alphas, dim=0)  # alpha bar @ t
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)  # alpha bar @ t - 1

        # timesteps and ddim sampling (Fast sampling)
        self.num_timesteps = timesteps # training timesteps
        self.sampling_timesteps = sampling_timesteps
        # self.ddim_sampling_eta = ddim_sampling_eta

        ######################
        #  Buffer Registery  #
        ######################

        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))


        # SNR CALCULATION
        # Less weight is given to easier (higher snr) images
        snr = alphas_cumprod / (1 - alphas_cumprod)
        maybe_clipped_snr = snr.clone()
        maybe_clipped_snr.clamp_(max = min_snr_gamma)

        # pred_v loss_weight
        loss_weight = maybe_clipped_snr / (snr + 1)
        register_buffer('loss_weight', loss_weight)

        # normalize data, automaticcaly
        self.normalize = normalize_to_neg_one_to_one
        self.unnormalize = unnormalize_to_zero_to_one

    ##################
    # Model Training #
    ##################

    # Takes in a sequence of shape (batch_size, seq_length) and returns the losses
    def forward(self, sequence, *args, **kwargs):
        batch_size, n, device = *sequence.shape, sequence.device
        
        assert n == self.seq_length, f'seq_length must be {self.seq_length}'
        t = torch.randint(0, self.num_timesteps, (batch_size,), device=device).long()
        assert(t.shape == (batch_size, 1)) # each member of the batch has a t

        sequence = self.normalize(sequence)
        # sequence is (batch_size, seq_length) and t is #(batch_size, 1)
        return self.p_losses(sequence, t, *args, **kwargs)

    # computes the loss of the NN when denoising:
    def p_losses(self, x_start, t):
        batch_size, n = x_start.shape
        noise = lambda: torch.randn_like(x_start) 

        # adds gaussian noise to x
        x = self.q_sample(x_start=x_start, t=t, noise=noise)

        # have model predict the original (non-noisy) image via the proxy of veloicty:
        # we need to reshape x to have a channel though for the UNET to work
        model_out = self.model(x.reshape(batch_size, 1, n), t)
        model_out = model_out.reshape(batch_size, n) #undo reshape

        target = self.predict_v(x_start, t, noise)

        # compute batched mean loss
        loss = F.mse_loss(model_out, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b', 'mean')

        # weight these according to the snr of the training
        loss = loss * extract(self.loss_weight, t, loss.shape) 
        return loss.mean()
    
    # Adds gaussian noise to x_t to make x_{t+1}
    # @autocast('cuda', enabled = False)
    def q_sample(self, x_start, t, noise):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    # prediction of velocity given noise
    def predict_v(self, x_start, t, noise):
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
        )
    
    # predict x_0 from x_t and v
    def predict_start_from_v(self, x_t, t, v):
        return (
            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )
    
    ############
    # SAMPLING #
    ############

    @torch.no_grad()
    def sample(self, batch_size = 16, cond_fn=None, guidance_kwargs=None):
        # no ddim sampling for now
        return self.p_sample_loop((batch_size, self.seq_length), cond_fn=cond_fn, guidance_kwargs=guidance_kwargs)
    
    @torch.no_grad()
    def p_sample_loop(self, shape, cond_fn=None, guidance_kwargs=None):
        batch_size, device = shape[0], self.device

        sequence = torch.randn(shape, device=device)
        for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
            # sample from the diffusion model to improve sequence
            sequence = self.p_sample(sequence, t, cond_fn, guidance_kwargs)

        return self.unnormalize(sequence)
    
    def condition_mean(self, cond_fn, mean,variance, x, t, guidance_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.
        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        # this fixes a bug in the official OpenAI implementation:
        # https://github.com/openai/guided-diffusion/issues/51 (see point 1)
        # use the predicted mean for the previous timestep to compute gradient
        gradient = cond_fn(mean, t, **guidance_kwargs)
        new_mean = (
            mean.float() + variance * gradient.float()
        )
        print("gradient: ",(variance * gradient.float()).mean())
        return new_mean
    
    # calculates the mean and variance of the distribution
    def p_mean_variance(self, x, t,):
        preds = self.model_predictions(x, t)
        x_start = preds.pred_x_start
        x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance, x_start
    
    @torch.no_grad()
    def p_sample(self, x, t: int, cond_fn=None, guidance_kwargs=None):
        batch_size, n, device = *x.shape, x.device
        batched_times = torch.full((batch_size,), t, device=x.device, dtype=torch.long)

        # predicted mean / variance of new distribution
        model_mean, variance, model_log_variance, x_start = self.p_mean_variance(x=x, t=batched_times)
        
        if exists(cond_fn) and exists(guidance_kwargs):
            model_mean = self.condition_mean(cond_fn, model_mean, variance, x, batched_times, guidance_kwargs)
        
        noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
        pred_sequence = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_sequence, x_start

from guided_diffusion_1d import Unet1D
# Set small values for testing
batch_size = 2
seq_length = 16
channels = 1

# Create a simple model
model = Unet1D(dim = seq_length, channels=channels)

# Create the diffusion model
diffusion = GaussianDiffusion1D(
    model,
    seq_length=seq_length,
    timesteps=10,  # Use a small number of timesteps for quick testing
)
    

In [84]:
print("Sampling from diffusion model...")
samples = diffusion.sample(batch_size=batch_size)

Sampling from diffusion model...


sampling loop time step:   0%|          | 0/10 [00:00<?, ?it/s]


AttributeError: 'GaussianDiffusion1D' object has no attribute 'model_predictions'