# Geodesic Solver

In [1]:
import os
import sys

# Make the local ``manifolds`` package importable. The checkout location can be
# overridden with the PY_MANIFOLDS_PATH environment variable; the default is the
# original hard-coded checkout. The membership check keeps re-runs of this cell
# from appending the same entry to sys.path again.
MANIFOLDS_PATH = os.environ.get("PY_MANIFOLDS_PATH", "/home/alok/Programming/py-manifolds")
if MANIFOLDS_PATH not in sys.path:
    sys.path.append(MANIFOLDS_PATH)

In [2]:
import jax
import jax.numpy as jnp

from manifolds.manifold import ChartPoint, Cotangent, ContravariantTensor, CovariantTensor, Tangent, Tensor
from manifolds.sphere import Sphere, StereographicChart, SpherePoint
from manifolds.riemannian import levi_civita

In [3]:
# Manifold shared by the cells below. The argument is presumably the radius;
# the driver cells use a stereographic chart with signed radius -5. Confirm
# against ``Sphere``.
sphere = Sphere(5.0)



The goal here is to implement the exponential map with a simple ODE solver.
The exponential map is defined by the geodesic equation, a second-order *nonlinear* ODE (it is quadratic in the velocity). Introducing the velocity $v = \dot x$ as a separate unknown turns it into a system of first-order ODEs in $x$ and $v$:

$$\dot x^k(t) = v^k(t)$$

$$ \dot v^k(t) = -\Gamma^k_{ij}(x(t))\, v^i(t)\, v^j(t) $$

(summing over the repeated indices $i, j$).

What does taking a step look like?

- You are at some point $x_0$ with velocity $v_0$
- In explicit (forward Euler) mode, you add $dt \cdot v_0$ to the coordinates of $x$. You also add $dt$ times the Christoffel-symbol term $-\Gamma^k_{ij}(x_0)\, v_0^i v_0^j$ to the coordinates of $v$. Everything is evaluated at the current point $(x_0, v_0)$.
- Explicit mode is equivalent to approximating $\dot x$ and $\dot v$ with forward differences in time.
- In implicit (backward Euler) mode, you approximate $\dot x$ and $\dot v$ with backward differences. The unknown new state then appears inside $\Gamma$, so you have to solve a nonlinear system at every step. That is painful when $\Gamma$ comes out of a black box.
- So I guess we're sticking to explicit mode.

In [27]:
def step(manifold, x, v, dt):
    """Advance the geodesic ODE by one explicit (forward-Euler) step.

    Integrates the first-order system

        x'^k = v^k
        v'^k = -Gamma^k_ij(x) v^i v^j

    entirely in the chart that ``x`` is expressed in. Both updates use the
    state at the start of the step.

    Args:
      manifold: the manifold, passed to ``levi_civita`` to obtain the
        Christoffel symbols.
      x: the current position as a ``ChartPoint``; its ``coords`` and
        ``chart`` are used directly.
      v: the current velocity as a ``Tangent`` anchored at ``x``; only its
        chart components ``v_coords`` are used.
      dt: time step size.

    Returns:
      A pair ``(x_new, v_new)``: the advanced ``ChartPoint`` (same chart as
      ``x``) and a ``Tangent`` anchored at ``x_new``.
    """
    coords = x.coords
    velocity = v.v_coords

    # Position update: x_{n+1} = x_n + dt * v_n.
    x_new = ChartPoint(coords + dt * velocity, x.chart)

    # Velocity update: the Christoffel symbols are evaluated at the *old* point
    # (explicit Euler).
    christoffel = levi_civita(manifold, x)
    # NOTE(review): ``array @ v @ v`` contracts the last two axes of the
    # Christoffel array. That equals Gamma^k_ij v^i v^j only if the array is
    # laid out as (k, i, j), with the upper index first. Confirm against
    # ``levi_civita``.
    acceleration = -christoffel.coords @ velocity @ velocity
    v_new = Tangent(x_new, velocity + dt * acceleration)

    # NOTE(review): forward Euler is only first-order accurate and does not
    # conserve the speed g(v, v) exactly, so long runs drift from the true
    # geodesic. A higher-order explicit scheme (e.g. RK4) would reduce this.
    return x_new, v_new
    
@jax.jit
def exp_map_step(initial_x, initial_v, signed_radius, dt=0.001):
    """Take one jitted geodesic step on ``sphere`` in a stereographic chart.

    Takes and returns raw coordinate arrays. The driver cells feed each output
    back in as the next input.

    Args:
      initial_x: chart coordinates of the current point.
      initial_v: chart components of the current velocity at that point.
      signed_radius: passed to ``StereographicChart`` to build the chart.
        NOTE(review): presumably the sign selects the projection pole.
        Confirm against ``StereographicChart``.
      dt: time step size (defaults to the previously hard-coded 0.001).

    Returns:
      ``(x_coords, v_coords)`` after one step, in the same chart.
    """
    # NOTE(review): an earlier comment said the equator was used "because the
    # chart doesn't matter there". The driver cells start at [4., 0.], which is
    # not on the equator of ``sphere = Sphere(5.)`` if the chart projects onto
    # the equatorial plane (equator at |x| = 5). The trajectory also leaves any
    # starting circle. Confirm the intended start point.
    x = ChartPoint(initial_x, StereographicChart(signed_radius))
    v = Tangent(x, initial_v)
    x_new, v_new = step(sphere, x, v, dt)
    return x_new.coords, v_new.v_coords

# Integrate 1000 single steps from the starting point/velocity below, then
# report the final state and the squared chart norm |x|^2.
signed_r = jnp.array(-5.)
x_coords = jnp.array([4., 0.])
v_coords = jnp.array([1., 1.])
for _ in range(1000):
    x_coords, v_coords = exp_map_step(x_coords, v_coords, signed_r)

print(x_coords, v_coords)
print((x_coords ** 2).sum())
print((x_coords ** 2).sum())

[4.868924  1.5079792] [0.5189059 2.1239953]
25.980423


In [28]:
# Continue from the state left by the previous cell (x_coords, v_coords,
# signed_r): 100 batches of 100 steps, printing |x|^2 after each batch to
# watch how the trajectory moves through the chart.
for _batch in range(100):
    for _ in range(100):
        x_coords, v_coords = exp_map_step(x_coords, v_coords, signed_r)
    print((x_coords ** 2).sum())

print(x_coords, v_coords)

27.126127
28.258762
29.357578
30.397831
31.351269
32.18747
32.875786
33.387905
33.700836
33.799667
33.679474
33.346
32.81499
32.10996
31.259583
30.294683
29.245611
28.140423
27.003582
25.85555
24.712729
23.587769
22.490026
21.426132
20.400467
19.415747
18.473257
17.57334
16.715597
15.899117
15.12261
14.384589
13.683374
13.017282
12.384577
11.783552
11.212547
10.669947
10.154217
9.663896
9.197599
8.754029
8.331957
7.9302335
7.5477886
7.183611
6.8367634
6.506369
6.1916103
5.8917227
5.6059914
5.3337517
5.0743837
4.8273096
4.5919886
4.3679194
4.1546297
3.9516883
3.7586732
3.5752106
3.4009442
3.2355418
3.078693
2.930109
2.7895212
2.65668
2.5313525
2.4133205
2.3023834
2.1983542
2.1010618
2.010344
1.9260553
1.8480599
1.7762339
1.7104641
1.6506476
1.596692
1.5485139
1.5060406
1.469206
1.4379567
1.4122441
1.3920285
1.3772821
1.367982
1.3641144
1.3656731
1.3726602
1.3850863
1.4029697
1.4263374
1.4552238
1.489673
1.5297358
1.5754738
1.6269557
1.6842613
1.7474778
1.8167036
[-1.042496  -0.8543452] 