Skip to content
This repository has been archived by the owner on Jul 4, 2023. It is now read-only.

Commit

Permalink
Add list and tuple support for pretrained word vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
floscha committed Apr 26, 2018
1 parent a467965 commit 6a59169
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
22 changes: 22 additions & 0 deletions tests/word_to_vector/test_fast_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,28 @@ def test_fasttext_simple(mock_urlretrieve):
os.remove(os.path.join(directory, 'wiki.simple.vec.pt'))


@mock.patch('urllib.request.urlretrieve')
def test_fasttext_list_arguments(mock_urlretrieve):
    """FastText vectors accept lists/tuples of tokens in lookup and membership."""
    directory = 'tests/_test_data/fast_text/'

    # Make sure URL has a 200 status
    mock_urlretrieve.side_effect = urlretrieve_side_effect

    # Load subset of FastText
    vectors = FastText(language='simple', cache=directory)

    # Test implementation of __getitem__() for token list and tuple:
    # a sequence of k tokens yields a stacked (k, dim) tensor.
    assert list(vectors[['the', 'of']].shape) == [2, 300]
    assert list(vectors[('the', 'of')].shape) == [2, 300]

    # Test implementation of __contains__() for token list and tuple.
    # NOTE: the dunder must be called directly -- the `in` operator
    # coerces the return value of __contains__ to bool, so
    # `tokens in vectors` could never produce the per-token list.
    # (Also, `assert x in vectors == [...]` would be comparison
    # chaining: `(x in vectors) and (vectors == [...])`.)
    assert vectors.__contains__(['the', 'of', 'a']) == [True, True, False]
    assert vectors.__contains__(('the', 'of', 'a')) == [True, True, False]

    # Clean up
    os.remove(os.path.join(directory, 'wiki.simple.vec.pt'))


@mock.patch('urllib.request.urlretrieve')
def test_aligned_fasttext(mock_urlretrieve):
directory = 'tests/_test_data/fast_text/'
Expand Down
25 changes: 22 additions & 3 deletions torchnlp/word_to_vector/pretrained_word_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,34 @@ def __init__(self,
self.name = name
self.cache(name, cache, url=url)

def __contains__(self, tokens):
    """Report vocabulary membership for a token or a sequence of tokens.

    Args:
        tokens (str or list or tuple): a single token, or a sequence of
            tokens.

    Returns:
        bool if ``tokens`` is a str; a list of bools (one per token) if
        ``tokens`` is a list or tuple.

    Raises:
        TypeError: if ``tokens`` is not a str, list, or tuple.

    .. warning:: The list result is only visible when calling
        ``__contains__`` directly -- the ``in`` operator coerces the
        return value to ``bool``, so ``['a', 'b'] in vectors`` evaluates
        to ``True`` whenever the list is non-empty.
    """
    if isinstance(tokens, (list, tuple)):
        return [token in self.stoi for token in tokens]
    elif isinstance(tokens, str):
        return tokens in self.stoi
    else:
        # Space added before 'str' -- the original adjacent string
        # literals concatenated to "types'str'".
        raise TypeError("'__contains__' method can only be used with types "
                        "'str', 'list', or 'tuple' as parameter")

def _get_token(self, token):
    """Return the embedding vector for ``token``.

    Falls back to a freshly initialized vector (via ``self.unk_init``)
    when the token is not in the vocabulary, so lookups never raise.
    """
    if token in self.stoi:
        return self.vectors[self.stoi[token]]
    # Unknown token: hand back an unk-initialized vector of the right size.
    return self.unk_init(torch.Tensor(self.dim))

def __getitem__(self, tokens):
    """Look up embeddings for a token or a sequence of tokens.

    Args:
        tokens (str or list or tuple): a single token, or a sequence of
            tokens.

    Returns:
        torch.Tensor: a (dim,) vector for a single token, or a
        (len(tokens), dim) tensor (rows stacked in input order) for a
        list/tuple of tokens. Unknown tokens map to unk-initialized
        vectors.

    Raises:
        TypeError: if ``tokens`` is not a str, list, or tuple.
    """
    if isinstance(tokens, (list, tuple)):
        vector_list = [self._get_token(token) for token in tokens]
        return torch.stack(vector_list)
    elif isinstance(tokens, str):
        return self._get_token(tokens)
    else:
        # Space added before 'str' -- the original adjacent string
        # literals concatenated to "types'str'".
        raise TypeError("'__getitem__' method can only be used with types "
                        "'str', 'list', or 'tuple' as parameter")

def __len__(self):
    """Return the vocabulary size (number of stored vectors)."""
    vector_count = len(self.vectors)
    return vector_count

Expand Down

0 comments on commit 6a59169

Please sign in to comment.