In [36]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torchvision.utils as vutils
import os

# 1. 환경 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

batch_size = 512
print(f"batch_size: {batch_size}")

# 2. 데이터 준비 (Fashion MNIST)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = torchvision.datasets.FashionMNIST(
    root='./Fashion_MNIST_dataset',
    train=True,
    transform=transform,
    download=True
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True
)


Using device: cuda
batch_size: 512


In [37]:
# 3. GAN 아키텍처 정의
class Generator(nn.Module):
    def __init__(self, noise_dim=100, img_dim=784):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(noise_dim, 128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(128, 256),
            nn.BatchNorm1d(256, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(256, 512),
            nn.BatchNorm1d(512, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(1024, img_dim),
            nn.Tanh()
        )
        
    def forward(self, x):
        return self.gen(x)

class Discriminator(nn.Module):
    def __init__(self, img_dim=784):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        return self.disc(x)

In [38]:
# 4. 모델 초기화 및 설정
noise_dim = 100
img_dim = 28*28  # Fashion MNIST 이미지 크기
lr, b1, b2 = 3e-5, 0.5, 0.999
num_epochs = 100

generator = Generator(noise_dim, img_dim).to(device)
discriminator = Discriminator(img_dim).to(device)

criterion = nn.BCELoss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# 5. 학습 루프 및 데이터 수집
# 5.1. 결과 저장을 위한 폴더 생성
results_dir = './results_dir/GAN'
if not os.path.exists(results_dir):
    os.makedirs(results_dir)

G_losses = []
D_losses = []
sample_images = []

# fixed noise 생성
fixed_noise = torch.randn(16, noise_dim).to(device)

# 5.2. 학습 루프
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(train_loader):
        real = real.view(-1, img_dim).to(device)
        batch_size = real.size(0)
        
        labels_real = torch.ones(batch_size, 1).to(device)
        labels_fake = torch.zeros(batch_size, 1).to(device)
        
        ## 판별자 학습 ##
        outputs = discriminator(real)
        loss_D_real = criterion(outputs, labels_real)
        
        noise = torch.randn(batch_size, noise_dim).to(device)
        fake = generator(noise)
        outputs = discriminator(fake.detach())
        loss_D_fake = criterion(outputs, labels_fake)
        
        loss_D = loss_D_real + loss_D_fake
        
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()
        
        ## 생성자 학습 ##
        outputs = discriminator(fake)
        loss_G = criterion(outputs, labels_real)
        
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()
        
    G_losses.append(loss_G.item())
    D_losses.append(loss_D.item())
    
    if epoch == 0 or (epoch+1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}]  Loss_D: {loss_D.item():.4f}  Loss_G: {loss_G.item():.4f}")

        with torch.no_grad():
            fake_images = generator(fixed_noise).reshape(-1, 1, 28, 28)
            sample_images.append(fake_images.cpu())


Epoch [1/100]  Loss_D: 1.3083  Loss_G: 0.5455
Epoch [10/100]  Loss_D: 1.3193  Loss_G: 0.6860
Epoch [20/100]  Loss_D: 1.3310  Loss_G: 0.7211
Epoch [30/100]  Loss_D: 1.3385  Loss_G: 0.7340
Epoch [40/100]  Loss_D: 1.3199  Loss_G: 0.7974
Epoch [50/100]  Loss_D: 1.3492  Loss_G: 0.7456
Epoch [60/100]  Loss_D: 1.3145  Loss_G: 0.7838
Epoch [70/100]  Loss_D: 1.3060  Loss_G: 0.7731
Epoch [80/100]  Loss_D: 1.3235  Loss_G: 0.7843
Epoch [90/100]  Loss_D: 1.3224  Loss_G: 0.7782
Epoch [100/100]  Loss_D: 1.3239  Loss_G: 0.7891


In [39]:
# 6. 시각화
# 6.1. 손실 곡선 시각화 및 저장
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="Generator Loss")
plt.plot(D_losses, label="Discriminator Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.savefig(os.path.join(results_dir, "loss_curve.png"))
plt.close()

# 6.2. 생성된 이미지 시각화 및 저장
for idx, images in enumerate(sample_images):
    grid = vutils.make_grid(images, nrow=4, normalize=True)
    plt.figure(figsize=(8,8))
    plt.title(f"Generated Images at Epoch {max(idx*10, 1)}")
    plt.imshow(np.transpose(grid, (1,2,0)))
    plt.axis("off")
    # 이미지 저장
    plt.savefig(os.path.join(results_dir, f"generated_epoch_{max(idx*10, 1)}.png"))
    plt.close()