<a href="https://colab.research.google.com/github/Krisss993/GAN/blob/main/06_13_GAN_03.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Loss function:
Earth movel distance - The effort needed to make both distributions equal
- Critics(discriminator) values not restricted to be between 0 and 1
- Even for very different distributions, gradients are significant and high enough to drive the process in the right way

In [1]:
import torch
import torchvision
import os
import PIL
import pdb
import numpy as np
import matplotlib.pyplot as plt

from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
from tqdm.auto import tqdm
from PIL import Image

In [2]:
# OPTIONAL
!pip install wandb -qqq
import wandb
wandb.login(key='')

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m20.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m20.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m289.6/289.6 kB[0m [31m25.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25h

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [3]:
def show(tensor, num=25, wandbactivation=0, name=''):
  data = tensor.detach().cpu()
  grid = make_grid(data[:num], nrow=5).permute(1,2,0)

  # optional
  # wandb - online activation
  if (wandbactivation==1 and wandbact==1):
      wandb.log({name:wandb.Image(grid.numpy().clip(0,1))})


  # cliping pixels to range(0,1)
  plt.imshow(grid.clip(0,1))
  plt.show()


In [5]:
## hyperparameters and general parameters

n_epochs = 1000
batch_size = 128
lr = 1e-4

# z_dim - input noise latent vector dim
z_dim = 200
device = 'cuda'

cur_step = 0

# 5 cycles training of the critic, then 1 of the generator
# generally critic needs more training than the generator
crit_cycles = 5

gen_losses = []
crit_losses = []
show_step  = 35
save_step = 35

# optional, tracking stats online
wandbact = 1


In [6]:
# optional wandb
%%capture
# experiment_name = wandb.util.generate_id()
experiment_name = 'MY EXP'
myrun = wandb.init(
    project='wgan',
    name=experiment_name,
    group=experiment_name,
    config={'optimizer':'adam',
            'model':'wgan gp',
            'epoch':'1000',
            'batch_size':128
            }
)
config = wandb.config

In [7]:
# optional wandb
print(experiment_name)

MY EXP


In [45]:
# generator model

class Generator(nn.Module):

  # d_dim - internal dimension for the output of the convolutional layers
  def __init__(self, z_dim=64, d_dim=16):
    super(Generator, self).__init__()
    self.z_dim = z_dim

    self.gen = nn.Sequential(

        # ConvTranspose2d: in_channels, out_channels, kernel_size, stride=1, padding=0
        # n - width or height
        # (n - 1) * stride - 2 * padding + kernel_size

        # generator starts 1x1 pixel and z_dim number of channels and it gives it dimensionality of latent space
        # starting with 200 channels bringing up channels to 512 and increasing size of the image

        # 1st block
        nn.ConvTranspose2d(in_channels=z_dim, out_channels=d_dim*32, kernel_size=4, stride=1, padding=0),
        # normalizing values for improving stability, num_features = out from conv layer
        nn.BatchNorm2d(num_features=d_dim*32),
        # applying nonlinearity
        nn.ReLU(True),


        # 2nd block
        nn.ConvTranspose2d(in_channels=d_dim*32, out_channels=d_dim*16, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(d_dim*16),
        nn.ReLU(inplace=True),

        # 3rd block
        nn.ConvTranspose2d(in_channels=d_dim*16, out_channels=d_dim*8, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(d_dim*8),
        nn.ReLU(inplace=True),

        # 4th block
        nn.ConvTranspose2d(in_channels=d_dim*8, out_channels=d_dim*4, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(d_dim*4),
        nn.ReLU(inplace=True),

        # 5th block
        nn.ConvTranspose2d(in_channels=d_dim*4, out_channels=d_dim*2, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(d_dim*2),
        nn.ReLU(inplace=True),

        # 6th block
        # in output layer we stay with ONLY 3 channels
        nn.ConvTranspose2d(in_channels=d_dim*2, out_channels=d_dim*1, kernel_size=4, stride=2, padding=1),
        # no batch norm

        # result from -1 to 1
        nn.Tanh()

    )


  def forward(self, noise):
    # generator recives noise as input
    x = noise.view(len(noise), self.z_dim, 1, 1) # 128(batch) x 200(z_dim) x 1(height) x 1(width)

    return self.gen(x)


def gen_noise(num, z_dim, device='cuda'):
  return torch.randn(num, z_dim, device=device) # 128 x 200


# n - width or height
# nn.Conv2d: (n + 2 * pad - ks) // stride + 1
# nn.ConvTranspose2d: (n - 1) * stride - 2 * padding + kernel_size

(n - 1) * stride - 2 * padding + kernel_size
- 1st step:
- (1 - 1) * 1 - 2 * 0 + 4 = 4x4 image, channels: 200 to 512
- 2nd step
- (4 - 1) * 2 - 2 * 1 + 4 = 8x8 image, channels: 512 to 256
- 3rd step
- (8 - 1) * 2 - 2 * 1 + 4 = 16x16 image, channels: 256 to 128
- 4th step
- (16 - 1) * 2 - 2 * 1 + 4 = 32x32 image, channels: 128 to 64
- 5th step
- (32 - 1) * 2 - 2 * 1 + 4 = 64x64 image, channels: 64 to 32
- 6th step
- (64 - 1) * 2 - 2 * 1 + 4 = 128x128 image, channels: 32 to 3

In [10]:
# critic model
# Conv2d: in_channels, out_channels, kernel_size, stride=1, padding=0
# (n + 2 * padding - kernel_size) // stride + 1

class Critic(nn.Module):
  def __init__(self, d_dim=16):
    super(Critic, self).__init__()

    self.crit = nn.Sequential(
        # 1st block
        nn.Conv2d(in_channels=3, out_channels=d_dim, kernel_size=4, stride=2, padding=1),

        # instead of batchnorm2d, normalizing according to the values of the whole instance insted of values of the batch
        nn.InstanceNorm2d(d_dim), # works the best

        # leaky keeps information, negative values have little slope(small negative numbers), but theyre not converted to 0
        nn.LeakyReLU(0.2),

        # 2nd block
        nn.Conv2d(in_channels=d_dim, out_channels=d_dim*2, kernel_size=4, stride=2, padding=1),
        nn.InstanceNorm2d(d_dim*2),
        nn.LeakyReLU(0.2),

        # 3rd block
        nn.Conv2d(in_channels=d_dim*2, out_channels=d_dim*4, kernel_size=4, stride=2, padding=1),
        nn.InstanceNorm2d(d_dim*4),
        nn.LeakyReLU(0.2),

        # 4th block
        nn.Conv2d(in_channels=d_dim*4, out_channels=d_dim*8, kernel_size=4, stride=2, padding=1),
        nn.InstanceNorm2d(d_dim*8),
        nn.LeakyReLU(0.2),

        # 5th block
        nn.Conv2d(in_channels=d_dim*8, out_channels=d_dim*16, kernel_size=4, stride=2, padding=1),
        nn.InstanceNorm2d(d_dim*16),
        nn.LeakyReLU(0.2),

        # 6th block
        # we return 1 value, either its fake or real
        # stride has to be 1 and padding 0
        nn.Conv2d(in_channels=d_dim*16, out_channels=1, kernel_size=4, stride=1, padding=0),

    )


  def forward(self, image):
    # image: 128(batch) x 3(channels) x 128(height) x 128(width)
    crit_pred = self.crit(image) # 128 x 1 x 1 x 1
    return crit_pred.view(len(crit_pred), -1) # 128(batch values) x 1(fake or real)


(n + 2 * padding - kernel_size) // stride + 1
- 1st step
- (128 + 2 * 1 - 4) //2 + 1 = 64x64 image, channels: 3 to 16
- 2nd step
- (64 + 2 * 1 - 4) //2 + 1 = 32x32 image, channels: 16 to 32
- 3rd step
- (32 + 2 * 1 - 4) //2 + 1 = 16x16 image, channels: 32 to 64
- 4th step
- (16 + 2 * 1 - 4) //2 + 1 = 8x8 image, channels: 64 to 128
- 5th step
- (8 + 2 * 1 - 4) //2 + 1 = 4x4 image, channels: 128 to 256
- 6th step
- (4 + 2 * 0 - 4) //1 + 1 = 1x1 image, channels: 256 to 1

Alternative way to initialize parameters


In [11]:
def init_weights(m):
  if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
    torch.nn.init.normal(m.weight, 0.0, 0.02)
    torch.nn.init.constant(m.bias, 0)

  if isinstance(m, nn.BatchNorm2d):
    torch.nn.init.normal(m.weight, 0.0, 0.02)
    torch.nn.init.constant(m.bias, 0)

# Initializations - NOT HERE
# gen = gen.apply(init_weights)
# crit = crit.apply(init_weights)

In [28]:
# loading dataset
import gdown, zipfile

# url = ''
path = '/content/drive/MyDrive'
download_path = f'{path}/img_align_celeba.zip'

if not os.path.exists(path):
  os.makedirs(path)

# gdown.download(url, download_path, quiet=False)

with zipfile.ZipFile(download_path, 'r') as ziphandler:
  ziphandler.extractall('.')

In [19]:
!wget https://drive.google.com/file/d/0B7EVK8r0v71pZjFTYXZWM3FlRnM/view?resourcekey=0-dYn9z10tMJOBAkviAcfdyQ
!unzip -q img_align_celeba.zip

--2024-06-17 10:52:45--  https://drive.google.com/file/d/0B7EVK8r0v71pZjFTYXZWM3FlRnM/view?resourcekey=0-dYn9z10tMJOBAkviAcfdyQ
Resolving drive.google.com (drive.google.com)... 108.177.121.138, 108.177.121.100, 108.177.121.101, ...
Connecting to drive.google.com (drive.google.com)|108.177.121.138|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [text/html]
Saving to: ‘view?resourcekey=0-dYn9z10tMJOBAkviAcfdyQ’

          view?reso     [<=>                 ]       0  --.-KB/s               view?resourcekey=0-     [ <=>                ]  88.03K  --.-KB/s    in 0.003s  

2024-06-17 10:52:45 (26.0 MB/s) - ‘view?resourcekey=0-dYn9z10tMJOBAkviAcfdyQ’ saved [90138]

unzip:  cannot find or open img_align_celeba.zip, img_align_celeba.zip.zip or img_align_celeba.zip.ZIP.


In [43]:
# class Dataset(Dataset):
class Dataset():
  def __init__(self, path, size=128, lim=10000):
    self.sizes = [size, size]

    # paths to the images
    items, labels = [], []

    for data in os.listdir(path)[:lim]:
      # path: './data/celeba/img_align_celeba'
      # data: '123213.img'
      item = os.path.join(path, data)
      items.append(item)
      labels.append(data)
    self.items=items
    self.labels=labels

  def __len__(self):
    return len(self.items)

  def __getitem__(self, idx):
    # open image idx
    data = PIL.Image.open(self.items[idx]).convert('RGB') # size of the image: fe. 1278, 121
    data = np.asarray(torchvision.transforms.Resize(self.sizes)(data)) # 128 x 128 x 3
    data = np.transpose(data, (2,0,1)).astype(np.float32, copy=False) # 3 x 128 x 128
    # from np to tensor for training, div = standarizing
    data = torch.from_numpy(data).div(255) # leaving values from 0 to 1
    return data, self.labels[idx]


In [44]:
# Dataset
data_path = './img_align_celeba'
ds = Dataset(data_path, size=128, lim=10000)

# DataLoader
dataloader = DataLoader(dataset=ds, batch_size=128, shuffle=True)

# Models
gen = Generator(z_dim=z_dim).to(device)
crit = Critic().to(device)

# Optimizers
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.9)) # betas - internal calculations, works well with this architecture
crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(0.5, 0.9))

# Initializations - OPTIONAL
gen = gen.apply(init_weights)
crit = crit.apply(init_weights)

# wandb - OPTIONAL
if (wandbact==1):
  wandb.watch(gen, log_freq=100)
  wandb.watch(crit, log_freq=100)

x, y = next(iter(dataloader))
show(x)

RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

In [38]:
# gradient penalty calculation

def get_gp(real, fake, crit, alpha, gamma=10): # alpha does random interpolations, gamma stands for intensity of gp regularizations
  mix_images = real * alpha + fake * (1-alpha) # 128(batch) x 3 x 128 x 128, linear interpolation
  mix_scores = crit(mix_images) # predictions: 128(batch) x 1

  # we want to penalize gradients that are to large
  # computing and returning the sum of the gradients of the outputs with respect to the inputs
  gradient = torch.autograd.grad(
      inputs = mix_images,
      outputs = mix_scores,

      # puting ones to take into account all the grades and outputs
      grad_outputs = torch.ones_like(mix_scores),
      retain_graph=True,
      create_graph=True,

  )[0] # return first batch 128(bs) x 3 x 128 x 128

  gradient = gradient.view(len(gradient), -1) # 128 x 49512(128x128x3)
  gradient_norm = gradient.norm(2, dim=1) # L2 norm
  gp = gamma * ((gradient_norm - 1)**2).mean()

  return gp

In [41]:
# saving and loading checkpoints

root_path='./data/'

def save_checkpoint(name):
  torch.save({
      'epoch':epoch,
      'model_state_dict':gen.state_dict(),
      'optimizer_state_dict':gen_opt.state_dict(),
  }, f'{root_path}G-{name}.pkl')

  torch.save({
      'epoch':epoch,
      'model_state_dict':crit_cycles.state_dict(),
      'optimizer_state_dict':crit_opt.state_dict(),
  }, f'{root_path}C-{name}.pkl')

  print('Saved checkpoint')


def load_checkpoint(name):
  # generator
  # loading file
  checkpoint = torch.load(f'{root_path}G-{name}.pkl')
  # loading values to the model
  gen.load_state_dict(checkpoint['model_state_dict'])
  gen_opt.load_state_dict(checkpoint['optimizer_state_dict'])

  # critic
  checkpoint = torch.load(f'{root_path}C-{name}.pkl')
  crit.load_state_dict(checkpoint['model_state_dict'])
  crit_opt.load_state_dict(checkpoint['optimizer_state_dict'])

  print('Loaded checkpoint')

In [42]:
epoch=1
save_checkpoint('test')
load_checkpoint('test')

NameError: name 'gen' is not defined

In [49]:
# Training loop

for epoch in range(n_epochs):
  for real, _ in tqdm(dataloader):
    cur_bs = len(real) # 128
    real = real.to(device)


    ## Critic
    mean_crit_loss = 0
    for _ in range(crit_cycles):

      # zeroing gradient of the optimizer
      crit_opt.zero_grad()

      noise = gen_noise(cur_bs, z_dim)
      fake = gen(noise)
      # detaching for not affecting the parameters of the generator
      crit_fake_pred = crit(fake.detach())
      crit_real_pred = crit(real)

      # alpha vector (numbers size of the batch)
      alpha = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True) # 128 x 1 x 1 x 1

      # calculating gradient penalty
      gp = get_gp(real, fake.detach(), crit, alpha)

      # calculating loss
      crit_loss = crit_fake_pred.mean() - crit_real_pred.mean() + gp

      #.item - taking only the number from the tensor
      mean_crit_loss += crit_loss.item() / crit_cycles

      # optimizer backpropagation
      crit_loss.backward(retain_graph=True)
      crit_opt.step()

    # list of losses values
    crit_losses += [mean_crit_loss]


    ## Generator

    # zeroing gradient of the optimizer
    gen_opt.zero_grad()

    # creating noise 128 x 200
    noise = gen.noise(cur_bs, z_dim)

    # passing noise through generator
    fake = gen(noise)

    # passing them through critic
    crit_fake_pred = crit(fake)

    # negative of the pred of the critic
    gen_loss = -crit_fake_pred.mean()

    # backpropagation
    gen_loss.backward()

    # updating the parameters of the generator
    gen_opt.step()

    gen_losses+=[gen_loss.item()]

    ## Statistics

    if (wandb==1):
      wandb.log(
          {'Epoch':epoch,
           'Step':cur_step,
           'Critic loss':mean_crit_loss,
           'Gen loss':gen_loss,
           }
      )

    if cur_step % save_step == 0 and cur_step > 0:
      print('Saving checkpoint:', cur_step, save_step)
      # best to save the files with the different names fe. nr of epoch
      save_checkpoint('latest')

    if (cur_step % show_step == 0 and cur_step > 0):
      show(fake, wandbactivation=1, name='fake')
      show(real, wandbactivation=1, name='real')

      gen_mean = sum(gen_losses[-show_step:]) / show_step
      crit_mean = sum(crit_losses[-show_step:]) / show_step
      print(f'Epoch: {epoch}, step: {cur_step}, Generator loss: {gen_loss}, Critic loss: {crit_loss}')

      plt.plot(range(len(gen_losses)),
               torch.Tensor(gen_losses),
               label='Generator loss')

      plt.plot(range(len(crit_losses)),
               torch.Tensor(crit_losses),
               label='Critic loss')

      plt.ylim(-200,200)
      plt.legend()
      plt.show()
    cur_step += 1

  0%|          | 0/79 [00:00<?, ?it/s]

RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx