In [None]:
import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from datasets import Div2kDataset, ImageDataset
from utils import *

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Model

In [None]:
class Conv(nn.Module):
    """Conv2d optionally followed by a ReLU or a per-channel PReLU.

    Args:
        in_channels, out_channels, kernel_size, stride, padding: forwarded to nn.Conv2d.
        relu: apply an in-place ReLU after the convolution.
        prelu: apply nn.PReLU(out_channels) after the convolution; ignored when
            `relu` is also True (ReLU takes precedence).

    NOTE: the PReLU branch used to evaluate `x` without applying `self.prelu`, so
    checkpoints trained before this fix never passed through the PReLU and their
    outputs will change under this version.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, relu=False, prelu=False):
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.relu = nn.ReLU(inplace=True) if relu else None
        self.prelu = nn.PReLU(out_channels) if prelu else None

    def forward(self, x):
        x = self.conv(x)
        if self.relu is not None:
            return self.relu(x)
        if self.prelu is not None:
            return self.prelu(x)
        return x

class Upsampler(nn.Sequential):
    """Sub-pixel upsampling head built from 3x3 convs and PixelShuffle.

    * scale == 2**n: n stages of (Conv2d -> 4*n_feats channels, PixelShuffle(2)).
    * scale == 3:   one stage of (Conv2d -> 9*n_feats channels, PixelShuffle(3)).
    Any other scale raises NotImplementedError. Channel count is preserved.
    """
    def __init__(self, scale, n_feats):
        if (scale & (scale - 1)) == 0:
            # power of two -> repeat the x2 stage log2(scale) times
            factors = [2] * int(math.log(scale, 2))
        elif scale == 3:
            factors = [3]
        else:
            raise NotImplementedError

        layers = []
        for factor in factors:
            layers += [
                nn.Conv2d(n_feats, factor * factor * n_feats, 3, padding=1),
                nn.PixelShuffle(factor),
            ]

        super(Upsampler, self).__init__(*layers)

class Scale(nn.Module):
    """Multiply the input by a single learnable scalar.

    Args:
        init_value: starting value of the scalar (stored as a 1-element float32 parameter).
    """
    def __init__(self, init_value=1e-3):
        super().__init__()
        self.scale = nn.Parameter(torch.FloatTensor([init_value]))

    def forward(self, input):
        factor = self.scale
        return input * factor

### Lightweight CNN backbone

In [None]:
class ResidualUnit(nn.Module):
    """Weighted residual bottleneck: weight1 * x + weight2 * expand(PReLU(reduce(x))).

    Args:
        in_chanels: channels of the input and of the output.
        out_channels: channels of the inner (bottleneck) feature map.
        kernel_size: kernel of both convs; padding kernel_size // 2 keeps H, W for odd sizes.
    """
    def __init__(self, in_chanels, out_channels, kernel_size = 3):
        super(ResidualUnit, self).__init__()
        pad = kernel_size // 2
        self.reduction = nn.Conv2d(in_chanels, out_channels, kernel_size, padding=pad)
        self.expansion = nn.Conv2d(out_channels, in_chanels, kernel_size, padding=pad)
        self.relu = nn.PReLU(out_channels)
        self.weight1 = Scale(1)
        self.weight2 = Scale(1)

    def forward(self, x):
        bottleneck = self.relu(self.reduction(x))
        branch = self.expansion(bottleneck)
        return self.weight1(x) + self.weight2(branch)

class ARFB(nn.Module):
    """Two chained residual units whose outputs are fused, channel-attended and refined.

    forward: x1 = ru1(x), x2 = ru2(x1); the concatenation [w1*x2, w2*x1] is reduced
    back to n_feats by a 1x1 conv, gated by channel attention, refined by a 3x3 conv,
    and combined with the input as w3*x + w4*refined. Spatial size is unchanged.
    """
    def __init__(self, n_feats):
        super(ARFB, self).__init__()
        self.ru1 = ResidualUnit(n_feats, n_feats // 2, 3)
        self.ru2 = ResidualUnit(n_feats, n_feats // 2, 3)
        self.conv1 = Conv(2*n_feats, n_feats, 1, 1, 0, prelu=True)
        self.conv3 = Conv(n_feats, n_feats, 3, 1, 1, prelu=True)
        self.attention = CA(n_feats)
        self.weight1 = Scale(1)
        self.weight2 = Scale(1)
        self.weight3 = Scale(1)
        self.weight4 = Scale(1)

    def forward(self, x):
        first = self.ru1(x)
        second = self.ru2(first)
        fused = torch.cat([self.weight1(second), self.weight2(first)], 1)
        refined = self.conv3(self.attention(self.conv1(fused)))
        return self.weight3(x) + self.weight4(refined)
    
class CA(nn.Module):
    """Channel attention: global average pool -> 1x1 bottleneck -> sigmoid gate.

    Each channel of the input is multiplied by its gate in [0, 1].

    Args:
        channel: number of input/output channels.
        reduction: bottleneck ratio; the hidden width is channel // reduction.
    """
    def __init__(self, channel, reduction=16):
        super(CA, self).__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channel, 1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        gate = self.conv_du(self.avg_pool(x))
        return x * gate
    
class HPB(nn.Module):
    """Splits features into a low- and a high-frequency path, processes both, and fuses them.

    Low path: encoder output is 2x average-pooled and passed 5 times through the SAME
    `decoder_low` block (shared weights), then bilinearly resized back.
    High path: encoder output minus its pooled-and-resized version, through `decoder_high`.
    The two paths are concatenated, reduced by a 1x1 conv + ReLU, channel-attended,
    refined by `alise`, and added to the input.
    """
    def __init__(self, n_feats):
        super(HPB, self).__init__()
        self.encoder = ARFB(n_feats)
        self.avgpool = nn.AvgPool2d(kernel_size=2)
        self.decoder_low = ARFB(n_feats)
        self.decoder_high = ARFB(n_feats)
        self.conv = Conv(2*n_feats, n_feats, 1, 1, 0, relu=True)
        self.attention = CA(n_feats)
        self.alise = ARFB(n_feats)

    def forward(self, x):
        full_size = x.size()[-2:]
        encoded = self.encoder(x)
        low = self.avgpool(encoded)

        # High-frequency residual: detail lost by pooling + bilinear re-upsampling.
        high = encoded - F.interpolate(low, size=full_size, mode='bilinear', align_corners=True)
        high = self.decoder_high(high)

        for _ in range(5):
            low = self.decoder_low(low)
        low = F.interpolate(low, size=full_size, mode='bilinear', align_corners=True)

        fused = self.conv(torch.cat([low, high], dim=1))
        return self.alise(self.attention(fused)) + x
    
class LCB(nn.Module):
    """A chain of `n_blocks` HPB blocks; forward returns the output of every block.

    Returns:
        list of n_blocks tensors, the i-th being the output after block i.
    """
    def __init__(self, n_feats, n_blocks):
        super(LCB, self).__init__()
        self.n_blocks = n_blocks
        self.encoders = nn.Sequential(*(HPB(n_feats) for _ in range(n_blocks)))

    def forward(self, x):
        per_block = []
        for block in self.encoders:
            x = block(x)
            per_block.append(x)
        return per_block

### Lightweight transformer backbone

In [None]:
def reverse_patches(images, out_size, ksizes, strides, padding):
    """Reassemble sliding-window columns into an image with torch.nn.Fold.

    Args:
        images: columns shaped (N, C * kh * kw, L), as produced by nn.Unfold.
        out_size: (H, W) of the reconstructed image.
        ksizes, strides, padding: window geometry (dilation is fixed at 1).
    Returns:
        (N, C, H, W). Overlapping window contributions are summed, not averaged,
        so this is not an exact inverse of Unfold when windows overlap.
    """
    fold = torch.nn.Fold(output_size=out_size, kernel_size=ksizes, dilation=1,
                         padding=padding, stride=strides)
    return fold(images)

def same_padding(images, ksizes, strides, rates):
    """Zero-pad a 4-D batch so a (dilated) window yields ceil(size / stride) positions per axis.

    This is the TensorFlow 'SAME' convention: the total padding is split evenly and,
    when it is odd, the extra pixel goes to the bottom / right.

    Args:
        images: 4-D tensor (B, C, rows, cols).
        ksizes, strides, rates: [row, col] window size, stride and dilation.
    """
    assert len(images.size()) == 4
    _, _, rows, cols = images.size()

    def pad_pair(size, k, s, r):
        # (before, after) padding for one spatial axis
        n_out = (size + s - 1) // s
        span = (k - 1) * r + 1
        total = max(0, (n_out - 1) * s + span - size)
        before = total // 2
        return before, total - before

    top, bottom = pad_pair(rows, ksizes[0], strides[0], rates[0])
    left, right = pad_pair(cols, ksizes[1], strides[1], rates[1])
    return torch.nn.ZeroPad2d((left, right, top, bottom))(images)

def extract_image_patches(images, ksizes, strides, rates):
    """Flatten every (dilated) sliding window of a 'SAME'-padded batch into a column.

    Args:
        images: 4-D tensor (B, C, rows, cols).
        ksizes, strides, rates: [row, col] window size, stride and dilation.
    Returns:
        (B, C * kh * kw, L) with L = ceil(rows / s_row) * ceil(cols / s_col);
        for stride 1 that is one column per input pixel.
    """
    padded = same_padding(images, ksizes, strides, rates)
    unfold = torch.nn.Unfold(kernel_size=ksizes, dilation=rates, padding=0, stride=strides)
    return unfold(padded)

class EMHA(nn.Module):
    """Efficient multi-head self-attention over a token sequence.

    Tokens are projected from `dim` to `dim // 2` features and split into `num_heads`
    heads. Attention is computed independently inside (at most) 4 contiguous groups of
    tokens, which caps the size of each attention matrix. The result is projected back
    to `dim`.

    Args:
        dim: token feature size; input and output are (B, N, dim).
        num_heads: number of heads; (dim // 2) must be divisible by it.
        qkv_bias: add a bias to the reduce / qkv projections.
        qk_scale: explicit attention scale; when falsy, (dim // num_heads) ** -0.5 is used.
        attn_drop: dropout probability applied to the attention weights.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # NOTE(review): head_dim is computed from `dim`, while each head actually carries
        # (dim // 2) // num_heads features. Kept as-is because trained weights depend on it.
        head_dim = dim // num_heads

        self.reduce = nn.Linear(dim, dim//2, bias=qkv_bias)
        self.qkv = nn.Linear(dim//2, dim//2 * 3, bias=qkv_bias)
        self.scale = qk_scale or head_dim ** -0.5
        self.expand = nn.Linear(dim//2, dim)

        self.attn_drop = nn.Dropout(attn_drop)

    def forward(self, x):
        x = self.reduce(x)

        B, N, C = x.shape
        # -> (3, B, heads, N, C // heads)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Group size that splits the N tokens into at most 4 groups. The old
        # `math.ceil(N // 4)` floor-divided first, which produced an extra tiny 5th group
        # when N % 4 != 0 and a zero split size (rejected by torch.split) when N < 4.
        group = max(1, math.ceil(N / 4))

        output = []
        for q_g, k_g, v_g in zip(q.split(group, dim=-2), k.split(group, dim=-2), v.split(group, dim=-2)):
            attn = (q_g @ k_g.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            # (B, heads, n_g, hd) -> (B, n_g, heads, hd) so groups concatenate along tokens
            output.append((attn @ v_g).transpose(1, 2))

        x = torch.cat(output, dim=1)
        x = x.reshape(B, N, C)
        x = self.expand(x)
        return x

class MLP(nn.Module):
    """Position-wise feed-forward: Linear -> ReLU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: input feature size.
        hidden_features: hidden size; falls back to in_features // 4 when falsy.
        out_features: output size; falls back to in_features when falsy.

    The dropout probability is fixed at 0, so both dropouts are currently no-ops.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None):
        super().__init__()
        width_out = out_features or in_features
        width_hidden = hidden_features or in_features//4
        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(0.)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class ET(nn.Module):
    """One pre-norm transformer layer applied to 3x3 image-patch tokens.

    forward takes a feature map (B, C, H, W), extracts stride-1 3x3 patches (one token
    per pixel, C * 9 features each), and returns the refined tokens as
    (B, H * W, C * 9). It does not fold them back into an image; the caller does.

    Args:
        dim: token size; must equal C * 9 for the incoming feature map.
    """
    def __init__(self, dim=768):
        super(ET, self).__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attention = EMHA(dim=dim, num_heads=8)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=dim//4)

    def forward(self, x):
        columns = extract_image_patches(x, ksizes=[3, 3], strides=[1, 1], rates=[1, 1])
        tokens = columns.permute(0, 2, 1)
        tokens = tokens + self.attention(self.norm1(tokens))
        return tokens + self.mlp(self.norm2(tokens))

class LTB(nn.Module):
    """Fuses all LCB outputs and refines them with one ET transformer layer.

    Args:
        n_feats: channels of each LCB output and of the block's output.
        n_lcb_blocks: how many LCB outputs are concatenated in `lcb_res`.
    """
    # Side of the square patches ET tokenises; must match ET.forward's ksizes=[3, 3].
    PATCH_SIZE = 3

    def __init__(self, n_feats, n_lcb_blocks):
        super(LTB, self).__init__()
        # Each ET token is a flattened PATCH_SIZE x PATCH_SIZE patch of n_feats channels
        # (288 for n_feats=32); deriving it keeps other n_feats values working.
        self.attention = ET(dim=n_feats * self.PATCH_SIZE * self.PATCH_SIZE)
        self.reduce = nn.Conv2d(n_lcb_blocks * n_feats, n_feats, 3, padding=1)
        self.alise = nn.Conv2d(n_feats, n_feats, 3, padding=1)
        self.weight1 = Scale(1)
        self.weight2 = Scale(1)

    def forward(self, x, lcb_res):
        """Return weight1 * x + weight2 * refined.

        Args:
            x: skip input; must broadcast against the (B, n_feats, H, W) refined map.
            lcb_res: list of n_lcb_blocks feature maps, each (B, n_feats, H, W).
        """
        _, _, h, w = lcb_res[-1].shape
        fused = self.reduce(torch.cat(lcb_res, dim=1))
        tokens = self.attention(fused)          # (B, H*W, n_feats * k * k)
        k = self.PATCH_SIZE
        # Fold sums the overlapping patch contributions back into a (B, n_feats, H, W) map;
        # padding k // 2 mirrors the 'SAME' padding used when the patches were extracted.
        out = reverse_patches(tokens.permute(0, 2, 1), (h, w), (k, k), 1, k // 2)
        out = self.alise(out)
        return self.weight1(x) + self.weight2(out)

In [None]:
class ESRT(nn.Module):
    """ESRT super-resolution network.

    Pipeline: 3x3 conv head (RGB -> 32 features) -> LCB (stack of HPB blocks) ->
    LTB (transformer over the concatenated LCB outputs) -> sub-pixel tail1. A second
    sub-pixel tail (tail2) on the raw head features is added as a global skip.

    Args:
        upscale: super-resolution factor; a power of two or 3 (see Upsampler).
    """
    def __init__(self, upscale=3):
        super(ESRT, self).__init__()
        n_feats = 32
        kernel_size = 3
        lcb_blocks = 3

        self.head = nn.Conv2d(3, n_feats, kernel_size, padding=1)

        self.body = nn.Sequential(
            LCB(n_feats, lcb_blocks),
            # LTB concatenates every LCB output, so it must use the same block count.
            LTB(n_feats, lcb_blocks)
        )

        self.tail1 = nn.Sequential(
                Upsampler(upscale, n_feats),
                nn.Conv2d(n_feats, 3, kernel_size, padding=1)
        )
        self.tail2 = nn.Sequential(
                Upsampler(upscale, n_feats),
                Conv(n_feats, 3, kernel_size, 1, 1, relu=True)
        )

    def forward(self, x):
        """Super-resolve a (B, 3, H, W) batch into (B, 3, H * upscale, W * upscale)."""
        shallow = self.head(x)
        lcb_outputs = self.body[0](shallow)
        deep = self.body[1](shallow, lcb_outputs)
        return self.tail1(deep) + self.tail2(shallow)

    def chunks_forward(self, x, scale, overlap=10, min_size=60000):
        """Super-resolve `x` in four overlapping quadrants (recursively) to bound memory.

        Each quadrant extends `overlap` LR pixels past the split lines; after upscaling
        those margins are cropped off and the quadrants are stitched back together.

        Args:
            x: LR batch (B, C, rows, cols).
            scale: the model's upscale factor (maps LR overlap to HR pixels).
            overlap: LR context each quadrant borrows from its neighbours; also passed
                unchanged to nested splits.
            min_size: quadrants with rows * cols below this run through the model
                directly; larger ones are split again.
        Returns:
            The stitched HR batch, (B, C_out, rows * scale, cols * scale).
        """
        _, _, rows, cols = x.size()
        half_r, half_c = rows // 2, cols // 2

        # With half <= overlap a quadrant would either slice with negative indices
        # (half < overlap) or span the whole axis (half == overlap, so recursing never
        # shrinks it). The image is small enough here to run in one pass.
        if min(half_r, half_c) <= overlap:
            return self(x)

        lr_chunks = [
            x[:, :, :half_r + overlap, :half_c + overlap],
            x[:, :, half_r - overlap:, :half_c + overlap],
            x[:, :, :half_r + overlap, half_c - overlap:],
            x[:, :, half_r - overlap:, half_c - overlap:],
        ]

        _, _, chunk_r, chunk_c = lr_chunks[0].size()
        if chunk_r * chunk_c < min_size:
            hr_chunks = [self(chunk) for chunk in lr_chunks]
        else:
            hr_chunks = [
                self.chunks_forward(chunk, scale, overlap=overlap, min_size=min_size)
                for chunk in lr_chunks
            ]

        _, _, out_r, out_c = hr_chunks[0].size()
        hr_overlap = overlap * scale
        top_left = hr_chunks[0][:, :, :out_r - hr_overlap, :out_c - hr_overlap]
        bottom_left = hr_chunks[1][:, :, hr_overlap:, :out_c - hr_overlap]
        top_right = hr_chunks[2][:, :, :out_r - hr_overlap, hr_overlap:]
        bottom_right = hr_chunks[3][:, :, hr_overlap:, hr_overlap:]

        left = torch.cat((top_left, bottom_left), dim=2)
        right = torch.cat((top_right, bottom_right), dim=2)
        return torch.cat((left, right), dim=3)

    def process_image(self, lr_tensor, upscale_factor, chop_size=60000):
        """Super-resolve one unbatched image and return it on the CPU.

        Switches the model to eval mode (and leaves it there). The input is moved to
        whatever device the model's parameters live on.

        Args:
            lr_tensor: a single image without batch dimension, presumably (C, H, W).
            upscale_factor: the model's scale, forwarded to chunks_forward.
            chop_size: min_size threshold for chunks_forward.
        """
        self.eval()
        model_device = next(self.parameters()).device
        with torch.no_grad():
            batch = lr_tensor.unsqueeze(0).to(model_device)
            # squeeze only the batch dim so no other size-1 dimension is dropped
            hr_tensor = self.chunks_forward(batch, upscale_factor, min_size=chop_size).squeeze(0).cpu()
        return hr_tensor

    def get_name(self):
        return "ESRT"

# Train

In [None]:
def train_model(
        model,
        train_dataloader,
        eval_dataloader,
        criterion,
        optimizer,
        learning_rate,
        upscale_factor,
        trained_path,
        start_epoch=1,
        end_epoch=50,
        device='cpu'):
    """Train `model` for epochs start_epoch..end_epoch (inclusive), keeping the best weights.

    * Before every epoch, and once more after the last, the model is evaluated with
      `eval(...)`. A snapshot replaces the best one only if BOTH PSNR and SSIM beat it.
    * At the start of each epoch the learning rate is set to
      learning_rate * 0.5 ** (epoch // 200).
    * Every 10th epoch the current weights are saved; at the end the best snapshot is
      saved with best=True.

    NOTE(review): `eval`, `train` and `save` presumably come from `from utils import *`;
    `eval` shadows the Python builtin of the same name.
    """
    best = {
        'psnr': -float('inf'),
        'ssim': -float('inf'),
        'epoch': -1,
        'weights': copy.deepcopy(model.state_dict()),
    }

    def evaluate_and_track(finished_epoch):
        # Evaluate the current weights; remember them if both metrics improved.
        psnr, ssim = eval(model, eval_dataloader, device, upscale_factor)
        if psnr > best['psnr'] and ssim > best['ssim']:
            best.update(psnr=psnr, ssim=ssim, epoch=finished_epoch,
                        weights=copy.deepcopy(model.state_dict()))
        return psnr, ssim

    next_epoch = start_epoch
    total_steps = len(train_dataloader) * (end_epoch - start_epoch + 1)

    with tqdm(total=total_steps, desc=f'Epoch {next_epoch}/{end_epoch}', unit='patches') as pbar:
        for epoch in range(start_epoch, end_epoch + 1):
            # Weights evaluated here are the result of the previous epoch.
            avg_psnr, avg_ssim = evaluate_and_track(next_epoch - 1)
            pbar.set_description_str(f'Epoch {next_epoch}/{end_epoch} | PSNR: {avg_psnr} | SSIM: {avg_ssim}', refresh=True)

            # Step decay: halve the learning rate every 200 epochs.
            previous_lr = optimizer.param_groups[0]['lr']
            lr = learning_rate * (0.5 ** (epoch // 200))
            for group in optimizer.param_groups:
                group['lr'] = lr
            if previous_lr != lr:
                print(f"New learning rate {lr} at epoch {epoch}")

            avg_loss = train(model, train_dataloader, optimizer, criterion, pbar, device)

            next_epoch += 1
            pbar.set_postfix(loss=avg_loss)

            if epoch % 10 == 0:
                save(model.state_dict(), model.get_name(), epoch, trained_path)

        # Evaluate the weights produced by the final epoch as well.
        avg_psnr, avg_ssim = evaluate_and_track(next_epoch - 1)
        pbar.set_description_str(f'Epoch {next_epoch-1}/{end_epoch} | PSNR: {avg_psnr} | SSIM: {avg_ssim}', refresh=True)

    save(best['weights'], model.get_name(), best['epoch'], trained_path, best=True)

# Showcase

### Settings

In [None]:
# Training hyper-parameters
batch_size = 16
learning_rate = 2e-4
start_epoch = 101  # train_model runs epochs start_epoch..end_epoch (inclusive)
end_epoch = 150
upscale_factor = 3
# NOTE(review): the size table below lists 144 for x3, yet 96 is used with
# upscale_factor = 3 -- confirm which value the x3 checkpoints were trained with.
patch_size = 96 # | 96 for x2 | 144 for x3 | 192 for x4 |
num_workers = 7

# Dataset locations
train_path = "./Datasets/DIV2K"
eval_path = "./Datasets/Set5"

set5_path = "./Datasets/Set5"
set14_path = "./Datasets/Set14"
urban100_path = "./Datasets/urban100"
bsd100_path = "./Datasets/BSD100"
manga109_path = "./Datasets/manga109"

# Run mode:
#   "load"       - load load_model_name only (any value other than "train" loads)
#   "train"      - train from freshly initialised weights
#   "load-train" - load load_model_name, then continue training
mode = "load"

trained_path = f"TrainedModels/esrt/X{upscale_factor}/"

# Best checkpoint per scale. Keyed by upscale_factor so the loaded checkpoint always
# matches the model built from it (the upsampler's conv shapes depend on the scale).
BEST_CHECKPOINTS = {2: "esrt_best_130.pt", 3: "esrt_best_450.pt", 4: "esrt_best_340.pt"}
load_model_name = BEST_CHECKPOINTS[upscale_factor]
# load_model_name = "esrt_best_340.pt"  # x4

### Dataloader

In [None]:
# NOTE(review): repeat=10 presumably re-samples each DIV2K image 10 times per epoch --
# confirm in datasets.Div2kDataset.
train_dataset = Div2kDataset(train_path, train=True, repeat=10, upscale_factor=upscale_factor, patch_size=patch_size)
eval_dataset = ImageDataset(eval_path, upscale_factor=upscale_factor)

# Shuffled training batches; drop_last keeps every batch at exactly batch_size.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True,
)
# Evaluation runs one full image at a time, in a fixed order.
eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=1)

### Train

In [None]:
model = ESRT(upscale_factor).to(device)
criterion = nn.L1Loss().to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Every mode except "train" starts from the saved checkpoint.
if mode != "train":
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    # NOTE(review): weights_only=False unpickles arbitrary objects -- only load trusted
    # files, and switch to weights_only=True if `save` writes a plain state_dict.
    checkpoint = torch.load(trained_path + load_model_name, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint)
    print("Loaded model: " + load_model_name)

if mode in ("train", "load-train"):
    train_model(model, train_dataloader, eval_dataloader, criterion, optimizer, learning_rate, upscale_factor, trained_path, start_epoch, end_epoch, device)

### Results

In [None]:
# Benchmark sets used by the result cells below; the larger ones are disabled by default.
set5_dataset, set14_dataset = (
    ImageDataset(path, upscale_factor=upscale_factor) for path in (set5_path, set14_path)
)
# bsd100_dataset = ImageDataset(bsd100_path, upscale_factor=upscale_factor)
# urban100_dataset = ImageDataset(urban100_path, upscale_factor=upscale_factor)
# manga109_dataset = ImageDataset(manga109_path, upscale_factor=upscale_factor)

In [None]:
# Visual side-by-side comparison for one Set5 sample.
# NOTE(review): 2 is presumably the dataset index and (50, 50) a crop position/size --
# confirm against show_comparison_picture (star-imported, presumably from utils).
show_comparison_picture(model, set5_dataset, 2, (50, 50), scale=upscale_factor)

In [None]:
def report_metrics(dataset_label, dataset, **metric_kwargs):
    """Print bicubic-baseline vs. ESRT average PSNR / SSIM (rounded to 4 places) for one dataset.

    Uses the notebook-level `model` and `upscale_factor`. Extra keyword arguments
    (e.g. chop_size) are forwarded to get_avg_metrics.
    """
    bic_psnr, bic_ssim, up_psnr, up_ssim = get_avg_metrics(model, dataset, upscale_factor, **metric_kwargs)
    print(f"{dataset_label}\n", f"Bicubic / {round(bic_psnr, 4)} / {round(bic_ssim, 4)}\n", f"ESRT / {round(up_psnr, 4)} / {round(up_ssim, 4)}")


print("Model:", load_model_name)
report_metrics("Set5", set5_dataset)
report_metrics("Set14", set14_dataset)

# Larger sets (enable their datasets first); a smaller chop_size processes smaller chunks:
# report_metrics("BSD100", bsd100_dataset)
# report_metrics("urban100", urban100_dataset, chop_size=20000)
# report_metrics("manga109", manga109_dataset, chop_size=25000)

In [None]:
# Score previously saved upscaled images against the ground-truth HR images.
dataset_name = "Set5"  # alternatives: "Set14", "BSD100", "urban100", "manga109"

gt_path = f"./Datasets/{dataset_name}/{dataset_name}_HR"
up_path_srcnn = f"./Results/{dataset_name}/SRCNN/X{upscale_factor}"
up_path_esrt = f"./Results/{dataset_name}/ESRT/X{upscale_factor}"
up_path_esrt_paper = f"./Results/{dataset_name}/ESRT_paper/X{upscale_factor}"

print(f"{dataset_name} x{upscale_factor} results for:")
for method_label, up_path in (
    ("SRCNN", up_path_srcnn),
    ("ESRT", up_path_esrt),
    ("ESRT paper", up_path_esrt_paper),
):
    print(method_label)
    metrics_from_results(gt_path, up_path)