from typing import Dict, List
import logging
import json
import glob
import os
import sqlite3
from overrides import overrides
from allennlp.common.file_utils import cached_path
from allennlp.common.checks import ConfigurationError
from allennlp.data import DatasetReader
from allennlp.data.fields import TextField, Field, ListField, IndexField
from allennlp.data.instance import Instance
from allennlp.data.tokenizers import Token
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp_semparse.common.sql import text2sql_utils as util
from allennlp_semparse.fields import ProductionRuleField
from allennlp_semparse.parsimonious_languages.worlds.text2sql_world import Text2SqlWorld

logger = logging.getLogger(__name__)


@DatasetReader.register("grammar_based_text2sql")
class GrammarBasedText2SqlDatasetReader(DatasetReader):
"""
Reads text2sql data from
`"Improving Text to SQL Evaluation Methodology" <https://arxiv.org/abs/1806.09029>`_
for a type constrained semantic parser.
Parameters
----------
schema_path : ``str``, required.
The path to the database schema.
database_path : ``str``, optional (default = None)
The path to a database.
use_all_sql : ``bool``, optional (default = False)
Whether to use all of the sql queries which have identical semantics,
or whether to just use the first one.
remove_unneeded_aliases : ``bool``, (default = True)
Whether or not to remove table aliases in the SQL which
are not required.
use_prelinked_entities : ``bool``, (default = True)
Whether or not to use the pre-linked entities in the text2sql data.
use_untyped_entities : ``bool``, (default = True)
Whether or not to attempt to infer the pre-linked entity types.
token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``)
We use this to define the input representation for the text. See :class:`TokenIndexer`.
Note that the `output` tags will always correspond to single token IDs based on how they
are pre-tokenised in the data file.
cross_validation_split_to_exclude : ``int``, optional (default = None)
Some of the text2sql datasets are very small, so you may need to do cross validation.
Here, you can specify a integer corresponding to a split_{int}.json file not to include
in the training set.
keep_if_unparsable : ``bool``, optional (default = True)
Whether or not to keep examples that we can't parse using the grammar.
"""

    def __init__(
        self,
        schema_path: str,
        database_file: str = None,
        use_all_sql: bool = False,
        remove_unneeded_aliases: bool = True,
        use_prelinked_entities: bool = True,
        use_untyped_entities: bool = True,
        token_indexers: Dict[str, TokenIndexer] = None,
        cross_validation_split_to_exclude: int = None,
        keep_if_unparsable: bool = True,
        lazy: bool = False,
    ) -> None:
        super().__init__(lazy)
        self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self._use_all_sql = use_all_sql
        self._remove_unneeded_aliases = remove_unneeded_aliases
        self._use_prelinked_entities = use_prelinked_entities
        self._keep_if_unparsable = keep_if_unparsable
        if not self._use_prelinked_entities:
            raise ConfigurationError(
                "The grammar based text2sql dataset reader "
                "currently requires the use of entity pre-linking."
            )
        # Stored as a string so that a ``None`` default becomes "None", which
        # will not match any split_{int}.json basename during filtering in ``_read``.
        self._cross_validation_split_to_exclude = str(cross_validation_split_to_exclude)

        if database_file is not None:
            database_file = cached_path(database_file)
            connection = sqlite3.connect(database_file)
            self._cursor = connection.cursor()
        else:
            self._cursor = None

        self._schema_path = schema_path
        self._world = Text2SqlWorld(
            schema_path,
            self._cursor,
            use_prelinked_entities=use_prelinked_entities,
            use_untyped_entities=use_untyped_entities,
        )

    @overrides
    def _read(self, file_path: str):
        """
        This dataset reader consumes the data from
        https://github.com/jkkummerfeld/text2sql-data/tree/master/data
        formatted using ``scripts/reformat_text2sql_data.py``.

        Parameters
        ----------
        file_path : ``str``, required.
            For this dataset reader, file_path can either be a path to a file `or` a
            path to a directory containing json files. This is because some of the
            text2sql datasets require cross validation, which means they are split
            up into many small files, from which you only want to exclude one.
        """
        files = [
            p
            for p in glob.glob(file_path)
            if self._cross_validation_split_to_exclude not in os.path.basename(p)
        ]
        schema = util.read_dataset_schema(self._schema_path)

        for path in files:
            with open(cached_path(path), "r") as data_file:
                data = json.load(data_file)

            for sql_data in util.process_sql_data(
                data,
                use_all_sql=self._use_all_sql,
                remove_unneeded_aliases=self._remove_unneeded_aliases,
                schema=schema,
            ):
                linked_entities = sql_data.sql_variables if self._use_prelinked_entities else None
                instance = self.text_to_instance(
                    sql_data.text_with_variables, linked_entities, sql_data.sql
                )
                if instance is not None:
                    yield instance

    @overrides
    def text_to_instance(
        self,  # type: ignore
        query: List[str],
        prelinked_entities: Dict[str, Dict[str, str]] = None,
        sql: List[str] = None,
    ) -> Instance:
        fields: Dict[str, Field] = {}
        tokens = TextField([Token(t) for t in query], self._token_indexers)
        fields["tokens"] = tokens

        # Note that the rest of this method assumes ``sql`` was provided,
        # because ``all_actions`` is only assigned inside this branch.
        if sql is not None:
            action_sequence, all_actions = self._world.get_action_sequence_and_all_actions(
                sql, prelinked_entities
            )
            if action_sequence is None and self._keep_if_unparsable:
                logger.warning("Parse error")
                action_sequence = []
            elif action_sequence is None:
                return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []

        for production_rule in all_actions:
            nonterminal, _ = production_rule.split(" ->")
            field = ProductionRuleField(
                production_rule, self._world.is_global_rule(nonterminal), nonterminal=nonterminal
            )
            production_rule_fields.append(field)

        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field
        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(valid_actions_field.field_list)
        }

        for production_rule in action_sequence:
            index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
        if not action_sequence:
            index_fields = [IndexField(-1, valid_actions_field)]

        action_sequence_field = ListField(index_fields)
        fields["action_sequence"] = action_sequence_field
        return Instance(fields)
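
# A sketch of the fields produced above (the entity dictionary keys are an
# assumption based on the reformatted text2sql data, not defined in this file):
#
#     instance = reader.text_to_instance(
#         query=["how", "many", "rivers", "run", "through", "state_name0"],
#         prelinked_entities={"state_name0": {"text": "texas", "type": "state_name"}},
#         sql=["SELECT", "COUNT", "(", "*", ")", "FROM", "RIVER"],
#     )
#     # instance.fields keys: "tokens" (TextField), "valid_actions" (ListField of
#     # ProductionRuleFields) and "action_sequence" (ListField of IndexFields).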