In [21]:
import gzip
import os
import pickle
import sys
from functools import partial
from os import path

import jax
import jax.numpy as np
import numpy as onp
import openTSNE
from jax import custom_vjp, random, vjp, vmap
from jax.lax import cond, scan, stop_gradient
from sklearn.utils import check_random_state

In [22]:
def load_data(n_samples=None, random_state=None):
    """Load the pickled MNIST dictionary and optionally subsample it.

    Parameters
    ----------
    n_samples : int or None
        When given, that many rows are drawn without replacement.
    random_state : None, int or numpy RandomState
        Forwarded to ``sklearn.utils.check_random_state``. ``None`` (the
        default) uses NumPy's global random state, exactly like the previous
        ``onp.random.choice`` call; pass an integer for a reproducible subset.

    Returns
    -------
    (x, y)
        The ``"pca_50"`` and ``"labels"`` entries of the pickle, subsampled
        with the same indices when ``n_samples`` is not None.
    """
    with gzip.open(path.join("examples/data/mnist", "mnist.pkl.gz"), "rb") as f:
        # NOTE: pickle.load can execute arbitrary code; only use trusted files.
        data = pickle.load(f)

    x, y = data["pca_50"], data["labels"]

    if n_samples is not None:
        rng = check_random_state(random_state)
        indices = rng.choice(x.shape[0], n_samples, replace=False)
        x, y = x[indices], y[indices]

    return x, y

In [48]:
@partial(custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, x, y_guess):
    """Run openTSNE on ``x`` and return the optimized embedding.

    ``f`` and ``y_guess`` are accepted for the custom-VJP interface but are
    not used by this forward computation: the embedding is started from
    openTSNE's random initialization with a fixed seed (42).
    """
    # High-dimensional affinities (approximate nearest neighbours via Annoy).
    neighbor_affinities = openTSNE.affinity.PerplexityBasedNN(
        x, perplexity=30.0, method="annoy", random_state=42, verbose=True,
    )

    # Random starting layout in two dimensions.
    start_layout = openTSNE.initialization.random(
        x, n_components=2, random_state=42, verbose=True,
    )

    embedding = openTSNE.TSNEEmbedding(
        start_layout,
        neighbor_affinities,
        learning_rate=200,
        negative_gradient_method="fft",
        random_state=42,
    )
    # Phase 1: exaggerated optimization; phase 2: regular optimization.
    embedding.optimize(250, exaggeration=12, momentum=0.8, inplace=True)
    embedding.optimize(750, momentum=0.5, inplace=True)
    return embedding

def fixed_point_fwd(f, x, y_guess):
    """Forward rule: compute the embedding and save ``(x, y_star)`` as residuals."""
    solution = fixed_point(f, x, y_guess)
    residuals = (x, solution)
    return solution, residuals

def fixed_point_bwd(f, res, y_star_bar):
    """Backward rule for ``fixed_point``.

    Parameters
    ----------
    f : callable
        Non-differentiable argument; called as ``f(x, y_star)``.
    res : tuple
        Residuals ``(x, y_star)`` saved by the forward rule.
    y_star_bar
        Incoming cotangent of the output (not used by this rule).

    Returns
    -------
    tuple
        ``(x_bar, zeros_like(y_star))``: the VJP of ``x -> f(x, y_star)``
        evaluated with cotangent ``1.`` and a zero cotangent for ``y_guess``.
    """
    x, y_star = res
    _, vjp_x = vjp(lambda x_: f(x_, y_star), x)
    x_bar, = vjp_x(1.)
    # jax.numpy is imported as `np` in this notebook; `jnp` was undefined.
    return x_bar, np.zeros_like(y_star)
# Register the forward/backward rules that define the custom VJP of fixed_point.
fixed_point.defvjp(fixed_point_fwd, fixed_point_bwd)

In [None]:
def KL_divergence(x, y):
    """Placeholder for the KL-divergence objective (not implemented yet).

    The original cell contained only the signature, which is a syntax error.
    Raising explicitly keeps the notebook parseable and makes accidental
    calls fail loudly.
    """
    raise NotImplementedError("KL_divergence has not been implemented yet")

In [49]:
# Draw a 400-point subset of the dataset (features in `x`, labels in `label`).
x, label = load_data(n_samples=400)

In [50]:
# Seeded random 2-D starting guess with one row per sample in `x`.
key = random.PRNGKey(42)
y_guess = random.normal(key, shape=(x.shape[0], 2))
# NOTE(review): the call below is commented out, so `y_star` is not defined by
# this cell; the captured output underneath looks like it came from an earlier run.
#y_star = fixed_point(lambda r, t: (r-2)*t, x, y_guess)

===> Finding 90 nearest neighbors using Annoy approximate search using euclidean distance...
   --> Time elapsed: 0.12 seconds
===> Calculating affinity matrix...
   --> Time elapsed: 0.02 seconds


In [47]:
# NOTE(review): `y_star` is only assigned by a commented-out line in the
# previous cell, so this display fails on a fresh kernel.
y_star

TSNEEmbedding([[ 9.69815460e+00,  5.79930041e+00],
               [-1.05424794e+01,  3.76391981e+00],
               [ 9.58967456e+00, -7.17843019e+00],
               [-1.67044756e+01, -9.49710197e+00],
               [-1.82768855e+01,  7.70417316e+00],
               [-5.04795673e+00, -1.73491630e+01],
               [-9.72283509e+00, -1.75503996e+01],
               [ 7.58853717e+00, -6.25629911e+00],
               [ 6.94522796e+00,  1.81062866e+01],
               [ 9.13919778e+00, -2.13848410e+00],
               [ 1.56999103e+00,  5.69943422e+00],
               [ 3.83729918e-01,  2.01975948e+01],
               [ 4.93949640e-01,  8.66656955e+00],
               [ 1.19741488e+01,  4.64422484e+00],
               [ 3.72802212e+00, -1.51540405e+01],
               [ 9.01455380e+00,  4.20599988e+00],
               [-7.23465139e+00, -1.65239464e+01],
               [ 6.39212972e+00, -1.24174902e+01],
               [-9.95663713e+00,  4.24574968e+00],
               [ 1.44630866e+01

In [None]:
@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def fixed_point_layer(solver, f, params, x):
    """Return ``z*`` with ``z* = f(params, x, z*)``, found by ``solver``.

    ``solver(g, z_init=...)`` must return a fixed point of ``g``; the search
    starts from zeros shaped like ``x``.
    """
    def iterate(z):
        return f(params, x, z)

    return solver(iterate, z_init=np.zeros_like(x))


def fixed_point_layer_fwd(solver, f, params, x):
    """Forward rule: solve for ``z*`` and keep ``(params, x, z*)`` as residuals."""
    solution = fixed_point_layer(solver, f, params, x)
    residuals = (params, x, solution)
    return solution, residuals


def fixed_point_layer_bwd(solver, f, res, z_star_bar):
    """Backward rule: solve the adjoint fixed-point problem with ``solver``.

    The adjoint ``u`` satisfies ``u = vjp_z(u) + z_star_bar``; it is then
    pulled back through ``f`` with respect to ``(params, x)``.
    """
    params, x, z_star = res
    _, pull_inputs = jax.vjp(lambda p, inp: f(p, inp, z_star), params, x)
    _, pull_state = jax.vjp(lambda z: f(params, x, z), z_star)

    def adjoint_map(u):
        return pull_state(u)[0] + z_star_bar

    adjoint = solver(adjoint_map, z_init=np.zeros_like(z_star))
    return pull_inputs(adjoint)


fixed_point_layer.defvjp(fixed_point_layer_fwd, fixed_point_layer_bwd)

In [5]:
def load_data(n_samples=None, random_state=None):
    """Load the pickled MNIST dictionary and optionally subsample it.

    Parameters
    ----------
    n_samples : int or None
        When given, that many rows are drawn without replacement.
    random_state : None, int or numpy RandomState
        Forwarded to ``sklearn.utils.check_random_state``. ``None`` (the
        default) uses NumPy's global random state, exactly like the previous
        ``onp.random.choice`` call; pass an integer for a reproducible subset.

    Returns
    -------
    (x, y)
        The ``"pca_50"`` and ``"labels"`` entries of the pickle, subsampled
        with the same indices when ``n_samples`` is not None.
    """
    with gzip.open(path.join("examples/data/mnist", "mnist.pkl.gz"), "rb") as f:
        # NOTE: pickle.load can execute arbitrary code; only use trusted files.
        data = pickle.load(f)

    x, y = data["pca_50"], data["labels"]

    if n_samples is not None:
        rng = check_random_state(random_state)
        indices = rng.choice(x.shape[0], n_samples, replace=False)
        x, y = x[indices], y[indices]

    return x, y

def compute_importances_per_sample(jacobian):
    """Sum absolute Jacobian column totals per sample.

    NOTE: relies on module-level globals ``n`` and ``p``; for sample ``i`` the
    entries at flat positions ``j * n + i`` (``j`` in ``range(p)``) of the
    column-wise absolute sums are added together.

    Returns a Python list with one accumulated value per sample.
    """
    column_totals = np.sum(np.abs(jacobian), axis=0)
    return [
        sum([column_totals[j * n + i] for j in range(p)])
        for i in range(n)
    ]

def HdiffGreaterTrue(*betas):
    """Double ``beta``; expects exactly ``(beta, betamax)``."""
    current, _upper = betas
    return current * 2

def HdiffGreaterFalse(*betas):
    """Return the midpoint of ``beta`` and ``betamax``; expects exactly two values."""
    current, upper = betas
    total = current + upper
    return total / 2

def HdiffSmallerTrue(*betas):
    """Halve ``beta``; expects exactly ``(beta, betamin)``."""
    current, _lower = betas
    return current / 2

def HdiffSmallerFalse(*betas):
    """Return the midpoint of ``beta`` and ``betamin``; expects exactly two values."""
    current, lower = betas
    total = current + lower
    return total / 2

def HdiffGreater(*betas):
    """Raise the lower bound to the current ``beta`` and pick a larger ``beta``.

    Expects ``(beta, betamin, betamax)``. If ``betamax`` is infinite, ``beta``
    is doubled; otherwise it becomes the midpoint of ``beta`` and ``betamax``.
    Returns ``(new_beta, old_beta_as_betamin, betamax)``.
    """
    beta, betamin, betamax = betas
    upper_is_infinite = np.logical_or(betamax == np.inf, betamax == -np.inf)
    new_beta = cond(upper_is_infinite, HdiffGreaterTrue, HdiffGreaterFalse, beta, betamax)
    return new_beta, beta, betamax

def HdiffSmaller(*betas):
    """Lower the upper bound to the current ``beta`` and pick a smaller ``beta``.

    Expects ``(beta, betamin, betamax)``. If ``betamin`` is infinite, ``beta``
    is halved; otherwise it becomes the midpoint of ``beta`` and ``betamin``.
    Returns ``(new_beta, betamin, old_beta_as_betamax)``.
    """
    beta, betamin, betamax = betas
    lower_is_infinite = np.logical_or(betamin == np.inf, betamin == -np.inf)
    new_beta = cond(lower_is_infinite, HdiffSmallerTrue, HdiffSmallerFalse, beta, betamin)
    return new_beta, betamin, beta

def HdiffGreaterTolerance(*betas):
    """One bisection update of ``(beta, betamin, betamax)`` driven by the sign of ``Hdiff``.

    Expects ``(beta, betamin, betamax, Hdiff)``; ``Hdiff`` is passed through
    unchanged as the fourth return value.
    """
    beta, betamin, betamax, Hdiff = betas
    new_beta, new_min, new_max = cond(
        Hdiff > 0, HdiffGreater, HdiffSmaller, beta, betamin, betamax
    )
    return new_beta, new_min, new_max, Hdiff


def pca(X: np.ndarray, no_dims=50):
    """
        Project the NxD array X onto its first ``no_dims`` right singular
        vectors after centering each column (PCA via SVD).
    """
    print("Preprocessing the data using PCA...")
    centered = X - np.mean(X, axis=0)
    # Only the right singular vectors are needed for the projection.
    _, _, vh = np.linalg.svd(centered, full_matrices=False, compute_uv=True, hermitian=False)
    components = vh[0:no_dims, :]
    return np.dot(centered, components.T)

def logSoftmax(x):
    """Return softmax probabilities of ``x`` computed through a stable log-softmax."""
    peak = np.max(x)
    # log-sum-exp with the maximum subtracted for numerical stability
    log_normalizer = peak + np.log(np.sum(np.exp(x - peak)))
    log_probs = x - log_normalizer

    # Map back from log space and renormalize so the entries sum to one.
    unnormalized = np.exp(log_probs)
    return unnormalized / np.sum(unnormalized)


def Hbeta(D: np.ndarray, beta=1.0):
    """
    Evaluate the Gaussian row for precision ``beta``.
    :param D: vector of squared Euclidean distances to all other datapoints (except itself)
    :param beta: precision = beta = 1/sigma**2
    :return: H: ``log(sumP) + beta * sum(D * P) / sumP`` from the unnormalized
        kernel (callers compare it with ``log(perplexity)``),
        P: normalized probabilities obtained via ``logSoftmax(-D * beta)``
    """
    # TODO: exchange by softmax as described by https://nlml.github.io/in-raw-numpy/in-raw-numpy-t-sne/
    logits = -D * beta
    kernel = np.exp(logits)                      # numerator of p j|i
    normalizer = np.sum(kernel, axis=None) + 1e-8   # denominator of p j|i, guarded against zero
    H = np.log(normalizer) + beta * np.sum(D * kernel) / normalizer
    return H, logSoftmax(logits)

def binarySearch(res, el, Di, logU, tol=1e-5):
    """One scan step of the bisection search for the precision ``beta``.

    Parameters
    ----------
    res : tuple
        Carry ``(Hdiff, thisP, beta, betamin, betamax)``.
    el
        Per-step scan input; returned unchanged as the per-step output.
    Di
        Squared distances for the current row, forwarded to ``Hbeta``.
    logU
        Target value that ``H`` is compared against.
    tol : float
        When ``|Hdiff| < tol`` the bounds and ``beta`` are left untouched.
        Defaults to the previously hard-coded ``1e-5``.

    Returns
    -------
    ((Hdiff, thisP, beta, betamin, betamax), el)
        The carry recomputed with the (possibly updated) ``beta``.
    """
    print('Entered binary search function')
    Hdiff, thisP, beta, betamin, betamax = res
    within_tolerance = np.abs(Hdiff) < tol
    beta, betamin, betamax, Hdiff = cond(
        within_tolerance,
        lambda a, b, c, d: (a, b, c, d),   # converged: keep everything as is
        HdiffGreaterTolerance,
        *(beta, betamin, betamax, Hdiff),
    )

    (H, thisP) = Hbeta(Di, beta)
    Hdiff = H - logU
    return (Hdiff, thisP, beta, betamin, betamax), el

def x2p_inner(Di: np.ndarray, iterator, beta, betamin, betamax, perplexity=30, tol=1e-5):
    """
    Search for the precision of row ``Di`` whose H value matches ``log(perplexity)``.
    :param Di: vector of squared Euclidean distances to all other datapoints (except itself)
    :param iterator: position at which a zero is inserted into the returned row
    :param beta: starting precision = 1/sigma**2
    :param betamin: starting lower bound of the search
    :param betamax: starting upper bound of the search
    :param perplexity: target perplexity
    :param tol: accepted for interface compatibility; not used in this function
    :return: probabilities p j|i with a zero inserted at ``iterator``
    """
    logU = np.log(perplexity)
    # Gaussian kernel and H value for the starting precision
    H0, row_P = Hbeta(Di, beta)

    print('Starting binary search')
    step = partial(binarySearch, Di=Di, logU=logU)

    # Always runs 50 bisection steps; there is no early exit on the threshold.
    start = (H0 - logU, row_P, beta, betamin, betamax)
    final, _ = scan(step, init=start, xs=None, length=50)
    row_P = final[1]
    return np.insert(row_P, iterator, 0)

def x2p(X: np.ndarray, tol=1e-5, perplexity=30.0):
    """
        Compute the row-wise conditional P matrix for X by running the
        precision search of ``x2p_inner`` on every row (vectorized with vmap).
    """
    print("Computing pairwise distances...")
    (n, d) = X.shape
    sq_norms = np.sum(np.square(X), 1)
    cross = -2 * np.dot(X, X.T)
    D = np.add(np.add(cross, sq_norms).T, sq_norms)

    # Drop the diagonal (self-distances): flat indices 0, n+1, 2(n+1), ...
    side = D.shape[0]
    diagonal_positions = np.array(list(range(0, side ** 2, side + 1)))
    D = np.reshape(np.delete(D, diagonal_positions), (n, n - 1))

    row_search = partial(x2p_inner, perplexity=perplexity, tol=tol)
    return vmap(row_search)(
        D,
        np.arange(n),
        beta=np.ones(n),                  # precisions (1/sigma**2)
        betamin=np.full(n, -np.inf),
        betamax=np.full(n, np.inf),
    )

def y2q(Y: np.ndarray):
    """Return the normalized Student-t affinity matrix Q of the embedding Y.

    The diagonal is zeroed before normalization and every entry is floored
    at 1e-12.
    """
    sq_norms = np.sum(np.square(Y), 1)
    cross = -2. * np.dot(Y, Y.T)
    sq_dists = np.add(np.add(cross, sq_norms).T, sq_norms)
    kernel = 1. / (1. + sq_dists)
    kernel = kernel.at[np.diag_indices_from(kernel)].set(0.)
    return np.maximum(kernel / np.sum(kernel), 1e-12)

def optimizeY(res, el, P, initial_momentum=0.5, final_momentum=0.8, eta=500, min_gain=0.01, exaggeration=4.):
    """One gradient-descent step of t-SNE, written as a ``scan`` body.

    Parameters
    ----------
    res : tuple
        Carry ``(Y, iY, gains, i)``: embedding, previous update, gains and
        the iteration counter.
    el
        Per-step scan input (unused).
    P
        Exaggerated joint probabilities (already multiplied by ``exaggeration``).
    initial_momentum, final_momentum
        Momentum used while ``i < 250`` and afterwards, respectively.
    eta
        Learning rate.
    min_gain
        Lower clip value for the gains.
    exaggeration
        Factor that is divided out of ``P`` once ``i > 100``.

    Returns
    -------
    ((Y, iY, gains, i + 1), 1.0)

    Notes
    -----
    The previous version divided ``P`` by ``exaggeration`` at ``i == 100`` but
    stored the result in a local that was discarded at the end of the step,
    so early exaggeration was never switched off. ``P`` is not part of the
    carry, so the effective matrix is now derived from ``i`` on every step:
    iterations ``0..100`` use the exaggerated ``P`` and later ones use
    ``P / exaggeration``.
    """
    Y, iY, gains, i = res

    # Effective P for this iteration (see Notes).
    P = np.where(i <= 100, P, P / exaggeration)

    # Compute pairwise affinities
    sum_Y = np.sum(np.square(Y), 1)
    num = -2. * np.dot(Y, Y.T)
    num = 1. / (1. + np.add(np.add(num, sum_Y).T, sum_Y))
    num = num.at[np.diag_indices_from(num)].set(0.)
    Q = num / np.sum(num)
    Q = np.maximum(Q, 1e-12)

    # Compute gradient
    PQ_exp = np.expand_dims(P - Q, 2)                           # NxNx1
    Y_diffs = np.expand_dims(Y, 1) - np.expand_dims(Y, 0)       # NxNx2
    Y_diffs_wt = Y_diffs * np.expand_dims(num, 2)
    grad = np.sum(PQ_exp * Y_diffs_wt, axis=1)                  # Nx2

    # Update Y. The adaptive-gain heuristic is disabled; gains are only clipped.
    momentum = np.where(i < 250, initial_momentum, final_momentum)
    gains = np.clip(gains, min_gain, np.inf)

    iY = momentum * iY - eta * (gains * grad)
    Y = Y + iY
    Y = Y - np.mean(Y, axis=0)
    return ((Y, iY, gains, i + 1), 1.0)

def optimizeYforBackprob(res, el, P, initial_momentum=0.8, final_momentum=0.5, eta=500, min_gain=0.01, exaggeration=12.):
    """One t-SNE update step whose derivative only flows through ``P``.

    Every quantity computed from ``Y`` alone is wrapped in ``stop_gradient``.

    Parameters
    ----------
    res : tuple
        Carry ``(Y, iY, gains, i)``.
    el
        Per-step scan input (unused).
    P
        Exaggerated joint probabilities.
    initial_momentum, final_momentum
        Momentum used while ``i < 20`` and afterwards, respectively.
    eta
        Learning rate.
    min_gain
        Lower clip value for the gains.
    exaggeration
        Factor divided out of ``P`` once ``i > 250``. Defaults to the
        previously hard-coded ``12.``.

    Returns
    -------
    ((Y, iY, gains, i + 1), 1.0)

    Notes
    -----
    The previous version divided ``P`` at ``i == 250`` into a local that was
    then discarded, so the exaggeration was never removed. The effective
    ``P`` is now derived from ``i`` on every step.
    """
    Y, iY, gains, i = res

    # Effective P for this iteration (see Notes).
    P = np.where(i <= 250, P, P / exaggeration)

    # Compute pairwise affinities (constants with respect to differentiation)
    sum_Y = stop_gradient(np.sum(np.square(Y), 1))
    num = stop_gradient(-2. * np.dot(Y, Y.T))
    num = stop_gradient(1. / (1. + np.add(np.add(num, sum_Y).T, sum_Y)))
    num = stop_gradient(num.at[np.diag_indices_from(num)].set(0.))
    Q = stop_gradient(num / np.sum(num))
    Q = stop_gradient(np.maximum(Q, 1e-12))

    # Compute gradient
    PQ_exp = np.expand_dims(P - Q, 2)                                          # NxNx1
    Y_diffs = stop_gradient(np.expand_dims(Y, 1) - np.expand_dims(Y, 0))       # NxNx2
    Y_diffs_wt = stop_gradient(Y_diffs * np.expand_dims(num, 2))
    grad = np.sum(PQ_exp * Y_diffs_wt, axis=1)                                 # Nx2

    # Update Y. The adaptive-gain heuristic is disabled; gains are only clipped.
    momentum = np.where(i < 20, initial_momentum, final_momentum)
    gains = np.clip(gains, min_gain, np.inf)

    iY = momentum * iY - eta * (gains * grad)
    Y = Y + iY
    Y = Y - np.mean(Y, axis=0)
    return ((Y, iY, gains, i + 1), 1.0)

def tsne(X: np.ndarray, no_dims=2, initial_dims=50, perplexity=30.0, learning_rate=500, max_iter = 1000, exaggeration=4., key=42):
    """
        Runs t-SNE on the dataset in the NxD array X to reduce its
        dimensionality to no_dims dimensions.

        :param X: NxD input array
        :param no_dims: embedding dimensionality; must be an integer (not a float)
        :param initial_dims: unused (the PCA preprocessing step is disabled)
        :param perplexity: target perplexity forwarded to ``x2p``
        :param learning_rate: step size ``eta`` of the update
        :param max_iter: number of scan iterations
        :param exaggeration: early-exaggeration factor applied to P
        :param key: integer seed for the random initial embedding
        :return: the embedding Y, or -1 when ``no_dims`` is invalid
    """

    # Check inputs. The first message used to complain about X's type
    # although the check is on no_dims.
    if isinstance(no_dims, float):
        print("Error: number of dimensions should be an integer, not a float.")
        return -1
    if round(no_dims) != no_dims:
        print("Error: number of dimensions should be an integer.")
        return -1

    n, _ = X.shape
    key = random.PRNGKey(key)

    initial_momentum = 0.8
    final_momentum = 0.5
    eta = learning_rate
    min_gain = 0.01

    # Initialize solution, previous update and gains
    Y = random.normal(key, shape=(n, no_dims))
    iY = np.zeros((n, no_dims))
    gains = np.ones((n, no_dims))

    # Compute P-values: symmetrize, normalize to sum 1, exaggerate, floor
    P = x2p(X, 1e-5, perplexity)
    P = P + np.transpose(P)
    P = P / np.sum(P)
    P = P * exaggeration  # early exaggeration
    P = np.maximum(P, 1e-12)

    # Run all iterations inside a single scan
    optimizeY_func = partial(optimizeY, P=P, initial_momentum=initial_momentum, final_momentum=final_momentum, eta=eta, min_gain=min_gain, exaggeration=exaggeration)
    ((Y, iY, gains, i), el) = scan(optimizeY_func, init=(Y, iY, gains, 0), xs=None, length=max_iter)
    return Y

In [None]:
# KL divergence between P and Q; 1e-10 is added inside each log to avoid log(0).
# NOTE(review): `P` and `Q` are not defined at module level in the visible cells.
f = np.sum(P * (np.log(P+1e-10) - np.log(Q+1e-10)))

In [10]:
def anderson_solver(f, z_init, m=5, lam=1e-4, max_iter=50, tol=1e-5, beta=1.0):
    """Find a fixed point of ``f`` with Anderson acceleration.

    Parameters
    ----------
    f : callable
        Map whose fixed point is sought.
    z_init
        Starting point.
    m : int
        Number of past iterates kept (history size; must be at least 2).
    lam : float
        Regularization added to the diagonal of the linear system.
    max_iter : int
        Upper bound on the iteration index; with ``max_iter <= 2`` only the
        two warm-up evaluations are performed.
    tol : float
        Relative residual below which iteration stops.
    beta : float
        Mixing between the ``F`` history (``beta``) and the ``X`` history
        (``1 - beta``).

    Returns
    -------
    The last iterate. Previously ``xk`` was unbound when ``max_iter <= 2``
    (UnboundLocalError); it now falls back to ``f(f(z_init))``.
    """
    x0 = z_init
    x1 = f(x0)
    x2 = f(x1)
    X = np.concatenate([np.stack([x0, x1]), np.zeros((m - 2, *np.shape(x0)))])
    F = np.concatenate([np.stack([x1, x2]), np.zeros((m - 2, *np.shape(x0)))])

    xk = x2  # best estimate available if the loop body never runs
    for k in range(2, max_iter):
        n = min(k, m)
        G = F[:n] - X[:n]
        GTG = np.tensordot(G, G, [list(range(1, G.ndim))] * 2)
        # Bordered system enforcing sum(alpha) == 1, regularized by lam
        H = np.block([[np.zeros((1, 1)), np.ones((1, n))],
                   [ np.ones((n, 1)), GTG]]) + lam * np.eye(n + 1)
        alpha = np.linalg.solve(H, np.zeros(n+1).at[0].set(1))[1:]

        xk = beta * np.dot(alpha, F[:n]) + (1-beta) * np.dot(alpha, X[:n])
        slot = k % m  # ring-buffer position for the newest iterate
        X = X.at[slot].set(xk)
        F = F.at[slot].set(f(xk))

        res = np.linalg.norm(F[slot] - X[slot]) / (1e-5 + np.linalg.norm(F[slot]))
        if res < tol:
            break
    return xk


In [11]:
@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def fixed_point_layer(solver, f, params, x):
    """Solve ``z = f(params, x, z)`` with ``solver``, starting from zeros like ``x``."""
    def layer_map(z):
        return f(params, x, z)

    start = np.zeros_like(x)
    return solver(layer_map, z_init=start)


def fixed_point_layer_fwd(solver, f, params, x):
    """Forward pass of the custom VJP; residuals are ``(params, x, z_star)``."""
    fixed = fixed_point_layer(solver, f, params, x)
    return fixed, (params, x, fixed)


def fixed_point_layer_bwd(solver, f, res, z_star_bar):
    """Backward pass: find ``u = vjp_z(u) + z_star_bar`` and pull it back to ``(params, x)``."""
    params, x, z_star = res

    def f_of_inputs(p, inp):
        return f(p, inp, z_star)

    def f_of_state(z):
        return f(params, x, z)

    _, vjp_inputs = jax.vjp(f_of_inputs, params, x)
    _, vjp_state = jax.vjp(f_of_state, z_star)

    u_star = solver(
        lambda u: vjp_state(u)[0] + z_star_bar,
        z_init=np.zeros_like(z_star),
    )
    return vjp_inputs(u_star)


fixed_point_layer.defvjp(fixed_point_layer_fwd, fixed_point_layer_bwd)

In [13]:
# Demo: differentiate the sum of the fixed point of z = tanh(W z + x) with
# respect to W, using the Anderson solver for both forward and backward solves.
f = lambda W, x, z: np.tanh(np.dot(W, z) + x)
ndim = 10
# W is scaled by 1/sqrt(ndim); inputs are drawn from fixed PRNG keys.
W = random.normal(random.PRNGKey(0), (ndim, ndim)) / np.sqrt(ndim)
x = random.normal(random.PRNGKey(1), (ndim,))
g = jax.grad(lambda W: fixed_point_layer(anderson_solver, f, W, x).sum())(W)
print(g[0])

[ 0.00752617 -0.8125575  -1.1404572  -0.04857638 -0.7125224  -0.55803984
  0.66979057  1.1068186  -0.09697762  0.97838473]


In [7]:
# Same gradient computed with a different solver, for comparison.
# NOTE(review): `newton_solver` is not defined in any visible cell.
g = jax.grad(lambda W: fixed_point_layer(newton_solver, f, W, x).sum())(W)
print(g[0])

[ 0.00752136 -0.8125742  -1.1404786  -0.04860303 -0.7125376  -0.55805624
  0.6697907   1.1068398  -0.09697363  0.9784083 ]


In [None]:
class gradient_descent:
    """Callable batch gradient-descent optimizer with momentum and per-parameter gains.

    The gains array is stored on the instance (``self.gains``) so it persists
    across successive calls.

    NOTE(review): the body references ``Iterable``, ``utils``, ``time``,
    ``kl_divergence_fft`` and ``OptimizationInterrupt``, none of which are
    imported in the visible cells. This notebook also binds ``np`` to
    ``jax.numpy``, while the in-place updates below presume mutable arrays —
    confirm which array module is intended before running this cell.
    """

    def __init__(self):
        # Gains are created lazily on the first call, once the embedding shape is known.
        self.gains = None

    def copy(self):
        """Return a new optimizer of the same class with a copy of the gains (if any)."""
        optimizer = self.__class__()
        if self.gains is not None:
            optimizer.gains = np.copy(self.gains)
        return optimizer

    def __call__(
        self,
        embedding,
        P,
        n_iter,
        objective_function,
        learning_rate=200,
        momentum=0.5,
        exaggeration=None,
        dof=1,
        min_gain=0.01,
        max_grad_norm=None,
        max_step_norm=5,
        theta=0.5,
        n_interpolation_points=3,
        min_num_intervals=50,
        ints_in_interval=1,
        reference_embedding=None,
        n_jobs=1,
        use_callbacks=False,
        callbacks=None,
        callbacks_every_iters=50,
        verbose=False,
    ):
        """Perform batch gradient descent with momentum and gains.

        Parameters
        ----------
        embedding: np.ndarray
            The embedding :math:`Y`.

        P: array_like
            Joint probability matrix :math:`P`.

        n_iter: int
            The number of iterations to run for.

        objective_function: Callable[..., Tuple[float, np.ndarray]]
            A callable that evaluates the error and gradient for the current
            embedding.

        learning_rate: Union[str, float]
            The learning rate for t-SNE optimization. When
            ``learning_rate="auto"`` the appropriate learning rate is selected
            according to max(200, N / 12), as determined in Belkina et al.
            "Automated optimized parameters for t-distributed stochastic
            neighbor embedding improve visualization and analysis of large
            datasets", 2019.

        momentum: float
            Momentum accounts for gradient directions from previous iterations,
            resulting in faster convergence.

        exaggeration: float
            The exaggeration factor is used to increase the attractive forces of
            nearby points, producing more compact clusters.

        dof: float
            Degrees of freedom of the Student's t-distribution.

        min_gain: float
            Minimum individual gain for each parameter.

        max_grad_norm: float
            Maximum gradient norm. If the norm exceeds this value, it will be
            clipped. This is most beneficial when adding points into an existing
            embedding and the new points overlap with the reference points,
            leading to large gradients. This can make points "shoot off" from
            the embedding, causing the interpolation method to compute a very
            large grid, and leads to worse results.

        max_step_norm: float
            Maximum update norm. If the norm exceeds this value, it will be
            clipped. This prevents points from "shooting off" from
            the embedding.

        theta: float
            This is the trade-off parameter between speed and accuracy of the
            tree approximation method. Typical values range from 0.2 to 0.8. The
            value 0 indicates that no approximation is to be made and produces
            exact results also producing longer runtime.

        n_interpolation_points: int
            Only used when ``negative_gradient_method="fft"`` or its other
            aliases. The number of interpolation points to use within each grid
            cell for interpolation based t-SNE. It is highly recommended leaving
            this value at the default 3.

        min_num_intervals: int
            Only used when ``negative_gradient_method="fft"`` or its other
            aliases. The minimum number of grid cells to use, regardless of the
            ``ints_in_interval`` parameter. Higher values provide more accurate
            gradient estimations.

        ints_in_interval: float
            Only used when ``negative_gradient_method="fft"`` or its other
            aliases. Indicates how large a grid cell should be e.g. a value of 3
            indicates a grid side length of 3. Lower values provide more
            accurate gradient estimations.

        reference_embedding: np.ndarray
            If we are adding points to an existing embedding, we have to compute
            the gradients and errors w.r.t. the existing embedding.

        n_jobs: int
            The number of threads to use while running t-SNE. This follows the
            scikit-learn convention, ``-1`` meaning all processors, ``-2``
            meaning all but one, etc.

        use_callbacks: bool

        callbacks: Callable[[int, float, np.ndarray] -> bool]
            Callbacks, which will be run every ``callbacks_every_iters``
            iterations.

        callbacks_every_iters: int
            How many iterations should pass between each time the callbacks are
            invoked.

        Returns
        -------
        float
            The KL divergence of the optimized embedding.
        np.ndarray
            The optimized embedding Y.

        Raises
        ------
        OptimizationInterrupt
            If the provided callback interrupts the optimization, this is raised.

        """
        assert isinstance(embedding, np.ndarray), (
            "`embedding` must be an instance of `np.ndarray`. Got `%s` instead"
            % type(embedding)
        )

        if reference_embedding is not None:
            assert isinstance(reference_embedding, np.ndarray), (
                "`reference_embedding` must be an instance of `np.ndarray`. Got "
                "`%s` instead" % type(reference_embedding)
            )

        # If the interpolation grid has not yet been evaluated, do it now
        if (
            reference_embedding is not None and
            reference_embedding.interp_coeffs is None and
            objective_function is kl_divergence_fft
        ):
            reference_embedding.prepare_interpolation_grid()

        # If we're running transform and using the interpolation scheme, then we
        # should limit the range where new points can go to
        should_limit_range = False
        if reference_embedding is not None:
            if reference_embedding.box_x_lower_bounds is not None:
                should_limit_range = True
                # Both limits are read from the first and last entries of
                # box_x_lower_bounds.
                lower_limit = reference_embedding.box_x_lower_bounds[0]
                upper_limit = reference_embedding.box_x_lower_bounds[-1]

        # `update` holds the previous step (the momentum term); starts at zero.
        update = np.zeros_like(embedding)
        if self.gains is None:
            self.gains = np.ones_like(embedding).view(np.ndarray)

        # Parameter bundles forwarded unchanged to the objective function.
        bh_params = {"theta": theta}
        fft_params = {
            "n_interpolation_points": n_interpolation_points,
            "min_num_intervals": min_num_intervals,
            "ints_in_interval": ints_in_interval,
        }

        # Lie about the P values for bigger attraction forces
        if exaggeration is None:
            exaggeration = 1

        # P is scaled in place here and restored before returning or raising.
        if exaggeration != 1:
            P *= exaggeration

        # Notify the callbacks that the optimization is about to start
        if isinstance(callbacks, Iterable):
            for callback in callbacks:
                # Only call function if present on object
                getattr(callback, "optimization_about_to_start", lambda: ...)()

        # The timer context manager is entered and exited manually around the loop.
        timer = utils.Timer(
            "Running optimization with exaggeration=%.2f, lr=%.2f for %d iterations..." % (
                exaggeration, learning_rate, n_iter
            ),
            verbose=verbose,
        )
        timer.__enter__()

        if verbose:
            start_time = time()

        for iteration in range(n_iter):
            should_call_callback = use_callbacks and (iteration + 1) % callbacks_every_iters == 0
            # Evaluate error on 50 iterations for logging, or when callbacks
            should_eval_error = should_call_callback or \
                (verbose and (iteration + 1) % 50 == 0)

            error, gradient = objective_function(
                embedding,
                P,
                dof=dof,
                bh_params=bh_params,
                fft_params=fft_params,
                reference_embedding=reference_embedding,
                n_jobs=n_jobs,
                should_eval_error=should_eval_error,
            )

            # Clip gradients to avoid points shooting off. This can be an issue
            # when applying transform and points are initialized so that the new
            # points overlap with the reference points, leading to large
            # gradients
            if max_grad_norm is not None:
                norm = np.linalg.norm(gradient, axis=1)
                coeff = max_grad_norm / (norm + 1e-6)
                mask = coeff < 1
                gradient[mask] *= coeff[mask, None]

            # Correct the KL divergence w.r.t. the exaggeration if needed
            if should_eval_error and exaggeration != 1:
                error = error / exaggeration - np.log(exaggeration)

            if should_call_callback:
                # Continue only if all the callbacks say so
                should_stop = any(
                    (bool(c(iteration + 1, error, embedding)) for c in callbacks)
                )
                if should_stop:
                    # Make sure to un-exaggerate P so it's not corrupted in future runs
                    if exaggeration != 1:
                        P /= exaggeration
                    raise OptimizationInterrupt(error=error, final_embedding=embedding)

            # Update the embedding using the gradient
            # Gains grow additively where the sign of the previous update and the
            # gradient differ, and shrink multiplicatively (plus min_gain) otherwise.
            grad_direction_flipped = np.sign(update) != np.sign(gradient)
            grad_direction_same = np.invert(grad_direction_flipped)
            self.gains[grad_direction_flipped] += 0.2
            self.gains[grad_direction_same] = (
                self.gains[grad_direction_same] * 0.8 + min_gain
            )
            update = momentum * update - learning_rate * self.gains * gradient

            # Clip the update sizes
            if max_step_norm is not None:
                update_norms = np.linalg.norm(update, axis=1, keepdims=True)
                mask = update_norms.squeeze() > max_step_norm
                update[mask] /= update_norms[mask]
                update[mask] *= max_step_norm

            embedding += update

            # Zero-mean the embedding only if we're not adding new data points,
            # otherwise this will reset point positions
            if reference_embedding is None:
                embedding -= np.mean(embedding, axis=0)

            # Limit any new points within the circle defined by the interpolation grid
            if should_limit_range:
                if embedding.shape[1] == 1:
                    mask = (embedding < lower_limit) | (embedding > upper_limit)
                    np.clip(embedding, lower_limit, upper_limit, out=embedding)
                elif embedding.shape[1] == 2:
                    r_limit = max(abs(lower_limit), abs(upper_limit))
                    embedding, mask = utils.clip_point_to_disc(embedding, r_limit, inplace=True)

                # Zero out the momentum terms for the points that hit the boundary
                # NOTE(review): this writes to `self.gains` at `~mask`, not to
                # `update` — confirm the mask polarity and target are as intended.
                self.gains[~mask] = 0

            if verbose and (iteration + 1) % 50 == 0:
                stop_time = time()
                print("Iteration %4d, KL divergence %6.4f, 50 iterations in %.4f sec" % (
                    iteration + 1, error, stop_time - start_time))
                start_time = time()

        timer.__exit__()

        # Make sure to un-exaggerate P so it's not corrupted in future runs
        if exaggeration != 1:
            P /= exaggeration

        # The error from the loop is the one for the previous, non-updated
        # embedding. We need to return the error for the actual final embedding, so
        # compute that at the end before returning
        error, _ = objective_function(
            embedding,
            P,
            dof=dof,
            bh_params=bh_params,
            fft_params=fft_params,
            reference_embedding=reference_embedding,
            n_jobs=n_jobs,
            should_eval_error=True,
        )

        return error, embedding

In [9]:
from jax.lax import custom_root

def KL_gradient(Y, theta):
    '''
    Computes the (unscaled) gradient of the KL-divergence D_KL(P(X)||Q(Y))
    with respect to the embedding Y.

    :param Y: Nx2 embedding
    :param theta: tuple ``(momentum, learning_rate, perplexity, X)``; only
        ``perplexity`` and ``X`` are used here
    :return: Nx2 array ``sum_j (P_ij - Q_ij) * num_ij * (y_i - y_j)``

    Fixes relative to the previous version:
    * ``num`` (the unnormalized Student-t kernel) was referenced but never
      defined here, which raised NameError; it is now computed locally.
    * ``P`` from ``x2p`` is row-conditional; it is symmetrized and normalized
      to sum to one (as done in ``tsne``) so it is comparable with the joint Q.
    '''
    momentum, learning_rate, perplexity, X = theta

    # Joint high-dimensional probabilities
    P = x2p(X, tol=1e-5, perplexity=perplexity)
    P = P + np.transpose(P)
    P = np.maximum(P / np.sum(P), 1e-12)

    # Student-t kernel and joint low-dimensional probabilities
    sum_Y = np.sum(np.square(Y), 1)
    num = 1. / (1. + np.add(np.add(-2. * np.dot(Y, Y.T), sum_Y).T, sum_Y))
    num = num.at[np.diag_indices_from(num)].set(0.)
    Q = np.maximum(num / np.sum(num), 1e-12)

    PQ_exp = np.expand_dims(P - Q, 2)                          # NxNx1
    Y_diffs = np.expand_dims(Y, 1) - np.expand_dims(Y, 0)      # NxNx2
    Y_diffs_wt = Y_diffs * np.expand_dims(num, 2)
    return np.sum(PQ_exp * Y_diffs_wt, axis=1)                 # Nx2
    
    
def T(Y, theta):
    """Plain gradient step ``Y - learning_rate * KL_gradient(Y, theta)``.

    ``theta`` must unpack into ``(momentum, learning_rate, perplexity, X)``;
    momentum and gains are not applied here.
    """
    _momentum, step_size, _perplexity, _X = theta
    descent_direction = KL_gradient(Y, theta)
    return Y - step_size * descent_direction

def tsne(init_Y, theta):
    """
        Runs t-SNE on the dataset X contained in ``theta`` and returns the
        2-D embedding.

        :param init_Y: initial embedding of shape (n, 2). It was previously
            ignored; it is now used as the starting point. Pass ``None`` to
            start from a seeded random normal embedding instead.
        :param theta: tuple ``(momentum, learning_rate, perplexity, X)``;
            ``momentum`` is not used here
        :return: the embedding Y, or -1 when the dimensionality check fails
    """
    momentum, learning_rate, perplexity, X = theta
    no_dims = 2
    max_iter = 1000
    exaggeration = 4.
    key = 42

    # Check inputs. The first message used to complain about X's type
    # although the check is on no_dims.
    if isinstance(no_dims, float):
        print("Error: number of dimensions should be an integer, not a float.")
        return -1
    if round(no_dims) != no_dims:
        print("Error: number of dimensions should be an integer.")
        return -1

    n, _ = X.shape

    initial_momentum = 0.8
    final_momentum = 0.5
    eta = learning_rate
    min_gain = 0.01

    # Initialize solution, previous update and gains
    if init_Y is None:
        Y = random.normal(random.PRNGKey(key), shape=(n, no_dims))
    else:
        Y = np.asarray(init_Y)
    iY = np.zeros((n, no_dims))
    gains = np.ones((n, no_dims))

    # Compute P-values: symmetrize, normalize to sum 1, exaggerate, floor
    P = x2p(X, 1e-5, perplexity)
    P = P + np.transpose(P)
    P = P / np.sum(P)
    P = P * exaggeration  # early exaggeration
    P = np.maximum(P, 1e-12)

    # Run all iterations inside a single scan
    optimizeY_func = partial(optimizeY, P=P, initial_momentum=initial_momentum, final_momentum=final_momentum, eta=eta, min_gain=min_gain, exaggeration=exaggeration)
    ((Y, iY, gains, i), el) = scan(optimizeY_func, init=(Y, iY, gains, 0), xs=None, length=max_iter)
    return Y