In [None]:
def ode_resolvent(K, x, x_star, h, Dh, cov_grad_f, gamma, t_max, n_grid):
  """Generate the theoretical solution to gradient flow by forward-Euler integration.

  The state is tracked through its projections onto the eigenvectors of ``K``:
  for every eigenvalue ``lambda_i`` the outer products ``S_11[i]`` (built from
  ``x``) and ``S_12[i]`` (built from ``x_star`` and ``x``) are evolved, while
  ``S_22[i]`` (built from ``x_star`` only) stays fixed.  At every step the
  reduced matrices ``C = sum_i lambda_i * S[i]`` are passed to ``h``, ``Dh``
  and ``cov_grad_f``.

  Parameters
  ----------
  K : array (d x d)
      Covariance matrix; diagonalised with ``np.linalg.eigh``, so it must be
      symmetric.
  x, x_star : array (d x o), array (d x t)
      Initialisations of x_0 and x_star.
  h : function (C_11, C_12, C_22) -> scalar
      Computes the risk given C.
  Dh : function (C_11, C_12, C_22) -> (DH_11, DH_21)
      Derivatives of h(C(X)) = R(X) with respect to C_11 (o x o) and C_12
      (returned as t x o).
  cov_grad_f : function (C_11, C_12, C_22) -> array (o x o)
      Computes E_a[grad f(x) grad f(x)^T]; it drives the gamma**2 noise term.
  gamma : float
      Step size.
  t_max : float
      Final time (the number of epochs).
  n_grid : int
      Number of time grid points, which is also the number of Euler steps.

  Returns
  -------
  t_grid : array (n_grid,)
      ``np.linspace(0.0, t_max, n_grid)``.
  risks : array (n_grid,)
      ``risks[i]`` is ``h`` evaluated at time ``t_grid[i]``.
  """
  t_grid = np.linspace(0.0, t_max, n_grid)
  risks = np.zeros_like(t_grid)

  # Euler step equal to the spacing of t_grid.  linspace has n_grid - 1
  # intervals; using t_max / n_grid would put risks[i] at time
  # i * t_max / n_grid instead of t_grid[i].
  Dt = t_grid[1] - t_grid[0] if n_grid > 1 else 0.0

  Keigs, Kvecs = np.linalg.eigh(K)
  # Dimension used to normalise the noise term, read from K instead of a
  # notebook-level global.
  d = Keigs.shape[0]
  halfS_x = x.transpose()@Kvecs
  halfS_x_star = x_star.transpose()@Kvecs

  # Per-eigendirection outer products.  From the einsum index order:
  # S_11 is (d, o, o), S_12 is (d, t, o) and S_22 is (d, t, t).
  S_11 = np.einsum('ki,ji->ijk', halfS_x, halfS_x)
  S_12 = np.einsum('ki,ji->ijk', halfS_x, halfS_x_star)
  S_22 = np.einsum('ki,ji->ijk', halfS_x_star, halfS_x_star)
  # x_star never moves, so S_22 and therefore C_22 are constant: compute once.
  C_22 = np.tensordot(S_22, Keigs, axes=(0,0))

  for i in range(n_grid):
    C_11 = np.tensordot(S_11, Keigs, axes=(0,0))
    C_12 = np.tensordot(S_12, Keigs, axes=(0,0))
    # DH_11 is o x o and DH_21 is t x o.
    DH_11, DH_21 = Dh(C_11, C_12, C_22)

    # Drift of S_11 in each eigendirection, scaled by its eigenvalue.
    # NOTE(review): the 2nd and 4th terms are identical; if the 4th is meant
    # to be the transpose of the 2nd this only matters once o or t > 1, and
    # the 2nd term pairs S_12's o-axis with DH_21's t-axis -- confirm.
    S_11_gr = -2.0*gamma*np.einsum('i,ijk->ijk', Keigs,(
                                       np.tensordot(S_11,DH_11, axes=(2,0))
                                       +np.tensordot(S_12,DH_21, axes=(2,0))
                                       +np.tensordot(S_11,DH_11, axes=(1,1))
                                       +np.tensordot(S_12,DH_21, axes=(2,0))
                                       ))
    # NOTE(review): the first contraction pairs S_12's t-axis with DH_11's
    # o-axis and the second yields a (d, o, t) array while S_12 is (d, t, o);
    # both only line up when o == t -- confirm for o, t > 1.
    S_12_gr = -2.0*gamma*np.einsum('i,ijk->ijk', Keigs,(
                                       np.tensordot(S_12,DH_11, axes=(1,0))
                                       +np.einsum('ijk,jl->ilk', S_22, DH_21)
                                       ))
    # Noise term: (gamma**2 / d) * lambda_i * cov_grad_f(C) per eigendirection.
    S_11_noise = (gamma**2/d)*np.tensordot(Keigs,cov_grad_f(C_11,C_12,C_22),axes=0)
    S_11 += Dt*(S_11_gr + S_11_noise)
    S_12 += Dt*(S_12_gr)
    # C_* were computed before this step's update: this is the risk at t_grid[i].
    risks[i] = h(C_11, C_12, C_22)
  return t_grid, risks

In [None]:
# Resolution of the time grid on which the theoretical ODE is integrated.
n_grid = 1_000

# Covariance handed to the theory: a diagonal matrix whose diagonal holds the
# elementwise squares of strlin_cov (flattened by diagflat).
K = np.diagflat(strlin_cov ** 2)

#Functions used to compute h, Dh and cov_grad_f, Dh = [D_11, D_21]
def h(C11,C12,C22):
  """Risk h(C) computed by 1-D quadrature.

  Returns ``sum(log(1 + e^a) * weights) / sqrt(pi)
  - C12 * sum(s'(b) * weights) / sqrt(pi)``, with ``a = sqrt(2*C11) * u`` and
  ``b = sqrt(2*C22) * u``.  Here ``u`` ranges over the module-level
  ``x_integral_points``, ``weights`` are the matching module-level weights,
  and ``s`` is the logistic sigmoid.  The sqrt(2*C) scaling and 1/sqrt(pi)
  factor suggest a Gauss-Hermite rule (weight exp(-u**2)), i.e. expectations
  under N(0, C11) and N(0, C22) -- confirm where those globals are defined.

  Both terms are written in overflow-free form.  The naive
  ``log(exp(a) + 1)`` returns inf once ``a`` exceeds ~709, and
  ``exp(b) / (exp(b) + 1)**2`` returns nan there.
  """
  norm = 1.0 / np.sqrt(np.pi)
  a = np.sqrt(C11 * 2.0) * x_integral_points
  b = np.sqrt(C22 * 2.0) * x_integral_points
  # log(1 + e^a), evaluated without forming e^a.
  part_1 = np.logaddexp(0.0, a) * weights * norm
  # s'(b) = e^b / (1 + e^b)^2 = s(b) * s(-b) = exp(-log(1+e^-b) - log(1+e^b)).
  part_2 = np.exp(-np.logaddexp(0.0, b) - np.logaddexp(0.0, -b)) * weights * norm
  return np.sum( part_1 ) - 1.0 * C12 * np.sum( part_2 )

def Dh(C11,C12,C22):
  """Return the derivatives (DH_11, DH_21) of the risk with respect to C11 and C12.

  Each derivative is a scalar expanded to a ``log_o x log_o`` multiple of the
  identity.  The integrals use module-level quadrature data:
  ``total_points`` / ``mult_weights`` (2-D grid, normalised by 1/pi) and
  ``x_integral_points`` / ``weights`` (1-D rule, normalised by 1/sqrt(pi)).
  These are presumably Gauss-Hermite nodes and weights defined elsewhere in
  the notebook -- confirm.

  Parameters
  ----------
  C11, C12, C22 : float or size-1 array
      Entries of the matrix [[C11, C12], [C12, C22]].  C11 > 0 and
      C11 * C22 > C12**2 keep the square roots real and det_C nonzero.

  Returns
  -------
  (Dh11 * np.eye(log_o), Dh21 * np.eye(log_o))

  The logistic sigmoid is evaluated as ``exp(-logaddexp(0, -y))``.  This
  equals ``exp(y) / (1 + exp(y))`` but does not become inf/inf = nan once y
  exceeds ~709.
  """
  # Cholesky factor L of [[C11, C12], [C12, C22]] (so that L @ L.T == C).
  L_11 = np.sqrt(C11)
  L_21 = C12 / L_11
  L_22 = np.sqrt( C22 - ( C12**2 / C11) )

  det_C = C11 * C22 - C12**2

  # Correlated pair (x, y) = sqrt(2) * L @ p for every 2-D quadrature point p.
  x = np.sqrt(2.0) * ( L_11 * total_points[:,0] )
  y = np.sqrt(2.0) * ( L_21 * total_points[:,0] + L_22 * total_points[:,1] )

  sig_y = np.exp(-np.logaddexp(0.0, -y))
  temp = 0.5 * (1.0 / np.pi ) * x * sig_y

  # Weight factors multiplying temp; presumably the derivative of the
  # Gaussian density with respect to C11 / C12 -- verify against the derivation.
  h11_start = ( y**2 / det_C ) - ( 2.0 * C22 / det_C ) * ( ( total_points[:,0] )**2 + ( total_points[:,1] )**2 )  + ( C22 / det_C )
  h21_start = (-1.0 * x * y / det_C) + ( 2.0 * C12 / det_C ) * ( ( total_points[:,0] )**2 + ( total_points[:,1] )**2 ) - ( C12 / det_C )

  # d/dC11 of log(1 + exp(sqrt(2*C11) * u)) = s(a) * sqrt(2) * u * 0.5 / sqrt(C11),
  # with a = sqrt(2*C11) * u, weighted by the 1-D rule.
  a = np.sqrt(C11 * 2.0) * x_integral_points
  sig_a = np.exp(-np.logaddexp(0.0, -a))
  h11_end = np.sqrt(2.0) * x_integral_points * 0.5 / L_11 * sig_a * weights * (1.0 / np.sqrt(np.pi))

  Dh11 = np.sum( temp * mult_weights * h11_start ) + np.sum( h11_end )
  Dh21 = np.sum( temp * mult_weights * h21_start )

  return(Dh11 *np.eye(log_o), Dh21 * np.eye(log_o))

def cov_grad_f(C11,C12,C22):
  """Quadrature value of (1/pi) * sum((s(x) - s(y))**2 * mult_weights) times np.eye(log_o).

  ``s`` is the logistic sigmoid.  The pair (x, y) = sqrt(2) * L @ p is built
  from the Cholesky factor L of [[C11, C12], [C12, C22]] at every point p of
  the module-level grid ``total_points``, weighted by ``mult_weights``.
  Presumably this is a 2-D Gauss-Hermite rule, so the result approximates
  E[(s(x) - s(y))**2] for (x, y) ~ N(0, C) -- confirm where those globals are
  defined.

  The sigmoid is evaluated as ``exp(-logaddexp(0, -v))``.  This equals
  ``exp(v) / (1 + exp(v))`` but stays finite (no inf/inf = nan) for v > ~709.
  """
  # Cholesky factor L of [[C11, C12], [C12, C22]].
  L_11 = np.sqrt(C11)
  L_21 = C12 / np.sqrt(C11)
  L_22 = np.sqrt( C22 - ( C12**2 / C11) )

  x = np.sqrt(2.0) * ( L_11 * total_points[:,0] )
  y = np.sqrt(2.0) * ( L_21 * total_points[:,0] + L_22 * total_points[:,1] )

  sig_x = np.exp(-np.logaddexp(0.0, -x))
  sig_y = np.exp(-np.logaddexp(0.0, -y))
  temp = sig_x - sig_y

  return np.sum( temp**2 * (1.0 / np.pi) * mult_weights ) * np.eye(log_o)

# Integrate the theoretical ODE from the initialisation `ist` towards the
# target `strlin_xstar`, both reshaped into (d, log_o) arrays.
times, risks = ode_resolvent(
  K,
  np.reshape(ist, (d, log_o)),
  np.reshape(strlin_xstar, (d, log_o)),
  h, Dh, cov_grad_f,
  gamma, t_max, n_grid,
)