# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import os

# Training hyperparameters.
total_epoch = 200        # number of passes over the training set
batch_size = 100         # samples per mini-batch
learning_rate = 2e-4     # step size handed to both Adam optimizers

# Network dimensions.
n_label = 10             # width of the one-hot class vector
n_noise = 128            # length of the noise vector fed to the generator
n_hidden = 256           # units in each network's hidden layer
n_input = 28 * 28        # number of values in one flattened image

# Use the GPU when PyTorch reports one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the MNIST training split (downloaded on first use). Each image is
# converted to a tensor and then flattened into a 1-D tensor.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(torch.flatten),
])
train_dataset = datasets.MNIST(
    root='./mnist/data',
    train=True,
    download=True,
    transform=transform,
)
# Shuffle every epoch and drop the final incomplete batch.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# Generator network
class Generator(nn.Module):
    """Conditional generator: maps a noise vector plus a label vector to a sample.

    The network is Linear -> ReLU -> Linear -> Sigmoid, so every output
    value lies in [0, 1].

    Args:
        noise_dim: Length of the noise vector. Defaults to the module-level
            ``n_noise`` when omitted.
        label_dim: Length of the label vector. Defaults to ``n_label``.
        hidden_dim: Number of hidden units. Defaults to ``n_hidden``.
        output_dim: Length of the generated vector. Defaults to ``n_input``.
    """

    def __init__(self, noise_dim=None, label_dim=None, hidden_dim=None, output_dim=None):
        super().__init__()
        # Resolve defaults lazily so that a plain `Generator()` keeps using
        # the module-level settings, exactly as before.
        noise_dim = n_noise if noise_dim is None else noise_dim
        label_dim = n_label if label_dim is None else label_dim
        hidden_dim = n_hidden if hidden_dim is None else hidden_dim
        output_dim = n_input if output_dim is None else output_dim

        self.net = nn.Sequential(
            nn.Linear(noise_dim + label_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid(),
        )

    def forward(self, z: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Generate samples from noise ``z`` conditioned on ``labels``.

        ``z`` and ``labels`` are concatenated along dimension 1, so they must
        agree on every other dimension.
        """
        return self.net(torch.cat([z, labels], dim=1))

# Discriminator network
class Discriminator(nn.Module):
    """Conditional discriminator: scores a sample given its label vector.

    The network is Linear -> ReLU -> Linear -> Sigmoid and produces one
    value in [0, 1] per input row.

    Args:
        input_dim: Length of the sample vector. Defaults to the module-level
            ``n_input`` when omitted.
        label_dim: Length of the label vector. Defaults to ``n_label``.
        hidden_dim: Number of hidden units. Defaults to ``n_hidden``.
    """

    def __init__(self, input_dim=None, label_dim=None, hidden_dim=None):
        super().__init__()
        # Resolve defaults lazily so that a plain `Discriminator()` keeps
        # using the module-level settings, exactly as before.
        input_dim = n_input if input_dim is None else input_dim
        label_dim = n_label if label_dim is None else label_dim
        hidden_dim = n_hidden if hidden_dim is None else hidden_dim

        self.net = nn.Sequential(
            nn.Linear(input_dim + label_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return one score per row of ``x`` conditioned on ``labels``.

        ``x`` and ``labels`` are concatenated along dimension 1, so they must
        agree on every other dimension.
        """
        return self.net(torch.cat([x, labels], dim=1))

# Build both networks (generator first, then discriminator) and move them
# to the selected device.
generator = Generator()
discriminator = Discriminator()
generator.to(device)
discriminator.to(device)

# Binary cross-entropy loss, plus one Adam optimizer per network.
criterion = nn.BCELoss()
g_optimizer = optim.Adam(params=generator.parameters(), lr=learning_rate)
d_optimizer = optim.Adam(params=discriminator.parameters(), lr=learning_rate)

# Make sure the directory for the sample images exists.
os.makedirs("samples_ex", exist_ok=True)

# Training loop: alternate one discriminator update and one generator update
# per mini-batch, log the last batch's losses each epoch, and periodically
# save a row of generated samples.
for epoch in range(total_epoch):
    for real_imgs, labels in train_loader:
        real_imgs = real_imgs.to(device)
        labels = torch.nn.functional.one_hot(labels, n_label).float().to(device)

        # Size of this mini-batch. Kept under its own name so the
        # module-level `batch_size` setting is not overwritten.
        cur_batch = real_imgs.size(0)
        # Targets are allocated directly on the device (no host-to-device copy).
        real_labels = torch.ones(cur_batch, 1, device=device)
        fake_labels = torch.zeros(cur_batch, 1, device=device)

        # --- Discriminator step ---
        z = torch.randn(cur_batch, n_noise, device=device)
        fake_imgs = generator(z, labels)

        d_real = discriminator(real_imgs, labels)
        # detach(): the discriminator loss must not back-propagate into the
        # generator.
        d_fake = discriminator(fake_imgs.detach(), labels)

        d_loss = criterion(d_real, real_labels) + criterion(d_fake, fake_labels)

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # --- Generator step ---
        z = torch.randn(cur_batch, n_noise, device=device)
        fake_imgs = generator(z, labels)
        d_fake = discriminator(fake_imgs, labels)

        # The generator is rewarded when its fakes are scored as real.
        g_loss = criterion(d_fake, real_labels)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

    # Losses reported here are those of the final mini-batch of the epoch.
    print(f"Epoch [{epoch+1:04}/{total_epoch}] D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

    # Save samples after the first epoch and after every 10th epoch.
    if epoch == 0 or (epoch + 1) % 10 == 0:
        with torch.no_grad():
            sample_z = torch.randn(n_label, n_noise, device=device)
            sample_labels = torch.eye(n_label, device=device)  # one sample per class
            samples = generator(sample_z, sample_labels).cpu().view(-1, 28, 28)

        # squeeze=False keeps `axes` two-dimensional for any n_label.
        fig, axes = plt.subplots(1, n_label, figsize=(n_label, 1), squeeze=False)
        for axis, img in zip(axes[0], samples):
            axis.imshow(img, cmap='gray')
            axis.axis('off')
        # NOTE: the file name uses the 0-based epoch index, while the log line
        # above is 1-based (e.g. the image for logged epoch 10 is 009.png).
        fig.savefig(f'samples_ex/{str(epoch).zfill(3)}.png', bbox_inches='tight')
        plt.close(fig)
