Here I replicate the 2014 paper 'Generative Adversarial Nets' to generate handwritten digits using the MNIST dataset.

In [1]:
""" 
Simple GAN using fully connected layers
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard

# Building the discriminator nn
class Discriminator(nn.Module):
    """Fully connected real-vs-fake classifier.

    A single hidden layer of 128 units with LeakyReLU feeds one output unit,
    which is passed through a sigmoid so each score lies in [0, 1].

    Args:
        in_features: Size of each input vector given to the first linear
            layer.
    """

    def __init__(self, in_features):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(in_features, hidden_units),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_units, 1),  # one score per sample (real or fake)
            nn.Sigmoid(),
        ]
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the sigmoid scores."""
        scores = self.disc(x)
        return scores

# Building the generator nn
class Generator(nn.Module):
    """Fully connected generator mapping latent noise to image vectors.

    A single hidden layer of 256 units with LeakyReLU is followed by a linear
    projection to ``img_dim`` values and a Tanh.

    Args:
        z_dim: Dimension of the latent noise vector.
        img_dim: Number of values in each generated output vector.
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        hidden_units = 256
        stack = (
            nn.Linear(z_dim, hidden_units),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_units, img_dim),
            # Training images are normalized to [-1, 1], so the output is
            # squashed to the same range.
            nn.Tanh(),
        )
        self.gen = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the Tanh output."""
        images = self.gen(x)
        return images


# Pick the best available accelerator instead of assuming Apple-silicon MPS,
# so the cell also runs on CUDA and CPU-only machines.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
lr = 3e-4
z_dim = 64  # dimension of the latent noise vector fed to the generator
image_dim = 28 * 28 * 1  # 784
batch_size = 32
num_epochs = 50

disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
# Fixed noise batch reused at every logging step, so TensorBoard shows how the
# same latent vectors evolve as the generator trains.
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
# Bound to its own name so the imported `transforms` module is not overwritten.
mnist_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

dataset = datasets.MNIST(
    root="dataset_GAN/", transform=mnist_transform, download=True
)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)

criterion = nn.BCELoss()  # loss = -w_n * [y_n * log(x_n) + (1 - y_n) * log(1 - x_n)]

writer_fake = SummaryWriter(log_dir="runs/GAN_MNIST/fake")
writer_real = SummaryWriter(log_dir="runs/GAN_MNIST/real")
step = 0

for epoch in range(num_epochs):
    # The MNIST labels are ignored: the GAN is trained without supervision.
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, image_dim).to(device)  # flatten each image
        # Kept separate from `batch_size` so the hyperparameter is not rebound.
        cur_batch_size = real.shape[0]

        #### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(cur_batch_size, z_dim).to(device)
        fake = gen(noise)

        disc_real = disc(real).view(-1)
        # Target 1 --> loss = -log(D(real))
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        # detach() stops the discriminator loss from back-propagating into the
        # generator; those gradients would be zeroed before the generator step
        # anyway, and the generator graph stays intact without retain_graph.
        disc_fake = disc(fake.detach()).view(-1)
        # Target 0 --> loss = -log(1 - D(G(z)))
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        lossD = (lossD_real + lossD_fake) / 2

        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        #### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        # The maximizing form is used because it does not suffer from
        # saturating gradients.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))

        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Log once per epoch, on the first batch.
        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1

# Flush any pending image events to disk once training finishes.
writer_fake.close()
writer_real.close()

Epoch [0/50] Batch 0/1875                       Loss D: 0.7407, loss G: 0.7217
Epoch [1/50] Batch 0/1875                       Loss D: 0.8602, loss G: 0.6461
Epoch [2/50] Batch 0/1875                       Loss D: 0.4275, loss G: 1.4299
Epoch [3/50] Batch 0/1875                       Loss D: 0.2308, loss G: 1.9767
Epoch [4/50] Batch 0/1875                       Loss D: 0.6046, loss G: 0.8696
Epoch [5/50] Batch 0/1875                       Loss D: 0.6486, loss G: 0.9889
Epoch [6/50] Batch 0/1875                       Loss D: 0.8064, loss G: 0.9886
Epoch [7/50] Batch 0/1875                       Loss D: 1.0574, loss G: 0.5108
Epoch [8/50] Batch 0/1875                       Loss D: 0.6067, loss G: 0.9279
Epoch [9/50] Batch 0/1875                       Loss D: 0.4648, loss G: 1.3461
Epoch [10/50] Batch 0/1875                       Loss D: 0.4855, loss G: 1.7373
Epoch [11/50] Batch 0/1875                       Loss D: 0.5832, loss G: 1.3158
Epoch [12/50] Batch 0/1875                       L

### Trying to improve the GAN by adding batch normalization and more layers to our networks.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard

# Building the discriminator nn
class Discriminator(nn.Module):
    """Deeper fully connected real-vs-fake classifier with batch norm.

    Two hidden blocks (Linear -> BatchNorm1d -> LeakyReLU) of widths 256 and
    1024 feed one output unit, which is passed through a sigmoid.

    Args:
        in_features: Size of each input vector given to the first linear
            layer.
    """

    def __init__(self, in_features):
        super().__init__()
        widths = (in_features, 256, 1024)
        layers = []
        # Each consecutive pair of widths becomes one hidden block.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.BatchNorm1d(fan_out))
            layers.append(nn.LeakyReLU(0.01))
        layers.append(nn.Linear(widths[-1], 1))  # one score per sample
        layers.append(nn.Sigmoid())
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the sigmoid scores."""
        scores = self.disc(x)
        return scores

# Building the generator nn
class Generator(nn.Module):
    """Deeper fully connected generator with batch normalization.

    Three hidden blocks (Linear -> BatchNorm1d -> LeakyReLU) of widths 256,
    1024 and 550 are followed by a linear projection to ``img_dim`` values,
    a BatchNorm1d and a Tanh.

    Args:
        z_dim: Dimension of the latent noise vector.
        img_dim: Number of values in each generated output vector.
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.01),
            nn.Linear(256, 1024),
            nn.BatchNorm1d(1024),
            # This hidden block previously had no activation, unlike the
            # other hidden blocks.
            nn.LeakyReLU(0.01),
            nn.Linear(1024, 550),
            nn.BatchNorm1d(550),
            nn.LeakyReLU(0.01),
            nn.Linear(550, img_dim),
            nn.BatchNorm1d(img_dim),
            # No LeakyReLU before Tanh: scaling negative pre-activations by
            # 0.01 would keep the output close to 0 on the negative side, so
            # values near -1 could hardly be produced.
            nn.Tanh(),  # training images are normalized to [-1, 1]
        )

    def forward(self, x):
        """Run ``x`` through the layer stack and return the Tanh output."""
        return self.gen(x)


# Pick the best available accelerator instead of assuming Apple-silicon MPS,
# so the cell also runs on CUDA and CPU-only machines.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
lr = 3e-4
z_dim = 64  # dimension of the latent noise vector fed to the generator
image_dim = 28 * 28 * 1  # 784
batch_size = 32
num_epochs = 50

disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
# Fixed noise batch reused at every logging step, so TensorBoard shows how the
# same latent vectors evolve as the generator trains.
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
# Bound to its own name so the imported `transforms` module is not overwritten.
mnist_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

dataset = datasets.MNIST(
    root="dataset_GAN/", transform=mnist_transform, download=True
)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)

criterion = nn.BCELoss()  # loss = -w_n * [y_n * log(x_n) + (1 - y_n) * log(1 - x_n)]

writer_fake = SummaryWriter(log_dir="runs/GAN_MNIST/fake_2")
writer_real = SummaryWriter(log_dir="runs/GAN_MNIST/real_2")
step = 0

for epoch in range(num_epochs):
    # The MNIST labels are ignored: the GAN is trained without supervision.
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, image_dim).to(device)  # flatten each image
        # Kept separate from `batch_size` so the hyperparameter is not rebound.
        cur_batch_size = real.shape[0]

        #### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(cur_batch_size, z_dim).to(device)
        fake = gen(noise)

        disc_real = disc(real).view(-1)
        # Target 1 --> loss = -log(D(real))
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        # detach() stops the discriminator loss from back-propagating into the
        # generator; those gradients would be zeroed before the generator step
        # anyway, and the generator graph stays intact without retain_graph.
        disc_fake = disc(fake.detach()).view(-1)
        # Target 0 --> loss = -log(1 - D(G(z)))
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        lossD = (lossD_real + lossD_fake) / 2

        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        #### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        # The maximizing form is used because it does not suffer from
        # saturating gradients.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))

        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Log once per epoch, on the first batch.
        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1

# Flush any pending image events to disk once training finishes.
writer_fake.close()
writer_real.close()

### Now we will try changing the architecture to a CNN