In [92]:
import os
import torch
import torchvision 
import torch.nn as nn 
import matplotlib.pyplot as plt
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader


In [93]:
#hyperparameters
z_dim = 64
image_dim = 784
lr = 3e-4
EPOCHS = 50
BATCH_SIZE = 100
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")

In [94]:
trainset = datasets.FashionMNIST(
    root = '../PaperWithCode/',
    download=True,
    train = True,
    transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,),(0.5))
    ])
)

train_loader = DataLoader(
    dataset=trainset,
    batch_size=BATCH_SIZE,
    shuffle=True
)

In [95]:
class Generator(nn.Module):
    def __init__(self,z_dim,img_dim):
        super().__init__()

        self.gen = nn.Sequential(
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Tanh()
        )
    
    def forward(self, z):
        return self.gen(z)

In [96]:
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        
        self.dis = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        return self.dis(x)

In [97]:
G = Generator(z_dim,image_dim).to(DEVICE)
D = Discriminator(image_dim).to(DEVICE)
d_optimizer = optim.Adam(D.parameters(), lr = 0.0002)
g_optimizer = optim.Adam(G.parameters(), lr = 0.0002)
criterion = nn.BCELoss() 

In [98]:
import os
from torchvision.utils import save_image

dir_name = "GAN_results"

# Create a directory for saving samples
if not os.path.exists(dir_name):
    os.makedirs(dir_name)

In [99]:
total_step = len(train_loader)
for epoch in range(EPOCHS):
    for i, (images, _) in enumerate(train_loader):
        images = images.reshape(BATCH_SIZE, -1).to(DEVICE)  #[100,786]
        
        # '진짜'와 '가짜' 레이블 생성
        real_labels = torch.ones(BATCH_SIZE, 1).to(DEVICE)  #[100,1]
        fake_labels = torch.zeros(BATCH_SIZE, 1).to(DEVICE) #[100,1]
        
        # 판별자가 진짜 이미지를 진짜로 인식하는 오차를 예산
        outputs = D(images)                                 #[100,1]
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        
        # 무작위 텐서로 가짜 이미지 생성
        z = torch.randn(BATCH_SIZE, 64).to(DEVICE)
        fake_images = G(z)
        
        # 판별자가 가짜 이미지를 가짜로 인식하는 오차를 계산
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs
        
        # 진짜와 가짜 이미지를 갖고 낸 오차를 더해서 판별자의 오차 계산
        d_loss = d_loss_real + d_loss_fake

        # 역전파 알고리즘으로 판별자 모델의 학습을 진행
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # 생성자가 판별자를 속였는지에 대한 오차를 계산
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)
        
        # 역전파 알고리즘으로 생성자 모델의 학습을 진행
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        
    # 학습 진행 알아보기
    print('Epoch [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
          .format(epoch, EPOCHS, d_loss.item(), g_loss.item(), 
                  real_score.mean().item(), fake_score.mean().item()))
    
    samples = fake_images.reshape(BATCH_SIZE, 1, 28, 28)
    save_image(samples, os.path.join(dir_name, 'GAN_fake_samples{}.png'.format(epoch + 1)))

Epoch [0/50], d_loss: 0.0818, g_loss: 4.2445, D(x): 0.97, D(G(z)): 0.04
Epoch [1/50], d_loss: 0.0475, g_loss: 6.2974, D(x): 0.98, D(G(z)): 0.01
Epoch [2/50], d_loss: 0.0468, g_loss: 6.9327, D(x): 1.00, D(G(z)): 0.04
Epoch [3/50], d_loss: 0.2253, g_loss: 6.8094, D(x): 0.96, D(G(z)): 0.08
Epoch [4/50], d_loss: 0.1796, g_loss: 6.6226, D(x): 0.94, D(G(z)): 0.01
Epoch [5/50], d_loss: 0.3658, g_loss: 6.6869, D(x): 0.90, D(G(z)): 0.00
Epoch [6/50], d_loss: 0.2494, g_loss: 4.3512, D(x): 0.97, D(G(z)): 0.13
Epoch [7/50], d_loss: 0.1379, g_loss: 5.0320, D(x): 0.97, D(G(z)): 0.07
Epoch [8/50], d_loss: 0.2992, g_loss: 4.4364, D(x): 0.93, D(G(z)): 0.04
Epoch [9/50], d_loss: 0.2737, g_loss: 4.1077, D(x): 0.91, D(G(z)): 0.06
Epoch [10/50], d_loss: 0.2496, g_loss: 4.0438, D(x): 0.96, D(G(z)): 0.11
Epoch [11/50], d_loss: 0.1858, g_loss: 5.2764, D(x): 0.94, D(G(z)): 0.04
Epoch [12/50], d_loss: 0.2694, g_loss: 4.4974, D(x): 0.92, D(G(z)): 0.06
Epoch [13/50], d_loss: 0.4177, g_loss: 3.6286, D(x): 0.90, D(