## 23.5 - Code Generation - Example - Three Body Problem

In [None]:
%matplotlib widget
from sympy import *
init_printing(use_latex=True)
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

In [None]:
def get_pos(i):
    """Return the symbolic position vector Matrix([Xi, Yi, Zi]) of body ``i``."""
    return Matrix([Symbol(f"{axis}{i}") for axis in "XYZ"])


def get_vel(i):
    """Return the symbolic velocity vector Matrix([V_Xi, V_Yi, V_Zi]) of body ``i``."""
    return Matrix([Symbol(f"V_{axis}{i}") for axis in "XYZ"])


r1, r2, r3 = map(get_pos, (1, 2, 3))
v1, v2, v3 = map(get_vel, (1, 2, 3))
r1, r2, r3, v1, v2, v3

In [None]:
# Gravitational constant and the masses of the three bodies, as SymPy symbols.
G, m1, m2, m3 = symbols("G m1 m2 m3")

In [None]:
def term(ma, ra, rb):
    """Acceleration felt at position ``rb`` due to a mass ``ma`` located at ``ra``.

    Returns the symbolic vector G * ma * (ra - rb) / |ra - rb|**3, which
    points from ``rb`` towards ``ra``. ``G`` is the module-level symbol.
    """
    separation = ra - rb
    distance = sqrt(separation.dot(separation))
    return G * ma * separation / distance**3

In [None]:
# Acceleration of each body: the sum of the pulls exerted by the other two.
a1, a2, a3 = (
    term(m2, r2, r1) + term(m3, r3, r1),
    term(m1, r1, r2) + term(m3, r3, r2),
    term(m1, r1, r3) + term(m2, r2, r3),
)
display(a1)

In [None]:
from collections import OrderedDict
# Catalog of initial conditions. For each case:
#   r1, r2, r3 : initial positions [x, y, z] of the three bodies
#   v1, v2, v3 : initial velocities [vx, vy, vz] of the three bodies
#   period     : end time of the integration (presumably the orbit period)
#   gmasses    : (G, m1, m2, m3)
# The "-1" terms shift every initial y coordinate by -1.
data = {
    "Flower in Circle": OrderedDict([
        ("r1", [-0.602885898116520, 1.059162128863347-1, 0]),
        ("r2", [0.252709795391000, 1.058254872224370-1, 0]),
        ("r3", [-0.355389016941814, 1.038323764315145-1, 0]),
        ("v1", [0.122913546623784, 0.747443868604908, 0]),
        ("v2", [-0.019325586404545, 1.369241993562101, 0]),
        ("v3", [-0.103587960218793, -2.116685862168820, 0]),
        ("period", 2.246101255307486),
        ("gmasses", (1, 1, 1, 1))
    ])
}

# initial conditions: flatten positions then velocities into one state vector
# [X1, Y1, Z1, ..., V_X3, V_Y3, V_Z3], the same ordering as the symbolic `y`.
case = "Flower in Circle"
y0 = [
    component
    for key in ("r1", "r2", "r3", "v1", "v2", "v3")
    for component in data[case][key]
]
period = data[case]["period"]
_G, _m1, _m2, _m3 = data[case]["gmasses"]

In [None]:
from sympy_utils import plot_arrows_direction_from_line

def add_line(x, y, i, label, arrows=True, N=6, hw=.025):
    """Plot one trajectory using matplotlib color-cycle entry ``C{i}``.

    Draws the path (x, y), marks its first point with a circle carrying
    ``label`` (for the legend) and, when ``arrows`` is true, passes the path
    together with ``N`` and ``hw`` to ``plot_arrows_direction_from_line``.

    Returns the list of Line2D objects created by ``plt.plot`` for the path.
    """
    color = f"C{i}"
    path = plt.plot(x, y, color)
    plt.plot(x[0], y[0], f"o{color}", label=label)
    if arrows:
        plot_arrows_direction_from_line(path, N, hw, skipfirst=True)
    return path

def plot_orbit(c, m=(1, 1, 1), com=True, title="Flower in circle"):
    """ Plot the orbits for the Three Body Problem in the x-y plane.

    Parameters
    ----------
        c : np.ndarray
            State history with one column per time step (e.g. ``sol.y`` from
            ``solve_ivp``). Rows 0-2, 3-5 and 6-8 hold the x, y, z
            coordinates of bodies 1, 2 and 3 respectively.
        m : list or tuple of three floats
            Masses of the three bodies, used to locate the center of mass.
            Default to (1, 1, 1).
        com : boolean
            If True, plot in the Center Of Mass reference system.
            Default to True
        title : str
            Title of the figure. Default to "Flower in circle".

    Raises
    ------
        TypeError
            If `m` is not a list or a tuple.
        ValueError
            If `m` does not contain exactly three elements.
    """
    if not isinstance(m, (list, tuple)):
        raise TypeError("`m` must be a list of three elements")
    if len(m) != 3:
        raise ValueError("`m` must be a list of three elements")
    m1, m2, m3 = m
    total = m1 + m2 + m3

    if com:
        # Center of mass: each body's coordinate is weighted by its OWN mass
        # (the previous version weighted all x's by m1 and all y's by m2,
        # which is only right when the masses are equal).
        xG = (m1 * c[0, :] + m2 * c[3, :] + m3 * c[6, :]) / total
        yG = (m1 * c[1, :] + m2 * c[4, :] + m3 * c[7, :]) / total

    plt.figure()
    for k in range(3):
        # Body k occupies rows 3k (x) and 3k + 1 (y) of the state history.
        x, y = c[3 * k, :], c[3 * k + 1, :]
        if com:
            x, y = x - xG, y - yG
        add_line(x, y, k, label=f"m{k + 1}")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title(title)
    plt.grid()
    plt.legend()
    plt.axis("equal")
    plt.show()

In [None]:
from scipy.integrate import solve_ivp
from sympy.utilities.codegen import codegen
from sympy.utilities.autowrap import autowrap

### 23.5.1 - Generating a lambda function

In [None]:
# Symbolic state vector and its time derivative, in matching order:
# positions of bodies 1-3 first, then their velocities.
y = list(r1) + list(r2) + list(r3) + list(v1) + list(v2) + list(v3)
dydt = list(v1) + list(v2) + list(v3) + list(a1) + list(a2) + list(a3)

In [None]:
Y = IndexedBase("Y")
# Map each state symbol y[i] to the array element Y[i]. Index explicitly
# instead of zip(y, Y): Y has no shape, so iterating it falls back on the
# legacy __getitem__ protocol and yields an unbounded sequence.
d = {sym: Y[i] for i, sym in enumerate(y)}
print(d)

In [None]:
# Rewrite every right-hand side in terms of Y[i] so the whole state can be
# passed to the generated function as the single argument Y.
dydt = [rhs.xreplace(d) for rhs in dydt]
dydt[-1]

In [None]:
t = Symbol("t")
# Numerical right-hand side f(t, Y, G, m1, m2, m3) returning the 18
# derivatives; the (t, y, *args) order is what solve_ivp calls it with.
f = lambdify([t, Y, G, m1, m2, m3], dydt, modules="numpy")

In [None]:
import inspect
# Show the Python source that lambdify generated for f.
print(inspect.getsource(f))

In [None]:
sol = solve_ivp(f, [0, period], y0, args=(_G, _m1, _m2, _m3),
              method="RK45", rtol=1e-10, atol=1e-10)
# solve_ivp does not raise when the integration fails: it returns the partial
# result with sol.success == False, so check before plotting a truncated orbit.
if not sol.success:
    raise RuntimeError(f"Integration failed: {sol.message}")
r = sol.y
plot_orbit(r)

In [None]:
# Benchmark the integration driven by the lambdify-generated right-hand side.
%timeit solve_ivp(f, [0, period], y0, args=(_G, _m1, _m2, _m3), method="RK45", rtol=1e-10, atol=1e-10)

### 23.5.2 - Generating an executable with autowrap()

In [None]:
from sympy.utilities.codegen import codegen
from sympy.utilities.autowrap import autowrap
# Rebuild the purely symbolic state and derivative lists: dydt was
# overwritten earlier with its Y[i]-substituted version.
y = list(r1) + list(r2) + list(r3) + list(v1) + list(v2) + list(v3)
dydt = list(v1) + list(v2) + list(v3) + list(a1) + list(a2) + list(a3)

In [None]:
# Display the documentation of sympy's codegen().
help(codegen)

In [None]:
# First attempt: generate C code straight from the symbolic right-hand side.
# The inputs are still the individual state symbols (X1, ..., V_Z3), not an
# array -- presumably each becomes its own argument; check the printed source.
[(cfilename, csourcecode), (hfilename, hsourcecode)] = codegen(("three_body", Matrix(dydt)), language="c")
print(csourcecode)

In [None]:
Y = MatrixSymbol("Y", len(y), 1)
# Map each state symbol y[i] to the element Y[i, 0] of the column vector Y.
d = {sym: Y[i] for i, sym in enumerate(y)}
print(d)

In [None]:
# Second attempt: express the right-hand side in terms of the column
# MatrixSymbol Y, so the state enters the generated C code as one array
# argument (verify in the printed source).
dydt = Matrix(dydt).xreplace(d)
[(cfilename, csourcecode), (hfilename, hsourcecode)] = codegen(("three_body", dydt), language="c")
print(csourcecode)

In [None]:
# Third attempt: state the result as the matrix equation dY = dydt. The
# left-hand side names the output, so the generated function writes into an
# array argument called dY (check the printed source).
dY = MatrixSymbol("dY", *dydt.shape)
eq = Eq(dY, dydt)
[(cfilename, csourcecode), (hfilename, hsourcecode)] = codegen(("three_body", eq), language="c")
print(csourcecode)

In [None]:
from sympy.utilities.codegen import C99CodeGen
# Display the documentation of the C99 code generator class.
help(C99CodeGen)

In [None]:
# Same generation with a C99 code generator configured with cse=True (common
# subexpression elimination). `cg` is reused later by autowrap().
cg = C99CodeGen(cse=True)
[(cfilename, csourcecode), (hfilename, hsourcecode)] = codegen(("three_body", eq), code_gen=cg)
print(csourcecode)

In [None]:
# Display the documentation of sympy's autowrap().
help(autowrap)

In [None]:
t = symbols("t")
# Compile `eq` (dY = dydt) into a Cython extension called as
# binary_func(t, Y, G, m1, m2, m3), which returns dY. The generated sources
# are kept in ./autowrap_3bp, the folder section 23.5.3 asks you to copy.
binary_func = autowrap(eq, args=[t, Y, G, m1, m2, m3],
                       backend='cython', tempdir='./autowrap_3bp', code_gen=cg)

In [None]:
def three_body_problem(t, y, G, m1, m2, m3):
    """Right-hand side for solve_ivp backed by the autowrapped ``binary_func``.

    ``binary_func`` was generated for the column MatrixSymbol ``Y``, so the
    1-D state ``y`` is passed as a column; singleton axes of the result are
    squeezed away to give back the 1-D derivative that solve_ivp expects.
    """
    state_column = y[:, np.newaxis]
    derivative = binary_func(t, state_column, G, m1, m2, m3)
    return np.squeeze(derivative)

In [None]:
sol = solve_ivp(three_body_problem, [0, period], np.asarray(y0), args=(_G, _m1, _m2, _m3),
              method="RK45", rtol=1e-10, atol=1e-10)
# solve_ivp does not raise when the integration fails: it returns the partial
# result with sol.success == False, so check before plotting a truncated orbit.
if not sol.success:
    raise RuntimeError(f"Integration failed: {sol.message}")
r = sol.y
plot_orbit(r)

In [None]:
# Benchmark the integration driven by the autowrap-compiled right-hand side.
%timeit solve_ivp(three_body_problem, [0, period], np.asarray(y0), args=(_G, _m1, _m2, _m3), method="RK45", rtol=1e-10, atol=1e-10)

### 23.5.3 - Manually generating an executable

Before executing the following cells, copy the folder `autowrap_3bp` (generated in the previous section) to a new folder named `autowrap_3bp_final` (the folder the cells below build and import from), then follow the instructions in the book.

In [None]:
import os
import subprocess
import sys

# Build the extension in ./autowrap_3bp_final in place.
# - sys.executable (the interpreter running this kernel) is used instead of
#   whatever `python3` is first on PATH, so the compiled module matches the
#   Python that imports it in the next cell.
# - cwd= runs setup.py inside the folder without chdir-ing the kernel.
olddir = os.getcwd()
build = subprocess.run(
    [sys.executable, "setup.py", "build_ext", "--inplace"],
    cwd=os.path.join(olddir, "autowrap_3bp_final"),
    capture_output=True,
    text=True,
    errors="replace",
)
print(build.stdout)
if build.returncode != 0:
    # Surface the build error here instead of letting the next import fail
    # (or silently pick up a stale build).
    raise RuntimeError(f"build_ext failed:\n{build.stderr}")

In [None]:
from autowrap_3bp_final.wrapper_module_0 import autofunc_c as three_body_binary


In [None]:
sol = solve_ivp(three_body_binary, [0, period], y0, args=(_G, _m1, _m2, _m3),
              method="RK45", rtol=1e-10, atol=1e-10)
# solve_ivp does not raise when the integration fails: it returns the partial
# result with sol.success == False, so check before plotting a truncated orbit.
if not sol.success:
    raise RuntimeError(f"Integration failed: {sol.message}")
r = sol.y
plot_orbit(r)

In [None]:
# Benchmark the integration driven by the manually built extension.
%timeit solve_ivp(three_body_binary, [0, period], y0, args=(_G, _m1, _m2, _m3), method="RK45", rtol=1e-10, atol=1e-10)