update loss registry

xinntao committed Jul 6, 2022
1 parent 15510d4 commit fc36502
Showing 7 changed files with 224 additions and 251 deletions.
19 changes: 12 additions & 7 deletions basicsr/losses/__init__.py
@@ -1,14 +1,19 @@
+import importlib
 from copy import deepcopy
+from os import path as osp
 
-from basicsr.utils import get_root_logger
+from basicsr.utils import get_root_logger, scandir
 from basicsr.utils.registry import LOSS_REGISTRY
-from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
-                     gradient_penalty_loss, r1_penalty)
+from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty
 
-__all__ = [
-    'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
-    'r1_penalty', 'g_path_regularize'
-]
+__all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize']
+
+# automatically scan and import loss modules for registry
+# scan all the files under the 'losses' folder and collect files ending with '_loss.py'
+loss_folder = osp.dirname(osp.abspath(__file__))
+loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')]
+# import all the loss modules
+_model_modules = [importlib.import_module(f'basicsr.losses.{file_name}') for file_name in loss_filenames]
 
 
 def build_loss(opt):
…
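With the registry in place, adding a loss no longer requires editing this file: any class decorated with @LOSS_REGISTRY.register() in a `*_loss.py` file under basicsr/losses is scanned and imported automatically. A minimal usage sketch, assuming build_loss (collapsed above) follows the usual BasicSR pattern of popping the 'type' key and instantiating the registered class with the remaining options:

    from basicsr.losses import build_loss

    # 'type' selects the registered class; the remaining keys become constructor kwargs
    opt = {'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'}
    criterion = build_loss(opt)  # an nn.Module; call as criterion(pred, target)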
239 changes: 0 additions & 239 deletions basicsr/losses/losses.py → basicsr/losses/basic_loss.py
@@ -1,6 +1,4 @@
-import math
 import torch
-from torch import autograd as autograd
 from torch import nn as nn
 from torch.nn import functional as F
 
@@ -253,240 +251,3 @@ def _gram_mat(self, x):
         features_t = features.transpose(1, 2)
         gram = features.bmm(features_t) / (c * h * w)
         return gram


-@LOSS_REGISTRY.register()
-class GANLoss(nn.Module):
-    """Define GAN loss.
-
-    Args:
-        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
-        real_label_val (float): The value for real label. Default: 1.0.
-        fake_label_val (float): The value for fake label. Default: 0.0.
-        loss_weight (float): Loss weight. Default: 1.0.
-            Note that loss_weight is only for generators; and it is always 1.0
-            for discriminators.
-    """
-
-    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
-        super(GANLoss, self).__init__()
-        self.gan_type = gan_type
-        self.loss_weight = loss_weight
-        self.real_label_val = real_label_val
-        self.fake_label_val = fake_label_val
-
-        if self.gan_type == 'vanilla':
-            self.loss = nn.BCEWithLogitsLoss()
-        elif self.gan_type == 'lsgan':
-            self.loss = nn.MSELoss()
-        elif self.gan_type == 'wgan':
-            self.loss = self._wgan_loss
-        elif self.gan_type == 'wgan_softplus':
-            self.loss = self._wgan_softplus_loss
-        elif self.gan_type == 'hinge':
-            self.loss = nn.ReLU()
-        else:
-            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
-
-    def _wgan_loss(self, input, target):
-        """wgan loss.
-
-        Args:
-            input (Tensor): Input tensor.
-            target (bool): Target label.
-
-        Returns:
-            Tensor: wgan loss.
-        """
-        return -input.mean() if target else input.mean()
-
-    def _wgan_softplus_loss(self, input, target):
-        """wgan loss with soft plus. softplus is a smooth approximation to the
-        ReLU function.
-
-        In StyleGAN2, it is called:
-            Logistic loss for discriminator;
-            Non-saturating loss for generator.
-
-        Args:
-            input (Tensor): Input tensor.
-            target (bool): Target label.
-
-        Returns:
-            Tensor: wgan loss.
-        """
-        return F.softplus(-input).mean() if target else F.softplus(input).mean()
-
-    def get_target_label(self, input, target_is_real):
-        """Get target label.
-
-        Args:
-            input (Tensor): Input tensor.
-            target_is_real (bool): Whether the target is real or fake.
-
-        Returns:
-            (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
-            return Tensor.
-        """
-
-        if self.gan_type in ['wgan', 'wgan_softplus']:
-            return target_is_real
-        target_val = (self.real_label_val if target_is_real else self.fake_label_val)
-        return input.new_ones(input.size()) * target_val
-
-    def forward(self, input, target_is_real, is_disc=False):
-        """
-        Args:
-            input (Tensor): The input for the loss module, i.e., the network
-                prediction.
-            target_is_real (bool): Whether the target is real or fake.
-            is_disc (bool): Whether the loss is for discriminators or not.
-                Default: False.
-
-        Returns:
-            Tensor: GAN loss value.
-        """
-        target_label = self.get_target_label(input, target_is_real)
-        if self.gan_type == 'hinge':
-            if is_disc:  # for discriminators in hinge-gan
-                input = -input if target_is_real else input
-                loss = self.loss(1 + input).mean()
-            else:  # for generators in hinge-gan
-                loss = -input.mean()
-        else:  # other gan types
-            loss = self.loss(input, target_label)
-
-        # loss_weight is always 1.0 for discriminators
-        return loss if is_disc else loss * self.loss_weight
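This GANLoss block is deleted here; judging by the new imports in __init__.py, the GAN-related losses move to the new gan_loss.py. For reference, a short sketch of how the class is typically driven from a training loop; net_d is hypothetical and the logits are random stand-ins:

    import torch
    from basicsr.losses.gan_loss import GANLoss  # assumed new location of the class

    cri_gan = GANLoss(gan_type='vanilla', loss_weight=0.1)
    fake_logits = torch.randn(4, 1)  # stand-in for net_d(fake_img)
    real_logits = torch.randn(4, 1)  # stand-in for net_d(real_img)

    # generator step: push fake samples towards the real label; loss_weight applies
    l_g_gan = cri_gan(fake_logits, target_is_real=True, is_disc=False)

    # discriminator step: loss_weight is deliberately ignored (always 1.0)
    l_d_real = cri_gan(real_logits, target_is_real=True, is_disc=True)
    l_d_fake = cri_gan(fake_logits.detach(), target_is_real=False, is_disc=True)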


-@LOSS_REGISTRY.register()
-class MultiScaleGANLoss(GANLoss):
-    """
-    MultiScaleGANLoss accepts a list of predictions
-    """
-
-    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
-        super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight)
-
-    def forward(self, input, target_is_real, is_disc=False):
-        """
-        The input is a list of tensors, or a list of (a list of tensors)
-        """
-        if isinstance(input, list):
-            loss = 0
-            for pred_i in input:
-                if isinstance(pred_i, list):
-                    # Only compute GAN loss for the last layer
-                    # in case of multiscale feature matching
-                    pred_i = pred_i[-1]
-                # Safe operation: 0-dim tensor calling self.mean() does nothing
-                loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean()
-                loss += loss_tensor
-            return loss / len(input)
-        else:
-            return super().forward(input, target_is_real, is_disc)
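A sketch of the nested input this forward accepts, with made-up tensor sizes; each inner list is one discriminator scale whose last entry is the final prediction:

    import torch
    from basicsr.losses.gan_loss import MultiScaleGANLoss  # assumed new location

    cri_gan_ms = MultiScaleGANLoss(gan_type='lsgan')
    preds = [
        [torch.randn(4, 64, 32, 32), torch.randn(4, 1, 8, 8)],  # scale 1: features, prediction
        [torch.randn(4, 64, 16, 16), torch.randn(4, 1, 4, 4)],  # scale 2
    ]
    loss = cri_gan_ms(preds, target_is_real=True)  # scores each last entry, averages over scales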


-def r1_penalty(real_pred, real_img):
-    """R1 regularization for discriminator. The core idea is to
-    penalize the gradient on real data alone: when the
-    generator distribution produces the true data distribution
-    and the discriminator is equal to 0 on the data manifold, the
-    gradient penalty ensures that the discriminator cannot create
-    a non-zero gradient orthogonal to the data manifold without
-    suffering a loss in the GAN game.
-
-    Ref:
-    Eq. 9 in Which training methods for GANs do actually converge.
-    """
-    grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
-    grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
-    return grad_penalty
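The penalty differentiates the discriminator output with respect to its input, so real_img must carry requires_grad before the forward pass. A runnable toy sketch; the tiny linear critic is invented for illustration:

    import torch
    from torch import nn
    from basicsr.losses import r1_penalty

    net_d = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))  # toy critic
    real_img = torch.randn(4, 3, 8, 8).requires_grad_(True)  # grad path must be alive
    real_pred = net_d(real_img)
    l_d_r1 = r1_penalty(real_pred, real_img)
    # StyleGAN2 applies this lazily every few steps, scaled by gamma / 2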


-def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
-    noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
-    grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
-    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
-
-    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
-
-    path_penalty = (path_lengths - path_mean).pow(2).mean()
-
-    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
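This is StyleGAN2's path length regularization: the sum(2).mean(1) reduction assumes latents of shape (n, num_latents, latent_dim), and mean_path_length is a running value carried between regularization steps. A runnable toy sketch; the linear to_rgb mapping is invented for illustration:

    import torch
    from torch import nn
    from basicsr.losses import g_path_regularize

    n, num_latents, latent_dim = 4, 14, 64
    to_rgb = nn.Linear(latent_dim, 3 * 8 * 8)  # toy generator head
    latents = torch.randn(n, num_latents, latent_dim, requires_grad=True)
    fake_img = to_rgb(latents.mean(1)).view(n, 3, 8, 8)

    mean_path_length = 0.0  # carried across training iterations
    path_loss, path_lengths, mean_path_length = g_path_regularize(fake_img, latents, mean_path_length)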


-def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
-    """Calculate gradient penalty for wgan-gp.
-
-    Args:
-        discriminator (nn.Module): Network for the discriminator.
-        real_data (Tensor): Real input data.
-        fake_data (Tensor): Fake input data.
-        weight (Tensor): Weight tensor. Default: None.
-
-    Returns:
-        Tensor: A tensor for gradient penalty.
-    """
-
-    batch_size = real_data.size(0)
-    alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
-
-    # interpolate between real_data and fake_data
-    interpolates = alpha * real_data + (1. - alpha) * fake_data
-    interpolates = autograd.Variable(interpolates, requires_grad=True)
-
-    disc_interpolates = discriminator(interpolates)
-    gradients = autograd.grad(
-        outputs=disc_interpolates,
-        inputs=interpolates,
-        grad_outputs=torch.ones_like(disc_interpolates),
-        create_graph=True,
-        retain_graph=True,
-        only_inputs=True)[0]
-
-    if weight is not None:
-        gradients = gradients * weight
-
-    gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
-    if weight is not None:
-        gradients_penalty /= torch.mean(weight)
-
-    return gradients_penalty
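In a WGAN-GP critic step the penalty is added to the Wasserstein loss. A runnable toy sketch; the critic is invented, and the coefficient 10 is the value used in the WGAN-GP paper:

    import torch
    from torch import nn
    from basicsr.losses import gradient_penalty_loss

    net_d = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))  # toy critic
    real_data = torch.randn(4, 3, 8, 8)
    fake_data = torch.randn(4, 3, 8, 8)

    l_d = -net_d(real_data).mean() + net_d(fake_data).mean()  # Wasserstein critic loss
    l_gp = gradient_penalty_loss(net_d, real_data, fake_data)
    l_d_total = l_d + 10.0 * l_gp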


-@LOSS_REGISTRY.register()
-class GANFeatLoss(nn.Module):
-    """Define feature matching loss for GANs.
-
-    Args:
-        criterion (str): Support 'l1', 'l2', 'charbonnier'.
-        loss_weight (float): Loss weight. Default: 1.0.
-        reduction (str): Specifies the reduction to apply to the output.
-            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
-    """
-
-    def __init__(self, criterion='l1', loss_weight=1.0, reduction='mean'):
-        super(GANFeatLoss, self).__init__()
-        if criterion == 'l1':
-            self.loss_op = L1Loss(loss_weight, reduction)
-        elif criterion == 'l2':
-            self.loss_op = MSELoss(loss_weight, reduction)
-        elif criterion == 'charbonnier':
-            self.loss_op = CharbonnierLoss(loss_weight, reduction)
-        else:
-            raise ValueError(f'Unsupported loss mode: {criterion}. Supported ones are: l1|l2|charbonnier')
-
-        self.loss_weight = loss_weight
-
-    def forward(self, pred_fake, pred_real):
-        num_d = len(pred_fake)
-        loss = 0
-        for i in range(num_d):  # for each discriminator
-            # last output is the final prediction, exclude it
-            num_intermediate_outputs = len(pred_fake[i]) - 1
-            for j in range(num_intermediate_outputs):  # for each layer output
-                unweighted_loss = self.loss_op(pred_fake[i][j], pred_real[i][j].detach())
-                loss += unweighted_loss / num_d
-        return loss * self.loss_weight
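The nesting convention matches MultiScaleGANLoss: one list per discriminator, intermediate features first, final prediction last (the final prediction is excluded from the matching). Building through the registry avoids caring which file the class lands in, which is the point of this commit; the shapes below are made up:

    import torch
    from basicsr.losses import build_loss

    cri_feat = build_loss({'type': 'GANFeatLoss', 'criterion': 'l1', 'loss_weight': 10.0})
    # two discriminators, each returning [feat_1, feat_2, final_prediction]
    pred_fake = [[torch.randn(4, 64, 32, 32), torch.randn(4, 128, 16, 16), torch.randn(4, 1, 8, 8)]
                 for _ in range(2)]
    pred_real = [[torch.randn(4, 64, 32, 32), torch.randn(4, 128, 16, 16), torch.randn(4, 1, 8, 8)]
                 for _ in range(2)]
    l_feat = cri_feat(pred_fake, pred_real)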
