In [None]:
import os
import tqdm
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision import datasets
from torchvision.utils import make_grid

In [None]:
class Generator(nn.Module):
    """Fully connected GAN generator.

    Maps a latent noise vector to a flattened image through four hidden
    Linear layers (128 -> 256 -> 512 -> 1024 units) with LeakyReLU(0.2)
    activations; BatchNorm1d is applied after every hidden layer except the
    first. The final Tanh bounds every output value to [-1, 1].

    Args:
        latent_dim: Length of the input noise vector.
        img_dim: Number of values in one flattened output image. The default
            of 784 corresponds to a 28 x 28 single-channel image.
    """

    def __init__(self, latent_dim=100, img_dim=784):
        super().__init__()
        # The attribute name `layers` and the order of the submodules define
        # the state-dict keys ("layers.0.weight", ...); keep both stable so
        # previously saved checkpoints still load.
        self.layers = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, img_dim),
            nn.Tanh()
        )

    def forward(self, x):
        """Generate flattened images from noise.

        Args:
            x: Tensor of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, img_dim) with values in [-1, 1].

        Note:
            In training mode BatchNorm1d needs more than one sample per
            batch; call ``.eval()`` before feeding single noise vectors.
        """
        return self.layers(x)

In [None]:
class Discriminator(nn.Module):
    """Fully connected GAN discriminator.

    Scores flattened images with a 512 -> 256 -> 1 MLP using LeakyReLU(0.2)
    activations. The final Sigmoid turns the score into a value in (0, 1),
    which is the input range expected by ``nn.BCELoss``.

    Args:
        img_dim: Number of values in one flattened input image. The default
            of 784 corresponds to a 28 x 28 single-channel image.
    """

    def __init__(self, img_dim=784):
        super().__init__()
        # Attribute name and submodule order fix the state-dict keys; keep
        # them stable for checkpoint compatibility.
        self.layers = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Score a batch of flattened images.

        Args:
            x: Tensor of shape (batch, img_dim).

        Returns:
            1-D tensor of length ``batch``; the trailing singleton dimension
            is flattened away so the result lines up with 1-D label tensors.
        """
        scores = self.layers(x)
        return scores.view(-1)

In [None]:
# Run on the CUDA device with index `cuda_id` when PyTorch can see a GPU,
# otherwise fall back to the CPU.
cuda_id = 0
if torch.cuda.is_available():
    device_name = "cuda:" + str(cuda_id)
else:
    device_name = "cpu"
device = torch.device(device_name)

In [None]:
# Training hyperparameters.
epochs = 250        # number of full passes over the training set
batch_size = 128    # samples per optimisation step
lr = 0.0002         # learning rate used for both Adam optimizers
latent_dim = 100    # length of the noise vector fed to the generator

# Seed PyTorch's random number generators so that weight initialisation,
# data shuffling and noise sampling repeat across "Restart & Run All".
# (Some CUDA kernels are non-deterministic, so GPU runs may still differ
# slightly from one another.)
seed = 0
torch.manual_seed(seed)

In [None]:
# Data pipeline: the MNIST training split, downloaded into ./data on first use.
# ToTensor yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 rescales
# them to [-1, 1], the same range as the generator's Tanh output.
# Mean and std are passed as one-element tuples (one entry per channel), the
# form accepted by every torchvision version.
dataloader = data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5,), (0.5,))
                   ])),
    batch_size=batch_size, shuffle=True)

# Build the generator from the configured latent size. Relying on the class
# default would silently diverge from `latent_dim` (which sizes the noise in
# the training loop) as soon as that setting is changed.
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

# One Adam optimizer per network; beta1 is lowered from Adam's default 0.9 to 0.5.
optimizer_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# Binary cross-entropy on the discriminator's sigmoid outputs.
criterion = nn.BCELoss()

In [None]:
# Bare expression: the notebook renders the generator's repr (its module tree)
# as this cell's output.
generator

In [None]:
# Bare expression: the notebook renders the discriminator's repr (its module
# tree) as this cell's output.
discriminator

In [None]:
# Train the GAN. Every iteration performs one discriminator update followed by
# one generator update. The mean loss of each epoch is appended to
# error_d_all / error_g_all, a real-vs-fake preview is drawn after the first
# epoch and after every 20th, and the generator weights are written to
# checkpoints/generator.pkl once training has finished.

# Create the output directories once, before training. (Doing this inside the
# epoch loop repeated the check every epoch and skipped it entirely when
# epochs == 0, which made the final torch.save fail.) Nothing in this cell
# writes into img/; it is still created so the directory layout is unchanged.
os.makedirs('img', exist_ok=True)
os.makedirs('checkpoints', exist_ok=True)

n_batches = len(dataloader)
error_d_all = []
error_g_all = []
for epoch in tqdm.tqdm(range(epochs)):
    error_d_avg = 0
    error_g_avg = 0
    for img, _ in dataloader:
        # Flatten every image to one row; labels from the dataset are unused.
        img = img.view(img.size(0), -1).to(device)
        n_samples = img.shape[0]
        # Allocate targets directly on the training device (no CPU round trip).
        true_labels = torch.ones(n_samples, device=device)
        fake_labels = torch.zeros(n_samples, device=device)

        # ---- Discriminator step: real images -> 1, generated images -> 0 ----
        true_output = discriminator(img)
        error_d_real = criterion(true_output, true_labels)
        noise = torch.randn(n_samples, latent_dim, device=device)
        # The generator is not updated in this step, so skip building its graph.
        with torch.no_grad():
            fake_img = generator(noise)
        fake_output = discriminator(fake_img)
        error_d_fake = criterion(fake_output, fake_labels)
        error_d = error_d_real + error_d_fake
        optimizer_d.zero_grad()
        error_d.backward()
        optimizer_d.step()

        # ---- Generator step: make the discriminator label fresh fakes as real ----
        noise = torch.randn(n_samples, latent_dim, device=device)
        fake_img = generator(noise)
        fake_output = discriminator(fake_img)
        error_g = criterion(fake_output, true_labels)
        optimizer_g.zero_grad()
        # This also leaves gradients on the discriminator's parameters; they are
        # cleared by optimizer_d.zero_grad() before the next discriminator update.
        error_g.backward()
        optimizer_g.step()

        error_d_avg += error_d.item()
        error_g_avg += error_g.item()

    error_d_epoch = error_d_avg / n_batches
    error_g_epoch = error_g_avg / n_batches
    print('Epoch [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'
          .format(epoch + 1, epochs, error_d_epoch, error_g_epoch))
    error_d_all.append(error_d_epoch)
    error_g_all.append(error_g_epoch)

    if epoch == 0 or (epoch + 1) % 20 == 0:
        # Preview: up to 64 images of the epoch's last real batch next to the
        # same number of freshly generated ones.
        # NOTE: the generator is left in training mode here, so BatchNorm
        # normalises with the statistics of this noise batch.
        noise = torch.randn(img.size(0), latent_dim, device=device)
        with torch.no_grad():
            fake = generator(noise)
        img_grid_real = make_grid(img.view(-1, 1, 28, 28)[:64], normalize=True)
        img_grid_fake = make_grid(fake.view(-1, 1, 28, 28)[:64], normalize=True)
        fig, axes = plt.subplots(1, 2)
        panels = ((img_grid_real, 'real images'), (img_grid_fake, 'fake images'))
        for ax, (grid, title) in zip(axes, panels):
            # make_grid returns (C, H, W); imshow expects (H, W, C).
            ax.imshow(grid.permute(1, 2, 0).cpu().numpy())
            ax.set_title(title)
            ax.axis('off')
        plt.show()
        plt.close(fig)  # release the figure so repeated previews do not pile up

# Loss curves. The x-axis follows the number of recorded epochs, so the plot
# stays valid even if the loop above was interrupted and the lists are short.
epochs_run = range(len(error_d_all))
plt.plot(epochs_run, error_d_all, label='error_d')
plt.plot(epochs_run, error_g_all, label='error_g')
plt.legend()
plt.xlabel('epochs')
plt.ylabel('loss')
plt.show()

# Persist only the generator's weights; the discriminator is not saved.
torch.save(generator.state_dict(), 'checkpoints/generator.pkl')

In [None]:
# Reload the trained generator and render a straight-line walk through latent
# space, from the all-zeros vector to the all-ones vector, one figure per step.

# Build the network with the configured latent size so it matches the
# checkpoint even when `latent_dim` differs from the class default.
model = Generator(latent_dim).to(device)
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only
# machine. torch.load unpickles the file, so only open checkpoints you trust.
model.load_state_dict(torch.load('checkpoints/generator.pkl', map_location=device))
# Inference mode: BatchNorm switches to its running statistics, which is
# required here because every forward pass below has batch size 1.
model.eval()

noise0 = torch.zeros(1, latent_dim, device=device)
noise1 = torch.ones(1, latent_dim, device=device)
# n_steps + 1 evenly spaced points: alpha = 0 gives noise0, alpha = 1 gives noise1.
n_steps = 8
noises = [
    (i / n_steps) * noise1 + (1 - i / n_steps) * noise0
    for i in range(n_steps + 1)
]

with torch.no_grad():
    for noise in noises:
        # One generated sample, reshaped from 784 values to a 28 x 28 image.
        fake = model(noise).view(28, 28)
        # Pass a plain 2-D array (accepted by every matplotlib version) and
        # draw it in grayscale, like the previews shown during training.
        plt.imshow(fake.cpu().numpy(), cmap='gray')
        plt.axis('off')
        plt.show()
