Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PaddlePaddle Hackathon] Add AdaptiveLogSoftmaxWithLoss #36267

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions python/paddle/nn/__init__.py
Expand Up @@ -99,6 +99,7 @@
from .layer.loss import MarginRankingLoss # noqa: F401
from .layer.loss import CTCLoss # noqa: F401
from .layer.loss import SmoothL1Loss # noqa: F401
from .layer.loss import AdaptiveLogSoftmaxWithLoss # noqa: F401
from .layer.norm import BatchNorm # noqa: F401
from .layer.norm import SyncBatchNorm # noqa: F401
from .layer.norm import GroupNorm # noqa: F401
Expand Down
1 change: 1 addition & 0 deletions python/paddle/nn/layer/__init__.py
Expand Up @@ -71,6 +71,7 @@
from .loss import MarginRankingLoss # noqa: F401
from .loss import CTCLoss # noqa: F401
from .loss import SmoothL1Loss # noqa: F401
from .loss import AdaptiveLogSoftmaxWithLoss # noqa: F401
from .norm import BatchNorm1D # noqa: F401
from .norm import BatchNorm2D # noqa: F401
from .norm import BatchNorm3D # noqa: F401
Expand Down
264 changes: 262 additions & 2 deletions python/paddle/nn/layer/loss.py
Expand Up @@ -14,16 +14,21 @@
# limitations under the License.

# TODO: define loss functions of neural network
from collections import namedtuple
from typing import List, Sequence

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle
from .. import functional as F
from paddle.fluid.framework import core, in_dygraph_mode, _varbase_creator
from .. import Layer
from .. import Layer, Sequential, LayerList
from paddle.nn.layer import Linear
from paddle import Tensor

__all__ = []

_ASMoutput = namedtuple('_ASMoutput', ['output', 'loss'])

class BCEWithLogitsLoss(Layer):
r"""
Expand Down Expand Up @@ -1207,3 +1212,258 @@ def forward(self, input, label):
reduction=self.reduction,
delta=self.delta,
name=self.name)


class AdaptiveLogSoftmaxWithLoss(Layer):
r"""Efficient softmax approximation as described in
`Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin,
Moustapha Cissé, David Grangier, and Hervé Jégou
<https://arxiv.org/abs/1609.04309>`__.
Adaptive softmax is an approximate strategy for training models with large
output spaces. It is most effective when the label distribution is highly
imbalanced, for example in natural language modelling, where the word
frequency distribution approximately follows the `Zipf's law`_.
Adaptive softmax partitions the labels into several clusters, according to
their frequency. These clusters may contain different number of targets
each.
Additionally, clusters containing less frequent labels assign lower
dimensional embeddings to those labels, which speeds up the computation.
For each minibatch, only clusters for which at least one target is
present are evaluated.
The idea is that the clusters which are accessed frequently
(like the first one, containing most frequent labels), should also be cheap
to compute -- that is, contain a small number of assigned labels.
We highly recommend taking a look at the original paper for more details.
* :attr:`cutoffs` should be an ordered Sequence of integers sorted
in the increasing order.
It controls number of clusters and the partitioning of targets into
clusters. For example setting ``cutoffs = [10, 100, 1000]``
means that first `10` targets will be assigned
to the 'head' of the adaptive softmax, targets `11, 12, ..., 100` will be
assigned to the first cluster, and targets `101, 102, ..., 1000` will be
assigned to the second cluster, while targets
`1001, 1002, ..., n_classes - 1` will be assigned
to the last, third cluster.
* :attr:`div_value` is used to compute the size of each additional cluster,
which is given as
:math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`,
where :math:`idx` is the cluster index (with clusters
for less frequent words having larger indices,
and indices starting from :math:`1`).
* :attr:`head_bias` if set to True, adds a bias term to the 'head' of the
adaptive softmax. See paper for details. Set to False in the official
implementation.
.. warning::
Labels passed as inputs to this module should be sorted according to
their frequency. This means that the most frequent label should be
represented by the index `0`, and the least frequent
label should be represented by the index `n_classes - 1`.
.. note::
This module returns a ``NamedTuple`` with ``output``
and ``loss`` fields. See further documentation for details.
.. note::
To compute log-probabilities for all classes, the ``log_prob``
method can be used.
Args:
in_features (int): Number of features in the input tensor
n_classes (int): Number of classes in the dataset
cutoffs (Sequence): Cutoffs used to assign targets to their buckets
div_value (float, optional): value used as an exponent to compute sizes
of the clusters. Default: 4.0
head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the
adaptive softmax. Default: ``False``
Returns:
``NamedTuple`` with ``output`` and ``loss`` fields:
* **output** is a Tensor of size ``N`` containing computed target
log probabilities for each example
* **loss** is a Scalar representing the computed negative
log likelihood loss
Shape:
- input: :math:`(N, \texttt{in\_features})`
- target: :math:`(N)` where each value satisfies :math:`0 <= \texttt{target[i]} <= \texttt{n\_classes}`
- output1: :math:`(N)`
- output2: ``Scalar``
.. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law
"""

in_features: int
n_classes: int
cutoffs: List[int]
div_value: float
head_bias: bool
head: Linear
tail: LayerList

def __init__(
self,
in_features: int,
n_classes: int,
cutoffs: Sequence[int],
div_value: float = 4.,
head_bias: bool = False,
device=None,
dtype=None
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(AdaptiveLogSoftmaxWithLoss, self).__init__()

cutoffs = list(cutoffs)

if (cutoffs != sorted(cutoffs)) \
or (min(cutoffs) <= 0) \
or (max(cutoffs) > (n_classes - 1)) \
or (len(set(cutoffs)) != len(cutoffs)) \
or any([int(c) != c for c in cutoffs]):

raise ValueError("cutoffs should be a sequence of unique, positive "
"integers sorted in an increasing order, where "
"each value is between 1 and n_classes-1")

self.in_features = in_features
self.n_classes = n_classes
self.cutoffs = cutoffs + [n_classes]
self.div_value = div_value
self.head_bias = head_bias

self.shortlist_size = self.cutoffs[0]
self.n_clusters = len(self.cutoffs) - 1
self.head_size = self.shortlist_size + self.n_clusters

self.head = Linear(self.in_features, self.head_size, bias=self.head_bias,
**factory_kwargs)
self.tail = LayerList()

for i in range(self.n_clusters):

hsz = int(self.in_features // (self.div_value ** (i + 1)))
osz = self.cutoffs[i + 1] - self.cutoffs[i]

projection = Sequential(
Linear(self.in_features, hsz, bias=False, **factory_kwargs),
Linear(hsz, osz, bias=False, **factory_kwargs),
)

self.tail.append(projection)


def forward(self, input: Tensor, target: Tensor) -> _ASMoutput:
if input.shape[0]!= target.shape[0]:
raise RuntimeError('Input and target should have the same size '
'in the batch dimension.')

used_rows = 0
batch_size = target.shape[0]

output = paddle.zeros([batch_size])
gather_inds = paddle.empty([batch_size])

cutoff_values = [0] + self.cutoffs
for i in range(len(cutoff_values) - 1):

low_idx = cutoff_values[i]
high_idx = cutoff_values[i + 1]

target_mask = (target >= low_idx) & (target < high_idx)
row_indices = target_mask.nonzero().squeeze()

if row_indices.numel() == 0:
continue

if i == 0:
for i, j in enumerate(row_indices):
gather_inds[j,:] = target[target_mask][i,:]

else:
relative_target = target[target_mask] - low_idx
input_subset = input.index_select(row_indices, 0)

cluster_output = self.tail[i - 1](input_subset)
cluster_index = self.shortlist_size + i - 1


for idx in row_indices:
gather_inds[idx, :] = cluster_index

cluster_logprob = F.log_softmax(cluster_output, axis=1)
local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1)).squeeze(1)
for i, j in enumerate(row_indices):
gather_inds[j,:] = local_logprob[i,:]
used_rows += row_indices.numel()

if used_rows != batch_size:
raise RuntimeError("Target values should be in [0, {}], "
"but values in range [{}, {}] "
"were found. ".format(self.n_classes - 1,
target.min().item(),
target.max().item()))

head_output = self.head(input)
head_logprob = F.log_softmax(head_output, axis=1)
output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze()
loss = (-output).mean()

return _ASMoutput(output, loss)

def _get_full_log_prob(self, input, head_output):
""" Given input tensor, and output of `self.head`,
compute the log of the full distribution """

out = paddle.empty((head_output.shape[0], self.n_classes))
head_logprob = F.log_softmax(head_output, axis=1)

out[:, :self.shortlist_size] = head_logprob[:, :self.shortlist_size]

for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])):
cluster_output = self.tail[i](input)
cluster_logprob = F.log_softmax(cluster_output, axis=1)
output_logprob = cluster_logprob + head_logprob[:, self.shortlist_size + i].unsqueeze(1)

out[:, start_idx:stop_idx] = output_logprob

return out

def log_prob(self, input: Tensor) -> Tensor:
r""" Computes log probabilities for all :math:`\texttt{n\_classes}`
Args:
input (Tensor): a minibatch of examples
Returns:
log-probabilities of for each class :math:`c`
in range :math:`0 <= c <= \texttt{n\_classes}`, where :math:`\texttt{n\_classes}` is a
parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor.
Shape:
- Input: :math:`(N, \texttt{in\_features})`
- Output: :math:`(N, \texttt{n\_classes})`
"""

head_output = self.head(input)
return self._get_full_log_prob(input, head_output)

def predict(self, input: Tensor) -> Tensor:
r""" This is equivalent to `self.log_pob(input).argmax(dim=1)`,
but is more efficient in some cases.
Args:
input (Tensor): a minibatch of examples
Returns:
output (Tensor): a class with the highest probability for each example
Shape:
- Input: :math:`(N, \texttt{in\_features})`
- Output: :math:`(N)`
"""

head_output = self.head(input)
output = paddle.argmax(head_output, axis=1)
not_in_shortlist = (output >= self.shortlist_size)
all_in_shortlist = not (not_in_shortlist.any())

if all_in_shortlist:
return output

elif not_in_shortlist.all():
log_prob = self._get_full_log_prob(input, head_output)
return paddle.argmax(log_prob, axis=1)

else:
log_prob = self._get_full_log_prob(input[not_in_shortlist],
head_output[not_in_shortlist])
output[not_in_shortlist] = paddle.argmax(log_prob, axis=1)
return output