<a href="https://colab.research.google.com/github/Qqqsse/Generative-Model-Comparison---HW2/blob/main/Final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
完整生成模型訓練程式碼
包含: VAE, GAN, cGAN, Diffusion Model
MNIST 手寫數字生成
"""

In [None]:
# ==================== 1. VAE libraries and definitions ====================
# Console banner announcing that the VAE section is being loaded.
print("=" * 60)
print("載入 VAE 模組...")
print("=" * 60)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import os
from datetime import datetime

# Seed every random number generator used by the notebook
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every generator.

    Besides the torch/NumPy generators this also seeds the standard-library
    ``random`` module, which would otherwise stay unseeded and make any code
    path that draws from it non-reproducible.
    """
    # Local import keeps this helper self-contained without touching the
    # file-level import block.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning for run-to-run repeatability.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed once up front, then select the GPU when one is available (CPU otherwise).
# `device` is a module-level global read by the helper functions in this file.
set_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}\n')

# VAE model definition
class ImprovedVAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: input_dim -> hidden_dim -> hidden_dim // 2 -> (mu, logvar).
    Decoder: latent_dim -> hidden_dim // 2 -> hidden_dim -> input_dim (sigmoid).

    Args:
        input_dim: Number of features per flattened sample.
        hidden_dim: Width of the widest hidden layer.
        latent_dim: Size of the latent code.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        # Remembered so forward() flattens to the configured width instead of
        # a hard-coded 784.
        self.input_dim = input_dim

        # Encoder layers
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc_mu = nn.Linear(hidden_dim // 2, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim // 2, latent_dim)

        # Decoder layers
        self.fc3 = nn.Linear(latent_dim, hidden_dim // 2)
        self.fc4 = nn.Linear(hidden_dim // 2, hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, input_dim)

        # One dropout module shared by encoder and decoder
        self.dropout = nn.Dropout(0.2)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of flattened inputs."""
        h1 = F.relu(self.fc1(x))
        h1 = self.dropout(h1)
        h2 = F.relu(self.fc2(h1))
        mu = self.fc_mu(h2)
        logvar = self.fc_logvar(h2)
        # Bound the log-variance so exp() in the loss/sampler stays finite.
        logvar = torch.clamp(logvar, -10, 10)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes to reconstructions in [0, 1]."""
        h3 = F.relu(self.fc3(z))
        h3 = self.dropout(h3)
        h4 = F.relu(self.fc4(h3))
        return torch.sigmoid(self.fc5(h4))

    def forward(self, x):
        """Flatten ``x``, encode, sample and decode.

        Returns:
            Tuple ``(recon, mu, logvar)``; ``recon`` has shape
            ``(batch, input_dim)``.
        """
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar

# VAE loss function
def improved_vae_loss(recon_x, x, mu, logvar, beta=1.0, free_bits=0.5):
    """Compute the beta-weighted VAE objective with a per-element KL floor.

    Args:
        recon_x: Reconstructions of shape ``(batch, features)`` in [0, 1].
        x: Targets; reshaped to ``(-1, features)`` to match ``recon_x``.
        mu: Posterior means.
        logvar: Posterior log-variances.
        beta: Weight applied to the KL term.
        free_bits: Lower bound applied to every per-sample, per-dimension KL
            value before summation (previously hard-coded to 0.5).

    Returns:
        Tuple ``(total, bce, kld)`` of scalar tensors, where ``bce`` and
        ``kld`` are sums divided by the batch size and
        ``total = bce + beta * kld``.
    """
    batch_size = x.size(0)
    # Use the reconstruction's feature width rather than a literal 784 so the
    # loss also works for models built with a different input_dim.
    target = x.view(-1, recon_x.size(-1))
    bce = F.binary_cross_entropy(recon_x, target, reduction='sum') / batch_size

    kld_per_dim = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    # Floor each element at `free_bits`; elements below the floor contribute
    # a constant and therefore no gradient.
    kld_per_dim = torch.clamp(kld_per_dim, min=free_bits)
    kld = torch.sum(kld_per_dim) / batch_size

    return bce + beta * kld, bce, kld

# VAE training routine
def train_vae(model, train_loader, optimizer, epoch, beta=1.0):
    """Train the VAE for one epoch and return dataset-averaged metrics.

    Args:
        model: VAE whose forward pass returns ``(recon, mu, logvar)``.
        train_loader: Iterable of ``(data, label)`` batches; labels are ignored.
        optimizer: Optimizer stepping the model parameters.
        epoch: Epoch number, used only in the progress-bar caption.
        beta: KL weight forwarded to ``improved_vae_loss``.

    Returns:
        Tuple ``(loss, bce, kld)``, each weighted by batch size and divided
        by ``len(train_loader.dataset)``.
    """
    model.train()
    loss_sum = 0
    bce_sum = 0
    kld_sum = 0
    progress = tqdm(train_loader, desc=f'VAE Epoch {epoch}')

    for data, _ in progress:
        data = data.to(device)
        samples_in_batch = data.size(0)

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss, bce, kld = improved_vae_loss(recon_batch, data, mu, logvar, beta)
        loss.backward()

        # Pull the scalars out once and reuse them for both the running
        # totals and the progress bar.
        loss_value = loss.item()
        bce_value = bce.item()
        kld_value = kld.item()
        loss_sum += loss_value * samples_in_batch
        bce_sum += bce_value * samples_in_batch
        kld_sum += kld_value * samples_in_batch

        optimizer.step()
        progress.set_postfix({
            'loss': f'{loss_value:.4f}',
            'bce': f'{bce_value:.4f}',
            'kld': f'{kld_value:.4f}',
        })

    dataset_size = len(train_loader.dataset)
    return loss_sum / dataset_size, bce_sum / dataset_size, kld_sum / dataset_size

def test_vae(model, test_loader, beta=1.0):
    """Evaluate the VAE on ``test_loader`` without tracking gradients.

    Args:
        model: VAE whose forward pass returns ``(recon, mu, logvar)``.
        test_loader: Iterable of ``(data, label)`` batches; labels are ignored.
        beta: KL weight forwarded to ``improved_vae_loss``.

    Returns:
        Tuple ``(loss, bce, kld)``, each weighted by batch size and divided
        by ``len(test_loader.dataset)``.
    """
    model.eval()
    loss_sum = 0
    bce_sum = 0
    kld_sum = 0

    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            samples_in_batch = data.size(0)
            recon_batch, mu, logvar = model(data)
            loss, bce, kld = improved_vae_loss(recon_batch, data, mu, logvar, beta)
            loss_sum += loss.item() * samples_in_batch
            bce_sum += bce.item() * samples_in_batch
            kld_sum += kld.item() * samples_in_batch

    dataset_size = len(test_loader.dataset)
    return loss_sum / dataset_size, bce_sum / dataset_size, kld_sum / dataset_size

# VAE visualization helpers
def generate_vae_images(model, num_images=10, latent_dim=20):
    """Decode random latent codes into 28x28 images.

    The latent batch is moved to the device of the model's own parameters
    rather than to a module-level ``device`` global, so the helper works for
    any model regardless of where it lives.

    Args:
        model: Module exposing ``decode(z)``; it is switched to eval mode and
            left in eval mode.
        num_images: Number of samples to draw.
        latent_dim: Size of each latent code.

    Returns:
        CPU tensor of shape ``(num_images, 1, 28, 28)``.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    # A parameter-free model has no device of its own; fall back to CPU.
    target_device = first_param.device if first_param is not None else torch.device('cpu')
    with torch.no_grad():
        # Draw on the CPU first (as before) so seeded runs keep the same
        # random stream, then move to the model's device.
        z = torch.randn(num_images, latent_dim).to(target_device)
        samples = model.decode(z)
        samples = samples.view(num_images, 1, 28, 28).cpu()
    return samples

def visualize_generated(images, title='Generated Images', save_path=None):
    """Show the first ten images in a single row.

    Args:
        images: Indexable collection of at least ten images; each entry is
            squeezed before display.
        title: Figure title.
        save_path: When truthy, the figure is also written to this path.
    """
    fig, axes = plt.subplots(1, 10, figsize=(20, 2))
    for column in range(len(axes)):
        panel = axes[column]
        # Fixed intensity range so brightness is comparable across panels.
        panel.imshow(images[column].squeeze(), cmap='gray', vmin=0, vmax=1)
        panel.axis('off')
    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

def plot_vae_training_curves(train_losses, test_losses, train_klds, test_klds, save_path=None):
    """Plot total loss and KL divergence per epoch in two side-by-side panels.

    Args:
        train_losses: Per-epoch training loss values.
        test_losses: Per-epoch test loss values.
        train_klds: Per-epoch training KL values.
        test_klds: Per-epoch test KL values.
        save_path: When truthy, the figure is also written to this path.
    """
    fig, panels = plt.subplots(1, 2, figsize=(15, 5))
    # (train series, train label, test series, test label, y-axis label, panel title)
    panel_specs = (
        (train_losses, 'Train Loss', test_losses, 'Test Loss', 'Loss', 'VAE Total Loss'),
        (train_klds, 'Train KLD', test_klds, 'Test KLD', 'KLD', 'KL Divergence'),
    )
    for panel, spec in zip(panels, panel_specs):
        train_series, train_label, test_series, test_label, y_label, heading = spec
        panel.plot(train_series, label=train_label, linewidth=2)
        panel.plot(test_series, label=test_label, linewidth=2)
        panel.set_xlabel('Epoch', fontsize=12)
        panel.set_ylabel(y_label, fontsize=12)
        panel.set_title(heading, fontsize=14)
        panel.legend(fontsize=11)
        panel.grid(True, alpha=0.3)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()


In [None]:
# ==================== 2. VAE training run ====================
print("\n" + "=" * 60)
print("開始訓練 VAE (Variational Autoencoder)")
print("=" * 60 + "\n")

# Load MNIST; ToTensor() is the only transform applied here.
# These loaders are reused by the GAN and cGAN cells.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# VAE hyperparameters
latent_dim = 20
hidden_dim = 400
learning_rate = 1e-3
num_epochs_vae = 50
beta_start = 0.5
beta_end = 4.0
warmup_epochs = 20

# Build the model, optimizer and LR scheduler
vae_model = ImprovedVAE(input_dim=784, hidden_dim=hidden_dim, latent_dim=latent_dim).to(device)
vae_optimizer = optim.Adam(vae_model.parameters(), lr=learning_rate, weight_decay=1e-5)
vae_scheduler = optim.lr_scheduler.ReduceLROnPlateau(vae_optimizer, mode='min', factor=0.5, patience=5)

print(f'VAE 配置:')
print(f'Latent Dim: {latent_dim}, Hidden Dim: {hidden_dim}')
print(f'Learning Rate: {learning_rate}, Epochs: {num_epochs_vae}')
print(f'β Range: {beta_start} -> {beta_end}\n')

# Per-epoch history used for the training curves
train_losses_vae = []
test_losses_vae = []
train_klds_vae = []
test_klds_vae = []

for epoch in range(1, num_epochs_vae + 1):
    # Linear beta warm-up: reaches beta_end at epoch == warmup_epochs, then stays there.
    if epoch <= warmup_epochs:
        beta = beta_start + (beta_end - beta_start) * (epoch / warmup_epochs)
    else:
        beta = beta_end

    train_loss, train_bce, train_kld = train_vae(vae_model, train_loader, vae_optimizer, epoch, beta)
    test_loss, test_bce, test_kld = test_vae(vae_model, test_loader, beta)

    train_losses_vae.append(train_loss)
    test_losses_vae.append(test_loss)
    train_klds_vae.append(train_kld)
    test_klds_vae.append(test_kld)

    # NOTE(review): test_loss is computed with the current beta, so values from
    # different warm-up epochs are not on the same scale when the plateau
    # scheduler compares them.
    vae_scheduler.step(test_loss)
    current_lr = vae_optimizer.param_groups[0]['lr']
    print(f'Epoch {epoch}: Loss={test_loss:.4f}, BCE={test_bce:.4f}, KLD={test_kld:.4f}, β={beta:.2f}, LR={current_lr:.6f}')

    # Preview generated samples every 10 epochs
    if epoch % 10 == 0:
        samples = generate_vae_images(vae_model, num_images=10, latent_dim=latent_dim)
        visualize_generated(samples, title=f'VAE Generated (Epoch {epoch})')

# Final VAE results: curves, samples and saved weights
plot_vae_training_curves(train_losses_vae, test_losses_vae, train_klds_vae, test_klds_vae, 'vae_training_curve.png')
final_vae_samples = generate_vae_images(vae_model, num_images=10, latent_dim=latent_dim)
visualize_generated(final_vae_samples, title='VAE Final Generated Images', save_path='vae_final_results.png')
torch.save(vae_model.state_dict(), 'vae_mnist_improved.pth')
print(f'\nVAE 訓練完成! 最終測試損失: {test_losses_vae[-1]:.4f}\n')

In [None]:
# ==================== 3. GAN libraries and definitions ====================
# Console banner announcing that the GAN section is being loaded.
print("=" * 60)
print("載入 GAN 模組...")
print("=" * 60 + "\n")

# GAN model definitions
class Generator(nn.Module):
    """MLP generator mapping latent vectors to 1x28x28 images in [-1, 1].

    Args:
        latent_dim: Size of the input noise vector.
        hidden_dim: Width of the first hidden layer; later layers are 2x and
            4x this width.
    """

    def __init__(self, latent_dim=100, hidden_dim=256):
        super().__init__()
        # Input stage has no batch norm.
        stack = [nn.Linear(latent_dim, hidden_dim), nn.LeakyReLU(0.2)]
        # Two widening stages, each Linear -> BatchNorm -> LeakyReLU.
        width = hidden_dim
        for _ in range(2):
            stack.append(nn.Linear(width, width * 2))
            stack.append(nn.BatchNorm1d(width * 2))
            stack.append(nn.LeakyReLU(0.2))
            width *= 2
        # Project to 784 pixels and squash to [-1, 1].
        stack.append(nn.Linear(width, 784))
        stack.append(nn.Tanh())
        self.model = nn.Sequential(*stack)

    def forward(self, z):
        """Return generated images of shape ``(batch, 1, 28, 28)``."""
        flat = self.model(z)
        return flat.view(flat.size(0), 1, 28, 28)

class Discriminator(nn.Module):
    """MLP discriminator scoring flattened 784-pixel images with a probability.

    Args:
        hidden_dim: Width of the last hidden layer; earlier layers are 4x and
            2x this width.
    """

    def __init__(self, hidden_dim=256):
        super().__init__()
        widths = [784, hidden_dim * 4, hidden_dim * 2, hidden_dim]
        stack = []
        # Three narrowing stages, each Linear -> LeakyReLU -> Dropout.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2))
            stack.append(nn.Dropout(0.3))
        # Single sigmoid output.
        stack.append(nn.Linear(hidden_dim, 1))
        stack.append(nn.Sigmoid())
        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Return a ``(batch, 1)`` tensor of scores in [0, 1]."""
        flattened = img.view(img.size(0), -1)
        return self.model(flattened)

# GAN training routine
def train_gan(generator, discriminator, train_loader, optimizer_G, optimizer_D, epoch, device,
              latent_dim=100):
    """Run one adversarial training epoch.

    Args:
        generator: Module mapping ``(batch, latent_dim)`` noise to images.
        discriminator: Module mapping images to ``(batch, 1)`` probabilities.
        train_loader: Iterable of ``(images, labels)`` batches; labels are
            ignored. Assumes pixel values in [0, 1] (they are rescaled to
            [-1, 1] below).
        optimizer_G: Optimizer for the generator.
        optimizer_D: Optimizer for the discriminator.
        epoch: Epoch number, used only in the progress-bar caption.
        device: Device on which tensors are placed.
        latent_dim: Noise vector size. Defaults to 100, the value that was
            previously hard-coded, so existing callers are unaffected.

    Returns:
        Tuple ``(mean generator loss, mean discriminator loss)`` over batches.
    """
    generator.train()
    discriminator.train()
    g_losses = []
    d_losses = []
    adversarial_loss = nn.BCELoss()

    pbar = tqdm(train_loader, desc=f'GAN Epoch {epoch}')
    for real_imgs, _ in pbar:
        batch_size = real_imgs.size(0)
        real_imgs = real_imgs.to(device)
        # Allocate targets directly on the device instead of on CPU + copy.
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)
        real_imgs = real_imgs * 2 - 1

        # Discriminator step: real images vs. detached fakes
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = generator(z)
        # detach() keeps this loss from back-propagating into the generator.
        fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Generator step: fresh noise, try to make the discriminator say "real"
        optimizer_G.zero_grad()
        z = torch.randn(batch_size, latent_dim).to(device)
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())
        pbar.set_postfix({'D_loss': d_loss.item(), 'G_loss': g_loss.item()})

    return np.mean(g_losses), np.mean(d_losses)

# GAN sampling helper
def generate_gan_images(generator, num_images=10, latent_dim=100):
    """Sample images from a trained generator and rescale them to [0, 1].

    The noise batch is moved to the device of the generator's own parameters
    rather than to a module-level ``device`` global.

    Args:
        generator: Module mapping ``(batch, latent_dim)`` noise to images in
            [-1, 1]; it is switched to eval mode and left in eval mode.
        num_images: Number of samples to draw.
        latent_dim: Size of each noise vector.

    Returns:
        CPU tensor holding the generator output mapped through ``(x + 1) / 2``.
    """
    generator.eval()
    first_param = next(generator.parameters(), None)
    # A parameter-free generator has no device of its own; fall back to CPU.
    target_device = first_param.device if first_param is not None else torch.device('cpu')
    with torch.no_grad():
        z = torch.randn(num_images, latent_dim).to(target_device)
        samples = generator(z)
        samples = (samples + 1) / 2
        samples = samples.cpu()
    return samples

def plot_gan_training_curve(g_losses, d_losses, save_path=None):
    """Plot generator and discriminator loss per epoch on one set of axes.

    Args:
        g_losses: Per-epoch generator loss values.
        d_losses: Per-epoch discriminator loss values.
        save_path: When truthy, the figure is also written to this path.
    """
    plt.figure(figsize=(10, 5))
    curves = ((g_losses, 'Generator Loss'), (d_losses, 'Discriminator Loss'))
    for series, caption in curves:
        plt.plot(series, label=caption, alpha=0.7)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('GAN Training Curve')
    plt.legend()
    plt.grid(True)
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# ==================== 4. GAN training run ====================
print("=" * 60)
print("開始訓練 GAN (Generative Adversarial Network)")
print("=" * 60 + "\n")

# Re-seed so this section starts from the same RNG state regardless of what ran before.
set_seed(42)

# GAN hyperparameters (the discriminator learning rate is twice the generator's)
latent_dim_gan = 100
hidden_dim_gan = 512
G_lr = 2e-4
D_lr = 2 * G_lr
num_epochs_gan = 100

# Build the models and one Adam optimizer per network
gan_generator = Generator(latent_dim=latent_dim_gan, hidden_dim=hidden_dim_gan).to(device)
gan_discriminator = Discriminator(hidden_dim=hidden_dim_gan).to(device)
optimizer_G = optim.Adam(gan_generator.parameters(), lr=G_lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(gan_discriminator.parameters(), lr=D_lr, betas=(0.5, 0.999))

# Note: only the generator learning rate is printed below.
print(f'GAN 配置:')
print(f'Latent Dim: {latent_dim_gan}, Hidden Dim: {hidden_dim_gan}')
print(f'Learning Rate: {G_lr}, Epochs: {num_epochs_gan}\n')

# Per-epoch mean losses used for the training curve
gan_g_losses = []
gan_d_losses = []

# Reuses train_loader from the VAE cell.
for epoch in range(1, num_epochs_gan + 1):
    g_loss, d_loss = train_gan(gan_generator, gan_discriminator, train_loader, optimizer_G, optimizer_D, epoch, device)
    gan_g_losses.append(g_loss)
    gan_d_losses.append(d_loss)
    print(f'Epoch {epoch}: G_loss = {g_loss:.4f}, D_loss = {d_loss:.4f}')

    # Preview generated samples every 10 epochs
    if epoch % 10 == 0:
        samples = generate_gan_images(gan_generator, num_images=10, latent_dim=latent_dim_gan)
        visualize_generated(samples, title=f'GAN Generated (Epoch {epoch})')

# Final GAN results: curve, samples and saved weights
plot_gan_training_curve(gan_g_losses, gan_d_losses, save_path='gan_training_curve.png')
final_gan_samples = generate_gan_images(gan_generator, num_images=10, latent_dim=latent_dim_gan)
visualize_generated(final_gan_samples, title='GAN Final Generated Images', save_path='gan_final_results.png')
torch.save(gan_generator.state_dict(), 'gan_generator.pth')
torch.save(gan_discriminator.state_dict(), 'gan_discriminator.pth')
print(f'\nGAN 訓練完成! 最終 G_loss: {gan_g_losses[-1]:.4f}, D_loss: {gan_d_losses[-1]:.4f}\n')

In [None]:
# ==================== 5. cGAN libraries and definitions ====================
# Console banner announcing that the cGAN section is being loaded.
print("=" * 60)
print("載入 cGAN 模組...")
print("=" * 60 + "\n")

# cGAN model definitions
class ConditionalGenerator(nn.Module):
    """Label-conditioned MLP generator producing 1x28x28 images in [-1, 1].

    The class label is embedded into a ``num_classes``-wide vector and
    concatenated with the noise vector before the MLP.

    Args:
        latent_dim: Size of the input noise vector.
        num_classes: Number of labels (also the embedding width).
        hidden_dim: Width of the first hidden layer; later layers are 2x and
            4x this width.
    """

    def __init__(self, latent_dim=100, num_classes=10, hidden_dim=256):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # Input stage has no batch norm.
        stack = [nn.Linear(latent_dim + num_classes, hidden_dim), nn.LeakyReLU(0.2)]
        # Two widening stages, each Linear -> BatchNorm -> LeakyReLU.
        width = hidden_dim
        for _ in range(2):
            stack.append(nn.Linear(width, width * 2))
            stack.append(nn.BatchNorm1d(width * 2))
            stack.append(nn.LeakyReLU(0.2))
            width *= 2
        stack.append(nn.Linear(width, 784))
        stack.append(nn.Tanh())
        self.model = nn.Sequential(*stack)

    def forward(self, z, labels):
        """Return images of shape ``(batch, 1, 28, 28)`` for the given labels."""
        conditioned = torch.cat([z, self.label_emb(labels)], dim=1)
        flat = self.model(conditioned)
        return flat.view(flat.size(0), 1, 28, 28)

class ConditionalDiscriminator(nn.Module):
    """Label-conditioned MLP discriminator.

    The flattened image is concatenated with a ``num_classes``-wide label
    embedding and scored with a sigmoid probability.

    Args:
        num_classes: Number of labels (also the embedding width).
        hidden_dim: Width of the last hidden layer; earlier layers are 4x and
            2x this width.
    """

    def __init__(self, num_classes=10, hidden_dim=256):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)

        widths = [784 + num_classes, hidden_dim * 4, hidden_dim * 2, hidden_dim]
        stack = []
        # Three narrowing stages, each Linear -> LeakyReLU -> Dropout.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2))
            stack.append(nn.Dropout(0.3))
        stack.append(nn.Linear(hidden_dim, 1))
        stack.append(nn.Sigmoid())
        self.model = nn.Sequential(*stack)

    def forward(self, img, labels):
        """Return a ``(batch, 1)`` tensor of scores in [0, 1]."""
        flattened = img.view(img.size(0), -1)
        conditioned = torch.cat([flattened, self.label_emb(labels)], dim=1)
        return self.model(conditioned)

# cGAN training routine
def train_cgan(generator, discriminator, train_loader, optimizer_G, optimizer_D, epoch, device,
               latent_dim=100, num_classes=10):
    """Run one conditional-GAN training epoch.

    Args:
        generator: Module called as ``generator(noise, labels)``.
        discriminator: Module called as ``discriminator(images, labels)`` and
            returning ``(batch, 1)`` probabilities.
        train_loader: Iterable of ``(images, labels)`` batches. Assumes pixel
            values in [0, 1] (they are rescaled to [-1, 1] below).
        optimizer_G: Optimizer for the generator.
        optimizer_D: Optimizer for the discriminator.
        epoch: Epoch number, used only in the progress-bar caption.
        device: Device on which tensors are placed.
        latent_dim: Noise vector size. Defaults to 100, the value that was
            previously hard-coded.
        num_classes: Exclusive upper bound for randomly drawn fake labels.
            Defaults to 10, the value that was previously hard-coded.

    Returns:
        Tuple ``(mean generator loss, mean discriminator loss)`` over batches.
    """
    generator.train()
    discriminator.train()
    g_losses = []
    d_losses = []
    adversarial_loss = nn.BCELoss()

    pbar = tqdm(train_loader, desc=f'cGAN Epoch {epoch}')
    for real_imgs, labels in pbar:
        batch_size = real_imgs.size(0)
        real_imgs = real_imgs.to(device)
        labels = labels.to(device)
        # One-sided label smoothing: the discriminator's "real" target is 0.85,
        # while the generator below is still trained against a hard 1.0.
        valid = torch.full((batch_size, 1), 0.85, device=device)
        fake = torch.zeros(batch_size, 1, device=device)
        generator_target = torch.ones(batch_size, 1, device=device)
        real_imgs = real_imgs * 2 - 1

        # Discriminator step: real (image, label) pairs vs. detached fakes
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs, labels), valid)
        z = torch.randn(batch_size, latent_dim).to(device)
        gen_labels = torch.randint(0, num_classes, (batch_size,)).to(device)
        fake_imgs = generator(z, gen_labels)
        # detach() keeps this loss from back-propagating into the generator.
        fake_loss = adversarial_loss(discriminator(fake_imgs.detach(), gen_labels), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Generator step: fresh noise and fresh random labels
        optimizer_G.zero_grad()
        z = torch.randn(batch_size, latent_dim).to(device)
        gen_labels = torch.randint(0, num_classes, (batch_size,)).to(device)
        gen_imgs = generator(z, gen_labels)
        g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), generator_target)
        g_loss.backward()
        optimizer_G.step()

        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())
        pbar.set_postfix({'D_loss': d_loss.item(), 'G_loss': g_loss.item()})

    return np.mean(g_losses), np.mean(d_losses)

# cGAN sampling helper
def generate_all_digits_cgan(generator, samples_per_digit=10, latent_dim=100, num_classes=10):
    """Generate ``samples_per_digit`` images for every class label in order.

    The noise and label tensors are moved to the device of the generator's
    own parameters rather than to a module-level ``device`` global.

    Args:
        generator: Module called as ``generator(noise, labels)`` returning
            images in [-1, 1]; it is switched to eval mode and left there.
        samples_per_digit: Number of images per label.
        latent_dim: Size of each noise vector.
        num_classes: Number of labels to iterate over. Defaults to 10, the
            value that was previously hard-coded.

    Returns:
        CPU tensor with ``num_classes * samples_per_digit`` images rescaled by
        ``(x + 1) / 2``, grouped by label (all of label 0 first, then 1, ...).
    """
    generator.eval()
    first_param = next(generator.parameters(), None)
    # A parameter-free generator has no device of its own; fall back to CPU.
    target_device = first_param.device if first_param is not None else torch.device('cpu')
    all_samples = []
    with torch.no_grad():
        for digit in range(num_classes):
            z = torch.randn(samples_per_digit, latent_dim).to(target_device)
            # Explicit integer dtype: the labels index an embedding table.
            labels = torch.full((samples_per_digit,), digit, dtype=torch.long).to(target_device)
            samples = generator(z, labels)
            samples = (samples + 1) / 2
            all_samples.append(samples.cpu())
    return torch.cat(all_samples, dim=0)

def visualize_cgan_grid(images, title='cGAN Generated Digits (0-9)', save_path=None):
    """Show 100 images in a 10x10 grid, one digit class per row.

    Args:
        images: Indexable collection of at least 100 images ordered row by
            row (index ``row * 10 + column``); each entry is squeezed before
            display. Assumes values in [0, 1].
        title: Figure title.
        save_path: When truthy, the figure is also written to this path.
    """
    fig, axes = plt.subplots(10, 10, figsize=(15, 15))
    for i in range(10):
        for j in range(10):
            ax = axes[i, j]
            # Fixed intensity range, matching visualize_generated, so every
            # panel is rendered on the same scale.
            ax.imshow(images[i * 10 + j].squeeze(), cmap='gray', vmin=0, vmax=1)
            if j == 0:
                # axis('off') also hides the y-label, so the row label would
                # never be drawn. Strip ticks and frame by hand instead.
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)
                ax.set_ylabel(f'{i}', fontsize=12, rotation=0, labelpad=20)
            else:
                ax.axis('off')
    plt.suptitle(title, fontsize=18, y=0.995)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

def plot_cgan_training_curve(g_losses, d_losses, save_path=None):
    """Plot cGAN generator and discriminator loss per epoch on one set of axes.

    Args:
        g_losses: Per-epoch generator loss values.
        d_losses: Per-epoch discriminator loss values.
        save_path: When truthy, the figure is also written to this path.
    """
    plt.figure(figsize=(10, 5))
    curves = ((g_losses, 'Generator Loss'), (d_losses, 'Discriminator Loss'))
    for series, caption in curves:
        plt.plot(series, label=caption, alpha=0.7)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('cGAN Training Curve')
    plt.legend()
    plt.grid(True)
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# ==================== 6. cGAN training run ====================
print("=" * 60)
print("開始訓練 cGAN (Conditional GAN)")
print("=" * 60 + "\n")

# Re-seed so this section starts from the same RNG state regardless of what ran before.
set_seed(42)

# cGAN hyperparameters (the discriminator learning rate is lower than the generator's)
latent_dim_cgan = 100
num_classes = 10
hidden_dim_cgan = 512
G_lr_cgan = 2e-4
D_lr_cgan = 5e-5
num_epochs_cgan = 50

# Build the models and one Adam optimizer per network
cgan_generator = ConditionalGenerator(latent_dim=latent_dim_cgan, num_classes=num_classes, hidden_dim=hidden_dim_cgan).to(device)
cgan_discriminator = ConditionalDiscriminator(num_classes=num_classes, hidden_dim=hidden_dim_cgan).to(device)
optimizer_G_cgan = optim.Adam(cgan_generator.parameters(), lr=G_lr_cgan, betas=(0.5, 0.999))
optimizer_D_cgan = optim.Adam(cgan_discriminator.parameters(), lr=D_lr_cgan, betas=(0.5, 0.999))

# Note: only the generator learning rate is printed below.
print(f'cGAN 配置:')
print(f'Latent Dim: {latent_dim_cgan}, Classes: {num_classes}, Hidden Dim: {hidden_dim_cgan}')
print(f'Learning Rate: {G_lr_cgan}, Epochs: {num_epochs_cgan}\n')

# Per-epoch mean losses used for the training curve
cgan_g_losses = []
cgan_d_losses = []

# Reuses train_loader from the VAE cell.
for epoch in range(1, num_epochs_cgan + 1):
    g_loss, d_loss = train_cgan(cgan_generator, cgan_discriminator, train_loader, optimizer_G_cgan, optimizer_D_cgan, epoch, device)
    cgan_g_losses.append(g_loss)
    cgan_d_losses.append(d_loss)
    print(f'Epoch {epoch}: G_loss = {g_loss:.4f}, D_loss = {d_loss:.4f}')

    # Preview a full digit grid every 10 epochs
    if epoch % 10 == 0:
        samples = generate_all_digits_cgan(cgan_generator, samples_per_digit=10, latent_dim=latent_dim_cgan)
        visualize_cgan_grid(samples, title=f'cGAN Generated (Epoch {epoch})')

# Final cGAN results: curve, digit grid and saved weights
plot_cgan_training_curve(cgan_g_losses, cgan_d_losses, save_path='cgan_training_curve.png')
final_cgan_samples = generate_all_digits_cgan(cgan_generator, samples_per_digit=10, latent_dim=latent_dim_cgan)
visualize_cgan_grid(final_cgan_samples, title='cGAN Final Generated Digits (0-9)', save_path='cgan_final_results.png')
torch.save(cgan_generator.state_dict(), 'cgan_generator.pth')
torch.save(cgan_discriminator.state_dict(), 'cgan_discriminator.pth')
print(f'\ncGAN 訓練完成! 最終 G_loss: {cgan_g_losses[-1]:.4f}, D_loss: {cgan_d_losses[-1]:.4f}\n')

In [None]:
# ==================== 7. Diffusion model libraries and definitions ====================
# Console banner announcing that the diffusion section is being loaded.
print("=" * 60)
print("載入 Diffusion Model 模組...")
print("=" * 60 + "\n")

import math
from torch.optim.lr_scheduler import OneCycleLR

# EMA implementation
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    """``AveragedModel`` whose running average is an exponential moving average.

    Args:
        model: Module to track.
        decay: Weight kept from the running average on each update; the new
            parameters contribute ``1 - decay``.
        device: Device on which the averaged copy is stored.
    """

    def __init__(self, model, decay, device="cpu"):
        def _blend(running, incoming, num_averaged):
            # num_averaged is part of the avg_fn signature but is not needed
            # for a fixed-decay EMA.
            return decay * running + (1 - decay) * incoming

        # use_buffers=True: buffers are averaged together with the parameters.
        super().__init__(model, device=device, avg_fn=_blend, use_buffers=True)

# ShuffleNet V2 building blocks
class ChannelShuffle(nn.Module):
    """Interleave channels across ``groups`` equally sized channel groups.

    Args:
        groups: Number of groups; the channel count must be divisible by it.
    """

    def __init__(self, groups):
        super().__init__()
        self.groups = groups

    def forward(self, x):
        """Return ``x`` with its channel axis regrouped; the shape is unchanged."""
        batch, channels, height, width = x.shape
        per_group = channels // self.groups
        # (N, groups, per_group, H, W) -> swap the two channel axes -> flatten back.
        grouped = x.view(batch, self.groups, per_group, height, width)
        swapped = grouped.permute(0, 2, 1, 3, 4).contiguous()
        return swapped.view(batch, -1, height, width)

class ConvBnSiLu(nn.Module):
    """Conv2d followed by BatchNorm2d and an in-place SiLU activation.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.SiLU(inplace=True)
        self.module = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Apply convolution, normalization and activation in sequence."""
        out = self.module(x)
        return out

class ResidualBottleneck(nn.Module):
    """Two-branch block over a channel-wise split of the input.

    The input is split in half along the channel axis, each half goes through
    its own branch, and the concatenated result is channel-shuffled.

    Args:
        in_channels: Input channel count (split evenly between the branches).
        out_channels: Output channel count (each branch produces half).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        half_in = in_channels // 2
        half_out = out_channels // 2

        # Branch 1: depthwise 3x3 -> BN -> pointwise projection
        self.branch1 = nn.Sequential(
            nn.Conv2d(half_in, half_in, 3, 1, 1, groups=half_in),
            nn.BatchNorm2d(half_in),
            ConvBnSiLu(half_in, half_out, 1, 1, 0),
        )
        # Branch 2: pointwise -> depthwise 3x3 -> BN -> pointwise projection
        self.branch2 = nn.Sequential(
            ConvBnSiLu(half_in, half_in, 1, 1, 0),
            nn.Conv2d(half_in, half_in, 3, 1, 1, groups=half_in),
            nn.BatchNorm2d(half_in),
            ConvBnSiLu(half_in, half_out, 1, 1, 0),
        )
        self.channel_shuffle = ChannelShuffle(2)

    def forward(self, x):
        """Return the shuffled concatenation of both branch outputs."""
        first_half, second_half = torch.chunk(x, 2, dim=1)
        merged = torch.cat([self.branch1(first_half), self.branch2(second_half)], dim=1)
        return self.channel_shuffle(merged)

class ResidualDownsample(nn.Module):
    """Two-branch stride-2 block; both branches see the full input.

    Each branch produces ``out_channels // 2`` channels; the results are
    concatenated and channel-shuffled.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        half_out = out_channels // 2

        # Branch 1: strided depthwise 3x3 -> BN -> pointwise projection
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, 2, 1, groups=in_channels),
            nn.BatchNorm2d(in_channels),
            ConvBnSiLu(in_channels, half_out, 1, 1, 0),
        )
        # Branch 2: pointwise -> strided depthwise 3x3 -> BN -> pointwise
        self.branch2 = nn.Sequential(
            ConvBnSiLu(in_channels, half_out, 1, 1, 0),
            nn.Conv2d(half_out, half_out, 3, 2, 1, groups=half_out),
            nn.BatchNorm2d(half_out),
            ConvBnSiLu(half_out, half_out, 1, 1, 0),
        )
        self.channel_shuffle = ChannelShuffle(2)

    def forward(self, x):
        """Return the shuffled concatenation of both branch outputs."""
        merged = torch.cat([self.branch1(x), self.branch2(x)], dim=1)
        return self.channel_shuffle(merged)

class TimeMLP(nn.Module):
    """Project a time embedding and add it to a feature map.

    Args:
        embedding_dim: Width of the incoming time embedding.
        hidden_dim: Hidden width of the projection MLP.
        out_dim: Output width; must match the feature map's channel count for
            the broadcast addition to work.
    """

    def __init__(self, embedding_dim, hidden_dim, out_dim):
        super().__init__()
        projection = [
            nn.Linear(embedding_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.mlp = nn.Sequential(*projection)
        self.act = nn.SiLU()

    def forward(self, x, t):
        """Return ``SiLU(x + mlp(t))`` with the embedding broadcast over H and W."""
        # (batch, out_dim) -> (batch, out_dim, 1, 1) so it broadcasts spatially.
        bias = self.mlp(t)[:, :, None, None]
        return self.act(x + bias)

class EncoderBlock(nn.Module):
    """U-Net encoder stage: bottlenecks, optional time conditioning, downsample.

    Args:
        in_channels: Input channel count.
        out_channels: Channel count after downsampling.
        time_embedding_dim: Width of the time embedding passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, time_embedding_dim):
        super().__init__()
        self.conv0 = nn.Sequential(
            *[ResidualBottleneck(in_channels, in_channels) for _ in range(3)],
            ResidualBottleneck(in_channels, out_channels//2)
        )
        self.time_mlp = TimeMLP(time_embedding_dim, out_channels, out_channels//2)
        self.conv1 = ResidualDownsample(out_channels//2, out_channels)

    def forward(self, x, t=None):
        """Encode ``x`` and return ``[downsampled, shortcut]``.

        Args:
            x: Input feature map.
            t: Optional time embedding; when given it is added to the
                bottleneck output before downsampling.

        Returns:
            Two-element list: the downsampled features and the pre-downsample
            bottleneck output used as the skip connection.
        """
        x_shortcut = self.conv0(x)
        if t is not None:
            x = self.time_mlp(x_shortcut, t)
        else:
            # Without a time embedding the downsample must still consume the
            # bottleneck output. Previously the raw input was passed through
            # and conv0's result was silently discarded.
            x = x_shortcut
        x = self.conv1(x)
        return [x, x_shortcut]

class DecoderBlock(nn.Module):
    """U-Net decoder stage: upsample, merge the skip connection, refine.

    Args:
        in_channels: Channel count after concatenating the upsampled input
            with the skip connection.
        out_channels: Twice the number of channels this block outputs.
        time_embedding_dim: Width of the time embedding passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, time_embedding_dim):
        super().__init__()
        half_in = in_channels // 2
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        bottlenecks = [ResidualBottleneck(in_channels, in_channels) for _ in range(3)]
        bottlenecks.append(ResidualBottleneck(in_channels, half_in))
        self.conv0 = nn.Sequential(*bottlenecks)
        self.time_mlp = TimeMLP(time_embedding_dim, in_channels, half_in)
        self.conv1 = ResidualBottleneck(half_in, out_channels // 2)

    def forward(self, x, x_shortcut, t=None):
        """Decode ``x`` using the matching encoder shortcut.

        Args:
            x: Low-resolution feature map.
            x_shortcut: Skip-connection features at twice the resolution.
            t: Optional time embedding added after the bottleneck stack.
        """
        merged = torch.cat([self.upsample(x), x_shortcut], dim=1)
        features = self.conv0(merged)
        if t is not None:
            features = self.time_mlp(features, t)
        return self.conv1(features)

class Unet(nn.Module):
    """Time-conditioned U-Net built from ShuffleNet-style blocks.

    Args:
        timesteps: Number of diffusion steps (rows in the time-embedding table).
        time_embedding_dim: Width of each time embedding.
        in_channels: Input image channels.
        out_channels: Output image channels.
        base_dim: Channel count after the initial convolution; must be even.
        dim_mults: Channel multipliers, one per encoder/decoder stage. The
            default is an immutable tuple rather than a shared list; lists
            are still accepted.
    """

    def __init__(self, timesteps, time_embedding_dim, in_channels=1, out_channels=1,
                 base_dim=32, dim_mults=(2, 4)):
        super().__init__()
        assert isinstance(dim_mults, (list, tuple))
        assert base_dim % 2 == 0

        channels = self._cal_channels(base_dim, dim_mults)
        deepest = channels[-1][1]

        self.init_conv = ConvBnSiLu(in_channels, base_dim, 3, 1, 1)
        # Learned lookup table: one embedding vector per integer timestep.
        self.time_embedding = nn.Embedding(timesteps, time_embedding_dim)

        self.encoder_blocks = nn.ModuleList(
            [EncoderBlock(c[0], c[1], time_embedding_dim) for c in channels]
        )
        # Decoder mirrors the encoder: reversed stage order, swapped channel pair.
        self.decoder_blocks = nn.ModuleList(
            [DecoderBlock(c[1], c[0], time_embedding_dim) for c in channels[::-1]]
        )

        self.mid_block = nn.Sequential(
            *[ResidualBottleneck(deepest, deepest) for _ in range(2)],
            ResidualBottleneck(deepest, deepest // 2)
        )

        self.final_conv = nn.Conv2d(channels[0][0] // 2, out_channels, kernel_size=1)

    def forward(self, x, t=None):
        """Run the U-Net.

        Args:
            x: Input image batch.
            t: Optional integer timestep tensor; looked up in the embedding
                table and passed to every encoder/decoder block.

        Returns:
            Tensor produced by the final 1x1 convolution.
        """
        x = self.init_conv(x)
        if t is not None:
            t = self.time_embedding(t)

        encoder_shortcuts = []
        for encoder_block in self.encoder_blocks:
            x, x_shortcut = encoder_block(x, t)
            encoder_shortcuts.append(x_shortcut)

        x = self.mid_block(x)

        # Deepest shortcut first, matching the decoder's stage order.
        for decoder_block, shortcut in zip(self.decoder_blocks, reversed(encoder_shortcuts)):
            x = decoder_block(x, shortcut, t)

        return self.final_conv(x)

    def _cal_channels(self, base_dim, dim_mults):
        """Return ``[(in, out), ...]`` channel pairs, one per stage."""
        dims = [base_dim, *(base_dim * mult for mult in dim_mults)]
        return list(zip(dims[:-1], dims[1:]))

# Diffusion model
class MNISTDiffusion(nn.Module):
    """Noise-prediction diffusion model wrapping a time-conditioned U-Net.

    Training (``forward``) noises a clean batch at random timesteps and
    returns the network's noise prediction; ``sampling`` runs the reverse
    process from pure noise.

    Args:
        image_size: Height/width of the square images.
        in_channels: Image channel count (also the U-Net output channels).
        time_embedding_dim: Width of the U-Net time embedding.
        timesteps: Number of diffusion steps.
        base_dim: Base channel width of the U-Net.
        dim_mults: Per-stage channel multipliers. The default is an immutable
            tuple rather than a shared list; lists are still accepted.
    """

    def __init__(self, image_size, in_channels, time_embedding_dim=256,
                 timesteps=1000, base_dim=32, dim_mults=(2, 4)):
        super().__init__()
        self.timesteps = timesteps
        self.in_channels = in_channels
        self.image_size = image_size

        betas = self._cosine_variance_schedule(timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=-1)

        # Buffers follow the module across devices and into the state_dict.
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alphas_cumprod", alphas_cumprod)
        self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1. - alphas_cumprod))

        self.model = Unet(timesteps, time_embedding_dim, in_channels, in_channels,
                          base_dim, dim_mults)

    def forward(self, x, noise):
        """Noise ``x`` at random timesteps and return the predicted noise.

        Args:
            x: Clean image batch.
            noise: Noise tensor with the same shape as ``x``.
        """
        t = torch.randint(0, self.timesteps, (x.shape[0],)).to(x.device)
        x_t = self._forward_diffusion(x, t, noise)
        pred_noise = self.model(x_t, t)
        return pred_noise

    @torch.no_grad()
    def sampling(self, n_samples, clipped_reverse_diffusion=True, device="cuda"):
        """Generate ``n_samples`` images by running the full reverse process.

        Args:
            n_samples: Number of images to generate.
            clipped_reverse_diffusion: When true, clamp the predicted clean
                image to [-1, 1] at every step.
            device: Device on which sampling runs.

        Returns:
            Tensor of shape ``(n_samples, in_channels, image_size, image_size)``
            mapped through ``(x + 1) / 2``.
        """
        x_t = torch.randn((n_samples, self.in_channels, self.image_size, self.image_size)).to(device)
        step_fn = self._reverse_diffusion_with_clip if clipped_reverse_diffusion else self._reverse_diffusion

        for i in tqdm(range(self.timesteps - 1, -1, -1), desc="Diffusion Sampling"):
            noise = torch.randn_like(x_t)
            # Constant timestep vector, built without an intermediate Python list.
            t = torch.full((n_samples,), i, dtype=torch.long, device=device)
            x_t = step_fn(x_t, t, noise)

        x_t = (x_t + 1.) / 2.
        return x_t

    def _cosine_variance_schedule(self, timesteps, epsilon=0.008):
        """Return ``timesteps`` betas from a squared-cosine schedule, clipped to [0, 0.999]."""
        steps = torch.linspace(0, timesteps, steps=timesteps + 1, dtype=torch.float32)
        f_t = torch.cos(((steps / timesteps + epsilon) / (1.0 + epsilon)) * math.pi * 0.5) ** 2
        betas = torch.clip(1.0 - f_t[1:] / f_t[:timesteps], 0.0, 0.999)
        return betas

    def _extract(self, values, t, batch_size):
        """Gather per-sample schedule values at ``t`` as a ``(batch, 1, 1, 1)`` tensor."""
        return values.gather(-1, t).reshape(batch_size, 1, 1, 1)

    def _forward_diffusion(self, x_0, t, noise):
        """Return ``sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise``."""
        assert x_0.shape == noise.shape
        batch_size = x_0.shape[0]
        signal_scale = self._extract(self.sqrt_alphas_cumprod, t, batch_size)
        noise_scale = self._extract(self.sqrt_one_minus_alphas_cumprod, t, batch_size)
        return signal_scale * x_0 + noise_scale * noise

    @torch.no_grad()
    def _reverse_diffusion(self, x_t, t, noise):
        """One reverse step using the predicted noise directly (no clipping)."""
        pred = self.model(x_t, t)
        batch_size = x_t.shape[0]
        alpha_t = self._extract(self.alphas, t, batch_size)
        alpha_t_cumprod = self._extract(self.alphas_cumprod, t, batch_size)
        beta_t = self._extract(self.betas, t, batch_size)
        sqrt_one_minus_alpha_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, batch_size)

        mean = (1. / torch.sqrt(alpha_t)) * (x_t - ((1.0 - alpha_t) / sqrt_one_minus_alpha_cumprod_t) * pred)

        if t.min() > 0:
            alpha_t_cumprod_prev = self._extract(self.alphas_cumprod, t - 1, batch_size)
            std = torch.sqrt(beta_t * (1. - alpha_t_cumprod_prev) / (1. - alpha_t_cumprod))
        else:
            # Final step (t == 0): return the mean without added noise.
            std = 0.0

        return mean + std * noise

    @torch.no_grad()
    def _reverse_diffusion_with_clip(self, x_t, t, noise):
        """One reverse step that first estimates x_0, clamps it to [-1, 1],
        then forms the step mean from the clamped estimate and ``x_t``."""
        pred = self.model(x_t, t)
        batch_size = x_t.shape[0]
        alpha_t = self._extract(self.alphas, t, batch_size)
        alpha_t_cumprod = self._extract(self.alphas_cumprod, t, batch_size)
        beta_t = self._extract(self.betas, t, batch_size)

        x_0_pred = torch.sqrt(1. / alpha_t_cumprod) * x_t - torch.sqrt(1. / alpha_t_cumprod - 1.) * pred
        x_0_pred.clamp_(-1., 1.)

        if t.min() > 0:
            alpha_t_cumprod_prev = self._extract(self.alphas_cumprod, t - 1, batch_size)
            mean = (beta_t * torch.sqrt(alpha_t_cumprod_prev) / (1. - alpha_t_cumprod)) * x_0_pred + \
                   ((1. - alpha_t_cumprod_prev) * torch.sqrt(alpha_t) / (1. - alpha_t_cumprod)) * x_t
            std = torch.sqrt(beta_t * (1. - alpha_t_cumprod_prev) / (1. - alpha_t_cumprod))
        else:
            # Final step (t == 0): no previous cumulative alpha exists and no
            # noise is added.
            mean = (beta_t / (1. - alpha_t_cumprod)) * x_0_pred
            std = 0.0

        return mean + std * noise

# Diffusion 訓練函數
def train_diffusion_epoch(model, model_ema, train_loader, optimizer, scheduler,
                         loss_fn, device, ema_update_interval=10):
    """Train ``model`` for one pass over ``train_loader`` and return the mean loss.

    For every batch, noise shaped like the images is drawn, the model is
    called as ``model(image, noise)`` and its output is compared to that noise
    with ``loss_fn``. The optimizer and the scheduler are both stepped once
    per batch.

    Args:
        model: Module called as ``model(image, noise)``.
        model_ema: Object exposing ``update_parameters(model)``.
        train_loader: Iterable of ``(image, label)`` batches; labels are ignored.
        optimizer: Optimizer stepped after each backward pass.
        scheduler: LR scheduler stepped once per batch.
        loss_fn: Callable ``loss_fn(pred, noise)`` returning a scalar tensor.
        device: Device the images and noise are moved to.
        ema_update_interval: ``model_ema`` is updated on every batch whose
            index within this epoch is a multiple of this value.

    Returns:
        Sum of the per-batch losses divided by ``len(train_loader)``.
    """
    model.train()
    running_loss = 0

    progress = tqdm(train_loader, desc='Diffusion Training')
    # The step index restarts at 0 on every call, i.e. once per epoch.
    for step, (image, _) in enumerate(progress):
        # Noise is sampled alongside the un-moved image, then transferred.
        noise = torch.randn_like(image).to(device)
        image = image.to(device)

        loss = loss_fn(model(image, noise), noise)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        if step % ema_update_interval == 0:
            model_ema.update_parameters(model)

        batch_loss = loss.item()
        running_loss += batch_loss

        current_lr = scheduler.get_last_lr()[0]
        progress.set_postfix({'loss': f'{batch_loss:.4f}', 'lr': f'{current_lr:.6f}'})

    return running_loss / len(train_loader)

def plot_diffusion_training_curve(losses, save_path='diffusion_training_curve.png'):
    """Plot the per-epoch loss curve, save it to ``save_path`` and show it.

    Args:
        losses: Sequence of loss values, plotted against their index.
        save_path: File the figure is written to (150 dpi, tight bounding box).
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(losses, label='MSE Loss', alpha=0.7)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Diffusion Model Training Curve')
    ax.legend()
    ax.grid(True)
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# ==================== 8. Diffusion Model: run training ====================
# Notebook cell that builds the MNIST data pipeline, the diffusion model, its
# EMA copy, the optimizer and the scheduler, then trains for
# `num_epochs_diffusion` epochs, sampling images every 20 epochs and saving
# the final weights.
print("=" * 60)
print("開始訓練 Diffusion Model")
print("=" * 60 + "\n")

set_seed(42)

# Diffusion hyper-parameters.
# The batch size is defined here so the DataLoader and the EMA decay
# adjustment below use the same value, instead of the EMA code reading a
# global `batch_size` that this cell never defines.
batch_size_diffusion = 128
num_epochs_diffusion = 200
timesteps = 2000
base_dim_diffusion = 128
lr_diffusion = 0.0008

# Diffusion data loading (pixels are normalized to [-1, 1]).
transform_diffusion = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
train_dataset_diffusion = datasets.MNIST(root='./data', train=True, download=True, transform=transform_diffusion)
train_loader_diffusion = DataLoader(train_dataset_diffusion, batch_size=batch_size_diffusion, shuffle=True, num_workers=2)

# Build the model.
diffusion_model = MNISTDiffusion(
    timesteps=timesteps,
    image_size=28,
    in_channels=1,
    base_dim=base_dim_diffusion,
    dim_mults=[2, 4]
).to(device)

total_params = sum(p.numel() for p in diffusion_model.parameters())
print(f'Diffusion 模型參數量: {total_params:,} ({total_params/1e6:.2f}M)')

# EMA setup: the base decay is rescaled by batch size, EMA update interval
# and epoch count, and the resulting update weight is capped at 1.0.
model_ema_steps = 10
model_ema_decay = 0.995
adjust = 1 * batch_size_diffusion * model_ema_steps / num_epochs_diffusion
alpha = 1.0 - model_ema_decay
alpha = min(1.0, alpha * adjust)
diffusion_model_ema = ExponentialMovingAverage(diffusion_model, device=device, decay=1.0 - alpha)

# Optimizer and scheduler; the scheduler is stepped once per batch, so its
# total step count is epochs * batches-per-epoch.
optimizer_diffusion = optim.Adam(diffusion_model.parameters(), lr=lr_diffusion)
scheduler_diffusion = OneCycleLR(
    optimizer_diffusion, lr_diffusion,
    total_steps=num_epochs_diffusion * len(train_loader_diffusion),
    pct_start=0.25,
    anneal_strategy='cos'
)
loss_fn_diffusion = nn.MSELoss(reduction='mean')

print(f'Diffusion 配置:')
print(f'Timesteps: {timesteps}, Base Dim: {base_dim_diffusion}')
print(f'Learning Rate: {lr_diffusion}, Epochs: {num_epochs_diffusion}\n')

diffusion_losses = []

for epoch in range(1, num_epochs_diffusion + 1):
    loss = train_diffusion_epoch(
        diffusion_model, diffusion_model_ema, train_loader_diffusion,
        optimizer_diffusion, scheduler_diffusion, loss_fn_diffusion,
        device, ema_update_interval=model_ema_steps
    )
    diffusion_losses.append(loss)
    print(f'Epoch {epoch}/{num_epochs_diffusion}: Loss = {loss:.4f}')

    # Sample from the EMA weights every 20 epochs and after the last epoch.
    if epoch % 20 == 0 or epoch == num_epochs_diffusion:
        print(f"\n生成 Epoch {epoch} 的樣本...")
        diffusion_model_ema.eval()
        samples = diffusion_model_ema.module.sampling(
            n_samples=10,
            clipped_reverse_diffusion=True,
            device=device
        )
        visualize_generated(samples, title=f'Diffusion Generated (Epoch {epoch})')

# Final diffusion results: loss curve, final samples and checkpoint.
plot_diffusion_training_curve(diffusion_losses, save_path='diffusion_training_curve.png')
diffusion_model_ema.eval()
final_diffusion_samples = diffusion_model_ema.module.sampling(
    n_samples=10,
    clipped_reverse_diffusion=True,
    device=device
)
visualize_generated(final_diffusion_samples, title='Diffusion Final Generated Images', save_path='diffusion_final_results.png')
torch.save({
    'model': diffusion_model.state_dict(),
    'model_ema': diffusion_model_ema.state_dict()
}, 'mnist_diffusion_final.pt')
print(f'\nDiffusion Model 訓練完成! 最終 Loss: {diffusion_losses[-1]:.4f}\n')
