In [28]:
import os
import torch
import torchvision 
import torch.nn as nn 
import matplotlib.pyplot as plt
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader


In [50]:
# Hyperparameters and run configuration.
# Seed PyTorch up front so weight initialisation, data shuffling and the
# sampled latent noise are repeatable across "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

z_dim = 64          # size of the latent noise vector fed to the generator
image_dim = 784     # size of one flattened image (28 * 28)
# NOTE(review): the optimizers below are built with a literal 0.0002, so this
# value is not currently used -- confirm which learning rate is intended.
lr = 3e-4
EPOCHS = 50
BATCH_SIZE = 100

# Train on the GPU when one is available, otherwise fall back to the CPU.
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")

In [60]:
# FashionMNIST training split, downloaded into `root` when it is not already there.
# ToTensor yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 rescales
# them to [-1, 1], the same range as the generator's Tanh output.
trainset = datasets.FashionMNIST(
    root = '../PaperWithCode/',
    download=True,
    train = True,
    transform = transforms.Compose([
                transforms.ToTensor(),
                # Both mean and std are one-element tuples (single channel);
                # a bare `(0.5)` is just the float 0.5, not a tuple.
                transforms.Normalize((0.5,), (0.5,))
    ])
)

# Shuffled mini-batches of BATCH_SIZE images for training.
train_loader = DataLoader(
    dataset=trainset,
    batch_size=BATCH_SIZE,
    shuffle=True
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../PaperWithCode/FashionMNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:20<00:00, 1273338.61it/s]


Extracting ../PaperWithCode/FashionMNIST\raw\train-images-idx3-ubyte.gz to ../PaperWithCode/FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../PaperWithCode/FashionMNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 107674.41it/s]


Extracting ../PaperWithCode/FashionMNIST\raw\train-labels-idx1-ubyte.gz to ../PaperWithCode/FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../PaperWithCode/FashionMNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:06<00:00, 717240.42it/s] 


Extracting ../PaperWithCode/FashionMNIST\raw\t10k-images-idx3-ubyte.gz to ../PaperWithCode/FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../PaperWithCode/FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 5149601.00it/s]

Extracting ../PaperWithCode/FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to ../PaperWithCode/FashionMNIST\raw






In [61]:
class Generator(nn.Module):
    """Fully connected generator mapping latent noise to a flattened image.

    Args:
        z_dim: Size of the latent noise vector (input features).
        img_dim: Size of the flattened image to produce (output features).

    The final Tanh keeps every output value in [-1, 1].
    """

    def __init__(self,z_dim,img_dim):
        super().__init__()

        # Layer sizes follow the constructor arguments instead of fixed
        # literals, so the network matches whatever dimensions the caller asks for.
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, img_dim),
            nn.Tanh()
        )
    
    def forward(self, z):
        """Return generated images of shape (batch, img_dim) for noise `z` of shape (batch, z_dim)."""
        return self.gen(z)

In [75]:
class Discriminator(nn.Module):
    """Fully connected discriminator scoring flattened images.

    Args:
        img_dim: Size of one flattened input image (input features).

    The final Sigmoid yields one value in (0, 1) per image.
    """

    def __init__(self, img_dim):
        super().__init__()
        
        # The input width follows `img_dim` instead of a fixed literal, so the
        # network accepts whatever image size the caller asks for.
        self.dis = nn.Sequential(
            nn.Linear(img_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        """Return scores of shape (batch, 1) for images `x` of shape (batch, img_dim)."""
        return self.dis(x)

In [76]:
# Build both networks and move them to the training device before the
# optimizers are created from their parameters.
G = Generator(z_dim, image_dim)
G = G.to(DEVICE)
D = Discriminator(image_dim)
D = D.to(DEVICE)

# One Adam optimizer per network. The learning rate is a literal here
# (2e-4), not the `lr` value from the hyperparameter cell.
d_optimizer = optim.Adam(params=D.parameters(), lr=2e-4)
g_optimizer = optim.Adam(params=G.parameters(), lr=2e-4)

# Binary cross-entropy on the discriminator's sigmoid outputs.
criterion = nn.BCELoss()

In [77]:
import os
from torchvision.utils import save_image

dir_name = "GAN_results"

# Create the directory for saving samples. `exist_ok=True` makes the cell safe
# to re-run and avoids the check-then-create race of a separate exists() test.
os.makedirs(dir_name, exist_ok=True)

In [89]:
# Adversarial training: each mini-batch performs one discriminator update
# followed by one generator update.
total_step = len(train_loader)
for epoch in range(EPOCHS):
    for i, (images, _) in enumerate(train_loader):
        # Size everything from the batch actually received: the last batch of an
        # epoch is smaller than BATCH_SIZE whenever the dataset size is not a
        # multiple of it, and a fixed reshape would then fail.
        batch_size = images.size(0)
        images = images.reshape(batch_size, -1).to(DEVICE)  # [batch, flattened image]
        
        # Create the 'real' and 'fake' target labels directly on the device
        real_labels = torch.ones(batch_size, 1, device=DEVICE)   # [batch, 1]
        fake_labels = torch.zeros(batch_size, 1, device=DEVICE)  # [batch, 1]
        
        # Loss for the discriminator recognising real images as real
        outputs = D(images)                                 # [batch, 1]
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        
        # Generate fake images from random noise of the configured latent size
        z = torch.randn(batch_size, z_dim, device=DEVICE)
        fake_images = G(z)
        
        # Loss for the discriminator recognising fake images as fake.
        # detach() stops this backward pass from running through the generator,
        # which is not updated in this step.
        outputs = D(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs
        
        # Discriminator loss is the sum of the real and fake losses
        d_loss = d_loss_real + d_loss_fake

        # Backpropagate and update the discriminator. zero_grad also clears any
        # gradients left on D by the previous generator step.
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # Loss for how well the generator fooled the (just updated) discriminator.
        # G has not changed since fake_images was produced, so the same batch is
        # reused with its graph intact instead of running G(z) a second time.
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)
        
        # Backpropagate and update the generator
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        
    # Report training progress using the values from the last batch of the epoch
    print('Epoch [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
          .format(epoch, EPOCHS, d_loss.item(), g_loss.item(), 
                  real_score.mean().item(), fake_score.mean().item()))

Epoch [0/50], d_loss: 0.0406, g_loss: 4.7338, D(x): 0.99, D(G(z)): 0.03
Epoch [1/50], d_loss: 0.0461, g_loss: 4.7364, D(x): 0.99, D(G(z)): 0.02
Epoch [2/50], d_loss: 0.1423, g_loss: 10.3365, D(x): 0.96, D(G(z)): 0.00
Epoch [3/50], d_loss: 0.0673, g_loss: 6.5878, D(x): 0.97, D(G(z)): 0.02
Epoch [4/50], d_loss: 0.0757, g_loss: 7.7468, D(x): 0.98, D(G(z)): 0.03
Epoch [5/50], d_loss: 0.0236, g_loss: 8.3082, D(x): 0.99, D(G(z)): 0.01
Epoch [6/50], d_loss: 0.2538, g_loss: 6.8520, D(x): 0.92, D(G(z)): 0.00
Epoch [7/50], d_loss: 0.1772, g_loss: 4.4231, D(x): 0.97, D(G(z)): 0.08
Epoch [8/50], d_loss: 0.1949, g_loss: 5.7602, D(x): 0.94, D(G(z)): 0.01
Epoch [9/50], d_loss: 0.1588, g_loss: 5.0980, D(x): 0.94, D(G(z)): 0.04
Epoch [10/50], d_loss: 0.1571, g_loss: 5.1245, D(x): 0.94, D(G(z)): 0.03
Epoch [11/50], d_loss: 0.3053, g_loss: 3.9137, D(x): 0.94, D(G(z)): 0.11
Epoch [12/50], d_loss: 0.1925, g_loss: 5.3882, D(x): 0.96, D(G(z)): 0.04
Epoch [13/50], d_loss: 0.3306, g_loss: 5.4865, D(x): 0.87, D