Added SMT and WMT dataset
PetrochukM committed Apr 2, 2018
1 parent f4987ff commit 349389e
Showing 13 changed files with 384 additions and 28 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -36,7 +36,7 @@ def find_version(*file_paths):
" metrics, neural network modules and text encoders. It's open-source software, released " +
"under the BSD3 license.",
license='BSD',
install_requires=['numpy', 'pandas', 'tqdm', 'ujson'],
install_requires=['numpy', 'pandas', 'tqdm', 'ujson', 'requests'],
classifiers=[
'Development Status :: 2 - Pre-Alpha',
'Intended Audience :: Developers',
Binary file added tests/_test_data/trainDevTestTrees_PTB.zip
Binary file not shown.
Binary file added tests/_test_data/wmt16_en_de/wmt16_en_de.tar.gz
Binary file not shown.
42 changes: 42 additions & 0 deletions tests/datasets/test_smt.py
@@ -0,0 +1,42 @@
import os
import shutil

import mock

from torchnlp.datasets import smt_dataset
from tests.datasets.utils import urlretrieve_side_effect

directory = 'tests/_test_data/'


@mock.patch("urllib.request.urlretrieve")
def test_smt_dataset_row(mock_urlretrieve):
mock_urlretrieve.side_effect = urlretrieve_side_effect

# Check a row is parsed correctly
train, dev, test = smt_dataset(directory=directory, test=True, dev=True, train=True)
assert len(train) > 0
assert len(dev) > 0
assert len(test) > 0
assert train[5] == {
'text':
"Whether or not you 're enlightened by any of Derrida 's lectures on `` the other '' " +
"and `` the self , '' Derrida is an undeniably fascinating and playful fellow .",
'label':
'positive'
}
train = smt_dataset(directory=directory, train=True, subtrees=True)
assert train[3] == {'text': 'Rock', 'label': 'neutral'}

train = smt_dataset(directory=directory, train=True, subtrees=True, fine_grained=True)
assert train[4] == {
'text':
"is destined to be the 21st Century 's new `` Conan '' and that he 's going to make a" +
" splash even greater than Arnold Schwarzenegger , Jean-Claud Van Damme or Steven" +
" Segal .",
'label':
'positive'
}

# Clean up
shutil.rmtree(os.path.join(directory, 'trees'))
11 changes: 9 additions & 2 deletions tests/datasets/test_trec.py
@@ -15,11 +15,18 @@ def test_penn_treebank_dataset_row(mock_urlretrieve):
assert len(train) > 0
assert len(test) > 0
assert train[:2] == [{
'label_fine': 'manner',
'label': 'DESC',
'text': 'How did serfdom develop in and then leave Russia ?'
}, {
'label_fine': 'cremat',
'label': 'ENTY',
'text': 'What films featured the character Popeye Doyle ?'
}]

train = trec_dataset(directory=directory, train=True, check_file=None, fine_grained=True)
assert train[:2] == [{
'label': 'manner',
'text': 'How did serfdom develop in and then leave Russia ?'
}, {
'label': 'cremat',
'text': 'What films featured the character Popeye Doyle ?'
}]
33 changes: 33 additions & 0 deletions tests/datasets/test_wmt.py
@@ -0,0 +1,33 @@
import os

import mock

from torchnlp.datasets import wmt_dataset
from tests.datasets.utils import download_from_drive_side_effect

directory = 'tests/_test_data/wmt16_en_de'


@mock.patch("torchnlp.utils.download_from_drive")
def test_wmt_dataset(mock_download_from_drive):
mock_download_from_drive.side_effect = download_from_drive_side_effect

# Check a row is parsed correctly
train, dev, test = wmt_dataset(directory=directory, test=True, dev=True, train=True)
assert len(train) > 0
assert len(test) > 0
assert len(dev) > 0
assert train[0] == {
'en': 'Res@@ um@@ ption of the session',
'de': 'Wiederaufnahme der Sitzungsperiode'
}

# Clean up
os.remove(os.path.join(directory, 'bpe.32000'))
os.remove(os.path.join(directory, 'newstest2013.tok.bpe.32000.en'))
os.remove(os.path.join(directory, 'newstest2013.tok.bpe.32000.de'))
os.remove(os.path.join(directory, 'newstest2014.tok.bpe.32000.en'))
os.remove(os.path.join(directory, 'newstest2014.tok.bpe.32000.de'))
os.remove(os.path.join(directory, 'train.tok.clean.bpe.32000.de'))
os.remove(os.path.join(directory, 'train.tok.clean.bpe.32000.en'))
os.remove(os.path.join(directory, 'vocab.bpe.32000'))
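
Note that the expected rows are byte-pair encoded: in 'Res@@ um@@ ption of the session', the '@@ ' suffix is the subword-nmt convention marking a split inside a word. A minimal sketch of undoing that segmentation, assuming the standard '@@ ' continuation marker:

def undo_bpe(text):
    # Rejoin subword pieces into words by removing the '@@ ' markers.
    return text.replace('@@ ', '')

assert undo_bpe('Res@@ um@@ ption of the session') == 'Resumption of the session'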
6 changes: 6 additions & 0 deletions tests/datasets/utils.py
@@ -5,3 +5,9 @@
def urlretrieve_side_effect(url, **kwargs):
# TODO: Fix the failure case when there is no internet connection
assert urllib.request.urlopen(url).getcode() == 200


# Check the URL requested is valid
def download_from_drive_side_effect(directory, filename, url, **kwargs):
# TODO: Fix the failure case when there is no internet connection
assert urllib.request.urlopen(url).getcode() == 200
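
Both side effects only assert that the requested URL is reachable; the tests never download anything. The dataset loaders instead operate on the fixture archives checked in above (tests/_test_data/trainDevTestTrees_PTB.zip and tests/_test_data/wmt16_en_de/wmt16_en_de.tar.gz). A minimal sketch of the wiring, with a hypothetical test name:

import mock

from tests.datasets.utils import urlretrieve_side_effect

@mock.patch("urllib.request.urlretrieve")
def test_some_dataset(mock_urlretrieve):
    # The patched download only validates the URL; the loader then finds
    # the pre-packaged fixture under tests/_test_data/.
    mock_urlretrieve.side_effect = urlretrieve_side_effect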
4 changes: 4 additions & 0 deletions torchnlp/datasets/__init__.py
@@ -11,9 +11,12 @@
from torchnlp.datasets.dataset import Dataset
from torchnlp.datasets.multi30k import multi30k_dataset
from torchnlp.datasets.iwslt import iwslt_dataset
from torchnlp.datasets.wmt import wmt_dataset
from torchnlp.datasets.smt import smt_dataset

__all__ = [
'Dataset',
'wmt_dataset',
'iwslt_dataset',
'multi30k_dataset',
'snli_dataset',
@@ -26,4 +29,5 @@
'reverse_dataset',
'count_dataset',
'zero_dataset',
'smt_dataset',
]
20 changes: 9 additions & 11 deletions torchnlp/datasets/multi30k.py
@@ -8,7 +8,6 @@ def multi30k_dataset(directory='data/multi30k/',
train=False,
dev=False,
test=False,
language_extensions=['en', 'de'],
train_filename='train',
dev_filename='val',
test_filename='test',
@@ -48,8 +47,6 @@ def multi30k_dataset(directory='data/multi30k/',
train (bool, optional): Whether to load the training split of the dataset.
dev (bool, optional): Whether to load the dev split of the dataset.
test (bool, optional): Whether to load the test split of the dataset.
language_extensions (:class:`list` of :class:`str`): List of language extensions ['en'|'de']
to load.
train_filename (str, optional): The filename of the training split.
dev_filename (str, optional): The filename of the dev split.
test_filename (str, optional): The filename of the test split.
@@ -77,17 +74,18 @@ def multi30k_dataset(directory='data/multi30k/',
ret = []
splits = [(train, train_filename), (dev, dev_filename), (test, test_filename)]
splits = [f for (requested, f) in splits if requested]

for filename in splits:
examples = []
for extension in language_extensions:
path = os.path.join(directory, filename + '.' + extension)
with open(path, 'r', encoding='utf-8') as f:
language_specific_examples = [l.strip() for l in f]

if len(examples) == 0:
examples = [{} for _ in range(len(language_specific_examples))]
for i, example in enumerate(language_specific_examples):
examples[i][extension] = example
en_path = os.path.join(directory, filename + '.en')
de_path = os.path.join(directory, filename + '.de')
with open(en_path, 'r', encoding='utf-8') as en_file:
    en_lines = [l.strip() for l in en_file]
with open(de_path, 'r', encoding='utf-8') as de_file:
    de_lines = [l.strip() for l in de_file]
assert len(en_lines) == len(de_lines)
for en_line, de_line in zip(en_lines, de_lines):
    if en_line != '' and de_line != '':
        examples.append({'en': en_line, 'de': de_line})

ret.append(Dataset(examples))

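With language_extensions gone, multi30k_dataset always returns aligned English-German pairs. A minimal usage sketch, assuming the Multi30k files are already cached under data/multi30k/:

from torchnlp.datasets import multi30k_dataset

train = multi30k_dataset(train=True)
# Each row is an aligned pair, e.g. {'en': '...', 'de': '...'}
print(train[0]['en'], '->', train[0]['de'])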
119 changes: 119 additions & 0 deletions torchnlp/datasets/smt.py
@@ -0,0 +1,119 @@
import os
import io

from torchnlp.utils import download_compressed_directory
from torchnlp.datasets.dataset import Dataset


def get_label_str(label, fine_grained=False):
pre = 'very ' if fine_grained else ''
return {
'0': pre + 'negative',
'1': 'negative',
'2': 'neutral',
'3': 'positive',
'4': pre + 'positive',
None: None
}[label]


def parse_tree(data, subtrees=False, fine_grained=False):
# https://github.com/pytorch/text/blob/6476392a801f51794c90378dd23489578896c6f2/torchtext/data/example.py#L56
try:
from nltk.tree import Tree
except ImportError:
print("Please install NLTK. " "See the docs at http://nltk.org for more information.")
raise
tree = Tree.fromstring(data)

if subtrees:
return [{
'text': ' '.join(t.leaves()),
'label': get_label_str(t.label(), fine_grained=fine_grained)
} for t in tree.subtrees()]

return {
'text': ' '.join(tree.leaves()),
'label': get_label_str(tree.label(), fine_grained=fine_grained)
}


def smt_dataset(directory='data/',
train=False,
dev=False,
test=False,
train_filename='train.txt',
dev_filename='dev.txt',
test_filename='test.txt',
extracted_name='trees',
check_file='trees/train.txt',
url='http://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip',
fine_grained=False,
subtrees=False):
"""
Load the Stanford Sentiment Treebank dataset.
Semantic word spaces have been very useful but cannot express the meaning of longer phrases in
a principled way. Further progress towards understanding compositionality in tasks such as
sentiment detection requires richer supervised training and evaluation resources and more
powerful models of composition. To remedy this, we introduce a Sentiment Treebank. It includes
fine-grained sentiment labels for 215,154 phrases in the parse trees of 11,855 sentences and
presents new challenges for sentiment compositionality.
More details:
https://nlp.stanford.edu/sentiment/index.html
Citation:
Richard Socher, Alex Perelygin, Jean Y. Wu, Jason Chuang, Christopher D. Manning,
Andrew Y. Ng and Christopher Potts. Recursive Deep Models for Semantic Compositionality Over a
Sentiment Treebank
Args:
directory (str, optional): Directory to cache the dataset.
train (bool, optional): Whether to load the training split of the dataset.
dev (bool, optional): Whether to load the development split of the dataset.
test (bool, optional): Whether to load the test split of the dataset.
train_filename (str, optional): The filename of the training split.
dev_filename (str, optional): The filename of the development split.
test_filename (str, optional): The filename of the test split.
extracted_name (str, optional): Name of the extracted dataset directory.
check_file (str, optional): Check this file exists if download was successful.
url (str, optional): URL of the dataset `.zip` file.
fine_grained (bool, optional): Whether to use 5-class instead of 3-class labeling.
subtrees (bool, optional): Whether to include sentiment-tagged subphrases in addition to
complete examples.
Returns:
:class:`tuple` of :class:`torchnlp.datasets.Dataset`: Tuple with the training, dev and
test datasets in order, if their respective boolean arguments are true.
Example:
>>> from torchnlp.datasets import smt_dataset
>>> train = smt_dataset(train=True)
>>> train[5]
{
'text': "Whether or not you 're enlightened by any of Derrida 's lectures on ...",
'label': 'positive'
}
"""
download_compressed_directory(file_url=url, directory=directory, check_file=check_file)

ret = []
splits = [(train, train_filename), (dev, dev_filename), (test, test_filename)]
splits = [f for (requested, f) in splits if requested]
for filename in splits:
full_path = os.path.join(directory, extracted_name, filename)
examples = []
with io.open(full_path, encoding='utf-8') as f:
for line in f:
line = line.strip()
if subtrees:
examples.extend(parse_tree(line, subtrees=subtrees))
else:
examples.append(parse_tree(line, subtrees=subtrees))
ret.append(Dataset(examples))

if len(ret) == 1:
return ret[0]
else:
return tuple(ret)
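
For a concrete sense of what parse_tree returns, here is a toy illustration; the tree string is made up for demonstration and NLTK must be installed:

from torchnlp.datasets.smt import parse_tree

print(parse_tree('(3 (2 Rock) (3 rocks))'))
# {'text': 'Rock rocks', 'label': 'positive'}

print(parse_tree('(3 (2 Rock) (3 rocks))', subtrees=True))
# [{'text': 'Rock rocks', 'label': 'positive'},
#  {'text': 'Rock', 'label': 'neutral'},
#  {'text': 'rocks', 'label': 'positive'}]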
10 changes: 6 additions & 4 deletions torchnlp/datasets/trec.py
@@ -13,7 +13,8 @@ def trec_dataset(directory='data/trec/',
urls=[
'http://cogcomp.org/Data/QA/QC/train_5500.label',
'http://cogcomp.org/Data/QA/QC/TREC_10.label'
]):
],
fine_grained=False):
"""
Load the Text REtrieval Conference (TREC) Question Classification dataset.
@@ -47,11 +48,9 @@ def trec_dataset(directory='data/trec/',
>>> train = trec_dataset(train=True)
>>> train[:2]
[{
'label_fine': 'manner',
'label': 'DESC',
'text': 'How did serfdom develop in and then leave Russia ?'
}, {
'label_fine': 'cremat',
'label': 'ENTY',
'text': 'What films featured the character Popeye Doyle ?'
}]
Expand All @@ -68,7 +67,10 @@ def trec_dataset(directory='data/trec/',
# there is one non-ASCII byte: sisterBADBYTEcity; replaced with space
label, _, text = line.replace(b'\xf0', b' ').strip().decode().partition(' ')
label, _, label_fine = label.partition(':')
examples.append({'label_fine': label_fine, 'label': label, 'text': text})
if fine_grained:
examples.append({'label': label_fine, 'text': text})
else:
examples.append({'label': label, 'text': text})
ret.append(Dataset(examples))

if len(ret) == 1:
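The new fine_grained flag selects between the six coarse TREC classes and their fine-grained subclasses. A minimal usage sketch, assuming the TREC files are already cached under data/trec/:

from torchnlp.datasets import trec_dataset

train = trec_dataset(train=True)  # coarse labels, e.g. 'DESC'
train_fine = trec_dataset(train=True, fine_grained=True)  # fine labels, e.g. 'manner'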
