import logging
from typing import Dict, List

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader, PathOrStr
from allennlp.data.instance import Instance
from allennlp.data.fields import Field, TextField, ListField, IndexField
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

logger = logging.getLogger(__name__)


@DatasetReader.register("babi")
class BabiReader(DatasetReader):
    """
    Reads a single task in the bAbI tasks format as formulated in
    Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks
    (https://arxiv.org/abs/1502.05698). Since this class handles a single file,
    loading multiple tasks together requires merging them into one file and
    using this reader on the result.

    Registered as a `DatasetReader` with name "babi".

    # Parameters

    keep_sentences : `bool`, optional, (default = `False`)
        Whether to keep each sentence in the context as a separate field or to
        concatenate them. The default of `False` corresponds to concatenation.
    token_indexers : `Dict[str, TokenIndexer]`, optional (default=`{"tokens": SingleIdTokenIndexer()}`)
        We use this to define the input representation for the text. See :class:`TokenIndexer`.
    """

    def __init__(
        self,
        keep_sentences: bool = False,
        token_indexers: Dict[str, TokenIndexer] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self._keep_sentences = keep_sentences
        self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}

    def _read(self, file_path: PathOrStr):
        # If `file_path` is a URL, redirect to the cache.
        file_path = cached_path(file_path)

        logger.info("Reading file at %s", file_path)
        with open(file_path) as dataset_file:
            dataset = dataset_file.readlines()

        logger.info("Reading the dataset")
        context: List[List[str]] = [[]]
        for line in dataset:
            if "?" in line:
                # Question line: "<id> <question>?<TAB><answer><TAB><supporting ids>".
                # Supporting-fact ids are 1-based in the data, so shift them to 0-based.
                question_str, answer, supports_str = line.replace("?", " ?").split("\t")
                question = question_str.split()[1:]
                supports = [int(support) - 1 for support in supports_str.split()]

                yield self.text_to_instance(context, question, answer, supports)
            else:
                # Context line: "<id> <sentence>."
                new_entry = line.replace(".", " .").split()[1:]

                # A sentence id of 1 marks the start of a new story, so reset the context.
                if line.split()[0] == "1":
                    context = [new_entry]
                else:
                    context.append(new_entry)

    def text_to_instance(  # type: ignore
        self,
        context: List[List[str]],
        question: List[str],
        answer: str,
        supports: List[int],
    ) -> Instance:
        fields: Dict[str, Field] = {}

        if self._keep_sentences:
            # Keep one TextField per sentence so the supporting facts can point
            # at individual sentences via IndexFields.
            context_field_ks = ListField(
                [TextField([Token(word) for word in line]) for line in context]
            )
            fields["supports"] = ListField(
                [IndexField(support, context_field_ks) for support in supports]
            )
        else:
            # Otherwise concatenate all sentences into a single TextField.
            context_field = TextField([Token(word) for line in context for word in line])

        fields["context"] = context_field_ks if self._keep_sentences else context_field
        fields["question"] = TextField([Token(word) for word in question])
        fields["answer"] = TextField([Token(answer)])

        return Instance(fields)

    def apply_token_indexers(self, instance: Instance) -> None:
        if self._keep_sentences:
            for text_field in instance.fields["context"]:  # type: ignore
                text_field._token_indexers = self._token_indexers  # type: ignore
        else:
            instance.fields["context"]._token_indexers = self._token_indexers  # type: ignore
        instance.fields["question"]._token_indexers = self._token_indexers  # type: ignore
        instance.fields["answer"]._token_indexers = self._token_indexers  # type: ignore