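# Generative Adversarial Network (GAN)

Generative Adversarial Networks (GANs) are a class of deep learning models introduced by Ian Goodfellow et al. in 2014. A GAN consists of two neural networks, a **Generator** and a **Discriminator**, that compete with each other within a game-theoretic framework. The sections below describe how they work: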

---

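### 1. **Generator Network**:
- **Task**: generate realistic synthetic data samples that are hard to distinguish from real data.
- **How it works**: maps a random noise vector (a point in latent space) to the target data distribution.
- **Goal**: produce samples that fool the discriminator into being unable to tell real from fake.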

---

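### 2. **Discriminator Network**:
- **Task**: assess whether a data sample is authentic.
- **How it works**: receives data samples and tries to tell real data (from the actual dataset) apart from fake data (produced by the generator).
- **Goal**: classify real and fake samples correctly; the gradients of its output are the feedback that helps the generator improve the quality of its samples.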

---

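### 3. **Training Process**:
- **Joint training**: the generator and discriminator are trained together, typically in alternating update steps.
- **Generator objective**: minimize the discriminator's ability to identify fake samples (by optimizing its loss function).
- **Discriminator objective**: maximize its accuracy at separating real from fake data (by optimizing its loss function).
- **Adversarial process**: training continues until the generator produces data realistic enough that the discriminator can no longer reliably tell real from fake.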
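
The two objectives above are usually written as a single minimax game. As a reference point, here is the value function from the original GAN formulation; $p_{data}$ (the data distribution) and $p_z$ (the noise prior) are symbols introduced here for illustration and are not defined elsewhere in this document:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

The discriminator increases $V$ by pushing $D(x)$ toward 1 and $D(G(z))$ toward 0, while the generator decreases it by pushing $D(G(z))$ toward 1.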

---

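### Key Characteristics:
- **Zero-sum game framework**: the generator's gain is the discriminator's loss, and vice versa.
- **No labeled data required**: a basic GAN is trained in an unsupervised fashion; the only labels are the real/fake tags that the training loop creates itself.
- **High flexibility**: GANs are widely used for image generation, style transfer, super-resolution, and video generation.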

---

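### Challenges:
1. **Training instability**:
   - Adversarial training may fail to converge or suffer mode collapse (the generator produces only a limited variety of outputs).
2. **Sensitivity to hyperparameters**:
   - Careful hyperparameter tuning is key to good performance.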

## Pseudo-Code for GAN
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Define the Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 28*28),
            nn.Tanh()
        )
    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)

# Define the Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, img):
        return self.model(img.view(img.size(0), -1))

# Initialize models, optimizers, and loss
G = Generator()
D = Discriminator()
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=0.0002)
optimizer_D = optim.Adam(D.parameters(), lr=0.0002)

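# Training Loop
# Assumption: placeholder data so the sketch runs end to end; replace it with a
# real DataLoader yielding image batches shaped (B, 1, 28, 28) in [-1, 1].
epochs = 1
dataloader = [torch.rand(64, 1, 28, 28) * 2 - 1 for _ in range(4)]

for epoch in range(epochs):
    for real_imgs in dataloader:  # Batch of real images
        batch_size = real_imgs.size(0)  # The last batch may be smaller
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # Train Discriminator
        z = torch.randn(batch_size, 100)  # Random noise
        fake_imgs = G(z)  # Generate fake images
        optimizer_D.zero_grad()
        real_loss = criterion(D(real_imgs), real_labels)
        # detach() keeps this step from backpropagating into the generator
        fake_loss = criterion(D(fake_imgs.detach()), fake_labels)
        D_loss = real_loss + fake_loss
        D_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        G_loss = criterion(D(fake_imgs), real_labels)  # Fool discriminator
        G_loss.backward()
        optimizer_G.step()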
```

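## Sample: DCGAN (Deep Convolutional Generative Adversarial Network) on MNIST
### Overview:
This code shows how to train a Deep Convolutional GAN (DCGAN) to generate handwritten digits. Its main steps are data loading, definition of the generator and discriminator, the training loop, and visualization and saving of the results.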

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision import utils as vutils

### 1. Environment Setup (Parameter setting)
- Device detection:
  - The code checks whether a GPU backend is available (CUDA or Apple's MPS) and falls back to the CPU otherwise.

In [2]:
DATASET_PATH = "../data/mnist"
BATCH_SIZE = 64
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("CUDA is available.")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("MPS backend is available.")
else:
    device = torch.device("cpu")
    print("No GPU backend is available; using CPU.")

CUDA is available.


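### 2. Loading the MNIST Data (Load data and transform to Dataset)
- Load the MNIST dataset with `torchvision.datasets.MNIST`.
- Transforms:
  - Resize the images to $28\times 28$.
  - Normalize pixel values to $[-1, 1]$ so they match the generator's output range (the output of the `Tanh` activation).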
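
To make the normalization step concrete, here is a minimal pure-Python sketch of the arithmetic that `Normalize((0.5,), (0.5,))` applies to each pixel after `ToTensor` has scaled it into $[0, 1]$. The helper names `normalize` and `denormalize` are hypothetical and exist only for this illustration.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Apply the per-channel (pixel - mean) / std rule used by Normalize."""
    return (pixel - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Invert the mapping, e.g. before viewing a generated image."""
    return value * std + mean


# ToTensor first scales 8-bit pixel values into [0, 1].
for raw in (0, 128, 255):
    scaled = raw / 255
    print(raw, round(normalize(scaled), 4))
```

With a mean and standard deviation of 0.5, black (0) maps to -1 and white (255) maps to 1, which is the same interval that `Tanh` produces.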

In [3]:
transform = transforms.Compose([
    transforms.Resize((28, 28)),  # MNIST images are already 28x28; kept as a guard
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Map [0, 1] to [-1, 1]
])
# Load MNIST dataset
dataset = MNIST(DATASET_PATH, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

print(dataset.data.shape)

torch.Size([60000, 28, 28])





### 3. Basic Parameters and Network Weight Initialization

In [4]:
nz = 100 # Size of z latent vector (i.e. size of generator input)
ngf = 64  # Size of feature maps in generator
ndf = 64  # Size of feature maps in discriminator

def weights_init(m):
    """Draw conv weights from N(mean 0, std 0.02) and batch-norm scales from N(mean 1, std 0.02)."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

### Initialize the Networks
- Generator: A neural network that takes random noise ($z$) as input and outputs an image of size $28\times28$.
- Discriminator: A neural network that takes an image as input and outputs a probability value ($P(real)$) indicating whether the image is real (from the dataset) or fake (from the generator).

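### 4. Generator
- Purpose: transform a random noise vector $z$ into a $28\times 28$ handwritten-digit image.
- Architecture:
  - Upsamples with transposed convolutions (`ConvTranspose2d`).
  - Hidden layers use `ReLU` activations; the output layer uses `Tanh`.
  - Batch normalization (`BatchNorm2d`) helps stabilize training.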
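
It can help to check by hand how the upsampling reaches $28\times 28$. The sketch below applies the standard output-size rule for a transposed convolution (assuming dilation 1 and no output padding) to the kernel, stride, and padding values used by the generator code in this section. The helper `conv_transpose_out` is a hypothetical name for this illustration, not a PyTorch function.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output side length of a transposed convolution (dilation 1, output_padding 0)."""
    return (size - 1) * stride - 2 * padding + kernel


# (kernel, stride, padding) for each ConvTranspose2d layer of the generator
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 3), (1, 1, 0)]

size = 1  # the latent vector enters as an nz x 1 x 1 tensor
sizes = []
for kernel, stride, padding in layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [4, 8, 16, 28, 28]
```

The padding of 3 in the fourth layer is what turns $16\times 16$ into $28\times 28$ rather than the $32\times 32$ that a padding of 1 would give.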

### Train the Generator
- Generate Fake Images:
  - Use the generator to create fake images: $x_{fake} = G(z)$.
- Fool the Discriminator:
  - Pass the fake images to the discriminator: $D(x_{fake})$.
  - Compute the loss: $L_G = -\log(D(x_{fake}))$.
    - Note: Here, the generator wants $D(x_{fake})$ to be close to 1 (to fool the discriminator into thinking the images are real).
- Update the Generator:
  - Update the generator's weights to minimize $L_G$.
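
As a small numeric illustration, the sketch below evaluates $L_G = -\log(D(x_{fake}))$ for a few discriminator outputs and checks that it equals binary cross-entropy with a target of 1, which is why the training code can reuse `BCELoss` with "real" labels for the generator step. The function names are hypothetical and the values are for a single sample, not a batch average.

```python
import math


def generator_loss(d_fake):
    """Generator loss for one sample: -log(D(G(z)))."""
    return -math.log(d_fake)


def bce(prediction, target):
    """Binary cross-entropy for a single prediction in (0, 1)."""
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))


for d_fake in (0.01, 0.5, 0.99):
    print(f"D(G(z))={d_fake:.2f}  L_G={generator_loss(d_fake):.4f}")
```

The loss is large when the discriminator confidently rejects the fake sample and approaches 0 as $D(x_{fake})$ approaches 1.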

In [5]:
use_bias = False
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # Input is Z (latent vector)
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=use_bias),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # State size: (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=use_bias),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # State size: (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=use_bias),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # State size: (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 3, bias=use_bias),  # Adjust padding to 3
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # State size: (ngf) x 28 x 28
            nn.ConvTranspose2d(ngf, 1, kernel_size=1, stride=1, padding=0, bias=use_bias),
            nn.Tanh()
            # Output size: 1 x 28 x 28
        )

    def forward(self, input):
        return self.main(input)

netG = Generator().to(device)
netG.apply(weights_init)

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(3, 3), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (13): Tanh()
  )
)

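### 5. Discriminator
- Purpose: classify an input image (real or generated) as real (1) or fake (0).
- Architecture:
  - Downsamples with convolutional layers (`Conv2d`).
  - Batch normalization (`BatchNorm2d`) and `LeakyReLU` activations stabilize training and avoid sparse gradients.
  - The final layer uses a `Sigmoid` activation for binary classification.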
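
The downsampling path can be traced the same way as the generator's. The sketch below applies the standard convolution output-size rule (assuming dilation 1) to the four `Conv2d` layers of the discriminator in this section, each of which uses kernel 4, stride 2, and padding 1. The helper `conv_out` is a hypothetical name for this illustration.

```python
def conv_out(size, kernel, stride, padding):
    """Output side length of a convolution (dilation 1); floor division drops the remainder."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28  # MNIST images are 1 x 28 x 28
sizes = []
for _ in range(4):  # four Conv2d layers with kernel 4, stride 2, padding 1
    size = conv_out(size, kernel=4, stride=2, padding=1)
    sizes.append(size)

print(sizes)  # [14, 7, 3, 1]
```

The last layer therefore yields a single value per image, which `Sigmoid` turns into the probability $P(real)$.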

### Train the Discriminator
- Input Real Images:
  - Feed a batch of real images ($x_{real}$) from the dataset to the discriminator.
  - Compute the discriminator's output: $D(x_{real})$, where $D$ is the discriminator.
  - Calculate the loss: $L_{real} = -\log(D(x_{real}))$.
- Input Fake Images:
  - Generate a batch of fake images ($x_{fake} = G(z)$) using the generator.
  - Feed $x_{fake}$ to the discriminator.
  - Compute the discriminator's output: $D(x_{fake})$.
  - Calculate the loss: $L_{fake} = -\log(1 - D(x_{fake}))$.
- Update the Discriminator:
  - Combine the losses: $L_D = L_{real} + L_{fake}$.
  - Update the discriminator's weights to minimize $L_D$.
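
The combined loss is easy to evaluate by hand for a single real/fake pair. The following sketch, with a hypothetical helper `discriminator_loss`, compares a discriminator that cannot decide (it outputs 0.5 for everything) with one that is mostly correct. Real training averages these terms over a batch, so treat the numbers as illustrative only.

```python
import math


def discriminator_loss(d_real, d_fake):
    """L_D = -log(D(x_real)) - log(1 - D(x_fake)) for one real and one fake sample."""
    return -math.log(d_real) - math.log(1.0 - d_fake)


undecided = discriminator_loss(0.5, 0.5)  # D outputs 0.5 for every input
confident = discriminator_loss(0.9, 0.1)  # D is mostly correct

print(round(undecided, 4), round(confident, 4))  # 1.3863 0.2107
```

A value near $2\ln 2 \approx 1.386$ therefore suggests a discriminator that is close to guessing, while values well below it indicate that it separates real from fake.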

In [6]:
# using LeakyReLU activation function to avoid sparse gradients
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # input is 1 x 28 x 28
            nn.Conv2d(1, ndf, 4, 2, 1, bias=use_bias),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 14 x 14
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=use_bias),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 7 x 7
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=use_bias),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 3 x 3
            nn.Conv2d(ndf * 4, 1, 4, 2, 1, bias=use_bias),
            # output size. 1 x 1 x 1
            nn.Sigmoid()
        )

    def forward(self, input):
        output = self.main(input)
        return output.view(-1, 1).squeeze(1)

netD = Discriminator().to(device)
netD.apply(weights_init)

Discriminator(
  (main): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): Sigmoid()
  )
)

### 6. Loss Function and Optimizers (Binary Cross-Entropy)
- Loss function: binary cross-entropy (`BCELoss`).
- Optimizer: Adam with learning rate 0.0002 and $(\beta_1, \beta_2) = (0.5, 0.999)$.

In [7]:
# Initialize BCELoss function
criterion = nn.BCELoss()

# Adam optimizers for the discriminator and the generator
optimizer_D = torch.optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_G = torch.optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

### Training the Model (25 Epochs)
- GAN training alternates between updating the discriminator and the generator for many iterations.
- Over time, the generator improves and its images become increasingly realistic, while the discriminator becomes better at distinguishing real from fake images.

In [None]:
fixed_noise = torch.randn(64, nz, 1, 1, device=device)  # Fixed latents to track progress
real_label = 1.0
fake_label = 0.0
niter = 25  # Number of epochs

for epoch in range(niter):
    for i, data in enumerate(dataloader, 0):
        # (1) Update D on a real batch: loss is -log(D(x))
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # (1) Update D on a fake batch: loss is -log(1 - D(G(z)))
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        # detach() keeps this step from backpropagating into the generator
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizer_D.step()

        # (2) Update G: loss is -log(D(G(z))), so fakes are labeled as real
        netG.zero_grad()
        label.fill_(real_label)
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizer_G.step()

        if i % 50 == 0:
            print(f"[{epoch}/{niter}][{i}/{len(dataloader)}] Loss_D: {errD.item()} Loss_G: {errG.item()} D(x): {D_x} D(G(z)): {D_G_z1}/{D_G_z2}")

        if i % 100 == 0:
            vutils.save_image(real_cpu, "real_samples.png", normalize=True)
            with torch.no_grad():  # Sampling only; no gradients needed
                fake = netG(fixed_noise)
            vutils.save_image(fake, f"fake_samples_epoch_{epoch}.png", normalize=True)
    # Save a checkpoint of both networks after every epoch
    torch.save(netG.state_dict(), f"netG_epoch_{epoch}.pth")
    torch.save(netD.state_dict(), f"netD_epoch_{epoch}.pth")

[0/25][0/938] Loss_D: 1.5097945928573608 Loss_G: 1.1543402671813965 D(x): 0.4472287893295288 D(G(z)): 0.45976877212524414/0.3312985301017761
[0/25][50/938] Loss_D: 0.34454309940338135 Loss_G: 5.165733337402344 D(x): 0.9224661588668823 D(G(z)): 0.2222352921962738/0.00794902816414833
[0/25][100/938] Loss_D: 0.6081126928329468 Loss_G: 3.144291400909424 D(x): 0.8790104389190674 D(G(z)): 0.35810741782188416/0.051959265023469925
[0/25][150/938] Loss_D: 0.5404613018035889 Loss_G: 2.705753803253174 D(x): 0.7721208333969116 D(G(z)): 0.20771917700767517/0.09180039167404175
[0/25][200/938] Loss_D: 0.4969502091407776 Loss_G: 2.1167471408843994 D(x): 0.737260103225708 D(G(z)): 0.15248596668243408/0.13881339132785797
[0/25][250/938] Loss_D: 0.43428170680999756 Loss_G: 3.2315146923065186 D(x): 0.9031277894973755 D(G(z)): 0.26793742179870605/0.04626741260290146
[0/25][300/938] Loss_D: 0.4345180094242096 Loss_G: 2.32780122756958 D(x): 0.7674660682678223 D(G(z)): 0.1359778642654419/0.13087432086467743

### Plot results

In [None]:
import matplotlib.pyplot as plt

n_samples = 25
sample_noise = torch.randn(n_samples, nz, 1, 1, device=device)
with torch.no_grad():  # Inference only; no gradients needed
    fake_images = netG(sample_noise).cpu().numpy()
fake_images = fake_images.reshape(-1, 28, 28)
R, C = 5, 5  # 5 x 5 grid of generated digits
for i in range(n_samples):
    plt.subplot(R, C, i+1)
    plt.imshow(fake_images[i], cmap='gray')
    plt.axis('off')
plt.show()

### Generate GIF file

In [None]:
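import glob
import re

import imageio.v2 as imageio

anim_file = 'dcgan.gif'
# Sort frames by numeric epoch; plain string sorting would put epoch 10 before epoch 2
filenames = sorted(
    glob.glob('fake_samples_epoch_*.png'),
    key=lambda name: int(re.search(r'epoch_(\d+)', name).group(1)),
)
with imageio.get_writer(anim_file, mode='I') as writer:
    for filename in filenames:
        image = imageio.imread(filename)
        writer.append_data(image)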
