<a href="https://colab.research.google.com/github/abhilash1910/AI-Geometric-Learning/blob/master/Chapter_2_Understanding_the_data/Intrinsic_Mean.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

$\psi = \sum_{i=1}^n d^2(p,x_{i})$


$\mu = argmin_{p\in M}\sum_{i=1}^n d^2(p,x_{i})$

In [None]:
!pip install sklearn
!pip install geomstats

Collecting geomstats
  Downloading geomstats-2.3.1-py3-none-any.whl (10.1 MB)
[K     |████████████████████████████████| 10.1 MB 6.9 MB/s 
Collecting matplotlib>=3.3.4
  Downloading matplotlib-3.4.3-cp37-cp37m-manylinux1_x86_64.whl (10.3 MB)
[K     |████████████████████████████████| 10.3 MB 11.0 MB/s 
Installing collected packages: matplotlib, geomstats
  Attempting uninstall: matplotlib
    Found existing installation: matplotlib 3.2.2
    Uninstalling matplotlib-3.2.2:
      Successfully uninstalled matplotlib-3.2.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.[0m
Successfully installed geomstats-2.3.1 matplotlib-3.4.3


In [27]:
import torch
import numpy as np
import logging
import math
from sklearn.base import BaseEstimator
import geomstats.backend as gs
import geomstats.errors as error
import geomstats.vectorization
from geomstats.geometry.euclidean import Euclidean, EuclideanMetric
from geomstats.geometry.riemannian_metric import RiemannianMetric


EPSILON = 1e-4

class Fretchet_Mean():
  def variance(self,points, base_point, weights=None, point_type="vector"):
      """Variance of (weighted) points wrt a base point."""

      n_points = geomstats.vectorization.get_n_points(points, point_type)

      if weights is None:
          weights = gs.ones((n_points,))

      sum_weights = gs.sum(weights)
      sq_dists = metric.squared_dist(base_point, points)
      var = weights * sq_dists

      var = gs.sum(var)
      var /= sum_weights

      return var


  def f_mean(self,points, weights=None, point_type="vector"):
      """Compute the weighted linear mean.
      The linear mean is the Frechet mean when points:
      - lie in a Euclidean space with Euclidean metric,
      - lie in a Minkowski space with Minkowski metric.
      """
      if isinstance(points, list):
          points = gs.stack(points, axis=0)
      if isinstance(weights, list):
          weights = gs.array(weights)
      n_points = geomstats.vectorization.get_n_points(points, point_type)
      if weights is None:
          weights = gs.ones((n_points,))
      sum_weights = gs.sum(weights)

      einsum_str = "...,...j->...j"
      if point_type == "matrix":
          einsum_str = "...,...jk->...jk"

      weighted_points = gs.einsum(einsum_str, weights, points)

      mean = gs.sum(weighted_points, axis=0) / sum_weights
      return mean


  def gradient_descent(self,points, metric, weights, max_iter, point_type, epsilon, initial_step_size, verbose):
      """Perform default gradient descent."""
      if point_type == "vector":
          points = gs.to_ndarray(points, to_ndim=2)
          einsum_str = "n,nj->j"
      else:
          points = gs.to_ndarray(points, to_ndim=3)
          einsum_str = "n,nij->ij"
      n_points = gs.shape(points)[0]

      if weights is None:
          weights = gs.ones((n_points,))

      mean = points[0]

      if n_points == 1:
          return mean

      sum_weights = gs.sum(weights)
      sq_dists_between_iterates = []
      iteration = 0
      sq_dist = 0.0
      var = 0.0

      norm_old = gs.linalg.norm(points)
      step = initial_step_size
      #_n_points = gs.array([[2.0, 1.0], [-2.0, -4.0], [-5.0, 1.0]])
      #result = metric.squared_norm(_n_points,None)
      #print(metric,result)
      while iteration < max_iter:
          logs = metric.log(point=points, base_point=mean)
          print(logs)
          logs=gs.array(logs)
          print(type(logs))
          var = gs.sum(gs.sqrt(EuclideanMetric.norm(logs, mean)) * weights) / gs.sum(weights)

          tangent_mean = gs.einsum(einsum_str, weights, logs)
          tangent_mean /= sum_weights
          norm = gs.linalg.norm(tangent_mean)

          sq_dist = gs.sqrt(metric.norm(tangent_mean, mean))
          sq_dists_between_iterates.append(sq_dist)

          var_is_0 = gs.isclose(var, 0.0)
          sq_dist_is_small = gs.less_equal(sq_dist, epsilon * metric.dim)
          condition = ~gs.logical_or(var_is_0, sq_dist_is_small)
          if not (condition or iteration == 0):
              break

          estimate_next = metric.exp(step * tangent_mean, mean)
          mean = estimate_next
          iteration += 1

          if norm < norm_old:
              norm_old = norm
          elif norm > norm_old:
              step = step / 2.0

      if iteration == max_iter:
          logging.warning(
              "Maximum number of iterations {} reached. "
              "The mean may be inaccurate".format(max_iter)
          )

      if verbose:
          logging.info(
              "n_iter: {}, final variance: {}, final dist: {}".format(
                  iteration, var, sq_dist
              )
          )

      return mean

if __name__=='__main__':
  fretchet=Fretchet_Mean()
  points=np.asarray([[-0.58831187, -0.02677797,  0.80819062],[-0.55208236, -0.02669815,  0.83336203]],dtype=np.float32)
  metric=EuclideanMetric
  mean=fretchet.f_mean(points)
  print('Fretchet Mean:',mean)
  weights=None
  max_iter=100
  point_type="vector"
  epsilon=EPSILON
  initial_step_size=5
  verbose=True
  dim=2
  metric=EuclideanMetric(dim)
  gradient_mean=fretchet.gradient_descent(points, metric, weights, max_iter, point_type, epsilon, initial_step_size, verbose)
  print(gradient_mean)

INFO: n_iter: 100, final variance: 1.0000286732700217, final dist: 0.1565888306948323


Fretchet Mean: [-0.57019711 -0.02673806  0.82077634]
[[0.000000e+00 0.000000e+00 0.000000e+00]
 [3.622949e-02 7.981993e-05 2.517140e-02]]
<class 'numpy.ndarray'>
[[-0.09057373 -0.00019955 -0.0629285 ]
 [-0.05434424 -0.00011973 -0.0377571 ]]
<class 'numpy.ndarray'>
[[0.27172118 0.00059865 0.18878549]
 [0.30795068 0.00067847 0.21395689]]
<class 'numpy.ndarray'>
[[-0.45286864 -0.00099775 -0.31464249]
 [-0.41663915 -0.00091793 -0.28947109]]
<class 'numpy.ndarray'>
[[0.09057373 0.00019955 0.0629285 ]
 [0.12680322 0.00027937 0.0880999 ]]
<class 'numpy.ndarray'>
[[2.26434320e-02 4.98874579e-05 1.57321244e-02]
 [5.88729233e-02 1.29707390e-04 4.09035236e-02]]
<class 'numpy.ndarray'>
[[9.90650151e-03 2.18257628e-05 6.88280445e-03]
 [4.61359927e-02 1.01645695e-04 3.20542036e-02]]
<class 'numpy.ndarray'>
[[5.52818165e-03 1.21795551e-05 3.84085070e-03]
 [4.17576729e-02 9.19994877e-05 2.90122498e-02]]
<class 'numpy.ndarray'>
[[3.68107795e-03 8.11006128e-06 2.55752646e-03]
 [3.99105692e-02 8.79299939