<a href="https://colab.research.google.com/github/JohnathonGil/Generative-Artificial-Networks-Study/blob/main/Advanced_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Advanced GAN
# https://medium.com/@ideami

# Importing the Libraries
import torch, torchvision, os, PIL, pdb
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
from tqdm.auto import tqdm
import numpy as np
from PIL import Image
from mpl_toolkits.axes_grid1 import ImageGrid
import matplotlib.pyplot as plt

def show(tensor, num = 25, wandbactive = 0, name = ''):
  """Display the first `num` images of `tensor` as a grid with 5 images per row.

  Pixel values are clipped to [0, 1] for display. When `wandbactive` is 1 the
  same (clipped) grid is also logged to Weights & Biases under the key `name`.
  """
  images = tensor.detach().cpu()[:num]
  # make_grid returns channels-first; matplotlib wants height x width x channels.
  grid = make_grid(images, nrow=5).permute(1, 2, 0)

  # Optional W&B logging; `wandb` is looked up as a module-level global at call time.
  if wandbactive == 1:
    wandb.log({name: wandb.Image(grid.numpy().clip(0, 1))})

  plt.imshow(grid.clip(0, 1))
  plt.show()

# Hyperparameters and general parameters

n_epochs = 10000
batch_size = 128
lr = 1e-4
z_dim = 200
# Prefer the GPU, but fall back to the CPU so the notebook can still run
# (slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

cur_step = 0      # global step counter, incremented by the training loop
crit_cycles = 5   # critic updates per generator update
gen_losses = []
crit_losses = []
show_step = 35    # display samples / loss curves every `show_step` steps
save_step = 35    # write a checkpoint every `save_step` steps

wandbact = 1 # 1 = track training statistics with Weights & Biases, 0 = disabled


In [None]:
### Optional
!pip install wandb -qqq
import wandb
wandb.login(key = 'aeca1fda4422b59db5390af38009e15f0524d5f3')


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m195.4/195.4 kB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m258.5/258.5 kB[0m [31m17.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[?25h

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
%%capture
experiment_name = wandb.util.generate_id()

myrun=wandb.init(
    project="Wasserstein GAN",
    group = experiment_name,
    config={
        "optimizer":"adam",
        "model":"wgan gp",
        "epoch":"1000",
        "batch_size":128
    }
)

config = wandb.config

In [None]:
print(f"{experiment_name}")  # W&B run group id generated above

mohrho7l


In [None]:
# Generator Model

class Generator(nn.Module):
  """Convolutional generator mapping a latent vector to a 3 x 128 x 128 image.

  The latent vector is treated as a 1 x 1 image with `z_dim` channels and
  upsampled by six transposed convolutions (kernel size 4). The output width
  and height follow (n - 1) * stride - 2 * padding + kernel_size, giving
  1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128. Channel widths are multiples of
  `d_dim` (with d_dim=16: z_dim -> 512 -> 256 -> 128 -> 64 -> 32 -> 3).

  NOTE(review): the final Tanh yields values in [-1, 1], while the Dataset in
  this notebook scales real images to [0, 1] (divides by 255); confirm that
  mismatch is intended.

  Args:
    z_dim: number of latent channels (values per noise vector).
    d_dim: base channel width; hidden widths are multiples of it.
  """
  def __init__(self, z_dim=64, d_dim=16):
    super(Generator, self).__init__()
    self.z_dim = z_dim

    def up_block(c_in, c_out, stride, padding):
      # Transposed conv (kernel 4) -> BatchNorm -> in-place ReLU.
      return [
          nn.ConvTranspose2d(c_in, c_out, 4, stride, padding),
          nn.BatchNorm2d(c_out),
          nn.ReLU(True),
      ]

    widths = [d_dim * mult for mult in (32, 16, 8, 4, 2)]  # hidden channel widths
    layers = up_block(z_dim, widths[0], 1, 0)               # 1x1 -> 4x4
    for c_in, c_out in zip(widths, widths[1:]):             # 4x4 -> ... -> 64x64
      layers += up_block(c_in, c_out, 2, 1)
    layers += [
        nn.ConvTranspose2d(widths[-1], 3, 4, 2, 1),  # 64x64 -> 128x128, RGB output
        nn.Tanh(),                                   # squash output into [-1, 1]
    ]
    self.gen = nn.Sequential(*layers)

  def forward(self, noise):
    """Reshape `noise` to (batch, z_dim, 1, 1) and decode it into images."""
    latent = noise.view(len(noise), self.z_dim, 1, 1)
    return self.gen(latent)

def gen_noise(num, z_dim, device='cuda'):
  """Sample `num` standard-normal latent vectors of length `z_dim` on `device`."""
  shape = (num, z_dim)  # e.g. 128 x 200
  return torch.randn(*shape, device=device)


In [None]:
# Critic model

class Critic(nn.Module):
  """WGAN critic that assigns a real-valued score to each image in a batch.

  Five stride-2 convolutions (kernel 4, padding 1), each followed by
  InstanceNorm and LeakyReLU(0.2), halve the width/height:
  (n + 2 * pad - ks) / stride + 1, i.e. 128 -> 64 -> 32 -> 16 -> 8 -> 4 for
  128 x 128 inputs. Channels grow 3 -> d_dim -> ... -> d_dim * 16. A final
  4 x 4 convolution with no padding collapses the map to one score.

  Args:
    d_dim: base channel width; hidden widths are multiples of it.
  """
  def __init__(self, d_dim=16):
    super(Critic, self).__init__()

    layers = []
    c_in = 3
    for mult in (1, 2, 4, 8, 16):
      c_out = d_dim * mult
      layers.extend([
          nn.Conv2d(c_in, c_out, 4, 2, 1),  # halves width and height
          nn.InstanceNorm2d(c_out),
          nn.LeakyReLU(0.2),
      ])
      c_in = c_out
    layers.append(nn.Conv2d(c_in, 1, 4, 1, 0))  # 4x4 -> 1x1, one channel
    self.crit = nn.Sequential(*layers)

  def forward(self, image):
    """Return per-image scores flattened to (batch, -1); (batch, 1) for 128 x 128 input."""
    scores = self.crit(image)
    return scores.view(len(scores), -1)

In [None]:
### Optional, Initialize your weights in a different way

def init_weights(m):
  """Re-initialize conv, transposed-conv and batch-norm layers in place.

  Meant to be used as `model.apply(init_weights)`, which calls it on every
  submodule. Follows the DCGAN initialization recipe:
    * Conv2d / ConvTranspose2d: weight ~ N(0, 0.02), bias = 0
    * BatchNorm2d: scale (weight) ~ N(1, 0.02), shift (bias) = 0
  Other module types are left untouched.
  """
  if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
    torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if m.bias is not None:  # layers built with bias=False have no bias
      torch.nn.init.constant_(m.bias, 0)

  elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:  # skip affine=False
    # Centre the scale at 1 (identity). Drawing it from N(0, .) starts every
    # channel with a near-zero or sign-flipped scale that suppresses the signal.
    torch.nn.init.normal_(m.weight, 1.0, 0.02)
    torch.nn.init.constant_(m.bias, 0)

# gen=gen.apply(init_weights)
# crit=crit.apply(init_weights)

In [None]:
# Load Dataset

import gdown, zipfile

url = 'https://dl.dropboxusercontent.com/scl/fi/vltmt8hlgdf9mv9kn7d0b/img_align_celeba.zip?rlkey=tacwpkr8d9bjpctdftjg3b00a'
path = 'data/celeba'
download_path=f'{path}/img_align_celeba.zip'
extract_dir = os.path.join(path, 'img_align_celeba')  # folder the Dataset cell reads

os.makedirs(path, exist_ok=True)

# Skip the ~1.4 GB download when a valid archive is already present. A missing
# or truncated file (e.g. an interrupted download) is not a valid zip, so it
# is fetched again instead of being extracted.
if not zipfile.is_zipfile(download_path):
  gdown.download(url, download_path, quiet=False)

# Only extract once; re-running the cell should not rewrite every image.
if not os.path.isdir(extract_dir):
  with zipfile.ZipFile(download_path, 'r') as ziphandler:
    ziphandler.extractall(path)

Downloading...
From: https://dl.dropboxusercontent.com/scl/fi/vltmt8hlgdf9mv9kn7d0b/img_align_celeba.zip?rlkey=tacwpkr8d9bjpctdftjg3b00a
To: /content/data/celeba/img_align_celeba.zip
 13%|█▎        | 195M/1.44G [00:01<00:07, 157MB/s]

KeyboardInterrupt: 

In [None]:
# Dataset, DataLoader, declare Generator and Critic, test the Dataset

class Dataset(Dataset):
  """Image-folder dataset yielding (image_tensor, file_name) pairs.

  Each image is converted to RGB, resized to size x size and returned as a
  float32 tensor of shape 3 x size x size with values in [0, 1].

  NOTE(review): this class shadows the imported torch.utils.data.Dataset (its
  own base class); renaming it would also require updating `Dataset(...)` calls.

  Args:
    path: directory containing the image files.
    size: target width and height in pixels.
    lim: use at most this many files from `path`.
  """
  def __init__(self, path, size=128, lim=10000):
    self.sizes=[size,size]
    # Build the resize transform once instead of once per __getitem__ call.
    self.resize = torchvision.transforms.Resize(self.sizes)

    # os.listdir order is arbitrary; sort so the `lim` subset is reproducible.
    names = sorted(os.listdir(path))[:lim]
    self.items = [os.path.join(path, name) for name in names]
    self.labels = names

  def __len__(self):
    return len(self.items)

  def __getitem__(self,idx):
    # Close the file handle promptly; convert() returns a new, fully loaded
    # image that remains valid after the file is closed.
    with PIL.Image.open(self.items[idx]) as img:
      data = img.convert('RGB')
    data = np.asarray(self.resize(data))  # size x size x 3
    data = np.transpose(data, (2,0,1)).astype(np.float32, copy=False)  # 3 x size x size
    data = torch.from_numpy(data).div(255)  # scale to [0, 1]
    return data, self.labels[idx]

# Instantiate the Dataset: first 10,000 files of the extracted folder, 128 x 128
data_path = './data/celeba/img_align_celeba'
ds = Dataset(data_path, size=128, lim=10000)
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True)

# Models (both moved to `device`)
gen = Generator(z_dim).to(device)
crit = Critic().to(device)

# Optimizers: Adam with betas (0.5, 0.9) for both networks
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.9))
crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(0.5, 0.9))

# Optional custom initialization (disabled):
# gen=gen.apply(init_weights)
# crit=crit.apply(init_weights)

# Optional: let Weights & Biases watch both models (generator first, then critic)
if wandbact == 1:
  for net in (gen, crit):
    wandb.watch(net, log_freq=100)

# Sanity check: display one shuffled batch of real images
x, y = next(iter(dataloader))
show(x)

In [None]:
# Gradient Penalty Calculation

def get_gp(real, fake, crit, alpha, gamma=10):
  """Compute the WGAN-GP gradient penalty.

  Scores interpolations between real and fake samples and penalizes the
  deviation of the L2 norm of the critic's input gradient from 1.

  Args:
    real: batch of real images.
    fake: batch of generated images, same shape as `real`.
    crit: critic mapping a batch of images to per-sample scores.
    alpha: per-sample mixing weights broadcastable against `real`
      (the training loop passes shape (batch, 1, 1, 1) drawn from U[0, 1]).
      It does not need requires_grad=True.
    gamma: penalty coefficient.

  Returns:
    Scalar tensor gamma * mean((||grad||_2 - 1) ** 2). The graph is built with
    create_graph=True so the penalty can be back-propagated into `crit`.
  """
  mix_images = real * alpha + fake * (1-alpha) # e.g. 128 x 3 x 128 x 128
  if not mix_images.requires_grad:
    # Nothing upstream requires grad (e.g. alpha without requires_grad and
    # detached inputs): make the interpolated batch a differentiable leaf so
    # autograd.grad can take gradients with respect to it.
    mix_images.requires_grad_(True)
  mix_scores = crit(mix_images) # e.g. 128 x 1

  gradient = torch.autograd.grad(
      inputs = mix_images,
      outputs = mix_scores,
      grad_outputs = torch.ones_like(mix_scores),
      retain_graph = True,
      create_graph = True,
  )[0] # same shape as mix_images

  # reshape (unlike view) also handles non-contiguous gradient tensors.
  gradient = gradient.reshape(len(gradient), -1) # e.g. 128 x 49152
  gradient_norm = gradient.norm(2, dim=1)
  gp = gamma * ((gradient_norm - 1)**2).mean()

  return gp

In [None]:
# Save and Load Checkpoints

root_path='./data/'

def save_checkpoint(name, epoch=None):
  """Save generator and critic weights plus optimizer states under `root_path`.

  Writes f"{root_path}G-{name}.pkl" (generator) and f"{root_path}C-{name}.pkl"
  (critic). Uses the module-level `gen`, `gen_opt`, `crit` and `crit_opt`.

  Args:
    name: checkpoint tag, e.g. "latest".
    epoch: epoch number to store. Defaults to the module-level `epoch` set by
      the training loop, or None if no such variable exists yet.
  """
  if epoch is None:
    epoch = globals().get('epoch')
  os.makedirs(root_path, exist_ok=True)

  torch.save({
         'epoch': epoch,
         'model_state_dict': gen.state_dict(),
         'optimizer_state_dict': gen_opt.state_dict(),
  }, f"{root_path}G-{name}.pkl")

  torch.save({
         'epoch': epoch,
         'model_state_dict': crit.state_dict(),
         'optimizer_state_dict': crit_opt.state_dict(),
  }, f"{root_path}C-{name}.pkl")

  print("Saved checkpoint")

def load_checkpoint(name):
  """Restore generator/critic weights and optimizer states written by save_checkpoint.

  Tensors are mapped onto the notebook's `device`, so a checkpoint saved on a
  different device (e.g. a GPU) can still be loaded.

  Returns:
    The epoch stored in the generator checkpoint (None if absent).
  """
  checkpoint = torch.load(f"{root_path}G-{name}.pkl", map_location=device)
  gen.load_state_dict(checkpoint['model_state_dict'])
  gen_opt.load_state_dict(checkpoint['optimizer_state_dict'])
  epoch = checkpoint.get('epoch')

  checkpoint = torch.load(f"{root_path}C-{name}.pkl", map_location=device)
  crit.load_state_dict(checkpoint['model_state_dict'])
  crit_opt.load_state_dict(checkpoint['optimizer_state_dict'])

  print("Loaded checkpoint")
  return epoch


In [None]:
# Training Loop (WGAN-GP): `crit_cycles` critic updates per generator update

for epoch in range(n_epochs):
  for real, _ in tqdm(dataloader):
    cur_bs = len(real)
    real = real.to(device)

    # Critic
    mean_crit_loss = 0
    for _ in range(crit_cycles):
      crit_opt.zero_grad()

      # The critic update never back-propagates into the generator, so sample
      # fakes without building the generator's autograd graph.
      with torch.no_grad():
        noise = gen_noise(cur_bs, z_dim, device=device)
        fake = gen(noise)
      crit_fake_pred = crit(fake)
      crit_real_pred = crit(real)

      alpha = torch.rand(cur_bs, 1, 1, 1, device = device, requires_grad = True) # 128 x 1 x 1 x 1
      gp = get_gp(real, fake, crit, alpha)

      crit_loss = crit_fake_pred.mean() - crit_real_pred.mean() + gp

      mean_crit_loss += crit_loss.item() / crit_cycles

      # A fresh graph is built on every cycle, so there is no need to retain it.
      crit_loss.backward()
      crit_opt.step()

    crit_losses+=[mean_crit_loss]

    # Generator

    gen_opt.zero_grad()
    noise = gen_noise(cur_bs, z_dim, device=device)
    fake = gen(noise)
    crit_fake_pred = crit(fake)

    gen_loss = -crit_fake_pred.mean()
    gen_loss.backward()
    gen_opt.step()

    gen_losses+=[gen_loss.item()]

    # Statistics

    if (wandbact == 1):
      # Log plain Python floats rather than a graph-attached tensor.
      wandb.log({'Epoch': epoch, 'Step': cur_step, 'Critic Loss':mean_crit_loss, 'Generator Loss':gen_losses[-1]})

    if cur_step % save_step == 0 and cur_step > 0:
      print("Saving checkpoint: ", cur_step, save_step)
      save_checkpoint("latest")

    if (cur_step % show_step == 0 and cur_step > 0):
      # Only log images to W&B when tracking is enabled.
      show(fake, wandbactive=wandbact, name='fake')
      show(real, wandbactive=wandbact, name='real')

      gen_mean=sum(gen_losses[-show_step:]) / show_step
      crit_mean=sum(crit_losses[-show_step:]) / show_step
      print(f"Epoch: {epoch}: Step {cur_step}: Generator Loss:{gen_mean}, Critic Loss:{crit_mean}")

      plt.plot(
          range(len(gen_losses)),
          torch.Tensor(gen_losses),
          label="Generator Loss"
      )

      plt.plot(
          range(len(crit_losses)),
          torch.Tensor(crit_losses),
          label="Critic Loss"
      )

      plt.ylim(-1000,1000)
      plt.legend()
      plt.show()

    cur_step += 1

In [None]:
# Generate New Faces

# Sampling only: skip building the autograd graph to save (GPU) memory.
with torch.no_grad():
  noise = gen_noise(batch_size, z_dim, device=device)
  fake = gen(noise)
show(fake)

In [None]:
sample = fake[4].detach().cpu()  # fifth generated image, channels-first
plt.imshow(sample.permute(1, 2, 0).squeeze().clip(0, 1))

In [None]:
# Morphing (interpolation between points in latent space)

gen_set=[]
z_shape=[1, z_dim, 1, 1]  # one latent vector in the generator's 1x1 input layout
rows=4
steps=17

# Sampling only: without no_grad every image kept in gen_set would also keep
# its autograd graph alive.
with torch.no_grad():
  for i in range(rows):
    # Draw on the CPU RNG (as before), then move to the model's device.
    z1, z2 = torch.randn(z_shape), torch.randn(z_shape)
    for alpha in np.linspace(0, 1, steps):
      z = alpha*z1 + (1-alpha)*z2
      res=gen(z.to(device))[0]
      gen_set.append(res)

fig = plt.figure(figsize=(25,11))
grid = ImageGrid(fig, 111, nrows_ncols=(rows,steps), axes_pad = 0.1)

for ax, img in zip (grid, gen_set):
  ax.axis('off')
  res = img.cpu().detach().permute(1, 2, 0)
  # Min-max normalize to [0, 1] for display; clamp the range so a constant
  # image does not produce 0/0 (NaN).
  res = (res - res.min()) / (res.max() - res.min()).clamp_min(1e-8)
  ax.imshow(res)

