In [1]:
import jax.numpy as np
from utils import MidpointNormalize, load_data
from jax import random, flatten_util, vjp, jvp, custom_vjp, jacfwd, jacrev, vmap, grad
from IFD_tsne import tsne_fwd
import jax
import matplotlib.pylab as plt
import seaborn as sns

In [53]:
# Pairwise squared Euclidean distance
def get_dists(Z):
    """
    Compute all pairwise distances of each data points z_i in Z.
    
    Params:
        Z: matrix with n rows. Each row is the z_i
        
    Return:
        All pairwise distances of each data points z_i in Z
    """

    
    diff = Z[:, :, None] - Z[:, :, None].T
    return (np.square(diff)).sum(1)

def perp_fn(i, beta_i, dists, perplexity_function):
    """
    Function that takes an index i, beta_i, and all pairwise distances of X
    and return the perplexity of p_{j|i} (Eq. 1 in the paper).
    """
    exp_dists = np.exp(-dists[i] / beta_i)
    exp_dists.at[i].set(0.)
    p_j_given_i = exp_dists / exp_dists.sum()

    perp_i = perplexity_function(p_j_given_i)

    return perp_i

def binary_search(perp, dists, perplexity_function):
    """
    Let beta_i := 2 \sigma_i^2. This function computes (beta_i) that achieve
    the desired perplexity.
    
    Params:                 
        perp: Desired perplexity value.
        
        dists: Pairwise squared Euclidean distances, stored in an (n x n)-matrix
        
        perplexity_function: A function that return the perplexity number given a probability vector
        
    Returns:
        betas: (n,) array of beta_i's 
    """    
    n = len(dists)
    betas = []
    
    for i in range(n):
        # Binary search
        min_beta, max_beta = 1e-10, 1e10

        for _ in range(1000):
            mid_beta = (min_beta + max_beta) / 2
            p_mid = perp_fn(i, mid_beta, dists, perplexity_function)

            if p_mid >= perp:
                max_beta = mid_beta
            else:
                min_beta = mid_beta

            # Close enough, use the current mid value
            if np.abs(p_mid - perp) < 1e-3:
                break
                
        betas.append(mid_beta)

    return np.array(betas)

def get_perplexity(p):
    """
    Returns the perplexity of p. See https://en.wikipedia.org/wiki/Perplexity
    
    Params:
        p: probability vector
        
    Return:
        A single number---the perplexity of p
    """
    entropy = -np.sum(p * np.log2(p + 1e-10))
    return 2**entropy


def get_beta(perp, dists_X):
    """
    Let beta_i := 2 \sigma_i^2. This function computes (beta_i) that achieve
    the desired perplexity.
    
    Params:                 
        perp: Desired perplexity value.
        
        dists_X: Pairwise squared Euclidean distances between points in X, stored in an (n x n)-matrix
                
    Returns:
        betas: (n,) array of beta_i's 
    """    
    return binary_search(perp, dists_X, get_perplexity)

def logSoftmax(x, i):
    """Compute softmax for vector x."""
    max_x = np.max(x)
    exp_x = np.exp(x - max_x)
    sum_exp_x = np.sum(exp_x)
    log_sum_exp_x = np.log(sum_exp_x)
    max_plus_log_sum_exp_x = max_x + log_sum_exp_x
    log_probs = x - max_plus_log_sum_exp_x

    # Recover probs
    exp_log_probs = np.exp(log_probs)
    exp_log_probs = exp_log_probs.at[i].set(0.)
    sum_log_probs = np.sum(exp_log_probs)
    probs = exp_log_probs / sum_log_probs
    return probs

def get_p_j_given_i(dists_X, perp):
    """
    Compute the conditional probabilities p_{j|i}'s
    
    Params
        dists_X: pairwise-distances matrix of X
        perp: the desired perplexity level (single number)

    Return:
        (n, n) matrix containing p_{j|i}'s
    """
    betas = get_beta(perp, jax.lax.stop_gradient(dists_X))
    
    p_j_given_i = []
    for i, j in enumerate(betas):
        p_j_given_i.append(logSoftmax(-dists_X[i] / (j), i))
    return np.array(p_j_given_i)
    

def x2p(dists_X, perp):
    """
    Compute the joint probabilities p_ij's
    
    Params
        dists_X: pairwise-distances matrix of X
        perp: the desired perplexity level (single number)

    Return:
        (n, n) matrix P containing p_ij's
    """
    n = len(dists_X)
    p_j_given_i = get_p_j_given_i(dists_X, perp)
    p_i_given_j = p_j_given_i.T
    P = (p_j_given_i + p_i_given_j) / (2*n)
    #P = (p_j_given_i + p_i_given_j) / (np.sum(p_i_given_j))
    return P

def y2q(dists_Y):    
    """
    Compute low-dimensional affinities q_ij
    
    Params
        dists_Y: (n, n) matrix containing all pairwise distances of elements of Y

    Return:
        (n, n) matrix Q containing q_ij's
    """
    Q = 1./(1. + dists_Y)
    Q = Q.at[np.diag_indices_from(Q)].set(0.)
    Q /= Q.sum()
    
    return Q

In [37]:
def KL_divergence(X_flat, Y_flat, X_unflattener, Y_unflattener):
    """
    (R^nxp x R^nxp)--> R
    """
    X = X_unflattener(X_flat)
    Y = Y_unflattener(Y_flat)
    learning_rate, perplexity = (200, 30)
    dists_X = get_dists(X)
    P = x2p(dists_X, perplexity)
    P = np.maximum(P, 1e-12)
    print('P', P)
    dists_Y = get_dists(Y)
    Q = y2q(dists_Y)
    Q = np.maximum(Q, 1e-12)
    print('Q', Q)
    return np.sum(P * (np.log(P+1e-10) - np.log(Q+1e-10)))

In [12]:
X, y = load_data(50)
key = random.PRNGKey(41)
#X = onp.array(random.normal(key, shape=(50, 50)))
y_guess = random.normal(key, shape=(X.shape[0], 2))
#Y_star = TSNE(n_components=2, learning_rate=200, init=onp.array(y_guess), perplexity=30).fit_transform(X)
Y_star = tsne_fwd(X, y_guess)

X_flat, X_unflattener = flatten_util.ravel_pytree(X)   # row-wise
Y_flat, Y_unflattener = flatten_util.ravel_pytree(Y_star) 

===> Finding 49 nearest neighbors using Annoy approximate search using euclidean distance...
   --> Time elapsed: 0.01 seconds
===> Calculating affinity matrix...
   --> Time elapsed: 0.00 seconds
===> Running optimization with exaggeration=12.00, lr=200.00 for 250 iterations...
Iteration   50, KL divergence 0.8348, 50 iterations in 0.5900 sec
Iteration  100, KL divergence 0.9573, 50 iterations in 0.4103 sec
Iteration  150, KL divergence 0.8907, 50 iterations in 0.4086 sec
Iteration  200, KL divergence 0.8565, 50 iterations in 0.4094 sec
Iteration  250, KL divergence 1.0399, 50 iterations in 0.3464 sec
   --> Time elapsed: 2.16 seconds
===> Running optimization with exaggeration=1.00, lr=200.00 for 750 iterations...
Iteration   50, KL divergence 0.1908, 50 iterations in 0.2989 sec
Iteration  100, KL divergence 0.1889, 50 iterations in 0.3017 sec
Iteration  150, KL divergence 0.1889, 50 iterations in 0.3016 sec
Iteration  200, KL divergence 0.1889, 50 iterations in 0.2989 sec
Iteration 

In [54]:
KL_divergence(X_flat, Y_flat, X_unflattener, Y_unflattener)

P [[1.0000000e-12 4.5234329e-04 1.2838001e-04 ... 3.0309788e-04
  1.7307920e-04 1.9320578e-04]
 [4.5234329e-04 1.0000000e-12 2.0654105e-04 ... 3.0992090e-04
  3.0359143e-04 2.5453736e-04]
 [1.2838001e-04 2.0654105e-04 1.0000000e-12 ... 4.4234644e-04
  3.3275821e-04 9.3964400e-04]
 ...
 [3.0309788e-04 3.0992090e-04 4.4234644e-04 ... 1.0000000e-12
  2.4623130e-04 4.4399317e-04]
 [1.7307920e-04 3.0359143e-04 3.3275821e-04 ... 2.4623130e-04
  1.0000000e-12 2.5393465e-04]
 [1.9320578e-04 2.5453736e-04 9.3964400e-04 ... 4.4399317e-04
  2.5393465e-04 1.0000000e-12]]
Q [[1.0000000e-12 4.9978946e-05 1.2624923e-04 ... 8.3590996e-05
  2.0909587e-04 1.5988105e-04]
 [4.9978946e-05 1.0000000e-12 1.9751908e-04 ... 5.4623774e-04
  9.2509639e-05 1.9815417e-04]
 [1.2624923e-04 1.9751908e-04 1.0000000e-12 ... 6.0966029e-04
  1.4086519e-04 1.0957340e-03]
 ...
 [8.3590996e-05 5.4623774e-04 6.0966029e-04 ... 1.0000000e-12
  1.4363181e-04 6.1760045e-04]
 [2.0909587e-04 9.2509639e-05 1.4086519e-04 ... 1.43631

Array(0.13831377, dtype=float32)

In [58]:
J_X_Y = jacfwd(jacfwd(KL_divergence, argnums=1), argnums=0)(X_flat, Y_flat, X_unflattener, Y_unflattener)
print('J', J_X_Y)

P Traced<ConcreteArray([[1.0000000e-12 4.5234329e-04 1.2838001e-04 ... 3.0309788e-04
  1.7307920e-04 1.9320578e-04]
 [4.5234329e-04 1.0000000e-12 2.0654105e-04 ... 3.0992090e-04
  3.0359143e-04 2.5453736e-04]
 [1.2838001e-04 2.0654105e-04 1.0000000e-12 ... 4.4234644e-04
  3.3275821e-04 9.3964400e-04]
 ...
 [3.0309788e-04 3.0992090e-04 4.4234644e-04 ... 1.0000000e-12
  2.4623130e-04 4.4399317e-04]
 [1.7307920e-04 3.0359143e-04 3.3275821e-04 ... 2.4623130e-04
  1.0000000e-12 2.5393465e-04]
 [1.9320578e-04 2.5453736e-04 9.3964400e-04 ... 4.4399317e-04
  2.5393465e-04 1.0000000e-12]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([[1.0000000e-12, 4.5234329e-04, 1.2838001e-04, ..., 3.0309788e-04,
        1.7307920e-04, 1.9320578e-04],
       [4.5234329e-04, 1.0000000e-12, 2.0654105e-04, ..., 3.0992090e-04,
        3.0359143e-04, 2.5453736e-04],
       [1.2838001e-04, 2.0654105e-04, 1.0000000e-12, ..., 4.4234644e-04,
        3.3275821e-04, 9.3964400e-04],
       ...,
       

In [55]:
H = jax.hessian(KL_divergence, argnums=1)(X_flat, Y_flat, X_unflattener, Y_unflattener)

P [[1.0000000e-12 4.5234329e-04 1.2838001e-04 ... 3.0309788e-04
  1.7307920e-04 1.9320578e-04]
 [4.5234329e-04 1.0000000e-12 2.0654105e-04 ... 3.0992090e-04
  3.0359143e-04 2.5453736e-04]
 [1.2838001e-04 2.0654105e-04 1.0000000e-12 ... 4.4234644e-04
  3.3275821e-04 9.3964400e-04]
 ...
 [3.0309788e-04 3.0992090e-04 4.4234644e-04 ... 1.0000000e-12
  2.4623130e-04 4.4399317e-04]
 [1.7307920e-04 3.0359143e-04 3.3275821e-04 ... 2.4623130e-04
  1.0000000e-12 2.5393465e-04]
 [1.9320578e-04 2.5453736e-04 9.3964400e-04 ... 4.4399317e-04
  2.5393465e-04 1.0000000e-12]]
Q Traced<ConcreteArray([[1.0000000e-12 4.9978946e-05 1.2624923e-04 ... 8.3590996e-05
  2.0909587e-04 1.5988105e-04]
 [4.9978946e-05 1.0000000e-12 1.9751908e-04 ... 5.4623774e-04
  9.2509639e-05 1.9815417e-04]
 [1.2624923e-04 1.9751908e-04 1.0000000e-12 ... 6.0966029e-04
  1.4086519e-04 1.0957340e-03]
 ...
 [8.3590996e-05 5.4623774e-04 6.0966029e-04 ... 1.0000000e-12
  1.4363181e-04 6.1760045e-04]
 [2.0909587e-04 9.2509639e-05 1.40

In [56]:
print(H)

[[ 7.74874352e-04 -2.03166110e-03  1.07850519e-05 ...  1.16690906e-04
  -3.03857560e-05  2.94965612e-05]
 [-2.03166110e-03  4.26577963e-03 -1.27868989e-04 ... -1.24047612e-04
   4.80741001e-05 -1.00675963e-04]
 [ 1.07850519e-05 -1.27869018e-04  4.81019076e-03 ... -1.06833235e-04
  -5.50453005e-05  4.88687074e-05]
 ...
 [ 1.16690906e-04 -1.24047627e-04 -1.06833220e-04 ...  4.13600635e-03
  -5.66409399e-05 -1.36302348e-04]
 [-3.03857560e-05  4.80740855e-05 -5.50453260e-05 ... -5.66409617e-05
   1.95421763e-02  1.32833445e-03]
 [ 2.94965594e-05 -1.00675941e-04  4.88687147e-05 ... -1.36302348e-04
   1.32833445e-03  1.92178823e-02]]


In [16]:
from simplified_tsne_jax import * 
def KL_divergence(X_flat, Y_flat, X_unflattener, Y_unflattener):
    """
    (R^nxp x R^nxp)--> R
    """
    X = X_unflattener(X_flat)
    Y = Y_unflattener(Y_flat)
    learning_rate, perplexity = (200, 30.0)
    P = x2p(X, tol=1e-5, perplexity=perplexity)
    P = (P + np.transpose(P))
    P = P / np.sum(P)      # Why don't we devide by 2N as described everywhere?
    P = np.maximum(P, 1e-12)
    print('P', P)
    Q, _ = y2q(Y)
    print('Q', Q)
    return np.sum(P * (np.log(P+1e-10) - np.log(Q+1e-10)))

In [17]:
KL_divergence(X_flat, Y_flat, X_unflattener, Y_unflattener)

Computing pairwise distances...
Starting binary search
Entered binary search function
P [[1.0000000e-12 4.1164763e-04 7.6536162e-05 ... 2.5737088e-04
  6.9805843e-05 1.5387476e-04]
 [4.1164763e-04 1.0000000e-12 1.3923405e-04 ... 2.4732167e-04
  1.7294187e-04 2.2261821e-04]
 [7.6536162e-05 1.3923405e-04 1.0000000e-12 ... 4.0122587e-04
  3.2324783e-04 1.1109435e-03]
 ...
 [2.5737088e-04 2.4732167e-04 4.0122587e-04 ... 1.0000000e-12
  1.4182110e-04 4.4223890e-04]
 [6.9805843e-05 1.7294187e-04 3.2324783e-04 ... 1.4182110e-04
  1.0000000e-12 2.1406981e-04]
 [1.5387476e-04 2.2261821e-04 1.1109435e-03 ... 4.4223890e-04
  2.1406981e-04 1.0000000e-12]]
Q [[1.0000000e-12 4.9978946e-05 1.2624923e-04 ... 8.3590996e-05
  2.0909587e-04 1.5988105e-04]
 [4.9978946e-05 1.0000000e-12 1.9751908e-04 ... 5.4623780e-04
  9.2509639e-05 1.9815419e-04]
 [1.2624923e-04 1.9751908e-04 1.0000000e-12 ... 6.0966029e-04
  1.4086517e-04 1.0957340e-03]
 ...
 [8.3590996e-05 5.4623780e-04 6.0966029e-04 ... 1.0000000e-12


Array(0.18891683, dtype=float32)

In [18]:
H = jax.hessian(KL_divergence, argnums=1)(X_flat, Y_flat, X_unflattener, Y_unflattener)

Computing pairwise distances...
Starting binary search
Entered binary search function
P [[1.0000000e-12 4.1164763e-04 7.6536162e-05 ... 2.5737088e-04
  6.9805843e-05 1.5387476e-04]
 [4.1164763e-04 1.0000000e-12 1.3923405e-04 ... 2.4732167e-04
  1.7294187e-04 2.2261821e-04]
 [7.6536162e-05 1.3923405e-04 1.0000000e-12 ... 4.0122587e-04
  3.2324783e-04 1.1109435e-03]
 ...
 [2.5737088e-04 2.4732167e-04 4.0122587e-04 ... 1.0000000e-12
  1.4182110e-04 4.4223890e-04]
 [6.9805843e-05 1.7294187e-04 3.2324783e-04 ... 1.4182110e-04
  1.0000000e-12 2.1406981e-04]
 [1.5387476e-04 2.2261821e-04 1.1109435e-03 ... 4.4223890e-04
  2.1406981e-04 1.0000000e-12]]
Q Traced<ConcreteArray([[1.0000000e-12 4.9978946e-05 1.2624923e-04 ... 8.3590996e-05
  2.0909587e-04 1.5988105e-04]
 [4.9978946e-05 1.0000000e-12 1.9751908e-04 ... 5.4623780e-04
  9.2509639e-05 1.9815419e-04]
 [1.2624923e-04 1.9751908e-04 1.0000000e-12 ... 6.0966029e-04
  1.4086517e-04 1.0957340e-03]
 ...
 [8.3590996e-05 5.4623780e-04 6.0966029e-

In [19]:
print(H)

[[ 1.09934481e-02 -2.08088756e-03  1.34697184e-05 ...  1.44742022e-04
  -1.81146625e-05  3.73517832e-05]
 [-2.08088756e-03  1.52829634e-02 -1.23976934e-04 ... -8.31123907e-05
   5.59293112e-05 -1.10158326e-04]
 [ 1.34697766e-05 -1.23976992e-04  1.28834257e-02 ... -1.04032639e-04
  -5.11443504e-05  6.18489066e-05]
 ...
 [ 1.44742065e-04 -8.31123689e-05 -1.04032697e-04 ...  9.06807743e-03
  -7.40125324e-05 -1.38489588e-04]
 [-1.81146515e-05  5.59293185e-05 -5.11443941e-05 ... -7.40125834e-05
   2.62568612e-02  1.24325382e-03]
 [ 3.73517905e-05 -1.10158318e-04  6.18489139e-05 ... -1.38489588e-04
   1.24325452e-03  2.56244503e-02]]


In [3]:
from sklearn.datasets import make_blobs
X, y = make_blobs(n_samples=50, n_features=127, centers=4, random_state=0, shuffle=False, cluster_std=[0.1, 3, 3, 3])
key = random.PRNGKey(41)
#X = onp.array(random.normal(key, shape=(50, 50)))
y_guess = random.normal(key, shape=(X.shape[0], 2))
#Y_star = TSNE(n_components=2, learning_rate=200, init=onp.array(y_guess), perplexity=30).fit_transform(X)
Y_star = tsne_fwd(X, y_guess)

X_flat, X_unflattener = flatten_util.ravel_pytree(X)   # row-wise
Y_flat, Y_unflattener = flatten_util.ravel_pytree(Y_star) 

===> Finding 49 nearest neighbors using Annoy approximate search using euclidean distance...
   --> Time elapsed: 0.01 seconds
===> Calculating affinity matrix...
   --> Time elapsed: 0.00 seconds
===> Running optimization with exaggeration=12.00, lr=200.00 for 250 iterations...
Iteration   50, KL divergence 1.0322, 50 iterations in 0.5239 sec
Iteration  100, KL divergence 0.9489, 50 iterations in 0.3998 sec
Iteration  150, KL divergence 0.9039, 50 iterations in 0.3732 sec
Iteration  200, KL divergence 0.9917, 50 iterations in 0.3724 sec
Iteration  250, KL divergence 0.9602, 50 iterations in 0.3707 sec
   --> Time elapsed: 2.04 seconds
===> Running optimization with exaggeration=1.00, lr=200.00 for 750 iterations...
Iteration   50, KL divergence 0.0145, 50 iterations in 0.3714 sec
Iteration  100, KL divergence 0.0145, 50 iterations in 0.3672 sec
Iteration  150, KL divergence 0.0145, 50 iterations in 0.3663 sec
Iteration  200, KL divergence 0.0145, 50 iterations in 0.3639 sec
Iteration 

In [4]:
#f = lambda X, Y: KL_divergence(X, Y, X_unflattener, Y_unflattener)
J_X_Y = jacfwd(jacfwd(KL_divergence, argnums=1), argnums=0)(X_flat, Y_flat, X_unflattener, Y_unflattener)
print(J_X_Y)

Computing pairwise distances...
Starting binary search
Entered binary search function
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]


In [6]:
print(jacfwd(KL_divergence, argnums=0)(X_flat, Y_flat, X_unflattener, Y_unflattener))

Computing pairwise distances...
Starting binary search
Entered binary search function
[0. 0. 0. ... 0. 0. 0.]
