In [ ]:
import torch
import torch.nn as nn

In [ ]:
'''
Important modules:
- Dataset
- Gaussian Diffusion
- Tokenizer
- Trainer
'''

In [ ]:
def cosine_beta_schedule(timesteps, s = 0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
    
    # Directly from the paper although more testing is required.
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.9999)

In [1]:
class GaussianDiffusion(nn.Module):
    def __init__(self):
        # Load required paremeters

        betas = cosine_beta_schedule(timesteps=1000)
        alphas = 1-betas
        alphas_cumprod = torch.cumprod(alphas)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1,0), value=1) # add padding to the front and none at the end (at every index, the prev. value is available at that index)
        
        # Calculate the remaining values needed here to easily calculate the posterior varianes downstream
        
        posterior_variance = betas * (1-alphas_cumprod_prev) / (1-alphas_cumprod)
        
        # Calculate remaining important values including text vals
        
        
    def denoise_fn():
        # TODO
        return pred_noise
        
    def extract(a, t, x_shape):
        b, *_ = t.shape
        out = a.gather(-1, t) # t gives us the time steps that we'd like to get values for (from a)
        return out.reshape(b, *((1,) * (len(x_shape) - 1)))
        
    def q_mean_variance(self, x_start, t):
        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    # Just applying the formulas presented in the paper
    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )
    
    def q_sample(self, x_start, t, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )        
    
    
    @torch.inference_mode()
    def p_sample_loop(self, shape, cond = None, cond_scale = 1.):
        device = self.betas.device

        b = shape[0]
        img = torch.randn(shape, device=device)

        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), cond = cond, cond_scale = cond_scale)

        return unnormalize_img(img)
    

    @torch.inference_mode()
    def p_sample(self, x, t, cond = None, cond_scale = 1., clip_denoised = True):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, clip_denoised = clip_denoised, cond = cond, cond_scale = cond_scale)
        noise = torch.randn_like(x)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    
    # Reverse -> recon
    def p_losses(self, x_start, t, cond = None, noise = None, **kwargs):
        b, c, f, h, w, device = *x_start.shape, x_start.device
        noise = default(noise, lambda: torch.randn_like(x_start))

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)

        if is_list_str(cond):
            cond = bert_embed(tokenize(cond), return_cls_repr = self.text_use_bert_cls)
            cond = cond.to(device)

        x_recon = self.denoise_fn(x_noisy, t, cond = cond, **kwargs)

        if self.loss_type == 'l1':
            loss = F.l1_loss(noise, x_recon)
        elif self.loss_type == 'l2':
            loss = F.mse_loss(noise, x_recon)
        else:
            raise NotImplementedError()

        return loss
        



SyntaxError: incomplete input (72868075.py, line 26)

In [0]:
'''
Training (General structure)
'''
# load config parameters here
step = 0
max_steps = config["steps"]
opt = ...
scaler = ...
model = ...
ema = ...

# initialize training loop
def train():
    while step < max_steps:
        # read from a dataloader

        # Use torch lightning to scale and train
        with autocast(enabled=amp):
            data = next(data_loader).cuda()
            loss = model(
                data,
            )
            scaler.scale(loss/gradient_accumulate_scaler).backward()
            
        # Save content at regular intervals
        
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()
        
        step +=1
        
    print("Training completed!")