## Building Model
- data
- Forward Pass
  - initialize parameters
  - forward pass
- Backward pass
  - U-Net
  - 

In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
import warnings

# Silence every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

# Kernel sanity check: this notebook expects the PyTorch 2.0.0 release.
# Only the public release part of the version is compared, because many
# builds append a local tag (e.g. "2.0.0+cu117") that an exact string match
# against "2.0.0" would wrongly reject.
if torch.__version__.split("+")[0] != "2.0.0":
    raise Exception("Reconnect to the Correct Kernel")
else:
    print("You are connected to Kernel, Torch Version: ", torch.__version__)


# Image side length in pixels; in this section it is only mentioned by the
# disabled Resize step noted below.
IMG_SIZE = 28
# Number of samples per batch produced by `dataloader`.
BATCH_SIZE = 4


# Dataset sample -> model input: convert to a tensor (ToTensor scales the data
# into [0, 1]) and then stretch the range to [-1, 1].
# Disabled steps: Resize((IMG_SIZE, IMG_SIZE)) and RandomHorizontalFlip().
data_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x * 2 - 1),
    ]
)

# Model-space tensor -> PIL image: undo the [-1, 1] scaling, expand to the
# 0..255 range, cast to uint8 via NumPy, and hand the array to ToPILImage.
# Disabled step: the CHW -> HWC permute (t.permute(1, 2, 0)), so the array
# reaches ToPILImage in the tensor's own axis order.
reverse_transforms = transforms.Compose(
    [
        transforms.Lambda(lambda x: (x + 1) / 2 * 255.0),
        transforms.Lambda(lambda x: x.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ]
)


# FashionMNIST stored under ./data (downloaded when missing), with every
# sample passed through `data_transforms`.
data = torchvision.datasets.FashionMNIST(
    "./data",
    transform=data_transforms,
    download=True,
)
# Shuffled batches of BATCH_SIZE; a trailing incomplete batch is dropped.
dataloader = DataLoader(
    data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

In [None]:
import torch.nn.functional as F

# * Step 1: create the noise levels (betas): T linearly spaced values from
# *         START to END, i.e. the amount of noise added at each step
T = 25
START = 0.0001
END = 0.02

Betas = torch.linspace(START, END, T)

# * Step 2: create Alphas, the amount of the original image kept at each step
Alphas = 1 - Betas

# * Step 3: calculate the cumulative product of Alphas (alpha-bar_t)
Alphas_cumprod = torch.cumprod(Alphas, dim=0)

# * Step 4: the cumulative product shifted one step to the right
# *         (alpha-bar_{t-1}): drop the last element and prepend 1, so the
# *         first step has "no noise yet" as its predecessor.
# *         This must be built from Alphas_cumprod, not from Alphas.
Alphas_cumprod_prev = F.pad(Alphas_cumprod[:-1], (1, 0), value=1)

# * Step 5: calculate the sqrt of the reciprocal of Alphas
Alphas_sqrt_reciprocals = torch.sqrt(1 / Alphas)

# * Step 6: calculate the square root of the cumulative product of Alphas
Alphas_sqrt_cumpord = torch.sqrt(Alphas_cumprod)

# * Step 7: calculate the square root of one minus the cumulative product of Alphas
Alphas_sqrt_one_minus_cumprod = torch.sqrt(1 - Alphas_cumprod)

# * Step 8: calculate the posterior variance,
# *         beta_t * (1 - alpha-bar_{t-1}) / (1 - alpha-bar_t);
# *         it is exactly 0 at the first step because alpha-bar_{-1} is 1.
postierior_variance = Betas * (1 - Alphas_cumprod_prev) / (1 - Alphas_cumprod)



In [None]:
def get_index_from_list(list, t, image_shape):
    """Pick one value of `list` per batch entry and shape it for broadcasting.

    Args:
        list: tensor of per-timestep values; it is indexed along its last
            dimension.
        t: index tensor whose first dimension is the batch size; it is moved
            to the CPU before being used as the gather index.
        image_shape: shape of the image batch the result will be combined
            with; only its length is used.

    Returns:
        The gathered values reshaped to ``(batch, 1, ..., 1)`` with
        ``len(image_shape) - 1`` trailing singleton axes, placed on the
        device of `t`.
    """
    # * look up list[t] for every entry of the batch (index goes through the cpu)
    picked = list.gather(-1, t.cpu())
    # * one leading batch axis followed by a singleton axis per remaining image dim
    target_shape = (t.shape[0],) + (1,) * (len(image_shape) - 1)
    return picked.reshape(target_shape).to(t.device)

def Forward_pass(
    image,
    t,
    sqrt_alphas_cumpord,
    sqrt_one_minus_alphas_cumpord,
    device="cpu",
    torch_seed=42,
):
    """Produce the noised version of `image` at timestep(s) `t`.

    The result is ``sqrt_alphas_cumpord[t] * image +
    sqrt_one_minus_alphas_cumpord[t] * noise`` where ``noise`` is standard
    normal with the same shape as `image`.

    Args:
        image: image batch to be noised.
        t: timestep index tensor, one entry per batch element.
        sqrt_alphas_cumpord: per-timestep scale applied to the image.
        sqrt_one_minus_alphas_cumpord: per-timestep scale applied to the noise.
        device: device every term is moved to before being combined.
        torch_seed: seed written to the global torch RNG at the start of
            every call, so calls with the same seed start from the same RNG
            state.

    Returns:
        A tuple ``(noisy_image_t, noise_part)``. The second element is the
        noise already multiplied by ``sqrt_one_minus_alphas_cumpord[t]``, not
        the raw sample.
        NOTE(review): confirm the caller expects the scaled noise rather than
        the raw noise.
    """
    # * reseed the global generator, then draw noise shaped like the image
    torch.manual_seed(torch_seed)
    eps = torch.randn_like(image)

    # * per-sample coefficients at time t, broadcastable against the image
    signal_scale = get_index_from_list(sqrt_alphas_cumpord, t, image.shape)
    noise_scale = get_index_from_list(sqrt_one_minus_alphas_cumpord, t, image.shape)

    # * scale both terms on the target device and add them up
    scaled_image = signal_scale.to(device) * image.to(device)
    scaled_noise = noise_scale.to(device) * eps.to(device)
    return scaled_image + scaled_noise, scaled_noise
