diff --git a/flambe/field/text.py b/flambe/field/text.py
index 0a2ab81b..517c5fe8 100644
--- a/flambe/field/text.py
+++ b/flambe/field/text.py
@@ -1,4 +1,4 @@
-from typing import Optional, Dict
+from typing import Optional, Dict, Set
 from collections import OrderedDict as odict
 
 import torch
@@ -40,7 +40,8 @@ def __init__(self,  # nosec
                  embeddings: Optional[str] = None,
                  embeddings_format: str = 'glove',
                  embeddings_binary: bool = False,
-                 unk_init_all: bool = False) -> None:
+                 unk_init_all: bool = False,
+                 drop_unknown: bool = False) -> None:
         """Initialize the TextField.
 
         Parameters
@@ -77,6 +78,10 @@ def __init__(self,  # nosec
            If True, every token not provided in the input embeddings
            is given a random embedding from a normal distribution.
            Otherwise, all of them map to the '<unk>' token.
+        drop_unknown: bool
+            Whether to drop tokens that don't have an embedding
+            associated with them. Defaults to False.
+            Important: this flag only takes effect when using embeddings.
 
         """
         self.tokenizer = tokenizer or WordTokenizer()
@@ -92,6 +97,9 @@ def __init__(self,  # nosec
         self.embeddings_binary = embeddings_binary
         self.embedding_matrix: Optional[torch.Tensor] = None
         self.unk_init_all = unk_init_all
+        self.drop_unknown = drop_unknown
+
+        self.unk_numericals: Set[int] = set()
 
         self.vocab: Dict = odict()
         specials = [pad_token, unk_token, sos_token, eos_token]
@@ -169,13 +177,15 @@ def setup(self, *data: np.ndarray) -> None:
                 if token in model:
                     self.vocab[token] = index = index + 1
                     embeddings_matrix.append(torch.tensor(model[token]))
-                elif self.unk_init_all:
-                    # Give every OOV it's own embedding
-                    self.vocab[token] = index = index + 1
-                    embeddings_matrix.append(torch.randn(model.vector_size))
                 else:
-                    # Collapse all OOV's to the same token id
-                    self.vocab[token] = self.vocab[self.unk]
+                    if self.unk_init_all:
+                        # Give every OOV its own embedding
+                        self.vocab[token] = index = index + 1
+                        embeddings_matrix.append(torch.randn(model.vector_size))
+                    else:
+                        # Collapse all OOVs to the same token id
+                        self.vocab[token] = self.vocab[self.unk]
+                    self.unk_numericals.add(self.vocab[token])
             else:
                 self.vocab[token] = index = index + 1
 
@@ -219,6 +229,12 @@ def process(self, example: str) -> torch.Tensor:  # type: ignore
                 token = self.unk
 
             numerical = self.vocab[token]  # type: ignore
+
+            if self.drop_unknown and \
+                    self.embeddings is not None and numerical in self.unk_numericals:
+                # Skip unknown tokens when drop_unknown is enabled
+                continue
+
             numericals.append(numerical)
 
         return torch.tensor(numericals).long()
diff --git a/requirements.txt b/requirements.txt
index 3e491154..b90ca692 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,7 +12,7 @@ requests>=2.21.0
 Flask>=1.0
 tensorboardx-hparams>=1.7
 GitPython>=2.1.11
-sru>=2.1.6
+sru>=2.1.7
 psutil>=5.6.0
 pygments>=2.3.1
 nltk>=3.4.1
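
Usage note: a minimal sketch of the new flag. The embeddings path and the toy
sentences below are illustrative assumptions, not part of this patch; only the
TextField constructor arguments, setup(), and process() come from the diff.

import numpy as np

from flambe.field import TextField

# Hypothetical GloVe file path -- an assumption for illustration only.
field = TextField(embeddings='glove.6B.100d.txt',
                  embeddings_format='glove',
                  drop_unknown=True)

# setup() builds the vocab; with unk_init_all=False (the default), tokens
# absent from the embeddings collapse to the unknown id, and that id is
# recorded in field.unk_numericals.
field.setup(np.array(['the quick brown fox', 'xyzzy jumps over']))

# With drop_unknown=True and embeddings set, process() skips any id found
# in unk_numericals, so the OOV token 'xyzzy' is omitted from the output.
print(field.process('the xyzzy fox'))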