In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
%matplotlib inline

In [None]:
from torchvision import datasets, models, transforms, utils

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import scipy.misc as msc
from PIL import Image

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
batch_size = 64
lr = 0.001

# Reproducibility: seed PyTorch's RNGs (on all devices) before any weights are
# initialised, noise is drawn, or the DataLoader shuffles.
seed = 0
torch.manual_seed(seed)

D_ent = 100            # presumably the latent noise size (Generator's fc1 takes 100 inputs)
D_length = 28          # side length of the square MNIST images, in pixels
D_img = D_length ** 2  # flattened image size (784)
D_hidden = 28          # NOTE(review): not referenced anywhere else in this notebook

In [None]:
# ToTensor scales pixels to [0, 1]. Normalize(0.5, 0.5) then maps them to
# [-1, 1], the same range as the Generator's tanh output. Without it the
# Discriminator could separate real from fake images by pixel range alone.
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

In [None]:
# MNIST training split. download=True fetches the files into ../data/ on the
# first run and is skipped once they exist, so a fresh checkout can
# Restart & Run All.
data = datasets.MNIST(root='../data/', train=True, transform=trans, download=True)

In [None]:
# Shuffled mini-batches of MNIST. drop_last=True discards the final partial
# batch, so every batch has exactly `batch_size` images and matches the
# fixed-size label tensors used in training.
data_loader = torch.utils.data.DataLoader(
    data,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [None]:
def mnist():
    """Return one batch of MNIST images, flattened to shape (batch_size, D_img).

    A new iterator over `data_loader` is created on every call, so each call
    returns the first batch of a fresh shuffle. Labels are discarded.
    """
    batch_images, _labels = next(iter(data_loader))
    return batch_images.view(batch_size, D_img)

In [None]:
# Preview one real training batch as an image grid.
# - view() replaces the deprecated `.data.resize_()` in-place reshape.
# - torchvision's ToPILImage replaces `scipy.misc.toimage`, which was removed
#   in SciPy 1.2.
# - normalize=True rescales the grid to [0, 1] by its min/max, so the preview
#   is correct whether the dataset transform yields pixels in [0, 1] or [-1, 1].
inputs = mnist().view(batch_size, 1, D_length, D_length)
out = utils.make_grid(inputs, normalize=True)
transforms.ToPILImage()(out)

In [None]:
class Discriminator(nn.Module):
    """Small CNN that scores 1x28x28 images with a probability of being real.

    Two conv -> max-pool -> ReLU stages (spatial 28 -> 12 -> 4) feed a
    two-layer MLP. The MLP ends in a sigmoid, so ``forward`` returns one
    value in (0, 1) per image, as a 1-D tensor of shape (N,).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # 20 channels * 4 * 4 spatial positions after the two conv/pool stages.
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        """Map x of shape (N, 1, 28, 28) to real-probabilities of shape (N,)."""
        # Stage 1: conv -> (N, 10, 24, 24), pool -> (N, 10, 12, 12), ReLU.
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        # Stage 2: conv -> (N, 20, 8, 8), channel dropout, pool -> (N, 20, 4, 4), ReLU.
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # flatten(1) keeps the batch dimension. view(-1, 320) would silently
        # regroup samples across the batch when the input is not 28x28;
        # flatten(1) makes fc1 fail loudly instead.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # torch.sigmoid replaces the deprecated F.sigmoid; (N, 1) -> (N,).
        return torch.sigmoid(x).view(-1)

In [None]:
class Generator(nn.Module):
    """MLP generator mapping latent noise vectors to single-channel square images.

    The output is bounded to [-1, 1] by a final tanh and reshaped to
    (N, 1, img_size, img_size). The defaults reproduce the original
    100 -> 256 -> 256 -> 784 network for 28x28 MNIST digits.

    Args:
        latent_dim: length of each input noise vector (default 100).
        hidden_dim: width of the two hidden layers (default 256).
        img_size: side length of the square output image (default 28).
    """

    def __init__(self, latent_dim=100, hidden_dim=256, img_size=28):
        super().__init__()
        self.img_size = img_size
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, img_size * img_size)

    def forward(self, x):
        """Map noise of shape (N, latent_dim) to images of shape (N, 1, img_size, img_size)."""
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        # torch.tanh replaces the deprecated F.tanh; it bounds pixels to [-1, 1].
        x = torch.tanh(self.fc3(x))
        return x.view(-1, 1, self.img_size, self.img_size)

In [None]:
# Build both networks (Discriminator first, then Generator) and move their
# weights onto `device`.
D, G = Discriminator().to(device), Generator().to(device)

In [None]:
# Binary cross-entropy on the Discriminator's sigmoid outputs, averaged over
# the batch ('mean' is BCELoss's default reduction, spelled out here).
criterion = nn.BCELoss(reduction='mean')

In [None]:
# One Adam optimizer per network, both using the shared learning rate `lr`.
D_opt, G_opt = (optim.Adam(net.parameters(), lr=lr) for net in (D, G))

In [None]:
# Fixed BCE targets for one full batch: 1 = real, 0 = fake. The loader uses
# drop_last=True, so every batch has exactly batch_size images.
ones = torch.ones(batch_size, device=device)
zeros = torch.zeros(batch_size, device=device)

In [None]:
# Training bookkeeping:
#   step         - global optimisation-step counter (not reset per epoch)
#   max_epoch    - number of passes over the training set
#   dis_step     - the Generator is updated once every `dis_step` steps
#   log_interval - print losses every `log_interval` steps
step, max_epoch, dis_step, log_interval = 0, 100, 3, 50

In [None]:
# Alternating GAN training. Every step updates D on one real and one fake
# batch; G is updated once every `dis_step` steps.
G_loss_value = float('nan')  # logged until G has been updated at least once
for epoch in range(max_epoch):
    for images, _ in data_loader:
        step += 1

        # Discriminator update: push D(real) toward 1 and D(G(z)) toward 0.
        x = images.to(device)
        x_loss = criterion(D(x), ones)
        z = torch.randn(batch_size, D_ent, device=device)
        # detach(): this loss only trains D, so don't back-propagate through G
        # (those gradients would be discarded by G.zero_grad() anyway).
        z_loss = criterion(D(G(z).detach()), zeros)
        D_loss = x_loss + z_loss
        D.zero_grad()
        D_loss.backward()
        D_opt.step()

        # Generator update every `dis_step` steps: push D(G(z)) toward 1.
        if step % dis_step == 0:
            z = torch.randn(batch_size, D_ent, device=device)
            G_loss = criterion(D(G(z)), ones)
            G.zero_grad()
            G_loss.backward()
            G_opt.step()
            G_loss_value = G_loss.item()

        if step % log_interval == 0:
            # epoch + 1 so the last epoch reads "100/100" rather than "99/100".
            print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(
                epoch + 1, max_epoch, step, D_loss.item(), G_loss_value))

    # Show generated samples once per epoch. The previous code called
    # scipy.misc.toimage (removed in SciPy 1.2) and discarded its result, so no
    # sample was ever displayed. Once per epoch keeps the notebook output small.
    with torch.no_grad():
        samples = G(torch.randn(batch_size, D_ent, device=device))
    grid = utils.make_grid(samples, normalize=True)  # rescale tanh output to [0, 1]
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.title('Generated samples after epoch {}'.format(epoch + 1))
    plt.axis('off')
    plt.show()