# Curvature

In this introductory example we compute various curvature quantities from the Riemannian metric on a manifold and compare them against analytically known results. Everything here should be accessible with a basic knowledge of scientific computing and differential geometry. There are two Jax-specific transformations which we explain briefly below; for more detail, please see the [official guides](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#how-to-think-in-jax).

* `jax.jit`: Short for Just-in-Time compilation, this converts Jax Python functions to an optimised sequence of primitive operations which are then passed to some hardware accelerator. The output of `jit` is another function - usually one that executes significantly faster than the Python equivalent. The price to be paid is that the program logic of a `jit`-compatible function is constrained by the compiler, so you don't want (or need) to `jit` everything.
* `jax.vmap`: Short for Vectorising Map, this transforms Jax Python functions written for execution on a single array element, to one which is automatically vectorised across the specified array axes. Again, program logic of a `vmap`-compatible function is restricted.

Jax transformations are compatible - you can `jit` a `vmap`-ed function and vice-versa. And that's pretty much all you need to know to understand this example!
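
As a minimal illustration of composing the two transformations, the sketch below writes a function for a single vector, lifts it over a batch with `vmap`, and compiles the result with `jit`. The function `squared_norm` and the input `xs` are made up for this illustration and are not part of the library.

```python
import jax
import jax.numpy as jnp

def squared_norm(x):
    # Written for a single vector of shape (d,)
    return jnp.sum(x ** 2)

# vmap lifts the function to a batch of shape (batch, d);
# jit then compiles the batched function as a whole.
batched_squared_norm = jax.jit(jax.vmap(squared_norm))

xs = jnp.arange(6.0).reshape(3, 2)   # rows: [0, 1], [2, 3], [4, 5]
result = batched_squared_norm(xs)    # one value per row
print(result)
```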

While not a dependency of the package, the example notebooks require `jupyter`. Install it locally if you haven't already:
```
pip install --upgrade jupyter notebook
```

In [1]:
import jax
from jax import random, jit, vmap
import jax.numpy as jnp

import os, time
import numpy as np

from functools import partial

jax.config.update("jax_enable_x64", True)

## Manifold definition / point sampling
The routines in this library will work for an arbitrary real or complex manifold from which points may be sampled. In this example, we consider complex projective space $\mathbb{P}^n$. This is the space of complex lines in $\mathbb{C}^{n+1}$ which pass through the origin.

To sample from $\mathbb{P}^n$, we use the fact that every complex line through the origin intersects the unit sphere along a circle, whose $U(1)$ action we mod out, $\mathbb{P}^n \simeq S^{2n+1} / U(1)$. This means that samples from the unit sphere, appropriately complexified, give samples in homogeneous coordinates on projective space. The dimension $n$ is set by `ambient_dim` in the cell below.

In [2]:
from cymyc.utils import math_utils

ambient_dim = 10
N = 10
seed = int(time.time()) # 42
rng = random.PRNGKey(seed)
rng, _rng = random.split(rng)

def S2np1_uniform(key, n_p, n, dtype=np.float64):
    """
    Sample `n_p` points uniformly on the unit sphere $S^{2n+1}$, treated as CP^n
    """
    # return random.uniform(key, (n,))*jnp.pi, random.uniform(key, (n,)) * 2 * jnp.pi
    x = random.normal(key, shape=(n_p, 2*(n+1)), dtype=dtype)
    x_norm = x / jnp.linalg.norm(x, axis=1, keepdims=True)
    sample = math_utils.to_complex(x_norm.reshape(-1, n+1, 2))

    return jnp.squeeze(sample)

In [3]:
Z = S2np1_uniform(_rng, N, ambient_dim)
Z

Array([[ 0.08603016+0.12421914j, -0.19522247-0.21834665j,
        -0.11015934-0.13064771j,  0.03369113+0.10227832j,
        -0.38531301-0.08720344j, -0.22524021-0.01446186j,
         0.04569006+0.15634691j, -0.1345815 +0.11960158j,
        -0.47859065+0.01476375j,  0.04425167-0.38554386j,
         0.35284231+0.28328966j],
       [-0.04930381+0.18714094j, -0.00862798-0.26209397j,
         0.06624076+0.26345604j, -0.47763194+0.24401247j,
         0.00746399+0.02439196j, -0.1925575 -0.1395849j ,
        -0.16413301-0.30888711j,  0.42940631+0.13573117j,
        -0.24028406+0.01348021j,  0.22884049-0.07554974j,
         0.11892793+0.14069137j],
       [-0.46571271+0.06606655j, -0.19043696+0.09948309j,
         0.11143777-0.02330335j, -0.10892183-0.25118627j,
         0.12068737+0.40133746j,  0.07840926+0.23023486j,
        -0.15291269+0.01495585j, -0.08312896-0.21207946j,
        -0.11511381-0.02870484j, -0.20139866+0.1836997j ,
         0.46872185-0.16215461j],
       [ 0.22165086-0.511594

We now use the scaling freedom in projective space to convert homogeneous coordinates on $\mathbb{P}^n$, $\left[z_0 : \cdots : z_n\right]$, to inhomogeneous coordinates in some local coordinate chart where $z_{\alpha}$ is nonzero, setting $z_{\alpha} = 1$ and removing it from the coordinate description,

$$\left[z_0 : \cdots : z_n\right] \mapsto \left(\frac{z_0}{z_{\alpha}}, \ldots, \frac{z_{\alpha-1}}{z_{\alpha}}, \frac{z_{\alpha+1}}{z_{\alpha}}, \ldots, \frac{z_n}{z_{\alpha}}\right) \triangleq \zeta^{(\alpha)}~. $$
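
The map above can be sketched for a single point as follows. This is only an illustration: `inhomogenize_sketch` is a hypothetical helper, not the library's `math_utils._inhomogenize`, and choosing the chart $\alpha$ as the coordinate of largest modulus is an assumption made here for numerical stability. The example also checks that an overall $U(1)$ phase drops out of the inhomogeneous coordinates.

```python
import jax.numpy as jnp

def inhomogenize_sketch(Z):
    """Hypothetical helper: homogeneous coords, shape (n+1,) -> inhomogeneous coords, shape (n,)."""
    alpha = int(jnp.argmax(jnp.abs(Z)))   # assumed chart choice: largest-modulus coordinate
    scaled = Z / Z[alpha]                 # rescale so that z_alpha = 1
    return jnp.concatenate([scaled[:alpha], scaled[alpha + 1:]])  # drop z_alpha

Z = jnp.array([1.0 + 2.0j, 0.5 - 0.3j, -3.0 + 1.0j, 0.2 + 0.1j])
zeta = inhomogenize_sketch(Z)
# Multiplying by a phase gives the same point of projective space
zeta_rotated = inhomogenize_sketch(jnp.exp(0.7j) * Z)
print(jnp.allclose(zeta, zeta_rotated, atol=1e-6))
```

The `int(...)` conversion keeps the sketch short but prevents it from being used under `vmap` or `jit`; a batched version would need to select the chart with array operations instead.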

In [4]:
Z, _ = math_utils.rescale(Z)
z = vmap(math_utils._inhomogenize)(Z)
z.shape

(10, 10)

## Metric definition

There is a natural metric on $\mathbb{P}^n$ - the **Fubini-Study metric**. Viewing $\mathbb{P}^n$ as the quotient $S^{2n+1} / U(1)$, the Fubini-Study metric is the unique metric such that the projection $\pi: S^{2n+1} \rightarrow \mathbb{P}^n$ is a Riemannian submersion. In inhomogeneous coordinates,

$$ g_{\mu \bar{\nu}} = \frac{1}{\sigma}\left( \delta_{\mu \bar{\nu}} - \frac{\bar{\zeta}_{\mu}\zeta_{\nu}}{\sigma}\right), \quad \sigma = 1 + \sum_{m=1}^n \zeta_m\bar{\zeta}_m~. $$

The function below returns the FS metric in local coordinates. Note that it takes a real input, so that automatic differentiation can be applied to it directly. We therefore use the map

$$z = (z_1, \ldots, z_n) \in \mathbb{C}^n \mapsto (\Re(z_1), \ldots, \Re(z_n); \Im(z_1), \ldots, \Im(z_n)) \in \mathbb{R}^{2n}~.$$
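
A minimal sketch of this map and its inverse is given below. The names `to_real_sketch` and `to_complex_sketch` are hypothetical; they follow the ordering in the formula above (all real parts, then all imaginary parts) and are not claimed to match the implementation of `math_utils.to_real` or `math_utils.to_complex`.

```python
import jax
import jax.numpy as jnp

def to_real_sketch(z):
    # (z_1, ..., z_n) -> (Re z_1, ..., Re z_n; Im z_1, ..., Im z_n)
    return jnp.concatenate([jnp.real(z), jnp.imag(z)], axis=-1)

def to_complex_sketch(p):
    # Inverse map: first half holds real parts, second half imaginary parts
    n = p.shape[-1] // 2
    return jax.lax.complex(p[..., :n], p[..., n:])

z = jnp.array([0.3 + 0.2j, -0.5 + 0.7j])
p = to_real_sketch(z)            # real vector of length 2n
z_back = to_complex_sketch(p)    # round trip recovers z
print(p, z_back)
```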

In [5]:
def fubini_study_metric(p):
    """
    Returns FS metric in CP^n evaluated at `p`.
    Parameters
    ----------
        `p`     : 2*complex_dim real inhomogeneous coords at 
                  which metric matrix is evaluated. Shape [i].
    Returns
    ----------
        `g`     : Hermitian metric in CP^n, $g_{ij}$. Shape [i,j].
    """

    # Inhomogeneous coords
    complex_dim = p.shape[-1]//2
    zeta = jax.lax.complex(p[:complex_dim],
                           p[complex_dim:])
    zeta_bar = jnp.conjugate(zeta)
    zeta_sq = 1. + jnp.sum(zeta * zeta_bar)
    
    zeta_outer = jnp.einsum('...i,...j->...ij', zeta_bar, zeta)

    delta_mn = jnp.eye(complex_dim, dtype=zeta.dtype)  # match input precision (complex128 under x64)

    g_FS = jnp.divide(delta_mn * zeta_sq - zeta_outer, jnp.square(zeta_sq))
    
    return g_FS

In [6]:
p = math_utils.to_real(z)
g_FS = vmap(fubini_study_metric)(p)
g_FS.shape

(10, 10, 10)

We can benchmark execution times with and without `jit`-compilation - note the exact speedup will depend on the hardware available. 

In [7]:
%%timeit
_ = vmap(fubini_study_metric)(p).block_until_ready()

6.18 ms ± 182 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
%%timeit
_ = vmap(jit(fubini_study_metric))(p).block_until_ready()

659 μs ± 6.58 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
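
Outside a notebook the `%%timeit` magic is unavailable. A rough equivalent using the standard-library `timeit` module is sketched below on a toy function (`squared_norm` and `xs` are made up for this illustration). The first call to a `jit`-ed function triggers compilation, so the sketch makes one warm-up call before timing.

```python
import timeit
import jax
import jax.numpy as jnp

def squared_norm(x):
    return jnp.sum(x ** 2)

xs = jnp.ones((1000, 16))
plain = jax.vmap(squared_norm)
compiled = jax.jit(jax.vmap(squared_norm))

# Warm-up call so that compilation time is excluded from the measurement
compiled(xs).block_until_ready()

n_calls = 100
t_plain = timeit.timeit(lambda: plain(xs).block_until_ready(), number=n_calls)
t_jit = timeit.timeit(lambda: compiled(xs).block_until_ready(), number=n_calls)
print(f"vmap only: {t_plain / n_calls * 1e6:.1f} us per call")
print(f"jit(vmap): {t_jit / n_calls * 1e6:.1f} us per call")
```

As with the cells above, the measured ratio depends on the hardware and on the size of the input.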


### The Kähler potential
$\mathbb{P}^n$ is a Kähler manifold. This imbues it with many special properties, one of them being that the metric is locally determined by a single real scalar function, the Kähler potential $\mathcal{K}$, which is smooth on each coordinate chart:

\begin{align*}
g_{\mu \bar{\nu }} &= \partial_{\mu}\overline{\partial}_{\bar{\nu}} \mathcal{K}~, \\
\mathcal{K} &= \log \left( 1+ \sum_{m=1}^n \left\vert \zeta_m \right\vert^2\right)~.
\end{align*}

This is particularly important in the context of approximating metrics, as it allows one to reduce the problem to approximation of a single scalar function.

In [9]:
def fubini_study_potential(p):
    """
    Returns Kahler potential associated with the FS metric
    in CP^n evaluated at `p`.
    Parameters
    ----------
        `p`        : 2*complex_dim real inhomogeneous coords at 
                     which potential is evaluated. Shape [i].
    Returns
    ----------
        `phi`      : Kahler potential, real scalar. Shape [].  
    """
    zeta_sq = jnp.sum(p**2)
    return jnp.log(1. + zeta_sq)

In [10]:
from cymyc import curvature
_g_FS = vmap(curvature.del_z_bar_del_z, in_axes=(0,None))(p, fubini_study_potential)
_g_FS.shape

(10, 10, 10)

In [11]:
jnp.allclose(g_FS, _g_FS)

Array(True, dtype=bool)
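
The same check can be reproduced without the library. In real coordinates $p = (x; y)$ with $\zeta = x + iy$, the Wirtinger derivatives give $\partial_{\mu}\overline{\partial}_{\bar{\nu}} \mathcal{K} = \frac{1}{4}\left(H_{xx} + H_{yy} + i(H_{xy} - H_{yx})\right)_{\mu\nu}$, where $H$ is the real Hessian of $\mathcal{K}$ split into blocks. The sketch below applies this with `jax.hessian`; `del_delbar_sketch` is a hypothetical name and this is not claimed to be how `curvature.del_z_bar_del_z` is implemented.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def kahler_potential(p):
    # K = log(1 + sum_m |zeta_m|^2), written in real coordinates
    return jnp.log(1.0 + jnp.sum(p ** 2))

def del_delbar_sketch(f, p):
    # Mixed complex Hessian of a real function f from the blocks of its real Hessian
    n = p.shape[-1] // 2
    H = jax.hessian(f)(p)
    H_xx, H_xy = H[:n, :n], H[:n, n:]
    H_yx, H_yy = H[n:, :n], H[n:, n:]
    return 0.25 * (H_xx + H_yy + 1j * (H_xy - H_yx))

def fs_metric_closed_form(p):
    # Closed-form Fubini-Study metric, same index convention as `fubini_study_metric`
    n = p.shape[-1] // 2
    zeta = jax.lax.complex(p[:n], p[n:])
    sigma = 1.0 + jnp.sum(jnp.abs(zeta) ** 2)
    return (sigma * jnp.eye(n) - jnp.outer(jnp.conj(zeta), zeta)) / sigma ** 2

p = jnp.array([0.3, -0.5, 0.2, 0.7])   # zeta = (0.3 + 0.2i, -0.5 + 0.7i)
g_autodiff = del_delbar_sketch(kahler_potential, p)
g_closed = fs_metric_closed_form(p)
print(jnp.allclose(g_autodiff, g_closed))
```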

## Riemann tensor

Measures of curvature corresponding to a given metric tensor involve derivatives of the metric - if a function corresponding to the metric tensor is known, these may be easily computed numerically using autodiff. The most important curvature quantity is the Riemann curvature - the endomorphism-valued two-form that informs us about local curvature effects, $\textsf{Riem} \in \Omega^2(X; \textsf{End}(T_X))$.

Schematically, the curvature tensor is given by taking two derivatives of the metric tensor w.r.t. the input coordinates. $\Gamma$ below refers to the Levi-Civita connection in local coordinates,

$$\textsf{Riem} \sim \partial \Gamma + \Gamma \cdot \Gamma, \quad \Gamma \sim g^{-1} \partial g~.$$
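
For a Kähler metric the schematic formula simplifies: the only nonzero Christoffel symbols are $\Gamma^{\kappa}_{\mu\lambda} = g^{\kappa\bar{\rho}}\partial_{\mu} g_{\lambda\bar{\rho}}$ (and their conjugates), and in one common convention $\textsf{Riem}^{\kappa}_{\;\lambda\mu\bar{\nu}} = -\overline{\partial}_{\bar{\nu}}\Gamma^{\kappa}_{\mu\lambda}$. The sketch below implements this with nested forward-mode autodiff for the Fubini-Study metric. All function names are hypothetical, the sign and index conventions are assumptions of this sketch, and it is not claimed to reproduce `curvature.riemann_tensor_kahler`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def fs_metric(p):
    """Fubini-Study metric g[mu, nu] = g_{mu nubar} at real coords p = (Re z; Im z)."""
    n = p.shape[-1] // 2
    zeta = jax.lax.complex(p[:n], p[n:])
    sigma = 1.0 + jnp.sum(jnp.abs(zeta) ** 2)
    return (sigma * jnp.eye(n) - jnp.outer(jnp.conj(zeta), zeta)) / sigma ** 2

def wirtinger(F, p):
    """Return (dF/dz, dF/dzbar); the derivative index is the last axis."""
    n = p.shape[-1] // 2
    J = jax.jacfwd(F)(p)                       # real Jacobian, last axis of size 2n
    J_x, J_y = J[..., :n], J[..., n:]
    return 0.5 * (J_x - 1j * J_y), 0.5 * (J_x + 1j * J_y)

def christoffel_sketch(p):
    """Gamma[k, m, l] = Gamma^k_{m l} = g^{k rbar} d_m g_{l rbar}."""
    g_inv = jnp.linalg.inv(fs_metric(p))
    dg, _ = wirtinger(fs_metric, p)            # dg[l, r, m] = d_m g_{l rbar}
    return jnp.einsum('rk,lrm->kml', g_inv, dg)

def riemann_sketch(p):
    """riem[k, l, m, n] = Riem^k_{l m nbar} = -dbar_n Gamma^k_{m l}."""
    _, dbar_gamma = wirtinger(christoffel_sketch, p)   # axes [k, m, l, n]
    return -jnp.transpose(dbar_gamma, (0, 2, 1, 3))

p = jnp.array([0.3, -0.5, 0.1, 0.2, 0.7, -0.4])   # a point in a chart of P^3
riem = riemann_sketch(p)
# Trace over the endomorphism indices; for P^n this should equal (n + 1) * g
ricci_from_riem = jnp.trace(riem, axis1=0, axis2=1)
print(jnp.allclose(ricci_from_riem, 4 * fs_metric(p)))
```

Taking the trace over the first two indices gives a quick consistency check, since the result can be compared with the known Einstein constant of the Fubini-Study metric.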

In [12]:
riem = vmap(curvature.riemann_tensor_kahler, in_axes=(0,None))(p, jax.tree_util.Partial(fubini_study_metric))

This involves two derivatives of a potentially expensive function, but is reasonably speedy for even $10^4$ points, as we can test by benchmarking - in this case the function is already `jit`-ed at definition. Note nested `jit`s are equivalent to a single `jit`.

In [13]:
%%timeit
riem = vmap(curvature.riemann_tensor_kahler, in_axes=(0,None))(p, jax.tree_util.Partial(fubini_study_metric)).block_until_ready()
riem.shape

1.5 ms ± 25.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [12]:
rtk = partial(curvature.riemann_tensor_kahler, return_aux=True)

In [13]:
_, riem = vmap(rtk, in_axes=(0,None))(p, jax.tree_util.Partial(fubini_study_metric))

In [14]:
riem = jnp.einsum('...abcd, ...ae->...becd', riem, g_FS)

In [15]:
riem.shape

(10000, 4, 4, 4, 4)

In [12]:
m = 77
a = riem[m][jnp.nonzero(np.round(riem[m],5))]
a

Array([ 4.84335398e-01+1.62630326e-19j, -2.70039201e-03-1.77210357e-03j,
        1.89404177e-04+2.57673254e-03j, ...,
        2.56613641e-02-1.89857814e-02j,  2.78335756e-03+7.84549814e-04j,
        4.69871583e-01-7.80625564e-18j], dtype=complex128)

In [13]:
riem[m].shape

(10, 10, 10, 10)

In [14]:
np.round(riem[m],5).shape

(10, 10, 10, 10)

In [15]:
riem[m][0,1,1,0]

Array(-3.46944695e-18+1.21430643e-17j, dtype=complex128)

In [16]:
riem[m][1,1,1,0]

Array(-0.00270039+0.0017721j, dtype=complex128)

In [17]:
riem[m][0,1,1,1]

Array(4.33680869e-18+4.33680869e-19j, dtype=complex128)

In [18]:
jnp.nonzero(np.round(riem[m],5))

(Array([0, 0, 0, ..., 9, 9, 9], dtype=int64),
 Array([0, 0, 0, ..., 9, 9, 9], dtype=int64),
 Array([0, 0, 0, ..., 9, 9, 9], dtype=int64),
 Array([0, 1, 2, ..., 7, 8, 9], dtype=int64))

In [19]:
n = 8
np.round(a, decimals=n)

Array([ 4.8433540e-01+0.j        , -2.7003900e-03-0.0017721j ,
        1.8940000e-04+0.00257673j, ...,  2.5661360e-02-0.01898578j,
        2.7833600e-03+0.00078455j,  4.6987158e-01-0.j        ],      dtype=complex128)

In [20]:
np.unique(np.round(a, decimals=n), return_counts=True)

(array([-6.9898940e-02-3.303800e-02j, -6.9898940e-02+3.303800e-02j,
        -6.9645600e-02-9.984210e-03j, -6.9645600e-02+9.984210e-03j,
        -5.0326590e-02-2.239172e-02j, -5.0326590e-02+2.239172e-02j,
        -4.9531890e-02-1.134680e-03j, -4.9531890e-02+1.134680e-03j,
        -3.4949470e-02-1.651900e-02j, -3.4949470e-02+1.651900e-02j,
        -3.4822800e-02-4.992110e-03j, -3.4822800e-02+4.992110e-03j,
        -2.5163300e-02-1.119586e-02j, -2.5163300e-02+1.119586e-02j,
        -2.4765950e-02-5.673400e-04j, -2.4765950e-02+5.673400e-04j,
        -1.1899650e-02-1.956415e-02j, -1.1899650e-02+1.956415e-02j,
        -9.2854900e-03-7.705090e-03j, -9.2854900e-03+7.705090e-03j,
        -7.3237200e-03-2.640970e-03j, -7.3237200e-03+2.640970e-03j,
        -5.9498300e-03-9.782070e-03j, -5.9498300e-03+9.782070e-03j,
        -5.8163200e-03-2.863932e-02j, -5.8163200e-03+2.863932e-02j,
        -4.7963300e-03-8.147560e-03j, -4.7963300e-03+8.147560e-03j,
        -4.6427400e-03-3.852540e-03j, -4.6427400

In [21]:
np.unique(np.round(jnp.real(a), decimals=n), return_counts=True)

(array([-6.9898940e-02, -6.9645600e-02, -5.0326590e-02, -4.9531890e-02,
        -3.4949470e-02, -3.4822800e-02, -2.5163300e-02, -2.4765950e-02,
        -1.1899650e-02, -9.2854900e-03, -7.3237200e-03, -5.9498300e-03,
        -5.8163200e-03, -4.7963300e-03, -4.6427400e-03, -3.8579000e-03,
        -3.7299000e-03, -3.6618600e-03, -2.9369100e-03, -2.9081600e-03,
        -2.7003900e-03, -2.6666300e-03, -2.3981600e-03, -2.1570700e-03,
        -1.9289500e-03, -1.8706600e-03, -1.8649500e-03, -1.5533600e-03,
        -1.5116700e-03, -1.4684500e-03, -1.3502000e-03, -1.3333200e-03,
        -1.0785300e-03, -9.3533000e-04, -7.7668000e-04, -7.5584000e-04,
        -2.4924000e-04, -1.2462000e-04,  8.1890000e-05,  9.1470000e-05,
         9.4700000e-05,  1.6378000e-04,  1.8295000e-04,  1.8940000e-04,
         4.6065000e-04,  5.3624000e-04,  6.5644000e-04,  6.7366000e-04,
         9.2131000e-04,  9.5229000e-04,  1.0724800e-03,  1.0881700e-03,
         1.3128700e-03,  1.3473100e-03,  1.3916800e-03,  1.90457

In [22]:
n = 10
len(np.unique(np.round(jnp.real(a), decimals=n), return_counts=True)[0])

110

### First Bianchi identity
We form the Riemann tensor with all indices lowered using the musical isomorphism defined by the metric. The resulting tensor satisfies the following symmetries, as a consequence of the first Bianchi identity,

$$ \textsf{Riem}_{a\overline{b}c\overline{d}} = \textsf{Riem}_{a \overline{d} c \overline{b}} = \textsf{Riem}_{c \overline{b} a \overline{d}} = \textsf{Riem}_{c \overline{d} a \overline{b}}~.$$

In [14]:
riem_lower = jnp.einsum('...ibcd, ...ia->...bacd', riem, g_FS)

In [15]:
jnp.allclose(riem_lower, jnp.einsum('...abcd->...adcb', riem_lower))  # first equality

Array(True, dtype=bool)

In [16]:
jnp.allclose(riem_lower, jnp.einsum('...abcd->...cbad', riem_lower))  # second equality

Array(True, dtype=bool)

In [17]:
jnp.allclose(riem_lower, jnp.einsum('...abcd->...cdab', riem_lower))  # third equality

Array(True, dtype=bool)

## Ricci curvature

Complex projective space is an Einstein manifold, meaning that the Ricci curvature of the Fubini-Study metric on $\mathbb{P}^n$ is proportional to the metric itself. The Ricci curvature is another important measure of curvature derived from $\textsf{Riem}$, which roughly measures the degree of volume distortion relative to Euclidean space as one travels along geodesics emanating from a given point.

$$\textsf{Ric} = \Lambda g~.$$

For $\mathbb{P}^n$ the Einstein constant is $\Lambda = n+1$, so $\textsf{Ric} = (n+1)\, g_{FS}$.

The Ricci curvature is given, in local coordinates, as the trace of the endomorphism part of the Riemann curvature tensor,

$$ \textsf{Ric}_{\mu \bar{\nu}} \triangleq \textsf{Riem}^{\kappa}_{\; \kappa \mu \bar{\nu}} = \textsf{Riem}^{\kappa}_{\; \mu \kappa \bar{\nu}}~.$$
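
On a Kähler manifold the Ricci curvature can also be written as $\textsf{Ric}_{\mu\bar{\nu}} = -\partial_{\mu}\overline{\partial}_{\bar{\nu}} \log\det g$, which avoids forming the full Riemann tensor. The sketch below uses this identity as an independent cross-check for the Fubini-Study metric. The helper names are hypothetical, the sign convention is the one for which $\textsf{Ric} = (n+1)\,g_{FS}$, and this is not claimed to be how `curvature.ricci_tensor_kahler` is implemented.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def fs_metric(p):
    """Fubini-Study metric g[mu, nu] = g_{mu nubar} at real coords p = (Re z; Im z)."""
    n = p.shape[-1] // 2
    zeta = jax.lax.complex(p[:n], p[n:])
    sigma = 1.0 + jnp.sum(jnp.abs(zeta) ** 2)
    return (sigma * jnp.eye(n) - jnp.outer(jnp.conj(zeta), zeta)) / sigma ** 2

def log_det_metric(p):
    # det g is real and positive for a Hermitian positive-definite metric
    _, logabsdet = jnp.linalg.slogdet(fs_metric(p))
    return logabsdet

def ricci_sketch(p):
    # Ric = -del delbar log det g, via the blocks of the real Hessian
    n = p.shape[-1] // 2
    H = jax.hessian(log_det_metric)(p)
    return -0.25 * (H[:n, :n] + H[n:, n:] + 1j * (H[:n, n:] - H[n:, :n]))

p = jnp.array([0.3, -0.5, 0.1, 0.2, 0.7, -0.4])   # a point in a chart of P^3
n = p.shape[-1] // 2
ricci = ricci_sketch(p)
g = fs_metric(p)
print(jnp.allclose(ricci, (n + 1) * g))
```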

In [18]:
ricci = vmap(curvature.ricci_tensor_kahler, in_axes=(0,None))(p, jax.tree_util.Partial(fubini_study_metric))

In [19]:
jnp.allclose(ricci, (ambient_dim + 1) * g_FS)

Array(True, dtype=bool)

This also means that the Ricci scalar, the trace of the Ricci curvature, should be, on $\mathbb{P}^n$:

$$ \textsf{R} = n(n+1)~.$$

In [20]:
jnp.einsum('...ba, ...ab', jnp.linalg.inv(g_FS), ricci)

Array([30.-8.33123781e-16j, 30.+1.21746809e-16j, 30.+1.72612599e-15j, ...,
       30.-7.98853592e-17j, 30.+1.10371070e-16j, 30.+9.64313597e-17j],      dtype=complex128)