creative_writing_utils.py
import re
from lib.run_query import run_query
import openai
import concurrent.futures
import time

N_THREADS = 1  # Parallelises the judge prompts (only relevant if COMBINE_CRITERIA == False)

# Separate client for the openai judge, because the test model's openai client might be
# using a different openai-compatible url.
openai_client_judge = None

SKIP_ANALYSIS = False  # Skip the "detailed analysis" part of the judge prompt (default False)
COMBINE_CRITERIA = True  # Combine all the criteria sets into one big judge prompt (default True)
INCLUDE_REFERENCE = True  # Include the exemplar reference output (default True)
RELATIVE_SCORING = False  # Use relative scoring (relative to the reference) (default False). ! This doesn't work very well
if RELATIVE_SCORING:
    INCLUDE_REFERENCE = True
TEST_MODEL_SEES_CRITERIA = False  # Is the test model shown the scoring criteria? (default False) ! This seems to produce worse results

# These criteria are excluded for now as they were weakly discriminative.
CRITERIA_TO_IGNORE = [
    'Appropriate Length',
    "Unearned Resolution: Characters' disagreements or tensions are too quickly or easily resolved, without exploring the depth or implications of the conflict.",
    "Melodramatic",
    "Clever / Witty",
    "Gripping",
    "Effective Use of Tropes: If applicable, common narrative tropes are employed thoughtfully and subverted, deconstructed, or used in service of the story's themes and character",
    "Correct Spelling & Grammar"
]
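
# Illustrative only (not part of the original module): the judge_params dict consumed
# below is assumed, based on how it is used in this file, to look roughly like the
# following; the exact keys and values come from the caller's config.
#
#   judge_params = {
#       'judge_model': 'gpt-4-0125-preview',   # hypothetical model name
#       'judge_model_api': 'openai',
#       'judge_model_api_key': '<api key>',
#   }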


def process_criteria(criteria_set, writing_prompt, reference_output, test_model_response, verbose, judge_params, judge_temp):
    # Build the judging prompt for this criteria set and query the judge model,
    # retrying up to 3 times on empty output or API errors.
    judging_prompt = create_judging_prompt(criteria_set, writing_prompt, reference_output, test_model_response)
    #print(judging_prompt)

    # Run judging process using judge model
    judge_model_response = None
    success = False
    tries = 0
    #judge_temp = 0
    while not success and tries < 3:
        try:
            judge_model_response = run_query(judge_params['judge_model'], None, judging_prompt, [], 3000, judge_params['judge_model'], None, judge_temp, judge_params['judge_model_api'], None, False, None, openai_client_judge, api_key=judge_params['judge_model_api_key'])
            if judge_model_response:
                success = True
            else:
                print('! Empty output from judge model')
                tries += 1
        except Exception as e:
            print(e)
            time.sleep(30)
            tries += 1
            #judge_temp += 0.2

    if verbose:
        print(judge_model_response)
    return judge_model_response


def process_writing_prompt(prompt_id, prompt_data, model_path, prompt_type, model, tokenizer, results, run_index,
                           run_iter, verbose, n_prompt_attempts, inference_engine, ooba_instance,
                           launch_ooba, ooba_request_timeout, openai_client, judge_params, test_model_output=None, judgemark_test_model=None):
    global openai_client_judge, SKIP_ANALYSIS, COMBINE_CRITERIA, N_THREADS, CRITERIA_TO_IGNORE

    if test_model_output:
        TEST_TYPE = 'judgemark'
    else:
        TEST_TYPE = 'creative-writing'

    combined_criteria = []
    for criteria_set in prompt_data['judging_criteria']:
        combined_criteria += criteria_set['criteria']
    combined_criteria = list(reversed(combined_criteria))
    filtered_criteria = [x for x in combined_criteria if x not in CRITERIA_TO_IGNORE]

    if judge_params['judge_model_api'] == 'openai' and not openai_client_judge:
        openai_client_judge = openai.OpenAI(
            #base_url='https://api.together.xyz',
            api_key=judge_params['judge_model_api_key'],
        )

    writing_prompt = "You are a talented creative writer of compelling, original prose.\n\n"
    if TEST_MODEL_SEES_CRITERIA:
        writing_prompt += 'You are taking a creative writing test. These will be the assessment criteria to help direct your writing:\n' + '\n'.join(filtered_criteria) + '\n\n'
    writing_prompt += prompt_data['writing_prompt']

    judging_criteria = prompt_data['judging_criteria']
    reference_output = prompt_data['reference_output']

    # Add the seed to modify the writing prompt for this iteration
    seed_index = (int(run_iter) - 1) % len(prompt_data['seed_modifiers'])
    writing_prompt = writing_prompt.replace("<SEED>", prompt_data['seed_modifiers'][seed_index])

    if verbose:
        print(writing_prompt)

    scores = {}
    judge_model_responses = []
    judging_tries = 0
    judging_success = False
    judge_temp = 0.01
    while not judging_success and judging_tries <= 3:
        judging_tries += 1
        try:
            if TEST_TYPE == 'creative-writing' and prompt_id in results[run_index]['iterations'][run_iter]['test_model_response'] and len(results[run_index]['iterations'][run_iter]['test_model_response'][prompt_id]) > 400:
                # This is for when the test has been conducted manually somewhere and the
                # test_model_response fields have been populated but still require judging.
                test_model_output = results[run_index]['iterations'][run_iter]['test_model_response'][prompt_id]

            # Generate response from test model
            test_generation_success = False
            if test_model_output is not None:  # these are pregenerated if we are running judgemark
                test_model_response = test_model_output
                test_generation_success = True  # automatic success!
            else:
                temp = 0.7
                tries = 0
                while not test_generation_success and tries <= 3:
                    tries += 1
                    this_writing_prompt = writing_prompt
                    if tries > 1:
                        this_writing_prompt += '\n\nDo not use language or themes that would get flagged by your content filter (Keep it PG-13).'
                    # Generate response from test model for creative writing benchmark
                    test_model_response = run_query(model_path, prompt_type, this_writing_prompt, [], 3000, model, tokenizer, temp, inference_engine, ooba_instance, launch_ooba, ooba_request_timeout, openai_client)
                    if not test_model_response or len(test_model_response) < 300:
                        temp += 0.1
                        if temp > 1:
                            temp = 1
                        print(test_model_response)
                        print('! Missing or too short output from test model')
                        if tries <= 3:
                            print('retrying...')
                        continue
                    test_generation_success = True

            if not test_model_response or len(test_model_response) < 300:
                print(test_model_response)
                print('! Failed to get output from test model')
                return None

            if verbose and TEST_TYPE != 'judgemark':
                print(test_model_response)
            scores = {}
            judge_model_responses = []
            if COMBINE_CRITERIA:
                judge_model_response = process_criteria({
                    'criteria': combined_criteria,
                    'prefix_text': 'Now, rate the supplied model output on the following criteria:'
                }, writing_prompt, reference_output, test_model_response, verbose, judge_params, judge_temp)
                # gemini likes to add *'s as markdown formatting. we can safely strip these out.
                judge_model_response = judge_model_response.replace('*', '')
                # other models (like wizardlm 8x22) like to add square brackets
                judge_model_response = judge_model_response.replace('[', '').replace(']', '')
                if not parse_scores(judge_model_response):
                    print(judge_model_response)
                    print('! Failed to parse scores in judge response')
                    judge_temp += 0.2
                    continue
                scores.update(parse_scores(judge_model_response))
                judge_model_responses.append(judge_model_response)
                judging_success = True
            else:
                # One judge prompt per criteria set, submitted in parallel.
                with concurrent.futures.ThreadPoolExecutor(max_workers=N_THREADS) as executor:
                    future_to_criteria = {
                        executor.submit(process_criteria, criteria_set, writing_prompt, reference_output,
                                        test_model_response, verbose, judge_params, judge_temp): criteria_set
                        for criteria_set in judging_criteria
                    }
                    for future in concurrent.futures.as_completed(future_to_criteria):
                        judge_model_response = future.result()
                        # gemini likes to add *'s as markdown formatting. we can safely strip these out.
                        judge_model_response = judge_model_response.replace('*', '')
                        # other models (like wizardlm 8x22) like to add square brackets
                        judge_model_response = judge_model_response.replace('[', '').replace(']', '')
                        scores.update(parse_scores(judge_model_response))
                        judge_model_responses.append(judge_model_response)
                judging_success = True
        except Exception as e:
            print(e)

    if verbose:
        print_score(scores)

    if not judging_success:
        return {}

    # Store scores and responses in results dict
    if TEST_TYPE == 'creative-writing':
        results[run_index]['iterations'][run_iter]['individual_scores'][prompt_id] = scores
        results[run_index]['iterations'][run_iter]['test_model_response'][prompt_id] = test_model_response
        results[run_index]['iterations'][run_iter]['judge_model_response'][prompt_id] = judge_model_responses
    elif TEST_TYPE == 'judgemark':
        results[run_index]['iterations'][run_iter]['judgemark_results'][judgemark_test_model]['individual_scores'][prompt_id] = scores
        results[run_index]['iterations'][run_iter]['judgemark_results'][judgemark_test_model]['test_model_response'][prompt_id] = test_model_response
        results[run_index]['iterations'][run_iter]['judgemark_results'][judgemark_test_model]['judge_model_response'][prompt_id] = judge_model_responses

    return scores


def parse_scores(judge_model_response):
    scores = {}

    # Parse scores using regex
    score_pattern = r'(.*?):\s*(?:Score\s+)?(-?\d+(?:\.\d+)?)'
    matches = re.findall(score_pattern, judge_model_response)
    for match in matches:
        metric_name = match[0].strip()
        score = float(match[1])
        scores[metric_name] = score

    return scores
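
# Illustrative only (not part of the original module): parse_scores matches any
# "Metric name: <number>" pair in the judge's response, with an optional "Score"
# keyword before the number, e.g.
#
#   parse_scores("Imagery: 8\nStilted Dialogue: Score 2.5")
#   # -> {'Imagery': 8.0, 'Stilted Dialogue': 2.5}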


def print_score(scores, RELATIVE_SCORING=False):
    if not scores:
        print('! No scores were parseable')
        return

    scoresum = 0
    # For these criteria a *lower* score indicates better writing, so they are
    # inverted before being added to the total.
    neg_criteria = [
        "melodramatic",
        "shallow resolution",
        "unearned resolution",  # old naming
        "simplistic moralizing",
        "shallow optimism",
        "forced optimism",  # old naming
        "trite",
        "overwrought",
        "amateurish",
        "contrived",
        "uninspiring",
        "characters are too good",
        "incongruent ending positivity",
        "unearned transformations",
        "profundity over-reach",
        "amateurish descriptives",
        "clunky asides and interruptive sentence structures",
        "stilted dialogue",
        "tit-for-tat dialogue"
    ]

    for criteria, score in scores.items():
        criteria_lower = criteria.lower().strip()
        if RELATIVE_SCORING:
            # Relative scores are on a -10..10 scale; map them onto 0..10.
            if any(neg_criterion in criteria_lower for neg_criterion in neg_criteria):
                scoresum += ((-1 * score) + 10) / 2
            else:
                scoresum += (score + 10) / 2
        else:
            if any(neg_criterion in criteria_lower for neg_criterion in neg_criteria):
                scoresum += 10 - score
            else:
                scoresum += score

    print('This question score:', round(10 * scoresum / len(scores)))
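
# Illustrative only (not part of the original module): with default (absolute) scoring,
# parsed scores of {'Imagery': 8, 'Trite': 3} aggregate as 8 + (10 - 3) = 15, and the
# printed question score is round(10 * 15 / 2) = 75, i.e. roughly a 0-100 scale.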


def create_judging_prompt(criteria_set, writing_prompt, reference_output, test_model_response):
    criteria = [x for x in criteria_set['criteria'] if x not in CRITERIA_TO_IGNORE]
    prefix_text = criteria_set['prefix_text']
    criteria_str = '\n'.join(criteria)

    analysis_section_1 = """
- You are to write a comprehensive analysis for each of the metrics, then give your scores.
"""
    analysis_section_2 = """
[Analysis]
Write your detailed analysis.
"""
    if SKIP_ANALYSIS:
        analysis_section_1 = ""
        analysis_section_2 = ""

    if RELATIVE_SCORING:
        relative_section_1 = """You are an expert in assessing creative writing. Your task is to score the quality of the test model's response above in comparison to the reference, by several metrics, on a -10 to 10 scale.
Scoring notes:
- You are not scoring the quality of the prompt or the reference response, only the test model response.
- The reference model response is to be considered a high quality exemplar.
- Scores are relative to the quality of the reference output. A score of zero means equal to reference. Below 0 means worse than the reference. Above 0 means better than the reference.
- The minimum score is -10 and the maximum is 10.
- If the criteria is not relevant to the piece (e.g. criteria relating to dialogue but the piece has none), exclude it from your assessment."""
        relative_section_2 = "Score [-10 to 10]"
    else:
        ref_str = ""
        if INCLUDE_REFERENCE:
            ref_str = """
- You are not scoring the quality of the prompt or the reference response, only the test model response.
- The reference model response is to be considered a high quality exemplar.
"""
        relative_section_1 = f"""You are an expert in assessing creative writing. Your task is to score the quality of the test model's response above, by several metrics, on a 0-10 scale.
Scoring notes:
{ref_str}
- Scores of 0 or 10 should not be considered highly unlikely just because they are the max/min. Use the full scoring range as appropriate.
- If the criteria is not relevant to the piece (e.g. criteria relating to dialogue but the piece has none), exclude it from your assessment."""
        relative_section_2 = "[Score 0-10]"

    reference_section_1 = ""
    if INCLUDE_REFERENCE:
        reference_section_1 = f"""
[REFERENCE RESPONSE (DO NOT JUDGE)]
{reference_output}
[REFERENCE RESPONSE END]
"""

    # Construct judging prompt
    judging_prompt = f"""
You are an expert in assessing creative writing. Your task is to score the quality of the test model's response below, by several metrics, on a 0-10 scale.
[PROMPT START]
{writing_prompt}
[PROMPT END]
{reference_section_1}
[TEST MODEL RESPONSE]
{test_model_response}
[TEST MODEL RESPONSE END]
[Task]
{relative_section_1}
- Do not add any commentary or explanation to the scores section.
- Do not penalize NSFW output; that is the default for some models. Just assess it on its merits.
- Everything within the "TEST MODEL RESPONSE" section was written by the test model. Sometimes models like to write comments on the piece after the piece is concluded; if this happens you should ignore their comments.
- In the output, write the metric names exactly as below so they can be parsed.
- Some models produce overly long outputs. You should neither penalise nor favour this if it happens; simply assess the writing on its merit. You should however penalise overly short pieces.
- The test model's output can suddenly truncate because of token length constraints. If you notice that this has occurred, don't penalise it.
- Do not use markdown in your response. Use the designated output format exactly.
- Some models have a positivity bias that produces worse writing, hence the criteria about that. Don't let the over-abundance of these criteria influence your assessment; it will only apply to some model outputs and you will know it when you see it. Likewise, there are a lot of "negative" critical criteria; these will not always apply and don't let their over-abundance colour your perception of the writing.
- For these criteria, lower is better:
Trite
Overwrought
Amateurish
Contrived
Uninspiring
Simplistic Moralizing
Shallow Optimism
Unearned Transformations
Incongruent Ending Positivity
Characters are Too Good
Shallow Resolution
Repetitive Tit-for-Tat Dialogue
Stilted Dialogue
Clunky Asides and Interruptive Sentence Structures
Amateurish Descriptives
Profundity Over-reach
- You are a critic, so be honest, objective, critical and discriminative. No need to be charitable; say what you genuinely think.
{analysis_section_1}
- Output format is:
{analysis_section_2}
[Scores]
Metric 1 name: {relative_section_2}
Metric 2 name: ...
---
{prefix_text}
{criteria_str}
"""

    return judging_prompt
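
# Illustrative only (not part of the original module): create_judging_prompt can be
# exercised standalone; the criteria_set dict needs 'criteria' and 'prefix_text' keys,
# e.g.
#
#   p = create_judging_prompt(
#       {'criteria': ['Imagery', 'Trite'], 'prefix_text': 'Rate the piece on:'},
#       'Write a short story about <SEED>.',
#       '<reference text>',
#       '<test model text>')
#   print(p)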