In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torchvision.utils as utils
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn.functional as F

# Run on the GPU when CUDA is available, otherwise on the CPU.
is_cuda = torch.cuda.is_available()
device = torch.device('cuda') if is_cuda else torch.device('cpu')

# ToTensor scales pixel values to [0, 1]; Normalize with mean 0.5 and std 0.5
# then maps them to [-1, 1], so the Generator's output layer should use tanh.
standardizator = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# MNIST train / test splits, stored under data/ (download=True fetches them if missing).
train_data = dsets.MNIST(
    root='data/',
    train=True,
    transform=standardizator,
    download=True,
)
test_data = dsets.MNIST(
    root='data/',
    train=False,
    transform=standardizator,
    download=True,
)

# Mini-batch loaders; both splits are reshuffled on every pass.
batch_size = 100
train_data_loader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True)
test_data_loader = torch.utils.data.DataLoader(
    test_data, batch_size=batch_size, shuffle=True)

In [18]:
# Take the first mini-batch yielded by a fresh iterator over the training
# loader and unpack it into its image and label parts.
example_mini_batch_img, example_mini_batch_label  = next(iter(train_data_loader))

In [2]:
# input noise generator for Generator
def input_noise_generator(batch_size, dim):
    """Sample a batch of standard-normal noise vectors.

    Args:
        batch_size: number of noise vectors (rows) to draw.
        dim: length of each noise vector (columns).

    Returns:
        A tensor of shape (batch_size, dim) drawn with torch.randn and
        allocated on the module-level ``device``.
    """
    noise_shape = (batch_size, dim)
    return torch.randn(noise_shape, device=device)

In [5]:
# define generator
class Generator(nn.Module):
    """Three-layer fully connected generator with a tanh output.

    Args:
        dim_input: number of input features fed to the first linear layer.
        dim_hidden_1: width of the first hidden layer.
        dim_hidden_2: width of the second hidden layer.
        dim_output: number of output features produced by the last layer.
    """

    def __init__(self, dim_input, dim_hidden_1, dim_hidden_2, dim_output):
        super().__init__()
        self.layer_1 = nn.Linear(dim_input, dim_hidden_1)
        self.layer_2 = nn.Linear(dim_hidden_1, dim_hidden_2)
        self.layer_3 = nn.Linear(dim_hidden_2, dim_output)

    def forward(self, x):
        """Map a batch of inputs to outputs squashed into (-1, 1) by tanh.

        Args:
            x: tensor whose last dimension equals ``dim_input``.

        Returns:
            Tensor whose last dimension equals ``dim_output``.
        """
        # F.dropout defaults to training=True, which would keep dropout active
        # even after .eval(); tie it to the module's mode instead.
        x = F.relu(self.layer_1(x))
        x = F.dropout(x, p=0.1, training=self.training)
        x = F.relu(self.layer_2(x))
        x = F.dropout(x, p=0.1, training=self.training)
        # torch.tanh replaces the deprecated F.tanh.
        return torch.tanh(self.layer_3(x))
        

In [None]:
class Discriminator(nn.Module):
    """Three-layer fully connected discriminator.

    Args:
        dim_input: number of input features fed to the first linear layer.
        dim_hidden_1: width of the first hidden layer.
        dim_hidden_2: width of the second hidden layer.
        dim_output: number of output features produced by the last layer.
    """

    def __init__(self, dim_input, dim_hidden_1, dim_hidden_2, dim_output):
        super().__init__()
        self.layer_1 = nn.Linear(dim_input, dim_hidden_1)
        self.layer_2 = nn.Linear(dim_hidden_1, dim_hidden_2)
        self.layer_3 = nn.Linear(dim_hidden_2, dim_output)

    def forward(self, x):
        """Score a batch of inputs; outputs are non-negative due to the final ReLU.

        Args:
            x: tensor whose last dimension equals ``dim_input``.

        Returns:
            Tensor whose last dimension equals ``dim_output``.
        """
        # F.dropout defaults to training=True, which would keep dropout active
        # even after .eval(); tie it to the module's mode instead.
        x = F.relu(self.layer_1(x))
        x = F.dropout(x, p=0.1, training=self.training)
        x = F.relu(self.layer_2(x))
        x = F.dropout(x, p=0.1, training=self.training)
        # NOTE(review): a ReLU output is unbounded above; if this is trained with
        # a probability-based loss (e.g. BCELoss) a sigmoid is presumably
        # intended here -- confirm against the training loop before changing.
        return F.relu(self.layer_3(x))
    