
## Riemannian Manifolds in Depth

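This notebook contains a detailed implementation of Riemannian metric operations (inner products, exponential and logarithm maps, and geodesic distances) using the [geomstats package](https://github.com/geomstats/geomstats/) and PyTorch. The geodesic distance is computed by solving the geodesic (flow) equation, which is built from the Christoffel symbols. These symbols enter through the covariant derivative of a vector $T^{\mu}$ along a curve $x^{\kappa}(\lambda)$: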
$$\nabla_{\lambda} T^{\mu} = \frac{dT^{\mu}}{d \lambda} + \Gamma^{\mu}_{\kappa \nu} T^{\nu} \frac{dx^{\kappa}}{d \lambda} \ldotp$$
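To make the formula concrete, here is a minimal NumPy sketch that is not part of the original notebook. It assumes the flat plane in polar coordinates $(r, \theta)$, whose Christoffel symbols are known in closed form ($\Gamma^{r}_{\theta\theta} = -r$, $\Gamma^{\theta}_{r\theta} = \Gamma^{\theta}_{\theta r} = 1/r$), and approximates the ordinary derivatives by central finite differences; all function names are hypothetical. For the unit circle traversed at unit angular speed, the covariant derivative of its own velocity should be the centripetal acceleration $(-1, 0)$.

```python
import numpy as np


def polar_christoffels(x):
    """Gamma[mu, kappa, nu] = Gamma^mu_{kappa nu} of the flat plane in polar
    coordinates x = (r, theta), metric ds^2 = dr^2 + r^2 dtheta^2."""
    r = x[0]
    gamma = np.zeros((2, 2, 2))
    gamma[0, 1, 1] = -r         # Gamma^r_{theta theta}
    gamma[1, 0, 1] = 1.0 / r    # Gamma^theta_{r theta}
    gamma[1, 1, 0] = 1.0 / r    # Gamma^theta_{theta r}
    return gamma


def covariant_derivative_along_curve(T, x, christoffels, lam, h=1e-5):
    """nabla_lambda T^mu = dT^mu/dlambda + Gamma^mu_{kappa nu} T^nu dx^kappa/dlambda.

    T and x map a parameter value lambda to arrays of shape (dim,); the
    ordinary derivatives are approximated by central finite differences.
    """
    dT = (T(lam + h) - T(lam - h)) / (2 * h)
    dx = (x(lam + h) - x(lam - h)) / (2 * h)
    return dT + np.einsum('mkn,n,k->m', christoffels(x(lam)), T(lam), dx)


def unit_circle(lam):
    """Curve x(lambda) = (r, theta) = (1, lambda)."""
    return np.array([1.0, lam])


def circle_velocity(lam):
    """The curve's own tangent vector dx/dlambda = (0, 1)."""
    return np.array([0.0, 1.0])


acc = covariant_derivative_along_curve(circle_velocity, unit_circle,
                                       polar_christoffels, lam=0.3)
print(acc)  # approximately [-1, 0]: the centripetal acceleration
```

The ordinary derivative $dT^{\mu}/d\lambda$ vanishes here, so the whole result comes from the Christoffel term: in curvilinear coordinates a constant component vector can still have a nonzero covariant derivative.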

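In one dimension, the Christoffel symbol is given in terms of the metric $G$ by

$$\Gamma = \frac{1}{2} G^{-1} \frac{dG}{dx} \ldotp$$

Generalizing this to more dimensions suggests an equation of the form

$$\Gamma^{b}_{ac} = \frac{1}{2} g^{bd} (\partial_{?} g_{??})$$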

where the inverse of the one-component matrix $G$ has been replaced by the inverse of the metric matrix, and, more importantly, the question marks indicate that there is more than one way to place the subscripts so that the result is a grammatical tensor equation. The most general such equation is:

$$\Gamma^{b}_{ac} = \frac{1}{2} g^{bd} (L \partial_{c} g_{ad} + M \partial_{a} g_{cd} + N \partial_{d} g_{ca})$$

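where $L$, $M$, $N$ are constants. Reducing to the one-dimensional case requires $L + M + N = 1$, symmetry of $\Gamma^{b}_{ac}$ in its lower indices requires $L = M$, and metric compatibility ($\nabla_{c} g_{ab} = 0$) fixes $L = M = 1$ and $N = -1$. The resulting general expression for the Christoffel symbol in terms of the metric is: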

$$\Gamma^{c}_{ab} = \frac{1}{2} g^{cd} (\partial_{a} g_{bd} + \partial_{b} g_{ad} - \partial_{d} g_{ab}) \ldotp$$

which is the unique torsion-free (symmetric in the lower indices) solution of the metric-compatibility condition
$$\nabla_{c} g_{ab} = 0 \ldotp$$
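The formula can be checked numerically. The following minimal NumPy sketch (not part of the notebook) computes $\Gamma^{c}_{ab}$ from a metric function with central finite differences and then verifies $\nabla_{c} g_{ab} = \partial_{c} g_{ab} - \Gamma^{d}_{ca} g_{db} - \Gamma^{d}_{cb} g_{ad} \approx 0$. It assumes the flat plane in polar coordinates, $g = \mathrm{diag}(1, r^2)$, for which $\Gamma^{r}_{\theta\theta} = -r$ and $\Gamma^{\theta}_{r\theta} = 1/r$; the function names are hypothetical.

```python
import numpy as np


def polar_metric(x):
    """Flat-plane metric in polar coordinates x = (r, theta): diag(1, r^2)."""
    return np.diag([1.0, x[0] ** 2])


def metric_derivative(metric, x, h=1e-6):
    """dg[a, b, d] = partial_d g_{ab}, by central finite differences."""
    dim = len(x)
    dg = np.zeros((dim, dim, dim))
    for d in range(dim):
        step = np.zeros(dim)
        step[d] = h
        dg[:, :, d] = (metric(x + step) - metric(x - step)) / (2 * h)
    return dg


def christoffels(metric, x):
    """gamma[c, a, b] = 1/2 g^{cd} (d_a g_{bd} + d_b g_{ad} - d_d g_{ab})."""
    g_inv = np.linalg.inv(metric(x))
    dg = metric_derivative(metric, x)
    bracket = (np.einsum('bda->abd', dg)      # d_a g_{bd}
               + np.einsum('adb->abd', dg)    # d_b g_{ad}
               - dg)                          # d_d g_{ab}
    return 0.5 * np.einsum('cd,abd->cab', g_inv, bracket)


def covariant_derivative_of_metric(metric, x):
    """nabla_c g_{ab} = d_c g_{ab} - Gamma^d_{ca} g_{db} - Gamma^d_{cb} g_{ad}."""
    g = metric(x)
    dg = metric_derivative(metric, x)         # dg[a, b, c] = d_c g_{ab}
    gamma = christoffels(metric, x)           # gamma[d, c, a] = Gamma^d_{ca}
    return (np.einsum('abc->cab', dg)
            - np.einsum('dca,db->cab', gamma, g)
            - np.einsum('dcb,ad->cab', gamma, g))


x0 = np.array([2.0, 0.7])                    # r = 2, theta = 0.7
gamma = christoffels(polar_metric, x0)
print(gamma[0, 1, 1], gamma[1, 0, 1])        # approximately -2.0 and 0.5
nabla_g = covariant_derivative_of_metric(polar_metric, x0)
print(np.abs(nabla_g).max())                 # close to 0
```

Because the same finite-difference derivative feeds both $\Gamma$ and $\partial_{c} g_{ab}$, the compatibility check cancels algebraically and only rounding error remains.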

In this notebook, the Christoffel symbols are used to compute the exponential and logarithm maps, inner products, and derivatives. Additional details on Christoffel symbols are given in this [article](https://jmureika.lmu.build/PHYS471/InClass/Christoffel.pdf).
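The notebook's `exp` and `log` methods delegate to geomstats' `Connection`. To make both maps concrete on their own, here is a minimal NumPy sketch under stated assumptions: the manifold is the flat plane in polar coordinates, so every geodesic is a straight line and the results can be checked by hand; the exponential map integrates the geodesic equation $\ddot{x}^{c} = -\Gamma^{c}_{ab} \dot{x}^{a} \dot{x}^{b}$ with classical fourth-order Runge-Kutta (a swapped-in choice; the notebook's methods default to `step='euler'`); and the logarithm map solves the boundary value problem by Newton shooting with a finite-difference Jacobian, not the L-BFGS-B minimization used in the notebook. All function names are hypothetical.

```python
import numpy as np


def polar_metric(x):
    """Flat-plane metric in polar coordinates x = (r, theta): diag(1, r^2)."""
    return np.diag([1.0, x[0] ** 2])


def polar_christoffels(x):
    """gamma[c, a, b] = Gamma^c_{ab} for polar_metric."""
    r = x[0]
    gamma = np.zeros((2, 2, 2))
    gamma[0, 1, 1] = -r                        # Gamma^r_{theta theta}
    gamma[1, 0, 1] = gamma[1, 1, 0] = 1.0 / r  # Gamma^theta_{r theta}
    return gamma


def geodesic_rhs(state, christoffels):
    """First-order geodesic ODE: x' = v, v'^c = -Gamma^c_{ab} v^a v^b."""
    x, v = state
    acc = -np.einsum('cab,a,b->c', christoffels(x), v, v)
    return np.stack([v, acc])


def exp_map(base_point, tangent_vec, christoffels, n_steps=100):
    """Initial value problem: integrate from t = 0 to t = 1 with classical RK4."""
    state = np.stack([base_point, tangent_vec]).astype(float)
    dt = 1.0 / n_steps
    for _ in range(n_steps):
        k1 = geodesic_rhs(state, christoffels)
        k2 = geodesic_rhs(state + 0.5 * dt * k1, christoffels)
        k3 = geodesic_rhs(state + 0.5 * dt * k2, christoffels)
        k4 = geodesic_rhs(state + dt * k3, christoffels)
        state = state + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
    return state[0]                            # position at t = 1


def log_map(point, base_point, christoffels, n_iter=20, h=1e-6, tol=1e-10):
    """Boundary value problem by shooting: Newton's method on exp(base, v) - point."""
    v = (point - base_point).astype(float)     # initial guess: coordinate difference
    dim = len(v)
    for _ in range(n_iter):
        residual = exp_map(base_point, v, christoffels) - point
        if np.linalg.norm(residual) < tol:
            break
        jac = np.zeros((dim, dim))
        for j in range(dim):                    # finite-difference Jacobian column j
            dv = np.zeros(dim)
            dv[j] = h
            jac[:, j] = (exp_map(base_point, v + dv, christoffels)
                         - exp_map(base_point, v - dv, christoffels)) / (2 * h)
        v = v - np.linalg.solve(jac, residual)
    return v


p = np.array([1.0, 0.0])                        # (r, theta)
end = exp_map(p, np.array([0.0, 1.0]), polar_christoffels)
print(end)                                      # approximately [1.41421, 0.78540]

q = np.array([np.sqrt(2.0), np.pi / 4])         # Cartesian point (1, 1)
v = log_map(q, p, polar_christoffels)
dist = np.sqrt(v @ polar_metric(p) @ v)         # metric norm of log_p(q)
print(v, dist)                                  # approximately [0, 1] and 1.0
```

Starting at $(r, \theta) = (1, 0)$ with velocity $(0, 1)$ traces the Cartesian line $(1, t)$, which ends at $(1, 1) = (\sqrt{2}, \pi/4)$; shooting back recovers that velocity, and its metric norm, 1, equals the Euclidean distance between the two points.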


In [None]:
!pip install geomstats

In [None]:
import torch
torch.__version__

'1.9.0+cu102'

In [None]:
"""Implementation of Riemann metric and manifold based on solutions of Covariant derivative equations and Christoffel symbols (adapted from Geomstats implementation)"""
import autograd
import autograd.numpy as np
import joblib
from scipy.optimize import minimize
import torch
import geomstats.errors
import geomstats.backend as gs
import geomstats.geometry as geometry
from geomstats.geometry.connection import Connection
from geomstats.integrator import integrate
EPSILON = 1e-4
N_CENTERS = 10
N_REPETITIONS = 20
N_MAX_ITERATIONS = 50000
N_STEPS = 2



class ManifoldMetric():
    """Class for Riemannian and pseudo-Riemannian metrics.
    """


    def inner_product_inverse_matrix(self, base_point):
        """Inverse of the inner-product (metric) matrix at a base point.

        In this notebook the base point itself is used as the metric matrix,
        so it must be square and invertible.
        """
        metric_matrix = torch.as_tensor(base_point)
        print("Metric Matrix should be invertible-> square matrix")
        cometric_matrix = torch.linalg.inv(metric_matrix)
        return cometric_matrix

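    def entropy_loss(self, base_point):
        """Elementwise entropy -x log x (NaN for negative entries)."""
        return -base_point * torch.log(base_point)

    def inner_product_derivative_matrix(self, base_point):
        """Derivative used for the inner-product matrix at base point.

        Returns the Jacobian of the elementwise entropy -x log x at
        base_point; for an (n, n) input its shape is (n, n, n, n).
        """
        base_point = torch.as_tensor(base_point)
        cal = torch.autograd.functional.jacobian(self.entropy_loss, base_point)
        return cal.detach().clone()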


    def christoffels(self, base_point):
        """Compute Christoffel symbols associated with the connection.

        Gamma^i_{kl} = 1/2 g^{im} (d_l g_{mk} + d_k g_{ml} - d_m g_{kl}),
        expecting metric_derivative[..., m, k, l] = d_l g_{mk}
        (derivative index last, as in geomstats).
        """
        cometric_mat_at_point = self.inner_product_inverse_matrix(base_point)
        metric_derivative_at_point = self.inner_product_derivative_matrix(
            base_point)
        # g^{im} d_l g_{mk}
        term_1 = torch.einsum('...im,...mkl->...ikl',
                              cometric_mat_at_point,
                              metric_derivative_at_point)
        # g^{im} d_k g_{ml}: same contraction with the last two indices swapped
        term_2 = torch.einsum('...im,...mlk->...ikl',
                              cometric_mat_at_point,
                              metric_derivative_at_point)
        # -g^{im} d_m g_{kl}
        term_3 = - torch.einsum('...im,...klm->...ikl',
                                cometric_mat_at_point,
                                metric_derivative_at_point)

        christoffels = 0.5 * (term_1 + term_2 + term_3)
        return christoffels


    def inner_product(self, tangent_vec_a, tangent_vec_b, base_point):
        """Inner product <a, b> = a^T G b of two tangent vectors, where the
        base point itself is used as the metric matrix G."""
        inner_prod_mat = torch.as_tensor(base_point)
        tangent_vec_a = torch.as_tensor(tangent_vec_a)
        tangent_vec_b = torch.as_tensor(tangent_vec_b)
        aux = torch.einsum('...j,...jk->...k', tangent_vec_a, inner_prod_mat)
        inner_prod = torch.einsum('...k,...k->...', aux, tangent_vec_b)
        return inner_prod


    def squared_norm(self, vector, base_point=None):
        """Compute the square of the norm of a vector.

        Squared norm of a vector associated to the inner product
        at the tangent space at a base point.
        """
        sq_norm = self.inner_product(vector, vector, base_point)
        return sq_norm

    
    def norm(self, vector, base_point=None):
        """Compute the norm of a vector at a base point."""
        sq_norm = self.squared_norm(vector, base_point)
        norm = torch.sqrt(sq_norm)
        return norm
    
    def exp(self, tangent_vec, base_point, n_steps=N_STEPS, step='euler',
            point_type=None, **kwargs):
        """Exponential map associated with the affine connection.

        Evaluates the geodesic equation (the initial value problem) at
        base_point with initial velocity tangent_vec, using the Christoffel
        symbols. No time integration is performed here, so n_steps and
        step are accepted but unused.
        """
        base_point = torch.as_tensor(base_point)
        # The position argument is the base point; the velocity is tangent_vec.
        flow = Connection.geodesic_equation(self, base_point, tangent_vec)
        print('Flow Equation Solution:', flow)
        exp = flow[-1][0]
        print("Exponential map in Riemann manifold: ", exp)
        return exp
    
    def log(self, point, base_point, n_steps=N_STEPS, step='euler',
            max_iter=25, verbose=False, tol=None):
        """Compute the logarithm map associated with the affine connection.

        Solves the boundary value problem of the geodesic equation by
        minimizing ||exp(velocity, base_point) - point||^2 over the initial
        velocity with scipy's L-BFGS-B optimizer.
        """
        max_shape = point.shape if point.ndim > base_point.ndim else base_point.shape

        def objective(velocity):
            """Squared residual between the shot endpoint and the target point."""
            # Unwrap autograd's boxed entries so exp can run on torch tensors.
            # This detaches the result from `velocity`: autograd then warns that
            # the output seems independent of the input and returns a zero gradient.
            velocity_li = []
            for i in velocity:
                velocity_li.append(i._value)
            velocity = torch.tensor(velocity_li)
            velocity = torch.reshape(velocity, max_shape)
            delta = self.exp(velocity, base_point, n_steps, step) - point
            return gs.sum(delta ** 2)

        objective_with_grad = autograd.value_and_grad(objective)
        tangent_vec = gs.flatten(gs.random.rand(*max_shape))
        res = minimize(
            objective_with_grad, tangent_vec, method='L-BFGS-B', jac=True,
            options={'disp': verbose, 'maxiter': max_iter}, tol=tol)

        tangent_vec = gs.array(res.x)
        tangent_vec = gs.reshape(tangent_vec, max_shape)
        tangent_vec = gs.cast(tangent_vec, dtype=base_point.dtype)
        print("Logarithmic map in Riemann manifold: ",tangent_vec)
        return tangent_vec
    
    def squared_dist(self, point_a, point_b):
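        """Squared geodesic distance between two points.

        Returns the squared Euclidean norm of the per-row squared norms of
        log_{point_b}(point_a), together with those squared norms.
        """
        point_a = gs.array(point_a)
        point_b = gs.array(point_b)
        log_map = torch.as_tensor(self.log(point_a, point_b))
        # The log map lives in the tangent space at point_b, so measure it there.
        sq_dist = self.squared_norm(vector=log_map, base_point=point_b)
        return torch.norm(sq_dist) ** 2, sq_dist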


    def dist(self, point_a, point_b):
        """Geodesic distance between two points.

        Note: It only works for positive definite
        Riemannian metrics.
        """
        sq_dist, sq_tensor = self.squared_dist(point_a, point_b)
        dist = torch.sqrt(sq_dist)
        return dist, sq_tensor

    def orthonormal_basis(self, basis, base_point):
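        """Normalize the basis vectors with respect to the metric at base_point.

        Only rescales each vector by its metric norm (no Gram-Schmidt step);
        the result is NaN wherever the squared norm is negative.
        """
        norms = self.squared_norm(basis, base_point)
        return torch.einsum('...i,...ikl->...ikl',
                            1. / (torch.sqrt(norms) + EPSILON), basis)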

if __name__=='__main__':
    # Create an object of the ManifoldMetric class
    m=ManifoldMetric()

    # Points of shape (1, 1, 3), used as tangent vectors and base points for the basis and inner-product examples
    pa=torch.tensor([[[-0.65726771, -0.02678122,  0.7531812]]])
    pb=torch.tensor([[[1.02, 0, 0.8]]])
    pc=torch.tensor([[[-0.58831187, -0.02677797,  0.80819062]]])
    print(f"The base points at which the tangents are drawn on the manifold: {pa} -> {pb}")
    print(f"The dimensions for measuring inner product with Einstein sum: {pa.shape} -> {pb.shape}")

    # Square (3, 3) tensors: the base point doubles as the metric matrix, so it must be invertible
    pa_1=torch.tensor([[-0.65726771, -0.02678122,  0.7531812],[-0.65726771, -0.02678122,  0.7531812],[-0.65726771, -0.02678122,  0.7531812]])
    pb_1=torch.tensor([[-0.58831187, -0.02677797,  0.80819062],[-0.58831187, -0.02677797,  0.80819062],[-0.58831187, -0.02677797,  0.80819062]])
    pc_1=torch.tensor([[0.5892, 0.7285, 0.0756],[0.7464, 0.9279, 0.5770],[0.8678, 0.6884, 0.3460]])
    
    print(f"Creating an orthonormal basis w.r.t base point {pa} w.r.t {pb}")
    val=m.orthonormal_basis(pa,pb)
    print("Basis Tensor values: ", val)
    print(f"Inner product of 2 tangents {pa}  & {pb} at base point {pc}")
    inner_product=m.inner_product(pa,pb,pc)
    print("Inner product values: " ,inner_product)
    print(f"Christoffel symbols associated with the affine connection at {pc_1} base point")
    chris_coef=m.christoffels(pc_1)
    print("Christoffel Coefficients: ", chris_coef)
    print(f"Geodesic Distance computation by solving Boundary value problem between points {pa_1} & {pb_1} (points need to be square tensors for inverse to exist) ")
    geodesic_dist,geodesic_tensor=m.dist(pa_1,pb_1)
    print("Geodesic distance: ",geodesic_dist)
    print("Geodesic Tensor: ",geodesic_tensor)
    

The base points at which the tangents are drawn on the manifold: tensor([[[-0.6573, -0.0268,  0.7532]]]) -> tensor([[[1.0200, 0.0000, 0.8000]]])
The dimensions for measuring inner product with Einstein sum: torch.Size([1, 1, 3]) -> torch.Size([1, 1, 3])
Creating an orthonormal basis w.r.t base point tensor([[[-0.6573, -0.0268,  0.7532]]]) w.r.t tensor([[[1.0200, 0.0000, 0.8000]]])
Basis Tensor values:  tensor([[[[nan, nan, nan]]]])
Inner product of 2 tangents tensor([[[-0.6573, -0.0268,  0.7532]]])  & tensor([[[1.0200, 0.0000, 0.8000]]]) at base point tensor([[[-0.5883, -0.0268,  0.8082]]])
Inner product values:  tensor([[0.0032]])
Christoffel symbols associated with the affine connection at tensor([[0.5892, 0.7285, 0.0756],
        [0.7464, 0.9279, 0.5770],
        [0.8678, 0.6884, 0.3460]]) base point
Metric Matrix should be invertible-> square matrix
Christoffel Coefficients:  tensor([[[[ 1.6343e-01,  1.2453e+00,  5.0495e+00],
          [-6.2267e-01,  0.0000e+00,  0.0000e+00],
     

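The `nan` entries in the basis output above come from taking the square root of a negative squared norm: because the notebook uses `base_point` itself as the metric matrix, nothing guarantees that it is positive definite, and for `pa` measured at `pb` the squared norm works out negative. `orthonormal_basis` also only rescales vectors and does not orthogonalize them. For comparison, here is a minimal NumPy sketch of modified Gram-Schmidt with respect to a metric matrix; the metric `G` is a hypothetical symmetric positive-definite example, and the helper names are illustrative, not geomstats API.

```python
import numpy as np


def inner(u, v, metric):
    """<u, v>_G = u^T G v for a metric matrix G."""
    return u @ metric @ v


def gram_schmidt(basis, metric):
    """Orthonormalize the rows of `basis` w.r.t. <., .>_G (G must be SPD)."""
    ortho = []
    for vec in basis:
        w = vec.astype(float)
        for e in ortho:                 # remove components along earlier vectors
            w = w - inner(w, e, metric) * e
        sq_norm = inner(w, w, metric)
        if sq_norm <= 0:                # sqrt would give NaN, as in the output above
            raise ValueError("non-positive squared norm: metric not SPD "
                             "or basis linearly dependent")
        ortho.append(w / np.sqrt(sq_norm))
    return np.array(ortho)


# Hypothetical symmetric positive-definite metric matrix at some base point.
G = np.array([[2.0, 0.5, 0.0],
              [0.5, 1.0, 0.0],
              [0.0, 0.0, 3.0]])
E = gram_schmidt(np.eye(3), G)
print(np.round(E @ G @ E.T, 10))        # approximately the identity matrix
```

Raising on a non-positive squared norm makes the failure explicit instead of propagating NaN, which is the behavior to expect whenever the metric is only pseudo-Riemannian.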
