Tokenization and encoding

In [10]:
with open('../data/expressions.txt') as f:
    data = list(map(lambda x: x.rstrip('\n'), f.readlines()))

data[0]

'(x1 - 0.125192826477114) + 0.43174347075339936'

In [14]:
import sympy as sp

expr = sp.parse_expr(data[1])
expr

sin(exp(0.4609526466213322*sin(x1)))

In [22]:
# Get all properties of expr
print("Expression:", expr)
print("Type:", type(expr))
print("Class:", expr.__class__)
print("Arguments:", expr.args)
print("Free symbols:", expr.free_symbols)
print("Is atomic:", expr.is_Atom)

if hasattr(expr, 'func'):
    print("Function:", expr.func)

if hasattr(expr, 'is_commutative'):
    print("Is commutative:", expr.is_commutative)

if hasattr(expr, 'is_real'):
    print("Is real:", expr.is_real)

if hasattr(expr, 'is_integer'):
    print("Is integer:", expr.is_integer)

if hasattr(expr, 'is_rational'):
    print("Is rational:", expr.is_rational)

if hasattr(expr, 'is_algebraic'):
    print("Is algebraic:", expr.is_algebraic)

if hasattr(expr, 'is_transcendental'):
    print("Is transcendental:", expr.is_transcendental)

# Get the structure of the expression
print("\nExpression structure:")
print(expr.structurally_equal(expr, check_numerics=False))

# If it's an operation, get the operation type
if hasattr(expr, 'is_Add') and expr.is_Add:
    print("Operation: Addition")
elif hasattr(expr, 'is_Mul') and expr.is_Mul:
    print("Operation: Multiplication")
elif hasattr(expr, 'is_Pow') and expr.is_Pow:
    print("Operation: Exponentiation")

# If it's a function, get the function name
if hasattr(expr, 'func') and expr.func != expr.__class__:
    print("Function name:", expr.func.__name__)




Expression: sin(exp(0.4609526466213322*sin(x1)))
Type: sin
Class: sin
Arguments: (exp(0.4609526466213322*sin(x1)),)
Free symbols: {x1}
Is atomic: False
Function: sin
Is commutative: True
Is real: None
Is integer: None
Is rational: None
Is algebraic: None
Is transcendental: None

Expression structure:


AttributeError: 'sin' object has no attribute 'structurally_equal'

In [26]:
a = ['sin', 'cos', 'tan', 'exp', 'log']
a.index('cos')

1

In [23]:
from sympy import pi, sin
from sympy.abc import a,x,y

# Turn into prefix notation first
# Then 1-hot encode with xVal

def depth_first_traverse(expr):
    print(f'Operator: {expr.func}')
    for arg in expr.args:
        print(f'Traversing: {arg}')
        depth_first_traverse(arg)
    if len(expr.args) == 0:
        # we reached a leaf of the tree
        print(f'Reached leaf node: {expr}')
    else:
        # do something with compound expr
        print(f'Reached compound node: {expr}')
depth_first_traverse(sin(a*x*pi+1.5)/y)




Operator: <class 'sympy.core.mul.Mul'>
Traversing: 1/y
Operator: <class 'sympy.core.power.Pow'>
Traversing: y
Operator: <class 'sympy.core.symbol.Symbol'>
Reached leaf node: y
Traversing: -1
Operator: <class 'sympy.core.numbers.NegativeOne'>
Reached leaf node: -1
Reached compound node: 1/y
Traversing: sin(pi*a*x + 1.5)
Operator: sin
Traversing: pi*a*x + 1.5
Operator: <class 'sympy.core.add.Add'>
Traversing: 1.50000000000000
Operator: <class 'sympy.core.numbers.Float'>
Reached leaf node: 1.50000000000000
Traversing: pi*a*x
Operator: <class 'sympy.core.mul.Mul'>
Traversing: pi
Operator: <class 'sympy.core.numbers.Pi'>
Reached leaf node: pi
Traversing: a
Operator: <class 'sympy.core.symbol.Symbol'>
Reached leaf node: a
Traversing: x
Operator: <class 'sympy.core.symbol.Symbol'>
Reached leaf node: x
Reached compound node: pi*a*x
Reached compound node: pi*a*x + 1.5
Reached compound node: sin(pi*a*x + 1.5)
Reached compound node: sin(pi*a*x + 1.5)/y


In [28]:
sp.parse_expr('sin(a*x*pi+1.5)/y')

sin(pi*a*x + 1.5)/y

In [45]:
import h5py

# Load the dataset
with h5py.File('../data/dataset.h5', 'r') as f:
    dataset = f['dataset'][()]
    encoding = f['encoding'][()]
    encoding = [str(token)[2:-1] for token in encoding]

#### Decoding and evaluation

In [61]:
import sympy

# Decode the onehot matrix into seq
def onehot_to_seq(onehot_matrix):
    seq = []
    for i, token_id in enumerate(onehot_matrix[:, :-1].argmax(axis=1)):
        if encoding[token_id] == '[PAD]':
            break
        elif encoding[token_id] == '[NUM]':
            seq.append(str(onehot_matrix[i, -1]))
        else:
            seq.append(encoding[token_id])
    return seq

def seq_to_infix(prefix):
    stack = []
    
    # read prefix in reverse order
    i = len(prefix) - 1
    while i >= 0:
        if not is_operator(prefix[i]):
            stack.append(prefix[i])
            i -= 1
        else:
            str = "(" + stack.pop() + prefix[i] + stack.pop() + ")"
            stack.append(str)
            i -= 1
    return stack.pop()

def is_operator(c):
    if c == "*" or c == "+" or c == "-" or c == "/" or c == "^" or c == "(" or c == ")":
        return True
    else:
        return False

In [65]:
onehot_matrix = dataset[0, ...]
seq = onehot_to_seq(onehot_matrix)
infix_str = seq_to_infix(seq)
expr = sympy.parse_expr(infix_str)

print(expr)
# evaluated_expr = expr.evalf(subs={'x1': 1})

1.5883794*x1 - 0.050188385
