-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
55 lines (49 loc) · 1.84 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import thermostat
import dataloader
def wrap_at_spaces(text, width=150):
    """Insert a line break after the first space past every ``width`` characters.

    Characters are counted from the start of ``text`` or from the last
    inserted break. Once the count exceeds ``width``, the next space is kept
    and followed by a newline, and the count restarts. Newlines already
    present in ``text`` do not reset the count, and a run that contains no
    space is never split.

    Args:
        text: The string to wrap.
        width: Character count after which the next space triggers a break.

    Returns:
        ``text`` with the additional newline characters inserted.
    """
    pieces = []
    run_length = 0
    for char in text:
        run_length += 1
        pieces.append(char)
        if run_length > width and char == " ":
            pieces.append("\n")
            run_length = 0
    # Join once at the end instead of growing a string character by character.
    return "".join(pieces)


def main():
    """Run the verbalizer on the configured source and print the results.

    Loads the source through ``thermostat`` when ``config["source"]`` is not
    an existing file, runs ``dataloader.Verbalizer`` on it, and then prints,
    for every key accepted by ``filter_verbalizations``, the wrapped sample
    text followed by the first five entries of each explanation subclass.
    """
    config = {
        "source": "imdb-bert-lig",
        # "source": "data/Thermostat_imdb-albert-LayerIntegratedGradients.jsonl",
        "sgn": "+",  # TODO: "-"
        "samples": 1,
        "metric": "mean: 0.4",  # TODO: {"name": "mean", "params": .2},
        "dev": True,
        # searches = {"span", "total"}; "all"
    }

    if not os.path.isfile(config["source"]):
        # Not a local file: load source from Thermostat configuration.
        thermo_config = thermostat.load(config["source"], cache_dir='data')
        # Convert to pandas DataFrame and then to a list of JSON-lines strings.
        source = thermo_config.to_pandas().to_json(orient='records', lines=True).splitlines()
    else:
        source = config["source"]
    # TODO: Change to NamedTemporaryFile, this is bad
    # TODO: put into module

    loader = dataloader.Verbalizer(source, config=config)
    explanations, texts, orders = loader()
    print(explanations["convolution search"])
    print(orders)

    valid_keys = loader.filter_verbalizations(
        explanations, texts, orders, maxlen=120, mincoverage=.3
    )

    for key in texts.keys():
        if key not in valid_keys:
            continue

        # "input_ids" is joined with spaces, so it is presumably a sequence of
        # token strings -- confirm against dataloader.Verbalizer.
        sample = "SAMPLE:\n" + " ".join(texts[key]["input_ids"])
        print(wrap_at_spaces(sample))  # makeshift \n-ing

        for expl_subclass in explanations.keys():
            print("subclass '{}'".format(expl_subclass))
            # Only the first five entries per subclass are shown.
            for verbalization in explanations[expl_subclass][key][:5]:
                print(verbalization)
    # TODO: pruned span search?


if __name__ == "__main__":
    main()