### Requirements

In [1]:
import json
import os
import re
import pandas as pd

from latex2sympy2 import latex2sympy, latex2latex
from sympy import simplify, srepr, Eq
from sympy.core.basic import Basic
from zss import simple_distance, Node
from PrettyPrint import PrettyPrintTree

from main import *

### Constants

In [2]:
BASE_PATH = os.path.dirname(os.path.abspath("__file__"))
EXAMPLE_DATA_FILE = "data_example.json"
DATA_FILE = "data.json"

### Read data

In [3]:
# JSON example tree data
with open(os.path.join(BASE_PATH, EXAMPLE_DATA_FILE), 'r') as file:
    example_json_data = json.load(file)
tree1 = example_json_data.get("exprl", {}) # Template answer
tree2 = example_json_data.get("expr2", {}) # Right answer
tree3 = example_json_data.get("expr3", {}) # Wrong answer

# Full JSON data
with open(os.path.join(BASE_PATH, DATA_FILE), 'r') as file:
    json_data = json.load(file)

# Latex string data
expr1 = r"\frac{d}{dx}(x^2 + 2*x) \times \int x \,dx"
expr2 = r"x^3 + x^2" # Correct would be x^3 + x^2
expr3 = r"\frac{(x^3 + x^3)}{\tan(10)}"

### Test of similarity tree analysis

In [4]:
tree1 = latex_to_tree(expr1)
tree2 = latex_to_tree(expr2)
expression_tree_similarity = get_tree_sequence_similarity(tree1, tree2)
print(f"Expression tree similarity: {round(expression_tree_similarity*100, 0)}%")

Expression tree similarity: 100.0%


### Test Bert text similarity

In [5]:
from transformers import BertModel, BertTokenizer

# Carregar o modelo BERT pré-treinado e o tokenizador
modelo = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

emb1 = get_bert_embeddings(latex2latex(expr1), modelo, tokenizer)
emb2 = get_bert_embeddings(latex2latex(expr2), modelo, tokenizer)
print(f"Expression tree similarity: {round(get_text_similarity(emb1, emb2)*100,0)}%")

  from .autonotebook import tqdm as notebook_tqdm
model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]


KeyboardInterrupt: 

### Tree edit distance (Zhang-Shasha)

In [6]:
pt = PrettyPrintTree(lambda x: x.children, lambda x: x.label)

In [7]:
# Latex string data

expr1 = r"\frac{d}{dx}(x^2 + 2*x) \times \int x \,dx"
expr2 = r"x^3 + x^2" # Correct would be x^3 + x^2
expr3 = r"\frac{(x^3 + x^3)}{\tan(10)}"

In [8]:
# Build the tree from the expression
tree1 = build_tree(simplify_latex_expression(expr1))
tree2 = build_tree(simplify_latex_expression(expr2))
tree3 = build_tree(simplify_latex_expression(expr3))

In [9]:
pt(tree1)

                 [100m <class 'sympy.core.mul.Mul'> [0m
               ┌───────────────┴───────────────┐
[100m <class 'sympy.core.power.Pow'> [0m [100m <class 'sympy.core.add.Add'> [0m
             ┌─┴─┐                           ┌─┴─┐             
            [100m x [0m [100m 2 [0m                         [100m x [0m [100m 1 [0m


In [10]:
pt(tree2) 

                 [100m <class 'sympy.core.mul.Mul'> [0m
               ┌───────────────┴───────────────┐
[100m <class 'sympy.core.power.Pow'> [0m [100m <class 'sympy.core.add.Add'> [0m
             ┌─┴─┐                           ┌─┴─┐             
            [100m x [0m [100m 2 [0m                         [100m x [0m [100m 1 [0m


In [11]:
pt(tree3)

                     [100m <class 'sympy.core.mul.Mul'> [0m
              ┌────────────────────┴────────────────────┐
[100m <class 'sympy.core.mul.Mul'> [0m           [100m <class 'sympy.core.power.Pow'> [0m
     ┌────────┴────────┐                              ┌─┴──┐             
    [100m 2 [0m [100m <class 'sympy.core.power.Pow'> [0m            [100m tan [0m [100m -1 [0m           
                     ┌─┴─┐                            |                  
                    [100m x [0m [100m 3 [0m                          [100m 10 [0m


In [12]:
simple_distance(tree1, tree2), simple_distance(tree1, tree3), simple_distance(tree2, tree3)

(0.0, 7.0, 7.0)

### Caspa Dataset 

In [13]:
df = pd.read_csv('test_datagen.csv')
df_test = df[:50].copy(deep=True)

In [14]:
def remove_format_operations(expr):
    return re.sub(r"\\left|\\right", "", expr)

In [15]:
# Apply simplify_latex_expression() to columns expr_l and expr_r
df_test['expr_l'] = df_test['expr_l'].apply(remove_format_operations)
df_test['expr_r'] = df_test['expr_r'].apply(remove_format_operations)

In [16]:
# Apply simplify_latex_expression() to columns expr_l and expr_r
df_test['expr_l'] = df_test['expr_l'].apply(simplify_latex_expression)
df_test['expr_r'] = df_test['expr_r'].apply(simplify_latex_expression)

In [17]:
# Build trees for simplified expressions in columns expr_l and expr_r

df_test['tree_l'] = df_test['expr_l'].apply(build_tree)
df_test['tree_r'] = df_test['expr_r'].apply(build_tree)

In [18]:
# Calculate tree scores using simple_distance()
df_test['tree_score'] = df_test.apply(lambda row: simple_distance(row['tree_l'], row['tree_r']), axis=1)

# Save the updated DataFrame to CSV
df_test.to_csv('output.csv', index=False)

In [19]:
simplify_latex_expression(r"\tan(x)")

tan(x)

In [20]:
pt(df_test.iloc[1]["tree_l"])
pt(df_test.iloc[1]["tree_r"])
df_test.iloc[1]

                           [100m <class 'sympy.core.add.Add'> [0m
              ┌──────────────────────────┴───────────────────────────┐
[100m <class 'sympy.core.mul.Mul'> [0m                         [100m <class 'sympy.core.mul.Mul'> [0m                                     
     ┌────────┴────────┐                  ┌──────────────────────────┴──────────────────────────┐                         
    [100m 5 [0m [100m <class 'sympy.core.power.Pow'> [0m [100m 1 [0m                                     [100m <class 'sympy.core.power.Pow'> [0m         
                     ┌─┴─┐                                                ┌─────────────────────┴──────────────────────┐  
                    [100m 2 [0m [100m -1 [0m                               [100m <class 'sympy.core.power.Pow'> [0m                           [100m -1 [0m
                                                           ┌──────────────┴───────────────┐                               
                       

Unnamed: 0                                                    1
expr_l                                  5/2 + 1/(p + r)**(11/2)
expr_r        -1*189168519186093/500000000000000 + 1/(p + r)...
score                                                        24
tree_l             <main.TreeNode object at 0x00000253244833A0>
tree_r             <main.TreeNode object at 0x0000025324529C00>
tree_score                                                  3.0
Name: 1, dtype: object

In [21]:
pt(df_test.iloc[17]["tree_r"])
pt(df_test.iloc[17]["tree_l"])
df_test.iloc[17]#[["expr_l", "expr_r", "tree_l", "tree_r"]]

            [100m exp [0m
              |
[100m <class 'sympy.core.mul.Mul'> [0m
   ┌────┬─────┴───────────┐
  [100m -1 [0m [100m i [0m [100m <class 'sympy.core.power.Pow'> [0m
                        ┌─┴──┐             
                       [100m 32 [0m [100m -1 [0m
            [100m exp [0m
              |
[100m <class 'sympy.core.mul.Mul'> [0m
   ┌────┬─────┴───────────┐
  [100m -1 [0m [100m i [0m [100m <class 'sympy.core.power.Pow'> [0m
                        ┌─┴──┐             
                       [100m 32 [0m [100m -1 [0m


Unnamed: 0                                              17
expr_l                                        exp(-1*i/32)
expr_r                                        exp(-1*i/32)
score                                                   24
tree_l        <main.TreeNode object at 0x00000253244EB1C0>
tree_r        <main.TreeNode object at 0x000002532451C460>
tree_score                                             0.0
Name: 17, dtype: object

#### Testing extreme cases

In [22]:
# Latex string data

lim_expr1 = r"\lim_{x \to \infty} \frac{1}{x}"
lim_expr2 = r"\lim_{x \to 0} \frac{10}{x}"
lim_expr3 = r"\lim_{x \to \infty} \cos(x)" 

In [23]:
simplify_latex_expression(lim_expr1), simplify_latex_expression(lim_expr2), simplify_latex_expression(lim_expr3) 

(0, oo, AccumBounds(-1, 1))

In [17]:
str1 = 'x - y'
tree1 = load_expr(str1)
pt(tree1)

[100m <class 'sympy.core.add.Add'> [0m
      ┌───────┴────────┐
     [100m x [0m [100m <class 'sympy.core.mul.Mul'> [0m
                     ┌─┴──┐            
                    [100m -1 [0m [100m y [0m


In [31]:
tree1.label == Add

True

In [52]:
def give_feedback(answer, expected):
    tree1 = load_expr(answer)
    tree2 = load_expr(expected)
    feedback = []
    stack1 = []
    stack2 = []
    #Breath-first traversal
    while True:
        stack1.extend(tree1.children)
        stack2.extend(tree2.children)
        print(tree1.label)
        print(tree2.label)
        print(stack1)
        print(stack2)
        print('==========')
        # if len(tree1.children) == 0:
        #     feedback.append("You forgot terms!")

        # if len(tree2.children) == 0:
        #     feedback.append("You have extra terms!")

        if tree1.label == Mul and not tree2.label.is_symbol:
            print('a')
            if tree1.children[0].label == '-1':
                feedback.append("You got one sign wrong!")

        if tree2.label == Mul and not tree1.label.is_symbol:
            print('b')
            if tree2.children[0].label == '-1':
                feedback.append("You got one sign wrong!")

        tree1 = stack1.pop(0)
        tree2 = stack2.pop(0)

        if len(tree1.children) > 0 or len(tree2.children) > 0:
            break

    return feedback

In [4]:
str1 = 'x - y'
str2 = 'x + y'

tree1 = load_expr(str1)
tree2 = load_expr(str2)

pt = PrettyPrintTree()

pt(tree1)
pt(tree2)

AttributeError: 'TreeNode' object has no attribute 'value'

In [1]:
from main import give_feedback
import json
import os
import re
import pandas as pd
from latex2sympy2 import latex2sympy, latex2latex
from sympy import simplify, srepr, Eq
from sympy.core.basic import Basic
from zss import simple_distance, Node
from PrettyPrint import PrettyPrintTree
from main import *

str1 = 'x - y'
str2 = 'x + y'

detect_terms(str1, str2, verbose=True)

x
y
x
y
------Terms expected: 2
------Terms received: 2


True

In [54]:
str1 = '(x + 1)^2 + y'
tree1 = load_expr(str1)

tree1.children[0].label.is_symbol

AttributeError: 'str' object has no attribute 'is_symbol'