In [1]:
import asyncio
import re

from langchain.schema import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from tqdm.asyncio import tqdm as atqdm


# Get the dataset

In [2]:
from datasets import load_dataset

# Load the GSM8K dataset
dataset = load_dataset("gsm8k", "main")

# Access the training and test splits
train_dataset = dataset["train"]
test_dataset = dataset["test"]

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Get a sample
val_sample = test_dataset.shuffle(seed=42).select(range(100))

# Zero-shot prompt

In [4]:
# Create a semaphore to limit concurrent requests
request_semaphore = asyncio.Semaphore(10)

# Initialize the language model
llm = ChatOpenAI(
    model_name="gpt-4o-mini-2024-07-18",
    temperature=0,
    seed=42,
)

def check_answer(model_answer, true_answer):
    try:
        true_answer = true_answer.split(",")

        true_answer = [int(x) for x in true_answer]
    except Exception:
        true_answer = [int(true_answer)]

    return int(model_answer) in true_answer

async def process_problem(problem, system_prompt):
    # Get the math problem and the correct answer
    math_problem = problem["question"]
    correct_answer = problem["answer"].split("### ")[1]

    system_message = SystemMessage(content=system_prompt)
    human_message = HumanMessage(content=math_problem)

    # Use semaphore to limit concurrent requests
    async with request_semaphore:
        # Generate the model's response asynchronously
        model_response = await llm.ainvoke([system_message, human_message])
        content = model_response.content

    # Use regex to parse the numerical answer
    try:
        model_ans = re.search(r"Answer:\s*[^0-9]*([\d]+(?:\.\d+)?)", content).group(1).strip()
        is_correct = check_answer(model_ans, correct_answer)
    except Exception:
        is_correct = False

    return content, is_correct


async def bench(system_prompt):
    k = 0
    model_answers = []

    # Create tasks for all problems
    tasks = [process_problem(problem, system_prompt) for problem in val_sample]

    # Process all tasks with progress tracking
    for result in atqdm(asyncio.as_completed(tasks), total=len(val_sample)):
        content, is_correct = await result
        model_answers.append(content)
        if is_correct:
            k += 1

    print(f"Precision: {k / len(val_sample)}")

In [5]:
zero_shot = """
### INSTRUCTIONS
1) Solve the following grade school level math problem step-by-step.
2) If you solve it right, I will give you a millon dollars.
3) At the end, provide the answer formatted as Answer: <ANSWER>
"""

In [6]:
await bench(zero_shot)

100%|██████████| 100/100 [00:57<00:00,  1.74it/s]

Precision: 0.85





# Few-shot

In [7]:
few_shot = """
### INSTRUCTIONS
1) Solve the following grade school level math problem step-by-step.
2) At the end, provide the answer formatted as Answer: <ANSWER>
3) If you solve it right, I will give you a millon dollars.

### EXAMPLE 1
## QUESTOIN
Mr. Sanchez found out that 40% of his Grade 5 students got a final grade below B. How many of his students got a final grade of B and above if he has 60 students in Grade 5?
## ANSWER
Since 40% of his students got below B, 100% - 40% = 60% of Mr. Sanchez's students got B and above.
Thus, 60 x 60/100 = <<60*60/100=36>>36 students got B and above in their final grade.
Answer: 36

### EXAMPLE 2
## PROBLEM
Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?
## ANSWER
Weng earns 12/60 = $0.2 per minute.
Working 50 minutes, she earned 0.2 x 50 = $10.
Answer: 10

### EXAMPLE 3
## PROBLEM
John writes 20 pages a day. How long will it take him to write 3 books that are 400 pages each?
## ANSWER
He wants to write 3*400=1200 pages.
So it will take him 1200/20=60 days.
Answer: 60

### EXAMPLE 4
## PROBLEM
Mark has a garden with flowers. He planted plants of three different colors in it. Ten of them are yellow, and there are 80% more of those in purple. There are only 25% as many green flowers as there are yellow and purple flowers. How many flowers does Mark have in his garden?
## ANSWER
There are 80/100 * 10 = <<80/100*10=8>>8 more purple flowers than yellow flowers.
So in Mark's garden, there are 10 + 8 = <<10+8=18>>18 purple flowers.
Purple and yellow flowers sum up to 10 + 18 = <<10+18=28>>28 flowers.
That means in Mark's garden there are 25/100 * 28 = <<25/100*28=7>>7 green flowers.
So in total Mark has 28 + 7 = <<28+7=35>>35 plants in his garden.
Answer: 35
"""

In [8]:
await bench(few_shot)

100%|██████████| 100/100 [00:49<00:00,  2.01it/s]

Precision: 0.92



