In [7]:
import numpy as np 
import jax
import jax.numpy as jnp
from jax import grad, random, vmap
import optax
import math

In [8]:
X = np.load('/Users/ariellerosinski/My Drive/Cambridge/Project/churchland.npy') 
X = jnp.array(X) 
print(X.shape)
K, N, T = X.shape

(108, 218, 61)


In [9]:
X_centered = X - jnp.mean(X, axis=0)                #(K, N, T) - (N, T) = (K, N, T)
A = jnp.swapaxes(X_centered, 0, 1)                  #(N, K, T)
A = A.reshape(N,-1)                                 #(N, K*T)

In [10]:
def K_X_Y(X, Y, sigma_sqrd):
    """For two spatial patterns X and Y, the kernel k(x_i,y_i) is equal to sum_i sigma_i^2 x_i y_i"""
    return jnp.dot(X.T * sigma_sqrd, Y) 

In [11]:
def stack(alpha, sigma_sqrd):
    return jnp.concatenate([alpha.reshape(-1), sigma_sqrd.reshape(-1)]) 

def unstack(params,K=108,T=61, D=3,N=218):
    alpha, sigma_sqrd = jnp.split(params, [K*T*D])
    alpha = alpha.reshape(K*T, D)
    sigma_sqrd.reshape(N,)
    return alpha, sigma_sqrd

In [12]:
def get_alpha(alpha_tilde, A, sigma_sqrd):
    K_A_A = K_X_Y(A, A, sigma_sqrd)
    K_A_A_reshaped = K_A_A.reshape(K,T,K,T)                          #(K,T,K,T)
    means = jnp.mean(K_A_A_reshaped, axis=(0, 2), keepdims=True)     #(1, T, 1, T)
    K_A_A_tilde = (K_A_A_reshaped - means).reshape(K*T,K*T)          #(K*T,K*T)
    P, S, Pt = jnp.linalg.svd(K_A_A_tilde, full_matrices=False)      #P is (K*T, K*T) and S is (K*T,)

    alpha_tilde_QR, _ = jnp.linalg.qr(alpha_tilde) 

    alpha = jnp.dot(P , 1/jnp.sqrt(S))[:,None] * alpha_tilde_QR

    print(f' P: {jnp.isnan(P).any()}' )
    print(f' 1/jnp.sqrt(S): {jnp.isnan(1/jnp.sqrt(S)).any()}' )
    print(f' alpha_tilde_QR: {jnp.isnan(alpha_tilde_QR).any()}' )
    print(f' alpha: {jnp.isnan(alpha).any()}' )
    return alpha

def single_pair_loss(alpha_H, sigma_sqrd, A, X_centered, id_1, id_2):
    K_A_X = K_X_Y(A, X_centered[id_1], sigma_sqrd)
    K_X_A = K_X_Y(X_centered[id_2], A, sigma_sqrd)
    
    Q = alpha_H.T @ K_A_X @ K_X_A @ alpha_H                         #(KT,D).T @ (KT,T) and (T,KT) @ (KT,D) --> (D,T) @ (T,D) --> (D,D)
    QQ_product = jnp.einsum('ij,lm->im', Q, Q)
    S_pair = jnp.trace(Q)**2 - jnp.trace(QQ_product)
    return S_pair 

def loss(params, A, X_centered, key, i,D=3):  
    K, N, T = X_centered.shape
    alpha_tilde, sigma_sqrd = unstack(params)
    alpha = get_alpha(alpha_tilde, A, sigma_sqrd)


    alpha_reshaped = alpha.reshape(K,T,D)                           #(K, T, D)
    mean = jnp.mean(alpha_reshaped, axis=(0), keepdims=True)        #(1, T, D)
    alpha_H = (alpha_reshaped - mean).reshape(K*T,D)                #(K*T,D)

    num_pairs = 10  
    indices = random.randint(key, shape=(num_pairs*2,), minval=0, maxval=N)
    index_pairs = indices.reshape((num_pairs, 2))

    batched_loss = vmap(single_pair_loss, in_axes=(None, None, None, None, 0, 0))(alpha_H, sigma_sqrd, A, X_centered, index_pairs[:, 0], index_pairs[:, 1]) #(num_pairs)

    S = (2 / (K**2) ) * jnp.sum(batched_loss)
    jax.debug.print("-S: {}", -S)
    return -S

def update(params, A, X_centered, optimizer, opt_state, key,i):
    grad_loss = grad(loss)(params, A, X_centered, key,i)
  
    updates, opt_state_updated = optimizer.update(grad_loss, opt_state, params)
    params_updated = optax.apply_updates(params, updates)
    return params_updated, opt_state_updated

def optimize_params(A, X_centered, iterations=100, learning_rate=0.001, D=3, seed=42):
    K, N, T = X_centered.shape
    key = random.PRNGKey(seed)
    
    sigma_sqrd = random.normal(key, (N,))
    alpha_tilde = random.normal(key, (K*T, D))
    
    params = stack(alpha_tilde, sigma_sqrd)
    
    keys = random.split(key, num=iterations)

    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(params)
    
    for i in range(iterations):
        params, opt_state = update(params, A, X_centered,  optimizer, opt_state, keys[i],i)

    return params

optimized_params = optimize_params(A, X_centered)

 P: False
 1/jnp.sqrt(S): False
 alpha_tilde_QR: False
 alpha: False
-S: 217824853229568.0
 P: True
 1/jnp.sqrt(S): True
 alpha_tilde_QR: False
 alpha: True
-S: nan
 P: True
 1/jnp.sqrt(S): True
 alpha_tilde_QR: True
 alpha: True
-S: nan
 P: True
 1/jnp.sqrt(S): True
 alpha_tilde_QR: True
 alpha: True
-S: nan
 P: True
 1/jnp.sqrt(S): True
 alpha_tilde_QR: True
 alpha: True
-S: nan
 P: True
 1/jnp.sqrt(S): True
 alpha_tilde_QR: True
 alpha: True
-S: nan
 P: True
 1/jnp.sqrt(S): True
 alpha_tilde_QR: True
 alpha: True
-S: nan
 P: True
 1/jnp.sqrt(S): True
 alpha_tilde_QR: True
 alpha: True
-S: nan
 P: True
 1/jnp.sqrt(S): True
 alpha_tilde_QR: True
 alpha: True
-S: nan
 P: True
 1/jnp.sqrt(S): True
 alpha_tilde_QR: True
 alpha: True
-S: nan
 P: True
 1/jnp.sqrt(S): True
 alpha_tilde_QR: True
 alpha: True
-S: nan
 P: True
 1/jnp.sqrt(S): True
 alpha_tilde_QR: True
 alpha: True
-S: nan
 P: True
 1/jnp.sqrt(S): True
 alpha_tilde_QR: True
 alpha: True
-S: nan
 P: True
 1/jnp.sqrt(S): True
 a

KeyboardInterrupt: 

In [22]:
a = jnp.ones((1,))
#a = a.at[0,0].set(math.nan)
a_check = jnp.isnan(a)
print(a_check)
print(a_check.any())

[False]
False


ENDS HERE