<a href="https://colab.research.google.com/github/LeeYuuuan/Applied_AI_midterm_exam/blob/main/SRGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import os
import sys
from tqdm import tqdm
import io

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
sys.path.append("/content/drive/MyDrive/Colab Notebooks/Applied_AI_Course_Assignment/Midterm")
from utils import *

In [None]:
def preprocess(input_zip, size=(128, 128)):
    """Load the images stored in a zip archive together with cat/dog labels.

    Archive members whose name ends in ``.png``, ``.jpg`` or ``.jpeg`` are
    opened with PIL.  The label is derived from the member name: ``0`` when
    it contains ``'cat'`` and ``1`` when it contains ``'dog'``.  Members that
    cannot be assigned exactly one label are skipped (with a warning) so the
    returned images and labels always line up one-to-one.

    Args:
        input_zip: Path (or file-like object) of the zip archive to read.
        size: Unused; kept so existing callers that pass it keep working.

    Returns:
        A tuple ``(imgs, labels)``: the list of opened PIL images and a NumPy
        array with one label per image, in the same order.
    """
    import warnings
    import zipfile  # local import: zipfile is not among the file's visible imports

    def label_from(name):
        """Return 0 for cat, 1 for dog, None when the name fits neither or both."""
        is_cat = 'cat' in name
        is_dog = 'dog' in name
        if is_cat == is_dog:
            return None
        return 0 if is_cat else 1

    imgs = []
    labels = []
    skipped = 0
    with zipfile.ZipFile(input_zip, 'r') as archive_zip:
        labelled = []
        for archive in archive_zip.namelist():
            if not archive.endswith(('.png', '.jpg', '.jpeg')):
                continue
            label = label_from(archive)
            if label is None:
                # A directory component may mention both classes; retry on
                # the file name alone (zip member names always use '/').
                label = label_from(archive.rsplit('/', 1)[-1])
            if label is None:
                skipped += 1
                continue
            labelled.append((archive, label))

        for img_path, label in tqdm(labelled):
            with archive_zip.open(img_path) as image_zip:
                # Copy the bytes into memory so the image stays readable
                # after the archive has been closed.
                imgs.append(Image.open(io.BytesIO(image_zip.read())))
            labels.append(label)

    if skipped:
        warnings.warn(
            '%d image(s) skipped: file name does not identify exactly one '
            'of cat/dog' % skipped
        )

    return imgs, np.array(labels)

In [None]:
# Load every image of the training archive plus the label preprocess()
# derives from each file name.
cat_and_dog_images_hr, labels = preprocess('/content/drive/MyDrive/Colab Notebooks/Applied_AI_Course_Assignment/Midterm/dataset/train.zip')

100%|██████████| 25000/25000 [00:14<00:00, 1780.28it/s]


In [None]:
class Lr_Hr_dataset(Dataset):
    """Paired low-/high-resolution image dataset for super-resolution.

    Every source image is resized once, at construction time, to a 32x32
    low-resolution input and a 128x128 high-resolution target.  Indexing
    returns the pair ``(lr_image, hr_image)``.

    Args:
        cat_and_dog_dataset: Sequence of images accepted by
            ``transforms.Resize`` (presumably the PIL images returned by
            ``preprocess`` -- confirm against the caller).
        labels: Per-image labels.  Stored on the instance; they are not part
            of the returned sample.
        transform: Optional callable applied to both images of a pair.
    """

    def __init__(self, cat_and_dog_dataset, labels, transform=None):
        self.cat_and_dog_dataset = cat_and_dog_dataset
        # Keep the labels on the instance instead of relying on a
        # module-level variable that happens to share the name.
        self.labels = labels
        self.transform = transform
        self.lr_images = []
        self.hr_images = []
        lr_transform = transforms.Resize((32, 32))
        hr_transform = transforms.Resize((128, 128))
        for img in tqdm(cat_and_dog_dataset):
            self.lr_images.append(lr_transform(img))
            self.hr_images.append(hr_transform(img))

    def __len__(self):
        return len(self.cat_and_dog_dataset)

    def __getitem__(self, idx):
        lr_image = self.lr_images[idx]
        hr_image = self.hr_images[idx]
        if self.transform is not None:
            # Replay the same torch RNG state for both images so a random
            # transform draws the same random numbers for the LR input and
            # its HR target; independent draws would flip/rotate the two
            # images differently and break the pairing.
            # NOTE(review): assumes the transforms draw from the global torch
            # RNG -- confirm for the installed torchvision version.  Random
            # parameters that depend on the image size may still differ
            # between the two resolutions.
            rng_state = torch.get_rng_state()
            lr_image = self.transform(lr_image)
            torch.set_rng_state(rng_state)
            hr_image = self.transform(hr_image)
        return lr_image, hr_image


In [None]:
# Build the paired LR/HR dataset.  The pipeline below is handed over as
# `transform`, which Lr_Hr_dataset.__getitem__ applies to both the
# low-resolution and the high-resolution image of every pair: random
# geometric/colour augmentation, conversion to a tensor, then per-channel
# normalisation with the given mean/std.
lr_hr_dataset = Lr_Hr_dataset(cat_and_dog_dataset=cat_and_dog_images_hr, labels=labels, transform=transforms.Compose([
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]))

100%|██████████| 25000/25000 [02:16<00:00, 183.34it/s]


In [None]:
# Split by index with torch's random_split so samples stay lazy: each pair is
# transformed when the DataLoader fetches it.
# NOTE(review): sklearn's train_test_split presumably indexes every element of
# a Dataset up front, which would hold all transformed tensors in memory and
# freeze the random augmentation -- confirm before relying on this rationale.
from torch.utils.data import random_split

n_test_samples = int(round(0.3 * len(lr_hr_dataset)))
train_dataset, test_dataset = random_split(
    lr_hr_dataset,
    [len(lr_hr_dataset) - n_test_samples, n_test_samples],
    generator=torch.Generator().manual_seed(42),  # reproducible 70/30 split
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
import math
import torch
from torch import nn


class Generator(nn.Module):
    """Generator network of the SRGAN model.

    Layout: a 9x9 convolution with PReLU (``block1``), five residual blocks
    (``block2`` .. ``block6``), a 3x3 convolution with batch norm
    (``block7``), and a tail (``block8``) made of x2 upsample blocks followed
    by a 9x9 convolution back to 3 channels.

    Args:
        scale_factor: Overall upscaling factor; ``int(log2(scale_factor))``
            x2 upsample blocks are stacked in the tail.
    """

    def __init__(self, scale_factor):
        n_upsample_stages = int(math.log(scale_factor, 2))

        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU(),
        )
        # block2 .. block6: identical 64-channel residual blocks.
        for block_number in range(2, 7):
            setattr(self, 'block%d' % block_number, ResidualBlock(64))
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
        )
        tail = []
        for _ in range(n_upsample_stages):
            tail.append(UpsampleBLock(64, 2))
        tail.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*tail)

    def forward(self, x):
        """Return the upscaled image, squashed into the range [0, 1]."""
        shallow = self.block1(x)

        deep = shallow
        trunk = (self.block2, self.block3, self.block4,
                 self.block5, self.block6, self.block7)
        for stage in trunk:
            deep = stage(deep)

        # Long skip connection around the whole residual trunk.
        upscaled = self.block8(shallow + deep)

        # tanh yields [-1, 1]; shift and scale it to [0, 1].
        return (torch.tanh(upscaled) + 1) / 2


class Discriminator(nn.Module):
    """Convolutional discriminator producing one score in (0, 1) per image.

    A first 3->64 convolution is followed by seven conv / batch-norm /
    LeakyReLU groups that alternate stride 2 and stride 1 while widening the
    channels up to 512.  The features are then average-pooled to 1x1 and
    reduced to a single value by two 1x1 convolutions.
    """

    def __init__(self):
        super().__init__()

        layers = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
        ]

        # (output channels, stride) of each conv/BN/LeakyReLU group.
        group_specs = (
            (64, 2),
            (128, 1),
            (128, 2),
            (256, 1),
            (256, 2),
            (512, 1),
            (512, 2),
        )
        in_channels = 64
        for out_channels, stride in group_specs:
            layers.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
            )
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.LeakyReLU(0.2))
            in_channels = out_channels

        layers.extend([
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1),
        ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return a 1-D tensor with one sigmoid score per sample of ``x``."""
        n_samples = x.size(0)
        logits = self.net(x).view(n_samples)
        return torch.sigmoid(logits)


class ResidualBlock(nn.Module):
    """Residual block: conv -> BN -> PReLU -> conv -> BN, added to the input.

    Args:
        channels: Number of input channels, which is also the number of
            output channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``x`` plus the output of the two-convolution branch."""
        branch = self.prelu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return x + branch


class UpsampleBLock(nn.Module):
    """Upsampling block: 3x3 convolution, pixel shuffle, then PReLU.

    The convolution widens the channels by ``up_scale ** 2`` and the pixel
    shuffle folds them back into space, so the output keeps ``in_channels``
    channels while height and width grow by ``up_scale``.

    Args:
        in_channels: Number of channels of the input (and of the output).
        up_scale: Spatial upscaling factor.
    """

    def __init__(self, in_channels, up_scale):
        super().__init__()
        expanded_channels = in_channels * up_scale ** 2
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Return ``x`` upscaled spatially by ``up_scale``."""
        return self.prelu(self.pixel_shuffle(self.conv(x)))

In [None]:
import torch
from torch import nn
from torchvision.models.vgg import vgg16


class GeneratorLoss(nn.Module):
    """Combined generator loss: image, adversarial, perception and TV terms.

    The perception term compares features taken from the first 31 layers of
    a pretrained VGG16, which is kept frozen and in eval mode.
    """

    def __init__(self):
        super().__init__()
        backbone = vgg16(pretrained=True)
        feature_layers = list(backbone.features)[:31]
        extractor = nn.Sequential(*feature_layers).eval()
        # The extractor only measures feature distance; it is never trained.
        for weight in extractor.parameters():
            weight.requires_grad = False
        self.loss_network = extractor
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        """Return the weighted sum of the four loss terms.

        Args:
            out_labels: Discriminator output for the generated images.
            out_images: Images produced by the generator.
            target_images: Ground-truth images to compare against.
        """
        # Pixel-wise reconstruction error.
        pixel_term = self.mse_loss(out_images, target_images)
        # Pushes the discriminator output for generated images towards 1.
        adversarial_term = torch.mean(1 - out_labels)
        # Distance between the VGG feature maps of output and target.
        generated_features = self.loss_network(out_images)
        target_features = self.loss_network(target_images)
        perception_term = self.mse_loss(generated_features, target_features)
        # Total-variation regulariser on the generated images.
        tv_term = self.tv_loss(out_images)

        return pixel_term + 0.001 * adversarial_term + 0.006 * perception_term + 2e-8 * tv_term


class TVLoss(nn.Module):
    """Total-variation loss: squared differences between neighbouring pixels.

    Args:
        tv_loss_weight: Multiplier applied to the resulting loss.
    """

    def __init__(self, tv_loss_weight=1):
        super().__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        """Return the weighted total variation of ``x``, averaged per sample.

        ``x`` is indexed on dimensions 0 to 3; the differences are taken along
        dimensions 2 and 3.
        """
        n_samples = x.shape[0]

        # Neighbour differences along dimension 2 and dimension 3.
        below = x[:, :, 1:, :]
        right = x[:, :, :, 1:]
        vertical_diff = below - x[:, :, :-1, :]
        horizontal_diff = right - x[:, :, :, :-1]

        vertical_tv = (vertical_diff ** 2).sum()
        horizontal_tv = (horizontal_diff ** 2).sum()

        # Each direction is normalised by the element count of one sample.
        per_direction = (vertical_tv / self.tensor_size(below)
                         + horizontal_tv / self.tensor_size(right))
        return self.tv_loss_weight * 2 * per_direction / n_samples

    @staticmethod
    def tensor_size(t):
        """Return the product of dimensions 1, 2 and 3 of ``t``."""
        dims = t.size()
        return dims[1] * dims[2] * dims[3]


# Smoke check: build the loss module (this fetches the pretrained VGG16
# weights if they are not cached) and print its layer structure.
if __name__ == "__main__":
    g_loss = GeneratorLoss()
    print(g_loss)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:05<00:00, 99.9MB/s]


GeneratorLoss(
  (loss_network): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding

In [None]:
# Pipeline: resize to 128x128, convert to a PIL image, convert back to a tensor.
# NOTE(review): `transform` is not referenced anywhere else in the visible
# source, and Resize runs before ToPILImage -- confirm whether this pipeline
# is still needed and whether the order is intended.
transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToPILImage(),
        transforms.ToTensor()
    ])

In [None]:
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
import pytorch_ssim
from math import log10
import torchvision.utils as utils

In [None]:
def train(train_loader, test_loader, num_epochs=150, learning_rate=0.001):
    """Train the SRGAN generator/discriminator pair and return the generator.

    Every epoch consists of:
      1. one pass over ``train_loader`` that updates the generator and then
         the discriminator on each batch;
      2. one validation pass over ``test_loader`` that tracks PSNR and saves
         grids of HR/SR image pairs under ``training_results/SRF_4/``;
      3. saving both networks' weights under ``epochs/``.
    Every tenth epoch (except epoch 0) the collected statistics are also
    written to ``statistics/srf_4_train_results.csv``.

    Args:
        train_loader: Iterable of ``(lr_batch, hr_batch)`` training pairs.
        test_loader: Iterable of ``(lr_batch, hr_batch)`` validation pairs.
        num_epochs: Number of epochs to run.
        learning_rate: Learning rate of both Adam optimizers.

    Returns:
        The generator after the last epoch.
    """
    import pandas as pd  # local import: pandas is not among the file's visible imports

    upscale_factor = 4
    image_dir = 'training_results/SRF_' + str(upscale_factor) + '/'
    checkpoint_dir = 'epochs/'
    stats_dir = 'statistics/'
    # torch.save / to_csv do not create missing directories.
    for directory in (image_dir, checkpoint_dir, stats_dir):
        os.makedirs(directory, exist_ok=True)

    netG = Generator(scale_factor=upscale_factor).to(device)
    netD = Discriminator().to(device)
    generator_criterion = GeneratorLoss().to(device)

    optimizerG = optim.Adam(netG.parameters(), lr=learning_rate)
    optimizerD = optim.Adam(netD.parameters(), lr=learning_rate)
    results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}

    for epoch in range(num_epochs):
        running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

        # The validation pass below switches the generator to eval mode, so
        # training mode has to be restored at the start of every epoch.
        netG.train()
        netD.train()

        # ---- training pass -------------------------------------------------
        for data, target in tqdm(train_loader):
            batch_size = data.size(0)
            running_results['batch_sizes'] += batch_size

            z = data.float().to(device)
            real_img = target.float().to(device)

            # Generator update.
            fake_img = netG(z)
            fake_out = netD(fake_img).mean()

            optimizerG.zero_grad()
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()
            optimizerG.step()

            # Discriminator update.  The fake batch is detached so this
            # backward pass does not reach the generator; zero_grad() also
            # discards the gradients g_loss.backward() left on netD.
            real_out = netD(real_img).mean()
            fake_out = netD(fake_img.detach()).mean()
            d_loss = 1 - real_out + fake_out

            optimizerD.zero_grad()
            d_loss.backward()

            # Score of the just-updated generator, used for reporting only,
            # so no autograd graph is needed.
            with torch.no_grad():
                fake_out = netD(netG(z)).mean()

            optimizerD.step()

            running_results['g_loss'] += g_loss.item() * batch_size
            running_results['d_loss'] += d_loss.item() * batch_size
            running_results['d_score'] += real_out.item() * batch_size
            running_results['g_score'] += fake_out.item() * batch_size

        # ---- validation pass (once per epoch) ------------------------------
        netG.eval()
        valing_results = {'mse': 0.0, 'psnr': 0.0, 'batch_sizes': 0}
        val_images = []
        with torch.no_grad():
            for val_lr, val_hr in tqdm(test_loader):
                batch_size = val_lr.size(0)
                valing_results['batch_sizes'] += batch_size

                lr = val_lr.float().to(device)
                hr = val_hr.float().to(device)
                sr = netG(lr)

                batch_mse = ((sr - hr) ** 2).mean().item()
                valing_results['mse'] += batch_mse * batch_size

                # Running PSNR over all validation samples seen so far.
                mean_mse = valing_results['mse'] / valing_results['batch_sizes']
                if mean_mse > 0:
                    valing_results['psnr'] = 10 * log10(hr.max().item() ** 2 / mean_mse)
                else:
                    valing_results['psnr'] = float('inf')

                # Keep the first HR/SR pair of every batch for the image grids.
                val_images.extend(
                    [display_transform()(hr[0].cpu().squeeze(0)),
                     display_transform()(sr[0].cpu().squeeze(0))])

        if val_images:
            stacked = torch.stack(val_images)
            # At least one chunk, otherwise torch.chunk rejects small sets.
            grids = torch.chunk(stacked, max(1, stacked.size(0) // 15))
            saving_bar = tqdm(grids, desc='[saving training results]')
            for index, image in enumerate(saving_bar, start=1):
                image = utils.make_grid(image, nrow=3, padding=5)
                utils.save_image(image, image_dir + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)

        # ---- checkpoints and statistics -----------------------------------
        torch.save(netG.state_dict(), checkpoint_dir + 'netG_epoch_%d_%d.pth' % (upscale_factor, epoch))
        torch.save(netD.state_dict(), checkpoint_dir + 'netD_epoch_%d_%d.pth' % (upscale_factor, epoch))

        seen = running_results['batch_sizes'] or 1  # avoid 0/0 on an empty loader
        results['d_loss'].append(running_results['d_loss'] / seen)
        results['g_loss'].append(running_results['g_loss'] / seen)
        results['d_score'].append(running_results['d_score'] / seen)
        results['g_score'].append(running_results['g_score'] / seen)
        results['psnr'].append(valing_results['psnr'])

        if epoch % 10 == 0 and epoch != 0:
            data_frame = pd.DataFrame(
                data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                      'Score_G': results['g_score'], 'PSNR': results['psnr']},
                # One row per finished epoch; epochs are counted from 0, so
                # the index has to follow the number of collected entries.
                index=range(1, len(results['d_loss']) + 1))
            data_frame.to_csv(stats_dir + 'srf_' + str(upscale_factor) + '_train_results.csv', index_label='Epoch')

    # Return only after all epochs have run.
    return netG

torch.Size([32, 3, 128, 128])

In [None]:
# Short run with num_epochs=2 instead of the default 150.
train(train_loader, test_loader, num_epochs=2)

  0%|          | 0/547 [00:00<?, ?it/s]
  0%|          | 0/235 [00:00<?, ?it/s][A
  2%|▏         | 5/235 [00:00<00:04, 47.81it/s][A
  4%|▍         | 10/235 [00:00<00:04, 48.04it/s][A
  6%|▋         | 15/235 [00:00<00:04, 48.25it/s][A
  9%|▊         | 20/235 [00:00<00:04, 47.63it/s][A
 11%|█         | 25/235 [00:00<00:04, 47.87it/s][A
 13%|█▎        | 30/235 [00:02<00:04, 48.56it/s][A
 15%|█▍        | 35/235 [00:02<00:22,  8.94it/s][A
 17%|█▋        | 40/235 [00:02<00:16, 12.07it/s][A
 20%|█▉        | 46/235 [00:02<00:11, 16.52it/s][A
 22%|██▏       | 51/235 [00:02<00:09, 20.43it/s][A
 24%|██▍       | 56/235 [00:02<00:07, 24.77it/s][A
 26%|██▌       | 61/235 [00:02<00:05, 29.02it/s][A
 28%|██▊       | 66/235 [00:03<00:08, 19.02it/s][A
 30%|███       | 71/235 [00:03<00:07, 23.31it/s][A
 32%|███▏      | 76/235 [00:03<00:05, 27.66it/s][A
 34%|███▍      | 81/235 [00:03<00:04, 31.59it/s][A
 37%|███▋      | 86/235 [00:03<00:04, 35.27it/s][A
 39%|███▊      | 91/235 [00:03<00: