-
Notifications
You must be signed in to change notification settings - Fork 344
/
scrambler.py
98 lines (79 loc) · 3.01 KB
/
scrambler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple scrambling test generator."""
import random
from typing import Optional
from lit_nlp.api import components as lit_components
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import model as lit_model
from lit_nlp.api import types
from lit_nlp.lib import utils
# Config key under which the user's selection of fields to scramble is passed.
FIELDS_TO_SCRAMBLE_KEY: str = 'Fields to scramble'
# Local alias for LIT's JSON-like dict type, used in signatures in this module.
JsonDict = types.JsonDict
def _scramble(val: str) -> str:
  """Return `val` with its space-delimited tokens put in random order.

  The string is split on the single character ' ' (not on arbitrary
  whitespace), so runs of spaces produce empty tokens that are shuffled along
  with the words. Randomness comes from the module-level `random` generator.

  Args:
    val: the text to scramble.

  Returns:
    A string containing the same ' '-separated tokens as `val`, reordered.
  """
  tokens = val.split(' ')
  random.shuffle(tokens)  # In-place shuffle of the freshly built token list.
  separator = ' '
  return separator.join(tokens)
class Scrambler(lit_components.Generator):
  """Scramble all words in an example to generate a new example."""

  def config_spec(self) -> types.Spec:
    """Let the user choose which TextSegment input fields to scramble.

    All matching fields are pre-selected (`select_all=True`).
    """
    return {
        FIELDS_TO_SCRAMBLE_KEY:
            types.MultiFieldMatcher(
                spec='input',
                types=['TextSegment'],
                select_all=True),
    }

  def is_compatible(self, model: lit_model.Model,
                    dataset: lit_dataset.Dataset) -> bool:
    """Return whether the dataset spec contains any TextSegment field."""
    del model  # Unused by Scrambler
    return utils.spec_contains(dataset.spec(), types.TextSegment)

  def generate(self,
               example: JsonDict,
               model: lit_model.Model,
               dataset: lit_dataset.Dataset,
               config: Optional[JsonDict] = None) -> list[JsonDict]:
    """Naively scramble all words in an example.

    Note: Even if more than one field is to be scrambled, only a single example
    will be produced, unlike other generators which will produce multiple
    examples, one per field.

    Selected fields whose value in `example` is missing or not a string are
    skipped instead of raising.

    Args:
      example: the example used for basis of generated examples.
      model: the model (unused).
      dataset: the dataset; its spec determines which fields are TextSegments.
      config: user-provided config properties; the FIELDS_TO_SCRAMBLE_KEY entry
        lists the fields to scramble.

    Returns:
      examples: a list with at most one generated example. It is empty if no
        field was selected, or if none of the selected TextSegment fields holds
        a string on this example.
    """
    del model  # Unused.
    config = config or {}

    # If config key is missing, generate no examples.
    fields_to_scramble = set(config.get(FIELDS_TO_SCRAMBLE_KEY, []))
    if not fields_to_scramble:
      return []

    # TODO(lit-dev): move this to generate_all(), so we read the spec once
    # instead of on every example.
    text_keys = [
        key for key in utils.find_spec_keys(dataset.spec(), types.TextSegment)
        if key in fields_to_scramble
    ]

    # Only scramble fields that actually hold text on this example. Indexing an
    # absent key would raise KeyError, and _scramble(None) would raise
    # AttributeError, so both cases are skipped.
    updates = {
        key: _scramble(example[key])
        for key in text_keys
        if isinstance(example.get(key), str)
    }
    if not updates:
      return []

    new_example = utils.make_modified_input(example, updates, 'Scrambler')
    return [new_example]