In [None]:
# Algorithm sketch (pseudo-code only; this string is never executed).
# Idea: beam search with beam size B. Each iteration takes the top b prompts,
# draws s - 1 de-duplicated samples of their error cases, and produces new
# prompts via reflection + refinement, so that b * s = B gives the new beam.
# The best prompts of the previous iteration are carried over unchanged,
# because reflection/refinement is unstable and may lower performance.
"""
# 思路：束搜索，束大小为 B ，每次迭代选取前 b 个最优指令，对错例进行 s - 1 次去重采样，经过反思和细化，生成新指令，满足 b * s = B ，得到新束
# 过程中保留上一次迭代的最优指令，原因是每一次反思与细化的结果不稳定，可能会导致性能下降，因此需要保留上一次迭代的最优指令
def APE_beam(eval_file_path, iter_limit, B, b, s):
    # prompt_insts = [[{"inst": initial prompt_inst, "accuracy": 0.0, "responses_path": ""} * B]]
    prompt_insts = [get_initial_prompt_insts(B)]
    best_prompt_inst = {"inst": "", "accuracy": 0.0, "responses_path": ""}
    for i in range(iter_limit):
        for inst_dict in prompt_insts[i]:
            if inst_dict["responses_path"] == "":
                responses = get_responses(inst_dict["inst"], dataset)
                store responses to file and set inst_dict["responses_path"]
                inst_dict["accuracy"] = calculate_accuracy(responses)
                update best_prompt_inst if needed
        if best_prompt_inst["accuracy"] == 1.0:
            break
        Record indexes of top b prompt_insts
        prompt_insts.append([])
        for index in top_b_indexes:
            prompt_insts[i+1].append(prompt_insts[i][index])
            all_error_examples = [select error examples from dataset]
            error_example_sets = [sample s - 1 times from all_error_examples, and remove duplicates]
            for error_example_set in error_example_sets:
                reflection_prompt = get_reflection_prompt(prompt_insts[i][index], error_example_set)
                reasons = reflect(reflection_prompt)
                refinement_prompt = get_refinement_prompt(prompt_insts[i][index], error_example_set, reasons)
                prompt_insts[i+1].append({"inst": refine(refinement_prompt), "accuracy": 0.0})
    return best_prompt_inst
"""
""" Globals """
# NOTE(review): these names are bound *before* `from APE_utils import *`, so a
# same-named global exported by APE_utils would overwrite them — confirm it
# does not. A star import cannot make APE_utils read these notebook globals;
# they take effect only where they are passed explicitly as arguments.
""" for APE_utils """
TARGET_MODEL = "gpt-3.5-turbo"
OPTIMIZER_MODEL = "gpt-3.5-turbo"
""" for data_process """
TASK_CODE_NAME = 'beam-trial-8'
MOVIE_RESPONSES_DIR = "../data/response/movie" + '/' + TASK_CODE_NAME
GSM8K_RESPONSES_DIR = "../data/response/gsm8k" + '/' + TASK_CODE_NAME

DEBUG = True

from APE_utils import *
# Imported explicitly (after the star import) instead of relying on
# APE_utils happening to re-export `os`.
import os

# Active dataset ("movie" or "gsm8k") and the evaluation file used by the search.
# CURRENT_DATASET = "movie"
CURRENT_DATASET = "gsm8k"
# eval_file_path = '../data/movie/eval.json'
eval_file_path = '../data/gsm8k/eval-1.json'

# Create the response directory of the *active* dataset so switching
# CURRENT_DATASET needs no other edits; exist_ok makes the cell re-runnable.
os.makedirs(
    MOVIE_RESPONSES_DIR if CURRENT_DATASET == "movie" else GSM8K_RESPONSES_DIR,
    exist_ok=True,
)

In [None]:
# Beam-search hyper-parameters: run `iter_limit` rounds; the first beam holds B
# prompts; each round keeps the top-b prompts, and each parent yields up to
# s - 1 refined children in addition to itself.
# NOTE(review): with these values later beams hold up to b * s = 6 prompts, not B.
iter_limit, B, b, s = 4, 2, 2, 3

# Former function signature, kept for reference:
# def APE_beam(eval_file_path: str, iter_limit: int, B = 8, b = 2, s = 3):

# Sizes passed to the reflection / sampling helpers.
FEEDBACK_REASONS_NUM, EVAL_SAMPLE_NUM, ERROR_SAMPLE_NUM = 2, 50, 3

# Initial search state.
eval_set = gen_samples_from_dataset(eval_file_path, EVAL_SAMPLE_NUM, keep_orginal_order=False)
prompt_insts = [get_initial_prompt_insts(CURRENT_DATASET, B)]   # TODO: extend the seed prompts later
best_prompt_inst = dict(inst="", accuracy=0.0, responses_path="", error_num=0)
# One list per iteration of (reflection_prompt, reasons, refinement_prompt, improved prompt) tuples.
reflection_refinement_record = []

In [None]:
def _select_top_indexes(beam, k):
    """Return the indexes of the ``k`` highest-accuracy prompts in ``beam``.

    ``beam`` is a list of prompt dicts with an ``"accuracy"`` key. Ties keep
    the earlier index first (``sorted`` stays stable with ``reverse=True``).
    At most ``len(beam)`` indexes are returned and prompts scoring exactly 0.0
    remain eligible, so no sentinel ``-1`` index (which would silently pick
    the last prompt, possibly twice) can be produced.
    """
    ranked = sorted(range(len(beam)), key=lambda idx: beam[idx]["accuracy"], reverse=True)
    return ranked[:k]


current_responses_dir = MOVIE_RESPONSES_DIR if CURRENT_DATASET == "movie" else GSM8K_RESPONSES_DIR

# NOTE: this cell extends `prompt_insts` and `reflection_refinement_record` in
# place; re-run the initialisation cell before running it a second time.
for i in range(iter_limit):
    if DEBUG:
        print(f"\n\n>>> Current iteration: {i}")

    # 1) Score every prompt of the current beam that has no responses yet.
    #    Parents carried over from the previous beam already have a responses_path.
    for inst_dict_id, inst_dict in enumerate(prompt_insts[i]):
        if inst_dict["responses_path"] == "":
            responses = get_target_model_responses(
                CURRENT_DATASET, TARGET_MODEL, MOVIE_RESPONSES_DIR, GSM8K_RESPONSES_DIR,
                eval_set, inst_dict["inst"], if_print=False
            )
            inst_dict["responses_path"] = write_target_model_responses(
                TARGET_MODEL + f"_{i}-{inst_dict_id}",
                current_responses_dir,
                responses
            )
            if CURRENT_DATASET == 'movie':
                inst_dict["accuracy"] = evaluation_movie(CURRENT_DATASET, inst_dict["responses_path"])
            else:
                inst_dict["accuracy"] = evaluation_gsm8k(CURRENT_DATASET, inst_dict["responses_path"])
            if inst_dict["accuracy"] > best_prompt_inst["accuracy"]:
                # Keep a snapshot rather than an alias into the beam.
                best_prompt_inst = inst_dict.copy()
        if DEBUG:
            print(f">>> Current prompt instance, iteration {i}, id {inst_dict_id}")
            print(f">>> Current accuracy rate: {inst_dict['accuracy']}")

    if best_prompt_inst["accuracy"] == 1.0:
        print(f"Early stop at iteration {i}")
        break
    if i == iter_limit - 1:
        # A beam generated now would never be evaluated (the loop ends here),
        # so skip the costly reflection/refinement optimizer calls.
        break

    # 2) Pick the top-b prompts of this beam as parents.
    top_b_indexes = _select_top_indexes(prompt_insts[i], b)
    if DEBUG:
        print(f"\n>>> Top b indexes in this iteration: {top_b_indexes}")
        print("Now begin to generate new prompt instances...")

    # 3) Build the next beam: each parent (kept as-is, since refinement is
    #    noisy and may score lower) plus one refined child per error sample set.
    iteration_records = []
    next_beam = []
    reflection_refinement_record.append(iteration_records)
    prompt_insts.append(next_beam)
    for index in top_b_indexes:
        parent = prompt_insts[i][index]
        next_beam.append(parent.copy())
        if CURRENT_DATASET == 'movie':
            error_example_sets = get_error_example_sets_movie(parent["responses_path"], ERROR_SAMPLE_NUM, s - 1)
        else:
            error_example_sets = get_error_example_sets_gsm8k(parent["responses_path"], ERROR_SAMPLE_NUM, s - 1)
        for error_example_set in error_example_sets:
            reflection_prompt = gen_reflection(FEEDBACK_REASONS_NUM, parent["inst"], error_example_set)
            reasons = get_reflection_from_optimizer(OPTIMIZER_MODEL, reflection_prompt)
            refinement_prompt = gen_refinement(parent["inst"], error_example_set, reasons)
            improved_inst = get_refinement_from_optimizer(OPTIMIZER_MODEL, refinement_prompt)[0]
            next_beam.append({"inst": improved_inst, "accuracy": 0.0, "responses_path": ""})
            iteration_records.append((reflection_prompt, reasons, refinement_prompt, improved_inst))
print("="*50)

In [None]:
# Print every beam (prompt text, accuracy, responses file), then the best prompt.
print("Prompt instances:")
for iter_index, iter_list in enumerate(prompt_insts):
    print(f"In iteration {iter_index}: ")
    for dict_index, inst_dict in enumerate(iter_list):
        print(f">>> {dict_index}")
        print(f">>> Prompt: {inst_dict['inst']}")
        print(f">>> Accuracy: {inst_dict['accuracy']}")
        print(f">>> Responses path: {inst_dict['responses_path']}")
    print("="*50)
print(f"Best prompt instance: {best_prompt_inst}")

# Dump the reflection/refinement trail of every iteration to record.txt in the
# active dataset's response directory.
record_save_path = MOVIE_RESPONSES_DIR if CURRENT_DATASET == "movie" else GSM8K_RESPONSES_DIR
record_save_path += "/record.txt"
# Prompts and optimizer output may contain non-ASCII text; pin UTF-8 so the
# write does not depend on the platform's default encoding.
with open(record_save_path, "w", encoding="utf-8") as f:
    for i, iteration_records in enumerate(reflection_refinement_record):
        f.write(f"\n\n>>> Iteration {i}")
        f.write("\n")
        for refl_prompt, refl_reasons, ref_prompt, new_inst in iteration_records:
            f.write(f">>> Reflection prompt: \n{refl_prompt}")
            f.write("\n")
            f.write(f">>> Reasons: \n{refl_reasons}")
            f.write("\n")
            f.write(f">>> Refinement prompt: \n{ref_prompt}")
            f.write("\n")
            f.write(f">>> Improved prompt: \n{new_inst}")
            f.write("\n")

In [None]:
# Collect every prompt's accuracy for each iteration's beam.
accuracy_list = [[inst_dict["accuracy"] for inst_dict in beam] for beam in prompt_insts]
# Display with 3 decimal places: the beam's accuracies, then its mean and best.
print("Accuracy list:")
for beam_accuracies in accuracy_list:
    rounded_accuracies = [round(acc, 3) for acc in beam_accuracies]
    print(rounded_accuracies)
    beam_mean = sum(beam_accuracies) / len(beam_accuracies)
    beam_best = max(beam_accuracies)
    print("average: ", round(beam_mean, 3), "; max: ", round(beam_best, 3))


In [None]:
# One-off data preparation, kept for provenance of ../data/gsm8k/eval-1.json.
# Translation of the note below:
#   - load ../data/gsm8k/test.jsonl;
#   - load ../data/gsm8k/eval.json;
#   - sample 100 items from test.jsonl that do not appear in eval.json,
#     where two items are duplicates if their `question` fields are equal;
#   - write the sample to ../data/gsm8k/eval-1.json.
# The code is commented out, so running this cell only evaluates the string.
# NOTE(review): if re-enabled, the inner `with ... as f` blocks shadow the
# outer `f`, and `test_set.pop()` raises IndexError if fewer than 100
# non-duplicate questions exist.
"""
从 ../data/gsm8k/test.jsonl 中载入数据；
从 ../data/gsm8k/eval.json 中载入数据；
对 test.jsonl 中的数据进行采样，得到 100 个样本，确保与 eval.json 中的数据不重复；
重复的判断标准是 question 字段相同
将采样结果写入 ../data/gsm8k/eval-1.json 中
"""
# eval_file_path = '../data/gsm8k/eval.json'
# test_file_path = '../data/gsm8k/test.jsonl'
# with open(test_file_path, "r") as f:
#     test_set = [json.loads(line) for line in f.readlines()]
#     with open(eval_file_path, "r") as f:
#         eval_set = json.load(f)
#         random.shuffle(test_set)
#         new_eval_set = []
#         while len(new_eval_set) < 100:
#             new_item = test_set.pop()
#             question = new_item["question"]
#             flag = False
#             for eval_item in eval_set:
#                 if eval_item["question"] == question:
#                     flag = True
#                     break
#             if not flag:
#                 new_eval_set.append(new_item)
#         with open('../data/gsm8k/eval-1.json', "w") as f:
#             json.dump(new_eval_set, f, indent=4)
#             print("Done!")

In [None]:
""" test """
# Held-out evaluation: run the best prompt found by the search on a separate
# sample (eval-1.json was drawn to avoid overlapping eval.json).
# NOTE(review): this path is GSM8K-specific; point it at the matching movie
# file before switching CURRENT_DATASET to "movie".
test_file_path = '../data/gsm8k/eval.json'
TEST_SAMPLE_NUM = 50
test_set = gen_samples_from_dataset(test_file_path, TEST_SAMPLE_NUM, keep_orginal_order=True)
test_responses = get_target_model_responses(
    CURRENT_DATASET, TARGET_MODEL, MOVIE_RESPONSES_DIR, GSM8K_RESPONSES_DIR,
    test_set, best_prompt_inst["inst"], if_print=False
)
test_responses_path = write_target_model_responses(
    TARGET_MODEL + "_test",
    MOVIE_RESPONSES_DIR if CURRENT_DATASET == "movie" else GSM8K_RESPONSES_DIR,
    test_responses
)
# Score with the same dataset-specific evaluator the search loop uses
# (previously this was always the GSM8K scorer).
if CURRENT_DATASET == "movie":
    accuracy_rate = evaluation_movie(CURRENT_DATASET, test_responses_path)
else:
    accuracy_rate = evaluation_gsm8k(CURRENT_DATASET, test_responses_path)

In [None]:
# Held-out accuracy of the best prompt found by the beam search.
print(f"{accuracy_rate}")