In [None]:
from langchain_openai import ChatOpenAI



# Define the local llama.cpp server as an OpenAI-compatible endpoint
llm = ChatOpenAI(
    openai_api_key="",
    model_name="gpt-4o-mini-2024-07-18",
    temperature=0,
    seed=42,
)

# Get the dataset

In [3]:
from datasets import load_dataset

# Load the GSM8K dataset
dataset = load_dataset("gsm8k", "main")

# Access the training and test splits
train_dataset = dataset["train"]
test_dataset = dataset["test"]

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Get a sample
val_sample = test_dataset.shuffle(seed=42).select(range(100))

# Test the model

In [5]:
def check_answer(model_answer, true_answer):
    try:
        true_answer = true_answer.split(',')
        
        true_answer = [int(x) for x in true_answer]
    except Exception:
        true_answer = [int(true_answer)]
    
    return int(model_answer) in true_answer

In [None]:
from langchain.schema import HumanMessage, SystemMessage
from langchain.schema import HumanMessage, AIMessage, SystemMessage
from tqdm import tqdm
import re

zero_shot = \
"""
### INSTRUCTIONS
1) Solve the following grade school level math problem step-by-step.
2) If you solve it right, I will give you a millon dollars.
3) At the end, provide the answer formatted as Answer: <ANSWER>
"""
system_message = SystemMessage(content=zero_shot)

k = 0
model_answers = []

for problem in tqdm(val_sample):
    # Get the math problem and the correct answer
    math_problem = problem['question']
    correct_answer = problem['answer'].split("### ")[1]

    human_message = HumanMessage(content=math_problem)
    
    # Generate the model's response
    model_response = llm.invoke([system_message, human_message])
    model_answers.append(model_response.content)

    # Use regex to parse the numerical answer
    try:
        model_ans = re.search(r'Answer:\s*[^0-9]*([\d]+(?:\.\d+)?)', model_response.content).group(1).strip()
        # print(model_ans)
        # print(correct_answer)
        # print('----------------')

        k += 1 if check_answer(model_ans, correct_answer) else 0
        
    except Exception:
        continue

  1%|          | 1/100 [00:03<06:15,  3.79s/it]

109
109
----------------


  2%|▏         | 2/100 [00:10<09:07,  5.59s/it]

107
89
----------------


  3%|▎         | 3/100 [00:13<07:16,  4.50s/it]

6
13
----------------


  4%|▍         | 4/100 [00:18<07:08,  4.46s/it]

5
5
----------------


  5%|▌         | 5/100 [00:22<06:45,  4.27s/it]

25
25
----------------


  6%|▌         | 6/100 [00:29<08:17,  5.29s/it]

452
452
----------------


  7%|▋         | 7/100 [00:32<06:57,  4.49s/it]

43
43
----------------


  8%|▊         | 8/100 [00:36<06:38,  4.33s/it]

34
34
----------------


  9%|▉         | 9/100 [00:42<07:20,  4.84s/it]

120
120
----------------


 10%|█         | 10/100 [00:45<06:45,  4.50s/it]

11
11
----------------


 11%|█         | 11/100 [00:51<07:10,  4.83s/it]

34
34
----------------


 12%|█▏        | 12/100 [00:54<06:16,  4.28s/it]

12
12
----------------


 13%|█▎        | 13/100 [01:01<07:21,  5.07s/it]

15
15
----------------


 14%|█▍        | 14/100 [01:08<08:10,  5.70s/it]

24
24
----------------


 15%|█▌        | 15/100 [01:15<08:42,  6.15s/it]

105
105
----------------


 16%|█▌        | 16/100 [01:18<07:06,  5.08s/it]

120
120
----------------


 17%|█▋        | 17/100 [01:22<06:43,  4.86s/it]

17
17
----------------


 18%|█▊        | 18/100 [01:26<06:10,  4.51s/it]

36
36
----------------


 19%|█▉        | 19/100 [01:29<05:30,  4.08s/it]

12
12
----------------


 20%|██        | 20/100 [01:32<05:06,  3.83s/it]

45
45
----------------


 21%|██        | 21/100 [01:34<04:18,  3.28s/it]

45
45
----------------


 22%|██▏       | 22/100 [01:37<04:11,  3.22s/it]

8
8
----------------


 23%|██▎       | 23/100 [01:42<04:33,  3.55s/it]

36
36
----------------


 24%|██▍       | 24/100 [01:50<06:12,  4.90s/it]

1
1
----------------


 25%|██▌       | 25/100 [01:53<05:24,  4.33s/it]

10
18
----------------


 26%|██▌       | 26/100 [01:58<05:42,  4.62s/it]

33.00
33
----------------


 27%|██▋       | 27/100 [02:03<05:52,  4.83s/it]

3744
1248
----------------


 28%|██▊       | 28/100 [02:11<06:56,  5.78s/it]

25
25
----------------


 29%|██▉       | 29/100 [02:17<06:53,  5.82s/it]

1
1
----------------


 30%|███       | 30/100 [02:20<05:40,  4.87s/it]

32
32
----------------


 31%|███       | 31/100 [02:28<06:52,  5.97s/it]

6
2
----------------


 32%|███▏      | 32/100 [02:32<05:59,  5.28s/it]

30
30
----------------


 33%|███▎      | 33/100 [02:41<07:15,  6.50s/it]

4.0
4
----------------


 34%|███▍      | 34/100 [02:45<06:16,  5.70s/it]

54
54
----------------


 35%|███▌      | 35/100 [02:51<06:08,  5.66s/it]

250
250
----------------


 36%|███▌      | 36/100 [02:56<06:00,  5.64s/it]

324
324
----------------


 37%|███▋      | 37/100 [03:01<05:27,  5.19s/it]

129200
129200
----------------


 38%|███▊      | 38/100 [03:07<05:49,  5.63s/it]

4400
4400
----------------


 39%|███▉      | 39/100 [03:12<05:35,  5.50s/it]

70
70
----------------


 40%|████      | 40/100 [03:17<05:21,  5.35s/it]

276000
276,000
----------------


 41%|████      | 41/100 [03:22<05:06,  5.20s/it]

12
108
----------------


 42%|████▏     | 42/100 [03:25<04:23,  4.54s/it]

160
160
----------------


 43%|████▎     | 43/100 [03:31<04:45,  5.00s/it]

90
90
----------------


 44%|████▍     | 44/100 [03:35<04:22,  4.69s/it]

20
20
----------------


 45%|████▌     | 45/100 [03:38<03:52,  4.22s/it]

296
296
----------------


 46%|████▌     | 46/100 [03:51<06:09,  6.85s/it]

4800
4800
----------------


 47%|████▋     | 47/100 [03:56<05:22,  6.08s/it]

30
30
----------------


 48%|████▊     | 48/100 [03:59<04:34,  5.28s/it]

8
8
----------------


 49%|████▉     | 49/100 [04:07<05:03,  5.95s/it]

1920
1920
----------------


 50%|█████     | 50/100 [04:17<06:08,  7.36s/it]

6
6
----------------


 51%|█████     | 51/100 [04:24<05:50,  7.15s/it]

13
13
----------------


 52%|█████▏    | 52/100 [04:32<06:02,  7.56s/it]

420
420
----------------


 53%|█████▎    | 53/100 [04:39<05:39,  7.22s/it]

10
10
----------------


 54%|█████▍    | 54/100 [04:43<04:53,  6.39s/it]

12
12
----------------


 55%|█████▌    | 55/100 [04:47<04:04,  5.44s/it]

50
50
----------------


 56%|█████▌    | 56/100 [04:51<03:48,  5.18s/it]

92
92
----------------


 57%|█████▋    | 57/100 [04:58<03:59,  5.57s/it]

623
623
----------------


 58%|█████▊    | 58/100 [05:00<03:19,  4.75s/it]

38
40
----------------


 59%|█████▉    | 59/100 [05:08<03:46,  5.52s/it]

17
17
----------------


 60%|██████    | 60/100 [05:11<03:08,  4.72s/it]

940
940
----------------


 61%|██████    | 61/100 [05:20<04:00,  6.16s/it]

10
10
----------------


 62%|██████▏   | 62/100 [05:23<03:19,  5.26s/it]

540
540
----------------


 63%|██████▎   | 63/100 [05:25<02:36,  4.24s/it]

110
110
----------------


 64%|██████▍   | 64/100 [05:33<03:08,  5.25s/it]

20
20
----------------


 65%|██████▌   | 65/100 [05:40<03:20,  5.74s/it]

560
560
----------------


 66%|██████▌   | 66/100 [05:42<02:37,  4.64s/it]

24
24
----------------


 67%|██████▋   | 67/100 [05:46<02:27,  4.47s/it]

22
22
----------------


 68%|██████▊   | 68/100 [05:51<02:30,  4.69s/it]

4
4
----------------


 69%|██████▉   | 69/100 [05:57<02:40,  5.17s/it]

22
22
----------------


 70%|███████   | 70/100 [06:02<02:33,  5.13s/it]

64
64
----------------


 71%|███████   | 71/100 [06:09<02:40,  5.54s/it]

525
525
----------------


 72%|███████▏  | 72/100 [06:14<02:28,  5.31s/it]

10
10
----------------


 73%|███████▎  | 73/100 [06:18<02:13,  4.95s/it]

410
410
----------------


 74%|███████▍  | 74/100 [06:20<01:49,  4.20s/it]

140
140
----------------


 75%|███████▌  | 75/100 [06:23<01:33,  3.74s/it]

720
720
----------------


 76%|███████▌  | 76/100 [06:26<01:26,  3.59s/it]

1050
1050
----------------


 77%|███████▋  | 77/100 [06:30<01:22,  3.59s/it]

8
8
----------------


 78%|███████▊  | 78/100 [06:43<02:23,  6.54s/it]

20
20
----------------


 79%|███████▉  | 79/100 [06:47<02:03,  5.88s/it]

139
138
----------------


 80%|████████  | 80/100 [06:49<01:33,  4.67s/it]

175
175
----------------


 81%|████████  | 81/100 [06:52<01:18,  4.11s/it]

64
64
----------------


 82%|████████▏ | 82/100 [06:56<01:11,  3.99s/it]

5600
5600
----------------


 83%|████████▎ | 83/100 [07:00<01:07,  3.98s/it]

720.00
720
----------------


 84%|████████▍ | 84/100 [07:03<01:00,  3.76s/it]

312
312
----------------


 85%|████████▌ | 85/100 [07:06<00:53,  3.58s/it]

60
60
----------------


 86%|████████▌ | 86/100 [07:08<00:44,  3.20s/it]

450
450
----------------


 87%|████████▋ | 87/100 [07:11<00:40,  3.10s/it]

10000
10000
----------------


 88%|████████▊ | 88/100 [07:17<00:48,  4.03s/it]

43500
43,500
----------------


 89%|████████▉ | 89/100 [07:21<00:43,  3.99s/it]

150
150
----------------


 90%|█████████ | 90/100 [07:25<00:38,  3.90s/it]

9.00
9
----------------


 91%|█████████ | 91/100 [07:32<00:42,  4.77s/it]

60
60
----------------


 92%|█████████▏| 92/100 [07:35<00:33,  4.14s/it]

24
24
----------------


 93%|█████████▎| 93/100 [07:40<00:30,  4.42s/it]

539
539
----------------


 94%|█████████▍| 94/100 [07:49<00:34,  5.81s/it]

2
18
----------------


 95%|█████████▌| 95/100 [07:56<00:31,  6.22s/it]

14
14
----------------


 96%|█████████▌| 96/100 [07:59<00:20,  5.19s/it]

60
60
----------------


 97%|█████████▋| 97/100 [08:04<00:15,  5.16s/it]

1080
1080
----------------


 98%|█████████▊| 98/100 [08:16<00:14,  7.20s/it]

291
291
----------------


 99%|█████████▉| 99/100 [08:19<00:06,  6.04s/it]

70
70
----------------


100%|██████████| 100/100 [08:25<00:00,  5.05s/it]

63
63
----------------





In [7]:
print(f"Precision: {k/len(val_sample)}")

Precision: 0.85
