In [1]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import scipy as sp
import numpy as np
import time

# Master PRNG seed and key for the notebook; later cells split `key` to obtain
# fresh sub-keys instead of reusing it directly.
seed = 10
key = jax.random.PRNGKey(seed)

#def enable_float64():
#  """Tell jax to enable float64."""
#  jax.config.update('jax_enable_x64', True)

#enable_float64()

## SGD

In [2]:
def jax_lsq_momentum1(key,
                g1, g2, g3, delta, batch, steps, init_x, init_w,
                t_oracle, loss, loss_times = jnp.array([0])
                ):
  """Run minibatch SGD with momentum on a least-squares problem and record losses.

  One iteration, with k the (0-based) iteration count, is

      grad = A^T (A x - y)
      w   <- (1 - delta(k)) * w + g1(k) * grad
      x   <- x - g2(k) * grad - g3(k) * w

  where (A, y) is a fresh minibatch drawn from ``t_oracle``.

  Parameters
  ----------
  key : PRNGKey
    Jax PRNGKey; it is split into one sub-key per iteration.
  g1, g2, g3 : function(time)
    The learning rate functions, evaluated at the iteration count.
  delta : function(time)
    The momentum function, evaluated at the iteration count.
  batch : int
    The batch-size to use
  steps : int
    The number of steps of minibatch SGD to generate
  init_x : vector
    The initial state for SGD to use
  init_w : vector
    The initial state for momentum
  t_oracle: callable
    Takes as an argument a jax RNG key and a batch-size.
    Expects in return two tensors (A, y)
    of dimension (batch x data-dimension) and dimension (batch x 1).
  loss: callable
    Takes as an argument a (data-dimension x 1) array,
    which is the current linear model parameters, and returns the
    loss.
  loss_times: vector
    Iteration counts at which to compute the loss.  A vector of length 1
    (including the default) means "every iteration".  Entries >= steps are
    dropped.

  Returns
  -------
  losses: vector
    The loss at each retained entry of loss_times
  loss_times: vector
    Iteration counts at which the losses were computed
  last_state: array (data-dimension x 1)
    The last recorded iterate.  States are recorded *before* each update, so
    this is the iterate after ``steps - 1`` updates.
  """

  # A length-1 loss_times is the sentinel for "record every iteration".
  if loss_times.shape[0]==1:
    loss_times = jnp.arange(steps)
  x = jnp.reshape(init_x,(len(init_x),1))
  w = jnp.reshape(init_w,(len(init_w),1))


  def update(z, things):
    keyz, iteration = things
    A,y = t_oracle(keyz, batch)
    x,w = z
    grad = jnp.tensordot(A,jnp.tensordot(A,x,axes=1)-y,axes=[[0],[0]])
    neww = (1.0 - delta(iteration)) * w + g1(iteration) * grad
    newx = x - g2(iteration) * grad - neww * g3(iteration)
    # Emit the iterate *before* this update so that states[k] is the state
    # after exactly k updates (states[0] is the initial point).
    return (newx,neww), x

  keys=jax.random.split(key,steps)
  # Iteration counts 0, 1, ..., steps-1.  (linspace(0, steps, num=steps) would
  # instead give a spacing of steps/(steps-1) and end at `steps`.)
  iters = jnp.arange(steps, dtype=jnp.float32)

  _, states = jax.lax.scan(update,(x,w),(keys,iters))

  kept_times = loss_times[loss_times < steps]
  return jax.lax.map(loss, states[kept_times]), kept_times, states[-1]

In [3]:
def jax_lsq_momentum1_opt2(key,
                g1, g2, g3, delta, batch, steps, init_x, init_w,
                t_oracle, loss
                ):
  """Memory-efficient variant of jax_lsq_momentum1 for long runs.

  For ``steps < 10**5`` this simply delegates to jax_lsq_momentum1.  Otherwise
  the first 10**5 iterations are run with the loss recorded at every
  iteration, and the remaining iterations are run in stages j = 5, 6, ...:
  stage j advances in chunks of 10**(j-2) iterations (at most 100 chunks) and
  records only the loss at the end of each chunk, so no per-iteration state is
  stored.

  Parameters
  ----------
  key : PRNGKey
    Jax PRNGKey
  g1, g2, g3 : function(time)
    The learning rate functions, evaluated at the iteration count
  delta : function(time)
    The momentum function, evaluated at the iteration count
  batch : int
    The batch-size to use
  steps : int
    The number of steps of minibatch SGD to generate
  init_x : vector
    The initial state for SGD to use
  init_w : vector
    The initial state for momentum
  t_oracle: callable
    Takes as an argument a jax RNG key and a batch-size.
    Expects in return two tensors (A, y)
    of dimension (batch x data-dimension) and dimension (batch x 1).
  loss: callable
    Takes as an argument a (data-dimension x 1) array,
    which is the current linear model parameters, and returns the
    loss.

  Returns
  -------
  losses: vector
    The recorded losses
  loss_times: vector
    Iteration labels attached to the recorded losses (see the NOTE in the
    body about the labelling convention)
  final_x: array (data-dimension x 1)
    The last iterate
  """

  if steps < 10**5:
    return jax_lsq_momentum1(key,
                g1, g2, g3, delta, batch, steps, init_x, init_w,
                t_oracle, loss)
  x = jnp.reshape(init_x,(len(init_x),1))
  w = jnp.reshape(init_w,(len(init_w),1))

  def step(z, things):
    # One SGD-with-momentum iteration; `things` is (PRNG key, iteration count).
    keyz, iteration = things
    A,y = t_oracle(keyz, batch)
    x,w = z
    grad = jnp.tensordot(A,jnp.tensordot(A,x,axes=1)-y,axes=[[0],[0]])
    neww = (1.0 - delta(iteration)) * w + g1(iteration) * grad
    newx = x - g2(iteration) * grad - neww * g3(iteration)
    return (newx,neww)

  def update(z, things):
    # Recording step: also emits the iterate held *before* the update.
    return step(z, things), z[0]

  def skinny_update(z, things):
    # Non-recording step: emits a dummy value so scan stores no iterates.
    return step(z, things), False

  p = np.int32(np.ceil(np.log10(steps+1)))

  mkey1,mkey2= jax.random.split(key)
  keys=jax.random.split(mkey1,10**5)
  # Iteration counts 0 .. 10**5 - 1, so the staged phase below starts cleanly
  # at 10**5.  (linspace(0, 10**5, num=10**5) has non-unit spacing and ends at
  # 10**5, which would reuse that index at the start of the next phase.)
  iters = jnp.arange(10**5, dtype=jnp.float32)

  z, states = jax.lax.scan(update,(x,w),(keys, iters))

  losslist = jax.lax.map(loss,states)
  # NOTE(review): states[k] is the iterate after k updates, but it is labelled
  # k+1 here; in the staged phase each loss is taken at the *end* of a chunk
  # but labelled with the chunk's *starting* iteration.  The labels are kept
  # as they were because code outside this function may rely on them.
  timelist = jnp.arange(1,10**5+1,1)
  lastiter = 10**5

  mkeyout =  jax.random.split(mkey2,p-5)
  for j, mkey in enumerate(mkeyout,start=5):
    u=j-2
    def outerloop(xw, thingz):
      # Advance 10**u iterations starting at `currentiter`, keep only the
      # loss of the final iterate of the chunk.
      keyz,currentiter = thingz
      mkeys = jax.random.split(keyz, 10**u)
      iterlist = currentiter + jnp.arange(0, 10**u, dtype = jnp.float32)
      (newx,neww), _ = jax.lax.scan(skinny_update,xw,(mkeys,iterlist))
      return (newx,neww), loss(newx)
    outerloopsteps = min( (steps-lastiter)//(10**u), 100)
    # Starting iteration count of every chunk in this stage.
    outerloopitercounts = lastiter + (10**u)*jnp.arange(0,outerloopsteps,1)
    keys=jax.random.split(mkey,outerloopsteps)
    z, late_loss = jax.lax.scan(outerloop,z,(keys,outerloopitercounts))
    losslist=jnp.concatenate([losslist,late_loss])
    timelist =jnp.concatenate([timelist,outerloopitercounts])
    # Advance by the iterations actually run; equals 10**j when the stage
    # ran its full 100 chunks.
    lastiter += outerloopsteps * 10**u
  return losslist, timelist, z[0]

## ODE

In [4]:
def ode_resolvent_log_implicit_full(eigs_K, rho_init, chi_init, sigma_init, risk_infinity,
                  g1, g2, g3, delta, batch, D, t_max, Dt):
  """Integrate the theoretical risk ODE for SGD with momentum on a log-time grid.

  The state is a (3 x d) array stacked as [rho, sigma, chi], one column per
  eigenvalue.  Time is discretized uniformly in s = log(t) with step Dt.  Each
  step is implicit: all rate functions are evaluated at the end of the step,
  t = exp(s + Dt), and the per-eigenvalue 3x3 system (I - Dt*t*Omega) is
  inverted in closed form.  The coupling of all eigen-directions through the
  scalar sum(eigs_K * rho) is handled by the closed-form correction in the
  `vNew` line.

  Parameters
  ----------
  eigs_K : array d
      eigenvalues of covariance matrix (W^TDW)
  rho_init : array d
    initial rho_j's (rho_j^2)
  chi_init : array (d)
      initialization of chi's
  sigma_init : array (d)
      initialization of sigma's (xi^2_j)
  risk_infinity : scalar
      represents the risk value at time infinity
  g1, g2, g3 : function(time)
      learning rate functions
  delta : function(time)
      momentum function
  batch : int
      batch size
  D : int
      number of eigenvalues (i.e. shape of eigs_K)
  t_max : float
      Final time; the log-time grid covers [0, log(t_max))
  Dt : float
      step size of the grid in log-time

  Returns
  -------
  t_grid: array(float)
      exp of the log-time grid, i.e. the times at which the risk is reported
  risks: array(float)
      the risk carried *into* each step (the first entry is the initial risk)

  """
  #times = jnp.arange(0, t_max, step = Dt, dtype= jnp.float64)
  times = jnp.arange(0, jnp.log(t_max), step = Dt, dtype= jnp.float32)

  risk_init = risk_infinity + jnp.sum(eigs_K * rho_init)

  def inverse_3x3(Omega):
      # Closed-form (cofactor / determinant) inverse of a 3 x 3 matrix whose
      # entries are length-d arrays, i.e. d independent inverses at once.
      # Extract matrix elements
      a11, a12, a13 = Omega[0][0], Omega[0][1], Omega[0][2]
      a21, a22, a23 = Omega[1][0], Omega[1][1], Omega[1][2]
      a31, a32, a33 = Omega[2][0], Omega[2][1], Omega[2][2]

      # Calculate determinant
      det = (a11*a22*a33 + a12*a23*a31 + a13*a21*a32
            - a13*a22*a31 - a11*a23*a32 - a12*a21*a33)

      # No singularity check is performed; a vanishing determinant yields inf/nan.
      #if abs(det) < 1e-10:
      #    raise ValueError("Matrix is singular or nearly singular")

      # Calculate each element of inverse matrix
      inv = [[0,0,0],[0,0,0],[0,0,0]]

      inv[0][0] = (a22*a33 - a23*a32) / det
      inv[0][1] = (a13*a32 - a12*a33) / det
      inv[0][2] = (a12*a23 - a13*a22) / det

      inv[1][0] = (a23*a31 - a21*a33) / det
      inv[1][1] = (a11*a33 - a13*a31) / det
      inv[1][2] = (a13*a21 - a11*a23) / det

      inv[2][0] = (a21*a32 - a22*a31) / det
      inv[2][1] = (a12*a31 - a11*a32) / det
      inv[2][2] = (a11*a22 - a12*a21) / det

      return jnp.array(inv)

  def odeUpdate(stuff, time):
    # One implicit step in log-time; `time` is the current grid value s.
    v, risk = stuff
    # Physical time at the END of the step; every coefficient is evaluated here.
    timePlus = jnp.exp(time + Dt)

    # Rows of the per-eigenvalue 3 x 3 drift matrix acting on (rho, sigma, chi).
    Omega11 = -2.0 * batch * g2(timePlus) * eigs_K + batch * (batch + 1.0) * g2(timePlus)**2 * eigs_K**2
    Omega12 = g3(timePlus)**2 * jnp.ones_like(eigs_K)
    Omega13 = 2.0 * g3(timePlus) * (-1.0 + g2(timePlus) * batch * eigs_K)
    Omega1 = jnp.array([Omega11, Omega12, Omega13])

    Omega21 = batch * (batch + 1.0) * g1(timePlus)**2 * eigs_K**2
    Omega22 = ( -2.0 * delta(timePlus) + delta(timePlus)**2 ) * jnp.ones_like(eigs_K)
    Omega23 = 2.0 * g1(timePlus) * eigs_K * batch * ( 1.0 - delta(timePlus) )
    Omega2 = jnp.array([Omega21, Omega22, Omega23])

    Omega31 = g1(timePlus) * batch * eigs_K
    Omega32 = -g3(timePlus) * jnp.ones_like(eigs_K)
    Omega33 = -delta(timePlus) - g2(timePlus) * batch * eigs_K
    Omega3 = jnp.array([Omega31, Omega32, Omega33])

    Omega = jnp.array([Omega1, Omega2, Omega3]) #3 x 3 x d

    # 3 x 3 identity repeated along a trailing axis of length D (must match len(eigs_K)).
    Identity = jnp.tensordot( jnp.eye(3), jnp.ones(D), 0 )

    # The factor timePlus multiplying Dt comes from the change of variable t = exp(s).
    A = inverse_3x3(Identity - (Dt * timePlus) * Omega) #3 x 3 x d

    # Forcing coefficients multiplying the risk in the rho and sigma equations (none for chi).
    Gamma = jnp.array([batch * g2(timePlus)**2, batch * g1(timePlus)**2, 0.0])
    # z selects the rho row weighted by eigs_K, so sum(x * z) = sum(eigs_K * x[0]).
    z = jnp.einsum('i, j -> ij', jnp.array([1.0, 0.0, 0.0]), eigs_K)
    G_Lambda = jnp.einsum('i,j->ij', Gamma, eigs_K) #3 x d

    x_temp = v + Dt * timePlus * risk_infinity * G_Lambda
    x = jnp.einsum('ijk, jk -> ik', A, x_temp)

    y = jnp.einsum('ijk, jk -> ik', A, G_Lambda)

    # Closed-form solve for the implicit dependence of the forcing on the new rho's.
    vNew = x + ( Dt * timePlus * y * jnp.sum(x * z) / (1.0 - Dt * timePlus * jnp.sum(y * z)) )
    #vNew = vNew.at[0].set(jnp.maximum(vNew[0], 0.0))
    #vNew[0] = jnp.maximum(vNew[0], 10**(-7))
    #vNew[2] = jnp.maximum(vNew[2], 10**(-7))

    riskNew = risk_infinity + jnp.sum(eigs_K * vNew[0])
    # Emit the risk held before this step, so the output aligns with `times`.
    return (vNew, riskNew), risk #(risk, vNew[0])

  # Note the stacking order [rho, sigma, chi], which differs from the argument order.
  _, risks = jax.lax.scan(odeUpdate,(jnp.array([rho_init,sigma_init, chi_init]),risk_init),times)
  return jnp.exp(times), risks

This is the approximate ODE: it is obtained from the full ODE by dropping some of the lower-order terms. It should only be used when **batch size = 1**. It is a further approximation of the ODE solved by `ode_resolvent_log_implicit_full`.

In [5]:
def ode_resolvent_log_implicit_approximate(eigs_K, rho_init, chi_init, sigma_init, risk_infinity,
                  g1, g2, g3, delta, batch, D, t_max, Dt):
  """Integrate a simplified risk ODE for SGD with momentum on a log-time grid.

  Same scheme as the full solver (state stacked as [rho, sigma, chi], uniform
  grid in s = log(t), implicit step with coefficients evaluated at
  t = exp(s + Dt)), but the drift matrix Omega keeps only the terms visible in
  the body below: the entries quadratic in g1, g2, g3 and delta, and the
  g2- and delta-dependent corrections in Omega13 and Omega23, are dropped.

  Parameters
  ----------
  eigs_K : array d
      eigenvalues of covariance matrix (W^TDW)
  rho_init : array d
    initial rho_j's
  chi_init : array (d)
      initialization of chi's
  sigma_init : array (d)
      initialization of sigma's
  risk_infinity : scalar
      represents the risk value at time infinity
  g1, g2, g3 : function(time)
      learning rate functions
  delta : function(time)
      momentum function
  batch : int
      batch size
  D : int
      number of eigenvalues (i.e. shape of eigs_K)
  t_max : float
      Final time; the log-time grid covers [0, log(t_max))
  Dt : float
      step size of the grid in log-time

  Returns
  -------
  t_grid: array(float)
      exp of the log-time grid, i.e. the times at which the risk is reported
  risks: array(float)
      the risk carried *into* each step (the first entry is the initial risk)

  """
  #times = jnp.arange(0, t_max, step = Dt, dtype= jnp.float64)
  times = jnp.arange(0, jnp.log(t_max), step = Dt, dtype= jnp.float32)

  risk_init = risk_infinity + jnp.sum(eigs_K * rho_init)

  def inverse_3x3(Omega):
      # Closed-form (cofactor / determinant) inverse of a 3 x 3 matrix whose
      # entries are length-d arrays, i.e. d independent inverses at once.
      # Extract matrix elements
      a11, a12, a13 = Omega[0][0], Omega[0][1], Omega[0][2]
      a21, a22, a23 = Omega[1][0], Omega[1][1], Omega[1][2]
      a31, a32, a33 = Omega[2][0], Omega[2][1], Omega[2][2]

      # Calculate determinant
      det = (a11*a22*a33 + a12*a23*a31 + a13*a21*a32
            - a13*a22*a31 - a11*a23*a32 - a12*a21*a33)

      # No singularity check is performed; a vanishing determinant yields inf/nan.
      #if abs(det) < 1e-10:
      #    raise ValueError("Matrix is singular or nearly singular")

      # Calculate each element of inverse matrix
      inv = [[0,0,0],[0,0,0],[0,0,0]]

      inv[0][0] = (a22*a33 - a23*a32) / det
      inv[0][1] = (a13*a32 - a12*a33) / det
      inv[0][2] = (a12*a23 - a13*a22) / det

      inv[1][0] = (a23*a31 - a21*a33) / det
      inv[1][1] = (a11*a33 - a13*a31) / det
      inv[1][2] = (a13*a21 - a11*a23) / det

      inv[2][0] = (a21*a32 - a22*a31) / det
      inv[2][1] = (a12*a31 - a11*a32) / det
      inv[2][2] = (a11*a22 - a12*a21) / det

      return jnp.array(inv)

  def odeApproximateUpdate(stuff, time):
    # One implicit step in log-time; `time` is the current grid value s.
    v, risk = stuff
    # Physical time at the END of the step; every coefficient is evaluated here.
    timePlus = jnp.exp(time + Dt)

    # Rows of the simplified per-eigenvalue 3 x 3 drift matrix acting on (rho, sigma, chi).
    Omega11 = -2.0 * batch * g2(timePlus) * eigs_K
    Omega12 = 0.0 * jnp.ones_like(eigs_K)
    Omega13 = 2.0 * g3(timePlus) * -1.0 * jnp.ones_like(eigs_K)
    Omega1 = jnp.array([Omega11, Omega12, Omega13])

    Omega21 = 0.0 * jnp.ones_like(eigs_K)
    Omega22 = ( -2.0 * delta(timePlus) ) * jnp.ones_like(eigs_K)
    Omega23 = 2.0 * g1(timePlus) * eigs_K * batch
    Omega2 = jnp.array([Omega21, Omega22, Omega23])

    Omega31 = g1(timePlus) * batch * eigs_K
    Omega32 = -g3(timePlus) * jnp.ones_like(eigs_K)
    Omega33 = -delta(timePlus) - g2(timePlus) * batch * eigs_K
    Omega3 = jnp.array([Omega31, Omega32, Omega33])

    Omega = jnp.array([Omega1, Omega2, Omega3]) #3 x 3 x d

    # 3 x 3 identity repeated along a trailing axis of length D (must match len(eigs_K)).
    Identity = jnp.tensordot( jnp.eye(3), jnp.ones(D), 0 )

    # The factor timePlus multiplying Dt comes from the change of variable t = exp(s).
    A = inverse_3x3(Identity - (Dt * timePlus) * Omega) #3 x 3 x d

    # Forcing coefficients multiplying the risk in the rho and sigma equations (none for chi).
    Gamma = jnp.array([batch * g2(timePlus)**2, batch * g1(timePlus)**2, 0.0])
    # z selects the rho row weighted by eigs_K, so sum(x * z) = sum(eigs_K * x[0]).
    z = jnp.einsum('i, j -> ij', jnp.array([1.0, 0.0, 0.0]), eigs_K)
    G_Lambda = jnp.einsum('i,j->ij', Gamma, eigs_K) #3 x d

    x_temp = v + Dt * timePlus * risk_infinity * G_Lambda
    x = jnp.einsum('ijk, jk -> ik', A, x_temp)

    y = jnp.einsum('ijk, jk -> ik', A, G_Lambda)

    # Closed-form solve for the implicit dependence of the forcing on the new rho's.
    vNew = x + ( Dt * timePlus * y * jnp.sum(x * z) / (1.0 - Dt * timePlus * jnp.sum(y * z)) )
    #vNew = vNew.at[0].set(jnp.maximum(vNew[0], 0.0))
    #vNew[0] = jnp.maximum(vNew[0], 10**(-7))
    #vNew[2] = jnp.maximum(vNew[2], 10**(-7))

    riskNew = risk_infinity + jnp.sum(eigs_K * vNew[0])
    # Emit the risk held before this step, so the output aligns with `times`.
    return (vNew, riskNew), risk #(risk, vNew[0])

  # Note the stacking order [rho, sigma, chi], which differs from the argument order.
  _, risks = jax.lax.scan(odeApproximateUpdate,(jnp.array([rho_init,sigma_init, chi_init]),risk_init),times)
  return jnp.exp(times), risks

## Theory for limit level of loss and left spectral edge

In [10]:
def tt_lmin(alpha):
    """Approximate left edge of the spectral measure (rough; only for alpha > 0.5).

    The constant c is the integral of 1/(1 + x^(2*alpha)) over [0, 1000]
    (a finite cutoff standing in for the half-line).

    Parameters
    ----------
    alpha : float
        parameter of the model, ASSUMES V>D

    Returns
    -------
    float
        (1/(2a-1)) * ((2a/(2a-1)/c)**(-2a)) with a = alpha
    """
    cutoff = 1000.0

    def integrand(x):
        return 1.0/(1.0+x**(2*alpha))

    c = sp.integrate.quad(integrand, 0.0, cutoff)[0]
    prefactor = 1/(2*alpha-1)
    ratio = 2*alpha/(2*alpha-1)/c
    return prefactor*(ratio**(-2*alpha))

In [11]:
def tt_dbetacirc_VD(alpha, beta,V,D):
    r"""Finite-(V, D) expression for the residual risk level (risk at time infinity).

    Evaluates sum_j j^{-2(alpha+beta)} / (j^{-2 alpha} * c + 1) over
    j = 1, ..., V-1, where the coefficient c is

    * kappa * D^{2 alpha}  (kappa from tt_kappa_VD)  when 2*alpha >= 1,
    * tau                  (tau from tt_tau_VD)      when 2*alpha <  1.

    Parameters
    ----------
    alpha,beta : floats
        parameters of the model, ASSUMES V>D
    V,D : integers
        parameters of the model

    Returns
    -------
    theoretical prediction for the norm
    """
    two_alpha = 2*alpha
    result = 0.0

    if two_alpha >= 1.0:
        kappa = tt_kappa_VD(alpha,V,D)
        idx = jnp.arange(1,V,1.0)
        weights_kappa = idx**(-2.0*(alpha))*kappa*(D**(2*alpha)) + 1.0
        result = jnp.sum( idx**(-2.0*(beta+alpha))/( weights_kappa ))

    if two_alpha < 1.0:
        tau = tt_tau_VD(alpha,V,D)
        idx = jnp.arange(1,V,1.0)
        weights_tau = idx**(-2.0*(alpha))*tau + 1.0
        result = jnp.sum( idx**(-2.0*(beta+alpha))/( weights_tau ))

    return result


In [12]:
def tt_kappa_VD(alpha, V,D):
    """Coefficient kappa with finite-sample corrections, by fixed-point iteration.

    Starts from c**(-2*alpha), with c the integral of 1/(1 + x^(2*alpha)) over
    [0, 1000], then iterates

        kappa <- 1 / integral_0^{V/D} dx / (kappa + x^(2*alpha))

    until the relative change of one iteration is at most 1e-3.

    Parameters
    ----------
    alpha : float
        parameter of the model.
    V,D : integers
        parameters of the model.

    Returns
    -------
    theoretical prediction for kappa parameter
    """
    def integral(shift, upper):
        return sp.integrate.quad(lambda x: 1.0/(shift+x**(2*alpha)),0.0,upper)[0]

    current = integral(1.0, 1000.0)**(-2.0*alpha)
    tolerance = 10E-4
    while True:
        proposal = 1.0/integral(current, V/D)
        rel_change = abs(proposal/current - 1.0)
        current = proposal
        if not rel_change > tolerance:
            break
    return current

In [13]:
def tt_tau_VD(alpha, V,D):
    """Coefficient tau with finite-sample corrections, by fixed-point iteration.

    With S(k) = sum_{j=1}^{V-1} 1 / (D * (j^(2*alpha) + k)), the iteration
    starts at tau = S(0) and repeats tau <- 1 / S(tau) until the relative
    change of one iteration is at most 1e-3.

    Parameters
    ----------
    alpha : float
        parameter of the model.
    V,D : integers
        parameters of the model.

    Returns
    -------
    theoretical prediction for the tau parameter
    """
    def resolvent_sum(shift):
        return jnp.sum( 1.0/(D*(jnp.arange(1,V,1)**(2*alpha) +shift)))

    current = resolvent_sum(0)
    tolerance = 10E-4
    while True:
        proposal = 1.0/resolvent_sum(current)
        rel_change = abs(proposal/current - 1.0)
        current = proposal
        if not rel_change > tolerance:
            break
    return current

## Newton Theory (Elliot)

Solves for the spectral distribution weighted by $D^{1/2} b$,
$$
\langle (\hat{K}-z)^{-1}, D^{1/2} b\rangle,
$$
using **Newton's method**. It removes the point mass at $0$ from the spectrum, so you will need to add that quantity back in.

**Note:** this algorithm is much faster than the fixed-point iteration above.

In [14]:
def jax_gen_m_batched(v,d, alpha, xs,
                eta = -6,
                eta0 = 6.0,
                etasteps=50,
                batches = 100,
                zbatch=1000):
    """Generate the powerlaw m by newton's method

    For every spectral point x the function looks for m solving

        m + (1/d) * sum_{j=1}^{v} (j^(-2 alpha) m) / (j^(-2 alpha) m - z) = 1

    It starts from m = 1 and performs exactly one Newton update per value of
    the imaginary shift, with z = x + 1j * eta_k and eta_k running over
    logspace(eta0, eta, etasteps), i.e. from 10**eta0 down to 10**eta.

    Parameters
    ----------
    v,d,alpha : floats
        parameters of the model.  v is cast to int32 and must be divisible by
        `batches` (the power-law weights are reshaped to (batches, -1)).
    xs : vector
        The vector of x-positions at which to estimate the spectrum.  Complex is also possible.
    eta : float
        log10 of the final (smallest) imaginary shift
    eta0 : float
        log10 of the initial (largest) imaginary shift
    etasteps : int
        number of shifts, and therefore of Newton updates
    batches : int
        number of slices the sum over j is cut into inside one Newton update
    zbatch : int
        xs is cut at multiples of zbatch and every piece is solved separately

    Returns
    -------
    m_Lambda: vector
        m_Lambda evaluated at xs.
    """
    #if zbatch > 0:
    #    xsplit = jnp.split(xs,jnp.arange(1,len(xs)//zbatch,1)*zbatch)
    #    ms = jnp.concatenate( [jax_gen_m_batched(v,d,alpha,x,eta,eta0,etasteps,batches,zbatch=0) for x in xsplit] )
    #    return ms
    v=jnp.int32(v)
    d=jnp.complex64(d)
    xs=jnp.complex64(xs)
    # Cut points are zbatch, 2*zbatch, ..., (len(xs)//zbatch - 1)*zbatch, so the
    # last piece also absorbs any remainder (it can be longer than zbatch).
    xsplit = jnp.split(xs,jnp.arange(1,len(xs)//zbatch,1)*zbatch)


    #print("xs length = {}".format(len(xs)))

    # Power-law weights j^(-2 alpha), j = 1..v, arranged as (batches, v/batches).
    js=jnp.arange(1,v+1,1,dtype=jnp.complex64)**(-2.0*alpha)
    jt=jnp.reshape(js,(batches,-1))
    onesjtslice=jnp.ones_like(jt)[0]

    # One Newton's method update step for current estimate m on a single value of z
    def mup_single(m,z):
        m1 = m
        # F accumulates the left-hand side of the equation; Fprime its derivative in m.
        F=m1
        Fprime=jnp.ones_like(m1,dtype=jnp.complex64)
        for j in range(batches):
            denom = (jnp.outer(jt[j],m1) - jnp.outer(onesjtslice,z))
            F += (1.0/d)*jnp.sum(jnp.outer(jt[j],m1)/denom,axis=0)
            Fprime -= (1.0/d)*jnp.sum(jnp.outer(jt[j],z)/(denom**2),axis=0)
        # Newton update for the root of F(m) - 1.
        return (-F + 1.0)/Fprime + m1
        #return 0.1*jnp.where(mask, m1, newm1)+0.9*m1

#    mup_single = jax.jit(mup_single, static_argnums=(0,1))

    def mup_scanner(ms,z,x):
        # Scan body: `z` is the current imaginary shift, `x` the spectral points.
        #mups = lambda m : mup_single(m,z*1.0j+xs)
        return mup_single(ms,z*1.0j+x), False

    #mup_scanner = jax.jit(mup_scanner, static_argnums=(0,1))
    # NOTE(review): the jitted wrapper below is never called; the scan uses
    # mup_scanner directly.
    mup_scannerjit =  jax.jit(mup_scanner)

    etas = jnp.logspace(eta0,eta,num = etasteps)
    # Solve each piece of xs independently, then stitch the results together.
    ms = jnp.concatenate( [jax.lax.scan(lambda m,z: mup_scanner(m,z,x),jnp.ones_like(x, dtype = jnp.complex64),etas)[0] for x in xsplit] )
    #ms, _ = jax.lax.scan(mup_scanner,jnp.ones_like(xs),etas)

    return ms

In [15]:
def jax_gen_trace_fmeasure(v,d, alpha, beta, xs,
                err = -6.0, timeChecks = False, batches=100):
    """Generate the spectral density weighted by j^(-2(alpha+beta)) at the points xs.

    The points are shifted off the real axis, zs = xs + 1j * 10**err, m is
    obtained from jax_gen_m_batched, and the returned density is

        Im( F / zs ) / pi,
        F = sum_{j=1}^{v} j^(-2(alpha+beta)) / (j^(-2 alpha) * m - (zs + 1j * 10**eta)),

    with eta = log10(10**err * d**(-2 alpha)).

    Parameters
    ----------
    v,d,alpha,beta : floats
        parameters of the model (v must be divisible by `batches`)
    xs : floats
        X-values at which to return the density
    err : float
        Error tolerance, log scale
    timeChecks: bool
        Print times for each part
    batches : int
        number of slices the sum over j is cut into

    Returns
    -------
    density: vector
        the weighted spectral density evaluated at xs
    """

    eps = 10.0**(err)

    zs = xs + 1.0j*eps

    if timeChecks:
        print("The number of points on the spectral curve is {}".format(len(xs)))

    eta = jnp.log10(eps*(d**(-2*alpha)))
    eta0 = 6
    etasteps = jnp.int32(40 + 10*(2*alpha)*jnp.log(d))

    start=time.time()
    if timeChecks:
        print("Running the Newton generator with {} steps".format(etasteps))

    ms = jax_gen_m_batched(v,d,alpha,zs,eta,eta0,etasteps,batches)

    if timeChecks:
        print("Completed Newton in {} time".format(time.time()-start) )

    js=jnp.arange(1,v+1,1)**(-2.0*alpha)
    jbs=jnp.arange(1,v+1,1)**(-2.0*(alpha+beta))

    jt=jnp.reshape(js,(batches,-1))
    jbt=jnp.expand_dims(jnp.reshape(jbs,(batches,-1)),-1)
    onesjtslice=jnp.ones_like(jt)[0]

    # The shifted spectral points do not depend on the slice, so build their
    # broadcast once instead of once per slice.
    shifted_zs = jnp.outer(onesjtslice,zs + 1.0j*(10**eta))

    # Only the weighted measure is returned; the unweighted trace that used to
    # be accumulated alongside it was never read and has been dropped.
    Fmeasure = jnp.zeros_like(ms)
    for j in range(batches):
        Fmeasure += jnp.sum(jbt[j]/(jnp.outer(jt[j],ms) - shifted_zs),axis=0)

    return jnp.imag(Fmeasure/zs) / jnp.pi

## Computing initial $\rho_j$'s deterministically

In [16]:
"""Generate the initial rho_j's deterministically.

This performs many small contour integrals each surrounding the real eigenvalues
where the vector a contains the values for the lower (left) edges of the
contours and the vector b contains the values of the upper (right) edges of the
contours.

The quantity we want to calculate is these contour integrals over the density
of zs, but we are choosing the xs to discretize this density. We therefore need
to choose the xs to be in a fine enough grid to give the desired accuracy.

This code uses a hacky method to choose the xs where the eigenvalues are divided
into num_splits different chunks (each containing the same num of eigenvalues)
so that the range of x values spanned is large for the large eigenvalues and
small for the small eigenvalues. Then this uses a linearly spaced grid within
each split so that each split uses the same number of xs.

The smallest eigenvalues actually don't need this dense of a grid, because they
make very small contributions, and the largest eigenvalues don't need this dense
of a grid because they are far apart. It is actually the intermediate
eigenvalues that are tricky because they are close together but still contribute
significantly.

Parameters
----------
num_splits (int): number of splits
a (vector): lower values of z's to be used to compute the density starting
            from largest j^{-2alpha} to smallest j^{-2alpha}
b (vector): upper values of z's to be used to compute the density starting from
            largest j^{-2alpha} to smallest j^{-2alpha}
xs_per_split (int): the number of x values to use per split

Returns
-------
rho_weights: vector
    returns rho_j weights in order of largest j^{-2alpha} to smallest j^{-2alpha}
"""

def weights(xs, density, a, b):
    """Integrate a sampled density over a list of intervals.

    For each pair (lower, upper) taken from zip(a, b) the density samples whose
    grid point lies in [lower, upper] are summed and multiplied by the grid
    spacing xs[1] - xs[0] (xs is therefore expected to be uniformly spaced).

    Parameters
    ----------
    xs : vector
        grid points at which `density` is sampled
    density : vector
        density values on the grid
    a, b : sequences
        lower and upper interval edges

    Returns
    -------
    list of float
        one integral per interval, in the order of a and b
    """
    def mass_between(lower, upper):
        spacing = xs[1] - xs[0]
        inside = (xs >= lower) & (xs <= upper)
        return float(jnp.sum(density[inside]) * spacing)

    return [mass_between(lower, upper) for lower, upper in zip(a, b)]


def deterministic_rho_weights(num_splits, a, b, xs_per_split = 10000):
  """Generate the initial rho_j weights deterministically from the theoretical density.

  The interval edges are cut into `num_splits` equal chunks.  For each chunk a
  linear grid of `xs_per_split` points is laid between the smallest lower edge
  and the largest upper edge of the chunk, the density is evaluated on that
  grid with jax_gen_trace_fmeasure, and it is integrated over every
  [a_j, b_j] of the chunk with `weights`.

  Reads the module-level globals V, D, alpha and beta.

  Parameters
  ----------
  num_splits : int
      number of chunks; must divide len(a) evenly
  a, b : vectors
      lower and upper interval edges, one pair per eigenvalue
  xs_per_split : int
      number of grid points used for each chunk

  Returns
  -------
  rho_weights: vector
      one weight per interval, in the order of a and b
  """
  a_splits = jnp.split(a, num_splits)
  b_splits = jnp.split(b, num_splits)

  err = -10
  batches = 1

  # Collect the per-chunk results and join them once at the end.
  rho_chunks = []
  for a_split, b_split in zip(a_splits, b_splits):
    xs = jnp.linspace(jnp.min(a_split), jnp.max(b_split), xs_per_split)
    zs = xs.astype(jnp.complex64)
    density = jax_gen_trace_fmeasure(V, D, alpha, beta, zs, err=err, batches = batches)
    rho_chunks.append(jnp.array(weights(xs, density, a_split, b_split)))

  return jnp.concatenate(rho_chunks, axis=0)

## Run SGD and ODE Solver

In [17]:
# Model parameters: alpha sets the power-law decay of the data scale
# (j^(-alpha)), beta the decay of the target coefficients, eta the label-noise
# scale used by ab_oracle.
alpha= 1.0
beta= 0.4
eta = 0.0

#SGD steps
sgd_steps = 10**1
delta_constant = jnp.maximum(2.0  + ( 2.0 * beta - 1 ) / (alpha ), 2.0 - 1.0 / alpha) + 1.0 #Need to be bigger than 2 + (2 * beta - 1) / (2 * alpha)
print('delta is {:.2f}'.format(delta_constant))

# D = model dimension, V = ambient dimension of the data.
D = 500
V = 5 * D
sgd_batch = 1 #jnp.int32(0.2*D)

# Variance of the entries of W.
omega = 1.0/jnp.float32(D)
# Sum of j^(-2 alpha) over j = 1..V.
traceK = jnp.sum(jnp.arange(1,V+1,dtype=jnp.float32)**(-2*alpha))
print('traceK is {:.2f}'.format(traceK))

key,nkey = jax.random.split(key)
W = jnp.sqrt(omega)*jax.random.normal(nkey, (V,D))
data_scale = jnp.power(jnp.arange(1,V+1,dtype=jnp.float32),-1.0*alpha) #D^(1/2)

#move power-scaling from X's to beta and W to save computation
check_beta = jnp.power(jnp.arange(1,V+1,dtype=jnp.float32),-1.0*(beta+alpha)) #D^(1/2) b
check_beta = jnp.reshape(check_beta,(V,1))

# check_W1 and check_W are the same row-scaled matrix, built two different ways.
check_W1 = jnp.einsum('i, ij->ij', data_scale, W) #D^(1/2) W
WtranD = jnp.einsum('ij, i->ji', W, data_scale**2) #WtranD
check_W = jnp.reshape(data_scale,(V,1)) * W #D^(1/2) W

# (D x D) matrix check_W1^T check_W1.
hatK = jnp.einsum('ji,jk->ik', check_W1, check_W1)

def ab_oracle(key,batch):
  """Draw one minibatch (A, y) for the least-squares problem.

  Standard-normal inputs of shape (batch, V) are projected through the global
  check_W to give A (batch x D); the targets are the same inputs applied to
  the global check_beta plus eta times standard-normal noise (batch x 1).
  """
  carry_key, input_key = jax.random.split(key)
  noise_key = jax.random.split(carry_key)[1]

  inputs = jax.random.normal(input_key, (batch, V))
  label_noise = jax.random.normal(noise_key,(batch, 1))

  targets = jnp.tensordot(inputs, check_beta,1) + eta*label_noise
  return jnp.tensordot(inputs,check_W,1), targets

def square_loss(theta):
  """Squared Euclidean norm of check_W @ theta - check_beta (both globals)."""
  residual = jnp.tensordot(check_W,theta,1) - check_beta
  squared = residual*residual
  return jnp.sum(squared)

delta is 2.80
traceK is 1.64


### SGD Code

Quantities needed for the empirical ODE check

In [18]:
#Code for running pure SGD
# Plain SGD is obtained from the momentum routine by setting g1 = g3 = delta = 0
# and starting the momentum buffer at zero: only the g2 gradient step remains.
sgd_gamma_2_constant = 0.5
# Constant step size: 0.5 * min(1 / batch, 1 / traceK).
sgd_gamma_2_scaling = sgd_gamma_2_constant * jnp.minimum( 1.0 / jnp.float32(sgd_batch), 1.0 / traceK )
def g2_sgd(time):
  """Constant learning rate, returned with the same shape as `time`."""
  return sgd_gamma_2_scaling * jnp.ones_like(time)
def g1_sgd(time):
  """Momentum-buffer learning rate: switched off."""
  return 0.0
def g3_sgd(time):
  """Weight of the momentum buffer in the parameter update: switched off."""
  return 0.0
def delta_sgd(time):
  """Momentum decay: switched off."""
  return 0.0

key,nkey = jax.random.split(key)

# Both the parameters and the momentum buffer start at zero.
losses_sgd, times_sgd, theta_final = jax_lsq_momentum1_opt2(nkey,
                g1_sgd, g2_sgd, g3_sgd, delta_sgd, sgd_batch, sgd_steps, jnp.zeros(D),jnp.zeros(D),
                ab_oracle,square_loss
                )

print('Initial loss value is {}'.format( losses_sgd[0]))

Initial loss value is 1.2470312118530273


In [19]:
# K_check method (Equation 18)
# This is slow, use the K_hat cell below instead.

# Eigen-decomposition of the (D x D) matrix hatK.
Keigs, Kevecs = np.linalg.eigh(hatK)

b = np.power(np.arange(1,V+1,dtype=jnp.float32),-1.0*beta)

hold_b = np.einsum('ij, j->i', WtranD,b)

# Solution of hatK x = WtranD b.
check_b = np.linalg.solve(hatK, hold_b)

halfDW = np.einsum('i,ij->ij', data_scale, W)
halfDb = np.power(np.arange(1,V+1,dtype=jnp.float32),-1.0*(beta+alpha))

# Element [1] of lstsq's return value is the array of squared residuals, so
# riskInfty is a 1-element array (hence the bracketed value printed below).
riskInfty = np.linalg.lstsq(check_W, halfDb)[1]
riskInftyTheory = tt_dbetacirc_VD(alpha, beta,V,D)
print('Empirical limiting loss value is {}'.format(riskInfty[0]))
print('Theoretical limiting loss value is {}'.format(tt_dbetacirc_VD(alpha, beta,V,D)))

#Initialize the rhos
initTheta = jnp.zeros(D, dtype=jnp.float32)
initY = jnp.zeros(D, dtype=jnp.float32)
# Coordinates of (initTheta - check_b) in the eigenbasis of hatK.
rho = jnp.einsum('ij,i->j',  Kevecs.astype(jnp.float32), initTheta - check_b)
rho_init = rho**2 #d each row is the ith eigenvalue
# NOTE(review): the result of this astype is discarded; the line has no effect.
rho_init.astype(jnp.float32)
omegaY = jnp.einsum('ij, i -> j',  Kevecs.astype(jnp.float32), initY) #d
sigma_init = omegaY**2
chi_init = omegaY * rho

print('Initial loss value is {}'.format( jnp.sum(rho_init * Keigs) + riskInfty))

# Step size of the ODE solver in log-time.
Dt = 10**(-2)

#check_b = jnp.ravel(theta_final)

# sgd_steps is passed as t_max, so the ODE covers the same horizon as the SGD run.
odeTimes_sgd, odeRisks_sgd = ode_resolvent_log_implicit_full(Keigs.astype(jnp.float32), rho_init, chi_init, sigma_init, riskInfty,
                                               g1_sgd, g2_sgd, g3_sgd, delta_sgd, sgd_batch, D, sgd_steps, Dt)

  riskInfty = np.linalg.lstsq(check_W, halfDb)[1]


Empirical limiting loss value is 8.853174949763343e-05
Theoretical limiting loss value is 9.77418094407767e-05
Initial loss value is [1.2470324]


In [None]:
# This uses K_hat (instead of K_check) to compute rho_init (see Prop. F1).
# Using K_hat is a much easier + faster way to compute rho_init than using K_check.
# This one should work for all values of alpha (including alpha < 0.5) and should work for any batch size.

# For the spectra computations
# Thin SVD of check_W (V x D): Uvec is V x D, s holds the D singular values.
Uvec, s, Vvec =jnp.linalg.svd(check_W,full_matrices=False)

#Compute < ( D^1/2 W W^T D^(1/2) - z)^{-1}, D^(1/2) b >
# Components of check_beta along the left singular vectors.
check_beta_weight = jnp.tensordot(check_beta,Uvec,axes=[[0],[0]])[0]

rho_init = (check_beta_weight)**2 / s**2
# NOTE(review): the result of this astype is discarded; the line has no effect.
rho_init.astype(jnp.float32)
sigma_init = jnp.zeros(D, dtype=jnp.float32)
chi_init = jnp.zeros(D, dtype=jnp.float32)

riskInftyTheory = tt_dbetacirc_VD(alpha, beta,V,D)

print('Theoretical limiting loss value is {}'.format(riskInftyTheory))
# Squared singular values of check_W, used as the eigenvalues for the ODE.
K_eigs = s**2

print('Initial loss value is {}'.format( jnp.sum(rho_init * K_eigs) + riskInftyTheory))

# Step size of the ODE solver in log-time.
Dt = 10**(-2)

odeTimes_sgd_1, odeRisks_sgd_1 = ode_resolvent_log_implicit_full(K_eigs.astype(jnp.float32), rho_init, chi_init, sigma_init, riskInftyTheory,
                                               g1_sgd, g2_sgd, g3_sgd, delta_sgd, sgd_batch, D, sgd_steps, Dt)


In [None]:
#Theory spectra generated by Newton (Elliot)

#Compute the theoretical limiting loss value
riskInftyTheory = tt_dbetacirc_VD(alpha, beta,V,D)
print('Theoretical limiting loss value is {}'.format(riskInftyTheory))

# Compute theoretical integrals using density approximation
# lower_bound / upper_bound close off the outermost integration intervals.
lower_bound = tt_lmin(alpha)*(D**(-2*alpha))#jnp.minimum(tt_lmin(alpha)*(D**(-2*alpha)), 0.9*(D+1)**(-2.0*alpha)) #jnp.minimum(0.00001, 0.9*(D+1)**(-2.0*alpha)) #tt_lmin(alpha)*(D**(-2*alpha)) #jnp.minimum(0.00001, 0.9*(D+1)**(-2.0*alpha))
upper_bound = 1.0*1.1

# Deterministic stand-ins for the eigenvalues: j^(-2 alpha), j = 1..D, in decreasing order.
fake_eigs = np.power(np.arange(1,D+1,dtype=jnp.float32),-2.0*alpha)
# Interval edges are the midpoints between neighbouring fake eigenvalues:
# b_values are the upper (right) edges, a_values the lower (left) edges.
b_values = fake_eigs - 0.5 * jnp.diff(fake_eigs, prepend = upper_bound)
a_values = fake_eigs + 0.5 * jnp.diff(fake_eigs, append = lower_bound)
num_splits = 5 #Must divide D into equal parts
rho_weights = deterministic_rho_weights(num_splits, a_values, b_values)

#rho_weights = split_eigenvalues(num_splits, a_values, b_values)
#rho_weights = weights(xs, density, a_values, b_values)

print('Initial loss value is is {}'.format(jnp.sum( rho_weights*fake_eigs) + riskInftyTheory))

# Compute integrals


#dx = xs[1] - xs[0]

rho_init = rho_weights #density * dx
#num_grid_points = jnp.shape(xs)[0] #Represents the number of eigenvalues
num_grid_points = D
# sigma and chi start at zero.
sigma_init = jnp.zeros(num_grid_points, dtype=jnp.float32)
chi_init = jnp.zeros(num_grid_points, dtype=jnp.float32)



Theoretical limiting loss value is 9.77418094407767e-05


In [None]:
# Dt is handed to the solver as its final argument (presumably the ODE step size).
Dt = 10 ** (-2)

# Full ODE on the deterministic spectrum with the SGD schedules; the solver
# returns a (times, risks) pair that is unpacked into the plotting variables.
sgd_theory_ode_out = ode_resolvent_log_implicit_full(
    fake_eigs, rho_init, chi_init, sigma_init, riskInftyTheory,
    g1_sgd, g2_sgd, g3_sgd, delta_sgd,
    sgd_batch, num_grid_points, sgd_steps, Dt,
)
odeTimes_sgd_theory, odeRisks_sgd_theory = sgd_theory_ode_out


In [None]:
# Theory spectra generated by Newton (Elliot)

# Compute the theoretical limiting loss value
riskInftyTheory = tt_dbetacirc_VD(alpha, beta, V, D)
print('Theoretical limiting loss value is {}'.format(riskInftyTheory))

# Compute theoretical integrals using density approximation.
# lower_bound / upper_bound pad the two ends of the eigenvalue grid below.
# (Alternatives tried for lower_bound: jnp.minimum(0.00001, 0.9*(D+1)**(-2.0*alpha)).)
lower_bound = tt_lmin(alpha) * (D**(-2*alpha))
upper_bound = 1.0*1.1

# Deterministic power-law spectrum j^(-2*alpha), j = 1..D (decreasing in j).
fake_eigs = np.power(np.arange(1, D+1, dtype=jnp.float32), -2.0*alpha)
# b_values[i]: midpoint between fake_eigs[i] and the previous (larger)
# eigenvalue (upper_bound before the first); a_values[i]: midpoint with the
# next (smaller) eigenvalue (lower_bound after the last).
b_values = fake_eigs - 0.5 * jnp.diff(fake_eigs, prepend=upper_bound)
a_values = fake_eigs + 0.5 * jnp.diff(fake_eigs, append=lower_bound)
num_splits = 5  # Must divide D into equal parts
rho_weights = deterministic_rho_weights(num_splits, a_values, b_values)

# Alternative weight constructions kept for reference:
#rho_weights = split_eigenvalues(num_splits, a_values, b_values)
#rho_weights = weights(xs, density, a_values, b_values)

# Message typo fixed (was "is is").
print('Initial loss value is {}'.format(jnp.sum(rho_weights*fake_eigs) + riskInftyTheory))

# Initial conditions for the ODE: one entry per eigenvalue of the deterministic spectrum.
rho_init = rho_weights  # (density-grid alternative: density * dx, with dx = xs[1] - xs[0])
num_grid_points = D     # (density-grid alternative: jnp.shape(xs)[0])
sigma_init = jnp.zeros(num_grid_points, dtype=jnp.float32)
chi_init = jnp.zeros(num_grid_points, dtype=jnp.float32)

In [None]:
# Dt is handed to the solver as its final argument (presumably the ODE step size).
Dt = 10 ** (-2)

# Same inputs as the full deterministic SGD solve, but routed through the
# "approximate" solver variant; the returned pair is unpacked for plotting.
sgd_theory_approx_out = ode_resolvent_log_implicit_approximate(
    fake_eigs, rho_init, chi_init, sigma_init, riskInftyTheory,
    g1_sgd, g2_sgd, g3_sgd, delta_sgd,
    sgd_batch, num_grid_points, sgd_steps, Dt,
)
odeTimes_sgd_theory_approximate, odeRisks_sgd_theory_approximate = sgd_theory_approx_out


### Dana-constant Code

In [None]:
# DANA with constant learning rates: schedule constants, schedule functions, one run.
dana_gamma_1_scaling = 1.0
dana_expMomentum = 1.0
dana_gamma_3_constant = 0.1
# gamma_2 is half the smaller of 1/batch and 1/traceK (alternative: 0.5 / traceK).
dana_gamma_2_scaling = 0.5 * jnp.minimum(1.0 / jnp.float32(sgd_batch), 1.0 / traceK)
# gamma_3 is scaled down by both the dimension D and traceK.
dana_gamma_3_scaling = dana_gamma_3_constant / (float(D)) * 1.0 / traceK


def g1_dana(time):
  """Constant gamma_1 schedule, broadcast to the shape of `time`."""
  return jnp.ones_like(time) * dana_gamma_1_scaling


def g2_dana(time):
  """Constant gamma_2 schedule, broadcast to the shape of `time`."""
  return jnp.ones_like(time) * dana_gamma_2_scaling


def g3_dana(time):
  """Constant gamma_3 schedule, broadcast to the shape of `time`."""
  return jnp.ones_like(time) * dana_gamma_3_scaling


def delta_dana(time):
  """Momentum schedule delta_constant / (1 + time)**dana_expMomentum."""
  decay = (1.0 + time) ** dana_expMomentum
  return delta_constant / decay * jnp.ones_like(time)


# Fresh subkey for this run; `key` itself is advanced for later cells.
key, nkey = jax.random.split(key)

# Both jnp.zeros(D) arguments are presumably the initial iterate and the
# initial momentum state -- verify against jax_lsq_momentum1_opt2.
dana_run = jax_lsq_momentum1_opt2(
    nkey,
    g1_dana, g2_dana, g3_dana, delta_dana,
    sgd_batch, sgd_steps,
    jnp.zeros(D), jnp.zeros(D),
    ab_oracle, square_loss,
)
losses_DANA, times_DANA, theta_final_DANA = dana_run

print('Initial loss value is {}'.format(losses_DANA[0]))

In [None]:
# Eigen-decomposition of hatK; eigh returns eigenvectors as the columns of Kevecs.
Keigs, Kevecs = np.linalg.eigh(hatK)

# Power-law target coefficients j^(-beta), j = 1..V.
b = np.power(np.arange(1, V+1, dtype=jnp.float32), -1.0*beta)

hold_b = np.einsum('ij, j->i', WtranD, b)

# check_b solves hatK @ check_b = WtranD @ b.
check_b = np.linalg.solve(hatK, hold_b)

halfDW = np.einsum('i,ij->ij', data_scale, W)
halfDb = np.power(np.arange(1, V+1, dtype=jnp.float32), -1.0*(beta+alpha))

# Element [1] of lstsq's return value is the array of residual sums of
# squares; it has length 1 here because halfDb is a single right-hand side.
riskInfty = np.linalg.lstsq(check_W, halfDb)[1]
print('Empirical limiting loss value is {}'.format(riskInfty[0]))
print('Theoretical limiting loss value is {}'.format(tt_dbetacirc_VD(alpha, beta, V, D)))

# Initialize the rhos
initTheta = jnp.zeros(D, dtype=jnp.float32)
initY = jnp.zeros(D, dtype=jnp.float32)
# Coordinates of (initTheta - check_b) in the eigenbasis of hatK.
rho = jnp.einsum('ij,i->j', Kevecs.astype(jnp.float32), initTheta - check_b)
# astype returns a new array, so the cast is applied in the same expression
# that binds rho_init (the original bare `rho_init.astype(...)` was a no-op).
rho_init = (rho**2).astype(jnp.float32)  # one entry per eigenvalue
omegaY = jnp.einsum('ij, i -> j', Kevecs.astype(jnp.float32), initY)
sigma_init = omegaY**2
chi_init = omegaY * rho

# Use the scalar residual so this prints a number, matching the
# "Empirical limiting loss" line above, instead of a length-1 array.
print('Initial loss value is {}'.format(jnp.sum(rho_init * Keigs) + riskInfty[0]))

# Dt is passed as the last solver argument (presumably the ODE step size).
Dt = 10**(-2)

# riskInfty is passed to the solver unchanged (still the length-1 residual array).
odeTimes_dana, odeRisks_dana = ode_resolvent_log_implicit_full(
    Keigs.astype(jnp.float32), rho_init, chi_init, sigma_init, riskInfty,
    g1_dana, g2_dana, g3_dana, delta_dana, sgd_batch, D, sgd_steps, Dt)

In [None]:
# For the spectra computations: thin SVD of check_W; the squared singular
# values (K_eigs below) are the spectrum handed to the ODE solver.
Uvec, s, Vvec = jnp.linalg.svd(check_W, full_matrices=False)

# Compute < ( D^1/2 W W^T D^(1/2) - z)^{-1}, D^(1/2) hat{\beta} >
# Contracts axis 0 of check_beta with axis 0 of Uvec and keeps the first row
# of the result (assumes check_beta has a leading singleton column axis -- TODO confirm).
check_beta_weight = jnp.tensordot(check_beta, Uvec, axes=[[0], [0]])[0]

# astype returns a new array rather than casting in place, so the result
# has to be bound back to rho_init for the cast to take effect.
rho_init = ((check_beta_weight)**2 / s**2).astype(jnp.float32)
sigma_init = jnp.zeros(D, dtype=jnp.float32)
chi_init = jnp.zeros(D, dtype=jnp.float32)

riskInftyTheory = tt_dbetacirc_VD(alpha, beta, V, D)

print('Theoretical limiting loss value is {}'.format(riskInftyTheory))
K_eigs = s**2

# Dt is passed as the last solver argument (presumably the ODE step size).
Dt = 10**(-2)

print('Initial loss value is {}'.format(jnp.sum(rho_init * K_eigs) + riskInftyTheory))

odeTimes_dana_1, odeRisks_dana_1 = ode_resolvent_log_implicit_full(
    K_eigs.astype(jnp.float32), rho_init, chi_init, sigma_init, riskInftyTheory,
    g1_dana, g2_dana, g3_dana, delta_dana, sgd_batch, D, sgd_steps, Dt)

### Dana-decay code

In [None]:
# DANA with a decaying gamma_3 learning rate: constants, schedules, one run.
dana_gamma_1_scaling_decay = 1.0
dana_expMomentum_decay = 1.0
dana_gamma_3_constant_decay = 0.1
dana_gamma_3_scaling_decay = dana_gamma_3_constant_decay
# gamma_2 is half the smaller of 1/batch and 1/traceK (alternative: 0.5 / traceK).
dana_gamma_2_scaling_decay = 0.5 * jnp.minimum(1.0 / jnp.float32(sgd_batch), 1.0 / traceK)


def g1_dana_decay(time):
  """Constant gamma_1 schedule, broadcast to the shape of `time`."""
  return jnp.ones_like(time) * dana_gamma_1_scaling_decay


def g2_dana_decay(time):
  """Constant gamma_2 schedule, broadcast to the shape of `time`."""
  return jnp.ones_like(time) * dana_gamma_2_scaling_decay


def g3_dana_decay(time):
  """gamma_3 schedule decaying like (1 + time)**(-1/(2*alpha)), divided by traceK."""
  decay = (1.0 + time) ** (1.0 / (2.0 * alpha))
  return dana_gamma_3_scaling_decay / decay * 1.0 / traceK


def delta_dana_decay(time):
  """Momentum schedule delta_constant / (1 + time)**dana_expMomentum_decay."""
  decay = (1.0 + time) ** dana_expMomentum_decay
  return delta_constant / decay * jnp.ones_like(time)


# Fresh subkey for this run; `key` itself is advanced for later cells.
key, nkey = jax.random.split(key)

# Both jnp.zeros(D) arguments are presumably the initial iterate and the
# initial momentum state -- verify against jax_lsq_momentum1_opt2.
dana_decay_run = jax_lsq_momentum1_opt2(
    nkey,
    g1_dana_decay, g2_dana_decay, g3_dana_decay, delta_dana_decay,
    sgd_batch, sgd_steps,
    jnp.zeros(D), jnp.zeros(D),
    ab_oracle, square_loss,
)
losses_DANA_decay, times_DANA_decay, theta_final_DANA_decay = dana_decay_run

print('Initial loss value is {}'.format(losses_DANA_decay[0]))

In [None]:
# Eigen-decomposition of hatK; eigh returns eigenvectors as the columns of Kevecs.
Keigs, Kevecs = np.linalg.eigh(hatK)

# Power-law target coefficients j^(-beta), j = 1..V.
b = np.power(np.arange(1, V+1, dtype=jnp.float32), -1.0*beta)

hold_b = np.einsum('ij, j->i', WtranD, b)

# check_b solves hatK @ check_b = WtranD @ b.
check_b = np.linalg.solve(hatK, hold_b)

halfDW = np.einsum('i,ij->ij', data_scale, W)
halfDb = np.power(np.arange(1, V+1, dtype=jnp.float32), -1.0*(beta+alpha))

# Element [1] of lstsq's return value is the array of residual sums of
# squares; it has length 1 here because halfDb is a single right-hand side.
riskInfty = np.linalg.lstsq(check_W, halfDb)[1]
riskInftyTheory = tt_dbetacirc_VD(alpha, beta, V, D)
print('Empirical limiting loss value is {}'.format(riskInfty[0]))
# Reuse the value computed above instead of calling tt_dbetacirc_VD a second time.
print('Theoretical limiting loss value is {}'.format(riskInftyTheory))

# Initialize the rhos
initTheta = jnp.zeros(D, dtype=jnp.float32)
initY = jnp.zeros(D, dtype=jnp.float32)
# Coordinates of (initTheta - check_b) in the eigenbasis of hatK.
rho = jnp.einsum('ij,i->j', Kevecs.astype(jnp.float32), initTheta - check_b)
# astype returns a new array, so the cast is applied in the same expression
# that binds rho_init (the original bare `rho_init.astype(...)` was a no-op).
rho_init = (rho**2).astype(jnp.float32)  # one entry per eigenvalue
omegaY = jnp.einsum('ij, i -> j', Kevecs.astype(jnp.float32), initY)
sigma_init = omegaY**2
chi_init = omegaY * rho

# Use the scalar residual so this prints a number, matching the
# "Empirical limiting loss" line above, instead of a length-1 array.
print('Initial loss value is {}'.format(jnp.sum(rho_init * Keigs) + riskInfty[0]))

# Dt is passed as the last solver argument (presumably the ODE step size).
Dt = 10**(-2)

#check_b = jnp.ravel(theta_final)

# riskInfty is passed to the solver unchanged (still the length-1 residual array).
odeTimes_dana_decay, odeRisks_dana_decay = ode_resolvent_log_implicit_full(
    Keigs.astype(jnp.float32), rho_init, chi_init, sigma_init, riskInfty,
    g1_dana_decay, g2_dana_decay, g3_dana_decay, delta_dana_decay, sgd_batch, D, sgd_steps, Dt)

In [None]:
# For the spectra computations: thin SVD of check_W; the squared singular
# values (K_eigs below) are the spectrum handed to the ODE solver.
Uvec, s, Vvec = jnp.linalg.svd(check_W, full_matrices=False)

# Compute < ( D^1/2 W W^T D^(1/2) - z)^{-1}, D^(1/2) hat{\beta} >
# Contracts axis 0 of check_beta with axis 0 of Uvec and keeps the first row
# of the result (assumes check_beta has a leading singleton column axis -- TODO confirm).
check_beta_weight = jnp.tensordot(check_beta, Uvec, axes=[[0], [0]])[0]

# astype returns a new array rather than casting in place, so the result
# has to be bound back to rho_init for the cast to take effect.
rho_init = ((check_beta_weight)**2 / s**2).astype(jnp.float32)
sigma_init = jnp.zeros(D, dtype=jnp.float32)
chi_init = jnp.zeros(D, dtype=jnp.float32)

riskInftyTheory = tt_dbetacirc_VD(alpha, beta, V, D)

print('Theoretical limiting loss value is {}'.format(riskInftyTheory))
K_eigs = s**2

print('Initial loss value is {}'.format(jnp.sum(rho_init * K_eigs) + riskInftyTheory))

# NOTE: Dt is not set in this cell; it is whatever an earlier cell last assigned.
odeTimes_dana_decay_1, odeRisks_dana_decay_1 = ode_resolvent_log_implicit_full(
    K_eigs.astype(jnp.float32), rho_init, chi_init, sigma_init, riskInftyTheory,
    g1_dana_decay, g2_dana_decay, g3_dana_decay, delta_dana_decay, sgd_batch, D, sgd_steps, Dt)

In [None]:
# Theory spectra generated by Newton (Elliot)

# Compute the theoretical limiting loss value
riskInftyTheory = tt_dbetacirc_VD(alpha, beta, V, D)
print('Theoretical limiting loss value is {}'.format(riskInftyTheory))

# Compute theoretical integrals using density approximation.
# lower_bound / upper_bound pad the two ends of the eigenvalue grid below.
# (Alternative tried for lower_bound: jnp.minimum(0.00001, 0.9*(D+1)**(-2.0*alpha)).)
lower_bound = tt_lmin(alpha) * (D**(-2*alpha))
upper_bound = 1.0*1.1

# Density-grid alternative, kept for reference:
#xs = jnp.linspace(lower_bound, upper_bound, 10000)
#err = -10
#batches = 1
#zs = xs.astype(jnp.complex64)
#density = jax_gen_trace_fmeasure(V, D, alpha, beta, zs, err=err, batches = batches)

# Deterministic power-law spectrum j^(-2*alpha), j = 1..D (decreasing in j).
fake_eigs = np.power(np.arange(1, D+1, dtype=jnp.float32), -2.0*alpha)
# b_values[i]: midpoint between fake_eigs[i] and the previous (larger)
# eigenvalue (upper_bound before the first); a_values[i]: midpoint with the
# next (smaller) eigenvalue (lower_bound after the last).
b_values = fake_eigs - 0.5 * jnp.diff(fake_eigs, prepend=upper_bound)
a_values = fake_eigs + 0.5 * jnp.diff(fake_eigs, append=lower_bound)
#rho_weights = deterministic_rho_weights(xs, density, a_values, b_values)

num_splits = 5
rho_weights = deterministic_rho_weights(num_splits, a_values, b_values)

# Message typo fixed (was "is is").
print('Initial loss value is {}'.format(jnp.sum(rho_weights*fake_eigs) + riskInftyTheory))

# Initial conditions for the ODE: one entry per eigenvalue of the deterministic spectrum.
rho_init = rho_weights  # (density-grid alternative: density * dx, with dx = xs[1] - xs[0])
num_grid_points = D     # (density-grid alternative: jnp.shape(xs)[0])
sigma_init = jnp.zeros(num_grid_points, dtype=jnp.float32)
chi_init = jnp.zeros(num_grid_points, dtype=jnp.float32)

# Dt is passed as the last solver argument (presumably the ODE step size).
Dt = 10**(-2)

odeTimes_dana_decay_theory, odeRisks_dana_decay_theory = ode_resolvent_log_implicit_full(
    fake_eigs, rho_init, chi_init, sigma_init, riskInftyTheory,
    g1_dana_decay, g2_dana_decay, g3_dana_decay, delta_dana_decay,
    sgd_batch, num_grid_points, sgd_steps, Dt)


In [None]:
# "Approximate" solver variant with the DANA-decay schedules. Every input
# (fake_eigs, initial conditions, riskInftyTheory, num_grid_points, Dt) is
# read from variables set by earlier cells.
dana_decay_theory_approx_out = ode_resolvent_log_implicit_approximate(
    fake_eigs, rho_init, chi_init, sigma_init, riskInftyTheory,
    g1_dana_decay, g2_dana_decay, g3_dana_decay, delta_dana_decay,
    sgd_batch, num_grid_points, sgd_steps, Dt,
)
odeTimes_dana_decay_theory_approximate, odeRisks_dana_decay_theory_approximate = dana_decay_theory_approx_out

### Graph

In [None]:
# Risk-vs-flops comparison on log-log axes: simulated runs (faint), empirical
# ODEs (medium), deterministic ODEs and their approximations (opaque).
axarr = plt.axes()
axarr.set_xlabel("flops", fontsize = '28')
axarr.set_ylabel("risk", fontsize = '28')
axarr.set_title(
    'alpha = {:,}, beta = {:,}, d = {:,}, batch = {:,}'.format(alpha, beta, D, sgd_batch)
)

# (times, risks, per-curve style), drawn in this order. Curves that exist in
# the notebook but are deliberately not drawn: the "empirical old" ODEs
# (odeTimes_sgd, odeTimes_dana, odeTimes_dana_decay) and the DANA-constant
# run and ODEs (times_DANA, odeTimes_dana_1, odeTimes_dana_theory).
_risk_curves = [
    (times_sgd, losses_sgd,
     {'label': 'sgd', 'c': 'blue', 'alpha': 0.2}),
    (odeTimes_sgd_1, odeRisks_sgd_1,
     {'label': 'ode (sgd, empirical new)', 'c': 'blue', 'alpha': 0.6}),
    (times_DANA_decay, losses_DANA_decay,
     {'label': 'dana, decay', 'c': 'green', 'alpha': 0.2}),
    (odeTimes_dana_decay_1, odeRisks_dana_decay_1,
     {'label': 'ode (dana, decay, emp. new)', 'c': 'green', 'alpha': 0.6}),
    (odeTimes_sgd_theory, odeRisks_sgd_theory,
     {'label': 'ode (sgd, deterministic)', 'c': 'gray'}),
    (odeTimes_dana_decay_theory, odeRisks_dana_decay_theory,
     {'label': 'ode (dana, decay, deterministic)', 'c': 'black'}),
    (odeTimes_sgd_theory_approximate, odeRisks_sgd_theory_approximate,
     {'label': 'approx. ode (sgd, deterministic)', 'c': 'red'}),
    (odeTimes_dana_decay_theory_approximate, odeRisks_dana_decay_theory_approximate,
     {'label': 'approx. ode (dana, decay, deterministic)', 'c': 'red', 'linestyle': 'dashed'}),
]

# The x-axis is the iteration/time count multiplied by the dimension D.
for _curve_times, _curve_risks, _curve_style in _risk_curves:
    axarr.plot(_curve_times * (float(D)), _curve_risks, linewidth = 3.0, **_curve_style)

axarr.set_yscale('log')
axarr.set_xscale('log')
axarr.grid()

leg = axarr.legend(loc='upper right', fontsize='10')
