In [1]:
import re

from datasets import load_dataset

gsm8k_train = load_dataset("gsm8k", "main")["train"]

# A solution line counts as an "equation line" only if it contains a digit AND
# either an "=" sign or the word "equals". Every other line -- including lines
# that merely mention numbers -- is kept as a line without an equation.
DIGIT_PATTERN = re.compile(r'\d')


def has_equation(line):
    """Return True if `line` contains a digit together with "=" or "equals"."""
    return bool(DIGIT_PATTERN.search(line)) and ("=" in line or "equals" in line)


lines_without_eqn = []
# Iterate the "answer" column directly instead of indexing the dataset row by row.
for answer in gsm8k_train["answer"]:
    # split into lines, strip surrounding whitespace, and drop empty lines
    answer_lines = [line.strip() for line in answer.split("\n")]
    answer_lines = [line for line in answer_lines if line]
    # remove the final answer line
    answer_lines = answer_lines[:-1]
    # keep only lines that do not look like calculations
    lines_without_eqn.extend(line for line in answer_lines if not has_equation(line))

# Drop lines starting with "Let " (presumably variable introductions such as "Let x be ...").
lines_without_eqn = [line for line in lines_without_eqn if not line.startswith("Let ")]

print(len(lines_without_eqn))
lines_without_eqn

# The list outputted below is saved as a py file "lines_without_eqn.py"
# Lines containing calculations are manually removed

583


['The second hundred years’ ship had twice as many as the first, so it had 2S people.',
 'In one hour, there are 3 sets of 20 minutes.',
 'According to the ratio, for every 5 parts that Johnson gets, Mike gets 2 parts',
 'Her income was increased by $600 so it is now p+$600',
 '15 coins collected in hour one',
 '35 coins collected in hour two',
 '35 coins collected in hour three',
 '50 coins collected in hour four',
 'Lee earned $558 mowing lawns last week.',
 'The last two people each lost 28 kilograms of weight.',
 'Last year: 86 geckos',
 'Janice had X sentences from yesterday already typed to start with today.',
 'She has 2 bills to mail.',
 'Jayden is 4 years old.',
 'Keiko sent 283 texts last week and this week combined.',
 'So this movie is 192 minutes',
 'The first snake is 24 inches because there are 12 inches in a foot.',
 'Dallas is 46 years old now.',
 'He ate 70 apples in total.',
 'We let x be the amount of money she put in the bank the first year,',
 'Then the second yea

In [2]:
import pandas as pd
import json

# Error-detection dataset (path is relative to the kernel's working directory).
error_detection_csv = '../data/aug-5-dataset/error_detection_dataset.csv'
df = pd.read_csv(error_detection_csv)

In [3]:
def filter_error_detection_dataset_for_coverage(dataset_df):
    """
    Filter error detection dataset to keep only rows with complete calculator annotation coverage.

    A row is "covered" when every solution line except the final answer ("FA")
    has a non-empty equation in its eqn-mapping JSON. Rows whose ``error_type``
    is ``'correct'`` are checked against ``correct_eqn_mapping`` /
    ``correct_answer_length``; every other error type is checked against
    ``wrong_eqn_mapping`` / ``wrong_answer_length``. Rows with a missing
    ``error_type`` are dropped.

    Args:
        dataset_df: DataFrame with an ``error_type`` column plus the mapping and
            length columns listed above.

    Returns:
        A copy of the covered rows, with the original index preserved.
    """

    def has_complete_annotation_coverage(row, eqn_mapping_col, length_col):
        """Check if all solution lines except FA have calculator annotations."""
        try:
            eqn_mapping = json.loads(row[eqn_mapping_col])
            answer_length = row[length_col]
            # When an FA entry is mapped, it is not expected to carry an equation.
            expected_lines = answer_length - 1 if 'FA' in eqn_mapping else answer_length
            non_empty_equations = sum(
                1 for key, value in eqn_mapping.items()
                if key != 'FA' and value and str(value).strip()
            )
            return non_empty_equations == expected_lines
        except (json.JSONDecodeError, KeyError, TypeError, AttributeError):
            # Missing/NaN cells, malformed JSON, or JSON that is not an object
            # count as "not covered" instead of aborting the whole filter.
            return False

    def coverage_columns(error_type):
        """Return the (eqn-mapping, answer-length) column pair for an error type."""
        if error_type == 'correct':
            return 'correct_eqn_mapping', 'correct_answer_length'
        return 'wrong_eqn_mapping', 'wrong_answer_length'

    print("🔍 Filtering error detection dataset for complete annotation coverage...")
    print(f"Original dataset size: {len(dataset_df):,} samples")

    error_types = dataset_df['error_type']

    # One boolean per row, aligned with dataset_df.index, so the final mask always
    # matches the frame exactly (no per-type concat that can leave rows unaligned).
    full_coverage_mask = pd.Series(
        [
            has_complete_annotation_coverage(row, *coverage_columns(row['error_type']))
            for _, row in dataset_df.iterrows()
        ],
        index=dataset_df.index,
        dtype=bool,
    )
    # Rows without an error_type cannot be attributed to any type: drop them.
    full_coverage_mask &= error_types.notna()

    for error_type in error_types.dropna().unique():
        type_mask = error_types == error_type
        with_coverage = int(full_coverage_mask[type_mask].sum())
        total_type = int(type_mask.sum())
        print(f"  {error_type}: {with_coverage:,} / {total_type:,} samples have complete coverage ({with_coverage/total_type:.1%})")

    filtered_dataset = dataset_df[full_coverage_mask].copy()

    print(f"\nFiltered dataset size: {len(filtered_dataset):,} samples")
    print(f"Removed {len(dataset_df) - len(filtered_dataset):,} samples without complete coverage")
    print("\nFinal dataset composition:")
    final_counts = filtered_dataset['error_type'].value_counts()
    for error_type, count in final_counts.items():
        print(f"  {error_type}: {count:,} samples")

    return filtered_dataset

# Keep only rows whose solution lines are fully annotated with equations.
df_filtered = filter_error_detection_dataset_for_coverage(dataset_df=df)

🔍 Filtering error detection dataset for complete annotation coverage...
Original dataset size: 6,067 samples
  conceptual_error: 1,591 / 2,067 samples have complete coverage (77.0%)
  computational_error: 1,571 / 2,000 samples have complete coverage (78.5%)
  correct: 1,608 / 2,000 samples have complete coverage (80.4%)

Filtered dataset size: 4,770 samples
Removed 1,297 samples without complete coverage

Final dataset composition:
  correct: 1,608 samples
  conceptual_error: 1,591 samples
  computational_error: 1,571 samples


In [4]:
# Split the filtered rows by error type, keeping only rows where both the
# erroneous line and its equation are present.
df_comp = (
    df_filtered.loc[df_filtered['error_type'].eq('computational_error')]
    .dropna(subset=['erroneous_line', 'erroneous_line_eqn'])
)
df_concep = (
    df_filtered.loc[df_filtered['error_type'].eq('conceptual_error')]
    .dropna(subset=['erroneous_line', 'erroneous_line_eqn'])
)

In [5]:
# Add lines with erroneous calculations: each computational-error row contributes
# its erroneous line together with the equation annotated for that line.
erroneous_lines_df = (
    df_comp[["erroneous_line", "erroneous_line_eqn"]]
    .rename(columns={"erroneous_line": "line", "erroneous_line_eqn": "eqn"})
    .assign(type="flawed")
)

In [6]:
# Add lines without calculations

from lines_without_eqn import lines_without_eqn

# Lines with no calculation: the equation target is an empty string for every line.
lines_without_eqn_df = pd.DataFrame({
    "line": lines_without_eqn,
    "eqn": "",
    "type": "missing",
})

In [7]:
# Add lines with correct calculations

import random
# Seeded RNG so the sampled correct lines are identical after a kernel restart.
CORRECT_SAMPLE_SEED = 0
rng = random.Random(CORRECT_SAMPLE_SEED)

correct_samples = []
for _, row in df_concep.iterrows():
    soln_dict = json.loads(row["correct_answer_mapping"])
    eqn_dict = json.loads(row["correct_eqn_mapping"])
    # Candidate lines: every solution line except the final answer ("FA") whose
    # equation is non-empty. `.get` tolerates lines absent from the eqn mapping --
    # for conceptual-error rows the coverage filter validated `wrong_eqn_mapping`,
    # not this correct mapping.
    candidates = [
        key for key in soln_dict
        if key != "FA" and eqn_dict.get(key) and str(eqn_dict[key]).strip()
    ]
    if not candidates:
        # No annotated line to sample from this solution; skip it rather than
        # letting random.choice raise on an empty list.
        continue
    chosen_key = rng.choice(candidates)

    correct_samples.append({
        "line": soln_dict[chosen_key],
        "eqn": eqn_dict[chosen_key],
        "type": "correct"
    })

# Explicit columns keep the schema stable even if no sample was collected.
correct_lines_df = pd.DataFrame(correct_samples, columns=["line", "eqn", "type"])

In [8]:
# Stack the correct, flawed and equation-free samples into one frame,
# renumbering the index 0..n-1.
final_df = pd.concat(
    objs=[correct_lines_df, erroneous_lines_df, lines_without_eqn_df],
    axis=0,
    ignore_index=True,
)

In [9]:
# Sanity check of the combined dataset: row count, per-column non-null counts and dtypes.
final_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3383 entries, 0 to 3382
Data columns (total 3 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   line    3383 non-null   object
 1   eqn     3383 non-null   object
 2   type    3383 non-null   object
dtypes: object(3)
memory usage: 79.4+ KB


In [None]:
# Opt-in export of the assembled dataset. Off by default, so a Run All does not
# (re)write the CSV; set to True to save it.
SAVE_EQN_EXTRACTION_DATASET = False
if SAVE_EQN_EXTRACTION_DATASET:
    final_df.to_csv("aug_10_eqn_extraction_dataset.csv", index=False)

In [11]:
# Show full line text in the sample table below; width 0 lets pandas auto-detect
# the display width.
pd.set_option('display.max_colwidth', None)
pd.set_option('display.width', 0)

print(final_df["type"].value_counts())

# Fixed random_state so the per-type preview shows the same rows on every run.
final_df.groupby("type").sample(n=5, random_state=0)

type
flawed     1571
correct    1295
missing     517
Name: count, dtype: int64


Unnamed: 0,line,eqn,type
1050,The final length is 35 feet because 41 - 6 = 35,41-6=35,correct
79,So it is 60/60=1 hour,60/60=1,correct
222,He completed 4 jobs because 20 / 5 = 4,20/5=4,correct
285,60 minutes are in an hour and he's adding 240 minutes to his trip so that's 240/60 = 4 more hours,240/60=4,correct
133,Then multiply the number of quarts of champagne Jackson buy by the normal price per bottle to find the cost before the discount: 160 quarts * $50/quart = $8000,160*50=8000,correct
2329,So James's older brother is 12+4=12 years old,12+4=12,flawed
2451,There are 21*2=24 polar bears,21*2=24,flawed
1727,The rental fee is 5000*0.1=$5000 per week,5000*0.1=5000,flawed
1323,A sandwich and a pack of juice cost $0.3 + $0.2 = $0.05.,0.3+0.2=0.05,flawed
2796,"On Tuesday, 227 books were taken, so there are 235 - 227= 7 books left.",235-227=7,flawed
