In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image

In [None]:
class CustomGANDataset(Dataset):
    """Image-only dataset that loads one array per file from a directory.

    Every entry of ``data_dir`` is loaded with ``np.load`` and passed through
    ``transform``; no labels are returned because the GAN only needs images.

    Args:
        transform: Callable applied to each loaded array. When ``None`` (the
            default) ``ToTensor()`` followed by ``Normalize(0.5, 0.5)`` is
            used. NOTE(review): this assumes the stored arrays are uint8 HxW
            images, so ``ToTensor`` scales them to [0, 1] and the normalization
            maps them to [-1, 1] (the generator's tanh range) -- confirm.
        data_dir: Directory holding the image files. Defaults to
            ``<cwd>/mnist/data``.
        max_images: If given, keep only the first ``max_images`` files (after
            sorting), e.g. for quick debugging runs.
    """

    def __init__(self, transform=None, data_dir=None, max_images=None):
        self.root_dir = os.getcwd()
        if data_dir is None:
            data_dir = os.path.join(self.root_dir, 'mnist', 'data')
        self.data_dir = data_dir

        if transform is None:
            # Built per instance rather than as a default argument, which would
            # be evaluated once and shared by every instance.
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5,), std=(0.5,)),
            ])
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sort so that index i
        # always refers to the same file across runs and machines.
        self.imgs = [os.path.join(self.data_dir, name)
                     for name in sorted(os.listdir(self.data_dir))]
        if max_images is not None:
            self.imgs = self.imgs[:max_images]

    def __len__(self):
        """Return the number of image files found."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Load the ``index``-th file with ``np.load`` and return it transformed."""
        img_pth = self.imgs[index]

        img = np.load(img_pth)
        img = self.transform(img)

        return img

In [None]:
class Generator(nn.Module):
    """Fully connected generator mapping a latent vector to a flat image.

    Three hidden ``Linear`` layers, each followed by LeakyReLU with negative
    slope 0.2, feed an output ``Linear`` layer squashed by ``tanh``, so every
    output value lies in [-1, 1].

    Args:
        input_dim: Size of the input (latent) vector.
        hidden_dim: Width of each hidden layer.
        output_dim: Number of output values per sample.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # The fc1..fc4 attribute names are the state_dict keys of saved
        # checkpoints, and their creation order fixes the init RNG draws.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden), negative_slope=0.2)
        logits = self.fc4(hidden)
        return torch.tanh(logits)

In [None]:
class Discriminator(nn.Module):
    """Fully connected discriminator producing one probability per sample.

    Three hidden ``Linear`` layers, each followed by LeakyReLU (slope 0.2) and
    dropout, feed a single-unit output layer passed through ``sigmoid``, so the
    output has shape ``(N, 1)`` with values in [0, 1].

    Args:
        input_dim: Number of input features per sample.
        hidden_dim: Width of each hidden layer.
        dropout_p: Dropout probability after each hidden layer (default 0.3).
    """

    def __init__(self, input_dim, hidden_dim, dropout_p=0.3):
        super(Discriminator, self).__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, 1)
        self.dropout_p = dropout_p

    def forward(self, x):
        for layer in (self.fc1, self.fc2, self.fc3):
            x = F.leaky_relu(layer(x), 0.2)
            # F.dropout defaults to training=True, which would keep dropping
            # units after D.eval(); follow the module's train/eval mode instead.
            x = F.dropout(x, self.dropout_p, training=self.training)
        return torch.sigmoid(self.fc4(x))

In [None]:
# Select the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's RNG (weight init, DataLoader shuffling, latent noise z) so a
# Restart & Run All is repeatable; this cell runs before any of those draws.
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Build the dataset and its shuffled mini-batch loader, and set the epoch count.
dataset = CustomGANDataset()
train_loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,  # reshuffled at the start of every epoch
)
epochs = 200

In [None]:
# Generator hyperparameters.
g_input_dim = 100       # length of the noise vector z fed to the generator
g_hidden_dim = 256      # width of every hidden layer
g_output_dim = 28 * 28  # one flattened 28x28 image = 784 values

In [None]:
# Discriminator hyperparameters: its input is exactly the generator's output.
d_hidden_dim = 1024
d_input_dim = g_output_dim

In [None]:
# Instantiate both networks (generator first) and move each to the device.
G = Generator(g_input_dim, g_hidden_dim, g_output_dim).to(device=device)
D = Discriminator(d_input_dim, d_hidden_dim).to(device=device)

In [None]:
# Objective: binary cross-entropy on the discriminator's sigmoid probabilities.
criterion = nn.BCELoss().to(device=device)

In [None]:
# One Adam optimizer per network, both with learning rate 2e-4.
optimizer_g = optim.Adam(params=G.parameters(), lr=2e-4)
optimizer_d = optim.Adam(params=D.parameters(), lr=2e-4)

In [None]:
# Output locations: the loss plot goes to ./mnist, checkpoints and sample
# grids to ./mnist/GAN (created if missing).
root_dir = os.getcwd()
data_dir, gan_dir = (os.path.join(root_dir, 'mnist'),
                     os.path.join(root_dir, 'mnist', 'GAN'))
os.makedirs(gan_dir, exist_ok=True)

In [None]:
# Training: for every batch, one discriminator update followed by one
# generator update; losses are logged, both networks are checkpointed, and a
# grid of generated samples is written at the end of every epoch.
total_step = len(train_loader)
cnt = 0
g_losses, d_losses = [], []

# Fixed latent vectors so the per-epoch sample grids are comparable over time.
# NOTE(review): 64 samples per grid is a choice made here, not a tuned value.
fixed_z = torch.randn(64, g_input_dim, device=device)


def _save_loss_plot(d_history, g_history, path):
    """Plot the logged discriminator/generator losses and save the figure to ``path``."""
    fig, ax = plt.subplots()
    ax.plot(range(len(d_history)), d_history, label='Discriminator loss', c='red')
    ax.plot(range(len(g_history)), g_history, label='Generator loss', c='blue')
    ax.set_xlabel('logged step (every 50 iterations)')
    ax.set_ylabel('BCE loss')
    ax.legend()
    fig.savefig(path)
    plt.close(fig)


for epoch in range(epochs):
    with tqdm(total=total_step) as tbar:
        for idx, imgs in enumerate(train_loader):
            N = imgs.size(0)
            # Flatten each (1, 28, 28) image into a row for the linear layers.
            imgs = imgs.to(device=device).view(N, -1)

            label_real = torch.ones(N, 1, device=device)
            label_fake = torch.zeros(N, 1, device=device)

            # ---- Discriminator step: push D(x) -> 1 and D(G(z)) -> 0 ----
            optimizer_d.zero_grad()
            output_real = D(imgs)
            d_loss_real = criterion(output_real, label_real)
            score_real = output_real

            z = torch.randn(N, g_input_dim, device=device)
            # detach(): the discriminator loss must not backpropagate into G.
            fake_imgs = G(z).detach()
            output_fake = D(fake_imgs)
            d_loss_fake = criterion(output_fake, label_fake)
            score_fake = output_fake

            d_loss_total = d_loss_fake + d_loss_real
            d_loss_total.backward()
            optimizer_d.step()

            # ---- Generator step: BCE against "real" labels, i.e. push D(G(z)) -> 1 ----
            optimizer_g.zero_grad()
            z = torch.randn(N, g_input_dim, device=device)
            fake_imgs = G(z)
            output_fake = D(fake_imgs)
            g_loss = criterion(output_fake, label_real)
            g_loss.backward()
            optimizer_g.step()

            # Keep every 50th iteration's losses for the loss plot.
            if cnt % 50 == 0:
                d_losses.append(d_loss_total.item())
                g_losses.append(g_loss.item())

            tbar.set_description(
                'Epoch [{:3d}/{:3d}], Step [{:2d}/{:2d}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                .format(epoch, epochs, idx + 1, total_step, d_loss_total.item(), g_loss.item(),
                        score_real.mean().item(), score_fake.mean().item()))
            cnt += 1

            if cnt % 100 == 0:
                # NOTE(review): a new pair of checkpoint files is written every
                # 100 iterations and never cleaned up; over 200 epochs this can
                # use a lot of disk. Consider keeping only the latest / best.
                torch.save(G.state_dict(), os.path.join(gan_dir, f'generator_{cnt}.pth'))
                torch.save(D.state_dict(), os.path.join(gan_dir, f'discriminator_{cnt}.pth'))
                _save_loss_plot(d_losses, g_losses, os.path.join(data_dir, 'loss plot.png'))

            tbar.update(1)

    # Save generated (not real) images for this epoch. The generator ends in
    # tanh, so map [-1, 1] back to [0, 1], which save_image expects.
    with torch.no_grad():
        samples = G(fixed_z).view(-1, 1, 28, 28)
    save_image(samples * 0.5 + 0.5, os.path.join(gan_dir, f"imgs_{epoch}.png"))