This repository has been archived by the owner on Jul 4, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 258
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f4987ff
commit 349389e
Showing
13 changed files
with
384 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import os | ||
import shutil | ||
|
||
import mock | ||
|
||
from torchnlp.datasets import smt_dataset | ||
from tests.datasets.utils import urlretrieve_side_effect | ||
|
||
directory = 'tests/_test_data/' | ||
|
||
|
||
@mock.patch("urllib.request.urlretrieve")
def test_smt_dataset_row(mock_urlretrieve):
    """End-to-end check that ``smt_dataset`` parses rows, subtrees and
    fine-grained labels from the (mocked) SST download."""
    mock_urlretrieve.side_effect = urlretrieve_side_effect

    # Every requested split should load and be non-empty.
    train, dev, test = smt_dataset(directory=directory, test=True, dev=True, train=True)
    for split in (train, dev, test):
        assert len(split) > 0

    # A sample sentence-level row is parsed with its text and label intact.
    expected_text = ("Whether or not you 're enlightened by any of Derrida 's lectures on "
                     "`` the other '' and `` the self , '' Derrida is an undeniably "
                     "fascinating and playful fellow .")
    assert train[5] == {'text': expected_text, 'label': 'positive'}

    # Subtrees expand each sentence into sentiment-tagged phrases.
    train = smt_dataset(directory=directory, train=True, subtrees=True)
    assert train[3] == {'text': 'Rock', 'label': 'neutral'}

    # Fine-grained labelling together with subtrees.
    train = smt_dataset(directory=directory, train=True, subtrees=True, fine_grained=True)
    expected_subtree_text = ("is destined to be the 21st Century 's new `` Conan '' and that "
                             "he 's going to make a splash even greater than Arnold "
                             "Schwarzenegger , Jean-Claud Van Damme or Steven Segal .")
    assert train[4] == {'text': expected_subtree_text, 'label': 'positive'}

    # Clean up the extracted dataset directory.
    shutil.rmtree(os.path.join(directory, 'trees'))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import os | ||
|
||
import mock | ||
|
||
from torchnlp.datasets import wmt_dataset | ||
from tests.datasets.utils import download_from_drive_side_effect | ||
|
||
directory = 'tests/_test_data/wmt16_en_de' | ||
|
||
|
||
@mock.patch("torchnlp.utils.download_from_drive")
def test_wmt_dataset(mock_download_from_drive):
    """End-to-end check that ``wmt_dataset`` loads the (mocked) WMT16 en-de
    download and parses aligned sentence pairs."""
    mock_download_from_drive.side_effect = download_from_drive_side_effect

    # Every requested split should load and be non-empty.
    train, dev, test = wmt_dataset(directory=directory, test=True, dev=True, train=True)
    assert len(train) > 0
    assert len(test) > 0
    assert len(dev) > 0

    # The first training pair is parsed with BPE-segmented English source.
    assert train[0] == {
        'en': 'Res@@ um@@ ption of the session',
        'de': 'Wiederaufnahme der Sitzungsperiode'
    }

    # Clean up every file extracted during the test run.
    extracted_files = [
        'bpe.32000',
        'newstest2013.tok.bpe.32000.en',
        'newstest2013.tok.bpe.32000.de',
        'newstest2014.tok.bpe.32000.en',
        'newstest2014.tok.bpe.32000.de',
        'train.tok.clean.bpe.32000.de',
        'train.tok.clean.bpe.32000.en',
        'vocab.bpe.32000',
    ]
    for filename in extracted_files:
        os.remove(os.path.join(directory, filename))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import os | ||
import io | ||
|
||
from torchnlp.utils import download_compressed_directory | ||
from torchnlp.datasets.dataset import Dataset | ||
|
||
|
||
def get_label_str(label, fine_grained=False):
    """Map a raw SST label character ('0'..'4') to a human-readable string.

    With ``fine_grained`` (5-class scheme), '0' and '4' become
    'very negative' / 'very positive'; otherwise they collapse onto the
    3-class 'negative' / 'positive'. A ``None`` label maps to ``None``.
    Raises ``KeyError`` for any other label value, as before.
    """
    if label is None:
        return None
    base = {
        '0': 'negative',
        '1': 'negative',
        '2': 'neutral',
        '3': 'positive',
        '4': 'positive',
    }[label]
    # Only the extreme classes get the 'very ' prefix in fine-grained mode.
    if fine_grained and label in ('0', '4'):
        return 'very ' + base
    return base
|
||
|
||
def parse_tree(data, subtrees=False, fine_grained=False):
    """Parse one Penn-Treebank-style tree string into example dict(s).

    Returns a single ``{'text', 'label'}`` dict for the whole sentence, or,
    when ``subtrees`` is true, a list of such dicts — one per subtree.
    Requires NLTK at call time.
    """
    # Adapted from torchtext:
    # https://github.com/pytorch/text/blob/6476392a801f51794c90378dd23489578896c6f2/torchtext/data/example.py#L56
    try:
        from nltk.tree import Tree
    except ImportError:
        print("Please install NLTK. " "See the docs at http://nltk.org for more information.")
        raise

    parsed = Tree.fromstring(data)

    def _to_example(node):
        # Leaves joined by spaces reconstruct the (tokenized) phrase text.
        return {
            'text': ' '.join(node.leaves()),
            'label': get_label_str(node.label(), fine_grained=fine_grained)
        }

    if subtrees:
        return [_to_example(node) for node in parsed.subtrees()]
    return _to_example(parsed)
|
||
|
||
def smt_dataset(directory='data/',
                train=False,
                dev=False,
                test=False,
                train_filename='train.txt',
                dev_filename='dev.txt',
                test_filename='test.txt',
                extracted_name='trees',
                check_file='trees/train.txt',
                url='http://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip',
                fine_grained=False,
                subtrees=False):
    """
    Load the Stanford Sentiment Treebank dataset.
    Semantic word spaces have been very useful but cannot express the meaning of longer phrases in
    a principled way. Further progress towards understanding compositionality in tasks such as
    sentiment detection requires richer supervised training and evaluation resources and more
    powerful models of composition. To remedy this, we introduce a Sentiment Treebank. It includes
    fine grained sentiment labels for 215,154 phrases in the parse trees of 11,855 sentences and
    presents new challenges for sentiment compositionality.
    More details:
    https://nlp.stanford.edu/sentiment/index.html
    Citation:
    Richard Socher, Alex Perelygin, Jean Y. Wu, Jason Chuang, Christopher D. Manning,
    Andrew Y. Ng and Christopher Potts. Recursive Deep Models for Semantic Compositionality Over a
    Sentiment Treebank
    Args:
        directory (str, optional): Directory to cache the dataset.
        train (bool, optional): If to load the training split of the dataset.
        dev (bool, optional): If to load the development split of the dataset.
        test (bool, optional): If to load the test split of the dataset.
        train_filename (str, optional): The filename of the training split.
        dev_filename (str, optional): The filename of the development split.
        test_filename (str, optional): The filename of the test split.
        extracted_name (str, optional): Name of the extracted dataset directory.
        check_file (str, optional): Check this file exists if download was successful.
        url (str, optional): URL of the dataset `tar.gz` file.
        subtrees (bool, optional): Whether to include sentiment-tagged subphrases in addition to
            complete examples.
        fine_grained (bool, optional): Whether to use 5-class instead of 3-class labeling.
    Returns:
        :class:`tuple` of :class:`torchnlp.datasets.Dataset`: Tuple with the training tokens, dev
        tokens and test tokens in order if their respective boolean argument is true.
    Example:
        >>> from torchnlp.datasets import smt_dataset
        >>> train = smt_dataset(train=True)
        >>> train[5]
        {
          'text': "Whether or not you 're enlightened by any of Derrida 's lectures on ...",
          'label': 'positive'
        }
    """
    download_compressed_directory(file_url=url, directory=directory, check_file=check_file)

    ret = []
    splits = [(train, train_filename), (dev, dev_filename), (test, test_filename)]
    splits = [f for (requested, f) in splits if requested]
    for filename in splits:
        full_path = os.path.join(directory, extracted_name, filename)
        examples = []
        with io.open(full_path, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                # BUG FIX: ``fine_grained`` was previously not forwarded to
                # ``parse_tree``, so the 5-class labels ('very negative' /
                # 'very positive') were never produced even when requested.
                if subtrees:
                    examples.extend(
                        parse_tree(line, subtrees=subtrees, fine_grained=fine_grained))
                else:
                    examples.append(
                        parse_tree(line, subtrees=subtrees, fine_grained=fine_grained))
        ret.append(Dataset(examples))

    # Unwrap a single requested split; otherwise return splits in order.
    if len(ret) == 1:
        return ret[0]
    else:
        return tuple(ret)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.