In [5]:
# Import all the necessary libraries
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
import matplotlib.pyplot as plt
from tqdm import notebook

In [1]:
# Define our simple vanilla generator
class Generator(nn.Module):
    """
    Vanilla MLP generator: maps a latent tensor to a single-channel image.

    Architecture
    ------------
    Latent Input: latent_shape
    Flattened
    Linear MLP(256, 512, 1024, prod(img_shape))

    Leaky Relu activation (slope 0.2) after every layer except last. (Important!)
    Tanh activation after last layer to normalize the output into [-1, 1]

    Parameters
    ----------
    latent_shape : sequence of int
        Shape of one latent sample; its product is the input width of the first Linear layer.
    img_shape : sequence of int
        Shape of one generated image without the channel axis; its product is the
        output width of the last Linear layer.
    """
    def __init__(self, latent_shape, img_shape):
        super(Generator, self).__init__()
        self.img_shape = img_shape
        # np.prod returns a NumPy scalar; cast so the Linear layers store plain Python ints
        latent_dim = int(np.prod(latent_shape))
        img_dim = int(np.prod(img_shape))
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, img_dim),
            nn.Tanh()
        )
    def forward(self, x):
        """
        Generate one image per latent sample in `x`.

        `x` is flattened from dim 1 onwards by the MLP, so any trailing shape whose
        product equals prod(latent_shape) is accepted.
        Returns a tensor of shape (batch, 1, *img_shape).
        """
        batch_size = x.shape[0]
        # reshape the flat MLP output into an image with an explicit channel axis of 1
        return self.mlp(x).reshape(batch_size, 1, *self.img_shape)

KeyboardInterrupt: 

In [7]:
# Define our simple vanilla discriminator
class Discriminator(nn.Module):
    """
    Vanilla MLP discriminator: scores an image as real (towards 1) or fake (towards 0).

    Architecture
    ------------
    Input Image: img_shape
    Flattened
    Linear MLP(1024, 512, 256, 1)
    Leaky Relu activation (slope 0.2) after every layer except last.
    Sigmoid activation after last layer to normalize in range 0 to 1

    Parameters
    ----------
    img_shape : sequence of int
        Shape of one input image; its product is the input width of the first Linear layer.
    """
    def __init__(self, img_shape):
        super(Discriminator, self).__init__()

        # np.prod returns a NumPy scalar; cast so the Linear layer stores a plain Python int
        img_dim = int(np.prod(img_shape))
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(img_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        """
        Score a batch of images.

        `x` is flattened from dim 1 onwards, so a channel axis of size 1 is accepted.
        Returns a tensor of shape (batch, 1) with values in [0, 1].
        """
        return self.mlp(x)

In [None]:
# load our data
latent_shape = (28, 28)
img_shape = (28, 28)
batch_size = 64

# ToTensor yields pixels in [0, 1]; Normalize with mean 0.5 / std 0.5 rescales them to
# [-1, 1], the same range as the generator's Tanh output.
# `transforms` is referenced through the already-imported `torchvision` package.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.5,), std=(0.5,))])
train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
# shuffled mini-batches; the final batch of an epoch may hold fewer than batch_size samples
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# build both networks and place them on the chosen device right away
generator = Generator(latent_shape, img_shape).to(device)
discriminator = Discriminator(img_shape).to(device)

# one Adam optimizer per network, both with the same learning rate
gen_optim = torch.optim.Adam(generator.parameters(), lr=2e-4)
disc_optim = torch.optim.Adam(discriminator.parameters(), lr=2e-4)

In [None]:
def train(generator, discriminator, generator_optim: torch.optim.Optimizer, discriminator_optim: torch.optim.Optimizer, epochs=10):
    """
    Alternately train the generator and the discriminator with binary cross-entropy.

    For every batch the generator is updated first (it tries to make the discriminator
    output 1 on fake images), then the discriminator is updated on the real batch
    (target 1) and on the same fake images detached from the generator (target 0).

    Reads the notebook-level names `train_dataloader`, `device` and `latent_shape`.

    Parameters
    ----------
    generator, discriminator : nn.Module
        The two networks; expected to already live on `device`.
    generator_optim, discriminator_optim : torch.optim.Optimizer
        Optimizers bound to the generator / discriminator parameters respectively.
    epochs : int
        Number of full passes over `train_dataloader`.
    """
    adversarial_loss = torch.nn.BCELoss()

    for epoch in range(1, epochs+1):
        print("Epoch {}".format(epoch))
        g_loss_sum = 0.0
        d_loss_sum = 0.0
        num_batches = 0
        pbar = notebook.tqdm(train_dataloader, total=len(train_dataloader))
        for data in pbar:
            num_batches += 1
            # first element of each batch is the image tensor
            real_images = data[0].to(device)
            # Use the size of this real batch, not the nominal batch size: the last batch
            # of an epoch can be smaller, and BCELoss requires prediction and target
            # tensors of identical shape.
            current_batch_size = real_images.shape[0]

            ### Train Generator ###
            generator_optim.zero_grad()

            latent_input = torch.randn((current_batch_size, 1, *latent_shape), device=device)
            fake_images = generator(latent_input)

            fake_res = discriminator(fake_images)

            # the generator wants its fakes to be classified as real (label 1)
            generator_loss = adversarial_loss(fake_res, torch.ones_like(fake_res))
            generator_loss.backward()
            generator_optim.step()

            ### Train Discriminator ###
            # also clears the discriminator gradients accumulated by the generator's backward pass
            discriminator_optim.zero_grad()

            real_res = discriminator(real_images)

            # detach so the discriminator update does not backpropagate into the generator
            fake_res = discriminator(fake_images.detach())

            discriminator_real_loss = adversarial_loss(real_res, torch.ones_like(real_res))
            # the target is shaped after fake_res itself, never after the real predictions
            discriminator_fake_loss = adversarial_loss(fake_res, torch.zeros_like(fake_res))
            discriminator_loss = (discriminator_real_loss + discriminator_fake_loss) / 2
            discriminator_loss.backward()
            discriminator_optim.step()

            g_loss_sum += generator_loss.item()
            d_loss_sum += discriminator_loss.item()
            pbar.set_postfix({"G_loss": generator_loss.item(), "D_loss": discriminator_loss.item()})

        # report the per-epoch averages; guard against an empty dataloader
        if num_batches:
            print("Avg G_loss: {:.4f} | Avg D_loss: {:.4f}".format(
                g_loss_sum / num_batches, d_loss_sum / num_batches))

In [None]:
# Run the adversarial training loop for the generator and the discriminator.
# Note: do not expect both losses to go down at the same time. The two models compete
# against each other, so at any moment one of them may be doing better than the other.
train(
    generator=generator,
    discriminator=discriminator,
    generator_optim=gen_optim,
    discriminator_optim=disc_optim,
)

In [None]:
# test it out!
# Sampling only needs a forward pass, so skip autograd bookkeeping.
with torch.no_grad():
    latent_input = torch.randn((batch_size, 1, *latent_shape))
    test = generator(latent_input.to(device))
# Show the first generated sample. Reshape with img_shape instead of a hard-coded 28x28
# so the cell keeps working if the configured image size changes; the image is
# single-channel, so draw it in grayscale.
plt.imshow(test[0].reshape(*img_shape).cpu().numpy(), cmap="gray");