In [None]:
!git clone https://github.com/JingyunLiang/SwinIR.git
!wget https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x2_GAN.pth -P experiments/pretrained_models

In [None]:
import os
import sys

# Make the cloned SwinIR repository importable (`models.*`, `utils.*`).
# Prefer the clone in the current working directory (where `git clone` puts it);
# fall back to the Colab default location when no such directory exists.
_swinir_dir = os.path.abspath("SwinIR")
if not os.path.isdir(_swinir_dir):
    _swinir_dir = "/content/SwinIR"

# Guard against stacking duplicate entries when the cell is re-run.
if _swinir_dir not in sys.path:
    sys.path.insert(0, _swinir_dir)

In [None]:
# parser = argparse.ArgumentParser()
# parser.add_argument('--task', type=str, default='color_dn', help='classical_sr, lightweight_sr, real_sr, '
#                                                                  'gray_dn, color_dn, jpeg_car, color_jpeg_car')
# parser.add_argument('--scale', type=int, default=1, help='scale factor: 1, 2, 3, 4, 8') # 1 for dn and jpeg car
# parser.add_argument('--noise', type=int, default=15, help='noise level: 15, 25, 50')
# parser.add_argument('--jpeg', type=int, default=40, help='JPEG quality factor: 10, 20, 30, 40')
# parser.add_argument('--training_patch_size', type=int, default=128, help='patch size used in training SwinIR. '
#                                    'Just used to differentiate two different settings in Table 2 of the paper. '
#                                    'Images are NOT tested patch by patch.')
# parser.add_argument('--large_model', action='store_true', help='use large model, only provided for real image sr')
# parser.add_argument('--model_path', type=str,
#                     default='model_zoo/swinir/001_classicalSR_DIV2K_s48w8_SwinIR-M_x2.pth')
# parser.add_argument('--folder_lq', type=str, default=None, help='input low-quality test image folder')
# parser.add_argument('--folder_gt', type=str, default=None, help='input ground-truth test image folder')
# parser.add_argument('--tile', type=int, default=None, help='Tile size, None for no tile during testing (testing as a whole)')
# parser.add_argument('--tile_overlap', type=int, default=32, help='Overlapping of different tiles')
# args = parser.parse_args()

In [None]:
from PIL import Image
import argparse
import cv2
import glob
import numpy as np
from collections import OrderedDict
import os
import torch
import requests

from models.network_swinir import SwinIR as net
from utils import util_calculate_psnr_ssim as util


def _ensure_model_file(model_path):
    """Make sure the checkpoint at `model_path` exists, downloading it if needed.

    The file is fetched from the SwinIR v0.0 GitHub release using the basename
    of `model_path`. Raises `requests.HTTPError` when the server does not return
    a successful response, so an error page is never saved as a checkpoint.
    """
    if os.path.exists(model_path):
        print(f"loading model from {model_path}")
        return

    # dirname is "" for a bare filename, and os.makedirs("") raises.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/{}".format(
        os.path.basename(model_path)
    )
    print(f"downloading model {model_path}")
    response = requests.get(url, allow_redirects=True, timeout=60)
    response.raise_for_status()  # fail loudly instead of writing a 404 body to disk
    with open(model_path, "wb") as fh:
        fh.write(response.content)


def main(
    task="real_sr",
    scale=1,
    folder_lq=None,
    tile=None,
    tile_overlap=32,
    model_path="experiments/pretrained_models/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth",
):
    """Run SwinIR on every file of `folder_lq` and write the results to disk.

    Args:
        task: Label forwarded to `setup`, which uses it to build the output
            directory name.
        scale: Upscaling factor; forwarded to `define_model` and used to crop
            the network output back to `scale` times the input size.
        folder_lq: Folder containing the input images (required).
        tile: Tile size for tiled inference, or None to process each image as
            a whole.
        tile_overlap: Overlap in pixels between neighbouring tiles.
        model_path: Local checkpoint path; downloaded from the SwinIR release
            page when the file does not exist yet.

    Raises:
        TypeError: If `folder_lq` is None.
        requests.HTTPError: If the checkpoint download fails.

    Files for which `get_image_pair` returns no pixel data are re-saved as PNG
    without being processed; all other results are written as `.jpeg`.

    NOTE(review): the default `scale` (1) does not match the "x4" checkpoint in
    the default `model_path`; callers appear to be expected to pass both - confirm.
    """
    # Validate before doing any expensive work (download / model construction).
    if folder_lq is None:
        raise TypeError("folder_lq must be the path of the folder with the input images")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _ensure_model_file(model_path)

    model = define_model(scale, model_path)
    model.eval()
    model = model.to(device)

    # setup folder and path
    folder, save_dir, _border, window_size = setup(task, scale, folder_lq)
    os.makedirs(save_dir, exist_ok=True)

    # Only regular files are candidates; sub-directories cannot be read as images.
    image_paths = [
        p for p in sorted(glob.glob(os.path.join(folder, "*"))) if os.path.isfile(p)
    ]

    for idx, path in enumerate(image_paths):
        # read image
        imgname, img_lq = get_image_pair(folder_lq, path)  # image to HWC-BGR, float32

        if img_lq is None:
            # Pass-through: copy the file as PNG without running the network.
            with Image.open(path) as img:
                img.save(os.path.join(save_dir, imgname + ".png"))
            continue

        img_lq = np.transpose(
            img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1)
        )  # HWC-BGR to CHW-RGB
        img_lq = (
            torch.from_numpy(img_lq).float().unsqueeze(0).to(device)
        )  # CHW-RGB to NCHW-RGB

        # inference
        with torch.no_grad():
            # Mirror-pad bottom/right so both sides become a multiple of window_size.
            # Note that a full extra window is added when a side is already aligned.
            _, _, h_old, w_old = img_lq.size()
            h_pad = (h_old // window_size + 1) * window_size - h_old
            w_pad = (w_old // window_size + 1) * window_size - w_old
            img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[
                :, :, : h_old + h_pad, :
            ]
            img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[
                :, :, :, : w_old + w_pad
            ]
            output = test(img_lq, model, tile, tile_overlap, scale, window_size)
            # Drop the region that corresponds to the padding.
            output = output[..., : h_old * scale, : w_old * scale]

        # save image
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        if output.ndim == 3:
            output = np.transpose(
                output[[2, 1, 0], :, :], (1, 2, 0)
            )  # CHW-RGB to HWC-BGR
        output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
        cv2.imwrite(os.path.join(save_dir, imgname + ".jpeg"), output)

        print("Testing {:d} {:20s}".format(idx, imgname))


def define_model(scale, model_path):
    """Build the SwinIR network and load its weights from `model_path`.

    Args:
        scale: Upscaling factor passed to the network as `upscale`.
        model_path: Path of the checkpoint file to load.

    Returns:
        The network with the checkpoint weights loaded (on CPU; the caller
        moves it to the target device).

    The checkpoint may either be a bare state dict or a dict holding the
    weights under the "params_ema" key. Loading is strict, so the checkpoint
    must match the architecture configured here exactly.
    """
    # use 'nearest+conv' to avoid block artifacts
    model = net(
        upscale=scale,
        in_chans=3,
        img_size=64,
        window_size=8,
        img_range=1.0,
        depths=[6, 6, 6, 6, 6, 6],
        embed_dim=180,
        num_heads=[6, 6, 6, 6, 6, 6],
        mlp_ratio=2,
        upsampler="nearest+conv",
        resi_connection="1conv",
    )
    param_key_g = "params_ema"

    # map_location="cpu" lets a checkpoint saved from GPU tensors load on a
    # CPU-only machine.
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # obtained from a trusted source.
    pretrained_model = torch.load(model_path, map_location="cpu")
    if param_key_g in pretrained_model.keys():
        state_dict = pretrained_model[param_key_g]
    else:
        state_dict = pretrained_model
    model.load_state_dict(state_dict, strict=True)

    return model


def setup(task, scale, folder_lq):
    save_dir = f"results/swinir_{task}_x{scale}"
    folder = folder_lq
    border = 0
    window_size = 8

    return folder, save_dir, border, window_size


def get_image_pair(folder_lq, path):
    """Load one input image as a float32 array scaled to [0, 1].

    Args:
        folder_lq: Unused; kept so existing callers keep working.
        path: Path of the image file.

    Returns:
        `(imgname, img_lq)` where `imgname` is the file name without extension
        and `img_lq` is the image as read by `cv2.imread(..., IMREAD_COLOR)`,
        converted to float32 and divided by 255. Files whose name contains
        "ovar" or "foll" are not read at all; `img_lq` is None for them.

    Raises:
        OSError: If OpenCV cannot read or decode the file.
    """
    imgname = os.path.splitext(os.path.basename(path))[0]
    if ("ovar" in imgname) or ("foll" in imgname):
        return imgname, None

    # cv2.imread signals failure by returning None instead of raising.
    img_raw = cv2.imread(path, cv2.IMREAD_COLOR)
    if img_raw is None:
        raise OSError(f"could not read image file: {path}")

    img_lq = img_raw.astype(np.float32) / 255.0

    return imgname, img_lq


def test(img_lq, model, tile, tile_overlap, scale, window_size):
    if tile is None:
        # test the image as a whole
        output = model(img_lq)
    else:
        # test the image tile by tile
        b, c, h, w = img_lq.size()
        tile = min(tile, h, w)
        assert tile % window_size == 0, "tile size should be a multiple of window_size"
        tile_overlap = tile_overlap
        sf = scale

        stride = tile - tile_overlap
        h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
        w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
        E = torch.zeros(b, c, h * sf, w * sf).type_as(img_lq)
        W = torch.zeros_like(E)

        for h_idx in h_idx_list:
            for w_idx in w_idx_list:
                in_patch = img_lq[..., h_idx : h_idx + tile, w_idx : w_idx + tile]
                out_patch = model(in_patch)
                out_patch_mask = torch.ones_like(out_patch)

                E[
                    ...,
                    h_idx * sf : (h_idx + tile) * sf,
                    w_idx * sf : (w_idx + tile) * sf,
                ].add_(out_patch)
                W[
                    ...,
                    h_idx * sf : (h_idx + tile) * sf,
                    w_idx * sf : (w_idx + tile) * sf,
                ].add_(out_patch_mask)
        output = E.div_(W)

    return output

In [None]:
# Upscale both image folders with the same x2 real-world SR checkpoint.
X2_MODEL_PATH = "experiments/pretrained_models/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x2_GAN.pth"

for lq_folder in ("train_clean/", "test_clean/"):
    main(
        task="real_sr",
        model_path=X2_MODEL_PATH,
        folder_lq=lq_folder,
        scale=2,
    )