In [1]:
import json
from pathlib import Path

def json_load(file_path):
    """Read the JSON document stored at ``file_path`` and return the parsed object.

    Args:
        file_path: Path (string or path-like) of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file's contents.
    """
    # Pin the encoding so decoding does not depend on the platform's default
    # locale encoding (JSON interchange text is UTF-8).
    with open(file_path, encoding="utf-8") as file:
        return json.load(file)

# Path of the results JSON file, relative to the notebook's working directory.
file_path = './results-gsm8k-capybara-bs-1.json'
# Parsed contents of that file; a later cell reads data['outputs_before'] and data['outputs_after'].
data = json_load(file_path)

In [2]:
from datasets import load_dataset
from transformers import AutoTokenizer

# Model identifier whose tokenizer (and chat template) is used to format the prompts below.
model_id = "TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ"
# local_files_only=False: the tokenizer is not restricted to files already present locally.
tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=False)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
# Load the 'main' configuration of the "gsm8k" dataset; its 'train' split is used below.
gsm8k_dataset = load_dataset("gsm8k", 'main')

In [4]:
import textwrap

def preprocess_dataset(dataset, tokenizer, pt, pt_cols, system_prompt, add_generation_prompt = True):
    """Add a chat-formatted prompt column and a raw-input column to ``dataset``.

    Args:
        dataset: Object exposing ``map(fn)``; ``fn`` is called with one sample
            (a mapping from column name to value) and returns the new columns.
        tokenizer: Object exposing
            ``apply_chat_template(messages, tokenize=..., add_generation_prompt=...)``.
        pt: Prompt template, filled in with ``str.format`` from the columns
            named in ``pt_cols``.
        pt_cols: Names of the sample columns substituted into ``pt``. The first
            one is also copied verbatim into the ``'actual input'`` column.
        system_prompt: Content of a leading system message, or ``None`` to
            send only the user message.
        add_generation_prompt: Forwarded unchanged to ``apply_chat_template``.

    Returns:
        Whatever ``dataset.map`` returns when given a function that produces
        the ``"X"`` (formatted prompt) and ``'actual input'`` columns.
    """

    def build_prompt(sample):
        """Fill the prompt template from ``sample`` and render it through the chat template."""
        fields = {}
        for col in pt_cols:
            fields[col] = sample[col]
        user_message = {"role": "user", "content": pt.format(**fields)}

        if system_prompt is None:
            conversation = [user_message]
        else:
            conversation = [{"role": "system", "content": system_prompt}, user_message]

        return tokenizer.apply_chat_template(
            conversation,
            tokenize = False,
            add_generation_prompt=add_generation_prompt,
        )

    def add_columns(sample):
        """Return the two new columns for a single sample."""
        prompt = build_prompt(sample)
        # The un-templated value of the first prompt column.
        raw_input = sample[pt_cols[0]]
        return {"X" : prompt, 'actual input' : raw_input}

    return dataset.map(add_columns)

# Few-shot prompt template: two worked Q/A examples followed by the question to solve.
# The backslash after the opening quotes suppresses the leading newline, and
# textwrap.dedent removes the common leading indentation.
pt = textwrap.dedent("""\
    Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
    A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.

    Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
    A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.

    Q: {question}""")
# Sample columns substituted into the {placeholders} of `pt`.
pt_cols = ['question']
# The "The answer is" phrasing matches what the answer-extraction regex later in the notebook searches for.
system_prompt = "Solve the following math problems, end with The answer is"

# Add the "X" (chat-formatted prompt) and 'actual input' columns to every training example
processed_dataset = preprocess_dataset(gsm8k_dataset['train'], tokenizer,pt = pt, pt_cols = pt_cols, system_prompt = system_prompt, add_generation_prompt = True)

Map:   0%|          | 0/7473 [00:00<?, ? examples/s]

In [5]:
# Fixed shuffle seed, passed to shuffle() below.
seed = 42
# Number of examples kept for the comparison.
bm_sample_size = 150
# Shuffle the processed dataset with the seed, then keep the first bm_sample_size rows.
bm_samples = processed_dataset.shuffle(seed = seed).select(range(bm_sample_size))

In [10]:
# Peek at the first reference answer only: the loop exits after one iteration.
for item in bm_samples['answer']:
    print(item)
    break

Mimi has 2 x 12 = <<2*12=24>>24 sea shells.
Kyle has 24 x 2 = <<24*2=48>>48 sea shells.
Leigh has 48 / 3 = <<48/3=16>>16 sea shells.
#### 16


In [8]:
# Print every reference answer in the benchmark sample.
# (The `for` header was missing from this cell, which made it a syntax error.)
for item in bm_samples['answer']:
    print(item)

Mimi has 2 x 12 = <<2*12=24>>24 sea shells.
Kyle has 24 x 2 = <<24*2=48>>48 sea shells.
Leigh has 48 / 3 = <<48/3=16>>16 sea shells.
#### 16
He has 6 - 2 = <<6-2=4>>4 cats.
He has 4 - 1 = <<4-1=3>>3 parrots.
He has 4 + 6 = <<4+6=10>>10 snakes.
He has a total of 2 + 4 + 3 + 10 = <<2+4+3+10=19>>19 pets.
#### 19
Dad gave Olaf 10 toy cars,
Mom has given Olaf 5 more toy cars than Dad, so 10 + 5 = <<10+5=15>>15 toy cars
Auntie gave Olaf 6 toy cars,
Uncle has given 1 less toy than Auntie, so 6 - 1 = <<6-1=5>>5 toy cars
Grandpa gave Olaf 2 * 5 = <<2*5=10>>10 toy cars.
All the family together gave Olaf 10 +15 + 6 + 5 + 10 = <<10+15+6+5+10=46>>46.
Adding the cars Olaf already had, Olaf's collection has 150 + 46 = <<150+46=196>>196 cars.
#### 196
She spend $56 because 7 x 8 = <<7*8=56>>56
She has $44 left in the bank because 100 - 56 = <<100-56=44>>44
She can get 8 five dollar bills because 44 / 5 = <<44/5=8.8>>8.8
This is equal to $40 because 8 x 5 = <<8*5=40>>40
She has $4 left in the account b

In [15]:
import re

def extract_answer_from_out(s):
    """Extract the numeric answer that follows "The answer is" in a model output.

    Only the first occurrence of the phrase is considered. An optional ``$``
    and an optional minus sign before the number are accepted, and thousands
    separators are removed, so ``"The answer is $1,440."`` yields ``"1440"``.

    Args:
        s: Text generated by the model.

    Returns:
        The number as a string (digits, optional leading ``-``, optional
        decimal part, no commas), or ``None`` if the phrase followed by a
        number is not found.
    """
    pattern = re.compile(
        r"The answer is \$?"
        # Comma-grouped number such as 1,440 or 12,345.50. The lookahead keeps
        # malformed groupings like "1,4405" from being partially consumed.
        r"(-?\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?"
        # Plain number such as 18, 3.5 or -5.
        r"|-?\d+(?:\.\d+)?)"
    )
    match = pattern.search(s)
    if match is None:
        return None
    # Drop thousands separators so the result compares equal to "1440"-style answers.
    return match.group(1).replace(",", "")

# Tally the examples on which the two runs disagree and exactly one of them is right:
#   after_counter  -> only the "after" output matches the reference answer
#   before_counter -> only the "before" output matches the reference answer (also printed)
# NOTE(review): assumes data['outputs_before'] / data['outputs_after'] are in the same
# order as bm_samples - confirm. zip() silently stops at the shortest of the three.
before_counter, after_counter = 0, 0

for out_before, out_after, answer in zip(data['outputs_before'], data['outputs_after'], bm_samples['answer']):

    before_answer = extract_answer_from_out(out_before)
    after_answer = extract_answer_from_out(out_after)
    # The reference answer follows "####"; drop thousands separators so that
    # e.g. "1,440" compares equal to an extracted "1440".
    correct_answer = answer.split("####")[-1].strip().replace(",", "")

    if before_answer == after_answer:
        # Both runs produced the same extracted answer (or both None): not a disagreement.
        continue

    if correct_answer == after_answer:
        after_counter += 1
    elif correct_answer == before_answer:
        print(f"Before: {before_answer}, After: {after_answer}, Correct: {correct_answer}")
        before_counter += 1

Before: 160, After: None, Correct: 160
Before: 1440, After: 1360, Correct: 1440
Before: 50, After: 600, Correct: 50
Before: 40, After: 28, Correct: 40
Before: 22, After: None, Correct: 22
Before: 21, After: 12, Correct: 21
Before: 10, After: 2, Correct: 10
Before: 1, After: 0, Correct: 1
Before: 3, After: 8, Correct: 3


In [16]:
# Display (examples where only "before" was correct, examples where only "after" was correct).
before_counter, after_counter

(9, 11)