Seq2seq dataset reader improvements (#2599)
* Use Python's core csv module.

* Add a delimiter parameter to the seq2seq dataset reader.

* Add a test for ConfigurationError in Seq2SeqDatasetReader.

* Add a test that ensures quoting with and without '"' works the same.

* Make some changes to satisfy pylint.
mfa authored and matt-gardner committed Mar 13, 2019
1 parent 1adb3e8 commit 18312a0
Showing 3 changed files with 60 additions and 11 deletions.
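Why the move from str.split('\t') to Python's core csv module matters: a quoted field may legitimately contain the delimiter, and csv.reader honors the quoting while a naive split does not. A minimal illustrative sketch (plain stdlib, not part of the commit):

import csv
import io

quoted = '"a\tb"\t"c d"\n'  # the first field contains a literal tab

# Naive tab-splitting breaks the quoted field apart.
print(quoted.strip("\n").split("\t"))                         # ['"a', 'b"', '"c d"']

# csv.reader keeps the field intact and strips the quotes.
print(next(csv.reader(io.StringIO(quoted), delimiter="\t")))  # ['a\tb', 'c d']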
19 changes: 9 additions & 10 deletions allennlp/data/dataset_readers/seq2seq.py
@@ -1,3 +1,4 @@
+import csv
 from typing import Dict
 import logging
 
@@ -45,35 +46,33 @@ class Seq2SeqDatasetReader(DatasetReader):
         ``source_token_indexers``.
     source_add_start_token : bool, (optional, default=True)
         Whether or not to add `START_SYMBOL` to the beginning of the source sequence.
+    delimiter : str, (optional, default="\t")
+        Set delimiter for tsv/csv file.
     """
     def __init__(self,
                  source_tokenizer: Tokenizer = None,
                  target_tokenizer: Tokenizer = None,
                  source_token_indexers: Dict[str, TokenIndexer] = None,
                  target_token_indexers: Dict[str, TokenIndexer] = None,
                  source_add_start_token: bool = True,
+                 delimiter: str = "\t",
                  lazy: bool = False) -> None:
         super().__init__(lazy)
         self._source_tokenizer = source_tokenizer or WordTokenizer()
         self._target_tokenizer = target_tokenizer or self._source_tokenizer
         self._source_token_indexers = source_token_indexers or {"tokens": SingleIdTokenIndexer()}
         self._target_token_indexers = target_token_indexers or self._source_token_indexers
         self._source_add_start_token = source_add_start_token
+        self._delimiter = delimiter
 
     @overrides
     def _read(self, file_path):
         with open(cached_path(file_path), "r") as data_file:
             logger.info("Reading instances from lines in file at: %s", file_path)
-            for line_num, line in enumerate(data_file):
-                line = line.strip("\n")
-
-                if not line:
-                    continue
-
-                line_parts = line.split('\t')
-                if len(line_parts) != 2:
-                    raise ConfigurationError("Invalid line format: %s (line number %d)" % (line, line_num + 1))
-                source_sequence, target_sequence = line_parts
+            for line_num, row in enumerate(csv.reader(data_file, delimiter=self._delimiter)):
+                if len(row) != 2:
+                    raise ConfigurationError("Invalid line format: %s (line number %d)" % (row, line_num + 1))
+                source_sequence, target_sequence = row
                 yield self.text_to_instance(source_sequence, target_sequence)
 
     @overrides
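A hedged usage sketch of the new delimiter parameter, mirroring test_delimiter_parameter below; it assumes the working directory is the repository root so the fixture path resolves:

from allennlp.common.util import ensure_list
from allennlp.data.dataset_readers import Seq2SeqDatasetReader

# Read a comma-delimited parallel corpus instead of the default TSV.
reader = Seq2SeqDatasetReader(delimiter=",")
instances = ensure_list(reader.read("allennlp/tests/fixtures/data/seq2seq_copy.csv"))
assert len(instances) == 3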
49 changes: 48 additions & 1 deletion allennlp/tests/data/dataset_readers/seq2seq_test.py
@@ -1,9 +1,11 @@
 # pylint: disable=no-self-use,invalid-name
+import tempfile
 import pytest
 
-from allennlp.data.dataset_readers import Seq2SeqDatasetReader
+from allennlp.common.checks import ConfigurationError
 from allennlp.common.util import ensure_list
 from allennlp.common.testing import AllenNlpTestCase
+from allennlp.data.dataset_readers import Seq2SeqDatasetReader
 
 class TestSeq2SeqDatasetReader:
     @pytest.mark.parametrize("lazy", (True, False))
@@ -39,3 +41,48 @@ def test_source_add_start_token(self):
         assert [t.text for t in fields["source_tokens"].tokens] == ["this", "is", "a", "sentence", "@end@"]
         assert [t.text for t in fields["target_tokens"].tokens] == ["@start@", "this", "is",
                                                                     "a", "sentence", "@end@"]
+
+    def test_delimiter_parameter(self):
+        reader = Seq2SeqDatasetReader(delimiter=",")
+        instances = reader.read(str(AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'seq2seq_copy.csv'))
+        instances = ensure_list(instances)
+
+        assert len(instances) == 3
+        fields = instances[0].fields
+        assert [t.text for t in fields["source_tokens"].tokens] == ["@start@", "this", "is",
+                                                                    "a", "sentence", "@end@"]
+        assert [t.text for t in fields["target_tokens"].tokens] == ["@start@", "this", "is",
+                                                                    "a", "sentence", "@end@"]
+        fields = instances[2].fields
+        assert [t.text for t in fields["source_tokens"].tokens] == ["@start@", "all", "these", "sentences",
+                                                                    "should", "get", "copied", "@end@"]
+        assert [t.text for t in fields["target_tokens"].tokens] == ["@start@", "all", "these", "sentences",
+                                                                    "should", "get", "copied", "@end@"]
+
+    @pytest.mark.parametrize("line", (
+            ("a\n"),
+            ("a\tb\tc\n"),
+    ))
+    def test_invalid_line_format(self, line):
+        with tempfile.NamedTemporaryFile("w") as fp_tmp:
+            fp_tmp.write(line)
+            fp_tmp.flush()
+            reader = Seq2SeqDatasetReader()
+            with pytest.raises(ConfigurationError):
+                reader.read(fp_tmp.name)
+
+    @pytest.mark.parametrize("line", (
+            ("a b\tc d\n"),
+            ('"a b"\t"c d"\n'),
+    ))
+    def test_correct_quote_handling(self, line):
+        with tempfile.NamedTemporaryFile("w") as fp_tmp:
+            fp_tmp.write(line)
+            fp_tmp.flush()
+            reader = Seq2SeqDatasetReader()
+            instances = reader.read(fp_tmp.name)
+            instances = ensure_list(instances)
+            assert len(instances) == 1
+            fields = instances[0].fields
+            assert [t.text for t in fields["source_tokens"].tokens] == ["@start@", "a", "b", "@end@"]
+            assert [t.text for t in fields["target_tokens"].tokens] == ["@start@", "c", "d", "@end@"]
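The same error path as test_invalid_line_format, sketched outside pytest; it relies on the POSIX behavior that a NamedTemporaryFile can be reopened by name while still open, and on flush() putting the bytes on disk before the reader opens the file:

import tempfile

from allennlp.common.checks import ConfigurationError
from allennlp.data.dataset_readers import Seq2SeqDatasetReader

with tempfile.NamedTemporaryFile("w") as fp_tmp:
    fp_tmp.write("a\tb\tc\n")  # three fields where the reader expects two
    fp_tmp.flush()
    try:
        Seq2SeqDatasetReader().read(fp_tmp.name)
    except ConfigurationError as err:
        print(err)  # Invalid line format: ['a', 'b', 'c'] (line number 1)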
3 changes: 3 additions & 0 deletions allennlp/tests/fixtures/data/seq2seq_copy.csv
@@ -0,0 +1,3 @@
+"this is a sentence","this is a sentence"
+"this is another","this is another"
+"all these sentences should get copied","all these sentences should get copied"

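For reference, a minimal sketch of what the fixture above yields at the csv level (again assuming the repository root as working directory): csv.reader strips the surrounding quotes, leaving three two-field rows.

import csv

with open("allennlp/tests/fixtures/data/seq2seq_copy.csv") as f:
    rows = list(csv.reader(f, delimiter=","))

assert len(rows) == 3
assert rows[0] == ["this is a sentence", "this is a sentence"]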