In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import time
import random
import numpy as np

In [None]:
# ---- Hyperparameters ----
batch_size, epochs, lr = 128, 50, 5e-4  # mini-batch size, training epochs, Adam learning rate
T = 300          # number of diffusion timesteps
img_size = 28    # side length of the generated images (MNIST digits are 28x28)

# Prefer the GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ---- Reproducibility: seed every RNG in use (Python, NumPy, torch CPU + CUDA) ----
seed = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed)
# Ask cuDNN for deterministic kernels and disable its autotuner, which may pick different algorithms per run.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# MNIST digits as float tensors in [0, 1] (ToTensor only; no further normalization).
transform = transforms.Compose([transforms.ToTensor()])

# Both splits share the same root directory, download flag and transform.
mnist_kwargs = dict(root="./data", download=True, transform=transform)
train_dataset = MNIST(train=True, **mnist_kwargs)
test_dataset = MNIST(train=False, **mnist_kwargs)

# Shuffle only the training split; the test split keeps its original order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [None]:
# Linear noise schedule: T values of beta_t evenly spaced from 1e-4 to 0.02.
beta = torch.linspace(start=1e-4, end=0.02, steps=T).to(device)
alpha = 1.0 - beta                 # alpha_t = 1 - beta_t
alpha_bar = alpha.cumprod(dim=0)   # alpha_bar_t = alpha_0 * alpha_1 * ... * alpha_t

In [None]:
class ConditionalUNet(nn.Module):
    """Small class-conditional U-Net that predicts the noise contained in an image.

    The (already normalized) timestep and the class label are each embedded into a
    ``time_emb_dim``-dimensional vector, broadcast over the spatial grid and
    concatenated with the image as extra input channels. Two stride-2 encoder
    stages are mirrored by two stride-2 transposed-conv decoder stages, joined by
    additive skip connections.

    Args:
        in_channels: number of image channels (1 for MNIST).
        base: channel width of the first encoder stage; deeper stages use
            ``2 * base`` and ``4 * base`` channels.
        time_emb_dim: width of the timestep embedding and of the label embedding.
        num_classes: number of rows in the label ``nn.Embedding``.
    """

    def __init__(self, in_channels=1, base=64, time_emb_dim=128, num_classes=10):
        super().__init__()
        # Timestep MLP: one scalar per sample -> time_emb_dim features.
        self.time_emb = nn.Sequential(
            nn.Linear(1, time_emb_dim),
            nn.ReLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )
        self.label_emb = nn.Embedding(num_classes, time_emb_dim)

        # The first conv sees the image plus the broadcast time map and label map.
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels + time_emb_dim + time_emb_dim, base, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(base, base, 3, padding=1),
            nn.ReLU()
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(base, base * 2, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(base * 2, base * 2, 3, padding=1),
            nn.ReLU()
        )
        self.enc3 = nn.Sequential(
            nn.Conv2d(base * 2, base * 4, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(base * 4, base * 4, 3, padding=1),
            nn.ReLU()
        )

        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(base * 4, base * 2, 4, stride=2, padding=1),
            nn.ReLU()
        )
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(base * 2, base, 4, stride=2, padding=1),
            nn.ReLU()
        )
        self.dec3 = nn.Sequential(
            nn.Conv2d(base, in_channels, 3, padding=1)
        )

    def forward(self, x, t, y):
        """Predict the noise in ``x`` given timestep ``t`` and label ``y``.

        Args:
            x: image batch of shape (B, in_channels, H, W). H and W must be
                divisible by 4 so the down/up-sampled maps line up for the skip
                additions (28x28 works).
            t: timestep already normalized by the caller (the training and
                sampling code pass ``t.float() / T``). Accepts shape (B,), (B, 1),
                or a single value (0-d tensor, one-element tensor or Python
                number) that is shared by the whole batch. It is cast to
                ``x.dtype`` but NOT rescaled.
            y: integer class labels of shape (B,), or a single label shared by
                the whole batch.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, _, height, width = x.shape

        # Bring t to a (B, 1) column and y to a (B,) vector; a single value is
        # broadcast to the batch, any other size mismatch raises in expand().
        t = torch.as_tensor(t, device=x.device).reshape(-1, 1).to(x.dtype).expand(batch, 1)
        y = torch.as_tensor(y, device=x.device).reshape(-1).expand(batch)

        t_map = self.time_emb(t).view(batch, -1, 1, 1).expand(-1, -1, height, width)
        y_map = self.label_emb(y).view(batch, -1, 1, 1).expand(-1, -1, height, width)

        h = torch.cat([x, t_map, y_map], dim=1)

        e1 = self.enc1(h)    # full resolution
        e2 = self.enc2(e1)   # 1/2 resolution
        e3 = self.enc3(e2)   # 1/4 resolution

        d1 = self.dec1(e3)        # back to 1/2 resolution, added to e2
        d2 = self.dec2(d1 + e2)   # back to full resolution, added to e1
        return self.dec3(d2 + e1)

In [None]:
# Build the network on the chosen device, its Adam optimizer, and the MSE
# objective used to compare predicted noise against the true noise.
model = ConditionalUNet().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
loss_fn = nn.MSELoss(reduction="mean")

In [None]:
def q_sample(x0, t, noise=None):
    """Draw x_t from the forward process q(x_t | x_0) in closed form.

    Computes ``sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * noise``.

    Args:
        x0: clean image batch. The per-sample coefficients are reshaped to
            (-1, 1, 1, 1), so x0 is expected to be 4-D (B, C, H, W).
        t: integer timestep index per sample, used to index ``alpha_bar``.
        noise: noise with the shape of ``x0``; drawn with ``torch.randn_like``
            when omitted.

    Returns:
        The noised batch, same shape as ``x0``.
    """
    eps = torch.randn_like(x0) if noise is None else noise
    ab = alpha_bar[t]
    signal_scale = ab.sqrt().view(-1, 1, 1, 1)
    noise_scale = (1 - ab).sqrt().view(-1, 1, 1, 1)
    return signal_scale * x0 + noise_scale * eps

In [None]:
# Training: each step noises a batch at uniformly random timesteps with q_sample
# and regresses the model's prediction onto the injected Gaussian noise (MSE).
train_start = time.time()
# The sampling cell switches the model to eval(); make sure (re-)running this
# cell always trains in train mode.
model.train()

for epoch in range(epochs):
    running_loss = 0.0
    seen = 0
    for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        x, y = x.to(device), y.to(device)
        t = torch.randint(0, T, (x.size(0),), device=device)
        noise = torch.randn_like(x)
        x_t = q_sample(x, t, noise)

        # The model receives the timestep normalized to [0, 1).
        noise_pred = model(x_t, t.float() / T, y)
        loss = loss_fn(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss is a per-batch mean; weight it by batch size so the epoch figure is a
        # true per-sample mean even though the last batch is smaller.
        running_loss += loss.item() * x.size(0)
        seen += x.size(0)

    print(f"📉 Epoch {epoch+1} Loss: {running_loss / max(seen, 1):.4f}")

print(f"訓練完成總用時：{time.time() - train_start:.2f} 秒\n")
# Save the trained weights once training has finished.
model_path = "conditional_unet_mnist.pth"
torch.save(model.state_dict(), model_path)
print(f"模型已儲存至 {model_path}")

In [None]:
@torch.no_grad()
def p_sample(x_t, t, y):
    """Take one reverse-diffusion step x_t -> x_{t-1} using the global ``model``.

    Uses the noise-prediction form of the posterior mean
        mean = (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps_pred) / sqrt(alpha_t)
    and adds ``sqrt(beta_t) * z`` with z ~ N(0, I) to every sample whose
    timestep is > 0. Samples already at t == 0 get the mean unchanged.

    Args:
        x_t: current noisy batch (B, C, H, W).
        t: timestep, either a Python/numpy integer shared by the whole batch or
            an integer tensor (0-d, or shape (B,) with per-sample timesteps).
        y: class labels passed through to the model.

    Returns:
        The batch at the previous timestep, same shape as ``x_t``.
    """
    if not torch.is_tensor(t):
        t = torch.full((x_t.size(0),), int(t), device=device, dtype=torch.long)
    elif t.dim() == 0:
        t = t.expand(x_t.size(0))

    beta_t = beta[t].view(-1, 1, 1, 1)
    alpha_t = alpha[t].view(-1, 1, 1, 1)
    alpha_bar_t = alpha_bar[t].view(-1, 1, 1, 1)

    noise_pred = model(x_t, t.float() / T, y)
    mean = (1 / alpha_t.sqrt()) * (x_t - ((1 - alpha_t) / (1 - alpha_bar_t).sqrt()) * noise_pred)

    # Decide per sample whether to add noise (previously only t[0] was checked,
    # which is wrong for batches mixing t == 0 and t > 0). When every sample is at
    # t == 0 no noise is drawn at all, matching the original RNG consumption.
    has_noise = t > 0
    if not bool(has_noise.any()):
        return mean
    mask = has_noise.to(x_t.dtype).view(-1, 1, 1, 1)
    return mean + mask * beta_t.sqrt() * torch.randn_like(x_t)


In [None]:
@torch.no_grad()
def sample_images(n=10, target_digit=None):
    """Generate ``n`` digits by running the full reverse chain, then plot them in a row.

    Args:
        n: number of images to generate (must be >= 1).
        target_digit: if given (0-9), every image is conditioned on this label;
            otherwise labels are drawn uniformly at random from 0-9.

    Returns:
        The generated batch of shape (n, 1, img_size, img_size), clamped to
        [0, 1] (the range of the ToTensor training data).

    Raises:
        ValueError: if ``n < 1`` or ``target_digit`` is outside 0-9.
    """
    # Validate before touching the RNG or the model mode.
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    if target_digit is not None and not 0 <= int(target_digit) <= 9:
        raise ValueError(f"target_digit must be in 0..9, got {target_digit}")

    was_training = model.training
    model.eval()
    try:
        x_t = torch.randn(n, 1, img_size, img_size).to(device)
        if target_digit is None:
            # Random labels 0-9.
            y = torch.randint(0, 10, (n,), device=device, dtype=torch.long)
            desc, title = "Sampling random digits", "Generated random digits"
        else:
            y = torch.full((n,), int(target_digit), device=device, dtype=torch.long)
            desc, title = f"Sampling digit {target_digit}", f"Generated digit {target_digit}"

        start_time = time.time()
        # reversed(range) has no len(), so give tqdm the total explicitly.
        for step in tqdm(reversed(range(T)), total=T, desc=desc):
            x_t = p_sample(x_t, step, y)
        end_time = time.time()
        print(f"⏱️ 生成 {n} 張圖花費時間: {end_time - start_time:.2f} 秒")
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    # Training data is in [0, 1]; clamp instead of a batch-wide min-max rescale,
    # which lets a few outlier pixels shift the brightness of every image.
    images = x_t.clamp(0.0, 1.0)

    fig, axes = plt.subplots(1, n, figsize=(16, 2), squeeze=False)
    for ax, img, label in zip(axes[0], images.cpu(), y.cpu().tolist()):
        ax.imshow(img.reshape(img_size, img_size), cmap="gray", vmin=0.0, vmax=1.0)
        ax.set_title(str(label))
        ax.axis("off")
    fig.suptitle(title, fontsize=20)
    plt.show()

    return images

sample_images(n=10);
# sample_images(target_digit=3, n=10);  # condition every sample on one digit

## 推理（Inference）：載入已儲存的模型並生成數字

In [None]:
# model = ConditionalUNet().to(device)
# model.load_state_dict(torch.load("conditional_unet_mnist.pth", map_location=device))
# model.eval()
# sample_images(n=10)