In [None]:
"""
Final target: automated prompt engineering
def APE_single(dataset, iter_limit, expected_accuracy_rate):
    prompt_inst = initial prompt_inst according to dataset
    for i in range(iter_limit):
        responses = get_responses_from_target_model(prompt_inst)
        store responses to file
        accuracy_rate = calculate accuracy of responses
        if accuracy_rate > expected_accuracy_rate:
            break
        if accuracy_rate has been decreasing for several iterations:
            prompt_inst = the highest-accuracy one that has ever occurred
        error_examples = select error examples from responses
        reflection_prompt = get_reflection_prompt(prompt_inst, error_examples)
        reasons = reflect(reflection_prompt)
        refinement_prompt = get_refinement_prompt(prompt_inst, error_examples, reasons)
        improved_prompt = refine(refinement_prompt)
        prompt_inst = improved_prompt
    return prompt_inst

def APE_batch(dataset, iter_limit, expected_accuracy_rate):
    prompt_insts = {"0": initial prompt_inst according to dataset}
    best_prompt_inst = prompt_insts["0"]
    for i in range(iter_limit):
        responses = {key: get_responses_from_target_model(prompt_inst) for key, prompt_inst in prompt_insts}
        store responses to file
        for each key in responses:
            accuracy_rate = calculate accuracy of responses[key]
            update best_prompt_inst if needed
            if accuracy_rate > expected_accuracy_rate:
                break
        delete the prompts with bad behavior in prompt_insts
        for each key in prompt_insts:
            error_examples = select error examples from responses[key]
            reflection_prompt = get_reflection_prompt(prompt_insts[key], error_examples)
            reasons = reflect(reflection_prompt)
            refinement_prompt = get_refinement_prompt(prompt_insts[key], error_examples, reasons)
            improved_prompt = refine(refinement_prompt) # {newkey1: ..., newkey2: ......, ......}
            prompt_insts += improved_prompt
    return best_prompt_inst
"""
# ----- Globals -----

# for APE_utils: model names for the target model and the optimizer model.
TARGET_MODEL = "gpt-3.5-turbo"
OPTIMIZER_MODEL = "gpt-3.5-turbo"

# for data_process: responses of one run are grouped under a per-task
# sub-directory named after TASK_CODE_NAME.
TASK_CODE_NAME = 'trial-1'
MOVIE_RESPONSES_DIR = '/'.join(("../data/response/movie", TASK_CODE_NAME))
GSM8K_RESPONSES_DIR = '/'.join(("../data/response/gsm8k", TASK_CODE_NAME))

# Gates the verbose progress prints.
DEBUG = True

In [None]:
from APE_utils import *

# Dataset switch: set to "movie" or "gsm8k". The evaluation file path and the
# response directory are both derived from this one value, so the three
# settings cannot drift apart when switching datasets.
CURRENT_DATASET = "gsm8k"
eval_file_path = f'../data/{CURRENT_DATASET}/eval.json'
# exist_ok=True creates the directory (and any missing parents) only when it
# is absent, without the check-then-create race of
# os.path.exists() followed by os.makedirs().
os.makedirs(
    MOVIE_RESPONSES_DIR if CURRENT_DATASET == 'movie' else GSM8K_RESPONSES_DIR,
    exist_ok=True,
)

def APE_single(eval_file_path: str, iter_limit: int, expected_accuracy_rate: float = 1.0):
    """Iteratively refine a single prompt instruction and return the best one.

    Each iteration sends the current instruction to ``TARGET_MODEL``, scores
    the stored responses, asks ``OPTIMIZER_MODEL`` why the sampled errors
    happened (reflection) and then for an improved instruction (refinement).
    The improved instruction is used in the next iteration.

    The loop ends when ``iter_limit`` iterations have run, when the accuracy
    reaches ``expected_accuracy_rate``, or when the accuracy has dropped
    strictly three times in a row.

    Args:
        eval_file_path: Path of the evaluation file handed to
            ``gen_samples_from_dataset``.
        iter_limit: Maximum number of evaluate/reflect/refine iterations.
        expected_accuracy_rate: Target accuracy; reaching or exceeding it
            stops the loop.

    Returns:
        The instruction with the highest accuracy among those evaluated, or
        the initial instruction if none scored above 0.0.
    """
    # Tuning knobs for this run.
    FEEDBACK_REASONS_NUM = 3           # passed to gen_reflection
    FEEDBACK_IMPROVED_PROMPTS_NUM = 1  # passed to gen_refinement
    EVAL_SAMPLE_NUM = 10               # passed to gen_samples_from_dataset
    ERROR_SAMPLE_NUM = 4               # passed to get_error_examples_*

    # Resolve the dataset-specific helpers once instead of on every iteration.
    is_movie = CURRENT_DATASET == 'movie'
    evaluate = evaluation_movie if is_movie else evaluation_gsm8k
    select_error_examples = get_error_examples_movie if is_movie else get_error_examples_gsm8k

    # NOTE(review): the meaning of the third positional argument (True) is not
    # visible in this file -- confirm against gen_samples_from_dataset.
    eval_set = gen_samples_from_dataset(eval_file_path, EVAL_SAMPLE_NUM, True)
    prompt_inst = MOVIE_INIT_INST if is_movie else GSM8K_INIT_INST
    best_prompt_inst = prompt_inst
    best_prompt_accuracy_rate = 0.0
    early_stop_record_list = []  # accuracy of every evaluated instruction, in order
    prompt_inst_id = "0"
    prompt_inst_list = {prompt_inst_id: prompt_inst}  # id -> instruction
    reason_list = {}                                  # id -> reflection output

    for i in range(iter_limit):
        if DEBUG:
            print(f"\n\n>>> Current iteration: {i}")
            print(f">>> Current prompt instance:\n{prompt_inst}")

        # Only the path of the stored responses is needed; scoring and error
        # selection both read the responses back from that file.
        _, current_responses_file_path = get_target_model_responses(
            CURRENT_DATASET, TARGET_MODEL, MOVIE_RESPONSES_DIR, GSM8K_RESPONSES_DIR,
            eval_set, prompt_inst, if_print=False
        )
        accuracy_rate = evaluate(CURRENT_DATASET, current_responses_file_path)
        if DEBUG:
            print(f">>> Current accuracy rate: {accuracy_rate}")

        # Record before any stop check so the printed history also contains
        # the accuracy of the run that reached the target.
        early_stop_record_list.append(accuracy_rate)

        # Track the best instruction together with its score; without updating
        # the score, any later non-zero result would overwrite a better one.
        if accuracy_rate > best_prompt_accuracy_rate:
            best_prompt_accuracy_rate = accuracy_rate
            best_prompt_inst = prompt_inst

        # ">=" so that a target of 1.0 (the default) can actually be reached.
        if accuracy_rate >= expected_accuracy_rate:
            if DEBUG:
                print('>>> Better than expected accuracy rate, stop iteration')
            best_prompt_inst = prompt_inst
            break

        # Early stop: the last four accuracies are strictly decreasing,
        # i.e. three consecutive drops.
        recent = early_stop_record_list[-4:]
        if len(recent) == 4 and all(
            later < earlier for earlier, later in zip(recent, recent[1:])
        ):
            print(f">>> Early stop triggered at iteration {i}")
            break

        error_examples = select_error_examples(
            CURRENT_DATASET, current_responses_file_path, ERROR_SAMPLE_NUM
        )
        reflection_prompt = gen_reflection(FEEDBACK_REASONS_NUM, prompt_inst, error_examples)
        if DEBUG:
            print(f"\n>>> Reflection prompt:\n{reflection_prompt}")
        reasons = get_reflection_from_optimizer(OPTIMIZER_MODEL, reflection_prompt)
        if DEBUG:
            print(f"\n>>> Reasons:\n{reasons}")
        reason_list[prompt_inst_id] = reasons

        refinement_prompt = gen_refinement(
            FEEDBACK_IMPROVED_PROMPTS_NUM, prompt_inst, error_examples, reasons
        )
        # Only the first candidate returned by the optimizer is used.
        # NOTE: the instruction produced in the final iteration is recorded
        # below but never evaluated, because the loop ends right after.
        prompt_inst = get_refinement_from_optimizer(OPTIMIZER_MODEL, refinement_prompt)[0]
        prompt_inst_id = str(i + 1)
        prompt_inst_list[prompt_inst_id] = prompt_inst

    # Summary of the whole run.
    print(f"Best prompt instance: {best_prompt_inst}")
    print(f"Accuracy rate list: {early_stop_record_list}")
    print("Prompt instance list:")
    for key, inst in prompt_inst_list.items():
        print(f"Prompt instance {key}: {inst}")
        if key in reason_list:
            print(f"Reasons: {reason_list[key]}")
    return best_prompt_inst

In [None]:
# Run the single-prompt optimisation: at most 6 iterations, target accuracy 1.0.
APE_single(eval_file_path, iter_limit=6, expected_accuracy_rate=1.0)