# Schwarzschild orbits
Star mass $M$, particle mass $m$, and total energy $E\leq m$

\begin{align}
ds^2 = \big(E^2-m^2+\frac{2Mm^2}{r}\big)\Big(\frac{dr^2}{\big(1-\frac{2M}{r}\big)^2}+ \frac{r^2d\Phi^2}{1-\frac{2M}{r}}\Big)
\end{align}
In Cartesian coordinates $x=r\cos\Phi$, $y=r\sin\Phi$ (so $r=\sqrt{x^2+y^2}$):
\begin{align}
&dr=\frac{1}{r}(x\,dx+y\,dy),\qquad d\Phi = -\frac{y}{r^2}dx+\frac{x}{r^2}dy\\
&dr^2 = \frac{1}{r^2}(x^2dx^2+2xy\,dxdy+y^2dy^2),\qquad d\Phi^2 = \frac{1}{r^4}(y^2dx^2-2xy\,dxdy+x^2dy^2)
\end{align}
Thus, with $g:=E^2-m^2+\frac{2Mm^2}{r}$ and $f:= \big(1-\frac{2M}{r}\big)^{-1}$, so that $ds^2 = fg\,\big(f\,dr^2 + r^2d\Phi^2\big)$,
\begin{align}
ds^2 = \frac{fg}{r^2}\Big( (f\,x^2+y^2)\,dx^2 + 2(f-1)xy\,dxdy+(x^2+f\,y^2)\,dy^2 \Big)
\end{align}

In [None]:
import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
from sympy import symbols, diff, Matrix, simplify
import sympy as sp

# --- Model parameters --------------------------------------------------------
# E is obtained from H via E^2 = 2H + m^2.  (L is not used in this cell.)
H = -1.5
L = 0.5
M = 0.0025
m = 1/M**0.5
E = (2*H+m**2)**0.5

# --- Symbolic metric in Cartesian coordinates ---------------------------------
X, Y = symbols('x y')
xvec = [X, Y]
r = sp.sqrt(X**2+Y**2)
g = E**2 - m**2 + 2 * M * m**2 / r
f = 1 / (1 - 2 * M / r)
G = simplify(
    f * g / r**2 * Matrix(2, 2, [
        f * X**2 + Y**2, (f - 1) * X * Y,
        (f - 1) * X * Y, X**2 + f * Y**2,
    ])
)

# Inverse metric from the adjugate formula G^{-1} = adj(G) / det(G).
Ginv = simplify(1/G.det()*G.adjugate())

# Christoffel symbols, stored as flat lists with index 4*i + 2*j + k.
# First kind:  chr1[4i+2j+k] = 1/2 (d_i G_kj + d_j G_ik - d_k G_ij)
chr1 = [
    simplify(0.5*(diff(G[k, j], xvec[i])
                  + diff(G[i, k], xvec[j])
                  - diff(G[i, j], xvec[k])))
    for i in range(2)
    for j in range(2)
    for k in range(2)
]
# Second kind: chr2[4i+2j+k] = Ginv[k,0]*chr1[4i+2j+0] + Ginv[k,1]*chr1[4i+2j+1]
chr2 = [
    Ginv[k, 0]*chr1[4*i+2*j+0] + Ginv[k, 1]*chr1[4*i+2*j+1]
    for i in range(2)
    for j in range(2)
    for k in range(2)
]

def rhs(s, v):
    """Right-hand side of the geodesic equation for the symbolic metric ``G``.

    Parameters
    ----------
    s : float
        Curve parameter ("time"); unused because the system is autonomous.
    v : array_like
        State ``[gamma^1, gamma^2, v^1, v^2]``: position followed by velocity.

    Returns
    -------
    list
        ``[v^1, v^2, a^1, a^2]`` with ``a^k = -Gamma^k_{ij} v^i v^j`` (summed
        over i, j), where ``Gamma^k_{ij} = chr2[4*i + 2*j + k]``.
    """
    # Substituting floats into the sympy expressions on every call is very
    # slow, so the global ``chr2`` is compiled to a NumPy function once and
    # cached on this function object.  The cache is keyed on the identity of
    # ``chr2``, so rebuilding the symbols triggers a recompilation.
    cache = getattr(rhs, "_chr2_cache", None)
    if cache is None or cache[0] is not chr2:
        cache = (chr2, sp.lambdify((X, Y), chr2, modules="numpy"))
        rhs._chr2_cache = cache
    gam = cache[1](v[0], v[1])
    chr111, chr112, chr121, chr122 = gam[0], gam[1], gam[2], gam[3]
    chr221, chr222 = gam[6], gam[7]
    return [v[2], v[3],
            -v[2]**2*chr111 - 2*v[2]*v[3]*chr121 - v[3]**2*chr221,
            -v[2]**2*chr112 - 2*v[2]*v[3]*chr122 - v[3]**2*chr222]


# Reference trajectory from the symbolic Christoffel symbols, integrated with
# SciPy's default solve_ivp method (a higher-order method such as
# method="DOP853" could be passed instead), sampled at 400 times.
Tend = 20
refval = solve_ivp(
    rhs,
    (0, Tend),
    [0.5, 0, -0.2, 0.5],
    t_eval=np.linspace(0, Tend, 400),
)

# Trajectory in the (x, y) plane with a red disk of radius 0.05 at the origin.
plt.plot(refval.y[0], refval.y[1])
plt.gca().add_patch(plt.Circle((0, 0), 0.05, color='r'))
plt.axis('equal')
plt.show()

In [None]:
from ngsolve import *
from netgen.occ import *
from ngsolve.webgui import Draw
import numpy as np
import matplotlib.pyplot as plt
from time import time


# Model parameters (same values as in the symbolic reference cell).
# E is obtained from H via E^2 = 2H + m^2.  (L is not used in this cell.)
H = -1.5
L = 0.5
M = 0.0025
m = 1/M**0.5
E = (2*H+m**2)**0.5

# Exact metric as an ngsolve CoefficientFunction in Cartesian coordinates,
# using ngsolve's coordinate functions x, y.
r = sqrt(x**2+y**2)
g = E**2 - m**2 + 2 * M * m**2 / r
f = 1 / (1 - 2 * M / r)
Gex = f * g / r**2 * CF(
    (f * x**2 + y**2, (f - 1) * x * y,
     (f - 1) * x * y, x**2 + f * y**2),
    dims=(2, 2),
)

# Computational domain: disk of radius 0.8 around the origin.
shape = Circle((0, 0), 0.8).Face()
mesh = Mesh(OCCGeometry(shape, dim=2).GenerateMesh(maxh=0.02))

def _cross2d(u, w):
    """Return the scalar (z-component) cross product of two 2D vectors."""
    return u[0]*w[1] - u[1]*w[0]


def segment_intersect(A, B, C, D):
    """Intersect the closed segments AB and CD in the plane.

    Parameters
    ----------
    A, B, C, D : array_like of shape (2,)
        End points of the two segments.

    Returns
    -------
    tuple
        ``(True, point)`` with the intersection point (as ``C + s*(D-C)``) if
        the segments meet (touching end points count), otherwise
        ``(False, None)``.  Parallel segments, i.e. those with
        |cross(B-A, D-C)| < 1e-12, are always reported as non-intersecting.
    """
    # Explicit 2D cross product instead of np.cross, which deprecates
    # 2-element vectors since NumPy 2.0.
    A, B, C, D = (np.asarray(P, dtype=float) for P in (A, B, C, D))
    AB = B - A
    CD = D - C
    denom = _cross2d(CD, AB)
    # Parallel (or degenerate) segments: treated as non-intersecting.
    if abs(denom) < 1e-12:
        return (False, None)
    # C and D strictly on the same side of line AB -> no intersection.
    if _cross2d(AB, C - A) * _cross2d(AB, D - A) > 0:
        return (False, None)
    # A and B strictly on the same side of line CD -> no intersection.
    if _cross2d(CD, A - C) * _cross2d(CD, B - C) > 0:
        return (False, None)
    # Parameter s of the intersection along C + s*(D-C), from
    # cross(C + s*(D-C) - A, B-A) = 0.
    s = (_cross2d(A, AB) - _cross2d(C, AB)) / denom
    return (True, C + s*CD)

def segm_trig_intersect(P,Q,el,mesh):
    """Find the edge of triangle ``el`` crossed by the segment from P to Q.

    Parameters
    ----------
    P, Q : numpy.ndarray of shape (2,)
        Start and end point of the segment.
    el : mesh element
        Triangle whose three vertices are read from ``mesh``.
    mesh : ngsolve.Mesh
        Mesh providing the vertex coordinates.

    Returns
    -------
    tuple or None
        ``(intersection_point, V1, V2)`` where ``(V1, V2)`` are the end points
        of the crossed edge; ``None`` (after printing a warning) if no edge is
        crossed.
    """
    A = np.array(mesh[el.vertices[0]].point)
    B = np.array(mesh[el.vertices[1]].point)
    C = np.array(mesh[el.vertices[2]].point)
    res1, res2, res3 = True, True, True
    scale = 1e-7
    direction = Q - P
    # If the segment hits more than one edge (e.g. P lies on an edge), move
    # the start point towards Q, doubling the shift each time, until only one
    # edge is hit.  Stop before the shifted start passes Q (scale >= 1): beyond
    # that the segment is reversed and meaningless, and in degenerate cases
    # (Q on a vertex) the loop would otherwise only end once the coordinates
    # overflow to inf/NaN.
    while res1 + res2 + res3 > 1 and scale < 1:
        newP = P + scale*direction
        res1, P1 = segment_intersect(A, B, newP, Q)
        res2, P2 = segment_intersect(A, C, newP, Q)
        res3, P3 = segment_intersect(C, B, newP, Q)
        scale *= 2
    if res1:
        return (P1, A, B)
    if res2:
        return (P2, A, C)
    if res3:
        return (P3, C, B)
    print("Warning: segement does not intersect trig")
    return None
        
def KineticEnergy(q,p,G,mesh):
    """Return the kinetic energy ``0.5 * p^T G(q) p``.

    ``G`` is evaluated at the mesh point ``mesh(*q)``; its (4) values are
    reshaped into a 2x2 matrix before contracting twice with ``p``.
    """
    metric = np.reshape(np.array(G(mesh(*q))), (2, 2))
    Gp = np.dot(metric, p)
    return 0.5 * np.dot(Gp, p)

def BisectElements(oldpoint,currentpoint,direction,old_el,mesh,dt):
    """Bisect the step ``oldpoint + s*dt*direction``, s in [0, 1], at the exit of ``old_el``.

    The parameter interval is halved until it is shorter than 1e-7. Probe
    points that still lie in ``old_el`` move the lower end up; all other
    probe points move the upper end down.

    Returns
    -------
    tuple
        ``(near_oldp, near_newp, other_el, s)``:
        - ``near_oldp``: last probe point found inside ``old_el``.
        - ``near_newp``: last probe point found outside it.
        - ``other_el``: the element containing ``near_newp``.
        - ``s``: the final bisection parameter.

        Before any probe, these default to ``oldpoint``,
        ``oldpoint + dt*direction`` and the element at ``currentpoint``.
    """
    def element_at(pnt):
        return mesh[ElementId(mesh(*pnt).nr)]

    other_el = element_at(currentpoint)
    near_oldp = oldpoint
    near_newp = oldpoint + 1*dt*direction
    lo, hi = 0, 1
    while hi - lo > 1e-7:
        s = (hi+lo)/2
        probe = oldpoint + s*dt*direction
        probe_el = element_at(probe)
        if probe_el.nr != old_el.nr:
            # Probe already left old_el: shrink from above.
            near_newp, other_el, hi = probe, probe_el, s
            continue
        # Probe still inside old_el: shrink from below.
        near_oldp, lo = probe, s
    return (near_oldp, near_newp, other_el, s)

In [None]:
def rhs(s, v, mesh, chr2):
    """Right-hand side of the geodesic ODE with Christoffel symbols ``chr2``.

    Parameters
    ----------
    s : float
        Time; unused (the system is autonomous).
    v : numpy.ndarray
        State ``[x, y, vx, vy]``.
    mesh : ngsolve.Mesh
        Mesh used to locate the point ``(x, y)`` via ``mesh(x, y)``.
    chr2 : CoefficientFunction
        Christoffel symbols of the second kind evaluated at that mesh point.
        Its flat result is read at indices ``4*i + 2*j + k`` as Gamma^k_{ij}.
        This presumably matches the ordering of the symbolic cell; TODO confirm
        for ``gfG.Operator("christoffel2")``.

    Returns
    -------
    list or numpy.ndarray
        ``[vx, vy, ax, ay]`` with ``a^k = -Gamma^k_{ij} v^i v^j``.  If
        evaluating ``chr2`` raises, a zero vector of the state's length is
        returned instead.
    """
    try:
        chr_eval = chr2(mesh(*v[:2]))
    except Exception:
        # Best-effort: evaluation failed (presumably the point is outside the
        # mesh), so return a zero right-hand side.  Narrowed from a bare
        # ``except:`` so KeyboardInterrupt/SystemExit still propagate.
        return np.zeros(len(v))
    chr111 = chr_eval[4*0+2*0+0]
    chr121 = chr_eval[4*0+2*1+0]
    chr221 = chr_eval[4*1+2*1+0]
    chr112 = chr_eval[4*0+2*0+1]
    chr122 = chr_eval[4*0+2*1+1]
    chr222 = chr_eval[4*1+2*1+1]
    return [v[2], v[3],
            -v[2]**2*chr111 - 2*v[2]*v[3]*chr121 - v[3]**2*chr221,
            -v[2]**2*chr112 - 2*v[2]*v[3]*chr122 - v[3]**2*chr222]


def SolveODE(x0,y0,v0,v1, mesh, chr2, dt, Tend=1, G=None):
    """Integrate a geodesic with explicit Euler steps on a mesh-based metric.

    The state is ``[x, y, vx, vy]``.  Each step adds ``dt * rhs(...)``.  When
    the new position lies in a different mesh element than the previous one,
    the step is shortened via ``BisectElements`` so that the point sits just
    across the crossed edge.  The velocity is then transformed using the
    metric on both sides of that edge.

    Parameters
    ----------
    x0, y0 : float
        Initial position.
    v0, v1 : float
        Initial velocity components.
    mesh : ngsolve.Mesh
        Mesh on which the metric and ``chr2`` live.
    chr2 : CoefficientFunction
        Christoffel symbols, passed through to ``rhs``.
    dt : float
        Time step.
    Tend : float, optional
        Final time.
    G : callable, optional
        Metric evaluated at a mesh point (e.g. a GridFunction), returning four
        values that are reshaped to 2x2.  Defaults to the notebook-global
        ``gfG`` (the previous hard-wired behavior).

    Returns
    -------
    values : list of numpy.ndarray
        States ``[x, y, vx, vy]``.
    ts : list of float
        Corresponding times.  If a step leaves the mesh, integration stops
        and that step's point is dropped.
    """
    if G is None:
        G = gfG

    T = 0
    values = [np.array([x0, y0, v0, v1], dtype=float)]

    currentpoint = values[0][:2]
    el = mesh[ElementId(mesh(*currentpoint).nr)]
    ts = [T]
    while T < Tend:
        # Explicit Euler step of the full state.
        inc = np.array(rhs(T, values[-1], mesh, chr2))
        values.append(values[-1] + dt*inc)

        oldpoint = currentpoint
        # currentpoint is a view into values[-1]: writing values[-1][:2] below
        # also moves it (and therefore the next step's oldpoint).
        currentpoint = values[-1][:2]
        old_el = el

        # New point outside the mesh: stop and drop it.
        mip = mesh(*currentpoint)
        if mip.nr < 0:
            return values[:-1], ts
        el = mesh[ElementId(mip.nr)]

        if old_el.nr != el.nr:
            # The geodesic crossed into another element: find the crossed
            # edge (V1, V2) and shorten the step to end just across it.
            _, V1, V2 = segm_trig_intersect(oldpoint, currentpoint, old_el, mesh)
            near_oldp, near_newp, el, s = BisectElements(
                oldpoint, currentpoint, values[-2][2:], old_el, mesh, dt)
            T += s*dt
            values[-1][:2] = near_newp

            # Metric just before and just after the crossed edge.
            gold = np.array(G(mesh(*near_oldp))).reshape(2, 2)
            gnew = np.array(G(mesh(*near_newp))).reshape(2, 2)

            # Euclidean unit tangent t of the edge and a normal n.
            t = 1/np.linalg.norm(V2-V1)*(V2-V1)
            n = np.array([t[1], -t[0]])
            # t normalised w.r.t. gold; n raised with each side's inverse
            # metric and normalised (nnew carries an extra minus sign).
            told = 1/np.sqrt((gold.dot(t)).dot(t))*t
            gold_inv_n = np.linalg.inv(gold).dot(n)
            gnew_inv_n = np.linalg.inv(gnew).dot(n)
            nold = 1/np.sqrt(gold_inv_n.dot(n))*gold_inv_n
            nnew = -1/np.sqrt(gnew_inv_n.dot(n))*gnew_inv_n

            # Keep the component along told; map the component along nold
            # onto the new-side normal nnew.
            gv = gold.dot(values[-1][2:])
            values[-1][2:] = gv.dot(told)*told - gv.dot(nold)*nnew
        else:
            T += dt
        ts.append(T)

    return values, ts

# Solve the geodesic problem on each mesh size and record trajectory,
# times and kinetic energy along the trajectory.
results = []
maxhs = [0.05]

Tend = 200
init = [0.5, 0, -0.2, 0.5]

energy = []
ts = []
with TaskManager():
    for maxh in maxhs:
        mesh = Mesh(OCCGeometry(shape, dim=2).GenerateMesh(maxh=maxh))

        # Order-0 HCurlCurl (Regge) GridFunction set from the exact metric Gex.
        gfG = GridFunction(HCurlCurl(mesh, order=0))
        gfG.Set(Gex, dual=True)
        # The Christoffel symbols (gfG.Operator("christoffel2")) are zero for
        # lowest-order Regge elements, so a zero CF is used directly.
        chr2 = CF((0,) * 8, dims=(2, 2, 2))

        values, t = SolveODE(x0=init[0], y0=init[1], v0=init[2], v1=init[3],
                             mesh=mesh, chr2=chr2, dt=0.005, Tend=Tend)
        results.append(np.array(values))
        ts.append(t)
        energy.append([KineticEnergy(row[:2], row[2:], mesh=mesh, G=gfG)
                       for row in results[-1]])

# Trajectories per mesh size against the SciPy reference solution.
for maxh, traj in zip(maxhs, results):
    plt.plot(traj[:, 0], traj[:, 1], label="h" + str(maxh))
plt.plot(refval.y[0], refval.y[1], '-', color="k", label="reference")
plt.legend()
plt.axis('equal')
plt.ticklabel_format(useOffset=False)
plt.show()

# Kinetic energy over time per mesh size.
for ts_h, energy_h in zip(ts, energy):
    plt.plot(ts_h, energy_h)
plt.ticklabel_format(useOffset=False)
plt.show()

# Flatten each trajectory into consecutive 3D segment end points
# (x_k, y_k, 0, x_{k+1}, y_{k+1}, 0) for the webgui "lines" object.
lpts = []
for traj in results:
    segs = []
    for (xa, ya), (xb, yb) in zip(traj[:-1, :2], traj[1:, :2]):
        segs += [xa, ya, 0.0, xb, yb, 0.0]
    lpts.append(segs)

colors = ["red", "blue", "magenta", "cyan", "orange", "teal", "brown", "white", "red"]
lines = [
    {"type": "lines", "position": pts, "name": "my lines", "color": colors[idx]}
    for idx, pts in enumerate(lpts)
]

Draw (mesh,objects=lines);