<a href="https://colab.research.google.com/github/Fu-Pei-Yin/Deep-Generative-Mode/blob/week2/cGAN%E5%9C%A8%E6%89%8B%E5%AF%AB%E6%95%B8%E5%AD%97%E4%B8%8A%E7%9A%84%E6%87%89%E7%94%A8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import os
import time
import matplotlib.pyplot as plt
import itertools
import pickle
import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np

# Generator Network
class Generator(nn.Module):
    def __init__(self, input_dim=100, output_dim=784, hidden_dim=256, num_classes=10):
        super(Generator, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes

        self.fc1_noise = nn.Linear(input_dim, hidden_dim)
        self.fc1_label = nn.Linear(num_classes, hidden_dim)

        self.main = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim * 2),
            nn.BatchNorm1d(hidden_dim * 2),
            nn.ReLU(True),

            nn.Linear(hidden_dim * 2, hidden_dim * 4),
            nn.BatchNorm1d(hidden_dim * 4),
            nn.ReLU(True),

            nn.Linear(hidden_dim * 4, output_dim),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        # Process noise
        noise_out = F.relu(self.fc1_noise(noise))
        # Process labels
        label_out = F.relu(self.fc1_label(labels))
        # Concatenate
        x = torch.cat([noise_out, label_out], dim=1)
        x = self.main(x)
        return x.view(-1, 1, 28, 28)

# Discriminator Network
class Discriminator(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=256, num_classes=10):
        super(Discriminator, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes

        self.fc1_image = nn.Linear(input_dim, hidden_dim * 4)
        self.fc1_label = nn.Linear(num_classes, hidden_dim * 4)

        self.main = nn.Sequential(
            nn.Linear(hidden_dim * 8, hidden_dim * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(hidden_dim * 2, 1),
            nn.Sigmoid()
        )

    def forward(self, images, labels):
        # Flatten images
        x = images.view(-1, self.input_dim)
        # Process images
        image_out = F.leaky_relu(self.fc1_image(x), 0.2)
        # Process labels
        label_out = F.leaky_relu(self.fc1_label(labels), 0.2)
        # Concatenate
        x = torch.cat([image_out, label_out], dim=1)
        x = self.main(x)
        return x.squeeze()

# 設置設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 準備固定輸入用於可視化 (生成數字0-9各10張)
def create_fixed_noise_and_labels():
    fixed_noise = torch.randn(100, 100).to(device)
    fixed_labels = torch.zeros(100, 10).to(device)

    # 創建標籤: 數字0-9各10個
    for digit in range(10):
        for i in range(10):
            idx = digit * 10 + i
            fixed_labels[idx, digit] = 1.0

    return fixed_noise, fixed_labels

fixed_noise, fixed_labels = create_fixed_noise_and_labels()

def show_result(num_epoch, show=False, save=False, path='result.png'):
    G.eval()
    with torch.no_grad():
        test_images = G(fixed_noise, fixed_labels)
    G.train()

    fig, axes = plt.subplots(10, 10, figsize=(10, 10))
    for i in range(10):
        for j in range(10):
            idx = i * 10 + j
            axes[i, j].imshow(test_images[idx].cpu().squeeze().numpy(), cmap='gray')
            axes[i, j].axis('off')
            axes[i, j].set_title(f'{i}', fontsize=8)

    plt.suptitle(f'cGAN Generated Images - Epoch {num_epoch}', fontsize=16)
    plt.tight_layout()
    plt.savefig(path, dpi=150, bbox_inches='tight')

    if show:
        plt.show()
    else:
        plt.close()

def show_train_hist(hist, show=False, save=False, path='Train_hist.png'):
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.plot(hist['D_losses'], label='Discriminator Loss')
    plt.plot(hist['G_losses'], label='Generator Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training Losses')
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot(hist['D_real_acc'], label='Real Accuracy')
    plt.plot(hist['D_fake_acc'], label='Fake Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title('Discriminator Accuracy')
    plt.grid(True)

    plt.tight_layout()

    if save:
        plt.savefig(path)

    if show:
        plt.show()
    else:
        plt.close()

# 訓練參數
batch_size = 128
lr = 0.0002  # 2e-4 as required
train_epoch = 50
label_smooth = 0.1  # Label smoothing factor

# 資料載入
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])

train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 初始化網絡
G = Generator().to(device)
D = Discriminator().to(device)

print(f"Generator parameters: {sum(p.numel() for p in G.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in D.parameters()):,}")

# 損失函數和優化器
criterion = nn.BCELoss()
G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

# 建立結果資料夾
os.makedirs('MNIST_cGAN_results/Fixed_results', exist_ok=True)

# 訓練歷史記錄
train_hist = {}
train_hist['D_losses'] = []
train_hist['G_losses'] = []
train_hist['D_real_acc'] = []
train_hist['D_fake_acc'] = []
train_hist['per_epoch_ptimes'] = []
train_hist['total_ptime'] = []

print('Training start!')
start_time = time.time()

for epoch in range(train_epoch):
    D_losses = []
    G_losses = []
    D_real_accuracies = []
    D_fake_accuracies = []

    epoch_start_time = time.time()

    for batch_idx, (real_images, real_labels) in enumerate(train_loader):
        batch_size = real_images.size(0)

        # 準備真實資料
        real_images = real_images.to(device)
        real_labels_onehot = F.one_hot(real_labels, num_classes=10).float().to(device)

        # 使用標籤平滑
        real_targets = torch.full((batch_size,), 1.0 - label_smooth, device=device)
        fake_targets = torch.zeros(batch_size, device=device)

        # 訓練判別器
        D.zero_grad()

        # 真實圖片損失
        D_real_output = D(real_images, real_labels_onehot)
        D_real_loss = criterion(D_real_output, real_targets)
        D_real_accuracy = (D_real_output > 0.5).float().mean()

        # 生成假圖片
        noise = torch.randn(batch_size, 100, device=device)
        fake_labels = torch.randint(0, 10, (batch_size,), device=device)
        fake_labels_onehot = F.one_hot(fake_labels, num_classes=10).float().to(device)

        fake_images = G(noise, fake_labels_onehot)

        # 假圖片損失
        D_fake_output = D(fake_images.detach(), fake_labels_onehot)
        D_fake_loss = criterion(D_fake_output, fake_targets)
        D_fake_accuracy = (D_fake_output < 0.5).float().mean()

        # 總判別器損失
        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_total_loss.backward()
        D_optimizer.step()

        # 訓練生成器
        G.zero_grad()

        D_fake_output = D(fake_images, fake_labels_onehot)
        G_loss = criterion(D_fake_output, real_targets)  # 騙過判別器

        G_loss.backward()
        G_optimizer.step()

        # 記錄損失和準確率
        D_losses.append(D_total_loss.item())
        G_losses.append(G_loss.item())
        D_real_accuracies.append(D_real_accuracy.item())
        D_fake_accuracies.append(D_fake_accuracy.item())

    epoch_end_time = time.time()
    per_epoch_ptime = epoch_end_time - epoch_start_time

    # 計算epoch平均
    avg_D_loss = np.mean(D_losses)
    avg_G_loss = np.mean(G_losses)
    avg_D_real_acc = np.mean(D_real_accuracies)
    avg_D_fake_acc = np.mean(D_fake_accuracies)

    print(f'Epoch [{epoch+1}/{train_epoch}] '
          f'D_loss: {avg_D_loss:.4f} '
          f'G_loss: {avg_G_loss:.4f} '
          f'D_real_acc: {avg_D_real_acc:.4f} '
          f'D_fake_acc: {avg_D_fake_acc:.4f} '
          f'Time: {per_epoch_ptime:.2f}s')

    # 保存結果
    fixed_p = f'MNIST_cGAN_results/Fixed_results/MNIST_cGAN_{epoch+1}.png'
    show_result(epoch+1, save=True, path=fixed_p)

    train_hist['D_losses'].append(avg_D_loss)
    train_hist['G_losses'].append(avg_G_loss)
    train_hist['D_real_acc'].append(avg_D_real_acc)
    train_hist['D_fake_acc'].append(avg_D_fake_acc)
    train_hist['per_epoch_ptimes'].append(per_epoch_ptime)

end_time = time.time()
total_ptime = end_time - start_time
train_hist['total_ptime'].append(total_ptime)

print(f"Total training time: {total_ptime:.2f} seconds")
print("Training finished! Saving models and training history...")

# 保存模型和訓練歷史
torch.save(G.state_dict(), "MNIST_cGAN_results/generator_param.pkl")
torch.save(D.state_dict(), "MNIST_cGAN_results/discriminator_param.pkl")

with open('MNIST_cGAN_results/train_hist.pkl', 'wb') as f:
    pickle.dump(train_hist, f)

# 生成訓練過程圖表和動畫
show_train_hist(train_hist, save=True, path='MNIST_cGAN_results/MNIST_cGAN_train_hist.png')

images = []
for e in range(train_epoch):
    img_name = f'MNIST_cGAN_results/Fixed_results/MNIST_cGAN_{e+1}.png'
    if os.path.exists(img_name):
        images.append(imageio.imread(img_name))
imageio.mimsave('MNIST_cGAN_results/generation_animation.gif', images, fps=5)

print("All results saved successfully!")

Using device: cpu
Generator parameters: 1,623,312
Discriminator parameters: 3,438,593
Training start!
Epoch [1/50] D_loss: 0.5143 G_loss: 1.7872 D_real_acc: 0.7240 D_fake_acc: 0.8593 Time: 104.73s
Epoch [2/50] D_loss: 0.4621 G_loss: 1.8789 D_real_acc: 0.7575 D_fake_acc: 0.9167 Time: 103.88s
Epoch [3/50] D_loss: 0.4217 G_loss: 2.0729 D_real_acc: 0.7925 D_fake_acc: 0.9403 Time: 103.19s
Epoch [4/50] D_loss: 0.3869 G_loss: 2.4186 D_real_acc: 0.8229 D_fake_acc: 0.9623 Time: 103.85s
Epoch [5/50] D_loss: 0.3937 G_loss: 2.4731 D_real_acc: 0.8143 D_fake_acc: 0.9495 Time: 103.23s
Epoch [6/50] D_loss: 0.5028 G_loss: 1.7030 D_real_acc: 0.7027 D_fake_acc: 0.8840 Time: 103.39s
Epoch [7/50] D_loss: 0.5463 G_loss: 1.4454 D_real_acc: 0.6517 D_fake_acc: 0.8486 Time: 103.08s
Epoch [8/50] D_loss: 0.5737 G_loss: 1.2986 D_real_acc: 0.6086 D_fake_acc: 0.8305 Time: 103.93s
Epoch [9/50] D_loss: 0.5844 G_loss: 1.2494 D_real_acc: 0.5884 D_fake_acc: 0.8210 Time: 104.49s
Epoch [10/50] D_loss: 0.5972 G_loss: 1.2016

  images.append(imageio.imread(img_name))


All results saved successfully!
