Skip to content

Commit

Permalink
Poincare ball model (geoopt#45)
Browse files Browse the repository at this point in the history
* add base

* add mobius add|sub

* fix

* missing formulas

* remove unused import

* add scalar mul, test props

* unnessesary cons in project

* no cover script functions

* add distance

* fix typo in comment

* add geodesics

* add expmap

* add functions

* add singlt apply

* black

* fix typos in docs

* fix typos in docs

* add parallel transport

* add dist to a plane and parallel transport. Parallel transport is numerically unstable

* fix math bugs

* add cool plots

* fix small things

* add egrad2rgrad

* add reference

* docs

* fix typos

* finish Poincare ball implementation

* fix small typo

* add to inifinite and beyond test

* add signed distance

* infinity and beyond test

* black

* docfix

* fix docs

* fix doc

* fix docs typos

* add import

* add dist0

* optim fails

* fix numerics, do not repare broken test

* black

* some refactoring

* fix typo

* p.data -> p in optim

* update docs a bit

* split pr

* remove torch script (it gave minor improvemets), delay to pytorch/pytorch#14455 resolution

* fix coadd impl

* coma typo in docs

* nan police float32

* nan police! arcsinh

* typo

* nan police scripted!\nwratpping artanh in a script function results in umstable behavior

* tests

* fix typo

* another test for parallel transport 0

* random doc fix to make typechecker happy

* manifold->module migration fix

* black

* fix test for poincare (autocast double)

* add float32 tests

* fix typo

* rename project->clip tangent

* docs

* fix side effect in tests

* infinity anb beyond test was failing in torch==1.0.1 but not in torch_nightly, acceptable tolerance differs

* add dim argument for poincare math

* batched matvec

* typo in dist formula

* fix tracing issues and grad numerics for Arsinh,Artanh

* _max_norm, specify device + dtype

* clamp before save to backward in artanh

* inplace ops in function impl

* black

* fix typo

* fix spelling

* some fixes to docs

* euclidean -> Euclidean

* black

* math font for number

* random travis fail?

* pytorch future reminder
  • Loading branch information
ferrine committed Mar 31, 2019
1 parent 9451655 commit e68796e
Show file tree
Hide file tree
Showing 34 changed files with 2,358 additions and 108 deletions.
88 changes: 44 additions & 44 deletions docs/conf.py
Expand Up @@ -36,34 +36,35 @@
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.coverage',
'sphinx.ext.mathjax',
'sphinx.ext.viewcode',
'sphinx.ext.intersphinx',
'sphinx.ext.napoleon',
"matplotlib.sphinxext.plot_directive",
"sphinx.ext.autodoc",
"sphinx.ext.coverage",
"sphinx.ext.mathjax",
"sphinx.ext.viewcode",
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
source_suffix = ".rst"

# The encoding of source files.
#
# source_encoding = 'utf-8-sig'

# The master toctree document.
master_doc = 'index'
master_doc = "index"

# General information about the project.
project = u'geoopt'
copyright = u'2018, Max Kochurov'
author = u'Max Kochurov'
project = u"geoopt"
copyright = u"2018, Max Kochurov"
author = u"Max Kochurov"

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
Expand Down Expand Up @@ -93,7 +94,7 @@
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

# The reST default role (used for this markup: `text`) to use for all
# documents.
Expand All @@ -115,7 +116,7 @@
# show_authors = False

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
pygments_style = "sphinx"

# A list of ignored prefixes for module index sorting.
# modindex_common_prefix = []
Expand All @@ -132,7 +133,7 @@
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'alabaster'
html_theme = "alabaster"

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
Expand Down Expand Up @@ -166,7 +167,7 @@
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_static_path = ["_static"]

# Add any extra paths that contain custom files (such as robots.txt or
# .htaccess) here, relative to this directory. These files are copied
Expand Down Expand Up @@ -246,34 +247,30 @@
# html_search_scorer = 'scorer.js'

# Output file base name for HTML help builder.
htmlhelp_basename = 'geooptdoc'
htmlhelp_basename = "geooptdoc"

# -- Options for LaTeX output ---------------------------------------------

latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',

# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',

# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',

# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'geoopt.tex', u'geoopt Documentation',
u'Max Kochurov', 'manual'),
(master_doc, "geoopt.tex", u"geoopt Documentation", u"Max Kochurov", "manual")
]

# The name of an image file (relative to this directory) to place at the top of
Expand Down Expand Up @@ -313,10 +310,7 @@

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'geoopt', u'geoopt Documentation',
[author], 1)
]
man_pages = [(master_doc, "geoopt", u"geoopt Documentation", [author], 1)]

# If true, show URL addresses after external links.
#
Expand All @@ -329,9 +323,15 @@
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'geoopt', u'geoopt Documentation',
author, 'geoopt', 'One line description of project.',
'Miscellaneous'),
(
master_doc,
"geoopt",
u"geoopt Documentation",
author,
"geoopt",
"One line description of project.",
"Miscellaneous",
)
]

# Documents to append as an appendix to all manuals.
Expand All @@ -352,7 +352,7 @@

# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'python': ('https://docs.python.org/', None),
'torch': ('https://pytorch.org/docs/master/', None),
"numpy": ("http://docs.scipy.org/doc/numpy/", None),
"python": ("https://docs.python.org/", None),
"torch": ("https://pytorch.org/docs/master/", None),
}
6 changes: 3 additions & 3 deletions docs/devguide.rst
@@ -1,5 +1,5 @@
Extending ``geoopt``
====================
Developer Guide
===============

Base Manifold
-------------
Expand All @@ -9,7 +9,7 @@ The common base class for all manifolds is :class:`geoopt.manifolds.base.Manifol
.. autoclass:: geoopt.manifolds.base.Manifold
:private-members:
:members:

:noindex:

Metaclass
---------
Expand Down
7 changes: 7 additions & 0 deletions docs/extended.rst
@@ -0,0 +1,7 @@
Extended Guide
==============

.. toctree::
:maxdepth: 1

extended/poincare
117 changes: 117 additions & 0 deletions docs/extended/poincare.rst
@@ -0,0 +1,117 @@
Poincare Ball model
===================

Poincare ball model is a compact representation of hyperbolic space.
To have a nice introduction into this model we should start from
simple concepts, putting them all together to build a more complete picture.

Hyperbolic spaces
-----------------

Hyperbolic space is a constant negative curvature Riemannian manifold.
A very simple example of Riemannian manifold with constant, but positive curvature is sphere.

An (N+1)-dimensional hyperboloid spans the manifold that can be embedded into N-dimensional space via projections.

.. figure:: ../plots/extended/poincare/hyperboloid_projection.png
:width: 300

img source `Wikipedia, Hyperboloid Model <https://en.wikipedia.org/wiki/Hyperboloid_model/>`_

Originally, the distance between points on the hyperboloid is defined as

.. math::
d(x, y) = \operatorname{arccosh}(x, y)
It is difficult to work in (N+1)-dimensional space and there is a range of useful embeddings
exist in literature

Klein Model
~~~~~~~~~~~

.. figure:: ../plots/extended/poincare/klein_tiling.png
:width: 300

img source `Wikipedia, Klein Model <https://en.wikipedia.org/wiki/Beltrami-Klein_model/>`_


Poincare Model
~~~~~~~~~~~~~~

.. figure:: ../plots/extended/poincare/poincare_lines.gif
:width: 300

img source `Bulatov, Poincare Model <http://bulatov.org/math/1001/>`_

Here we go.

First of all we note, that Poincare ball is embedded in a Sphere of radius :math:`r=1/\sqrt{c}`,
where c is negative curvature. We also note, as :math:`c` goes to :math:`0`, we recover infinite radius ball.
We should expect this limiting behaviour recovers Euclidean geometry.

To connect Euclidean space with its embedded manifold we need to get :math:`g_x`.
It is done via `conformal factor` :math:`\lambda^c_x`.


.. autofunction:: geoopt.manifolds.poincare.math.lambda_x


:math:`\lambda^c_x` connects Euclidean inner product with Riemannian one

.. autofunction:: geoopt.manifolds.poincare.math.inner
.. autofunction:: geoopt.manifolds.poincare.math.norm
.. autofunction:: geoopt.manifolds.poincare.math.egrad2rgrad

Math
----
The good thing about Poincare ball is that it forms a Gyrogroup. Minimal definition of a Gyrogroup
assumes a binary operation :math:`*` defined that satisfies a set of properties.

Left identity
For every element :math:`a\in G` there exist :math:`e\in G` such that :math:`e * a = a`.
Left Inverse
For every element :math:`a\in G` there exist :math:`b\in G` such that :math:`b * a = e`
Gyroassociativity
For any :math:`a,b,c\in G` there exist :math:`gyr[a, b]c\in G` such that :math:`a * (b * c)=(a * b) * gyr[a, b]c`
Gyroautomorphism
:math:`gyr[a, b]` is a magma automorphism in G
Left loop
:math:`gyr[a, b] = gyr[a * b, b]`

As mentioned above, hyperbolic space forms a Gyrogroup equipped with

.. autofunction:: geoopt.manifolds.poincare.math.mobius_add
.. autofunction:: geoopt.manifolds.poincare.math.gyration

Using this math, it is possible to define another useful operations

.. autofunction:: geoopt.manifolds.poincare.math.mobius_sub
.. autofunction:: geoopt.manifolds.poincare.math.mobius_scalar_mul
.. autofunction:: geoopt.manifolds.poincare.math.mobius_pointwise_mul
.. autofunction:: geoopt.manifolds.poincare.math.mobius_matvec
.. autofunction:: geoopt.manifolds.poincare.math.mobius_fn_apply
.. autofunction:: geoopt.manifolds.poincare.math.mobius_fn_apply_chain

Manifold
--------
Now we are ready to proceed with studying distances, geodesics, exponential maps and more

.. autofunction:: geoopt.manifolds.poincare.math.dist
.. autofunction:: geoopt.manifolds.poincare.math.dist2plane
.. autofunction:: geoopt.manifolds.poincare.math.parallel_transport
.. autofunction:: geoopt.manifolds.poincare.math.geodesic
.. autofunction:: geoopt.manifolds.poincare.math.geodesic_unit
.. autofunction:: geoopt.manifolds.poincare.math.expmap
.. autofunction:: geoopt.manifolds.poincare.math.expmap0
.. autofunction:: geoopt.manifolds.poincare.math.logmap
.. autofunction:: geoopt.manifolds.poincare.math.logmap0


Stability
---------
Numerical stability is a pain in this model. It is strongly recommended to work in ``float64``,
so expect adventures in ``float32`` (but this is not certain).

.. autofunction:: geoopt.manifolds.poincare.math.project
.. autofunction:: geoopt.manifolds.poincare.math.clip_tangent
1 change: 1 addition & 0 deletions docs/index.rst
Expand Up @@ -17,6 +17,7 @@ API
optimizers
tensors
samplers
extended
devguide

Indices and tables
Expand Down
8 changes: 3 additions & 5 deletions docs/manifolds.rst
Expand Up @@ -4,13 +4,11 @@ Manifolds
.. currentmodule:: geoopt.manifolds


All manifolds share same API. In order not to duplicate the same information, the complete public API is provided only for :class:`geoopt.manifolds.Euclidean` in the end of this file.
All manifolds share same API. In order not to duplicate the same information, the complete public API is provided only for :class:`geoopt.manifolds.Manifold` in the end of this file.

.. automodule:: geoopt.manifolds
:members: Stiefel, Sphere, SphereSubspaceComplementIntersection, SphereSubspaceIntersection
:members: Euclidean, Stiefel, Sphere, SphereSubspaceComplementIntersection, SphereSubspaceIntersection, PoincareBall


.. autoclass:: geoopt.manifolds.Euclidean
.. autoclass:: geoopt.manifolds.base.Manifold
:members:
:inherited-members:

27 changes: 27 additions & 0 deletions docs/plots/extended/poincare/distance.py
@@ -0,0 +1,27 @@
import geoopt.manifolds.poincare.math as pmath
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("white")
radius = 1
coords = np.linspace(-radius, radius, 100)
x = torch.tensor([-0.75, 0])
xx, yy = np.meshgrid(coords, coords)
dist2 = xx ** 2 + yy ** 2
mask = dist2 <= radius ** 2
grid = np.stack([xx, yy], axis=-1)
dists = pmath.dist(torch.from_numpy(grid).float(), x)
dists[(~mask).nonzero()] = np.nan
circle = plt.Circle((0, 0), 1, fill=False, color="b")
plt.gca().add_artist(circle)
plt.xlim(-1.1, 1.1)
plt.ylim(-1.1, 1.1)
plt.gca().set_aspect("equal")
plt.contourf(
grid[..., 0], grid[..., 1], dists.log().numpy(), levels=100, cmap="inferno"
)
plt.colorbar()
plt.title("log distance to ($-$0.75, 0)")
plt.show()

0 comments on commit e68796e

Please sign in to comment.