In [2]:
import os
import queue
import threading

import cv2
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset, DataLoader

import imgproc

In [None]:
import os
import time

import torch
from torch import nn
from torch import optim
from torch.cuda import amp
from torch.optim import lr_scheduler
from torch.optim.swa_utils import AveragedModel
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import esrgan_config
import model
from dataset import CUDAPrefetcher, TrainValidImageDataset, TestImageDataset
from image_quality_assessment import PSNR, SSIM
from utils import load_state_dict, make_directory, save_checkpoint, AverageMeter, ProgressMeter


In [3]:
class TrainValidImageDataset(Dataset):
    """Paired low-/high-resolution image dataset read from two directories.

    Samples are paired by position after sorting each directory listing, so the
    low-resolution and ground-truth files are expected to sort into the same
    order (for example, identical file names in both directories).

    Args:
        lr_image_dir (str): Directory containing the low-resolution images.
        hr_image_dir (str): Directory containing the ground-truth (high-resolution) images.

    Raises:
        ValueError: If the two directories contain a different number of entries.
    """

    def __init__(
            self,
            lr_image_dir: str,
            hr_image_dir: str,
    ) -> None:
        super(TrainValidImageDataset, self).__init__()
        # os.listdir() returns entries in arbitrary order; sort both listings so
        # that index i refers to the same image in both directories.
        self.lr_img_names = [os.path.join(lr_image_dir, image_file_name)
                             for image_file_name in sorted(os.listdir(lr_image_dir))]
        self.hr_img_names = [os.path.join(hr_image_dir, image_file_name)
                             for image_file_name in sorted(os.listdir(hr_image_dir))]

        if len(self.lr_img_names) != len(self.hr_img_names):
            raise ValueError(
                f"Number of LR images ({len(self.lr_img_names)}) in '{lr_image_dir}' does not match "
                f"number of HR images ({len(self.hr_img_names)}) in '{hr_image_dir}'."
            )

    @staticmethod
    def _read_rgb_image(image_path: str) -> np.ndarray:
        """Read an image file and return it as a float32 RGB array scaled to [0, 1]."""
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports a missing/undecodable file by returning None
            # instead of raising, so fail here with the offending path.
            raise OSError(f"Unable to read image file: {image_path}")
        image = image.astype(np.float32) / 255.
        # OpenCV loads images in BGR channel order; convert to RGB.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, batch_index: int) -> dict:
        """Load the LR/HR image pair stored at position ``batch_index``.

        Returns:
            dict: ``{"gt": <ground-truth tensor>, "lr": <low-resolution tensor>}``.

        Raises:
            OSError: If either image file cannot be read.
        """
        lr_image = self._read_rgb_image(self.lr_img_names[batch_index])
        hr_image = self._read_rgb_image(self.hr_img_names[batch_index])

        # Convert image data into Tensor stream format (PyTorch).
        # Note: the image values passed in are in the range [0, 1].
        # NOTE(review): the two False flags are presumably "no range normalisation"
        # and "no half precision" -- confirm against imgproc.image_to_tensor.
        gt_tensor = imgproc.image_to_tensor(hr_image, False, False)
        lr_tensor = imgproc.image_to_tensor(lr_image, False, False)

        return {"gt": gt_tensor, "lr": lr_tensor}

    def __len__(self) -> int:
        """Return the number of LR/HR image pairs."""
        return len(self.lr_img_names)

In [None]:
def load_dataset():
    """Build the training and test data pipelines.

    Creates the training and test datasets from the directories configured in
    ``esrgan_config``, wraps each in a ``DataLoader`` and then in a
    ``CUDAPrefetcher`` bound to ``esrgan_config.device``.

    Returns:
        tuple: ``(train_prefetcher, test_prefetcher)``.
    """
    # Load train and test datasets.
    # NOTE(review): the original call passed an undefined name (`esrgan`) as the
    # only argument. The attribute names below are assumed to exist in
    # esrgan_config -- confirm them (or add them) before running.
    train_datasets = TrainValidImageDataset(lr_image_dir=esrgan_config.train_lr_images_dir,
                                            hr_image_dir=esrgan_config.train_gt_images_dir)
    test_datasets = TestImageDataset(esrgan_config.test_gt_images_dir, esrgan_config.test_lr_images_dir)

    # Generate all data loaders.
    train_dataloader = DataLoader(train_datasets,
                                  batch_size=esrgan_config.batch_size,
                                  shuffle=True,
                                  num_workers=esrgan_config.num_workers,
                                  pin_memory=True,
                                  drop_last=True,
                                  # DataLoader raises ValueError if persistent_workers
                                  # is requested while num_workers == 0.
                                  persistent_workers=esrgan_config.num_workers > 0)
    test_dataloader = DataLoader(test_datasets,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=1,
                                 pin_memory=True,
                                 drop_last=False,
                                 persistent_workers=True)

    # Wrap each data loader in a prefetcher targeting the configured device.
    train_prefetcher = CUDAPrefetcher(train_dataloader, esrgan_config.device)
    test_prefetcher = CUDAPrefetcher(test_dataloader, esrgan_config.device)

    return train_prefetcher, test_prefetcher