# Verify spin transformations in Jim against LALSimulation

In [65]:
import sys
from pathlib import Path

from lal import MSUN_SI
from lalsimulation import \
    SimInspiralTransformPrecessingNewInitialConditions, \
    SimInspiralTransformPrecessingWvf2PE


## Define Jim functions in-place

In [72]:
## Little hack to get past Jax and Jim: run the Jim functions below with
## plain NumPy standing in for jax.numpy, and LAL's MTSUN_SI standing in
## for Jim's MTSUN constant.
import numpy as np
from lal import MTSUN_SI

jnp = np
MTSUN = MTSUN_SI

# `Float` is only used in the type annotations of the functions below.
# Fall back to the builtin float if it has not been imported elsewhere
# (e.g. from jaxtyping), so the definitions do not raise NameError.
try:
    Float  # noqa: B018
except NameError:
    Float = float

def Mc_q_to_m1_m2(M_c: Float, q: Float) -> tuple[Float, Float]:
    """
    Convert chirp mass and mass ratio into the two component masses.

    Parameters
    ----------
    M_c : Float
        Chirp mass.
    q : Float
        Mass ratio; the secondary mass is returned as ``q * m1``.

    Returns
    -------
    m1 : Float
        Primary mass, in the same units as ``M_c`` (the larger mass when
        ``q <= 1``).
    m2 : Float
        Secondary mass, equal to ``q * m1``.
    """
    one_plus_q = 1 + q
    # Symmetric mass ratio eta = m1 * m2 / (m1 + m2)**2 expressed through q.
    symmetric_ratio = q / one_plus_q ** 2
    # M_c = M_tot * eta**(3/5)  =>  M_tot = M_c / eta**(3/5)
    total_mass = M_c / symmetric_ratio ** (3.0 / 5)
    primary = total_mass / one_plus_q
    return primary, primary * q

def rotate_y(angle, vec):
    """
    Rotate the 3-vector ``vec`` by ``angle`` (radians) about the y-axis.

    The rotation matrix applied is

        [[ cos(angle), 0, sin(angle)],
         [          0, 1,          0],
         [-sin(angle), 0, cos(angle)]]

    and the result is ``matrix @ vec``.
    """
    c, s = jnp.cos(angle), jnp.sin(angle)
    matrix = jnp.array([
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c],
    ])
    return jnp.dot(matrix, vec)


def rotate_z(angle, vec):
    """
    Rotate the 3-vector ``vec`` by ``angle`` (radians) about the z-axis.

    The rotation matrix applied is

        [[cos(angle), -sin(angle), 0],
         [sin(angle),  cos(angle), 0],
         [         0,           0, 1]]

    and the result is ``matrix @ vec``.
    """
    c, s = jnp.cos(angle), jnp.sin(angle)
    matrix = jnp.array([
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1],
    ])
    return jnp.dot(matrix, vec)


def Lmag_2PN(m1, m2, v0):
    """
    Return the orbital angular momentum magnitude to 2PN order.

    Computes ``m1 * m2 / v0 * (1 + v0**2 * (1.5 + eta / 6))`` with
    ``eta = m1 * m2 / (m1 + m2)**2``.
    """
    mass_product = m1 * m2
    eta = mass_product / (m1 + m2) ** 2
    correction = 1.5 + eta / 6.0
    # Newtonian term (m1 + m2)**2 * eta / v0 simplifies to m1 * m2 / v0.
    newtonian = mass_product / v0
    return newtonian * (1.0 + v0 * v0 * correction)


def spin_angles_to_cartesian_spin(
    theta_jn: Float,
    phi_jl: Float,
    tilt_1: Float,
    tilt_2: Float,
    phi_12: Float,
    chi_1: Float,
    chi_2: Float,
    M_c: Float,
    q: Float,
    fRef: Float,
    phiRef: Float,
) -> tuple[Float, Float, Float, Float, Float, Float, Float]:
    """
    Transform spin angles and magnitudes into Cartesian spin components.

    The code is based on the approach used in LALsimulation
    (XLALSimInspiralTransformPrecessingNewInitialConditions):
    https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/group__lalsimulation__inference.html

    Parameters
    ----------
    theta_jn: Float
        Zenith angle between the total angular momentum and the line of sight
    phi_jl: Float
        Azimuthal angle of the orbital angular momentum on its cone about the
        total angular momentum
    tilt_1: Float
        Zenith angle between the spin and orbital angular momenta for the primary object
    tilt_2: Float
        Zenith angle between the spin and orbital angular momenta for the secondary object
    phi_12: Float
        Difference between the azimuthal angles of the individual spin vector projections
        onto the orbital plane
    chi_1: Float
        Dimensionless spin magnitude of the primary object
    chi_2: Float
        Dimensionless spin magnitude of the secondary object
    M_c: Float
        The chirp mass, in solar masses (converted to seconds via MTSUN)
    q: Float
        The mass ratio, as taken by Mc_q_to_m1_m2 (m2 = q * m1)
    fRef: Float
        The reference frequency (Hz)
    phiRef: Float
        Binary phase at the reference frequency

    Returns
    -------
    iota: Float
        Zenith angle between the orbital angular momentum and the line of sight
    S1x: Float
        The x-component of the primary's dimensionless spin vector
    S1y: Float
        The y-component of the primary's dimensionless spin vector
    S1z: Float
        The z-component of the primary's dimensionless spin vector
    S2x: Float
        The x-component of the secondary's dimensionless spin vector
    S2y: Float
        The y-component of the secondary's dimensionless spin vector
    S2z: Float
        The z-component of the secondary's dimensionless spin vector

    The spin components are expressed in the final frame, where the orbital
    angular momentum lies along z (rotations 4-5). Rotation 6 then fixes the
    azimuth using the line-of-sight direction and phiRef. Each spin vector
    has norm chi_1 (resp. chi_2).
    """

    # Starting frame: LNh along the z-axis.
    # s1hat has azimuth phiRef (so it lies in the x-z plane only when
    # phiRef == 0); s2hat has azimuth phi_12 + phiRef.
    LNh = jnp.array([0.0, 0.0, 1.0])

    # Define the spin unit vectors in the LNh frame
    s1hat = jnp.array(
        [
            jnp.sin(tilt_1) * jnp.cos(phiRef),
            jnp.sin(tilt_1) * jnp.sin(phiRef),
            jnp.cos(tilt_1)
        ]
    )
    s2hat = jnp.array(
        [
            jnp.sin(tilt_2) * jnp.cos(phi_12 + phiRef),
            jnp.sin(tilt_2) * jnp.sin(phi_12 + phiRef),
            jnp.cos(tilt_2)
        ]
    )

    m1, m2 = Mc_q_to_m1_m2(M_c, q)
    # PN velocity parameter v0 = (pi * M * fRef)**(1/3), with the total mass
    # converted from solar masses to seconds by MTSUN.
    v0 = jnp.cbrt((m1 + m2) * MTSUN * jnp.pi * fRef)

    # Define S1, S2, and J (spins scaled by m**2 so they share units with Lmag)
    Lmag = Lmag_2PN(m1, m2, v0)
    s1 = m1 * m1 * chi_1 * s1hat
    s2 = m2 * m2 * chi_2 * s2hat
    J = s1 + s2 + jnp.array([0.0, 0.0, Lmag])

    # Normalize J, and find theta0 and phi0 (the angles in starting frame)
    Jhat = J / jnp.linalg.norm(J)
    theta0 = jnp.arccos(Jhat[2])
    phi0 = jnp.arctan2(Jhat[1], Jhat[0])

    # Rotation 1: Rotate about z-axis by -phi0 to put Jhat in the x-z plane.
    # LNh lies on the z-axis, so it is unchanged by this rotation.
    s1hat = rotate_z(-phi0, s1hat)
    s2hat = rotate_z(-phi0, s2hat)

    # Rotation 2: Rotate about y-axis by -theta0 so Jhat lies along z.
    # LNh now sits in the x-z plane at azimuth pi (negative x side).
    LNh = rotate_y(-theta0, LNh)
    s1hat = rotate_y(-theta0, s1hat)
    s2hat = rotate_y(-theta0, s2hat)

    # Rotation 3: Rotate about z-axis by phi_jl - pi, moving LNh from
    # azimuth pi to azimuth phi_jl on its cone about J.
    LNh = rotate_z(phi_jl - jnp.pi, LNh)
    s1hat = rotate_z(phi_jl - jnp.pi, s1hat)
    s2hat = rotate_z(phi_jl - jnp.pi, s2hat)

    # Compute iota: N lies in the y-z plane at angle theta_jn from J (now z)
    N = jnp.array([0.0, jnp.sin(theta_jn), jnp.cos(theta_jn)])
    iota = jnp.arccos(jnp.dot(N, LNh))

    # Polar and azimuthal angles of LNh in the J-aligned frame
    thetaLJ = jnp.arccos(LNh[2])
    phiL = jnp.arctan2(LNh[1], LNh[0])

    # Rotation 4: Rotate about z-axis by -phiL (LNh into the x-z plane)
    s1hat = rotate_z(-phiL, s1hat)
    s2hat = rotate_z(-phiL, s2hat)
    N = rotate_z(-phiL, N)

    # Rotation 5: Rotate about y-axis by -thetaLJ (LNh onto the z-axis)
    s1hat = rotate_y(-thetaLJ, s1hat)
    s2hat = rotate_y(-thetaLJ, s2hat)
    N = rotate_y(-thetaLJ, N)

    # Rotation 6: about z by pi/2 - phiN - phiRef. The pi/2 - phiN part would
    # bring N into the y-z plane (positive y); the extra -phiRef additionally
    # rotates by the reference phase.
    phiN = jnp.arctan2(N[1], N[0])
    s1hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s1hat)
    s2hat = rotate_z(jnp.pi / 2.0 - phiN - phiRef, s2hat)

    # Scale unit vectors back to dimensionless spin vectors
    S1 = s1hat * chi_1
    S2 = s2hat * chi_2
    return iota, S1[0], S1[1], S1[2], S2[0], S2[1], S2[2]


def cartesian_spin_to_spin_angles(
        iota: Float,
        S1x: Float,
        S1y: Float,
        S1z: Float,
        S2x: Float,
        S2y: Float,
        S2z: Float,
        M_c: Float,
        q: Float,
        fRef: Float,
        phiRef: Float,
) -> tuple[Float, Float, Float, Float, Float, Float, Float]:
    """
    Transform Cartesian spin components into spin angles and magnitudes.

    The code is based on the approach used in LALsimulation
    (XLALSimInspiralTransformPrecessingWvf2PE):
    https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/group__lalsimulation__inference.html

    Note: the zero-spin and phase-wrapping branches below are Python-level
    ``if`` statements on computed values. They need concrete (non-traced)
    numbers, which holds with the NumPy ``jnp`` alias used in this notebook.

    Parameters
    ----------
    iota: Float
        Zenith angle between the orbital angular momentum and the line of sight
    S1x: Float
        The x-component of the primary's dimensionless spin vector
    S1y: Float
        The y-component of the primary's dimensionless spin vector
    S1z: Float
        The z-component of the primary's dimensionless spin vector
    S2x: Float
        The x-component of the secondary's dimensionless spin vector
    S2y: Float
        The y-component of the secondary's dimensionless spin vector
    S2z: Float
        The z-component of the secondary's dimensionless spin vector
    M_c: Float
        The chirp mass, in solar masses (converted to seconds via MTSUN)
    q: Float
        The mass ratio, as taken by Mc_q_to_m1_m2 (m2 = q * m1)
    fRef: Float
        The reference frequency (Hz)
    phiRef: Float
        The binary phase at the reference frequency

    Returns
    -------
    theta_jn: Float
        Zenith angle between the total angular momentum and the line of sight
    phi_jl: Float
        Azimuthal angle of the orbital angular momentum about the total
        angular momentum, wrapped into [0, 2*pi)
    tilt_1: Float
        Zenith angle between the spin and orbital angular momenta for the primary object
    tilt_2: Float
        Zenith angle between the spin and orbital angular momenta for the secondary object
    phi_12: Float
        Difference between the azimuthal angles of the individual spin vector projections
        onto the orbital plane, wrapped into [0, 2*pi)
    chi_1: Float
        Dimensionless spin magnitude of the primary (norm of (S1x, S1y, S1z))
    chi_2: Float
        Dimensionless spin magnitude of the secondary (norm of (S2x, S2y, S2z))
    """
    # Starting frame: LNh along the z-axis
    LNh = jnp.array([0.0, 0.0, 1.0])

    # Define the dimensionless component spin vectors and magnitudes
    s1_vec = jnp.array([S1x, S1y, S1z])
    s2_vec = jnp.array([S2x, S2y, S2z])
    chi_1 = jnp.linalg.norm(s1_vec)
    chi_2 = jnp.linalg.norm(s2_vec)

    # Define the spin unit vectors in the LNh frame.
    # A zero spin gets a zero "unit" vector, which yields tilt = arccos(0) =
    # pi/2 and azimuth arctan2(0, 0) = 0 below.
    if chi_1 > 0:
        s1hat = s1_vec / chi_1
    else:
        s1hat = jnp.array([0.0, 0.0, 0.0])
    if chi_2 > 0:
        s2hat = s2_vec / chi_2
    else:
        s2hat = jnp.array([0.0, 0.0, 0.0])

    # Azimuthal and polar angles of the spin vectors
    phi1 = jnp.arctan2(s1hat[1], s1hat[0])
    phi2 = jnp.arctan2(s2hat[1], s2hat[0])

    phi_12 = phi2 - phi1

    # Wrap the azimuthal difference into [0, 2*pi)
    if phi_12 < 0:
        phi_12 += 2 * jnp.pi

    tilt_1 = jnp.arccos(s1hat[2])
    tilt_2 = jnp.arccos(s2hat[2])

    # Get angles in the J-N frame
    m1, m2 = Mc_q_to_m1_m2(M_c, q)
    total_mass = m1 + m2
    # PN velocity parameter v0 = (pi * M * fRef)**(1/3), with M in seconds
    v0 = jnp.cbrt(total_mass * MTSUN * jnp.pi * fRef)

    # Define S1, S2, J (spins scaled by m**2 so they share units with Lmag)
    S1 = m1 * m1 * s1_vec
    S2 = m2 * m2 * s2_vec

    Lmag = Lmag_2PN(m1, m2, v0)
    J = S1 + S2 + Lmag * LNh

    # Normalize J
    Jhat = J / jnp.linalg.norm(J)

    # Polar and azimuthal angles of J in the L-frame
    thetaJL = jnp.arccos(Jhat[2])
    phiJ = jnp.arctan2(Jhat[1], Jhat[0])

    # Azimuthal angle of the line of sight N, set by the reference phase
    phi0 = 0.5 * jnp.pi - phiRef
    # Line-of-sight vector in L-frame
    N = jnp.array(
            [
                jnp.sin(iota) * jnp.cos(phi0),
                jnp.sin(iota) * jnp.sin(phi0),
                jnp.cos(iota)
            ]
    )

    # Inclination w.r.t. J
    theta_jn = jnp.arccos(jnp.dot(Jhat, N))

    # Rotate from L-frame to J-frame (J along z)
    N = rotate_z(-phiJ, N)
    N = rotate_y(-thetaJL, N)

    LNh = rotate_z(-phiJ, LNh)
    LNh = rotate_y(-thetaJL, LNh)

    # Rotate about z so that N would lie in the y-z plane with positive y;
    # phi_jl is then the azimuth of LNh in that frame.
    phiN = jnp.arctan2(N[1], N[0])
    LNh = rotate_z(0.5 * jnp.pi - phiN, LNh)

    # Wrap phi_jl into [0, 2*pi)
    phi_jl = jnp.arctan2(LNh[1], LNh[0])
    if phi_jl < 0:
        phi_jl += 2 * jnp.pi

    return theta_jn, phi_jl, tilt_1, tilt_2, phi_12, chi_1, chi_2

## Backward transform (Components to Angles)

In [62]:
rng = np.random.default_rng()

# NOTE: spin components are drawn independently in [-1, 1], so spin
# magnitudes can exceed 1; this is only a Jim-vs-LAL consistency check.
n_points = 100
params = np.vstack([
    rng.uniform(2, 100, n_points),      # chirp mass
    rng.uniform(0.1, 1, n_points),      # mass ratio
    rng.uniform(-1, 1, n_points),       # spin 1x
    rng.uniform(-1, 1, n_points),       # spin 1y
    rng.uniform(-1, 1, n_points),       # spin 1z
    rng.uniform(-1, 1, n_points),       # spin 2x
    rng.uniform(-1, 1, n_points),       # spin 2y
    rng.uniform(-1, 1, n_points),       # spin 2z
    rng.uniform(10, 300, n_points),     # reference frequency
    rng.uniform(0, np.pi, n_points),    # iota
    rng.uniform(0, 2*np.pi, n_points),  # phase
]).T

# NOTE: rtol = 1e-15 is only a few ulp of double precision. Differences from
# floating-point round-off between Jim and LAL can therefore be reported as
# failures.
atol = 1e-18
rtol = 1e-15

print(f'Begin test with abs.tol: {atol:.1e} and rel.tol: {rtol:.1e}')

n_failed = 0
for row in params:
    (chirp_mass, mass_ratio,
     spin_1x, spin_1y, spin_1z,
     spin_2x, spin_2y, spin_2z,
     f_ref, iota, phase) = row
    # Compute Jim result
    jim_result = cartesian_spin_to_spin_angles(
        iota, spin_1x, spin_1y, spin_1z,
        spin_2x, spin_2y, spin_2z,
        chirp_mass, mass_ratio, f_ref, phase)
    # Convert chirp mass and mass ratio to component masses. They are passed
    # to LAL here without MSUN_SI scaling (unlike the forward-transform test).
    mass_1, mass_2 = Mc_q_to_m1_m2(chirp_mass, mass_ratio)
    # Compute LAL result
    lal_result = SimInspiralTransformPrecessingWvf2PE(
        iota, spin_1x, spin_1y, spin_1z,
        spin_2x, spin_2y, spin_2z,
        mass_1, mass_2, f_ref, phase)

    is_equal = np.all(np.isclose(jim_result, lal_result, atol=atol, rtol=rtol))
    if not is_equal:
        n_failed += 1
        print('-----------------------------------------------------')
        print('This one fails ↓')
        print('Input:  ' + ('{:.4f}  ' * 11).format(*row))
        diff = np.array(lal_result) - np.array(jim_result)
        print('Output diff: ' + ('{:.3e}  ' * 7).format(*diff))
print('-----------------------------------------------------')
print(f'{n_failed} of {len(params)} points exceeded the tolerance')

Begin test with abs.tol: 1.0e-18 and rel.tol: 1.0e-15
Begin test with abs.tol: 1.0e-18 and rel.tol: 1.0e-15
-----------------------------------------------------
This one fails ↓
Input:  11.8971  0.7758  0.5176  -0.7164  0.4866  0.0715  0.6946  0.1356  82.7228  1.3398  4.0470  
Output diff: 0.000e+00  -4.996e-16  0.000e+00  0.000e+00  4.441e-16  -2.220e-16  0.000e+00  
-----------------------------------------------------
This one fails ↓
Input:  28.1024  0.7745  0.2853  0.0617  -0.6425  -0.9804  0.8564  -0.7462  262.0770  1.8339  1.8908  
Output diff: 0.000e+00  5.551e-16  4.441e-16  0.000e+00  0.000e+00  0.000e+00  0.000e+00  
-----------------------------------------------------
This one fails ↓
Input:  62.7490  0.4954  0.0690  -0.0338  0.2364  -0.0769  0.9920  0.9999  38.2152  0.1267  5.6020  
Output diff: 1.388e-17  -4.441e-16  -3.886e-16  0.000e+00  0.000e+00  -2.776e-17  0.000e+00  
-----------------------------------------------------
This one fails ↓
Input:  16.3182  0.7710  -

## Forward transform (Angles to Components)

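**Remark**: Jim takes the chirp mass in solar masses. LAL's `SimInspiralTransformPrecessingNewInitialConditions` instead expects the component masses in SI units (kg), so the test below multiplies the component masses by `MSUN_SI`.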

In [71]:
rng = np.random.default_rng()

n_points = 100
params = np.vstack([
    rng.uniform(2, 100, n_points),      # chirp mass (solar masses)
    # Mass ratio q (m2 = q * m1), as taken by Mc_q_to_m1_m2 and the Jim
    # function. q in [0.01, 1] spans roughly eta in [0.01, 0.25].
    rng.uniform(0.01, 1, n_points),     # mass ratio
    rng.uniform(0, 1, n_points),        # a_1
    rng.uniform(0, 1, n_points),        # a_2
    rng.uniform(0, np.pi, n_points),    # tilt_1
    rng.uniform(0, np.pi, n_points),    # tilt_2
    rng.uniform(0, 2*np.pi, n_points),  # phi_12
    rng.uniform(0, 2*np.pi, n_points),  # phi_jl
    rng.uniform(10, 300, n_points),     # reference frequency
    rng.uniform(0, np.pi, n_points),    # theta_jn
    rng.uniform(0, 2*np.pi, n_points),  # phase
]).T

atol = 1e-16
rtol = 1e-13

print(f'Begin test with abs.tol: {atol:.1e} and rel.tol: {rtol:.1e}')

n_failed = 0
for row in params:
    (chirp_mass, mass_ratio,
     a_1, a_2, tilt_1, tilt_2, phi_12, phi_jl,
     f_ref, theta_jn, phase) = row
    # Compute Jim result
    jim_result = spin_angles_to_cartesian_spin(
        theta_jn, phi_jl, tilt_1, tilt_2,
        phi_12, a_1, a_2,
        chirp_mass, mass_ratio, f_ref, phase)
    # Convert chirp mass and mass ratio to component masses (solar masses)
    mass_1, mass_2 = Mc_q_to_m1_m2(chirp_mass, mass_ratio)
    # Compute LAL result; this LAL routine takes masses in kg
    lal_result = SimInspiralTransformPrecessingNewInitialConditions(
        theta_jn, phi_jl, tilt_1, tilt_2,
        phi_12, a_1, a_2,
        mass_1*MSUN_SI, mass_2*MSUN_SI, f_ref, phase)

    is_equal = np.all(np.isclose(jim_result, lal_result, atol=atol, rtol=rtol))
    if not is_equal:
        n_failed += 1
        print('-----------------------------------------------------')
        print('This one fails ↓')
        print('Input:  ' + ('{:.4f}  ' * 11).format(*row))
        diff = np.array(lal_result) - np.array(jim_result)
        print('Output diff: ' + ('{:.3e}  ' * 7).format(*diff))
print('-----------------------------------------------------')
print(f'{n_failed} of {len(params)} points exceeded the tolerance')

Begin test with abs.tol: 1.0e-16 and rel.tol: 1.0e-13
-----------------------------------------------------
This one fails ↓
Input:  86.9934  0.2375  0.2450  0.9479  0.0324  2.8610  4.7945  0.5569  299.3490  2.1282  3.2522  
Output diff: 0.000e+00  1.106e-15  5.586e-16  -1.388e-16  -4.219e-15  -1.461e-15  -6.661e-16  
-----------------------------------------------------
This one fails ↓
Input:  39.7141  0.0116  0.6685  0.2404  0.0339  1.4745  3.0423  0.4697  205.1901  0.5042  5.3020  
Output diff: -5.551e-17  -1.050e-15  -4.458e-15  -6.661e-16  2.734e-15  -8.882e-16  1.624e-15  
-----------------------------------------------------
This one fails ↓
Input:  89.0694  0.1813  0.3462  0.0000  0.0067  2.8994  6.1739  0.5009  247.5381  2.7641  3.0933  
Output diff: -3.775e-14  -2.090e-16  3.357e-16  -5.551e-17  -5.336e-19  1.116e-18  6.776e-21  
-----------------------------------------------------
This one fails ↓
Input:  84.9090  0.1344  0.9451  0.9321  1.3132  1.2855  5.1284  1.3235  179