In [8]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torchvision.utils as vutils
import os
from torch.autograd import Variable, grad

# 1. 환경 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

batch_size = 512
print(f"batch_size: {batch_size}")

# 2. 데이터 준비 (Fashion MNIST)
transform = transforms.Compose([
    transforms.Resize(32),  # Resize to 32x32 for DCGAN
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = torchvision.datasets.FashionMNIST(
    root='./Fashion_MNIST_dataset',
    train=True,
    transform=transform,
    download=True
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True
)

Using device: cuda
batch_size: 512


In [9]:
# 3. WDCGAN-GP 아키텍처 정의
class Generator(nn.Module):
    def __init__(self, noise_dim=100, channels_img=1, features_g=64):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # Input: N x noise_dim x 1 x 1
            nn.ConvTranspose2d(noise_dim, features_g * 8, kernel_size=4, stride=1, padding=0, bias=False),  # 4x4
            nn.BatchNorm2d(features_g * 8),
            nn.ReLU(True),

            nn.ConvTranspose2d(features_g * 8, features_g * 4, kernel_size=4, stride=2, padding=1, bias=False),  # 8x8
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(features_g * 4, features_g * 2, kernel_size=4, stride=2, padding=1, bias=False),  # 16x16
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(features_g * 2, features_g, kernel_size=4, stride=2, padding=1, bias=False),  # 32x32
            nn.BatchNorm2d(features_g),
            nn.ReLU(True),

            nn.ConvTranspose2d(features_g, channels_img, kernel_size=3, stride=1, padding=1, bias=False),  # 32x32
            nn.Tanh()
            # Output: (channels_img) x32x32
        )

    def forward(self, x):
        return self.gen(x)

class Critic(nn.Module):
    def __init__(self, channels_img=1, features_d=64):
        super(Critic, self).__init__()
        self.disc = nn.Sequential(
            # Input: (channels_img) x32x32
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1, bias=False),  # 16x16
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features_d, features_d * 2, kernel_size=4, stride=2, padding=1, bias=False),  # 8x8
            nn.BatchNorm2d(features_d * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features_d * 2, features_d * 4, kernel_size=4, stride=2, padding=1, bias=False),  # 4x4
            nn.BatchNorm2d(features_d * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features_d * 4, features_d * 8, kernel_size=4, stride=2, padding=1, bias=False),  # 2x2
            nn.BatchNorm2d(features_d * 8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(features_d * 8, 1, kernel_size=2, stride=1, padding=0, bias=False)  # 1x1
            # No activation
        )

    def forward(self, x):
        return self.disc(x).view(-1)

In [10]:
# 4. 모델 초기화 및 설정
noise_dim = 100
channels_img = 1  # Grayscale images
features_g = 64
features_d = 64
lr, b1, b2 = 3e-5, 0.0, 0.9
num_epochs = 100
n_critic = 5  # Number of critic iterations per generator iteration
lambda_gp = 10  # Gradient penalty lambda hyperparameter

generator = Generator(noise_dim, channels_img, features_g).to(device)
critic = Critic(channels_img, features_d).to(device)

# Initialize weights
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

generator.apply(weights_init)
critic.apply(weights_init)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_C = torch.optim.Adam(critic.parameters(), lr=lr, betas=(b1, b2))

# 5. Gradient Penalty 함수 정의
def compute_gradient_penalty(critic, real, fake, device="cpu"):
    batch_size = real.size(0)
    epsilon = torch.rand(batch_size, 1, 1, 1, device=device, requires_grad=True)
    interpolated = epsilon * real + (1 - epsilon) * fake
    interpolated.requires_grad_(True)
    
    critic_interpolated = critic(interpolated)
    
    gradients = grad(
        outputs=critic_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(critic_interpolated),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    gradient_penalty = lambda_gp * ((gradient_norm - 1) ** 2).mean()
    return gradient_penalty

# 6. 학습 루프 및 데이터 수집
# 6.1. 결과 저장을 위한 폴더 생성
results_dir = './results_dir/WDCGAN-GP'
if not os.path.exists(results_dir):
    os.makedirs(results_dir)

G_losses = []
C_losses = []
sample_images = []

# fixed noise 생성
fixed_noise = torch.randn(16, noise_dim, 1, 1).to(device)

# 6.2. 학습 루프
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(train_loader):
        real = real.to(device)
        batch_size_curr = real.size(0)
        
        # ---------------------
        #  Train Critic
        # ---------------------
        for _ in range(n_critic):
            # Sample noise and generate fake images
            noise = torch.randn(batch_size_curr, noise_dim, 1, 1, device=device)
            fake = generator(noise)
            
            # Compute critic scores
            critic_real = critic(real)
            critic_fake = critic(fake.detach())
            
            # Compute gradient penalty
            gp = compute_gradient_penalty(critic, real, fake.detach(), device=device)
            
            # WDCGAN-GP loss for critic
            loss_C = -(torch.mean(critic_real) - torch.mean(critic_fake)) + gp
            
            # Backprop and optimize
            optimizer_C.zero_grad()
            loss_C.backward()
            optimizer_C.step()
        
        # ---------------------
        #  Train Generator
        # ---------------------
        # Generate fake images
        noise = torch.randn(batch_size_curr, noise_dim, 1, 1, device=device)
        fake = generator(noise)
        
        # Compute critic scores for fake images
        critic_fake = critic(fake)
        
        # WDCGAN-GP loss for generator
        loss_G = -torch.mean(critic_fake)
        
        # Backprop and optimize
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()
        
    # Append losses for visualization
    G_losses.append(loss_G.item())
    C_losses.append(loss_C.item())
    
    # Print losses and save samples at intervals
    if epoch == 0 or (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}]  Loss_C: {loss_C.item():.4f}  Loss_G: {loss_G.item():.4f}")
        
        with torch.no_grad():
            fake_images = generator(fixed_noise).detach().cpu()
            sample_images.append(fake_images)



Epoch [1/100]  Loss_C: -7.4389  Loss_G: -2.0688
Epoch [10/100]  Loss_C: -197.7918  Loss_G: 111.1652
Epoch [20/100]  Loss_C: -628.8533  Loss_G: 327.8903
Epoch [30/100]  Loss_C: -1277.9586  Loss_G: 668.4132
Epoch [40/100]  Loss_C: -140.6190  Loss_G: 1112.9521
Epoch [50/100]  Loss_C: -3168.7153  Loss_G: 1634.6560
Epoch [60/100]  Loss_C: -4388.1484  Loss_G: 2319.5256
Epoch [70/100]  Loss_C: -6049.3940  Loss_G: 3035.3926
Epoch [80/100]  Loss_C: -7649.5435  Loss_G: 3806.7280
Epoch [90/100]  Loss_C: -9490.2695  Loss_G: 4625.8105
Epoch [100/100]  Loss_C: -10862.4346  Loss_G: 5124.2451


In [11]:
# 7. 시각화
# 7.1. 손실 곡선 시각화 및 저장
plt.figure(figsize=(10,5))
plt.title("Generator and Critic Loss During Training (WDCGAN-GP)")
plt.plot(G_losses, label="Generator Loss")
plt.plot(C_losses, label="Critic Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.savefig(os.path.join(results_dir, "loss_curve_wdcgan_gp.png"))
plt.close()

# 7.2. 생성된 이미지 시각화 및 저장
for idx, images in enumerate(sample_images):
    grid = vutils.make_grid(images, nrow=4, normalize=True)
    plt.figure(figsize=(8,8))
    plt.title(f"Generated Images at Epoch {max(idx*10, 1)}")
    plt.imshow(np.transpose(grid, (1,2,0)))
    plt.axis("off")
    # 이미지 저장
    plt.savefig(os.path.join(results_dir, f"generated_epoch_{max(idx*10, 1)}.png"))
    plt.close()