In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torchvision.utils as vutils
import os

# 1. Environment setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed torch's global RNG so weight init, noise sampling and DataLoader shuffling
# are repeatable across runs (bitwise GPU reproducibility additionally needs
# deterministic kernels).
seed = 42
torch.manual_seed(seed)

batch_size = 64
print(f"batch_size: {batch_size}")

# 2. Data preparation (Fashion MNIST)
# Images are kept at their native 28x28 resolution. The DCGAN models in this
# notebook are designed around 28x28 images, so resizing the real images
# (e.g. to 32x32) would make real and generated samples differ in size.
transform = transforms.Compose([
    transforms.ToTensor(),
    # Map pixel values from [0, 1] to [-1, 1], the same range as a Tanh output.
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = torchvision.datasets.FashionMNIST(
    root='./Fashion_MNIST_dataset',
    train=True,
    transform=transform,
    download=True
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True
)

In [None]:
# 3. DCGAN architecture definition
class Generator(nn.Module):
    """DCGAN generator mapping latent noise to an image.

    Spatial progression (the Discriminator's in reverse):
    1x1 -> 3x3 -> 7x7 -> 14x14 -> 28x28.

    Args:
        noise_dim: Channel count of the input noise tensor (N, noise_dim, 1, 1).
        channels_img: Channel count of the generated image (1 for grayscale).
        features_g: Base channel width; hidden layers use multiples of it.
    """

    def __init__(self, noise_dim=100, channels_img=1, features_g=64):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # Input: N x noise_dim x 1 x 1
            nn.ConvTranspose2d(noise_dim, features_g * 4, kernel_size=3, stride=2, padding=0, bias=False),
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(True),
            # State: (features_g*4) x 3 x 3

            # kernel 3 / stride 2 / padding 0: (3 - 1) * 2 + 3 = 7. This 7x7 map is
            # what lets the two stride-2 layers below reach exactly 14x14 and 28x28
            # (a kernel-4 / stride-1 layer here would give 6x6 and a 24x24 image).
            nn.ConvTranspose2d(features_g * 4, features_g * 2, kernel_size=3, stride=2, padding=0, bias=False),
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(True),
            # State: (features_g*2) x 7 x 7

            nn.ConvTranspose2d(features_g * 2, features_g, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(features_g),
            nn.ReLU(True),
            # State: (features_g) x 14 x 14

            nn.ConvTranspose2d(features_g, channels_img, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
            # Output: (channels_img) x 28 x 28, values in [-1, 1]
        )

    def forward(self, x):
        """Generate a batch of images from noise ``x`` of shape (N, noise_dim, 1, 1)."""
        return self.gen(x)

class Discriminator(nn.Module):
    """DCGAN discriminator returning the probability that each image is real.

    Spatial progression for a 28x28 input: 28x28 -> 14x14 -> 7x7 -> 3x3 -> 1x1.

    Args:
        channels_img: Channel count of the input images (1 for grayscale).
        features_d: Base channel width; deeper layers use multiples of it.
    """

    def __init__(self, channels_img=1, features_d=64):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # Input: (channels_img) x 28 x 28
            # No BatchNorm on the input layer.
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (features_d) x 14 x 14

            nn.Conv2d(features_d, features_d * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(features_d * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (features_d*2) x 7 x 7

            nn.Conv2d(features_d * 2, features_d * 4, kernel_size=3, stride=2, padding=0, bias=False),
            nn.BatchNorm2d(features_d * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (features_d*4) x 3 x 3

            nn.Conv2d(features_d * 4, 1, kernel_size=3, stride=1, padding=0, bias=False),
            nn.Sigmoid()
            # Output: 1 x 1 x 1
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Image batch of shape (N, channels_img, H, W); 28x28 is the
               designed input size.

        Returns:
            Tensor of shape (N, 1) with Sigmoid outputs in [0, 1].

        Raises:
            ValueError: If the input size does not reduce to a 1x1 output map.
        """
        scores = self.disc(x)
        # Without this check, view(-1, 1) would silently fold leftover spatial
        # positions into the batch dimension, so the result would no longer
        # line up with per-image (N, 1) labels.
        if scores.shape[-2:] != (1, 1):
            raise ValueError(
                f"Discriminator expects inputs that reduce to a 1x1 map (e.g. 28x28); "
                f"got input {tuple(x.shape[-2:])} -> output {tuple(scores.shape[-2:])}"
            )
        return scores.view(-1, 1)


In [None]:
# 4. Model initialization and settings
# Architecture sizes
noise_dim = 100      # channels of the (N, noise_dim, 1, 1) latent input
channels_img = 1     # Grayscale images
features_g = 64      # base channel width of the generator
features_d = 64      # base channel width of the discriminator

# Adam hyperparameters (learning rate, betas) and training length
lr = 3e-5
b1 = 0.5
b2 = 0.999
num_epochs = 100

# Build both networks on the selected device (generator first, then discriminator).
generator = Generator(
    noise_dim=noise_dim, channels_img=channels_img, features_g=features_g
).to(device)
discriminator = Discriminator(
    channels_img=channels_img, features_d=features_d
).to(device)

# Initialize weights (DCGAN-style): conv kernels ~ N(0, 0.02),
# batch-norm scale ~ N(1, 0.02) and batch-norm shift = 0.
_CONV_TYPES = (
    nn.Conv1d, nn.Conv2d, nn.Conv3d,
    nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
)
_BATCHNORM_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)


def weights_init(m):
    """Re-initialize the parameters of one module; intended for ``model.apply``.

    Conv / transposed-conv weights are drawn from N(0, 0.02). Batch-norm scale
    is drawn from N(1, 0.02) and its shift is set to 0. Other modules (and conv
    biases) are left untouched.

    Args:
        m: A ``torch.nn.Module``; ``Module.apply`` calls this on every submodule.
    """
    # isinstance instead of class-name substring matching, so unrelated modules
    # whose names merely contain "Conv"/"BatchNorm" are not touched.
    if isinstance(m, _CONV_TYPES):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, _BATCHNORM_TYPES):
        # affine=False batch norm has no weight/bias (they are None).
        if m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

# Apply the custom initialization to every submodule of both networks.
for net in (generator, discriminator):
    net.apply(weights_init)

# Binary cross-entropy on the discriminator's Sigmoid outputs.
criterion = nn.BCELoss()

# Both networks are optimized with Adam using the same lr and betas.
optimizer_G = torch.optim.Adam(generator.parameters(), betas=(b1, b2), lr=lr)
optimizer_D = torch.optim.Adam(discriminator.parameters(), betas=(b1, b2), lr=lr)

# 5. Training loop and data collection
# 5.1. Create the folder for saving results (no-op if it already exists)
results_dir = './results_dcgan'
os.makedirs(results_dir, exist_ok=True)

# One entry per epoch: the mean loss over all batches of that epoch.
G_losses = []
D_losses = []
# Generator outputs for `fixed_noise`, captured after epoch 1 and every 10th epoch.
sample_images = []

# Fixed noise, reused at every snapshot so samples are comparable across epochs.
fixed_noise = torch.randn(16, noise_dim, 1, 1).to(device)

# 5.2. Training loop
for epoch in range(num_epochs):
    # Accumulate detached losses on-device instead of calling .item()
    # (which synchronizes with the device) on every step.
    running_loss_G = torch.zeros((), device=device)
    running_loss_D = torch.zeros((), device=device)
    num_batches = 0

    for real, _ in train_loader:
        real = real.to(device)
        batch_size_curr = real.size(0)

        # BCELoss targets: 1 = real, 0 = fake, shaped (N, 1) like the discriminator output.
        labels_real = torch.ones(batch_size_curr, 1, device=device)
        labels_fake = torch.zeros(batch_size_curr, 1, device=device)

        ## Discriminator step ##
        # Real images
        outputs = discriminator(real)
        loss_D_real = criterion(outputs, labels_real)

        # Fake images; detached so this step does not backpropagate into the generator.
        noise = torch.randn(batch_size_curr, noise_dim, 1, 1).to(device)
        fake = generator(noise)
        outputs = discriminator(fake.detach())
        loss_D_fake = criterion(outputs, labels_fake)

        loss_D = loss_D_real + loss_D_fake

        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        ## Generator step ##
        # Re-score the same fakes with the just-updated discriminator, keeping the
        # graph so gradients reach the generator; the generator's target is "real".
        outputs = discriminator(fake)
        loss_G = criterion(outputs, labels_real)

        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        running_loss_G += loss_G.detach()
        running_loss_D += loss_D.detach()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("train_loader yielded no batches")

    # Record the epoch mean rather than only the last batch's (noisy) loss.
    epoch_loss_G = (running_loss_G / num_batches).item()
    epoch_loss_D = (running_loss_D / num_batches).item()
    G_losses.append(epoch_loss_G)
    D_losses.append(epoch_loss_D)

    if epoch == 0 or (epoch+1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}]  Loss_D: {epoch_loss_D:.4f}  Loss_G: {epoch_loss_G:.4f}")

        with torch.no_grad():
            fake_images = generator(fixed_noise).cpu()
            sample_images.append(fake_images)

In [None]:
# 6. Visualization
# 6.1. Plot the generator/discriminator loss curves and save them to disk
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
ax.plot(G_losses, label="Generator Loss")
ax.plot(D_losses, label="Discriminator Loss")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.legend()
fig.savefig(os.path.join(results_dir, "loss_curve_dcgan.png"))
plt.close(fig)

# 6.2. Render each stored batch of samples as a 4x4 grid and save it.
# Snapshots were taken after epoch 1 and then every 10 epochs, so snapshot k
# is labelled epoch 10*k (the first one is labelled epoch 1).
for snapshot_idx, snapshot in enumerate(sample_images):
    epoch_label = max(snapshot_idx * 10, 1)
    grid = vutils.make_grid(snapshot, nrow=4, normalize=True)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_title(f"Generated Images at Epoch {epoch_label}")
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    ax.imshow(grid.permute(1, 2, 0).numpy())
    ax.axis("off")
    # Save the image
    fig.savefig(os.path.join(results_dir, f"generated_epoch_{epoch_label}.png"))
    plt.close(fig)