From fc36502cc406f3b23b97c3910483d95ce232625a Mon Sep 17 00:00:00 2001
From: Xintao
Date: Wed, 6 Jul 2022 22:53:10 +0800
Subject: [PATCH] update loss registry

---
 basicsr/losses/__init__.py                  |  19 +-
 basicsr/losses/{losses.py => basic_loss.py} | 239 --------------------
 basicsr/losses/gan_loss.py                  | 208 +++++++++++++++++
 basicsr/models/__init__.py                  |   3 +-
 basicsr/models/stylegan2_model.py           |   2 +-
 tests/test_losses/test_losses.py            |   2 +-
 tests/test_models/test_sr_model.py          |   2 +-
 7 files changed, 224 insertions(+), 251 deletions(-)
 rename basicsr/losses/{losses.py => basic_loss.py} (51%)
 create mode 100644 basicsr/losses/gan_loss.py

diff --git a/basicsr/losses/__init__.py b/basicsr/losses/__init__.py
index b1570dd2d..70a172aee 100644
--- a/basicsr/losses/__init__.py
+++ b/basicsr/losses/__init__.py
@@ -1,14 +1,19 @@
+import importlib
 from copy import deepcopy
+from os import path as osp
 
-from basicsr.utils import get_root_logger
+from basicsr.utils import get_root_logger, scandir
 from basicsr.utils.registry import LOSS_REGISTRY
-from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
-                     gradient_penalty_loss, r1_penalty)
+from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty
 
-__all__ = [
-    'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
-    'r1_penalty', 'g_path_regularize'
-]
+__all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize']
+
+# automatically scan and import loss modules for registry
+# scan all the files under the 'losses' folder and collect files ending with '_loss.py'
+loss_folder = osp.dirname(osp.abspath(__file__))
+loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')]
+# import all the loss modules
+_loss_modules = [importlib.import_module(f'basicsr.losses.{file_name}') for file_name in loss_filenames]
 
 
 def build_loss(opt):
diff --git a/basicsr/losses/losses.py b/basicsr/losses/basic_loss.py
similarity index 51%
rename from basicsr/losses/losses.py
rename to basicsr/losses/basic_loss.py
index 55436902e..d2e965526 100644
--- a/basicsr/losses/losses.py
+++ b/basicsr/losses/basic_loss.py
@@ -1,6 +1,4 @@
-import math
 import torch
-from torch import autograd as autograd
 from torch import nn as nn
 from torch.nn import functional as F
 
@@ -253,240 +251,3 @@ def _gram_mat(self, x):
         features_t = features.transpose(1, 2)
         gram = features.bmm(features_t) / (c * h * w)
         return gram
-
-
-@LOSS_REGISTRY.register()
-class GANLoss(nn.Module):
-    """Define GAN loss.
-
-    Args:
-        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
-        real_label_val (float): The value for real label. Default: 1.0.
-        fake_label_val (float): The value for fake label. Default: 0.0.
-        loss_weight (float): Loss weight. Default: 1.0.
-            Note that loss_weight is only for generators; and it is always 1.0
-            for discriminators.
- """ - - def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): - super(GANLoss, self).__init__() - self.gan_type = gan_type - self.loss_weight = loss_weight - self.real_label_val = real_label_val - self.fake_label_val = fake_label_val - - if self.gan_type == 'vanilla': - self.loss = nn.BCEWithLogitsLoss() - elif self.gan_type == 'lsgan': - self.loss = nn.MSELoss() - elif self.gan_type == 'wgan': - self.loss = self._wgan_loss - elif self.gan_type == 'wgan_softplus': - self.loss = self._wgan_softplus_loss - elif self.gan_type == 'hinge': - self.loss = nn.ReLU() - else: - raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.') - - def _wgan_loss(self, input, target): - """wgan loss. - - Args: - input (Tensor): Input tensor. - target (bool): Target label. - - Returns: - Tensor: wgan loss. - """ - return -input.mean() if target else input.mean() - - def _wgan_softplus_loss(self, input, target): - """wgan loss with soft plus. softplus is a smooth approximation to the - ReLU function. - - In StyleGAN2, it is called: - Logistic loss for discriminator; - Non-saturating loss for generator. - - Args: - input (Tensor): Input tensor. - target (bool): Target label. - - Returns: - Tensor: wgan loss. - """ - return F.softplus(-input).mean() if target else F.softplus(input).mean() - - def get_target_label(self, input, target_is_real): - """Get target label. - - Args: - input (Tensor): Input tensor. - target_is_real (bool): Whether the target is real or fake. - - Returns: - (bool | Tensor): Target tensor. Return bool for wgan, otherwise, - return Tensor. - """ - - if self.gan_type in ['wgan', 'wgan_softplus']: - return target_is_real - target_val = (self.real_label_val if target_is_real else self.fake_label_val) - return input.new_ones(input.size()) * target_val - - def forward(self, input, target_is_real, is_disc=False): - """ - Args: - input (Tensor): The input for the loss module, i.e., the network - prediction. - target_is_real (bool): Whether the targe is real or fake. - is_disc (bool): Whether the loss for discriminators or not. - Default: False. - - Returns: - Tensor: GAN loss value. 
- """ - target_label = self.get_target_label(input, target_is_real) - if self.gan_type == 'hinge': - if is_disc: # for discriminators in hinge-gan - input = -input if target_is_real else input - loss = self.loss(1 + input).mean() - else: # for generators in hinge-gan - loss = -input.mean() - else: # other gan types - loss = self.loss(input, target_label) - - # loss_weight is always 1.0 for discriminators - return loss if is_disc else loss * self.loss_weight - - -@LOSS_REGISTRY.register() -class MultiScaleGANLoss(GANLoss): - """ - MultiScaleGANLoss accepts a list of predictions - """ - - def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): - super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight) - - def forward(self, input, target_is_real, is_disc=False): - """ - The input is a list of tensors, or a list of (a list of tensors) - """ - if isinstance(input, list): - loss = 0 - for pred_i in input: - if isinstance(pred_i, list): - # Only compute GAN loss for the last layer - # in case of multiscale feature matching - pred_i = pred_i[-1] - # Safe operation: 0-dim tensor calling self.mean() does nothing - loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean() - loss += loss_tensor - return loss / len(input) - else: - return super().forward(input, target_is_real, is_disc) - - -def r1_penalty(real_pred, real_img): - """R1 regularization for discriminator. The core idea is to - penalize the gradient on real data alone: when the - generator distribution produces the true data distribution - and the discriminator is equal to 0 on the data manifold, the - gradient penalty ensures that the discriminator cannot create - a non-zero gradient orthogonal to the data manifold without - suffering a loss in the GAN game. - - Ref: - Eq. 9 in Which training methods for GANs do actually converge. - """ - grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0] - grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() - return grad_penalty - - -def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): - noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3]) - grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0] - path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) - - path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) - - path_penalty = (path_lengths - path_mean).pow(2).mean() - - return path_penalty, path_lengths.detach().mean(), path_mean.detach() - - -def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None): - """Calculate gradient penalty for wgan-gp. - - Args: - discriminator (nn.Module): Network for the discriminator. - real_data (Tensor): Real input data. - fake_data (Tensor): Fake input data. - weight (Tensor): Weight tensor. Default: None. - - Returns: - Tensor: A tensor for gradient penalty. - """ - - batch_size = real_data.size(0) - alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1)) - - # interpolate between real_data and fake_data - interpolates = alpha * real_data + (1. 
-    interpolates = autograd.Variable(interpolates, requires_grad=True)
-
-    disc_interpolates = discriminator(interpolates)
-    gradients = autograd.grad(
-        outputs=disc_interpolates,
-        inputs=interpolates,
-        grad_outputs=torch.ones_like(disc_interpolates),
-        create_graph=True,
-        retain_graph=True,
-        only_inputs=True)[0]
-
-    if weight is not None:
-        gradients = gradients * weight
-
-    gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
-    if weight is not None:
-        gradients_penalty /= torch.mean(weight)
-
-    return gradients_penalty
-
-
-@LOSS_REGISTRY.register()
-class GANFeatLoss(nn.Module):
-    """Define feature matching loss for gans
-
-    Args:
-        criterion (str): Support 'l1', 'l2', 'charbonnier'.
-        loss_weight (float): Loss weight. Default: 1.0.
-        reduction (str): Specifies the reduction to apply to the output.
-            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
-    """
-
-    def __init__(self, criterion='l1', loss_weight=1.0, reduction='mean'):
-        super(GANFeatLoss, self).__init__()
-        if criterion == 'l1':
-            self.loss_op = L1Loss(loss_weight, reduction)
-        elif criterion == 'l2':
-            self.loss_op = MSELoss(loss_weight, reduction)
-        elif criterion == 'charbonnier':
-            self.loss_op = CharbonnierLoss(loss_weight, reduction)
-        else:
-            raise ValueError(f'Unsupported loss mode: {criterion}. Supported ones are: l1|l2|charbonnier')
-
-        self.loss_weight = loss_weight
-
-    def forward(self, pred_fake, pred_real):
-        num_d = len(pred_fake)
-        loss = 0
-        for i in range(num_d):  # for each discriminator
-            # last output is the final prediction, exclude it
-            num_intermediate_outputs = len(pred_fake[i]) - 1
-            for j in range(num_intermediate_outputs):  # for each layer output
-                unweighted_loss = self.loss_op(pred_fake[i][j], pred_real[i][j].detach())
-                loss += unweighted_loss / num_d
-        return loss * self.loss_weight
diff --git a/basicsr/losses/gan_loss.py b/basicsr/losses/gan_loss.py
new file mode 100644
index 000000000..447eb6bb2
--- /dev/null
+++ b/basicsr/losses/gan_loss.py
@@ -0,0 +1,208 @@
+import math
+import torch
+from torch import autograd as autograd
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import LOSS_REGISTRY
+
+
+@LOSS_REGISTRY.register()
+class GANLoss(nn.Module):
+    """Define GAN loss.
+
+    Args:
+        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'wgan_softplus', 'hinge'.
+        real_label_val (float): The value for real label. Default: 1.0.
+        fake_label_val (float): The value for fake label. Default: 0.0.
+        loss_weight (float): Loss weight. Default: 1.0.
+            Note that loss_weight is only for generators; and it is always 1.0
+            for discriminators.
+    """
+
+    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
+        super(GANLoss, self).__init__()
+        self.gan_type = gan_type
+        self.loss_weight = loss_weight
+        self.real_label_val = real_label_val
+        self.fake_label_val = fake_label_val
+
+        if self.gan_type == 'vanilla':
+            self.loss = nn.BCEWithLogitsLoss()
+        elif self.gan_type == 'lsgan':
+            self.loss = nn.MSELoss()
+        elif self.gan_type == 'wgan':
+            self.loss = self._wgan_loss
+        elif self.gan_type == 'wgan_softplus':
+            self.loss = self._wgan_softplus_loss
+        elif self.gan_type == 'hinge':
+            self.loss = nn.ReLU()
+        else:
+            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
+
+    def _wgan_loss(self, input, target):
+        """wgan loss.
+
+        Args:
+            input (Tensor): Input tensor.
+            target (bool): Target label.
+
+        Returns:
+            Tensor: wgan loss.
+ """ + return -input.mean() if target else input.mean() + + def _wgan_softplus_loss(self, input, target): + """wgan loss with soft plus. softplus is a smooth approximation to the + ReLU function. + + In StyleGAN2, it is called: + Logistic loss for discriminator; + Non-saturating loss for generator. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return F.softplus(-input).mean() if target else F.softplus(input).mean() + + def get_target_label(self, input, target_is_real): + """Get target label. + + Args: + input (Tensor): Input tensor. + target_is_real (bool): Whether the target is real or fake. + + Returns: + (bool | Tensor): Target tensor. Return bool for wgan, otherwise, + return Tensor. + """ + + if self.gan_type in ['wgan', 'wgan_softplus']: + return target_is_real + target_val = (self.real_label_val if target_is_real else self.fake_label_val) + return input.new_ones(input.size()) * target_val + + def forward(self, input, target_is_real, is_disc=False): + """ + Args: + input (Tensor): The input for the loss module, i.e., the network + prediction. + target_is_real (bool): Whether the targe is real or fake. + is_disc (bool): Whether the loss for discriminators or not. + Default: False. + + Returns: + Tensor: GAN loss value. + """ + target_label = self.get_target_label(input, target_is_real) + if self.gan_type == 'hinge': + if is_disc: # for discriminators in hinge-gan + input = -input if target_is_real else input + loss = self.loss(1 + input).mean() + else: # for generators in hinge-gan + loss = -input.mean() + else: # other gan types + loss = self.loss(input, target_label) + + # loss_weight is always 1.0 for discriminators + return loss if is_disc else loss * self.loss_weight + + +@LOSS_REGISTRY.register() +class MultiScaleGANLoss(GANLoss): + """ + MultiScaleGANLoss accepts a list of predictions + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight) + + def forward(self, input, target_is_real, is_disc=False): + """ + The input is a list of tensors, or a list of (a list of tensors) + """ + if isinstance(input, list): + loss = 0 + for pred_i in input: + if isinstance(pred_i, list): + # Only compute GAN loss for the last layer + # in case of multiscale feature matching + pred_i = pred_i[-1] + # Safe operation: 0-dim tensor calling self.mean() does nothing + loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean() + loss += loss_tensor + return loss / len(input) + else: + return super().forward(input, target_is_real, is_disc) + + +def r1_penalty(real_pred, real_img): + """R1 regularization for discriminator. The core idea is to + penalize the gradient on real data alone: when the + generator distribution produces the true data distribution + and the discriminator is equal to 0 on the data manifold, the + gradient penalty ensures that the discriminator cannot create + a non-zero gradient orthogonal to the data manifold without + suffering a loss in the GAN game. + + Ref: + Eq. 9 in Which training methods for GANs do actually converge. 
+ """ + grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0] + grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() + return grad_penalty + + +def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): + noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3]) + grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0] + path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) + + path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) + + path_penalty = (path_lengths - path_mean).pow(2).mean() + + return path_penalty, path_lengths.detach().mean(), path_mean.detach() + + +def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None): + """Calculate gradient penalty for wgan-gp. + + Args: + discriminator (nn.Module): Network for the discriminator. + real_data (Tensor): Real input data. + fake_data (Tensor): Fake input data. + weight (Tensor): Weight tensor. Default: None. + + Returns: + Tensor: A tensor for gradient penalty. + """ + + batch_size = real_data.size(0) + alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1)) + + # interpolate between real_data and fake_data + interpolates = alpha * real_data + (1. - alpha) * fake_data + interpolates = autograd.Variable(interpolates, requires_grad=True) + + disc_interpolates = discriminator(interpolates) + gradients = autograd.grad( + outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=torch.ones_like(disc_interpolates), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + + if weight is not None: + gradients = gradients * weight + + gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean() + if weight is not None: + gradients_penalty /= torch.mean(weight) + + return gradients_penalty diff --git a/basicsr/models/__init__.py b/basicsr/models/__init__.py index 285ce3ef9..85796deae 100644 --- a/basicsr/models/__init__.py +++ b/basicsr/models/__init__.py @@ -8,8 +8,7 @@ __all__ = ['build_model'] # automatically scan and import model modules for registry -# scan all the files under the 'models' folder and collect files ending with -# '_model.py' +# scan all the files under the 'models' folder and collect files ending with '_model.py' model_folder = osp.dirname(osp.abspath(__file__)) model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] # import all the model modules diff --git a/basicsr/models/stylegan2_model.py b/basicsr/models/stylegan2_model.py index b03844378..d7da70812 100644 --- a/basicsr/models/stylegan2_model.py +++ b/basicsr/models/stylegan2_model.py @@ -8,7 +8,7 @@ from basicsr.archs import build_network from basicsr.losses import build_loss -from basicsr.losses.losses import g_path_regularize, r1_penalty +from basicsr.losses.gan_loss import g_path_regularize, r1_penalty from basicsr.utils import imwrite, tensor2img from basicsr.utils.registry import MODEL_REGISTRY from .base_model import BaseModel diff --git a/tests/test_losses/test_losses.py b/tests/test_losses/test_losses.py index 38b03deaa..6047253a6 100644 --- a/tests/test_losses/test_losses.py +++ b/tests/test_losses/test_losses.py @@ -1,7 +1,7 @@ import pytest import torch -from basicsr.losses.losses import CharbonnierLoss, L1Loss, MSELoss, WeightedTVLoss +from basicsr.losses.basic_loss import CharbonnierLoss, L1Loss, MSELoss, WeightedTVLoss @pytest.mark.parametrize('loss_class', [L1Loss, MSELoss, CharbonnierLoss]) diff 
index 2583fe19f..2e95cda52 100644
--- a/tests/test_models/test_sr_model.py
+++ b/tests/test_models/test_sr_model.py
@@ -4,7 +4,7 @@
 from basicsr.archs.srresnet_arch import MSRResNet
 from basicsr.data.paired_image_dataset import PairedImageDataset
-from basicsr.losses.losses import L1Loss, PerceptualLoss
+from basicsr.losses.basic_loss import L1Loss, PerceptualLoss
 from basicsr.models.sr_model import SRModel
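
For anyone picking up this patch: the new scan in basicsr/losses/__init__.py means a loss only has to live in a file whose name ends with '_loss.py' and register itself; build_loss() can then construct it by name from an option dict. A minimal sketch, assuming a hypothetical file my_loss.py and class MyLoss:

    # basicsr/losses/my_loss.py  (hypothetical name; any '*_loss.py' file
    # under basicsr/losses/ is imported automatically on package import)
    import torch
    from torch import nn

    from basicsr.utils.registry import LOSS_REGISTRY


    @LOSS_REGISTRY.register()  # registers the class under the key 'MyLoss'
    class MyLoss(nn.Module):

        def __init__(self, loss_weight=1.0):
            super().__init__()
            self.loss_weight = loss_weight

        def forward(self, pred, target):
            return self.loss_weight * torch.mean(torch.abs(pred - target))

    # elsewhere, built by name from an option dict:
    # from basicsr.losses import build_loss
    # cri_pix = build_loss({'type': 'MyLoss', 'loss_weight': 0.5})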
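
The loss_weight note in GANLoss is easy to miss: the weight is applied only when is_disc=False. A small sketch against the new basicsr.losses.gan_loss module (the tensor shapes here are arbitrary stand-ins):

    import torch

    from basicsr.losses.gan_loss import GANLoss

    gan_loss = GANLoss(gan_type='vanilla', loss_weight=0.1)
    real_pred = torch.randn(4, 1)  # stand-in discriminator logits for real images
    fake_pred = torch.randn(4, 1)  # stand-in discriminator logits for fake images

    # generator side: is_disc=False, so the 0.1 loss_weight is applied
    l_g_gan = gan_loss(fake_pred, target_is_real=True, is_disc=False)

    # discriminator side: is_disc=True, so the weight is fixed at 1.0
    l_d_real = gan_loss(real_pred, target_is_real=True, is_disc=True)
    l_d_fake = gan_loss(fake_pred, target_is_real=False, is_disc=True)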
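
r1_penalty differentiates through inputs=real_img, so the call fails unless real_img carries requires_grad=True. A toy sketch; the one-layer discriminator is only a placeholder:

    import torch
    from torch import nn

    from basicsr.losses.gan_loss import r1_penalty

    disc = nn.Sequential(nn.Flatten(), nn.Linear(3 * 16 * 16, 1))  # placeholder
    real_img = torch.randn(4, 3, 16, 16, requires_grad=True)  # grad is required
    real_pred = disc(real_img)
    l_d_r1 = r1_penalty(real_pred, real_img)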
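
gradient_penalty_loss is self-contained: it draws the interpolation coefficients itself and returns the penalty that pulls the gradient norm on real/fake interpolations towards 1. A sketch along the same lines; the factor 100 is the conventional wgan-gp weight, an assumption rather than anything this patch fixes:

    import torch
    from torch import nn

    from basicsr.losses.gan_loss import gradient_penalty_loss

    disc = nn.Sequential(nn.Flatten(), nn.Linear(3 * 16 * 16, 1))  # placeholder
    real = torch.randn(8, 3, 16, 16)
    fake = torch.randn(8, 3, 16, 16)

    l_gp = gradient_penalty_loss(disc, real, fake)  # scalar penalty
    (100 * l_gp).backward()  # 100 = assumed wgan-gp weight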
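
g_path_regularize expects latents of shape (batch, num_latents, dim) with an autograd path to fake_img, and threads a running mean_path_length through successive calls. A contrived but runnable sketch; the linear "generator" is purely illustrative:

    import torch
    from torch import nn

    from basicsr.losses.gan_loss import g_path_regularize

    latents = torch.randn(2, 6, 16, requires_grad=True)  # (batch, num_latents, dim)
    to_img = nn.Linear(6 * 16, 3 * 8 * 8)  # illustrative generator stand-in
    fake_img = to_img(latents.flatten(1)).view(2, 3, 8, 8)

    mean_path_length = 0.0  # running statistic, carried across training iterations
    path_penalty, path_lengths_mean, mean_path_length = g_path_regularize(
        fake_img, latents, mean_path_length)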