Skip to content
Branch: master
Find file Copy path
Find file Copy path
7 contributors

Users who have contributed to this file

@joelgrus @schmmd @DeNeutoy @saujasv @bryant1410 @nafitzgerald @hayata-yamamoto
63 lines (47 sloc) 2.23 KB
from overrides import overrides
import torch
from allennlp.modules.seq2vec_encoders.seq2vec_encoder import Seq2VecEncoder
from allennlp.nn.util import get_lengths_from_binary_sequence_mask
class BagOfEmbeddingsEncoder(Seq2VecEncoder):
    """
    A `BagOfEmbeddingsEncoder` is a simple `Seq2VecEncoder` which simply sums
    the embeddings of a sequence across the time dimension. The input to this module is of shape
    `(batch_size, num_tokens, embedding_dim)`, and the output is of shape `(batch_size, embedding_dim)`.

    # Parameters

    embedding_dim : `int`, required
        This is the input dimension to the encoder.
    averaged : `bool`, optional (default=`False`)
        If `True`, this module will average the embeddings across time, rather than simply summing
        (ie. we will divide the summed embeddings by the length of the sentence).
    """

    def __init__(self, embedding_dim: int, averaged: bool = False) -> None:
        # The base class must be initialised before attributes are assigned
        # (required if `Seq2VecEncoder` is a `torch.nn.Module` -- its definition
        # is not visible from this file).
        super().__init__()
        self._embedding_dim = embedding_dim
        self._averaged = averaged

    def get_input_dim(self) -> int:
        """Return the size of the last dimension expected on the input."""
        return self._embedding_dim

    def get_output_dim(self) -> int:
        """Return the size of the last dimension of the output.

        Summing/averaging over time leaves the embedding dimension untouched,
        so this equals the input dimension.
        """
        return self._embedding_dim

    def forward(self, tokens: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Sum (or average) `tokens` over the `num_tokens` dimension.

        # Parameters

        tokens : `torch.Tensor`, required
            Embedded sequence of shape `(batch_size, num_tokens, embedding_dim)`.
        mask : `torch.Tensor`, optional (default=`None`)
            Binary mask of shape `(batch_size, num_tokens)`. Positions where
            the mask is zero contribute nothing to the sum and, when
            averaging, are not counted in the sequence length.

        # Returns

        A tensor of shape `(batch_size, embedding_dim)`.
        """
        if mask is not None:
            # Zero out padded positions so they do not contribute to the sum.
            tokens = tokens * mask.unsqueeze(-1).float()

        # Our input has shape `(batch_size, num_tokens, embedding_dim)`, so we sum out the
        # `num_tokens` dimension.
        summed = tokens.sum(1)

        if self._averaged:
            if mask is not None:
                # Presumably the per-sequence count of unmasked tokens, shape
                # `(batch_size,)` -- the helper is defined in `allennlp.nn.util`.
                lengths = get_lengths_from_binary_sequence_mask(mask)
                length_mask = lengths > 0

                # Set any length 0 to 1, to avoid dividing by zero.
                lengths = torch.max(lengths, lengths.new_ones(1))
            else:
                # Without a mask every sequence spans the full `num_tokens`
                # dimension; a single-element tensor broadcasts over the batch.
                lengths = tokens.new_full((1,), fill_value=tokens.size(1))
                length_mask = None

            summed = summed / lengths.unsqueeze(-1).float()

            if length_mask is not None:
                # Force the result for fully-masked (length 0) sequences to zero.
                summed = summed * (length_mask > 0).float().unsqueeze(-1)

        return summed
You can’t perform that action at this time.