<a href="https://colab.research.google.com/github/Fu-Pei-Yin/Deep-Generative-Mode/blob/main/Diffusion_Model%E5%9C%A8%E6%89%8B%E5%AF%AB%E6%95%B8%E5%AD%97%E4%B8%8A%E7%9A%84%E6%87%89%E7%94%A8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import os
from tqdm import tqdm

# ---------------------
# Configuration
# ---------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# Training hyper-parameters and output location.
batch_size = 128
ddpm_timesteps = 1000  # number of diffusion steps T; 200 trains/samples faster, 1000 is the standard setting
ddpm_lr = 2e-4
ddpm_epochs = 200
save_dir = "ddpm_outputs"
os.makedirs(save_dir, exist_ok=True)

# ---------------------
# MNIST data
# ---------------------
transform = transforms.Compose(
    [
        transforms.ToTensor(),                          # PIL image -> float tensor in [0, 1]
        transforms.Normalize(mean=(0.5,), std=(0.5,)),  # (x - 0.5) / 0.5 -> [-1, 1]
    ]
)
train_ds = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,  # every batch has exactly batch_size images
)

# ---------------------
# Beta schedule
# ---------------------
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=2e-2):
    """Return the linear DDPM noise schedule.

    Produces ``timesteps`` values of beta_t spaced evenly from ``beta_start`` to
    ``beta_end`` (both endpoints included) as a 1-D tensor.
    """
    schedule = torch.linspace(start=beta_start, end=beta_end, steps=timesteps)
    return schedule

# Precomputed schedule tensors on `device`, shared by training and sampling:
#   betas[t]         = beta_t
#   alphas[t]        = 1 - beta_t
#   alpha_cumprod[t] = alphas[0] * alphas[1] * ... * alphas[t]   (alpha-bar_t)
betas = linear_beta_schedule(ddpm_timesteps).to(device)
alphas = torch.ones_like(betas) - betas
alpha_cumprod = alphas.cumprod(dim=0)

def q_sample(x_start, t, noise=None, alpha_bar=None):
    """Sample x_t ~ q(x_t | x_0) in closed form (forward diffusion).

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

    Args:
        x_start: clean batch x_0. The first dimension is the batch dimension;
            any number of trailing dimensions is supported.
        t: integer timestep per batch element, shape (batch,).
        noise: optional noise with the same shape as ``x_start``; drawn with
            ``torch.randn_like`` when omitted.
        alpha_bar: optional 1-D cumulative-product schedule indexed by ``t``.
            Defaults to the module-level ``alpha_cumprod``.

    Returns:
        Tensor with the same shape as ``x_start``.
    """
    if noise is None:
        noise = torch.randn_like(x_start)
    if alpha_bar is None:
        alpha_bar = alpha_cumprod
    ab_t = alpha_bar[t.long()]
    # Reshape (batch,) -> (batch, 1, ..., 1) so the coefficients broadcast over
    # every non-batch dimension, whatever the rank of x_start.
    ab_t = ab_t.view(-1, *([1] * (x_start.dim() - 1)))
    return ab_t.sqrt() * x_start + (1 - ab_t).sqrt() * noise

# ---------------------
# Simple convolutional denoiser
# ---------------------
class SmallDenoiser(nn.Module):
    """Small convolutional noise predictor eps_theta(x_t, t) for 1-channel images.

    The integer timestep is scaled into [0, 1) by the total number of steps,
    passed through a 2-layer MLP, projected to 64 channels and added as a
    per-channel bias after the third conv. The first conv's 32-channel features
    are added back after the fourth conv (skip connection). Every conv is 3x3
    with padding 1, so the output has the same spatial size as the input.
    """

    def __init__(self, time_embed_dim=128, timesteps=None):
        """
        Args:
            time_embed_dim: width of the timestep-embedding MLP.
            timesteps: total number of diffusion steps T used to normalise t.
                ``None`` (default) reads the module-level ``ddpm_timesteps``
                at every forward call.
        """
        super().__init__()
        # Plain attribute (not a buffer), so state_dict() is unchanged.
        self.timesteps = timesteps
        self.time_mlp = nn.Sequential(
            nn.Linear(1, time_embed_dim),
            nn.ReLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        self.temb_proj = nn.Linear(time_embed_dim, 64)

        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv4 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv5 = nn.Conv2d(32, 1, 3, padding=1)
        self.act = nn.ReLU()

    def forward(self, x, t):
        """Predict the noise contained in ``x`` at timesteps ``t``.

        Args:
            x: noisy images with 1 channel, (batch, 1, H, W).
            t: integer timesteps, shape (batch,).

        Returns:
            Tensor with the same shape as ``x``.
        """
        total_steps = ddpm_timesteps if self.timesteps is None else self.timesteps
        t = t.float().unsqueeze(-1) / float(total_steps)        # (batch, 1)
        temb = self.time_mlp(t)                                  # (batch, time_embed_dim)
        temb2 = self.temb_proj(temb).unsqueeze(-1).unsqueeze(-1)  # (batch, 64, 1, 1)

        h1 = self.act(self.conv1(x))
        h2 = self.act(self.conv2(h1))
        h3 = self.act(self.conv3(h2) + temb2)  # time conditioning as a channel bias
        h4 = self.act(self.conv4(h3) + h1)     # skip connection from the first layer
        out = self.conv5(h4)
        return out

def weights_init(m):
    """Re-initialise Conv2d/Linear layers: weights ~ N(0, 0.02), biases = 0.

    Meant to be passed to ``nn.Module.apply``; other module types are skipped.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.normal_(m.weight, 0.0, 0.02)
    if m.bias is not None:
        nn.init.zeros_(m.bias)


# Build the model, apply the custom init in place, then create the optimiser
# over the (same, already re-initialised) parameter tensors.
denoiser = SmallDenoiser().to(device)
denoiser.apply(weights_init)
opt_ddpm = optim.Adam(denoiser.parameters(), lr=ddpm_lr)

# ---------------------
# DDPM training
# ---------------------
@torch.no_grad()
def _ddpm_sample(model, n, image_shape=(1, 28, 28)):
    """Generate ``n`` images with DDPM ancestral sampling (Ho et al. 2020, Alg. 2).

    Starting from x_T ~ N(0, I), each reverse step computes
        x_{t-1} = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) / sqrt(alpha_t)
                  + sqrt(beta_t) * z,   z ~ N(0, I),
    and adds no noise on the final step (t = 0). Uses the module-level
    ``betas`` / ``alphas`` / ``alpha_cumprod``, ``ddpm_timesteps`` and ``device``.

    Returns:
        Tensor of shape (n, *image_shape) in model space (approximately [-1, 1]).
    """
    was_training = model.training
    model.eval()
    try:
        x = torch.randn(n, *image_shape, device=device)
        for time_step in reversed(range(ddpm_timesteps)):
            t = torch.full((n,), time_step, device=device, dtype=torch.long)
            beta_t = betas[time_step]
            pred_noise = model(x, t)
            # The noise coefficient is beta_t / sqrt(1 - alpha_bar_t), NOT
            # sqrt(beta_t / (1 - alpha_bar_t)); the latter over-scales the
            # predicted noise by 1 / sqrt(beta_t) (100x when beta_t = 1e-4).
            eps_coef = beta_t / (1.0 - alpha_cumprod[time_step]).sqrt()
            x = (x - eps_coef * pred_noise) / alphas[time_step].sqrt()
            if time_step > 0:
                x = x + beta_t.sqrt() * torch.randn_like(x)
        return x
    finally:
        model.train(was_training)


def _to_unit_range(x):
    """Map model-space images from [-1, 1] to [0, 1] for saving, clipping outliers."""
    return (x.clamp(-1.0, 1.0) + 1.0) / 2.0


print("Start DDPM training...")
for epoch in range(1, ddpm_epochs + 1):
    denoiser.train()
    pbar = tqdm(train_loader, desc=f"DDPM Epoch {epoch}/{ddpm_epochs}")
    for imgs, _ in pbar:
        imgs = imgs.to(device)
        bs = imgs.size(0)
        # One uniformly random timestep per image; noise the batch in closed form.
        t = torch.randint(0, ddpm_timesteps, (bs,), device=device)
        noise = torch.randn_like(imgs)
        x_t = q_sample(imgs, t, noise=noise)

        # Simplified DDPM objective: regress the injected noise with MSE.
        opt_ddpm.zero_grad()
        pred_noise = denoiser(x_t, t)
        loss = F.mse_loss(pred_noise, noise)
        loss.backward()
        opt_ddpm.step()
        pbar.set_postfix(ddpm_loss=loss.item())

    # Save a 4x4 grid of reverse-process samples after every epoch.
    epoch_samples = _ddpm_sample(denoiser, 16)
    save_image(
        make_grid(_to_unit_range(epoch_samples), nrow=4),
        os.path.join(save_dir, f"ddpm_samples_epoch{epoch}.png"),
    )

print("DDPM training finished.")

# ---------------------
# Final sampling (10 images)
# ---------------------
final_samples = _ddpm_sample(denoiser, 10)
ddpm_final_path = os.path.join(save_dir, "ddpm_final_samples.png")
save_image(make_grid(_to_unit_range(final_samples), nrow=5), ddpm_final_path)
print("Saved final ddpm samples ->", ddpm_final_path)

# ---------------------
# Side-by-side comparison of the four models (demo)
# ---------------------
# Assumed output paths produced by the other models:
vae_path = "vae_outputs/vae_samples.png"       # VAE: 10 random samples
gan_path = "gan_outputs/gan_samples.png"       # GAN: 10 random samples
cgan_path = "cgan_outputs/cgan_samples.png"    # cGAN: 10 samples for each digit 0-9 (10x10)
# ddpm_final_path was set when the final DDPM samples were saved.

from PIL import Image
import matplotlib.pyplot as plt


def load_image(path):
    """Load the image at ``path`` as single-channel grayscale (PIL mode "L").

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    # The context manager closes the file handle; convert() returns a
    # separate, fully loaded image that stays valid afterwards.
    with Image.open(path) as img:
        return img.convert("L")


def _load_or_none(path):
    """Return ``load_image(path)``, or None (with a printed warning) if the file is missing."""
    try:
        return load_image(path)
    except FileNotFoundError:
        print("Image not found, leaving panel empty ->", path)
        return None


# Load every model's sample sheet; a missing file no longer aborts the figure.
vae_img = _load_or_none(vae_path)
gan_img = _load_or_none(gan_path)
cgan_img = _load_or_none(cgan_path)
ddpm_img = _load_or_none(ddpm_final_path)

# Build the 2x2 comparison figure.
panels = [
    ("VAE", vae_img),
    ("GAN", gan_img),
    ("cGAN", cgan_img),
    ("Diffusion (DDPM)", ddpm_img),
]
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
for ax, (title, img) in zip(axs.flatten(), panels):
    if img is not None:
        ax.imshow(img, cmap='gray')
    else:
        ax.text(0.5, 0.5, "missing", ha="center", va="center", transform=ax.transAxes)
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
compare_path = os.path.join(save_dir, "compare_all_models.png")
plt.savefig(compare_path)
plt.show()
print("Saved comparison image ->", compare_path)


Device: cpu


100%|██████████| 9.91M/9.91M [00:00<00:00, 42.5MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.05MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 9.51MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.72MB/s]


Start DDPM training...


DDPM Epoch 1/200:   9%|▊         | 40/468 [00:38<06:59,  1.02it/s, ddpm_loss=0.868]