# 0. Setup

In [1]:
# Mount to Google Drive
from google.colab import drive
drive.mount('drive')

Mounted at drive


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from torchvision.models import vgg19
from PIL import Image
import numpy as np
import os

# GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
folder_path = '/content/drive/MyDrive/SRAE'

import os
os.chdir(folder_path+'/functions')

# 1. Data Loader

In [None]:
import data_loader
import importlib
importlib.reload(data_loader)

batch_size = 4
scale = 4

lr_size = 48
num_patches=10

guide_name = "bicubic"

dataloader = data_loader.load_ds(f"Flickr2K_A_(x{scale})", guide_name, batch_size, scale, lr_size=lr_size, num_patches=num_patches)
# dataloader = data_loader.load_ds(f"BSD100_(x{scale})", guide_name, 1, scale, lr_size=32, num_patches=num_patches)

1325 images loaded


In [None]:
bsdLoader = data_loader.load_ds(f"BSD100_(x{scale})", guide_name, 1, scale, lr_size=None)
s5Loader  = data_loader.load_ds(f"set5_(x{scale})" , guide_name, 1, scale, lr_size=None)
s14Loader = data_loader.load_ds(f"set14_(x{scale})", guide_name, 1, scale, lr_size=None)
urbLoader = data_loader.load_ds(f"urban100_(x{scale})", guide_name, 1, scale, lr_size=None)

testLoaders = []
testLoaders.append((bsdLoader, "BSD100"))
testLoaders.append((s14Loader, "set14"))

100 images loaded
5 images loaded
14 images loaded
100 images loaded


# 2. Train (x4 SR Model)

In [None]:
modelName = "test"
depth = 5

Model_name = f"{modelName}_x{scale}_d{depth}" # save name
num_epochs = 200

# SR(x4) Decoder
import NET
importlib.reload(NET)

# Initialize the model
model = NET.SRAE(scale=scale, depth=depth)
if depth > 1:
  path = f"../models/save/{modelName}_x{scale}_d{depth-1}.pth"
  model.load_model(path, depth-1)

optimizer = optim.Adam(model.parameters(), lr=1e-5)
epoch=0

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:02<00:00, 216MB/s]
  state_dict = torch.load(path)


Decoder 1 loaded successfully.
Decoder 2 loaded successfully.
Decoder 3 loaded successfully.
Decoder 4 loaded successfully.


In [None]:
import net_trainer
import net_tester
importlib.reload(net_trainer)
importlib.reload(net_tester)
net_trainer.trainer(device, model, optimizer, num_epochs, dataloader, testLoaders, Model_name, st_epoch=epoch)

In [None]:
####if stopped ###
epoch = 32
model.load_state_dict(torch.load(f"../models//temp/{Model_name}/{Model_name}_e{epoch}.pth"))
optimizer = optim.Adam(model.parameters(), lr=1e-4)
optimizer.load_state_dict(torch.load(f"../models/optim/{Model_name}/optim_{Model_name}_e{epoch}.pth"))

  model.load_state_dict(torch.load(f"../models//temp/{Model_name}/{Model_name}_e{epoch}.pth"))
  optimizer.load_state_dict(torch.load(f"../models/optim/{Model_name}/optim_{Model_name}_e{epoch}.pth"))


# 3. Test

In [None]:
# Load Model
model = NET.SRAE(scale=scale, depth=depth).to(device)

epoch = 14

model.load_state_dict(torch.load(f"../models/temp/{Model_name}/{Model_name}_e{epoch}.pth"))
optimizer = optim.Adam(model.parameters(), lr=1e-4)
model.eval()
model.to(device)

  model.load_state_dict(torch.load(f"../models/temp/{Model_name}/{Model_name}_e{epoch}.pth"))


SRAE(
  (encoders): ModuleList(
    (0): None
    (1): block1(
      (feature_extractor): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
      )
    )
    (2): block2(
      (feature_extractor): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace=True)
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU(inplace=True)
      )
    )
    (3): block3(
      (feature_extractor): Sequential(
   

> save image

In [None]:
import net_tester
importlib.reload(net_tester)

print("[set5]")
net_tester.test_set(s5Loader, device, model, save_path=f"../outputs/{Model_name}/set5")

print("\n[set14]")
net_tester.test_set(s14Loader, device, model, save_path=f"../outputs/{Model_name}/set14")

print("\n[bsd]")
net_tester.test_set(bsdLoader, device, model, save_path=f"../outputs/{Model_name}/BSD100")

print("\n[urban]")
net_tester.test_set(urbLoader, device, model, save_path=f"../outputs/{Model_name}/urban100")

[set5]
Average PSNR: 30.0972 dB
Average SSIM: 0.8834

[set14]
Average PSNR: 26.4901 dB
Average SSIM: 0.7728

[bsd]
Average PSNR: 26.0341 dB
Average SSIM: 0.7311

[urban]
Average PSNR: 24.0167 dB
Average SSIM: 0.7615
