In [None]:
from collections import namedtuple

import numpy as np

import sympy as sym
from sympy import Symbol, symbols, Matrix, ImmutableMatrix, MatrixSymbol, Function, simplify, diff
from sympy.utilities.codegen import codegen
from sympy.utilities.lambdify import implemented_function
from sympy.utilities.iterables import numbered_symbols
from sympy.codegen.ast import Assignment, FunctionDefinition
from sympy.printing.cxx import CXX11CodePrinter
from sympy.codegen.rewriting import create_expand_pow_optimization

In [None]:
# A symbolic handle (`name`) paired with the expression it stands for (`expr`).
NamedExpr = namedtuple("NamedExpr", "name expr")


def name_expr(expr, name):
    """Pair ``expr`` with a freshly created symbol called ``name``.

    Objects exposing a ``shape`` attribute get a ``MatrixSymbol`` of that
    shape as their handle; everything else gets a scalar ``Symbol``.

    Returns:
        NamedExpr: ``(handle, expr)``.
    """
    try:
        shape = expr.shape
    except AttributeError:
        handle = Symbol(name)
    else:
        handle = sym.MatrixSymbol(name, *shape)
    return NamedExpr(handle, expr)

In [None]:
# Scalar symbols: `l` stands for q/p (and is printed as "lambda"),
# `h` is the step length.
l, h = symbols("lambda h")

# 3x1 column-vector symbols: `p` is the position, `d` the direction.
p, d = (MatrixSymbol(label, 3, 1) for label in ("p", "d"))

In [None]:
# Field vectors used by the stages below. They are plain symbols here; the
# "all in one" cell later binds them to expressions of p, p2 and p3.
B1 = MatrixSymbol("B1", 3, 1)
B2 = MatrixSymbol("B2", 3, 1)
B3 = MatrixSymbol("B3", 3, 1)

# Each stage refers to the previous ones through their symbols (`.name`),
# not their full expressions, so the expressions stay small.
k1 = name_expr(l * d.as_explicit().cross(B1), "k1")
p2 = name_expr(p + h/2 * d + h**2/8 * k1.name, "p2")

# k2 and k3 both use B2.
k2 = name_expr(l * (d + h/2 * k1.name).as_explicit().cross(B2), "k2")
k3 = name_expr(l * (d + h/2 * k2.name).as_explicit().cross(B2), "k3")
p3 = name_expr(p + h * d + h**2/2 * k3.name, "p3")

k4 = name_expr(l * (d + h * k3.name).as_explicit().cross(B3), "k4")

# Error estimate: h^2 times the 1-norm of (k1 - k2 - k3 + k4).
err = name_expr(h**2 * (k1.name - k2.name - k3.name + k4.name).as_explicit().norm(1), "err")

# `m` and `p_abs` must exist before `dtds` is built; previously they were only
# created in a later cell, so this cell failed with NameError on a fresh kernel.
m = Symbol("m")
p_abs = Symbol("p_abs")
dtds = name_expr(sym.sqrt(1 + m**2 / p_abs**2), "dtds")

new_p = name_expr(p + h * d + h**2/6 * (k1.name + k2.name + k3.name), "new_p")
new_d = name_expr(d + h/6 * (k1.name + 2 * (k2.name + k3.name) + k4.name), "new_d")
# Normalised direction, expressed through the `new_d` symbol. The original line
# referenced `new_d_tmp`, which does not exist at this point, and reused the
# label "new_d"; it now has its own label.
new_d_norm = name_expr(new_d.name.as_explicit() / new_d.name.as_explicit().norm(), "new_d_norm")

In [None]:
t = Symbol("t")
m = Symbol("m")
p_abs = Symbol("p_abs")

# 8x8 matrix filled row-block by row-block; the columns follow the variable
# list [p (3), t, d (3), l] passed to `jacobian`.
# Exact symbolic zeros instead of NumPy floats, so untouched entries stay 0
# rather than 0.0.
D = sym.zeros(8, 8)
# Substitute the stage expressions (later stages first) before differentiating.
D[0:3,:] = new_p.expr.subs([k3,k2,k1]).simplify().as_explicit().jacobian([p, t, d, l])
# `dtds` is a NamedExpr; divide by its symbol (as the chain-rule cell does),
# not by the tuple itself.
D[3,7] = h * m**2 * l / dtds.name
D[4:7,:] = new_d.expr.subs([k4,k3,k2,k1]).simplify().as_explicit().jacobian([p, t, d, l])
D[7,7] = 1
D = sym.simplify(D)
# `subs` returns a new matrix; the result used to be discarded.
D = D.subs([(k1.expr,k1.name),(k2.expr,k2.name),(k3.expr,k3.name),(k4.expr,k4.name)])
D = sym.simplify(D)
# NOTE(review): `D` is rebuilt from scratch by the chain-rule cell further down,
# so this NamedExpr is not used by the code generation — confirm whether this
# cell is still needed.
D = name_expr(D, "D")

In [None]:
# Derivatives of the stages with respect to the direction `d` ("T"), built with
# the chain rule. dk1dT is used through its expression, the later ones through
# their symbols (`.name.as_explicit()`).
dk1dT = name_expr(k1.expr.jacobian(d), "dk1dT")
dk2dT = name_expr(k2.expr.jacobian(d) + k2.expr.jacobian(k1.name) * dk1dT.expr, "dk2dT")
dk3dT = name_expr(k3.expr.jacobian(d) + k3.expr.jacobian(k2.name) * dk2dT.name.as_explicit(), "dk3dT")
dk4dT = name_expr(k4.expr.jacobian(d) + k4.expr.jacobian(k3.name) * dk3dT.name.as_explicit(), "dk4dT")

# Same construction for the derivative with respect to `l` ("L").
dk1dL = name_expr(k1.expr.diff(l), "dk1dL")
dk2dL = name_expr(k2.expr.diff(l) + k2.expr.jacobian(k1.name) * dk1dL.expr, "dk2dL")
dk3dL = name_expr(k3.expr.diff(l) + k3.expr.jacobian(k2.name) * dk2dL.name.as_explicit(), "dk3dL")
dk4dL = name_expr(k4.expr.diff(l) + k4.expr.jacobian(k3.name) * dk3dL.name.as_explicit(), "dk4dL")

# Total derivatives of the updated position (F) and direction (G).
# The direct term must be taken with respect to `d`, matching dk*dT above.
# It was taken with respect to `p`, which gave the identity instead of h*I for
# dFdT and dropped the identity from dGdT.
dFdT = name_expr(
    new_p.expr.as_explicit().jacobian(d) +
    new_p.expr.as_explicit().jacobian(k1.name) * dk1dT.expr +
    new_p.expr.as_explicit().jacobian(k2.name) * dk2dT.name.as_explicit() +
    new_p.expr.as_explicit().jacobian(k3.name) * dk3dT.name.as_explicit(), "dFdT")
dGdT = name_expr(
    new_d.expr.as_explicit().jacobian(d) +
    new_d.expr.as_explicit().jacobian(k1.name) * dk1dT.expr +
    new_d.expr.as_explicit().jacobian(k2.name) * dk2dT.name.as_explicit() +
    new_d.expr.as_explicit().jacobian(k3.name) * dk3dT.name.as_explicit() +
    new_d.expr.as_explicit().jacobian(k4.name) * dk4dT.name.as_explicit(), "dGdT")
dFdL = name_expr(
    new_p.expr.as_explicit().diff(l) +
    new_p.expr.as_explicit().jacobian(k1.name) * dk1dL.expr +
    new_p.expr.as_explicit().jacobian(k2.name) * dk2dL.name.as_explicit() +
    new_p.expr.as_explicit().jacobian(k3.name) * dk3dL.name.as_explicit(), "dFdL")
dGdL = name_expr(
    new_d.expr.as_explicit().diff(l) +
    new_d.expr.as_explicit().jacobian(k1.name) * dk1dL.expr +
    new_d.expr.as_explicit().jacobian(k2.name) * dk2dL.name.as_explicit() +
    new_d.expr.as_explicit().jacobian(k3.name) * dk3dL.name.as_explicit() +
    new_d.expr.as_explicit().jacobian(k4.name) * dk4dL.name.as_explicit(), "dGdL")

# Assemble the 8x8 step matrix: identity everywhere except the blocks below
# (rows/cols 0-2: position, 3: time, 4-6: direction, 7: l).
D = sym.eye(8)
D[0:3,4:7] = dFdT.expr
D[0:3,7:8] = dFdL.expr
D[4:7,4:7] = dGdT.expr
D[4:7,7:8] = dGdL.expr
D[3,7] = h * m**2 * l / dtds.name

# `J` is an 8x8 matrix of symbolic elements; wherever D is a constant 0 or 1
# the corresponding element of J is replaced by that constant, which shrinks
# the product J * D.
J = Matrix(MatrixSymbol("J", 8, 8).as_explicit())
for indices in np.ndindex(J.shape):
    if D[indices] in [0, 1]:
        J[indices] = D[indices]
J = ImmutableMatrix(J)
J_new = name_expr(J * D, "J_new")

In [None]:
# Code generation: helpers that turn the named expressions into C++ source

In [None]:
# C++11 code printer instance passed to the print helpers in the cells below.
cxx_printer = CXX11CodePrinter()

In [None]:
def inflate_expr(name_expr):
    """Split one named expression into per-element ``(name, expr)`` pairs.

    Args:
        name_expr: ``(name, expr)`` pair. If ``expr`` has a ``shape``, both
            ``name`` and ``expr`` are indexed element by element.

    Returns:
        ``(pairs, references)`` of equal length. For a shaped expression each
        reference is ``(name, shape, indices)`` describing where the element
        came from; for anything else there is a single pair with reference
        ``None``.
    """
    name, expr = name_expr

    # Scalars pass through untouched.
    if not hasattr(expr, "shape"):
        return [(name, expr)], [None]

    positions = list(np.ndindex(expr.shape))
    pairs = [(name[pos], expr[pos]) for pos in positions]
    origins = [(name, expr.shape, pos) for pos in positions]
    return pairs, origins

def inflate_exprs(name_exprs):
    """Apply ``inflate_expr`` to every entry and concatenate the outcomes.

    Returns:
        ``(pairs, references)``: two flat lists, in input order, holding the
        per-element pairs and their matching references.
    """
    pairs, origins = [], []
    for some_pairs, some_origins in map(inflate_expr, name_exprs):
        pairs += some_pairs
        origins += some_origins
    return pairs, origins

def deflate_exprs(name_exprs, references):
    """Reassemble per-element pairs into whole matrices (inverse of inflating).

    Args:
        name_exprs: ``(name, expr)`` pairs, one per element or scalar.
        references: parallel list; ``None`` for a scalar entry, otherwise
            ``(matrix_name, shape, indices)`` telling where the element goes.

    Returns:
        A list of named expressions. Scalar entries are passed through as
        given; every matrix appears once, at the position of its first
        element, as ``NamedExpr(matrix_name, ImmutableMatrix)``.
    """
    result = []
    buffers = {}  # matrix name -> mutable matrix being filled

    for name_expr, reference in zip(name_exprs, references):
        if reference is None:
            result.append(name_expr)
            continue

        _, expr = name_expr
        name, shape, indices = reference
        if name not in buffers:
            # Exact symbolic zeros rather than NumPy floats as placeholders.
            buffers[name] = Matrix.zeros(*shape)
            result.append(NamedExpr(name, buffers[name]))
        # `indices` is already a tuple; plain tuple indexing avoids the
        # star-in-subscript syntax that only parses on Python 3.11+.
        buffers[name][indices] = expr

    # Freeze the filled matrices; leave everything else as it was.
    frozen = []
    for name_expr in result:
        name, expr = name_expr
        if isinstance(expr, Matrix):
            frozen.append(NamedExpr(name, ImmutableMatrix(expr)))
        else:
            frozen.append(name_expr)

    return frozen

In [None]:
def build_dependency_graph(name_exprs):
    """Map every expression name to the free symbols its expression uses."""
    return {name: expr.free_symbols for name, expr in name_exprs}

def build_influence_graph(name_exprs):
    """Map every free symbol to the set of expression names that use it.

    This is the reverse view of ``build_dependency_graph``.
    """
    influenced = {}
    for name, expr in name_exprs:
        for symbol in expr.free_symbols:
            if symbol not in influenced:
                influenced[symbol] = set()
            influenced[symbol].add(name)
    return influenced

In [None]:
def order_exprs_by_input(name_exprs):
    """Sort named expressions so each one comes after everything it depends on.

    Symbols that are used but not defined by any entry are treated as inputs
    (level 0). An expression's level is one more than the highest level among
    its free symbols; an expression without free symbols gets level 1.

    Args:
        name_exprs: iterable of ``(name, expr)`` pairs; ``expr`` must provide
            ``free_symbols`` and ``args``.

    Returns:
        list: the pairs sorted by level, then by number of free symbols, then
        by number of ``args``; remaining ties keep their input order.

    Raises:
        ValueError: if some expressions can never be ordered because they
            depend on each other (or on themselves).
    """
    defined_names = {name for name, _ in name_exprs}
    used_symbols = set().union(*(expr.free_symbols for _, expr in name_exprs))
    inputs = used_symbols - defined_names

    order = dict.fromkeys(inputs, 0)

    pending = list(name_exprs)
    while pending:
        unresolved = []
        for name, expr in pending:
            levels = [order.get(s) for s in expr.free_symbols]
            if None in levels:
                # Some dependency has no level yet; retry in the next pass.
                unresolved.append((name, expr))
                continue
            # `default` covers constant expressions, where max() of an empty
            # list used to raise ValueError.
            order[name] = max(levels, default=0) + 1
        if len(unresolved) == len(pending):
            # No progress in a full pass: this used to loop forever.
            stuck = ", ".join(str(name) for name, _ in unresolved)
            raise ValueError(f"cannot order expressions with cyclic dependencies: {stuck}")
        pending = unresolved

    # One sort on a composite key is equivalent to three successive stable
    # sorts from least to most significant key.
    return sorted(
        name_exprs,
        key=lambda n_e: (order[n_e[0]], len(n_e[1].free_symbols), len(n_e[1].args)),
    )

In [None]:
def order_exprs_by_output(name_exprs, outputs):
    """Arrange expressions output by output, each preceded by what it needs.

    For every entry of ``outputs`` (in order) the not-yet-emitted expressions
    it transitively depends on are emitted first, ordered by
    ``order_exprs_by_input``, followed by the output itself. Expressions no
    output depends on are dropped.

    Args:
        name_exprs: iterable of ``(name, expr)`` pairs; ``expr`` must provide
            ``free_symbols``.
        outputs: names that must be produced, in the desired order.

    Returns:
        list: the selected ``(name, expr)`` pairs in emission order, each at
        most once.

    Raises:
        KeyError: if an output has no entry in ``name_exprs``.
    """
    name_expr_by_name = {name_expr[0]: name_expr for name_expr in name_exprs}

    # Transitive dependencies per name. Caching keeps the walk linear on
    # expression graphs with heavily shared sub-expressions (e.g. after CSE);
    # without it the recursion revisits shared nodes exponentially often.
    closure_cache = {}

    def get_inputs(output):
        if output in closure_cache:
            return closure_cache[output]
        named = name_expr_by_name.get(output, None)
        if named is None:
            # Not defined here: an external input without dependencies.
            return set()
        direct = set(named[1].free_symbols)
        closure = set(direct)
        for symbol in direct:
            closure |= get_inputs(symbol)
        closure_cache[output] = closure
        return closure

    result = []
    done = set()

    for output in outputs:
        output_entry = name_expr_by_name[output]
        if output in done:
            # Already emitted as a dependency of an earlier output (or listed
            # twice); emitting it again would duplicate the assignment.
            continue
        inputs = get_inputs(output) - done
        needed = [entry for entry in name_exprs if entry[0] in inputs]
        result.extend(order_exprs_by_input(needed))
        result.append(output_entry)
        done.update(inputs)
        done.add(output)

    return result

In [None]:
def my_cse(name_exprs, inflate_deflate=True):
    """Run common-subexpression elimination over a list of named expressions.

    Args:
        name_exprs: ``(name, expr)`` pairs.
        inflate_deflate: when true, the entries are split with
            ``inflate_exprs`` before the elimination and reassembled with
            ``deflate_exprs`` afterwards.

    Returns:
        list: the extracted sub-expression pairs returned by ``sym.cse``
        followed by the simplified named expressions.
    """
    if inflate_deflate:
        name_exprs, references = inflate_exprs(name_exprs)

    names = [entry[0] for entry in name_exprs]
    exprs = [entry[1] for entry in name_exprs]

    # Temporaries are named by a fresh numbered-symbol generator.
    replacements, reduced = sym.cse(exprs, symbols=numbered_symbols())

    simplified = list(zip(names, reduced))
    if inflate_deflate:
        simplified = deflate_exprs(simplified, references)

    return [*replacements, *simplified]

In [None]:
def my_expression_print(printer, name_exprs, outputs, run_cse=True):
    """Print the named expressions as a sequence of assignment statements.

    Args:
        printer: code printer providing ``doprint``.
        name_exprs: ``(name, expr)`` pairs to print.
        outputs: names that are results of the generated code. They are
            assigned without a declaration; scalar outputs are written
            through a pointer (``*name = ...``).
        run_cse: run ``my_cse`` on the expressions first.

    Returns:
        str: newline-joined statements, ordered by ``order_exprs_by_output``.
        Non-output scalars are declared ``const auto``; non-output matrices
        get a ``T name[size];`` declaration before their element assignments.
    """
    if run_cse:
        name_exprs = my_cse(name_exprs, inflate_deflate=True)
    name_exprs = order_exprs_by_output(name_exprs, outputs)

    lines = []

    for var, expr in name_exprs:
        code = printer.doprint(Assignment(var, expr))
        is_matrix = hasattr(expr, "shape")
        is_output = var not in outputs
        is_output = not is_output

        if is_matrix:
            if not is_output:
                # The buffer must hold every element: the printed element
                # assignments use flattened indices, so sizing it with the
                # row count alone overflows for anything but a column vector.
                size = 1
                for dim in expr.shape:
                    size *= dim
                lines.append(f"T {var}[{size}];")
            lines.extend(code.split("\n"))
        elif is_output:
            lines.append("*" + code)
        else:
            lines.append("const auto " + code)

    return "\n".join(lines)

def my_function_print(printer, name, inputs, name_exprs, outputs, run_cse=True):
    """Print a complete templated C++ function computing ``outputs``.

    Args:
        printer: code printer, forwarded to ``my_expression_print``.
        name: name of the generated function.
        inputs: input symbols; a ``MatrixSymbol`` becomes a ``const T*``
            parameter, anything else a ``const T`` parameter.
        name_exprs: ``(name, expr)`` pairs making up the function body.
        outputs: output names; each becomes a ``T*`` parameter after the
            inputs.
        run_cse: forwarded to ``my_expression_print``.

    Returns:
        str: the function source, body indented by two spaces.
    """
    def input_param(symbol):
        if isinstance(symbol, MatrixSymbol):
            return f"const T* {symbol.name}"
        return f"const T {symbol.name}"

    def output_param(symbol):
        return f"T* {symbol}"

    params = [input_param(i) for i in inputs] + [output_param(o) for o in outputs]
    # Joined outside the f-string: reusing the same quote character inside a
    # replacement field is a SyntaxError before Python 3.12.
    param_list = ", ".join(params)
    head = f"template<typename T> void {name}({param_list}) {{"

    body = my_expression_print(printer, name_exprs, outputs, run_cse)

    lines = [head]
    lines.extend(f"  {line}" for line in body.split("\n"))
    lines.append("}")

    return "\n".join(lines)

In [None]:
# All in one: generate a single C++ function for the whole step

In [None]:
# Bind the field symbols to expressions of the three evaluation points so the
# generated function is self-contained.
B1 = name_expr(p, "B1")
B2 = name_expr(p2.name, "B2")
B3 = name_expr(p3.name, "B3")

# The step matrix reads `m`, `dtds` (which needs `p_abs`) and the elements of
# the matrix symbol `J`. They used to be missing from both the parameters and
# the expression list, so the generated code referred to undeclared names.
code = my_function_print(
    cxx_printer,
    "rk4",
    [p, d, h, l, m, p_abs, MatrixSymbol("J", 8, 8)],
    [B1, k1, p2, B2, k2, k3, p3, B3, k4, err, dtds, new_p, new_d, dk2dT, dk3dT, dk4dT, dk2dL, dk3dL, dk4dL, dFdT, dFdL, dGdT, dGdL, J_new],
    [err.name, new_p.name, new_d.name, J_new.name],
    run_cse=True,
)
print(code)

In [None]:
# Step by step: generate one small C++ function per stage

In [None]:
# Stage 1: k1 from the direction, l and a generic field vector B.
B = MatrixSymbol("B", 3, 1)
k1 = name_expr(l * d.as_explicit().cross(B), "k1")

print(my_function_print(cxx_printer, "rk4_k1", inputs=[d, l, B], name_exprs=[k1], outputs=[k1.name]))

# From here on k1 is an input symbol rather than an expression.
k1 = MatrixSymbol("k1", 3, 1)
p2 = name_expr(p + h/2 * d + h**2/8 * k1, "p2")
print(my_function_print(cxx_printer, "rk4_p2", inputs=[p, d, h, k1], name_exprs=[p2], outputs=[p2.name]))

In [None]:
# Stage 2: k2 from k1 (an input symbol here) and a generic field vector B.
B = MatrixSymbol('B', 3, 1)
k1 = MatrixSymbol('k1', 3, 1)
# Wrapped with name_expr and passed together with the output list: the print
# helper needs symbolic names and a separate `outputs` argument.
k2 = name_expr(l * (d + h/2 * k1).as_explicit().cross(B), "k2")

print(my_function_print(cxx_printer, "rk4_k2", [d, h, l, B, k1], [k2], [k2.name]))

In [None]:
# Stage 3: k3 from k2 (an input symbol here) and a generic field vector B.
B = MatrixSymbol('B', 3, 1)
k2 = MatrixSymbol('k2', 3, 1)
# Wrapped with name_expr and passed together with the output list: the print
# helper needs symbolic names and a separate `outputs` argument.
k3 = name_expr(l * (d + h/2 * k2).as_explicit().cross(B), "k3")

print(my_function_print(cxx_printer, "rk4_k3", [d, h, l, B, k2], [k3], [k3.name]))

# From here on k3 is an input symbol rather than an expression.
k3 = MatrixSymbol('k3', 3, 1)
p3 = name_expr(p + h * d + h**2/2 * k3, "p3")
print(my_function_print(cxx_printer, "rk4_p3", [p, d, h, k3], [p3], [p3.name]))

In [None]:
# Stage 4: k4 from k3 (an input symbol here) and a generic field vector B.
B = MatrixSymbol('B', 3, 1)
k3 = MatrixSymbol('k3', 3, 1)
# Wrapped with name_expr and passed together with the output list: the print
# helper needs symbolic names and a separate `outputs` argument.
k4 = name_expr(l * (d + h * k3).as_explicit().cross(B), "k4")

print(my_function_print(cxx_printer, "rk4_k4", [d, h, l, B, k3], [k4], [k4.name]))

In [None]:
# Error estimate from the four stages, all taken as input symbols.
k1 = MatrixSymbol('k1', 3, 1)
k2 = MatrixSymbol('k2', 3, 1)
k3 = MatrixSymbol('k3', 3, 1)
k4 = MatrixSymbol('k4', 3, 1)

# Wrapped with name_expr and passed together with the output list: the print
# helper needs symbolic names and a separate `outputs` argument.
err = name_expr(h**2 * (k1 - k2 - k3 + k4).as_explicit().norm(1), "err")

print(my_function_print(cxx_printer, "rk4_err", [h, k1, k2, k3, k4], [err], [err.name]))

In [None]:
# Final update of position and direction from the four stages, all taken as
# input symbols.
k1 = MatrixSymbol('k1', 3, 1)
k2 = MatrixSymbol('k2', 3, 1)
k3 = MatrixSymbol('k3', 3, 1)
k4 = MatrixSymbol('k4', 3, 1)

# Wrapped with name_expr and passed together with the output list: the print
# helper needs symbolic names and a separate `outputs` argument.
new_p = name_expr(p + h * d + h**2/6 * (k1 + k2 + k3), "new_p")
# Unnormalised direction; the printed `new_d` is this vector divided by its norm.
new_d_tmp = d + h/6 * (k1 + 2 * (k2 + k3) + k4)
new_d = name_expr(new_d_tmp / new_d_tmp.as_explicit().norm(), "new_d")

print(my_function_print(cxx_printer, "rk4_fin", [p, d, h, k1, k2, k3, k4], [new_p, new_d], [new_p.name, new_d.name]))