In [64]:
import jaxoplanet
import numpy as np
from numpy.linalg import inv
from scipy.sparse import load_npz

In [65]:
import starry
starry.config.lazy = False

In [66]:
jaxoplanet.__path__

['/Users/SuperTiger/ADACS/repo/SS2023A-BPope/jaxoplanet/bpope_2023a/src/jaxoplanet']

In [67]:
from jaxoplanet.light_curves import LimbDarkLightCurve, QuadLightCurve

In [68]:
from jaxoplanet.orbits import KeplerianOrbit, TransitOrbit

In [69]:
from jaxoplanet._src.core.limb_dark import light_curve as limb_dark_light_curve
from jaxoplanet._src.core.quad import light_curve as quad_light_curve

In [70]:
from jaxoplanet._src.experimental.starry.rotation import *
from jaxoplanet._src.experimental.starry.solution import *
from jaxoplanet._src.experimental.starry.basis import *

In [71]:
params = {
    "log_f0": 0.0,
    "u": jnp.array([0.3, 0.2]),
    "log_duration": jnp.log(0.12),
    "t0": 0.0,
    "b": 0.1,
    "log_r": jnp.log(0.1),
    "log_amp": jnp.log(0.002),
    "log_ell": jnp.log(0.02),
}


In [72]:
lc = QuadLightCurve.init(u1=params["u"][0], u2=params["u"][1])

In [73]:
orbit = TransitOrbit.init(
        period=1.0,
        duration=jnp.exp(params["log_duration"]),
        time_transit=params["t0"],
        impact_param=params["b"],
        radius=jnp.exp(params["log_r"]),
    )

In [74]:
import jax
import jax.numpy as jnp

In [75]:
@jax.jit
def test_light_curve(params, t, period=1.0):
    u1=params["u"][0]
    u2=params["u"][1]
    orbit = TransitOrbit.init(
        period=period,
        duration=jnp.exp(params["log_duration"]),
        time_transit=params["t0"],
        impact_param=params["b"],
        radius=jnp.exp(params["log_r"]),
    )
    x, y, z = orbit.relative_position(t)
    b = jnp.sqrt(x**2 + y**2)
    r_star = orbit.central_radius
    r = orbit.radius / r_star
    lc_func = partial(quad_light_curve, u1, u2)
#     lc_func = quad_light_curve(u1, u2)
    if orbit.shape == ():
            b /= r_star
            lc = lc_func(b, r)
    else:
        b /= r_star[..., None]
        print("b", b.shape)
        print("r", r.shape)
        
        lc = jnp.vectorize(lc_func)(b, r)
        print("lc", lc.shape)
    return lc
#     return jnp.where(z > 0, lc, 0)

In [76]:
@jax.jit
def light_curve(params, t, period=1.0):
    lc = QuadLightCurve.init(u1=params["u"][0], u2=params["u"][1])
    orbit = TransitOrbit.init(
        period=period,
        duration=jnp.exp(params["log_duration"]),
        time_transit=params["t0"],
        impact_param=params["b"],
        radius=jnp.exp(params["log_r"]),
    )
    return lc.light_curve(orbit, t)[0]
#     return jnp.exp(params["log_f0"]) * (1 + lc.light_curve(orbit, t)[0])

In [77]:
t_grid = jnp.linspace(-0.3, 0.3, 1000)
t = jnp.linspace(-0.2, 0.2, 75)
y_err = 0.001

In [78]:
lc = test_light_curve(params, t)

b (1, 75)
r (1,)
lc (1, 75)


In [79]:
lc.shape

(1, 75)

### input (for testing)

In [80]:
fixtures = "/Users/SuperTiger/ADACS/repo/SS2023A-BPope/starry/tests/fixtures/"

In [81]:
x = np.load(fixtures + "ylm_x_xo.npy").flatten()
y = np.load(fixtures + "ylm_x_yo.npy").flatten()
z = np.load(fixtures + "ylm_x_zo.npy").flatten()
theta = np.load(fixtures + "ylm_x_theta.npy")

In [82]:
theta.shape

(810,)

In [83]:
ro = 10.
inc = 1.57080
obl = 0.0

In [84]:
iocc = np.load(fixtures + "ylm_x_iocc.npy")
rta1_1 = np.load(fixtures + "ylm_x_rta1_1.npy")
rta1_2 = np.load(fixtures + "ylm_x_rta1_2.npy")
x_1 = np.load(fixtures + "ylm_x_X_1.npy")
a = load_npz(fixtures + "ylm_x_A.npz")
st = np.load(fixtures + "ylm_st_f.npy")
sta = np.load(fixtures + "ylm_sta_f.npy")
theta_z = np.load(fixtures + "ylm_x_theta_z.npy")
star = np.load(fixtures + "ylm_star.npy")
x_2 = np.load(fixtures + "ylm_x_X_2.npy")

In [85]:
x_2.shape

(810, 36)

In [86]:
np.array_equal(occ_idx, iocc)

NameError: name 'occ_idx' is not defined

In [87]:
np.array_equal(y, y_occ)

NameError: name 'y_occ' is not defined

### implement calcluations in OpsYlm.X

In [88]:
b = np.sqrt(x**2 + y**2)

In [89]:
b_iocc = b[iocc]

#### sT

In [90]:
st_func = jax.jit(solution_vector(5))

In [91]:
st_res = st_func(b_iocc,ro)

In [92]:
st_res.shape

(810, 36)

In [93]:
st.shape

(810, 36)

In [94]:
np.allclose(np.array(st_res), st, atol=1.e-5)

True

#### A

In [95]:
a1 = A1(5)

In [96]:
a1.shape

(36, 36)

In [97]:
a2_inv = A2_inv(5)

In [98]:
a2 = inv(a2_inv)

In [99]:
a_res = np.dot(a2, a1)

In [100]:
a.shape

(36, 36)

##### verification

In [101]:
starry_a = a.toarray()

In [102]:
np.allclose(a_res*2/np.sqrt(np.pi), starry_a)

True

In [103]:
a_norm = a_res*2/np.sqrt(np.pi)

#### sTA

In [104]:
sta_res = jnp.dot(st_res, a_norm)

In [105]:
np.allclose(np.array(sta_res), sta, atol=1.e-4)

True

#### theta_z

In [106]:
theta_z_res = jnp.arctan2(x[iocc], y[iocc])

In [107]:
np.allclose(np.array(theta_z_res), theta_z)

True

#### sTAR

In [108]:
u_axis = [0,0,1]

In [109]:
rfull_func = jax.jit(R_full(5, u_axis))

In [110]:
r_time = rfull_func(theta_z_res)

##### computeRz

In [111]:
# @jax.jit
def computeRz(theta,deg,N):
    # Length of timeseries
    npts = theta.size

    # Compute sin & cos
    costheta = np.cos(theta)
    sintheta = np.sin(theta)

    # Initialize our z rotation vectors
    cosnt = np.zeros((npts, max(2, deg + 1)))
    cosnt[:,0] = 1.0
    sinnt = np.zeros((npts, max(2, deg + 1)))
    cosnt[:,1] = costheta
    sinnt[:,1] = sintheta
    cosmt = np.zeros((npts, N))
    sinmt = np.zeros((npts, N))

    # Compute the cos and sin vectors for the z-hat rotation
    for n in range(2, deg + 1):
        cosnt[:,n] = 2.0 * cosnt[:,n-1] * costheta - cosnt[:,n-2]
        sinnt[:,n] = 2.0 * sinnt[:,n-1] * costheta - sinnt[:,n-2]
    n = 0
    for l in range(deg + 1):
        for m in range(-l, 0):
            cosmt[:,n] = cosnt[:,-m]
            sinmt[:,n] = -sinnt[:,-m]
            n += 1
        for m in range(l + 1):
            cosmt[:,n] = cosnt[:,m]
            sinmt[:,n] = sinnt[:,m]
            n += 1

    # Set output
    return (cosmt, sinmt)

In [112]:
def computeRz_j(theta,deg, N):
    # Length of timeseries
    npts = theta.size
    nt_col = max(2, deg + 1)
    
    
    @partial(jnp.vectorize, signature=f"()->({npts},{N}),({npts},{N})")
    def _computeRz(x):
        # Compute sin & cos
        costheta = np.cos(theta)
        sintheta = np.sin(theta)

        # Initialize our z rotation vectors
        cosnt = np.zeros((npts, nt_col))
        cosnt[:,0] = 1.0
        sinnt = np.zeros((npts, nt_col))
        cosnt[:,1] = costheta
        sinnt[:,1] = sintheta
        cosmt = np.zeros((npts, N))
        sinmt = np.zeros((npts, N))

        # Compute the cos and sin vectors for the z-hat rotation
        for n in range(2, deg + 1):
            cosnt[:,n] = 2.0 * cosnt[:,n-1] * costheta - cosnt[:,n-2]
            sinnt[:,n] = 2.0 * sinnt[:,n-1] * costheta - sinnt[:,n-2]
        
        n = 0
        for l in range(deg + 1):
            for m in range(-l, 0):
                cosmt[:,n] = cosnt[:,-m]
                sinmt[:,n] = -sinnt[:,-m]
                n += 1
            for m in range(l + 1):
                cosmt[:,n] = cosnt[:,m]
                sinmt[:,n] = sinnt[:,m]
                n += 1
        return (cosmt, sinmt)
    return _computeRz

In [113]:
crz_func = jax.jit(computeRz_j(theta_z, 5, 36))

In [114]:
cm,sm = crz_func(5)

In [115]:
cm.shape

(810, 36)

In [116]:
cmt, smt = computeRz(theta_z,5,36)

In [117]:
np.allclose(cmt, np.array(cm))

True

In [118]:
np.allclose(smt, np.array(sm))

True

##### tensordotRz

In [119]:
def tensordotRz(M, cosmt, sinmt, theta):
#     global tensordotRz_result, cosmt, sinmt, deg
    # Shape checks
    npts = len(theta)
    Nr = M.shape[1]
    degr = int(np.sqrt(Nr) - 1)

    # Compute the sin & cos matrices
#     computeRz(theta)

    # Init result
    tensordotRz_result = np.zeros((npts, Nr))
#     if (npts == 0):
#         return

    # Dot them in
    for l in range(degr + 1):
        for j in range(2 * l + 1):
            if (M.shape[0] == 1):
                tensordotRz_result[:, l * l + j] = M[0, l * l + j] * cosmt[:, l * l + j] + M[0, l * l + 2 * l - j] * sinmt[:, l * l + j]
            else:
                tensordotRz_result[:, l * l + j] = np.multiply(M[:, l * l + j], cosmt[:, l * l + j]) + np.multiply(M[:, l * l + 2 * l - j], sinmt[:, l * l + j])
    return tensordotRz_result

In [120]:
tdrz = tensordotRz(sta, cmt, smt, theta_z)

In [121]:
tdrz.shape

(810, 36)

In [122]:
np.allclose(tdrz, star)

True

In [123]:
cmt.shape

(810, 36)

In [124]:
def tensordotRz_j(M, cosmt, sinmt, theta):
    # Shape checks
    npts = len(theta)
    Nr = cosmt.shape[1]
    degr = int(np.sqrt(Nr) - 1)
    
    @partial(jnp.vectorize, signature=f"()->({npts},{Nr})")
    def _tensordotRz(x):
        # Init result
        tensordotRz_result = np.zeros((npts, Nr))

        # Dot them in
        for l in range(degr + 1):
            for j in range(2 * l + 1):
                if (M.shape[0] == 1):
                    tensordotRz_result[:, l * l + j] = M[0, l * l + j] * cosmt[:, l * l + j] + M[0, l * l + 2 * l - j] * sinmt[:, l * l + j]
                else:
                    tensordotRz_result[:, l * l + j] = M[:, l * l + j] * cosmt[:, l * l + j] + M[:, l * l + 2 * l - j] * sinmt[:, l * l + j]

        return tensordotRz_result
    return _tensordotRz

In [125]:
def tensordotRz_j(M, l, theta):
    # Shape checks
    npts = len(theta)
    degr = l
    Nr = (l + 1)**2
#     Nr = cosmt.shape[1]
#     degr = int(np.sqrt(Nr) - 1)
    @jax.jit
    @partial(jnp.vectorize, signature=f"()->({npts},{Nr})")
    def _tensordotRz(x):
        # Init result
        tensordotRz_result = jnp.zeros((npts, Nr))
        
        # calculate Rz
        cosmt, sinmt = computeRz(theta, degr, Nr)
        
        # Dot them in
        for l in range(degr + 1):
            for j in range(2 * l + 1):
                if (M.shape[0] == 1):
                    tensordotRz_result = tensordotRz_result.at[:, l * l + j].set(M[0, l * l + j] * cosmt[:, l * l + j] + M[0, l * l + 2 * l - j] * sinmt[:, l * l + j])
                else:
                    tensordotRz_result = tensordotRz_result.at[:, l * l + j].set(M[:, l * l + j] * cosmt[:, l * l + j] + M[:, l * l + 2 * l - j] * sinmt[:, l * l + j])
        return tensordotRz_result
    return _tensordotRz

In [126]:
trz_func = tensordotRz_j(sta, 5, theta_z)

In [127]:
srz_j = trz_func(3)

In [128]:
srz_j

DeviceArray([[ 9.9054283e-01,  7.8796223e-03,  1.1512566e+00, ...,
               3.2623989e-18,  5.3955959e-03, -1.1990021e-02],
             [ 9.6892178e-01,  2.4646113e-02,  1.1380664e+00, ...,
               2.4031512e-19,  1.7007306e-02, -1.9275181e-02],
             [ 9.4061136e-01,  4.4718497e-02,  1.1157191e+00, ...,
              -1.0413722e-17,  2.3796119e-02, -1.2194489e-02],
             ...,
             [ 9.3787611e-01, -4.6560436e-02,  1.1133505e+00, ...,
              -1.7290695e-17,  2.3981314e-02,  1.1063970e-02],
             [ 9.6639317e-01, -2.6518634e-02,  1.1362591e+00, ...,
              -8.2554688e-19,  1.8004864e-02,  1.9156536e-02],
             [ 9.8858660e-01, -9.4591742e-03,  1.1502808e+00, ...,
               5.1178591e-18,  6.6289399e-03,  1.3507359e-02]],            dtype=float32)

In [129]:
tt = jnp.zeros((5,5))

In [130]:
for i in range(5):
    tt = tt.at[i,i].add(i**3)

In [131]:
tt

DeviceArray([[ 0.,  0.,  0.,  0.,  0.],
             [ 0.,  1.,  0.,  0.,  0.],
             [ 0.,  0.,  8.,  0.,  0.],
             [ 0.,  0.,  0., 27.,  0.],
             [ 0.,  0.,  0.,  0., 64.]], dtype=float32)

In [132]:
np.allclose(star, np.array(srz_j))

True

#### right_project

In [133]:
rp = np.load(fixtures + "ylm_x_right_project_2.npy")

##### dotR

In [134]:
dotr_0 = np.load(fixtures + "ylm_x_rp_dotR_0.npy")

In [135]:
u_a = [-jnp.cos(obl), -jnp.sin(obl), 0.]
theta_a = -(0.5 * jnp.pi - inc)        
dotR_func_a = dotR(5, u_a)

In [136]:
theta_a

3.673205103416066e-06

In [137]:
a_arr = a.toarray()

In [138]:
from jaxoplanet._src.types import Array
from typing import Callable

In [155]:
def new_dotR(l_max: int, u: Array, M: Array) -> Callable[[Array],Array]:
    """Dot product M@R of a matrix M with the rotation matrix R
    Parameters
    ----------
    l_max : int
        maximum order of the spherical harmonics map
    u : Array
        axis-rotation vector
    Returns
    -------
    Callable[[Array], Array]
        a jax.vmap function of (M, theta) returning the product M@R where
        - M is a matrix (Array)
        - theta is the rotation angle in radians
    """
    Rls = [Rl(l) for l in range(l_max + 1)]
    n_max = l_max**2 + 2 * l_max + 1
    m_row = M.shape[0]
    
#     @jax.jit
    @partial(jnp.vectorize, signature=f"()->({m_row},{n_max})")
    def _R(theta: Array) -> Array:
        alpha, beta, gamma = axis_to_euler(u[0], u[1], u[2], theta)
        return jnp.hstack(
            [
                M[:, l**2 : (l + 1) ** 2] @ Rls[l](alpha, beta, gamma)
                for l in range(l_max + 1)
            ]
        )

    return _R

In [154]:
new_dotR_func_a = new_dotR(5, u_a, star)

TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

In [141]:
new_res = new_dotR_func_a(theta_a)

In [142]:
new_res.shape

(810, 36)

In [143]:
np.allclose(np.array(new_res), dotr_0)

True

In [144]:
theta_b = obl

In [145]:
theta_b

0.0

##### right_project

In [156]:
@jax.jit
def right_project(l, M, inc, obl, theta):
        r"""Apply the projection operator on the right.

        Specifically, this method returns the dot product :math:`M \cdot R`,
        where ``M`` is an input matrix and ``R`` is the Wigner rotation matrix
        that transforms a spherical harmonic coefficient vector in the
        input frame to a vector in the observer's frame.
        """
        # Rotate to the sky frame
        # TODO: Do this in a single compound rotation
        u_a = [-jnp.cos(obl), -jnp.sin(obl), 0.]
        theta_a = -(0.5 * jnp.pi - inc)   
#         dotR_func_a = partial(new_dotR, l, u_a)
        dotR_func_a = new_dotR(l, u_a, M)
        return dotR_func_a
        dotR_a = dotR_func_a(theta_a)
        
        u_b = [0., 0., 1.]
        theta_b = obl
        dotR_func_b = new_dotR(l, u_b, dotR_a)
        dotR_b = dotR_func_b(theta_b)
    
        u_c = [1., 0., 0.]
        theta_c = -0.5 * jnp.pi
        dotR_func_c = new_dotR(l, u_c, dotR_b)
        M = dotR_func_c(theta_c)

        # Rotate to the correct phase
        if theta.ndim > 0:
            print("ndim > 0")
            trz_j = tensordotRz_j(M,l,theta)
            M = trz_j(2)

        else:
            M_func = new_dotR(l, u_b, M)
            M = M_func(theta)


        # Rotate to the polar frame
        M_func = new_dotR(l, u_c, M)
        M = M_func(-theta_c)

        return M

In [157]:
theta_j = jnp.array(theta)

In [158]:
rp_res = right_project(5, star, inc, obl, theta_j)

TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

In [None]:
np.allclose(np.array(rp_res), rp, atol=1.e-6)