# Generative Adversarial Network (GAN)

In [None]:
import os
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.datasets import ImageFolder
import torchvision.transforms as trans
from torchvision.utils import save_image
from torchvision.utils import make_grid

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
%matplotlib inline

# Data directory

In [None]:
# Root folder handed to ImageFolder further down; listing a few entries is a
# quick sanity check that the path exists and is populated.
data_dir = 'dataset/'
print(os.listdir(data_dir)[:3])

# Defining training data and loader

In [None]:
# Side length (pixels) of the square training images and the mini-batch size.
image_size = 64
batch_size = 128

# Per-channel statistics for Normalize: (x - 0.5) / 0.5 on every channel.
# The same values are used by denorm() to undo the scaling for display.
mean = (0.5,) * 3
std = (0.5,) * 3

# Dataset: resize, light augmentation (small rotation + horizontal flip),
# crop to image_size x image_size, convert to tensor, then normalise.
train_data = ImageFolder(
    data_dir,
    transform=trans.Compose([
        trans.Resize(size=image_size),
        trans.RandomRotation(degrees=5),
        trans.RandomHorizontalFlip(p=0.5),
        trans.CenterCrop(size=image_size),
        trans.ToTensor(),
        trans.Normalize(mean, std),
    ]),
)

# Shuffled loader with two worker processes and pinned host memory.
train_loader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

# Display images

In [None]:
def denorm(img_tensors):
    """Invert the Normalize step so tensors can be displayed or saved.

    Only the first entry of the module-level ``std`` and ``mean`` tuples is
    used, applied uniformly to every element of ``img_tensors``.
    """
    scale, shift = std[0], mean[0]
    return img_tensors * scale + shift

def show_images(images, nmax=8):
    """Draw the first ``nmax`` images of a batch as one grid, 8 per row.

    ``images`` is detached from autograd, de-normalised with ``denorm`` and
    tiled with ``make_grid`` before being shown on a fresh 8x8-inch figure.
    """
    _fig, axes = plt.subplots(figsize=(8, 8))
    plt.title("Cats images for GAN")
    plt.axis("off")

    subset = images.detach()[:nmax]
    grid = make_grid(denorm(subset), nrow=8)
    # Move the channel axis last, which is the layout imshow expects.
    axes.imshow(grid.permute(1, 2, 0))

def show_batch(dl, nmax=64):
    """Display up to ``nmax`` images from the first batch yielded by ``dl``.

    ``dl`` must yield ``(images, labels)`` pairs; the labels are ignored.
    Nothing is shown when ``dl`` yields no batches.
    """
    batches = iter(dl)
    try:
        images, _ = next(batches)
    except StopIteration:
        return
    show_images(images, nmax)

# Preview the first training batch (up to show_batch's default of 64 images).
show_batch(train_loader)

# Select device

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

cpu


# Discriminator description

In [None]:
def _disc_stage(in_channels, out_channels):
    """Return one stride-2 stage: Conv2d -> BatchNorm2d -> LeakyReLU(0.2)."""
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2, inplace=True),
    ]


# Real/fake classifier. Each stage halves the spatial size (kernel 4,
# stride 2, padding 1), so a 64x64 input shrinks 64 -> 32 -> 16 -> 8 -> 4.
# The final 4x4 convolution collapses that to a single value per image,
# which Flatten + Sigmoid turn into an (N, 1) tensor of values in (0, 1).
discriminator = nn.Sequential(
    *_disc_stage(3, 64),
    *_disc_stage(64, 128),
    *_disc_stage(128, 256),
    *_disc_stage(256, 512),
    nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
    nn.Flatten(),
    nn.Sigmoid(),
).to(device)

# Generator description

In [None]:
def _gen_stage(in_channels, out_channels, stride, padding):
    """Return one upsampling stage: ConvTranspose2d -> BatchNorm2d -> ReLU."""
    return [
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
                           stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]


# Maps a (N, 128, 1, 1) latent tensor to a (N, 3, 64, 64) image.
# The first stage expands 1x1 -> 4x4; every later transposed convolution
# (kernel 4, stride 2, padding 1) doubles the size: 4 -> 8 -> 16 -> 32 -> 64.
# Tanh bounds the output to (-1, 1).
generator = nn.Sequential(
    *_gen_stage(128, 512, stride=1, padding=0),
    *_gen_stage(512, 256, stride=2, padding=1),
    *_gen_stage(256, 128, stride=2, padding=1),
    *_gen_stage(128, 64, stride=2, padding=1),
    nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh(),
).to(device)

# Train discriminator

In [None]:
def train_disc(real_images, disc):
    """Run one discriminator update on a real batch and a generated batch.

    Args:
        real_images: Batch of real images; expected to be on ``device``
            already, since the targets are created there.
        disc: Optimizer holding the discriminator's parameters (despite the
            name, this is the optimizer, not the network).

    Returns:
        Tuple ``(loss, real_score, fake_score)`` of Python floats: the summed
        BCE loss on both batches, and the mean discriminator output on the
        real and on the fake batch respectively.
    """
    disc.zero_grad()

    # Real images should be classified as 1.
    real_prediction = discriminator(real_images)
    real_target = torch.ones(real_images.size(0), 1, device=device)
    real_loss = F.binary_cross_entropy(real_prediction, real_target)
    real_score = torch.mean(real_prediction).item()

    latent = torch.randn(batch_size, 128, 1, 1, device=device)
    # Detach the fakes: this step only updates the discriminator, so there is
    # no reason to backpropagate through the generator (it wastes compute and
    # memory and leaves stale gradients on the generator's parameters).
    fake_images = generator(latent).detach()

    # Generated images should be classified as 0.
    fake_target = torch.zeros(fake_images.size(0), 1, device=device)
    fake_prediction = discriminator(fake_images)
    fake_loss = F.binary_cross_entropy(fake_prediction, fake_target)
    fake_score = torch.mean(fake_prediction).item()

    loss = real_loss + fake_loss
    loss.backward()
    disc.step()
    return loss.item(), real_score, fake_score

# Train generator

In [None]:
def train_gen(gen):
    """Run one generator update and return its loss as a Python float.

    ``gen`` is the optimizer holding the generator's parameters. A batch of
    ``batch_size`` random latent vectors is turned into images, scored by the
    discriminator, and the generator is penalised (BCE) for every image the
    discriminator does not rate as real.
    """
    gen.zero_grad()

    # The generator's goal: every fake should be labelled 1 ("real").
    wanted = torch.ones(batch_size, 1, device=device)

    noise = torch.randn(batch_size, 128, 1, 1, device=device)
    verdict = discriminator(generator(noise))
    loss = F.binary_cross_entropy(verdict, wanted)

    loss.backward()
    gen.step()
    return loss.item()

# Save generated images

In [None]:
# Output folder for the sample grids written by save_generated_images;
# exist_ok=True makes re-running this cell harmless.
newfolder = 'generated_images'
os.makedirs(newfolder, exist_ok=True)

def save_generated_images(index, latent_tensors):
    """Generate images from ``latent_tensors`` and save them as one PNG grid.

    Args:
        index: Integer used in the file name, zero-padded to four digits
            (``output-0001.png``).
        latent_tensors: Latent batch fed to the module-level ``generator``.

    The grid has 8 images per row and is written into ``newfolder``.
    """
    # Sampling only: no gradients are needed, so skip building the autograd
    # graph. The generator's train/eval mode is deliberately left untouched so
    # the produced images are the same as before.
    with torch.no_grad():
        fake_images = generator(latent_tensors)
    fake_file = 'output-{0:0=4d}.png'.format(index)
    save_image(denorm(fake_images), os.path.join(newfolder, fake_file), nrow=8)

# A fixed batch of 64 latent vectors, reused for every saved grid so that
# successive images show the same noise inputs as the generator changes.
fixed_latent = torch.randn(64, 128, 1, 1, device=device)
# Index 0: a grid saved here, before the training cell further down is run.
save_generated_images(0, fixed_latent)

# Training loop definition

In [None]:
def model(epochs, learning_rate, start_index=1):
    """Train the generator and discriminator against each other.

    Args:
        epochs: Number of full passes over ``train_loader``.
        learning_rate: Learning rate for both Adam optimizers.
        start_index: Offset added to the epoch number when naming the sample
            grid saved after each epoch (the printed epoch counter always
            starts at 1).

    Returns:
        Tuple ``(loss_gen, loss_disc, real_scores, fake_scores)`` of lists
        with one entry per epoch. Each entry is the value from the LAST batch
        of that epoch, not an average over the epoch.
    """
    torch.cuda.empty_cache()

    real_scores = []
    fake_scores = []
    loss_gen = []
    loss_disc = []

    # Separate Adam optimizers for the two networks, both with beta1 = 0.5.
    genopt = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    discopt = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

    for epoch in range(epochs):
        for real_images, _ in tqdm(train_loader):
            real_images = real_images.to(device)
            # Alternate: one discriminator step, then one generator step.
            discloss, real_score, fake_score = train_disc(real_images, discopt)
            genloss = train_gen(genopt)

        # Recorded once per epoch, using the values left over from the final batch.
        loss_gen.append(genloss)
        loss_disc.append(discloss)
        real_scores.append(real_score)
        fake_scores.append(fake_score)

        print("Epoch [{}/{}], genloss: {:.4f}, discloss: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
            epoch+1, epochs, genloss, discloss, real_score, fake_score))

        # Same fixed latent batch every epoch, so the saved grids are comparable.
        save_generated_images(epoch + start_index, fixed_latent)

    return loss_gen, loss_disc, real_scores, fake_scores

# Hyper-parameters for the run: Adam learning rate and number of epochs.
learning_rate = 2e-4
epochs = 50

# Train, keep the raw result tuple, then unpack the four per-epoch series.
history = model(epochs=epochs, learning_rate=learning_rate)
(loss_gen, loss_disc, real_scores, fake_scores) = history

# Scores graph

In [None]:
# Mean discriminator output on real vs. generated images.
# Each series holds one value per epoch, so the x-axis is labelled in epochs
# (it was previously labelled 'iterations').
plt.plot(real_scores, '-')
plt.plot(fake_scores, '-')
plt.xlabel('epoch')
plt.ylabel('score')
plt.legend(['Real', 'Fake'])
plt.title('Real and fake scores');

# Losses graph

In [None]:
# Discriminator and generator BCE losses.
# Each series holds one value per epoch, so the x-axis is labelled in epochs
# (it was previously labelled 'iterations').
plt.plot(loss_disc, '-')
plt.plot(loss_gen, '-')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['Discriminator', 'Generator'])
plt.title('Generator and Discriminator loss during training');