-
Notifications
You must be signed in to change notification settings - Fork 0
/
inspection.py
101 lines (83 loc) · 2.68 KB
/
inspection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import json
from dataclasses import dataclass
from pathlib import Path
from string import Template
import dotenv
import tyro
from openai import OpenAI
from tqdm import tqdm
import utils
# Load variables from a .env file into the environment at import time, before
# the OpenAI client is created further down.
# NOTE(review): presumably this is how the OpenAI API key is supplied — confirm.
dotenv.load_dotenv()
# Command-line arguments of this script; an instance is built by tyro.cli.
#   input_filepath  - opened and parsed with json.load; its items are iterated
#                     in lockstep with the class names.
#   dataset_name    - passed to utils.load_classnames to obtain the class names.
#   output_filepath - destination of the JSON dump of all API responses; missing
#                     parent directories are created.
# The bare strings below each field are attribute docstrings and are kept
# verbatim (presumably tyro shows them as per-option --help text — confirm).
@dataclass
class ScriptArguments:
    input_filepath: str
    """Path to the file with descriptions"""
    dataset_name: str
    """Name of the dataset"""
    output_filepath: str
    """Where to save the API responses"""
# Parse the command line into a ScriptArguments instance. This runs at module
# level, so it executes as soon as the file is run or imported.
args: ScriptArguments = tyro.cli(ScriptArguments)
def get_message_list(class_name, input_text):
    """Build the few-shot chat prompt that asks whether a snippet is informative.

    The conversation has four turns: a system instruction, one worked example
    (a user question about a tench followed by the assistant's JSON answer),
    and finally the real question rendered from the arguments.

    Args:
        class_name: Name of the class the snippet is supposed to describe.
        input_text: The text snippet to be judged.

    Returns:
        A list of four ``{"role": ..., "content": ...}`` dicts in the order
        system, user (example), assistant (example), user (actual query).
    """
    question = Template(
        "You want to explain what a ${cn} is to your students. Does the following text snippet mention any specific details about ${cn} that increases your students' knowledge about ${cn}? Answer yes or no. Provide an explanation for your answer.\n\nText snippet: ${it}"
    )

    def user_turn(subject, snippet):
        # Render the shared question template as a single user message.
        return {
            "role": "user",
            "content": question.substitute(cn=subject, it=snippet),
        }

    # indent=0 makes json.dumps put each key on its own line with no leading
    # spaces; the example answer is shown to the model in exactly that layout.
    example_answer = json.dumps(
        {
            "explanation": "It teaches the students about the color of a tench.",
            "increases_knowledge": "Yes",
        },
        indent=0,
    )

    return [
        {
            "role": "system",
            "content": "You are a knowledgeable teacher. Answer the questions in JSON format.",
        },
        user_turn("tench", "A photo of a tench, with dark green color."),
        {"role": "assistant", "content": example_answer},
        user_turn(class_name, input_text),
    ]
# --- Load inputs -------------------------------------------------------------
# Materialise the class names so they can be counted for the progress bar.
class_names = list(utils.load_classnames(args.dataset_name))
# JSON is UTF-8 by specification; do not depend on the platform default encoding.
with open(args.input_filepath, "r", encoding="utf-8") as f:
    # NOTE(review): assumed to hold one list of description strings per class,
    # in the same order as class_names — confirm against the producer of this file.
    descriptions = json.load(f)

# zip() below stops at the shorter input, so make a mismatch visible instead of
# silently skipping the trailing classes or description groups.
num_pairs = min(len(class_names), len(descriptions))
if len(class_names) != len(descriptions):
    print(
        f"Warning: {len(class_names)} class names but {len(descriptions)} "
        f"description groups; only the first {num_pairs} pairs will be processed."
    )

# --- Query the model ---------------------------------------------------------
# Keyword arguments shared by every chat-completion request. response_format
# requests a JSON object, matching the JSON instruction in the system message.
gen_kwargs = dict(
    model="gpt-3.5-turbo-1106",
    temperature=0.1,
    top_p=0.1,
    frequency_penalty=0,
    presence_penalty=0,
    response_format={"type": "json_object"},
    max_tokens=128,
)
client = OpenAI()

# response_list[i][j] is the raw JSON dump of the API response for the j-th
# description of the i-th class.
response_list = list()
# A zip object has no len(), so pass the total explicitly; otherwise tqdm can
# show neither a percentage nor an ETA.
for cls_name, cls_descs in tqdm(zip(class_names, descriptions), total=num_pairs):
    response_list.append([])
    for desc in cls_descs:
        response = client.chat.completions.create(
            messages=get_message_list(class_name=cls_name, input_text=desc),
            **gen_kwargs,
        )
        response_json_str = response.model_dump_json()
        response_list[-1].append(response_json_str)

# --- Save results ------------------------------------------------------------
output_filepath = Path(args.output_filepath)
# Create missing parent directories so the final write cannot fail on them.
output_filepath.parent.mkdir(exist_ok=True, parents=True)
with open(output_filepath, "w", encoding="utf-8") as f:
    json.dump(response_list, f)