## DDPM

In [23]:
import torch
import torch.nn as nn

class FakeUNet(nn.Module):
    """Placeholder noise-prediction network used to smoke-test ``MyDDPM``.

    It ignores the timestep and applies a single fully connected layer to the
    flattened image, so the output has the same shape as one input image.

    Args:
        image_chw: (channels, height, width) of the images this network
            accepts. Defaults to ``(1, 28, 28)``.
    """

    def __init__(self, image_chw=(1, 28, 28)) -> None:
        super().__init__()
        self.image_chw = tuple(image_chw)
        c, h, w = self.image_chw
        n_features = c * h * w
        self.linear = nn.Linear(n_features, n_features)

    def forward(self, x, t):
        """Map a batch ``x`` of shape (n, c, h, w) to a tensor of the same shape.

        ``t`` (the timestep) is accepted for interface compatibility with a
        real UNet and is ignored.
        """
        n = x.size(0)
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # sliced batches, are accepted instead of raising a RuntimeError.
        out = self.linear(x.reshape(n, -1))
        return out.reshape(n, *self.image_chw)
    
    

In [25]:
class MyDDPM(nn.Module):
    """Minimal DDPM wrapper: a linear beta schedule plus a noise-prediction network.

    Args:
        network: module called as ``network(x, t)`` to predict the added noise.
        n_steps: number of diffusion steps in the schedule.
        min_beta: first value of the linear beta schedule.
        max_beta: last value of the linear beta schedule.
        device: device the network and schedule tensors are moved to.
        image_chw: (channels, height, width) of the images; stored for callers.
    """

    def __init__(self, network, n_steps=200, min_beta=10**(-4), max_beta=0.02, device=None, image_chw=(1, 28, 28)) -> None:
        super().__init__()
        self.n_steps = n_steps
        self.device = device
        self.image_chw = image_chw
        self.network = network.to(device)

        betas = torch.linspace(min_beta, max_beta, n_steps).to(device)
        alphas = 1 - betas
        # alpha_bar[i] = alphas[0] * ... * alphas[i]; cumprod computes every
        # prefix product in one pass instead of re-multiplying each prefix.
        alpha_bars = torch.cumprod(alphas, dim=0)
        # Non-persistent buffers follow .to()/.cuda()/.double() together with
        # the module, but are not added to state_dict().
        self.register_buffer("betas", betas, persistent=False)
        self.register_buffer("alphas", alphas, persistent=False)
        self.register_buffer("alpha_bars", alpha_bars, persistent=False)

    def forward(self, x0, t, eta=None):
        """Noise clean images ``x0`` to timestep ``t``.

        Args:
            x0: clean images, shape (n, c, h, w).
            t: timestep index/indices into the schedule: a tensor with one
                index per image, or a single index shared by the whole batch.
            eta: noise with the same shape as ``x0``; if ``None``, standard
                normal noise matching ``x0``'s shape, dtype and device is drawn.

        Returns:
            ``sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * eta``.
        """
        if x0.dim() != 4:
            raise ValueError(f"expected x0 of shape (n, c, h, w), got {tuple(x0.shape)}")
        # (-1, 1, 1, 1) yields one coefficient per sample for per-image t and
        # broadcasts over the batch when a single timestep is given.
        a_bar = self.alpha_bars[t].reshape(-1, 1, 1, 1)

        if eta is None:
            eta = torch.randn_like(x0)

        return a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eta

    def backword(self, x, t):
        """Predict the noise in ``x`` at timestep ``t`` with the wrapped network.

        The misspelled name is kept for existing callers; ``backward`` is an alias.
        """
        return self.network(x, t)

    backward = backword
    
if __name__ == "__main__":
    # Smoke test: wire a FakeUNet into MyDDPM and check that one forward
    # (noising) step and one backword (noise prediction) step keep the shape.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    network = FakeUNet().to(device)
    ddpm = MyDDPM(network, device=device)

    # One random image and one random timestep in [0, ddpm.n_steps).
    x0 = torch.randn((1, 1, 28, 28)).to(device)
    t = torch.randint(ddpm.n_steps, (1,)).to(device)

    noisy_image = ddpm(x0, t)  # expected: [1, 1, 28, 28]
    print("forward pass: noisy_image.shape", noisy_image.shape)

    predict_noise = ddpm.backword(noisy_image, t)
    print("backward pass: predict_noise.shape", predict_noise.shape)

forward pass: noisy_image.shape torch.Size([1, 1, 28, 28])
backward pass: predict_noise.shape torch.Size([1, 1, 28, 28])


## UNet

In [None]:
# unet.jpg
def sinusoidal_embedding(n, d):
    embedding = torch.zeros(n, d)
    wk = torch.tensor([1 / 1])