Imports

In [1]:
"""Collection of general utility functions."""
import os
from typing import Tuple, Text, Dict, Union, Any

import numpy as np
from jax.experimental import sparse
from absl import logging
from jax import random, vmap, jit, grad, numpy as jnp
from jax.nn import softmax, relu
#from jax.ops import index_update, index DEPRECATED: replace with jax.numpy.ndarray.at in the corresponding code
from jax.scipy.linalg import expm
from scipy.optimize import minimize

# Tuple of optimization-parameter arrays; this is the return type of
# init_params and the argument type of action / boundedness.
Params = Tuple[jnp.ndarray, ...]
# Mapping from a text key to an array, list or float.
# NOTE(review): not used in the visible part of the file.
Results = Dict[Text, Union[jnp.ndarray, list, float]]
# Pair of floats.
# NOTE(review): not used in the visible part of the file.
Tup = Tuple[float, float]



Initialise

In [188]:
def init_params(key: jnp.ndarray, n: int, f: int, m: int,
                sigma_weights: float = 0, init_spectrum: float = 1,
                sigma_spectrum: float = 0) -> Params:
  """Draw the initial optimization parameters.

  Args:
    key: The random key; it is split into five independent subkeys, one per
      returned array.
    n: The spin dimension.
    f: The number of particles.
    m: The cardinality of the support of the discrete measure.
    sigma_weights: Standard deviation of the Gaussian noise that makes up the
      initial weights (the noise-free initial value is zero).
    init_spectrum: Initial absolute value of the negative eigenvalues. The
      positive eigenvalues start at init_spectrum + 1/n, so that the sum of n
      positive and n negative eigenvalues is 1.
    sigma_spectrum: Standard deviation of the Gaussian noise added to the
      logarithms of the initial spectra.

  Returns:
    a 5-tuple of initial parameters:
        weights: weights of the discrete measure (m,); real
        pos_spectrum: log of the initial positive eigenvalues (m, n); real
        neg_spectrum: log of the initial absolute values of the negative
          eigenvalues (m, n); real
        alphas: alpha angles defining the unitary (m, 2n(f-n)-n); real,
          uniform in [0, 4 pi)
        betas: beta angles defining the unitary (m, 2n(f-n)-n); real,
          uniform in [0, pi/2)
  """
  key_weights, key_pos, key_neg, key_alphas, key_betas = random.split(key, 5)
  spectrum_shape = (m, n)
  angle_shape = (m, 2 * n * (f - n) - n)

  weights = sigma_weights * random.normal(key_weights, (m,))

  # The spectra are stored as logarithms; exponentiating later fixes the
  # sign of every eigenvalue regardless of the value of the parameter.
  pos_noise = sigma_spectrum * random.normal(key_pos, spectrum_shape)
  pos_spectrum = jnp.log(init_spectrum + 1. / n) + pos_noise
  neg_noise = sigma_spectrum * random.normal(key_neg, spectrum_shape)
  neg_spectrum = jnp.log(init_spectrum) + neg_noise

  alphas = random.uniform(key_alphas, angle_shape,
                          minval=0, maxval=4 * jnp.pi)
  betas = random.uniform(key_betas, angle_shape,
                         minval=0, maxval=jnp.pi / 2)
  return weights, pos_spectrum, neg_spectrum, alphas, betas

Make spectra

In [192]:
def make_spectra(pos_spectrum: jnp.ndarray,
                 neg_spectrum: jnp.ndarray) -> jnp.ndarray:
  """Turn the log-spectrum parameters into normalized spectra.

  Exponentiating the parameters yields strictly positive numbers; the second
  half is negated, and each row is then divided by its own sum so that every
  spectrum sums to one (the trace constraint).

  Args:
    pos_spectrum: Optimization parameters for positive eigenvalues.
    neg_spectrum: Optimization parameters for negative eigenvalues.

  Returns:
    Array of the m spectra: positive entries first, then the negative ones,
    concatenated along axis 1.
  """
  positive = jnp.exp(pos_spectrum)
  negative = -jnp.exp(neg_spectrum)
  unnormalized = jnp.concatenate((positive, negative), axis=1)
  row_sums = jnp.sum(unnormalized, axis=1, keepdims=True)
  return unnormalized / row_sums

Make the unitary matrix

In [189]:
#masks
def make_masks(f,band_number):
  """Build the boolean masks that place the angle factors in a band matrix.

  There is one (f, f) mask per angle of the band (f - band_number angles) and
  per kind of factor. With t the angle index and p = t + band_number - 1:
    * cos/exp(+) mask: row p, columns band_number-1 .. p
    * cos/exp(-) mask: column p+1, rows p+1 and below
    * sin mask: columns band_number-1 .. p in rows p+1 and below, plus the
      superdiagonal entry (p, p+1)

  Args:
    f: the number of particles
    band_number: index of the band in the unitary decomposition (ranges from
      1 to number of nonzero eigenvalues)

  Returns:
    3 boolean arrays (mask_cos_exp_pos, mask_cos_exp_neg, mask_sin), each of
    shape (f - band_number, f, f)
  """
  first_col = band_number - 1
  term = jnp.arange(f - band_number)[:, None, None]
  row = jnp.arange(f)[None, :, None]
  col = jnp.arange(f)[None, None, :]
  pivot = term + first_col

  in_band = first_col <= col
  mask_cos_exp_pos = (row == pivot) & in_band & (col <= row)
  mask_cos_exp_neg = (col == pivot + 1) & (row >= col)

  below_pivot = in_band & (col <= pivot) & (row >= pivot + 1)
  on_superdiagonal = (row == pivot) & (col == row + 1)
  mask_sin = below_pivot | on_superdiagonal

  return mask_cos_exp_pos, mask_cos_exp_neg, mask_sin



In [3]:
def get_building_blocks(alphas:jnp.ndarray, betas:jnp.ndarray):
  """Evaluate the trigonometric factors from which the band matrices are built.

  Args:
    alphas: alpha angles (phases); real array
    betas: beta angles, same shape as alphas; real array

  Returns:
    A 3-tuple of arrays with the shape of the inputs:
      cos(betas) * exp(+i * alphas)   (complex)
      cos(betas) * exp(-i * alphas)   (complex)
      sin(betas)                      (real)
  """
  phases = jnp.exp(1J * alphas)
  cosines = jnp.cos(betas)
  return phases * cosines, jnp.conj(phases) * cosines, jnp.sin(betas)


In [179]:
def make_single_band_unitary(alphas_band,betas_band,f):
  """Use the angle parameters and the masks to generate one band matrix.

  The band index is inferred from the number of angles:
  band_number = f - len(alphas_band). Every nonzero entry of the result is a
  product of building blocks (cos(beta) * exp(+i alpha),
  cos(beta) * exp(-i alpha), sin(beta)); the masks from make_masks decide
  which angle contributes which factor to which entry.

  Args:
    alphas_band: alpha angles of this band, shape (f - band_number,)
    betas_band: beta angles of this band, shape (f - band_number,)
    f: dimension of the matrix (number of particles)

  Returns:
    (f, f) complex band matrix (intended to be unitary)
  """
  band_number = f - len(alphas_band)

  building_blocks = get_building_blocks(alphas_band, betas_band)
  masks = make_masks(f, band_number)

  # Support of the band matrix: lower triangle plus superdiagonal. Entries
  # that a mask does not select are filled with 1 so that they are neutral
  # in the products below.
  ones_tril = jnp.tril(jnp.ones((f, f))) + jnp.eye(f, k=1)
  band_matrix = ones_tril

  for mask, building_block in zip(masks, building_blocks):
    # mask has shape (f - band_number, f, f): one (f, f) slice per angle.
    factors = mask * building_block[:, jnp.newaxis, jnp.newaxis]
    factors = factors + (ones_tril - mask)
    # Multiply the contributions of all angles, then fold this kind of
    # building block into the running elementwise product.
    band_matrix = band_matrix * jnp.prod(factors, axis=0)

  # Final mask: zero the strictly-lower entries of the first band_number - 1
  # columns, keep the diagonal, and put -1 on the superdiagonal from row
  # band_number - 1 onwards (the negative sines); the superdiagonal above
  # that row becomes 0.
  final_mask = jnp.tril(
      jnp.concatenate((jnp.zeros((f, band_number - 1)),
                       jnp.ones((f, f - band_number + 1))), axis=1),
      k=-1)
  final_mask += jnp.eye(f)
  super_diagonal_terms = jnp.concatenate(
      (jnp.zeros(band_number - 1), jnp.ones(f - band_number)))
  final_mask -= jnp.diag(super_diagonal_terms, k=1)

  return band_matrix * final_mask  # shape (f, f)


# Batched version over a leading axis of the angle arrays; f is shared.
# This has to live at module level: placed after the `return` inside the
# function it was unreachable and the name was never bound.
make_band_unitary = vmap(make_single_band_unitary, in_axes=(0, 0, None))

Random test for band unitary

In [225]:
# Smoke test of a single band matrix with fixed angles (band 2 of a 5x5
# matrix, hence f - band_number = 3 angles of each kind).
f, n, band_number = 5, 1, 2
alphas = jnp.array([0.1, 0.2, 0.3])
betas = jnp.array([0.4, 0.5, 0.6])
# Random alternative:
# alphas = np.random.uniform(low=0, high=4*np.pi, size=(f-band_number))
# betas = np.random.uniform(low=0, high=np.pi/2, size=(f-band_number))
print('alphas: ', alphas)
print('betas: ', betas)
print("band unitary: ", make_single_band_unitary(alphas, betas, f))

alphas:  [0.1 0.2 0.3]
betas:  [0.4 0.5 0.6]
band unitary:  [[ 1.        +0.j          0.        +0.j          0.        +0.j
   0.        +0.j          0.        +0.j        ]
 [ 0.        +0.j          0.9164595 +0.09195267j -0.38941833+0.j
   0.        +0.j          0.        +0.j        ]
 [ 0.        +0.j          0.33493456+0.06789459j  0.8042689 +0.08069605j
  -0.47942555+0.j          0.        +0.j        ]
 [ 0.        +0.j          0.14720567+0.04553605j  0.35718706+0.07240541j
   0.7206817 +0.07230937j -0.5646425 +0.j        ]
 [ 0.        +0.j          0.10541711+0.j          0.24808928-0.02489196j
   0.485643  -0.09844471j  0.78847325-0.24390337j]]


Suggestion: instead of having three different masks over which we iterate, vectorise a function that handles a single mask.

In [6]:
f = 3
c = [jnp.pi, jnp.pi / 2, 2 * jnp.pi]
d = [jnp.pi / 4, jnp.pi / 6, jnp.pi / 3]

# Batched band matrix: map over the leading axis of both angle arrays.
make_band_unitary = vmap(make_single_band_unitary, in_axes=(0, 0, None))

# Batch of two identical samples (replaced right below by a batch of one).
alphas = jnp.array([[jnp.pi, jnp.pi / 2], [jnp.pi, jnp.pi / 2]])
betas = jnp.array([[0, 0], [0, 0]])

alphas = jnp.array([[jnp.pi, jnp.pi / 2]])
betas = jnp.array([[0, 0]])

# Integer-truncate the real and imaginary parts separately before printing.
print(
    jnp.int_(make_band_unitary(alphas, betas, f))
    + 1J * jnp.int_((-1J) * make_band_unitary(alphas, betas, f))
)



[[[-1.+0.j  0.+0.j  0.+0.j]
  [ 0.+0.j  0.-1.j  0.+0.j]
  [ 0.+0.j  0.+0.j  0.-1.j]]]


  return asarray(x, dtype=self.dtype)
  out_array: Array = lax_internal._convert_element_type(


In [28]:
def make_single_unitary(alphas,betas,f, n):
  """Build an (f, f) matrix as the ordered product of n band matrices.

  Band b (b = 1 .. n) consumes the next f - b entries of alphas and betas,
  and the result is B_1 @ B_2 @ ... @ B_n.
  Note: This is not any unitary but a unitary corresponding to the change of
  basis into the eigenvector basis of a matrix with n non zero eigenvalues
  counted with their multiplicity.

  Args:
    alphas: alpha angles, 1-D; the first n*f - n*(n+1)/2 entries are used
    betas: beta angles, 1-D; the first n*f - n*(n+1)/2 entries are used
    f: dimension of the matrix (number of particles)
    n: number of bands, i.e. total number of (nonzero) eigenvalues

  Returns:
    (f, f) matrix (intended to be unitary)
  """
  unitary = jnp.eye(f)

  end_index = 0
  for band_number in range(1, n + 1):
    # Angles of band b occupy the f - b entries following the previous band.
    start_index, end_index = end_index, end_index + f - band_number
    band_matrix = make_single_band_unitary(
        alphas[start_index:end_index], betas[start_index:end_index], f)

    # Accumulate the product of the band matrices from the right.
    unitary = jnp.dot(unitary, band_matrix)

  return unitary


# Batched version over a leading axis of the angle arrays; f and n are
# shared. This has to live at module level: placed after the `return` inside
# the function it was unreachable and the name was never bound.
make_unitary = vmap(make_single_unitary, in_axes=(0, 0, None, None))

In [8]:
# Two bands of a 4x4 matrix need (4 - 1) + (4 - 2) = 5 angles of each kind.
f, n = 4, 2
alphas = jnp.array([jnp.pi, jnp.pi / 2, 3 / 2 * jnp.pi, 2 * jnp.pi, jnp.pi])
betas = jnp.array([0, 0, 0, 0, 0])

print(alphas)

# Integer-truncate the real and imaginary parts separately before printing.
print(
    jnp.int_(make_single_unitary(alphas, betas, f, n))
    + 1J * jnp.int_((-1J) * make_single_unitary(alphas, betas, f, n))
)


[3.1415927 1.5707964 4.712389  6.2831855 3.1415927]
[[-2.+0.j  1.+0.j  0.+0.j  0.+0.j]
 [ 0.-1.j  0.-1.j  0.+0.j  0.+0.j]
 [-1.+0.j  0.+0.j  1.+0.j  0.+0.j]
 [ 0.+1.j  0.+0.j  0.+0.j  0.-1.j]]


  return asarray(x, dtype=self.dtype)
  out_array: Array = lax_internal._convert_element_type(


Random test, total unitary

In [16]:
# Random angles: n*(2f - 2n - 1) of each kind are drawn. make_single_unitary
# with n bands only reads the leading entries it needs.
f, n = 5, 1
alphas = np.random.uniform(low=0, high=4 * np.pi, size=(n * (2 * f - 2 * n - 1)))
betas = np.random.uniform(low=0, high=np.pi / 2, size=(n * (2 * f - 2 * n - 1)))
print('alphas: ', alphas)
print('betas: ', betas)
print("Unitary: ", make_single_unitary(alphas, betas, f, n))


alphas:  [ 1.7041011   0.9618042   2.70441546  2.79945505  2.91583542  0.88146734
 12.4149835 ]
betas:  [0.40743155 0.55728247 0.09842417 0.16788414 0.34641495 0.2499538
 0.16757512]
Unitary:  [[-0.12203047+0.9099959j  -0.39625242+0.j          0.        +0.j
   0.        +0.j          0.        +0.j        ]
 [ 0.19237606+0.27583975j  0.5742231 -0.52674073j -0.5288818 +0.j
   0.        +0.j          0.        +0.j        ]
 [-0.18894178+0.08829939j  0.26096684+0.406713j   -0.14439999+0.83215237j
  -0.09826533+0.j          0.        +0.j        ]
 [-0.01912717+0.00681202j  0.02153428+0.04182782j -0.02168258+0.07931449j
   0.9767409 +0.09310954j -0.16709661+0.j        ]
 [ 0.00344111+0.j         -0.00105973-0.00790253j  0.00797163-0.01143018j
  -0.15064861-0.07040359j -0.9287951 -0.3307845j ]]


Adapt for eigenvectors

In [62]:

#masks

def make_masks_cfs(f,n_col, band_number):
  """Build the boolean masks for the first n_col columns of a band matrix.

  Same selection rules as for a full band matrix, but the column index only
  runs over 0 .. n_col - 1. With t the angle index and
  p = t + band_number - 1:
    * cos/exp(+) mask: row p, columns band_number-1 .. p
    * cos/exp(-) mask: column p+1, rows p+1 and below
    * sin mask: columns band_number-1 .. p in rows p+1 and below, plus the
      superdiagonal entry (p, p+1)

  Args:
    f: the number of particles
    n_col: the number of columns of the masks we want to build. To build a
      full matrix we will want n_col = f
    band_number: index of the band in the unitary decomposition (ranges from
      1 to number of nonzero eigenvalues)

  Returns:
    3 boolean arrays (mask_cos_exp_pos, mask_cos_exp_neg, mask_sin), each of
    shape (f - band_number, f, n_col)
  """
  offset = band_number - 1
  angle_idx = jnp.arange(f - band_number)[:, None, None]
  rows = jnp.arange(f)[None, :, None]
  cols = jnp.arange(n_col)[None, None, :]
  pivot_row = angle_idx + offset

  cols_in_band = offset <= cols
  mask_cos_exp_pos = (rows == pivot_row) & cols_in_band & (cols <= rows)
  mask_cos_exp_neg = (cols == pivot_row + 1) & (rows >= cols)

  lower_block = cols_in_band & (cols <= pivot_row) & (rows >= pivot_row + 1)
  superdiagonal = (rows == pivot_row) & (cols == rows + 1)
  mask_sin = lower_block | superdiagonal

  return mask_cos_exp_pos, mask_cos_exp_neg, mask_sin


In [84]:
# Show the sin mask (third returned mask) as 0/1 integers.
print(1 * make_masks_cfs(f=5, n_col=3, band_number=2)[2])

[[[0 0 0]
  [0 0 1]
  [0 1 0]
  [0 1 0]
  [0 1 0]]

 [[0 0 0]
  [0 0 0]
  [0 0 0]
  [0 1 1]
  [0 1 1]]

 [[0 0 0]
  [0 0 0]
  [0 0 0]
  [0 0 0]
  [0 1 1]]]


In [174]:
def make_single_band_unitary_cfs(alphas_band,betas_band,f,n_col):
  """Generate the first n_col columns of a band matrix.

  The band index is inferred from the number of angles:
  band_number = f - len(alphas_band). Every nonzero entry of the result is a
  product of building blocks (cos(beta) * exp(+i alpha),
  cos(beta) * exp(-i alpha), sin(beta)); the masks from make_masks_cfs decide
  which angle contributes which factor to which entry.

  Args:
    alphas_band: alpha angles of this band, shape (f - band_number,)
    betas_band: beta angles of this band, shape (f - band_number,)
    f: dimension of the matrix (number of particles)
    n_col: number of columns of the unitary which we want to build

  Returns:
    (f, n_col) complex matrix: the leading columns of the band matrix
  """
  band_number = f - len(alphas_band)

  building_blocks = get_building_blocks(alphas_band, betas_band)
  masks = make_masks_cfs(f, n_col, band_number)

  # Support of the band matrix (lower triangle plus superdiagonal), cut to
  # the requested columns. Entries that a mask does not select are filled
  # with 1 so that they are neutral in the products below.
  ones_tril = (jnp.tril(jnp.ones((f, f))) + jnp.eye(f, k=1))[:, :n_col]
  band_matrix = ones_tril

  for mask, building_block in zip(masks, building_blocks):
    # mask has shape (f - band_number, f, n_col): one slice per angle.
    factors = mask * building_block[:, jnp.newaxis, jnp.newaxis]
    factors = factors + (ones_tril - mask)
    # Multiply the contributions of all angles, then fold this kind of
    # building block into the running elementwise product.
    band_matrix = band_matrix * jnp.prod(factors, axis=0)

  # Final mask: zero the strictly-lower entries of the first band_number - 1
  # columns, keep the diagonal, and put -1 on the superdiagonal from row
  # band_number - 1 onwards (the negative sines); the superdiagonal above
  # that row becomes 0.
  final_mask = jnp.tril(
      jnp.concatenate((jnp.zeros((f, band_number - 1)),
                       jnp.ones((f, f - band_number + 1))), axis=1),
      k=-1)
  final_mask += jnp.eye(f)
  super_diagonal_terms = jnp.concatenate(
      (jnp.zeros(band_number - 1), jnp.ones(f - band_number)))
  final_mask -= jnp.diag(super_diagonal_terms, k=1)

  return band_matrix * final_mask[:, :n_col]  # shape (f, n_col)


# Batched version over a leading axis of the angle arrays; f and n_col are
# shared. It has to live at module level: placed after the `return` inside
# the function it was unreachable. It gets its own name so that it does not
# replace make_band_unitary, which wraps the 3-argument
# make_single_band_unitary.
make_band_unitary_cfs = vmap(make_single_band_unitary_cfs,
                             in_axes=(0, 0, None, None))

In [175]:
# Band 1 of a 4x4 matrix (three angles), all columns requested.
f, n_col = 4, 4
alphas = jnp.array([jnp.pi, jnp.pi, jnp.pi])
betas = jnp.array([0, 0, 0])

print(alphas)

# Integer-truncate the real and imaginary parts separately before printing.
print(
    jnp.int_(make_single_band_unitary_cfs(alphas, betas, f, n_col))
    - 1J * jnp.int_(1J * make_single_band_unitary_cfs(alphas, betas, f, n_col))
)

[3.1415927 3.1415927 3.1415927]
[[-1.+0.j  0.+0.j  0.+0.j  0.+0.j]
 [ 0.+0.j  1.+0.j  0.+0.j  0.+0.j]
 [ 0.+0.j  0.+0.j  1.+0.j  0.+0.j]
 [ 0.+0.j  0.+0.j  0.+0.j -1.+0.j]]


  return asarray(x, dtype=self.dtype)
  out_array: Array = lax_internal._convert_element_type(


In [187]:
def make_single_eigenvectors(alphas,betas,f, n):
  """Build the first 2n eigenvectors of a spacetime point x from its angles.

  Computes B_1 @ B_2 @ ... @ B_2n restricted to its first 2n columns. Only
  the last band is built with 2n columns and the remaining bands are applied
  from the left, so the full (f, f) product is never formed.

  The angle arrays are read from the back: band 2n uses the last f - 2n
  entries, band 2n - 1 the f - (2n - 1) entries before those, and so on.

  Args:
    alphas: alpha angles, shape (2n(f-n)-n,)
    betas: beta angles, shape (2n(f-n)-n,)
    f: dimension of the matrix (number of particles)
    n: spin number, hence 2n is the total number of eigenvalues

  Returns:
    (f, 2n) matrix whose columns are the eigenvectors
  """
  # Innermost factor: band 2n, of which only the first 2n columns are needed.
  end_index = len(alphas)
  start_index = end_index - (f - 2 * n)
  eigenvectors = make_single_band_unitary_cfs(
      alphas[start_index:end_index], betas[start_index:end_index],
      f, n_col=2 * n)

  # Apply the remaining bands 2n - 1, ..., 1 from the left.
  for band_number in range(2 * n - 1, 0, -1):
    end_index = start_index
    start_index = end_index - (f - band_number)
    band_matrix = make_single_band_unitary_cfs(
        alphas[start_index:end_index], betas[start_index:end_index],
        f, n_col=f)
    eigenvectors = jnp.dot(band_matrix, eigenvectors)

  return eigenvectors


# Batched version over a leading axis of the angle arrays; f and n are
# shared. This has to live at module level: placed after the `return` inside
# the function it was unreachable and the name was never bound.
make_eigenvectors = vmap(make_single_eigenvectors, in_axes=(0, 0, None, None))



In [137]:
# f = 4, n = 1 needs 2n(f-n)-n = 5 angles of each kind.
f = 4

alphas = jnp.array([jnp.pi, jnp.pi / 2, 3 / 2 * jnp.pi, 2 * jnp.pi, jnp.pi])
betas = jnp.array([0, 0, 0, 0, 0])

# Integer-truncate the real and imaginary parts separately before printing.
print(
    jnp.int_(make_single_eigenvectors(alphas, betas, f, n=1))
    + 1J * jnp.int_((-1J) * make_single_eigenvectors(alphas, betas, f, n=1))
)

eigenvectors initial [[ 1.+0.0000000e+00j -1.+0.0000000e+00j]
 [ 0.+0.0000000e+00j  1.+1.7484555e-07j]
 [ 0.+0.0000000e+00j  0.+0.0000000e+00j]
 [ 0.+0.0000000e+00j  0.+0.0000000e+00j]]
eigenvectors initial [[ 1.+0.0000000e+00j -1.+0.0000000e+00j]
 [ 0.+0.0000000e+00j  1.+1.7484555e-07j]
 [ 0.+0.0000000e+00j  0.+0.0000000e+00j]
 [ 0.+0.0000000e+00j  0.+0.0000000e+00j]]
[[-1.+0.j  1.+0.j]
 [ 0.+0.j  0.-1.j]
 [ 0.+0.j  0.+0.j]
 [ 0.+0.j  0.+0.j]]


Random Test Eigenvectors

In [178]:
# Random angles for f = 8, n = 2; print the leading 2n columns of the
# unitary next to the directly built eigenvectors.
f, n = 8, 2
alphas = np.random.uniform(low=0, high=4 * np.pi, size=(n * (2 * f - 2 * n - 1)))
betas = np.random.uniform(low=0, high=np.pi / 2, size=(n * (2 * f - 2 * n - 1)))

print('alphas: ', alphas)
print('betas: ', betas)
print("Unitary: ", make_single_unitary(alphas, betas, f, n)[:, :2 * n])

print("eigenvectors", make_single_eigenvectors(alphas, betas, f, n))




alphas:  [11.83454483  4.7001774  11.88405135  5.30077314 11.31669712  9.82417698
  9.92689272 10.33022909  8.69606062 11.78894435 11.43511158  3.37789308
  0.50561575 10.65022845 12.49114432  7.03986528 10.67774546  3.92685007
  1.30326135  4.13147884  4.45699129  4.89403326]
betas:  [0.78087431 1.56967989 1.27859307 0.52718766 0.06210627 0.31798072
 0.60153587 1.11806767 0.40708618 0.98322888 1.2222888  0.31441328
 0.0816592  1.28775437 0.05949248 0.82857341 0.23230854 1.28421272
 1.32681702 1.3691796  0.74273294 1.0627137 ]
Unitary:  [[ 5.2843058e-01-0.47464183j -3.3835381e-01+0.7168686j
   6.3298756e-01+0.j          0.0000000e+00+0.j        ]
 [-9.5968126e-06-0.00078583j  6.1570972e-01-0.549123j
   2.4967952e-02+0.401404j    3.9593509e-01+0.j        ]
 [ 1.5737031e-01-0.12786411j -3.4403908e-01+0.1873368j
  -1.9395868e-01-0.1005514j   4.8680219e-01-0.02383385j]
 [ 3.2332128e-01-0.48458123j -5.2641547e-01+0.3308674j
  -5.1346707e-01+0.13053705j  1.4168407e-03-0.00833675j]
 [ 1.06832

test band  unitary vs band unitary for cfs

In [184]:
# With n_col = f the column-restricted builder must reproduce the full band
# matrix entry by entry.
f, n, band_number = 8, 2, 4
alphas = np.random.uniform(low=0, high=4 * np.pi, size=(f - band_number))
betas = np.random.uniform(low=0, high=np.pi / 2, size=(f - band_number))

print('alphas: ', alphas)
print('betas: ', betas)

print(
    make_single_band_unitary_cfs(alphas, betas, f, n_col=f)
    == make_single_band_unitary(alphas, betas, f)
)

alphas:  [12.41127559 11.34728954  0.40510808 11.86795894]
betas:  [1.21733967 0.12445021 1.3335901  1.5623361 ]
[[ True  True  True  True  True  True  True  True]
 [ True  True  True  True  True  True  True  True]
 [ True  True  True  True  True  True  True  True]
 [ True  True  True  True  True  True  True  True]
 [ True  True  True  True  True  True  True  True]
 [ True  True  True  True  True  True  True  True]
 [ True  True  True  True  True  True  True  True]
 [ True  True  True  True  True  True  True  True]]


Build Lagrangian


In [248]:


def make_lagrangian_1(spectra,eigenvectors, i: int, j: int) -> float:
  """Lagrangian of one pair of spacetime points in the case n = 1.

  A 2x2 matrix M is assembled from the overlaps of the eigenvectors of the
  two points and the products of their eigenvalues (not exactly xy but an
  isospectral matrix, M in the write up). The returned value is
      max(0, 0.5 * (Re M00 - Re M11)**2 + 2 * Re(M01 * M10)).

  Args:
    spectra: (m,2n) eigenvalues of all points
    eigenvectors: (m,f,2n) eigenvectors of all points
    i: Index for first point.
    j: Index for second point.

  Returns:
    value of the Lagrangian
  """
  vectors_i = eigenvectors[i]
  vectors_j = eigenvectors[j]

  # overlap[a, b] = <a-th eigenvector of point i | b-th eigenvector of point j>
  overlap = jnp.dot(jnp.conj(vectors_i.T), vectors_j)
  eigenvalue_weights = jnp.outer(spectra[i], spectra[j])
  closed_chain = jnp.dot(overlap * eigenvalue_weights, jnp.conj(overlap.T))

  diagonal_gap = jnp.real(closed_chain[0, 0]) - jnp.real(closed_chain[1, 1])
  off_diagonal = jnp.real(closed_chain[0, 1] * closed_chain[1, 0])
  discriminant = 0.5 * diagonal_gap ** 2 + 2 * off_diagonal

  # Alternative trace-based expression left by the author:
  # R = jnp.trace(xy_product@xy_product) - 0.5*(jnp.trace(xy_product))**2
  return relu(discriminant)




Test make_lagrangian_1

In [228]:
# Reproducible initial parameters for a small system.
seed_value = 34
key = random.PRNGKey(seed_value)
f, n, m = 5, 1, 3

# Noise levels and starting point of the initialisation.
sigma_spectrum = 0.01
init_spectrum = 1
sigma_weights = 0.01

weights, pos_spectrum, neg_spectrum, alphas, betas = init_params(
    key, n=n, f=f, m=m,
    sigma_weights=sigma_weights,
    init_spectrum=init_spectrum,
    sigma_spectrum=sigma_spectrum,
)


In [249]:
# Spectra and eigenvectors of all m points, then the Lagrangian of the pair
# (0, 1).
spectra = make_spectra(pos_spectrum, neg_spectrum)
make_eigenvectors = vmap(make_single_eigenvectors, in_axes=(0, 0, None, None))
eigenvectors = make_eigenvectors(alphas, betas, f, n)

print(make_lagrangian_1(spectra, eigenvectors, 0, 1))



eigenvalue_products:  [[ 4.008974  -1.9972023]
 [-2.016216   1.0044444]]
xy_products:  (-0.3905294+2.9802322e-08j)
xy_products:  [[ 0.24903847+2.9802322e-08j  1.2698573 -1.6396946e-01j]
 [-0.63864386-8.2464427e-02j -0.3905294 +2.9802322e-08j]]
(Array(1.1920929e-07-6.803816e-08j, dtype=complex64), Array(-1.4444928, dtype=float32), Array(-1.4444929+6.803816e-08j, dtype=complex64), Array(-1.4444928+7.063845e-08j, dtype=complex64))


Action

In [None]:
def action(params: Params) -> float:
  """The action.

  Computed as the weighted double sum of the Lagrangian over all pairs of
  points: twice the strict upper triangle plus the diagonal.

  NOTE(review): as written this function cannot run to completion; see the
  inline notes on `two_n`, `make_lag` and the arguments passed to
  `make_lagrangian_1`.

  Args:
    params: The 5-tuple of parameters (weights, positive spectrum,
        negative spectrum, block_ul, block_ur).
        NOTE(review): init_params in this file returns alphas and betas as
        the last two entries, not block_ul / block_ur; confirm which layout
        is expected here.

  Returns:
    single float for the value of the action
  """
  # NOTE(review): make_xs_and_weights is not defined in the visible part of
  # the file.
  xs, weights = make_xs_and_weights(params)
  # weighted sum of Lagrangian for pairs
  # Requires params[-2] to be 3-dimensional. NOTE(review): the alphas array
  # produced by init_params is 2-dimensional, so this unpacking would fail
  # on it.
  m, n, _ = params[-2].shape
  # NOTE(review): make_lag stays unbound when n != 1 (the else branch is
  # commented out), which raises at the first call below.
  # NOTE(review): in_axes (None, 0, 0, None) maps the 2nd and 3rd arguments,
  # but make_lagrangian_1 takes (spectra, eigenvectors, i, j), so the index
  # arguments i, j are not the mapped ones.
  if n == 1:
    make_lag = vmap(make_lagrangian_1, (None, 0, 0, None))
  #else:
    #make_lag = vmap(make_lagrangian_n, (None, 0, 0, None))

  # Only looking at upper triangle (without diagonal)
  rows, cols = jnp.triu_indices(m, k=1)
  # NOTE(review): two_n is never defined in this function (the shape is
  # unpacked into n above), so this line raises NameError.
  lag_ij = make_lag(xs, rows, cols, two_n)
  # Factor 2 accounts for the lower triangle.
  act = 2 * jnp.sum(weights[rows] * weights[cols] * lag_ij)
  # Add diagonal
  diag = jnp.arange(m)
  lag_ij = make_lag(xs, diag, diag, two_n)
  act += jnp.sum(weights ** 2 * lag_ij)
  return act

Boundedness functional

In [None]:
def boundedness(params: Params) -> float:
  """The boundedness functional.

  For every pair of points (i, j) the product xs[i] @ xs[j] is formed, the
  two_n largest absolute eigenvalues are summed and the sum is squared. The
  functional is the weighted double sum of these values over all pairs:
  twice the strict upper triangle plus the diagonal.

  Args:
    params: The 5-tuple of parameters (weights, positive spectrum,
        negative spectrum, block_ul, block_ur).
        NOTE(review): init_params in this file returns alphas and betas as
        the last two entries; params[-2] must be 3-dimensional for the shape
        unpacking below to work. Confirm which layout is expected here.

  Returns:
    single float for the value of the boundedness functional
  """
  # NOTE(review): make_xs_and_weights is not defined in the visible part of
  # the file.
  xs, weights = make_xs_and_weights(params)
  m, two_n, _ = params[-2].shape

  def _boundedness(_xs, _i, _j, _two_n):
    # Squared sum of the _two_n largest absolute eigenvalues of x_i x_j.
    xij = _xs[_i] @ _xs[_j]
    spec = jnp.sort(jnp.abs(jnp.linalg.eigvals(xij)))[-_two_n:]
    return jnp.sum(spec) ** 2

  make_bnd = vmap(_boundedness, (None, 0, 0, None))

  # The vmapped helper returns a single array with one value per pair, so
  # the result must not be tuple-unpacked.
  # Only looking at upper triangle (without diagonal)
  rows, cols = jnp.triu_indices(m, k=1)
  bnd_ij = make_bnd(xs, rows, cols, two_n)
  # Factor 2 accounts for the lower triangle.
  bnd = 2 * jnp.sum(weights[rows] * weights[cols] * bnd_ij)
  # Add diagonal
  diag = jnp.arange(m)
  bnd_ij = make_bnd(xs, diag, diag, two_n)
  bnd += jnp.sum(weights ** 2 * bnd_ij)
  return bnd
