In [1]:
# speedy.ipynb
# Authors: Stephan Meighen-Berger
# Testing different rotation implementations

<a href="https://colab.research.google.com/github/mjg-phys/cdm-computing-subgroup/blob/main/advancedPythonTutorial/nuisance/notebooks/speedy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


In [None]:
from functools import partial
import timeit

import numpy as np
from numba import njit

import jax
import jax.numpy as jnp

In [2]:
# Mixing parameters, one row per rotation: [i, j, theta_ij in degrees] with an
# optional fourth entry delta_ij (CP phase, degrees). Each angle is obtained from
# a sin^2(theta_ij) value via arcsin(sqrt(.)), converted from radians to degrees.
# dtype=object because the rows have different lengths (the 1-3 row has a phase).
mixing_angles = np.array([
    [1, 2, np.arcsin(np.sqrt(0.312)) / np.pi * 180.],
    [1, 3, np.arcsin(np.sqrt(0.025)) / np.pi * 180., 0],  # CP phase delta_13 in degrees; set non-zero for CP violation
    [2, 3, np.arcsin(np.sqrt(0.420)) / np.pi * 180.],
], dtype=object)

In [3]:
def rotmatrix_inefficient(dim: int, i: int, j: int, ang: float, cp: float):
    """ constructs a (Gell-Mann) rotational matrix ij with angle ang and cp violating phase cp,
    filling every element in a pure-Python double loop (deliberately slow baseline)
    for symmetry i < j is required

    Parameters
    ----------
    dim: int
        Dimensions of the matrix
    i, j: int
        Positions of the matrix (0-based)
    ang: float
        Angle of rotation in radians
    cp: float
        CP violating phase in radians

    Returns
    -------
    rotmat: numpy.ndarray
        The complex128 rotation matrix of shape (dim, dim)
    """
    # Off-diagonal entries of the (i, j) plane; the phase factor is only applied when cp != 0
    if cp == 0:
        upper = np.sin(ang)
        lower = -np.sin(ang)
    else:
        upper = np.sin(ang) * np.exp(-1j * cp)
        lower = -np.sin(ang) * np.exp(1j * cp)

    R = np.eye(dim, dtype=np.complex128)
    # Element-by-element on purpose: this is the slow reference for the timings below
    for row in range(dim):
        for col in range(dim):
            if row == i and col == j:
                R[row, col] = upper
            elif row == j and col == i:
                R[row, col] = lower
            elif row == col == i or row == col == j:
                R[row, col] = np.cos(ang)
            elif row == col:
                # Axes outside the (i, j) plane must stay at 1 (identity);
                # zeroing them makes the rotation singular.
                R[row, col] = 1.0
            else:
                R[row, col] = 0.0
    return R

def buildmixingmatrix_inefficient(params: np.ndarray) -> np.ndarray:
    """ builds the mixing matrix from rotation parameters using the loop-based
    rotmatrix_inefficient; symmetry requires i < j in every entry

    Each entry is [i, j, theta_ij] (1-based indices, angle in degrees), or
    [i, j, theta_ij, delta_ij] with a CP-violating phase delta_ij in degrees.

    Parameters
    ----------
    params: sequence of sequences
        Rotation descriptions, applied in the given order.

    Returns
    -------
    mixing_matrix: numpy.ndarray
        The product of the rotations in reverse order, e.g.
        params = [(1,2,33.89),(1,3,9.12),(2,3,45.00)]
        => U = R_23 . R_13 . R_12
    """
    # The matrix size is set by the largest second index
    dim = max(par[1] for par in params)
    mixing = np.eye(dim)
    for par in params:
        row, col = par[0] - 1, par[1] - 1
        theta = np.deg2rad(par[2])
        delta = np.deg2rad(par[3]) if len(par) > 3 else 0
        # Left-multiply so later rotations end up on the left
        mixing = np.dot(rotmatrix_inefficient(dim, row, col, theta, delta), mixing)
    return mixing

In [4]:
# basis transform
def rotmatrix(dim: int, i: int, j: int, ang: float, cp: float):
    """ constructs the (Gell-Mann) rotation matrix in the (i, j) plane with
    angle ang and CP-violating phase cp; symmetry requires i < j

    Parameters
    ----------
    dim: int
        Dimensions of the matrix
    i, j: int
        Positions of the matrix (0-based)
    ang: float
        Angle of rotation in radians
    cp: float
        CP violating phase in radians

    Returns
    -------
    rotmat: numpy.ndarray
        The rotation matrix; real (float64) when cp == 0, complex128 otherwise
    """
    cos_a, sin_a = np.cos(ang), np.sin(ang)
    if cp != 0:
        rot = np.eye(dim, dtype='complex128')
        off_upper = sin_a * np.exp(-1j * cp)
        off_lower = -sin_a * np.exp(1j * cp)
    else:
        # Stay real-valued when there is no phase
        rot = np.eye(dim)
        off_upper, off_lower = sin_a, -sin_a
    rot[i, j] = off_upper
    rot[j, i] = off_lower
    rot[i, i] = cos_a
    rot[j, j] = cos_a
    return rot

def buildmixingmatrix(params: np.ndarray) -> np.ndarray:
    """ builds the mixing matrix from rotation parameters using rotmatrix;
    symmetry requires i < j in every entry

    Each entry is [i, j, theta_ij] (1-based indices, angle in degrees), or
    [i, j, theta_ij, delta_ij] with a CP-violating phase delta_ij in degrees.

    Parameters
    ----------
    params: sequence of sequences
        Rotation descriptions, applied in the given order.

    Returns
    -------
    mixing_matrix: numpy.ndarray
        The product of the rotations in reverse order, e.g.
        params = [(1,2,33.89),(1,3,9.12),(2,3,45.00)]
        => U = R_23 . R_13 . R_12
    """
    dim = max(par[1] for par in params)
    result = np.eye(dim)
    for par in params:
        phase = np.deg2rad(par[3]) if len(par) > 3 else 0
        step = rotmatrix(dim, par[0] - 1, par[1] - 1, np.deg2rad(par[2]), phase)
        result = np.dot(step, result)
    return result

In [5]:
@njit
def rotmatrix_jit(dim: int, i: int, j: int, ang: float, cp: float) -> np.ndarray:
    """ constructs a (Gell-Mann) rotation matrix in the (i, j) plane with angle ang
    and CP-violating phase cp, compiled with numba's njit; symmetry requires i < j

    Parameters
    ----------
    dim: int
        Dimensions of the matrix
    i, j: int
        Positions of the matrix (0-based)
    ang: float
        Angle of rotation in radians
    cp: float
        CP violating phase in radians

    Returns
    -------
    rotmat: numpy.ndarray
        The complex128 rotation matrix of shape (dim, dim)
    """
    sin_a = np.sin(ang)
    cos_a = np.cos(ang)
    # Always complex so one compiled signature covers both cp == 0 and cp != 0
    rot = np.eye(int(dim), dtype=np.complex128)
    rot[i, j] = sin_a * np.exp(-1j * cp)
    rot[j, i] = -sin_a * np.exp(1j * cp)
    rot[i, i] = cos_a
    rot[j, j] = cos_a
    return rot

def buildmixingmatrix_jit(params: np.ndarray) -> np.ndarray:
    """ builds the mixing matrix from rotation parameters using the numba-compiled
    rotmatrix_jit; symmetry requires i < j in every entry

    Each entry is [i, j, theta_ij] (1-based indices, angle in degrees), or
    [i, j, theta_ij, delta_ij] with a CP-violating phase delta_ij in degrees.

    Parameters
    ----------
    params: sequence of sequences
        Rotation descriptions, applied in the given order.

    Returns
    -------
    mixing_matrix: numpy.ndarray
        The product of the rotations in reverse order, e.g.
        params = [(1,2,33.89),(1,3,9.12),(2,3,45.00)]
        => U = R_23 . R_13 . R_12
    """
    dim = max(par[1] for par in params)
    mixing = np.eye(dim)
    for par in params:
        has_phase = len(par) > 3
        delta = np.deg2rad(par[3]) if has_phase else 0
        rotation = rotmatrix_jit(dim, par[0] - 1, par[1] - 1, np.deg2rad(par[2]), delta)
        mixing = np.dot(rotation, mixing)
    return mixing

In [6]:
@partial(jax.jit, static_argnames=("dim",))
def rotmatrix_jax(i: int, j: int, ang: float, cp: float, dim: int = 3) -> jnp.ndarray:
    """ constructs a (Gell-Mann) rotational matrix ij with angle ang
    and cp violating phase cp for symmetry i < j is required

    Parameters
    ----------
    i, j: int
        Positions of the matrix (0-based); these may be traced values
    ang: float
        Angle of rotation in radians
    cp: float
        CP violating phase in radians
    dim: int, optional
        Dimensions of the matrix (default 3). It is a static argument of
        jax.jit because array shapes must be known at trace time, so each
        distinct value triggers a recompilation.

    Returns
    -------
    rotmat: jax.numpy.ndarray
        The complex64 rotation matrix of shape (dim, dim)
    """
    R = jnp.eye(dim, dtype=jnp.complex64)
    # JAX arrays are immutable: .at[...].set(...) returns a new array, so the
    # result has to be reassigned to R or the update is lost.
    R = R.at[i, j].set(jnp.sin(ang) * jnp.exp(-1j * cp))
    R = R.at[j, i].set(-jnp.sin(ang) * jnp.exp(1j * cp))
    R = R.at[i, i].set(jnp.cos(ang))
    R = R.at[j, j].set(jnp.cos(ang))
    return R

def buildmixingmatrix_jax(params: np.ndarray) -> np.ndarray:
    """ builds the 3x3 mixing matrix from rotation parameters using the jitted
    rotmatrix_jax; symmetry requires i < j in every entry

    Each entry is [i, j, theta_ij] (1-based indices, angle in degrees), or
    [i, j, theta_ij, delta_ij] with a CP-violating phase delta_ij in degrees.
    Only 3-flavour mixing is handled: the accumulator is a 3x3 identity and
    rotmatrix_jax is called without a dimension argument.

    Parameters
    ----------
    params: sequence of sequences
        Rotation descriptions, applied in the given order.

    Returns
    -------
    mixing_matrix: jax.numpy.ndarray
        The product of the rotations in reverse order, e.g.
        params = [(1,2,33.89),(1,3,9.12),(2,3,45.00)]
        => U = R_23 . R_13 . R_12
    """
    dim = 3
    mixing = jnp.eye(dim)
    for par in params:
        delta = np.deg2rad(par[3]) if len(par) > 3 else 0
        rotation = rotmatrix_jax(par[0] - 1, par[1] - 1, np.deg2rad(par[2]), delta)
        mixing = jnp.dot(rotation, mixing)
    return mixing

In [7]:
# Author's note: a ChatGPT rewrite of this loop-based version was wrong.
def testing_inefficient():
    """Benchmark target: build the mixing matrix with the loop-based implementation."""
    _ = buildmixingmatrix_inefficient(mixing_angles)

testing_inefficient()

In [8]:
# Total seconds for one million calls, then for a single call
for n_calls in (1_000_000, 1):
    print(timeit.timeit("testing_inefficient()", globals=locals(), number=n_calls))

22.070774467999968
4.8891000005824026e-05


In [9]:
def testing():
    """Benchmark target: build the mixing matrix with the direct-indexing NumPy version."""
    _ = buildmixingmatrix(mixing_angles)

testing()

In [10]:
# Total seconds for one million calls, then for a single call
for n_calls in (1_000_000, 1):
    print(timeit.timeit("testing()", globals=locals(), number=n_calls))

14.541723359999992
4.708500000560889e-05


In [11]:
# Call once up front so numba compiles rotmatrix_jit before anything is timed
def testing_jit():
    """Benchmark target: build the mixing matrix with the numba-compiled version."""
    _ = buildmixingmatrix_jit(mixing_angles)

testing_jit()

In [12]:
# Total seconds for one million calls, then for a single call
for n_calls in (1_000_000, 1):
    print(timeit.timeit("testing_jit()", globals=locals(), number=n_calls))

9.660977402999947
3.727099999650818e-05


In [13]:
# Call once up front so jax.jit traces and compiles rotmatrix_jax before timing.
# Author's note: a ChatGPT rewrite of this was wrong.
def testing_jax():
    """Benchmark target: build the mixing matrix with the JAX version and wait for it.

    JAX dispatches work asynchronously. Without block_until_ready() a timer
    would only measure how long it takes to enqueue the work, not to compute it.
    """
    buildmixingmatrix_jax(mixing_angles).block_until_ready()

testing_jax()

In [14]:
# Seconds for a single (already compiled) call; displayed as the cell output
jax_single_call_seconds = timeit.timeit("testing_jax()", globals=locals(), number=1)
jax_single_call_seconds

0.0012558700000226963

In [15]:
# Got this from ChatGPT; it doesn't work!
# jax.ops.index_update / jax.ops.index were removed from JAX a while ago
# (the replacement is R.at[idx].set(value)), and in a jitted function the dimensions need to be static
# 
# import jax
# import jax.numpy as jnp
# 
# @jax.jit
# def rotmatrix_jit_optimized(dim: int, i: int, j: int, ang: float, cp: float) -> jnp.ndarray:
#     """ constructs a (Gell-Mann) rotational matrix ij with angle ang
#     and cp violating phase cp for symmetry i < j is required
# 
#     Parameters
#     ----------
#     dim: int
#         Dimensions of the matrix
#     i, j: int
#         Positions of the matrix
#     ang: float
#         Angle of rotation in radians
#     cp: float
#         CP violating phase
# 
#     Returns
#     -------
#     rotmat: jax.numpy.array
#         The rotation matrix
# 
#     Raises
#     ------
#     DimensionError:
#         The input dimensions dim, i, j are wrong
#     """
#     # Building
#     R = jnp.eye(dim, dtype=jnp.complex128)
#     R = jax.ops.index_update(R, jax.ops.index[i, j], jnp.sin(ang) * jnp.exp(-1j * cp))
#     R = jax.ops.index_update(R, jax.ops.index[j, i], -jnp.sin(ang) * jnp.exp(1j * cp))
#     R = jax.ops.index_update(R, jax.ops.index[i, i], jnp.cos(ang))
#     R = jax.ops.index_update(R, jax.ops.index[j, j], jnp.cos(ang))
#     return R