This repository has been archived by the owner on Jul 4, 2023. It is now read-only.

Merge pull request #68 from PetrochukM/update
Release 0.4.0 - Encoder rewrite, variable sequence collate support, reduced memory usage, doctests, removed SRU
PetrochukM committed Apr 3, 2019
2 parents aa50d77 + d944083 commit e852dae
Showing 98 changed files with 1,247 additions and 1,860 deletions.
4 changes: 4 additions & 0 deletions .style.yapf
@@ -0,0 +1,4 @@
+[style]
+based_on_style = chromium
+indent_width = 4
+column_limit = 100
13 changes: 9 additions & 4 deletions .travis.yml
@@ -2,15 +2,20 @@ dist: trusty
sudo: required

language: python
-python:
-- '3.6'
-- '3.5'
+matrix:
+  include:
+    - python: 3.6
+      dist: trusty
+      sudo: false
+    - python: 3.7
+      dist: xenial
+      sudo: true

cache: pip

notifications:
email: false

before_install: source build_tools/travis/before_install.sh
install: source build_tools/travis/install.sh
script: RUN_DOCS=true RUN_SLOW=true RUN_FLAKE8=true bash build_tools/travis/test_script.sh
25 changes: 11 additions & 14 deletions README.md
@@ -19,7 +19,7 @@ Join our community, add datasets and neural network layers! Chat with us on [Git

## Installation

-Make sure you have Python 3.5+ and PyTorch 0.4 or newer. You can then install `pytorch-nlp` using
+Make sure you have Python 3.6+ and PyTorch 1.0+. You can then install `pytorch-nlp` using
pip:

pip install pytorch-nlp
@@ -50,35 +50,32 @@ train[0] # RETURNS: {'text': 'For a movie that gets..', 'sentiment': 'pos'}

### Apply [Neural Networks](http://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.nn.html) Layers

-For example, from the neural network package, apply a Simple Recurrent Unit (SRU):
+For example, from the neural network package, apply state-of-the-art LockedDropout:

```python
-from torchnlp.nn import SRU
import torch
+from torchnlp.nn import LockedDropout

-input_ = torch.autograd.Variable(torch.randn(6, 3, 10))
-sru = SRU(10, 20)
+input_ = torch.randn(6, 3, 10)
+dropout = LockedDropout(0.5)

-# Apply a Simple Recurrent Unit to `input_`
-sru(input_)
-# RETURNS: (
-#   output [torch.FloatTensor (6x3x20)],
-#   hidden_state [torch.FloatTensor (2x3x20)]
-# )
+# Apply a LockedDropout to `input_`
+dropout(input_)
+# RETURNS: torch.FloatTensor (6x3x10)
```
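
Note: `LockedDropout` (also known as variational dropout) samples a single dropout mask per batch and reuses it at every time step, instead of resampling per step. A minimal sketch of that idea, assuming the README's `(seq_len, batch, features)` input layout — an illustration, not the library's exact implementation:

```python
import torch


def locked_dropout(x, prob=0.5, training=True):
    # `x` has shape (seq_len, batch, features), as in the README example above.
    if not training or prob == 0:
        return x
    # Sample ONE mask of shape (1, batch, features) and broadcast it over time,
    # so every time step of a sequence is dropped out identically.
    mask = x.new_empty(1, x.size(1), x.size(2)).bernoulli_(1 - prob) / (1 - prob)
    return x * mask


y = locked_dropout(torch.randn(6, 3, 10))  # same shape out: torch.FloatTensor (6x3x10)
```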

-### [Encode Text](http://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.text_encoders.html)
+### [Encode Text](http://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.encoders.text.html)

Tokenize and encode text as a tensor. For example, a `WhitespaceEncoder` breaks text into terms whenever it encounters a whitespace character.

```python
-from torchnlp.text_encoders import WhitespaceEncoder
+from torchnlp.encoders.text import WhitespaceEncoder

# Create a `WhitespaceEncoder` with a corpus of text
encoder = WhitespaceEncoder(["now this ain't funny", "so don't you dare laugh"])

# Encode and decode phrases
encoder.encode("this ain't funny.") # RETURNS: torch.LongTensor([6, 7, 1])
encoder.encode("this ain't funny.") # RETURNS: torch.Tensor([6, 7, 1])
encoder.decode(encoder.encode("This ain't funny.")) # RETURNS: "this ain't funny."
```
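
Note: the rewritten encoders also collate batches of variable-length sequences (the "variable sequence collate support" in this release's title). A hedged sketch mirroring the `batch_encode` usage in the `CharacterEncoder` test added later in this commit — the exact return values here are assumptions based on that test:

```python
from torchnlp.encoders.text import WhitespaceEncoder

encoder = WhitespaceEncoder(["now this ain't funny", "so don't you dare laugh"])

# `batch_encode` pads the shorter sequences and also returns the true lengths.
encoded, lengths = encoder.batch_encode(["now this ain't funny", "so don't you dare laugh"])
assert encoded.shape[0] == 2  # one padded row per input sequence
assert len(lengths) == 2      # original token count of each sequence
```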

8 changes: 4 additions & 4 deletions build_tools/travis/install.sh
@@ -35,10 +35,10 @@ python -m spacy download en
python -m nltk.downloader perluniprops nonbreaking_prefixes

# Install PyTorch Dependancies
-if [[ $TRAVIS_PYTHON_VERSION == '3.6' ]]; then
-    pip install http://download.pytorch.org/whl/cpu/torch-0.4.0-cp36-cp36m-linux_x86_64.whl
+if [[ $TRAVIS_PYTHON_VERSION == '3.7' ]]; then
+    pip install https://download.pytorch.org/whl/cpu/torch-1.0.1.post2-cp37-cp37m-linux_x86_64.whl
fi
-if [[ $TRAVIS_PYTHON_VERSION == '3.5' ]]; then
-    pip install http://download.pytorch.org/whl/cpu/torch-0.4.0-cp35-cp35m-linux_x86_64.whl
+if [[ $TRAVIS_PYTHON_VERSION == '3.6' ]]; then
+    pip install https://download.pytorch.org/whl/cpu/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl
fi
pip install torchvision
17 changes: 9 additions & 8 deletions build_tools/travis/test_script.sh
@@ -7,22 +7,23 @@
# Exit immediately if a command exits with a non-zero status.
set -e

-python --version
-
-if [[ "$RUN_FLAKE8" == "true" ]]; then
-    flake8
-fi
+export PYTHONPATH=.
+
+python --version

if [[ "$RUN_DOCS" == "true" ]]; then
    make -C docs html
fi

+if [[ "$RUN_FLAKE8" == "true" ]]; then
+    flake8 torchnlp/
+    flake8 tests/
+fi

run_tests() {
+    TEST_CMD="python -m pytest tests/ torchnlp/ --verbose --durations=20 --cov=torchnlp --doctest-modules"
    if [[ "$RUN_SLOW" == "true" ]]; then
-        TEST_CMD="py.test -v --durations=20 --cov=torchnlp --runslow"
-    else
-        TEST_CMD="py.test -v --durations=20 --cov=torchnlp"
+        TEST_CMD="$TEST_CMD --runslow"
    fi
    $TEST_CMD
}
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -16,7 +16,7 @@ and text encoders. It's open-source software, released under the BSD3 license.
   source/torchnlp.datasets
   source/torchnlp.word_to_vector
   source/torchnlp.nn
-  source/torchnlp.text_encoders
+  source/torchnlp.encoders
   source/torchnlp.samplers
   source/torchnlp.metrics
   source/torchnlp.utils
15 changes: 15 additions & 0 deletions docs/source/torchnlp.encoders.rst
@@ -0,0 +1,15 @@
+torchnlp.encoders package
+===============================
+
+The ``torchnlp.encoders`` package supports encoding objects as a vector
+:class:`torch.Tensor` and decoding a vector :class:`torch.Tensor` back.
+
+.. automodule:: torchnlp.encoders
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+.. automodule:: torchnlp.encoders.text
+    :members:
+    :undoc-members:
+    :show-inheritance:
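
Note: a minimal sketch of the encode/decode round trip this package documents, mirroring the `LabelEncoder` tests added later in this commit (the sample labels here are made up):

```python
import torch

from torchnlp.encoders import LabelEncoder

encoder = LabelEncoder(['label_a', 'label_b'])

# Encode a known label to a tensor index, then decode it back (round trip).
index = encoder.encode('label_a')
assert encoder.decode(index) == 'label_a'

# Encode the whole vocabulary as one tensor in a single call.
encoded = encoder.batch_encode(encoder.vocab)
assert torch.equal(encoded, torch.arange(encoder.vocab_size))
```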
11 changes: 0 additions & 11 deletions docs/source/torchnlp.text_encoders.rst

This file was deleted.

2 changes: 1 addition & 1 deletion examples/awd-lstm-lm/README.md
@@ -1,4 +1,4 @@
-`awd-lstm-lm` set the state-of-the-art in word level perplexities in 2017. With PyTorch NLP, we show that in 30 minutes, we were able to reduce the footprint of this repository by 4 files (185 lines of code). We employ the use of the [datasets package](https://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.datasets.html), [IdentityEncoder module](https://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.text_encoders.html#torchnlp.text_encoders.IdentityEncoder), [BPTTBatchSampler module](https://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.samplers.html#torchnlp.samplers.BPTTBatchSampler), [LockedDropout module](https://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.nn.html#torchnlp.nn.LockedDropout) and [WeightDrop module](https://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.nn.html#torchnlp.nn.WeightDrop)
+`awd-lstm-lm` set the state-of-the-art in word level perplexities in 2017. With PyTorch NLP, we show that in 30 minutes, we were able to reduce the footprint of this repository by 4 files (185 lines of code). We employ the use of the [datasets package](https://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.datasets.html), [IdentityEncoder module](https://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.encoders.text.html#torchnlp.encoders.text.IdentityEncoder), [BPTTBatchSampler module](https://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.samplers.html#torchnlp.samplers.BPTTBatchSampler), [LockedDropout module](https://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.nn.html#torchnlp.nn.LockedDropout) and [WeightDrop module](https://pytorchnlp.readthedocs.io/en/latest/source/torchnlp.nn.html#torchnlp.nn.WeightDrop)


Below is the original README from the repository:
10 changes: 5 additions & 5 deletions examples/awd-lstm-lm/main.py
@@ -96,17 +96,17 @@ def model_load(fn):


from torchnlp import datasets
-from torchnlp.text_encoders import IdentityEncoder
+from torchnlp.encoders import LabelEncoder
from torchnlp.samplers import BPTTBatchSampler

print('Producing dataset...')
train, val, test = getattr(datasets, args.data)(train=True, dev=True, test=True)

-encoder = IdentityEncoder(train + val + test)
+encoder = LabelEncoder(train + val + test)

-train_data = encoder.encode(train)
-val_data = encoder.encode(val)
-test_data = encoder.encode(test)
+train_data = encoder.batch_encode(train)
+val_data = encoder.batch_encode(val)
+test_data = encoder.batch_encode(test)

eval_batch_size = 10
test_batch_size = 1
5 changes: 3 additions & 2 deletions examples/snli/train.py
@@ -13,7 +13,8 @@
from torchnlp.samplers import BucketBatchSampler
from torchnlp.datasets import snli_dataset
from torchnlp.utils import datasets_iterator
-from torchnlp.text_encoders import WhitespaceEncoder, IdentityEncoder
+from torchnlp.encoders.text import WhitespaceEncoder
+from torchnlp.encoders import LabelEncoder
from torchnlp import word_to_vector

from model import SNLIClassifier
@@ -38,7 +39,7 @@
sentence_encoder = WhitespaceEncoder(sentence_corpus)

label_corpus = [row['label'] for row in datasets_iterator(train, dev, test)]
-label_encoder = IdentityEncoder(label_corpus)
+label_encoder = LabelEncoder(label_corpus)

# Encode
for row in datasets_iterator(train, dev, test):
2 changes: 1 addition & 1 deletion examples/snli/util.py
@@ -4,7 +4,7 @@

import torch

-from torchnlp.utils import pad_batch
+from torchnlp.encoders.text import pad_batch


def makedirs(name):
2 changes: 1 addition & 1 deletion requirements.txt
@@ -10,7 +10,7 @@
# Testing + Code Coverage
codecov
coverage
-pytest
+pytest>=3.6
pytest-cov

# Linting
4 changes: 2 additions & 2 deletions setup.py
@@ -43,8 +43,8 @@ def find_version(*file_paths):
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: Apache Software License',
-'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
+'Programming Language :: Python :: 3.7',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Mathematics',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
@@ -53,7 +53,7 @@ def find_version(*file_paths):
'Topic :: Software Development :: Libraries :: Python Modules',
],
keywords='pytorch nlp text torchtext torchnlp',
-python_requires='>=3.5',
+python_requires='>=3.6',

# Package info
packages=find_packages(exclude=['.vscode', 'build_tools', 'docs', 'tests']),
67 changes: 67 additions & 0 deletions tests/encoders/test_label_encoder.py
@@ -0,0 +1,67 @@
+import pickle
+
+import pytest
+import torch
+
+from torchnlp.encoders import LabelEncoder
+from torchnlp.encoders.label_encoder import DEFAULT_UNKNOWN_TOKEN
+
+
+@pytest.fixture
+def label_encoder():
+    sample = ['people/deceased_person/place_of_death', 'symbols/name_source/namesakes']
+    return LabelEncoder(sample)
+
+
+def test_label_encoder_no_reserved():
+    sample = ['people/deceased_person/place_of_death', 'symbols/name_source/namesakes']
+    label_encoder = LabelEncoder(sample, reserved_labels=[], unknown_index=None)
+
+    label_encoder.encode('people/deceased_person/place_of_death')
+
+    # No ``unknown_index`` defined causes ``RuntimeError`` if an unknown label is used.
+    with pytest.raises(RuntimeError):
+        label_encoder.encode('symbols/namesake/named_after')
+
+
+def test_label_encoder_enforce_reversible(label_encoder):
+    label_encoder.enforce_reversible()
+
+    with pytest.raises(ValueError):
+        label_encoder.encode('symbols/namesake/named_after')
+
+    with pytest.raises(IndexError):
+        label_encoder.decode(torch.tensor(label_encoder.vocab_size))
+
+
+def test_label_encoder_batch_encoding(label_encoder):
+    encoded = label_encoder.batch_encode(label_encoder.vocab)
+    assert torch.equal(encoded, torch.arange(label_encoder.vocab_size).view(-1))
+
+
+def test_label_encoder_batch_decoding(label_encoder):
+    assert label_encoder.vocab == label_encoder.batch_decode(torch.arange(label_encoder.vocab_size))
+
+
+def test_label_encoder_vocab(label_encoder):
+    assert len(label_encoder.vocab) == 3
+    assert len(label_encoder.vocab) == label_encoder.vocab_size
+
+
+def test_label_encoder_unknown(label_encoder):
+    input_ = 'symbols/namesake/named_after'
+    output = label_encoder.encode(input_)
+    assert label_encoder.decode(output) == DEFAULT_UNKNOWN_TOKEN
+
+
+def test_label_encoder_known(label_encoder):
+    input_ = 'symbols/namesake/named_after'
+    sample = ['people/deceased_person/place_of_death', 'symbols/name_source/namesakes']
+    sample.append(input_)
+    label_encoder = LabelEncoder(sample)
+    output = label_encoder.encode(input_)
+    assert label_encoder.decode(output) == input_
+
+
+def test_label_encoder_is_pickleable(label_encoder):
+    pickle.dumps(label_encoder)
47 changes: 47 additions & 0 deletions tests/encoders/text/test_character_encoder.py
@@ -0,0 +1,47 @@
+import pickle
+
+import pytest
+
+from torchnlp.encoders.text import CharacterEncoder
+from torchnlp.encoders.text import DEFAULT_RESERVED_TOKENS
+from torchnlp.encoders.text import DEFAULT_UNKNOWN_TOKEN
+
+
+@pytest.fixture
+def sample():
+    return ['The quick brown fox jumps over the lazy dog']
+
+
+@pytest.fixture
+def encoder(sample):
+    return CharacterEncoder(sample)
+
+
+def test_character_encoder(encoder, sample):
+    input_ = 'english-language pangram'
+    output = encoder.encode(input_)
+    assert encoder.vocab_size == len(set(list(sample[0]))) + len(DEFAULT_RESERVED_TOKENS)
+    assert len(output) == len(input_)
+    assert encoder.decode(output) == input_.replace('-', DEFAULT_UNKNOWN_TOKEN)
+
+
+def test_character_encoder_batch(encoder, sample):
+    input_ = 'english-language pangram'
+    longer_input_ = 'english-language pangram pangram'
+    encoded, lengths = encoder.batch_encode([input_, longer_input_])
+    assert encoded.shape[0] == 2
+    assert len(lengths) == 2
+    decoded = encoder.batch_decode(encoded, lengths=lengths)
+    assert decoded[0] == input_.replace('-', DEFAULT_UNKNOWN_TOKEN)
+    assert decoded[1] == longer_input_.replace('-', DEFAULT_UNKNOWN_TOKEN)
+
+
+def test_character_encoder_min_occurrences(sample):
+    encoder = CharacterEncoder(sample, min_occurrences=10)
+    input_ = 'English-language pangram'
+    output = encoder.encode(input_)
+    assert encoder.decode(output) == ''.join([DEFAULT_UNKNOWN_TOKEN] * len(input_))
+
+
+def test_is_pickleable(encoder):
+    pickle.dumps(encoder)
@@ -2,9 +2,9 @@

import pytest

-from torchnlp.text_encoders import DelimiterEncoder
-from torchnlp.text_encoders import UNKNOWN_TOKEN
-from torchnlp.text_encoders import EOS_TOKEN
+from torchnlp.encoders.text import DelimiterEncoder
+from torchnlp.encoders.text import DEFAULT_UNKNOWN_TOKEN
+from torchnlp.encoders.text import DEFAULT_EOS_TOKEN


@pytest.fixture
@@ -16,7 +16,8 @@ def encoder():
def test_delimiter_encoder(encoder):
    input_ = 'symbols/namesake/named_after'
    output = encoder.encode(input_)
-    assert encoder.decode(output) == '/'.join(['symbols', UNKNOWN_TOKEN, UNKNOWN_TOKEN]) + EOS_TOKEN
+    assert encoder.decode(output) == '/'.join(
+        ['symbols', DEFAULT_UNKNOWN_TOKEN, DEFAULT_UNKNOWN_TOKEN, DEFAULT_EOS_TOKEN])


def test_is_pickleable(encoder):
