In [25]:
from typing import Iterable
import warnings

import pandas as pd
import sympy
import lark

import gadgets

In [26]:
# Load the raw SVAMP problems and lower-case every column label so the rest
# of the notebook can address columns as "body", "question", "equation", ...
df = pd.read_json("../data/svamp/svamp.json").rename(columns=str.lower)
df

Unnamed: 0,id,body,question,equation,answer,type
0,chal-1,Each pack of dvds costs 76 dollars. If there i...,How much do you have to pay to buy each pack?,( 76.0 - 25.0 ),51,Subtraction
1,chal-2,Dan had $ 3 left with him after he bought a ca...,How much did the candy bar cost?,( 4.0 - 3.0 ),1,Subtraction
2,chal-3,Paco had 26 salty cookies and 17 sweet cookies...,How many salty cookies did Paco have left?,( 26.0 - 9.0 ),17,Subtraction
3,chal-4,43 children were riding on the bus. At the bus...,How many children got off the bus at the bus s...,( 43.0 - 21.0 ),22,Subtraction
4,chal-5,28 children were riding on the bus. At the bus...,How many more children got on the bus than tho...,( 30.0 - 28.0 ),2,Subtraction
...,...,...,...,...,...,...
995,chal-996,Paige was helping her mom plant flowers and to...,How many flower beds did they have?,( 36.0 / 12.0 ),3,Common-Division
996,chal-997,"At the zoo, a cage had 3 snakes and 75 alligat...",How many alligators were not hiding?,( 75.0 - 19.0 ),56,Subtraction
997,chal-998,Paige was helping her mom plant flowers and to...,How many flowers did they grow?,( 60.0 * ( 55.0 / 15.0 ) ),220,Multiplication
998,chal-999,Mary is baking a cake. The recipe calls for 7 ...,How many more cups of sugar does she need to add?,( 7.0 - 4.0 ),3,Subtraction


In [27]:
# Tally the first character of every question. This matters because
# merge_body_with_question (defined below) lower-cases that character when it
# appends a question to a body that does not end with a full stop.
df["question"].str[0].value_counts()

question
H    988
W      6
A      4
T      2
Name: count, dtype: int64

In [28]:
# Tally the last character of every body. Bodies ending in "." are treated as
# complete sentences by merge_body_with_question (defined below); all others
# are joined to their question with a comma.
df["body"].str[-1].value_counts()

body
.    642
s    145
y     45
t     34
k     24
e     17
w     14
r     14
g     12
d     10
l      9
n      6
a      6
m      4
x      4
p      3
u      3
f      2
7      2
5      1
3      1
8      1
2      1
Name: count, dtype: int64

In [29]:
def merge_body_with_question(body: str, question: str) -> str:
    """Join a problem body and its question into a single piece of text.

    Args:
        body: The narrative part of the problem.
        question: The question asked about the body.

    Returns:
        ``body + " " + question`` when the body ends with a full stop.
        Otherwise the body is treated as an unfinished sentence: the question
        is appended after ``", "`` with its first character lower-cased.
    """
    if body.endswith("."):
        return body + " " + question
    # Slice instead of indexing so that an empty question does not raise
    # IndexError; for non-empty questions the result is unchanged.
    return body + ", " + question[:1].lower() + question[1:]

In [30]:
# Fold every body into its question row by row, then drop the now redundant
# "body" column.
merged_questions = df.apply(
    lambda record: merge_body_with_question(
        body=record["body"],
        question=record["question"],
    ),
    axis=1,
)
df["question"] = merged_questions
del df["body"]

In [31]:
# Show the frame after "body" has been merged into "question" and removed.
df

Unnamed: 0,id,question,equation,answer,type
0,chal-1,Each pack of dvds costs 76 dollars. If there i...,( 76.0 - 25.0 ),51,Subtraction
1,chal-2,Dan had $ 3 left with him after he bought a ca...,( 4.0 - 3.0 ),1,Subtraction
2,chal-3,Paco had 26 salty cookies and 17 sweet cookies...,( 26.0 - 9.0 ),17,Subtraction
3,chal-4,43 children were riding on the bus. At the bus...,( 43.0 - 21.0 ),22,Subtraction
4,chal-5,28 children were riding on the bus. At the bus...,( 30.0 - 28.0 ),2,Subtraction
...,...,...,...,...,...
995,chal-996,Paige was helping her mom plant flowers and to...,( 36.0 / 12.0 ),3,Common-Division
996,chal-997,"At the zoo, a cage had 3 snakes and 75 alligat...",( 75.0 - 19.0 ),56,Subtraction
997,chal-998,Paige was helping her mom plant flowers and to...,( 60.0 * ( 55.0 / 15.0 ) ),220,Multiplication
998,chal-999,Mary is baking a cake. The recipe calls for 7 ...,( 7.0 - 4.0 ),3,Subtraction


In [32]:
# Grammar text for the arithmetic expressions in the "equation" column; it is
# passed to lark.Lark further down.
# Rules are chained from loosest to tightest binding:
#   neg -> add -> sub -> mul -> div -> pow -> perc -> atom
# An atom is a signed number, a parenthesised expression, or an implicit
# multiplication written as adjacent parenthesised groups, e.g. "2 (3) (4)".
# The tree labels produced here (neg, add, sub, mul, div, pow, perc,
# implicit_mul) are the operation names TreeEvaluator._format_op dispatches on.
grammar = """
?start: expr

?expr: neg

?atom: num
    | implicit_mul
    | "(" expr ")"

implicit_mul: num ( "(" expr ")" )+
            | "(" expr ")" ( "(" expr ")" )+

?neg: add
    | "-" neg -> neg
    | "-" add -> neg
?add: sub
    | sub ("+" sub)+ -> add
?sub: mul
    | mul ("-" mul)+ -> sub
?mul: div
    | div ("*" div)+ -> mul
?div: pow
    | pow ("/" pow)+ -> div
?pow: perc
    | perc ("**" perc)+ -> pow
?perc: atom "%" -> perc
     | atom
?num: SIGNED_NUMBER

%import common.SIGNED_NUMBER
%import common.WS
%ignore WS
"""

In [33]:
class TreeEvaluator:
    """Turn a parsed arithmetic expression into a chain of calculator steps.

    Every operation node of the parse tree is rendered as a flat calculator
    input string built from the already evaluated values of its children, and
    that string is evaluated with ``calc``. Results are memoised per node.
    """

    def __init__(self, calc: gadgets.gadget.Calculator, parser: lark.Lark) -> None:
        """Store the calculator and parser and start with an empty memo.

        Args:
            calc: Calculator used both to evaluate input strings and to
                format the resulting values.
            parser: Parser that converts an expression string into a tree.
        """
        # Memo of evaluated nodes, keyed by the tree/token itself.
        self.cache = {}
        self.calc = calc
        self.parser = parser

    def eval_tree(self, tree: lark.Tree | lark.Token) -> tuple[str | None, sympy.Expr]:
        """Return the memoised ``(inputs, value)`` pair for ``tree``.

        ``inputs`` is the calculator input string for an operation node and
        ``None`` for a bare number token.
        """
        if tree not in self.cache:
            self.cache[tree] = self._eval_tree(tree)
        return self.cache[tree]

    def _eval_tree(self, tree: lark.Tree | lark.Token) -> tuple[str | None, sympy.Expr]:
        """Evaluate a single node without consulting the memo for the node itself.

        Raises:
            ValueError: If a token is not a number token, or (via
                ``_format_op``) if the operation name is not recognised.
        """
        if isinstance(tree, lark.Token):
            if tree.type not in ("SIGNED_NUMBER", "NUMBER"):
                raise ValueError(f"unknown token {tree}")
            # A number is not a computation step, hence no input string.
            return None, self.calc.evaluate(tree.value)

        assert isinstance(tree.data, str)
        operation = tree.data
        args_nodes = tree.children
        # Children go through eval_tree so shared sub-expressions hit the memo;
        # only their values (index 1) are needed to build this node's input.
        args = [self._format_arg(self.eval_tree(arg_node)[1]) for arg_node in args_nodes]
        inputs = self._format_op(operation, args)
        return inputs, self.calc.evaluate(inputs)

    def _format_op(self, op: str, args: list[str]) -> str:
        """Render operation ``op`` applied to pre-formatted ``args`` as text.

        Raises:
            ValueError: If ``op`` is not one of the supported operation names.
        """
        if op == "neg":
            assert len(args) == 1
            return "-" + args[0]
        if op == "add" or op == "implicit_add":
            return " + ".join(args)
        if op == "sub":
            return " - ".join(args)
        if op == "mul" or op == "implicit_mul":
            return " * ".join(args)
        if op == "div":
            return " / ".join(args)
        if op == "pow":
            return " ** ".join(args)
        if op == "perc":
            assert len(args) == 1
            return f"{args[0]} / 100"
        raise ValueError(f"unknown operation {op}")

    def _format_arg(self, value_expr: sympy.Number) -> str:
        """Format an evaluated value for use as an operand in a larger input.

        Function applications and non-negative Float/Integer/NumberSymbol
        values are returned bare; everything else is wrapped in parentheses.
        A warning is emitted for value types not handled explicitly.
        """
        value_str = self.calc.format_sympy_number(value_expr, add_approx=False)
        if isinstance(value_expr, sympy.core.function.Application):
            return value_str
        if isinstance(value_expr, (sympy.Float, sympy.Integer, sympy.NumberSymbol)):
            if value_expr < 0:
                # Parenthesise negatives so the sign is not read as an operator.
                return "(" + value_str + ")"
            return value_str
        if isinstance(value_expr, sympy.Rational):
            return "(" + value_str + ")"
        if isinstance(value_expr, (sympy.Mul, sympy.Pow, sympy.Add)):
            return "(" + value_str + ")"
        warnings.warn(f"weird value type {type(value_expr)} for {value_expr} (string: '{value_str}')")
        return "(" + value_str + ")"

    def dfs(self, tree: lark.Tree | lark.Token) -> Iterable[lark.Tree | lark.Token]:
        """Yield every node of ``tree`` in post-order (children before parent)."""
        if isinstance(tree, lark.Tree):
            for child in tree.children:
                yield from self.dfs(child)
        yield tree

    def expr_to_steps(self, expr: str, drop_repeated: bool = True) -> tuple[list[gadgets.datatypes.Interaction], sympy.Expr]:
        """Parse ``expr`` and list the calculator interactions that evaluate it.

        Args:
            expr: Expression text accepted by ``self.parser``.
            drop_repeated: When true, an interaction equal to one already in
                the list is not appended again.

        Returns:
            ``(steps, result)``: the interactions in post-order and the value
            of the whole expression.
        """
        tree = self.parser.parse(expr)
        steps = []
        for subtree in self.dfs(tree):
            # eval_tree always returns a 2-tuple, so it can be unpacked directly.
            inputs, output_expr = self.eval_tree(subtree)
            if inputs is None:
                # Bare number: nothing was computed, so there is no step to record.
                continue
            interaction = gadgets.datatypes.Interaction(
                gadget_id="calculator",
                inputs=inputs,
                outputs=self.calc.format_sympy_number(output_expr),
            )
            if drop_repeated and interaction in steps:
                continue
            steps.append(interaction)

        _, result = self.eval_tree(tree)
        return steps, result
    

In [34]:
# Build the calculator, the parser for the grammar above, and the evaluator
# that ties the two together.
calc = gadgets.gadget.Calculator()
parser = lark.Lark(grammar)
tree_evaluator = TreeEvaluator(calc=calc, parser=parser)

# expr_to_steps yields one (steps, result) pair per equation; transpose the
# pairs into two parallel sequences, one for each new column.
steps_and_results = df["equation"].apply(tree_evaluator.expr_to_steps)
chains, evaluated_results = zip(*steps_and_results)
df["chain"] = chains
df["evaluated_result"] = evaluated_results
df

Unnamed: 0,id,question,equation,answer,type,chain,evaluated_result
0,chal-1,Each pack of dvds costs 76 dollars. If there i...,( 76.0 - 25.0 ),51,Subtraction,[gadget_id='calculator' inputs='76 - 25' outpu...,51
1,chal-2,Dan had $ 3 left with him after he bought a ca...,( 4.0 - 3.0 ),1,Subtraction,[gadget_id='calculator' inputs='4 - 3' outputs...,1
2,chal-3,Paco had 26 salty cookies and 17 sweet cookies...,( 26.0 - 9.0 ),17,Subtraction,[gadget_id='calculator' inputs='26 - 9' output...,17
3,chal-4,43 children were riding on the bus. At the bus...,( 43.0 - 21.0 ),22,Subtraction,[gadget_id='calculator' inputs='43 - 21' outpu...,22
4,chal-5,28 children were riding on the bus. At the bus...,( 30.0 - 28.0 ),2,Subtraction,[gadget_id='calculator' inputs='30 - 28' outpu...,2
...,...,...,...,...,...,...,...
995,chal-996,Paige was helping her mom plant flowers and to...,( 36.0 / 12.0 ),3,Common-Division,[gadget_id='calculator' inputs='36 / 12' outpu...,3
996,chal-997,"At the zoo, a cage had 3 snakes and 75 alligat...",( 75.0 - 19.0 ),56,Subtraction,[gadget_id='calculator' inputs='75 - 19' outpu...,56
997,chal-998,Paige was helping her mom plant flowers and to...,( 60.0 * ( 55.0 / 15.0 ) ),220,Multiplication,[gadget_id='calculator' inputs='55 / 15' outpu...,220
998,chal-999,Mary is baking a cake. The recipe calls for 7 ...,( 7.0 - 4.0 ),3,Subtraction,[gadget_id='calculator' inputs='7 - 4' outputs...,3


In [35]:
# List the rows whose dataset-provided "answer" differs from the value
# obtained by evaluating the "equation" column.
df[df["answer"] != df["evaluated_result"]]

Unnamed: 0,id,question,equation,answer,type,chain,evaluated_result
679,chal-680,Rachel's tree had 4 apples. She picked 2 apple...,( ( 4.0 - 2.0 ) + 3.0 ),1,Addition,[gadget_id='calculator' inputs='4 - 2' outputs...,5


In [36]:
# Render every evaluated value as text with add_approx=False, derive a float
# from that text, and drop the intermediate column of evaluated values.
df["result"] = df["evaluated_result"].map(
    lambda value: calc.format_sympy_number(value, add_approx=False)
)
df["result_float"] = df["result"].map(calc._float_eval)
del df["evaluated_result"]

In [37]:
# Convert each list of calculator interactions into model markup.
# The closing result is taken from the evaluated "result" column rather than
# the dataset's "answer": the two disagree for at least one row (see the
# mismatch listing above), and using "answer" there would make the markup end
# with a value that contradicts both its own last calculator output and the
# exported "result" column.
df["chain"] = df.apply(
    lambda row: gadgets.markup.to_model_markup(
        chain=row["chain"],
        result=str(row["result"]),
    ),
    axis=1,
)
df["chain"]

0                [\n, [76 - 25], \n, [51], \n, \n, [51]]
1                    [\n, [4 - 3], \n, [1], \n, \n, [1]]
2                 [\n, [26 - 9], \n, [17], \n, \n, [17]]
3                [\n, [43 - 21], \n, [22], \n, \n, [22]]
4                  [\n, [30 - 28], \n, [2], \n, \n, [2]]
                             ...                        
995                [\n, [36 / 12], \n, [3], \n, \n, [3]]
996              [\n, [75 - 19], \n, [56], \n, \n, [56]]
997    [\n, [55 / 15], \n, [11/3 = around 3.666667], ...
998                  [\n, [7 - 4], \n, [3], \n, \n, [3]]
999               [\n, [13 - 2], \n, [11], \n, \n, [11]]
Name: chain, Length: 1000, dtype: object

In [38]:
# Serialise every markup object to plain text, then trim leading and trailing
# whitespace.
chain_text = df["chain"].map(str)
df["chain"] = chain_text.str.strip()

In [39]:
# Element-wise check that every float result equals its rounded value, i.e. is
# a whole number.
# NOTE(review): the displayed Series is truncated, so this does not confirm the
# property for all rows; consider reducing with .all() or asserting it.
df["result_float"] == df["result_float"].apply(round)

0      True
1      True
2      True
3      True
4      True
       ... 
995    True
996    True
997    True
998    True
999    True
Name: result_float, Length: 1000, dtype: bool

In [40]:
# Show the frame with the textual chain and the "result" / "result_float" columns.
df

Unnamed: 0,id,question,equation,answer,type,chain,result,result_float
0,chal-1,Each pack of dvds costs 76 dollars. If there i...,( 76.0 - 25.0 ),51,Subtraction,"<gadget id=""calculator"">76 - 25</gadget>\n<out...",51,51.0
1,chal-2,Dan had $ 3 left with him after he bought a ca...,( 4.0 - 3.0 ),1,Subtraction,"<gadget id=""calculator"">4 - 3</gadget>\n<outpu...",1,1.0
2,chal-3,Paco had 26 salty cookies and 17 sweet cookies...,( 26.0 - 9.0 ),17,Subtraction,"<gadget id=""calculator"">26 - 9</gadget>\n<outp...",17,17.0
3,chal-4,43 children were riding on the bus. At the bus...,( 43.0 - 21.0 ),22,Subtraction,"<gadget id=""calculator"">43 - 21</gadget>\n<out...",22,22.0
4,chal-5,28 children were riding on the bus. At the bus...,( 30.0 - 28.0 ),2,Subtraction,"<gadget id=""calculator"">30 - 28</gadget>\n<out...",2,2.0
...,...,...,...,...,...,...,...,...
995,chal-996,Paige was helping her mom plant flowers and to...,( 36.0 / 12.0 ),3,Common-Division,"<gadget id=""calculator"">36 / 12</gadget>\n<out...",3,3.0
996,chal-997,"At the zoo, a cage had 3 snakes and 75 alligat...",( 75.0 - 19.0 ),56,Subtraction,"<gadget id=""calculator"">75 - 19</gadget>\n<out...",56,56.0
997,chal-998,Paige was helping her mom plant flowers and to...,( 60.0 * ( 55.0 / 15.0 ) ),220,Multiplication,"<gadget id=""calculator"">55 / 15</gadget>\n<out...",220,220.0
998,chal-999,Mary is baking a cake. The recipe calls for 7 ...,( 7.0 - 4.0 ),3,Subtraction,"<gadget id=""calculator"">7 - 4</gadget>\n<outpu...",3,3.0


In [41]:
# Final export schema: prefix ids with the dataset name, rename "type" to
# "problem_type", and keep only the exported columns in their output order.
exported_columns = [
    "id",
    "question",
    "chain",
    "result",
    "result_float",
    "equation",
    "problem_type",
]
df = (
    df.assign(id="svamp__" + df["id"])
    .rename(columns={"type": "problem_type"})
    .loc[:, exported_columns]
)

In [42]:
# Show the frame in its final export layout.
df

Unnamed: 0,id,question,chain,result,result_float,equation,problem_type
0,svamp__chal-1,Each pack of dvds costs 76 dollars. If there i...,"<gadget id=""calculator"">76 - 25</gadget>\n<out...",51,51.0,( 76.0 - 25.0 ),Subtraction
1,svamp__chal-2,Dan had $ 3 left with him after he bought a ca...,"<gadget id=""calculator"">4 - 3</gadget>\n<outpu...",1,1.0,( 4.0 - 3.0 ),Subtraction
2,svamp__chal-3,Paco had 26 salty cookies and 17 sweet cookies...,"<gadget id=""calculator"">26 - 9</gadget>\n<outp...",17,17.0,( 26.0 - 9.0 ),Subtraction
3,svamp__chal-4,43 children were riding on the bus. At the bus...,"<gadget id=""calculator"">43 - 21</gadget>\n<out...",22,22.0,( 43.0 - 21.0 ),Subtraction
4,svamp__chal-5,28 children were riding on the bus. At the bus...,"<gadget id=""calculator"">30 - 28</gadget>\n<out...",2,2.0,( 30.0 - 28.0 ),Subtraction
...,...,...,...,...,...,...,...
995,svamp__chal-996,Paige was helping her mom plant flowers and to...,"<gadget id=""calculator"">36 / 12</gadget>\n<out...",3,3.0,( 36.0 / 12.0 ),Common-Division
996,svamp__chal-997,"At the zoo, a cage had 3 snakes and 75 alligat...","<gadget id=""calculator"">75 - 19</gadget>\n<out...",56,56.0,( 75.0 - 19.0 ),Subtraction
997,svamp__chal-998,Paige was helping her mom plant flowers and to...,"<gadget id=""calculator"">55 / 15</gadget>\n<out...",220,220.0,( 60.0 * ( 55.0 / 15.0 ) ),Multiplication
998,svamp__chal-999,Mary is baking a cake. The recipe calls for 7 ...,"<gadget id=""calculator"">7 - 4</gadget>\n<outpu...",3,3.0,( 7.0 - 4.0 ),Subtraction


In [43]:
# Write the processed problems as JSON Lines (one record per line) without
# escaping non-ASCII characters.
df.to_json(
    path_or_buf="../data/svamp/svamp-processed.jsonl",
    orient="records",
    lines=True,
    force_ascii=False,
)

In [44]:
import datasets

# Wrap the processed frame as the single "test" split of a DatasetDict.
ds = datasets.DatasetDict({
    "test": datasets.Dataset.from_pandas(df),
})

# The upload is disabled; uncomment the next line to publish the dataset under
# the "original-splits" config.
#ds.push_to_hub("anonym-repos/Calc-svamp", config_name="original-splits")

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Deleting unused files from dataset repository:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading metadata:   0%|          | 0.00/3.09k [00:00<?, ?B/s]

In [45]:
# Load the "original-splits" config of the published dataset and show its
# "test" split (features and row count).
datasets.load_dataset("anonym-repos/Calc-svamp", "original-splits")["test"]

Downloading readme:   0%|          | 0.00/3.09k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/116k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Dataset({
    features: ['id', 'question', 'chain', 'result', 'result_float', 'equation', 'problem_type'],
    num_rows: 1000
})

In [46]:
# Show the first record of the "test" split as a spot check of the stored fields.
datasets.load_dataset("anonym-repos/Calc-svamp", "original-splits")["test"][0]

{'id': 'svamp__chal-1',
 'question': 'Each pack of dvds costs 76 dollars. If there is a discount of 25 dollars on each pack, how much do you have to pay to buy each pack?',
 'chain': '<gadget id="calculator">76 - 25</gadget>\n<output>51</output>\n\n<result>51</result>',
 'result': '51',
 'result_float': 51.0,
 'equation': '( 76.0 - 25.0 )',
 'problem_type': 'Subtraction'}