In [1]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import struct

import optax

import numpy as np
from flax import struct 
from clu import metrics
from dataclasses import field
from functools import partial
from typing import Any, Callable, Optional, Tuple, Union
from flax import struct  

from clu import metrics

from chex import Array
from flax import linen as nn
from flax.linen.activation import sigmoid, tanh
from flax.linen.dtypes import promote_dtype
from flax.linen.initializers import orthogonal
from flax.linen.linear import default_kernel_init
from jax import numpy as jnp
from jax import random, vmap
from chex import Array
from jax import lax
from jax.nn.initializers import Initializer as Initializer
from jax._src import dtypes

from flax.training import train_state

from rieoptax.geometry.hyperbolic import PoincareBall


@struct.dataclass
class Metrics(metrics.Collection):
  loss: metrics.Average.from_output('loss')

  jax.tree_util.register_keypaths(


### Learning target

In [8]:

# from distributions import NormalPoincareBall


key = jax.random.PRNGKey(0)
import numpy as np

class NormalPoincareBall:
    def __init__(self, dim, c):
        self.dim = dim
        self.c = c
        self.ball = PoincareBall(dim, c)
        self.base_point = jnp.zeros(dim)
    
    def sample(self):
        seed = np.random.randint(1000)
        key = jax.random.PRNGKey(seed)
        key, subkey = jax.random.split(key)
        euclidian_sample = jax.random.normal(key, shape=self.dim)
        poincare_ball_sample = self.ball.exp(self.base_point, euclidian_sample)
        return poincare_ball_sample


def nonlin(x):
    return 4 * x ** 3 +  x ** 2 + 2 * x

class NonLinFnPoincareBall:
    def __init__(self, in_dim=10, out_dim=5, c=-1):
        self.c = c
        self.in_dim, self.out_dim = in_dim, out_dim
        global key
        key, subkey = jax.random.split(key)
        self.theta_vec = jax.random.normal(key, shape=(out_dim, in_dim))
        
        poincare_ball_normal_bias = NormalPoincareBall(dim=(out_dim,), c=-1)
        self.theta_bias = poincare_ball_normal_bias.sample()

    def apply_fn(self, x):
        ball_in = PoincareBall(self.in_dim, self.c)
        Ax = ball_in.mobius_matvec(self.theta_vec, x)
        ball_out = PoincareBall(self.out_dim, self.c)
        Ax_plus_b = ball_out.mobius_add(Ax, self.theta_bias)
        
        output = ball_out.log(jnp.zeros_like(Ax_plus_b), Ax_plus_b)
        output = nonlin(output)
        output = ball_out.exp(jnp.zeros_like(output), output)
        return output
    
real_fn_poincare = NonLinFnPoincareBall(in_dim=10, out_dim=5)

poincare_ball_normal_x = NormalPoincareBall(dim=(10,), c=-1)

sampled_poincare = poincare_ball_normal_x.sample()
sampled_y = real_fn_poincare.apply_fn(sampled_poincare)

jnp.linalg.norm(sampled_y)


Array(0.9999899, dtype=float32)

### Poincare Linear

In [33]:

in_dim, out_dim = 10, 5
c = -1

class PoincareLinear(nn.Module):
    
    param_dtype = jnp.float32
    kernel_init: Callable = default_kernel_init
    bias_init: Callable = nn.initializers.uniform(scale= 1 / jnp.sqrt(in_dim))
        
    in_dim, out_dim = in_dim, out_dim
    c = c
    
    def setup(self):
        self.scalars = self.param(
            "scalars",
            self.kernel_init,
            (self.out_dim, self.in_dim),
            self.param_dtype,
        )
    
        self.bias_poincare = self.param(
            "bias_poincare",
            self.bias_init,
            (self.out_dim,),
            self.param_dtype,
        )

    def __call__(self, x):
        ball_in = PoincareBall(self.in_dim, self.c)
        y = ball_in.mobius_matvec(self.scalars, x)
        
        ball_out = PoincareBall(self.out_dim, self.c)

        return ball_out.mobius_add(y, self.bias_poincare)


### Training procedure

In [2]:
batch_size = 1

import flax
import numpy as np
import optax
    
from flax.training import train_state
class TrainStateRiemannian(train_state.TrainState):
    metrics: Metrics
    
    def apply_gradients(self, *, grads, **kwargs):
        """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.

        Note that internally this function calls `.tx.update()` followed by a call
        to `optax.apply_updates()` to update `params` and `opt_state`.

        Args:
          grads: Gradients that have the same pytree structure as `.params`.
          **kwargs: Additional dataclass attributes that should be `.replace()`-ed.

        Returns:
          An updated instance of `self` with `step` incremented by one, `params`
          and `opt_state` updated by applying `grads`, and additional attributes
          replaced as specified by `kwargs`.
        """
        updates, new_opt_state = self.tx.update(
            grads, self.opt_state, self.params)
        
        
        old_bias_params = self.params['bias_poincare']
        
        ball_bias = PoincareBall(out_dim, c)
        
        bias_poincare_grads = updates['bias_poincare']
        
        r_grad_poincare_bias = ball_bias.egrad_to_rgrad(old_bias_params, bias_poincare_grads)
        
        updates = updates.unfreeze()
        updates['bias_poincare'] = r_grad_poincare_bias
        updates = flax.core.frozen_dict.freeze(updates)
        
        new_params = optax.apply_updates(self.params, updates)
        
        bias_poincare_new = new_params['bias_poincare']
        new_params = new_params.unfreeze()
        lr = self.lr
        tv = lr * r_grad_poincare_bias
        
        new_params['bias_poincare'] = ball_bias.exp(old_bias_params, tv)
        
        new_params = flax.core.frozen_dict.freeze(new_params)
        
        return self.replace(
            step=self.step + 1,
            params=new_params,
            opt_state=new_opt_state,
            **kwargs,
        )



def create_train_state(module, rng, learning_rate):
  """Creates an initial `TrainState`."""
  from flax.training import train_state
  class TrainState(TrainStateRiemannian):
    metrics: Metrics
    lr = learning_rate
  params = module.init(rng, jnp.ones([in_dim]))['params'] # initialize parameters by passing a template image
  tx = optax.sgd(learning_rate)
  return TrainState.create(
      apply_fn=module.apply, params=params, tx=tx,
      metrics=Metrics.empty())


# @jax.jit
def train_step(state, batch_x, batch_y):
  """Train for a single step."""

  def loss_fn(params):
    y = state.apply_fn({'params': params}, batch_x)
    loss = ((y - batch_y) ** 2).sum()
    return loss
  
  grad_fn = jax.grad(loss_fn)
  grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state

@jax.jit
def compute_metrics(*, state, batch_x, batch_y):
    y = state.apply_fn({'params': state.params}, batch_x)
    loss = ((y - batch_y) ** 2).mean()
    metric_updates = state.metrics.single_from_model_output(loss=loss)
    metrics = state.metrics.merge(metric_updates)
    state = state.replace(metrics=metrics)
    return state


In [3]:

poincare_linear = PoincareLinear()
train_state = create_train_state(poincare_linear, jax.random.PRNGKey(43), 1e-02)

n_iter = 10000

for i in range(n_iter):
    
    batch_x = poincare_ball_normal_x.sample()
    batch_y = real_fn_poincare.apply_fn(batch_x)
    
    train_state = train_step(train_state, batch_x, batch_y)
    train_state = compute_metrics(state=train_state, batch_x=batch_x,batch_y=batch_y)
    
    loss = train_state.metrics.compute()['loss']
    if i % 100 == 0:
        print(f'{i} Loss: {loss}')


NameError: name 'PoincareLinear' is not defined

### Poincare MLP

In [4]:

in_dim, out_dim = 10, 5

hidden_dim = 100

dims = [in_dim, hidden_dim, out_dim]
c = -1

class PoincareMLP(nn.Module):
    
    param_dtype = jnp.float32
    kernel_init: Callable = default_kernel_init
    bias_init: Callable = nn.initializers.uniform(scale= 1 / jnp.sqrt(in_dim))
        
    in_dim, out_dim = in_dim, out_dim
    hidden_dim = hidden_dim
    c = c
    act = nn.sigmoid
    
    def setup(self):
        self.scalars_1 = self.param(
            "scalars@1",
            self.kernel_init,
            (self.hidden_dim, self.in_dim),
            self.param_dtype,
        )
    
        self.bias_poincare_1 = self.param(
            "bias_poincare@1",
            self.bias_init,
            (self.hidden_dim,),
            self.param_dtype,
        )
        
        self.scalars_2 = self.param(
            "scalars@2",
            self.kernel_init,
            (self.out_dim, self.hidden_dim),
            self.param_dtype,
        )
    
        self.bias_poincare_2 = self.param(
            "bias_poincare@2",
            self.bias_init,
            (self.out_dim,),
            self.param_dtype,
        )
        
        self.balls = {'bias_poincare@1': PoincareBall(self.hidden_dim, self.c),
                      'bias_poincare@2': PoincareBall(self.out_dim, self.c)}
        

    def __call__(self, x):
        # Linear 1
        ball_in = PoincareBall(self.in_dim, self.c)
        Ax = ball_in.mobius_matvec(self.scalars_1, x)
        ball_hidden = PoincareBall(self.hidden_dim, self.c)
        Ax_b = ball_hidden.mobius_add(Ax, self.bias_poincare_1)
        
        # Activation
        activation_hid = ball_hidden.log(jnp.zeros_like(Ax_b), Ax_b)
        activation_hid = nonlin(activation_hid)
        activation_hid = ball_hidden.exp(jnp.zeros_like(activation_hid), activation_hid)
        
        # Linear 2
        ball_out = PoincareBall(self.out_dim, self.c)
        output = ball_in.mobius_matvec(self.scalars_2, activation_hid)
        output = ball_out.mobius_add(output, self.bias_poincare_2)
        return output


No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [9]:

from flax.training import train_state
class TrainStateRiemannianMLP(train_state.TrainState):
    metrics: Metrics
    
    balls = {'bias_poincare@1': PoincareBall(hidden_dim, c),
                  'bias_poincare@2': PoincareBall(out_dim, c)}
    
    def apply_gradients(self, *, grads, **kwargs):
        """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.

        Note that internally this function calls `.tx.update()` followed by a call
        to `optax.apply_updates()` to update `params` and `opt_state`.

        Args:
          grads: Gradients that have the same pytree structure as `.params`.
          **kwargs: Additional dataclass attributes that should be `.replace()`-ed.

        Returns:
          An updated instance of `self` with `step` incremented by one, `params`
          and `opt_state` updated by applying `grads`, and additional attributes
          replaced as specified by `kwargs`.
        """
        updates, new_opt_state = self.tx.update(
            grads, self.opt_state, self.params)
        
        
        poincare_param_names = self.balls.keys()
        r_grad_poincare_bias, old_bias_params = {}, {}
        for name in poincare_param_names:
            
            old_bias_params[name] = self.params[name]
#             print('old_bias_params', old_bias_params.shape)
#             ball_bias = PoincareBall(out_dim, c)
            bias_poincare_grads = updates[name]
            r_grad_poincare_bias[name] = self.balls[name].egrad_to_rgrad(old_bias_params[name], bias_poincare_grads)
        
        updates = updates.unfreeze()
        for name in poincare_param_names:
            updates[name] = r_grad_poincare_bias[name]
            
        updates = flax.core.frozen_dict.freeze(updates)
        
        new_params = optax.apply_updates(self.params, updates)
        
        lr = self.lr
        new_params = new_params.unfreeze()
        for name in poincare_param_names:
            bias_poincare_new = new_params[name]
            tv = lr * r_grad_poincare_bias[name]
            new_params[name] = self.balls[name].exp(old_bias_params[name], tv)
        
        new_params = flax.core.frozen_dict.freeze(new_params)
        
        return self.replace(
            step=self.step + 1,
            params=new_params,
            opt_state=new_opt_state,
            **kwargs,
        )

In [30]:
loss_safe_grads_coef = 100
@jax.jit
def train_step(state, batch_x, batch_y):
  """Train for a single step."""

  def loss_fn(params):
    y = state.apply_fn({'params': params}, batch_x)
    loss = 1000 * ((y - batch_y) ** 2).sum()
    return loss
  
  grad_fn = jax.grad(loss_fn)
  grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state

def create_train_state(module, rng, learning_rate):
  """Creates an initial `TrainState`."""
  from flax.training import train_state
  class TrainState(TrainStateRiemannianMLP):
    metrics: Metrics
    lr = learning_rate
  params = module.init(rng, jnp.ones([in_dim]))['params'] # initialize parameters by passing a template image
  tx = optax.sgd(learning_rate)
  return TrainState.create(
      apply_fn=module.apply, params=params, tx=tx,
      metrics=Metrics.empty())

In [37]:

poincare_linear = PoincareMLP()
train_state = create_train_state(poincare_linear, jax.random.PRNGKey(42), 1e-03)

n_iter = 10000

for i in range(n_iter):
    
    batch_x = poincare_ball_normal_x.sample()
    batch_y = real_fn_poincare.apply_fn(batch_x)
    
    train_state = train_step(train_state, batch_x, batch_y)
    train_state = compute_metrics(state=train_state, batch_x=batch_x,batch_y=batch_y)
    
    loss = train_state.metrics.compute()['loss']
    if i % 100 == 0:
        print(f'{i} Loss: {loss}')


0 Loss: 0.05742352083325386
100 Loss: 0.11213408410549164
200 Loss: 0.07124282419681549
300 Loss: 0.05154486745595932
400 Loss: 0.040676213800907135
500 Loss: 0.03365926817059517
600 Loss: 0.028710979968309402
700 Loss: 0.025007419288158417
800 Loss: 0.022199496626853943
900 Loss: 0.020025629550218582
1000 Loss: 0.018245309591293335
1100 Loss: 0.016704751178622246
1200 Loss: 0.01546729076653719
1300 Loss: 0.014344530180096626
1400 Loss: 0.013374610804021358
1500 Loss: 0.012529105879366398
1600 Loss: 0.011789258569478989
1700 Loss: 0.011152444407343864
1800 Loss: 0.010583269409835339
1900 Loss: 0.010052936151623726
2000 Loss: 0.009694541804492474
2100 Loss: 0.009276802651584148
2200 Loss: 0.008880427107214928
2300 Loss: 0.008538922294974327
2400 Loss: 0.00823428574949503
2500 Loss: 0.007926433347165585
2600 Loss: 0.007645555306226015
2700 Loss: 0.007375412620604038
2800 Loss: 0.007122365292161703
2900 Loss: 0.006889529526233673


KeyboardInterrupt: 

In [None]:
train_state