In [None]:
import copy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from datasets import ImageDataset
from utils import *

In [None]:
# Pick the compute device once: CUDA when PyTorch can see a GPU, otherwise CPU.
# Later cells place the model and run training on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Model

In [None]:
class SRCNN(nn.Module):
    """Three-layer SRCNN: conv(9x9) -> ReLU -> conv(5x5) -> ReLU -> conv(5x5).

    Every convolution uses stride 1 with padding = kernel_size // 2, so the
    output has the same height and width as the input. The network therefore
    does not upscale by itself; it refines an image that is already at the
    target resolution (the datasets in this notebook are built with
    ``bic_up=True``).

    Args:
        num_channels: number of image channels in the input and the output
            (default 3, matching the original hard-coded RGB configuration).
    """

    def __init__(self, num_channels=3):
        super(SRCNN, self).__init__()
        # Feature extraction (9x9), non-linear mapping (5x5), reconstruction (5x5).
        self.conv1 = nn.Conv2d(in_channels=num_channels, out_channels=64, kernel_size=9, stride=1, padding=4)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=num_channels, kernel_size=5, stride=1, padding=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the three convolutions; the last layer has no activation.

        Args:
            x: batched image tensor, presumably (N, C, H, W) with
                C == num_channels.

        Returns:
            Tensor with the same spatial size as ``x``.
        """
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return self.conv3(x)

    def process_image(self, lr_tensor):
        """Run the network on a single un-batched image without tracking gradients.

        The input is moved to the device holding the model's parameters. The
        module's previous train/eval mode is restored afterwards.

        Args:
            lr_tensor: one image tensor without a batch axis, presumably
                (C, H, W), already at the output resolution.

        Returns:
            The network output for that image with the batch axis removed,
            on the CPU.
        """
        was_training = self.training
        model_device = next(self.parameters()).device
        self.eval()
        try:
            with torch.no_grad():
                hr_batch = self(lr_tensor.unsqueeze(0).to(model_device))
        finally:
            self.train(was_training)
        # squeeze(0) removes only the batch axis that was added above; a bare
        # squeeze() would also drop any channel/spatial axis of size 1.
        return hr_batch.squeeze(0).cpu()

    def get_name(self):
        """Model name used when saving checkpoints."""
        return "SRCNN"

# Train

In [None]:
def train_model(
        model, 
        train_dataloader, 
        eval_dataloader, 
        criterion, 
        optimizer, 
        trained_path, 
        start_epoch=1, 
        end_epoch=50, 
        device='cpu',
        rgb=False,
        save_every=50):
    """Train ``model`` for epochs ``start_epoch..end_epoch`` and save the best weights.

    Before each training epoch the current weights are evaluated with
    ``eval(model, eval_dataloader, device, rgb=rgb)``. The evaluation at the
    start of epoch ``e`` therefore scores the weights produced by epoch
    ``e - 1``. One last evaluation runs after the final epoch. A snapshot
    becomes the new "best" only when BOTH average PSNR and average SSIM
    strictly exceed the best values recorded so far.

    Args:
        model: network exposing ``state_dict()`` and ``get_name()``.
        train_dataloader: loader passed to ``train`` each epoch; its length
            sizes the progress bar.
        eval_dataloader: loader passed to ``eval`` for PSNR/SSIM.
        criterion: loss passed to ``train``.
        optimizer: optimizer passed to ``train``.
        trained_path: directory handed to ``save`` for checkpoints.
        start_epoch: first epoch number (inclusive).
        end_epoch: last epoch number (inclusive).
        device: device forwarded to ``eval`` and ``train``.
        rgb: forwarded to ``eval`` (metrics on RGB vs. the Y channel).
        save_every: write a periodic checkpoint whenever
            ``epoch % save_every == 0``. ``None`` or a value <= 0 disables
            periodic checkpoints.

    Side effects:
        Writes periodic checkpoints and, at the end, the best weights
        (``best=True``) via ``save``. Returns nothing.
    """
    num_epochs = max(0, end_epoch - start_epoch + 1)
    # Epoch number of the weights scored by the final evaluation; equals
    # start_epoch - 1 when the training range is empty.
    last_epoch = max(end_epoch, start_epoch - 1)

    best_psnr = -float('inf')
    best_ssim = -float('inf')
    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = -1

    def evaluate_and_track(completed_epoch):
        """Evaluate current weights; record them as best if both metrics improved."""
        nonlocal best_psnr, best_ssim, best_weights, best_epoch
        avg_psnr, avg_ssim = eval(model, eval_dataloader, device, rgb=rgb)
        if avg_psnr > best_psnr and avg_ssim > best_ssim:
            best_psnr = avg_psnr
            best_ssim = avg_ssim
            best_epoch = completed_epoch
            # Deep copy: state_dict() returns references to the live tensors.
            best_weights = copy.deepcopy(model.state_dict())
        return avg_psnr, avg_ssim

    with tqdm(total=len(train_dataloader) * num_epochs, desc=f'Epoch {start_epoch}/{end_epoch}', unit=' patches') as pbar:
        for epoch in range(start_epoch, end_epoch + 1):
            # Score the weights left by the previous epoch before training this one.
            avg_psnr, avg_ssim = evaluate_and_track(epoch - 1)
            pbar.set_description_str(f'Epoch {epoch}/{end_epoch} | PSNR: {avg_psnr} | SSIM: {avg_ssim}', refresh=True)

            avg_loss = train(model, train_dataloader, optimizer, criterion, pbar, device)
            pbar.set_postfix(loss=avg_loss)

            if save_every is not None and save_every > 0 and epoch % save_every == 0:
                save(model.state_dict(), model.get_name(), epoch, trained_path)

        # Final eval scores the weights produced by the last training epoch.
        avg_psnr, avg_ssim = evaluate_and_track(last_epoch)
        pbar.set_description_str(f'Epoch {last_epoch}/{end_epoch} | PSNR: {avg_psnr} | SSIM: {avg_ssim}', refresh=True)

    save(best_weights, model.get_name(), best_epoch, trained_path, best=True)

# Showcase

### Settings

In [None]:
batch_size = 4
learning_rate = 1e-4
start_epoch = 1 # 21884/epoch
end_epoch = 200
upscale_factor = 3
metrics_rgb = False # Evaluate psnr and ssim on rgb channels or Y channel

train_path = "./Datasets/T91"
eval_path = "./Datasets/Set5"

set5_path = "./Datasets/Set5"
set14_path = "./Datasets/Set14"
urban100_path = "./Datasets/urban100"
bsd100_path = "./Datasets/BSD100"
manga109_path = "./Datasets/manga109"

# Run mode, one of:
#   "load"       - load the pretrained checkpoint below, no training
#   "train"      - train from scratch
#   "load-train" - load the pretrained checkpoint, then continue training
mode = "load"

trained_path = "TrainedModels/srcnn/X" + str(upscale_factor) + "/"

# Pretrained checkpoint file (inside `trained_path`) for each supported scale.
pretrained = ["srcnn_best_175.pt", "srcnn_best_156_bs4.pt", "srcnn_best_185.pt"] # x2 / x3 / x4
pretrained_by_scale = dict(zip((2, 3, 4), pretrained))
# Fail loudly instead of list-indexing into the wrong checkpoint
# (e.g. x1 would otherwise wrap around to the x4 file).
if upscale_factor not in pretrained_by_scale:
    raise ValueError(
        f"No pretrained SRCNN checkpoint for x{upscale_factor}; "
        f"supported scales: {sorted(pretrained_by_scale)}"
    )
load_model_name = pretrained_by_scale[upscale_factor]

### Dataloader

In [None]:
# Training pairs come from T91 and evaluation from Set5. Both datasets use
# bic_up=True (presumably bicubic pre-upsampling of the LR input - confirm in
# datasets.ImageDataset).
train_dataset, eval_dataset = (
    ImageDataset(train_path, upscale_factor, train=True, bic_up=True),
    ImageDataset(eval_path, upscale_factor, train=False, bic_up=True),
)

# Shuffled mini-batches for training; one image at a time, in fixed order, for evaluation.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=1, shuffle=False)

### Train

In [None]:
model = SRCNN().to(device)
criterion = nn.MSELoss().to(device)

# Per-layer learning rates: the reconstruction layer (conv3) trains with a
# 10x smaller rate than conv1/conv2, as in the original SRCNN setup.
optimizer = optim.Adam([
    {'params': model.conv1.parameters(), 'lr': learning_rate},
    {'params': model.conv2.parameters(), 'lr': learning_rate},
    {'params': model.conv3.parameters(), 'lr': learning_rate * 0.1}
])

# Reject typos instead of silently falling through to "load".
valid_modes = ("train", "load-train", "load")
if mode not in valid_modes:
    raise ValueError(f"Unknown mode {mode!r}; expected one of {valid_modes}")

if mode in ("load", "load-train"):
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    # NOTE(review): weights_only=False unpickles arbitrary objects. The file is
    # fed straight into load_state_dict, so weights_only=True should suffice -
    # switch once every checkpoint is confirmed to load with it.
    checkpoint = torch.load(trained_path + load_model_name, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint)
    print("Loaded model: " + load_model_name)

if mode == "load":
    print(model)
else:
    train_model(model, train_dataloader, eval_dataloader, criterion, optimizer, trained_path, start_epoch, end_epoch, device, metrics_rgb)

### Results

In [None]:
# Benchmark datasets for the results below, all built with bic_up=True like the
# training/evaluation datasets. The larger benchmarks are disabled; uncomment
# them (and the matching report lines) to include them.
set5_dataset, set14_dataset = (
    ImageDataset(benchmark_path, upscale_factor=upscale_factor, bic_up=True)
    for benchmark_path in (set5_path, set14_path)
)
# bsd100_dataset = ImageDataset(bsd100_path, upscale_factor=upscale_factor, bic_up=True)
# urban100_dataset = ImageDataset(urban100_path, upscale_factor=upscale_factor, bic_up=True)
# manga109_dataset = ImageDataset(manga109_path, upscale_factor=upscale_factor, bic_up=True)

In [None]:
# Visual side-by-side for one Set5 sample.
# NOTE(review): the positional arguments 2 and (50, 50) are forwarded unchanged
# to utils.show_comparison_picture - presumably an image index and a crop
# position; confirm against its signature.
comparison_args = (model, set5_dataset, 2, (50, 50))
show_comparison_picture(*comparison_args, scale=upscale_factor, other_model=False, rgb=metrics_rgb)

In [None]:
def print_avg_metrics(dataset_label, dataset):
    """Print average PSNR/SSIM of the bicubic baseline and of SRCNN on ``dataset``.

    The four averages come from ``get_avg_metrics`` and are printed rounded to
    4 decimals, in the format ``<label> / Bicubic / psnr / ssim / SRCNN / psnr / ssim``
    (one item per line).
    """
    bic_psnr, bic_ssim, up_psnr, up_ssim = get_avg_metrics(model, dataset, upscale_factor, rgb=metrics_rgb)
    print(f"{dataset_label}\n", f"Bicubic / {round(bic_psnr, 4)} / {round(bic_ssim, 4)}\n", f"SRCNN / {round(up_psnr, 4)} / {round(up_ssim, 4)}")


print("Model:", load_model_name)
benchmarks = [
    ("Set5", set5_dataset),
    ("Set14", set14_dataset),
    # Uncomment once the matching datasets are built in the cell above.
    # ("BSD100", bsd100_dataset),
    # ("urban100", urban100_dataset),
    # ("manga109", manga109_dataset),
]
for benchmark_label, benchmark_dataset in benchmarks:
    print_avg_metrics(benchmark_label, benchmark_dataset)

In [None]:
# Benchmark to score; one of: "Set5", "Set14", "BSD100", "urban100", "manga109".
dataset_name = "Set5"

gt_path = f"./Datasets/{dataset_name}/{dataset_name}_HR"
# Printed label -> folder under ./Results/<dataset_name>/ holding upscaled outputs.
result_folders = {
    "SRCNN": "SRCNN",
    "ESRT": "ESRT",
    "ESRT paper": "ESRT_paper",
}

print(f"{dataset_name} x{upscale_factor} results for:")
for result_label, result_folder in result_folders.items():
    print(result_label)
    metrics_from_results(gt_path, f"./Results/{dataset_name}/{result_folder}/X{upscale_factor}", metrics_rgb)