In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
from dotenv import load_dotenv
import pickle
import os
import plotly.express as px
import json
import os

from langchain_community.llms import HuggingFaceEndpoint
from langchain.chat_models import ChatHuggingFace, ChatOpenAI

from scripts.optimize_prompt import optimize_prompt, get_answers
from scripts.evaluation import (
    load_math_datasets,
    score_last_match_series,
)

# Load variables from a local .env file into the process environment (python-dotenv).
# Presumably this is where the credentials for the model clients used below come
# from — TODO confirm.
load_dotenv()
# Make DataFrame previews show everything: no cap on the number of columns, no
# wrapping of wide frames over several lines, and no truncation of long cell text
# (prompts and answers are long strings).
for option_name, option_value in (
    ("display.max_columns", None),
    ("display.expand_frame_repr", False),
    ("max_colwidth", None),
):
    pd.set_option(option_name, option_value)

In [None]:
# Load the math benchmark once, then carve three disjoint, contiguous slices out of it.
math_dataset = load_math_datasets()

# Rows 0-29: scored for every prompt variant in the evaluation loop further down.
eval_dataset = math_dataset.select(range(30))
# Rows 30-34: held out for few-shot demonstrations.
fewshot_dataset = math_dataset.select(range(30, 35))
# Rows 35-49: handed to optimize_prompt while the prompt is being optimized.
validation_dataset = math_dataset.select(range(35, 50))

# DataFrame views of the same three slices.
eval_df = pd.DataFrame(eval_dataset)
fewshot_df = pd.DataFrame(fewshot_dataset)
validation_df = pd.DataFrame(validation_dataset)

# Directory for cached artefacts (pickled optimization logs, JSON results).
OUTPUT_DIR = "dump"
# NOTE(review): this points at the Mixtral results file; the save cell further down
# rebinds prompts_file_name to a zephyr-specific path before anything is written.
prompts_file_name = f"{OUTPUT_DIR}/prompts_gsm8k_mixtral.json"

In [None]:
# Prompting-strategy switches, both off.
# NOTE(review): neither flag is read anywhere in the visible notebook — confirm
# whether they are still needed.
USE_COT, USE_FEW_SHOT = False, False

# Setup student agent

In [None]:
from huggingface_hub import InferenceClient

# Student model client: zephyr-7b-beta behind the Hugging Face inference API, with a
# 120 s request timeout. It is the client passed to optimize_prompt and get_answers.
llm_client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta", timeout=120)

# Smoke test: one short generation to check that the endpoint responds.
# Kept as the last expression so the reply is displayed as the cell output.
llm_client.text_generation(prompt="How are you today?", max_new_tokens=20)

# Setup teacher optimization

In [None]:
# Teacher model: handed to optimize_prompt as `teacher_agent` in the optimization cell.
teacher_agent = ChatOpenAI(model="gpt-4-1106-preview")

In [None]:
initial_prompt = """
Q: What is the answer to the following math problem? Make sure to think first, then give your answer at the end in the format "The answer is 42.36".
- {question}

A:"""

In [None]:
# Cache file for the optimization run. It must stay identical to `file_path` in the
# save/load cells that follow, which write and re-read this same pickle.
cached_logs_path = f"{OUTPUT_DIR}/optimizer_zephyr_gpt4-teach4.pkl"

initial_logs = []
if os.path.exists(cached_logs_path):
    # A finished run is already on disk. The following cells never overwrite an
    # existing file and always reload `logs` from it, so running the optimization
    # again here would spend student/teacher LLM calls on a result that is discarded.
    # Delete the file to force a fresh optimization run.
    # NOTE: pickle can execute arbitrary code on load — only read files you created.
    with open(cached_logs_path, "rb") as cache_file:
        logs = pickle.load(cache_file)
else:
    # No cache yet: optimize the prompt on the validation split, starting from
    # `initial_prompt` and an empty log history. The student client answers the
    # questions, the teacher agent is passed along for the optimization, and
    # `score_last_match_series` is the scoring function.
    logs = optimize_prompt(
        initial_logs,
        initial_prompt,
        validation_df,
        llm_client,
        teacher_agent=teacher_agent,
        scoring_function=score_last_match_series,
    )

In [None]:
# Pickle file that caches the optimization logs; written and re-read by the next two cells.
file_path = f"{OUTPUT_DIR}/optimizer_zephyr_gpt4-teach4.pkl"

In [None]:
# Persist the optimization logs without ever clobbering an earlier run: when the
# file already exists this cell is a no-op and the in-memory `logs` are NOT saved.
cache_missing = not os.path.exists(file_path)
if cache_missing:
    with open(file_path, "wb") as pickle_file:
        pickle.dump(logs, pickle_file)

In [None]:
# Rebind `logs` to whatever is stored in the cache file, so every later cell works
# from the persisted run. The context manager closes the file handle as soon as
# the pickle has been read.
# NOTE: pickle can execute arbitrary code on load — only read files you created.
with open(file_path, "rb") as pickle_file:
    logs = pickle.load(pickle_file)

In [None]:
# Peek at the answers recorded in the first log entry.
logs[0]["answers"]

In [None]:
# Position of the highest-scoring log entry (ties resolve to the earliest entry),
# and the prompt stored in that entry.
index_best_prompt = max(range(len(logs)), key=lambda position: logs[position]["score"])
best_prompt = logs[index_best_prompt]["prompt"]

In [None]:
# Show every logged prompt next to its score.
[(el["prompt"], el["score"]) for el in logs]

# Test all prompts

In [None]:
# Chain-of-thought variant: the baseline template with a "think step-by-step" suffix.
cot_prompt = initial_prompt + " Let's think step-by-step. "

# Registry of prompt variants to evaluate, keyed by name. Each entry starts with only
# its "prompt"; the evaluation loop later adds "answers" and "score".
prompt_dict = {
    variant_name: {"prompt": template}
    for variant_name, template in (
        ("initial_prompt", initial_prompt),
        ("best_prompt", best_prompt),
        ("CoT", cot_prompt),
    )
}

In [None]:
# Few-shot template: instructions, then three worked examples, then the literal
# `{question}` placeholder for the problem to solve.
fewshot_prompt = """
Please answer the following math problem. Make sure to give your answer as a float, and the LAST NUMBER OF ALL NUMBER YOU GIVE in the format "The answer is 42.36 dollars".
Here are a few examples to help you.
"""
# Take the demonstrations from the dedicated few-shot split, NOT from the evaluation
# split: every prompt variant is scored on eval_df, so showing eval questions together
# with their answers inside the prompt would leak the solutions and inflate the score.
for example in fewshot_dataset.select(range(3)):
    fewshot_prompt += f"""
Q: {example['question']}
A: {example['true_reasoning'] + '. So the answer is ' + str(example['true_answer'])}
"""
# Not an f-string: `{question}` stays a placeholder in the finished template.
fewshot_prompt += "Now begin!\nQ: {question}\n\nA:"
# Same template with a chain-of-thought cue appended directly after "A:".
fewshot_cot_prompt = fewshot_prompt + "Let's think step-by-step."

# Register both variants so the evaluation loop scores them with the others.
prompt_dict["fewshot"] = {"prompt": fewshot_prompt}
prompt_dict["fewshot_cot"] = {"prompt": fewshot_cot_prompt}

In [None]:
# Score every prompt variant that has no score yet on the evaluation split.
for prompt_name, entry in prompt_dict.items():
    if "score" in entry:
        # Already evaluated (e.g. the cell was re-run): keep the stored result.
        continue

    template = entry["prompt"]
    print(f"========== Prompt: {prompt_name} ==========")
    print(f"Prompt content: {template}")

    # Query the student model for each evaluation question, then mark each answer
    # against the reference. eval_df["is_correct"] is overwritten on every pass, so
    # afterwards it reflects the last variant evaluated.
    answers = get_answers(template, llm_client, eval_df["question"])
    eval_df["is_correct"] = score_last_match_series(answers, eval_df["true_answer"])
    accuracy = eval_df["is_correct"].mean()

    entry["answers"] = answers.to_dict()
    entry["score"] = accuracy
    print(accuracy)

# Reference point copied in by hand from a separate experiment (nothing is run here).
prompt_dict["langchain_agent"] = {
    "prompt": "Cf source file",
    "answers": "cf other experiment",
    "score": 0.73,
}

In [None]:
# Write all prompt results to JSON, but only on the first run: an existing file is
# left untouched, so delete it manually to store fresh results.
prompts_file_name = "dump/prompts_gsm8k_zephyr_gpt4-teach4.json"
already_saved = os.path.exists(prompts_file_name)
if not already_saved:
    with open(prompts_file_name, "w") as json_file:
        json.dump(prompt_dict, json_file)

### Display results

In [None]:
def _load_prompt_results(path, model_label, excluded=("langchain_agent", "fewshot")):
    """Read one saved results file and turn it into a tidy DataFrame.

    Args:
        path: JSON file mapping prompt name -> result entry (a dict).
        model_label: Value written to the "model" column of every row.
        excluded: Prompt names dropped before the frame is built; names that are
            absent from the file are ignored.

    Returns:
        A ``(results, frame)`` tuple: the filtered dict as loaded from JSON, and a
        DataFrame with one row per remaining prompt holding "prompt_name", "model"
        and every key of that prompt's entry.
    """
    # Context manager so the file handle is closed right after parsing.
    with open(path, "r") as results_file:
        results = json.load(results_file)
    for prompt_name in excluded:
        results.pop(prompt_name, None)
    frame = pd.DataFrame(
        [
            {"prompt_name": prompt_name, "model": model_label, **entry}
            for prompt_name, entry in results.items()
        ]
    )
    return results, frame


# NOTE(review): the file labelled "mistral" holds the zephyr-7b-beta run, and its rows
# are tagged "mistral-7b" in the chart — confirm that this label is intended.
file_name_mistral = "dump/prompts_gsm8k_zephyr_gpt4-teach4.json"
file_name_mixtral = "dump/prompts_gsm8k_mixtral.json"

prompt_dict_mistral, results_df_mistral = _load_prompt_results(
    file_name_mistral, "mistral-7b"
)
prompt_dict_mixtral, results_df_mixtral = _load_prompt_results(
    file_name_mixtral, "mixtral-8x7b"
)

In [None]:
# Mixtral-only view of the results. NOTE: when the notebook is run top-to-bottom the
# next cell immediately rebinds `results_df` to the combined frame, so this assignment
# only matters if that cell is skipped.
results_df = results_df_mixtral

In [None]:
# Stack both models' results into one frame. Row labels are kept as-is (no
# ignore_index), so index values may repeat across the two halves.
results_df = pd.concat([results_df_mistral, results_df_mixtral])

In [None]:
# Mean score per (prompt, model) pair, converted from a fraction to a percentage and
# ordered from worst to best so the bars in the chart rise from left to right.
aggregate = (
    results_df.groupby(["prompt_name", "model"])[["score"]]
    .mean()
    .reset_index()
    .assign(score=lambda frame: frame["score"] * 100)
    .sort_values("score")
)

In [None]:
# Display labels for the chart's x axis, keyed by the internal prompt name.
prompt_display_names = {
    "initial_prompt": "Initial prompt",
    "best_prompt": "Optimized prompt",
    "CoT": "CoT",
    "fewshot": "Fewshot",
    "fewshot_cot": "Fewshot+CoT",
}
# Fall back to the existing value for names missing from the mapping. A plain
# `.map(dict)` turns those into NaN, which (a) drops any prompt variant that has no
# label yet and (b) wipes the column when this cell is run a second time, because the
# already-translated labels are not keys of the mapping.
aggregate["prompt_name"] = aggregate["prompt_name"].map(
    lambda name: prompt_display_names.get(name, name)
)

In [None]:
# Axis/legend titles keyed by column name.
# NOTE(review): "fewshot" is not a column of `aggregate`, so that entry presumably has
# no effect — confirm before removing it.
axis_labels = {
    "prompt_name": "<b>Prompt choice</b>",
    "score": "<b>Score</b>",
    "fewshot": "Few-shot",
}

# One bar per prompt variant, coloured by model.
fig = px.bar(aggregate, x="prompt_name", y="score", color="model", labels=axis_labels)

# Width: 100 px per distinct prompt label plus 200 px of fixed room.
figure_width = aggregate["prompt_name"].nunique() * 100 + 200
fig.update_layout(
    width=figure_width,
    height=600,
    barmode="group",  # models side by side within each prompt group
    bargap=0.35,
    bargroupgap=0.0,
    yaxis_range=[0, 80],  # fixed y range; bars above 80% would extend past it
    yaxis_ticksuffix="%",
)

# Print the rounded score above every bar.
fig.update_traces(texttemplate="%{y:.0f}", textposition="outside")
fig.show()

### Insights from the experiment
- Prompt optimization with GPT-4 does not seem to work well for large models like Mixtral-8x7B 🚫
- Prompting techniques matter most for less powerful models like Mistral-7B (evaluated here through its fine-tune zephyr-7b-beta).