In [1]:
import time
import math
import random
import pickle
import signal
import numpy as np
import sympy as sp
import pandas as pd
import multiprocessing
import concurrent.futures
from tqdm.notebook import tqdm_notebook as tqdm
from typing import List, Dict, Tuple, Union, Optional

In [2]:
# Force the 'fork' start method (force=True lets this cell be re-run safely).
# generate_sample_with_timeout() starts a multiprocessing.Process whose target
# is a nested function; nested functions cannot be pickled, which the 'spawn'
# and 'forkserver' start methods would require.
# NOTE(review): 'fork' is only available on POSIX systems - confirm the target platform.
multiprocessing.set_start_method('fork', force=True)

In [3]:
# Symbol shared by every generated expression. It is declared real and
# positive, which SymPy takes into account when it builds expressions.
x = sp.symbols('x', real=True, positive=True)

# Token vocabulary. Each row is [token, sampling weight, arity].
# NOTE: np.array() coerces this mixed str/int list into an array of strings,
# which is why consumers convert columns 1 and 2 back to numbers before use.
vocab = np.array([
    ['add', 6, 2],  # binary operators
    ['sub', 3, 2],
    ['mul', 6, 2],
    ['div', 3, 2],
    ['pow', 3, 2],
    ['sq', 2, 1],  # unary operators: powers and roots
    ['sqrt', 2, 1],
    ['cb', 2, 1],
    ['cbrt', 2, 1],
    ['exp', 2, 1],  # exponential and logarithm
    ['ln', 2, 1],  # natural logarithm (mapped to sp.log)
    ['sin', 2, 1],  # trigonometric functions
    ['cos', 2, 1],
    ['tan', 2, 1],
    ['asin', 2, 1],  # inverse trig functions
    ['acos', 2, 1],
    ['atan', 2, 1],
    ['sinh', 2, 1],  # hyperbolic functions
    ['cosh', 2, 1],
    ['tanh', 2, 1],
    ['asinh', 2, 1],  # inverse hyperbolic functions
    ['acosh', 2, 1],
    ['atanh', 2, 1],
    ['x', 10, 0]  # the variable: the only leaf (arity 0) token
])

In [4]:
def generate_expression(vocab, max_depth=6, depth=0) -> List[str]:
    """
    Recursive function to generate one expression using the tokens and their
    respective probabilities provided by 'vocab'.

    The expression is returned in prefix order: an operator token is followed
    by the tokens of each of its operands. Once a token from one of the
    function families below has been used, that whole family is removed from
    the vocabulary handed to its sub-expressions, so e.g. a trigonometric
    function never appears anywhere underneath another one.

    Args:
        vocab: Vocabulary array with rows of [token, weight, arity]
        max_depth: Maximum depth of the expression tree
        depth: Current depth in the recursion (callers leave this at 0)

    Returns:
        List of tokens representing an expression
    """
    if depth == 0:
        # Reseed from OS entropy once per expression rather than on every
        # recursive call. Forked worker processes inherit the parent's RNG
        # state and would otherwise all draw the same sequence.
        random.seed()
        np.random.seed()

    if depth >= max_depth:
        return ['x']

    # The vocabulary is an array of strings, so convert the weights first.
    weights = vocab[:, 1].astype('float64')
    probs = weights / weights.sum()
    rand_idx = np.random.choice(len(vocab), p=probs)
    cur_token = vocab[rand_idx, 0]
    cur_arity = int(vocab[rand_idx, 2])

    if cur_arity == 0:
        return [cur_token]

    # Token families to avoid repeating within the same branch
    token_families = (
        ('sin', 'cos', 'tan', 'asin', 'acos', 'atan'),  # Trigonometric functions
        ('sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh'),  # Hyperbolic functions
        ('exp', 'ln'),  # Exponential and logarithm
        ('sq', 'sqrt'),  # Square and square root
        ('cb', 'cbrt'),  # Cube and cube root
    )
    token_family = next(
        (family for family in token_families if cur_token in family), None
    )

    # Drop every token of the current token's family for the sub-expressions.
    if token_family is None:
        new_vocab = vocab
    else:
        new_vocab = vocab[~np.isin(vocab[:, 0], token_family)]

    expr = [cur_token]
    for _ in range(cur_arity):
        expr.extend(generate_expression(new_vocab, max_depth, depth + 1))
    return expr

In [5]:
def sequence_to_sympy(expr, vocab, x=None) -> "sp.Expr":
    """
    Convert a prefix-ordered sequence of tokens into a SymPy expression.

    The sequence is parsed in a single left-to-right pass. A token that parses
    as a float is treated as a numeric constant and becomes a Python float.
    Tokens left over after the first complete expression are ignored.

    Args:
        expr: List of tokens representing an expression
        vocab: Vocabulary array with rows of [token, weight, arity]
        x: SymPy symbol for variable x (a plain Symbol('x') is created if None)

    Returns:
        SymPy expression

    Raises:
        ValueError: If the sequence is empty or ends before the expression is
            complete, or if a token is missing from 'vocab' or has no known
            meaning.
    """
    if x is None:
        x = sp.Symbol('x')

    # token -> arity; the first matching row wins if a token is listed twice.
    arities = {}
    for row in vocab:
        arities.setdefault(str(row[0]), int(row[2]))

    # Unary tokens implemented by a SymPy function, mapped to its name in `sp`.
    sympy_functions = {
        'sin': 'sin', 'cos': 'cos', 'tan': 'tan',
        'asin': 'asin', 'acos': 'acos', 'atan': 'atan',
        'sinh': 'sinh', 'cosh': 'cosh', 'tanh': 'tanh',
        'asinh': 'asinh', 'acosh': 'acosh', 'atanh': 'atanh',
        'exp': 'exp', 'ln': 'log', 'sqrt': 'sqrt', 'cbrt': 'cbrt',
    }

    def parse(pos):
        """Parse the sub-expression starting at `pos`; return (value, next_pos)."""
        if pos >= len(expr):
            raise ValueError("Empty or truncated expression")

        cur_token = expr[pos]
        try:
            return float(cur_token), pos + 1  # for cases when constants are evaluated
        except ValueError:
            pass

        if cur_token not in arities:
            raise ValueError(f"Unknown token: {cur_token}")
        cur_arity = arities[cur_token]

        if cur_arity == 0:
            if cur_token == 'x':
                return x, pos + 1

        elif cur_arity == 1:
            operand, next_pos = parse(pos + 1)
            if cur_token == 'sq':
                return operand**2, next_pos
            if cur_token == 'cb':
                return operand**3, next_pos
            if cur_token in sympy_functions:
                function = getattr(sp, sympy_functions[cur_token])
                return function(operand), next_pos

        elif cur_arity == 2:
            # The right operand starts exactly where the left one ended.
            left_expr, next_pos = parse(pos + 1)
            right_expr, next_pos = parse(next_pos)
            if cur_token == 'add':
                return left_expr + right_expr, next_pos
            if cur_token == 'sub':
                return left_expr - right_expr, next_pos
            if cur_token == 'mul':
                return left_expr * right_expr, next_pos
            if cur_token == 'div':
                return left_expr / right_expr, next_pos
            if cur_token == 'pow':
                return left_expr ** right_expr, next_pos

        # The token is in the vocabulary but has no implementation here.
        raise ValueError(f"Unknown token: {cur_token}")

    result, _ = parse(0)
    return result

In [6]:
def taylor_expansion(function, x, point=0, order=4, precision=5):
    """
    Compute the Taylor expansion of a function.

    Each value is first obtained by direct substitution; if that is not known
    to be finite, the limit at the expansion point is used instead. A
    derivative term that is still not finite is left out of the series.

    Args:
        function: SymPy expression
        x: SymPy symbol for variable
        point: Expansion point
        order: Order of Taylor expansion
        precision: Numerical precision (digits passed to evalf)

    Returns:
        Taylor expansion as a SymPy expression, or sp.nan if the function is
        undefined at the expansion point or the expansion fails.
    """
    try:
        # Try direct substitution first
        value_at_point = function.subs(x, point)
        if not value_at_point.is_finite:
            value_at_point = sp.limit(function, x, point)
            if not value_at_point.is_finite:
                return sp.nan  # Function is undefined at the expansion point

        # Initialize the Taylor series
        taylor_series = value_at_point.evalf(precision)
        dx = x - point
        factorial = 1
        derivative = function

        # Compute each term in the series
        for n in range(1, order + 1):
            # Keep n! in step with n even when this term ends up skipped;
            # otherwise every later coefficient is divided by the wrong value.
            factorial *= n

            # Differentiate the previous derivative once instead of
            # recomputing the n-th derivative from scratch.
            derivative = sp.diff(derivative, x)

            # Try direct substitution first
            derivative_at_point = derivative.subs(x, point)
            if not derivative_at_point.is_finite:
                derivative_at_point = sp.limit(derivative, x, point)
                if not derivative_at_point.is_finite:
                    continue  # Skip this term if it's undefined

            # Compute the Taylor term
            coeff = derivative_at_point.evalf(precision) / factorial
            taylor_series += coeff * (dx ** n)

        return taylor_series

    except Exception:
        # Fallback: Use SymPy's built-in series function
        try:
            return function.series(x, point, order + 1).removeO().evalf(precision)
        except Exception:
            return sp.nan

In [7]:
def not_real(function, x):
    """
    Tell whether a function, viewed as a polynomial in x, has any coefficient
    with a non-zero imaginary part.

    Args:
        function: SymPy expression
        x: SymPy symbol for variable

    Returns:
        True if at least one coefficient is not real, or if the function
        cannot be processed as a polynomial in x at all; False otherwise
    """
    try:
        coefficients = sp.Poly(function, x).all_coeffs()
        return any(sp.im(coefficient) != 0 for coefficient in coefficients)
    except Exception:
        # Anything that cannot be analysed is treated as not real.
        return True

def is_constant(function):
    """
    Tell whether a function is constant.

    Args:
        function: SymPy expression

    Returns:
        True if the function has no free symbols or if its derivative with
        respect to every free symbol is zero, False otherwise
    """
    symbols = function.free_symbols
    if not symbols:
        return True

    for symbol in symbols:
        if not (sp.diff(function, symbol) == 0):
            return False
    return True

In [8]:
def generate_sample(vocab, x, max_depth):
    """
    Generate one random expression together with its Taylor expansion.

    Args:
        vocab: Vocabulary array
        x: SymPy symbol for variable
        max_depth: Maximum depth of the expression tree

    Returns:
        Tuple (token list, SymPy expression, Taylor expansion), or None when
        the expansion is NaN, constant, or has non-real coefficients
    """
    tokens = generate_expression(vocab, max_depth)
    function = sequence_to_sympy(tokens, vocab, x)
    series = taylor_expansion(function, x)

    # All three checks are evaluated, in this order, before deciding.
    rejections = (
        not_real(series, x),
        is_constant(series),
        series == sp.nan,
    )
    if any(rejections):
        return None
    return (tokens, function, series)

In [9]:
def generate_sample_with_timeout(vocab, x, max_depth, timeout=60):
    """
    Run generate_sample() in a child process and give up after 'timeout'.

    The parent waits on the pipe rather than on the process. Joining first
    could discard a finished result: a child whose result does not fit in the
    pipe buffer stays blocked in send() until the parent reads it.

    Args:
        vocab: Vocabulary array
        x: SymPy symbol for variable
        max_depth: Maximum depth of the expression tree
        timeout: Seconds to wait for a result (None waits indefinitely)

    Returns:
        The value returned by generate_sample(), or None if the child timed
        out or exited without sending a result.
    """
    def worker(conn):
        try:
            conn.send(generate_sample(vocab, x, max_depth))
        finally:
            conn.close()

    # One-way pipe: the parent only receives, the child only sends.
    parent_conn, child_conn = multiprocessing.Pipe(duplex=False)
    process = multiprocessing.Process(target=worker, args=(child_conn,))
    process.start()
    # Close the parent's copy of the sending end so that a child that dies
    # without sending shows up as EOF instead of a full-length wait.
    child_conn.close()

    result = None
    received = False
    try:
        if parent_conn.poll(timeout):
            try:
                result = parent_conn.recv()
                received = True
            except (EOFError, OSError):
                result = None  # child exited without sending anything
    finally:
        # Kill a child that is still computing (timeout or interruption).
        if not received and process.is_alive():
            process.terminate()
        process.join()
        parent_conn.close()

    return result

In [10]:
def generate_dataset(num_samples, vocab, x, max_depth, n_workers=16):
    """
    Generate 'num_samples' unique samples using a pool of worker processes.

    'n_workers' generation attempts are kept in flight at all times; each
    finished attempt is replaced by a new one until enough unique samples have
    been collected.

    Args:
        num_samples: Number of unique samples to collect
        vocab: Vocabulary array
        x: SymPy symbol for variable
        max_depth: Maximum depth of the expression tree
        n_workers: Number of worker processes / attempts kept in flight

    Returns:
        List of at most 'num_samples' tuples as returned by generate_sample()
    """
    dataset = []
    # Duplicates are detected by token sequence. The other two tuple entries
    # are derived from it, so this matches a full tuple comparison without a
    # linear scan of SymPy comparisons per attempt.
    seen_token_sequences = set()
    total_attempts = 0

    with concurrent.futures.ProcessPoolExecutor(max_workers=n_workers) as executor, \
            tqdm(total=num_samples, desc="Generating Samples") as pbar:
        pbar.set_postfix(total_attempts=0, success_rate="0%")

        def submit_attempt():
            return executor.submit(generate_sample_with_timeout, vocab, x, max_depth)

        pending = {submit_attempt() for _ in range(n_workers)}
        try:
            while len(dataset) < num_samples:
                done, pending = concurrent.futures.wait(
                    pending, return_when=concurrent.futures.FIRST_COMPLETED
                )
                for future in done:
                    sample = future.result()
                    total_attempts += 1

                    if sample is not None:
                        key = tuple(sample[0])
                        if key not in seen_token_sequences:
                            seen_token_sequences.add(key)
                            dataset.append(sample)

                    # Keep the number of in-flight attempts constant.
                    pending.add(submit_attempt())

                success_rate = (len(dataset) / total_attempts) * 100
                pbar.n = min(len(dataset), num_samples)
                pbar.set_postfix(total_attempts=total_attempts, success_rate=f"{success_rate:.2f}%")
                pbar.refresh()
        finally:
            # Drop attempts that have not started yet so that leaving the
            # executor does not wait for work nobody needs.
            for future in pending:
                future.cancel()

    return dataset[:num_samples]

In [11]:
# Collect 3000 unique samples with max_depth=3 (n_workers keeps its default of 16).
dataset = generate_dataset(3000, vocab, x, 3)

Generating Samples:   0%|          | 0/3000 [00:00<?, ?it/s]

In [12]:
# One row per sample; the columns are still positional (0, 1, 2) at this point.
df = pd.DataFrame(dataset)
df

Unnamed: 0,0,1,2
0,[x],x,1.0*x
1,"[cos, x]",cos(x),0.041667*x**4 - 0.5*x**2 + 1.0
2,"[tan, sqrt, mul, x, x]",tan(x),0.33333*x**3 + 1.0*x
3,"[tanh, x]",tanh(x),-0.33333*x**3 + 1.0*x
4,"[add, tanh, x, x]",x + tanh(x),-0.33333*x**3 + 2.0*x
...,...,...,...
2995,"[div, cosh, add, x, x, div, cbrt, x, cbrt, x]",cosh(2*x),0.66667*x**4 + 2.0*x**2 + 1.0
2996,"[sub, x, sin, x]",x - sin(x),0.16667*x**3
2997,"[atanh, tan, cb, x]",atanh(tan(x**3)),1.0*x**3
2998,"[div, x, sqrt, div, x, x]",x,1.0*x


In [14]:
# Name the columns after the tuple returned by generate_sample():
# token sequence, SymPy expression, Taylor expansion.
df.columns = ['function_tree', 'function', 'taylor']

In [17]:
# Register tqdm with pandas so that progress_apply() is available below.
tqdm.pandas()

In [18]:
# Add a column holding the simplified form of every generated function.
# The lambda parameter is not called `x`, so it cannot be confused with the
# global SymPy symbol of that name.
df['simplified_functions'] = df['function'].progress_apply(
    lambda function: sp.simplify(function)
)

  0%|          | 0/3000 [00:00<?, ?it/s]

In [20]:
# NOTE: plain assignment - `test` is a second name for the same DataFrame,
# not a copy, so changes made through either name affect both.
test = df

In [21]:
# Sanity check: 3000 rows and 4 columns are expected.
test.shape

(3000, 4)

In [24]:
# Same object as `test`, so the shape is identical.
df.shape

(3000, 4)

In [25]:
# Persist the dataset. The frame holds SymPy objects, so it is stored as a
# pickle; only load pickle files that come from a trusted source.
df.to_pickle("df.pkl")