Skip to content
This repository has been archived by the owner on Jul 4, 2023. It is now read-only.

Commit

Permalink
Merge pull request #79 from PetrochukM/fix_enforce_reversible
Browse files Browse the repository at this point in the history
Fix `enforce_reversible`, `char_n_gram` test, and travis.
  • Loading branch information
PetrochukM committed Jul 12, 2019
2 parents 19e5001 + cbfca9b commit f57af13
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 20 deletions.
7 changes: 2 additions & 5 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
dist: trusty
sudo: required

language: python
matrix:
include:
- python: 3.6
dist: trusty
sudo: false
dist: xenial
sudo: true
- python: 3.7
dist: xenial
sudo: true
Expand Down
14 changes: 7 additions & 7 deletions build_tools/travis/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ fi
# TODO: Add a script similar to RTD to test locally with virtual environment

# Install requirements via pip
pip install -r requirements.txt
pip install -r requirements.txt --progress-bar off

# Optional Requirements
pip install spacy
pip install nltk
pip install sacremoses
pip install spacy --progress-bar off
pip install nltk --progress-bar off
pip install sacremoses --progress-bar off

# SpaCy English web model
python -m spacy download en
Expand All @@ -36,9 +36,9 @@ python -m nltk.downloader perluniprops nonbreaking_prefixes

# Install PyTorch Dependencies
if [[ $TRAVIS_PYTHON_VERSION == '3.7' ]]; then
pip install https://download.pytorch.org/whl/cpu/torch-1.0.1.post2-cp37-cp37m-linux_x86_64.whl
pip install https://download.pytorch.org/whl/cpu/torch-1.0.1.post2-cp37-cp37m-linux_x86_64.whl --progress-bar off
fi
if [[ $TRAVIS_PYTHON_VERSION == '3.6' ]]; then
pip install https://download.pytorch.org/whl/cpu/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl
pip install https://download.pytorch.org/whl/cpu/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl --progress-bar off
fi
pip install torchvision
pip install torchvision --progress-bar off
17 changes: 16 additions & 1 deletion tests/encoders/text/test_character_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torchnlp.encoders.text import CharacterEncoder
from torchnlp.encoders.text import DEFAULT_RESERVED_TOKENS
from torchnlp.encoders.text import DEFAULT_UNKNOWN_TOKEN
from torchnlp.encoders.text import DEFAULT_UNKNOWN_INDEX


@pytest.fixture
Expand All @@ -25,7 +26,21 @@ def test_character_encoder(encoder, sample):
assert encoder.decode(output) == input_.replace('-', DEFAULT_UNKNOWN_TOKEN)


def test_character_encoder_batch(encoder, sample):
def test_character_encoder__enforce_reversible(encoder):
    """ Ensure the encoder raises when `enforce_reversible` detects a lossy round-trip. """
    encoder.enforce_reversible = True

    lossy = 'english-language pangram'  # '-' maps to the unknown token in the fixture
    lossless = 'english language pangram'

    # Encoding text with an out-of-vocabulary character is not reversible.
    with pytest.raises(ValueError):
        encoder.decode(encoder.encode(lossy))

    # In-vocabulary text round-trips without raising.
    encoder.decode(encoder.encode(lossless))

    # Corrupting one encoded index makes decoding irreversible.
    encoded = encoder.encode(lossless)
    encoded[7] = DEFAULT_UNKNOWN_INDEX
    with pytest.raises(ValueError):
        encoder.decode(encoded)


def test_character_encoder_batch(encoder):
input_ = 'english-language pangram'
longer_input_ = 'english-language pangram pangram'
encoded, lengths = encoder.batch_encode([input_, longer_input_])
Expand Down
4 changes: 2 additions & 2 deletions tests/word_to_vector/test_char_n_gram.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import mock

from torchnlp.word_to_vector import CharNGram
from tests.word_to_vector.utils import urlretrieve_side_effect
from torchnlp.encoders.text import DEFAULT_UNKNOWN_TOKEN


Expand All @@ -11,7 +10,8 @@ def test_charngram_100d(mock_urlretrieve):
directory = 'tests/_test_data/char_n_gram/'

# Make sure URL has a 200 status
mock_urlretrieve.side_effect = urlretrieve_side_effect
# TODO: Skip for now due to SSL failure.
# mock_urlretrieve.side_effect = urlretrieve_side_effect

# Attempt to parse a subset of CharNGram
vectors = CharNGram(cache=directory)
Expand Down
10 changes: 6 additions & 4 deletions torchnlp/encoders/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ def encode(self, object_):
"""
if self.enforce_reversible:
self.enforce_reversible = False
if self.decode(self.encode(object_)) != object_:
raise ValueError('Encoding is not reversible for "%s"' % object_)
encoded_decoded = self.decode(self.encode(object_))
self.enforce_reversible = True
if encoded_decoded != object_:
raise ValueError('Encoding is not reversible for "%s"' % object_)

return object_

Expand All @@ -51,9 +52,10 @@ def decode(self, encoded):
"""
if self.enforce_reversible:
self.enforce_reversible = False
if self.encode(self.decode(encoded)) != encoded:
raise ValueError('Decoding is not reversible for "%s"' % encoded)
decoded_encoded = self.encode(self.decode(encoded))
self.enforce_reversible = True
if decoded_encoded != encoded:
raise ValueError('Decoding is not reversible for "%s"' % encoded)

return encoded

Expand Down
18 changes: 18 additions & 0 deletions torchnlp/encoders/text/text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,24 @@ def stack_and_pad_tensors(batch, padding_index=DEFAULT_PADDING_INDEX, dim=0):

class TextEncoder(Encoder):

def decode(self, encoded):
    """ Decodes an object.

    Args:
        encoded (torch.Tensor): Encoded object.

    Returns:
        object: Object decoded.

    Raises:
        ValueError: If ``self.enforce_reversible`` is set and re-encoding the
            decoded value does not reproduce ``encoded``.
    """
    if self.enforce_reversible:
        # Temporarily clear the flag so the round-trip below does not re-enter
        # this check; restore it even if `encode` / `decode` raises, otherwise
        # one failure would silently disable enforcement for all later calls.
        self.enforce_reversible = False
        try:
            decoded_encoded = self.encode(self.decode(encoded))
        finally:
            self.enforce_reversible = True
        if not torch.equal(decoded_encoded, encoded):
            raise ValueError('Decoding is not reversible for "%s"' % encoded)

    return encoded

def batch_encode(self, iterator, *args, dim=0, **kwargs):
"""
Args:
Expand Down
2 changes: 1 addition & 1 deletion torchnlp/word_to_vector/char_n_gram.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class CharNGram(_PretrainedWordVectors):
"""

name = 'charNgram.txt'
url = ('http://www.logos.t.u-tokyo.ac.jp/~hassy/publications/arxiv2016jmt/'
url = ('https://www.logos.t.u-tokyo.ac.jp/~hassy/publications/arxiv2016jmt/'
'jmt_pre-trained_embeddings.tar.gz')

def __init__(self, **kwargs):
Expand Down

0 comments on commit f57af13

Please sign in to comment.