# **Generic notebook used on Colab for training models**

In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
from tqdm import tqdm

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# MNIST training split stored under ./data (download=True is passed, so it is
# fetched when missing). ToTensor is the only preprocessing step applied.
transform = transforms.Compose([transforms.ToTensor()])
train_ds = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 64 samples.
train_loader = DataLoader(dataset=train_ds, batch_size=64, shuffle=True)

class ImprovedUNet(nn.Module):
    """Small convolutional encoder/decoder that maps an image batch to a
    tensor with the same number of channels.

    The encoder is three 3x3 convolutions (the middle one has stride 2); the
    decoder is a stride-2 transposed convolution followed by a 3x3
    convolution, and a final 3x3 convolution projects back to ``in_ch``
    channels. There are no skip connections between encoder and decoder.

    Args:
        in_ch: Number of channels of the input (and of the output).
        base_ch: Channel width of the first layer; the inner layers use
            ``2 * base_ch``.
    """

    def __init__(self, in_ch=1, base_ch=64):
        super().__init__()
        wide_ch = base_ch * 2

        # Encoder: keep resolution, halve it with a strided conv, then refine.
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_ch, base_ch, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(base_ch, wide_ch, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        self.enc3 = nn.Sequential(
            nn.Conv2d(wide_ch, wide_ch, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

        # Decoder: upsample with a stride-2 transposed conv, then narrow the channels.
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(wide_ch, wide_ch, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )
        self.dec2 = nn.Sequential(
            nn.Conv2d(wide_ch, base_ch, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

        # Output projection back to the input channel count (no activation).
        self.out = nn.Conv2d(base_ch, in_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, t):
        """Run the encoder/decoder stack on ``x``.

        Args:
            x: Input batch fed to the first convolution.
            t: Timestep argument. It is accepted but not used anywhere in the
                computation, so the output does not depend on it.

        Returns:
            The output of the final convolution.
        """
        encoded = self.enc3(self.enc2(self.enc1(x)))
        decoded = self.dec2(self.dec1(encoded))
        return self.out(decoded)
    
# Network instance, moved to the selected device.
model = ImprovedUNet().to(device)

# Linear beta schedule over the diffusion steps, computed with NumPy in float32.
num_timesteps = 100
beta_start = 1e-4
beta_end = 0.01
betas = np.linspace(beta_start, beta_end, num=num_timesteps, dtype=np.float32)
alphas = 1.0 - betas
# Running product of the alphas, stored as a tensor on `device` so it can be
# indexed with a batch of timestep indices.
alpha_bars = torch.tensor(np.cumprod(alphas), device=device)

# Adam over all model parameters, with a mean-squared-error objective.
optimizer = optim.Adam(model.parameters(), lr=1e-4)
mse = nn.MSELoss()


# Train the model to predict the noise mixed into each image.
# The weights are written to `checkpoint_path` at the end of every epoch, so an
# interrupted or disconnected session keeps the last completed epoch instead of
# losing the whole run.
checkpoint_path = "ddpm_mnist_100_100e_imp_low_beta.pt"

epochs = 100
for epoch in range(epochs):
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for x, _ in pbar:
        x = x.to(device)
        b = x.shape[0]

        # One timestep index per sample, plus the Gaussian noise to mix in.
        t = torch.randint(0, num_timesteps, (b,), device=device, dtype=torch.long)
        eps = torch.randn_like(x)

        # Noisy input: sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * eps.
        # The view(-1, 1, 1, 1) lets the per-sample factor broadcast over the
        # remaining three dimensions of x.
        alpha_bar_t = alpha_bars[t].view(-1, 1, 1, 1)
        x_t = torch.sqrt(alpha_bar_t) * x + torch.sqrt(1 - alpha_bar_t) * eps

        # Regress the model output onto the noise that was added.
        eps_pred = model(x_t, t)
        loss = mse(eps_pred, eps)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pbar.set_postfix({"loss": loss.item()})

    # Per-epoch checkpoint; overwrites the previous one at the same path.
    torch.save(model.state_dict(), checkpoint_path)

# Final save, kept so the file is written even if the loop body never ran.
torch.save(model.state_dict(), checkpoint_path)
print(f"Training done, weights saved as {checkpoint_path}")

Epoch 1:  98%|█████████▊| 921/938 [03:45<00:04,  4.09it/s, loss=0.118]


KeyboardInterrupt: 

In [None]:
# List the working directory (presumably to check that the saved weights file is present).
!ls

data				     ddpm_mnist_100_100e_imp_low_beta.zip
ddpm_mnist_100_100e_imp_low_beta.pt  sample_data
