-
Notifications
You must be signed in to change notification settings - Fork 0
/
example.py
40 lines (31 loc) · 1.14 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from vilnius.evaluation import check_answer_binary, evaluate_fact_accuracy
from vilnius.fact import generate_facts
from vilnius.gpt3 import gpt3_query
from vilnius.graph import assign_names_to_nodes, generate_dag, plot_graph
from vilnius.prompt import (
generate_binary_question_prompt,
generate_templated_prompt_header,
)
from vilnius.question import few_shot_balanced_types, generate_all_pair_questions
# Demo: build a small random DAG, render facts about it into a prompt, send
# one binary question to GPT-3, and report whether the answer (and the facts
# the model states) are correct.

# Index of the generated question to ask, and number of few-shot examples.
# The same index selects the target question AND excludes it from the
# few-shot examples, so the two can no longer drift apart (otherwise the
# target question could appear, with its answer, among its own examples).
# NOTE(review): assumes `exclude` refers to the same positional index that
# `.iloc` uses (true for a default RangeIndex) — confirm in vilnius.question.
QUESTION_IDX = 0
N_FEW_SHOT = 5

# Graph generation: n=5, p=0.4 (presumably node count and edge probability —
# see vilnius.graph.generate_dag). Nodes are named with use_real_words=False.
G = generate_dag(n=5, p=0.4)
G = assign_names_to_nodes(G, use_real_words=False)
f = plot_graph(G)
f.savefig("example.png")

# Turn the graph into facts and a prompt header, then enumerate questions.
facts = generate_facts(G, fact_type="v1")
prompt_header = generate_templated_prompt_header(G, facts, prompt_type="v6")
questions = generate_all_pair_questions(G, facts)
question = questions.iloc[QUESTION_IDX]

few_shot_examples = few_shot_balanced_types(
    N_FEW_SHOT, questions, exclude=[QUESTION_IDX]
)
question_prompt = generate_binary_question_prompt(question, few_shot_examples)
prompt = prompt_header + question_prompt
print(prompt)

# Query the model and evaluate its reply against the selected question.
model_answer = gpt3_query(prompt)
print(model_answer)
print(
    "\n\nIs the answer correct?",
    check_answer_binary(question["answer"], model_answer),
)
print(
    "\n\nAre the facts correct?",
    evaluate_fact_accuracy(question, model_answer),
)