This repository has been archived by the owner on Jul 4, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 258
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
def0dfb
commit f4987ff
Showing
20 changed files
with
325 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import os | ||
import shutil | ||
|
||
import mock | ||
|
||
from torchnlp.datasets import iwslt_dataset | ||
from tests.datasets.utils import urlretrieve_side_effect | ||
|
||
iwslt_directory = 'tests/_test_data/iwslt' | ||
|
||
|
||
@mock.patch("urllib.request.urlretrieve")
def test_iwslt_dataset_row(mock_urlretrieve):
    """Rows parse into {language: sentence} dicts and iwslt_clean is idempotent."""
    mock_urlretrieve.side_effect = urlretrieve_side_effect

    expected_first_row = {
        'en': "David Gallo: This is Bill Lange. I'm Dave Gallo.",
        'de': 'David Gallo: Das ist Bill Lange. Ich bin Dave Gallo.'
    }

    # Check that a row is parsed correctly
    train, dev, test = iwslt_dataset(directory=iwslt_directory, test=True, dev=True, train=True)
    for split in (train, dev, test):
        assert len(split) > 0
    assert train[0] == expected_first_row

    # Smoke test for iwslt_clean running twice
    train, dev, test = iwslt_dataset(directory=iwslt_directory, test=True, dev=True, train=True)
    assert train[0] == expected_first_row

    # Clean up
    shutil.rmtree(os.path.join(iwslt_directory, 'en-de'))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import os | ||
|
||
import mock | ||
|
||
from torchnlp.datasets import multi30k_dataset | ||
from tests.datasets.utils import urlretrieve_side_effect | ||
|
||
multi30k_directory = 'tests/_test_data/multi30k' | ||
|
||
|
||
@mock.patch("urllib.request.urlretrieve")
def test_multi30k_dataset_row(mock_urlretrieve):
    """Rows parse into {language: sentence} dicts for all three splits."""
    mock_urlretrieve.side_effect = urlretrieve_side_effect

    # Check that a row is parsed correctly
    train, dev, test = multi30k_dataset(
        directory=multi30k_directory, test=True, dev=True, train=True)
    for split in (train, dev, test):
        assert len(split) > 0
    assert train[0] == {
        'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.',
        'en': 'Two young, White males are outside near many bushes.'
    }

    # Clean up (same files, same order as before: train/test/val × en/de)
    for prefix in ('train', 'test', 'val'):
        for extension in ('en', 'de'):
            os.remove(os.path.join(multi30k_directory, prefix + '.' + extension))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
import os | ||
import xml.etree.ElementTree as ElementTree | ||
import io | ||
import glob | ||
|
||
from torchnlp.utils import download_compressed_directory | ||
from torchnlp.datasets.dataset import Dataset | ||
|
||
|
||
def iwslt_dataset(
        directory='data/iwslt/',
        train=False,
        dev=False,
        test=False,
        language_extensions=('en', 'de'),
        train_filename='{source}-{target}/train.{source}-{target}.{lang}',
        dev_filename='{source}-{target}/IWSLT16.TED.tst2013.{source}-{target}.{lang}',
        test_filename='{source}-{target}/IWSLT16.TED.tst2014.{source}-{target}.{lang}',
        check_file='{source}-{target}/train.tags.{source}-{target}.{source}',
        url='https://wit3.fbk.eu/archive/2016-01/texts/{source}/{target}/{source}-{target}.tgz'):
    """
    Load the International Workshop on Spoken Language Translation (IWSLT) translation dataset.

    The dataset is downloaded from the WIT3 project archive, the raw XML/tagged
    release is converted to plain-text files via :func:`iwslt_clean`, and the
    requested splits are returned as line-aligned ``{language: sentence}`` rows.

    NOTE(review): the default filenames (``IWSLT16.TED.tst2013``/``tst2014``)
    and the ``archive/2016-01`` URL point at the IWSLT 2016 release, although
    the original docstring described the 2017 evaluation campaign — confirm
    which edition is intended. Likewise ``'ni'`` in the language list looks
    like a typo for ``'nl'`` (Dutch) — confirm against the WIT3 release.

    Citation:
        M. Cettolo, C. Girardi, and M. Federico. 2012. WIT3: Web Inventory of
        Transcribed and Translated Talks. In Proc. of EAMT, pp. 261-268, Trento, Italy.

    Args:
        directory (str, optional): Directory to cache the dataset.
        train (bool, optional): If to load the training split of the dataset.
        dev (bool, optional): If to load the dev split of the dataset.
        test (bool, optional): If to load the test split of the dataset.
        language_extensions (sequence of str): Exactly two language extensions,
            ``(source, target)``, e.g. ``('en', 'de')``.
        train_filename (str, optional): The filename of the training split.
        dev_filename (str, optional): The filename of the dev split.
        test_filename (str, optional): The filename of the test split.
        check_file (str, optional): Check this file exists if download was successful.
        url (str, optional): URL of the dataset file.

    Returns:
        :class:`tuple` of :class:`torchnlp.datasets.Dataset`: Tuple with the
        training tokens, dev tokens and test tokens in order if their respective
        boolean argument is true; a single dataset if only one split is requested.

    Raises:
        ValueError: If ``language_extensions`` does not contain exactly two entries.

    Example:
        >>> from torchnlp.datasets import iwslt_dataset
        >>> train = iwslt_dataset(train=True)
        >>> train[:2]
        [{
          'en': "David Gallo: This is Bill Lange. I'm Dave Gallo.",
          'de': 'David Gallo: Das ist Bill Lange. Ich bin Dave Gallo.'
        }, {
          'en': "And we're going to tell you some stories from the sea here in video.",
          'de': 'Wir werden Ihnen einige Geschichten über das Meer in Videoform erzählen.'
        }]
    """
    if len(language_extensions) != 2:
        raise ValueError("`language_extensions` must be two language extensions "
                         "['en'|'de'|'it'|'ni'|'ro'] to load.")

    # Format filenames with the source/target language pair.
    source, target = tuple(language_extensions)
    check_file = check_file.format(source=source, target=target)
    url = url.format(source=source, target=target)

    download_compressed_directory(file_url=url, directory=directory, check_file=check_file)

    # Convert the raw release into plain-text files; iwslt_clean skips files
    # that already exist, so re-running is safe.
    iwslt_clean(os.path.join(directory, '{source}-{target}'.format(source=source, target=target)))

    ret = []
    splits = [(train, train_filename), (dev, dev_filename), (test, test_filename)]
    splits = [f for (requested, f) in splits if requested]
    for filename in splits:
        examples = []
        for extension in language_extensions:
            path = os.path.join(directory,
                                filename.format(lang=extension, source=source, target=target))
            with open(path, 'r', encoding='utf-8') as f:
                language_specific_examples = [l.strip() for l in f]

            # The first language initializes one dict per row; subsequent
            # languages fill in their line-aligned translations.
            if len(examples) == 0:
                examples = [{} for _ in range(len(language_specific_examples))]
            for i, example in enumerate(language_specific_examples):
                examples[i][extension] = example

        ret.append(Dataset(examples))

    if len(ret) == 1:
        return ret[0]
    else:
        return tuple(ret)
|
||
|
||
def iwslt_clean(directory):
    """Extract plain-text corpora from a raw IWSLT release in ``directory``.

    Each ``*.xml`` file is flattened into a same-named extensionless text file
    with one ``<seg>`` per line, and each ``train.tags*`` file is copied to
    ``train*`` with the inline XML metadata lines dropped. Files whose output
    already exists are skipped, so the function is safe to run repeatedly.
    """
    # Thanks to torchtext for this snippet:
    # https://github.com/pytorch/text/blob/ea64e1d28c794ed6ffc0a5c66651c33e2f57f01f/torchtext/datasets/translation.py#L152
    for xml_path in glob.iglob(os.path.join(directory, '*.xml')):
        text_path = os.path.splitext(xml_path)[0]
        if os.path.isfile(text_path):
            continue

        with io.open(text_path, mode='w', encoding='utf-8') as out:
            # The root's first child holds the <doc> elements of this release.
            corpus = ElementTree.parse(xml_path).getroot()[0]
            for document in corpus.findall('doc'):
                for segment in document.findall('seg'):
                    out.write(segment.text.strip() + '\n')

    metadata_tags = (
        '<url', '<keywords', '<talkid', '<description', '<reviewer', '<translator',
        '<title', '<speaker')
    for tagged_path in glob.iglob(os.path.join(directory, 'train.tags*')):
        text_path = tagged_path.replace('.tags', '')
        if os.path.isfile(text_path):
            continue

        with io.open(text_path, mode='w', encoding='utf-8') as out, \
                io.open(tagged_path, mode='r', encoding='utf-8') as tagged_file:
            for line in tagged_file:
                if not any(tag in line for tag in metadata_tags):
                    out.write(line.strip() + '\n')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import os | ||
|
||
from torchnlp.utils import download_urls | ||
from torchnlp.datasets.dataset import Dataset | ||
|
||
|
||
def multi30k_dataset(directory='data/multi30k/',
                     train=False,
                     dev=False,
                     test=False,
                     language_extensions=('en', 'de'),
                     train_filename='train',
                     dev_filename='val',
                     test_filename='test',
                     check_file='train.de',
                     urls=(
                         'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz',
                         'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz',
                         'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/mmt16_task1_test.tar.gz',
                     )):
    """
    Load the WMT 2016 machine translation dataset (Multi30k).

    As a translation task, this task consists in translating English sentences that describe an
    image into German, given the English sentence itself. As training and development data, we
    provide 29,000 and 1,014 triples respectively, each containing an English source sentence and
    its German human translation. As test data, we provide a new set of 1,000 tuples containing an
    English description.

    More details:
    http://www.statmt.org/wmt16/multimodal-task.html
    http://shannon.cs.illinois.edu/DenotationGraph/

    Citation:
    ```
    @article{elliott-EtAl:2016:VL16,
        author    = {{Elliott}, D. and {Frank}, S. and {Sima'an}, K. and {Specia}, L.},
        title     = {Multi30K: Multilingual English-German Image Descriptions},
        booktitle = {Proceedings of the 5th Workshop on Vision and Language},
        year      = {2016},
        pages     = {70--74}
    }
    ```

    Args:
        directory (str, optional): Directory to cache the dataset.
        train (bool, optional): If to load the training split of the dataset.
        dev (bool, optional): If to load the dev split of the dataset.
        test (bool, optional): If to load the test split of the dataset.
        language_extensions (sequence of str): Language extensions ['en'|'de'] to load.
        train_filename (str, optional): The filename of the training split.
        dev_filename (str, optional): The filename of the dev split.
        test_filename (str, optional): The filename of the test split.
        check_file (str, optional): Check this file exists if download was successful.
        urls (sequence of str, optional): URLs to download.

    Returns:
        :class:`tuple` of :class:`torchnlp.datasets.Dataset`: Tuple with the
        training tokens, dev tokens and test tokens in order if their respective
        boolean argument is true; a single dataset if only one split is requested.

    Example:
        >>> from torchnlp.datasets import multi30k_dataset
        >>> train = multi30k_dataset(train=True)
        >>> train[:2]
        [{
          'en': 'Two young, White males are outside near many bushes.',
          'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.'
        }, {
          'en': 'Several men in hard hats are operating a giant pulley system.',
          'de': 'Mehrere Männer mit Schutzhelmen bedienen ein Antriebsradsystem.'
        }]
    """
    download_urls(directory=directory, file_urls=urls, check_file=check_file)

    ret = []
    splits = [(train, train_filename), (dev, dev_filename), (test, test_filename)]
    splits = [f for (requested, f) in splits if requested]
    for filename in splits:
        examples = []
        for extension in language_extensions:
            path = os.path.join(directory, filename + '.' + extension)
            with open(path, 'r', encoding='utf-8') as f:
                language_specific_examples = [l.strip() for l in f]

            # The first language initializes one dict per row; subsequent
            # languages fill in their line-aligned translations.
            if len(examples) == 0:
                examples = [{} for _ in range(len(language_specific_examples))]
            for i, example in enumerate(language_specific_examples):
                examples[i][extension] = example

        ret.append(Dataset(examples))

    if len(ret) == 1:
        return ret[0]
    else:
        return tuple(ret)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.