# qangaroo.py
import json
import logging
from typing import Dict, List

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.instance import Instance
from allennlp.data.fields import Field, TextField, ListField, MetadataField, IndexField
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Tokenizer, SpacyTokenizer

logger = logging.getLogger(__name__)


@DatasetReader.register("qangaroo")
class QangarooReader(DatasetReader):
"""
Reads a JSON-formatted Qangaroo file and returns a ``Dataset`` where the ``Instances`` have six
fields: ``candidates``, a ``ListField[TextField]``, ``query``, a ``TextField``, ``supports``, a
``ListField[TextField]``, ``answer``, a ``TextField``, and ``answer_index``, a ``IndexField``.
We also add a ``MetadataField`` that stores the instance's ID and annotations if they are present.
# Parameters
tokenizer : `Tokenizer`, optional (default=`SpacyTokenizer()`)
We use this `Tokenizer` for both the question and the passage. See :class:`Tokenizer`.
Default is ```SpacyTokenizer()``.
token_indexers : `Dict[str, TokenIndexer]`, optional
We similarly use this for both the question and the passage. See :class:`TokenIndexer`.
Default is `{"tokens": SingleIdTokenIndexer()}`.
"""

    def __init__(
        self,
        tokenizer: Tokenizer = None,
        token_indexers: Dict[str, TokenIndexer] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self._tokenizer = tokenizer or SpacyTokenizer()
        self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
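
    # Note (illustrative only, not part of the original module): because this class is
    # registered under the name "qangaroo", an AllenNLP experiment config can refer to it
    # by that name, e.g. "dataset_reader": {"type": "qangaroo"}, with a "train_data_path"
    # pointing at a QAngaroo-style JSON file (any path shown here is a hypothetical placeholder).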

    def _read(self, file_path: str):
        # if `file_path` is a URL, redirect to the cache
        file_path = cached_path(file_path)

        logger.info("Reading file at %s", file_path)
        with open(file_path) as dataset_file:
            dataset = json.load(dataset_file)

        logger.info("Reading the dataset")
        for sample in dataset:
            instance = self.text_to_instance(
                sample["candidates"],
                sample["query"],
                sample["supports"],
                sample["id"],
                sample["answer"],
                # annotations are not present in every sample; fall back to a single empty list
                sample["annotations"] if "annotations" in sample else [[]],
            )
            yield instance
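
    # For reference, each element of the JSON list consumed by `_read` above looks roughly
    # like the toy sample below (all values are invented purely for illustration):
    #
    #   {
    #       "id": "WH_dev_0",
    #       "query": "country_of_citizenship juan",
    #       "candidates": ["colombia", "spain"],
    #       "answer": "colombia",
    #       "supports": ["Juan was born in Bogota.", "Bogota is the capital of Colombia."],
    #       "annotations": [["follows", "single"]]
    #   }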

    def text_to_instance(
        self,  # type: ignore
        candidates: List[str],
        query: str,
        supports: List[str],
        _id: str = None,
        answer: str = None,
        annotations: List[List[str]] = None,
    ) -> Instance:

        fields: Dict[str, Field] = {}

        # tokenize each answer candidate and wrap it in its own TextField
        candidates_field = ListField(
            [
                TextField(candidate, self._token_indexers)
                for candidate in self._tokenizer.batch_tokenize(candidates)
            ]
        )

        fields["query"] = TextField(self._tokenizer.tokenize(query), self._token_indexers)

        # each supporting document becomes a TextField inside a ListField
        fields["supports"] = ListField(
            [
                TextField(support, self._token_indexers)
                for support in self._tokenizer.batch_tokenize(supports)
            ]
        )

        fields["answer"] = TextField(self._tokenizer.tokenize(answer), self._token_indexers)
        # the answer is expected to be one of the candidates, so store its position as an IndexField
        fields["answer_index"] = IndexField(candidates.index(answer), candidates_field)

        fields["candidates"] = candidates_field

        fields["metadata"] = MetadataField({"annotations": annotations, "id": _id})

        return Instance(fields)
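

# A minimal usage sketch, not part of the original module. It assumes a local
# QAngaroo-style JSON file at the hypothetical path "wikihop/dev.json" and prints
# the first Instance produced by the reader (candidates, query, supports, answer,
# answer_index, and metadata fields).
if __name__ == "__main__":
    reader = QangarooReader()
    for instance in reader.read("wikihop/dev.json"):
        print(instance)
        break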