from typing import Dict, List
import itertools

from overrides import overrides
import torch

from allennlp.common.util import pad_sequence_to_length
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.tokenizers.token import Token
from allennlp.data.token_indexers.token_indexer import TokenIndexer
class SingleIdTokenIndexer(TokenIndexer[int]):
    """
    This :class:`TokenIndexer` represents tokens as single integers.

    Parameters
    ----------
    namespace : ``str``, optional (default=``tokens``)
        We will use this namespace in the :class:`Vocabulary` to map strings to indices.
    lowercase_tokens : ``bool``, optional (default=``False``)
        If ``True``, we will call ``token.lower()`` before getting an index for the token from the
        vocabulary.
    start_tokens : ``List[str]``, optional (default=``None``)
        These are prepended to the tokens provided to ``tokens_to_indices``.
    end_tokens : ``List[str]``, optional (default=``None``)
        These are appended to the tokens provided to ``tokens_to_indices``.
    token_min_padding_length : ``int``, optional (default=``0``)
        See :class:`TokenIndexer`.
    """
    # pylint: disable=no-self-use
    def __init__(self,
                 namespace: str = 'tokens',
                 lowercase_tokens: bool = False,
                 start_tokens: List[str] = None,
                 end_tokens: List[str] = None,
                 token_min_padding_length: int = 0) -> None:
        # Forward the padding minimum to the base class; previously this argument was
        # accepted but silently discarded.
        super().__init__(token_min_padding_length)
        self.namespace = namespace
        self.lowercase_tokens = lowercase_tokens
        # Wrap the configured boundary strings as Tokens once, so they can be chained
        # directly with the caller's tokens in `tokens_to_indices`.
        self._start_tokens = [Token(st) for st in (start_tokens or [])]
        self._end_tokens = [Token(et) for et in (end_tokens or [])]

    def _normalized_text(self, token: Token) -> str:
        """Return ``token.text``, lowercased when ``lowercase_tokens`` is enabled."""
        text = token.text
        return text.lower() if self.lowercase_tokens else text

    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
        """
        Increment ``counter[self.namespace]`` for this token's (possibly lowercased) text.

        ``counter[self.namespace][text] += 1`` is used directly, so the inner mapping must
        tolerate missing keys (presumably a ``defaultdict(int)`` — confirm with caller).
        """
        # If `text_id` is set on the token (e.g., if we're using some kind of hash-based word
        # encoding), we will not be using the vocab for this token.
        if getattr(token, 'text_id', None) is None:
            counter[self.namespace][self._normalized_text(token)] += 1

    def tokens_to_indices(self,
                          tokens: List[Token],
                          vocabulary: Vocabulary,
                          index_name: str) -> Dict[str, List[int]]:
        """
        Convert ``start_tokens + tokens + end_tokens`` into one integer id per token.

        Returns ``{index_name: ids}``. A token carrying a non-``None`` ``text_id`` contributes
        that id directly; every other token is looked up in ``self.namespace`` of
        ``vocabulary``.
        """
        indices: List[int] = []
        for token in itertools.chain(self._start_tokens, tokens, self._end_tokens):
            text_id = getattr(token, 'text_id', None)
            if text_id is not None:
                # `text_id` being set on the token means that we aren't using the vocab, we just
                # use this id instead.
                indices.append(text_id)
            else:
                indices.append(vocabulary.get_token_index(self._normalized_text(token),
                                                          self.namespace))
        return {index_name: indices}

    def get_padding_lengths(self, token: int) -> Dict[str, int]:  # pylint: disable=unused-argument
        """A single id needs no per-token padding, so no padding keys are reported."""
        return {}

    def as_padded_tensor(self,
                         tokens: Dict[str, List[int]],
                         desired_num_tokens: Dict[str, int],
                         padding_lengths: Dict[str, int]) -> Dict[str, torch.Tensor]:  # pylint: disable=unused-argument
        """
        Pad (or truncate) each id list to ``desired_num_tokens[key]`` via
        ``pad_sequence_to_length`` and wrap it in a ``torch.LongTensor``.
        """
        return {key: torch.LongTensor(pad_sequence_to_length(val, desired_num_tokens[key]))
                for key, val in tokens.items()}
