In [9]:
import jax.numpy as jnp

a = jnp.array([[1], [2], [3]])
b = jnp.array([[4], [5], [6]])
a = a[:, jnp.newaxis, :]
b = b[jnp.newaxis, :, :]
Z = jnp.sqrt(jnp.sum((a - b) ** 2, axis=2))
print(Z)

[[3. 4. 5.]
 [2. 3. 4.]
 [1. 2. 3.]]


In [10]:
from jax import config

config.update("jax_debug_nans", True)

In [11]:
# def cdist(A):
#     a = A[:,jnp.newaxis,:]
#     b = A[jnp.newaxis,:,:]
#     return jnp.sqrt(jnp.sum((a - b)**2, axis=2))


def calc_y(A, alpha):
    A = cdist(A)
    return A @ alpha


def joint_probabilities(sq_distance, perplexity: int):
    conditional_P = binary_search_perplexity(sq_distance, perplexity)
    P = conditional_P + conditional_P.T
    P /= jnp.sum(P)
    return P


def binary_search_perplexity(sq_distance, perplexity: int):
    PERPLEXITY_TOLERANCE = 1e-5
    n = sq_distance.shape[0]
    # Maximum number of binary search steps
    max_iter = 100
    eps = 1.0e-10
    full_eps = jnp.full(n, eps)
    beta = jnp.full(n, 1.0)
    beta_max = jnp.full(n, jnp.inf)
    beta_min = jnp.full(n, -jnp.inf)
    logPerp = jnp.log(perplexity)
    for _ in range(max_iter):
        conditional_P = jnp.exp(-sq_distance * beta.reshape((n, 1)))
        for i in range(n):
            conditional_P.at[i, i].set(0.0)
        P_sum = jnp.sum(conditional_P, axis=1)
        P_sum = jnp.maximum(P_sum, full_eps)
        conditional_P /= P_sum.reshape((n, 1))
        H = jnp.log(P_sum) + beta * jnp.sum(sq_distance * conditional_P, axis=1)
        H_diff = H - logPerp
        if jnp.abs(H_diff).max() < PERPLEXITY_TOLERANCE:
            break

        # 二分探索
        # beta_min
        pos_flag = jnp.logical_and((H_diff > 0.0), (jnp.abs(H_diff) > eps))
        beta_min.at[pos_flag].set(beta[pos_flag])
        inf_flag = jnp.logical_and(pos_flag, (beta_max == jnp.inf))
        beta.at[inf_flag].multiply(2.0)
        not_inf_flag = jnp.logical_and((H_diff > 0.0), (beta_max != jnp.inf))
        not_inf_flag = jnp.logical_and(jnp.logical_not(inf_flag), not_inf_flag)
        beta.at[not_inf_flag].set((beta[not_inf_flag] + beta_max[not_inf_flag]) / 2.0)
        # beta_max
        neg_flag = jnp.logical_and((H_diff <= 0.0), jnp.abs(H_diff) > eps)
        beta_max.at[neg_flag].set(beta[neg_flag])
        neg_inf_flag = jnp.logical_and(neg_flag, (beta_min == -jnp.inf))
        beta.at[neg_inf_flag].divide(2.0)
        neg_not_inf_flag = jnp.logical_and((H_diff <= 0.0), (beta_min != -jnp.inf))
        neg_not_inf_flag = jnp.logical_and(jnp.logical_not(neg_inf_flag), neg_not_inf_flag)
        beta.at[neg_not_inf_flag].set((beta[neg_not_inf_flag] + beta_min[neg_not_inf_flag]) / 2.0)
    return conditional_P


def calc_probabilities_q(c_data):
    # Student's t-distribution
    q_tmp = 1 / (1 + cdist(c_data))
    n_data = len(c_data)
    for i in range(n_data):
        q_tmp.at[i, i].set(0.0)
    q_sum = jnp.sum(q_tmp)
    q_probs = q_tmp / q_sum
    return q_probs


def calc_loss(p_prob, q_prob):
    p_prob = jnp.maximum(p_prob, 1e-12)
    q_prob = jnp.maximum(q_prob, 1e-12)
    C = p_prob * jnp.log(p_prob / q_prob)
    c = jnp.sum(C)
    return c


def calc_probabilities_p(X):
    sq_distance = cdist(X)
    p_probs = joint_probabilities(sq_distance, 30)
    return p_probs


from jax.scipy.special import kl_div


def cdist(x):
    x1_norm = jnp.sum(x**2, axis=1)[:, jnp.newaxis]
    x1_pad = jnp.ones_like(x1_norm)
    psi = jnp.concatenate((-2 * x, x1_norm, x1_pad), axis=1)
    x2_norm = jnp.sum(x**2, axis=1)[:, jnp.newaxis]
    x2_pad = jnp.ones_like(x2_norm)
    print(x)
    phi = jnp.concatenate((x, x2_pad, x2_norm), axis=1)
    d2 = jnp.dot(psi, phi.T)
    d2 = jnp.sqrt(d2)
    return d2
    # return jnp.sqrt(((x[:, None, :] - x[None, :, :]) ** 2).sum(-1))


#   return jnp.sqrt(jnp.sum((x[:, None] - x[None, :]) ** 2, -1))
def cost(A, alpha):
    p_prob = calc_probabilities_p(A)
    y = calc_y(A, alpha)
    q_prob = calc_probabilities_q(y)
    loss = kl_div(p_prob, q_prob)
    loss = jnp.sum(loss)
    print(loss)
    return loss


from jax import grad, value_and_grad

grad_cdist = grad(cdist)
grad_y = value_and_grad(cost, argnums=(0, 1))
alpha = jnp.array([1, 2, 3, 4, 5, 6], dtype=jnp.float32)
alpha = jnp.reshape(alpha, (3, 2))
print(grad_y(jnp.array([[1], [2], [3]], dtype=jnp.float32), alpha))

Traced<ConcreteArray([[1.]
 [2.]
 [3.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[1.],
       [2.],
       [3.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[3,1])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3,1]), None)
    recipe = LambdaBinding()
Traced<ConcreteArray([[1.]
 [2.]
 [3.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[1.],
       [2.],
       [3.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[3,1])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3,1]), None)
    recipe = LambdaBinding()
Traced<ConcreteArray([[13. 16.]
 [ 6.  8.]
 [ 5.  8.]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[13., 16.],
       [ 6.,  8.],
       [ 5.,  8.]], dtype=float32)
  tangent = Traced<ShapedArray(float32[3,2])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3,2]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x0000016A943A03D0>, in_trac

FloatingPointError: invalid value (nan) encountered in jit(add). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. 

It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. 

If you see this error, consider opening a bug report at https://github.com/jax-ml/jax.