Imports

In [31]:
"""Collection of general utility functions."""
import os
from typing import Tuple, Text, Dict, Union, Any

import numpy as np
from jax.experimental import sparse
from absl import logging
from jax import random, vmap, jit, grad, numpy as jnp
from jax.nn import softmax, relu
#from jax.ops import index_update, index DEPRECATED: replace with jax.numpy.ndarray.at in the corresponding code
from jax.scipy.linalg import expm
from scipy.optimize import minimize

Params = Tuple[jnp.ndarray, ...]
Results = Dict[Text, Union[jnp.ndarray, list, float]]
Tup = Tuple[float, float]



New unitary decomposition parametrization:

In [203]:

def init_params(key: jnp.ndarray, n: int, f: int, m: int,
                sigma_weights: float = 0, init_spectrum: float = 1,
                sigma_spectrum: float = 0) -> Params:
  """Initialize optimization parameters.

  Args:
    key: The random key.
    n: The spin dimension.
    f: The number of particles.
    m: The cardinality of the support of the discrete measure.
    sigma_weights: The standard deviation of the noise to be added to the fixed
      initialization of weights.
    init_spectrum: The initial absolute value of the mean of the negative
      eigenvalues. We will add 1/n to this value for each of the positive
      eigenvalues to ensure trace = 1.
    sigma_spectrum: The standard deviation of the noise to be added to the fixed
      initialization of the spectra.

  Returns:
    a 5-tuple of initial parameters:
        weights: The weights of the discrete measure (m,); real
        pos_spectrum: initial log values of pos spectrum (m, n); real
        neg_spectrum: initial log values of pos spectrum (m, n); real
        alphas: initial values of the alpha angles used to define the unitary (m,2n(f-n)-n): real
        betas: initial values of the beta angles used to define the unitary (m,2n(f-n)-n): real
  """
  subkeys = random.split(key, 5)
  weights = sigma_weights * random.normal(subkeys[0], (m,))
  # Initial pos spectra are all (init_spectrum + 1 / n) and
  # negative ones are all (init_spectrum)
  # We work with logs of desired spectra to ensure the correct signs later
  pos_spectrum = jnp.log(init_spectrum + 1. / n)
  pos_spectrum += sigma_spectrum * random.normal(subkeys[1], (m, n))
  neg_spectrum = jnp.log(init_spectrum)
  neg_spectrum += sigma_spectrum * random.normal(subkeys[2], (m, n))
  alphas = random.uniform(subkeys[3],(m,2*n*(f-n)-n), minval = 0, maxval = 4*jnp.pi)
  betas = random.uniform(subkeys[4],(m,2*n*(f-n)-n), minval = 0, maxval = jnp.pi/2)
  return weights, pos_spectrum, neg_spectrum, alphas, betas


# =============================================================================
# STEPS TO COMPOSE ACTION (AND BOUNDEDNESS FUNCTIONAL) FROM PARAMETERS
# =============================================================================

def make_masks(f:int, band_number: int):
  """Create the original masks from which a band of the unitary will be created.
  Precisely, it creates 3 masks which are np.arrays with  f-band_number sparse matrices of shape (f,f).
  Each of the 3 masks is associated with one of the main "building blocks"
  of the band matrices: the exp(ialpha)cos(beta) terms, the exp(-ialpha)cos(beta)
  an sin(beta) terms. The ith line of a given mask is a sparse matrix with
  1s where the ith term appears in the band matrix
  Note that, reading bands from left to right, bands progressively have fewer parameters

  Args:
    f: the number of particles
    band_number: the number of the band in the unitary for which we are computing the mask
  Returns:
    the 3 masks
  """
  mask_cos_exp_pos_coordinates =[[] for _ in range(f-band_number)]
  for i in range(f-band_number):
    for j in range(i+1):
        mask_cos_exp_pos_coordinates[i].append([i+band_number-1,j+band_number-1])

  mask_cos_exp_neg_coordinates = [[] for _ in range(f-band_number)]
  for i in range(1,f-band_number + 1):
    for j in range(1,i+1):
        mask_cos_exp_neg_coordinates[i-1].append([i+band_number-1,j+band_number-1])

  mask_sin_coordinates = [[] for _ in range(f-band_number)]
  for k in range(f-band_number+1):
    #first add the terms on the superdiagonal (-sin)
    if k != f-band_number:
      mask_sin_coordinates[k].append([k+band_number-1,k+1+band_number-1])
    #then add the remaining (+sin) terms
    for i in range(k,f-band_number+1):
      for j in range(k):
        mask_sin_coordinates[k-1].append([i+band_number-1,j+band_number-1])

  def make_mask_matrices (mask_coordinates):
    f = len(mask_coordinates) + band_number
    print(f)
    mask_matrices = np.empty((f-band_number), object)
    for i in range(f-band_number):
      coordinates = mask_coordinates[i]
      num_coordinates = len(coordinates)
      data = jnp.ones(num_coordinates)
      mask_matrices[i] = sparse.BCOO((data, jnp.array(coordinates)), shape=(f, f))
    return mask_matrices

  mask_cos_exp_pos_matrices = make_mask_matrices(mask_cos_exp_pos_coordinates)
  mask_cos_exp_neg_matrices = make_mask_matrices(mask_cos_exp_neg_coordinates)
  mask_sin_matrices = make_mask_matrices(mask_sin_coordinates)

  return mask_cos_exp_pos_matrices,mask_cos_exp_neg_matrices,mask_sin_matrices

def get_building_blocks(alphas:jnp.ndarray, betas:jnp.ndarray):
  """Get the spacetime points from the current parameters.

  Args:
    alphas: Values of the alpha angles used to define the unitary (m,2n(f-n)-n): real
    betas: Values of the beta angles used to define the unitary (m,2n(f-n)-n): real
  Returns:
  3 "building blocks" terms, basic trigonometric functions of the alphas and betas of the same size
  (m,2n(f-n)-n)
  """
  exp_alphas = jnp.exp(1J*alphas)
  cos_betas = jnp.cos(betas)
  sin_betas = jnp.sin(betas)
  cos_betas_exp_pos_alphas = exp_alphas*cos_betas
  cos_betas_exp_neg_alphas = jnp.conj(exp_alphas)* cos_betas

  return cos_betas_exp_pos_alphas, cos_betas_exp_neg_alphas, sin_betas
def make_unitary(alphas,betas,f, n, m):
  """Use the parameters and the masks to generate successive band matrices
  and multiply them together to obtain the first 2n eigenvectors of the Unitary
  for each spacetime point

    Args:
    n: number of eigenvalues CAREFUL
    params: The tuple of parameters (weights, positive spectrum,
        negative spectrum, alphas, betas).

  Returns:
    (m, f, f) array with a stack of m unitary matrices
    """
  #extract all the parameters
  building_blocks = get_building_blocks(alphas, betas)

  #initialise the unitary
  unitary = jnp.eye(f)

  #iterate over the bands
  end_index = -1
  for band_number in range(1,n+1):

    #extract the correct alphas and betas
    start_index = end_index +1
    end_index =start_index + f - band_number
    alphas_band, betas_band = alphas[:,start_index:end_index],betas[:,start_index:end_index]

    #extract the building blocks
    building_blocks = get_building_blocks(alphas_band, betas_band)

    # extract the correct masks
    masks = make_masks(f,band_number)

    #initialise the band matrix with a matrix with ones on lower triangle and superdiagonal
    ones_tril= jnp.tril(jnp.ones((f,f))) + jnp.eye(f,k = 1)
    m_ones_tril = jnp.tile(ones_tril,(m,1,1))
    band_matrix= m_ones_tril.copy()

    #iterate over the different masks for each building block
    num_masks = len(masks)
    for building_block_index in range(num_masks):
      mask= masks[building_block_index] #shape (f-band_number,f,f)
      building_block = building_blocks[building_block_index]

      #reshape the arrays to multiply them
      mask_reshaped = jnp.tile(mask,(m,1,1,1))
      building_block_reshaped = building_block[:,:,jnp.newaxis,jnp.newaxis]
      band_matrix_building_block = mask_reshaped* building_block_reshaped

      #add ones in the lower triangle and superdiagonal before multiplying  the matrices
      ones_tril_reshaped = jnp.tile(ones_tril,(m,f-band_number,1,1))
      band_matrix_building_block += ones_tril_reshaped - mask_reshaped

      #multiply the matrices together
      band_matrix_building_block = jnp.prod(band_matrix_building_block, axis =1)

      #multiply the different  building blocks together
      band_matrix *=band_matrix_building_block

    #multiply the superdiagonal by -1 to represent  the negative  sines.
    final_mask = jnp.tril(jnp.ones((f,f))) - jnp.eye(f,k = 1)
    m_final_mask = jnp.tile(final_mask, (m,1,1))
    band_matrix*=m_final_mask #shape (m,f,f)

    #add the diagonal 1s for the different bands
    for i in range(band_number-1):
      band_matrix[:,i,i] = 1

    #multiply the different unitaries together
    unitary = jnp.dot(unitary,band_matrix)

  return unitary

def make_spectra(pos_spectrum: jnp.ndarray,
                 neg_spectrum: jnp.ndarray) -> jnp.ndarray:
  """Compute actual spectra from optimization parameters.

  The spectra have to have n positive and n negative eigenvalues
  and satisfy the trace constraint, which we ensure here.

  Args:
    pos_spectrum: Optimization parameters for positive eigenvalues.
    neg_spectrum: Optimization parameters for negative eigenvalues.

  Returns:
    Full (m, 2 n) array of the m spectra
  """
  spectra = jnp.concatenate((jnp.exp(pos_spectrum), - jnp.exp(neg_spectrum)), 1)
  return spectra / jnp.sum(spectra, axis=1)[..., jnp.newaxis]


def make_xs(spectra: jnp.ndarray, hamiltonian: jnp.ndarray) -> jnp.ndarray:
  """Generate the m spacetime points.

  Args:
    spectra: The spectra of the m points.
    hamiltonian: The Hamiltonian matrices.

  Returns:
    (m, f, f) stack of m spacetime points
  """
  m, two_n = spectra.shape
  f = hamiltonian.shape[-1]
  # (m, f, f), unitary <- expm(anti-hermitian)
  unitary = vmap(expm)(- 1J * hamiltonian)
  # (m, f)
  xs = vmap(jnp.diag)(jnp.concatenate((spectra, jnp.zeros((m, f - two_n))), 1))
  # (m, f, f), hermitian (U_i x_i U_i^{\dagger})
  xs = jnp.einsum('...ij,...jk,...lk->...il', unitary, xs, jnp.conj(unitary))
  return xs


def make_xs_and_weights(params: Params) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Get the spacetime points from the current parameters.

  Args:
    params: The tuple of parameters (weights, positive spectrum,
        negative spectrum, block_ul, block_ur).

  Returns:
    (m, f, f) array with a stack of m spacetime points of dimensions (f, f)
  """
  # (m, n), (m, n), (m, 2n, 2n), (m, 2n, f - 2n)
  weights, pos_spectrum, neg_spectrum, block_ul, block_ur = params
  # Scale weights into [0,1] and sum to 1
  weights = softmax(weights)
  # (m, 2n)
  spectra = make_spectra(pos_spectrum, neg_spectrum)
  # (m, f, f), hermitian
  hamiltonian = make_hamiltonian(block_ul, block_ur)
  # (m, f, f), hermitian
  return make_xs(spectra, hamiltonian), weights


def make_lagrangian_n(xs: jnp.ndarray, i: int, j: int, two_n: int) -> float:
  """The Lagrangian for a single pair of spacetime points for n >= 1.

  Args:
    xs: (m, f, f) array of all the spacetime points.
    i: Index for first point.
    j: Index for second point.
    two_n: 2n (2 times the spin dimension)

  Returns:
    value of the Lagrangian
  """
  xij = xs[i] @ xs[j]
  spec = jnp.sort(jnp.abs(jnp.linalg.eigvals(xij)))[-two_n:]
  bnd = jnp.sum(spec) ** 2
  return jnp.sum(spec ** 2) - bnd / two_n


def make_lagrangian_1(xs: jnp.ndarray, i: int, j: int, two_n: int) -> float:
  """The Lagrangian for a single pair of spacetime points for n = 1.

  Args:
    xs: (m, f, f) array of all the spacetime points.
    i: Index for first point.
    j: Index for second point.
    two_n: 2n (2 times the spin dimension); unused here, kept for compatibility.

  Returns:
    value of the Lagrangian
  """
  xij = xs[i] @ xs[j]
  tmp = jnp.real(jnp.trace(xij) ** 2)
  return relu(jnp.real(jnp.trace(xij @ xij)) - tmp / 2.)


def action(params: Params) -> float:
  """The action.

  Args:
    params: The 5-tuple of parameters (weights, positive spectrum,
        negative spectrum, block_ul, block_ur).

  Returns:
    single float for the value of the action
  """
  xs, weights = make_xs_and_weights(params)
  # weighted sum of Lagrangian for pairs
  m, two_n, _ = params[-2].shape
  if two_n == 2:
    make_lag = vmap(make_lagrangian_1, (None, 0, 0, None))
  else:
    make_lag = vmap(make_lagrangian_n, (None, 0, 0, None))
  # Only looking at upper triangle (without diagonal)
  rows, cols = jnp.triu_indices(m, k=1)
  lag_ij = make_lag(xs, rows, cols, two_n)
  act = 2 * jnp.sum(weights[rows] * weights[cols] * lag_ij)
  # Add diagonal
  diag = jnp.arange(m)
  lag_ij = make_lag(xs, diag, diag, two_n)
  act += jnp.sum(weights ** 2 * lag_ij)
  return act


Old exponential parametrization:

In [None]:

def init_params(key: jnp.ndarray, n: int, f: int, m: int,
                sigma_weights: float = 0, init_spectrum: float = 1,
                sigma_spectrum: float = 0) -> Params:
  """Initialize optimization parameters.

  Args:
    key: The random key.
    n: The spin dimension.
    f: The number of particles.
    m: The cardinality of the support of the discrete measure.
    sigma_weights: The standard deviation of the noise to be added to the fixed
      initialization of weights.
    init_spectrum: The initial absolute value of the mean of the negative
      eigenvalues. We will add 1/n to this value for each of the positive
      eigenvalues to ensure trace = 1.
    sigma_spectrum: The standard deviation of the noise to be added to the fixed
      initialization of the spectra.

  Returns:
    a 5-tuple of initial parameters:
        weights: The weights of the discrete measure (m,); real
        pos_spectrum: initial log values of pos spectrum (m, n); real
        neg_spectrum: initial log values of pos spectrum (m, n); real
        block_ul: the upper left block of H (m, 2n, 2n); complex
        block_ur: the upper right block of H (m, 2n, f - 2n); complex
  """
  subkeys = random.split(key, 7)
  weights = sigma_weights * random.normal(subkeys[0], (m,))
  # Initial pos spectra are all (init_spectrum + 1 / n) and
  # negative ones are all (init_spectrum)
  # We work with logs of desired spectra to ensure the correct signs later
  pos_spectrum = jnp.log(init_spectrum + 1. / n)
  pos_spectrum += sigma_spectrum * random.normal(subkeys[1], (m, n))
  neg_spectrum = jnp.log(init_spectrum)
  neg_spectrum += sigma_spectrum * random.normal(subkeys[2], (m, n))


  # Upper right block of Hermitian uniformly at random from
  # \{x + i y | -pi \le x, y \le pi \}
  # Compute lower left as conjugate transpose
  block_ur = random.uniform(subkeys[3], (m, 2 * n, f - 2 * n),
                            minval=-jnp.pi, maxval=jnp.pi)
  block_ur += 1J * random.uniform(subkeys[4], (m, 2 * n, f - 2 * n),
                                  minval=-jnp.pi, maxval=jnp.pi)
  # Upper left block paramaterized as block_ul + block_ul^{\dagger}
  # will set diagonal to zero explicitly
  block_ul = random.uniform(subkeys[5], (m, 2 * n, 2 * n),
                            minval=-jnp.pi, maxval=jnp.pi)
  block_ul += 1J * random.uniform(subkeys[6], (m, 2 * n, 2 * n),
                                  minval=-jnp.pi, maxval=jnp.pi)
  return weights, pos_spectrum, neg_spectrum, block_ul, block_ur


# =============================================================================
# STEPS TO COMPOSE ACTION (AND BOUNDEDNESS FUNCTIONAL) FROM PARAMETERS
# =============================================================================

def make_hamiltonian(block_ul: jnp.ndarray,
                     block_ur: jnp.ndarray) -> jnp.ndarray:
  """Put together the full Hamiltonian from blocks.

  We ensure that Hamiltonians H are hermitian.
  Thereby, - i H are anti-hermitian and exp(-i H) is unitary.

  Args:
    block_ul: The upper left block is of shape (m, 2n, 2n)
    block_ur: The upper right block is of shape (m, 2n, f - 2n)

  Returns:
    The stack of full Hamiltonians of shape (m, f, f)
  """
  m = block_ul.shape[0]
  two_n = block_ul.shape[-1]
  f = block_ur.shape[-1] + two_n
  # compute upper left block as A + A^{\dagger} and set diag to 0
  block_ul = block_ul + jnp.swapaxes(jnp.conj(block_ul), 1, 2)
  block_ul *= (1 - jnp.eye(two_n)[jnp.newaxis, ...])
  # put pieces together
  hamiltonian = jnp.zeros((m, f, f), dtype=jnp.complex64)
  hamiltonian = index_update(hamiltonian, index[:, :two_n, :two_n], block_ul)
  hamiltonian = index_update(hamiltonian, index[:, :two_n, two_n:], block_ur)
  block_ur = jnp.swapaxes(jnp.conj(block_ur), 1, 2)
  hamiltonian = index_update(hamiltonian, index[:, two_n:, :two_n], block_ur)
  return hamiltonian


def make_spectra(pos_spectrum: jnp.ndarray,
                 neg_spectrum: jnp.ndarray) -> jnp.ndarray:
  """Compute actual spectra from optimization parameters.

  The spectra have to have n positive and n negative eigenvalues
  and satisfy the trace constraint, which we ensure here.

  Args:
    pos_spectrum: Optimization parameters for positive eigenvalues.
    neg_spectrum: Optimization parameters for negative eigenvalues.

  Returns:
    Full (m, 2 n) array of the m spectra
  """
  spectra = jnp.concatenate((jnp.exp(pos_spectrum), - jnp.exp(neg_spectrum)), 1)
  return spectra / jnp.sum(spectra, axis=1)[..., jnp.newaxis]


def make_xs(spectra: jnp.ndarray, hamiltonian: jnp.ndarray) -> jnp.ndarray:
  """Generate the m spacetime points.

  Args:
    spectra: The spectra of the m points.
    hamiltonian: The Hamiltonian matrices.

  Returns:
    (m, f, f) stack of m spacetime points
  """
  m, two_n = spectra.shape
  f = hamiltonian.shape[-1]
  # (m, f, f), unitary <- expm(anti-hermitian)
  unitary = vmap(expm)(- 1J * hamiltonian)
  # (m, f)
  xs = vmap(jnp.diag)(jnp.concatenate((spectra, jnp.zeros((m, f - two_n))), 1))
  # (m, f, f), hermitian (U_i x_i U_i^{\dagger})
  xs = jnp.einsum('...ij,...jk,...lk->...il', unitary, xs, jnp.conj(unitary))
  return xs


def make_xs_and_weights(params: Params) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Get the spacetime points from the current parameters.

  Args:
    params: The tuple of parameters (weights, positive spectrum,
        negative spectrum, block_ul, block_ur).

  Returns:
    (m, f, f) array with a stack of m spacetime points of dimensions (f, f)
  """
  # (m, n), (m, n), (m, 2n, 2n), (m, 2n, f - 2n)
  weights, pos_spectrum, neg_spectrum, block_ul, block_ur = params
  # Scale weights into [0,1] and sum to 1
  weights = softmax(weights)
  # (m, 2n)
  spectra = make_spectra(pos_spectrum, neg_spectrum)
  # (m, f, f), hermitian
  hamiltonian = make_hamiltonian(block_ul, block_ur)
  # (m, f, f), hermitian
  return make_xs(spectra, hamiltonian), weights


def make_lagrangian_n(xs: jnp.ndarray, i: int, j: int, two_n: int) -> float:
  """The Lagrangian for a single pair of spacetime points for n >= 1.

  Args:
    xs: (m, f, f) array of all the spacetime points.
    i: Index for first point.
    j: Index for second point.
    two_n: 2n (2 times the spin dimension)

  Returns:
    value of the Lagrangian
  """
  xij = xs[i] @ xs[j]
  spec = jnp.sort(jnp.abs(jnp.linalg.eigvals(xij)))[-two_n:]
  bnd = jnp.sum(spec) ** 2
  return jnp.sum(spec ** 2) - bnd / two_n


def make_lagrangian_1(xs: jnp.ndarray, i: int, j: int, two_n: int) -> float:
  """The Lagrangian for a single pair of spacetime points for n = 1.

  Args:
    xs: (m, f, f) array of all the spacetime points.
    i: Index for first point.
    j: Index for second point.
    two_n: 2n (2 times the spin dimension); unused here, kept for compatibility.

  Returns:
    value of the Lagrangian
  """
  xij = xs[i] @ xs[j]
  tmp = jnp.real(jnp.trace(xij) ** 2)
  return relu(jnp.real(jnp.trace(xij @ xij)) - tmp / 2.)


def action(params: Params) -> float:
  """The action.

  Args:
    params: The 5-tuple of parameters (weights, positive spectrum,
        negative spectrum, block_ul, block_ur).

  Returns:
    single float for the value of the action
  """
  xs, weights = make_xs_and_weights(params)
  # weighted sum of Lagrangian for pairs
  m, two_n, _ = params[-2].shape
  if two_n == 2:
    make_lag = vmap(make_lagrangian_1, (None, 0, 0, None))
  else:
    make_lag = vmap(make_lagrangian_n, (None, 0, 0, None))
  # Only looking at upper triangle (without diagonal)
  rows, cols = jnp.triu_indices(m, k=1)
  lag_ij = make_lag(xs, rows, cols, two_n)
  act = 2 * jnp.sum(weights[rows] * weights[cols] * lag_ij)
  # Add diagonal
  diag = jnp.arange(m)
  lag_ij = make_lag(xs, diag, diag, two_n)
  act += jnp.sum(weights ** 2 * lag_ij)
  return act


In [4]:
f = 3
mask_exp_pos = [[] for _ in range(f-1)]
for i in range(f-1):
  for j in range(i+1):
       mask_exp_pos[i].append((i,j))

mask_exp_neg = [[] for _ in range(f-1)]
for i in range(1,f):
  for j in range(1,i+1):
       mask_exp_neg[i-1].append((i,j))

mask_sin = [[] for _ in range(f-1)]
for k in range(f):
  #first add the terms on the superdiagonal (-sin)
  if k != f-1:
    mask_sin_pos[k].append((k,k+1))
  #then add the remaining (+sin) terms
  for i in range(k,f):
    for j in range(k):
      mask_sin_pos[k-1].append(((i,j)))



[[(0, 1)], [(1, 2)]]
[[(0, 1), (1, 0), (2, 0)], [(1, 2), (2, 0), (2, 1)]]


In [143]:
f = 4
band_number =  3 #from 1 to n (number of eigenvalues)

from jax.experimental import sparse


mask_exp_pos_coordinates =[[] for _ in range(f-band_number)]
for i in range(f-band_number):
  for j in range(i+1):
       mask_exp_pos_coordinates[i].append([i+band_number-1,j+band_number-1])

mask_exp_neg_coordinates = [[] for _ in range(f-band_number)]
for i in range(1,f-band_number + 1):
  for j in range(1,i+1):
       mask_exp_neg_coordinates[i-1].append([i+band_number-1,j+band_number-1])

mask_sin_coordinates = [[] for _ in range(f-band_number)]
for k in range(f-band_number+1):
  #first add the terms on the superdiagonal (-sin)
  if k != f-band_number:
    mask_sin_coordinates[k].append([k+band_number-1,k+1+band_number-1])
  #then add the remaining (+sin) terms
  for i in range(k,f-band_number+1):
    for j in range(k):
      mask_sin_coordinates[k-1].append([i+band_number-1,j+band_number-1])

def make_mask_matrices (mask_coordinates):
  f = len(mask_coordinates) + band_number
  print(f)
  mask_matrices = np.empty((f-band_number), object)
  for i in range(f-band_number):
    coordinates = mask_coordinates[i]
    num_coordinates = len(coordinates)
    data = jnp.ones(num_coordinates)
    mask_matrices[i] = sparse.BCOO((data, jnp.array(coordinates)), shape=(f, f))
  return mask_matrices

masks =

#now remember you have to add the diagonal terms in your overall unitaries at the end

4
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 1.]
 [0. 0. 1. 0.]]


[[0. 1. 0.]
 [0. 0. 1.]
 [0. 0. 0.]]


In [92]:

import numpy as np

# Sample input: list of coordinates and corresponding values
coordinates_list = [
    [(0, 1), (1, 2)],  # Coordinates for the first matrix
    [(0, 0), (2, 2)],  # Coordinates for the second matrix
    [(1, 1)]           # Coordinates for the third matrix
]

values = np.array([1, 1, 1])  # Values to multiply each matrix

# Determine the shape of the final matrices (e.g., 3x3 for this example)
shape = (3, 3)
num_matrices = len(coordinates_list)

# Flatten the list of coordinates and create arrays of indices
all_coords = np.array([coord for coords in coordinates_list for coord in coords])
row_indices = all_coords[:, 0]
col_indices = all_coords[:, 1]
matrix_indices = np.repeat(np.arange(num_matrices), [len(coords) for coords in coordinates_list])

# Initialize an array to hold the mask for all matrices
masks = np.zeros((num_matrices, *shape), dtype=np.float32)

# Set the appropriate elements in the masks array to 1
masks[matrix_indices, row_indices, col_indices] = 1

# Multiply each matrix in masks by the corresponding value
weighted_masks = masks * values[:, np.newaxis, np.newaxis]

# Sum all the matrices together
sum_matrix = np.sum(weighted_masks, axis=0)

print("Sum of matrices:")
print(sum_matrix)

Sum of matrices:
[[1. 1. 0.]
 [0. 1. 1.]
 [0. 0. 1.]]


tasks:
0) Replace sin_neg and sin_pos by sin
1) transform the lists of lists of coordinates into a numpy  array of sparse matrices
2) Extract the correct amount of alphas from the total sequence
3) at each step, make  the right modification of the sparse matrices to get the correct masks
4) transform a mask into by adding an extra  dimension to your array to have a array of matrices for each m
5) modify the multiplications

In [29]:
f = 5
band_number = 1
from scipy import sparse

mask_cos_exp_pos_coordinates =[[] for _ in range(f-band_number)]
for i in range(f-band_number):
  for j in range(i+1):
      mask_cos_exp_pos_coordinates[i].append([i+band_number-1,j+band_number-1])

mask_cos_exp_neg_coordinates = [[] for _ in range(f-band_number)]
for i in range(1,f-band_number + 1):
  for j in range(1,i+1):
      mask_cos_exp_neg_coordinates[i-1].append([i+band_number-1,j+band_number-1])

mask_sin_coordinates = [[[],[]] for _ in range(f-band_number)]
for k in range(f-band_number+1):
  #first add the terms on the superdiagonal (-sin)
  if k != f-band_number:
    mask_sin_coordinates[k][0].append(k+band_number-1)
    mask_sin_coordinates[k][1].append(k+1+band_number-1)
  #then add the remaining (+sin) terms
  for i in range(k,f-band_number+1):
    for j in range(k):
      mask_sin_coordinates[k-1][0].append(i+band_number-1)
      mask_sin_coordinates[k-1][1].append(j+band_number-1)

def make_mask_matrices (mask_coordinates):
  f = len(mask_coordinates) + band_number
  print(f)
  mask_matrices = np.empty((f-band_number), object)
  for i in range(f-band_number):
    coordinates = mask_coordinates[i]
    row_coordinates, col_coordinates = coordinates[0],coordinates[1]
    num_coordinates = len(row_coordinates)
    data = jnp.ones(num_coordinates)
    mask_matrices[i] = sparse.coo_matrix((data, (row_coordinates,col_coordinates)), shape=(f, f)).toarray()
  return mask_matrices

#mask_cos_exp_pos_matrices = make_mask_matrices(mask_cos_exp_pos_coordinates)
#mask_cos_exp_neg_matrices = make_mask_matrices(mask_cos_exp_neg_coordinates)
mask_sin_matrices = make_mask_matrices(mask_sin_coordinates)

print(jnp.tile(mask_sin_matrices,(2,1,1,1)))
#mask_cos_exp_pos_matrices.dtype

5


TypeError: Value '[[[[[[[[array([[0., 1., 0., 0., 0.],
               [1., 0., 0., 0., 0.],
               [1., 0., 0., 0., 0.],
               [1., 0., 0., 0., 0.],
               [1., 0., 0., 0., 0.]], dtype=float32)
        array([[0., 0., 0., 0., 0.],
               [0., 0., 1., 0., 0.],
               [1., 1., 0., 0., 0.],
               [1., 1., 0., 0., 0.],
               [1., 1., 0., 0., 0.]], dtype=float32)
        array([[0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0.],
               [0., 0., 0., 1., 0.],
               [1., 1., 1., 0., 0.],
               [1., 1., 1., 0., 0.]], dtype=float32)
        array([[0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 0.],
               [0., 0., 0., 0., 1.],
               [1., 1., 1., 1., 0.]], dtype=float32)]]]]]]]]' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

In [196]:
a = jnp.array([[[1,0],[1,0]],[[0,1],[0,1]],[[0,2],[0,2]]] )
b = jnp.tile(a, (3,1,1,1))
values = jnp.array([[1,2,3], [3,4,4],[5,6,5]])
values_b = values[:, :, jnp.newaxis,jnp.newaxis]

c = b* values_b

d = jnp.prod(c, axis =1)
c = jnp.array([[1,0],[1,0]])
d = jnp.tile(c,(3,1,1))
#c = b*values_b
print(d)



[[[1 0]
  [1 0]]

 [[1 0]
  [1 0]]

 [[1 0]
  [1 0]]]


In [206]:
f = 4 ;m = 1; n = 2
alphas = np.random.uniform(low=0, high=4*np.pi, size=(m, 2*n*(f-n)-n))
betas = np.random.uniform(low=0, high=np.pi/2, size=(m, 2*n*(f-n)-n))
make_unitary(alphas,betas,f, n, m)

4
4
4


TypeError: Value '[[[[[[[[BCOO(float32[4, 4], nse=1) BCOO(float32[4, 4], nse=2)
        BCOO(float32[4, 4], nse=3)]]]]]]]]' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

In [37]:
from jax.experimental import sparse
def make_masks(f:int, band_number: int):
  """Create the original masks from which a band of the unitary will be created.
  Precisely, it creates 3 masks which are np.arrays with  f-band_number sparse matrices of shape (f,f).
  Each of the 3 masks is associated with one of the main "building blocks"
  of the band matrices: the exp(ialpha)cos(beta) terms, the exp(-ialpha)cos(beta)
  an sin(beta) terms. The ith line of a given mask is a sparse matrix with
  1s where the ith term appears in the band matrix
  Note that, reading bands from left to right, bands progressively have fewer parameters

  Args:
    f: the number of particles
    band_number: the number of the band in the unitary for which we are computing the mask
  Returns:
    the 3 masks
  """
  mask_cos_exp_pos_coordinates =[[] for _ in range(f-band_number)]
  for i in range(f-band_number):
    for j in range(i+1):
        mask_cos_exp_pos_coordinates[i].append([i+band_number-1,j+band_number-1])

  mask_cos_exp_neg_coordinates = [[] for _ in range(f-band_number)]
  for i in range(1,f-band_number + 1):
    for j in range(1,i+1):
        mask_cos_exp_neg_coordinates[i-1].append([i+band_number-1,j+band_number-1])

  mask_sin_coordinates = [[] for _ in range(f-band_number)]
  for k in range(f-band_number+1):
    #first add the terms on the superdiagonal (-sin)
    if k != f-band_number:
      mask_sin_coordinates[k].append([k+band_number-1,k+1+band_number-1])
    #then add the remaining (+sin) terms
    for i in range(k,f-band_number+1):
      for j in range(k):
        mask_sin_coordinates[k-1].append([i+band_number-1,j+band_number-1])

  def make_mask_matrices (mask_coordinates):
    f = len(mask_coordinates) + band_number
    print(f)
    mask_matrices = np.empty((f-band_number), object)
    for i in range(f-band_number):
      coordinates = mask_coordinates[i]
      num_coordinates = len(coordinates)
      data = jnp.ones(num_coordinates)
      mask_matrices[i] = sparse.BCOO((data, jnp.array(coordinates)), shape=(f, f))
    return mask_matrices

  mask_cos_exp_pos_matrices = make_mask_matrices(mask_cos_exp_pos_coordinates)
  mask_cos_exp_neg_matrices = make_mask_matrices(mask_cos_exp_neg_coordinates)
  mask_sin_matrices = make_mask_matrices(mask_sin_coordinates)

  return mask_cos_exp_pos_matrices,mask_cos_exp_neg_matrices,mask_sin_matrices
print(make_masks(5,1)[1][2])

5
5
5
BCOO(float32[5, 5], nse=3)


In [3]:
import jax.numpy as jnp

# Example arrays
x = jnp.array([1, 2, 3, 4, 5])
y = jnp.array([10, 20, 30, 40, 50])

# Condition
condition = x > 3

# Create new array based on condition
result = jnp.where(condition, x, y)
print(result)


[10 20 30  4  5]


In [23]:
#mask_exp_pos
term_index = jnp.arange(f-band_number)[:,jnp.newaxis, jnp.newaxis]
row_index = jnp.arange(f)[jnp.newaxis,:, jnp.newaxis]
col_index = jnp.arange(f)[jnp.newaxis,jnp.newaxis,:]

condition = (row_index == term_index + band_number-1) & (band_number-1 <= col_index) & (col_index <= row_index)

mask_exp_pos = jnp.where(condition, 1, 0)


[[[0 0 0 0]
  [0 1 0 0]
  [0 0 0 0]
  [0 0 0 0]]

 [[0 0 0 0]
  [0 0 0 0]
  [0 1 1 0]
  [0 0 0 0]]]


Array([[[0, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]],

       [[0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 1, 1, 0],
        [0, 0, 0, 0]]], dtype=int32, weak_type=True)

In [47]:
#mask_sin
f = 4
band_number = 2

term_index = jnp.arange(f-band_number)[:,jnp.newaxis, jnp.newaxis]
row_index = jnp.arange(f)[jnp.newaxis,:, jnp.newaxis]
col_index = jnp.arange(f)[jnp.newaxis,jnp.newaxis,:]

condition = ((band_number-1 <= col_index) & (col_index <= term_index + band_number-1) &
 (row_index >= term_index + band_number)|
((row_index == term_index + band_number-1) & (col_index == row_index +1))
)

mask_sing = jnp.where(condition, 1, 0)



[[[0 0 0 0]
  [0 0 1 0]
  [0 1 0 0]
  [0 1 0 0]]

 [[0 0 0 0]
  [0 0 0 0]
  [0 0 0 1]
  [0 1 1 0]]]


In [104]:
#masks
def make_masks(f,band_number):

  """build

  Args:
    f: the number of particles
    n_col: the number of columns of the masks we want to build. To build a full matrix we will want n_col = f
    band_number: index of the band in the unitary decomposition (ranges from 1 to number of nonzero eigenvalues)
  Returns:
  3 arrays of masks of shape (f-band_number,f,n_col)
  (m,2n(f-n)-n)
  """

  term_index = jnp.arange(f-band_number)[:,jnp.newaxis, jnp.newaxis]
  row_index = jnp.arange(f)[jnp.newaxis,:, jnp.newaxis]
  col_index = jnp.arange(f)[jnp.newaxis,jnp.newaxis,:]

  #conditions
  condition_cos_exp_pos = (row_index == term_index + band_number-1) & (band_number-1 <= col_index) & (col_index <= row_index)
  condition_cos_exp_neg = (col_index == term_index + band_number) & (row_index >= col_index)
  condition_sin = ((band_number-1 <= col_index) & (col_index <= term_index + band_number-1) & (row_index >= term_index + band_number)|
((row_index == term_index + band_number-1) & (col_index == row_index +1)))

  mask_cos_exp_pos = jnp.where(condition_cos_exp_pos, 1, 0)
  mask_cos_exp_neg = jnp.where(condition_cos_exp_neg, 1, 0)
  mask_sin = jnp.where(condition_sin, 1, 0)

  return mask_cos_exp_pos, mask_cos_exp_neg,mask_sin


Array([[[0, 1, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 0, 0, 0]],

       [[0, 0, 0, 0],
        [0, 0, 1, 0],
        [1, 1, 0, 0],
        [1, 1, 0, 0]],

       [[0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 1],
        [1, 1, 1, 0]]], dtype=int32, weak_type=True)

In [62]:
def get_building_blocks(alphas:jnp.ndarray, betas:jnp.ndarray):
  """Get the spacetime points from the current parameters.

  Args:
    alphas: Values of the alpha angles used to define the unitary (m,2n(f-n)-n): real
    betas: Values of the beta angles used to define the unitary (m,2n(f-n)-n): real
  Returns:
  3 "building blocks" terms, basic trigonometric functions of the alphas and betas of the same size
  (m,2n(f-n)-n)
  """
  exp_alphas = jnp.exp(1J*alphas)
  cos_betas = jnp.cos(betas)
  sin_betas = jnp.sin(betas)
  cos_betas_exp_pos_alphas = exp_alphas*cos_betas
  cos_betas_exp_neg_alphas = jnp.conj(exp_alphas)* cos_betas

  return cos_betas_exp_pos_alphas, cos_betas_exp_neg_alphas, sin_betas


In [106]:
def make_single_band_unitary(alphas_band,betas_band,f):
  """Use the angle parameters and the masks to generate a unitary band matrix

    Args:
    f: dimension of the matrix (number of  particles)
    alphas_band: shape (f-band_number,)
    betas_band: shape (f-band_number,)

  Returns:
    (f, f) band  unitary matrix
    """
  band_number = f - len(alphas_band)

  #extract the building blocks and masks
  building_blocks = get_building_blocks(alphas_band, betas_band)
  masks = make_masks(f,band_number)

  #initialise the band matrix with a matrix with ones on lower triangle and superdiagonal
  ones_tril= jnp.tril(jnp.ones((f,f))) + jnp.eye(f,k = 1)
  band_matrix= ones_tril.copy()

  #iterate over the different masks for each building block
  num_masks = len(masks)
  for building_block_index in range(num_masks):
    mask = masks[building_block_index] #shape (f-band_number,f,f)
    building_block = building_blocks[building_block_index]
    band_matrix_building_block = mask*building_block[:,jnp.newaxis,jnp.newaxis]

    #add ones in the lower triangle and superdiagonal before multiplying  the matrices
    ones_tril_reshaped = jnp.tile(ones_tril,(f-band_number,1,1))
    band_matrix_building_block += ones_tril_reshaped - mask

    #multiply the matrices together
    band_matrix_building_block = jnp.prod(band_matrix_building_block, axis =0)

    #multiply the different  building blocks together
    band_matrix *=band_matrix_building_block

  #multiply the superdiagonal by -1 to represent  the negative  sines.
  final_mask = jnp.tril(jnp.ones((f,f))) - jnp.eye(f,k = 1)
  band_matrix *= final_mask #shape (f,f)

  # Add the diagonal 1s for the first band_number-1 bands
  diag_indices = jnp.arange(band_number-1)  # indices for the diagonals to set to 1
  band_matrix = band_matrix.at[diag_indices, diag_indices].set(1)

  return band_matrix

  make_band_unitary = vmap(make_single_band_unitary, in_axes=(0, 0, None))

In [108]:
f = 3
c = [jnp.pi,jnp.pi/2,2*jnp.pi]
d = [jnp.pi/4, jnp.pi/6, jnp.pi/3]

make_band_unitary = vmap(make_single_band_unitary, in_axes=(0, 0, None))
alphas = jnp.array([[jnp.pi,jnp.pi/2],[jnp.pi,jnp.pi/2]])
betas = jnp.array([[0,0],[0,0]])

alphas = jnp.array([[jnp.pi,jnp.pi/2]] )
betas = jnp.array([[0,0]])

print(jnp.int_(make_band_unitary(alphas,betas,f))+1J*jnp.int_((-1J)*make_band_unitary(alphas,betas,f)))



[[[-1.+0.j  0.+0.j  0.+0.j]
  [ 0.+0.j  0.-1.j  0.+0.j]
  [ 0.+0.j  0.+0.j  0.-1.j]]]


  return asarray(x, dtype=self.dtype)
  out_array: Array = lax_internal._convert_element_type(


In [125]:
def make_single_unitary(alphas,betas,f, n):
  """Use the angle parameters to build a unitary matrix through
  the matrix product of  successive band unitary band matrices
  Note: This is not any unitary but a unitary corresponding to the change of basis into
  the eigenvector basis of a matrix with n non zero eigenvalues counted with their multiplicity

    Args:
    f: dimension of the matrix (number of  particles)
    n: total number of eigenvalues
    alphas: shape (2n(f-n)-n,) WRONG BECAUSE DIFF DEF OF N
    betas: shape (2n(f-n)-n,) SAME HERE

  Returns:
    (f, f) unitary matrix
    """
  #extract all the parameters
  building_blocks = get_building_blocks(alphas, betas)

  #initialise the unitary
  unitary = jnp.eye(f)

  #iterate over the bands
  end_index = 0
  for band_number in range(1,n+1):
    #extract the correct alphas and betas
    start_index = end_index
    end_index =start_index + f - band_number
    alphas_band, betas_band = alphas[start_index:end_index],betas[start_index:end_index]


    #build the band unitary
    band_matrix = make_single_band_unitary(alphas_band,betas_band,f)

    #multiply the different unitaries together
    unitary = jnp.dot(unitary,band_matrix)

  return unitary

  #vectorize
  make_unitary = vmap(make_single_unitary, in_axes=(0, 0, None, None))

In [126]:
f = 4
n = 2
alphas = jnp.array([jnp.pi,jnp.pi/2,3/2*jnp.pi, 2*jnp.pi, jnp.pi])
betas = jnp.array([0,0,0,0,0])

print(alphas)

print(jnp.int_(make_single_unitary(alphas,betas,f,n))+1J*jnp.int_((-1J)*make_single_unitary(alphas,betas,f,n)))


[3.1415927 1.5707964 4.712389  6.2831855 3.1415927]
[3.1415927 1.5707964 4.712389 ]
[6.2831855 3.1415927]
[3.1415927 1.5707964 4.712389 ]
[6.2831855 3.1415927]
[[-1.+0.j  1.+0.j  0.+0.j  0.+0.j]
 [ 0.-1.j  0.-1.j  0.+0.j  0.+0.j]
 [-1.+0.j  0.+0.j  1.+0.j  0.+0.j]
 [ 0.+1.j  0.+0.j  0.+0.j  0.-1.j]]


In [127]:
#adapt for eigenvector
#masks

def make_masks(f,n_col, band_number):

  """build

  Args:
    f: the number of particles
    n_col: the number of columns of the masks we want to build. To build a full matrix we will want n_col = f
    band_number: index of the band in the unitary decomposition (ranges from 1 to number of nonzero eigenvalues)
  Returns:
  3 arrays of masks of shape (f-band_number,f,n_col)
  """

  term_index = jnp.arange(f-band_number)[:,jnp.newaxis, jnp.newaxis]
  row_index = jnp.arange(f)[jnp.newaxis,:, jnp.newaxis]
  col_index = jnp.arange(n_col)[jnp.newaxis,jnp.newaxis,:]

  #conditions
  condition_cos_exp_pos = (row_index == term_index + band_number-1) & (band_number-1 <= col_index) & (col_index <= row_index)
  condition_cos_exp_neg = (col_index == term_index + band_number) & (row_index >= col_index)
  condition_sin = ((band_number-1 <= col_index) & (col_index <= term_index + band_number-1) & (row_index >= term_index + band_number)|
((row_index == term_index + band_number-1) & (col_index == row_index +1)))

  mask_cos_exp_pos = jnp.where(condition_cos_exp_pos, 1, 0)
  mask_cos_exp_neg = jnp.where(condition_cos_exp_neg, 1, 0)
  mask_sin = jnp.where(condition_sin, 1, 0)

  return mask_cos_exp_pos, mask_cos_exp_neg,mask_sin


In [130]:
print(make_masks(4,2,1)[0])

[[[1 0]
  [0 0]
  [0 0]
  [0 0]]

 [[0 0]
  [1 1]
  [0 0]
  [0 0]]

 [[0 0]
  [0 0]
  [1 1]
  [0 0]]]


In [135]:
def make_single_band_unitary(alphas_band,betas_band,f,n_col):
  """Use the angle parameters and the masks to generate a unitary band matrix

    Args:

    alphas_band: shape (f-band_number,)
    betas_band: shape (f-band_number,)
    f: dimension of the matrix (number of  particles)
    n_col: number of columns of the unitary which we want to build

  Returns:
    (f, n_col) band  unitary matrix
    """
  band_number = f - len(alphas_band)

  #extract the building blocks and masks
  building_blocks = get_building_blocks(alphas_band, betas_band)
  masks = make_masks(f,n_col,band_number)

  #initialise the band matrix with a matrix with ones on lower triangle and superdiagonal
  ones_tril= jnp.tril(jnp.ones((f,f))) + jnp.eye(f,k = 1)
  ones_tril = ones_tril[:,:n_col]
  band_matrix= ones_tril.copy()

  #iterate over the different masks for each building block
  num_masks = len(masks)
  for building_block_index in range(num_masks):
    mask = masks[building_block_index] #shape (f-band_number,f,f)
    building_block = building_blocks[building_block_index]
    band_matrix_building_block = mask*building_block[:,jnp.newaxis,jnp.newaxis]

    #add ones in the lower triangle and superdiagonal before multiplying  the matrices
    ones_tril_reshaped = jnp.tile(ones_tril,(f-band_number,1,1))
    band_matrix_building_block += ones_tril_reshaped - mask

    #multiply the matrices together
    band_matrix_building_block = jnp.prod(band_matrix_building_block, axis =0)

    #multiply the different  building blocks together
    band_matrix *=band_matrix_building_block

  #multiply the superdiagonal by -1 to represent  the negative  sines.
  final_mask = jnp.tril(jnp.ones((f,f))) - jnp.eye(f,k = 1)
  band_matrix *= final_mask[:,:n_col] #shape (f,n_col)

  # Add the diagonal 1s for the first band_number-1 bands
  diag_indices = jnp.arange(band_number-1)  # indices for the diagonals to set to 1
  band_matrix = band_matrix.at[diag_indices, diag_indices].set(1)

  return band_matrix

  make_band_unitary = vmap(make_single_band_unitary, in_axes=(0, 0, None,None))

In [137]:
f = 4
n_col = 2
alphas = jnp.array([jnp.pi,jnp.pi/2,3/2*jnp.pi])
betas = jnp.array([0,0,0])

print(alphas)

print(jnp.int_(make_single_band_unitary(alphas,betas,f,n_col))+1J*jnp.int_((-1J)*make_single_band_unitary(alphas,betas,f,n_col)))

[3.1415927 1.5707964 4.712389 ]
[[-1.+0.j  0.+0.j]
 [ 0.+0.j  0.-1.j]
 [ 0.+0.j  0.+0.j]
 [ 0.+0.j  0.+0.j]]


  return asarray(x, dtype=self.dtype)
  out_array: Array = lax_internal._convert_element_type(


In [129]:
def make_single_eigenvectors(alphas,betas,f, n):
  """Use the angle parameters to build the 2n first eigenvectors of a spacetime point  x

    Args:
    f: dimension of the matrix (number of  particles)
    n: spin number hence  2n is total number of eigenvalues
    alphas: shape (2n(f-n)-n)
    betas: shape (2n(f-n)-n)

  Returns:
    (f,2n) matrix
    """
  #extract all the parameters
  num_alphas = len(alphas)
  building_blocks = get_building_blocks(alphas, betas)

  #initialise the eigenvectors:
  start_index = len(alphas)-(f-2*n)
  end_index = len(alphas)
  alphas_band, betas_band = alphas[start_index:end_index],betas[start_index:end_index]

  eigenvectors = make_single_band_unitary(alphas_band,betas_band,f,n_col = 2*n)

  #iterate over the remaining 2n-1 bands
  for band_number in range(1,2*n,-1):
    #extract the correct alphas and betas
    start_index = end_index
    end_index =start_index + f - band_number
    alphas_band, betas_band = alphas[start_index:end_index],betas[start_index:end_index]

    #build the band unitary
    band_matrix = make_single_band_unitary(alphas_band,betas_band,f,n_col = f)

    #multiply the unitaries and the vectors
    eigenvectors = jnp.dot(band_matrix,eigenvectors)

  return eigenvectors

  #vectorize
  make_eigenvectors = vmap(make_single_eigenvectors, in_axes=(0, 0, None,None))

In [138]:
f = 4

alphas = jnp.array([jnp.pi,jnp.pi/2,3/2*jnp.pi,2*jnp.pi,jnp.pi])
betas = jnp.array([0,0,0,0,0])


print(jnp.int_(make_single_eigenvectors(alphas,betas,f,n=1))+1J*jnp.int_((-1J)*make_single_eigenvectors(alphas,betas,f,n = 1)))

[3.1415927 1.5707964 4.712389  6.2831855 3.1415927]
[[ 1.+0.j -1.+0.j]
 [ 1.+0.j  1.+0.j]
 [ 1.+0.j  0.+0.j]
 [ 1.+0.j  0.+0.j]]


  return asarray(x, dtype=self.dtype)
  out_array: Array = lax_internal._convert_element_type(
