/
generate_chars.py
157 lines (130 loc) · 4.81 KB
/
generate_chars.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import time
import json
import os
import random
import re
import shutil
from functools import partial
from multiprocessing import Pool
from jinja2 import Template
import fire
import tqdm
from rouge_score import rouge_scorer
from src.util.io import read_jsonl, write_jsonl
from src.util.openai import openai_batch_completion, OpenAIDecodingArguments
# Runs of anything outside latin/cyrillic lowercase letters or digits;
# used to normalize descriptions before ROUGE comparison.
NON_ALPHANUM_RE = re.compile(r"[^a-zа-яё0-9]+")


def tokenize(text):
    """Lowercase *text*, collapse non-alphanumeric runs to spaces, and split."""
    normalized = NON_ALPHANUM_RE.sub(" ", text.lower())
    return normalized.split()
def encode_prompt(example_chars, template_path):
    """Render the few-shot generation prompt for *example_chars*.

    Similarity bookkeeping fields are stripped from each character dict
    (NOTE: in place — the caller's dicts are mutated, which also affects
    what gets persisted downstream) before JSON serialization.
    """
    with open(template_path) as f:
        template = Template(f.read())
    for char in example_chars:
        for key in ("most_similar_chars", "avg_similarity_score"):
            char.pop(key, None)
    serialized = json.dumps(example_chars, ensure_ascii=False)
    return template.render(example_chars=serialized).strip() + "\n"
def post_process(response):
    """Extract the list of generated characters from a model response.

    Returns an empty list for missing/truncated responses or content that
    is not valid JSON describing a character list.
    """
    if not response:
        return []
    if response["finish_reason"] == "length":
        # A truncated generation is almost certainly broken JSON; skip it.
        return []
    raw_content = response["message"]["content"]
    try:
        chars = json.loads(raw_content)
    except Exception:
        return []
    if isinstance(chars, list):
        return chars
    if isinstance(chars, dict):
        # Some models wrap the list as {"characters": [...]}.
        return chars.get("characters", [])
    # Fix: scalar JSON (string/number/bool/null) previously fell through
    # and implicitly returned None, which breaks callers that do
    # ``new_chars.extend(post_process(result))``.
    return []
def generate_chars(
    output_path: str,
    seed_chars_path: str,
    template_path: str,
    num_chars_to_generate: int = 200,
    model_name: str = "gpt-4",
    request_batch_size: int = 5,
    temperature: float = 1.0,
    top_p: float = 0.95,
    num_cpus: int = 8,
    rouge_cutoff: float = 0.24
):
    """Generate character descriptions with an OpenAI model, self-instruct style.

    Seed characters (plus any previously generated ones found at
    ``output_path``) are sampled into few-shot prompts; model outputs are
    parsed and kept only when their ``context`` is sufficiently novel
    (max ROUGE-L F-measure against every known description must not
    exceed ``rouge_cutoff``). Progress is checkpointed to ``output_path``
    after every batch via a tmp-file + rename.

    Args:
        output_path: JSONL file of accepted characters (resumable on rerun).
        seed_chars_path: JSONL file with human-written seed characters.
        template_path: Jinja2 prompt template path.
        num_chars_to_generate: stop once this many machine characters exist.
        model_name: OpenAI chat model name.
        request_batch_size: prompts sent per API batch.
        temperature: sampling temperature.
        top_p: nucleus sampling parameter.
        num_cpus: worker processes for parallel ROUGE scoring.
        rouge_cutoff: maximum allowed similarity to any existing description.
    """
    random.seed(43)
    # Fix: use a context manager so the seed file handle is closed.
    with open(seed_chars_path, "r") as f:
        seed_chars = [json.loads(line) for line in f]
    print(f"Loaded {len(seed_chars)} character examples")
    machine_chars = []
    if os.path.exists(output_path):
        machine_chars = read_jsonl(output_path)
        print(f"Loaded {len(machine_chars)} machine-generated characters")
    all_descriptions = [d["context"] for d in seed_chars + machine_chars]
    all_description_tokens = [tokenize(d) for d in all_descriptions]
    request_idx = 0
    progress_bar = tqdm.tqdm(total=num_chars_to_generate)
    if machine_chars:
        progress_bar.update(len(machine_chars))
    is_prompt_printed = False
    is_output_printed = False
    # Fix: create the worker pool once for the whole run. The original
    # built and tore down a num_cpus-process Pool inside the inner loop,
    # i.e. once per candidate character.
    with Pool(num_cpus) as pool:
        while len(machine_chars) < num_chars_to_generate:
            request_idx += 1
            batch = []
            for _ in range(request_batch_size):
                # Few-shot mix: one machine + one seed example when possible,
                # two seeds on the very first pass.
                if machine_chars:
                    prompt_chars = random.sample(machine_chars, 1)
                    prompt_chars += random.sample(seed_chars, 1)
                else:
                    prompt_chars = random.sample(seed_chars, 2)
                random.shuffle(prompt_chars)
                prompt = encode_prompt(prompt_chars, template_path)
                messages = [{"role": "user", "content": prompt}]
                batch.append(messages)
            if not is_prompt_printed:
                is_prompt_printed = True
                print("Prompt example:")
                for message in batch[0]:
                    print("Role: {}, content: {}".format(message["role"], message["content"]))
            request_start = time.time()
            results = openai_batch_completion(
                batch=batch,
                model_name=model_name,
                decoding_args=OpenAIDecodingArguments(
                    temperature=temperature,
                    top_p=top_p
                )
            )
            if not is_output_printed:
                is_output_printed = True
                print("Output example:")
                # NOTE(review): attribute access here vs. subscripting in
                # post_process — presumably the result type supports both;
                # confirm against openai_batch_completion's return type.
                print(results[0].message["content"])
            request_duration = time.time() - request_start
            process_start = time.time()
            new_chars = []
            for result in results:
                new_chars.extend(post_process(result))
            total = len(new_chars)
            keep = 0
            for new_char in new_chars:
                new_description_tokens = tokenize(new_char["context"])
                # Score the candidate against every known description in
                # parallel; reject near-duplicates.
                rouge_scores = pool.map(
                    partial(rouge_scorer._score_lcs, new_description_tokens),
                    all_description_tokens,
                )
                rouge_scores = [score.fmeasure for score in rouge_scores]
                if max(rouge_scores) > rouge_cutoff:
                    continue
                keep += 1
                machine_chars.append(new_char)
                all_descriptions.append(new_char["context"])
                all_description_tokens.append(new_description_tokens)
                progress_bar.update(1)
            process_duration = time.time() - process_start
            print(f"Request {request_idx} took {request_duration:.2f}s, processing took {process_duration:.2f}s")
            print(f"Generated {total} chars, kept {keep} chars")
            print("===================================")
            # Checkpoint: write to a tmp file, then rename over the target
            # so a crash mid-write cannot corrupt the output.
            write_jsonl(machine_chars, output_path + "_tmp")
            shutil.move(output_path + "_tmp", output_path)
if __name__ == "__main__":
    # Expose generate_chars as a command-line interface via python-fire:
    # each function parameter becomes a CLI flag.
    fire.Fire(generate_chars)