<a href="https://colab.research.google.com/github/EVA4-RS-Group/Phase2/blob/master/S6_GAN/EVA4_P2_S6_GenerativeAdversarialNetwork_R1GAN_v1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Implementation of GAN with R1 Regularizer
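Reference: L. Mescheder, A. Geiger, S. Nowozin, *Which Training Methods for GANs do actually Converge?* (ICML 2018) — https://arxiv.org/pdf/1801.04406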

In [None]:
# Uncomment and run the line below only when using Google Colab
# !pip install torch torchvision

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image

In [None]:
import numpy as np
import datetime
import os, sys
import glob
from tqdm import tqdm

In [None]:
from PIL import Image

In [None]:
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow, imsave
%matplotlib inline

In [None]:
MODEL_NAME = 'R1'  # prefix for the sample image file names written during training
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# (height, width, channels): Resize uses the first two entries, the networks use the last.
IMAGE_DIM = 64, 64, 3

In [None]:
def tensor2img(tensor):
    """Convert a rank-3 channel-first tensor into a channel-last numpy image.

    The tensor is detached and copied to the CPU, then axes are moved from
    (C, H, W) to (H, W, C), and values are mapped from [-1, 1] to [0, 1]
    (anything outside that range is clipped).
    """
    hwc = tensor.detach().cpu().numpy().transpose(1, 2, 0)
    return np.clip((hwc + 1) / 2.0, 0, 1)

In [None]:
def get_sample_image(G, n_noise=100, n_samples=64):
    """
        Generate fake images with G and return them tiled as one square grid image.

        Args:
            G: generator mapping a (n_samples, n_noise) noise batch to channel-first images.
            n_noise: length of each noise vector; must match G's input size.
            n_samples: number of noise vectors drawn from U[-1, 1]. Only the first
                int(sqrt(n_samples))**2 samples are placed in the grid, so pass a perfect
                square (e.g. 25 or 64) to show every sample.

        Returns:
            numpy array of shape (rows*H, rows*W, C) with values in [0, 1]. Sample k is
            placed at grid row k % rows, column k // rows (the grid is filled column by column).
    """
    n_rows = int(np.sqrt(n_samples))
    z = (torch.rand(size=[n_samples, n_noise])*2-1).to(DEVICE) # U[-1, 1]
    # Sampling only: no autograd graph through G is needed.
    with torch.no_grad():
        x_fake = G(z)
    # Inner cat stacks n_rows samples vertically (dim=1 is height); the outer cat puts
    # those columns side by side (dim=2 is width).
    columns = [torch.cat([x_fake[n_rows*j+i] for i in range(n_rows)], dim=1) for j in range(n_rows)]
    grid = torch.cat(columns, dim=2)
    return tensor2img(grid)

In [None]:
class ResidualBlock(nn.Module):
    """
        Two-convolution residual block: y = act(conv2(act(conv1(x)))) + shortcut(x),
        where act is LeakyReLU(0.2).

        Args:
            inplanes: number of input channels.
            planes: number of output channels.
            kernel_size: convolution kernel size; with the kernel_size//2 padding used here,
                odd sizes keep the spatial size unchanged at stride 1.
            stride: stride of the first convolution. The shortcut uses the same stride so
                both branches always have matching spatial sizes.
            downsample, groups: accepted for signature compatibility but unused.
    """
    def __init__(self, inplanes, planes, kernel_size=3, stride=1, downsample=None, groups=1):
        super(ResidualBlock, self).__init__()
        p = kernel_size//2
        self.conv1 = nn.Sequential(
            nn.Conv2d(inplanes, planes, kernel_size, stride=stride, padding=p),
            nn.LeakyReLU(0.2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(planes, planes, kernel_size, padding=p),
            nn.LeakyReLU(0.2)
        )
        # The shortcut must be projected whenever it cannot be the identity: a change in
        # channel count OR a strided first convolution. Without the strided 1x1 projection,
        # a stride > 1 block would add tensors of different spatial sizes.
        needs_proj = inplanes != planes or stride != 1
        self.proj = nn.Conv2d(inplanes, planes, 1, stride=stride) if needs_proj else None
    
    def forward(self, x):
        identity = x
        
        y = self.conv1(x)
        y = self.conv2(y)
        
        identity = identity if self.proj is None else self.proj(identity)
        y = y + identity
        return y

In [None]:
class Discriminator(nn.Module):
    """
        Convolutional Discriminator for 64x64 inputs.

        A 3x3 stem convolution is followed by four (ResidualBlock, AvgPool2d) stages that
        grow the channels 64 -> 128 -> 256 -> 512 -> 1024 and halve the spatial size each
        time (64 -> 4). A linear layer maps the flattened (1024, 4, 4) features to one
        unnormalized score per image.

        Args:
            in_channel: number of channels of the input images.
    """
    def __init__(self, in_channel=1):
        super(Discriminator, self).__init__()
        layers = [nn.Conv2d(in_channel, 64, 3, padding=1)]  # (N, 64, 64, 64)
        for c_in, c_out in [(64, 128), (128, 256), (256, 512), (512, 1024)]:
            layers.append(ResidualBlock(c_in, c_out))
            layers.append(nn.AvgPool2d(3, 2, padding=1))
        # Same module order (and therefore the same state_dict keys D.0, D.1, ...) as
        # listing every layer explicitly; ends at (N, 1024, 4, 4) for 64x64 inputs.
        self.D = nn.Sequential(*layers)
        self.fc = nn.Linear(1024*4*4, 1)  # (N, 1)

    def forward(self, x):
        features = self.D(x).view(x.size(0), -1)
        return self.fc(features)

In [None]:
class Generator(nn.Module):
    """
        Convolutional Generator producing 64x64 images.

        A linear layer maps each noise vector to a (1024, 4, 4) feature map. Four
        (ResidualBlock, 2x bilinear Upsample) stages grow it to (64, 64, 64), and a final
        ResidualBlock plus 3x3 convolution project to out_channel channels. No output
        activation (e.g. tanh) is applied, so pixel values are unbounded.

        Args:
            out_channel: number of channels of the generated images.
            n_filters: unused; kept for signature compatibility.
            n_noise: length of each input noise vector.
    """
    def __init__(self, out_channel=1, n_filters=128, n_noise=512):
        super(Generator, self).__init__()
        self.fc = nn.Linear(n_noise, 1024*4*4)
        # align_corners=False is what bilinear nn.Upsample already uses when the argument
        # is omitted; passing it explicitly gives identical outputs and silences the
        # UserWarning about the default.
        self.G = nn.Sequential(
            ResidualBlock(1024, 512),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), # (N, 512, 8, 8)
            ResidualBlock(512, 256),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), # (N, 256, 16, 16)
            ResidualBlock(256, 128),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), # (N, 128, 32, 32)
            ResidualBlock(128, 64),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), # (N, 64, 64, 64)
            ResidualBlock(64, 64),
            nn.Conv2d(64, out_channel, 3, padding=1) # (N, out_channel, 64, 64)
        )
        
    def forward(self, z):
        B = z.size(0)
        h = self.fc(z)
        h = h.view(B, 1024, 4, 4)  # reshape the flat projection into a 4x4 seed feature map
        x = self.G(h)
        return x

In [None]:
# Resize to IMAGE_DIM's (height, width), convert to a [0, 1] tensor, then map every
# channel to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((IMAGE_DIM[0], IMAGE_DIM[1])),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

Dataset: the aligned-and-cropped CelebA images (`img_align_celeba.zip`), available from the [CelebA project page](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).

In [None]:
# Mount Google Drive (Colab only) so the CelebA zip can be copied from it in the next cell.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!rm -rf img_align_celeba
!cp "/content/drive/My Drive/EVA4/Phase2/S6/Img/img_align_celeba.zip" /content
!unzip -q img_align_celeba.zip
!mv /content/img_align_celeba /content/train
!mkdir /content/img_align_celeba
!cp -r /content/train/ /content/img_align_celeba/
!rm -rf train

In [None]:
# ImageFolder treats every sub-directory of the root as a class; the labels are not used for training.
dataset = datasets.ImageFolder('/content/img_align_celeba', transform=transform)

In [None]:
batch_size, n_noise = 64, 256  # images per step, length of each noise vector

In [None]:
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,  # every step sees exactly batch_size images
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Both networks use IMAGE_DIM's channel count. D is constructed before G so parameter
# initialisation draws from the RNG in the same order.
n_channels = IMAGE_DIM[-1]
D = Discriminator(in_channel=n_channels).to(DEVICE)
G = Generator(out_channel=n_channels, n_noise=n_noise).to(DEVICE)

In [None]:
# Identical RMSprop settings for both networks (alpha=0.99 is also PyTorch's default).
rmsprop_kwargs = dict(lr=1e-4, alpha=0.99)
D_opt = torch.optim.RMSprop(D.parameters(), **rmsprop_kwargs)
G_opt = torch.optim.RMSprop(G.parameters(), **rmsprop_kwargs)

In [None]:
# D_scheduler = torch.optim.lr_scheduler.MultiStepLR(D_opt, milestones=[3, 10, 17], gamma=0.6)
# G_scheduler = torch.optim.lr_scheduler.MultiStepLR(G_opt, milestones=[3, 10, 17], gamma=0.6)

In [None]:
# criterion = nn.L1Loss()
def r1loss(inputs, label=None):
    """Non-saturating logistic GAN loss on raw discriminator scores.

    Despite the name, the R1 gradient penalty is not part of this function.

    Args:
        inputs: unnormalized discriminator outputs (logits).
        label: truthy to score the inputs as "real", giving mean(softplus(-inputs)),
            i.e. mean(-log sigmoid(inputs)); falsy (including the default None) to score
            them as "fake", giving mean(softplus(inputs)), i.e. mean(-log(1 - sigmoid(inputs))).

    Returns:
        Scalar tensor holding the mean loss.
    """
    signed = -inputs if label else inputs
    return F.softplus(signed).mean()

In [None]:
max_epoch, step = 20, 0  # number of epochs; global optimisation-step counter

In [None]:
# Every 1000 steps: print a log line and save a sample grid, and write a checkpoint.
log_term = save_term = 1000

In [None]:
# R1 penalty weight gamma: the training loop adds 0.5 * r1_gamma * E[||grad_x D(x)||^2] on real images.
r1_gamma = 10

In [None]:
n_images = len(data_loader.dataset)
steps_per_epoch = n_images // batch_size  # full batches only (the loader uses drop_last=True)
steps_per_epoch

3165

In [None]:
# Output folders for sample grids and checkpoints. exist_ok=True keeps the cell re-runnable
# without the check-then-create race of os.path.exists followed by os.makedirs.
for out_dir in ('samples', 'ckpt'):
    os.makedirs(out_dir, exist_ok=True)

In [None]:
def save_checkpoint(state, file_name='checkpoint.pth.tar'):
    """Serialize `state` (typically a dict of state_dicts and counters) to `file_name` with torch.save."""
    torch.save(obj=state, f=file_name)

In [None]:
for epoch in range(max_epoch):
    # ImageFolder labels are unused: the GAN is unconditional.
    for idx, (images, labels) in enumerate(tqdm(data_loader, total=len(data_loader))):
        # ---- Discriminator step ----------------------------------------------------
        # Non-saturating loss on real and fake batches plus the R1 penalty
        # (r1_gamma / 2) * E[||grad_x D(x)||^2], evaluated on real images only.
        x = images.to(DEVICE)
        x.requires_grad_(True)  # needed to differentiate D(x) w.r.t. the real images
        x_outputs = D(x)
        d_real_loss = r1loss(x_outputs, True)
        # Reference >> https://github.com/rosinality/style-based-gan-pytorch/blob/a3d000e707b70d1a5fc277912dc9d7432d6e6069/train.py
        # little different with original DiracGAN
        # create_graph=True so the penalty itself can be backpropagated into D's weights.
        grad_real = grad(outputs=x_outputs.sum(), inputs=x, create_graph=True)[0]
        grad_penalty = (grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean()
        grad_penalty = 0.5*r1_gamma*grad_penalty
        D_x_loss = d_real_loss + grad_penalty

        z = (torch.rand(size=[batch_size, n_noise])*2-1).to(DEVICE)  # U[-1, 1]
        # The D step never backpropagates into G, so its fakes are generated without
        # recording an autograd graph through G (same values as G(z).detach()).
        with torch.no_grad():
            x_fake = G(z)
        z_outputs = D(x_fake)
        D_z_loss = r1loss(z_outputs, False)
        D_loss = D_x_loss + D_z_loss
        
        D.zero_grad()
        D_loss.backward()
        D_opt.step()

        # ---- Generator step --------------------------------------------------------
        z = (torch.rand(size=[batch_size, n_noise])*2-1).to(DEVICE)
        x_fake = G(z)
        z_outputs = D(x_fake)
        G_loss = r1loss(z_outputs, True)  # non-saturating: push D(G(z)) toward "real"
        
        # G_loss.backward() also leaves gradients in D's parameters; that is harmless
        # because D.zero_grad() runs before the next discriminator update.
        G.zero_grad()
        G_loss.backward()
        G_opt.step()
        
        if step % save_term == 0:
            save_checkpoint({'global_step': step,
                 'D':D.state_dict(),
                 'G':G.state_dict(),
                 'd_optim': D_opt.state_dict(),
                 'g_optim' : G_opt.state_dict()},
                'ckpt/r1gan{:06d}.pth.tar'.format(step))
        
        if step % log_term == 0:
            dt = datetime.datetime.now().strftime('%H:%M:%S')
            print('Epoch: {}/{}, Step: {}, D Loss: {:.4f}, G Loss: {:.4f}, gp: {:.4f}, Time:{}'.format(epoch, max_epoch, step, D_loss.item(), G_loss.item(), grad_penalty.item(), dt))
            # eval()/train() is a no-op for the current G (no BatchNorm/Dropout), but keeps
            # sampling correct if such layers are added later.
            G.eval()
            img = get_sample_image(G, n_noise, n_samples=25)
            imsave('samples/{}_step{:06d}.jpg'.format(MODEL_NAME, step), img)
            G.train()
        
        step += 1
#     D_scheduler.step()
#     G_scheduler.step()

  "See the documentation of nn.Upsample for details.".format(mode))
  0%|          | 1/3165 [00:03<2:45:08,  3.13s/it]

Epoch: 0/20, Step: 0, D Loss: 1.3818, G Loss: 0.2489, gp: 0.0000, Time:04:41:05


 32%|███▏      | 1001/3165 [22:33<1:01:31,  1.71s/it]

Epoch: 0/20, Step: 1000, D Loss: 1.3570, G Loss: 0.6873, gp: 0.1113, Time:05:03:36


 63%|██████▎   | 2001/3165 [45:10<33:45,  1.74s/it]

Epoch: 0/20, Step: 2000, D Loss: 1.1044, G Loss: 0.7396, gp: 0.1250, Time:05:26:12


 95%|█████████▍| 3001/3165 [1:07:47<04:40,  1.71s/it]

Epoch: 0/20, Step: 3000, D Loss: 1.3055, G Loss: 0.6895, gp: 0.0714, Time:05:48:50


100%|██████████| 3165/3165 [1:11:29<00:00,  1.36s/it]
 26%|██▋       | 836/3165 [18:53<1:08:00,  1.75s/it]

Epoch: 1/20, Step: 4000, D Loss: 1.4147, G Loss: 0.9737, gp: 0.0211, Time:06:11:25


 58%|█████▊    | 1836/3165 [41:28<38:10,  1.72s/it]

Epoch: 1/20, Step: 5000, D Loss: 1.3167, G Loss: 0.7888, gp: 0.0574, Time:06:34:01


 90%|████████▉ | 2836/3165 [1:04:05<09:33,  1.74s/it]

Epoch: 1/20, Step: 6000, D Loss: 1.3393, G Loss: 0.7623, gp: 0.0543, Time:06:56:37


100%|██████████| 3165/3165 [1:11:30<00:00,  1.36s/it]
 21%|██        | 671/3165 [15:10<1:12:47,  1.75s/it]

Epoch: 2/20, Step: 7000, D Loss: 1.3101, G Loss: 0.8326, gp: 0.0585, Time:07:19:14


 53%|█████▎    | 1671/3165 [37:46<43:03,  1.73s/it]

Epoch: 2/20, Step: 8000, D Loss: 1.3726, G Loss: 0.8579, gp: 0.0446, Time:07:41:49


 84%|████████▍ | 2671/3165 [1:00:23<14:08,  1.72s/it]

Epoch: 2/20, Step: 9000, D Loss: 1.3652, G Loss: 0.9076, gp: 0.0406, Time:08:04:26


100%|██████████| 3165/3165 [1:11:33<00:00,  1.36s/it]
 16%|█▌        | 506/3165 [11:27<1:17:16,  1.74s/it]

Epoch: 3/20, Step: 10000, D Loss: 1.3275, G Loss: 0.8380, gp: 0.0401, Time:08:27:04


 48%|████▊     | 1506/3165 [34:02<47:20,  1.71s/it]

Epoch: 3/20, Step: 11000, D Loss: 1.3407, G Loss: 0.7436, gp: 0.0384, Time:08:49:39


 79%|███████▉  | 2506/3165 [56:37<18:50,  1.72s/it]

Epoch: 3/20, Step: 12000, D Loss: 1.3232, G Loss: 0.7574, gp: 0.0369, Time:09:12:14


100%|██████████| 3165/3165 [1:11:31<00:00,  1.36s/it]
 11%|█         | 341/3165 [07:43<1:22:15,  1.75s/it]

Epoch: 4/20, Step: 13000, D Loss: 1.3224, G Loss: 0.7754, gp: 0.0335, Time:09:34:51


 42%|████▏     | 1341/3165 [30:20<52:00,  1.71s/it]

Epoch: 4/20, Step: 14000, D Loss: 1.3552, G Loss: 0.6814, gp: 0.0332, Time:09:57:28


 74%|███████▍  | 2341/3165 [52:56<23:43,  1.73s/it]

Epoch: 4/20, Step: 15000, D Loss: 1.3790, G Loss: 0.7045, gp: 0.0350, Time:10:20:05


100%|██████████| 3165/3165 [1:11:34<00:00,  1.36s/it]
  6%|▌         | 176/3165 [04:00<1:27:11,  1.75s/it]

Epoch: 5/20, Step: 16000, D Loss: 1.3481, G Loss: 0.8187, gp: 0.0330, Time:10:42:43


 37%|███▋      | 1176/3165 [26:37<57:54,  1.75s/it]

Epoch: 5/20, Step: 17000, D Loss: 1.3548, G Loss: 0.7301, gp: 0.0300, Time:11:05:20


 69%|██████▉   | 2176/3165 [49:14<28:39,  1.74s/it]

Epoch: 5/20, Step: 18000, D Loss: 1.3535, G Loss: 0.7309, gp: 0.0275, Time:11:27:57


100%|██████████| 3165/3165 [1:11:34<00:00,  1.36s/it]
  0%|          | 11/3165 [00:16<1:32:05,  1.75s/it]

Epoch: 6/20, Step: 19000, D Loss: 1.3519, G Loss: 0.7394, gp: 0.0293, Time:11:50:33


 32%|███▏      | 1011/3165 [22:51<1:02:57,  1.75s/it]

Epoch: 6/20, Step: 20000, D Loss: 1.3539, G Loss: 0.8208, gp: 0.0277, Time:12:13:09


 64%|██████▎   | 2011/3165 [45:26<33:38,  1.75s/it]

Epoch: 6/20, Step: 21000, D Loss: 1.3624, G Loss: 0.7708, gp: 0.0266, Time:12:35:44


 95%|█████████▌| 3011/3165 [1:08:03<04:28,  1.74s/it]

Epoch: 6/20, Step: 22000, D Loss: 1.3501, G Loss: 0.7151, gp: 0.0279, Time:12:58:20


100%|██████████| 3165/3165 [1:11:31<00:00,  1.36s/it]
 27%|██▋       | 846/3165 [19:08<1:08:18,  1.77s/it]

Epoch: 7/20, Step: 23000, D Loss: 1.3499, G Loss: 0.7333, gp: 0.0253, Time:13:20:57


 58%|█████▊    | 1846/3165 [41:44<38:02,  1.73s/it]

Epoch: 7/20, Step: 24000, D Loss: 1.3501, G Loss: 0.7451, gp: 0.0274, Time:13:43:33


 90%|████████▉ | 2846/3165 [1:04:20<09:14,  1.74s/it]

Epoch: 7/20, Step: 25000, D Loss: 1.3585, G Loss: 0.7017, gp: 0.0249, Time:14:06:09


100%|██████████| 3165/3165 [1:11:33<00:00,  1.36s/it]
 22%|██▏       | 681/3165 [15:25<1:12:10,  1.74s/it]

Epoch: 8/20, Step: 26000, D Loss: 1.3633, G Loss: 0.6821, gp: 0.0227, Time:14:28:47


 32%|███▏      | 999/3165 [22:36<48:56,  1.36s/it]

In [None]:
# Save the final state too: the training loop only checkpoints every save_term steps.
final_state = {
    'global_step': step,
    'D': D.state_dict(),
    'G': G.state_dict(),
    'd_optim': D_opt.state_dict(),
    'g_optim': G_opt.state_dict(),
}
save_checkpoint(final_state, 'ckpt/r1gan{:06d}.pth.tar'.format(step))

### Random Sample

In [None]:
# Restore G from the most recent checkpoint. This notebook names checkpoints with the step
# zero-padded to 6 digits, so lexicographic order matches step order (up to 999,999 steps).
ckpt_paths = sorted(glob.glob(os.path.join('ckpt', '*.pth.tar')))
if not ckpt_paths:
    raise FileNotFoundError("no '*.pth.tar' checkpoints found in 'ckpt/'")
G_path = ckpt_paths[-1]
# map_location lets a checkpoint written on GPU be restored on a CPU-only runtime.
state = torch.load(G_path, map_location=DEVICE)
G.load_state_dict(state['G'])

In [None]:
# Switch G to inference mode (same as G.eval()); the trailing semicolon hides the module repr.
G.train(False);

In [None]:
img = get_sample_image(G, n_noise=n_noise, n_samples=25)  # 5 x 5 grid of fakes
imshow(img)

In [None]:
# Fake Image: crop the tile at grid row idx[0], column idx[1] out of the sample grid
idx = [3, 1]
row = idx[0] * IMAGE_DIM[0]
col = idx[1] * IMAGE_DIM[1]
imshow(img[row:row + IMAGE_DIM[0], col:col + IMAGE_DIM[1]])

In [None]:
# Real Image
# Draw a fresh batch instead of reusing `images` leaked from the training loop, so this
# cell also works after a kernel restart where training was skipped (e.g. G restored
# from a checkpoint). Requires i < batch_size.
i = 14
real_images, real_labels = next(iter(data_loader))
rimg = tensor2img(real_images[i])
imshow(rimg)

### Interpolation

In [None]:
def sample_noise(size=None):
    """Draw values uniformly from [-1, 1) using NumPy's global RNG.

    `size` follows NumPy's convention: None yields a single float, while an int or
    tuple yields an array of that shape.
    """
    return 2 * np.random.random(size=size) - 1

In [None]:
# 10 evenly spaced latent vectors on the straight line from z_a to z_b (endpoints included).
z_a, z_b = sample_noise(n_noise), sample_noise(n_noise)
# np.linspace broadcasts over array endpoints and returns shape (10, n_noise) directly,
# avoiding the very slow torch.tensor(list of ndarrays) conversion and the transpose.
zs = torch.tensor(np.linspace(z_a, z_b, num=10), dtype=torch.float32).to(DEVICE)
zs.shape

In [None]:
# Generate the 10 interpolated images and lay them out left to right (dim=-1 is width).
frames = G(zs).detach()
imgs = torch.cat(frames[:10].unbind(0), dim=-1)
imgs.shape

In [None]:
strip = tensor2img(imgs)  # channel-last numpy strip of the interpolated images
fig = plt.figure(figsize=(20, 3))
imshow(strip)