# Import The Necessary Libraries

In [1]:
# Pytorch
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.datasets as datasets

# Images -> GIF
import imageio

# Operations
import numpy as np

# Plotting
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import os


# Setting Up the Model Parameters

In [2]:
BATCH_SIZE = 16
EPOCHS = 200
SAMPLE_SIZE = 64
NZ = 128
k = 1
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ]
)
to_pil_image = transforms.ToPILImage()

In [4]:
train_data = datasets.MNIST(root="../input/data/",
                            train=True,
                            download=True,
                            transform=transform)
train_dataloader = DataLoader(train_data,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=os.cpu_count())

In [5]:
class Generator(nn.Module):
    def __init__(self, nz):
        super().__init__()
        self.nz = nz
        self.layer_stack = nn.Sequential(
            nn.Linear(self.nz, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )
    def forward(self, x):
        return self.layer_stack(x).view(-1, 1, 28, 28)

In [21]:
class Discriminator(nn.Module):
    def __init__(self, input):
        super().__init__()
        self.input = 784
        self.layer_stack = nn.Sequential(
            nn.Linear(self.input, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        x = x.view(-1, 784)
        return self.layer_stack(x)

In [22]:
generator = Generator(NZ).to(device)
discriminator = Discriminator(NZ).to(device)

In [23]:
optim_g = optim.Adam(generator.parameters(), 0.0002)
optim_d = optim.Adam(discriminator.parameters(), 0.0002)

In [24]:
loss_fn = nn.BCELoss()

In [25]:
losses_g = []
losses_d = []
images = []

In [26]:
def label_real(size):
    data = torch.ones(size, 1)
    return data.to(device)

In [27]:
def label_fake(size):
    data = torch.zeros(size, 1)
    return data.to(device)

In [28]:
def create_noise(sample_size, nz):
    return torch.randn(sample_size, nz).to(device)

In [29]:
def save_generator_image(image, path):
    save_image(image, path)

In [30]:
def train_discriminator(optimizer, data_real, data_fake):
    b_size = data_real.size(0)
    real_label = label_real(b_size)
    fake_label = label_fake(b_size)

    optimizer.zero_grad()
    real_preds = discriminator(data_real)
    real_loss = loss_fn(real_preds, real_label)

    fake_preds = discriminator(data_fake)
    fake_loss = loss_fn(fake_preds, fake_label)

    real_loss.backward()
    fake_loss.backward()
    optimizer.step()

    return real_loss + fake_loss

In [31]:
def train_generator(optimizer, data_fake):
    b_size = data_fake.size(0)
    real_label = label_real(b_size)

    optimizer.zero_grad()
    preds = discriminator(data_fake)
    loss = loss_fn(preds, real_label)

    loss.backward()
    optimizer.step()
    return loss

In [32]:
noise = create_noise(SAMPLE_SIZE, NZ)

In [33]:
noise[0].shape

torch.Size([128])

In [34]:
generator.train()
discriminator.train()

Discriminator(
  (layer_stack): Sequential(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Linear(in_features=1024, out_features=512, bias=True)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Linear(in_features=512, out_features=256, bias=True)
    (5): LeakyReLU(negative_slope=0.2)
    (6): Linear(in_features=256, out_features=1, bias=True)
    (7): Sigmoid()
  )
)

In [35]:
for epoch in range(EPOCHS):
    loss_g = 0.0
    loss_d = 0.0
    for bi, data in tqdm(enumerate(train_dataloader)):
        image, _  = data
        image = image.to(device)
        b_size = len(image)
        for step in range(k):
            data_fake = generator(create_noise(b_size, NZ)).detach()
            data_real = image
            loss_d += train_discriminator(optimizer=optim_d, data_fake=data_fake, data_real=data_real)
        data_fake = generator(create_noise(b_size, NZ))
        loss_g += train_generator(optimizer=optim_g, data_fake=data_fake)
    generated_img = generator(noise).cpu().detach()
    generated_img = make_grid(generated_img)
    save_generator_image(generated_img,  f"../outputs/gen_img{epoch}.png")
    images.append(generated_img)
    epoch_loss_g = loss_g / bi
    epoch_loss_d = loss_d / bi
    losses_g.append(epoch_loss_g)
    losses_d.append(epoch_loss_d)
    print(f"Epoch {epoch} of {EPOCHS}")
    print(f"Generator loss: {epoch_loss_g:.8f}, Discriminator loss: {epoch_loss_d:.8f}")

3750it [00:47, 79.69it/s] 


Epoch 0 of 200
Generator loss: 2.61756182, Discriminator loss: 0.95613670


3750it [00:46, 80.94it/s] 


Epoch 1 of 200
Generator loss: 2.61412835, Discriminator loss: 0.67624497


3750it [00:40, 92.38it/s] 


Epoch 2 of 200
Generator loss: 2.36129045, Discriminator loss: 0.69832164


3750it [00:40, 92.96it/s] 


Epoch 3 of 200
Generator loss: 1.78526974, Discriminator loss: 0.84575677


3750it [00:38, 96.97it/s] 


Epoch 4 of 200
Generator loss: 1.68933558, Discriminator loss: 0.88962138


3750it [00:39, 95.79it/s] 


Epoch 5 of 200
Generator loss: 1.62609541, Discriminator loss: 0.89697891


3750it [00:38, 98.08it/s] 


Epoch 6 of 200
Generator loss: 1.42942250, Discriminator loss: 0.97756666


3750it [00:39, 95.26it/s] 


Epoch 7 of 200
Generator loss: 1.36320627, Discriminator loss: 1.01598930


3750it [00:39, 95.36it/s] 


Epoch 8 of 200
Generator loss: 1.37778091, Discriminator loss: 1.00625420


3750it [00:46, 81.29it/s] 


Epoch 9 of 200
Generator loss: 1.27129710, Discriminator loss: 1.05908287


3750it [00:43, 86.04it/s] 


Epoch 10 of 200
Generator loss: 1.24002755, Discriminator loss: 1.07696033


0it [00:00, ?it/s]

In [None]:
imgs = [np.array(to_pil_image(img)) for img in images]
imageio.mimsave('../outputs/generator_images.gif', imgs)

In [None]:
plt.figure()
plt.plot(losses_g, label='Generator loss')
plt.plot(losses_d, label='Discriminator Loss')
plt.legend()
plt.savefig('../outputs/loss.png')