import json
import logging
from typing import Dict, List, Optional, Tuple
from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.instance import Instance
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer
from allennlp_models.coref.util import make_coref_instance

logger = logging.getLogger(__name__)


@DatasetReader.register("preco")
class PrecoReader(DatasetReader):
"""
Reads a single JSON-lines file for [the PreCo dataset](https://www.aclweb.org/anthology/D18-1016.pdf).
Each line contains a "sentences" key for a list of sentences and a "mention_clusters" key
for the clusters.
Returns a `Dataset` where the `Instances` have four fields : `text`, a `TextField`
containing the full document text, `spans`, a `ListField[SpanField]` of inclusive start and
end indices for span candidates, and `metadata`, a `MetadataField` that stores the instance's
original text. For data with gold cluster labels, we also include the original `clusters`
(a list of list of index pairs) and a `SequenceLabelField` of cluster ids for every span
candidate.
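
    For illustration only, a made-up miniature input line might look like:

        {"sentences": [["Alice", "saw", "Bob", "."], [" "], ["She", "waved", "."]],
         "mention_clusters": [[[0, 0, 1], [2, 0, 1]], [[0, 2, 3]]]}

    where each mention is given as `[sentence_index, start_token, exclusive_end_token]` and a
    sentence consisting of a single space marks a paragraph break.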

    # Parameters

    max_span_width : `int`, required.
        The maximum width of candidate spans to consider.
    token_indexers : `Dict[str, TokenIndexer]`, optional
        This is used to index the words in the document. See :class:`TokenIndexer`.
        Default is `{"tokens": SingleIdTokenIndexer()}`.
    wordpiece_modeling_tokenizer : `PretrainedTransformerTokenizer`, optional (default = `None`)
        If not `None`, this dataset reader does subword tokenization using the supplied tokenizer
        and distributes the labels to the resulting wordpieces. All the modeling will then be
        based on wordpieces. If this is left as `None` (the default), the user is expected to use
        `PretrainedTransformerMismatchedIndexer` and `PretrainedTransformerMismatchedEmbedder`,
        and the modeling will be on the word level.
    max_sentences : `int`, optional (default = `None`)
        The maximum number of sentences in each document to keep. By default keeps all sentences.
    remove_singleton_clusters : `bool`, optional (default = `True`)
        Some datasets contain clusters that are singletons (i.e. have no coreferents). This
        option allows removing them.
    """

    def __init__(
        self,
        max_span_width: int,
        token_indexers: Optional[Dict[str, TokenIndexer]] = None,
        wordpiece_modeling_tokenizer: Optional[PretrainedTransformerTokenizer] = None,
        max_sentences: Optional[int] = None,
        remove_singleton_clusters: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self._max_span_width = max_span_width
        self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self._wordpiece_modeling_tokenizer = wordpiece_modeling_tokenizer
        self._max_sentences = max_sentences
        self._remove_singleton_clusters = remove_singleton_clusters

    def _read(self, file_path: str):
        # If `file_path` is a URL, redirect to the cache.
        file_path = cached_path(file_path)

        with open(file_path, "r") as preco_file:
            for line in preco_file:
                example = json.loads(line)
                yield self.text_to_instance(example["sentences"], example["mention_clusters"])

    def text_to_instance(
        self,  # type: ignore
        sentences: List[List[str]],
        gold_clusters: Optional[List[List[Tuple[int, int, int]]]] = None,
    ) -> Instance:
        sentence_offsets = [0]
        for sentence in sentences[:-1]:
            sent_length = len(sentence)
            if sentence == [" "]:  # paragraph separator
                sent_length = 0  # we ignore them
            sentence_offsets.append(sentence_offsets[-1] + sent_length)
        sentences = [sentence for sentence in sentences if sentence != [" "]]

        # Convert (sent_idx, rel_start, rel_exclusive_end) to (abs_start, abs_inclusive_end)
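        # For example (hypothetical numbers): with sentence lengths [4, 3] the offsets are
        # [0, 4], so a mention (1, 0, 2) covering the first two tokens of the second
        # sentence becomes the absolute inclusive span (4, 5).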
        if gold_clusters:
            for cluster in gold_clusters:
                for mention_id, (sent_idx, start, end) in enumerate(cluster):
                    start = start + sentence_offsets[sent_idx]
                    end = end + sentence_offsets[sent_idx] - 1  # exclusive -> inclusive
                    cluster[mention_id] = (start, end)  # type: ignore

        return make_coref_instance(
            sentences,
            self._token_indexers,
            self._max_span_width,
            gold_clusters,  # type: ignore
            self._wordpiece_modeling_tokenizer,
            self._max_sentences,
            self._remove_singleton_clusters,
        )
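

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): it writes a tiny,
    # invented PreCo-style document to a temporary JSON-lines file and reads it
    # back with the reader. The document and file name are hypothetical and only
    # illustrate the input format described in the class docstring.
    import tempfile

    example = {
        "sentences": [["Alice", "saw", "Bob", "."], [" "], ["She", "waved", "."]],
        # Each mention is [sentence_index, start_token, exclusive_end_token].
        "mention_clusters": [[[0, 0, 1], [2, 0, 1]], [[0, 2, 3]]],
    }
    with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as handle:
        handle.write(json.dumps(example) + "\n")
        sample_path = handle.name

    reader = PrecoReader(max_span_width=5)
    for instance in reader.read(sample_path):
        # Each instance holds the fields described in the class docstring.
        print(list(instance.fields.keys()))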