# **Generative Adversarial Networks (GANs)**
<img align='right' width='800' src="https://cdn-images-1.medium.com/v2/resize:fit:851/0*pPEL7ryJR51VpnDO.jpg">

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import os
from PIL import Image

import scipy

import torch
from torch import nn
import torchvision
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torchvision.datasets import MNIST, FashionMNIST
from torchvision import transforms

## **Working with Kaggle**
- First, go to your Kaggle profile and create an API token; this downloads `kaggle.json` to your computer.
- Drag and drop the JSON file into the Colab *Files* tab.
- Run the commands below:
    - ```
    !mkdir /root/.kaggle
    !mv kaggle.json /root/.kaggle/kaggle.json
    !chmod 600 /root/.kaggle/kaggle.json
    !kaggle datasets list
    ```
- To download a dataset to Colab, open the desired competition (or dataset) page, click the three-dot menu in the upper-right corner and choose `Copy API command`.

- For more information, [click here](https://www.kaggle.com/discussions/general/74235).

In [14]:
# !rm -r /root/.kaggle
!mkdir /root/.kaggle
!mv kaggle.json /root/.kaggle/kaggle.json
!chmod 600 /root/.kaggle/kaggle.json

mkdir: cannot create directory ‘/root/.kaggle’: File exists


In [15]:
# Download the CelebA dataset archive (about 1.3 GB) from Kaggle into the current working directory.
!kaggle datasets download -d jessicali9530/celeba-dataset

Downloading celeba-dataset.zip to /content
100% 1.33G/1.33G [00:35<00:00, 42.7MB/s]
100% 1.33G/1.33G [00:35<00:00, 39.7MB/s]


In [None]:
!unzip /content/celeba-dataset.zip

In [5]:
# Hyperparameters
EPOCH = 20            # number of training epochs
LR = 2e-4             # learning rate
BS = 32               # mini-batch size
C, H, W = 3, 24, 24   # default (channels, height, width) used when displaying images
# Train on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Available device is: ", device)

Availabe device is:  cuda


In [6]:
# Visioalize the data
def show(tensor, ch=C, size=(H, W), num_to_display=16):
    """Plot the first ``num_to_display`` images of a batch as a 4-column grid.

    Args:
        tensor: image batch; it is reshaped to ``(-1, ch, *size)``, so its
            element count must be a multiple of ``ch * prod(size)``.
        ch: channels per image.
        size: spatial size ``(height, width)`` of each image.
        num_to_display: maximum number of images placed in the grid.

    The tensor is detached (plotting needs no gradients) and moved to the CPU.
    ``make_grid(normalize=True)`` rescales the grid to [0, 1]. The grid is then
    permuted from ``(C, H, W)`` to the ``(H, W, C)`` layout expected by
    ``plt.imshow``.
    """
    batch = tensor.detach().cpu().view(-1, ch, *size)
    selected = batch[:num_to_display]
    grid_chw = make_grid(selected, nrow=4, normalize=True)
    grid_hwc = grid_chw.permute(1, 2, 0)

    plt.axis(False)
    plt.imshow(grid_hwc)
    plt.show()

## **ESRGAN architecture**

<img align='center' width='1200' src="https://esrgan.readthedocs.io/en/latest/_images/architecture.png">

In [7]:
class convBlock(nn.Module):
    """Conv2d -> optional BatchNorm2d -> optional LeakyReLU(0.2).

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        use_act: append ``LeakyReLU(0.2, inplace=True)`` when True, identity otherwise.
        use_bn: insert ``BatchNorm2d`` after the convolution when True.
        discriminator: accepted for backward compatibility. Whether BatchNorm
            is used is decided by ``use_bn`` only.
        **kw: forwarded to ``nn.Conv2d`` (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_channel, out_channel, use_act=True, use_bn=False, discriminator=False, **kw):
        super().__init__()
        # A conv bias is redundant right before BatchNorm, whose learnable
        # shift replaces it after the per-channel mean is subtracted.
        self.conv = nn.Conv2d(in_channel, out_channel, **kw, bias=not use_bn)
        # Keyed on `use_bn` (not `discriminator`) so callers such as the
        # discriminator's first layer can actually opt out of BatchNorm.
        self.bn = nn.BatchNorm2d(out_channel) if use_bn else nn.Identity()
        self.act = nn.LeakyReLU(0.2, inplace=True) if use_act else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

class denseBlock(nn.Module):
    """Residual dense block made of five densely connected 3x3 ``convBlock``s.

    Layer ``k`` receives the block input concatenated with the outputs of all
    earlier layers, i.e. ``in_channel + k * channels`` channels.
    - Layers 0-3 emit ``channels`` maps with LeakyReLU.
    - Layer 4 maps back to ``in_channel`` without activation.
    The block returns ``beta * last_layer_output + x`` (residual scaling).

    Note: the ``use_act`` argument and ``**kw`` are accepted but not used.
    """

    def __init__(self, in_channel, channels=32, use_act=True, beta=0.2, **kw):
        super().__init__()
        self.beta = beta
        self.conv = nn.ModuleList()
        final_idx = 4
        for k in range(final_idx + 1):
            is_final = k == final_idx
            self.conv.append(
                convBlock(
                    in_channel + k * channels,
                    in_channel if is_final else channels,
                    not is_final,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            )

    def forward(self, x):
        stacked = x
        for layer in self.conv:
            latest = layer(stacked)
            stacked = torch.cat([stacked, latest], dim=1)
        # Scale the residual branch before adding it back to the input.
        return self.beta * latest + x

class basicBlock(nn.Module):
    """Residual-in-residual block: three stacked ``denseBlock``s plus a scaled skip.

    Returns ``beta * dense_stack(x) + x``; the channel count is unchanged.
    """

    def __init__(self, in_channel, beta=0.2):
        super().__init__()
        self.beta = beta
        stages = [denseBlock(in_channel) for _ in range(3)]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.conv(x)
        return self.beta * residual + x


class upSample(nn.Module):
    """Nearest-neighbour upsampling by ``up_factor``, then a 3x3 conv and LeakyReLU(0.2).

    The channel count is preserved; height and width are multiplied by ``up_factor``.
    """

    def __init__(self, in_channel, up_factor=2):
        super().__init__()
        # nn.Upsample also accepts 'linear', 'bilinear', 'bicubic' and 'trilinear'.
        self.upsample = nn.Upsample(scale_factor=up_factor, mode='nearest')
        self.conv = nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        enlarged = self.upsample(x)
        refined = self.conv(enlarged)
        return self.act(refined)

In [8]:
# Generator
class Generator(nn.Module):
    """ESRGAN-style generator that enlarges its input 4x.

    The pipeline is:
    1. a 3x3 ``convBlock`` (with LeakyReLU);
    2. ``n_block`` residual-in-residual blocks;
    3. a 3x3 conv whose output is added to the first conv's output (long skip);
    4. two 2x ``upSample`` stages;
    5. a 3x3 ``convBlock`` and a final 3x3 conv back to ``in_channel``.

    A (N, 3, 24, 24) input yields a (N, 3, 96, 96) output.

    NOTE(review): reference ESRGAN implementations apply no activation after the
    first conv, while ``convBlock`` here adds LeakyReLU by default — confirm intent.
    """

    def __init__(self, in_channel=3, h_channel=64, n_block=23):
        super().__init__()
        conv3x3 = dict(kernel_size=3, stride=1, padding=1)
        self.initial = convBlock(in_channel, h_channel, **conv3x3)
        self.resblocks = nn.Sequential(*(basicBlock(h_channel) for _ in range(n_block)))
        self.conv = nn.Conv2d(h_channel, h_channel, **conv3x3)
        self.up_sample = nn.Sequential(upSample(h_channel), upSample(h_channel))
        self.final = nn.Sequential(
            convBlock(h_channel, h_channel, **conv3x3),
            nn.Conv2d(h_channel, in_channel, **conv3x3),
        )

    def forward(self, x):
        shallow = self.initial(x)
        deep = self.conv(self.resblocks(shallow))
        upscaled = self.up_sample(deep + shallow)
        return self.final(upscaled)


# Discriminator
class Discriminator(nn.Module):
    """SRGAN/ESRGAN-style discriminator producing one real/fake probability per image.

    One 3x3 ``convBlock`` per entry of ``h_feature`` is applied, with:
    - stride 1 on even indices and stride 2 on odd ones;
    - BatchNorm on every block except the first.

    The head is adaptive 6x6 average pooling -> Flatten -> Linear(1024)
    -> LeakyReLU -> Linear(1) -> Sigmoid.

    Args:
        in_c: number of input image channels.
        h_feature: output channels of each conv block.

    Pooling to a fixed 6x6 map makes the head independent of the input
    resolution and of ``h_feature[-1]``. For a 96x96 input with the default
    four stride-2 blocks the map is already 6x6, so the pooling is the identity.
    """

    def __init__(self, in_c=3, h_feature=(64, 64, 128, 128, 256, 256, 512, 512)):
        super().__init__()
        conv_blocks = []
        for i, feature in enumerate(h_feature):
            conv_blocks.append(
                convBlock(
                    in_c,
                    feature,
                    use_bn=i != 0,
                    discriminator=True,
                    kernel_size=3,
                    stride=1 if i % 2 == 0 else 2,
                    padding=1,
                )
            )
            in_c = feature

        self.final = nn.Sequential(
            *conv_blocks,
            # Pooling and flattening share one parameter-free slot so the
            # Linear layers keep their indices (and state_dict keys).
            nn.Sequential(nn.AdaptiveAvgPool2d((6, 6)), nn.Flatten()),
            nn.Linear(h_feature[-1] * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.final(x)



def test_models(device=device):
    """Smoke test: push a random 24x24 batch through Generator and Discriminator.

    Prints the input, generator-output and discriminator-output shapes and
    returns a confirmation string when both forward passes succeed.
    """
    sample = torch.rand(10, 3, 24, 24).to(device)
    upscaled = Generator().to(device)(sample)
    scores = Discriminator().to(device)(upscaled)
    report = (
        f" --> input shape: {sample.shape}\n"
        f" --> Generator output size: {upscaled.shape}\n"
        f" --> Discriminator output size: {scores.shape}\n"
    )
    print(report)
    return "Every thing is O.K"


test_models()

 --> input shape: torch.Size([10, 3, 24, 24])
 --> Generator output size: torch.Size([10, 3, 96, 96])
 --> Discriminator output size: torch.Size([10, 1])



'Every thing is O.K'

In [9]:
# def weights_init(m):
#     classname = m.__class__.__name__
#     if classname.find('Conv') != -1:
#         nn.init.normal_(m.weight.data, 0.0, 0.02)
#     elif classname.find('Line') != -1:
#         nn.init.normal_(m.weight.data, 1.0, 0.02)
#         nn.init.constant_(m.bias.data, 0)

def weights_init(m, scale=0.1):
    """Scaled Kaiming-normal initialisation for convolution and linear layers.

    Meant for ``model.apply(weights_init)``. The weight of every ``nn.Conv*d``,
    ``nn.ConvTranspose*d`` and ``nn.Linear`` module is redrawn with
    ``nn.init.kaiming_normal_`` and multiplied by ``scale``. Biases and all
    other modules are left untouched.

    Args:
        m: module visited by ``nn.Module.apply``.
        scale: factor applied to the freshly drawn weights.
    """
    # isinstance instead of class-name substrings: a test such as
    # `'Conv' in name` also matches wrapper modules (e.g. a class named
    # "ConvBlock") that have no `.weight` and would raise AttributeError.
    layer_types = (
        nn.Conv1d, nn.Conv2d, nn.Conv3d,
        nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
        nn.Linear,
    )
    if isinstance(m, layer_types):
        with torch.no_grad():
            nn.init.kaiming_normal_(m.weight)
            m.weight.mul_(scale)

## **Perceptual Loss**

The ESRGAN authors also develop a more effective perceptual loss $\mathcal{L}_{percep}$ by constraining features before activation rather than after activation, as was done in SRGAN.

$$
    \large \mathcal {L}_{G} = \mathcal {L}_{percep} + \lambda \mathcal {L}^{Ra}_{Gen} + \eta \mathcal {L}_{1}
$$

<br>

$\hspace{20pt} \lambda = 5*10^{-3}$

$\hspace{20pt} \eta = 10^{-2}$

<br>

- **Perceptual Loss:**
Instead of comparing pixels, SRGAN's authors compute the loss on feature maps of several VGG layers. The pre-trained 19-layer VGG network serves as a fixed feature extractor, and the VGG loss is the Euclidean distance between the feature representations of the reconstructed image and of the reference image.

$$
    \large \mathcal {L}_{percep} = \frac{1}{W_{i,j}H_{i,j}} \sum^{W_{i,j}}_{x=1} \sum^{H_{i,j}}_{y=1} \left(\phi_{i,j}(I^{HR})_{x,y} - \phi_{i,j}(G(I^{LR}))_{x,y} \right)^2
$$

Here $\phi_{i,j}$ denotes the feature map obtained by the $j$-th convolution before the $i$-th max-pooling layer of the VGG19 network. SRGAN takes this feature map after the activation, whereas ESRGAN takes it before the activation.

Here $W_{i,j}$ and $H_{i,j}$ describe the dimensions of the respective feature maps within the VGG network.

<br>

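- **Adversarial Loss (Relativistic):**

- $C(\cdot)$ denotes the raw (pre-sigmoid) discriminator output and $\sigma$ the sigmoid. The relativistic discriminator estimates how much more realistic an image is than the *average* image of the other kind.

- For real images (label 1):

$$
    \large D_{Ra}(x_r, x_f) = \sigma\left(C(I^{HR}_{x,y}) - \mathbb{E}\left[C(G(I^{LR}_{x,y}))\right]\right)
$$

<br>

- For generated images (label 0):

$$
    \large D_{Ra}(x_f, x_r) = \sigma\left(C(G(I^{LR}_{x,y})) - \mathbb{E}\left[C(I^{HR}_{x,y})\right]\right)
$$

<br>

The expectations are approximated by mini-batch means:

$
\begin{align*}
    \hspace{20pt} \mathbb{E}\left[C(G(I^{LR}_{x,y}))\right] &=  \overline{C(G(I^{LR}_{x,y}))} \\
    \hspace{20pt} \mathbb{E}\left[C(I^{HR}_{x,y})\right] &=  \overline{C(I^{HR}_{x,y})}
\end{align*}
$

The discriminator is trained with binary cross-entropy on these outputs:

$$
    \large \mathcal{L}^{Ra}_{D} = -\frac{1}{n}\sum_{i=1}^{n}\log\left(D_{Ra}(x_{r,i}, x_f)\right) - \frac{1}{n}\sum_{i=1}^{n}\log\left(1 - D_{Ra}(x_{f,i}, x_r)\right)
$$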

<br>

$$
    \large \mathcal {L}^{Ra}_{Gen} = -\mathbb{E}_{x_r}\left[\log\left(1 - D_{Ra}(x_r, x_f)\right)\right] - \mathbb{E}_{x_f}\left[\log\left(D_{Ra}(x_f, x_r)\right)\right]
$$

- **Content loss:**
 evaluates the L1 distance between the recovered image $G(I^{LR}_{x,y})$ and the ground truth $I^{HR}_{x,y}$:

<br>

$$
    \large \mathcal {L}_{1} = \mathbb{E}_{x_i} \|G(I^{LR}_{x,y}) - I^{HR}_{x,y}\|_1
$$

<br>

[SRGAN paper](https://arxiv.org/abs/1609.04802) · [ESRGAN paper](https://arxiv.org/abs/1809.00219)

In [10]:
class vggPartial(nn.Module):
    """Frozen VGG19 feature extractor returning conv5_4 features *before* its ReLU.

    Only ``features[:35]`` of torchvision's VGG19 is kept. ``features[34]`` is
    the last 3x3 convolution of block 5, so the trailing ReLU, the max-pool and
    the classifier are never evaluated.

    A forward hook on ``features[34]`` of the full network does not achieve
    this. The next layer, ``features[35]``, is ``ReLU(inplace=True)`` and
    overwrites the hooked tensor in place, which yields post-activation
    features. ESRGAN's perceptual loss uses pre-activation features.

    Args:
        device: device the VGG weights are moved to.
    """

    def __init__(self, device=device):
        super().__init__()
        vgg = torchvision.models.vgg19(weights='VGG19_Weights.IMAGENET1K_V1')
        self.vgg = vgg.features[:35].to(device).eval()
        # Fixed feature extractor: its weights get no gradients. Gradients
        # still flow through it to the input image.
        for param in self.vgg.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        return self.vgg(x)

class vggLoss(nn.Module):
    """Perceptual loss: mean squared error between ``vggPartial`` features of two images.

    Args:
        device: device of the VGG feature extractor (forwarded to ``vggPartial``).
    """

    def __init__(self, device=device):
        super().__init__()
        self.vgg = vggPartial(device)

    def forward(self, x, y):
        """Return ``F.mse_loss(vgg(x), vgg(y))`` for image batches ``x`` and ``y``."""
        return F.mse_loss(self.vgg(x), self.vgg(y))


def test_vggLoss(device=device):
    """Smoke test: perceptual loss between a generated 96x96 image and a random 96x96 target."""
    criterion = vggLoss()
    low_res = torch.rand(1, 3, 24, 24).to(device)
    target = torch.rand(1, 3, 96, 96).to(device)
    upscaled = Generator().to(device)(low_res)
    return criterion(upscaled, target)


test_vggLoss()

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:09<00:00, 61.6MB/s]


tensor(0.1114, device='cuda:0', grad_fn=<MseLossBackward0>)

In [11]:
# torch.cuda.empty_cache()
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "caching_allocator"
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb=512'
# !export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512

In [12]:
_GEN_VGG_LOSS = None


def _get_vgg_loss():
    """Return a shared, lazily built ``vggLoss`` whose parameters are frozen.

    Building ``vggLoss`` loads the VGG19 weights, so it is done once rather
    than on every generator step.
    """
    global _GEN_VGG_LOSS
    if _GEN_VGG_LOSS is None:
        _GEN_VGG_LOSS = vggLoss().requires_grad_(False)
    return _GEN_VGG_LOSS


def gen_loss_func(gen_net, disc_net, loss_func, l_res, h_res, mode=None):
    """Generator loss.

    Args:
        gen_net: generator mapping ``l_res`` to a super-resolved image.
        disc_net: discriminator; only used outside ``'First_run'`` mode.
        loss_func: adversarial criterion applied to ``disc_net(fake)`` against
            a target of ones (the training cell passes ``nn.BCELoss()``).
        l_res: low-resolution input batch.
        h_res: matching high-resolution target batch.
        mode: ``'First_run'`` selects the pixel-wise MSE pre-training loss;
            any other value selects the full loss.

    Returns:
        Scalar tensor:
        - ``'First_run'``: ``MSE(G(l_res), h_res)``.
        - otherwise: ``L_percep + 5e-3 * L_adv + 1e-2 * L1(G(l_res), h_res)``.

    NOTE(review): ``L_adv`` here is the standard (non-relativistic) generator
    term, not the relativistic one described in the markdown.
    """
    fake = gen_net(l_res)

    if mode == 'First_run':
        return F.mse_loss(fake, h_res)

    pred = disc_net(fake)
    adversarial_loss = loss_func(pred, torch.ones_like(pred))
    perceptual_loss = _get_vgg_loss()(fake, h_res)
    content_loss = F.l1_loss(fake, h_res)
    return perceptual_loss + 5e-3 * adversarial_loss + 1e-2 * content_loss



def disc_loss_func(gen_net, disc_net, loss_func, l_res, h_res, mode='relativistic', eps=1e-6):
    """Discriminator loss, either standard or relativistic-average (RaGAN, as in ESRGAN).

    ``disc_net`` must return probabilities (it ends with a Sigmoid), and
    ``loss_func`` must accept probabilities, e.g. ``nn.BCELoss``.

    In relativistic mode the comparison is made on raw scores ``C = logit(D)``:
    ``sigmoid(C(real) - mean C(fake))`` is pushed towards 1 and
    ``sigmoid(C(fake) - mean C(real))`` towards 0. Subtracting the
    probabilities themselves can produce negative values, which
    ``nn.BCELoss`` rejects.

    Args:
        gen_net: generator; its output is produced without gradients.
        disc_net: discriminator being trained.
        loss_func: binary criterion on probabilities.
        l_res: low-resolution input batch.
        h_res: real high-resolution batch.
        mode: ``'relativistic'`` for RaGAN; any other value gives the standard GAN loss.
        eps: probabilities are clamped to ``[eps, 1 - eps]`` before ``torch.logit``
            so saturated outputs give finite logits.

    Returns:
        Mean of the real and fake loss terms.
    """
    # The generator is not updated here, so no autograd graph is built for it.
    with torch.no_grad():
        fake = gen_net(l_res)
    fake_pred = disc_net(fake)
    real_pred = disc_net(h_res)

    if mode == 'relativistic':
        real_logit = torch.logit(real_pred, eps=eps)
        fake_logit = torch.logit(fake_pred, eps=eps)
        real_rel = torch.sigmoid(real_logit - fake_logit.mean(0, keepdim=True))
        fake_rel = torch.sigmoid(fake_logit - real_logit.mean(0, keepdim=True))
        loss_real = loss_func(real_rel, torch.ones_like(real_rel))
        loss_fake = loss_func(fake_rel, torch.zeros_like(fake_rel))
    else:
        loss_real = loss_func(real_pred, torch.ones_like(real_pred))
        loss_fake = loss_func(fake_pred, torch.zeros_like(fake_pred))

    return (loss_real + loss_fake) / 2

In [13]:
class dataSet(torch.utils.data.Dataset):
    """Paired (low-res, high-res) wrapper around a torchvision image dataset.

    Each sample of ``data`` is randomly cropped to ``high_res x high_res`` and
    converted to a tensor. The low-resolution copy is that crop resized to
    ``high_res // 4``.

    Args:
        data: torchvision dataset class accepting ``root``, ``download`` and
            ``transform`` whose items are ``(image, target)`` pairs
            (e.g. ``torchvision.datasets.LFWPeople``).
        high_res: side length of the high-resolution crop.
        root: directory the dataset is stored in / downloaded to.
    """

    def __init__(self, data, high_res=96, root='data'):
        super().__init__()
        low_res = high_res // 4
        # Normalize with mean 0 / std 1 is the identity; it is kept as a
        # placeholder. antialias=True makes the tensor resize antialiased
        # regardless of the installed torchvision's default.
        self.high_transform = transforms.Compose([
            transforms.Resize((high_res, high_res), antialias=True),
            transforms.Normalize([0, 0, 0], [1, 1, 1]),
        ])

        self.low_transform = transforms.Compose([
            transforms.Resize((low_res, low_res), antialias=True),
            transforms.Normalize([0, 0, 0], [1, 1, 1]),
        ])

        self.both_transform = transforms.Compose([
            transforms.RandomCrop((high_res, high_res)),
            # Optional augmentation: transforms.RandomHorizontalFlip(p=0.5),
            # transforms.RandomRotation(90)
            transforms.ToTensor(),
        ])

        self.data = data(root=root,
                         download=True,
                         transform=self.both_transform)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(low, high)`` tensors for sample ``idx``; the dataset label is dropped."""
        crop, _ = self.data[idx]
        return self.low_transform(crop), self.high_transform(crop)

In [None]:
from torch.utils.data import Dataset
class KaggleDataset(Dataset):
    """Paired (low-res, high-res) dataset built from a folder of images (e.g. CelebA).

    Every image is opened, randomly cropped and resized once in ``__init__``
    and kept in memory, so each image's random crop is fixed for all epochs.

    Args:
        dir_path: folder containing the image files.
        high_res: side length of the high-resolution crop; the low-resolution
            side is ``high_res // 4``.
        max_images: number of files loaded (the first ones in sorted order).

    Each item is a ``(low, high)`` pair of float tensors of shapes
    ``(3, high_res // 4, high_res // 4)`` and ``(3, high_res, high_res)``.
    """

    def __init__(self, dir_path, high_res=96, max_images=12800):
        super().__init__()
        low_res = high_res // 4
        self.dir_path = dir_path

        # Normalize(mean=0, std=1) is the identity; it is kept as a placeholder.
        # antialias=True makes the 4x tensor downscale antialiased regardless of
        # the installed torchvision's default.
        self.high_transform = transforms.Compose([
            transforms.Resize((high_res, high_res), antialias=True),
            transforms.Normalize([0, 0, 0], [1, 1, 1]),
        ])

        self.low_transform = transforms.Compose([
            transforms.Resize((low_res, low_res), antialias=True),
            transforms.Normalize([0, 0, 0], [1, 1, 1]),
        ])

        self.both_transform = transforms.Compose([
            transforms.RandomCrop((high_res, high_res)),
            # Optional augmentation: transforms.RandomHorizontalFlip(p=0.5),
            # transforms.RandomRotation(90)
            transforms.ToTensor(),
        ])

        # os.listdir order is arbitrary; sorting loads the same subset every run.
        file_names = sorted(os.listdir(self.dir_path))[:max_images]
        self.images = []
        for name in tqdm(file_names):
            # `with` closes the file handle. convert('RGB') guarantees three
            # channels for grayscale/RGBA files, matching the 3-channel Normalize.
            with Image.open(os.path.join(self.dir_path, name)) as img:
                crop = self.both_transform(img.convert('RGB'))
            self.images.append((self.low_transform(crop), self.high_transform(crop)))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return the cached ``(low, high)`` pair at ``idx``."""
        return self.images[idx]

In [18]:
# data = dataSet(torchvision.datasets.LFWPeople)
path = "img_align_celeba/img_align_celeba"  # images extracted from the CelebA archive
data = KaggleDataset(path)
# BS (= 32) comes from the hyperparameter cell. Shuffling changes the batch
# order every epoch instead of replaying the same sequence.
loader = torch.utils.data.DataLoader(data, batch_size=BS, shuffle=True)

  0%|          | 0/12800 [00:00<?, ?it/s]



In [None]:
# Preview one batch: low-resolution inputs (24x24) and their high-resolution targets (96x96).
x, y = next(iter(loader))
show(x, size=(24, 24))
show(y, size=(96, 96))

In [20]:
from torch.utils.tensorboard import SummaryWriter
!rm -r /content/runs
writer = SummaryWriter("/content/runs")
writer_fake = SummaryWriter("/content/runs/fake")
writer_l = SummaryWriter("/content/runs/l_res")
writer_h = SummaryWriter("/content/runs/h_res")

rm: cannot remove '/content/runs': No such file or directory


In [None]:
# Start TensorBoard inside the notebook, reading logs from ./runs.
# The commented lines are manual helpers: `!kill <pid>` presumably stops a
# previously started TensorBoard process, and `%reload_ext` reloads the extension.
# !kill 5081
%load_ext tensorboard
%tensorboard --logdir=runs
# %reload_ext tensorboard

In [22]:
# Directory for the per-epoch checkpoints written by the training loop.
PATH = "/content/model/"
# exist_ok=True keeps the cell re-runnable (`!mkdir` fails when the folder exists).
os.makedirs(PATH, exist_ok=True)

In [23]:
gen = Generator().to(device)
gen.apply(weights_init)
# gen.load_state_dict(torch.load('/content/ESRGAN_Gen_1', map_location=device))
# Generator learning rate comes from the LR hyperparameter (2e-4).
gen_opt = torch.optim.Adam(gen.parameters(), lr=LR, betas=(0.9, 0.999))
# StepLR(step_size=1, gamma=0.8): the learning rate is multiplied by 0.8 on
# every scheduler.step(); the training loop calls it once per epoch.
gen_exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(gen_opt, step_size=1, gamma=0.8)

disc = Discriminator().to(device)
disc.apply(weights_init)
# disc.load_state_dict(torch.load('/content/disc_20', map_location=device))
disc_opt = torch.optim.Adam(disc.parameters(), lr=4e-4, betas=(0.5, 0.999))
disc_exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(disc_opt, step_size=1, gamma=0.8)

In [None]:
step = 0  # TensorBoard step for the image summaries

loss_func = nn.BCELoss()
num_batches = len(loader)

for epoch in range(EPOCH):
    discLoss, genLoss = 0.0, 0.0
    print(f"\nEpoch: {epoch + 1}")

    for batch, (l_res, h_res) in enumerate(tqdm(loader)):
        l_res = l_res.to(device)
        h_res = h_res.to(device)

        # Discriminator update: disabled during the MSE pre-training phase.
        # disc_opt.zero_grad()
        # disc_loss = disc_loss_func(gen, disc, loss_func, l_res, h_res, mode='relativistic')
        # disc_loss.backward()
        # disc_opt.step()

        gen_opt.zero_grad()
        # Phase 1: pixel-wise MSE pre-training ('First_run'). For phase 2 use
        # gen_loss_func(gen, disc, loss_func, l_res, h_res) and enable the
        # discriminator update above.
        gen_loss = gen_loss_func(gen, disc, loss_func, l_res, h_res, mode='First_run')
        gen_loss.backward()
        gen_opt.step()

        # Accumulate Python floats: adding the loss tensor itself would keep
        # every batch's autograd graph alive for the whole epoch. Each loss is
        # a per-batch mean, so the epoch average divides by the batch count.
        # discLoss += disc_loss.item() / num_batches
        genLoss += gen_loss.item() / num_batches

        if batch % 30 == 0 and batch != 0:
            with torch.no_grad():
                step += 1
                fake = gen(l_res)
                h_res_grid = make_grid(h_res[:32], normalize=True)
                l_res_grid = make_grid(l_res[:32], normalize=True)
                fake_grid = make_grid(fake[:32], normalize=True)

                writer_fake.add_image(
                    "fake image", fake_grid, global_step=step
                )
                writer_l.add_image(
                    "Low_res image", l_res_grid, global_step=step
                )
                writer_h.add_image(
                    "High_res image", h_res_grid, global_step=step
                )

        # Global step that increases monotonically across epochs.
        writer.add_scalars("Loss", {
                    "Critic": 0,
                    "Generator": gen_loss.item()
                }, epoch * num_batches + batch)

    print(f'  Discriminator Loss: {discLoss:.4f} -- Generator Loss: {genLoss:.4f}')
    gen_exp_lr_scheduler.step()
    disc_exp_lr_scheduler.step()

    # Save checkpoints after every epoch.
    torch.save(gen.state_dict(), f"{PATH}Gen_{epoch+1}")
    torch.save(disc.state_dict(), f"{PATH}disc_{epoch+1}")

    if (epoch + 1) % 2 == 0:
        print(f"  >>> Discriminator Learning Rate: {disc_opt.param_groups[0]['lr']}")
        print(f"  >>> Generator Learning Rate: {gen_opt.param_groups[0]['lr']}")

In [None]:
# !cp /content/model/Gen_3
# ! cp /content/ESRGAN_Gen_1 /content/ESRGAN_Gen_1_

In [None]:
# gen.load_state_dict(torch.load('/content/ESRGAN_Gen_1', map_location=device))
# Generated / low-res / high-res comparison: one column per batch, first 5 batches.
fig = plt.figure(figsize=(10, 8))
for i, (l, h) in enumerate(loader):
    with torch.no_grad():  # inference only: no autograd graph needed
        generated = gen(l.to(device))
    # Min-max rescale the batch to [0, 1] for imshow. The denominator is the
    # value range (max - min), clamped to avoid dividing by zero.
    value_range = (generated.max() - generated.min()).clamp_min(1e-8)
    generated = (generated - generated.min()) / value_range
    fig.add_subplot(3, 5, i+1)
    plt.imshow(generated[0].cpu().permute(1, 2, 0))
    plt.axis(False)
    plt.title('Generated')
    fig.add_subplot(3, 5, i+6)
    plt.imshow(l[0].cpu().permute(1, 2, 0))
    plt.axis(False)
    plt.title('Low res')
    fig.add_subplot(3, 5, i+11)
    plt.imshow(h[0].cpu().permute(1, 2, 0))
    plt.axis(False)
    plt.title('High res')
    if i == 4:
        break

plt.show()

In [None]:
# Same comparison without per-batch min-max rescaling. The loader yields
# tensors already in [0, 1] (ToTensor followed by a mean-0 / std-1 Normalize),
# so they are shown directly. A (x + 1) / 2 mapping would assume [-1, 1] and
# wash the images out. The generator output is unbounded, hence the clamp.
fig = plt.figure(figsize=(10, 8))
for i, (l, h) in enumerate(loader):
    with torch.no_grad():  # inference only: no autograd graph needed
        generated = gen(l.to(device))
    fig.add_subplot(3, 5, i+1)
    plt.imshow(generated[0].clamp(0, 1).cpu().permute(1, 2, 0))
    plt.axis(False)
    plt.title('Generated')
    fig.add_subplot(3, 5, i+6)
    plt.imshow(l[0].clamp(0, 1).cpu().permute(1, 2, 0))
    plt.axis(False)
    plt.title('Low res')
    fig.add_subplot(3, 5, i+11)
    plt.imshow(h[0].clamp(0, 1).cpu().permute(1, 2, 0))
    plt.axis(False)
    plt.title('High res')
    if i == 4:
        break

plt.show()

In [None]:
# Sanity check of the (x + 1) / 2 display mapping: minimum of the first low-res image after mapping.
((l[0]+1)/2).min()

tensor(0.)

In [None]:
from PIL import Image
# Load a custom test image and build 24x24 (low-res) and 96x96 (high-res)
# versions scaled from [0, 255] to [-1, 1].
# convert('RGB') guarantees an (H, W, 3) array; grayscale or RGBA files would
# otherwise break the permute below. `with` closes the file handle.
with Image.open('/content/img.jpg') as pil_img:
    img = np.array(pil_img.convert('RGB'))
img = torch.from_numpy(img).permute(2, 0, 1)
# 127.5 is the midpoint of [0, 255]: (v - 127.5) / 127.5 maps 0 -> -1 and 255 -> 1.
l = (torchvision.transforms.Resize((24, 24), antialias=True)(img) - 127.5) / 127.5
h = (torchvision.transforms.Resize((96, 96), antialias=True)(img) - 127.5) / 127.5

In [None]:
# Load pretrained generator weights; map_location puts the tensors on the
# current `device`, so the checkpoint also loads on a CPU-only runtime.
gen.load_state_dict(torch.load('/content/drive/MyDrive/Gans_models/SRGAN_Gen_MSE', map_location=device))

In [None]:
# Test-time fine-tuning: 100 generator updates on the single (low, high) pair above.
# NOTE: this overfits `gen` to this one image; reload the checkpoint to undo it.
l_res = l.unsqueeze(0).to(device)
h_res = h.unsqueeze(0).to(device)

for i in range(1, 101):
    gen_opt.zero_grad()
    # gen_loss_func takes (gen_net, disc_net, loss_func, l_res, h_res[, mode]);
    # the L1 content term is computed inside it.
    gen_loss = gen_loss_func(gen, disc, loss_func, l_res, h_res)
    gen_loss.backward()
    gen_opt.step()

In [None]:
# Super-resolve the custom image; no_grad because only the output is needed.
with torch.no_grad():
    generated = gen(l.unsqueeze(0).to(device))

In [None]:
# Side-by-side view of the custom image: generated, low-res input, high-res reference.
# (x + 1) / 2 maps tensors scaled to [-1, 1] back to [0, 1]. The clamp keeps
# out-of-range values (e.g. from the unbounded generator output) within imshow's valid range.
fig = plt.figure(figsize=(10, 8))
fig.add_subplot(1, 3, 1)
plt.imshow(((generated[0]+1)/2).clamp(0, 1).detach().cpu().permute(1, 2, 0))
plt.axis(False)
plt.title('Generated')
fig.add_subplot(1, 3, 2)
plt.imshow(((l+1)/2).clamp(0, 1).detach().cpu().permute(1, 2, 0))
plt.axis(False)
plt.title('Low res')
fig.add_subplot(1, 3, 3)
plt.imshow(((h+1)/2).clamp(0, 1).detach().cpu().permute(1, 2, 0))
plt.axis(False)
plt.title('High res')



In [None]:
# Re-plot the custom-image comparison (e.g. after fine-tuning) and display it.
# (x + 1) / 2 maps tensors scaled to [-1, 1] back to [0, 1]. The clamp keeps
# out-of-range values within imshow's valid range.
fig = plt.figure(figsize=(10, 8))
fig.add_subplot(1, 3, 1)
plt.imshow(((generated[0]+1)/2).clamp(0, 1).detach().cpu().permute(1, 2, 0))
plt.axis(False)
plt.title('Generated')
fig.add_subplot(1, 3, 2)
plt.imshow(((l+1)/2).clamp(0, 1).detach().cpu().permute(1, 2, 0))
plt.axis(False)
plt.title('Low res')
fig.add_subplot(1, 3, 3)
plt.imshow(((h+1)/2).clamp(0, 1).detach().cpu().permute(1, 2, 0))
plt.axis(False)
plt.title('High res')

plt.show()