In [None]:
from google.colab import drive
import json
# Attach Google Drive under /content/gdrive; force_remount re-attaches it even when an
# earlier session in this runtime already mounted it.
drive.mount(mountpoint="/content/gdrive", force_remount=True)

Mounted at /content/gdrive


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchvision.models.vgg import vgg19
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

import os
import time
from random import seed
from util import *

# Seed every RNG the notebook may rely on so runs are repeatable: Python's `random`,
# NumPy (presumably used by util's data pipeline -- TODO confirm), and PyTorch, whose
# generator drives weight initialisation and DataLoader shuffling.
import numpy as np

SEED = 11785
seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA devices

In [None]:
# Define the GAN architecture (generator, discriminator and their building blocks) and the VGG perceptual loss
class sub_conv(nn.Conv2d):
    """nn.Conv2d whose weights are drawn from N(0, 0.02) and whose bias starts at zero.

    All arguments are forwarded to nn.Conv2d unchanged, so `padding` and `bias`
    now take effect (they used to be overridden with kernel_size // 2 and True;
    `conv` below passes exactly those values, so its layers are unaffected).
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias):
        super(sub_conv, self).__init__(in_channels = in_channels, out_channels = out_channels,
                                       kernel_size = kernel_size, stride = stride, padding = padding, bias = bias)
        # In-place init keeps the registered Parameters (and their requires_grad=True).
        nn.init.normal_(self.weight, mean = 0.0, std = 0.02)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

class conv(nn.Module):
    """sub_conv -> optional BatchNorm2d -> optional activation.

    Padding is kernel_size // 2, so odd kernels with stride 1 preserve H and W.

    Args:
        in_channel, out_channel: conv input / output channels.
        kernel_size: square kernel size.
        BN: append nn.BatchNorm2d(out_channel) after the conv when True.
        act: activation module appended last, or None for no activation.
        stride: conv stride.
        bias: whether the conv has a bias term (forwarded to sub_conv).
    """
    def __init__(self, in_channel, out_channel, kernel_size, BN = False, act = None, stride = 1, bias = True):
        super(conv, self).__init__()
        m = [sub_conv(in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size,
                      stride = stride, padding = kernel_size // 2, bias = bias)]
        if BN:
            m.append(nn.BatchNorm2d(num_features = out_channel))
        if act is not None:
            m.append(act)
        self.layers = nn.Sequential(*m)

    def forward(self, x):
        return self.layers(x)
        
class ResidualBlock(nn.Module):
    """Identity-shortcut block: out = body(x) + x.

    body = conv+BN+act followed by conv+BN (no activation), both with `channels`
    in and out and 'same' padding, so the shortcut needs no projection.
    `bias` is accepted for interface compatibility but not used.
    """
    def __init__(self, channels, kernel_size, act = nn.ReLU(inplace = True), bias = True):
        super(ResidualBlock, self).__init__()
        first = conv(channels, channels, kernel_size, BN = True, act = act)
        second = conv(channels, channels, kernel_size, BN = True, act = None)
        self.blocks = nn.Sequential(first, second)

    def forward(self, x):
        return self.blocks(x) + x
        
class Upsampler(nn.Module):
    """Sub-pixel upsampling stage: conv to channel * scale**2 maps -> PixelShuffle(scale) -> act.

    Output has `channel` channels and H, W multiplied by `scale`.
    """
    def __init__(self, channel, kernel_size, scale, act = nn.ReLU(inplace = True)):
        super(Upsampler, self).__init__()
        expand = conv(channel, channel * scale ** 2, kernel_size)
        self.upsample = nn.Sequential(expand, nn.PixelShuffle(scale), act)

    def forward(self, x):
        return self.upsample(x)

class discriminator_block(nn.Module):
    """Two conv+BN+act layers: in_feats -> out_feats at stride 1, then out_feats -> out_feats at stride 2."""
    def __init__(self, in_feats, out_feats, kernel_size, act = nn.LeakyReLU(inplace = True)):
        super(discriminator_block, self).__init__()
        self.body = nn.Sequential(
            conv(in_feats, out_feats, kernel_size, BN = True, act = act),
            conv(out_feats, out_feats, kernel_size, BN = True, act = act, stride = 2),
        )

    def forward(self, x):
        return self.body(x)

class Generator(nn.Module):
    """SRGAN generator: 9x9 head conv -> residual trunk (+ global skip) -> x2 pixel-shuffle stages -> Tanh conv.

    Args:
        img_feat: channels of the input LR image and of the output SR image.
        n_feats: feature channels inside the network.
        kernel_size: kernel of the residual, trunk-tail, upsampler and output convs (the head is 9x9).
        num_block: number of residual blocks in the trunk.
        act: activation used after the head conv, inside each residual block and after each
            upsampler. The default (None) builds a separate nn.PReLU() for every one of those
            positions. A module passed explicitly is reused at all of them.
        scale: upscaling factor; must be a power of two (log2(scale) x2 upsampling stages).

    Raises:
        ValueError: if `scale` is not a positive power of two.

    forward(x) returns (sr, feat): `sr` is the Tanh output (values in [-1, 1]) and `feat` is
    the trunk output plus the head skip connection, before upsampling.
    """
    def __init__(self, img_feat = 3, n_feats = 64, kernel_size = 3, num_block = 16, act = None, scale=4):
        super(Generator, self).__init__()

        # Number of x2 upsampling stages = log2(scale).
        num_upsample, remaining = 0, scale
        while remaining > 1 and remaining % 2 == 0:
            remaining //= 2
            num_upsample += 1
        if remaining != 1:
            raise ValueError(f"scale must be a positive power of two, got {scale!r}")

        def make_act():
            # A PReLU built once as a default argument would be shared by every layer of every
            # Generator (one learnable slope for the whole net, tied across instances). Build a
            # fresh one per position unless the caller chose the activation explicitly.
            return nn.PReLU() if act is None else act

        self.conv1 = conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 9, BN = False, act = make_act())

        layers = [ResidualBlock(channels = n_feats, kernel_size = kernel_size, act = make_act()) for _ in range(num_block)]
        layers.append(conv(in_channel = n_feats, out_channel = n_feats, kernel_size = kernel_size, BN = True, act = None))
        self.layers = nn.Sequential(*layers)

        upsample_blocks = [Upsampler(channel = n_feats, kernel_size = kernel_size, scale = 2, act = make_act())
                           for _ in range(num_upsample)]
        self.tail = nn.Sequential(*upsample_blocks)

        self.last_conv = conv(in_channel = n_feats, out_channel = img_feat, kernel_size = kernel_size, BN = False, act = nn.Tanh())

    def forward(self, x):
        x = self.conv1(x)
        _skip_connection = x

        x = self.layers(x)
        feat = x + _skip_connection  # global residual around the whole trunk

        x = self.tail(feat)
        x = self.last_conv(x)

        return x, feat
    
class Discriminator(nn.Module):
    """SRGAN discriminator mapping a batch of square images to per-image P(real), shape (batch, 1).

    A 3x3 conv stem (its second conv has stride 2) is followed by `num_of_block`
    discriminator_blocks, each doubling channels and halving H and W. A two-layer MLP
    with a Sigmoid output follows.

    Args:
        img_feat: input image channels.
        n_feats: channels of the stem; block i outputs n_feats * 2 ** (i + 1).
        kernel_size: accepted for interface compatibility; every conv here is 3x3.
        act: activation module reused after each conv and after the first Linear.
        num_of_block: number of discriminator_blocks.
        patch_size: height/width of the input images; fixes the first Linear's input size.
    """

    def __init__(self, img_feat = 3, n_feats = 64, kernel_size = 3, act = nn.LeakyReLU(inplace = True), num_of_block = 3, patch_size = 96):
        super(Discriminator, self).__init__()
        self.act = act

        layers = [
            conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 3, BN = False, act = self.act),
            conv(in_channel = n_feats, out_channel = n_feats, kernel_size = 3, BN = False, act = self.act, stride = 2),
        ]
        layers.extend(discriminator_block(in_feats = n_feats * (2 ** i), out_feats = n_feats * (2 ** (i + 1)), kernel_size = 3, act = self.act)
                      for i in range(num_of_block))
        self.layers = nn.Sequential(*layers)

        # Each of the (num_of_block + 1) stride-2, 3x3, padding-1 convs maps a side of n to ceil(n / 2).
        # Tracking that (rather than patch_size // 2 ** k) stays exact when patch_size is not a
        # multiple of 2 ** (num_of_block + 1), and equals the old value when it is.
        spatial = patch_size
        for _ in range(num_of_block + 1):
            spatial = (spatial + 1) // 2
        self.linear_size = spatial * spatial * (n_feats * (2 ** num_of_block))

        self.final_layers = nn.Sequential(nn.Linear(self.linear_size, 1024),
                                          self.act,
                                          nn.Linear(1024, 1),
                                          nn.Sigmoid())

    def forward(self, x):
        x = self.layers(x)
        # Flatten per sample so the batch dimension is preserved. An input whose size does not
        # match `patch_size` now fails in the Linear layer instead of being silently re-batched.
        x = torch.flatten(x, 1)
        return self.final_layers(x)

class PerceptualLoss(nn.Module):
    """MSE between frozen VGG19 feature maps of two images.

    The feature extractor is the first 31 modules of torchvision's pretrained
    `vgg19().features`, put in eval mode with all parameters frozen.
    forward(high_resolution, fake_high_resolution) returns
    MSE(features(fake_high_resolution), features(high_resolution)).

    NOTE(review): torchvision's pretrained VGG expects ImageNet mean/std-normalised input;
    train_valid passes (x + 1) / 2 without that normalisation -- confirm this is intended.
    """
    def __init__(self):
        super(PerceptualLoss, self).__init__()

        feature_layers = list(vgg19(pretrained=True).features)[:31]
        frozen = nn.Sequential(*feature_layers).eval()
        frozen.requires_grad_(False)
        self.network = frozen
        self.loss = nn.MSELoss()

    def forward(self, high_resolution, fake_high_resolution):
        fake_feats = self.network(fake_high_resolution)
        real_feats = self.network(high_resolution)
        return self.loss(fake_feats, real_feats)

In [None]:
# Hyper-parameters, paths and device configuration
class args():
  """Hyper-parameters, paths and device shared by training (train_valid) and testing (test)."""
  res_num = 16  # number of residual blocks in the generator
  num_workers = 0  # intended number of DataLoader worker processes (0 = load in the main process)
  batch_size = 16  # training batch size
  sample_batch_size = 1  # validation / test batch size
  L2_coeff = 1.0  # intended weight of the pixel-wise L2 term in the generator loss
  adv_coeff = 1e-3  # weight of the adversarial term in the generator loss
  vgg_rescale_coeff = 0.006  # weight of the VGG perceptual term in the generator loss
  pre_train_epoch = 10  # epochs of L2-only generator pre-training
  fine_train_epoch = 60  # epochs of adversarial (perceptual + adversarial + L2) training
  checkpoint_dir = 'SRGAN/checkpoints'  # folder for generator_<epoch>.pth / discriminator_<epoch>.pth
  sample_dir = 'samples'  # folder for sample output images
  scale = 4  # upscaling factor: HR size = scale * LR size
  patch_size = 32  # LR patch size; HR patches are patch_size * scale
  load_model_epoch = False  # checkpoint epoch to resume from; False presumably means start from scratch (not read by train_valid)
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Define the training routine: L2 pre-training of the generator, then adversarial training with per-epoch validation and checkpointing
def train_valid(args, save_root='/content/gdrive/My Drive/11785/HW5'):
    """Pre-train the generator with L2, then train the SRGAN with L2 + perceptual + adversarial losses.

    Stage 1 runs `args.pre_train_epoch` epochs of pixel-wise L2 training of the generator.
    Stage 2 runs `args.fine_train_epoch` epochs of alternating discriminator / generator
    updates. After every stage-2 epoch both models are evaluated on the validation split
    (losses, PSNR, SSIM). Their state_dicts are then saved as generator_<epoch>.pth and
    discriminator_<epoch>.pth, with epoch counted from 1.

    Args:
        args: configuration object (see the `args` class).
        save_root: directory that `args.checkpoint_dir` is resolved against. The default is
            the Drive folder the notebook works in. Pass None to resolve it against the
            current working directory. The directory is created if missing.
    """
    num_workers = getattr(args, 'num_workers', 0)
    train_loader = torch.utils.data.DataLoader(dataset=Datasets('train'), batch_size=args.batch_size,
                                               shuffle=True, num_workers=num_workers)
    # Validation metrics are averages, so shuffling buys nothing; keep the order deterministic.
    dev_loader = torch.utils.data.DataLoader(dataset=Datasets('valid'), batch_size=args.sample_batch_size,
                                             shuffle=False, num_workers=num_workers)

    generator = Generator(img_feat = 3, n_feats = 64, kernel_size = 3, num_block = args.res_num, scale=args.scale).to(args.device)
    l2_loss = nn.MSELoss()
    g_optim = optim.Adam(generator.parameters(), lr = 1e-4)

    #### Stage 1: train using L2 loss only
    for pre_epoch in range(1, args.pre_train_epoch + 1):
        start_time = time.time()
        mean_loss = _pretrain_epoch(generator, train_loader, g_optim, l2_loss, args.device)
        print('pre_epoch:', pre_epoch)
        print('time: ', time.time() - start_time)
        print('mean L2 loss:', mean_loss)
        print('=========')

    #### Stage 2: train using perceptual & adversarial loss
    discriminator = Discriminator(patch_size = args.patch_size * args.scale).to(args.device)
    d_optim = optim.Adam(discriminator.parameters(), lr = 1e-4)
    # Decay the generator LR 10x every fine_train_epoch // 4 epochs; StepLR needs step_size >= 1.
    scheduler = optim.lr_scheduler.StepLR(g_optim, step_size = max(1, args.fine_train_epoch // 4), gamma = 0.1)

    VGG_loss = PerceptualLoss().to(args.device)
    cross_ent = nn.BCELoss()

    ckpt_dir = args.checkpoint_dir if save_root is None else os.path.join(save_root, args.checkpoint_dir)
    os.makedirs(ckpt_dir, exist_ok=True)

    for fine_epoch in range(args.fine_train_epoch):
        start_time = time.time()
        train_g, train_d = _adversarial_train_epoch(generator, discriminator, train_loader, g_optim, d_optim,
                                                    l2_loss, VGG_loss, cross_ent, args)
        print('fine_epoch:', fine_epoch)
        print('training time: ', time.time() - start_time)
        print('total g loss:', train_g)
        print('total d loss:', train_d)
        print('====')

        val_start = time.time()
        val_g, val_d, psnr, ssim = _validate_epoch(generator, discriminator, dev_loader,
                                                   l2_loss, VGG_loss, cross_ent, args)
        scheduler.step()
        print('time: ', time.time() - val_start)
        print('total g loss:', val_g)
        print('total d loss:', val_d)
        print('PSNR:', psnr)
        print('SSIM:', ssim)

        # Checkpoints are numbered from 1 (epoch index + 1).
        torch.save(generator.state_dict(), os.path.join(ckpt_dir, f"generator_{fine_epoch + 1}.pth"))
        torch.save(discriminator.state_dict(), os.path.join(ckpt_dir, f"discriminator_{fine_epoch + 1}.pth"))

        print('=========')


def _pretrain_epoch(generator, loader, optimizer, criterion, device):
    """Run one epoch of pixel-wise (L2) generator training; return the mean batch loss."""
    generator.train()
    total = 0.0
    for batch in loader:
        gt = batch['hr'].to(device)
        lr = batch['lr'].to(device)
        output, _ = generator(lr)
        loss = criterion(output, gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / max(len(loader), 1)


def _discriminator_loss(real_prob, fake_prob, bce):
    """BCE loss pushing real -> 1 and fake -> 0. Targets are sized per batch, so a short last batch works."""
    return bce(real_prob, torch.ones_like(real_prob)) + bce(fake_prob, torch.zeros_like(fake_prob))


def _generator_loss(output, gt, fake_prob, l2_loss, vgg_loss, bce, args):
    """Generator objective: weighted L2 + rescaled VGG perceptual loss + weighted adversarial loss."""
    # Images are mapped with (x + 1) / 2 before VGG, i.e. treated as lying in [-1, 1]
    # (the generator's Tanh output range).
    percep_loss = args.vgg_rescale_coeff * vgg_loss((gt + 1.0) / 2.0, (output + 1.0) / 2.0)
    adversarial_loss = args.adv_coeff * bce(fake_prob, torch.ones_like(fake_prob))
    return percep_loss + adversarial_loss + args.L2_coeff * l2_loss(output, gt)


def _adversarial_train_epoch(generator, discriminator, loader, g_optim, d_optim, l2_loss, vgg_loss, bce, args):
    """Run one epoch of alternating D / G updates; return (mean G loss, mean D loss) per batch."""
    generator.train()
    discriminator.train()
    total_g_loss, total_d_loss = 0.0, 0.0
    for batch in loader:
        gt = batch['hr'].to(args.device)
        lr = batch['lr'].to(args.device)
        output, _ = generator(lr)

        # Discriminator update. `output` is detached so this backward pass stops at D
        # instead of also computing (and then discarding) generator gradients.
        fake_prob = discriminator(output.detach())
        real_prob = discriminator(gt)
        d_loss = _discriminator_loss(real_prob, fake_prob, bce)
        d_optim.zero_grad()
        d_loss.backward()
        d_optim.step()

        # Generator update: score the same SR output with the freshly updated D.
        # (Gradients this leaves on D are cleared by d_optim.zero_grad() next iteration.)
        fake_prob = discriminator(output)
        g_loss = _generator_loss(output, gt, fake_prob, l2_loss, vgg_loss, bce, args)
        g_optim.zero_grad()
        g_loss.backward()
        g_optim.step()

        total_d_loss += d_loss.item()
        total_g_loss += g_loss.item()
    n_batches = max(len(loader), 1)
    return total_g_loss / n_batches, total_d_loss / n_batches


def _validate_epoch(generator, discriminator, loader, l2_loss, vgg_loss, bce, args):
    """Evaluate without gradients; return (mean G loss, mean D loss, mean PSNR, mean SSIM).

    Losses are averaged per batch, PSNR/SSIM per image. Metrics are computed on images mapped
    to [0, 1] via (x + 1) / 2 with data_range=1. For data in [-1, 1] this gives the same
    PSNR the implicit skimage range (2) gave, and SSIM on the conventional [0, 1] scale.
    """
    generator.eval()
    discriminator.eval()
    total_g_loss, total_d_loss = 0.0, 0.0
    psnr_sum, ssim_sum, num_images = 0.0, 0.0, 0
    with torch.no_grad():
        for batch in loader:
            gt = batch['hr'].to(args.device)
            lr = batch['lr'].to(args.device)
            output, _ = generator(lr)
            fake_prob, real_prob = discriminator(output), discriminator(gt)

            total_d_loss += _discriminator_loss(real_prob, fake_prob, bce).item()
            total_g_loss += _generator_loss(output, gt, fake_prob, l2_loss, vgg_loss, bce, args).item()

            # NCHW tensors -> NHWC numpy arrays in [0, 1] for skimage.
            gt_np = ((gt + 1.0) / 2.0).permute(0, 2, 3, 1).cpu().numpy()
            out_np = ((output + 1.0) / 2.0).permute(0, 2, 3, 1).cpu().numpy()
            for gt_img, out_img in zip(gt_np, out_np):
                psnr_sum += peak_signal_noise_ratio(gt_img, out_img, data_range=1.0)
                ssim_sum += structural_similarity(gt_img, out_img, data_range=1.0, channel_axis=-1)
                num_images += 1
    n_batches = max(len(loader), 1)
    n_images = max(num_images, 1)
    return total_g_loss / n_batches, total_d_loss / n_batches, psnr_sum / n_images, ssim_sum / n_images

In [None]:
# Fetch the dataset, then crop/resize every split into fixed-size patches:
# HR patches are `args.scale` times larger than their LR counterparts.
download_dataset()

print('[!] Making Patches')

hr_patch = args.patch_size * args.scale
for split_folder, patch in (('train_hr', hr_patch), ('train_lr', args.patch_size),
                            ('valid_hr', hr_patch), ('valid_lr', args.patch_size)):
    resize_image(split_folder, patch)

# Release cached GPU memory before the long training run, then train and checkpoint.
torch.cuda.empty_cache()
train_valid(args)

In [None]:
%cd /content/gdrive/My Drive/11785/HW5

# The %cd above switches to the Google Drive project folder, so the relative test_lr/, test_results/ and checkpoint paths used below resolve there.
def test(target_folder, load_model_epoch, args, resize=False): # generate image from testset
    """Super-resolve every image of one test split with a saved generator checkpoint.

    Args:
        target_folder: split name, e.g. 'small_test'. Inputs come from
            Datasets(mode='test_lr/<target_folder>'); outputs go to 'test_results/<target_folder>/'.
        load_model_epoch: epoch suffix of '<args.checkpoint_dir>/generator_<epoch>.pth'.
        args: configuration object (res_num, scale, device, checkpoint_dir, sample_batch_size).
        resize: falsy to use the LR images as-is. Otherwise it is passed as the target size to
            resize_image, and the resized copies under 'test_resize/<target_folder>' are used.

    All paths are relative to the current working directory. Each output is saved under
    its input's file name.
    """
    if resize:
        resize_image('test_lr/' + target_folder, resize, 'test_resize/' + target_folder)
        test_dataset = Datasets(mode='test_resize/' + target_folder)
    else:
        test_dataset = Datasets(mode='test_lr/' + target_folder)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=args.sample_batch_size)

    result_dir = os.path.join('test_results', target_folder)
    os.makedirs(result_dir, exist_ok=True)

    model = Generator(img_feat=3, n_feats=64, kernel_size=3, num_block=args.res_num, scale=args.scale).to(args.device)
    ckpt_path = os.path.join(args.checkpoint_dir, f'generator_{load_model_epoch}.pth')
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only runtime.
    model.load_state_dict(torch.load(ckpt_path, map_location=args.device))
    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            lr = batch['lr'].to(args.device)
            # NOTE(review): in test mode the dataset appears to return file names under 'hr' -- confirm in util.Datasets.
            image_names = batch['hr']
            outputs, _ = model(lr)
            # NOTE(review): save_image expects values in [0, 1] and clamps everything else. The generator ends
            # in Tanh ([-1, 1]), and training maps (x + 1) / 2 before VGG, which suggests outputs may need the
            # same rescaling here -- confirm the data range produced by util.Datasets.
            for name, sr_image in zip(image_names, outputs):
                save_image(sr_image, os.path.join(result_dir, name))

# Super-resolve every test split with one generator checkpoint. train_valid numbers its
# checkpoints generator_1.pth ... generator_<fine_train_epoch>.pth.
target_folder = ['large_test', 'small_test', 'comics', 'structures']
load_model_epoch = 59
for split_name in target_folder:
    test(split_name, load_model_epoch, args)

In [None]:
# Use the provided scoring script to calculate test PSNR and SSIM for each test split
# (label printed in the report, test folder name), in report order.
SCORE_SPLITS = [
    ('small_test', 'small_test'),
    ('large_test', 'large_test'),
    ('comic', 'comics'),
    ('structures', 'structures'),
]

test_scores = {}  # folder name -> value returned by compute_scores
for label, folder in SCORE_SPLITS:
    TEST_LABEL_PATH = f'test_labels/{folder}/'
    TEST_RESULT_PATH = f'test_results/{folder}/'
    TEST_RESTORE_PATH = f'test_restore_results/{folder}/'

    scores = compute_scores(TEST_LABEL_PATH, TEST_RESULT_PATH, TEST_RESTORE_PATH)
    test_scores[folder] = scores
    print(f'{label}: ', scores)