In [1]:
import random
import imageio
import numpy as np
from argparse import ArgumentParser

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import einops # git visualization
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

from torchvision.transforms import Compose, ToTensor, Lambda  # preprocessing
from torchvision.datasets.mnist import MNIST, FashionMNIST # Dataset




In [None]:
# Run configuration.
# fashion: True -> "fashion" checkpoint / GIF file names are used later; False -> "mnist" names.
# train_flag: True -> run training_loop before the best checkpoint is reloaded.
fashion, train_flag = True, True


## Training

In [None]:
def training_loop(ddpm, loader, n_epochs, optim, device, display=False, store_path = "ddpm_model.path"):
    """Train a DDPM noise-prediction model and checkpoint the best epoch.

    For every batch, Gaussian noise ``eta`` and one random timestep per image
    ``t`` are sampled. The clean images are noised with ``ddpm(x0, t, eta)``.
    The model's noise estimate ``ddpm.backward(noisy_imgs, t)`` is regressed
    onto ``eta`` with an MSE loss.

    Args:
        ddpm: Diffusion model. It must expose ``n_steps``, be callable as
            ``ddpm(x0, t, eta)``, and provide ``backward(noisy_imgs, t)`` and
            ``state_dict()``.
        loader: Data loader whose batches hold the images at ``batch[0]`` and
            which exposes ``loader.dataset`` (used for loss averaging).
        n_epochs: Number of passes over ``loader``.
        optim: Optimizer stepping the model parameters.
        device: Device the images (and therefore the noise and timesteps) live on.
        display: If True, generate and show samples after every epoch. This
            relies on ``generate_new_images`` and ``show_images`` defined
            elsewhere in the notebook.
        store_path: File the state dict is saved to whenever the epoch loss
            beats the best loss seen so far.
    """
    mse = nn.MSELoss()
    best_loss = float("inf")
    n_steps = ddpm.n_steps

    # Epoch loop. tqdm's keyword is spelled ``colour``; an unknown ``color``
    # keyword makes tqdm raise TqdmKeyError, and hex colours need 6 digits.
    for epoch in tqdm(range(n_epochs), desc="Training progress", colour="#00ff00"):
        epoch_loss = 0.0
        for batch in tqdm(loader, leave=False, desc=f"Epoch {epoch+1}/{n_epochs}", colour="#00ff00"):
            # Load the clean images for this batch.
            x0 = batch[0].to(device)
            n = len(x0)

            # Sample the target noise and one timestep per image. randn_like
            # inherits x0's device, so eta already lives on `device`.
            eta = torch.randn_like(x0)
            t = torch.randint(0, n_steps, (n,), device=device)

            # Forward (noising) pass.
            noisy_imgs = ddpm(x0, t, eta)

            # Reverse pass: the network predicts the noise that was added.
            # NOTE(review): some UNet variants expect t shaped (n, 1), i.e.
            # t.reshape(n, -1); confirm against MyDDPM.backward.
            eta_theta = ddpm.backward(noisy_imgs, t)

            loss = mse(eta_theta, eta)
            optim.zero_grad()
            loss.backward()
            optim.step()

            # Weight each batch loss by its size relative to the dataset, so
            # epoch_loss is a size-weighted average over the epoch.
            epoch_loss += loss.item() * n / len(loader.dataset)

        if display:
            show_images(generate_new_images(ddpm, device=device), f"Images generated at epoch {epoch+1}")

        log_string = f"Loss at epoch{epoch+1}:{epoch_loss:.3f}"

        # Checkpoint whenever this epoch's training loss is the best so far.
        if best_loss > epoch_loss:
            best_loss = epoch_loss
            torch.save(ddpm.state_dict(), store_path)
            log_string += " --> best model"
        print(log_string)

In [None]:

# Checkpoint file for the selected dataset; the best model is later reloaded from this path.
store_path = "ddpm_fashion.pt" if fashion else "ddpm_mnist.pt"
if train_flag:
    # Pass store_path explicitly. Otherwise training_loop saves to its own
    # default path and the checkpoint at store_path is never written.
    training_loop(ddpm, loader, n_epochs, optim, device, store_path=store_path)
    

In [None]:
# Testing and generation: rebuild the network and load the best checkpoint.
# NOTE(review): the keyword is presumably `n_steps`, matching the `ddpm.n_steps`
# attribute read in training_loop (the previous `n_step=` would not match it);
# confirm against MyDDPM's signature.
best_model = MyDDPM(MyUNet(), n_steps=n_steps, device=device)
best_model.load_state_dict(torch.load(store_path, map_location=device))
best_model.eval()
print("Model loaded")

In [None]:
# Sample 100 images from the best model and display them.
# gif_name presumably names an animation file written by generate_new_images.
generated = generate_new_images(
    best_model,
    n_samples=100,
    device=device,
    gif_name="fashion.gif" if fashion else "mnist.gif",
)

show_images(generated, "Final result")