In [1]:
from google.colab import drive
import json
# Mount Google Drive; later cells read/write checkpoints under this mount point.
DRIVE_MOUNT_POINT = "/content/gdrive"
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

Mounted at /content/gdrive


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

import os
from random import seed
import time
from util import *

# Seed every RNG the notebook uses. Python's `random` alone does not affect PyTorch:
# weight initialisation and DataLoader(shuffle=True) draw from torch's generator.
SEED = 11785
seed(SEED)
torch.manual_seed(SEED)

In [3]:
#Define the Convolutional Autoencoder
class ConvAutoencoder(nn.Module):
    def __init__(self):
        super(ConvAutoencoder, self).__init__()
       
        #Encoder
        self.encoder = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1),
                                     nn.BatchNorm2d(64),
                                     nn.PReLU(),
                                     nn.MaxPool2d(2, 2),
                                     nn.Conv2d(64, 128, 3, padding=1),
                                     nn.BatchNorm2d(128),
                                     nn.PReLU(),
                                     nn.MaxPool2d(2, 2))
       
        #Decoder
        self.decoder = nn.Sequential(nn.ConvTranspose2d(128, 64, 2, stride=2),
                                     nn.BatchNorm2d(64),
                                     nn.PReLU(),
                                     nn.ConvTranspose2d(64, 3, 2, stride=2))
    def forward(self, x):
        emb = self.encoder(x)
        y = self.decoder(emb)
        return y

In [4]:
# parameter
class args():
  batch_size = 8
  checkpoint_dir = 'VAE/checkpoints' #path to saved models
  sample_dir = 'samples'  #folder to output images and model checkpoints
  sample_batch_size = 1
  load_model_epoch = False  # model epoch to load, start from scratch if False
  resize = 128
  n_epochs = 100
  lr = 1e-3
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# download dataset and pre processing (crop or resize)
download_dataset()

print('[!] Making Patches')
# crop_image('train_hr', 'train_lr', 96, 24)
resize_image('train_hr', args.resize)
resize_image('train_lr', args.resize)

resize_image('valid_hr', args.resize)
resize_image('valid_lr', args.resize)

DIV2K_train_HR.zip: 0.00B [00:00, ?B/s]

[!] Downloading Dataset
[!] download data file


DIV2K_train_HR.zip: 3.53GB [03:04, 19.1MB/s]                            
DIV2K_train_LR_bicubic_X4.zip: 0.00B [00:00, ?B/s]

[!] download data file


DIV2K_train_LR_bicubic_X4.zip: 247MB [00:14, 16.7MB/s]                           
DIV2K_valid_HR.zip: 0.00B [00:00, ?B/s]

[!] download data file


DIV2K_valid_HR.zip: 449MB [00:25, 17.3MB/s]                           
DIV2K_valid_LR_bicubic_X4.zip: 0.00B [00:00, ?B/s]

[!] download data file


DIV2K_valid_LR_bicubic_X4.zip: 31.5MB [00:04, 7.65MB/s]                            


[!] Upzip zipfile
[!] Reformat DIV2K HR (Training Set)
[!] Reformat DIV2K LR (Training Set)
[!] Reformat DIV2K HR (Validation Set)
[!] Reformat DIV2K LR (Validation Set)
[!] Making Patches
[*] [0/800] Make patch train_hr/0015.png
[*] [500/800] Make patch train_hr/0486.png
[*] [0/800] Make patch train_lr/0015.png
[*] [500/800] Make patch train_lr/0486.png
[*] [0/100] Make patch valid_hr/0015.png
[*] [0/100] Make patch valid_lr/0015.png


In [5]:
model = ConvAutoencoder().to(args.device)
if args.load_model_epoch:
    # Absolute Drive path instead of `%cd ./gdrive/...`: the relative cd only worked when
    # the cwd was /content, and it left the cwd on Drive whenever loading raised.
    ckpt_path = os.path.join('/content/gdrive/My Drive/11785/HW5', args.checkpoint_dir,
                             f'VAE_{args.load_model_epoch}.pth')
    print(f"[*] Load model from google drive {args.checkpoint_dir} epoch {args.load_model_epoch}")
    # map_location lets a GPU-saved checkpoint load on a CPU-only runtime (and vice versa).
    model.load_state_dict(torch.load(ckpt_path, map_location=args.device))
print(model)

# Loss function: pixel-wise MSE between prediction and HR target
criterion = nn.MSELoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=args.lr)

# Scheduler: decay LR by 10x five times over the run. max(1, ...) avoids step_size=0
# (a modulo-by-zero inside StepLR) when n_epochs < 5.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=max(1, args.n_epochs // 5), gamma=0.1)

train_dataset = Datasets('train')
train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True)
dev_dataset = Datasets('valid')
# Validation metrics are averages, so shuffling only makes runs less reproducible.
dev_loader = DataLoader(dataset=dev_dataset, batch_size=args.sample_batch_size, shuffle=False)

ConvAutoencoder(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): PReLU(num_parameters=1)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
    (3): ConvTranspose2d(64, 3, kernel_size=(2, 2), stride=(2, 2))
  )
)


In [7]:
# Train for args.n_epochs epochs. After each epoch, evaluate loss / PSNR / SSIM on the
# validation split, step the LR scheduler, and (from save_from_epoch on) checkpoint to Drive.
torch.cuda.empty_cache()

os.makedirs(args.sample_dir, exist_ok=True)
save_from_epoch = 10  # earliest epoch whose weights are written to Drive
# Absolute path: the previous `%cd ./gdrive/...` only worked from /content and left the
# cwd on Drive if torch.save raised.
drive_checkpoint_dir = os.path.join('/content/gdrive/My Drive/11785/HW5', args.checkpoint_dir)

for epoch in range(args.n_epochs):
    start_time = time.time()

    # ---- training ----
    model.train()
    train_loss = 0.0
    for i, batch in enumerate(train_loader):
        hr = batch['hr'].to(args.device)
        lr = batch['lr'].to(args.device)

        optimizer.zero_grad()
        outputs = model(lr)
        loss = criterion(outputs, hr)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        if i == 0:
            # Visual check of the first batch: HR target, prediction and LR input
            # concatenated along the height axis (dim 2).
            result = torch.cat((hr, outputs.detach(), lr), 2)
            save_image(result, os.path.join(args.sample_dir, f"SR_{epoch}.png"))
    train_end_time = time.time()

    # ---- validation ----
    model.eval()
    PSNR, SSIM, valid_loss = 0.0, 0.0, 0.0
    n_valid_images = 0
    with torch.no_grad():
        for batch in dev_loader:
            hr = batch['hr'].to(args.device)
            lr = batch['lr'].to(args.device)
            outputs = model(lr)
            valid_loss += criterion(outputs, hr).item()

            # skimage wants HWC numpy arrays. The prediction is clamped to [0, 1] so metrics
            # match what save_image writes. Assumes images are scaled to [0, 1] - TODO confirm
            # in util.Datasets.
            hr_np = hr.permute(0, 2, 3, 1).cpu().numpy()
            out_np = outputs.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy()
            # Score every image in the batch, not just the first.
            for target, pred in zip(hr_np, out_np):
                # Explicit data_range: for float input, older skimage SSIM defaults to the
                # [-1, 1] float range (data_range=2), and newer versions refuse to guess.
                # NOTE(review): skimage >= 0.19 deprecates multichannel in favour of channel_axis=-1.
                PSNR += peak_signal_noise_ratio(target, pred, data_range=1.0)
                SSIM += structural_similarity(target, pred, data_range=1.0, multichannel=True)
                n_valid_images += 1
    end_time = time.time()
    scheduler.step()

    train_loss = train_loss / max(len(train_loader), 1)
    valid_loss = valid_loss / max(len(dev_loader), 1)
    # Average per image (not per batch) so the metrics stay correct for sample_batch_size > 1.
    PSNR = PSNR / max(n_valid_images, 1)
    SSIM = SSIM / max(n_valid_images, 1)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))
    print('Epoch: {} \tTraining Time: {:.6f}'.format(epoch, train_end_time-start_time))
    print('Epoch: {} \tValidation Loss: {:.6f}'.format(epoch, valid_loss))
    print('Epoch: {} \tPSNR: {:.6f}'.format(epoch, PSNR))
    print('Epoch: {} \tSSIM: {:.6f}\n'.format(epoch, SSIM))
    print('Epoch: {} \tValidation Time: {:.6f}\n'.format(epoch, end_time-train_end_time))

    if epoch >= save_from_epoch:
        os.makedirs(drive_checkpoint_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(drive_checkpoint_dir, f"VAE_{epoch}.pth"))

In [18]:
# Change to the project folder on Google Drive to generate results with a trained model.
# test() below resolves test_lr/, test_resize/, test_results/ and args.checkpoint_dir
# relative to this working directory.
%cd /content/gdrive/My Drive/11785/HW5

def test(target_folder, load_model_epoch, args, resize=False): # generate image from testset
    """Run a saved checkpoint over one test folder and write its predictions to disk.

    Args:
        target_folder: sub-folder name under ``test_lr/`` (e.g. ``'small_test'``).
        load_model_epoch: N of the checkpoint ``VAE_N.pth`` inside ``args.checkpoint_dir``.
        args: settings namespace; reads ``device``, ``checkpoint_dir`` and ``sample_batch_size``.
        resize: falsy to read ``test_lr/<target_folder>`` directly; otherwise passed as the size
            to ``resize_image``, and the resized copy in ``test_resize/<target_folder>`` is used.

    Side effects: writes one image per input into ``test_results/<target_folder>/``.
    All paths are relative to the current working directory.
    """
    if resize:
        resize_image('test_lr/' + target_folder, resize, 'test_resize/' + target_folder)
        test_dataset = Datasets(mode='test_resize/' + target_folder)
    else:
        test_dataset = Datasets(mode='test_lr/' + target_folder)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=args.sample_batch_size)

    out_dir = os.path.join('test_results', target_folder)
    os.makedirs(out_dir, exist_ok=True)

    model = ConvAutoencoder().to(args.device)
    ckpt_path = os.path.join(args.checkpoint_dir, f'VAE_{load_model_epoch}.pth')
    # map_location: load GPU-trained weights on whatever device this runtime has.
    model.load_state_dict(torch.load(ckpt_path, map_location=args.device))
    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            lr = batch['lr'].to(args.device)
            # NOTE(review): in test mode the dataset appears to return the file name under
            # 'hr' (used as the output name) - confirm in util.Datasets.
            names = batch['hr']
            outputs = model(lr)
            # Save each image of the batch under its own name. Previously the whole batch
            # was written as one grid named after the first file.
            for output, name in zip(outputs, names):
                save_image(output, os.path.join(out_dir, name))

# Checkpoint epoch used for inference.
# NOTE(review): args.n_epochs is 100, so VAE_133.pth cannot come from the training cell as
# configured here - confirm it exists in args.checkpoint_dir (e.g. from a longer earlier run).
load_model_epoch = 133
target_folder = ['large_test', 'small_test', 'comics', 'structures']
for folder_name in target_folder:
    test(folder_name, load_model_epoch, args)

/content/gdrive/My Drive/11785/HW5


In [9]:
# Calculate test-set PSNR and SSIM with the provided util.compute_scores, once per split.
# Each split's labels, model outputs and restore-result folders share the same sub-folder
# name; the printed label is the folder name.
for split in ('small_test', 'large_test', 'comics', 'structures'):
    scores = compute_scores(f'test_labels/{split}/',
                            f'test_results/{split}/',
                            f'test_restore_results/{split}/')
    print(f'{split}: ', scores)