From bdeac7522e074b9106d01d1447057a5249bfedbd Mon Sep 17 00:00:00 2001
From: robinbg
Date: Wed, 6 Oct 2021 23:32:51 +0800
Subject: [PATCH] Add AdaptiveLogSoftmaxWithLoss

---
 python/paddle/nn/__init__.py       |   1 +
 python/paddle/nn/layer/__init__.py |   1 +
 python/paddle/nn/layer/loss.py     | 264 ++++++++++++++++++++++++++++-
 3 files changed, 264 insertions(+), 2 deletions(-)

diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index 98444e69d0b1b..b9c0b76b321a8 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -99,6 +99,7 @@
 from .layer.loss import MarginRankingLoss  # noqa: F401
 from .layer.loss import CTCLoss  # noqa: F401
 from .layer.loss import SmoothL1Loss  # noqa: F401
+from .layer.loss import AdaptiveLogSoftmaxWithLoss  # noqa: F401
 from .layer.norm import BatchNorm  # noqa: F401
 from .layer.norm import SyncBatchNorm  # noqa: F401
 from .layer.norm import GroupNorm  # noqa: F401
diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py
index 074dfac5108f9..6d2308ddd86c7 100644
--- a/python/paddle/nn/layer/__init__.py
+++ b/python/paddle/nn/layer/__init__.py
@@ -71,6 +71,7 @@
 from .loss import MarginRankingLoss  # noqa: F401
 from .loss import CTCLoss  # noqa: F401
 from .loss import SmoothL1Loss  # noqa: F401
+from .loss import AdaptiveLogSoftmaxWithLoss  # noqa: F401
 from .norm import BatchNorm1D  # noqa: F401
 from .norm import BatchNorm2D  # noqa: F401
 from .norm import BatchNorm3D  # noqa: F401
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index 781e13867f243..8626d44173e4f 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -14,16 +14,21 @@
 # limitations under the License.
 
 # TODO: define loss functions of neural network
+from collections import namedtuple
+from typing import List, Sequence
+
 import numpy as np
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 import paddle
 from .. import functional as F
 from paddle.fluid.framework import core, in_dygraph_mode, _varbase_creator
-from .. import Layer
+from .. import Layer, Sequential, LayerList
+from .common import Linear
+from paddle import Tensor
 
 __all__ = []
-
+_ASMoutput = namedtuple('_ASMoutput', ['output', 'loss'])
 
 class BCEWithLogitsLoss(Layer):
     r"""
@@ -1207,3 +1212,258 @@ def forward(self, input, label):
             reduction=self.reduction,
             delta=self.delta,
             name=self.name)
+
+
+class AdaptiveLogSoftmaxWithLoss(Layer):
+    r"""Efficient softmax approximation as described in
+    `Efficient softmax approximation for GPUs <https://arxiv.org/abs/1609.04309>`__
+    by Edouard Grave, Armand Joulin, Moustapha Cissé, David Grangier, and
+    Hervé Jégou.
+
+    Adaptive softmax is an approximate strategy for training models with large
+    output spaces. It is most effective when the label distribution is highly
+    imbalanced, for example in natural language modelling, where the word
+    frequency distribution approximately follows `Zipf's law`_.
+
+    Adaptive softmax partitions the labels into several clusters, according to
+    their frequency. These clusters may each contain a different number of
+    targets. Additionally, clusters containing less frequent labels assign
+    lower-dimensional embeddings to those labels, which speeds up the
+    computation. For each minibatch, only the clusters for which at least one
+    target is present are evaluated.
+
+    The idea is that the clusters which are accessed frequently (like the first
+    one, containing the most frequent labels) should also be cheap to compute,
+    that is, contain a small number of assigned labels.
+
+    We highly recommend taking a look at the original paper for more details.
+
+    * :attr:`cutoffs` should be an ordered Sequence of integers sorted in
+      increasing order. It controls the number of clusters and the partitioning
+      of targets into clusters. For example, setting ``cutoffs = [10, 100, 1000]``
+      means that the first `10` targets will be assigned to the 'head' of the
+      adaptive softmax, targets `11, 12, ..., 100` will be assigned to the first
+      cluster, targets `101, 102, ..., 1000` will be assigned to the second
+      cluster, and targets `1001, 1002, ..., n_classes - 1` will be assigned to
+      the last, third cluster.
+
+    * :attr:`div_value` is used to compute the size of each additional cluster,
+      which is given as
+      :math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`,
+      where :math:`idx` is the cluster index (with clusters for less frequent
+      words having larger indices, and indices starting from :math:`1`).
+
+    * :attr:`head_bias`, if set to ``True``, adds a bias term to the 'head' of
+      the adaptive softmax. See the paper for details. Set to ``False`` in the
+      official implementation.
+
+    .. warning::
+        Labels passed as inputs to this module should be sorted according to
+        their frequency. This means that the most frequent label should be
+        represented by the index `0`, and the least frequent label should be
+        represented by the index `n_classes - 1`.
+
+    .. note::
+        This module returns a ``NamedTuple`` with ``output``
+        and ``loss`` fields. See further documentation for details.
+
+    .. note::
+        To compute log-probabilities for all classes, the ``log_prob``
+        method can be used.
+
+    Args:
+        in_features (int): Number of features in the input tensor.
+        n_classes (int): Number of classes in the dataset.
+        cutoffs (Sequence): Cutoffs used to assign targets to their buckets.
+        div_value (float, optional): Value used as an exponent to compute sizes
+            of the clusters. Default: 4.0
+        head_bias (bool, optional): If ``True``, adds a bias term to the 'head'
+            of the adaptive softmax. Default: ``False``
+
+    Returns:
+        ``NamedTuple`` with ``output`` and ``loss`` fields:
+
+        * **output** is a Tensor of size ``N`` containing the computed target
+          log-probabilities for each example
+        * **loss** is a scalar Tensor representing the computed negative
+          log-likelihood loss
+
+    Shape:
+        - input: :math:`(N, \texttt{in\_features})`
+        - target: :math:`(N)` where each value satisfies
+          :math:`0 <= \texttt{target[i]} <= \texttt{n\_classes} - 1`
+        - output1: :math:`(N)`
+        - output2: ``Scalar``
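+
+    Examples:
+        A minimal usage sketch; the sizes below are illustrative only, and the
+        labels are drawn at random rather than sorted by frequency:
+
+        .. code-block:: python
+
+            import paddle
+            import paddle.nn as nn
+
+            # 64-dim features, 1000 classes: classes 0-99 form the shortlist,
+            # classes 100-499 the first tail cluster, 500-999 the second.
+            # With div_value=4.0 the tail projections map 64 -> 16 and
+            # 64 -> 4 dimensions before producing the per-cluster logits.
+            asfm = nn.AdaptiveLogSoftmaxWithLoss(in_features=64,
+                                                 n_classes=1000,
+                                                 cutoffs=[100, 500])
+
+            input = paddle.randn([128, 64])
+            target = paddle.randint(low=0, high=1000, shape=[128])
+
+            out, loss = asfm(input, target)   # out: [128], loss: scalar
+            loss.backward()
+
+            all_logprob = asfm.log_prob(input)   # [128, 1000]
+            predictions = asfm.predict(input)    # [128]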
+
+    .. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law
+    """
+
+    in_features: int
+    n_classes: int
+    cutoffs: List[int]
+    div_value: float
+    head_bias: bool
+    head: Linear
+    tail: LayerList
+
+    def __init__(self,
+                 in_features: int,
+                 n_classes: int,
+                 cutoffs: Sequence[int],
+                 div_value: float = 4.,
+                 head_bias: bool = False) -> None:
+        super(AdaptiveLogSoftmaxWithLoss, self).__init__()
+
+        cutoffs = list(cutoffs)
+
+        if (cutoffs != sorted(cutoffs)) \
+                or (min(cutoffs) <= 0) \
+                or (max(cutoffs) > (n_classes - 1)) \
+                or (len(set(cutoffs)) != len(cutoffs)) \
+                or any(int(c) != c for c in cutoffs):
+            raise ValueError("cutoffs should be a sequence of unique, positive "
+                             "integers sorted in an increasing order, where "
+                             "each value is between 1 and n_classes - 1")
+
+        self.in_features = in_features
+        self.n_classes = n_classes
+        self.cutoffs = cutoffs + [n_classes]
+        self.div_value = div_value
+        self.head_bias = head_bias
+
+        self.shortlist_size = self.cutoffs[0]
+        self.n_clusters = len(self.cutoffs) - 1
+        # the head scores the shortlist classes plus one entry per tail cluster
+        self.head_size = self.shortlist_size + self.n_clusters
+
+        # paddle.nn.Linear controls the bias through `bias_attr`:
+        # None creates a default bias parameter, False disables it
+        self.head = Linear(self.in_features, self.head_size,
+                           bias_attr=None if self.head_bias else False)
+        self.tail = LayerList()
+
+        for i in range(self.n_clusters):
+            # each tail cluster uses a low-rank projection: the hidden size
+            # shrinks by a factor of div_value for every additional cluster
+            hsz = int(self.in_features // (self.div_value ** (i + 1)))
+            osz = self.cutoffs[i + 1] - self.cutoffs[i]
+
+            projection = Sequential(
+                Linear(self.in_features, hsz, bias_attr=False),
+                Linear(hsz, osz, bias_attr=False),
+            )
+
+            self.tail.append(projection)
+
+    def forward(self, input: Tensor, target: Tensor) -> _ASMoutput:
+        if input.shape[0] != target.shape[0]:
+            raise RuntimeError('Input and target should have the same size '
+                               'in the batch dimension.')
+
+        used_rows = 0
+        batch_size = target.shape[0]
+
+        # per-example target log-probability, and the head column to gather
+        # (the target itself for shortlist rows, its cluster token otherwise)
+        output = paddle.zeros([batch_size], dtype=input.dtype)
+        gather_inds = paddle.zeros([batch_size], dtype='int64')
+
+        cutoff_values = [0] + self.cutoffs
+        for i in range(len(cutoff_values) - 1):
+            low_idx = cutoff_values[i]
+            high_idx = cutoff_values[i + 1]
+
+            target_mask = paddle.logical_and(target >= low_idx,
+                                             target < high_idx)
+            row_indices = paddle.nonzero(target_mask).squeeze(axis=1)
+
+            if row_indices.shape[0] == 0:
+                continue
+
+            target_subset = paddle.gather(target, row_indices).astype('int64')
+
+            if i == 0:
+                # shortlist targets are scored directly by the head
+                gather_inds = paddle.scatter(gather_inds, row_indices,
+                                             target_subset)
+            else:
+                relative_target = target_subset - low_idx
+                input_subset = paddle.index_select(input, row_indices, axis=0)
+
+                cluster_output = self.tail[i - 1](input_subset)
+                cluster_index = self.shortlist_size + i - 1
+
+                # these rows gather the log-probability of their cluster token
+                # from the head, plus the within-cluster log-probability below
+                gather_inds = paddle.scatter(
+                    gather_inds, row_indices,
+                    paddle.full_like(row_indices, cluster_index))
+
+                cluster_logprob = F.log_softmax(cluster_output, axis=1)
+                local_logprob = paddle.index_sample(
+                    cluster_logprob, relative_target.unsqueeze(1)).squeeze(axis=1)
+                output = paddle.scatter(output, row_indices, local_logprob)
+
+            used_rows += row_indices.shape[0]
+
+        if used_rows != batch_size:
+            raise RuntimeError("Target values should be in [0, {}], "
+                               "but values in range [{}, {}] "
+                               "were found. ".format(self.n_classes - 1,
+                                                     target.min().item(),
+                                                     target.max().item()))
+
+        head_output = self.head(input)
+        head_logprob = F.log_softmax(head_output, axis=1)
+        output = output + paddle.index_sample(
+            head_logprob, gather_inds.unsqueeze(1)).squeeze(axis=1)
+        loss = (-output).mean()
+
+        return _ASMoutput(output, loss)
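+
+    # Layout of the full distribution (assembled by `_get_full_log_prob`):
+    # columns [0, shortlist_size) come straight from the head, and for each
+    # tail cluster `j` the columns [cutoffs[j], cutoffs[j + 1]) are the head's
+    # log-probability of cluster token `shortlist_size + j` plus the
+    # within-cluster log-probabilities produced by `self.tail[j]`.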
".format(self.n_classes - 1, + target.min().item(), + target.max().item())) + + head_output = self.head(input) + head_logprob = F.log_softmax(head_output, axis=1) + output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze() + loss = (-output).mean() + + return _ASMoutput(output, loss) + + def _get_full_log_prob(self, input, head_output): + """ Given input tensor, and output of `self.head`, + compute the log of the full distribution """ + + out = paddle.empty((head_output.shape[0], self.n_classes)) + head_logprob = F.log_softmax(head_output, axis=1) + + out[:, :self.shortlist_size] = head_logprob[:, :self.shortlist_size] + + for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])): + cluster_output = self.tail[i](input) + cluster_logprob = F.log_softmax(cluster_output, axis=1) + output_logprob = cluster_logprob + head_logprob[:, self.shortlist_size + i].unsqueeze(1) + + out[:, start_idx:stop_idx] = output_logprob + + return out + + def log_prob(self, input: Tensor) -> Tensor: + r""" Computes log probabilities for all :math:`\texttt{n\_classes}` + Args: + input (Tensor): a minibatch of examples + Returns: + log-probabilities of for each class :math:`c` + in range :math:`0 <= c <= \texttt{n\_classes}`, where :math:`\texttt{n\_classes}` is a + parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor. + Shape: + - Input: :math:`(N, \texttt{in\_features})` + - Output: :math:`(N, \texttt{n\_classes})` + """ + + head_output = self.head(input) + return self._get_full_log_prob(input, head_output) + + def predict(self, input: Tensor) -> Tensor: + r""" This is equivalent to `self.log_pob(input).argmax(dim=1)`, + but is more efficient in some cases. + Args: + input (Tensor): a minibatch of examples + Returns: + output (Tensor): a class with the highest probability for each example + Shape: + - Input: :math:`(N, \texttt{in\_features})` + - Output: :math:`(N)` + """ + + head_output = self.head(input) + output = paddle.argmax(head_output, axis=1) + not_in_shortlist = (output >= self.shortlist_size) + all_in_shortlist = not (not_in_shortlist.any()) + + if all_in_shortlist: + return output + + elif not_in_shortlist.all(): + log_prob = self._get_full_log_prob(input, head_output) + return paddle.argmax(log_prob, axis=1) + + else: + log_prob = self._get_full_log_prob(input[not_in_shortlist], + head_output[not_in_shortlist]) + output[not_in_shortlist] = paddle.argmax(log_prob, axis=1) + return output \ No newline at end of file