In [1]:
from huggingface_hub import login
login(token="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") 

In [2]:
dataset = "svamp"
llm = "gemma-3-4b-pt"

In [3]:
'''from datasets import load_dataset

ds = load_dataset(f"openai/{dataset}", "main")

print(ds)

example = ds["train"][0]
print("Example:", example)

print("\n\nQuestion:", example["question"], "\n\n")
print("Answer:", example["answer"])''';

In [4]:
from datasets import load_dataset, DatasetDict
import re

# 1) Load dataset
raw = load_dataset(f"garrethlee/{dataset}")  # has 'train' and 'test' splits

# 2) Parser: from "( 26 / 13 ) = 2 #### 2" -> eq="26/13", ans="2"
eq_ans_re = re.compile(
    r"""            # whole answer field
    (?P<lhs>.+?)    # left side including equation text
    \s*####\s*
    (?P<ans>[-+]?\d+(?:\.\d+)?)\s*$   # numeric final answer
    """,
    re.VERBOSE,
)

def normalize_equation(lhs: str) -> str:
    """
    lhs examples:
      "( 26 / 13 ) = 2"
      "( ( 11 - 2 ) - 2 ) = 7"
    We take the text before '=', strip spaces, but preserve parentheses.
    """
    # take everything before '='
    before_eq = lhs.split('=')[0].strip()
    # remove all spaces
    before_eq = before_eq.replace(' ', '')
    # also collapse redundant outer parentheses like "((a))" -> "(a)" (optional)
    # simple pass: keep as-is; downstream doesn't require removal
    return before_eq

def to_gsm8k_style(example):
    """
    Produce:
      question: copy as-is
      answer: "<eq> = <<eq=ans>>ans"
    """
    text = example["answer"]
    m = eq_ans_re.match(text)
    if not m:
        # fallback: if the format deviates, return minimal passthrough
        eq = ""
        ans = ""
    else:
        lhs = m.group("lhs")
        ans = m.group("ans")
        eq = normalize_equation(lhs)

    return {
        "question": example["question"],
        "answer": f"The final answer is {eq} = <<{eq}={ans}>>{ans}",
    }

# 3) Transform both splits, keep only the 2 fields
processed = {}
for split in raw.keys():  # 'train', 'test'
    processed[split] = raw[split].map(
        to_gsm8k_style,
        remove_columns=raw[split].column_names
    )

# 4) Ensure final variable name is `ds` and contains only 'train' and 'test'
ds = DatasetDict(processed)

# 5) Quick inspection
print(ds)
print(ds["train"][0])


Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['question', 'answer'],
        num_rows: 800
    })
    test: Dataset({
        features: ['question', 'answer'],
        num_rows: 200
    })
})
{'question': "In Haley's class 13 boys love to play marbles and 50 boys love to play cards. If Haley has 26 marbles How many will each of the boys receive?", 'answer': 'The final answer is (26/13) = <<(26/13)=2>>2'}


In [5]:
from transformers import AutoTokenizer 
tokenizer = AutoTokenizer.from_pretrained(f"google/{llm}")


In [6]:
import re
tool_call_pattern = re.compile(r'(<<[^>]+>>)([^<]*)')

In [7]:
list(tool_call_pattern.finditer(ds['train'][0]['answer']))

[<re.Match object; span=(30, 44), match='<<(26/13)=2>>2'>]

In [8]:
def parse_tool_call(expression: str):
    """
    Parses a mathematical expression and converts it into the ToolkenGPT format.
    Examples: 
    "48/2=24" -> "<divide>(48, 2)=24"
    "+30+46+38+11+18=143" -> "<add>(30, 46, 38, 11, 18)=143"
    """
    # Simple mapping from operator to function name
    op_to_func = {
        '+': 'add',
        '-': 'subtract',
        '*': 'multiply',
        '/': 'divide'
    }

    # Find the operator and split on equals sign
    if '=' in expression:
        expr_part, result = expression.split('=', 1)
        result = result.strip()
        
        # Handle multiple additions/subtractions
        for op, func_name in op_to_func.items():
            if op in expr_part:
                # Skip if it's just a leading sign
                if expr_part.strip() == op:
                    continue
                    
                # Split on operator, filter out empty strings
                parts = [p.strip() for p in expr_part.split(op) if p.strip()]
                
                # If we found valid parts, convert to function call format
                if parts:
                    return f"<{func_name}>({', '.join(parts)})={result}"
    
    return expression

In [9]:
parse_tool_call("30+46+18=100")

'<add>(30, 46, 18)=100'

In [10]:
ds['train'][108]['answer']

'The final answer is (11+8) = <<(11+8)=19>>19'

In [11]:
# Get the index by searching through processed samples
'''found_idx = next((i for i, sample in enumerate(processed_samples) 
                  if sample['text'] == original[108]['text']), None)
print(found_idx)''';

In [12]:
ds['train'][140]["answer"]

'The final answer is (6-(2+3)) = <<(6-(2+3))=1>>1'

In [13]:
#original[108]

In [14]:
from tqdm import tqdm
verbose = 1
mismatches = 0


processed_samples = []


for sample in tqdm(ds['train']):
    start_indices = []
    end_indices = []
    target_equations = []
    target_numbers = []
        
    removed_chars_offset = 0

    full_text = sample['question'] + " Let's think step by step. " + sample['answer']

    matches = list(tool_call_pattern.finditer(full_text))

    # Create the final clean text by removing only the <<...>> syntax.
    clean_text = re.sub(r'<<[^>]+>>', '', full_text)

    start_indices = []
    end_indices = []
    target_equations = []
    target_numbers = []

    removed_chars_offset = 0

    for match in matches:
        expression_part = match.group(1) # The <<...>> part
    
        expression = expression_part[2:-2] # The content inside
        expression = re.sub(r'(?<=[\s=+\-*/])\.(\d+)', r'0.\1', expression)  # Add 0 before decimal points

        following_text = match.group(2) # The text after <<...>>
        
        # The pattern should handle negative numbers as well.
        number_pattern = re.compile(r'-?[\d,]*\.?\d+')
        num_match = number_pattern.search(following_text)
        num = num_match.group(0) if num_match else None
        expected_num_str = expression.split('=')[-1].strip()

        """
        try:
            assert num == expected_num_str
        except AssertionError:
            mismatches += 1
            print("-"*25)
            print(num_match)
            print(f"Full text: {sample['answer']}")
            print("/"*10)
            print(f"Found match: {match.group(0)}")
            print(f"Expression part: {expression_part}")
            print(f"Following text: {following_text}")
            print(f"Number found: {num}")
            print(f"Expected number from expression: {expected_num_str}")
        """
        # 1. Calculate the character position of where the number *starts* in the clean_text.
        char_pos_start = (match.start() - removed_chars_offset) + num_match.start()

        # 2. Tokenize the clean text *before* the number's position to find the start_idx.
        text_before_result = clean_text[:char_pos_start]
        tokens_before = tokenizer.encode(text_before_result)
        start_idx = len(tokens_before)

        # 3. Calculate the character position of where the number *ends* in the clean_text.
        char_pos_end = char_pos_start + len(num)
        
        # 4. Tokenize the clean text up to the *end* of the number. The length of this
        #    token sequence is our end_idx.
        text_up_to_end_of_result = clean_text[:char_pos_end]
        tokens_up_to_end = tokenizer.encode(text_up_to_end_of_result)
        end_idx = len(tokens_up_to_end)

        # 5. Store the start and end indices for the current match.
        start_indices.append(start_idx)
        end_indices.append(end_idx)
        target_equations.append(parse_tool_call(expression))

        target_numbers.append(num)

        # Update the offset for the next iteration by adding the length of the
        # <<...>> syntax string we just processed.
        removed_chars_offset += len(expression_part)

        if num != expected_num_str and False:
            print(f"Text: {clean_text}")
            print(f"Before: {text_before_result}")
            print(f"After: {text_up_to_end_of_result}")
            print(f"Target equations: {target_equations}")
    
    processed_samples.append({
            "text": clean_text,
            "start_token_idx": start_indices,
            "end_token_idx": end_indices,
            "tar_eq": target_equations,
            "tar_number": target_numbers,
        })

print(f"Mismatch ratio: {mismatches/len(ds['train'])}")

100%|██████████| 800/800 [00:00<00:00, 2944.39it/s]

Mismatch ratio: 0.0





In [15]:
expression_part

'<<(16+6)=22>>'

In [16]:
expression

'(16+6)=22'

In [17]:
following_text

'22'

In [18]:
#len(mismatches_after_correction)

In [19]:
# Strict validation: update samples in-place with corrected indices and match status.
strict_total = 0
strict_ok = 0
mismatches_after_correction = []

for i, samp in enumerate(processed_samples):
    # Initialize a new list to store the match status for each tool call in the sample
    samp["strict_match"] = []
    
    # keep tokenizer behavior consistent with above preprocessing (same encode)
    token_ids = tokenizer.encode(samp["text"])
    
    # Enumerate to get the index 'j' for updating the lists
    for j, (s, t, num) in enumerate(zip(samp["start_token_idx"], samp["end_token_idx"], samp.get("tar_number", []))):
        strict_total += 1
        span_text = tokenizer.decode(token_ids[s:t])
        span_text_m1 = tokenizer.decode(token_ids[s-1:t]).strip()
        span_text_p1 = tokenizer.decode(token_ids[s+1:t]).strip()

        is_match = False

        if (span_text == num):
            strict_ok += 1
            is_match = True
        # Check for a match with a left-shifted start index
        elif (span_text_m1 == num):
            strict_ok += 1
            is_match = True
            # Correct the start index in the sample
            samp["start_token_idx"][j] -= 1
        # Check for a match with a right-shifted start index
        elif (span_text_p1 == num):
            strict_ok += 1
            is_match = True
            # Correct the start index in the sample
            samp["start_token_idx"][j] += 1
        
        # Append the match status (True or False) for the current tool call
        samp["strict_match"].append(is_match)

        if not is_match:
            mismatches_after_correction.append({
                "sample_idx": i,
                "tool_call_idx": j,
                "s": s,
                "t": t,
                "expected": num,
                "span_text": span_text,
                "ctx_left": tokenizer.decode(token_ids[max(0, s-5):s]),
                "ctx_right": tokenizer.decode(token_ids[t:min(len(token_ids), t+5)]),
            })

print(f"Strict — total: {strict_total}, matches (with correction): {strict_ok}, mismatches: {len(mismatches_after_correction)}")

# Show a few remaining mismatches for inspection
for r in mismatches_after_correction[:10]:
    print("-"*50)
    print({k: r[k] for k in ["sample_idx", "tool_call_idx", "s", "t", "expected", "span_text"]})
    print("L:", r["ctx_left"])
    print("R:", r["ctx_right"])

Strict — total: 800, matches (with correction): 800, mismatches: 0


In [20]:
processed_samples[0]

{'text': "In Haley's class 13 boys love to play marbles and 50 boys love to play cards. If Haley has 26 marbles How many will each of the boys receive? Let's think step by step. The final answer is (26/13) = 2",
 'start_token_idx': [61],
 'end_token_idx': [62],
 'tar_eq': ['<divide>((26, 13))=2'],
 'tar_number': ['2'],
 'strict_match': [True]}

In [21]:
import os
import json

# --- Save the processed data ---
os.makedirs("outputs", exist_ok=True)

output_path = os.path.join("outputs", f"{dataset}_{llm}.json")
with open(output_path, 'w') as f:
    json.dump(processed_samples, f, indent=4)
print(f"✅ Processed data saved to {output_path}")

# --- Create and save the final function dictionary ---
func_dict = {
    "<add>": 0,
    "<subtract>": 1,
    "<multiply>": 2,
    "<divide>": 3
}
func_dict_path = os.path.join("outputs", "func_dict.json")
with open(func_dict_path, 'w') as f:
    json.dump(func_dict, f, indent=4)
print(f"✅ Function dictionary saved to {func_dict_path}")


✅ Processed data saved to outputs/svamp_gemma-3-4b-pt.json
✅ Function dictionary saved to outputs/func_dict.json


In [22]:
'''import json

original_path = os.path.join("outputs", f"{dataset}_{llm}_original.json")
processed_path = os.path.join("outputs", f"{dataset}_{llm}.json")

with open(original_path, 'r') as f:
    original = json.load(f)

with open(processed_path, 'r') as f:
    processed = json.load(f)

# Statistics
total_samples = len(original)
found_matches = 0
text_matches = 0
tar_eq_matches = 0
tar_number_matches = 0

# Store mismatches
mismatches = {
    'text_mismatches': [],
    'tar_eq_mismatches': [],
    'tar_number_mismatches': []
}

# Create a dictionary of processed samples with text as key for faster lookup
processed_dict = {p['text']: p for p in processed}

for i, o_sample in enumerate(original):
    # Try to find matching text in processed samples
    if o_sample['text'] in processed_dict:
        found_matches += 1
        p_sample = processed_dict[o_sample['text']]
        
        # Compare each field
        if o_sample['text'] == p_sample['text']:
            text_matches += 1
        else:
            mismatches['text_mismatches'].append({
                'index': i,
                'original': o_sample['text'],
                'processed': p_sample['text']
            })

        c = [item.replace("<eoe>", "") for item in o_sample['tar_eq']]
        if c == p_sample['tar_eq']:
            tar_eq_matches += 1
        else:
            mismatches['tar_eq_mismatches'].append({
                'index': i,
                'original': o_sample['tar_eq'],
                'processed': p_sample['tar_eq']
            })
        
        n = [item.replace(",", "") for item in p_sample['tar_number']]
        if o_sample['tar_number'] == n:
            tar_number_matches += 1
        else:
            mismatches['tar_number_mismatches'].append({
                'index': i,
                'original': o_sample['tar_number'],
                'processed': p_sample['tar_number']
            })

# Print statistics
print(f"Total original samples: {total_samples}")
print(f"Matching samples found: {found_matches} ({found_matches/total_samples:.2%})")
print(f"Text matches: {text_matches} ({text_matches/total_samples:.2%})")
print(f"Target equation matches: {tar_eq_matches} ({tar_eq_matches/total_samples:.2%})")
print(f"Target number matches: {tar_number_matches} ({tar_number_matches/total_samples:.2%})")

# Save mismatches to file
mismatch_path = "outputs/mismatches.json"
with open(mismatch_path, 'w') as f:
    json.dump(mismatches, f, indent=4)
print(f"\nMismatches saved to {mismatch_path}")''';


In [23]:
ds['train'][0]

{'question': "In Haley's class 13 boys love to play marbles and 50 boys love to play cards. If Haley has 26 marbles How many will each of the boys receive?",
 'answer': 'The final answer is (26/13) = <<(26/13)=2>>2'}

In [24]:
ds['train']

Dataset({
    features: ['question', 'answer'],
    num_rows: 800
})

In [25]:
len(processed_samples)

800

In [26]:
#assert len(original) == len(processed_samples), "Original and processed datasets have different lengths"

In [27]:
'''text_mismatch = 0
for o_sample, gsm, p_sample in zip(original, processed_samples, ds['train']):
    full_text = p_sample['question'] + " Let's think step by step. " + gsm['text']

    try:
        assert "Let's think step by step." in o_sample['text'], "Missing 'Let's think step by step.' in original text"
        #assert o_sample['text'] == full_text, "Text mismatch"
    except:
        print(f"Original: {o_sample['text']}")
        print(f"Processed: {full_text}")
        text_mismatch += 1
        continue
    #assert o_sample['start_token_idx'] == p_sample['start_token_idx'], "Start indices mismatch"
    #assert o_sample['end_token_idx'] == p_sample['end_token_idx'], "End indices mismatch"
    #assert o_sample['tar_eq'] == p_sample['tar_eq'], "Target equations mismatch"
    #assert o_sample['tar_number'] == p_sample['tar_number'], "Target numbers mismatch"

print(f"Text mismatches found: {text_mismatch/len(original):.2%}")''';