<a href="https://colab.research.google.com/github/Fu-Pei-Yin/Deep-Generative-Mode/blob/main/GAN%E5%9C%A8%E6%89%8B%E5%AF%AB%E6%95%B8%E5%AD%97%E4%B8%8A%E7%9A%84%E6%87%89%E7%94%A8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import os
import random
from tqdm import tqdm

# ---------------------
# Configuration
# ---------------------
# Seed Python's and PyTorch's random number generators with the same value.
seed = 42
torch.manual_seed(seed)
random.seed(seed)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

# Training hyperparameters.
batch_size = 128   # samples per mini-batch
gan_lr = 2e-4      # learning rate passed to both optimizers
gan_epochs = 300   # number of passes over the training set

# Model / data dimensions.
nz = 100           # dimensionality of the generator's noise input
img_size = 28      # side length (pixels) of the square images
img_channels = 1   # number of image channels

# Directory that receives the sample grids written during training.
save_dir = "outputs_gan"
os.makedirs(save_dir, exist_ok=True)

# ---------------------
# Data
# ---------------------
# ToTensor produces values in [0, 1]; normalising with mean 0.5 and std 0.5
# rescales them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split, downloaded into ./data when not already present.
train_ds = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches; drop_last=True discards the final incomplete batch.
train_loader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# ---------------------
# Utility
# ---------------------
def weights_init(m):
    """Re-initialise a convolutional or linear layer in place.

    Intended to be passed to ``nn.Module.apply``. For ``nn.Conv2d``,
    ``nn.ConvTranspose2d`` and ``nn.Linear`` modules the weights are drawn
    from a normal distribution with mean 0.0 and std 0.02, and the bias
    (when the layer has one) is set to 0. Any other module type is left
    untouched.
    """
    if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        return

    nn.init.normal_(m.weight, mean=0.0, std=0.02)

    bias = m.bias
    if bias is not None:
        nn.init.constant_(bias, 0)

# ---------------------
# GAN 模型定義
# ---------------------
class Generator(nn.Module):
    """Fully connected generator mapping a noise vector to an image.

    The network is a stack of linear layers (nz -> 256 -> 512 -> 1024 ->
    img_size * img_size) with in-place ReLU activations between them and a
    Tanh on the output, so generated pixel values lie in [-1, 1].
    ``img_size`` is read from module scope.

    Args:
        nz: dimensionality of the input noise vector.
    """

    def __init__(self, nz=100):
        super().__init__()
        widths = [nz, 256, 512, 1024]

        # Hidden stack: one Linear + ReLU pair per consecutive width pair.
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU(True))

        # Output projection to one value per pixel, squashed into [-1, 1].
        layers.append(nn.Linear(widths[-1], img_size*img_size))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Return generated images shaped (-1, 1, img_size, img_size)."""
        flat = self.model(z)
        return flat.view(-1, 1, img_size, img_size)

class Discriminator(nn.Module):
    """Fully connected discriminator scoring images as real or fake.

    Each image is flattened and passed through linear layers
    (img_size * img_size -> 512 -> 256 -> 1) with LeakyReLU(0.2)
    activations and a final Sigmoid, giving one value in (0, 1) per image.
    ``img_size`` is read from module scope.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(img_size*img_size, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return a 1-D tensor holding one score per image in the batch."""
        n = img.size(0)
        scores = self.model(img.view(n, -1))
        return scores.view(-1)

# ---------------------
# Build models and optimizers
# ---------------------
# Both networks are constructed and moved to the device first; weights_init
# is applied afterwards, to G and then to D.
G = Generator(nz=nz).to(device)
D = Discriminator().to(device)
for _net in (G, D):
    _net.apply(weights_init)

# Binary cross-entropy on the discriminator's sigmoid outputs.
bceloss = nn.BCELoss()

# Both optimizers share the same Adam settings.
_adam_kwargs = {"lr": gan_lr, "betas": (0.5, 0.999)}
optimG = optim.Adam(G.parameters(), **_adam_kwargs)
optimD = optim.Adam(D.parameters(), **_adam_kwargs)

# ---------------------
# Training loop
# ---------------------
print("Start GAN training...")
for epoch in range(1, gan_epochs+1):
    pbar = tqdm(train_loader, desc=f"GAN Epoch {epoch}/{gan_epochs}")
    for imgs, _ in pbar:
        imgs = imgs.to(device)
        bs = imgs.shape[0]

        # Targets: 1 for "real", 0 for "fake".
        real_labels = torch.ones(bs, device=device)
        fake_labels = torch.zeros(bs, device=device)

        # ---- Discriminator step ----
        D.zero_grad()

        pred_real = D(imgs)
        loss_real = bceloss(pred_real, real_labels)

        # Detach the fakes so this backward pass stops at D and does not
        # propagate into G.
        z = torch.randn(bs, nz, device=device)
        fake_imgs = G(z).detach()
        pred_fake = D(fake_imgs)
        loss_fake = bceloss(pred_fake, fake_labels)

        # Average of the real and fake losses.
        d_loss = (loss_real + loss_fake) / 2
        d_loss.backward()
        optimD.step()

        # ---- Generator step ----
        G.zero_grad()

        # Fresh noise; G is trained so that D labels its output as real.
        z = torch.randn(bs, nz, device=device)
        gen_imgs = G(z)
        pred_gen = D(gen_imgs)
        g_loss = bceloss(pred_gen, real_labels)
        g_loss.backward()
        optimG.step()

        pbar.set_postfix(d_loss=d_loss.item(), g_loss=g_loss.item())

    # After every epoch, save an 8x8 grid of 64 randomly generated images.
    with torch.no_grad():
        z = torch.randn(64, nz, device=device)
        samples = (G(z) + 1) / 2  # map from [-1, 1] back to [0, 1]
        grid = make_grid(samples, nrow=8)
        out_path = os.path.join(save_dir, f"gan_samples_epoch{epoch}.png")
        save_image(grid, out_path)

print("GAN training finished.")

# ---------------------
# Generate 10 random images and display them
# ---------------------
import matplotlib.pyplot as plt

G.eval()
with torch.no_grad():
    z = torch.randn(10, nz, device=device)
    # Move to the CPU for plotting and map from [-1, 1] back to [0, 1].
    gen_imgs = (G(z).cpu() + 1) / 2

# One row of ten panels, one generated image per panel, axes hidden.
fig, axs = plt.subplots(1, 10, figsize=(15, 2))
for i, ax in enumerate(axs):
    ax.imshow(gen_imgs[i].squeeze(), cmap='gray')
    ax.axis('off')
plt.suptitle("GAN Generated MNIST Digits")
plt.show()

Device: cpu


100%|██████████| 9.91M/9.91M [00:00<00:00, 18.0MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 486kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.49MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.83MB/s]


Start GAN training...


GAN Epoch 1/300: 100%|██████████| 468/468 [00:45<00:00, 10.26it/s, d_loss=0.239, g_loss=1.93]
GAN Epoch 2/300: 100%|██████████| 468/468 [00:45<00:00, 10.34it/s, d_loss=0.374, g_loss=2.07]
GAN Epoch 3/300:  10%|▉         | 46/468 [00:04<00:36, 11.66it/s, d_loss=0.258, g_loss=2.62]