In [1]:
from datasets import load_dataset

ds = load_dataset("openai/gsm8k", "main")

In [8]:
ds['train'][0]['answer']

'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72'

In [2]:
from transformers import AutoTokenizer 
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-pt")


In [3]:
import re
tool_call_pattern = re.compile(r'(<<[^>]+>>)([^<]*)')

In [4]:
list(tool_call_pattern.finditer(ds['train'][0]['answer']))

[<re.Match object; span=(20, 69), match='<<48/2=24>>24 clips in May.\nNatalia sold 48+24 =>,
 <re.Match object; span=(69, 126), match='<<48+24=72>>72 clips altogether in April and May.>]

In [19]:
parse_tool_call("30+46+18=100")

'<add>(30, 46, 18)=100'

In [18]:
def parse_tool_call(expression: str):
    """
    Parses a mathematical expression and converts it into the ToolkenGPT format.
    Examples: 
    "48/2=24" -> "<divide>(48, 2)=24"
    "+30+46+38+11+18=143" -> "<add>(30, 46, 38, 11, 18)=143"
    """
    # Simple mapping from operator to function name
    op_to_func = {
        '+': 'add',
        '-': 'subtract',
        '*': 'multiply',
        '/': 'divide'
    }

    # Find the operator and split on equals sign
    if '=' in expression:
        expr_part, result = expression.split('=', 1)
        result = result.strip()
        
        # Handle multiple additions/subtractions
        for op, func_name in op_to_func.items():
            if op in expr_part:
                # Skip if it's just a leading sign
                if expr_part.strip() == op:
                    continue
                    
                # Split on operator, filter out empty strings
                parts = [p.strip() for p in expr_part.split(op) if p.strip()]
                
                # If we found valid parts, convert to function call format
                if parts:
                    return f"<{func_name}>({', '.join(parts)})={result}"
    
    return expression

In [25]:
ds['train'][108]['answer']

'There were 9-4 = <<9-4=5>>5 other pills\nEach of the other pills cost 1.50+5.50 = <<1.50+5.50=7>>7 dollars each.\nThe 5 pills cost a total of 7*5 = <<7*5=35>>35 dollars.\nThe first 4 pills cost 1.50*4 = <<1.50*4=6>>6 dollars in total.\nHenry spent a total of 35+6 = <<35+6=41>>41 dollars.\n#### 41'

In [31]:
# Get the index by searching through processed samples
found_idx = next((i for i, sample in enumerate(processed_samples) 
                  if sample['text'] == original[108]['text']), None)
print(found_idx)

140


In [32]:
ds['train'][140]["answer"]

'First find how many kids from Riverside High are rejected: 20% * 120 kids = <<20*.01*120=24>>24 kids\nThen find how many kids from West Side High are rejected: 70% * 90 kids = <<70*.01*90=63>>63 kids\nThen find how many kids from Mountaintop High are rejected: 50 kids / 2 = <<50/2=25>>25 kids\nThen add the number of kids from each school to find the total number of kids: 120 kids + 90 kids + 50 kids = <<120+90+50=260>>260 kids\nThen subtract all the kids who were rejected from the total number of kids to find the number who got in: 260 kids - 24 kids - 63 kids - 25 kids = <<260-24-63-25=148>>148 kids\n#### 148'

In [27]:
original[108]

{'text': "Fern is checking IDs to get into an R-rated movie. She denied 20% of the 120 kids from Riverside High, 70% of the 90 kids from West Side High, and half the 50 kids from Mountaintop High. How many kids got into the movie? Let's think step by step. First find how many kids from Riverside High are rejected: 20% * 120 kids = 24 kids\nThen find how many kids from West Side High are rejected: 70% * 90 kids = 63 kids\nThen find how many kids from Mountaintop High are rejected: 50 kids / 2 = 25 kids\nThen add the number of kids from each school to find the total number of kids: 120 kids + 90 kids + 50 kids = 260 kids\nThen subtract all the kids who were rejected from the total number of kids to find the number who got in: 260 kids - 24 kids - 63 kids - 25 kids = 148 kids\n#### 148",
 'start_token_idx': [109, 139, 168, 212, 268],
 'end_token_idx': [111, 141, 170, 215, 271],
 'tar_eq': ['<multiply>(20, 0.01, 120)=24<eoe>',
  '<multiply>(70, 0.01, 90)=63<eoe>',
  '<divide>(50, 2)=25<eoe

In [53]:
from tqdm import tqdm
verbose = 1
mismatches = 0


processed_samples = []


for sample in tqdm(ds['train']):
    start_indices = []
    end_indices = []
    target_equations = []
    target_numbers = []
        
    removed_chars_offset = 0

    full_text = sample['question'] + " Let's think step by step. " + sample['answer']

    matches = list(tool_call_pattern.finditer(full_text))

    # Create the final clean text by removing only the <<...>> syntax.
    clean_text = re.sub(r'<<[^>]+>>', '', full_text)

    start_indices = []
    end_indices = []
    target_equations = []
    target_numbers = []

    removed_chars_offset = 0

    for match in matches:
        expression_part = match.group(1) # The <<...>> part
    
        expression = expression_part[2:-2] # The content inside
        expression = re.sub(r'(?<=[\s=+\-*/])\.(\d+)', r'0.\1', expression)  # Add 0 before decimal points

        following_text = match.group(2) # The text after <<...>>
        
        # The pattern should handle negative numbers as well.
        number_pattern = re.compile(r'-?[\d,]*\.?\d+')
        num_match = number_pattern.search(following_text)
        num = num_match.group(0) if num_match else None
        expected_num_str = expression.split('=')[-1].strip()

        """
        try:
            assert num == expected_num_str
        except AssertionError:
            mismatches += 1
            print("-"*25)
            print(num_match)
            print(f"Full text: {sample['answer']}")
            print("/"*10)
            print(f"Found match: {match.group(0)}")
            print(f"Expression part: {expression_part}")
            print(f"Following text: {following_text}")
            print(f"Number found: {num}")
            print(f"Expected number from expression: {expected_num_str}")
        """
        # 1. Calculate the character position of where the number *starts* in the clean_text.
        char_pos_start = (match.start() - removed_chars_offset) + num_match.start()

        # 2. Tokenize the clean text *before* the number's position to find the start_idx.
        text_before_result = clean_text[:char_pos_start]
        tokens_before = tokenizer.encode(text_before_result)
        start_idx = len(tokens_before)

        # 3. Calculate the character position of where the number *ends* in the clean_text.
        char_pos_end = char_pos_start + len(num)
        
        # 4. Tokenize the clean text up to the *end* of the number. The length of this
        #    token sequence is our end_idx.
        text_up_to_end_of_result = clean_text[:char_pos_end]
        tokens_up_to_end = tokenizer.encode(text_up_to_end_of_result)
        end_idx = len(tokens_up_to_end)

        # 5. Store the start and end indices for the current match.
        start_indices.append(start_idx)
        end_indices.append(end_idx)
        target_equations.append(parse_tool_call(expression))

        target_numbers.append(num)

        # Update the offset for the next iteration by adding the length of the
        # <<...>> syntax string we just processed.
        removed_chars_offset += len(expression_part)

        if num != expected_num_str and False:
            print(f"Text: {clean_text}")
            print(f"Before: {text_before_result}")
            print(f"After: {text_up_to_end_of_result}")
            print(f"Target equations: {target_equations}")
    
    processed_samples.append({
            "text": clean_text,
            "start_token_idx": start_indices,
            "end_token_idx": end_indices,
            "tar_eq": target_equations,
            "tar_number": target_numbers,
        })

print(f"Mismatch ratio: {mismatches/len(ds['train'])}")

  0%|          | 0/7473 [00:00<?, ?it/s]

100%|██████████| 7473/7473 [00:15<00:00, 488.94it/s]

Mismatch ratio: 0.0





In [18]:
expression_part

'<<48+24=72>>'

In [19]:
expression

'48+24=72'

In [20]:
following_text

'72 clips altogether in April and May.\n#### 72'

'72'

In [12]:
len(mismatches_after_correction)

23694

In [54]:
# Strict validation: update samples in-place with corrected indices and match status.
strict_total = 0
strict_ok = 0
mismatches_after_correction = []

for i, samp in enumerate(processed_samples):
    # Initialize a new list to store the match status for each tool call in the sample
    samp["strict_match"] = []
    
    # keep tokenizer behavior consistent with above preprocessing (same encode)
    token_ids = tokenizer.encode(samp["text"])
    
    # Enumerate to get the index 'j' for updating the lists
    for j, (s, t, num) in enumerate(zip(samp["start_token_idx"], samp["end_token_idx"], samp.get("tar_number", []))):
        strict_total += 1
        span_text = tokenizer.decode(token_ids[s:t])
        span_text_m1 = tokenizer.decode(token_ids[s-1:t]).strip()
        span_text_p1 = tokenizer.decode(token_ids[s+1:t]).strip()

        is_match = False

        if (span_text == num):
            strict_ok += 1
            is_match = True
        # Check for a match with a left-shifted start index
        elif (span_text_m1 == num):
            strict_ok += 1
            is_match = True
            # Correct the start index in the sample
            samp["start_token_idx"][j] -= 1
        # Check for a match with a right-shifted start index
        elif (span_text_p1 == num):
            strict_ok += 1
            is_match = True
            # Correct the start index in the sample
            samp["start_token_idx"][j] += 1
        
        # Append the match status (True or False) for the current tool call
        samp["strict_match"].append(is_match)

        if not is_match:
            mismatches_after_correction.append({
                "sample_idx": i,
                "tool_call_idx": j,
                "s": s,
                "t": t,
                "expected": num,
                "span_text": span_text,
                "ctx_left": tokenizer.decode(token_ids[max(0, s-5):s]),
                "ctx_right": tokenizer.decode(token_ids[t:min(len(token_ids), t+5)]),
            })

print(f"Strict — total: {strict_total}, matches (with correction): {strict_ok}, mismatches: {len(mismatches_after_correction)}")

# Show a few remaining mismatches for inspection
for r in mismatches_after_correction[:10]:
    print("-"*50)
    print({k: r[k] for k in ["sample_idx", "tool_call_idx", "s", "t", "expected", "span_text"]})
    print("L:", r["ctx_left"])
    print("R:", r["ctx_right"])

Strict — total: 23716, matches (with correction): 23695, mismatches: 21
--------------------------------------------------
{'sample_idx': 929, 'tool_call_idx': 1, 's': 90, 't': 92, 'expected': '.25', 'span_text': '25'}
L: 0/2= $.
R:  on the half-priced
--------------------------------------------------
{'sample_idx': 1273, 'tool_call_idx': 0, 's': 66, 't': 69, 'expected': '.75', 'span_text': '$.75'}
L: .5/2=
R:  per pack
So he
--------------------------------------------------
{'sample_idx': 1875, 'tool_call_idx': 0, 's': 75, 't': 77, 'expected': '.25', 'span_text': '25'}
L: 5/60=.
R:  hour head start
That
--------------------------------------------------
{'sample_idx': 1957, 'tool_call_idx': 3, 's': 122, 't': 124, 'expected': '.5', 'span_text': '$.5'}
L: -1.5=
R: 
So he spent 
--------------------------------------------------
{'sample_idx': 2270, 'tool_call_idx': 5, 's': 215, 't': 216, 'expected': '.8', 'span_text': '8'}
L: 2/15=.
R: 
#### 80
----------------------------------------

In [11]:
processed_samples[0]

{'text': 'Natalia sold 48/2 = 24 clips in May.\nNatalia sold 48+24 = 72 clips altogether in April and May.\n#### 72',
 'start_token_idx': [11, 29],
 'end_token_idx': [13, 31],
 'tar_eq': ['<divide>(48, 2)=24', '<add>(48, 24)=72'],
 'tar_number': ['24', '72'],
 'strict_match': [True, True]}

In [55]:
import os
import json

# --- Save the processed data ---
os.makedirs("outputs", exist_ok=True)

output_path = os.path.join("outputs", "gsm8k_gemma-4b-pt.json")
with open(output_path, 'w') as f:
    json.dump(processed_samples, f, indent=4)
print(f"✅ Processed data saved to {output_path}")

# --- Create and save the final function dictionary ---
func_dict = {
    "<add>": 0,
    "<subtract>": 1,
    "<multiply>": 2,
    "<divide>": 3
}
func_dict_path = os.path.join("outputs", "func_dict.json")
with open(func_dict_path, 'w') as f:
    json.dump(func_dict, f, indent=4)
print(f"✅ Function dictionary saved to {func_dict_path}")


✅ Processed data saved to outputs\gsm8k_gemma-4b-pt.json
✅ Function dictionary saved to outputs\func_dict.json


In [56]:
import json

original_path = "data/gsm8k-xl/train.json"
processed_path = "outputs/gsm8k_gemma-4b-pt.json"

with open(original_path, 'r') as f:
    original = json.load(f)

with open(processed_path, 'r') as f:
    processed = json.load(f)

# Statistics
total_samples = len(original)
found_matches = 0
text_matches = 0
tar_eq_matches = 0
tar_number_matches = 0

# Store mismatches
mismatches = {
    'text_mismatches': [],
    'tar_eq_mismatches': [],
    'tar_number_mismatches': []
}

# Create a dictionary of processed samples with text as key for faster lookup
processed_dict = {p['text']: p for p in processed}

for i, o_sample in enumerate(original):
    # Try to find matching text in processed samples
    if o_sample['text'] in processed_dict:
        found_matches += 1
        p_sample = processed_dict[o_sample['text']]
        
        # Compare each field
        if o_sample['text'] == p_sample['text']:
            text_matches += 1
        else:
            mismatches['text_mismatches'].append({
                'index': i,
                'original': o_sample['text'],
                'processed': p_sample['text']
            })

        c = [item.replace("<eoe>", "") for item in o_sample['tar_eq']]
        if c == p_sample['tar_eq']:
            tar_eq_matches += 1
        else:
            mismatches['tar_eq_mismatches'].append({
                'index': i,
                'original': o_sample['tar_eq'],
                'processed': p_sample['tar_eq']
            })
        
        n = [item.replace(",", "") for item in p_sample['tar_number']]
        if o_sample['tar_number'] == n:
            tar_number_matches += 1
        else:
            mismatches['tar_number_mismatches'].append({
                'index': i,
                'original': o_sample['tar_number'],
                'processed': p_sample['tar_number']
            })

# Print statistics
print(f"Total original samples: {total_samples}")
print(f"Matching samples found: {found_matches} ({found_matches/total_samples:.2%})")
print(f"Text matches: {text_matches} ({text_matches/total_samples:.2%})")
print(f"Target equation matches: {tar_eq_matches} ({tar_eq_matches/total_samples:.2%})")
print(f"Target number matches: {tar_number_matches} ({tar_number_matches/total_samples:.2%})")

# Save mismatches to file
mismatch_path = "outputs/mismatches.json"
with open(mismatch_path, 'w') as f:
    json.dump(mismatches, f, indent=4)
print(f"\nMismatches saved to {mismatch_path}")


Total original samples: 6054
Matching samples found: 5799 (95.79%)
Text matches: 5799 (95.79%)
Target equation matches: 5564 (91.91%)
Target number matches: 5760 (95.14%)

Mismatches saved to outputs/mismatches.json

Mismatches saved to outputs/mismatches.json


In [21]:
ds['train'][0]

{'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
 'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72'}

In [34]:
ds['train']

Dataset({
    features: ['question', 'answer'],
    num_rows: 7473
})

In [33]:
len(processed_samples)

7473

In [31]:
assert len(original) == len(processed_samples), "Original and processed datasets have different lengths"

AssertionError: Original and processed datasets have different lengths

In [35]:
text_mismatch = 0
for o_sample, gsm, p_sample in zip(original, processed_samples, ds['train']):
    full_text = p_sample['question'] + " Let's think step by step. " + gsm['text']

    try:
        assert "Let's think step by step." in o_sample['text'], "Missing 'Let's think step by step.' in original text"
        #assert o_sample['text'] == full_text, "Text mismatch"
    except:
        print(f"Original: {o_sample['text']}")
        print(f"Processed: {full_text}")
        text_mismatch += 1
        continue
    #assert o_sample['start_token_idx'] == p_sample['start_token_idx'], "Start indices mismatch"
    #assert o_sample['end_token_idx'] == p_sample['end_token_idx'], "End indices mismatch"
    #assert o_sample['tar_eq'] == p_sample['tar_eq'], "Target equations mismatch"
    #assert o_sample['tar_number'] == p_sample['tar_number'], "Target numbers mismatch"

print(f"Text mismatches found: {text_mismatch/len(original):.2%}")

Text mismatches found: 0.00%
