diff --git a/celldetection/ops/loss.py b/celldetection/ops/loss.py
index 5f3a3e4..3df084d 100644
--- a/celldetection/ops/loss.py
+++ b/celldetection/ops/loss.py
@@ -2,10 +2,10 @@
 import torch.nn.functional as F
 from torch import Tensor
 
-__all__ = ['reduce_loss', 'log_margin_loss', 'margin_loss']
+__all__ = ['reduce_loss', 'log_margin_loss', 'margin_loss', 'r1_regularization']
 
 
-def reduce_loss(x: Tensor, reduction: str):
+def reduce_loss(x: Tensor, reduction: str, **kwargs):
     """Reduce loss.
 
     Reduces Tensor according to ``reduction``.
@@ -13,6 +13,7 @@ def reduce_loss(x: Tensor, reduction: str):
     Args:
         x: Input.
         reduction: Reduction method. Must be a symbol of ``torch``.
+        **kwargs: Additional keyword arguments, passed to the ``torch`` reduction function.
 
     Returns:
         Reduced Tensor.
@@ -21,8 +22,8 @@
         return x
     fn = getattr(torch, reduction, None)
     if fn is None:
-        raise ValueError
-    return fn(x)
+        raise ValueError(f'Unknown reduction: {reduction}')
+    return fn(x, **kwargs)
 
 
 def log_margin_loss(inputs: Tensor, targets: Tensor, m_pos=.9, m_neg=None, exponent=1, reduction='mean', eps=1e-6):
@@ -43,3 +44,41 @@ def margin_loss(inputs: Tensor, targets: Tensor, m_pos=.9, m_neg=None, exponent=
     neg = torch.pow(F.relu_(inputs - m_neg), exponent)
     loss = targets * pos + (1 - targets) * neg
     return reduce_loss(loss, reduction)
+
+
+def r1_regularization(logits, inputs, gamma=1., reduction='sum'):
+    r"""R1 regularization.
+
+    A gradient penalty regularization.
+    This regularization may for example be applied to a discriminator with real data:
+
+    .. math::
+
+        R_1(\psi) := \frac{\gamma}{2} \mathbb E_{p_{\mathcal D}(x)} \left[\|\nabla D_\psi(x)\|^2\right]
+
+    References:
+        - https://arxiv.org/pdf/1801.04406.pdf (Eq. 9)
+        - https://arxiv.org/pdf/1705.09367.pdf
+        - https://arxiv.org/pdf/1711.09404.pdf
+
+    Examples:
+        >>> real.requires_grad_(True)
+        ... real_logits = discriminator(real)
+        ... loss_d_real = F.softplus(-real_logits)
+        ... loss_d_r1 = r1_regularization(real_logits, real)
+        ... loss_d_real = (loss_d_r1 + loss_d_real).mean()
+        ... loss_d_real.backward()
+        ... real.requires_grad_(False)
+
+    Args:
+        logits: Logits, e.g. as returned by a discriminator.
+        inputs: Inputs from which ``logits`` were computed. Must have ``requires_grad=True``.
+        gamma: Regularization strength :math:`\gamma`.
+        reduction: How to reduce all non-batch dimensions. E.g. ``'sum'`` or ``'mean'``.
+
+    Returns:
+        Penalty Tensor[n], one value per batch element.
+    """
+    grads = torch.autograd.grad(logits.sum(), inputs=inputs, create_graph=True, retain_graph=True, only_inputs=True)[0]
+    penalty = reduce_loss(grads.square(), reduction, dim=list(range(1, grads.ndim)))
+    return penalty * (gamma * .5)
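
A quick sketch of the extended reduce_loss (assuming it is imported from celldetection.ops.loss, per the file path above); additional keyword arguments such as dim are forwarded verbatim to the matching torch reduction:

    import torch
    from celldetection.ops.loss import reduce_loss

    x = torch.rand(4, 3)
    s = reduce_loss(x, 'mean')        # scalar: torch.mean(x)
    v = reduce_loss(x, 'sum', dim=1)  # Tensor[4]: torch.sum(x, dim=1)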
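
And a self-contained version of the docstring example; the discriminator and the real batch below are stand-ins for illustration only. Note that r1_regularization already includes the gamma/2 factor of Eq. 9, so no extra scaling is needed:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from celldetection.ops.loss import r1_regularization

    discriminator = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))  # stand-in discriminator
    real = torch.rand(4, 3, 8, 8)  # stand-in for a batch of real images

    real.requires_grad_(True)  # input gradients are required for the penalty
    real_logits = discriminator(real)
    loss_d_real = F.softplus(-real_logits).squeeze(1)  # Tensor[4]: non-saturating loss on real data
    loss_d_r1 = r1_regularization(real_logits, real)   # Tensor[4]: one penalty per sample
    loss_d = (loss_d_real + loss_d_r1).mean()
    loss_d.backward()  # create_graph inside r1_regularization makes this second-order step work
    real.requires_grad_(False)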