In [None]:
from collections import namedtuple

import sympy as sym
from sympy import Symbol, symbols, Matrix, MatrixSymbol, Function
from sympy.utilities.codegen import codegen
from sympy.utilities.lambdify import implemented_function
from sympy.utilities.iterables import numbered_symbols
from sympy.codegen.ast import Assignment, FunctionDefinition
from sympy.printing.cxx import CXX11CodePrinter
from sympy.codegen.rewriting import create_expand_pow_optimization

In [None]:
# (placeholder symbol, defining expression) pair.
NamedExpr = namedtuple("NamedExpr", "name expr")

def name_expr(expr, name):
    """Pair ``expr`` with a placeholder symbol called ``name``.

    Anything exposing a ``shape`` attribute (matrix-valued expressions) gets a
    ``MatrixSymbol`` of that shape; everything else gets a scalar ``Symbol``.
    The placeholder can then stand in for ``expr`` inside later expressions.
    """
    placeholder = (
        sym.MatrixSymbol(name, *expr.shape)
        if hasattr(expr, "shape")
        else Symbol(name)
    )
    return NamedExpr(placeholder, expr)

In [None]:
# Scalar step parameters:
#   lambda -- q/p
#   h      -- step length
l, h = symbols('lambda h')

# 3-component column vectors: position p and direction d.
p, d = (MatrixSymbol(vec_name, 3, 1) for vec_name in ('p', 'd'))

In [None]:
# Field vectors used by the stages: B1 for k1, B2 for k2 and k3, B3 for k4.
# Presumably B2/B3 are meant to be sampled at p2/p3 (not enforced here).
B1, B2, B3 = (MatrixSymbol(f"B{i}", 3, 1) for i in (1, 2, 3))


def _cross_rhs(v, field):
    """Stage right-hand side: lambda * (v x field), with ``v`` made explicit."""
    return l * v.as_explicit().cross(field)


# First stage at the start point.
k1 = name_expr(_cross_rhs(d, B1), "k1")
# Intermediate position p2 built from k1.
p2 = name_expr(p + h/2 * d + h**2/8 * k1.name, "p2")

# Second and third stages share the field B2.
k2 = name_expr(_cross_rhs(d + h/2 * k1.name, B2), "k2")
k3 = name_expr(_cross_rhs(d + h/2 * k2.name, B2), "k3")
# End position p3 built from k3.
p3 = name_expr(p + h * d + h**2/2 * k3.name, "p3")

# Fourth stage with the field B3.
k4 = name_expr(_cross_rhs(d + h * k3.name, B3), "k4")

# Error estimate: h^2 * ||k1 - k2 - k3 + k4||_1.
err = name_expr(h**2 * (k1.name - k2.name - k3.name + k4.name).as_explicit().norm(1), "err")

In [None]:
# C++ code generation helpers

In [None]:
# Single C++11 printer instance, passed to every code-generation call below.
cxx_printer = CXX11CodePrinter()

In [None]:
def my_function_print(printer, name, inputs, output_exprs):
    """Render a C++ function template that evaluates ``output_exprs``.

    Common subexpressions are extracted with ``sympy.cse`` and emitted as
    local temporaries before the outputs are assigned.

    Parameters
    ----------
    printer : sympy code printer (e.g. ``CXX11CodePrinter``)
        Used to print each ``Assignment``.
    name : str
        Name of the generated C++ function.
    inputs : sequence of Symbol / MatrixSymbol
        Function inputs. Matrix symbols become ``const T*`` pointers; anything
        else becomes a ``const T`` value.
    output_exprs : sequence of (str, expr) pairs
        Output name and the expression assigned to it. Every output becomes a
        ``T*`` parameter. Matrix outputs are written element-wise; scalar
        outputs are written through ``*name``.

    Returns
    -------
    str
        The C++ source of the function.
    """
    def input_param(arg):
        if isinstance(arg, MatrixSymbol):
            return f"const T* {arg.name}"
        return f"const T {arg.name}"

    def prefixed(code, prefix):
        # The printer may emit several lines (one per matrix element).
        return [prefix + line for line in code.split('\n')]

    params = [input_param(arg) for arg in inputs]
    params += [f"T* {out_name}" for out_name, _ in output_exprs]
    # Join outside the f-string: nesting the same quote type inside an
    # f-string is a SyntaxError before Python 3.12.
    param_list = ', '.join(params)
    lines = [f"template<typename T> void {name}({param_list}) {{"]

    out_names = [out_name for out_name, _ in output_exprs]
    sub_exprs, simp_exprs = sym.cse([expr for _, expr in output_exprs])

    for var, expr in sub_exprs:
        if hasattr(expr, "shape"):
            # Matrix temporaries need a declared array before the element-wise
            # assignments; "const auto x0[0] = ...;" is not valid C++.
            rows, cols = expr.shape
            lhs = sym.MatrixSymbol(f'{var}', rows, cols)
            lines.append(f"  T {var}[{rows * cols}];")
            lines.extend(prefixed(printer.doprint(Assignment(lhs, expr)), '  '))
        else:
            lhs = sym.Symbol(f'{var}')
            lines.extend(prefixed(printer.doprint(Assignment(lhs, expr)), '  const auto '))

    for out_name, expr in zip(out_names, simp_exprs):
        if hasattr(expr, "shape"):
            lhs = sym.MatrixSymbol(f'{out_name}', *expr.shape)
            lines.extend(prefixed(printer.doprint(Assignment(lhs, expr)), '  '))
        else:
            lhs = sym.Symbol(f'{out_name}')
            # Scalar outputs are T* parameters, so assign through the pointer.
            lines.extend(prefixed(printer.doprint(Assignment(lhs, expr)), '  *'))

    lines.append('}')
    return '\n'.join(lines)

In [None]:
def order_exprs(name_exprs, iter=100):
    """Order ``(name, expr)`` pairs so every expression follows its dependencies.

    Every free symbol that is not one of the names is an input at depth 0.
    Each named expression gets depth ``1 + max(depth of its free symbols)``, or
    depth 0 when it has no free symbols. The result is sorted by depth, then by
    number of free symbols, then by number of top-level ``args``.

    Parameters
    ----------
    name_exprs : sequence of (name, expr) pairs
        Each ``expr`` must expose ``free_symbols`` and ``args``.
    iter : int
        Maximum number of relaxation passes. The loop stops early once a pass
        changes nothing.

    Returns
    -------
    list of (name, expr) pairs in evaluation order.

    Raises
    ------
    RuntimeError
        If some expressions could not be assigned a depth. This happens with a
        dependency cycle, or with nesting deeper than ``iter`` passes resolve.
    """
    all_expr_names = {name for name, _ in name_exprs}
    all_expr_symbols = set().union(*(expr.free_symbols for _, expr in name_exprs))
    inputs = all_expr_symbols - all_expr_names

    order = dict.fromkeys(inputs, 0)

    for _ in range(iter):
        changed = False
        for name, expr in name_exprs:
            symbols_order = [order.get(s) for s in expr.free_symbols]
            if None in symbols_order:
                continue  # some dependency not ordered yet
            # default=-1 puts constant expressions (no free symbols) at depth 0
            # instead of raising ValueError from max([]).
            new_order = max(symbols_order, default=-1) + 1
            if order.get(name) != new_order:
                order[name] = new_order
                changed = True
        if not changed:
            break  # fixed point reached; further passes cannot change anything

    if len(order) < len(inputs) + len(name_exprs):
        unresolved = [name for name, _ in name_exprs if name not in order]
        raise RuntimeError(
            f"did not finish: could not order {[str(n) for n in unresolved]} "
            f"(inputs: {sorted(str(i) for i in inputs)})"
        )

    # One stable sort with a composite key is equivalent to the original chain
    # of three stable sorts (least significant key first).
    return sorted(
        name_exprs,
        key=lambda n_e: (order[n_e[0]], len(n_e[1].free_symbols), len(n_e[1].args)),
    )

In [None]:
def my_cse(name_exprs, iter=2):
    """Apply ``sympy.cse`` repeatedly to a list of ``(name, expr)`` pairs.

    Each pass extracts common subexpressions as new ``(symbol, expr)`` pairs
    and places them before the rewritten original pairs. All passes draw from
    one ``numbered_symbols()`` stream, so the new symbols never collide across
    passes. The loop stops after ``iter`` passes or as soon as a pass extracts
    nothing. If the first pass finds nothing, ``name_exprs`` is returned as-is.
    """
    fresh_symbols = numbered_symbols()
    current = name_exprs

    for _ in range(iter):
        labels = [pair[0] for pair in current]
        bodies = [pair[1] for pair in current]

        replacements, reduced = sym.cse(bodies, symbols=fresh_symbols)
        if not replacements:
            break

        current = list(replacements) + list(zip(labels, reduced))

    return current

In [None]:
def my_function_print_2(printer, name, inputs, name_exprs, outputs):
    """Render one C++ function template that evaluates a chain of named expressions.

    ``name_exprs`` may reference each other by name. Common subexpressions are
    extracted with ``my_cse``, and the resulting assignments are put in
    dependency order with ``order_exprs``.

    Parameters
    ----------
    printer : sympy code printer
        Used to print each ``Assignment``.
    name : str
        Name of the generated C++ function.
    inputs : sequence of Symbol / MatrixSymbol
        Function inputs. Matrix symbols become ``const T*`` pointers; anything
        else becomes a ``const T`` value.
    name_exprs : sequence of (name, expr) pairs
        For example ``NamedExpr`` tuples.
    outputs : sequence of names from ``name_exprs``
        Exposed as ``T*`` output parameters. Every other name, and every CSE
        temporary, becomes a local variable.

    Returns
    -------
    str
        The C++ source of the function.
    """
    def input_param(arg):
        if isinstance(arg, MatrixSymbol):
            return f"const T* {arg.name}"
        return f"const T {arg.name}"

    def print_assign(var, expr):
        if isinstance(var, (Symbol, MatrixSymbol)):
            # Names from name_expr / cse are already symbols of the right kind.
            lhs = var
        elif hasattr(expr, "shape"):
            lhs = sym.MatrixSymbol(var, *expr.shape)
        else:
            lhs = sym.Symbol(var)
        return printer.doprint(Assignment(lhs, expr))

    def prefixed(code, prefix):
        # The printer may emit several lines (one per matrix element).
        return [prefix + line for line in code.split('\n')]

    all_name_exprs = order_exprs(my_cse(name_exprs))

    params = [input_param(arg) for arg in inputs] + [f"T* {output}" for output in outputs]
    # Join outside the f-string: nesting the same quote type inside an
    # f-string is a SyntaxError before Python 3.12.
    param_list = ', '.join(params)
    lines = [f"template<typename T> void {name}({param_list}) {{"]

    for var, expr in all_name_exprs:
        code = print_assign(var, expr)  # print once per assignment
        is_matrix = hasattr(expr, "shape")
        if var in outputs:
            # Outputs are T* parameters: matrices are written element-wise,
            # scalars through the pointer.
            lines.extend(prefixed(code, '  ' if is_matrix else '  *'))
        elif is_matrix:
            rows, cols = expr.shape
            # Size the local array as rows * cols so non-column matrices fit.
            lines.append(f"  T {var}[{rows * cols}];")
            lines.extend(prefixed(code, '  '))
        else:
            lines.extend(prefixed(code, '  const auto '))

    lines.append('}')
    return '\n'.join(lines)

In [None]:
# All in one: a single C++ function computing every RK4 stage and the error estimate

In [None]:
# B1, B2 and B3 are free symbols of the stage expressions, so they must be
# function inputs; otherwise the generated C++ references undeclared names.
# NOTE(review): p2 and p3 are computed as unused locals (not outputs), so a
# caller cannot use them to sample B2/B3 -- TODO decide whether to expose them.
print(my_function_print_2(
    cxx_printer,
    "rk4",
    [p, d, h, l, B1, B2, B3],
    [k1, p2, k2, k3, p3, k4, err],
    [k1.name, k2.name, k3.name, k4.name, err.name],
))

In [None]:
# Step by step: one C++ function per RK4 stage, with earlier stages passed in as inputs

In [None]:
# Distinct *_expr / *_sym names keep this cell from overwriting the global
# k1 / p2 NamedExprs used by the all-in-one generation cell.

# Stage k1 = lambda * (d x B).
B = MatrixSymbol('B', 3, 1)
k1_expr = l * d.as_explicit().cross(B)

print(my_function_print(cxx_printer, "rk4_k1", [d, l, B], [("k1", k1_expr)]))

# Intermediate position p2, with k1 supplied as an input array.
k1_sym = MatrixSymbol('k1', 3, 1)
p2_expr = p + h/2 * d + h**2/8 * k1_sym
print(my_function_print(cxx_printer, "rk4_p2", [p, d, h, k1_sym], [("p2", p2_expr)]))

In [None]:
# Stage k2 = lambda * ((d + h/2 * k1) x B), with k1 supplied as an input.
# *_sym / *_expr names avoid overwriting the global k1 / k2 NamedExprs.
B = MatrixSymbol('B', 3, 1)
k1_sym = MatrixSymbol('k1', 3, 1)
k2_expr = l * (d + h/2 * k1_sym).as_explicit().cross(B)

print(my_function_print(cxx_printer, "rk4_k2", [d, h, l, B, k1_sym], [("k2", k2_expr)]))

In [None]:
# Stage k3 = lambda * ((d + h/2 * k2) x B), with k2 supplied as an input.
# *_sym / *_expr names avoid overwriting the global k2 / k3 / p3 NamedExprs.
B = MatrixSymbol('B', 3, 1)
k2_sym = MatrixSymbol('k2', 3, 1)
k3_expr = l * (d + h/2 * k2_sym).as_explicit().cross(B)

print(my_function_print(cxx_printer, "rk4_k3", [d, h, l, B, k2_sym], [("k3", k3_expr)]))

# End position p3, with k3 supplied as an input array.
k3_sym = MatrixSymbol('k3', 3, 1)
p3_expr = p + h * d + h**2/2 * k3_sym
print(my_function_print(cxx_printer, "rk4_p3", [p, d, h, k3_sym], [("p3", p3_expr)]))

In [None]:
# Stage k4 = lambda * ((d + h * k3) x B), with k3 supplied as an input.
# *_sym / *_expr names avoid overwriting the global k3 / k4 NamedExprs.
B = MatrixSymbol('B', 3, 1)
k3_sym = MatrixSymbol('k3', 3, 1)
k4_expr = l * (d + h * k3_sym).as_explicit().cross(B)

print(my_function_print(cxx_printer, "rk4_k4", [d, h, l, B, k3_sym], [("k4", k4_expr)]))

In [None]:
# Error estimate h^2 * ||k1 - k2 - k3 + k4||_1 from precomputed stages.
# *_sym / *_expr names avoid overwriting the global k1..k4 / err NamedExprs.
k1_sym, k2_sym, k3_sym, k4_sym = (MatrixSymbol(f"k{i}", 3, 1) for i in range(1, 5))

err_expr = h**2 * (k1_sym - k2_sym - k3_sym + k4_sym).as_explicit().norm(1)

print(my_function_print(cxx_printer, "rk4_err", [h, k1_sym, k2_sym, k3_sym, k4_sym], [("err", err_expr)]))

In [None]:
# Final update of position and (re-normalised) direction from the four stages.
# *_sym names avoid overwriting the global k1..k4 NamedExprs.
k1_sym, k2_sym, k3_sym, k4_sym = (MatrixSymbol(f"k{i}", 3, 1) for i in range(1, 5))

new_p = p + h * d + h**2/6 * (k1_sym + k2_sym + k3_sym)
new_d_unnormalized = d + h/6 * (k1_sym + 2 * (k2_sym + k3_sym) + k4_sym)
new_d = new_d_unnormalized / new_d_unnormalized.as_explicit().norm()

print(my_function_print(
    cxx_printer,
    "rk4_fin",
    [p, d, h, k1_sym, k2_sym, k3_sym, k4_sym],
    [("new_p", new_p), ("new_d", new_d)],
))