In [None]:
from collections import namedtuple

import numpy as np

import sympy as sym
from sympy import Symbol, symbols, Matrix, ImmutableMatrix, MatrixSymbol, Function, simplify, diff, Integer
from sympy.utilities.codegen import codegen
from sympy.utilities.lambdify import implemented_function
from sympy.utilities.iterables import numbered_symbols
from sympy.codegen.ast import Assignment, FunctionDefinition
from sympy.printing.cxx import CXX17CodePrinter
from sympy.codegen.rewriting import create_expand_pow_optimization

In [None]:
# Turn on SymPy's pretty printing for expressions displayed by this notebook.
sym.init_printing()

In [None]:
# Pairs the symbol an expression is assigned to (`name`) with the expression
# itself (`expr`). Instances unpack and index like plain 2-tuples.
NamedExpr = namedtuple("NamedExpr", ["name", "expr"])

def name_expr(name, expr):
    """Wrap ``expr`` in a ``NamedExpr`` whose symbol is called ``name``.

    Args:
        name: string used as the symbol name.
        expr: the expression to label.

    Returns:
        ``NamedExpr(symbol, expr)``. The symbol is a ``MatrixSymbol`` with the
        shape of ``expr`` when ``expr`` has a ``shape`` attribute, and a plain
        ``Symbol`` otherwise.
    """
    is_matrix = hasattr(expr, "shape")
    symbol = sym.MatrixSymbol(name, *expr.shape) if is_matrix else Symbol(name)
    return NamedExpr(name=symbol, expr=expr)

def find_by_name(name_exprs, name):
    """Return the first entry whose symbol prints as ``name``.

    Args:
        name_exprs: iterable of ``(symbol, expr)`` pairs.
        name: string compared against ``str(symbol)``.

    Returns:
        The first matching pair, or ``None`` when no entry matches.
    """
    for candidate in name_exprs:
        if str(candidate[0]) == name:
            return candidate
    return None

In [None]:
# Scalar symbols:
#   l     - lambda, used for q/p
#   h     - step length
#   t     - time
#   m     - mass
#   p_abs - absolute momentum
l, h, t, m, p_abs = (Symbol(label) for label in ("lambda", "h", "t", "m", "p_abs"))

# 3x1 column vectors: p is the position, d the direction.
p, d = (MatrixSymbol(label, 3, 1) for label in ("p", "d"))

# Magnetic field: three separate 3x1 symbols.
B1, B2, B3 = (MatrixSymbol(label, 3, 1) for label in ("B1", "B2", "B3"))

In [None]:
def rk4_short_math():
    """Build the RK4 step as a chain of short, named expressions.

    Every intermediate quantity is wrapped with ``name_expr`` and later
    expressions refer to it through its symbol (``.name``) rather than its
    full expression, so each returned expression stays small.

    Reads the module-level symbols ``p``, ``d``, ``t``, ``h``, ``l`` (lambda),
    ``m``, ``p_abs`` and the field symbols ``B1``, ``B2``, ``B3``.

    Returns:
        list of ``NamedExpr`` in the order ``[k1, p2, k2, k3, p3, k4, err,
        new_p, new_d, dtds, new_time, path_derivatives, dk2dTL, dk3dTL,
        dk4dTL, dFdTL, dGdTL, new_J]``. ``new_d`` is returned as computed,
        without normalisation. ``dk1dTL`` is not in the list because its
        expression is used inline wherever it is needed.
    """
    # Stage slopes k1..k4 and the intermediate positions p2, p3.
    k1 = name_expr("k1", d.as_explicit().cross(l * B1))
    p2 = name_expr("p2", p + h/2 * d + h**2/8 * k1.name)

    k2 = name_expr("k2", (d + h/2 * k1.name).as_explicit().cross(l * B2))
    k3 = name_expr("k3", (d + h/2 * k2.name).as_explicit().cross(l * B2))
    p3 = name_expr("p3", p + h * d + h**2/2 * k3.name)

    k4 = name_expr("k4", (d + h * k3.name).as_explicit().cross(l * B3))

    # Error estimate: h^2 times norm(1) of k1 - k2 - k3 + k4.
    err = name_expr("err", h**2 * (k1.name - k2.name - k3.name + k4.name).as_explicit().norm(1))

    new_p = name_expr("new_p", p + h * d + h**2/6 * (k1.name + k2.name + k3.name))
    new_d = name_expr("new_d", d + h/6 * (k1.name + 2 * (k2.name + k3.name) + k4.name))

    dtds = name_expr("dtds", sym.sqrt(1 + m**2 / p_abs**2))
    new_time = name_expr("new_time", t + h * dtds.name)

    # 8x1 column: rows 0-2 <- new_d, row 3 <- dtds, rows 4-6 <- k4; row 7 stays 0.
    path_derivatives = name_expr("path_derivatives", sym.zeros(8, 1))
    path_derivatives.expr[0:3,0] = new_d.name.as_explicit()
    path_derivatives.expr[3,0] = dtds.name
    path_derivatives.expr[4:7,0] = k4.name.as_explicit()

    # Build the Jacobian with respect to [d, l] step by step via the chain
    # rule, referring to the previous stage's Jacobian by its symbol.
    dk1dTL = name_expr("dk1dTL", k1.expr.jacobian([d, l]))
    dk2dTL = name_expr("dk2dTL", k2.expr.jacobian([d, l]) + k2.expr.jacobian(k1.name) * dk1dTL.expr)
    dk3dTL = name_expr("dk3dTL", k3.expr.jacobian([d, l]) + k3.expr.jacobian(k2.name) * dk2dTL.name.as_explicit())
    dk4dTL = name_expr("dk4dTL", k4.expr.jacobian([d, l]) + k4.expr.jacobian(k3.name) * dk3dTL.name.as_explicit())

    dFdTL = name_expr("dFdTL",
        new_p.expr.as_explicit().jacobian([d, l]) +
        new_p.expr.as_explicit().jacobian(k1.name) * dk1dTL.expr +
        new_p.expr.as_explicit().jacobian(k2.name) * dk2dTL.name.as_explicit() +
        new_p.expr.as_explicit().jacobian(k3.name) * dk3dTL.name.as_explicit())
    dGdTL = name_expr("dGdTL",
        new_d.expr.as_explicit().jacobian([d, l]) +
        new_d.expr.as_explicit().jacobian(k1.name) * dk1dTL.expr +
        new_d.expr.as_explicit().jacobian(k2.name) * dk2dTL.name.as_explicit() +
        new_d.expr.as_explicit().jacobian(k3.name) * dk3dTL.name.as_explicit() +
        new_d.expr.as_explicit().jacobian(k4.name) * dk4dTL.name.as_explicit())

    # Step Jacobian D: identity plus the symbolic blocks filled in below.
    D = sym.eye(8)
    D[0:3,4:8] = dFdTL.name.as_explicit()
    D[4:7,4:8] = dGdTL.name.as_explicit()
    D[3,7] = h * m**2 * l / dtds.name

    # Symbolic 8x8 matrix J; wherever D is exactly 0 or 1, the matching entry
    # of J is replaced by that constant.
    J = Matrix(MatrixSymbol("J", 8, 8).as_explicit())
    for indices in np.ndindex(J.shape):
        if D[indices] in [0, 1]:
            J[indices] = D[indices]
    J = ImmutableMatrix(J)
    new_J = name_expr("new_J", J * D)

    return [k1, p2, k2, k3, p3, k4, err, new_p, new_d, dtds, new_time, path_derivatives, dk2dTL, dk3dTL, dk4dTL, dFdTL, dGdTL, new_J]

# Evaluate the builder once in this cell; the trailing `;` keeps the returned
# list from being displayed as cell output.
rk4_short_math();

In [None]:
def rk4_full_math():
    """Build the RK4 step with every intermediate fully substituted.

    Unlike ``rk4_short_math``, each stage uses the full expression (``.expr``)
    of the previous stage, so the returned expressions depend only on the
    module-level input symbols ``p``, ``d``, ``t``, ``h``, ``l`` (lambda),
    ``m``, ``p_abs``, ``B1``, ``B2``, ``B3`` and the symbolic matrix ``J``.

    Returns:
        list of ``NamedExpr`` in the order ``[p2, p3, err, new_p, new_d,
        new_time, path_derivatives, new_J]``. ``new_d`` is returned as
        computed, without normalisation.
    """
    # Stage slopes k1..k4 and the intermediate positions p2, p3.
    k1 = name_expr("k1", d.as_explicit().cross(l * B1))
    p2 = name_expr("p2", p + h/2 * d + h**2/8 * k1.expr)

    k2 = name_expr("k2", (d + h/2 * k1.expr).as_explicit().cross(l * B2))
    k3 = name_expr("k3", (d + h/2 * k2.expr).as_explicit().cross(l * B2))
    p3 = name_expr("p3", p + h * d + h**2/2 * k3.expr)

    k4 = name_expr("k4", (d + h * k3.expr).as_explicit().cross(l * B3))

    # Error estimate: h^2 times norm(1) of k1 - k2 - k3 + k4.
    err = name_expr("err", h**2 * (k1.expr - k2.expr - k3.expr + k4.expr).norm(1))

    new_p = name_expr("new_p", p + h * d + h**2/6 * (k1.expr + k2.expr + k3.expr))
    new_d = name_expr("new_d", d + h/6 * (k1.expr + 2 * (k2.expr + k3.expr) + k4.expr))

    dtds = name_expr("dtds", sym.sqrt(1 + m**2 / p_abs**2))
    new_time = name_expr("new_time", t + h * dtds.expr)

    # 8x1 column: rows 0-2 <- new_d, row 3 <- dtds, rows 4-6 <- k4; row 7 stays 0.
    path_derivatives = name_expr("path_derivatives", sym.zeros(8, 1))
    path_derivatives.expr[0:3,0] = new_d.expr.as_explicit()
    path_derivatives.expr[3,0] = dtds.expr
    path_derivatives.expr[4:7,0] = k4.expr.as_explicit()

    # Get the full Jacobian directly by differentiating the fully substituted
    # expressions with respect to [p, t, d, l].
    D = sym.eye(8)
    D[0:3,:] = new_p.expr.as_explicit().jacobian([p, t, d, l])
    D[4:7,:] = new_d.expr.as_explicit().jacobian([p, t, d, l])
    D[3,7] = h * m**2 * l / dtds.expr

    # Symbolic 8x8 matrix J; wherever D is exactly 0 or 1, the matching entry
    # of J is replaced by that constant.
    J = MatrixSymbol("J", 8, 8).as_explicit().as_mutable()
    for indices in np.ndindex(J.shape):
        if D[indices] in [0, 1]:
            J[indices] = D[indices]
    J = ImmutableMatrix(J)

    new_J = name_expr("new_J", J * D)

    return [p2, p3, err, new_p, new_d, new_time, path_derivatives, new_J]

# Evaluate the builder once in this cell; the trailing `;` keeps the returned
# list from being displayed as cell output.
rk4_full_math();

In [None]:
# full transport

In [None]:
# Module-level symbolic inputs shared by the full_transport_jacobian_*
# functions below. Each matrix starts fully symbolic and then has its known
# constant entries pinned, so only the remaining entries appear as symbols.

# 8x1 column of symbols; the last row is fixed to 0.
step_path_derivatives = MatrixSymbol("step_path_derivatives", 8, 1).as_explicit().as_mutable()
step_path_derivatives[7,0] = 0 # qop

# 1x8 row of symbols; columns 3 and 7 are fixed to 0.
surface_path_derivatives = MatrixSymbol("surface_path_derivatives", 1, 8).as_explicit().as_mutable()
surface_path_derivatives[0,3] = 0
surface_path_derivatives[0,7] = 0

# J_bf (8x6): all zeros except the three symbolic blocks copied below and the
# two unit entries [3,5] and [7,4]. `tmp` is scratch storage that is rebound
# for every matrix in this cell.
# NOTE(review): "bf" presumably means bound-to-free - confirm.
J_bf = MatrixSymbol("J_bf", 8, 6).as_explicit().as_mutable()
tmp = sym.zeros(8, 6)
tmp[0:3,0:2] = J_bf[0:3,0:2]
tmp[0:3,2:4] = J_bf[0:3,2:4]
tmp[4:7,2:4] = J_bf[4:7,2:4] # line surface
tmp[3,5] = 1
tmp[7,4] = 1
J_bf = tmp

# J_t (8x8): identity with symbolic entries in rows 0-2 / columns 4-7,
# entry [3,7], and rows 4-6 / columns 4-7 - the same entries that the rk4_*
# functions fill in their matrix D.
J_t = MatrixSymbol("J_t", 8, 8).as_explicit().as_mutable()
tmp = sym.eye(8)
tmp[0:3,4:8] = J_t[0:3,4:8]
tmp[3,7] = J_t[3,7]
tmp[4:7,4:8] = J_t[4:7,4:8]
J_t = tmp

# J_fb (6x8): all zeros except the two symbolic blocks copied below and the
# two unit entries [5,3] and [4,7].
# NOTE(review): "fb" presumably means free-to-bound - confirm.
J_fb = MatrixSymbol("J_fb", 6, 8).as_explicit().as_mutable()
tmp = sym.zeros(6, 8)
tmp[0:2,0:3] = J_fb[0:2,0:3]
tmp[2:4,4:7] = J_fb[2:4,4:7]
tmp[5,3] = 1
tmp[4,7] = 1
J_fb = tmp

def full_transport_jacobian_generic():
    """Build the full transport Jacobian from the module-level matrices.

    Computes ``J_fb * (I + step_path_derivatives * surface_path_derivatives)
    * J_t * J_bf`` using the module-level symbolic matrices of those names.

    Returns:
        A one-element list holding the ``NamedExpr`` called ``J_full``.
    """
    path_correction = sym.eye(8) + step_path_derivatives * surface_path_derivatives
    product = J_fb * path_correction * J_t * J_bf
    return [name_expr("J_full", product)]

def full_transport_jacobian_curvilinear(direction):
    """Build the full transport Jacobian with a direction-based surface row.

    The row vector that multiplies ``step_path_derivatives`` is not the
    module-level ``surface_path_derivatives``. It is built here as
    ``[-direction^T, 0, 0, 0, 0, 0]``.

    Args:
        direction: 3x1 matrix expression; ``as_explicit()`` is called on it.

    Returns:
        A one-element list holding the ``NamedExpr`` called ``J_full``.
    """
    # Columns 0-2 carry the negated direction, columns 3-7 are zero.
    surface_row = sym.zeros(1, 8)
    surface_row[0, 0:3] = -direction.as_explicit().transpose()

    path_correction = sym.eye(8) + step_path_derivatives * surface_row
    product = J_fb * path_correction * J_t * J_bf
    return [name_expr("J_full", product)]

# Evaluate both builders once in this cell; the trailing `;` keeps the returned
# lists from being displayed as cell output.
full_transport_jacobian_generic();
full_transport_jacobian_curvilinear(MatrixSymbol("dir", 3, 1));

In [None]:
# Symbolic 6x6 matrix C made symmetric: every entry (i, j) is replaced by the
# entry at the sorted index pair, i.e. by the element on or above the diagonal.
C = MatrixSymbol("C", 6, 6).as_explicit().as_mutable()
for indices in np.ndindex(C.shape):
    C[indices] = C[tuple(sorted(indices))]

# J_full (6x6): identity with symbolic entries in rows 0-3 and row 5,
# columns 0-4. Row 4 and column 5 keep their identity values.
# This rebinds the module-level names `J_full` and `tmp`.
J_full = MatrixSymbol("J_full", 6, 6).as_explicit().as_mutable()
tmp = sym.eye(6)
tmp[0:4,0:5] = J_full[0:4,0:5]
tmp[5:6,0:5] = J_full[5:6,0:5]
J_full = tmp

def covariance_transport_generic():
    """Build the transported covariance ``J_full * C * J_full^T``.

    Uses the module-level matrices ``J_full`` and ``C``.

    Returns:
        A one-element list holding the ``NamedExpr`` called ``new_C``.
    """
    transported = J_full * C * J_full.T
    return [name_expr("new_C", transported)]

# Evaluate the builder once in this cell; the trailing `;` keeps the returned
# list from being displayed as cell output.
covariance_transport_generic();

In [None]:
# code gen

In [None]:
class MyCXXCodePrinter(CXX17CodePrinter):
    """C++17 printer that addresses matrices as flat, column-major arrays."""

    def _traverse_matrix_indices(self, mat):
        """Return an iterator over ``(row, col)`` pairs, column by column."""
        n_rows, n_cols = mat.shape
        return ((row, col) for col in range(n_cols) for row in range(n_rows))

    def _print_MatrixElement(self, expr):
        """Print a matrix element as ``parent[row + col * n_rows]``."""
        from sympy.printing.precedence import PRECEDENCE

        parent_code = self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True)
        flat_index = expr.i + expr.j*expr.parent.shape[0]
        return "{}[{}]".format(parent_code, flat_index)

# Shared printer instance used by the code-generation cells below.
cxx_printer = MyCXXCodePrinter()

In [None]:
def inflate_expr(name_expr):
    """Split one named expression into per-element ``(name, expr)`` pairs.

    Args:
        name_expr: a ``(name, expr)`` pair.

    Returns:
        ``(result, references)``, two lists of equal length.

        * If ``expr`` has a ``shape`` attribute, there is one entry per index
          of that shape (in ``np.ndindex`` order). ``result`` holds
          ``(name[idx], expr[idx])`` and ``references`` holds
          ``(name, expr.shape, idx)``.
        * Otherwise ``result`` is ``[(name, expr)]`` and ``references`` is
          ``[None]``.
    """
    name, expr = name_expr

    if not hasattr(expr, "shape"):
        return [(name, expr)], [None]

    positions = list(np.ndindex(expr.shape))
    result = [(name[idx], expr[idx]) for idx in positions]
    references = [(name, expr.shape, idx) for idx in positions]
    return result, references

def inflate_exprs(name_exprs):
    """Apply ``inflate_expr`` to every entry and concatenate the outcomes.

    Args:
        name_exprs: iterable of ``(name, expr)`` pairs.

    Returns:
        ``(result, references)``: the flattened per-element pairs and the
        matching reference entries, in input order.
    """
    flat_pairs, flat_refs = [], []
    for entry in name_exprs:
        pairs, refs = inflate_expr(entry)
        flat_pairs += pairs
        flat_refs += refs
    return flat_pairs, flat_refs

def deflate_exprs(name_exprs, references):
    """Reassemble per-element pairs into whole named expressions.

    This is the inverse of ``inflate_exprs``.

    Args:
        name_exprs: iterable of ``(name, expr)`` pairs, one per element.
        references: iterable of the same length. An entry is either ``None``
            (the pair is passed through unchanged) or
            ``(matrix_name, shape, indices)``, meaning ``expr`` belongs at
            ``indices`` of the matrix called ``matrix_name``.

    Returns:
        list of entries in order of first appearance. Matrix entries are
        ``NamedExpr(matrix_name, ImmutableMatrix)``; scalar entries are the
        input pairs themselves.
    """
    ordered = []   # output slots, in order of first appearance
    matrices = {}  # matrix name -> mutable matrix still being filled

    for name_expr, reference in zip(name_exprs, references):
        if reference is None:
            ordered.append(name_expr)
            continue

        _, expr = name_expr
        name, shape, indices = reference
        if name not in matrices:
            # Seed with exact zeros so a slot that is never written stays an
            # exact 0 rather than a value taken from a NumPy float array.
            matrices[name] = Matrix.zeros(*shape)
            ordered.append(NamedExpr(name, matrices[name]))
        # `indices` is already a tuple, so it can be used as the key directly.
        # This avoids the star-unpacking subscript that needs Python >= 3.11.
        matrices[name][indices] = expr

    # Freeze the matrices that were assembled above.
    return [
        NamedExpr(entry[0], ImmutableMatrix(entry[1])) if isinstance(entry[1], Matrix) else entry
        for entry in ordered
    ]

def my_subs(expr, sub_name_exprs):
    """Replace occurrences of known sub-expressions by their names.

    Args:
        expr: expression supporting ``.subs``.
        sub_name_exprs: iterable of ``(name, expr)`` pairs. They are inflated
            element-wise first, so matrix entries are substituted per element.

    Returns:
        ``expr`` with every listed sub-expression replaced by its name.
    """
    flat_pairs, _ = inflate_exprs(sub_name_exprs)
    replacements = [(sub_expr, sub_name) for sub_name, sub_expr in flat_pairs]
    return expr.subs(replacements)

In [None]:
def build_dependency_graph(name_exprs):
    """Map each name to the free symbols of its expression.

    Args:
        name_exprs: iterable of ``(name, expr)`` pairs.

    Returns:
        dict ``{name: expr.free_symbols}``. For repeated names the last
        entry wins.
    """
    return {name: expr.free_symbols for name, expr in name_exprs}

def build_influence_graph(name_exprs):
    """Map each free symbol to the set of names whose expression uses it.

    Args:
        name_exprs: iterable of ``(name, expr)`` pairs.

    Returns:
        dict ``{symbol: {names...}}``.
    """
    influenced = {}
    for target, expr in name_exprs:
        for symbol in expr.free_symbols:
            if symbol not in influenced:
                influenced[symbol] = set()
            influenced[symbol].add(target)
    return influenced

In [None]:
def order_exprs_by_input(name_exprs):
    """Sort named expressions so each comes after everything it depends on.

    Symbols that appear in an expression but are not defined by any entry are
    treated as inputs (level 0). An expression's level is one more than the
    highest level among its free symbols. An expression with no free symbols
    gets level 1.

    Args:
        name_exprs: iterable of ``(name, expr)`` pairs; ``expr`` must expose
            ``free_symbols`` and ``args``.

    Returns:
        list of the same pairs sorted by level, then by number of free
        symbols, then by ``len(expr.args)``. Ties keep their input order.

    Raises:
        ValueError: if some definitions can never be resolved (cyclic or
            self-referencing definitions). The original loop would never
            terminate in that situation.
    """
    name_exprs = list(name_exprs)

    defined_names = {name for name, _ in name_exprs}
    used_symbols = set().union(*(expr.free_symbols for _, expr in name_exprs))
    inputs = used_symbols - defined_names

    order = {symbol: 0 for symbol in inputs}

    pending = name_exprs
    while pending:
        unresolved = []
        for name, expr in pending:
            levels = [order.get(symbol) for symbol in expr.free_symbols]
            if None in levels:
                unresolved.append((name, expr))
            else:
                # default=0 covers expressions without any free symbol.
                order[name] = max(levels, default=0) + 1
        if len(unresolved) == len(pending):
            # A full pass made no progress, so the rest can never resolve.
            stuck = ", ".join(str(name) for name, _ in unresolved)
            raise ValueError(f"cannot order expressions with unresolved dependencies: {stuck}")
        pending = unresolved

    # One sort with a composite key is equivalent to three chained stable sorts.
    return sorted(
        name_exprs,
        key=lambda n_e: (order[n_e[0]], len(n_e[1].free_symbols), len(n_e[1].args)),
    )

In [None]:
def order_exprs_by_output(name_exprs, outputs):
    """Order expressions so each output directly follows its dependencies.

    For every output, in the given order, the not-yet-emitted expressions it
    depends on (directly or indirectly) are emitted first, ordered by
    ``order_exprs_by_input``, followed by the output itself. Expressions that
    no output depends on are dropped.

    Args:
        name_exprs: sequence of ``(name, expr)`` pairs.
        outputs: iterable of names. Each must be defined in ``name_exprs``,
            otherwise a ``KeyError`` is raised.

    Returns:
        list of ``(name, expr)`` pairs in emission order. Every entry appears
        at most once: an output already emitted as a dependency of an earlier
        output is not emitted again.
    """
    entry_by_name = {entry[0]: entry for entry in name_exprs}
    closure_cache = {}  # name -> all symbols it depends on, however indirectly

    def get_inputs(output):
        cached = closure_cache.get(output)
        if cached is not None:
            return cached
        entry = entry_by_name.get(output)
        if entry is None:
            # Not defined here, so it is an external input with no dependencies.
            return set()
        inputs = set(entry[1].free_symbols)
        # Iterate over a snapshot because `inputs` grows inside the loop.
        for symbol in tuple(inputs):
            inputs |= get_inputs(symbol)
        closure_cache[output] = inputs
        return inputs

    result = []
    done = set()

    for output in outputs:
        if output in done:
            # Already emitted as a dependency of an earlier output.
            continue
        pending = get_inputs(output) - done
        result.extend(order_exprs_by_input([entry for entry in name_exprs if entry[0] in pending]))
        result.append(entry_by_name[output])
        done.update(pending)
        done.add(output)

    return result

In [None]:
def my_cse(name_exprs, inflate_deflate=True):
    """Run common-subexpression elimination over named expressions.

    Args:
        name_exprs: iterable of ``(name, expr)`` pairs.
        inflate_deflate: when true, matrix-valued entries are split into
            per-element pairs before ``sym.cse`` and reassembled afterwards.

    Returns:
        A new list: first the ``(symbol, sub_expression)`` pairs introduced by
        ``sym.cse``, then the simplified named expressions.
    """
    temp_symbols = numbered_symbols()

    references = None
    if inflate_deflate:
        name_exprs, references = inflate_exprs(name_exprs)

    names = [entry[0] for entry in name_exprs]
    exprs = [entry[1] for entry in name_exprs]

    replacements, reduced = sym.cse(exprs, symbols=temp_symbols)

    simplified = list(zip(names, reduced))
    if inflate_deflate:
        simplified = deflate_exprs(simplified, references)

    return [*replacements, *simplified]

In [None]:
def my_expression_print(printer, name_exprs, outputs, run_cse=True, pre_expr_hook=None, post_expr_hook=None):
    """Print named expressions as a sequence of C++ assignment statements.

    Args:
        printer: code printer providing ``doprint``.
        name_exprs: iterable of ``(name, expr)`` pairs.
        outputs: names that are results. They are written to, not declared.
        run_cse: when true, ``my_cse`` is applied first.
        pre_expr_hook: optional ``hook(var) -> str | None``; the returned text
            is inserted before the statement(s) for ``var``.
        post_expr_hook: optional ``hook(var) -> str | None``; the returned
            text is inserted after the statement(s) for ``var``.

    Returns:
        The generated code as one newline-joined string. Per expression:

        * matrix, not an output: ``T name[size];`` followed by the assignments
        * scalar, not an output: ``const auto name = ...;``
        * matrix output: the assignments only
        * scalar output: ``*name = ...;``
    """
    def hook_lines(hook, var):
        # Lines contributed by an optional hook, or nothing.
        if hook is None:
            return []
        code = hook(var)
        return [] if code is None else code.split("\n")

    if run_cse:
        name_exprs = my_cse(name_exprs, inflate_deflate=True)
    name_exprs = order_exprs_by_output(name_exprs, outputs)

    lines = []

    for var, expr in name_exprs:
        lines.extend(hook_lines(pre_expr_hook, var))

        code = printer.doprint(Assignment(var, expr))
        is_matrix = hasattr(expr, "shape")
        is_output = var in outputs

        if is_matrix:
            if not is_output:
                # Intermediate matrices need local storage before assignment.
                lines.append(f"T {var}[{np.prod(expr.shape)}];")
            lines.extend(code.split("\n"))
        elif is_output:
            # Scalar outputs are written through a pointer.
            lines.append("*" + code)
        else:
            lines.append("const auto " + code)

        lines.extend(hook_lines(post_expr_hook, var))

    return "\n".join(lines)

def my_function_print(printer, name, inputs, name_exprs, outputs, run_cse=True):
    """Print named expressions as a templated C++ function.

    Args:
        printer: code printer providing ``doprint``.
        name: name of the generated function.
        inputs: input symbols. ``MatrixSymbol`` inputs become ``const T*``
            parameters, everything else ``const T``.
        name_exprs: iterable of ``(name, expr)`` pairs forming the body.
        outputs: output names; each becomes a ``T*`` parameter.
        run_cse: forwarded to ``my_expression_print``.

    Returns:
        The generated function as one newline-joined string.
    """
    def input_param(symbol):
        if isinstance(symbol, MatrixSymbol):
            return f"const T* {symbol.name}"
        return f"const T {symbol.name}"

    def output_param(symbol):
        return f"T* {symbol}"

    params = [input_param(symbol) for symbol in inputs] + [output_param(symbol) for symbol in outputs]
    # Join outside the f-string: reusing the same quote character inside a
    # replacement field is a syntax error before Python 3.12.
    param_list = ", ".join(params)
    head = f"template<typename T> void {name}({param_list}) {{"

    body = my_expression_print(printer, name_exprs, outputs, run_cse=run_cse)

    lines = [head]
    lines.extend(f"  {line}" for line in body.split("\n"))
    lines.append("}")

    return "\n".join(lines)

In [None]:
def my_step_function_print(name_exprs, run_cse=True):
    """Print the C++ ``rk4`` step function for the given named expressions.

    The signature of the generated function is fixed. The body is produced by
    ``my_expression_print`` with the module-level ``cxx_printer``, plus hand
    written glue: local declarations for ``p2``, ``p3``, ``err`` and
    ``new_J``, the field lookups ``B1``/``B2``/``B3`` via ``getB``, an early
    ``return false`` when ``err > 1e-4``, an early ``return true`` when ``J``
    is ``nullptr``, and the final copy of ``new_J`` into ``J``.

    Args:
        name_exprs: sequence of ``(name, expr)`` pairs. It must define
            ``p2``, ``p3``, ``err``, ``new_p``, ``new_d``, ``new_time``,
            ``path_derivatives`` and ``new_J``.
        run_cse: forwarded to ``my_expression_print``.

    Returns:
        The generated function as one newline-joined string.
    """
    printer = cxx_printer
    output_names = ["p2", "p3", "err", "new_p", "new_d", "new_time", "path_derivatives", "new_J"]
    outputs = [find_by_name(name_exprs, output_name)[0] for output_name in output_names]

    head = "template <typename T, typename GetB> bool rk4(const T* p, const T* d, const T t, const T h, const T lambda, const T m, const T p_abs, GetB getB, T* new_p, T* new_d, T* new_time, T* path_derivatives, T* J) {"

    # These outputs are not function parameters, so they are declared locally.
    declarations = {
        "p2": "T p2[3];",
        "p3": "T p3[3];",
        "err": "T err;",
        "new_J": "T new_J[64];",
    }
    # Fixed statements inserted right after the named result is computed.
    follow_ups = {
        "p2": "const auto B2 = getB(p2);",
        "p3": "const auto B3 = getB(p3);",
        "err": "if (err > 1e-4) {\n  return false;\n}",
        "new_time": "if (J == nullptr) {\n  return true;\n}",
    }

    def pre_expr_hook(var):
        return declarations.get(str(var))

    def post_expr_hook(var):
        if str(var) == "new_J":
            # Copy the freshly computed Jacobian into the caller's buffer.
            return printer.doprint(Assignment(MatrixSymbol("J", 8, 8), var))
        return follow_ups.get(str(var))

    code = my_expression_print(
        printer,
        name_exprs,
        outputs,
        run_cse=run_cse,
        pre_expr_hook=pre_expr_hook,
        post_expr_hook=post_expr_hook,
    )
    # `err` is a local scalar, not a pointer, so drop the dereference that
    # my_expression_print adds for scalar outputs.
    # NOTE(review): this is a plain text replacement and would also rewrite
    # any other "*err" substring in the generated code.
    code = code.replace("*err", "err")

    lines = [head, "  const auto B1 = getB(p);"]
    lines.extend(f"  {line}" for line in code.split("\n"))
    lines.append("  return true;")
    lines.append("}")

    return "\n".join(lines)

In [None]:
# Named shortcuts for products that recur in the expanded RK4 expressions.
sub_name_exprs = [
    name_expr(label, value)
    for label, value in [
        ("hlB1", h * l * B1),
        ("hlB2", h * l * B2),
        ("hlB3", h * l * B3),
        ("lB1", l * B1),
        ("lB2", l * B2),
        ("lB3", l * B3),
        ("h2_2", h**2/2),
        ("h_8", h/8),
        ("h_6", h/6),
        ("h_2", h/2),
    ]
]

# Expand every RK4 expression, replace the shortcuts by their names, and
# append the shortcut definitions themselves.
all_name_exprs = [
    NamedExpr(name, my_subs(expr.expand(), sub_name_exprs))
    for name, expr in rk4_short_math()
] + sub_name_exprs

code = my_step_function_print(all_name_exprs, run_cse=True)
print(code)

In [None]:
# Generate the C++ statements for the generic full transport Jacobian.
all_name_exprs = full_transport_jacobian_generic()
code = my_expression_print(
    cxx_printer,
    all_name_exprs,
    outputs=[find_by_name(all_name_exprs, "J_full")[0]],
    run_cse=True,
)
print(code)

In [None]:
# Generate the C++ statements for the covariance transport.
all_name_exprs = covariance_transport_generic()
code = my_expression_print(
    cxx_printer,
    all_name_exprs,
    outputs=[find_by_name(all_name_exprs, "new_C")[0]],
    run_cse=True,
)
print(code)

In [None]:
# Generate the C++ statements for the curvilinear full transport Jacobian,
# using a symbolic 3x1 direction called "dir".
all_name_exprs = full_transport_jacobian_curvilinear(MatrixSymbol("dir", 3, 1))
code = my_expression_print(
    cxx_printer,
    all_name_exprs,
    outputs=[find_by_name(all_name_exprs, "J_full")[0]],
    run_cse=True,
)
print(code)