<a href="https://colab.research.google.com/github/Velociraptorvelraptor/fake-food-generation-with-GAN-PyTorch-Lightning/blob/main/fake_food_GAN_PTL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the packages this notebook relies on: PyTorch Lightning, the Kaggle
# CLI (used below to download the dataset) and torchvision. -q keeps pip quiet.
!pip install pytorch-lightning kaggle torchvision -q

In [None]:
# Put the uploaded Kaggle API token where the Kaggle CLI reads it from.
# The directory is created first: if /root/.kaggle did not exist yet, a bare
# `mv` would rename kaggle.json to a *file* called /root/.kaggle.
!mkdir -p /root/.kaggle
!mv /content/kaggle.json /root/.kaggle/kaggle.json
# Owner-only permissions, since the token is a credential.
!chmod 600 /root/.kaggle/kaggle.json

In [None]:
# Download the food11-image-dataset archive with the Kaggle CLI. The zip lands
# in the current working directory (/content in the output captured below).
!kaggle datasets download -d trolukovich/food11-image-dataset

Downloading food11-image-dataset.zip to /content
 99% 1.07G/1.08G [00:13<00:00, 147MB/s]
100% 1.08G/1.08G [00:13<00:00, 86.4MB/s]


In [None]:
# Project folder on the mounted Google Drive; the dataset archive and the
# data/ directory used for training live underneath it.
src_path = '/content/drive/MyDrive/Colab Notebooks/GAN-fake-food'

In [None]:
# Move the downloaded archive into the Drive project folder (the same path that
# src_path holds).
# NOTE(review): this assumes the target folder already exists on Drive --
# otherwise mv renames the zip to that path instead. Confirm.
!mv  '/content/food11-image-dataset.zip' '/content/drive/MyDrive/Colab Notebooks/GAN-fake-food'

In [None]:
# Create the data/ folder inside the Drive project directory.
# -p makes the cell safe to re-run: it does not fail when the folder already
# exists, and it creates missing parent folders as well.
!mkdir -p '/content/drive/MyDrive/Colab Notebooks/GAN-fake-food/data'

In [None]:
# Root folder of the training images; handed to ImageFolder further down.
training_path = f'{src_path}/data/training'

In [None]:
import os

import torch
import torch.functional as F
from torch import nn, optim

from torchvision import transforms as T
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader


In [None]:
# Per-channel (mean, std) pair unpacked into T.Normalize below. With 0.5 / 0.5
# the transform (x - 0.5) / 0.5 maps pixel values from [0, 1] to [-1, 1].
normalization_mtx = [(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)]

In [None]:
# Preprocessing applied to every training image: resize, crop the central
# 64x64 square, convert to a tensor and normalize each channel with the
# mean / std stored in normalization_mtx.
transformer = T.Compose(
    [
        T.Resize(64),
        T.CenterCrop(64),
        T.ToTensor(),
        T.Normalize(mean=normalization_mtx[0], std=normalization_mtx[1]),
    ]
)

In [None]:
# Training set read from the folder tree under training_path; every image goes
# through the `transformer` pipeline when it is loaded.
food_train_dataset = ImageFolder(root=training_path, transform=transformer)

In [None]:
# Number of images per training batch; passed to the DataLoader below.
batch_size = 128

In [None]:
# Shuffled mini-batch loader over the training set, using 4 worker processes
# and pinned host memory.
train_dl = DataLoader(
    dataset=food_train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)



In [None]:
def denormalize(input_image_tensors):
  """
  Undo the Normalize step of the input pipeline: x * std + mean.

  A new tensor is returned and the argument is left untouched. The previous
  in-place `*=` / `+=` also rewrote the caller's data, including the original
  batch whenever a slice (a view) was passed in.

  input_image_tensors: image tensor that was normalized with normalization_mtx
  returns: tensor of the same shape with the normalization reversed
  """
  # Every channel shares the same mean and std in normalization_mtx, so the
  # first entry of each tuple is applied as a scalar to the whole tensor.
  mean = normalization_mtx[0][0]
  std = normalization_mtx[1][0]
  return input_image_tensors * std + mean

In [None]:
def save_samples(idx, sample_images):
  """
  Save the last 64 generated images (half of a 128-image batch) as one PNG
  grid with 8 images per row.

  The file is written to the current working directory as fake-img-<idx>.png.

  idx: epoch number, used in the file name
  sample_images: batch of images returned by the GAN generator at the end of
    an epoch, still in normalized form
  """
  # Imported here because the notebook's import cell does not bring
  # save_image into scope; without it the call below raises NameError.
  from torchvision.utils import save_image

  fake_fname = f'fake-img-{idx}.png'
  # Slicing yields a view of the batch; clone it so that denormalizing the
  # samples can never alter the caller's tensor.
  last_images = sample_images[-64:].clone()
  save_image(denormalize(last_images), os.path.join(".", fake_fname), nrow=8)

In [None]:
class FoodDiscriminator(nn.Module):
  """
  Convolutional discriminator: scores each RGB image of a batch with a value
  in (0, 1), where 1 stands for "real" and 0 for "fake".

  Four stride-2 convolutions each halve the spatial size, so a square input
  of side `input_size` becomes a (1024, input_size // 16, input_size // 16)
  feature map, which is flattened and fed to a single sigmoid unit.

  input_size: side length in pixels of the square input images (64 for the
    images produced by FoodGenerator); must be at least 16.
  raises: ValueError if input_size is smaller than 16.
  """
  def __init__(self, input_size):
    super().__init__()

    # The four stride-2 convolutions shrink the side by a factor of 16; a
    # smaller input would leave nothing for the linear layer.
    if input_size < 16:
      raise ValueError(f'input_size must be at least 16, got {input_size}')

    self.input_size = input_size
    self.channel = 3
    self.kernel_size = 4
    self.stride = 2
    self.padding = 1
    self.bias = False
    self.negative_slope = 0.2

    # (3, S, S) -> (128, S/2, S/2)
    self.conv1 = nn.Conv2d(self.channel,
                           128,
                           self.kernel_size,
                           self.stride,
                           self.padding,
                           bias=self.bias)
    self.bn1 = nn.BatchNorm2d(128)
    # A single activation module is reused after every stage; it holds no
    # parameters.
    self.relu = nn.LeakyReLU(self.negative_slope, inplace=True)

    # (128, S/2, S/2) -> (256, S/4, S/4)
    self.conv2 = nn.Conv2d(128,
                           256,
                           self.kernel_size,
                           self.stride,
                           self.padding,
                           bias=self.bias)
    self.bn2 = nn.BatchNorm2d(256)

    # (256, S/4, S/4) -> (512, S/8, S/8)
    self.conv3 = nn.Conv2d(256,
                           512,
                           self.kernel_size,
                           self.stride,
                           self.padding,
                           bias=self.bias)
    self.bn3 = nn.BatchNorm2d(512)

    # (512, S/8, S/8) -> (1024, S/16, S/16)
    self.conv4 = nn.Conv2d(512,
                           1024,
                           self.kernel_size,
                           self.stride,
                           self.padding,
                           bias=self.bias)
    self.bn4 = nn.BatchNorm2d(1024)

    # Width of the flattened feature map; 1024 * 4 * 4 = 16384 for 64x64
    # input, which matches the previously hard-coded value.
    final_side = input_size // 16
    self.fc = nn.Sequential(nn.Linear(1024 * final_side * final_side, 1),
                            nn.Sigmoid())

  def forward(self, input_img):
    """
    input_img: image batch of shape (N, 3, input_size, input_size)
    returns: tensor of shape (N, 1) with the "real" probability per image
    """
    x = self.relu(self.bn1(self.conv1(input_img)))
    x = self.relu(self.bn2(self.conv2(x)))
    x = self.relu(self.bn3(self.conv3(x)))
    x = self.relu(self.bn4(self.conv4(x)))
    # Flatten per sample, keeping the batch dimension, so an input of the
    # wrong size fails in the linear layer instead of being folded into a
    # different batch size.
    x = x.flatten(start_dim=1)
    return self.fc(x)

In [None]:
class FoodGenerator(nn.Module):
  """
  Generator network: turns latent vectors of shape (latent_size, 1, 1) into
  3-channel 64x64 images through five transposed-convolution stages. The
  final Tanh keeps the output values within [-1, 1].

  latent_size: number of channels of the latent input vector
  """
  def __init__(self, latent_size=256):
    super().__init__()
    self.latent_size = latent_size
    self.kernel_size, self.stride, self.padding = 4, 2, 1
    self.bias = False

    # Stage 1 -- (latent_size, 1, 1) -> (512, 4, 4)
    layers = [
        nn.ConvTranspose2d(latent_size, 512, self.kernel_size, stride=1, padding=0, bias=self.bias),
        nn.BatchNorm2d(512),
        nn.ReLU(inplace=True),
    ]

    # Stages 2-4 -- each doubles the spatial size and halves the channels:
    # (512, 4, 4) -> (256, 8, 8) -> (128, 16, 16) -> (64, 32, 32)
    for channels_in, channels_out in ((512, 256), (256, 128), (128, 64)):
      layers.extend([
          nn.ConvTranspose2d(channels_in, channels_out, self.kernel_size, self.stride, self.padding, bias=self.bias),
          nn.BatchNorm2d(channels_out),
          nn.ReLU(inplace=True),
      ])

    # Stage 5 -- (64, 32, 32) -> (3, 64, 64)
    layers.extend([
        nn.ConvTranspose2d(64, 3, self.kernel_size, self.stride, self.padding, bias=self.bias),
        nn.Tanh(),
    ])

    self.model = nn.Sequential(*layers)

  def forward(self, input_img):
    """Run the latent batch `input_img` through the layer stack."""
    generated = self.model(input_img)
    return generated

In [None]:
class FoodGAN():
  """
  GAN wrapper that pairs a FoodGenerator with a FoodDiscriminator.

  NOTE(review): no base class is declared, yet the code below calls
  self.save_hyperparameters(), self.hparams and self.log, none of which are
  defined in the visible source. Presumably this is meant to subclass
  pytorch_lightning.LightningModule -- confirm.
  """
  def __init__(self, latent_size=256, learninig_rate=0.0002, bias1=0.5, bias2=0.999, batch_size=128):
    # latent_size: channel count of the noise vectors drawn below
    # learninig_rate (sic): learning rate passed to both Adam optimizers
    # bias1, bias2: the two Adam beta coefficients
    # batch_size: number of noise vectors / targets created per step
    super().__init__()

    self.save_hyperparameters()
    # NOTE(review): latent_size is not forwarded here, so the generator always
    # uses its own default (256) while the noise below uses self.latent_size.
    self.generator = FoodGenerator()
    self.discriminator = FoodDiscriminator(input_size=64)

    self.batch_size = batch_size
    self.latent_size = latent_size
    # Noise batch drawn once at construction time; it is not used anywhere in
    # the visible code (presumably kept for sampling images -- confirm).
    self.validation = torch.randn(self.batch_size, self.latent_size, 1, 1)

    # NOTE(review): every `def` below is indented inside __init__, so these
    # are local functions created on each construction, not methods of
    # FoodGAN; `self.adversial_loss(...)` or `self(...)` would not find them.

    # Binary cross-entropy between discriminator outputs and target labels.
    # NOTE(review): F is imported as `torch.functional`, which provides no
    # binary_cross_entropy; that function lives in torch.nn.functional.
    def adversial_loss(self, preds, targets):
      return F.binary_cross_entropy(preds, targets)

    # Builds one Adam optimizer for the generator and one for the
    # discriminator; returns them with an empty second list.
    def configure_optimizers(self):
      # NOTE(review): `learning_rate` is assigned but never read; the
      # optimizers use `learninig_rate`, which only resolves because this
      # function is nested in __init__ and closes over its parameter.
      learning_rate = self.hparams.learninig_rate
      bias1 = self.hparams.bias1
      bias2 = self.hparams.bias2

      opt_g = optim.Adam(self.generator.parameters(), lr=learninig_rate, betas=(bias1, bias2))
      opt_d = optim.Adam(self.discriminator.parameters(), lr=learninig_rate, betas=(bias1, bias2))     

      return [opt_g, opt_d], []

    # Generates images from a batch of latent vectors z.
    def forward(self, z):
      return self.generator(z)

    # Training step. Only the generator branch (optimizer_idx == 0) is visible
    # here; for any other optimizer_idx the visible code returns nothing.
    # NOTE(review): the name `train` looks unintended if this is supposed to
    # be a Lightning training hook -- confirm.
    def train(self, batch, optimizer_idx):
      real_img, _ = batch
      if optimizer_idx == 0:
        # Noise and targets are sized by self.batch_size, not by the size of
        # the incoming batch; type_as copies real_img's dtype and device.
        fake_random_noise = torch.randn(self.batch_size, self.latent_size, 1, 1)
        fake_random_noise = fake_random_noise.type_as(real_img)
        fake_img = self(fake_random_noise)
        preds = self.discriminator(fake_img)
        # The generator is trained against "real" labels (all ones).
        targets = torch.ones(self.batch_size, 1)
        targets = targets.type_as(real_img)

        loss = self.adversial_loss(preds, targets)
        self.log('generator_loss', loss, prog_bar=True)

        # NOTE(review): {'g_loss', loss} is a set literal, not a dict -- a
        # colon was probably intended.
        tqdm_dict = {'g_loss', loss}
        # NOTE(review): `OrderDict` is not defined or imported in the visible
        # source (likely collections.OrderedDict) -- confirm.
        output = OrderDict({
            'loss': loss, 
            'progress_bar': tqdm_dict, 
            'log': tqdm_dict
        })
        return output