# STEP 2: Define Evaluation Metrics
* PSNR
* SSIM

In [None]:
import numpy as np
import math
import cv2
import matplotlib.pyplot as plt

def plt_img(input_array, model_x, w, h):
    """Show the first image of a batch with matplotlib.

    Args:
        input_array: Indexable batch of arrays; element 0 is transposed with
            axes (1, 2, 0) before display (presumably channel-first images
            being converted to channel-last -- confirm against callers).
        model_x: "SRCNN" displays the image as-is; "ESPCN" first resizes it
            to ``(w, h)``. Any other value displays an empty figure.
        w: Target width passed to ``cv2.resize`` for "ESPCN".
        h: Target height passed to ``cv2.resize`` for "ESPCN".
    """
    input_img = input_array[0].transpose((1, 2, 0))

    if model_x == "SRCNN":
        plt.imshow(input_img)
    elif model_x == "ESPCN":
        # Refer to OpenCV through ``cv2``, the name this cell imports; the
        # ``cv`` alias is only bound by a later cell, so using it here made
        # the function depend on cell execution order.
        resized = cv2.resize(input_img, (w, h), interpolation=cv2.INTER_AREA)
        plt.imshow(resized)
    plt.show()

def PSNR(original, compressed):
    """Return the peak signal-to-noise ratio (dB) between two arrays.

    Both arrays are divided by 255 before the mean squared error is taken,
    and the peak value used in the ratio is 1 -- so inputs are presumably
    expected on a 0..255 scale (confirm against callers).

    When the normalized mean squared error is below 1e-10 the function
    returns 100 instead of evaluating the logarithm.
    """
    scale = 255.
    peak = 1
    error = np.mean((original / scale - compressed / scale) ** 2)
    if error < 1.0e-10:
        return 100
    return 20 * math.log10(peak / math.sqrt(error))

def calculate_PSNR(original_batch, compressed_batch, batch_size):
    """Sum the PSNR over the first ``batch_size`` paired items of two batches.

    The result is the total, not the mean: dividing by the number of items
    is left to the caller.
    """
    return sum(
        PSNR(original_batch[idx], compressed_batch[idx])
        for idx in range(batch_size)
    )

def SSIM(img1, img2):
    """Return the mean structural similarity index of two arrays.

    Local statistics are taken with an 11x11 Gaussian window (sigma 1.5) and
    the 5-pixel border on each side of the first two axes is discarded, so
    only positions where the window fits entirely are averaged.

    The stabilizing constants are built from 255, so inputs are presumably
    on a 0..255 scale (confirm against callers).
    """
    c1 = (0.01 * 255)**2
    c2 = (0.03 * 255)**2

    gauss = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss, gauss.transpose())

    def blur(arr):
        # Gaussian-filter, then crop away the border the window overhangs.
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    first = img1.astype(np.float64)
    second = img2.astype(np.float64)

    mean_first = blur(first)
    mean_second = blur(second)
    mean_first_sq = mean_first**2
    mean_second_sq = mean_second**2
    mean_product = mean_first * mean_second

    # Variance / covariance as E[x*y] - E[x]*E[y] under the same window.
    var_first = blur(first**2) - mean_first_sq
    var_second = blur(second**2) - mean_second_sq
    covariance = blur(first * second) - mean_product

    numerator = (2 * mean_product + c1) * (2 * covariance + c2)
    denominator = (mean_first_sq + mean_second_sq + c1) * (var_first + var_second + c2)
    return (numerator / denominator).mean()


def calculate_SSIM(img1, img2):
    """Return the SSIM of two equally shaped images.

    Args:
        img1: 2-D array, or 3-D array whose last axis has 1 or 3 channels.
        img2: Array with the same shape as ``img1``.

    Returns:
        The SSIM value; for 3-channel input, the mean of the per-channel
        values.

    Raises:
        ValueError: If the shapes differ, or the input is neither 2-D nor
            3-D with 1 or 3 channels on the last axis.
    """
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    if img1.ndim == 2:
        return SSIM(img1, img2)
    if img1.ndim == 3:
        channels = img1.shape[2]
        if channels == 3:
            # Score each channel on its own and average the three results.
            return np.mean(
                [SSIM(img1[..., c], img2[..., c]) for c in range(3)]
            )
        if channels == 1:
            # Drop only the channel axis (np.squeeze would also drop a
            # height or width of 1).
            return SSIM(img1[..., 0], img2[..., 0])
    # Unsupported rank or channel count: fail loudly instead of returning None.
    raise ValueError('Wrong input image dimensions.')

# STEP 3: Design Custom Dataset

In [None]:
import cv2 as cv
import os
from torch.utils.data import Dataset, DataLoader

class DIV2K_Dataset(Dataset):
    """Dataset of (low-resolution, high-resolution) image pairs from a folder.

    Every entry of ``path_to_imgs`` is treated as an image. Each image is
    read with OpenCV, converted to RGB, axis-swapped to landscape if it is
    taller than wide, and resized to ``width`` x ``height`` to form the
    high-resolution target. The low-resolution input is that target shrunk
    by ``scale``; for "SRCNN" it is then resized back to the target size,
    for "ESPCN" it is left small.

    Args:
        width: Width of the high-resolution image.
        height: Height of the high-resolution image.
        scale: Integer down-scaling factor.
        path_to_imgs: Directory containing the images.
        model_x: "SRCNN" or "ESPCN".
        transform: Optional callable applied to both images of a pair; when
            omitted the resized arrays are returned unchanged.

    Raises:
        ValueError: If ``model_x`` is not a supported model name.
    """

    # Model names for which __getitem__ knows how to build the input image.
    SUPPORTED_MODELS = ("SRCNN", "ESPCN")

    def __init__(self, width, height, scale, path_to_imgs, model_x, transform = None):
        if model_x not in self.SUPPORTED_MODELS:
            raise ValueError(
                f"model_x must be one of {self.SUPPORTED_MODELS}, got {model_x!r}"
            )
        self.model_x = model_x
        self.width = width
        self.height = height
        self.scale = scale
        self.path_to_imgs = path_to_imgs
        # List the directory once: os.listdir returns entries in arbitrary
        # order, so sorting keeps the index -> file mapping stable, and
        # caching avoids a directory scan on every item access.
        self.file_names = sorted(os.listdir(path_to_imgs))
        self.length = len(self.file_names)
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(low_res, high_res)`` pair for ``index``.

        Raises:
            IOError: If OpenCV cannot read the image file.
        """
        # os.path.join works whether or not path_to_imgs ends with a separator.
        img_path = os.path.join(self.path_to_imgs, self.file_names[index])
        img = cv.imread(img_path)
        if img is None:
            # cv.imread signals failure by returning None rather than raising.
            raise IOError(f"Could not read image: {img_path}")
        img_rgb = cv.cvtColor(img, cv.COLOR_BGR2RGB)
        h, w = img_rgb.shape[:2]
        if h > w:
            # Swap the two spatial axes so every sample is landscape.
            img_rgb = img_rgb.transpose((1, 0, 2))

        # Interpolation candidates: INTER_CUBIC, INTER_NEAREST, INTER_LINEAR,
        # INTER_LANCZOS4, INTER_AREA; INTER_AREA is the one used throughout.
        hr_size = (self.width, self.height)
        lr_size = (self.width // self.scale, self.height // self.scale)
        img_hr = cv.resize(img_rgb, hr_size, interpolation = cv.INTER_AREA)
        img_lr = cv.resize(img_hr, lr_size, interpolation = cv.INTER_AREA)
        if self.model_x == "SRCNN":
            # SRCNN gets an input already at target size; ESPCN keeps the
            # small image, so the extra resize is skipped for it.
            img_lr = cv.resize(img_lr, hr_size, interpolation = cv.INTER_AREA)

        if self.transform is None:
            return (img_lr, img_hr)
        return (self.transform(img_lr), self.transform(img_hr))

    def __len__(self):
        """Return the number of directory entries found at construction."""
        return self.length

# STEP 4: Define Models ==> SRCNN and ESPCN
* SRCNN GitHub repo link: https://github.com/yjn870/SRCNN-pytorch
* Differences from that repo:
    1. Added zero padding
    2. Used Adam instead of SGD
    3. Removed the weight initialization

In [None]:
from torch import nn

class SRCNN(nn.Module):
    """Three-convolution network whose output has the input's spatial size.

    Kernel sizes are 9, 5 and 5, each padded by half the kernel size, with
    ReLU after the first two convolutions. Input and output both carry
    ``num_channels`` channels.
    """

    def __init__(self, num_channels = 3):
        super().__init__()
        # (attribute name, in channels, out channels, kernel size)
        layer_specs = (
            ("conv1", num_channels, 64, 9),
            ("conv2", 64, 32, 5),
            ("conv3", 32, num_channels, 5),
        )
        for attr, c_in, c_out, k in layer_specs:
            setattr(self, attr, nn.Conv2d(c_in, c_out, kernel_size = k, padding = k // 2))
        self.relu = nn.ReLU(inplace = True)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU -> conv3 to ``x``."""
        hidden = self.relu(self.conv1(x))
        hidden = self.relu(self.conv2(hidden))
        return self.conv3(hidden)
    
    
class ESPCN(nn.Module):
    """Sub-pixel convolution network that upscales by ``scale_factor``.

    ``first_part`` is two Tanh-activated convolutions (5x5 then 3x3);
    ``last_part`` is a 3x3 convolution producing
    ``num_channels * scale_factor**2`` channels followed by ``PixelShuffle``,
    which rearranges them into an image ``scale_factor`` times larger.
    """

    def __init__(self, scale_factor, num_channels=2):
        super().__init__()
        feature_layers = [
            nn.Conv2d(num_channels, 64, kernel_size=5, padding=5//2),
            nn.Tanh(),
            nn.Conv2d(64, 32, kernel_size=3, padding=3//2),
            nn.Tanh(),
        ]
        self.first_part = nn.Sequential(*feature_layers)

        upscale_layers = [
            nn.Conv2d(32, num_channels * (scale_factor ** 2), kernel_size=3, padding=3 // 2),
            nn.PixelShuffle(scale_factor),
        ]
        self.last_part = nn.Sequential(*upscale_layers)

        self._initialize_weights()

    def _initialize_weights(self):
        """Draw conv weights from zero-mean normals and zero every bias.

        Convolutions with 32 input channels use std 0.001; the others use
        sqrt(2 / (out_channels * elements in one kernel slice)).
        """
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            if module.in_channels == 32:
                std = 0.001
            else:
                std = math.sqrt(2/(module.out_channels*module.weight.data[0][0].numel()))
            nn.init.normal_(module.weight.data, mean=0.0, std=std)
            nn.init.zeros_(module.bias.data)

    def forward(self, x):
        """Run ``x`` through ``first_part`` and then ``last_part``."""
        return self.last_part(self.first_part(x))

# STEP 5: Summarize Model & Set DataLoader

In [None]:
from torchsummary import summary
from torchvision import transforms
import torch

# Size of the high-resolution target and the down-scaling factor used to
# derive the low-resolution input.
width, height = 2000, 1600
scale = 8
path_to_train_imgs = "../__HW8_DATA/DIV2K_train_HR/"
path_to_valid_imgs = "../__HW8_DATA/DIV2K_valid_HR/"
trans_train = transforms.Compose([transforms.ToTensor()])
trans_valid = transforms.Compose([transforms.ToTensor()])
batch_size = 3
model_x = "ESPCN"
# Set to True to print a torchsummary layer table for the selected model.
show_summary = False

Train_Dataset = DIV2K_Dataset(width = width, height = height, scale = scale, path_to_imgs = path_to_train_imgs, model_x = model_x, transform = trans_train)
Train_Dataloader = DataLoader(Train_Dataset, batch_size = batch_size, shuffle = True, num_workers = 0)

Valid_Dataset = DIV2K_Dataset(width = width, height = height, scale = scale, path_to_imgs = path_to_valid_imgs, model_x = model_x, transform = trans_valid)
# Validation is not shuffled: the averaged metrics do not depend on order,
# and a fixed order keeps per-batch results comparable between epochs.
Valid_Dataloader = DataLoader(Valid_Dataset, batch_size = batch_size, shuffle = False, num_workers = 0)

if show_summary:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The dataset resizes with cv.resize((width, height)) and ToTensor yields
    # channel-first tensors, so the per-sample shape is (3, height, width).
    if model_x == "SRCNN":
        model = SRCNN(num_channels = 3)
        model = model.to(device)
        summary(model, input_size = (3, height, width))
    elif model_x == "ESPCN":
        model = ESPCN(scale_factor = scale, num_channels = 3)
        model = model.to(device)
        summary(model, input_size = (3, height//scale, width//scale))
        # Estimated total size (MB): 220.60

# STEP 6: Set Hyper Parameter

In [None]:
import torch
from torch import nn
import torch.optim as optim

# Train on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

if model_x == "SRCNN":
    model = SRCNN(num_channels = 3).to(device)
elif model_x == "ESPCN":
    model = ESPCN(scale_factor = scale, num_channels = 3).to(device)
else:
    # Fail here with a clear message instead of a NameError on ``model`` below.
    raise ValueError(f"Unsupported model_x: {model_x!r}")

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr = 1e-4)
epochs = 10

# STEP 7: Train

In [None]:
from tqdm import tqdm

# Every ``probe_number`` training batches, sample images are shown and the
# running statistics are recorded and reset.
probe_number =  80
loss_list = []
train_psnr_before_list = []
train_psnr_after_list = []
valid_psnr_before_list = []
valid_psnr_after_list = []


def _to_pixel_range(tensor):
    """Return ``tensor`` as a NumPy array multiplied by 255.

    PSNR() divides both of its inputs by 255, while the tensors in this loop
    are produced by ToTensor() and clamped to [0, 1]. Passing them unscaled
    shrinks the error by 255 and overstates every PSNR by
    20 * log10(255) (about 48 dB).
    """
    return tensor.detach().cpu().numpy() * 255.0


for epoch in range(epochs):

    model.train()
    running_loss = 0
    running_psnr_before = 0
    running_psnr_after = 0
    inner_epoch_count = 0
    train_total = 0
    for inputs, labels in tqdm(Train_Dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        n_samples = labels.size(0)

        # Optimize on the raw network output: clamping before the loss gives
        # zero gradient for every pixel predicted outside [0, 1].
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics and plots use a detached copy clamped to the valid range.
        preds = outputs.detach().clamp(0.0, 1.0)
        labels_px = _to_pixel_range(labels)

        if model_x == "SRCNN":
            # Only SRCNN inputs have the label's size, so only there can the
            # input itself be scored against the label.
            psnr_before = calculate_PSNR(_to_pixel_range(inputs), labels_px, n_samples)
            running_psnr_before = running_psnr_before + psnr_before

        psnr_after = calculate_PSNR(_to_pixel_range(preds), labels_px, n_samples)
        running_psnr_after  = running_psnr_after  + psnr_after

        # The criterion returns a per-batch mean; weight it by the batch size
        # so that dividing by train_total yields a per-sample mean.
        running_loss = running_loss + loss.item() * n_samples

        train_total = train_total + n_samples
        inner_epoch_count = inner_epoch_count + 1
        if inner_epoch_count % probe_number == probe_number - 1:
            for ele in (inputs, preds, labels):
                plt_img(ele.detach().cpu().numpy(), model_x, width, height)
            loss_list.append(running_loss/train_total)
            if model_x == "SRCNN":
                train_psnr_before_list.append(running_psnr_before/train_total)

            train_psnr_after_list.append(running_psnr_after/train_total)
            print(f"Before: {round(running_psnr_before/train_total, 4)}, After: {round(running_psnr_after/train_total, 4)}, Diff: {(round((running_psnr_after - running_psnr_before)/train_total, 4))}, Loss: {round(running_loss/train_total, 5)}")
            train_total = 0
            running_psnr_before = 0
            running_psnr_after = 0
            running_loss = 0

    model.eval()
    running_psnr_before = 0
    running_psnr_after = 0
    valid_total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(Valid_Dataloader):
            inputs, labels = inputs.to(device), labels.to(device)
            n_samples = labels.size(0)
            valid_total = valid_total + n_samples
            preds = model(inputs).clamp(0.0, 1.0)
            labels_px = _to_pixel_range(labels)

            if model_x == "SRCNN":
                psnr_before = calculate_PSNR(_to_pixel_range(inputs), labels_px, n_samples)
                # Store the per-sample mean, matching valid_psnr_after_list.
                valid_psnr_before_list.append(psnr_before/n_samples)
                running_psnr_before = running_psnr_before + psnr_before

            psnr_after = calculate_PSNR(_to_pixel_range(preds), labels_px, n_samples)
            valid_psnr_after_list.append(psnr_after/n_samples)
            running_psnr_after = running_psnr_after + psnr_after

    # Skip the report when the validation loader produced no samples.
    if valid_total:
        print(f"Before: {round(running_psnr_before/valid_total, 4)}, After: {round(running_psnr_after/valid_total, 4)}, Diff: {round((running_psnr_after - running_psnr_before)/valid_total, 4)}")
            

# STEP 8: Plot

In [None]:
import matplotlib.pyplot as plt

# Figure 1: training PSNR recorded at each probe, for the inputs ("before")
# and for the model output ("after"), drawn on the same axes.
for series in (train_psnr_before_list, train_psnr_after_list):
    plt.plot(series)
plt.show()

# Figure 2: validation PSNR of the model output, one point per batch.
# valid_psnr_before_list is not drawn here (its plot call is disabled).
plt.plot(valid_psnr_after_list)
plt.show()

# Figure 3: training loss recorded at each probe.
plt.plot(loss_list)
plt.show()