In [None]:
from utils import *
""" Globals """
# Run-wide configuration: which models are used, where responses are stored,
# and which dataset the notebook is working on.
TARGET_MODEL = "gpt-3.5-turbo"       # model that answers the dataset questions with each candidate prompt
OPTIMIZER_MODEL = "gpt-3.5-turbo"    # model queried for reflections and refined prompts
TASK_CODE_NAME = 'summary-beam-2'    # sub-directory name that separates this run's outputs
MOVIE_RESPONSES_DIR = "../data/response/movie" + '/' + TASK_CODE_NAME
GSM8K_RESPONSES_DIR = "../data/response/gsm8k" + '/' + TASK_CODE_NAME
DEBUG = True                         # enables the progress prints of the optimization loop
CURRENT_DATASET = "gsm8k"            # "movie" selects the movie helpers/directory, anything else GSM8K
eval_file_path = '../data/gsm8k/eval.json'
# Create the output directory of the dataset that is actually in use (the rest of
# the notebook picks the directory with this same expression). `exist_ok=True`
# makes the cell safe to re-run and avoids the check-then-create race.
os.makedirs(
    MOVIE_RESPONSES_DIR if CURRENT_DATASET == "movie" else GSM8K_RESPONSES_DIR,
    exist_ok=True,
)

In [None]:
# --- search hyper-parameters -------------------------------------------------
iter_limit = 6              # upper bound on the number of optimization iterations
B, b, s = 2, 2, 3           # B: passed to get_initial_prompt_insts; b: candidates kept per
                            # iteration; s: `s - 1` error-example sets are requested per kept candidate
# --- sample sizes ------------------------------------------------------------
EVAL_SAMPLE_NUM = 50        # number of samples drawn from the evaluation file
ERROR_SAMPLE_NUM = 3        # forwarded to get_error_example_sets_* for each kept candidate
FEEDBACK_REASONS_NUM = 2    # forwarded to gen_reflection

# --- initial state -----------------------------------------------------------
# Seed first: the sampling below runs after it and must stay in this order.
random.seed(2)
eval_set = gen_samples_from_dataset(eval_file_path, EVAL_SAMPLE_NUM, keep_orginal_order=False)
prompt_insts = [get_initial_prompt_insts(CURRENT_DATASET, B)]   # one list of candidates per iteration
best_prompt_inst = dict(inst="", accuracy=0.0, responses_path="")
# One entry per iteration; each entry is a list of
# (reflection_prompt, reasons, refinement_prompt, improved prompt) tuples.
reflection_refinement_record = []

In [None]:
def _select_top_indexes(candidates, k):
    """Return the indexes of the (at most) ``k`` most accurate candidates.

    Args:
        candidates: list of prompt-instance dicts, each holding an "accuracy" value.
        k: maximum number of indexes to return.

    Returns:
        Indexes into ``candidates`` in descending order of accuracy. Ties keep
        their original relative order because the sort is stable. Fewer than
        ``k`` indexes are returned when there are fewer than ``k`` candidates.
    """
    ranked = sorted(
        range(len(candidates)),
        key=lambda idx: candidates[idx]["accuracy"],
        reverse=True,
    )
    return ranked[:k]


# Beam-search style prompt optimization: evaluate the candidates of the current
# iteration, keep the best `b`, and expand each of them into refined prompts.
for i in range(iter_limit):
    if DEBUG:
        print(f"\n\n>>> Current iteration: {i}")
    # evaluate: only candidates without stored responses are run; candidates
    # carried over from the previous iteration keep their earlier result
    for inst_dict_id, inst_dict in enumerate(prompt_insts[i]):
        if inst_dict["responses_path"] == "":
            responses = get_target_model_responses(
                CURRENT_DATASET, TARGET_MODEL, MOVIE_RESPONSES_DIR, GSM8K_RESPONSES_DIR,
                eval_set, inst_dict["inst"], if_print=False
            )
            inst_dict["responses_path"] = write_target_model_responses(
                TARGET_MODEL + f"_{i}-{inst_dict_id}",
                MOVIE_RESPONSES_DIR if CURRENT_DATASET == "movie" else GSM8K_RESPONSES_DIR,
                responses
            )
            inst_dict["accuracy"] = evaluation_movie(CURRENT_DATASET, inst_dict["responses_path"]) \
                                    if CURRENT_DATASET == 'movie' \
                                    else evaluation_gsm8k(CURRENT_DATASET, inst_dict["responses_path"])
            if inst_dict["accuracy"] > best_prompt_inst["accuracy"]:
                best_prompt_inst = inst_dict
        if DEBUG:
            print(f">>> Current prompt instance, iteration {i}, id {inst_dict_id}")
            print(f">>> Current accuracy rate: {inst_dict['accuracy']}")
    if best_prompt_inst["accuracy"] == 1.0:
        print(f"Early stop at iteration {i}")
        break
    # sort: a stable descending sort replaces the manual arg-max scan, which
    # skipped candidates with accuracy 0.0 and appended -1 (i.e. the last
    # candidate) whenever fewer than `b` candidates scored above zero
    top_b_indexes = _select_top_indexes(prompt_insts[i], b)
    if DEBUG:
        print(f"\n>>> Top b indexes in this iteration: {top_b_indexes}")
        print("Now begin to generate new prompt instances...")
    reflection_refinement_record.append([])
    prompt_insts.append([])
    # sample and update
    # NOTE(review): candidates generated during the last iteration are never
    # evaluated, because the loop ends before they are visited.
    for index in top_b_indexes:
        parent_inst = prompt_insts[i][index]
        # carry the parent over unchanged (a copy, so its stored result is reused)
        prompt_insts[i+1].append(parent_inst.copy())
        error_example_sets = get_error_example_sets_movie(parent_inst["responses_path"], ERROR_SAMPLE_NUM, s - 1) \
                            if CURRENT_DATASET == 'movie' \
                            else get_error_example_sets_gsm8k(parent_inst["responses_path"], ERROR_SAMPLE_NUM, s - 1)
        for error_example_set in error_example_sets:
            reflection_prompt = gen_reflection(FEEDBACK_REASONS_NUM, parent_inst["inst"], error_example_set)
            reasons = get_reflection_from_optimizer(OPTIMIZER_MODEL, reflection_prompt)
            refinement_prompt = gen_refinement(parent_inst["inst"], error_example_set, reasons)
            # only the first element returned by the optimizer is used
            improved_inst = get_refinement_from_optimizer(OPTIMIZER_MODEL, refinement_prompt)[0]
            # empty responses_path marks the new candidate as not yet evaluated
            prompt_insts[i+1].append({"inst": improved_inst, "accuracy": 0.0, "responses_path": ""})
            reflection_refinement_record[i].append((reflection_prompt, reasons, refinement_prompt, improved_inst))
print("="*50)

In [None]:
# Print every candidate of every iteration, then the best instance found.
print("Prompt instances:")
for iter_index, iter_list in enumerate(prompt_insts):
    print(f"In iteration {iter_index}: ")
    for dict_index, inst_dict in enumerate(iter_list):
        print(f">>> {dict_index}")
        print(f">>> Prompt: {inst_dict['inst']}")
        print(f">>> Accuracy: {inst_dict['accuracy']}")
        print(f">>> Responses path: {inst_dict['responses_path']}")
    print("="*50)
print(f"Best prompt instance: {best_prompt_inst}")

# Dump the reflection/refinement history next to the responses of the active dataset.
record_save_path = MOVIE_RESPONSES_DIR if CURRENT_DATASET == "movie" else GSM8K_RESPONSES_DIR
record_save_path += "/record.txt"
# Explicit UTF-8: the prompts and model outputs may contain non-ASCII text,
# which would fail to encode under a non-UTF-8 platform default.
with open(record_save_path, "w", encoding="utf-8") as f:
    for iteration_index, iteration_records in enumerate(reflection_refinement_record):
        f.write(f"\n\n>>> Iteration {iteration_index}")
        f.write("\n")
        # each record is (reflection prompt, reasons, refinement prompt, improved prompt)
        for reflection_text, reasons_text, refinement_text, improved_text in iteration_records:
            f.write(f">>> Reflection prompt: \n{reflection_text}")
            f.write("\n")
            f.write(f">>> Reasons: \n{reasons_text}")
            f.write("\n")
            f.write(f">>> Refinement prompt: \n{refinement_text}")
            f.write("\n")
            f.write(f">>> Improved prompt: \n{improved_text}")
            f.write("\n")

In [None]:
# Collect the accuracies per iteration and print them with their mean and max.
# NOTE(review): candidates that were never evaluated still carry their 0.0
# placeholder accuracy and are included in these statistics.
accuracy_list = [
    [candidate["accuracy"] for candidate in candidates]
    for candidates in prompt_insts
]
print("Accuracy list:")
for iteration_accuracies in accuracy_list:
    rounded_accuracies = [round(value, 3) for value in iteration_accuracies]
    print(rounded_accuracies)
    mean_accuracy = sum(iteration_accuracies) / len(iteration_accuracies)
    best_accuracy = max(iteration_accuracies)
    print("average: ", round(mean_accuracy, 3), "; max: ", round(best_accuracy, 3))

In [None]:
""" test """
# Held-out check: answer the test samples with the best prompt found by the
# search, store the responses, and report the resulting accuracy.
# NOTE(review): the test file and the evaluator are GSM8K-specific, while the
# output directory follows CURRENT_DATASET - confirm before switching datasets.
test_file_path = '../data/gsm8k/test.json'
TEST_SAMPLE_NUM = 100
test_set = gen_samples_from_dataset(test_file_path, TEST_SAMPLE_NUM, keep_orginal_order=True)

best_inst_text = best_prompt_inst["inst"]
test_responses = get_target_model_responses(
    CURRENT_DATASET, TARGET_MODEL, MOVIE_RESPONSES_DIR, GSM8K_RESPONSES_DIR,
    test_set, best_inst_text, if_print=False
)

# Same directory choice as the rest of the notebook, written as a statement.
if CURRENT_DATASET == "movie":
    test_output_dir = MOVIE_RESPONSES_DIR
else:
    test_output_dir = GSM8K_RESPONSES_DIR
test_responses_path = write_target_model_responses(TARGET_MODEL + "_test", test_output_dir, test_responses)

accuracy_rate = evaluation_gsm8k(CURRENT_DATASET, test_responses_path)
print(f"Accuracy rate on test set: {accuracy_rate}")

In [None]:
# """ initial test for comparison """
# test_file_path = '../data/gsm8k/test.json'
# TEST_SAMPLE_NUM = 100
# test_set = gen_samples_from_dataset(test_file_path, TEST_SAMPLE_NUM, keep_orginal_order=True)
# # index of the best initial prompt for each random seed: 0 for seed-0; 0 for seed-1; 1 for seed-2
# initial_prompt_inst = prompt_insts[0][1]      # initial max
# responses = get_target_model_responses(
#     CURRENT_DATASET, TARGET_MODEL, MOVIE_RESPONSES_DIR, GSM8K_RESPONSES_DIR, 
#     test_set, initial_prompt_inst["inst"], if_print=False
# )
# responses_path = write_target_model_responses(
#     TARGET_MODEL + "_init-test",
#     MOVIE_RESPONSES_DIR if CURRENT_DATASET == "movie" else GSM8K_RESPONSES_DIR, 
#     responses
# )
# accuracy_rate = evaluation_gsm8k(CURRENT_DATASET, responses_path)
# print(f"Accuracy rate on test set with initial prompt: {accuracy_rate}")