In [61]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

In [63]:
import math

states = 10




def _cosine_variance_schedule(timesteps, epsilon=0.003, power=2.0):
    steps = torch.linspace(0, timesteps, steps=timesteps + 1, dtype=torch.float32)
    f_t = (
        torch.cos(((steps / timesteps + epsilon) / (1.0 + epsilon)) * math.pi * 0.5)
        ** power
    )
    # betas = torch.clip(1.0 - f_t[1:] / f_t[:timesteps], 0.0, 0.999)
    betas = torch.clip(1.0 - f_t[1:] / f_t[:timesteps], 0.0, 0.999)

    return betas

betas = _cosine_variance_schedule(states, power=2.0)
alphas = 1.0 - betas
alphas_bar = alphas.cumprod(dim=-1)
sqrt_alpha_bar = alphas_bar.sqrt()
sqrt_1_sub_alpha_b = (1.0 - alphas_bar).sqrt()

def forward_diffusion(
        clean_data: torch.Tensor, noise : torch.Tensor, target: torch.Tensor=None, keep_intermediate: bool = False) -> torch.Tensor:
    if keep_intermediate:
        images = [clean_data]

        for t in range(states-1):
            image_scale = (1-betas[t]).sqrt()
            noise_scale = betas[t].sqrt()
            noised = image_scale * images[-1] + noise_scale * torch.randn_like(
                clean_data
            )
            # noised = torch.clip(noised, min=-1, max=1)
            images.append(noised)

        # concatenate each step into one image for for each sample
        return torch.cat(images, dim=2)

    else:
        image_scale = sqrt_alpha_bar.gather(0, target).reshape(
            clean_data.shape[0], 1, 1, 1
        )
        noise_scale = sqrt_1_sub_alpha_b.gather(0, target).reshape(
            clean_data.shape[0], 1, 1, 1
        )
        noised = image_scale * clean_data + noise_scale * noise
        # noised = torch.clip(noised, min=-1, max=1)
        return noised


In [87]:
data = torch.rand(8*100).round().reshape(-1, 8, 1)

model = nn.Sequential(nn.Linear(2, 50), nn.GELU(), nn.Linear(50, 1))

optim = torch.optim.Adam(model.parameters())

for i, flip in enumerate(data):
    print(i, flip)
    noise = torch.randn(1)
    t = torch.randint(0, states-1, size=(8,))
    print(t)
    noised = forward_diffusion(clean_data=flip, noise=noise, target=t, keep_intermediate=False)
    print("noised:", noised)
    pred = model(torch.cat((noised, t)))
    loss = torch.mean(noised-pred)**2


0 tensor([[1.],
        [0.],
        [1.],
        [0.],
        [0.],
        [0.],
        [0.],
        [1.]])
tensor([3, 5, 4, 6, 5, 5, 1, 7])
noised: tensor([[[[ 0.7960],
          [-0.0114],
          [ 0.7960],
          [-0.0114],
          [-0.0114],
          [-0.0114],
          [-0.0114],
          [ 0.7960]]],


        [[[ 0.5706],
          [-0.0156],
          [ 0.5706],
          [-0.0156],
          [-0.0156],
          [-0.0156],
          [-0.0156],
          [ 0.5706]]],


        [[[ 0.6918],
          [-0.0137],
          [ 0.6918],
          [-0.0137],
          [-0.0137],
          [-0.0137],
          [-0.0137],
          [ 0.6918]]],


        [[[ 0.4355],
          [-0.0172],
          [ 0.4355],
          [-0.0172],
          [-0.0172],
          [-0.0172],
          [-0.0172],
          [ 0.4355]]],


        [[[ 0.5706],
          [-0.0156],
          [ 0.5706],
          [-0.0156],
          [-0.0156],
          [-0.0156],
          [-0.0156],
         

RuntimeError: Tensors must have same number of dimensions: got 4 and 1