Skip to content
This repository has been archived by the owner on Jul 4, 2023. It is now read-only.

Commit

Permalink
Added TREC Dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
PetrochukM committed Apr 1, 2018
1 parent 7aafdfa commit b348344
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 15 deletions.
15 changes: 15 additions & 0 deletions tests/_test_data/trec/TREC_10.label
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
NUM:dist How far is it from Denver to Aspen ?
LOC:city What county is Modesto , California in ?
HUM:desc Who was Galileo ?
DESC:def What is an atom ?
NUM:date When did Hawaii become a state ?
NUM:dist How tall is the Sears Building ?
HUM:gr George Bush purchased a small interest in which baseball team ?
ENTY:plant What is Australia 's national flower ?
DESC:reason Why does the moon turn orange ?
DESC:def What is autism ?
LOC:city What city had a world fair in 1900 ?
HUM:ind What person 's head is on a dime ?
NUM:weight What is the average weight of a Yellow Labrador ?
HUM:ind Who was the first man to fly across the Pacific Ocean ?
NUM:date When did Idaho become a state ?
15 changes: 15 additions & 0 deletions tests/_test_data/trec/train_5500.label
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
DESC:manner How did serfdom develop in and then leave Russia ?
ENTY:cremat What films featured the character Popeye Doyle ?
DESC:manner How can I find a list of celebrities ' real names ?
ENTY:animal What fowl grabs the spotlight after the Chinese Year of the Monkey ?
ABBR:exp What is the full form of .com ?
HUM:ind What contemptible scoundrel stole the cork from my lunch ?
HUM:gr What team did baseball 's St. Louis Browns become ?
HUM:title What is the oldest profession ?
DESC:def What are liver enzymes ?
HUM:ind Name the scar-faced bounty hunter of The Old West .
NUM:date When was Ozzy Osbourne born ?
DESC:reason Why do heavier objects travel downhill faster ?
HUM:ind Who was The Pride of the Yankees ?
HUM:ind Who killed Gandhi ?
ENTY:event What is considered the costliest disaster the insurance industry has ever faced ?
2 changes: 1 addition & 1 deletion tests/datasets/test_penn_treebank.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torchnlp.datasets import penn_treebank_dataset
from tests.datasets.utils import urlretrieve_side_effect

directory = 'tests/_test_data/'
directory = 'tests/_test_data/penn-treebank'


@mock.patch("urllib.request.urlretrieve")
Expand Down
25 changes: 25 additions & 0 deletions tests/datasets/test_trec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import mock

from torchnlp.datasets import trec_dataset
from tests.datasets.utils import urlretrieve_side_effect

directory = 'tests/_test_data/trec'


@mock.patch("urllib.request.urlretrieve")
def test_trec_dataset_row(mock_urlretrieve):
    """The first rows of the TREC train/test splits parse into the expected dicts.

    ``urlretrieve`` is mocked so the cached fixture files under ``directory``
    are used instead of hitting the network.
    """
    mock_urlretrieve.side_effect = urlretrieve_side_effect

    # Check rows are parsed correctly (coarse label, fine label, question text)
    train, test = trec_dataset(directory=directory, test=True, train=True, check_file=None)
    assert len(train) > 0
    assert len(test) > 0
    assert train[:2] == [{
        'label_fine': 'manner',
        'label': 'DESC',
        'text': 'How did serfdom develop in and then leave Russia ?'
    }, {
        'label_fine': 'cremat',
        'label': 'ENTY',
        'text': 'What films featured the character Popeye Doyle ?'
    }]
2 changes: 2 additions & 0 deletions torchnlp/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torchnlp.datasets.penn_treebank import penn_treebank_dataset
from torchnlp.datasets.ud_pos import ud_pos_dataset
from torchnlp.datasets.snli import snli_dataset
from torchnlp.datasets.trec import trec_dataset
from torchnlp.datasets.dataset import Dataset

__all__ = [
Expand All @@ -17,6 +18,7 @@
'wikitext_2_dataset',
'penn_treebank_dataset',
'ud_pos_dataset',
'trec_dataset',
'reverse_dataset',
'count_dataset',
'zero_dataset',
Expand Down
15 changes: 5 additions & 10 deletions torchnlp/datasets/penn_treebank.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@
import os
import io
import urllib.request

from tqdm import tqdm

from torchnlp.utils import reporthook
from torchnlp.text_encoders import UNKNOWN_TOKEN
from torchnlp.text_encoders import EOS_TOKEN
from torchnlp.utils import download_urls


def penn_treebank_dataset(
directory='data/',
directory='data/penn-treebank',
train=False,
dev=False,
test=False,
train_filename='ptb.train.txt',
dev_filename='ptb.valid.txt',
test_filename='ptb.test.txt',
name='penn-treebank',
check_file='penn-treebank/ptb.train.txt',
check_file='ptb.train.txt',
urls=[
'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.train.txt',
'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.valid.txt',
Expand Down Expand Up @@ -48,7 +43,7 @@ def penn_treebank_dataset(
test_filename (str, optional): The filename of the test split.
name (str, optional): Name of the dataset directory.
check_file (str, optional): Check this file exists if download was successful.
urls (str, optional): URLs of the dataset `tar.gz` file.
urls (str, optional): URLs to download.
Returns:
:class:`tuple` of :class:`list` of :class:`str`: Tuple with the training tokens, dev tokens
Expand All @@ -61,13 +56,13 @@ def penn_treebank_dataset(
['aer', 'banknote', 'berlitz', 'calloway', 'centrust', 'cluett', 'fromstein', 'gitano',
'guterman', 'hydro-quebec']
"""
download_urls(directory=os.path.join(directory, name), urls=urls, check_file=check_file)
download_urls(directory=directory, urls=urls, check_file=check_file)

ret = []
splits = [(train, train_filename), (dev, dev_filename), (test, test_filename)]
split_filenames = [dir_ for (requested, dir_) in splits if requested]
for filename in split_filenames:
full_path = os.path.join(directory, name, filename)
full_path = os.path.join(directory, filename)
text = []
with io.open(full_path, encoding='utf-8') as f:
for line in f:
Expand Down
77 changes: 77 additions & 0 deletions torchnlp/datasets/trec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import os

from torchnlp.utils import download_urls
from torchnlp.datasets.dataset import Dataset


def trec_dataset(directory='data/trec/',
                 train=False,
                 test=False,
                 train_filename='train_5500.label',
                 test_filename='TREC_10.label',
                 check_file='train_5500.label',
                 urls=[
                     'http://cogcomp.org/Data/QA/QC/train_5500.label',
                     'http://cogcomp.org/Data/QA/QC/TREC_10.label'
                 ]):
    """
    Load the Text REtrieval Conference (TREC) Question Classification dataset.
    TREC dataset contains 5500 labeled questions in training set and another 500 for test set. The
    dataset has 6 labels, 50 level-2 labels. Average length of each sentence is 10, vocabulary size
    of 8700.
    More details:
    https://nlp.stanford.edu/courses/cs224n/2004/may-steinberg-project.pdf
    http://cogcomp.org/Data/QA/QC/
    http://www.aclweb.org/anthology/C02-1150
    Citation:
    Xin Li, Dan Roth, Learning Question Classifiers. COLING'02, Aug., 2002.
    Args:
        directory (str, optional): Directory to cache the dataset.
        train (bool, optional): If to load the training split of the dataset.
        test (bool, optional): If to load the test split of the dataset.
        train_filename (str, optional): The filename of the training split.
        test_filename (str, optional): The filename of the test split.
        check_file (str, optional): Check this file exists if download was successful.
        urls (str, optional): URLs to download.
    Returns:
        :class:`Dataset` or :class:`tuple` of :class:`Dataset`: The requested splits, in
        (train, test) order, where each example is a dict with keys ``label`` (coarse
        category, e.g. ``DESC``), ``label_fine`` (level-2 category, e.g. ``manner``) and
        ``text`` (the question). A single ``Dataset`` is returned if only one split is
        requested.
    Example:
        >>> from torchnlp.datasets import trec_dataset
        >>> train = trec_dataset(train=True)
        >>> train[:2]
        [{
          'label_fine': 'manner',
          'label': 'DESC',
          'text': 'How did serfdom develop in and then leave Russia ?'
        }, {
          'label_fine': 'cremat',
          'label': 'ENTY',
          'text': 'What films featured the character Popeye Doyle ?'
        }]
    """
    download_urls(directory=directory, urls=urls, check_file=check_file)

    ret = []
    splits = [(train, train_filename), (test, test_filename)]
    split_filenames = [filename for (requested, filename) in splits if requested]
    for filename in split_filenames:
        full_path = os.path.join(directory, filename)
        examples = []
        # Read as bytes so the one stray non-ASCII byte can be repaired before decoding.
        with open(full_path, 'rb') as f:
            for line in f:
                # The raw data contains one non-ASCII byte ("sister\xf0city"); replace it
                # with a space so the line decodes cleanly.
                label, _, text = line.replace(b'\xf0', b' ').strip().decode().partition(' ')
                # Labels look like ``DESC:manner`` -> coarse ``DESC``, fine ``manner``.
                label, _, label_fine = label.partition(':')
                examples.append({'label_fine': label_fine, 'label': label, 'text': text})
        ret.append(Dataset(examples))

    if len(ret) == 1:
        return ret[0]
    return tuple(ret)
8 changes: 4 additions & 4 deletions torchnlp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,13 @@ def get_filename_from_url(url):
return os.path.basename(parse.path)


def download_urls(urls, directory, check_file):
def download_urls(urls, directory, check_file=None):
""" Download a set of ``urls`` into a ``directory``.
Args:
urls (:class:`list` of :class:`str`): Set of urls to download.
directory (str): Directory in which to download urls.
check_file (str): Operation was successful if this file exists.
check_file (str, optional): Operation was successful if this file exists.
Returns:
None:
"""
Expand All @@ -178,13 +178,13 @@ def download_urls(urls, directory, check_file):
raise ValueError('[DOWNLOAD FAILED] `check_file` not found')


def download_compressed_directory(url, directory, check_file):
def download_compressed_directory(url, directory, check_file=None):
""" Download a ``tar.gz`` from ``url`` and extract into ``directory``.
Args:
url (str): Url of a compressed directory
directory (str): Directory to extract ``tar.gz`` to.
check_file (str): Operation was successful if this file exists.
check_file (str, optional): Operation was successful if this file exists.
Returns:
None:
"""
Expand Down

0 comments on commit b348344

Please sign in to comment.