Skip to content
This repository has been archived by the owner on Jul 4, 2023. It is now read-only.

Commit

Permalink
Merge pull request #46 from PetrochukM/issues
Browse files Browse the repository at this point in the history
Add Batch Encoding & Fix Issues 43 - 45
  • Loading branch information
PetrochukM committed Jun 2, 2018
2 parents fd94d21 + af25109 commit 0fda5b6
Show file tree
Hide file tree
Showing 10 changed files with 85 additions and 43 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@ Join our community, add datasets and neural network layers! Chat with us on [Git

## Installation

Make sure you have Python 3.5+ and PyTorch 0.2.0 or newer. You can then install `pytorch-nlp` using
Make sure you have Python 3.5+ and PyTorch 0.4 or newer. You can then install `pytorch-nlp` using
pip:

pip install pytorch-nlp

Or to install the latest code via:

pip install git+https://github.com/PetrochukM/PyTorch-NLP.git

## Docs 📖

The complete documentation for PyTorch-NLP is available via [our ReadTheDocs website](https://pytorchnlp.readthedocs.io).
Expand Down
18 changes: 9 additions & 9 deletions tests/datasets/test_imdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@ def test_imdb_dataset_row(mock_urlretrieve):
train, test = imdb_dataset(directory=directory, test=True, train=True)
assert len(train) > 0
assert len(test) > 0
test = sorted(test, key=lambda r: len(r['text']))
assert test[0] == {
'text':
"My boyfriend and I went to watch The Guardian.At first I didn't want to watch it, " +
"but I loved the movie- It was definitely the best movie I have seen in sometime." +
"They portrayed the USCG very well, it really showed me what they do and I think " +
"they should really be appreciated more.Not only did it teach but it was a really " +
"good movie. The movie shows what the really do and how hard the job is.I think " +
"being a USCG would be challenging and very scary. It was a great movie all around. " +
"I would suggest this movie for anyone to see.The ending broke my heart but I know " +
"why he did it. The storyline was great I give it 2 thumbs up. I cried it was very " +
"emotional, I would give it a 20 if I could!",
"This movie was sadly under-promoted but proved to be truly exceptional. Entering " +
"the theatre I knew nothing about the film except that a friend wanted to see it." +
"<br /><br />I was caught off guard with the high quality of the film. I couldn't " +
"image Ashton Kutcher in a serious role, but his performance truly exemplified his " +
"character. This movie is exceptional and deserves our monetary support, unlike so " +
"many other movies. It does not come lightly for me to recommend any movie, but in " +
"this case I highly recommend that everyone see it.<br /><br />This films is Truly " +
"Exceptional!",
'sentiment':
'pos'
}
Expand Down
12 changes: 4 additions & 8 deletions tests/datasets/test_iwslt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,13 @@ def test_iwslt_dataset_row(mock_urlretrieve):
assert len(train) > 0
assert len(dev) > 0
assert len(test) > 0
assert train[0] == {
'en': "David Gallo: This is Bill Lange. I'm Dave Gallo.",
'de': 'David Gallo: Das ist Bill Lange. Ich bin Dave Gallo.'
}
train = sorted(train, key=lambda r: len(r['en']))
assert train[0] == {'en': 'Thank you.', 'de': 'Danke.'}

# Smoke test for iwslt_clean running twice
train, dev, test = iwslt_dataset(directory=iwslt_directory, test=True, dev=True, train=True)
assert train[0] == {
'en': "David Gallo: This is Bill Lange. I'm Dave Gallo.",
'de': 'David Gallo: Das ist Bill Lange. Ich bin Dave Gallo.'
}
train = sorted(train, key=lambda r: len(r['en']))
assert train[0] == {'en': 'Thank you.', 'de': 'Danke.'}

# Clean up
shutil.rmtree(os.path.join(iwslt_directory, 'en-de'))
9 changes: 9 additions & 0 deletions tests/text_encoders/test_character_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ def test_character_encoder(encoder, sample):
assert encoder.decode(output) == input_.replace('-', UNKNOWN_TOKEN)


def test_character_batch_encoder(encoder, sample):
input_ = 'english-language pangram'
outputs = encoder.batch_encode([input_, input_])
assert len(outputs) == 2
for output in outputs:
assert len(output) == len(input_)
assert encoder.decode(output) == input_.replace('-', UNKNOWN_TOKEN)


def test_character_encoder_min_occurrences(sample):
encoder = CharacterEncoder(sample, min_occurrences=10)
input_ = 'English-language pangram'
Expand Down
27 changes: 21 additions & 6 deletions tests/text_encoders/test_spacy_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,33 @@


@pytest.fixture
def encoder():
input_ = 'This is a sentence'
def input_():
return ('This is a sentence')


@pytest.fixture
def encoder(input_):
return SpacyEncoder([input_])


def test_spacy_encoder(encoder):
input_ = 'This is a sentence'
def test_spacy_encoder(encoder, input_):
tokens = encoder.encode(input_)
assert encoder.decode(tokens) == input_


def test_spacy_encoder_issue_44():
# https://github.com/PetrochukM/PyTorch-NLP/issues/44
encoder = SpacyEncoder(["This ain't funny."])
assert 'ai' in encoder.vocab
assert 'n\'t' in encoder.vocab


def test_spacy_encoder_batch(encoder, input_):
tokens = encoder.batch_encode([input_, input_])
assert encoder.decode(tokens[0]) == input_
assert encoder.decode(tokens[1]) == input_


def test_spacy_encoder_not_installed_language():
error_message = ''
try:
Expand All @@ -34,8 +50,7 @@ def test_spacy_encoder_unsupported_language():
except Exception as e:
error_message = str(e)

assert error_message.startswith("No tokenizer available for language " +
"'python'.")
assert error_message.startswith("No tokenizer available for language " + "'python'.")


def test_is_pickleable(encoder):
Expand Down
3 changes: 3 additions & 0 deletions torchnlp/datasets/imdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def imdb_dataset(directory='data/',
training, and 25,000 for testing. There is additional unlabeled data for use as well. Raw text
and already processed bag of words formats are provided.
Note:
The order examples are returned is not guaranteed due to ``iglob``.
**Reference:** http://ai.stanford.edu/~amaas/data/sentiment/
Args:
Expand Down
3 changes: 3 additions & 0 deletions torchnlp/datasets/iwslt.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def iwslt_dataset(
challenging due to their variety in topics, but are very benign as they are very thoroughly
rehearsed and planned, leading to easy to recognize and translate language.
Note:
The order examples are returned is not guaranteed due to ``iglob``.
References:
* http://workshop2017.iwslt.org/downloads/iwslt2017_proceeding_v2.pdf
* http://workshop2017.iwslt.org/
Expand Down
11 changes: 5 additions & 6 deletions torchnlp/metrics/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,8 @@ def get_moses_multi_bleu(hypotheses, references, lowercase=False):
"master/scripts/generic/multi-bleu.perl")
os.chmod(multi_bleu_path, 0o755)
except:
logger.info("Unable to fetch multi-bleu.perl script, using local.")
metrics_dir = os.path.dirname(os.path.realpath(__file__))
bin_dir = os.path.abspath(os.path.join(metrics_dir, "..", "..", "bin"))
multi_bleu_path = os.path.join(bin_dir, "tools/multi-bleu.perl")
logger.warning("Unable to fetch multi-bleu.perl script")
return None

# Dump hypotheses and references to tempfiles
hypothesis_file = tempfile.NamedTemporaryFile()
Expand All @@ -95,14 +93,15 @@ def get_moses_multi_bleu(hypotheses, references, lowercase=False):
bleu_out = bleu_out.decode("utf-8")
bleu_score = re.search(r"BLEU = (.+?),", bleu_out).group(1)
bleu_score = float(bleu_score)
bleu_score = np.float32(bleu_score)
except subprocess.CalledProcessError as error:
if error.output is not None:
logger.warning("multi-bleu.perl script returned non-zero exit code")
logger.warning(error.output)
bleu_score = np.float32(0.0)
bleu_score = None

# Close temp files
hypothesis_file.close()
reference_file.close()

return np.float32(bleu_score)
return bleu_score
35 changes: 22 additions & 13 deletions torchnlp/text_encoders/spacy_encoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from functools import partial

import torch

from torchnlp.text_encoders.reserved_tokens import EOS_INDEX
from torchnlp.text_encoders.reserved_tokens import UNKNOWN_INDEX
from torchnlp.text_encoders.static_tokenizer_encoder import StaticTokenizerEncoder


Expand Down Expand Up @@ -46,8 +50,7 @@ def __init__(self, *args, **kwargs):
try:
import spacy
except ImportError:
print("Please install spaCy: "
"`pip install spacy`")
print("Please install spaCy: " "`pip install spacy`")
raise

# Use English as default when no language was specified
Expand All @@ -60,17 +63,23 @@ def __init__(self, *args, **kwargs):
if language in supported_languages:
# Load the spaCy language model if it has been installed
try:
nlp = spacy.load(language)
self.spacy = spacy.load(language, disable=['parser', 'tagger', 'ner'])
except OSError:
raise ValueError(("Language '{0}' not found. Install using " +
"spaCy: `python -m spacy download {0}`"
).format(language))

from spacy.tokenizer import Tokenizer
tokenizer = Tokenizer(nlp.vocab)
"spaCy: `python -m spacy download {0}`").format(language))
else:
raise ValueError(("No tokenizer available for language '%s'. " +
"Currently supported are %s")
% (language, supported_languages))

super().__init__(*args, tokenize=partial(_tokenize, tokenizer=tokenizer), **kwargs)
raise ValueError(
("No tokenizer available for language '%s'. " + "Currently supported are %s") %
(language, supported_languages))

super().__init__(*args, tokenize=partial(_tokenize, tokenizer=self.spacy), **kwargs)

def batch_encode(self, texts, eos_index=EOS_INDEX, unknown_index=UNKNOWN_INDEX):
return_ = []
for tokens in self.spacy.pipe(texts, n_threads=-1):
text = [token.text for token in tokens]
vector = [self.stoi.get(token, unknown_index) for token in text]
if self.append_eos:
vector.append(eos_index)
return_.append(torch.LongTensor(vector))
return return_
4 changes: 4 additions & 0 deletions torchnlp/text_encoders/text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ def encode(self, string): # pragma: no cover
""" Returns a :class:`torch.LongTensor` encoding of the `text`. """
raise NotImplementedError

def batch_encode(self, strings, *args, **kwargs):
""" Returns a :class:`list` of :class:`torch.LongTensor` encoding of the `text`. """
return [self.encode(s, *args, **kwargs) for s in strings]

def decode(self, tensor): # pragma: no cover
""" Given a :class:`torch.Tensor`, returns a :class:`str` representing the decoded text.
Expand Down

0 comments on commit 0fda5b6

Please sign in to comment.