In [None]:
cd drive/My Drive/Colab Notebooks/Machine Learning/Plots/GAN_Training

/content/drive/My Drive/Colab Notebooks/Machine Learning/Plots/GAN_Training


In [None]:
# Based on: https://debuggercafe.com/generating-mnist-digit-images-using-vanilla-gan-with-pytorch/

import os

import imageio
import matplotlib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image
from tqdm import tqdm
matplotlib.style.use('ggplot')

# learning parameters
batch_size = 512
epochs = 300
sample_size = 64 # number of images in the fixed per-epoch sample grid
nz = 32 # latent vector size (was 128)
k = 4 # number of discriminator steps per generator step
seed = 0 # torch RNG seed: makes weight init, shuffling, dropout and noise reproducible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(seed)

# ToTensor yields pixels in [0, 1]; Normalize((0.5,), (0.5,)) maps them to [-1, 1],
# the same range as the generator's final Tanh.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
to_pil_image = transforms.ToPILImage()

# root is relative to the working directory set by the `cd` in the first cell.
train_data = datasets.MNIST(
    root='../../data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

class Generator(nn.Module):
    """MLP generator mapping latent vectors to 1x28x28 images in [-1, 1].

    Hidden widths 256 -> 512 -> 1024, each followed by LeakyReLU(0.2), then a
    784-unit Linear + Tanh reshaped to (N, 1, 28, 28).

    Layer order is fixed so the state_dict keys are main.{0,2,4,6}.weight/bias;
    checkpoints saved from this architecture load unchanged.

    Args:
        nz: size of the latent input vector.
    """
    def __init__(self, nz):
        super().__init__()
        self.nz = nz
        layers = []
        in_features = self.nz
        for out_features in (256, 512, 1024):
            layers += [nn.Linear(in_features, out_features), nn.LeakyReLU(0.2)]
            in_features = out_features
        layers += [nn.Linear(in_features, 784), nn.Tanh()]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        flat = self.main(x)
        return flat.view(-1, 1, 28, 28)

class Discriminator(nn.Module):
    """MLP discriminator: flattens each input to 784 features and outputs P(real).

    Three hidden blocks (Linear -> LeakyReLU(0.2) -> Dropout(0.3)) of width
    1024, 512 and 256, then Linear -> Sigmoid down to one score per sample.

    Layer order is fixed so the state_dict keys are main.{0,3,6,9}.weight/bias;
    checkpoints saved from this architecture load unchanged.
    """
    def __init__(self):
        super().__init__()
        self.n_input = 784
        layers = []
        in_features = self.n_input
        for out_features in (1024, 512, 256):
            layers.extend([
                nn.Linear(in_features, out_features),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
            ])
            in_features = out_features
        layers.extend([nn.Linear(in_features, 1), nn.Sigmoid()])
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        flat = x.view(-1, self.n_input)
        return self.main(flat)
convolutional = False  # set True only when the convolutional Generator/Discriminator below are enabled; selects the *_conv checkpoint and output paths


# # Convolutional Configuration
# class Generator(nn.Module):
#     def __init__(self, nz):
#         super(Generator, self).__init__()
#         self.nz = nz
#         self.generator1 = nn.Sequential(
#             nn.Linear(self.nz, 64),
#             nn.LeakyReLU(0.2),
#             nn.Linear(64, 512),
#             nn.LeakyReLU(0.2),
#             nn.Linear(512, 1152),
#             nn.BatchNorm1d(1152),
#         )
#         self.generator2 = nn.Sequential(
#             nn.ConvTranspose2d(128,64, kernel_size=3, stride=2, bias=False),
#             nn.BatchNorm2d(64),
#             nn.LeakyReLU(0.2),
#             nn.ConvTranspose2d(64,32, kernel_size=4, stride=2, padding=1, bias=False),
#             nn.BatchNorm2d(32),
#             nn.LeakyReLU(0.2),
#             nn.ConvTranspose2d(32,1, kernel_size=4, stride=2, padding=1, bias=False),
#             nn.Tanh()
#         )
#     def forward(self, x):
#       x = self.generator1(x)
#       x = x.view(-1,128,3,3)
#       x = self.generator2(x)
#       return x

# class Discriminator(nn.Module):
#     def __init__(self):
#         super(Discriminator, self).__init__()
#         self.n_input = 784
#         self.discrim1 = nn.Sequential(
#           nn.Conv2d(1,32, kernel_size=3, stride=2, bias=False),
#           nn.BatchNorm2d(32),
#           nn.LeakyReLU(0.2),
#         )
#         self.discrim2 = nn.Sequential(
#           nn.Conv2d(32,64, kernel_size=3, stride=3, bias=False),
#           nn.BatchNorm2d(64),
#           nn.LeakyReLU(0.2),
#         )
#         self.discrim3 = nn.Sequential(
#           nn.Conv2d(64,128, kernel_size=3, stride=3, bias=False),
#           nn.BatchNorm2d(128),
#           nn.LeakyReLU(0.2),
#         )
#         self.discrim4 = nn.Sequential(
#           nn.Linear(1152,512),
#           nn.LeakyReLU(0.2),
#           nn.Linear(512,64),
#           nn.LeakyReLU(0.2),
#           nn.Linear(64,16),
#           nn.LeakyReLU(0.2),
#           nn.Linear(16,1),
#           nn.Sigmoid(),
#         )
#     def forward(self, x):
#       x = F.pad(x, (15, 14, 15, 14),)
#       x = self.discrim1(x)
#       x = self.discrim2(x)
#       x = self.discrim3(x)
#       x = x.view(-1,1152)
#       x = self.discrim4(x)
#       return x
# convolutional = 1


################################################
generator = Generator(nz).to(device)
discriminator = Discriminator().to(device)

# Resume from previously saved state_dicts when they exist. The '_conv' suffix
# matches the paths the save step uses for the convolutional variant, so one
# architecture never tries to load the other's weights. On a fresh runtime
# there is nothing to load and training starts from random initialisation.
ckpt_suffix = '_conv' if convolutional else ''
for net, ckpt_name in ((generator, 'generator'), (discriminator, 'discriminator')):
    ckpt_path = f'models/{ckpt_name}{ckpt_suffix}.pth'
    if os.path.exists(ckpt_path):
        net.load_state_dict(torch.load(ckpt_path, map_location=device))
        print(f'Loaded {ckpt_path}')
    else:
        print(f'No checkpoint at {ckpt_path}; starting from random init')

print('##### GENERATOR #####')
print(generator)
print('######################')
print('\n##### DISCRIMINATOR #####')
print(discriminator)
print('######################')

# Adam for both networks with the same (small) learning rate of 2e-5.
optim_g = optim.Adam(generator.parameters(), lr=2e-5)
optim_d = optim.Adam(discriminator.parameters(), lr=2e-5)

# Binary cross-entropy on the discriminator's sigmoid probability.
criterion = nn.BCELoss()

# Per-epoch history filled by the training loop:
# generator losses, discriminator losses, and generated sample grids.
losses_g, losses_d, images = [], [], []

def label_real(size):
    """Return a (size, 1) tensor of ones on `device`: BCE targets meaning "real"."""
    return torch.ones(size, 1, device=device)
def label_fake(size):
    """Return a (size, 1) tensor of zeros on `device`: BCE targets meaning "fake"."""
    return torch.zeros(size, 1, device=device)

def create_noise(sample_size, nz):
    """Draw `sample_size` standard-normal latent vectors of length `nz`, moved to `device`.

    The sample is drawn with the default (CPU) generator and then transferred.
    """
    latent = torch.randn(sample_size, nz)
    return latent.to(device)

def save_generator_image(image, path):
    """Write an image tensor (here: a make_grid sample grid) to `path` via torchvision's save_image."""
    save_image(image, path)

def train_generator(optimizer, data_fake):
    """Run one generator update.

    Scores `data_fake` with the module-level `discriminator` and minimises
    `criterion` against "real" (all-ones) targets, so the generator is pushed
    toward samples the discriminator accepts as real.

    Args:
        optimizer: optimizer over the generator's parameters (the loop passes optim_g).
        data_fake: generator output that is still attached to the generator's graph.

    Returns:
        The scalar loss tensor for this step.
    """
    optimizer.zero_grad()
    targets = label_real(data_fake.size(0))
    scores = discriminator(data_fake)
    loss = criterion(scores, targets)
    # Gradients also reach the discriminator's parameters, but only the
    # parameters held by `optimizer` are stepped here.
    loss.backward()
    optimizer.step()
    return loss

def train_discriminator(optimizer, data_real, data_fake):
    """Run one discriminator update on a real batch and a fake batch.

    Real samples are pushed toward label 1 and fake samples toward label 0
    using the module-level `criterion`. Both target tensors are sized from
    `data_real`, so `data_fake` is assumed to have the same batch size.

    Args:
        optimizer: optimizer over the discriminator's parameters (the loop passes optim_d).
        data_real: batch of real images.
        data_fake: batch of generated images (the loop passes them detached).

    Returns:
        The sum of the real and fake losses, as a scalar tensor.
    """
    n = data_real.size(0)
    targets_real = label_real(n)
    targets_fake = label_fake(n)

    optimizer.zero_grad()

    # The real batch goes through the discriminator first, then the fake batch.
    loss_real = criterion(discriminator(data_real), targets_real)
    loss_fake = criterion(discriminator(data_fake), targets_fake)

    # Separate backward passes; the gradients accumulate before the single step.
    loss_real.backward()
    loss_fake.backward()
    optimizer.step()

    return loss_real + loss_fake

# Training Loop
# Fixed latent batch: sampling it after every epoch shows how the same inputs evolve.
noise = create_noise(sample_size, nz)
generator.train()
discriminator.train()

# save_image does not create missing directories (e.g. on a fresh runtime).
output_dir = 'output_conv' if convolutional else 'output'
os.makedirs(output_dir, exist_ok=True)

num_batches = len(train_loader)
for epoch in tqdm(range(epochs), desc='Epoch'):
    loss_g = 0.0
    loss_d = 0.0
    for bi, (image, _) in enumerate(train_loader):
        data_real = image.to(device)
        b_size = len(data_real)
        # run the discriminator for k steps per generator step
        for step in range(k):
            # detach: the discriminator update must not backpropagate into the generator
            data_fake = generator(create_noise(b_size, nz)).detach()
            # .item() stores a Python float. Accumulating the returned tensors would keep
            # their autograd graphs alive, and CUDA tensors cannot be handed to plt.plot.
            loss_d += train_discriminator(optim_d, data_real, data_fake).item()
        data_fake = generator(create_noise(b_size, nz))
        loss_g += train_generator(optim_g, data_fake).item()

    # Sample the fixed noise without recording an autograd graph.
    with torch.no_grad():
        generated_img = generator(noise).cpu()
    # The generator's Tanh output lies in [-1, 1]; map it back to [0, 1] so the
    # saved PNGs and the GIF frames keep their grey levels instead of being clipped.
    generated_img = make_grid(generated_img * 0.5 + 0.5)
    save_generator_image(generated_img, f'{output_dir}/gen_img{epoch}.png')
    images.append(generated_img)
    # Mean loss per update: the generator is updated once per batch and the
    # discriminator k times per batch (each D value is real-loss + fake-loss).
    epoch_loss_g = loss_g / num_batches
    epoch_loss_d = loss_d / (num_batches * k)
    losses_g.append(epoch_loss_g)
    losses_d.append(epoch_loss_d)
  

print('DONE TRAINING')
# The convolutional variant gets a '_conv' suffix so the two architectures
# never overwrite each other's files.
ckpt_suffix = '_conv' if convolutional else ''
# torch.save does not create the target directory.
os.makedirs('models', exist_ok=True)
# state_dicts, restorable with load_state_dict into a freshly built model ...
torch.save(generator.state_dict(), f'models/generator{ckpt_suffix}.pth')
torch.save(discriminator.state_dict(), f'models/discriminator{ckpt_suffix}.pth')
# ... and the full pickled modules.
torch.save(generator, f'models/generator_model{ckpt_suffix}.pth')
torch.save(discriminator, f'models/discriminator_model{ckpt_suffix}.pth')

# save the generated images as GIF file
gif_dir = 'output_conv' if convolutional else 'output'
# ToPILImage converts float tensors to 8-bit by scaling by 255, so values
# outside [0, 1] (e.g. negative Tanh outputs) would wrap around rather than
# saturate; clamp first. A no-op for frames already in [0, 1].
imgs = [np.array(to_pil_image(img.clamp(0, 1))) for img in images]
imageio.mimsave(f'{gif_dir}/generator_images.gif', imgs)

# plot and save the generator and discriminator loss
plot_dir = 'output_conv' if convolutional else 'output'
# float() accepts plain numbers and scalar tensors alike (including CUDA ones,
# which matplotlib cannot convert), so this works whatever the history holds.
fig, ax = plt.subplots()
ax.plot([float(v) for v in losses_g], label='Generator loss')
ax.plot([float(v) for v in losses_d], label='Discriminator Loss')
ax.set(xlabel='Epoch', ylabel='Loss', title='GAN training losses')
ax.legend()
fig.savefig(f'{plot_dir}/loss.png')