In [1]:
import numpy as np
import matplotlib.pyplot as plt
import h5py
import os

# Load response matrix

f = h5py.File(os.path.join("neuron",'centered_normalized_response_matrix_100_neurons_13k_steps_nonuniform_half_random_walk.h5'), 'r')
response_matrix = np.array(f.get('raster')).T
print(response_matrix.shape)

(13000, 100)


In [2]:
ys = response_matrix[:, :, np.newaxis]
ys = ys - np.mean(ys, axis=0)

# Make PCA faster

In [3]:
def PCA1(ys, n):
    S_hat = np.sum(np.array([y.dot(y.T) for y in ys]), axis=0)
    _, v = np.linalg.eigh(S_hat)

    return np.flip(v[:, -n:], axis=-1)

def PCA2(ys, n):
    S_hat = np.concatenate(ys, axis=-1)
    u, _, _ = np.linalg.svd(S_hat, full_matrices=False)

    return u[:, :n]

In [4]:
%timeit PCA1(ys, 2)
%timeit PCA2(ys, 2)

370 ms ± 17.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
113 ms ± 1.48 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
a1 = PCA1(ys, 2)
a2 = PCA2(ys, 2)
np.allclose(a1, a2)

True

# Make yhat projection faster

In [6]:
def orth_proj(X, k=None):
    u, _, vh = np.linalg.svd(X)
    if not k:
        k = len(vh)
    return u[:, :k].dot(vh)

def yhat_alpha(alpha, y):
    return orth_proj(alpha.T.dot(y))

def yhat_alpha_all1(alpha, ys):
    return np.array([yhat_alpha(alpha, y) for y in ys])

def yhat_alpha_all2(alpha, ys):
  Y = alpha.T.dot(ys).transpose((1, 0, 2)) # faster version of np.array([alpha.T.dot(y) for y in ys])
  u, s, vh = np.linalg.svd(Y) # faster version of np.array([np.linalg.svd(y) for y in Y])
  k = vh.shape[-1]
  return np.einsum('ijk,ikl->ijl', u[:, :, :k], vh)

In [7]:
%timeit yhat_alpha_all1(a1, ys)
%timeit yhat_alpha_all2(a1, ys)

70.4 ms ± 904 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
8.86 ms ± 141 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
x1 = yhat_alpha_all1(a1, ys)
x2 = yhat_alpha_all2(a1, ys)
np.allclose(x1, x2)

True

# Make pi_alpha faster

In [9]:
def pi_alpha(alpha, y):
    return alpha.dot(yhat_alpha(alpha, y))

def pi_alpha_all1(alpha, ys):
    return np.array([pi_alpha(alpha, y) for y in ys])

def pi_alpha_all2(alpha, ys):
  yhats = yhat_alpha_all2(alpha, ys)
  return alpha.dot(yhats).transpose((1,0,2))

In [10]:
%timeit pi_alpha_all1(a1, ys)
%timeit pi_alpha_all2(a1, ys)

87.9 ms ± 722 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
15.9 ms ± 288 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [11]:
x1 = pi_alpha_all1(a1, ys)
x2 = pi_alpha_all2(a1, ys)
np.allclose(x1, x2)

True

# Make cost calculation faster

In [12]:
def projection_cost1(alpha, ys):
    return np.sum([np.linalg.norm(y - pi_alpha(alpha, y), 'fro')**2 for y in ys])/len(ys)

def projection_cost2(alpha, ys):
    return np.sum((ys - pi_alpha_all2(alpha, ys))**2)/len(ys)

In [13]:
%timeit projection_cost1(a1, ys)
%timeit projection_cost2(a1, ys)

111 ms ± 858 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
16.8 ms ± 162 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
x1 = projection_cost1(a1, ys)
x2 = projection_cost2(a1, ys)
np.allclose(x1, x2)

True

In [15]:
def nuc_cost1(alpha, ys):
    return -np.sum([np.linalg.norm(alpha.T.dot(y), 'nuc') for y in ys])/len(ys)

def nuc_cost2(alpha, ys):
  Y = alpha.T.dot(ys).transpose((1, 0, 2))
  u, s, vh = np.linalg.svd(Y)
  return -np.sum(s)/len(ys)


In [16]:
%timeit nuc_cost1(a1, ys)
%timeit nuc_cost2(a1, ys)

106 ms ± 3.05 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
8.81 ms ± 159 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [17]:
x1 = nuc_cost1(a1, ys)
x2 = nuc_cost2(a1, ys)
np.allclose(x1, x2)

True

# Make manopt_alpha faster

In [18]:
import autograd.numpy as anp
from pymanopt.manifolds.stiefel import Stiefel
from pymanopt.core.problem import Problem
from pymanopt.function import autograd, numpy
from pymanopt import optimizers

def manopt_alpha1(ys, alpha_init, verbosity=1):
    N, n = alpha_init.shape
    st_Nn = Stiefel(N, n)

    @autograd(st_Nn)
    def cost(point):
        return -anp.sum([anp.linalg.norm(anp.dot(point.T, y), 'nuc') for y in ys])/len(ys)

    problem = Problem(st_Nn, cost=cost)
    optimizer = optimizers.SteepestDescent(verbosity=verbosity)
    res = optimizer.run(problem, initial_point=alpha_init).point

    # print('nuc_cost of initial alpha', cost(alpha_init), 'nuc_cost of final alpha', cost(res))
    # print('projection_cost of initial alpha', projection_cost(alpha_init, ys), 'projection_cost of final alpha', projection_cost(res, ys))
    return res

def manopt_alpha2(ys, alpha_init, verbosity=1):
    N, n = alpha_init.shape
    st_Nn = Stiefel(N, n)

    @autograd(st_Nn)
    def cost(point):
        Y = anp.dot(anp.transpose(point), ys)
        Y = anp.swapaxes(Y, 1, 0)
        u, s, vh = anp.linalg.svd(Y, full_matrices=False)
        return -anp.sum(s)/len(ys)

    problem = Problem(st_Nn, cost=cost)
    optimizer = optimizers.SteepestDescent(verbosity=verbosity)
    res = optimizer.run(problem, initial_point=anp.array(alpha_init)).point

    # print('nuc_cost of initial alpha', cost(alpha_init), 'nuc_cost of final alpha', cost(res))
    # print('projection_cost of initial alpha', projection_cost(alpha_init, ys), 'projection_cost of final alpha', projection_cost(res, ys))
    return res

In [19]:
x1 = manopt_alpha1(ys, a1, verbosity=2)
x2 = manopt_alpha2(ys, a1, verbosity=2)
np.allclose(x1, x2) # seems to not give the same result

Optimizing...
Iteration    Cost                       Gradient norm     
---------    -----------------------    --------------    
   1         -3.6574821471507646e-01    6.68931555e-02    
   2         -3.7816227656240808e-01    8.44618316e-02    
   3         -3.9996651832323760e-01    3.58912720e-02    
   4         -4.0219798555868946e-01    4.27138891e-02    
   5         -4.0727973361841507e-01    1.13127138e-02    
   6         -4.0752822439418435e-01    1.42102700e-02    
   7         -4.0816024022745140e-01    5.06526215e-03    
   8         -4.0831577774482508e-01    3.66670567e-03    
   9         -4.0841210843044384e-01    6.53883376e-03    
  10         -4.0856908822534016e-01    4.27192289e-03    
  11         -4.0861537303225109e-01    7.85643646e-03    
  12         -4.0875315424559294e-01    4.21594358e-03    
  13         -4.0882286823402730e-01    6.89291590e-03    
  14         -4.0896794268608044e-01    3.34659866e-03    
  15         -4.0912372881806747e-01    6.

False