diff --git a/basicsr/archs/discriminator_arch.py b/basicsr/archs/discriminator_arch.py index 870acec30..33f9a8f1b 100644 --- a/basicsr/archs/discriminator_arch.py +++ b/basicsr/archs/discriminator_arch.py @@ -1,4 +1,6 @@ from torch import nn as nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm from basicsr.utils.registry import ARCH_REGISTRY @@ -83,3 +85,66 @@ def forward(self, x): feat = self.lrelu(self.linear1(feat)) out = self.linear2(feat) return out + + +@ARCH_REGISTRY.register(suffix='basicsr') +class UNetDiscriminatorSN(nn.Module): + """Defines a U-Net discriminator with spectral normalization (SN) + + It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + Arg: + num_in_ch (int): Channel number of inputs. Default: 3. + num_feat (int): Channel number of base intermediate features. Default: 64. + skip_connection (bool): Whether to use skip connections between U-Net. Default: True. + """ + + def __init__(self, num_in_ch, num_feat=64, skip_connection=True): + super(UNetDiscriminatorSN, self).__init__() + self.skip_connection = skip_connection + norm = spectral_norm + # the first convolution + self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1) + # downsample + self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) + self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) + self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) + # upsample + self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) + self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)) + self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)) + # extra convolutions + self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) + self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) + self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1) + + def forward(self, x): + # downsample + x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True) + x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True) + x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True) + x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True) + + # upsample + x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False) + x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x4 = x4 + x2 + x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) + x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x5 = x5 + x1 + x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False) + x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True) + + if self.skip_connection: + x6 = x6 + x0 + + # extra convolutions + out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True) + out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True) + out = self.conv9(out) + + return out diff --git a/basicsr/archs/srvgg_arch.py b/basicsr/archs/srvgg_arch.py new file mode 100644 index 000000000..d8fe5ceb4 --- /dev/null +++ b/basicsr/archs/srvgg_arch.py @@ -0,0 +1,70 @@ +from torch import nn as nn +from torch.nn import functional as F + +from basicsr.utils.registry import ARCH_REGISTRY + + +@ARCH_REGISTRY.register(suffix='basicsr') +class SRVGGNetCompact(nn.Module): + """A compact VGG-style network structure for super-resolution. + + It is a compact network structure, which performs upsampling in the last layer and no convolution is + conducted on the HR feature space. + + Args: + num_in_ch (int): Channel number of inputs. Default: 3. + num_out_ch (int): Channel number of outputs. Default: 3. + num_feat (int): Channel number of intermediate features. Default: 64. + num_conv (int): Number of convolution layers in the body network. Default: 16. + upscale (int): Upsampling factor. Default: 4. + act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu. + """ + + def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): + super(SRVGGNetCompact, self).__init__() + self.num_in_ch = num_in_ch + self.num_out_ch = num_out_ch + self.num_feat = num_feat + self.num_conv = num_conv + self.upscale = upscale + self.act_type = act_type + + self.body = nn.ModuleList() + # the first conv + self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) + # the first activation + if act_type == 'relu': + activation = nn.ReLU(inplace=True) + elif act_type == 'prelu': + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == 'leakyrelu': + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the body structure + for _ in range(num_conv): + self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) + # activation + if act_type == 'relu': + activation = nn.ReLU(inplace=True) + elif act_type == 'prelu': + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == 'leakyrelu': + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the last conv + self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) + # upsample + self.upsampler = nn.PixelShuffle(upscale) + + def forward(self, x): + out = x + for i in range(0, len(self.body)): + out = self.body[i](out) + + out = self.upsampler(out) + # add the nearest upsampled image, so that the network learns the residual + base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') + out += base + return out diff --git a/basicsr/data/realesrgan_dataset.py b/basicsr/data/realesrgan_dataset.py new file mode 100644 index 000000000..1616e9b91 --- /dev/null +++ b/basicsr/data/realesrgan_dataset.py @@ -0,0 +1,193 @@ +import cv2 +import math +import numpy as np +import os +import os.path as osp +import random +import time +import torch +from torch.utils import data as data + +from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels +from basicsr.data.transforms import augment +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils.registry import DATASET_REGISTRY + + +@DATASET_REGISTRY.register(suffix='basicsr') +class RealESRGANDataset(data.Dataset): + """Dataset used for Real-ESRGAN model: + Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It loads gt (Ground-Truth) images, and augments them. + It also generates blur kernels and sinc kernels for generating low-quality images. + Note that the low-quality images are processed in tensors on GPUS for faster processing. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + meta_info (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). + Please see more options in the codes. + """ + + def __init__(self, opt): + super(RealESRGANDataset, self).__init__() + self.opt = opt + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.gt_folder = opt['dataroot_gt'] + + # file client (lmdb io backend) + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.gt_folder] + self.io_backend_opt['client_keys'] = ['gt'] + if not self.gt_folder.endswith('.lmdb'): + raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") + with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: + self.paths = [line.split('.')[0] for line in fin] + else: + # disk backend with meta_info + # Each line in the meta_info describes the relative path to an image + with open(self.opt['meta_info']) as fin: + paths = [line.strip().split(' ')[0] for line in fin] + self.paths = [os.path.join(self.gt_folder, v) for v in paths] + + # blur settings for the first degradation + self.blur_kernel_size = opt['blur_kernel_size'] + self.kernel_list = opt['kernel_list'] + self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability + self.blur_sigma = opt['blur_sigma'] + self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels + self.betap_range = opt['betap_range'] # betap used in plateau blur kernels + self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters + + # blur settings for the second degradation + self.blur_kernel_size2 = opt['blur_kernel_size2'] + self.kernel_list2 = opt['kernel_list2'] + self.kernel_prob2 = opt['kernel_prob2'] + self.blur_sigma2 = opt['blur_sigma2'] + self.betag_range2 = opt['betag_range2'] + self.betap_range2 = opt['betap_range2'] + self.sinc_prob2 = opt['sinc_prob2'] + + # a final sinc filter + self.final_sinc_prob = opt['final_sinc_prob'] + + self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21 + # TODO: kernel range is now hard-coded, should be in the configure file + self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect + self.pulse_tensor[10, 10] = 1 + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # -------------------------------- Load gt images -------------------------------- # + # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32. + gt_path = self.paths[index] + # avoid errors caused by high latency in reading files + retry = 3 + while retry > 0: + try: + img_bytes = self.file_client.get(gt_path, 'gt') + except (IOError, OSError) as e: + logger = get_root_logger() + logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}') + # change another file to read + index = random.randint(0, self.__len__()) + gt_path = self.paths[index] + time.sleep(1) # sleep 1s for occasional server congestion + else: + break + finally: + retry -= 1 + img_gt = imfrombytes(img_bytes, float32=True) + + # -------------------- Do augmentation for training: flip, rotation -------------------- # + img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot']) + + # crop or pad to 400 + # TODO: 400 is hard-coded. You may change it accordingly + h, w = img_gt.shape[0:2] + crop_pad_size = 400 + # pad + if h < crop_pad_size or w < crop_pad_size: + pad_h = max(0, crop_pad_size - h) + pad_w = max(0, crop_pad_size - w) + img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101) + # crop + if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size: + h, w = img_gt.shape[0:2] + # randomly choose top and left coordinates + top = random.randint(0, h - crop_pad_size) + left = random.randint(0, w - crop_pad_size) + img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...] + + # ------------------------ Generate kernels (used in the first degradation) ------------------------ # + kernel_size = random.choice(self.kernel_range) + if np.random.uniform() < self.opt['sinc_prob']: + # this sinc filter setting is for kernels ranging from [7, 21] + if kernel_size < 13: + omega_c = np.random.uniform(np.pi / 3, np.pi) + else: + omega_c = np.random.uniform(np.pi / 5, np.pi) + kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) + else: + kernel = random_mixed_kernels( + self.kernel_list, + self.kernel_prob, + kernel_size, + self.blur_sigma, + self.blur_sigma, [-math.pi, math.pi], + self.betag_range, + self.betap_range, + noise_range=None) + # pad kernel + pad_size = (21 - kernel_size) // 2 + kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) + + # ------------------------ Generate kernels (used in the second degradation) ------------------------ # + kernel_size = random.choice(self.kernel_range) + if np.random.uniform() < self.opt['sinc_prob2']: + if kernel_size < 13: + omega_c = np.random.uniform(np.pi / 3, np.pi) + else: + omega_c = np.random.uniform(np.pi / 5, np.pi) + kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) + else: + kernel2 = random_mixed_kernels( + self.kernel_list2, + self.kernel_prob2, + kernel_size, + self.blur_sigma2, + self.blur_sigma2, [-math.pi, math.pi], + self.betag_range2, + self.betap_range2, + noise_range=None) + + # pad kernel + pad_size = (21 - kernel_size) // 2 + kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size))) + + # ------------------------------------- the final sinc kernel ------------------------------------- # + if np.random.uniform() < self.opt['final_sinc_prob']: + kernel_size = random.choice(self.kernel_range) + omega_c = np.random.uniform(np.pi / 3, np.pi) + sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21) + sinc_kernel = torch.FloatTensor(sinc_kernel) + else: + sinc_kernel = self.pulse_tensor + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0] + kernel = torch.FloatTensor(kernel) + kernel2 = torch.FloatTensor(kernel2) + + return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path} + return return_d + + def __len__(self): + return len(self.paths) diff --git a/basicsr/data/realesrgan_paired_dataset.py b/basicsr/data/realesrgan_paired_dataset.py new file mode 100644 index 000000000..e014f8311 --- /dev/null +++ b/basicsr/data/realesrgan_paired_dataset.py @@ -0,0 +1,109 @@ +import os +from torch.utils import data as data +from torchvision.transforms.functional import normalize + +from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb +from basicsr.data.transforms import augment, paired_random_crop +from basicsr.utils import FileClient, imfrombytes, img2tensor +from basicsr.utils.registry import DATASET_REGISTRY + + +@DATASET_REGISTRY.register(suffix='basicsr') +class RealESRGANPairedDataset(data.Dataset): + """Paired image dataset for image restoration. + + Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. + + There are three modes: + 1. 'lmdb': Use lmdb files. + If opt['io_backend'] == lmdb. + 2. 'meta_info': Use meta information file to generate paths. + If opt['io_backend'] != lmdb and opt['meta_info'] is not None. + 3. 'folder': Scan folders to generate paths. + The rest. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + dataroot_lq (str): Data root path for lq. + meta_info (str): Path for meta information file. + io_backend (dict): IO backend type and other kwarg. + filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. + Default: '{}'. + gt_size (int): Cropped patched size for gt patches. + use_hflip (bool): Use horizontal flips. + use_rot (bool): Use rotation (use vertical flip and transposing h + and w for implementation). + + scale (bool): Scale, which will be added automatically. + phase (str): 'train' or 'val'. + """ + + def __init__(self, opt): + super(RealESRGANPairedDataset, self).__init__() + self.opt = opt + self.file_client = None + self.io_backend_opt = opt['io_backend'] + # mean and std for normalizing the input images + self.mean = opt['mean'] if 'mean' in opt else None + self.std = opt['std'] if 'std' in opt else None + + self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] + self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}' + + # file client (lmdb io backend) + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] + self.io_backend_opt['client_keys'] = ['lq', 'gt'] + self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) + elif 'meta_info' in self.opt and self.opt['meta_info'] is not None: + # disk backend with meta_info + # Each line in the meta_info describes the relative path to an image + with open(self.opt['meta_info']) as fin: + paths = [line.strip() for line in fin] + self.paths = [] + for path in paths: + gt_path, lq_path = path.split(', ') + gt_path = os.path.join(self.gt_folder, gt_path) + lq_path = os.path.join(self.lq_folder, lq_path) + self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)])) + else: + # disk backend + # it will scan the whole folder to get meta info + # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file + self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + scale = self.opt['scale'] + + # Load gt and lq images. Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32. + gt_path = self.paths[index]['gt_path'] + img_bytes = self.file_client.get(gt_path, 'gt') + img_gt = imfrombytes(img_bytes, float32=True) + lq_path = self.paths[index]['lq_path'] + img_bytes = self.file_client.get(lq_path, 'lq') + img_lq = imfrombytes(img_bytes, float32=True) + + # augmentation for training + if self.opt['phase'] == 'train': + gt_size = self.opt['gt_size'] + # random crop + img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) + # flip, rotation + img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) + # normalize + if self.mean is not None or self.std is not None: + normalize(img_lq, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) + + return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} + + def __len__(self): + return len(self.paths) diff --git a/basicsr/models/realesrgan_model.py b/basicsr/models/realesrgan_model.py new file mode 100644 index 000000000..eb0ec1ca7 --- /dev/null +++ b/basicsr/models/realesrgan_model.py @@ -0,0 +1,259 @@ +import numpy as np +import random +import torch +from collections import OrderedDict +from torch.nn import functional as F + +from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt +from basicsr.data.transforms import paired_random_crop +from basicsr.models.srgan_model import SRGANModel +from basicsr.utils import DiffJPEG, USMSharp +from basicsr.utils.img_process_util import filter2D +from basicsr.utils.registry import MODEL_REGISTRY + + +@MODEL_REGISTRY.register(suffix='basicsr') +class RealESRGANModel(SRGANModel): + """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It mainly performs: + 1. randomly synthesize LQ images in GPU tensors + 2. optimize the networks with GAN training. + """ + + def __init__(self, opt): + super(RealESRGANModel, self).__init__(opt) + self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts + self.usm_sharpener = USMSharp().cuda() # do usm sharpening + self.queue_size = opt.get('queue_size', 180) + + @torch.no_grad() + def _dequeue_and_enqueue(self): + """It is the training pair pool for increasing the diversity in a batch. + + Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a + batch could not have different resize scaling factors. Therefore, we employ this training pair pool + to increase the degradation diversity in a batch. + """ + # initialize + b, c, h, w = self.lq.size() + if not hasattr(self, 'queue_lr'): + assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}' + self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() + _, c, h, w = self.gt.size() + self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() + self.queue_ptr = 0 + if self.queue_ptr == self.queue_size: # the pool is full + # do dequeue and enqueue + # shuffle + idx = torch.randperm(self.queue_size) + self.queue_lr = self.queue_lr[idx] + self.queue_gt = self.queue_gt[idx] + # get first b samples + lq_dequeue = self.queue_lr[0:b, :, :, :].clone() + gt_dequeue = self.queue_gt[0:b, :, :, :].clone() + # update the queue + self.queue_lr[0:b, :, :, :] = self.lq.clone() + self.queue_gt[0:b, :, :, :] = self.gt.clone() + + self.lq = lq_dequeue + self.gt = gt_dequeue + else: + # only do enqueue + self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone() + self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone() + self.queue_ptr = self.queue_ptr + b + + @torch.no_grad() + def feed_data(self, data): + """Accept data from dataloader, and then add two-order degradations to obtain LQ images. + """ + if self.is_train and self.opt.get('high_order_degradation', True): + # training data synthesis + self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_sharpener(self.gt) + + self.kernel1 = data['kernel1'].to(self.device) + self.kernel2 = data['kernel2'].to(self.device) + self.sinc_kernel = data['sinc_kernel'].to(self.device) + + ori_h, ori_w = self.gt.size()[2:4] + + # ----------------------- The first degradation process ----------------------- # + # blur + out = filter2D(self.gt_usm, self.kernel1) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, scale_factor=scale, mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob'] + if np.random.uniform() < self.opt['gaussian_noise_prob']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) + out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts + out = self.jpeger(out, quality=jpeg_p) + + # ----------------------- The second degradation process ----------------------- # + # blur + if np.random.uniform() < self.opt['second_blur_prob']: + out = filter2D(out, self.kernel2) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range2'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range2'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate( + out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob2'] + if np.random.uniform() < self.opt['gaussian_noise_prob2']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range2'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + + # JPEG compression + the final sinc filter + # We also need to resize images to desired sizes. We group [resize back + sinc filter] together + # as one operation. + # We consider two orders: + # 1. [resize back + sinc filter] + JPEG compression + # 2. JPEG compression + [resize back + sinc filter] + # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. + if np.random.uniform() < 0.5: + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + else: + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + + # clamp and round + self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. + + # random crop + gt_size = self.opt['gt_size'] + (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size, + self.opt['scale']) + + # training pair pool + self._dequeue_and_enqueue() + # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue + self.gt_usm = self.usm_sharpener(self.gt) + self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract + else: + # for paired training or validation + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_sharpener(self.gt) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + # do not use the synthetic process during validation + self.is_train = False + super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img) + self.is_train = True + + def optimize_parameters(self, current_iter): + # usm sharpening + l1_gt = self.gt_usm + percep_gt = self.gt_usm + gan_gt = self.gt_usm + if self.opt['l1_gt_usm'] is False: + l1_gt = self.gt + if self.opt['percep_gt_usm'] is False: + percep_gt = self.gt + if self.opt['gan_gt_usm'] is False: + gan_gt = self.gt + + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, l1_gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + # gan loss + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + real_d_pred = self.net_d(gan_gt) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9 + l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + l_d_fake.backward() + self.optimizer_d.step() + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + self.log_dict = self.reduce_loss_dict(loss_dict) diff --git a/basicsr/models/realesrnet_model.py b/basicsr/models/realesrnet_model.py new file mode 100644 index 000000000..f5790918b --- /dev/null +++ b/basicsr/models/realesrnet_model.py @@ -0,0 +1,189 @@ +import numpy as np +import random +import torch +from torch.nn import functional as F + +from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt +from basicsr.data.transforms import paired_random_crop +from basicsr.models.sr_model import SRModel +from basicsr.utils import DiffJPEG, USMSharp +from basicsr.utils.img_process_util import filter2D +from basicsr.utils.registry import MODEL_REGISTRY + + +@MODEL_REGISTRY.register(suffix='basicsr') +class RealESRNetModel(SRModel): + """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. + + It is trained without GAN losses. + It mainly performs: + 1. randomly synthesize LQ images in GPU tensors + 2. optimize the networks with GAN training. + """ + + def __init__(self, opt): + super(RealESRNetModel, self).__init__(opt) + self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts + self.usm_sharpener = USMSharp().cuda() # do usm sharpening + self.queue_size = opt.get('queue_size', 180) + + @torch.no_grad() + def _dequeue_and_enqueue(self): + """It is the training pair pool for increasing the diversity in a batch. + + Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a + batch could not have different resize scaling factors. Therefore, we employ this training pair pool + to increase the degradation diversity in a batch. + """ + # initialize + b, c, h, w = self.lq.size() + if not hasattr(self, 'queue_lr'): + assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}' + self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() + _, c, h, w = self.gt.size() + self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() + self.queue_ptr = 0 + if self.queue_ptr == self.queue_size: # the pool is full + # do dequeue and enqueue + # shuffle + idx = torch.randperm(self.queue_size) + self.queue_lr = self.queue_lr[idx] + self.queue_gt = self.queue_gt[idx] + # get first b samples + lq_dequeue = self.queue_lr[0:b, :, :, :].clone() + gt_dequeue = self.queue_gt[0:b, :, :, :].clone() + # update the queue + self.queue_lr[0:b, :, :, :] = self.lq.clone() + self.queue_gt[0:b, :, :, :] = self.gt.clone() + + self.lq = lq_dequeue + self.gt = gt_dequeue + else: + # only do enqueue + self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone() + self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone() + self.queue_ptr = self.queue_ptr + b + + @torch.no_grad() + def feed_data(self, data): + """Accept data from dataloader, and then add two-order degradations to obtain LQ images. + """ + if self.is_train and self.opt.get('high_order_degradation', True): + # training data synthesis + self.gt = data['gt'].to(self.device) + # USM sharpen the GT images + if self.opt['gt_usm'] is True: + self.gt = self.usm_sharpener(self.gt) + + self.kernel1 = data['kernel1'].to(self.device) + self.kernel2 = data['kernel2'].to(self.device) + self.sinc_kernel = data['sinc_kernel'].to(self.device) + + ori_h, ori_w = self.gt.size()[2:4] + + # ----------------------- The first degradation process ----------------------- # + # blur + out = filter2D(self.gt, self.kernel1) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, scale_factor=scale, mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob'] + if np.random.uniform() < self.opt['gaussian_noise_prob']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) + out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts + out = self.jpeger(out, quality=jpeg_p) + + # ----------------------- The second degradation process ----------------------- # + # blur + if np.random.uniform() < self.opt['second_blur_prob']: + out = filter2D(out, self.kernel2) + # random resize + updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] + if updown_type == 'up': + scale = np.random.uniform(1, self.opt['resize_range2'][1]) + elif updown_type == 'down': + scale = np.random.uniform(self.opt['resize_range2'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate( + out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) + # add noise + gray_noise_prob = self.opt['gray_noise_prob2'] + if np.random.uniform() < self.opt['gaussian_noise_prob2']: + out = random_add_gaussian_noise_pt( + out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.opt['poisson_scale_range2'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + + # JPEG compression + the final sinc filter + # We also need to resize images to desired sizes. We group [resize back + sinc filter] together + # as one operation. + # We consider two orders: + # 1. [resize back + sinc filter] + JPEG compression + # 2. JPEG compression + [resize back + sinc filter] + # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. + if np.random.uniform() < 0.5: + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + else: + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) + out = filter2D(out, self.sinc_kernel) + + # clamp and round + self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. + + # random crop + gt_size = self.opt['gt_size'] + self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale']) + + # training pair pool + self._dequeue_and_enqueue() + self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract + else: + # for paired training or validation + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + self.gt_usm = self.usm_sharpener(self.gt) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + # do not use the synthetic process during validation + self.is_train = False + super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img) + self.is_train = True diff --git a/options/train/RealESRGAN/train_realesrgan_x2plus.yml b/options/train/RealESRGAN/train_realesrgan_x2plus.yml new file mode 100644 index 000000000..3c98a0f37 --- /dev/null +++ b/options/train/RealESRGAN/train_realesrgan_x2plus.yml @@ -0,0 +1,186 @@ +# general settings +name: train_RealESRGANx2plus_400k_B12G4 +model_type: RealESRGANModel +scale: 2 +num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs +manual_seed: 0 + +# ----------------- options for synthesizing training data in RealESRGANModel ----------------- # +# USM the ground-truth +l1_gt_usm: True +percep_gt_usm: True +gan_gt_usm: False + +# the first degradation process +resize_prob: [0.2, 0.7, 0.1] # up, down, keep +resize_range: [0.15, 1.5] +gaussian_noise_prob: 0.5 +noise_range: [1, 30] +poisson_scale_range: [0.05, 3] +gray_noise_prob: 0.4 +jpeg_range: [30, 95] + +# the second degradation process +second_blur_prob: 0.8 +resize_prob2: [0.3, 0.4, 0.3] # up, down, keep +resize_range2: [0.3, 1.2] +gaussian_noise_prob2: 0.5 +noise_range2: [1, 25] +poisson_scale_range2: [0.05, 2.5] +gray_noise_prob2: 0.4 +jpeg_range2: [30, 95] + +gt_size: 256 +queue_size: 180 + +# dataset and data loader settings +datasets: + train: + name: DF2K+OST + type: RealESRGANDataset + dataroot_gt: datasets/DF2K + meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt + io_backend: + type: disk + + blur_kernel_size: 21 + kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob: 0.1 + blur_sigma: [0.2, 3] + betag_range: [0.5, 4] + betap_range: [1, 2] + + blur_kernel_size2: 21 + kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob2: 0.1 + blur_sigma2: [0.2, 1.5] + betag_range2: [0.5, 4] + betap_range2: [1, 2] + + final_sinc_prob: 0.8 + + gt_size: 256 + use_hflip: True + use_rot: False + + # data loader + use_shuffle: true + num_worker_per_gpu: 5 + batch_size_per_gpu: 12 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + # Uncomment these for validation + # val: + # name: validation + # type: PairedImageDataset + # dataroot_gt: path_to_gt + # dataroot_lq: path_to_lq + # io_backend: + # type: disk + +# network structures +network_g: + type: RRDBNet + num_in_ch: 3 + num_out_ch: 3 + num_feat: 64 + num_block: 23 + num_grow_ch: 32 + scale: 2 + +network_d: + type: UNetDiscriminatorSN + num_in_ch: 3 + num_feat: 64 + skip_connection: True + +# path +path: + # use the pre-trained Real-ESRNet model + pretrain_network_g: experiments/pretrained_models/RealESRNet_x2plus.pth + param_key_g: params_ema + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + optim_d: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [400000] + gamma: 0.5 + + total_iter: 400000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + # perceptual loss (content and style losses) + perceptual_opt: + type: PerceptualLoss + layer_weights: + # before relu + 'conv1_2': 0.1 + 'conv2_2': 0.1 + 'conv3_4': 1 + 'conv4_4': 1 + 'conv5_4': 1 + vgg_type: vgg19 + use_input_norm: true + perceptual_weight: !!float 1.0 + style_weight: 0 + range_norm: false + criterion: l1 + # gan loss + gan_opt: + type: GANLoss + gan_type: vanilla + real_label_val: 1.0 + fake_label_val: 0.0 + loss_weight: !!float 1e-1 + + net_d_iters: 1 + net_d_init_iters: 0 + +# Uncomment these for validation +# validation settings +# val: +# val_freq: !!float 5e3 +# save_img: True + +# metrics: +# psnr: # metric name +# type: calculate_psnr +# crop_border: 4 +# test_y_channel: false + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/RealESRGAN/train_realesrgan_x4plus.yml b/options/train/RealESRGAN/train_realesrgan_x4plus.yml new file mode 100644 index 000000000..763199a35 --- /dev/null +++ b/options/train/RealESRGAN/train_realesrgan_x4plus.yml @@ -0,0 +1,185 @@ +# general settings +name: train_RealESRGANx4plus_400k_B12G4 +model_type: RealESRGANModel +scale: 4 +num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs +manual_seed: 0 + +# ----------------- options for synthesizing training data in RealESRGANModel ----------------- # +# USM the ground-truth +l1_gt_usm: True +percep_gt_usm: True +gan_gt_usm: False + +# the first degradation process +resize_prob: [0.2, 0.7, 0.1] # up, down, keep +resize_range: [0.15, 1.5] +gaussian_noise_prob: 0.5 +noise_range: [1, 30] +poisson_scale_range: [0.05, 3] +gray_noise_prob: 0.4 +jpeg_range: [30, 95] + +# the second degradation process +second_blur_prob: 0.8 +resize_prob2: [0.3, 0.4, 0.3] # up, down, keep +resize_range2: [0.3, 1.2] +gaussian_noise_prob2: 0.5 +noise_range2: [1, 25] +poisson_scale_range2: [0.05, 2.5] +gray_noise_prob2: 0.4 +jpeg_range2: [30, 95] + +gt_size: 256 +queue_size: 180 + +# dataset and data loader settings +datasets: + train: + name: DF2K+OST + type: RealESRGANDataset + dataroot_gt: datasets/DF2K + meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt + io_backend: + type: disk + + blur_kernel_size: 21 + kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob: 0.1 + blur_sigma: [0.2, 3] + betag_range: [0.5, 4] + betap_range: [1, 2] + + blur_kernel_size2: 21 + kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob2: 0.1 + blur_sigma2: [0.2, 1.5] + betag_range2: [0.5, 4] + betap_range2: [1, 2] + + final_sinc_prob: 0.8 + + gt_size: 256 + use_hflip: True + use_rot: False + + # data loader + use_shuffle: true + num_worker_per_gpu: 5 + batch_size_per_gpu: 12 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + # Uncomment these for validation + # val: + # name: validation + # type: PairedImageDataset + # dataroot_gt: path_to_gt + # dataroot_lq: path_to_lq + # io_backend: + # type: disk + +# network structures +network_g: + type: RRDBNet + num_in_ch: 3 + num_out_ch: 3 + num_feat: 64 + num_block: 23 + num_grow_ch: 32 + +network_d: + type: UNetDiscriminatorSN + num_in_ch: 3 + num_feat: 64 + skip_connection: True + +# path +path: + # use the pre-trained Real-ESRNet model + pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth + param_key_g: params_ema + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + optim_d: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [400000] + gamma: 0.5 + + total_iter: 400000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + # perceptual loss (content and style losses) + perceptual_opt: + type: PerceptualLoss + layer_weights: + # before relu + 'conv1_2': 0.1 + 'conv2_2': 0.1 + 'conv3_4': 1 + 'conv4_4': 1 + 'conv5_4': 1 + vgg_type: vgg19 + use_input_norm: true + perceptual_weight: !!float 1.0 + style_weight: 0 + range_norm: false + criterion: l1 + # gan loss + gan_opt: + type: GANLoss + gan_type: vanilla + real_label_val: 1.0 + fake_label_val: 0.0 + loss_weight: !!float 1e-1 + + net_d_iters: 1 + net_d_init_iters: 0 + +# Uncomment these for validation +# validation settings +# val: +# val_freq: !!float 5e3 +# save_img: True + +# metrics: +# psnr: # metric name +# type: calculate_psnr +# crop_border: 4 +# test_y_channel: false + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/RealESRGAN/train_realesrnet_x2plus.yml b/options/train/RealESRGAN/train_realesrnet_x2plus.yml new file mode 100644 index 000000000..59b30737e --- /dev/null +++ b/options/train/RealESRGAN/train_realesrnet_x2plus.yml @@ -0,0 +1,145 @@ +# general settings +name: train_RealESRNetx2plus_1000k_B12G4 +model_type: RealESRNetModel +scale: 2 +num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs +manual_seed: 0 + +# ----------------- options for synthesizing training data in RealESRNetModel ----------------- # +gt_usm: True # USM the ground-truth + +# the first degradation process +resize_prob: [0.2, 0.7, 0.1] # up, down, keep +resize_range: [0.15, 1.5] +gaussian_noise_prob: 0.5 +noise_range: [1, 30] +poisson_scale_range: [0.05, 3] +gray_noise_prob: 0.4 +jpeg_range: [30, 95] + +# the second degradation process +second_blur_prob: 0.8 +resize_prob2: [0.3, 0.4, 0.3] # up, down, keep +resize_range2: [0.3, 1.2] +gaussian_noise_prob2: 0.5 +noise_range2: [1, 25] +poisson_scale_range2: [0.05, 2.5] +gray_noise_prob2: 0.4 +jpeg_range2: [30, 95] + +gt_size: 256 +queue_size: 180 + +# dataset and data loader settings +datasets: + train: + name: DF2K+OST + type: RealESRGANDataset + dataroot_gt: datasets/DF2K + meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt + io_backend: + type: disk + + blur_kernel_size: 21 + kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob: 0.1 + blur_sigma: [0.2, 3] + betag_range: [0.5, 4] + betap_range: [1, 2] + + blur_kernel_size2: 21 + kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob2: 0.1 + blur_sigma2: [0.2, 1.5] + betag_range2: [0.5, 4] + betap_range2: [1, 2] + + final_sinc_prob: 0.8 + + gt_size: 256 + use_hflip: True + use_rot: False + + # data loader + use_shuffle: true + num_worker_per_gpu: 5 + batch_size_per_gpu: 12 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + # Uncomment these for validation + # val: + # name: validation + # type: PairedImageDataset + # dataroot_gt: path_to_gt + # dataroot_lq: path_to_lq + # io_backend: + # type: disk + +# network structures +network_g: + type: RRDBNet + num_in_ch: 3 + num_out_ch: 3 + num_feat: 64 + num_block: 23 + num_grow_ch: 32 + scale: 2 + +# path +path: + pretrain_network_g: experiments/pretrained_models/ESRGAN/RealESRGAN_x4plus.pth + param_key_g: params_ema + strict_load_g: False + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 2e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [1000000] + gamma: 0.5 + + total_iter: 1000000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + +# Uncomment these for validation +# validation settings +# val: +# val_freq: !!float 5e3 +# save_img: True + +# metrics: +# psnr: # metric name +# type: calculate_psnr +# crop_border: 4 +# test_y_channel: false + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/RealESRGAN/train_realesrnet_x4plus.yml b/options/train/RealESRGAN/train_realesrnet_x4plus.yml new file mode 100644 index 000000000..ae093705c --- /dev/null +++ b/options/train/RealESRGAN/train_realesrnet_x4plus.yml @@ -0,0 +1,144 @@ +# general settings +name: train_RealESRNetx4plus_1000k_B12G4 +model_type: RealESRNetModel +scale: 4 +num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs +manual_seed: 0 + +# ----------------- options for synthesizing training data in RealESRNetModel ----------------- # +gt_usm: True # USM the ground-truth + +# the first degradation process +resize_prob: [0.2, 0.7, 0.1] # up, down, keep +resize_range: [0.15, 1.5] +gaussian_noise_prob: 0.5 +noise_range: [1, 30] +poisson_scale_range: [0.05, 3] +gray_noise_prob: 0.4 +jpeg_range: [30, 95] + +# the second degradation process +second_blur_prob: 0.8 +resize_prob2: [0.3, 0.4, 0.3] # up, down, keep +resize_range2: [0.3, 1.2] +gaussian_noise_prob2: 0.5 +noise_range2: [1, 25] +poisson_scale_range2: [0.05, 2.5] +gray_noise_prob2: 0.4 +jpeg_range2: [30, 95] + +gt_size: 256 +queue_size: 180 + +# dataset and data loader settings +datasets: + train: + name: DF2K+OST + type: RealESRGANDataset + dataroot_gt: datasets/DF2K + meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt + io_backend: + type: disk + + blur_kernel_size: 21 + kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob: 0.1 + blur_sigma: [0.2, 3] + betag_range: [0.5, 4] + betap_range: [1, 2] + + blur_kernel_size2: 21 + kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] + kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] + sinc_prob2: 0.1 + blur_sigma2: [0.2, 1.5] + betag_range2: [0.5, 4] + betap_range2: [1, 2] + + final_sinc_prob: 0.8 + + gt_size: 256 + use_hflip: True + use_rot: False + + # data loader + use_shuffle: true + num_worker_per_gpu: 5 + batch_size_per_gpu: 12 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + # Uncomment these for validation + # val: + # name: validation + # type: PairedImageDataset + # dataroot_gt: path_to_gt + # dataroot_lq: path_to_lq + # io_backend: + # type: disk + +# network structures +network_g: + type: RRDBNet + num_in_ch: 3 + num_out_ch: 3 + num_feat: 64 + num_block: 23 + num_grow_ch: 32 + +# path +path: + pretrain_network_g: experiments/pretrained_models/ESRGAN/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth + param_key_g: params_ema + strict_load_g: true + resume_state: ~ + +# training settings +train: + ema_decay: 0.999 + optim_g: + type: Adam + lr: !!float 2e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [1000000] + gamma: 0.5 + + total_iter: 1000000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: L1Loss + loss_weight: 1.0 + reduction: mean + +# Uncomment these for validation +# validation settings +# val: +# val_freq: !!float 5e3 +# save_img: True + +# metrics: +# psnr: # metric name +# type: calculate_psnr +# crop_border: 4 +# test_y_channel: false + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500