In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import optax
import e3nn_jax as e3nn

from ott.geometry import costs, grid, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

## 1D grid example

In [None]:
# Radial bin positions. The number of bins is derived from them so the one-hot
# target (sized by ``num_radii``) and the grid geometry can never disagree.
radii = jnp.asarray([1.0, 2.0, 3.0])
num_radii = radii.shape[0]
grid_size = (num_radii,)
# 1D OTT grid geometry whose single axis is ``radii``.
geom = grid.Grid(x=[radii])

def compute_cost(predicted_logits, target_index=1):
    """Entropic OT cost between softmax(predicted_logits) and a one-hot target.

    Builds a linear OT problem on the module-level grid geometry ``geom`` and
    solves it with OTT's Sinkhorn solver.

    Args:
        predicted_logits: unnormalised scores, one per grid point; the last
            axis length must match the number of points in ``geom``.
        target_index: grid index that receives all of the target mass
            (default ``1``, the middle radius in this example).

    Returns:
        ``reg_ot_cost`` of the Sinkhorn solution; differentiated with
        ``jax.value_and_grad`` by the optimisation loop below.
    """
    predicted_histogram = jax.nn.softmax(predicted_logits)
    # Size the target from the input rather than the global ``num_radii`` so the
    # two marginals always have the same length.
    target_histogram = jax.nn.one_hot(target_index, predicted_histogram.shape[-1])
    prob = linear_problem.LinearProblem(geom, a=predicted_histogram, b=target_histogram)
    out = sinkhorn.Sinkhorn()(prob)
    return out.reg_ot_cost

def optimize_predicted_logits(learning_rate=1e-2, num_steps=1000, log_every=100):
    """Fit logits so that softmax(logits) matches the one-hot target under OT.

    Runs Adam on ``compute_cost`` starting from all-ones (uniform) logits.

    Args:
        learning_rate: Adam step size.
        num_steps: number of optimisation steps.
        log_every: print loss and histogram every this many steps;
            a value ``<= 0`` disables logging.

    Returns:
        The optimised logits, a vector of length ``num_radii``.
    """
    tx = optax.adam(learning_rate)
    init_logits = jnp.ones(num_radii)
    opt_state = tx.init(init_logits)

    @jax.jit
    def update_fn(predicted_logits, opt_state):
        # ``loss`` is evaluated at the incoming logits, i.e. before this step's update.
        loss, grads = jax.value_and_grad(compute_cost)(predicted_logits)
        updates, new_opt_state = tx.update(grads, opt_state)
        new_predicted_logits = optax.apply_updates(predicted_logits, updates)
        return new_predicted_logits, new_opt_state, loss

    predicted_logits = init_logits
    for i in range(num_steps):
        new_predicted_logits, opt_state, loss = update_fn(predicted_logits, opt_state)
        if log_every > 0 and i % log_every == 0:
            # Log the histogram that produced ``loss`` so both lines describe the same state.
            print("Loss at step {}: {}".format(i, loss))
            print("Predicted histogram: {}".format(jax.nn.softmax(predicted_logits)))
        predicted_logits = new_predicted_logits

    return predicted_logits

# Optimise, then map the fitted logits onto the probability simplex and show them.
predicted_logits = optimize_predicted_logits()
predicted_histogram = jax.nn.softmax(predicted_logits, axis=-1)
print(predicted_histogram)

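## 2D grid example

> **TODO:** despite the heading, the cell below still uses the same 1D grid (`grid.Grid(x=[radii])` over three radii) as the 1D example. Only the variable names change (`predicted_irreps` instead of `predicted_logits`); it has not yet been extended to a 2D grid.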

In [None]:
# Radial grid for this section. ``num_radii`` is (re)derived here so the cell
# does not depend on a variable left over from an earlier cell.
radii = jnp.asarray([1.0, 2.0, 3.0])
num_radii = radii.shape[0]
grid_size = (num_radii,)
# NOTE(review): still a 1D grid (single axis ``radii``), despite the section title.
geom = grid.Grid(x=[radii])

def compute_cost(predicted_irreps, target_index=1):
    """Entropic OT cost between softmax(predicted_irreps) and a one-hot target.

    NOTE(review): this redefines ``compute_cost`` with the same behaviour as
    the earlier cell (only the argument name differs); consider reusing it.

    Builds a linear OT problem on the module-level grid geometry ``geom`` and
    solves it with OTT's Sinkhorn solver.

    Args:
        predicted_irreps: unnormalised scores, one per grid point; the last
            axis length must match the number of points in ``geom``.
        target_index: grid index that receives all of the target mass
            (default ``1``).

    Returns:
        ``reg_ot_cost`` of the Sinkhorn solution; differentiated with
        ``jax.value_and_grad`` by the optimisation loop below.
    """
    predicted_histogram = jax.nn.softmax(predicted_irreps)
    # Size the target from the input rather than the global ``num_radii`` so the
    # two marginals always have the same length.
    target_histogram = jax.nn.one_hot(target_index, predicted_histogram.shape[-1])
    prob = linear_problem.LinearProblem(geom, a=predicted_histogram, b=target_histogram)
    out = sinkhorn.Sinkhorn()(prob)
    return out.reg_ot_cost


def optimize_predicted_irreps(learning_rate=1e-2, num_steps=1000, log_every=100):
    """Fit scores so that softmax(scores) matches the one-hot target under OT.

    Runs Adam on ``compute_cost`` starting from all-ones (uniform) scores.

    Args:
        learning_rate: Adam step size.
        num_steps: number of optimisation steps.
        log_every: print loss and histogram every this many steps;
            a value ``<= 0`` disables logging.

    Returns:
        The optimised scores, a vector of length ``num_radii``.
    """
    tx = optax.adam(learning_rate)
    init_irreps = jnp.ones(num_radii)
    opt_state = tx.init(init_irreps)

    @jax.jit
    def update_fn(predicted_irreps, opt_state):
        # ``loss`` is evaluated at the incoming scores, i.e. before this step's update.
        loss, grads = jax.value_and_grad(compute_cost)(predicted_irreps)
        updates, new_opt_state = tx.update(grads, opt_state)
        new_predicted_irreps = optax.apply_updates(predicted_irreps, updates)
        return new_predicted_irreps, new_opt_state, loss

    predicted_irreps = init_irreps
    for i in range(num_steps):
        new_predicted_irreps, opt_state, loss = update_fn(predicted_irreps, opt_state)
        if log_every > 0 and i % log_every == 0:
            # Log the histogram that produced ``loss`` so both lines describe the same state.
            print("Loss at step {}: {}".format(i, loss))
            print("Predicted histogram: {}".format(jax.nn.softmax(predicted_irreps)))
        predicted_irreps = new_predicted_irreps

    return predicted_irreps

predicted_irreps = optimize_predicted_irreps()
# Use the result of *this* section's optimisation; the original cell mistakenly
# reused ``predicted_logits`` from the 1D section and printed the wrong histogram.
predicted_histogram = jax.nn.softmax(predicted_irreps)
print(predicted_histogram)