In [4]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torchvision.utils as vutils
import os
from torch.autograd import Variable, grad

# 1. 환경 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

batch_size = 512
print(f"batch_size: {batch_size}")

# 2. 데이터 준비 (Fashion MNIST)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = torchvision.datasets.FashionMNIST(
    root='./Fashion_MNIST_dataset',
    train=True,
    transform=transform,
    download=True
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True
)

Using device: cuda
batch_size: 512


In [5]:
# 3. WGAN-GP 아키텍처 정의
class Generator(nn.Module):
    def __init__(self, noise_dim=100, img_dim=784):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(noise_dim, 128),
            nn.ReLU(inplace=True),

            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),

            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),

            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),

            nn.Linear(1024, img_dim),
            nn.Tanh()
        )
        
    def forward(self, x):
        return self.gen(x)

class Critic(nn.Module):  # Discriminator를 Critic으로 변경
    def __init__(self, img_dim=784):
        super(Critic, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 1)  # 마지막에 활성화 함수 제거
        )
        
    def forward(self, x):
        return self.disc(x)

In [6]:
# 4. 모델 초기화 및 설정
noise_dim = 100
img_dim = 28*28  # Fashion MNIST 이미지 크기
features_g = 64
features_d = 64
lr, b1, b2 = 3e-5, 0.0, 0.9
num_epochs = 100
n_critic = 5  # Number of critic iterations per generator iteration
lambda_gp = 10  # Gradient penalty lambda hyperparameter

generator = Generator(noise_dim, img_dim).to(device)
critic = Critic(img_dim).to(device)

# Initialize weights
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        nn.init.normal_(m.weight.detach(), 0.0, 0.02)
    if classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.detach(), 1.0, 0.02)
        nn.init.constant_(m.bias.detach(), 0)

generator.apply(weights_init)
critic.apply(weights_init)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_C = torch.optim.Adam(critic.parameters(), lr=lr, betas=(b1, b2))

# 5. Gradient Penalty 함수 정의
def compute_gradient_penalty(critic, real, fake, device="cpu"):
    batch_size = real.size(0)
    epsilon = torch.rand(batch_size, 1, device=device, requires_grad=True)
    interpolated = epsilon * real + (1 - epsilon) * fake
    interpolated.requires_grad_(True)
    
    critic_interpolated = critic(interpolated)
    
    gradients = grad(
        outputs=critic_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(critic_interpolated),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    gradient_penalty = lambda_gp * ((gradient_norm - 1) ** 2).mean()
    
    return gradient_penalty

# 5. 학습 루프 및 데이터 수집
# 5.1. 결과 저장을 위한 폴더 생성
results_dir = './results_dir/WGAN_GP'
if not os.path.exists(results_dir):
    os.makedirs(results_dir)

G_losses = []
C_losses = []
sample_images = []

# fixed noise 생성
fixed_noise = torch.randn(16, noise_dim).to(device)

# 5.2. 학습 루프
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(train_loader):
        real = real.view(-1, img_dim).to(device)
        batch_size_current = real.size(0)
        
        ## ---------------------
        ##  Train Critic
        ## ---------------------
        for _ in range(n_critic):
            noise = torch.randn(batch_size_current, noise_dim).to(device)
            fake = generator(noise)
            
            critic_real = critic(real)
            critic_fake = critic(fake.detach())
            
            loss_C = -(torch.mean(critic_real) - torch.mean(critic_fake))
            
            # Gradient penalty
            loss_C += compute_gradient_penalty(critic, real, fake.detach(), device=device)
            
            optimizer_C.zero_grad()
            loss_C.backward()
            optimizer_C.step()
        
        ## ---------------------
        ##  Train Generator
        ## ---------------------
        noise = torch.randn(batch_size_current, noise_dim).to(device)
        fake = generator(noise)
        # Generator loss aims to maximize critic's output for fake samples
        loss_G = -torch.mean(critic(fake))
        
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()
        
    G_losses.append(loss_G.item())
    C_losses.append(loss_C.item())
    
    if epoch == 0 or (epoch+1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}]  Loss_C: {loss_C.item():.4f}  Loss_G: {loss_G.item():.4f}")

        with torch.no_grad():
            fake_images = generator(fixed_noise).reshape(-1, 1, 28, 28).detach().cpu()
            sample_images.append(fake_images)


Epoch [1/100]  Loss_C: -2.0705  Loss_G: -10.0021
Epoch [10/100]  Loss_C: -2.8403  Loss_G: -3.0662
Epoch [20/100]  Loss_C: -1.8603  Loss_G: -4.3060
Epoch [30/100]  Loss_C: -2.0045  Loss_G: -0.0236
Epoch [40/100]  Loss_C: -1.8784  Loss_G: -0.7376
Epoch [50/100]  Loss_C: -1.5112  Loss_G: -1.4654
Epoch [60/100]  Loss_C: -1.3396  Loss_G: -1.1964
Epoch [70/100]  Loss_C: -1.5365  Loss_G: -1.9615
Epoch [80/100]  Loss_C: -1.5113  Loss_G: -1.5102
Epoch [90/100]  Loss_C: -1.4593  Loss_G: -1.5154
Epoch [100/100]  Loss_C: -1.4294  Loss_G: -3.6116


In [7]:
# 6. 시각화
# 6.1. 손실 곡선 시각화 및 저장
plt.figure(figsize=(10,5))
plt.title("Generator and Critic Loss During Training (WGAN-GP)")
plt.plot(G_losses, label="Generator Loss")
plt.plot(C_losses, label="Critic Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.savefig(os.path.join(results_dir, "loss_curve_wgan_gp.png"))
plt.close()

# 6.2. 생성된 이미지 시각화 및 저장
for idx, images in enumerate(sample_images):
    grid = vutils.make_grid(images, nrow=4, normalize=True)
    plt.figure(figsize=(8,8))
    plt.title(f"Generated Images at Epoch {max(idx*10, 1)}")
    plt.imshow(np.transpose(grid, (1,2,0)))
    plt.axis("off")
    # 이미지 저장
    plt.savefig(os.path.join(results_dir, f"generated_epoch_{max(idx*10, 1)}.png"))
    plt.close()