This repository has been archived by the owner on Jul 4, 2023. It is now read-only.

Commit fd94d21

Merge pull request #42 from benjamin-work/feature/remove-lambdas-for-pickle

remove lambdas for pickle
PetrochukM committed Jun 2, 2018
2 parents aa23d5f + 644d105 commit fd94d21
Showing 27 changed files with 307 additions and 59 deletions.
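
The motivation for the change: pickle serializes a function by reference to an importable, module-level name, so objects that hold lambdas (for example as a sort_key default) cannot be pickled at all. A minimal sketch of the failure mode, using only the standard library (illustrative; this snippet is not part of the commit, and _identity is a hypothetical name):

import pickle


def _identity(x):
    # A module-level function pickles by its qualified name.
    return x


# A lambda has no importable name, so pickling it fails:
try:
    pickle.dumps(lambda x: x)
except (pickle.PicklingError, AttributeError) as error:
    print('cannot pickle a lambda:', error)

# The equivalent named function round-trips without issue:
assert pickle.loads(pickle.dumps(_identity))('ok') == 'ok'

Replacing lambda defaults with named, module-level functions is what makes the new test_is_pickleable / test_pickleable tests below pass; each one simply asserts that pickle.dumps succeeds on a freshly constructed object.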
26 changes: 21 additions & 5 deletions tests/samplers/test_bptt_batch_sampler.py
@@ -1,19 +1,31 @@
+import pickle
+import string
+
+import pytest
+
 from torchnlp.samplers import BPTTBatchSampler
 from torchnlp.utils import sampler_to_iterator


-def test_bptt_batch_sampler_drop_last():
+@pytest.fixture
+def alphabet():
+    return list(string.ascii_lowercase)
+
+
+@pytest.fixture
+def sampler(alphabet):
+    return BPTTBatchSampler(alphabet, bptt_length=2, batch_size=4, drop_last=True)
+
+
+def test_bptt_batch_sampler_drop_last(sampler, alphabet):
     # Test samplers iterate over chunks similar to:
     # https://github.com/pytorch/examples/blob/c66593f1699ece14a4a2f4d314f1afb03c6793d9/word_language_model/main.py#L112
-    alphabet = list('abcdefghijklmnopqrstuvwxyz')
-    sampler = BPTTBatchSampler(alphabet, bptt_length=2, batch_size=4, drop_last=True)
     list_ = list(sampler_to_iterator(alphabet, sampler))
     assert list_[0] == [['a', 'b'], ['g', 'h'], ['m', 'n'], ['s', 't']]
     assert len(sampler) == len(list_)


-def test_bptt_batch_sampler():
-    alphabet = list('abcdefghijklmnopqrstuvwxyz')
+def test_bptt_batch_sampler(alphabet):
     sampler = BPTTBatchSampler(alphabet, bptt_length=2, batch_size=4, drop_last=False)
     list_ = list(sampler_to_iterator(alphabet, sampler))
     assert list_[0] == [['a', 'b'], ['h', 'i'], ['o', 'p'], ['u', 'v']]
@@ -27,3 +39,7 @@ def test_bptt_batch_sampler_example():
     sampler = BPTTBatchSampler(
         range(100), bptt_length=2, batch_size=3, drop_last=False, type_='target')
     assert list(sampler)[0] == [slice(1, 3), slice(35, 37), slice(68, 70)]
+
+
+def test_is_pickleable(sampler):
+    pickle.dumps(sampler)
15 changes: 13 additions & 2 deletions tests/samplers/test_bptt_sampler.py
@@ -1,10 +1,17 @@
+import pickle
 import random

+import pytest
+
 from torchnlp.samplers import BPTTSampler


-def test_bptt_sampler_odd():
-    sampler = BPTTSampler(range(5), 2)
+@pytest.fixture
+def sampler():
+    return BPTTSampler(range(5), 2)
+
+
+def test_bptt_sampler_odd(sampler):
     assert list(sampler) == [slice(0, 2), slice(2, 4)]
     assert len(sampler) == 2

@@ -19,3 +26,7 @@ def test_bptt_sampler_length():
     for i in range(1, 1000):
         sampler = BPTTSampler(range(i), random.randint(1, 17))
         assert len(sampler) == len(list(sampler))
+
+
+def test_is_pickleable(sampler):
+    pickle.dumps(sampler)
9 changes: 8 additions & 1 deletion tests/samplers/test_bucket_batch_sampler.py
@@ -1,5 +1,6 @@
-import torch
+import pickle

+import torch
 from torchnlp.samplers import BucketBatchSampler


@@ -56,3 +57,9 @@ def test_bucket_batch_sampler_sorted():
     # Largest batch (4) is in first batch
     for i, batch in enumerate(batches):
         assert batch[0] == i
+
+
+def test_pickleable():
+    data_source = [[1], [2], [3], [4], [5]]
+    sampler = BucketBatchSampler(data_source, batch_size=2, drop_last=False)
+    pickle.dumps(sampler)
8 changes: 8 additions & 0 deletions tests/samplers/test_noisy_sorted_batch_sampler.py
@@ -1,3 +1,5 @@
+import pickle
+
 from torchnlp.samplers import NoisySortedBatchSampler


@@ -49,3 +51,9 @@ def test_noisy_sorted_batch_sampler_sorted():
     # Largest batch (4) is in first batch
     for i, batch in enumerate(batches):
         assert batch[0] == i
+
+
+def test_pickleable():
+    data_source = [[1], [2], [3], [4], [5], [6]]
+    sampler = NoisySortedBatchSampler(data_source, batch_size=2, drop_last=False)
+    pickle.dumps(sampler)
8 changes: 8 additions & 0 deletions tests/samplers/test_noisy_sorted_sampler.py
@@ -1,3 +1,5 @@
+import pickle
+
 from torchnlp.samplers import NoisySortedSampler


@@ -24,3 +26,9 @@ def test_noisy_sorted_sampler_sort_key_noise():
     indexes = list(NoisySortedSampler(data_source, sort_key=sort_key, sort_key_noise=0.25))
     for i, j in enumerate(indexes):
         assert i == j
+
+
+def test_pickleable():
+    data_source = [[1], [2], [3], [4], [5], [6]]
+    sampler = NoisySortedSampler(data_source)
+    pickle.dumps(sampler)
8 changes: 8 additions & 0 deletions tests/samplers/test_shuffle_batch_sampler.py
@@ -1,3 +1,5 @@
+import pickle
+
 from torchnlp.samplers import ShuffleBatchSampler

 from torchnlp.samplers import SortedSampler
@@ -20,3 +22,9 @@ def test_shuffle_batch_sampler_drop_last():
         ShuffleBatchSampler(
             SortedSampler(data_source, sort_key=sort_key), batch_size, drop_last=True))
     assert len(batches) == 2
+
+
+def test_pickleable():
+    data_source = [[1], [2], [3], [4], [5], [6]]
+    sampler = ShuffleBatchSampler(SortedSampler(data_source), batch_size=2, drop_last=False)
+    pickle.dumps(sampler)
8 changes: 8 additions & 0 deletions tests/samplers/test_sorted_sampler.py
@@ -1,3 +1,5 @@
+import pickle
+
 from torchnlp.samplers import SortedSampler


@@ -8,3 +10,9 @@ def test_sorted_sampler():
     assert len(indexes) == len(data_source)
     for i, j in enumerate(indexes):
         assert i == j
+
+
+def test_pickleable():
+    data_source = [[1], [2], [3], [4], [5], [6]]
+    sampler = SortedSampler(data_source)
+    pickle.dumps(sampler)
25 changes: 20 additions & 5 deletions tests/text_encoders/test_character_encoder.py
@@ -1,21 +1,36 @@
+import pickle
+
+import pytest
+
 from torchnlp.text_encoders import CharacterEncoder
 from torchnlp.text_encoders import UNKNOWN_TOKEN
 from torchnlp.text_encoders.reserved_tokens import RESERVED_ITOS


-def test_character_encoder():
-    sample = ['The quick brown fox jumps over the lazy dog']
-    encoder = CharacterEncoder(sample)
+@pytest.fixture
+def sample():
+    return ['The quick brown fox jumps over the lazy dog']
+
+
+@pytest.fixture
+def encoder(sample):
+    return CharacterEncoder(sample)
+
+
+def test_character_encoder(encoder, sample):
     input_ = 'english-language pangram'
     output = encoder.encode(input_)
     assert encoder.vocab_size == len(set(list(sample[0]))) + len(RESERVED_ITOS)
     assert len(output) == len(input_)
     assert encoder.decode(output) == input_.replace('-', UNKNOWN_TOKEN)


-def test_character_encoder_min_occurrences():
-    sample = ['The quick brown fox jumps over the lazy dog']
+def test_character_encoder_min_occurrences(sample):
     encoder = CharacterEncoder(sample, min_occurrences=10)
     input_ = 'English-language pangram'
     output = encoder.encode(input_)
     assert encoder.decode(output) == ''.join([UNKNOWN_TOKEN] * len(input_))
+
+
+def test_is_pickleable(encoder):
+    pickle.dumps(encoder)
16 changes: 14 additions & 2 deletions tests/text_encoders/test_delimiter_encoder.py
@@ -1,11 +1,23 @@
+import pickle
+
+import pytest
+
 from torchnlp.text_encoders import DelimiterEncoder
 from torchnlp.text_encoders import UNKNOWN_TOKEN
 from torchnlp.text_encoders import EOS_TOKEN


-def test_delimiter_encoder():
+@pytest.fixture
+def encoder():
     sample = ['people/deceased_person/place_of_death', 'symbols/name_source/namesakes']
-    encoder = DelimiterEncoder('/', sample, append_eos=True)
+    return DelimiterEncoder('/', sample, append_eos=True)
+
+
+def test_delimiter_encoder(encoder):
     input_ = 'symbols/namesake/named_after'
     output = encoder.encode(input_)
     assert encoder.decode(output) == '/'.join(['symbols', UNKNOWN_TOKEN, UNKNOWN_TOKEN]) + EOS_TOKEN
+
+
+def test_is_pickleable(encoder):
+    pickle.dumps(encoder)
20 changes: 15 additions & 5 deletions tests/text_encoders/test_identity_encoder.py
@@ -1,10 +1,18 @@
+import pickle
+
+import pytest
+
 from torchnlp.text_encoders import IdentityEncoder
 from torchnlp.text_encoders import UNKNOWN_TOKEN


-def test_identity_encoder_unknown():
+@pytest.fixture
+def encoder():
     sample = ['people/deceased_person/place_of_death', 'symbols/name_source/namesakes']
-    encoder = IdentityEncoder(sample)
+    return IdentityEncoder(sample)
+
+
+def test_identity_encoder_unknown(encoder):
     input_ = 'symbols/namesake/named_after'
     output = encoder.encode(input_)
     assert len(output) == 1
@@ -21,10 +29,12 @@ def test_identity_encoder_known():
     assert encoder.decode(output) == input_


-def test_identity_encoder_sequence():
+def test_identity_encoder_sequence(encoder):
     input_ = ['symbols/namesake/named_after', 'people/deceased_person/place_of_death']
-    sample = ['people/deceased_person/place_of_death', 'symbols/name_source/namesakes']
-    encoder = IdentityEncoder(sample)
     output = encoder.encode(input_)
     assert len(output) == 2
     assert encoder.decode(output) == [UNKNOWN_TOKEN, 'people/deceased_person/place_of_death']
+
+
+def test_is_pickleable(encoder):
+    pickle.dumps(encoder)
24 changes: 20 additions & 4 deletions tests/text_encoders/test_moses_encoder.py
@@ -1,11 +1,23 @@
+import pickle
+
+import pytest
+
 from torchnlp.text_encoders import MosesEncoder


-def test_moses_encoder():
+@pytest.fixture
+def input_():
+    return ("This ain't funny. It's actually hillarious, yet double Ls. | [] < > [ ] & " +
+            "You're gonna shake it off? Don't?")
+
+
+@pytest.fixture
+def encoder(input_):
+    return MosesEncoder([input_])
+
+
+def test_moses_encoder(encoder, input_):
     # TEST adapted from example in http://www.nltk.org/_modules/nltk/tokenize/moses.html
-    input_ = ("This ain't funny. It's actually hillarious, yet double Ls. | [] < > [ ] & " +
-              "You're gonna shake it off? Don't?")
-    encoder = MosesEncoder([input_])
     expected_tokens = [
         'This', 'ain', '&apos;t', 'funny', '.', 'It', '&apos;s', 'actually', 'hillarious', ',',
         'yet', 'double', 'Ls', '.', '&#124;', '&#91;', '&#93;', '&lt;', '&gt;', '&#91;', '&#93;',
@@ -16,3 +28,7 @@ def test_moses_encoder():
     tokens = encoder.encode(input_)
     assert [encoder.itos[i] for i in tokens] == expected_tokens
     assert encoder.decode(tokens) == expected_decode
+
+
+def test_is_pickleable(encoder):
+    pickle.dumps(encoder)
17 changes: 15 additions & 2 deletions tests/text_encoders/test_spacy_encoder.py
@@ -1,9 +1,18 @@
+import pickle
+
+import pytest
+
 from torchnlp.text_encoders import SpacyEncoder


-def test_spacy_encoder():
+@pytest.fixture
+def encoder():
+    input_ = 'This is a sentence'
+    return SpacyEncoder([input_])
+
+
+def test_spacy_encoder(encoder):
     input_ = 'This is a sentence'
-    encoder = SpacyEncoder([input_])
     tokens = encoder.encode(input_)
     assert encoder.decode(tokens) == input_

@@ -27,3 +36,7 @@ def test_spacy_encoder_unsupported_language():

     assert error_message.startswith("No tokenizer available for language " +
                                     "'python'.")
+
+
+def test_is_pickleable(encoder):
+    pickle.dumps(encoder)
24 changes: 24 additions & 0 deletions tests/text_encoders/test_static_tokenizer_encoder.py
@@ -0,0 +1,24 @@
+import pickle
+
+import pytest
+
+from torchnlp.text_encoders import StaticTokenizerEncoder
+
+
+@pytest.fixture
+def input_():
+    return 'This is a sentence'
+
+
+@pytest.fixture
+def encoder(input_):
+    return StaticTokenizerEncoder([input_])
+
+
+def test_static_tokenizer_encoder(encoder, input_):
+    tokens = encoder.encode(input_)
+    assert encoder.decode(tokens) == input_
+
+
+def test_is_pickleable(encoder):
+    pickle.dumps(encoder)
30 changes: 18 additions & 12 deletions tests/text_encoders/test_subword_encoder.py
@@ -1,31 +1,37 @@
-import unittest
+import pickle

+import pytest
+
 from torchnlp.text_encoders import SubwordEncoder
 from torchnlp.text_encoders import EOS_INDEX


-class SubwordEncoderTest(unittest.TestCase):
+class TestSubwordEncoder:

-    def setUp(self):
-        self.corpus = [
+    @pytest.fixture(scope='module')
+    def corpus(self):
+        return [
             "One morning I shot an elephant in my pajamas. How he got in my pajamas, I don't",
             'know.', '', 'Groucho Marx',
             "I haven't slept for 10 days... because that would be too long.", '', 'Mitch Hedberg'
         ]

-    def test_build_vocab_target_size(self):
+    @pytest.fixture
+    def encoder(self, corpus):
+        return SubwordEncoder(corpus, target_vocab_size=86, min_occurrences=2, max_occurrences=6)
+
+    def test_build_vocab_target_size(self, encoder):
         # NOTE: `target_vocab_size` is approximate; therefore, it may not be exactly the target size
-        encoder = SubwordEncoder(
-            self.corpus, target_vocab_size=86, min_occurrences=2, max_occurrences=6)
         assert len(encoder.vocab) == 86

-    def test_encode(self):
-        encoder = SubwordEncoder(
-            self.corpus, target_vocab_size=86, min_occurrences=2, max_occurrences=6)
+    def test_encode(self, encoder):
         input_ = 'This has UPPER CASE letters that are out of alphabet'
         assert encoder.decode(encoder.encode(input_)) == input_

-    def test_eos(self):
-        encoder = SubwordEncoder(self.corpus, append_eos=True)
+    def test_eos(self, corpus):
+        encoder = SubwordEncoder(corpus, append_eos=True)
         input_ = 'This is a sentence'
         assert encoder.encode(input_)[-1] == EOS_INDEX
+
+    def test_is_pickleable(self, encoder):
+        pickle.dumps(encoder)
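
All of the converted test files follow the same shape: construct the object once in a pytest fixture, then add a test that simply pickles it. A condensed, self-contained version of the pattern, assembled from the diffs above (the data_source value is illustrative):

import pickle

import pytest

from torchnlp.samplers import SortedSampler


@pytest.fixture
def sampler():
    # Build the object under test once, via a fixture, instead of
    # repeating the constructor inside every test function.
    return SortedSampler([[1], [2], [3]])


def test_is_pickleable(sampler):
    # pickle.dumps raises if the sampler still holds a lambda (or any
    # other unpicklable state), which fails the test.
    pickle.dumps(sampler)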
