Skip to content
This repository has been archived by the owner on Jul 4, 2023. It is now read-only.

Commit

Permalink
Merge pull request #23 from floscha/embeddings-token-list
Browse files Browse the repository at this point in the history
Get vectors from token list
  • Loading branch information
PetrochukM committed Apr 26, 2018
2 parents 36f5aca + dfa987a commit 9fc48f5
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
42 changes: 42 additions & 0 deletions tests/word_to_vector/test_fast_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,48 @@ def test_fasttext_simple(mock_urlretrieve):
os.remove(os.path.join(directory, 'wiki.simple.vec.pt'))


@mock.patch('urllib.request.urlretrieve')
def test_fasttext_list_arguments(mock_urlretrieve):
    """``__getitem__()`` on a list or tuple of tokens returns a stacked matrix."""
    directory = 'tests/_test_data/fast_text/'

    # Make sure URL has a 200 status
    mock_urlretrieve.side_effect = urlretrieve_side_effect

    # Load subset of FastText
    vectors = FastText(language='simple', cache=directory)

    # Test implementation of __getitem__() for token list and tuple.
    # BUG FIX: the original comparisons lacked `assert`, so their results were
    # silently discarded and the test could never fail.
    assert list(vectors[['the', 'of']].shape) == [2, 300]
    assert list(vectors[('the', 'of')].shape) == [2, 300]

    # Clean up
    os.remove(os.path.join(directory, 'wiki.simple.vec.pt'))


@mock.patch('urllib.request.urlretrieve')
def test_fasttext_non_list_or_tuple_raises_type_error(mock_urlretrieve):
    """``__getitem__()`` rejects arguments that are not str, list, or tuple."""
    directory = 'tests/_test_data/fast_text/'

    # Make sure URL has a 200 status
    mock_urlretrieve.side_effect = urlretrieve_side_effect

    # Load subset of FastText
    vectors = FastText(language='simple', cache=directory)

    # Indexing with an unsupported type must surface a TypeError.
    raised_type = None
    try:
        vectors[None]
    except Exception as exc:
        raised_type = type(exc)

    assert raised_type is TypeError

    # Clean up
    os.remove(os.path.join(directory, 'wiki.simple.vec.pt'))


@mock.patch('urllib.request.urlretrieve')
def test_aligned_fasttext(mock_urlretrieve):
directory = 'tests/_test_data/fast_text/'
Expand Down
14 changes: 13 additions & 1 deletion torchnlp/word_to_vector/pretrained_word_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,24 @@ def __init__(self,
def __contains__(self, token):
return token in self.stoi

def __getitem__(self, token):
def _get_token_vector(self, token):
"""Return embedding for token or for UNK if token not in vocabulary"""
if token in self.stoi:
return self.vectors[self.stoi[token]]
else:
return self.unk_init(torch.Tensor(self.dim))

def __getitem__(self, tokens):
if isinstance(tokens, list) or isinstance(tokens, tuple):
vector_list = [self._get_token_vector(token) for token in tokens]
return torch.stack(vector_list)
elif isinstance(tokens, str):
token = tokens
return self._get_token_vector(token)
else:
raise TypeError("'__getitem__' method can only be used with types"
"'str', 'list', or 'tuple' as parameter")

def __len__(self):
return len(self.vectors)

Expand Down

0 comments on commit 9fc48f5

Please sign in to comment.