# Dataset Prepration files

## imgproc.py

In [1]:
import random
from typing import Any

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

def normalize(image: np.ndarray) -> np.ndarray:
    """Normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.

    Args:
        image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.

    Returns:
        Normalized image data. Data range [0, 1].
    """

    return image.astype(np.float64) / 255.0


def unnormalize(image: np.ndarray) -> np.ndarray:
    """Un-normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.

    Args:
        image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.

    Returns:
        Denormalized image data. Data range [0, 255].
    """

    return image.astype(np.float64) * 255.0


def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
    """Convert ``PIL.Image`` to Tensor.

    Args:
        image (np.ndarray): The image data read by ``PIL.Image``
        range_norm (bool): Scale [0, 1] data to between [-1, 1]
        half (bool): Whether to convert torch.float32 similarly to torch.half type.

    Returns:
        Normalized image data

    Examples:
        >>> image = Image.open("image.bmp")
        >>> tensor_image = image2tensor(image, range_norm=False, half=False)
    """
    convert_tensor = transforms.ToTensor()
    tensor = convert_tensor(image)

    if range_norm:
        tensor = tensor.mul_(2.0).sub_(1.0)
    if half:
        tensor = tensor.half()

    return tensor


def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
    """Converts ``torch.Tensor`` to ``PIL.Image``.

    Args:
        tensor (torch.Tensor): The image that needs to be converted to ``PIL.Image``
        range_norm (bool): Scale [-1, 1] data to between [0, 1]
        half (bool): Whether to convert torch.float32 similarly to torch.half type.

    Returns:
        Convert image data to support PIL library

    Examples:
        >>> tensor = torch.randn([1, 3, 128, 128])
        >>> image = tensor2image(tensor, range_norm=False, half=False)
    """

    if range_norm:
        tensor = tensor.add_(1.0).div_(2.0)
    if half:
        tensor = tensor.half()

    image = tensor.squeeze_(0).permute(1, 2, 0).mul_(255).clamp_(0, 255).cpu().numpy().astype("uint8")

    return image


def convert_rgb_to_y(image: Any) -> Any:
    """Convert RGB image or tensor image data to YCbCr(Y) format.

    Args:
        image: RGB image data read by ``PIL.Image''.

    Returns:
        Y image array data.
    """

    if type(image) == np.ndarray:
        return 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
    elif type(image) == torch.Tensor:
        if len(image.shape) == 4:
            image = image.squeeze_(0)
        return 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
    else:
        raise Exception("Unknown Type", type(image))


def convert_rgb_to_ycbcr(image: Any) -> Any:
    """Convert RGB image or tensor image data to YCbCr format.

    Args:
        image: RGB image data read by ``PIL.Image''.

    Returns:
        YCbCr image array data.
    """

    if type(image) == np.ndarray:
        y = 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
        cb = 128. + (-37.945 * image[:, :, 0] - 74.494 * image[:, :, 1] + 112.439 * image[:, :, 2]) / 256.
        cr = 128. + (112.439 * image[:, :, 0] - 94.154 * image[:, :, 1] - 18.285 * image[:, :, 2]) / 256.
        return np.array([y, cb, cr]).transpose([1, 2, 0])
    elif type(image) == torch.Tensor:
        if len(image.shape) == 4:
            image = image.squeeze(0)
        y = 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
        cb = 128. + (-37.945 * image[0, :, :] - 74.494 * image[1, :, :] + 112.439 * image[2, :, :]) / 256.
        cr = 128. + (112.439 * image[0, :, :] - 94.154 * image[1, :, :] - 18.285 * image[2, :, :]) / 256.
        return torch.cat([y, cb, cr], 0).permute(1, 2, 0)
    else:
        raise Exception("Unknown Type", type(image))


def convert_ycbcr_to_rgb(image: Any) -> Any:
    """Convert YCbCr format image to RGB format.

    Args:
       image: YCbCr image data read by ``PIL.Image''.

    Returns:
        RGB image array data.
    """

    if type(image) == np.ndarray:
        r = 298.082 * image[:, :, 0] / 256. + 408.583 * image[:, :, 2] / 256. - 222.921
        g = 298.082 * image[:, :, 0] / 256. - 100.291 * image[:, :, 1] / 256. - 208.120 * image[:, :, 2] / 256. + 135.576
        b = 298.082 * image[:, :, 0] / 256. + 516.412 * image[:, :, 1] / 256. - 276.836
        return np.array([r, g, b]).transpose([1, 2, 0])
    elif type(image) == torch.Tensor:
        if len(image.shape) == 4:
            image = image.squeeze(0)
        r = 298.082 * image[0, :, :] / 256. + 408.583 * image[2, :, :] / 256. - 222.921
        g = 298.082 * image[0, :, :] / 256. - 100.291 * image[1, :, :] / 256. - 208.120 * image[2, :, :] / 256. + 135.576
        b = 298.082 * image[0, :, :] / 256. + 516.412 * image[1, :, :] / 256. - 276.836
        return torch.cat([r, g, b], 0).permute(1, 2, 0)
    else:
        raise Exception("Unknown Type", type(image))


def center_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int):
    """Cut ``PIL.Image`` in the center area of the image.

    Args:
        lr: Low-resolution image data read by ``PIL.Image``.
        hr: High-resolution image data read by ``PIL.Image``.
        image_size (int): The size of the captured image area. It should be the size of the high-resolution image.
        upscale_factor (int): magnification factor.

    Returns:
        Randomly cropped low-resolution images and high-resolution images.
    """

    w, h = hr.size

    left = (w - image_size) // 2
    top = (h - image_size) // 2
    right = left + image_size
    bottom = top + image_size

    lr = lr.crop((left // upscale_factor,
                  top // upscale_factor,
                  right // upscale_factor,
                  bottom // upscale_factor))
    hr = hr.crop((left, top, right, bottom))

    return lr, hr


def random_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int):
    """Will ``PIL.Image`` randomly capture the specified area of the image.

    Args:
        lr: Low-resolution image data read by ``PIL.Image``.
        hr: High-resolution image data read by ``PIL.Image``.
        image_size (int): The size of the captured image area. It should be the size of the high-resolution image.
        upscale_factor (int): magnification factor.

    Returns:
        Randomly cropped low-resolution images and high-resolution images.
    """

    w, h = hr.size
    left = torch.randint(0, w - image_size + 1, size=(1,)).item()
    top = torch.randint(0, h - image_size + 1, size=(1,)).item()
    right = left + image_size
    bottom = top + image_size

    lr = lr.crop((left // upscale_factor,
                  top // upscale_factor,
                  right // upscale_factor,
                  bottom // upscale_factor))
    hr = hr.crop((left, top, right, bottom))

    return lr, hr


def random_rotate(lr: Any, hr: Any, degrees: list):
    """Will ``PIL.Image`` randomly rotate the image.

    Args:
        lr: Low-resolution image data read by ``PIL.Image``.
        hr: High-resolution image data read by ``PIL.Image``.
        degrees (list): rotation angle, clockwise and counterclockwise rotation.

    Returns:
        Randomly rotated low-resolution images and high-resolution images.
    """

    angle = random.choice(degrees)
    lr = F.rotate(lr, angle)
    hr = F.rotate(hr, angle)

    return lr, hr


def random_horizontally_flip(lr: Any, hr: Any, p=0.5):
    """Flip the ``PIL.Image`` image horizontally randomly.

    Args:
        lr: Low-resolution image data read by ``PIL.Image``.
        hr: High-resolution image data read by ``PIL.Image``.
        p (optional, float): rollover probability. (Default: 0.5)

    Returns:
        Low-resolution image and high-resolution image after random horizontal flip.
    """

    if torch.rand(1).item() > p:
        lr = F.hflip(lr)
        hr = F.hflip(hr)

    return lr, hr


def random_vertically_flip(lr: Any, hr: Any, p=0.5):
    """Turn the ``PIL.Image`` image upside down randomly.

    Args:
        lr: Low-resolution image data read by ``PIL.Image``.
        hr: High-resolution image data read by ``PIL.Image``.
        p (optional, float): rollover probability. (Default: 0.5)

    Returns:
        Randomly rotated up and down low-resolution images and high-resolution images.
    """

    if torch.rand(1).item() > p:
        lr = F.vflip(lr)
        hr = F.vflip(hr)

    return lr, hr


def random_adjust_brightness(lr: Any, hr: Any, factor: list):
    """Set ``PIL.Image`` to randomly adjust the image brightness.

    Args:
        lr: Low-resolution image data read by ``PIL.Image``.
        hr: High-resolution image data read by ``PIL.Image``.
        factor (list): Brightness coefficient adjustment range.

    Returns:
        Low-resolution image and high-resolution image with randomly adjusted brightness.
    """

    # Randomly adjust the brightness gain range.
    brightness_factor = random.choice(factor)
    lr = F.adjust_brightness(lr, brightness_factor)
    hr = F.adjust_brightness(hr, brightness_factor)

    return lr, hr


def random_adjust_contrast(lr: Any, hr: Any, factor: list):
    """Set ``PIL.Image`` to randomly adjust the image contrast.

    Args:
        lr: Low-resolution image data read by ``PIL.Image``.
        hr: High-resolution image data read by ``PIL.Image``.
        factor (list): Contrast coefficient adjustment range.

    Returns:
        Low-resolution image and high-resolution image with randomly adjusted contrast.
    """

    # Randomly adjust the contrast gain range.
    contrast_factor = random.choice(factor)
    lr = F.adjust_contrast(lr, contrast_factor)
    hr = F.adjust_contrast(hr, contrast_factor)

    return lr, hr

## Dataset.py

In [2]:
import io
import os

import lmdb
from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode as IMode

class ImageDataset(Dataset):
    """Customize the data set loading function and prepare low/high resolution image data in advance.

    Args:
        dataroot         (str): Training data set address
        image_size       (int): High resolution image size
        upscale_factor   (int): Image magnification
        mode             (str): Data set loading method, the training data set is for data enhancement,
                             and the verification data set is not for data enhancement

    """

    def __init__(self, dataroot: str, image_size: int, upscale_factor: int, mode: str) -> None:
        super(ImageDataset, self).__init__()
        self.filenames = [os.path.join(dataroot, x) for x in os.listdir(dataroot)]

        if mode == "train":
            self.hr_transforms = transforms.Compose([
                transforms.RandomCrop(image_size),
                transforms.RandomRotation([0, 90]),
                transforms.RandomHorizontalFlip(0.5),
            ])
        elif mode == "valid":
            self.hr_transforms = transforms.CenterCrop(image_size)
        else:
            raise "Unsupported data processing model, please use `train` or `valid`."

        self.lr_transforms = transforms.Resize(image_size // upscale_factor, interpolation=IMode.BICUBIC)

    def __getitem__(self, batch_index: int):
        # Read a batch of image data
        image = Image.open(self.filenames[batch_index])

        # Transform image
        hr_image = self.hr_transforms(image)
        lr_image = self.lr_transforms(hr_image)

        # Convert image data into Tensor stream format (PyTorch).
        # Note: The range of input and output is between [0, 1]
        lr_tensor = image2tensor(lr_image, range_norm=False, half=False)
        hr_tensor = image2tensor(hr_image, range_norm=False, half=False)

        return lr_tensor, hr_tensor

    def __len__(self) -> int:
        return len(self.filenames)


class LMDBDataset(Dataset):
    """Load the data set as a data set in the form of LMDB.

    Attributes:
        lr_datasets (list): Low-resolution image data in the dataset
        hr_datasets (list): High-resolution image data in the dataset

    Args:
        lr_lmdb_path (str): LMDB file address of low-resolution image
        hr_lmdb_path (int): LMDB file address of high-resolution image
        image_size (int): High resolution image size
        upscale_factor (int): Image magnification
        mode (str): Data set loading method, the training data set is for data enhancement,
            and the verification data set is not for data enhancement
    """

    def __init__(self, lr_lmdb_path: str, hr_lmdb_path: str, image_size: int, upscale_factor: int, mode: str) -> None:
        super(LMDBDataset, self).__init__()
        self.image_size = image_size
        self.upscale_factor = upscale_factor
        self.mode = mode

        # Create low/high resolution image array
        self.lr_datasets = []
        self.hr_datasets = []

        # Initialize the LMDB database file address
        self.lr_lmdb_path = lr_lmdb_path
        self.hr_lmdb_path = hr_lmdb_path

        # Write image data in LMDB database to memory
        self.read_lmdb_dataset()

    def __getitem__(self, batch_index: int):
        # Read a batch of image data
        lr_image = self.lr_datasets[batch_index]
        hr_image = self.hr_datasets[batch_index]

        # Data augment
        if self.mode == "train:":
            lr_image, hr_image = random_crop(lr_image, hr_image, image_size=self.image_size, upscale_factor=self.upscale_factor)
            lr_image, hr_image = random_rotate(lr_image, hr_image, degrees=[0, 90])
            lr_image, hr_image = random_horizontally_flip(lr_image, hr_image, p=0.5)
        elif self.mode == "valid:":
            lr_image, hr_image = center_crop(lr_image, hr_image, image_size=self.image_size, upscale_factor=self.upscale_factor)
        else:
            raise "Unsupported data processing model, please use `train` or `valid`."

        # Convert image data into Tensor stream format (PyTorch).
        # Note: The range of input and output is between [0, 1]
        lr_tensor = image2tensor(lr_image, range_norm=False, half=False)
        hr_tensor = image2tensor(hr_image, range_norm=False, half=False)

        return lr_tensor, hr_tensor

    def __len__(self) -> int:
        return len(self.hr_datasets)

    def read_lmdb_dataset(self):
        # Open two LMDB database writing environments to read low/high image data
        lr_lmdb_env = lmdb.open(self.lr_lmdb_path)
        hr_lmdb_env = lmdb.open(self.hr_lmdb_path)

        # Write the image data in the low-resolution LMDB data set to the memory
        for _, image_bytes in lr_lmdb_env.begin().cursor():
            image = Image.open(io.BytesIO(image_bytes))
            self.lr_datasets.append(image)

        # Write the image data in the high-resolution LMDB data set to the memory
        for _, image_bytes in hr_lmdb_env.begin().cursor():
            image = Image.open(io.BytesIO(image_bytes))
            self.hr_datasets.append(image)

# Model

## model.py

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class ResidualDenseBlock(nn.Module):
    """Achieves densely connected convolutional layers.
    `Densely Connected Convolutional Networks <https://arxiv.org/pdf/1608.06993v5.pdf>` paper.

    Args:
        channels (int): The number of channels in the input image.
        growths (int): The number of channels that increase in each layer of convolution.
    """

    def __init__(self, channels: int, growths: int) -> None:
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels + growths * 0, growths, (3, 3), (1, 1), (1, 1))
        self.conv2 = nn.Conv2d(channels + growths * 1, growths, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(channels + growths * 2, growths, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(channels + growths * 3, growths, (3, 3), (1, 1), (1, 1))
        self.conv5 = nn.Conv2d(channels + growths * 4, channels, (3, 3), (1, 1), (1, 1))

        self.leaky_relu = nn.LeakyReLU(0.2, True)
        self.identity = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x

        out1 = self.leaky_relu(self.conv1(x))
        out2 = self.leaky_relu(self.conv2(torch.cat([x, out1], 1)))
        out3 = self.leaky_relu(self.conv3(torch.cat([x, out1, out2], 1)))
        out4 = self.leaky_relu(self.conv4(torch.cat([x, out1, out2, out3], 1)))
        out5 = self.identity(self.conv5(torch.cat([x, out1, out2, out3, out4], 1)))
        out = torch.mul(out5, 0.2)
        out = torch.add(out, identity)

        return out


class ResidualResidualDenseBlock(nn.Module):
    """Multi-layer residual dense convolution block.

    Args:
        channels (int): The number of channels in the input image.
        growths (int): The number of channels that increase in each layer of convolution.
    """

    def __init__(self, channels: int, growths: int) -> None:
        super(ResidualResidualDenseBlock, self).__init__()
        self.rdb1 = ResidualDenseBlock(channels, growths)
        self.rdb2 = ResidualDenseBlock(channels, growths)
        self.rdb3 = ResidualDenseBlock(channels, growths)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x

        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        out = torch.mul(out, 0.2)
        out = torch.add(out, identity)

        return out


class Discriminator(nn.Module):
    def __init__(self) -> None:
        super(Discriminator, self).__init__()
        self.features = nn.Sequential(
            # input size. (3) x 128 x 128
            nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=True),
            nn.LeakyReLU(0.2, True),
            # state size. (64) x 64 x 64
            nn.Conv2d(64, 64, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            # state size. (128) x 32 x 32
            nn.Conv2d(128, 128, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            # state size. (256) x 16 x 16
            nn.Conv2d(256, 256, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            # state size. (512) x 8 x 8
            nn.Conv2d(512, 512, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            # state size. (512) x 4 x 4
            nn.Conv2d(512, 512, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True)
        )

        self.classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 100),
            nn.LeakyReLU(0.2, True),
            nn.Linear(100, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.features(x)
        out = torch.flatten(out, 1)
        out = self.classifier(out)

        return out


class Generator(nn.Module):
    def __init__(self) -> None:
        super(Generator, self).__init__()
        # The first layer of convolutional layer.
        self.conv1 = nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1))

        # Feature extraction backbone network.
        trunk = []
        for _ in range(23):
            trunk.append(ResidualResidualDenseBlock(64, 32))
        self.trunk = nn.Sequential(*trunk)

        # After the feature extraction network, reconnect a layer of convolutional blocks.
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))

        # Upsampling convolutional layer.
        self.upsampling = nn.Sequential(
            nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
            nn.LeakyReLU(0.2, True)
        )

        # Reconnect a layer of convolution block after upsampling.
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
            nn.LeakyReLU(0.2, True)
        )

        # Output layer.
        self.conv4 = nn.Conv2d(64, 3, (3, 3), (1, 1), (1, 1))

    # The model should be defined in the Torch.script method.
    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
        out1 = self.conv1(x)
        out = self.trunk(out1)
        out2 = self.conv2(out)
        out = torch.add(out1, out2)
        out = self.upsampling(F.interpolate(out, scale_factor=2, mode="nearest"))
        out = self.upsampling(F.interpolate(out, scale_factor=2, mode="nearest"))
        out = self.conv3(out)
        out = self.conv4(out)

        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward_impl(x)

    def _initialize_weights(self) -> None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                module.weight.data *= 0.1
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)


class ContentLoss(nn.Module):
    """Constructs a content loss function based on the VGG19 network.
    Using high-level feature mapping layers from the latter layers will focus more on the texture content of the image.

    Paper reference list:
        -`Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>` paper.
        -`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks                    <https://arxiv.org/pdf/1809.00219.pdf>` paper.
        -`Perceptual Extreme Super Resolution Network with Receptive Field Block               <https://arxiv.org/pdf/2005.12597.pdf>` paper.

     """

    def __init__(self) -> None:
        super(ContentLoss, self).__init__()
        # Load the VGG19 model trained on the ImageNet dataset.
        vgg19 = models.vgg19(pretrained=True).eval()
        # Extract the thirty-sixth layer output in the VGG19 model as the content loss.
        self.feature_extractor = nn.Sequential(*list(vgg19.features.children())[:35])
        # Freeze model parameters.
        for parameters in self.feature_extractor.parameters():
            parameters.requires_grad = False

        # The preprocessing method of the input data. This is the VGG model preprocessing method of the ImageNet dataset.
        self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, sr: torch.Tensor, hr: torch.Tensor) -> torch.Tensor:
        # Standardized operations
        sr = sr.sub(self.mean).div(self.std)
        hr = hr.sub(self.mean).div(self.std)

        # Find the feature map difference between the two images
        loss = F.l1_loss(self.feature_extractor(sr), self.feature_extractor(hr))

        return loss

# Configuration

## config.py

In [4]:
import torch
from torch.backends import cudnn
# Random seed to maintain reproducible results
torch.manual_seed(0)
# Use GPU for training by default
if torch.cuda.is_available():
    device =  torch.device('cuda')
else:
    device =  torch.device('cpu')
# device = torch.device("cuda", 0)
# Turning on when the image size does not change during training can speed up training
cudnn.benchmark = True
# Image magnification factor
upscale_factor = 4
# Current configuration parameter method
mode = "train_rrdbnet"
# Experiment name, easy to save weights and log files
exp_name = "RRDBNet_baseline"

if mode == "train_rrdbnet":
    # Dataset address
    train_image_dir = '/kaggle/input/dataset/data/DIV2K/ESRGAN/train/'
    valid_image_dir = '/kaggle/input/dataset/data/DIV2K/ESRGAN/valid/'

    image_size = 192
    batch_size = 48
    num_workers = 4

    # Incremental training and migration training
    resume = False
    strict = False
    start_epoch = 0
    resume_weight = ""

    # Total num epochs
    epochs = 120

    # Adam optimizer parameter for RRDBNet(p)
    model_lr = 2e-4
    model_betas = (0.9, 0.999)

    # StepLR scheduler
    step_size = epochs // 5
    gamma = 0.5

    # Print the training log every one hundred iterations
    print_frequency = 1000

if mode == "train_esrgan":
    # Dataset address
    train_image_dir = 'data/DIV2K/ESRGAN/train/'
    valid_image_dir = 'data/DIV2K/ESRGAN/valid/'

    image_size = 128
    batch_size = 16
    num_workers = 4

    # Incremental training and migration training
    resume = False
    strict = False
    start_epoch = 0
    resume_d_weight = ""
    resume_g_weight = "results/RRDBNet_baseline/g-last.pth"

    # Total num epochs
    epochs = 48

    # Loss function weight
    pixel_weight = 1.0
    content_weight = 1.0
    adversarial_weight = 0.001

    # Adam optimizer parameter for Discriminator
    d_model_lr = 1e-4
    d_model_betas = (0.9, 0.999)

    # Adam optimizer parameter for Generator
    g_model_lr = 1e-4
    g_model_betas = (0.9, 0.999)

    # MultiStepLR scheduler parameter for ESRGAN
    d_optimizer_milestones = [int(epochs * 0.125), int(epochs * 0.250), int(epochs * 0.500), int(epochs * 0.750)]
    g_optimizer_milestones = [int(epochs * 0.125), int(epochs * 0.250), int(epochs * 0.500), int(epochs * 0.750)]
    d_optimizer_gamma = 0.5
    g_optimizer_gamma = 0.5

    # Print the training log every one hundred iterations
    print_frequency = 1000

if mode == "valid":
    # Test data address
    lr_dir = f"data/Set14/LRbicx{upscale_factor}"
    sr_dir = f"results/test/{exp_name}"
    hr_dir = f"data/Set14/GTmod12"

    model_path = f"results/{exp_name}/g-last.pth"

# Training

## train_rrdbnet.py

In [5]:
import os
import time

import torch
from torch import nn
from torch import optim
from torch.cuda import amp
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

def main():
    print("Load train dataset and valid dataset...")
    train_dataloader, valid_dataloader = load_dataset()
    print("Load train dataset and valid dataset successfully.")

    print("Build RRDBNet model...")
    model = build_model()
    print("Build RRDBNet model successfully.")

    print("Define all loss functions...")
    psnr_criterion, pixel_criterion = define_loss()
    print("Define all loss functions successfully.")

    print("Define all optimizer functions...")
    optimizer = define_optimizer(model)
    print("Define all optimizer functions successfully.")

    print("Define all scheduler functions...")
    scheduler = define_scheduler(optimizer)
    print("Define all scheduler functions successfully.")

    print("Check whether the training weight is restored...")
    resume_checkpoint(model)
    print("Check whether the training weight is restored successfully.")

    # Create a folder of super-resolution experiment results
    samples_dir = os.path.join("samples", exp_name)
    results_dir = os.path.join("results", exp_name)
    if not os.path.exists(samples_dir):
        os.makedirs(samples_dir)
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)

    # Create training process log file
    writer = SummaryWriter(os.path.join("samples", "logs", exp_name))

    # Initialize the gradient scaler
    scaler = amp.GradScaler()

    # Initialize training to generate network evaluation indicators
    best_psnr = 0.0

    print("Start train RRDBNet model.")
    for epoch in range(start_epoch, epochs):
        print(f"Epoch : {epoch+1}")
        train(model, train_dataloader, psnr_criterion, pixel_criterion, optimizer, epoch, scaler, writer)

        psnr = validate(model, valid_dataloader, psnr_criterion, epoch, writer)
        # Automatically save the model with the highest index
        is_best = psnr > best_psnr
        best_psnr = max(psnr, best_psnr)
        torch.save(model.state_dict(), os.path.join(samples_dir, f"g_epoch_{epoch + 1}.pth"))
        if is_best:
            torch.save(model.state_dict(), os.path.join(results_dir, "g-best.pth"))

        # Update LR
        scheduler.step()

    # Save the generator weight under the last Epoch in this stage
    torch.save(model.state_dict(), os.path.join(results_dir, "g-last.pth"))
    print("End train RRDBNet model.")


def load_dataset():
    train_datasets = ImageDataset(train_image_dir, image_size, upscale_factor, "train")
    valid_datasets = ImageDataset(valid_image_dir, image_size, upscale_factor, "valid")
    # Make it into a data set type supported by PyTorch
    train_dataloader = DataLoader(train_datasets,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  pin_memory=True,
                                  persistent_workers=True)
    valid_dataloader = DataLoader(valid_datasets,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=num_workers,
                                  pin_memory=True,
                                  persistent_workers=True)

    return train_dataloader, valid_dataloader


def build_model() -> nn.Module:
    model = Generator().to(device)

    return model


def define_loss():
    psnr_criterion = nn.MSELoss().to(device)
    pixel_criterion = nn.L1Loss().to(device)

    return psnr_criterion, pixel_criterion


def define_optimizer(model) -> optim.Adam:
    optimizer = optim.Adam(model.parameters(), model_lr, model_betas)

    return optimizer


def define_scheduler(optimizer) -> optim.lr_scheduler:
    scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    return scheduler


def resume_checkpoint(model) -> None:
    if resume:
        if resume_weight != "":
            # Get pretrained model state dict
            pretrained_state_dict = torch.load(resume_weight)
            model_state_dict = model.state_dict()
            # Extract the fitted model weights
            new_state_dict = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict.items()}
            # Overwrite the pretrained model weights to the current model
            model_state_dict.update(new_state_dict)
            model.load_state_dict(model_state_dict, strict=strict)


def train(model, train_dataloader, psnr_criterion, pixel_criterion, optimizer, epoch, scaler, writer) -> None:
    # Calculate how many iterations there are under epoch
    batches = len(train_dataloader)

    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":6.6f")
    psnres = AverageMeter("PSNR", ":4.2f")
    progress = ProgressMeter(batches, [batch_time, data_time, losses, psnres], prefix=f"Epoch: [{epoch + 1}]")

    # Put the generator in train mode.
    model.train()

    end = time.time()
    for index, (lr, hr) in enumerate(train_dataloader):
        # measure data loading time
        data_time.update(time.time() - end)

        lr = lr.to(device, non_blocking=True)
        hr = hr.to(device, non_blocking=True)

        # Initialize the generator gradient
        model.zero_grad()

        # Mixed precision training
        with amp.autocast():
            sr = model(lr)
            loss = pixel_criterion(sr, hr)

        # Gradient zoom
        scaler.scale(loss).backward()
        # Update generator weight
        scaler.step(optimizer)
        scaler.update()

        # measure accuracy and record loss
        psnr = 10. * torch.log10(1. / psnr_criterion(sr, hr))
        losses.update(loss.item(), lr.size(0))
        psnres.update(psnr.item(), lr.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Writer Loss to file
        writer.add_scalar("Train/Loss", loss.item(), index + epoch * batches + 1)
        if index % print_frequency == 0 and index != 0:
            progress.display(index)


def validate(model, valid_dataloader, psnr_criterion, epoch, writer) -> float:
    batch_time = AverageMeter("Time", ":6.3f")
    psnres = AverageMeter("PSNR", ":4.2f")
    progress = ProgressMeter(len(valid_dataloader), [batch_time, psnres], prefix="Valid: ")

    # Put the generator in verification mode.
    model.eval()

    with torch.no_grad():
        end = time.time()
        for index, (lr, hr) in enumerate(valid_dataloader):
            lr = lr.to(device, non_blocking=True)
            hr = hr.to(device, non_blocking=True)

            # Mixed precision
            with amp.autocast():
                sr = model(lr)

            # measure accuracy and record loss
            psnr = 10. * torch.log10(1. / psnr_criterion(sr, hr))
            psnres.update(psnr.item(), hr.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if index % print_frequency == 0:
                progress.display(index)

        writer.add_scalar("Valid/PSNR", psnres.avg, epoch + 1)
        # Print evaluation indicators.
        print(f"* PSNR: {psnres.avg:4.2f}.\n")

    return psnres.avg


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print("\t".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


if __name__ == "__main__":
    main()

Load train dataset and valid dataset...
Load train dataset and valid dataset successfully.
Build RRDBNet model...


  cpuset_checked))


Build RRDBNet model successfully.
Define all loss functions...
Define all loss functions successfully.
Define all optimizer functions...
Define all optimizer functions successfully.
Define all scheduler functions...
Define all scheduler functions successfully.
Check whether the training weight is restored...
Check whether the training weight is restored successfully.
Start train RRDBNet model.
Epoch : 1


  cpuset_checked))


Valid: [0/3]	Time 11.453 (11.453)	PSNR 16.63 (16.63)
* PSNR: 16.59.

Epoch : 2
Valid: [0/3]	Time  8.619 ( 8.619)	PSNR 18.69 (18.69)
* PSNR: 18.53.

Epoch : 3
Valid: [0/3]	Time  8.616 ( 8.616)	PSNR 19.95 (19.95)
* PSNR: 19.69.

Epoch : 4
Valid: [0/3]	Time  8.343 ( 8.343)	PSNR 20.54 (20.54)
* PSNR: 20.27.

Epoch : 5
Valid: [0/3]	Time  8.433 ( 8.433)	PSNR 20.98 (20.98)
* PSNR: 20.69.

Epoch : 6
Valid: [0/3]	Time  8.306 ( 8.306)	PSNR 21.16 (21.16)
* PSNR: 20.86.

Epoch : 7
Valid: [0/3]	Time  8.583 ( 8.583)	PSNR 21.51 (21.51)
* PSNR: 21.19.

Epoch : 8
Valid: [0/3]	Time  9.170 ( 9.170)	PSNR 20.88 (20.88)
* PSNR: 20.57.

Epoch : 9
Valid: [0/3]	Time  8.673 ( 8.673)	PSNR 21.82 (21.82)
* PSNR: 21.49.

Epoch : 10
Valid: [0/3]	Time  8.896 ( 8.896)	PSNR 22.07 (22.07)
* PSNR: 21.74.

Epoch : 11
Valid: [0/3]	Time  8.826 ( 8.826)	PSNR 22.01 (22.01)
* PSNR: 21.68.

Epoch : 12
Valid: [0/3]	Time  8.570 ( 8.570)	PSNR 22.45 (22.45)
* PSNR: 22.09.

Epoch : 13
Valid: [0/3]	Time  8.181 ( 8.181)	PSNR 21.94 (21

# Configuration

# Config.py

In [6]:
import torch
from torch.backends import cudnn
# Random seed to maintain reproducible results
torch.manual_seed(0)
# Use GPU for training by default
if torch.cuda.is_available():
    device =  torch.device('cuda')
else:
    device =  torch.device('cpu')
# device = torch.device("cuda", 0)
# Turning on when the image size does not change during training can speed up training
cudnn.benchmark = True
# Image magnification factor
upscale_factor = 4
# Current configuration parameter method
mode = "train_esrgan"
# Experiment name, easy to save weights and log files
exp_name = "RRDBNet_baseline"

if mode == "train_rrdbnet":
    # Dataset address
    train_image_dir = 'data/DIV2K/ESRGAN/train/'
    valid_image_dir = 'data/DIV2K/ESRGAN/valid/'

    image_size = 192
    batch_size = 48
    num_workers = 4

    # Incremental training and migration training
    resume = False
    strict = False
    start_epoch = 0
    resume_weight = ""

    # Total num epochs
    epochs = 120

    # Adam optimizer parameter for RRDBNet(p)
    model_lr = 2e-4
    model_betas = (0.9, 0.999)

    # StepLR scheduler
    step_size = epochs // 5
    gamma = 0.5

    # Print the training log every one hundred iterations
    print_frequency = 1000

if mode == "train_esrgan":
    # Dataset address
    train_image_dir = 'data/DIV2K/ESRGAN/train/'
    valid_image_dir = 'data/DIV2K/ESRGAN/valid/'

    image_size = 128
    batch_size = 32
    num_workers = 4

    # Incremental training and migration training
    resume = False
    strict = False
    start_epoch = 0
    resume_d_weight = ""
    resume_g_weight = "results/RRDBNet_baseline/g-last.pth"

    # Total num epochs
    epochs = 120

    # Loss function weight
    pixel_weight = 1.0
    content_weight = 1.0
    adversarial_weight = 0.001

    # Adam optimizer parameter for Discriminator
    d_model_lr = 1e-4
    d_model_betas = (0.9, 0.999)

    # Adam optimizer parameter for Generator
    g_model_lr = 1e-4
    g_model_betas = (0.9, 0.999)

    # MultiStepLR scheduler parameter for ESRGAN
    d_optimizer_milestones = [int(epochs * 0.125), int(epochs * 0.250), int(epochs * 0.500), int(epochs * 0.750)]
    g_optimizer_milestones = [int(epochs * 0.125), int(epochs * 0.250), int(epochs * 0.500), int(epochs * 0.750)]
    d_optimizer_gamma = 0.5
    g_optimizer_gamma = 0.5

    # Print the training log every one hundred iterations
    print_frequency = 1000

if mode == "valid":
    # Test data address
    lr_dir = f"data/Set14/LRbicx{upscale_factor}"
    sr_dir = f"results/test/{exp_name}"
    hr_dir = f"data/Set14/GTmod12"

    model_path = f"results/{exp_name}/g-last.pth"

## train_esrgan.py


In [7]:
def main():
    print("Load train dataset and valid dataset...")
    train_dataloader, valid_dataloader = load_dataset()
    print("Load train dataset and valid dataset successfully.")

    print("Build ESRGAN model...")
    discriminator, generator = build_model()
    print("Build ESRGAN model successfully.")

    print("Define all loss functions...")
    psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion = define_loss()
    print("Define all loss functions successfully.")

    print("Define all optimizer functions...")
    d_optimizer, g_optimizer = define_optimizer(discriminator, generator)
    print("Define all optimizer functions successfully.")

    print("Define all optimizer scheduler...")
    d_scheduler, g_scheduler = define_scheduler(d_optimizer, g_optimizer)
    print("Define all optimizer scheduler functions successfully.")

    print("Check whether the training weight is restored...")
    resume_checkpoint(discriminator, generator)
    print("Check whether the training weight is restored successfully.")

    # Create a folder of super-resolution experiment results
    samples_dir = os.path.join("samples", exp_name)
    results_dir = os.path.join("results", exp_name)
    if not os.path.exists(samples_dir):
        os.makedirs(samples_dir)
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)

    # Create training process log file
    writer = SummaryWriter(os.path.join("samples", "logs", exp_name))

    # Initialize the gradient scaler.
    scaler = amp.GradScaler()

    # Initialize training to generate network evaluation indicators
    best_psnr = 0.0

    print("Start train ESRGAN model.")
    for epoch in range(start_epoch, epochs):
        train(discriminator,
              generator,
              train_dataloader,
              psnr_criterion,
              pixel_criterion,
              content_criterion,
              adversarial_criterion,
              d_optimizer,
              g_optimizer,
              epoch,
              scaler,
              writer)

        psnr = validate(generator, valid_dataloader, psnr_criterion, epoch, writer)
        # Automatically save the model with the highest index
        is_best = psnr > best_psnr
        best_psnr = max(psnr, best_psnr)
        torch.save(discriminator.state_dict(), os.path.join(samples_dir, f"d_epoch_{epoch + 1}.pth"))
        torch.save(generator.state_dict(), os.path.join(samples_dir, f"g_epoch_{epoch + 1}.pth"))
        if is_best:
            torch.save(discriminator.state_dict(), os.path.join(results_dir, "d-best.pth"))
            torch.save(generator.state_dict(), os.path.join(results_dir, f"g-best.pth"))

        # Update LR
        d_scheduler.step()
        g_scheduler.step()

    # Save the generator weight under the last Epoch in this stage
    torch.save(discriminator.state_dict(), os.path.join(results_dir, "d-last.pth"))
    torch.save(generator.state_dict(), os.path.join(results_dir, "g-last.pth"))
    print("End train ESRGAN model.")


def load_dataset():
    # Initialize the LMDB data set class and write the contents of the LMDB database file into memory
    train_datasets = ImageDataset(train_image_dir, image_size, upscale_factor, "train")
    valid_datasets = ImageDataset(valid_image_dir, image_size, upscale_factor, "valid")
    # Make it into a data set type supported by PyTorch
    train_dataloader = DataLoader(train_datasets,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  pin_memory=True,
                                  persistent_workers=True)
    valid_dataloader = DataLoader(valid_datasets,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=num_workers,
                                  pin_memory=True,
                                  persistent_workers=True)

    return train_dataloader, valid_dataloader


def build_model() -> nn.Module:
    discriminator = Discriminator().to(device)
    generator = Generator().to(device)

    return discriminator, generator


def define_loss():
    psnr_criterion = nn.MSELoss().to(device)
    pixel_criterion = nn.L1Loss().to(device)
    content_criterion = ContentLoss().to(device)
    adversarial_criterion = nn.BCEWithLogitsLoss().to(device)

    return psnr_criterion, pixel_criterion, content_criterion, adversarial_criterion


def define_optimizer(discriminator: nn.Module, generator: nn.Module):
    d_optimizer = optim.Adam(discriminator.parameters(), d_model_lr, d_model_betas)
    g_optimizer = optim.Adam(generator.parameters(), g_model_lr, g_model_betas)

    return d_optimizer, g_optimizer


def define_scheduler(d_optimizer: optim.Adam, g_optimizer: optim.Adam):
    d_scheduler = lr_scheduler.MultiStepLR(d_optimizer, d_optimizer_milestones, d_optimizer_gamma)
    g_scheduler = lr_scheduler.MultiStepLR(g_optimizer, g_optimizer_milestones, g_optimizer_gamma)

    return d_scheduler, g_scheduler


def resume_checkpoint(discriminator: nn.Module, generator: nn.Module) -> None:
    if resume:
        if resume_d_weight != "":
            # Get pretrained model state dict
            pretrained_state_dict = torch.load(resume_d_weight)
            model_state_dict = discriminator.state_dict()
            # Extract the fitted model weights
            new_state_dict = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict.items()}
            # Overwrite the pretrained model weights to the current model
            model_state_dict.update(new_state_dict)
            discriminator.load_state_dict(model_state_dict, strict=strict)
        if resume_g_weight != "":
            # Get pretrained model state dict
            pretrained_state_dict = torch.load(resume_g_weight)
            model_state_dict = generator.state_dict()
            # Extract the fitted model weights
            new_state_dict = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict.items()}
            # Overwrite the pretrained model weights to the current model
            model_state_dict.update(new_state_dict)
            generator.load_state_dict(model_state_dict, strict=strict)


def train(discriminator,
          generator,
          train_dataloader,
          psnr_criterion,
          pixel_criterion,
          content_criterion,
          adversarial_criterion,
          d_optimizer,
          g_optimizer,
          epoch,
          scaler,
          writer) -> None:
    # Calculate how many iterations there are under epoch
    batches = len(train_dataloader)

    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    pixel_losses = AverageMeter("Pixel loss", ":6.6f")
    content_losses = AverageMeter("Content loss", ":6.6f")
    adversarial_losses = AverageMeter("Adversarial loss", ":6.6f")
    d_hr_probabilities = AverageMeter("D(HR)", ":6.3f")
    d_sr_probabilities = AverageMeter("D(SR)", ":6.3f")
    psnres = AverageMeter("PSNR", ":4.2f")
    progress = ProgressMeter(batches,
                             [batch_time, data_time,
                              pixel_losses, content_losses, adversarial_losses,
                              d_hr_probabilities, d_sr_probabilities,
                              psnres],
                             prefix=f"Epoch: [{epoch + 1}]")

    # Put all model in train mode.
    discriminator.train()
    generator.train()

    end = time.time()
    for index, (lr, hr) in enumerate(train_dataloader):
        # measure data loading time
        data_time.update(time.time() - end)

        # Send data to designated device
        lr = lr.to(device, non_blocking=True)
        hr = hr.to(device, non_blocking=True)

        # Set the real sample label to 1, and the false sample label to 0
        real_label = torch.full([lr.size(0), 1], 1.0, dtype=lr.dtype, device=device)
        fake_label = torch.full([lr.size(0), 1], 0.0, dtype=lr.dtype, device=device)

        # Use generators to create super-resolution images
        sr = generator(lr)

        # Start training discriminator
        # At this stage, the discriminator needs to require a derivative gradient
        for p in discriminator.parameters():
            p.requires_grad = True

        # Initialize the discriminator optimizer gradient
        d_optimizer.zero_grad()

        # Calculate the loss of the discriminator on the high-resolution image
        with amp.autocast():
            hr_output = discriminator(hr)
            sr_output = discriminator(sr.detach())
            d_loss_hr = adversarial_criterion(hr_output - torch.mean(sr_output), real_label)
            d_loss_sr = adversarial_criterion(sr_output - torch.mean(hr_output), fake_label)
        # Gradient zoom
        scaler.scale(d_loss_hr).backward(retain_graph=True)
        scaler.scale(d_loss_sr).backward()

        # Update gradient
        scaler.step(d_optimizer)
        scaler.update()

        # Count discriminator total loss
        d_loss = d_loss_hr + d_loss_sr
        # End training discriminator

        # Start training generator
        # At this stage, the discriminator no needs to require a derivative gradient
        for p in discriminator.parameters():
            p.requires_grad = False

        # Initialize the generator optimizer gradient
        g_optimizer.zero_grad()

        # Calculate the loss of the generator on the super-resolution image
        with amp.autocast():
            hr_output = discriminator(hr.detach())
            sr_output = discriminator(sr)
            pixel_loss = pixel_weight * pixel_criterion(sr, hr.detach())
            content_loss = content_weight * content_criterion(sr, hr.detach())
            adversarial_loss = adversarial_weight * adversarial_criterion(sr_output - torch.mean(hr_output), real_label)
        # Count discriminator total loss
        g_loss = pixel_loss + content_loss + adversarial_loss
        # Gradient zoom
        scaler.scale(g_loss).backward()
        # Update generator parameters
        scaler.step(g_optimizer)
        scaler.update()

        # End training generator

        # Calculate the scores of the two images on the discriminator
        d_hr_probability = torch.sigmoid(torch.mean(hr_output))
        d_sr_probability = torch.sigmoid(torch.mean(sr_output))

        # measure accuracy and record loss
        psnr = 10. * torch.log10(1. / psnr_criterion(sr, hr))
        pixel_losses.update(pixel_loss.item(), lr.size(0))
        content_losses.update(content_loss.item(), lr.size(0))
        adversarial_losses.update(adversarial_loss.item(), lr.size(0))
        d_hr_probabilities.update(d_hr_probability.item(), lr.size(0))
        d_sr_probabilities.update(d_sr_probability.item(), lr.size(0))
        psnres.update(psnr.item(), lr.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        iters = index + epoch * batches + 1
        writer.add_scalar("Train/D_Loss", d_loss.item(), iters)
        writer.add_scalar("Train/G_Loss", g_loss.item(), iters)
        writer.add_scalar("Train/Pixel_Loss", pixel_loss.item(), iters)
        writer.add_scalar("Train/Content_Loss", content_loss.item(), iters)
        writer.add_scalar("Train/Adversarial_Loss", adversarial_loss.item(), iters)
        writer.add_scalar("Train/D(HR)_Probability", d_hr_probability.item(), iters)
        writer.add_scalar("Train/D(SR)_Probability", d_sr_probability.item(), iters)
        if index % print_frequency == 0 and index != 0:
            progress.display(index)


def validate(model, valid_dataloader, psnr_criterion, epoch, writer) -> float:
    batch_time = AverageMeter("Time", ":6.3f")
    psnres = AverageMeter("PSNR", ":4.2f")
    progress = ProgressMeter(len(valid_dataloader), [batch_time, psnres], prefix="Valid: ")

    # Put the generator in verification mode.
    model.eval()

    with torch.no_grad():
        end = time.time()
        for index, (lr, hr) in enumerate(valid_dataloader):
            lr = lr.to(device, non_blocking=True)
            hr = hr.to(device, non_blocking=True)

            # Mixed precision
            with amp.autocast():
                sr = model(lr)

            # measure accuracy and record loss
            psnr = 10. * torch.log10(1. / psnr_criterion(sr, hr))
            psnres.update(psnr.item(), hr.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if index % print_frequency == 0:
                progress.display(index)

        writer.add_scalar("Valid/PSNR", psnres.avg, epoch + 1)
        # Print evaluation indicators.
        print(f"* PSNR: {psnres.avg:4.2f}.\n")

    return psnres.avg


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print("\t".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


if __name__ == "__main__":
    main()

Load train dataset and valid dataset...
Load train dataset and valid dataset successfully.
Build ESRGAN model...
Build ESRGAN model successfully.
Define all loss functions...


Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


  0%|          | 0.00/548M [00:00<?, ?B/s]

Define all loss functions successfully.
Define all optimizer functions...
Define all optimizer functions successfully.
Define all optimizer scheduler...
Define all optimizer scheduler functions successfully.
Check whether the training weight is restored...
Check whether the training weight is restored successfully.
Start train ESRGAN model.
Valid: [0/4]	Time  9.603 ( 9.603)	PSNR 16.52 (16.52)
* PSNR: 16.22.

Valid: [0/4]	Time  8.082 ( 8.082)	PSNR 17.14 (17.14)
* PSNR: 16.68.

Valid: [0/4]	Time  8.529 ( 8.529)	PSNR 17.70 (17.70)
* PSNR: 17.22.

Valid: [0/4]	Time  8.114 ( 8.114)	PSNR 17.78 (17.78)
* PSNR: 17.35.

Valid: [0/4]	Time  8.325 ( 8.325)	PSNR 19.50 (19.50)
* PSNR: 18.99.

Valid: [0/4]	Time  7.874 ( 7.874)	PSNR 20.10 (20.10)
* PSNR: 19.57.

Valid: [0/4]	Time  7.825 ( 7.825)	PSNR 20.75 (20.75)
* PSNR: 20.17.

Valid: [0/4]	Time  8.101 ( 8.101)	PSNR 19.47 (19.47)
* PSNR: 18.94.

Valid: [0/4]	Time  8.144 ( 8.144)	PSNR 21.12 (21.12)
* PSNR: 20.50.

Valid: [0/4]	Time  7.829 ( 7.829)	PS