# Init a model

In [1]:
from langchain_ollama import ChatOllama

# Initialize the model
llm = ChatOllama(model="mistral:7b-instruct", seed=42, temperature=0.001)

# Get the dataset

In [2]:
from datasets import load_dataset

# Load the GSM8K dataset
dataset = load_dataset("gsm8k", "main")

# Access the training and test splits
train_dataset = dataset["train"]
test_dataset = dataset["test"]

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Get a sample
val_sample = test_dataset.shuffle(seed=42).select(range(100))

# Test the model

In [None]:
def check_answer(model_answer, true_answer):
    try:
        true_answer = true_answer.split(',')
        
        true_answer = [int(x) for x in true_answer]
    except Exception:
        true_answer = [int(true_answer)]
    
    return int(model_answer) in true_answer

False

In [None]:
from langchain.schema import HumanMessage, SystemMessage
from langchain.schema import HumanMessage, AIMessage, SystemMessage
from tqdm import tqdm
import re

zero_shot = \
"""
### INSTRUCTIONS
1) Solve the following grade school level math problem step-by-step.
2) If you solve it right, I will give you a millon dollars.
3) At the end, provide the answer formatted as Answer: <ANSWER>
"""
system_message = SystemMessage(content=zero_shot)

k = 0

for problem in tqdm(val_sample):
    # Get the math problem and the correct answer
    math_problem = problem['question']
    correct_answer = problem['answer'].split("### ")[1]

    human_message = HumanMessage(content=math_problem)
    
    # Generate the model's response
    model_response = llm.invoke([system_message, human_message])
    # Use regex to parse the numerical answer
    try:
        model_ans = re.search(r'Answer:\s*[^0-9]*([\d]+(?:\.\d+)?)', model_response.content).group(1).strip()
        print(model_ans)
        print(correct_answer)
        print('----------------')

        k += 1 if check_answer(model_ans, correct_answer) else 0
        
    except Exception:
        continue


  1%|          | 1/100 [00:39<1:05:24, 39.64s/it]

10
109
----------------


  2%|▏         | 2/100 [01:04<50:55, 31.18s/it]  

114
89
----------------


  3%|▎         | 3/100 [01:27<44:00, 27.23s/it]

6
13
----------------


  4%|▍         | 4/100 [01:47<39:01, 24.39s/it]

5
5
----------------


  5%|▌         | 5/100 [02:22<44:32, 28.13s/it]

25
25
----------------


  6%|▌         | 6/100 [02:59<49:04, 31.33s/it]

22
452
----------------


  7%|▋         | 7/100 [03:21<43:34, 28.11s/it]

43
43
----------------


  8%|▊         | 8/100 [03:40<39:00, 25.44s/it]

66
34
----------------


  9%|▉         | 9/100 [04:09<40:03, 26.41s/it]

120
120
----------------


 11%|█         | 11/100 [05:11<42:22, 28.57s/it]

39
34
----------------


 12%|█▏        | 12/100 [05:29<37:11, 25.36s/it]

12
12
----------------


 13%|█▎        | 13/100 [06:04<40:40, 28.05s/it]

24
15
----------------


 14%|█▍        | 14/100 [06:37<42:33, 29.69s/it]

104
24
----------------


 15%|█▌        | 15/100 [07:21<48:06, 33.96s/it]

25
105
----------------


 16%|█▌        | 16/100 [07:42<41:59, 30.00s/it]

80
120
----------------


 17%|█▋        | 17/100 [07:59<36:20, 26.28s/it]

18
17
----------------


 18%|█▊        | 18/100 [08:15<31:42, 23.21s/it]

36
36
----------------


 19%|█▉        | 19/100 [08:27<26:48, 19.85s/it]

12
12
----------------


 20%|██        | 20/100 [09:07<34:13, 25.66s/it]

15
45
----------------


 21%|██        | 21/100 [09:27<31:50, 24.18s/it]

45
45
----------------


 22%|██▏       | 22/100 [09:38<26:16, 20.21s/it]

8
8
----------------


 23%|██▎       | 23/100 [10:04<28:04, 21.88s/it]

36
36
----------------


 24%|██▍       | 24/100 [10:27<28:13, 22.28s/it]

1
1
----------------


 25%|██▌       | 25/100 [10:44<25:51, 20.69s/it]

14
18
----------------


 26%|██▌       | 26/100 [11:08<26:44, 21.69s/it]

33.00
33
----------------


 27%|██▋       | 27/100 [11:51<33:56, 27.90s/it]

3744
1248
----------------


 28%|██▊       | 28/100 [12:18<33:13, 27.69s/it]

3
25
----------------


 29%|██▉       | 29/100 [12:47<33:24, 28.23s/it]

1.00
1
----------------


 30%|███       | 30/100 [13:07<29:57, 25.68s/it]

32
32
----------------


 31%|███       | 31/100 [13:33<29:29, 25.65s/it]

3
2
----------------


 32%|███▏      | 32/100 [13:47<25:21, 22.37s/it]

30
30
----------------


 33%|███▎      | 33/100 [14:12<25:38, 22.96s/it]

4
4
----------------


 34%|███▍      | 34/100 [14:30<23:39, 21.51s/it]

144
54
----------------


 35%|███▌      | 35/100 [14:56<24:47, 22.88s/it]

110
250
----------------


 36%|███▌      | 36/100 [15:24<26:08, 24.50s/it]

420.00
324
----------------


 37%|███▋      | 37/100 [15:46<24:51, 23.68s/it]

127
129200
----------------


 38%|███▊      | 38/100 [16:16<26:32, 25.69s/it]

4400
4400
----------------


 39%|███▉      | 39/100 [16:38<24:54, 24.49s/it]

140
70
----------------


 40%|████      | 40/100 [17:10<26:49, 26.83s/it]

228
276,000
----------------


 41%|████      | 41/100 [17:41<27:29, 27.96s/it]

100
108
----------------


 42%|████▏     | 42/100 [17:58<23:57, 24.78s/it]

160
160
----------------


 43%|████▎     | 43/100 [18:11<20:02, 21.10s/it]

90
90
----------------


 44%|████▍     | 44/100 [18:29<18:43, 20.06s/it]

20
20
----------------


 45%|████▌     | 45/100 [18:58<21:05, 23.00s/it]

296
296
----------------


 46%|████▌     | 46/100 [19:30<23:05, 25.66s/it]

4800
4800
----------------


 47%|████▋     | 47/100 [19:49<20:57, 23.72s/it]

20
30
----------------


 48%|████▊     | 48/100 [20:11<19:57, 23.03s/it]

24
8
----------------


 49%|████▉     | 49/100 [20:40<21:13, 24.97s/it]

1920
1920
----------------


 50%|█████     | 50/100 [21:00<19:26, 23.34s/it]

18
6
----------------


 51%|█████     | 51/100 [21:16<17:20, 21.23s/it]

13
13
----------------


 52%|█████▏    | 52/100 [21:42<18:01, 22.53s/it]

267
420
----------------


 53%|█████▎    | 53/100 [22:00<16:34, 21.15s/it]

8
10
----------------


 54%|█████▍    | 54/100 [22:16<15:12, 19.84s/it]

16
12
----------------


 55%|█████▌    | 55/100 [22:31<13:45, 18.36s/it]

50
50
----------------


 56%|█████▌    | 56/100 [22:49<13:16, 18.10s/it]

92
92
----------------


 57%|█████▋    | 57/100 [23:07<12:58, 18.10s/it]

623
623
----------------


 58%|█████▊    | 58/100 [23:28<13:13, 18.89s/it]

38
40
----------------


 59%|█████▉    | 59/100 [23:53<14:11, 20.76s/it]

9
17
----------------


 60%|██████    | 60/100 [24:04<11:58, 17.95s/it]

940
940
----------------


 61%|██████    | 61/100 [24:30<13:13, 20.36s/it]

10
10
----------------


 62%|██████▏   | 62/100 [24:52<13:06, 20.70s/it]

540
540
----------------


 63%|██████▎   | 63/100 [25:16<13:30, 21.90s/it]

20
110
----------------


 64%|██████▍   | 64/100 [26:15<19:43, 32.88s/it]

45
20
----------------


 65%|██████▌   | 65/100 [26:35<17:01, 29.17s/it]

8400
560
----------------


 66%|██████▌   | 66/100 [26:47<13:30, 23.84s/it]

24
24
----------------


 67%|██████▋   | 67/100 [27:17<14:11, 25.81s/it]

6
22
----------------


 68%|██████▊   | 68/100 [27:49<14:47, 27.74s/it]

15
4
----------------


 70%|███████   | 70/100 [28:46<14:14, 28.49s/it]

16
64
----------------


 71%|███████   | 71/100 [29:28<15:43, 32.54s/it]

555
525
----------------


 72%|███████▏  | 72/100 [29:45<13:04, 28.03s/it]

10
10
----------------


 73%|███████▎  | 73/100 [30:06<11:33, 25.70s/it]

410
410
----------------


 74%|███████▍  | 74/100 [30:32<11:13, 25.90s/it]

1540
140
----------------


 75%|███████▌  | 75/100 [30:53<10:12, 24.49s/it]

720
720
----------------


 76%|███████▌  | 76/100 [31:11<09:01, 22.58s/it]

1050
1050
----------------


 77%|███████▋  | 77/100 [31:38<09:08, 23.83s/it]

7.5
8
----------------


 78%|███████▊  | 78/100 [32:40<12:55, 35.25s/it]

20
20
----------------


 79%|███████▉  | 79/100 [33:05<11:16, 32.21s/it]

140
138
----------------


 80%|████████  | 80/100 [33:43<11:21, 34.06s/it]

175
175
----------------


 81%|████████  | 81/100 [34:06<09:41, 30.63s/it]

64
64
----------------


 82%|████████▏ | 82/100 [34:35<09:04, 30.27s/it]

11
5600
----------------


 83%|████████▎ | 83/100 [34:57<07:50, 27.66s/it]

720.00
720
----------------


 84%|████████▍ | 84/100 [35:23<07:12, 27.04s/it]

312
312
----------------


 85%|████████▌ | 85/100 [35:39<05:58, 23.91s/it]

60
60
----------------


 86%|████████▌ | 86/100 [35:53<04:50, 20.77s/it]

450
450
----------------


 87%|████████▋ | 87/100 [36:36<05:57, 27.49s/it]

4
10000
----------------


 88%|████████▊ | 88/100 [37:16<06:14, 31.22s/it]

43
43,500
----------------


 89%|████████▉ | 89/100 [37:39<05:16, 28.81s/it]

75
150
----------------


 90%|█████████ | 90/100 [38:08<04:48, 28.89s/it]

11.00
9
----------------


 91%|█████████ | 91/100 [38:37<04:20, 28.92s/it]

60
60
----------------


 92%|█████████▏| 92/100 [38:59<03:34, 26.83s/it]

24
24
----------------


 94%|█████████▍| 94/100 [40:09<03:04, 30.76s/it]

25
18
----------------


 95%|█████████▌| 95/100 [40:33<02:24, 28.81s/it]

0
14
----------------


 96%|█████████▌| 96/100 [40:54<01:46, 26.60s/it]

60
60
----------------


 97%|█████████▋| 97/100 [41:39<01:35, 31.92s/it]

1080
1080
----------------


 98%|█████████▊| 98/100 [43:22<01:46, 53.35s/it]

3
291
----------------


 99%|█████████▉| 99/100 [43:59<00:48, 48.51s/it]

70
70
----------------


100%|██████████| 100/100 [44:31<00:00, 26.71s/it]

63
63
----------------





In [None]:
print(f"Precision: {k/len(val_sample)}")

Precision: 0.46
