In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

In [None]:
# MNIST Dataset

# MNIST images have a single (grayscale) channel, so Normalize must receive
# exactly one mean and one std value. (x - 0.5) / 0.5 maps the [0, 1] output
# of ToTensor to [-1, 1].
# NOTE: this transform only runs when samples are fetched through the dataset
# object itself (indexing / a DataLoader built on it); reading the raw `.data`
# tensor bypasses it.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))])

# Downloads into ./mnist_data/ on first use and reuses the local copy afterwards.
original_train_dataset = datasets.MNIST(root='./mnist_data/', train=True,
                                        transform=transform, download=True)
original_test_dataset = datasets.MNIST(root='./mnist_data/', train=False,
                                       transform=transform, download=True)

In [None]:
# Hyper-parameters: mini-batch size, number of passes over the training set,
# and the Adam step sizes for the discriminator (D) and the generator (G).
BATCH_SIZE = 100
N_EPOCH = 250
LEARNING_RATE_D = 2e-4
LEARNING_RATE_G = 2e-4

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Define train / test loaders
# The raw pixel tensors are read directly from `.data` (so the dataset's own
# `transform` is not applied here) and rescaled from 0..255 to [-1, 1]:
# the generator ends in tanh, so real images must live in the same range as
# the generated ones, otherwise the discriminator can separate them by value
# range alone.
train_tensors = (original_train_dataset.data.float() / 255 - 0.5) / 0.5
test_tensors = (original_test_dataset.data.float() / 255 - 0.5) / 0.5

# Pair every image tensor with its digit label.
train_dataset = torch.utils.data.TensorDataset(train_tensors, original_train_dataset.targets)
test_dataset = torch.utils.data.TensorDataset(test_tensors, original_test_dataset.targets)

# Only the training data is shuffled; the test order stays fixed.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
class Discriminator(nn.Module):
    """Fully connected binary classifier that scores flattened images.

    Four linear layers (input -> 1024 -> 512 -> 256 -> 1) with LeakyReLU(0.2)
    and dropout(p=0.3) after each hidden layer, followed by a sigmoid.

    Args:
        dis_input_dim: number of features of each flattened input image.
    """

    def __init__(self, dis_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(dis_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    def forward(self, img):
        """Return a (batch, 1) tensor of sigmoid outputs for `img`.

        Args:
            img: tensor whose last dimension equals `dis_input_dim`.
        """
        # F.dropout defaults to training=True; forwarding self.training makes
        # dropout a no-op after `.eval()` so inference is deterministic.
        x = F.leaky_relu(self.fc1(img), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        return torch.sigmoid(self.fc4(x))

class Generator(nn.Module):
    """Fully connected network mapping a noise vector to a flattened image.

    Layer widths grow as input -> 256 -> 512 -> 1024 -> output. Hidden layers
    use LeakyReLU(0.2); the output goes through tanh, so values lie in [-1, 1].

    Args:
        gen_input_dim: size of the input noise vector.
        gen_output_dim: number of features of each generated sample.
    """

    def __init__(self, gen_input_dim, gen_output_dim):
        super().__init__()
        # Layers are created in fc1..fc4 order so parameter initialisation
        # consumes the random stream in the same sequence.
        self.fc1 = nn.Linear(in_features=gen_input_dim, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=512)
        self.fc3 = nn.Linear(in_features=512, out_features=1024)
        self.fc4 = nn.Linear(in_features=1024, out_features=gen_output_dim)

    def forward(self, z):
        """Return generated samples of shape (batch, gen_output_dim) for `z`."""
        hidden = z
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden), negative_slope=0.2)
        return torch.tanh(self.fc4(hidden))

# Length of the random noise vector fed to the generator.
RV_dim = 100

# Flattened image size: product of the second and third dimensions of the
# raw training tensor.
_data_shape = original_train_dataset.data.shape
pic_dim = _data_shape[1] * _data_shape[2]

# The generator emits vectors of the same length the discriminator consumes.
D = Discriminator(dis_input_dim=pic_dim)
G = Generator(RV_dim, pic_dim)

In [None]:
# Show the layer layout of both networks, discriminator first.
print(D, G, sep="\n")

Discriminator(
  (fc1): Linear(in_features=784, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=1, bias=True)
)
Generator(
  (fc1): Linear(in_features=100, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=1024, bias=True)
  (fc4): Linear(in_features=1024, out_features=784, bias=True)
)


In [None]:
# Move both networks onto the selected device (GPU if available, else CPU).
D, G = D.to(device), G.to(device)

In [None]:
# Binary cross-entropy between the discriminator's sigmoid output and the
# real/fake target labels.
loss_function = nn.BCELoss().to(device)

# One Adam optimizer per network, each with its own learning rate.
opt_G = optim.Adam(params=G.parameters(), lr=LEARNING_RATE_G)
opt_D = optim.Adam(params=D.parameters(), lr=LEARNING_RATE_D)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
# Default figure size (inches): wide and short, suited to the 3 x 10 image
# grid drawn after every epoch.
plt.rcParams["figure.figsize"] = (10, 3)

In [None]:
# Adversarial training: per mini-batch, one discriminator update followed by
# one generator update. After every epoch the losses of the last batch are
# printed and up to 30 images from the last generated batch are displayed.
for epoch in range(N_EPOCH):
    for i, (img, label) in enumerate(train_loader):
        n_samples = img.shape[0]

        # ---- Discriminator step ------------------------------------------
        D.zero_grad()

        # Flatten each image to a vector of length pic_dim.
        real_img = img.view(-1, pic_dim).to(device)
        # Targets are allocated directly on the device (no CPU round trip).
        real_labels = torch.ones(n_samples, 1, device=device)

        D_output = D(real_img)
        D_real_loss = loss_function(D_output, real_labels)
        # Kept so the discriminator's output on real data stays inspectable.
        D_real_score = D_output

        z = torch.randn(n_samples, RV_dim, device=device)
        fake_img = G(z)
        fake_labels = torch.zeros(n_samples, 1, device=device)

        # detach(): the discriminator update must not backpropagate into G.
        # Without it every D step also computes (and accumulates) generator
        # gradients that are thrown away by G.zero_grad() below.
        D_output = D(fake_img.detach())
        D_fake_loss = loss_function(D_output, fake_labels)
        # Kept so the discriminator's output on fake data stays inspectable.
        D_fake_score = D_output

        loss_d = D_real_loss + D_fake_loss
        loss_d.backward()
        opt_D.step()

        # ---- Generator step ----------------------------------------------
        G.zero_grad()

        # Fresh noise; the generator is rewarded when D labels its samples
        # as real, hence the all-ones target Y.
        z = torch.randn(n_samples, RV_dim, device=device)
        fake_img = G(z)
        Y = torch.ones(n_samples, 1, device=device)

        D_output = D(fake_img)
        loss_g = loss_function(D_output, Y)

        # This also writes gradients into D; they are cleared by
        # D.zero_grad() at the start of the next iteration.
        loss_g.backward()
        opt_G.step()

    print("epoch: {} \t last batch loss D: {} \t last batch loss G: {}".format(epoch + 1,
                                                                               loss_d.item(),
                                                                               loss_g.item()))

    # Preview grid (3 rows x 10 columns) from the last generated batch.
    # min() guards against a final batch with fewer than 30 samples.
    samples = fake_img.detach().cpu().view(-1, 28, 28).numpy()
    for idx in range(min(30, len(samples))):
        plt.subplot(3, 10, idx + 1)
        plt.imshow(samples[idx])
    plt.show()

