In [None]:
import copy
import math

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from datasets import Div2kDataset, ImageDataset
from utils import *

In [None]:
# Prefer the GPU when PyTorch can see one; the model and loss are moved here later.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Model

In [None]:
# CARN-M modules

class ResidualEBlock(nn.Module):
    """Efficient residual block used by CARN-M.

    Two grouped 3x3 convolutions (ReLU between them) are followed by a
    pointwise 1x1 convolution; the result is added to the input and passed
    through a final ReLU. Channel count and spatial size are preserved.

    Args:
        channels: number of input and output channels; must be divisible by
            ``groups``.
        groups: number of groups for the two 3x3 convolutions.
    """

    def __init__(self, channels, groups=4):
        super().__init__()
        grouped_a = nn.Conv2d(channels, channels, 3, padding=1, groups=groups)
        grouped_b = nn.Conv2d(channels, channels, 3, padding=1, groups=groups)
        pointwise = nn.Conv2d(channels, channels, 1)
        # NOTE(review): verify against the CARN paper whether a ReLU is expected
        # between the second grouped conv and the 1x1 conv. Inserting one into
        # `body` would shift the Sequential indices and break saved checkpoints.
        # Weighted layers sit at indices 0, 2 and 3; those indices are the
        # state_dict keys.
        self.body = nn.Sequential(grouped_a, nn.ReLU(True), grouped_b, pointwise)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        # Identity skip connection followed by the final activation.
        return self.relu(self.body(x) + x)

class RecursiveBlock(nn.Module):
    """CARN-M cascading block built around one shared (recursive) residual-E block.

    The same ``ResidualEBlock`` is applied ``num_blocks`` times. Its outputs
    are collected in a list that starts with the block input. After each
    application, the concatenation of that list is fused back to ``channels``
    by a 1x1 conv + ReLU, and the fused map feeds the next application.

    Args:
        channels: feature channels of the input and the output.
        num_blocks: number of recursive applications (and fusion convs).
    """

    def __init__(self, channels, num_blocks=3):
        super().__init__()
        self.num_blocks = num_blocks
        # One residual-E block whose weights are reused at every step.
        self.rec_b = ResidualEBlock(channels)
        # Fusion conv for step k sees the input plus k + 1 residual outputs.
        fusers = []
        for step in range(num_blocks):
            fusers.append(nn.Sequential(nn.Conv2d(channels * (step + 2), channels, 1), nn.ReLU(True)))
        self.convs = nn.ModuleList(fusers)

    def forward(self, x):
        collected = [x]
        fused = x
        for step in range(self.num_blocks):
            collected.append(self.rec_b(fused))
            fused = self.convs[step](torch.cat(collected, 1))
        return fused

In [None]:
class ResidualBlock(nn.Module):
    """Plain residual block used inside the full CARN cascading block.

    conv3x3 -> ReLU -> conv3x3, plus an identity skip, then a final ReLU.
    Channel count and spatial size are unchanged.

    Args:
        channels: number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()
        layers = [
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(channels, channels, 3, padding=1),
        ]
        # Layout (weights at indices 0 and 2) determines checkpoint keys.
        self.body = nn.Sequential(*layers)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        residual = self.body(x)
        return self.relu(residual + x)

class CascadingBlock(nn.Module):
    """CARN cascading block: residual blocks with dense concatenation + 1x1 fusion.

    Stage ``i`` applies its own ``ResidualBlock`` to the previous stage's fused
    output. It then concatenates the result onto a running feature stack
    (initially the block input) and fuses the stack back to ``channels`` with
    a 1x1 conv + ReLU.

    Args:
        channels: feature channels of the input and the output.
        num_blocks: number of residual/fusion stages. ``forward`` honours this
            value instead of assuming exactly three stages.
    """

    def __init__(self, channels, num_blocks=3):
        super(CascadingBlock, self).__init__()

        # body[i] = [ResidualBlock, fusion conv]; fusion i sees (i + 2) * channels.
        self.body = nn.ModuleList([
            nn.ModuleList([
                ResidualBlock(channels),
                nn.Sequential(
                    nn.Conv2d(channels * (i + 2), channels, 1),
                    nn.ReLU(True)
                )
            ]) for i in range(num_blocks)
        ])

    def forward(self, x):
        stack = x
        out = x
        for i, (res_block, fuse) in enumerate(self.body):
            res = res_block(out)
            # Concatenation order must match trained weights: the first stage
            # puts the residual output first, later stages append it.
            stack = torch.cat([res, stack] if i == 0 else [stack, res], 1)
            out = fuse(stack)
        return out

class Upsampler(nn.Sequential):
    """Sub-pixel (PixelShuffle) upsampler placed before the CARN tail.

    If ``scale`` is a power of two, it stacks log2(scale) stages of
    conv3x3 (channels -> 4*channels) + PixelShuffle(2). If ``scale == 3``, it
    uses a single conv3x3 (channels -> 9*channels) + PixelShuffle(3). Channel
    count is preserved; spatial size is multiplied by ``scale``.

    Args:
        scale: integer upscaling factor (a power of two, or 3).
        channels: number of feature channels in and out.

    Raises:
        NotImplementedError: for any other ``scale``, including 0 and negatives.
    """

    def __init__(self, scale, channels):
        m = []
        # `scale & (scale - 1) == 0` is the power-of-two test. The `scale > 0`
        # guard stops 0 from passing it and then failing inside log2.
        if scale > 0 and (scale & (scale - 1)) == 0:
            # log2 is exact for powers of two, unlike log(scale, 2).
            for _ in range(int(math.log2(scale))):
                m.append(nn.Conv2d(channels, 4 * channels, 3, padding=1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(channels, 9 * channels, 3, padding=1))
            m.append(nn.PixelShuffle(3))
        else:
            raise NotImplementedError(f"Unsupported upscale factor: {scale}")

        super(Upsampler, self).__init__(*m)

In [None]:
class CARN(nn.Module):
    """Cascading Residual Network (CARN) / CARN-M for single-image super-resolution.

    Pipeline: 3x3 head conv (3 -> 64 channels), then three cascading stages
    with dense concatenation and 1x1 fusion, then a sub-pixel ``Upsampler``,
    then a 3x3 tail conv (64 -> 3 channels).

    Args:
        scale: upscaling factor forwarded to ``Upsampler``.
        light: if True build CARN-M (``RecursiveBlock`` stages), otherwise
            CARN (``CascadingBlock`` stages).
    """

    def __init__(self, scale=3, light=False):
        super(CARN, self).__init__()
        channels = 64
        self.light = light

        self.head = nn.Conv2d(3, channels, 3, padding=1)

        num_blocks = 3
        # body[i] = [stage block, fusion conv]; fusion i sees (i + 2) * channels.
        self.body = nn.ModuleList([
            nn.ModuleList([
                RecursiveBlock(channels, 3) if light else CascadingBlock(channels),
                nn.Sequential(
                    nn.Conv2d(channels * (i+2), channels, 1),
                    nn.ReLU(True)
                )
            ]) for i in range(num_blocks)
        ])

        self.upsampler = Upsampler(scale, channels)

        self.tail = nn.Conv2d(channels, 3, 3, padding=1)

        self.init_weights()

    def forward(self, x):
        out = self.head(x)

        stack = out
        for i, (block, fuse) in enumerate(self.body):
            res = block(out)
            # Concatenation order must match trained weights: the first stage
            # puts the block output first, later stages append it.
            stack = torch.cat([res, stack] if i == 0 else [stack, res], 1)
            out = fuse(stack)

        out = self.upsampler(out)

        return self.tail(out)

    def init_weights(self) -> None:
        """Kaiming-normal init for every Conv2d weight; conv biases set to 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def load_components(self, components, model_name):
        """Load only the parameters of the named top-level submodules from a checkpoint.

        Intended for warm-starting from a model trained at another scale, e.g.
        ``['head', 'body', 'tail']`` skips ``upsampler``.

        Args:
            components: names of submodules (``'head'``, ``'body'``, ...); a
                dotted prefix such as ``'body.0'`` also works.
            model_name: path to a saved state_dict.

        Returns:
            The ``load_state_dict`` result (missing / unexpected keys), which is
            useful because loading is non-strict.
        """
        # NOTE: weights_only=False unpickles arbitrary objects; load trusted files only.
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines;
        # load_state_dict then copies tensors onto this model's device.
        pretrained = torch.load(model_name, map_location='cpu', weights_only=False)
        prefixes = tuple(components)
        # Match whole module-path prefixes rather than substrings, so a
        # component name cannot accidentally select keys of another module.
        state_dict = {
            name: param
            for name, param in pretrained.items()
            if any(name == p or name.startswith(p + '.') for p in prefixes)
        }
        return self.load_state_dict(state_dict, strict=False)

    def process_image(self, lr_tensor):
        """Super-resolve one image that has no batch dimension.

        Args:
            lr_tensor: a single image, presumably (3, H, W) to match ``head`` —
                confirm with the dataset.

        Returns:
            The upscaled image on the CPU, with the batch dimension removed.

        Leaves the model in eval mode.
        """
        self.eval()
        # Run on the device holding the parameters instead of a global `device`.
        model_device = next(self.parameters()).device
        with torch.no_grad():
            # squeeze(0) removes only the batch dim; a bare squeeze() would also
            # drop height/width dimensions that happen to be 1.
            hr_tensor = self(lr_tensor.unsqueeze(0).to(model_device)).squeeze(0).cpu()
        return hr_tensor

    def get_name(self):
        """Model name used for checkpoints and reports."""
        return "CARN-M" if self.light else "CARN"

# Train

In [None]:
def train_model(
        model,
        train_dataloader,
        eval_dataloader,
        criterion,
        optimizer,
        learning_rate,
        trained_path,
        start_epoch=1,
        end_epoch=50,
        device='cpu',
        rgb=False,
        lr_decay_steps=400000):
    """Train ``model`` for epochs ``start_epoch``..``end_epoch`` (inclusive).

    Evaluation and checkpointing:
    - The model is evaluated with ``eval`` before every epoch and once more
      after the last one.
    - The weights with the best PSNR and SSIM are saved via
      ``save(..., best=True)`` when training ends.
    - A regular checkpoint is saved after every epoch divisible by 5.

    Args:
        model: network exposing ``get_name()`` (used in checkpoint names).
        train_dataloader: training batches; ``train`` is called once per epoch.
        eval_dataloader: data passed to ``eval``.
        criterion: loss passed to ``train``.
        optimizer: optimizer whose param-group LRs are set at each epoch.
        learning_rate: initial, undecayed learning rate.
        trained_path: directory passed to ``save``.
        start_epoch: first epoch number (> 1 when resuming).
        end_epoch: last epoch number, inclusive.
        device: passed to ``eval`` and ``train``.
        rgb: forwarded to ``eval`` (RGB vs. Y-channel metrics).
        lr_decay_steps: the learning rate is halved every this many training
            batches, counted from epoch 1 so a resumed run continues the schedule.
    """
    tqdmEpoch = start_epoch

    num_epochs = end_epoch - start_epoch + 1
    steps_per_epoch = len(train_dataloader)

    best_psnr = -float('inf')
    best_ssim = -float('inf')
    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = -1

    with tqdm(total=steps_per_epoch * num_epochs, desc=f'Epoch {tqdmEpoch}/{end_epoch}', unit='patches') as pbar:
        for epoch in range(start_epoch, end_epoch + 1):
            # Evaluate the weights produced by the previous epoch (epoch - 1).
            avg_psnr, avg_ssim = eval(model, eval_dataloader, device, -1, rgb)
            # A checkpoint counts as best only if BOTH metrics strictly improve.
            if avg_psnr > best_psnr and avg_ssim > best_ssim:
                best_psnr = avg_psnr
                best_ssim = avg_ssim
                best_epoch = tqdmEpoch - 1
                best_weights = copy.deepcopy(model.state_dict())

            pbar.set_description_str(f'Epoch {tqdmEpoch}/{end_epoch} | PSNR: {avg_psnr} | SSIM: {avg_ssim}', refresh=True)

            # Step decay: halve the LR every `lr_decay_steps` batches. The count
            # is the number of batches completed before this epoch. Using the
            # total run length instead would keep the LR constant across epochs.
            cur_lr = optimizer.param_groups[0]['lr']
            steps_done = (epoch - 1) * steps_per_epoch
            lr = learning_rate * (0.5 ** (steps_done // lr_decay_steps))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            if cur_lr != lr:
                print(f"New learning rate {lr} at epoch {epoch}")

            # train
            avg_loss = train(model, train_dataloader, optimizer, criterion, pbar, device)

            tqdmEpoch += 1
            pbar.set_postfix(loss=avg_loss)

            if epoch % 5 == 0:
                save(model.state_dict(), model.get_name(), epoch, trained_path)

        # Final evaluation of the last epoch's weights; ties favour these final weights.
        avg_psnr, avg_ssim = eval(model, eval_dataloader, device, -1, rgb)
        if avg_psnr >= best_psnr and avg_ssim >= best_ssim:
            best_epoch = tqdmEpoch - 1
            best_weights = copy.deepcopy(model.state_dict())

        pbar.set_description_str(f'Epoch {tqdmEpoch-1}/{end_epoch} | PSNR: {avg_psnr} | SSIM: {avg_ssim}', refresh=True)

    save(best_weights, model.get_name(), best_epoch, trained_path, best=True)

# Showcase

### Settings

In [None]:
batch_size = 32
learning_rate = 2e-4
start_epoch = 47
end_epoch = 50
upscale_factor = 3
patch_size = 64 * upscale_factor
num_workers = 16
metrics_rgb = False  # Evaluate psnr and ssim on rgb channels or Y channel
carn_m = False  # True -> CARN-M (light variant), False -> full CARN

train_path = "./Datasets/DIV2K"
eval_path = "./Datasets/Set5"

set5_path = "./Datasets/Set5"
set14_path = "./Datasets/Set14"
urban100_path = "./Datasets/urban100"
bsd100_path = "./Datasets/BSD100"
manga109_path = "./Datasets/manga109"

mode = "load"  # one of "train", "load-train", "load"

# part_load: path of a pretrained checkpoint.
# Loads all parameters except the upsampler (head, body and tail only).
# Use this if you trained one scale and want to train others.
# Only used in "train" mode.
# part_load = "TrainedModels/carn/X3/carn_best_46.pt"
part_load = ""

trained_path = "TrainedModels/carn/X" + str(upscale_factor) + "/"

# Best checkpoint per scale: index 0 -> x2, 1 -> x3, 2 -> x4.
pretrained = ["carn_best_9.pt", "carn_best_46.pt", "carn_best_13.pt"]
# Without this guard, upscale_factor < 2 would produce a negative index and
# silently pick a checkpoint trained for a different scale.
if not 2 <= upscale_factor < 2 + len(pretrained):
    raise ValueError(f"No pretrained checkpoint listed for upscale_factor={upscale_factor}; expected 2, 3 or 4")
load_model_name = pretrained[upscale_factor - 2]

### Dataloader

In [None]:
train_dataset = Div2kDataset(
    train_path,
    train=True,
    repeat=40,
    upscale_factor=upscale_factor,
    patch_size=patch_size,
)
eval_dataset = ImageDataset(eval_path, upscale_factor=upscale_factor)

# Training: shuffled batches of patches, dropping the incomplete final batch
# (patch extraction itself is defined in Div2kDataset).
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=True,
)
# Evaluation: one image per batch in a fixed order (presumably because image
# sizes differ — confirm in ImageDataset).
eval_dataloader = DataLoader(eval_dataset, batch_size=1, shuffle=False)

### Train

In [None]:
model = CARN(upscale_factor, carn_m).to(device)
criterion = nn.L1Loss().to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

num_params = sum(param.numel() for param in model.parameters())
print(f'Total number of parameters: {num_params}')


def _load_trained_weights(net, checkpoint_name):
    """Load a full state_dict from ``trained_path + checkpoint_name`` into ``net``.

    ``map_location=device`` lets a checkpoint saved on a GPU load when
    ``device`` is the CPU. NOTE: weights_only=False unpickles arbitrary objects,
    so only load trusted files.
    """
    state = torch.load(trained_path + checkpoint_name, map_location=device, weights_only=False)
    net.load_state_dict(state)
    print("Loaded model: " + checkpoint_name)


if mode == "train":
    if part_load != "":
        model.load_components(['head', 'body', 'tail'], part_load)
        print(f"Loaded partial weights from {part_load}")
    train_model(model, train_dataloader, eval_dataloader, criterion, optimizer, learning_rate, trained_path, start_epoch, end_epoch, device, metrics_rgb)
elif mode == "load-train":
    _load_trained_weights(model, load_model_name)
    train_model(model, train_dataloader, eval_dataloader, criterion, optimizer, learning_rate, trained_path, start_epoch, end_epoch, device, metrics_rgb)
elif mode == "load":
    _load_trained_weights(model, load_model_name)
else:
    # Fail loudly instead of silently treating a typo as "load".
    raise ValueError(f'Unknown mode {mode!r}; expected "train", "load-train" or "load"')

### Results

In [None]:
# Evaluation datasets at the configured scale. Further benchmarks are
# commented out; uncomment them to include them.
_eval_kwargs = {"upscale_factor": upscale_factor}
set5_dataset = ImageDataset(set5_path, **_eval_kwargs)
set14_dataset = ImageDataset(set14_path, **_eval_kwargs)
# bsd100_dataset = ImageDataset(bsd100_path, upscale_factor=upscale_factor)
# urban100_dataset = ImageDataset(urban100_path, upscale_factor=upscale_factor)
# manga109_dataset = ImageDataset(manga109_path, upscale_factor=upscale_factor)

In [None]:
# Visual comparison on a Set5 image (other_model="ESRT").
# NOTE(review): `2` and `(50, 50)` are passed positionally; they are presumably
# an image index and a crop position/size — confirm against show_comparison_picture.
show_comparison_picture(
    model,
    set5_dataset,
    2,
    (50, 50),
    scale=upscale_factor,
    other_model="ESRT",
    rgb=metrics_rgb,
)

In [None]:
def report_metrics(dataset_label, dataset, **metric_kwargs):
    """Print bicubic vs. model PSNR/SSIM on one dataset.

    Uses the notebook's ``model``, ``upscale_factor`` and ``metrics_rgb``.
    Extra keyword arguments (e.g. ``chop_size``) are forwarded to ``get_avg_metrics``.
    """
    bic_psnr, bic_ssim, up_psnr, up_ssim = get_avg_metrics(
        model, dataset, upscale_factor, rgb=metrics_rgb, **metric_kwargs)
    # Use the model's own name so CARN-M results are not reported as "CARN".
    print(f"{dataset_label}\n", f"Bicubic / {bic_psnr} / {bic_ssim}\n", f"{model.get_name()} / {up_psnr} / {up_ssim}")


print("Model:", load_model_name)
benchmarks = [
    ("Set5", set5_dataset, {}),
    ("Set14", set14_dataset, {}),
    # ("BSD100", bsd100_dataset, {}),
    # ("urban100", urban100_dataset, {"chop_size": 25000}),
    # ("manga109", manga109_dataset, {"chop_size": 45000}),
]
for label, dataset, extra_kwargs in benchmarks:
    report_metrics(label, dataset, **extra_kwargs)

In [None]:
datasets = [
    "Set5",
    # "Set14",
    # "BSD100",
    # "urban100",
    # "manga109",
]

# (printed label, results sub-directory) for each method whose saved outputs are scored.
methods = [
    ("SRCNN", "SRCNN"),
    ("CARN", "CARN"),
    ("ESRT", "ESRT"),
    # ("ESRT paper", "ESRT_paper"),
]

for dataset_name in datasets:
    gt_path = f"./Datasets/{dataset_name}/{dataset_name}_HR"
    print(f"{dataset_name} x{upscale_factor} results for:")
    for method_label, method_dir in methods:
        print(method_label)
        metrics_from_results(gt_path, f"./Results/{dataset_name}/{method_dir}/X{upscale_factor}", metrics_rgb)