Update reqs and other improvements (#110)
* Add unknown set to text field

* Update reqs

* Rollback reqs

* Fix mypy

* Add drop_unk option to text field

* Fix flake8
iitzco-asapp committed Oct 18, 2019
1 parent b6276d8 commit c7fb5cc
Showing 2 changed files with 26 additions and 9 deletions.
33 changes: 25 additions & 8 deletions flambe/field/text.py
@@ -1,4 +1,4 @@
from typing import Optional, Dict
from typing import Optional, Dict, Set
from collections import OrderedDict as odict

import torch
@@ -40,7 +40,8 @@ def __init__(self, # nosec
embeddings: Optional[str] = None,
embeddings_format: str = 'glove',
embeddings_binary: bool = False,
unk_init_all: bool = False) -> None:
unk_init_all: bool = False,
drop_unknown: bool = False) -> None:
"""Initialize the TextField.
Parameters
@@ -77,6 +78,10 @@ def __init__(self, # nosec
If True, every token not provided in the input embeddings is
given a random embedding from a normal distribution.
Otherwise, all of them map to the '<unk>' token.
drop_unknown: bool
Whether to drop tokens that don't have associated
embeddings. Defaults to False.
Important: this flag only works when using embeddings.
"""
self.tokenizer = tokenizer or WordTokenizer()
@@ -92,6 +97,9 @@ def __init__(self, # nosec
self.embeddings_binary = embeddings_binary
self.embedding_matrix: Optional[torch.Tensor] = None
self.unk_init_all = unk_init_all
self.drop_unknown = drop_unknown

self.unk_numericals: Set[int] = set()

self.vocab: Dict = odict()
specials = [pad_token, unk_token, sos_token, eos_token]
@@ -169,13 +177,16 @@ def setup(self, *data: np.ndarray) -> None:
if token in model:
self.vocab[token] = index = index + 1
embeddings_matrix.append(torch.tensor(model[token]))
elif self.unk_init_all:
# Give every OOV its own embedding
self.vocab[token] = index = index + 1
embeddings_matrix.append(torch.randn(model.vector_size))
else:
# Collapse all OOVs to the same token id
self.vocab[token] = self.vocab[self.unk]
if self.unk_init_all:
# Give every OOV its own embedding
self.vocab[token] = index = index + 1
embeddings_matrix.append(torch.randn(model.vector_size))
else:
# Collapse all OOVs to the same token id
self.vocab[token] = self.vocab[self.unk]
self.unk_numericals.add(self.vocab[token])
else:
self.vocab[token] = index = index + 1

@@ -219,6 +230,12 @@ def process(self, example: str) -> torch.Tensor: # type: ignore
token = self.unk

numerical = self.vocab[token] # type: ignore

if self.drop_unknown and \
self.embeddings is not None and numerical in self.unk_numericals:
# Don't add unknown tokens when the flag is activated
continue

numericals.append(numerical)

return torch.tensor(numericals).long()
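
A minimal usage sketch of the new drop_unknown flag, assuming TextField is importable from flambe.field and using a placeholder GloVe file path (both are assumptions, not part of this commit):

import numpy as np
from flambe.field import TextField  # assumed import path for the class above

# Placeholder embeddings path; drop_unknown only takes effect when
# embeddings are provided.
field = TextField(
    embeddings='glove.6B.100d.txt',
    embeddings_format='glove',
    drop_unknown=True,
)

# Build the vocabulary; tokens missing from the embeddings collapse to
# '<unk>' and their ids are recorded in field.unk_numericals.
field.setup(np.array(['the quick brown fox', 'zzxqw fox']))

# With drop_unknown=True, process() skips ids found in unk_numericals,
# so an out-of-embedding token like 'zzxqw' is omitted from the output.
print(field.process('zzxqw fox'))

With the default drop_unknown=False, the same call would instead keep the '<unk>' id in the returned tensor.
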
2 changes: 1 addition & 1 deletion requirements.txt
@@ -12,7 +12,7 @@ requests>=2.21.0
Flask>=1.0
tensorboardx-hparams>=1.7
GitPython>=2.1.11
sru>=2.1.6
sru>=2.1.7
psutil>=5.6.0
pygments>=2.3.1
nltk>=3.4.1
