# 1. Import required Libraries

In [84]:
import os
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.utils as utils
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as v_utils 
from torch.autograd import Variable

# 2. Hyperparameter setting

In [85]:
# num_gpus는 사용할 gpu의 개수 

epoch = 50
batch_size = 512
learning_rate = 0.0002
num_gpus = 1
z_size = 50
middle_size = 200


# 3. Data Setting

In [86]:
mnist_train = dset.MNIST("./data", train=True, transform=transforms.ToTensor(), target_transform=None, download=True)

train_loader = torch.utils.data.DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)


# 4. Generator

In [87]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.layer1 = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(z_size, middle_size)),
            ('bn1', nn.BatchNorm1d(middle_size)), 
            ('act1', nn.ReLU()),
        ]))
        self.layer2 = nn.Sequential(OrderedDict([
            ('fc2', nn.Linear(middle_size, 784)),
            # ('bn2', nn.BatchNorm2d(784)),
            ('tanh', nn.Tanh()),
        ]))
        
    def forward(self, z):
        out = self.layer1(z)
        out = self.layer2(out)
        out = out.view(batch_size // num_gpus, 1, 28, 28)
        
        return out

# 5. Discriminator

In [88]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.layer1 = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(784, middle_size)),
            ('bn1', nn.BatchNorm1d(middle_size)),
            ('act1', nn.LeakyReLU()),
        ]))
        self.layer2 = nn.Sequential(OrderedDict([
            ('fc2', nn.Linear(middle_size, 1)),
#             ('bn2', nn.BatchNorm2d(1)),
            ('act2', nn.Sigmoid()),
        ]))
        
    def forward(self, x):
        out = x.view(batch_size // num_gpus, -1)
        out = self.layer1(out)
        out = self.layer2(out)
        
        return out

# 6. Put instances on Multi-GPU

In [89]:
generator = nn.DataParallel(Generator()).cuda()
discriminator = nn.DataParallel(Discriminator()).cuda()

# 7. Check layers

In [90]:
gen_params = generator.state_dict().keys()
dis_params = discriminator.state_dict().keys()

for i in gen_params:
    print(i)

module.layer1.fc1.weight
module.layer1.fc1.bias
module.layer1.bn1.weight
module.layer1.bn1.bias
module.layer1.bn1.running_mean
module.layer1.bn1.running_var
module.layer1.bn1.num_batches_tracked
module.layer2.fc2.weight
module.layer2.fc2.bias


# 8. Set Loss function & Optimizer

In [91]:
loss_func = nn.MSELoss()
gen_optim = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
dis_optim = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

ones_label = Variable(torch.ones(batch_size, 1)).cuda()
zeros_label = Variable(torch.zeros(batch_size, 1)).cuda()

# 9. Restore Model 

In [92]:
try:
    generator, discriminator = torch.load("./model/vanilla_gan.pkl")
    print("\n---------model restored------------\n")
except:
    print("\n------------model not restored---------\n")
    pass


------------model not restored---------



# 10. Train Model

In [95]:
for i in range(epoch):
    for j, (image, label) in enumerate(train_loader):
        image = Variable(image).cuda()
        
        dis_optim.zero_grad()
        
        # 노이즈 z를 생성하고 fake 데이터를 generator를 통해 만듬 
        z = Variable(init.normal(torch.Tensor(batch_size, z_size), mean=0, std=0.1)).cuda()
        gen_fake = generator.forward(z)
        dis_fake = discriminator.forward(gen_fake)
        
        dis_real = discriminator.forward(image)
        # 실제 데이터와 가짜 데이터에 대한 손실값을 더함 
        dis_loss = torch.sum(loss_func(dis_fake, zeros_label)) + torch.sum(loss_func(dis_real, ones_label))
        
        
        gen_optim.zero_grad()
        
        z = Variable(init.normal(torch.Tensor(batch_size, z_size), mean=0, std=0.1)).cuda()
        gen_fake = generator.forward(z)
        dis_fake = discriminator.forward(gen_fake)
        
        gen_loss = torch.sum(loss_func(dis_fake, ones_label))
        gen_loss.backward()
        gen_optim.step()
        
        
        if j % 100 == 0:
            print(gen_loss, dis_loss)
            torch.save([generator, discriminator], './model/vanilla_gan.pkl')
            
            print("{}th iteration gen_loss: {}".format(i, gen_loss.data, dis_loss.data))
            v_utils.save_image(gen_fake.data[0:25], "./result/gen_{}_{}.png".format(i, j), nrow=5)

  


tensor(0.2766, device='cuda:0', grad_fn=<SumBackward0>) tensor(0.5093, device='cuda:0', grad_fn=<AddBackward0>)
0th iteration gen_loss: 0.2766050100326538
tensor(0.2721, device='cuda:0', grad_fn=<SumBackward0>) tensor(0.5084, device='cuda:0', grad_fn=<AddBackward0>)
0th iteration gen_loss: 0.2720579206943512
tensor(0.2719, device='cuda:0', grad_fn=<SumBackward0>) tensor(0.5089, device='cuda:0', grad_fn=<AddBackward0>)
1th iteration gen_loss: 0.271920382976532
tensor(0.2684, device='cuda:0', grad_fn=<SumBackward0>) tensor(0.5120, device='cuda:0', grad_fn=<AddBackward0>)
1th iteration gen_loss: 0.2683781683444977
tensor(0.2686, device='cuda:0', grad_fn=<SumBackward0>) tensor(0.5122, device='cuda:0', grad_fn=<AddBackward0>)
2th iteration gen_loss: 0.2685858905315399
tensor(0.2648, device='cuda:0', grad_fn=<SumBackward0>) tensor(0.5150, device='cuda:0', grad_fn=<AddBackward0>)
2th iteration gen_loss: 0.26479625701904297
tensor(0.2635, device='cuda:0', grad_fn=<SumBackward0>) tensor(0.5151,