In [1]:
import os
import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader 
from torch.autograd import Variable 
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision import datasets 

In [3]:
## Notebook-wide configuration / hyperparameters
num_eps = 10         # presumably the number of training epochs (used by a later training cell)
bsize = 32           # DataLoader mini-batch size
lrate = 0.001        # learning rate, presumably for optimizers defined later
lat_dimension = 64   # length of the random-noise (latent) vector fed to the generator
image_sz = 64        # images are resized to image_sz x image_sz; also sizes both networks
chnls = 1            # image channels; MNIST is grayscale, so 1
logging_intv = 200   # logging interval; presumably in batches — TODO confirm in training loop

## Seed the RNGs up front so weight initialisation and DataLoader shuffling
## are reproducible under "Restart Kernel -> Run All".
rand_seed = 42
np.random.seed(rand_seed)
torch.manual_seed(rand_seed)

In [4]:
class GANGenerator(nn.Module):
    """DCGAN-style generator that maps a latent noise vector to an image.

    Pipeline: Linear projection -> reshape to (batch, 128, img_size // 4, img_size // 4)
    -> BatchNorm -> 2x upsample -> conv/BatchNorm/LeakyReLU -> 2x upsample
    -> conv/BatchNorm/LeakyReLU -> 3x3 conv to ``channels`` -> Tanh.
    Tanh keeps outputs in [-1, 1], the same range the MNIST cell produces with
    ``Normalize([0.5], [0.5])``.

    Args:
        latent_dim: length of the input noise vector. Defaults to ``lat_dimension``.
        img_size: side of the square output image. It must be divisible by 4,
            because there are two 2x upsamples. Defaults to ``image_sz``.
        channels: number of output image channels. Defaults to ``chnls``, which is
            the input channel count ``GANDiscriminator`` is built with.

    Raises:
        ValueError: if ``img_size`` is not divisible by 4.
    """

    def __init__(self, latent_dim=None, img_size=None, channels=None):
        super(GANGenerator, self).__init__()
        # None -> fall back to the config-cell constants, so `GANGenerator()` behaves as before.
        latent_dim = lat_dimension if latent_dim is None else latent_dim
        img_size = image_sz if img_size is None else img_size
        channels = chnls if channels is None else channels
        if img_size % 4 != 0:
            raise ValueError(f"img_size must be divisible by 4, got {img_size}")

        # Attribute names are unchanged so existing state_dict checkpoints still load.
        self.inp_sz = img_size // 4
        # Linear(in_features, out_features): noise vector -> flattened 128-channel seed map.
        self.lin = nn.Linear(latent_dim, 128 * self.inp_sz ** 2)
        self.bn1 = nn.BatchNorm2d(128)
        self.up1 = nn.Upsample(scale_factor=2)
        self.conv1 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        # BatchNorm2d's 2nd positional argument is `eps` (not momentum), so this is eps=0.8.
        # NOTE(review): far above the 1e-5 default; kept as-is to preserve behaviour.
        self.bn2 = nn.BatchNorm2d(128, eps=0.8)
        self.rl1 = nn.LeakyReLU(0.2, inplace=True)
        self.up2 = nn.Upsample(scale_factor=2)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.rl2 = nn.LeakyReLU(0.2, inplace=True)
        # Was hard-coded to 3 channels, which the `chnls`-channel discriminator cannot accept.
        self.conv3 = nn.Conv2d(64, channels, kernel_size=3, stride=1, padding=1)
        self.act = nn.Tanh()

    def forward(self, x):
        """Generate a batch of images from noise.

        Args:
            x: latent batch, presumably shaped (batch, latent_dim).

        Returns:
            Tensor of shape (batch, channels, img_size, img_size) with values in [-1, 1].
        """
        x = self.lin(x)

        # x.shape[0] is the batch size: unflatten each row into a 128 x inp_sz x inp_sz map.
        x = x.view(x.shape[0], 128, self.inp_sz, self.inp_sz)
        x = self.bn1(x)
        x = self.up1(x)

        x = self.conv1(x)
        x = self.bn2(x)
        x = self.rl1(x)

        x = self.up2(x)

        x = self.conv2(x)
        x = self.bn3(x)
        x = self.rl2(x)

        x = self.conv3(x)
        out = self.act(x)

        return out


In [8]:
class GANDiscriminator(nn.Module):
    """Convolutional discriminator producing one probability per input image.

    The image is downsampled by four stride-2 conv blocks (16 -> 32 -> 64 -> 128 channels).
    A Linear + Sigmoid head then emits a value in (0, 1) per sample, matching ``BCELoss``.
    Which class maps to 1 (presumably "real") is decided by the labels in the training loop.

    Args:
        img_size: side of the square input image. Defaults to ``image_sz``.
        channels: number of input image channels (1 for grayscale MNIST).
            Defaults to ``chnls``.
    """

    def __init__(self, img_size=None, channels=None):
        super(GANDiscriminator, self).__init__()
        # None -> fall back to the config-cell constants, so `GANDiscriminator()` behaves as before.
        img_size = image_sz if img_size is None else img_size
        channels = chnls if channels is None else channels

        def disc_module(input_channels, output_channels, bnorm=True):
            """One downsampling block: 3x3 stride-2 conv -> LeakyReLU -> Dropout [-> BatchNorm]."""
            mod = [nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=2, padding=1),
                   nn.LeakyReLU(0.2, inplace=True),
                   nn.Dropout(0.25)]
            if bnorm:
                # BatchNorm2d's 2nd positional argument is `eps` (not momentum): eps=0.8.
                # NOTE(review): unusually large eps; kept as-is to preserve behaviour.
                mod += [nn.BatchNorm2d(output_channels, eps=0.8)]
            return mod

        # `channels` is the colour-channel count of the input image (1 for MNIST).
        # The first block deliberately skips BatchNorm.
        self.disc_model = nn.Sequential(
            *disc_module(channels, 16, bnorm=False),
            *disc_module(16, 32),
            *disc_module(32, 64),
            *disc_module(64, 128),
        )

        # A k=3, s=2, p=1 conv maps side n to floor((n - 1) / 2) + 1 == ceil(n / 2).
        # Apply that once per block (4 blocks) instead of `img_size // 2 ** 4`, which is
        # only correct when img_size is a multiple of 16 (e.g. it breaks for 28).
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        self.adverse_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: image batch, presumably (batch, channels, img_size, img_size).

        Returns:
            Tensor of shape (batch, 1) with Sigmoid probabilities.
        """
        x = self.disc_model(x)
        x = x.view(x.shape[0], -1)  # flatten each sample's feature map for the Linear head
        out = self.adverse_layer(x)
        return out

In [9]:
## Build one generator and one discriminator from the notebook-level config.
gen, disc = GANGenerator(), GANDiscriminator()

## Adversarial loss: binary cross-entropy over the discriminator's Sigmoid output,
## i.e. a loss suited to a binary classification objective.
adv_loss_func = nn.BCELoss()

In [10]:
## MNIST training data (torchvision default split), resized to image_sz x image_sz.
## ToTensor yields values in [0, 1]; Normalize(mean=0.5, std=0.5) rescales them to [-1, 1].
mnist_transform = transforms.Compose([
    transforms.Resize((image_sz, image_sz)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
mnist_train = datasets.MNIST("./data/mnist/", download=True, transform=mnist_transform)
dloader = DataLoader(mnist_train, batch_size=bsize, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/mnist/MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 9465516.79it/s] 


Extracting ./data/mnist/MNIST\raw\train-images-idx3-ubyte.gz to ./data/mnist/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 9651477.48it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/mnist/MNIST\raw\train-labels-idx1-ubyte.gz
Extracting ./data/mnist/MNIST\raw\train-labels-idx1-ubyte.gz to ./data/mnist/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/mnist/MNIST\raw\t10k-images-idx3-ubyte.gz



100%|██████████| 1648877/1648877 [00:00<00:00, 6957280.26it/s]


Extracting ./data/mnist/MNIST\raw\t10k-images-idx3-ubyte.gz to ./data/mnist/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<?, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/mnist/MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting ./data/mnist/MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data/mnist/MNIST\raw




