# 1. Import required Libraries

In [1]:
import torch
import torch.nn as nn
import torch.utils as utils
import torch.nn.init as init
from torch.autograd import Variable
import torchvision.utils as v_utils
import torchvision.datasets as dset
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict

# 2. Hyperparameters

In [2]:
epoch = 50
batch_size = 512
learning_rate = 0.0002
num_gpus = 1

# 3. Data Setting

In [3]:
mnist_train = dset.MNIST("./data", train=True, 
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.5,), (0.5,)),
                        ]),
                        target_transform=None,
                        download=True)

train_loader = torch.utils.data.DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)


# 4. Generator

In [4]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.layer1 = nn.Sequential(
                    nn.Linear(100, 7*7*256),
                    nn.BatchNorm1d(7*7*256),
                    nn.ReLU(),
        )
        self.layer2 = nn.Sequential(OrderedDict([
            ('conv1', nn.ConvTranspose2d(256, 128, 3, 2, 1, 1)),
            ('bn1', nn.BatchNorm2d(128)),
            ('relu1', nn.LeakyReLU()),
            ('conv2', nn.ConvTranspose2d(128, 64, 3, 1, 1)),
            ('bn2', nn.BatchNorm2d(64)),
            ('relu2', nn.LeakyReLU()),
        ]))
        self.layer3 = nn.Sequential(OrderedDict([
            ('conv3', nn.ConvTranspose2d(64, 16, 3, 1, 1)),
            ('bn3', nn.BatchNorm2d(16)),
            ('relu', nn.LeakyReLU()),
            ('conv4', nn.ConvTranspose2d(16, 1, 3, 2, 1, 1)),
            ('relu4', nn.Tanh())
        ]))
  
    def forward(self, z):
        out = self.layer1(z)
        out = out.view(batch_size // num_gpus, 256, 7, 7)
        out = self.layer2(out)
        out = self.layer3(out)
        return out
    

# 5. Discriminator

In [5]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.layer1 = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(1, 8, 3, padding=1)),
            #('bn1', nn.BatchNorm2d(8)),
            ('relu1', nn.LeakyReLU()),
            ('conv2', nn.Conv2d(8, 16, 3, padding=1)),
            ('bn2', nn.BatchNorm2d(16)),
            ('relu2', nn.LeakyReLU()),
            ('max1', nn.MaxPool2d(2, 2))
        ]))
        self.layer2 = nn.Sequential(OrderedDict([
            ('conv3', nn.Conv2d(16, 32, 3, padding=1)),
            ('bn3', nn.BatchNorm2d(32)),
            ('relu3', nn.LeakyReLU()),
            ('max2', nn.MaxPool2d(2, 2)),
            ('conv4', nn.Conv2d(32, 64, 3, padding=1)),
            ('bn4', nn.BatchNorm2d(64)),
            ('relu4', nn.LeakyReLU())
        ]))
        self.fc = nn.Sequential(
            nn.Linear(64*7*7, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(batch_size // num_gpus, -1)
        out = self.fc(out)
        return out

# 6. Put instances on Multi-gpu

In [6]:
generator = nn.DataParallel(Generator()).cuda()
discriminator = nn.DataParallel(Discriminator()).cuda()


# 7. Check layers

In [7]:
gen_params = generator.state_dict().keys()
dis_params = discriminator.state_dict().keys()

for i in gen_params:
    print(i)

module.layer1.0.weight
module.layer1.0.bias
module.layer1.1.weight
module.layer1.1.bias
module.layer1.1.running_mean
module.layer1.1.running_var
module.layer1.1.num_batches_tracked
module.layer2.conv1.weight
module.layer2.conv1.bias
module.layer2.bn1.weight
module.layer2.bn1.bias
module.layer2.bn1.running_mean
module.layer2.bn1.running_var
module.layer2.bn1.num_batches_tracked
module.layer2.conv2.weight
module.layer2.conv2.bias
module.layer2.bn2.weight
module.layer2.bn2.bias
module.layer2.bn2.running_mean
module.layer2.bn2.running_var
module.layer2.bn2.num_batches_tracked
module.layer3.conv3.weight
module.layer3.conv3.bias
module.layer3.bn3.weight
module.layer3.bn3.bias
module.layer3.bn3.running_mean
module.layer3.bn3.running_var
module.layer3.bn3.num_batches_tracked
module.layer3.conv4.weight
module.layer3.conv4.bias


# 8. Set Loss function & Optimizer

In [8]:
loss_func = nn.MSELoss()
gen_optim = torch.optim.Adam(generator.parameters(), lr= 5*learning_rate, betas=(0.5, 0.999))
dis_optim = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

ones_label = Variable(torch.ones(batch_size, 1)).cuda()
zeros_label = Variable(torch.zeros(batch_size, 1)).cuda()

def image_check(gen_fake):
    img = gen_fake.data.numpy()
    for i in range(10):
        plt.imshow(img[i][0], cmap='gray')
        plt.show()

# 9. Restore Model

In [9]:
try:
    generator, discriminator, = torch.load("./model/dcgan.pkl")
    print("\n-----------model restored----------------\n")
except:
    print("\n-----------model not restored--------------\n")
    pass



-----------model not restored--------------



# 10. Train Model

In [12]:
for i in range(epoch):
    for j, (image, label) in enumerate(train_loader):
        image = Variable(image).cuda()
        
        gen_optim.zero_grad()
        
        z = Variable(init.normal(torch.Tensor(batch_size, 100), mean=0, std=0.1)).cuda()
        gen_fake = generator.forward(z)
        dis_fake = discriminator.forward(gen_fake)
        
        gen_loss = torch.sum(loss_func(dis_fake, ones_label))
        gen_loss.backward()
        gen_optim.step()
        
        dis_optim.zero_grad()
        
        z = Variable(init.normal(torch.Tensor(batch_size, 100), mean=0, std=0.1)).cuda()
        gen_fake = generator.forward(z)
        dis_fake = discriminator.forward(gen_fake)
        
        dis_real = discriminator.forward(image)
        dis_loss = torch.sum(loss_func(dis_fake, zeros_label)) + torch.sum(loss_func(dis_real, ones_label))
        dis_loss.backward()
        dis_optim.step()
        
        if j % 50 == 0:
            print(gen_loss, dis_loss)
            torch.save([generator, discriminator], "./model/dcgan.pkl")
            
            print("{}th iteration gen_loss: {} dis_loss: {}".format(i, gen_loss.data, dis_loss.data))
            v_utils.save_image(gen_fake.data[0:25], "./result/gen_{}_{}.png".format(i, j), nrow=5)

  import sys


tensor(0.1965, device='cuda:0', grad_fn=<SumBackward0>) tensor(0.6035, device='cuda:0', grad_fn=<AddBackward0>)
0th iteration gen_loss: 0.19651813805103302 dis_loss: 0.6035259366035461


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


tensor(0.7669, device='cuda:0', grad_fn=<SumBackward0>) tensor(0.0745, device='cuda:0', grad_fn=<AddBackward0>)
0th iteration gen_loss: 0.766917884349823 dis_loss: 0.07447990030050278
tensor(0.5292, device='cuda:0', grad_fn=<SumBackward0>) tensor(0.1864, device='cuda:0', grad_fn=<AddBackward0>)
0th iteration gen_loss: 0.5292307734489441 dis_loss: 0.1863524168729782
tensor(0.5764, device='cuda:0', grad_fn=<SumBackward0>) tensor(0.1831, device='cuda:0', grad_fn=<AddBackward0>)
1th iteration gen_loss: 0.5764449834823608 dis_loss: 0.1830945760011673
tensor(0.5387, device='cuda:0', grad_fn=<SumBackward0>) tensor(0.1800, device='cuda:0', grad_fn=<AddBackward0>)
1th iteration gen_loss: 0.5387470126152039 dis_loss: 0.18000206351280212
tensor(0.2675, device='cuda:0', grad_fn=<SumBackward0>) tensor(0.3146, device='cuda:0', grad_fn=<AddBackward0>)
1th iteration gen_loss: 0.2674868404865265 dis_loss: 0.31460481882095337
tensor(0.6083, device='cuda:0', grad_fn=<SumBackward0>) tensor(0.2719, device=