In [None]:
from xmm.preprocess._expr2ast import expr2ast
from xmm.preprocess._ast2sympy import ast2sympy


def main(expression):
    """Parse an expression string and convert it to a SymPy expression.

    The string is parsed with ``expr2ast`` and the resulting tree is handed
    to ``ast2sympy``. The converted value is printed with the label
    "SymPy expression:" and returned.

    Args:
        expression: source text of the expression to convert.

    Returns:
        The value produced by ``ast2sympy`` for the parsed tree.
    """
    converted = ast2sympy(expr2ast(expression))
    print("SymPy expression:", converted)
    return converted


# Example usage: convert a sample expression string and display the SymPy result.
expression = "c1 * (c2 * c2 * (r1 + c3) * (r1 + c3) - 1) * exp(-0.5 * (r1 * r1 * c2 * c2))"
expr = main(expression)
expr

In [None]:
from xmm.preprocess._ast2CUDAexpr import ast2CUDAexpr

def main_taichi(expression):
    """Parse an expression string and emit a CUDA expression string for it.

    Despite ``taichi`` in the name, the conversion is done by
    ``ast2CUDAexpr`` and the result is printed as a CUDA expression.

    Args:
        expression: source text of the expression to convert.

    Returns:
        The value returned by ``ast2CUDAexpr`` for the parsed tree's ``body``.
    """
    tree = expr2ast(expression)
    # Unlike `main`, which hands the whole tree to ast2sympy, ast2CUDAexpr is
    # given `tree.body` — NOTE(review): presumably the top-level expression node.
    cuda_src = ast2CUDAexpr(tree.body)
    print("CUDA expression:", cuda_src)
    return cuda_src


# Example usage
expression = "log(c1 * (c2 * c2 * (r1 + c3)) * atan2(r1 + c3, c2) - 1) * exp(-0.5 * (r1 * r1 * c2 * c2))"
taichi_expr = main_taichi(expression)

In [None]:
from xmm.preprocess._sympy2ast import sympy2ast
import astor

# Round trip: expression string -> AST -> SymPy -> AST -> source string.
def around(expression):
    """Round-trip an expression string through SymPy and back to a Python AST.

    Steps: ``expr2ast`` parses the string, ``ast2sympy`` converts the tree,
    ``sympy2ast`` rebuilds an AST from that result, and ``astor.to_source``
    regenerates source text. The text is printed (stripped) so it can be
    compared visually with the input.

    Args:
        expression: source text of the expression to round-trip.

    Returns:
        The AST produced by ``sympy2ast``.
    """
    sympy_form = ast2sympy(expr2ast(expression))
    rebuilt_tree = sympy2ast(sympy_form)
    rebuilt_src = astor.to_source(rebuilt_tree).strip()
    print("Reconstructed Expression String:", rebuilt_src)
    return rebuilt_tree


# Example usage
expression = "c1 * (c2 * c2 * (r1 + c3) * (r1 + c3) - 1) * atan2(-0.5 * (r1 * r1 * c2 * c2), c1)"
new_ast = around(expression)

In [None]:
from xmm.codegen.codegen import generate_expr
# Example usage: generate the CUDA expression for `expression` together with
# its per-variable derivatives, then print them.
expression = "c1 * (c2 * c2 * (r1 + c3) * (r1 + c3) - 1) * atan2(-0.5 * (r1 * r1 * c2 * c2), c1)"
taichi_expr, taichi_derivatives = generate_expr(expression)

print('\n'+'*'*50+'\n')

print("Original CUDA Expression:\n", taichi_expr)
print("CUDA Derivatives:")
# Loop variable is `deriv_expr`, not `expr`, so running this cell does not
# overwrite the SymPy expression bound to `expr` in the first cell.
for var, deriv_expr in taichi_derivatives.items():
    print(f"d/d{var}:", deriv_expr)

In [None]:
from xmm.codegen.codegen import generate_operator_source

# Use distinct names rather than `a`/`b`: those names are reassigned to the
# CUDA input tensors further down, and reusing them for generated source
# makes the notebook state confusing.
# NOTE(review): generate_operator_source returns a 2-element sequence; what
# each element holds isn't visible here — only the second one is printed.
op_source_0, op_source_1 = generate_operator_source(
    1,
    3,
    "c1 * (c2 * c2 * (r1 + c3) * (r1 + c3) - 1) * atan2(-0.5 * (r1 * r1 * c2 * c2), c1)",
)
print(op_source_1)

## Example: compile a `SumOperator` and check it against reference implementations

In [None]:
from xmm.SumOperator import SumOperator
import torch

# Build the operator from a Mexican-hat-style expression, then compile it.
# 0.867325070 ≈ 2 / (sqrt(3) * pi**0.25), the usual Ricker-wavelet normalisation
# (NOTE(review): presumably the intent; the same constant scales the reference
# computed in the next cell).
op = SumOperator(
    1,
    3,
    "0.867325070 * c3 * (c1 * c1 * (r1 + c2) * (r1 + c2) - 1) * exp(-0.5 * c1 * c1 * (r1 + c2) * (r1 + c2))",
)
op.compile()


In [45]:
import torch
import wavkan.wavelet
from math import exp  # NOTE(review): no longer used in this cell; kept in case later cells rely on `exp`

# Fixed seed so the random inputs (and the printed outputs) reproduce on re-run.
torch.manual_seed(0)
device = torch.device('cuda:0')

M, N, K = 128, 42, 62
# Allocate directly on the GPU instead of creating on CPU and copying with .to().
a = torch.randn(M, K, device=device)
b = torch.randn(N, K, device=device)
c = torch.randn(N, K, device=device)
d = torch.randn(N, K, device=device)

# Compiled operator vs. the wavkan implementation of the same wavelet.
opf = op.forward(a, b, c, d)
print(opf)
wavkan_out = wavkan.wavelet.mexhat_wavelet(a, b, c, d)
print(wavkan_out)

# Hand-computed reference for output element [0, 0], using the same term as
# the operator expression with r1 -> a[0], c1 -> b[0], c2 -> c[0], c3 -> d[0]
# (the mapping the original per-k loop used), vectorised over k.
MEXHAT_SCALE = 0.867325070
scaled_sq = b[0] ** 2 * (a[0] + c[0]) ** 2
res = (d[0] * (scaled_sq - 1) * torch.exp(-0.5 * scaled_sq)).sum()
print(res * MEXHAT_SCALE)

# Turn the eyeball comparison into real checks. The loose tolerances allow for
# different float32 summation order / exp implementations in the kernel.
torch.testing.assert_close(opf, wavkan_out, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(opf[0, 0], res * MEXHAT_SCALE, rtol=1e-4, atol=1e-4)

tensor([[-0.3832, -1.1263, -2.8932,  ...,  7.3758, -5.1120,  1.2537],
        [-1.8569,  1.8759,  2.7706,  ...,  0.3477,  1.3390, -1.2371],
        [-2.5918, -2.4867,  3.5251,  ...,  5.5163, -4.5873,  2.8335],
        ...,
        [ 4.0604,  0.2948,  2.7115,  ...,  3.9962, -6.0438,  6.3948],
        [ 1.6745,  0.4129,  2.2036,  ...,  5.3293, -3.4070,  1.0890],
        [ 2.4895,  0.6932,  4.4876,  ...,  4.4679, -3.9333,  1.3650]],
       device='cuda:0')
tensor([[-0.3832, -1.1263, -2.8932,  ...,  7.3758, -5.1120,  1.2537],
        [-1.8569,  1.8759,  2.7706,  ...,  0.3477,  1.3390, -1.2371],
        [-2.5918, -2.4867,  3.5251,  ...,  5.5163, -4.5873,  2.8335],
        ...,
        [ 4.0604,  0.2948,  2.7115,  ...,  3.9962, -6.0438,  6.3948],
        [ 1.6745,  0.4129,  2.2036,  ...,  5.3293, -3.4070,  1.0890],
        [ 2.4895,  0.6932,  4.4876,  ...,  4.4679, -3.9333,  1.3650]],
       device='cuda:0')
tensor(-0.3832, device='cuda:0')
