In [2]:
import numpy as np
import starry
starry.config.lazy = False
import exoplanet

In [3]:
import jax
import jax.numpy as jnp
from jax import jit

In [4]:
fixtures = "/Users/SuperTiger/ADACS/repo/SS2023A-BPope/starry/tests/fixtures"

### general inputs

In [5]:
time = np.linspace(-0.25, 3.25, 10000)

In [6]:
pri_deg = 0
pri_udeg = 2
pri_amp = 1.0
pri_m = 1.0
pri_r = 1.0
pri_prot = 1.0
sec_deg=5
sec_amp = 5e-3
sec_m = 0
sec_r = 0.1
sec_porb = 1.0
sec_prot = 1.0
sec_omega = 30
sec_ecc = 0.3
sec_w = 30
sec_t0 = 0

In [7]:
map1 = starry.Map(deg=pri_deg, udeg=pri_udeg, amp=pri_amp)

In [8]:
ys = map1._y

In [9]:
map2 = starry.Map(ydeg=sec_deg, amp=sec_amp)

Pre-computing some matrices... INFO:starry.ops:Pre-computing some matrices... 
Done.
INFO:starry.ops:Done.


dotr value:  <bound method PyCapsule.dotR of <starry._c_ops.Ops object at 0x13463af30>>


In [10]:
yp = map2._y

### system flux function

In [11]:
def _get_ay(y1, amp1, y2, amp2):
    ay = [amp1*y1, amp2*y2]
    ay = np.concatenate(ay)
    return ay

In [12]:
ay = _get_ay(ys, pri_amp, yp, sec_amp)

#### jax version

In [113]:
@jit
def _get_ay_j(y1,amp1,y2,amp2):
    ay = [amp1*y1, amp2*y2]
    ay = jnp.concatenate(ay)
    return ay

In [114]:
print(jax.make_jaxpr(_get_ay_j)(ys, pri_amp, yp, sec_amp))

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[1][39m b[35m:f32[][39m c[35m:f32[36][39m d[35m:f32[][39m. [34m[22m[1mlet
    [39m[22m[22me[35m:f32[37][39m = xla_call[
      call_jaxpr={ [34m[22m[1mlambda [39m[22m[22m; f[35m:f32[1][39m g[35m:f32[][39m h[35m:f32[36][39m i[35m:f32[][39m. [34m[22m[1mlet
          [39m[22m[22mj[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] g
          k[35m:f32[1][39m = mul j f
          l[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] i
          m[35m:f32[36][39m = mul l h
          n[35m:f32[37][39m = concatenate[dimension=0] k m
        [34m[22m[1min [39m[22m[22m(n,) }
      name=_get_ay_j
    ] a b c d
  [34m[22m[1min [39m[22m[22m(e,) }


In [115]:
jay = _get_ay_j(ys, pri_amp, yp, sec_amp)

In [116]:
jay

DeviceArray([1.   , 0.005, 0.   , 0.   , 0.   , 0.   , 0.   , 0.   ,
             0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   ,
             0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   ,
             0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   ,
             0.   , 0.   , 0.   , 0.   , 0.   ], dtype=float32)

In [117]:
ay

array([1.   , 0.005, 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   ,
       0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   ,
       0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   ,
       0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   ,
       0.   ])

In [118]:
def system_flux(x,ay):
    return np.dot(x, ay)    

In [119]:
x = np.load(fixtures + "/" + "system_X.npy")

In [120]:
x.shape

(10000, 37)

In [121]:
sflux = np.load(fixtures + "/" + "system_flux.npy")

In [122]:
%%timeit
res = system_flux(x,ay)

289 µs ± 47.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


#### jax version (func 2)

In [123]:

def system_flux_j(x,ay):
    return jnp.dot(x, ay) 



In [124]:
sf_j_jit(x,ay)

DeviceArray([1.005, 1.005, 1.005, ..., 1.005, 1.005, 1.005], dtype=float32)

In [125]:
%timeit -o -r 7 -n 10000 system_flux_j(x,ay)

465 µs ± 73.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


<TimeitResult : 465 µs ± 73.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)>

In [126]:
sf_j_jit = jit(system_flux_j)

sf_j_jit(x,ay)

%timeit -o -r 7 -n 10000 sf_j_jit(x,jay)


396 µs ± 26.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


<TimeitResult : 396 µs ± 26.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)>

In [127]:
jres

DeviceArray([-1.5707964, -1.5685971, -1.5663977, ..., 20.415955 ,
             20.418154 , 20.420353 ], dtype=float32)

In [128]:
jres.shape

(10000,)

#### verification

In [129]:
np.array_equal(res, sflux)

True

### system design matrix function (core.OpsSystem.X)

#### get relative positions (input)

In [41]:
# todo: need to understand how to get this value
sec_iorb = 1.57079633

In [42]:
# todo: need to understand how to calculate omega and Omega
orbit = exoplanet.orbits.KeplerianOrbit(
period=sec_porb,
t0=sec_t0,
incl=sec_iorb,
ecc=sec_ecc,
omega=0.52359878,
Omega=0.52359878,
m_planet=sec_m,
m_star=pri_m,
r_star=pri_r)

In [132]:
print(sec_porb, sec_t0, sec_iorb, sec_ecc, sec_w, sec_omega, sec_m, pri_m, pri_r)

1.0 0 1.57079633 0.3 30 30 0 1.0 1.0


In [43]:
x, y, z = orbit.get_relative_position(time)

In [44]:
xv = x.eval()
yv = y.eval()
zv = z.eval()

In [135]:
xv

array([-1.64475377, -1.65078852, -1.6568127 , ...,  4.41457845,
        4.41634656,  4.41810274])

In [136]:
x_starry = np.load(fixtures + "/x.npy")
y_starry = np.load(fixtures + "/y.npy")
z_starry = np.load(fixtures + "/z.npy")

In [137]:
def _get_xyz():
    orbit = exoplanet.orbits.KeplerianOrbit(
    period=sec_porb,
    t0=sec_t0,
    incl=sec_iorb,
    ecc=sec_ecc,
    omega=0.52359878,
    Omega=0.52359878,
    m_planet=sec_m,
    m_star=pri_m,
    r_star=pri_r)
    x, y, z = orbit.get_relative_position(time)
    xv = x.eval()
    yv = y.eval()
    zv = z.eval()
    return xv, yv, zv

In [138]:
xx,yy,zz = _get_xyz()

##### verification

In [139]:
np.allclose(xx, x_starry.flatten())

True

In [140]:
np.allclose(yy, y_starry.flatten(), atol=1.e-7)

True

In [141]:
yv

array([-0.94959905, -0.95308322, -0.95656128, ...,  2.54875808,
        2.5497789 ,  2.55079283])

In [142]:
np.allclose(zz, z_starry.flatten())

True

#### get theta of primary

In [143]:
def _get_theta_pri(pri_prot, pri_t0, pri_theta0, time):
    return (2 * np.pi) / pri_prot * (time - pri_t0) + pri_theta0

##### jax version

In [144]:
@jax.jit
def _get_theta_pri_j(pri_prot, pri_t0, pri_theta0, time):
    return (2 * jnp.pi) / pri_prot * (time - pri_t0) + pri_theta0

In [145]:
pri_t0 = 0
pri_theta0 = 0

In [146]:
res = _get_theta_pri(pri_prot, pri_t0, pri_theta0, time)

In [147]:
jres = _get_theta_pri_j(pri_prot, pri_t0, pri_theta0, time)

In [148]:
jres

DeviceArray([-1.5707964, -1.5685971, -1.5663977, ..., 20.415955 ,
             20.418154 , 20.420353 ], dtype=float32)

In [149]:
theta_pri = np.load(fixtures + "/" + "theta_pri.npy")

In [150]:
theta_pri

array([-1.57079633, -1.56859699, -1.56639766, ..., 20.41595358,
       20.41815291, 20.42035225])

##### verification

In [151]:
np.array_equal(res, theta_pri)

True

#### get theta of secondary

In [152]:
sec_theta0 = 0

In [153]:
def _get_theta_sec(sec_prot, sec_t0, sec_theta0, time):
    theta_sec = (2 * np.pi) / sec_prot * (time - sec_t0
        ) + sec_theta0
    theta_sec = np.expand_dims(theta_sec, axis=0)
    return theta_sec

In [154]:
res_theta_sec = _get_theta_sec(sec_prot, sec_t0, sec_theta0, time)

In [155]:
res_theta_sec[0].shape

(10000,)

In [156]:
theta_sec = np.load(fixtures + "/" + "theta_sec.npy")

##### verification

In [157]:
np.array_equal(res_theta_sec, theta_sec)

True

#### get phase primary

In [158]:
phase_pri = np.load(fixtures + "/" + "phase_pri.npy")

In [159]:
phase_pri.min()

1.0

#### get phase secondary

In [160]:
phase_sec = np.load(fixtures + "/" + "phase_sec.npy")

In [161]:
phase_sec.shape

(1, 10000, 36)

In [162]:
phase_sec

array([[[ 1.00000000e+00,  6.40987562e-17,  7.07050159e-17, ...,
          2.60012009e-16, -8.11440028e-32, -4.89276055e-16],
        [ 1.00000000e+00,  6.39577817e-17,  2.53957102e-03, ...,
          2.59995385e-16, -2.91448744e-18, -4.89270987e-16],
        [ 1.00000000e+00,  6.38168079e-17,  5.07912976e-03, ...,
          2.59945517e-16, -5.82878453e-18, -4.89255785e-16],
        ...,
        [ 1.00000000e+00,  6.38168079e-17,  5.07912976e-03, ...,
         -2.59945517e-16, -5.82878453e-18,  4.89255785e-16],
        [ 1.00000000e+00,  6.39577817e-17,  2.53957102e-03, ...,
         -2.59995385e-16, -2.91448744e-18,  4.89270987e-16],
        [ 1.00000000e+00,  6.40987562e-17, -1.13199499e-15, ...,
         -2.60012009e-16,  1.29912431e-30,  4.89276055e-16]]])

#### get transits across the primary (idx)

In [163]:
occ_pri = np.zeros((phase_pri.shape))
occ_sec = np.zeros((phase_sec.shape))

In [46]:
def _get_occ_idx(x,y,z,pri_r, sec_r):
    xo = x/pri_r
    yo = y/pri_r
    zo = z/pri_r
    ro = sec_r/pri_r
    b = np.sqrt(xo**2 + yo**2)
    cond_1 = b >= ro + 1.
    cond_2 = zo <=0.
    cond_3 = ro == 0.
    conds = cond_1 | cond_2 | cond_3
    b_occ = ~conds
    idx = np.arange(b.shape[0])[b_occ]
    return idx
    

##### verification

In [47]:
res_idx = _get_occ_idx(xv,yv,zv,pri_r,sec_r)

In [166]:
occ_idx = np.load(fixtures + "/occ_idx.npy")

In [167]:
np.array_equal(res_idx, occ_idx)

True

#### get transit across primary

In [168]:
primary_x = np.load(fixtures + "/primary_x.npy")

In [169]:
primary_x.shape

(802, 1)

In [170]:
def _get_occ_pri(occ_pri, idx, pri_amp, pri_x, phase_pri):
    occ_pri[idx] = occ_pri[idx] + pri_amp*pri_x - phase_pri[idx]
    return occ_pri

In [171]:
res_occ_prif = _get_occ_pri(occ_pri, res_idx, pri_amp, primary_x, phase_pri)

In [172]:
res_occ_prif.shape

(10000, 1)

In [173]:
res_occ_prif.min()

-0.012133554943887792

##### verification

In [174]:
occ_pri_f = np.load(fixtures +"/occ_pri_f.npy")

In [175]:
occ_pri_f.shape

(10000, 1)

In [176]:
occ_pri_f.min()

-0.012133554943887792

In [177]:
np.array_equal(res_occ_prif, occ_pri_f)

True

#### get occultations by the primary (idx)

In [178]:
def _get_sec_occ_idx(x,y,z,sec_r,pri_r):
    xo = -x/sec_r
    yo = -y/sec_r
    zo = -z/sec_r
    ro = pri_r/sec_r
    b = np.sqrt(xo**2 + yo**2)
    cond1 = b >= ro + 1.
    cond2 = zo <= 0.
    cond3 = ro == 0.
    conds = cond1 | cond2 | cond3
    b_occ = ~conds
    idx = np.arange(b.shape[0])[b_occ]
    return idx

In [179]:
res_socc_idx = _get_sec_occ_idx(xv,yv,zv,sec_r,pri_r)

##### verification

In [180]:
socc_idx = np.load(fixtures + "/sec_occ_idx.npy")

In [181]:
socc_idx.shape

(810,)

In [182]:
res_socc_idx.shape

(810,)

In [183]:
np.array_equal(res_socc_idx, socc_idx)

True

#### get occultations by the primary

In [184]:
secondary_x = np.load(fixtures +"/secondary_x.npy")

In [185]:
secondary_x.shape

(810, 36)

In [186]:
def _get_sec_occ(occ_sec, idx, sec_amp, sec_x, phase_sec):
    occ_sec[0][idx] = occ_sec[0][idx] + sec_amp*sec_x-phase_sec[0][idx]
    return occ_sec

In [187]:
sec_amp_here = 1.0 # need to understand why set the value to 1.0 why given sec_amp

In [188]:
res_sec_occf = _get_sec_occ(occ_sec, res_socc_idx, sec_amp_here, secondary_x, phase_sec)

In [189]:
res_occ_secf.shape

NameError: name 'res_occ_secf' is not defined

##### verification

In [190]:
occ_secf = np.load(fixtures +"/occ_sec_f.npy")

In [191]:
occ_secf.shape

(1, 10000, 36)

In [192]:
np.array_equal(res_sec_occf, occ_secf)

True

#### get the design matrix

In [193]:
def _get_design_matrix(phase_pri, occ_pri, phase_sec, occ_sec):
    x_pri = phase_pri + occ_pri
    x_sec = (phase_sec + occ_sec)[0]
    x = np.hstack((x_pri,x_sec))
    return x

In [194]:
res_x = _get_design_matrix(phase_pri, res_occ_prif, phase_sec, res_sec_occf)

In [195]:
res_x.shape

(10000, 37)

##### verification

In [196]:
np.array_equal(res_x,x)

False

#### summary function

In [197]:
def system_x():
    # get positions
    xv, yv, zv = _get_xyz()
    # get theta pri
    theta_pri = _get_theta_pri(pri_prot, pri_t0, pri_theta0, time)
    # get theta sec
    theta_sec = _get_theta_sec(sec_prot, sec_t0, sec_theta0, time)
    # get empty arrays for occ
    occ_pri = np.zeros((phase_pri.shape))
    occ_sec = np.zeros((phase_sec.shape))
    # get occ pri idx
    occ_pri_idx = _get_occ_idx(xv,yv,zv,pri_r,sec_r)
    # get occ pri
    occ_pri_f = _get_occ_pri(occ_pri, occ_pri_idx, pri_amp, primary_x, phase_pri)
    # get occ sec idx
    occ_sec_idx = _get_sec_occ_idx(xv,yv,zv,sec_r,pri_r)
    # get occ sec
    occ_sec_f = _get_sec_occ(occ_sec, occ_sec_idx, sec_amp_here, secondary_x, phase_sec)
    # get x
    x = _get_design_matrix(phase_pri, occ_pri_f, phase_sec, occ_sec_f)
    return x
    
    

In [198]:
sf_x = system_x()

In [199]:
sf_x.shape

(10000, 37)

In [200]:
np.array_equal(sf_x,x)

False

### phase curve of primary object (core.OpsLD.X)

In [13]:
inp_xo = np.zeros(len(time))
inp_yo = np.zeros(len(time))
inp_zo = np.zeros(len(time))

#### init flat light curve

In [14]:
def _ld_init_light_curve(x_list_len):
    flux_init = np.ones(x_list_len)
    return flux_init

In [15]:
flux_init = _ld_init_light_curve(len(time))

In [16]:
flux_init

array([1., 1., 1., ..., 1., 1., 1.])

#### get occultation mask

In [17]:
def _ld_get_occ_mask(xo,yo,zo, ro):
    b = np.sqrt(xo**2 + yo**2)
    cond1 = b >=ro + 1.
    cond2 = zo <= 0.
    cond3 = ro == 0.
    conds = cond1 | cond2 | cond3
    b_occ = ~conds
    i_occ = np.arange(b.size)[b_occ]
    return i_occ

In [18]:
ld_i_occ = _ld_get_occ_mask(inp_xo,inp_yo,inp_zo,0)

In [19]:
ld_i_occ

array([], dtype=int64)

#### get Agol 'c' coefficients

In [20]:
# todo: understand how c is calculated from u
# u = [-1, 0.4, 0.26] as input
c = np.array([0.21, 0.92,-0.065])

In [21]:
u1 = 0.4
u2 = 0.26

In [22]:
c0 = 1- u1 -1.5*u2
c1 = u1 + 2*u2
c2 = -0.25*u2

In [23]:
print(c0,c1,c2)

0.20999999999999996 0.92 -0.065


In [24]:
@jit
def get_cl(u1, u2):
    c0 = 1- u1 -1.5*u2
    c1 = u1 + 2*u2
    c2 = -0.25*u2
    return jnp.array([c0,c1,c2])

In [25]:
get_cl(0.4, 0.26)

DeviceArray([ 0.21000004,  0.91999996, -0.065     ], dtype=float32)

In [26]:
c_norm = c/(np.pi * (c[0] + 2 * c[1] / 3))

In [27]:
c_norm

array([ 0.08118835,  0.3556823 , -0.02512973])

In [28]:
from scipy.special import binom

In [29]:
from typing import Any

Scalar = Any
Array = Any
PyTree = Any

In [30]:
@jit
def greens_basis_transform(u: Array) -> Array:
    u = jnp.append(-1, u)
    size = len(u)
    i = np.arange(size)
    arg = binom(i[None, :], i[:, None]) @ u
    p = (-1) ** (i + 1) * arg
    g = [0 for _ in range(size + 2)]
    for n in range(size - 1, 1, -1):
        g[n] = p[n] / (n + 2) + g[n + 2]
    g[1] = p[1] + 3 * g[3]
    g[0] = p[0] + 2 * g[2]
    return jnp.stack(g[:-2])

In [31]:
greens_basis_transform(jnp.array([0.4, 0.26]))

DeviceArray([ 0.21000004,  0.91999996, -0.065     ], dtype=float32)

#### get occultation flux

In [32]:
limbdark_arr = np.array([])

In [33]:
limbdark_arr

array([], dtype=float64)

In [34]:
def _ld_get_occ_flux(zo, ro, i_occ, flux, limbdark_arr):
    los = zo[i_occ]
    print(los.shape)
    r = ro * np.ones(len(los))
    flux[i_occ] = limbdark_arr
    return flux

In [35]:
ld_occ_flux = _ld_get_occ_flux(inp_zo, 0, ld_i_occ, flux_init, limbdark_arr)

(0,)


In [36]:
ld_occ_flux.min()

1.0

In [37]:
ld_occ_flux.shape

(10000,)

#### verification

In [64]:
ld_x = np.load(fixtures + "/ld_x.npy")

In [39]:
ld_x.shape

(802, 1)

In [48]:
inp_xo_f = (xv/pri_r)[res_idx]
inp_yo_f = (yv/pri_r)[res_idx]
inp_zo_f = (zv/pri_r)[res_idx]
ro = sec_r/pri_r

In [100]:
ro

0.1

In [49]:
inp_xo_f.shape

(802,)

In [50]:
flux_init_2 = _ld_init_light_curve(len(inp_xo_f))

In [51]:
flux_init_2.shape

(802,)

In [52]:
ld_i_occ_2 = _ld_get_occ_mask(inp_xo_f,inp_yo_f,inp_zo_f,ro)

In [53]:
ld_i_occ_2.shape

(802,)

In [54]:
limbdark_arr_2 = np.load(fixtures + "/ld_limbdark.npy")

In [55]:
limbdark_arr_2.shape

(802,)

In [227]:
ld_occ_flux_2 = _ld_get_occ_flux(inp_zo_f, ro, ld_i_occ_2, flux_init_2, limbdark_arr_2)

(802,)


In [228]:
ld_occ_flux_2.shape

(802,)

In [229]:
np.array_equal(ld_occ_flux_2, ld_x.flatten())

True

#### summary

In [56]:
def ld_x(xo, yo, zo, ro, limbdark_arr):
    flux_init = np.ones(len(xo))
    b = np.sqrt(xo**2 + yo**2)
    return b
#     cond1 = b >=ro + 1.
#     cond2 = zo <= 0.
#     cond3 = ro == 0.
#     conds = cond1 | cond2 | cond3
#     b_occ = ~conds
#     i_occ = np.arange(b.size)[b_occ]
#     los = zo[i_occ]
#     r = ro * np.ones(len(los))
#     flux_init[i_occ] = limbdark_arr
#     return flux_init

In [57]:
ld_b = ld_x(inp_xo_f, inp_yo_f, inp_zo_f, ro, limbdark_arr_2)

In [59]:
ld_b.shape

(802,)

In [231]:
ldx = ld_x(inp_xo, inp_yo, inp_zo, 0, limbdark_arr)

In [232]:
ldx.min()

1.0

In [233]:
ldx_2 = ld_x(inp_xo_f, inp_yo_f, inp_zo_f, ro, limbdark_arr_2)

In [234]:
ldx_2.shape

(802,)

In [236]:
np.array_equal(ldx_2, ld_x.flatten())

True

##### jax version

In [400]:
@jax.jit
def j_cond(xo, yo, zo, ro):
    b = jnp.sqrt(xo**2 + yo**2)
    cond1 = b >=ro + 1.
    cond2 = zo <= 0.
    cond3 = ro == 0.
    conds = cond1 | cond2 | cond3
    b_occ = ~conds
    return jnp.where(b_occ == 1, b_occ, 0)

In [404]:
b_idx = j_cond(inp_xo_f, inp_yo_f, inp_zo_f, ro)

(802,)

In [406]:
@jax.jit
def jld_x(xo, yo, zo, ro, limbdark_arr):
    flux_init = jnp.ones(len(xo))
#     b = jnp.sqrt(xo**2 + yo**2)
#     cond1 = b >=ro + 1.
#     cond2 = zo <= 0.
#     cond3 = ro == 0.
#     conds = cond1 | cond2 | cond3
#     b_occ = ~conds
# #     i_occ = jnp.arange(b.size)[b_occ]
#     val_idx = jnp.where(b_occ == 1, b_occ, 0)
    val_idx = j_cond(xo, yo, zo, ro)
#     return val_idx
#     i_occ = jnp.arange(b.size)*b_occ
#     return i_occ
    flux = jax.lax.select(val_idx, limbdark_arr, flux_init)
#     flux = flux_init.at[i_occ].set(limbdark_arr)
    return flux

In [293]:
b_occ = jld_x(inp_xo_f, inp_yo_f, inp_zo_f, ro, limbdark_arr_2)

In [352]:
i_occ = jld_x(inp_xo_f, inp_yo_f, inp_zo_f, ro, limbdark_arr_2)

In [407]:
jldx2 = jld_x(inp_xo_f, inp_yo_f, inp_zo_f, ro, limbdark_arr_2)

In [408]:
jldx2

DeviceArray([0.999992  , 0.9998695 , 0.99965906, 0.9993814 , 0.9990474 ,
             0.9986648 , 0.99823993, 0.9977787 , 0.9972865 , 0.9967686 ,
             0.9962304 , 0.9956774 , 0.99511564, 0.9945515 , 0.9939924 ,
             0.9934469 , 0.99292576, 0.9924435 , 0.992024  , 0.99173427,
             0.99152464, 0.99133736, 0.9911666 , 0.99100906, 0.99086255,
             0.9907256 , 0.9905969 , 0.9904756 , 0.99036086, 0.99025214,
             0.99014884, 0.99005055, 0.98995686, 0.98986745, 0.989782  ,
             0.98970014, 0.9896218 , 0.9895467 , 0.98947465, 0.9894055 ,
             0.98933905, 0.98927516, 0.9892137 , 0.9891546 , 0.9890977 ,
             0.9890429 , 0.9889901 , 0.9889393 , 0.98889023, 0.988843  ,
             0.9887974 , 0.98875356, 0.9887112 , 0.98867035, 0.98863095,
             0.988593  , 0.9885563 , 0.98852104, 0.988487  , 0.98845416,
             0.9884225 , 0.98839206, 0.98836267, 0.98833436, 0.9883071 ,
             0.98828095, 0.98825574, 0.9882315 , 0.

In [354]:
i_occ.shape

(802,)

In [343]:
ta = jnp.array([True,False,True, True, False])

In [348]:
ta

DeviceArray([ True, False,  True,  True, False], dtype=bool)

In [392]:
import jax.numpy as jnp
from jax import lax

# Define a boolean mask
x = jnp.array([1, 2, 3, 4, 5])
mask = jnp.where(x>2, x, 0)

# Define two arrays to select from
a = jnp.array([10, 20, 30, 40, 50])
b = jnp.array([-10, -20, -30, -40, -50])

# Use jax.lax.select to conditionally select values
result = lax.select(mask, a, b)

# Print the result
print(result)  # [ -10  -20  30  40  50]


[-10 -20  30  40  50]


In [370]:
fn(ta)

DeviceArray([0, 2, 3], dtype=int32)

In [287]:
jldx_2 = jld_x(inp_xo_f, inp_yo_f, inp_zo_f, ro, limbdark_arr_2)

In [387]:
x = jnp.array([1, 2, 3, 4, 5])
mask = x > 2


In [389]:
mask.any()

DeviceArray(True, dtype=bool)

In [277]:
jldx_2_np = jax.device_get(jldx_2)

In [409]:
np.allclose(jldx2, ld_x.flatten())

True

In [77]:
from jaxoplanet._src.core.limb_dark import light_curve as limb_dark_light_curve

In [80]:
from functools import partial

In [93]:
lc_func = partial(limb_dark_light_curve, u)

In [94]:
dir(lc_func)

['__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 'args',
 'func',
 'keywords']

In [96]:
lc_func.args

(array([0.4 , 0.26]),)

In [97]:
res_2 = lc_func(ld_b, ro)

In [98]:
res_2

DeviceArray([-8.04662704e-06, -1.30474567e-04, -3.40938568e-04,
             -6.18517399e-04, -9.52601433e-04, -1.33520365e-03,
             -1.76006556e-03, -2.22122669e-03, -2.71350145e-03,
             -3.23140621e-03, -3.76969576e-03, -4.32258844e-03,
             -4.88442183e-03, -5.44846058e-03, -6.00761175e-03,
             -6.55311346e-03, -7.07429647e-03, -7.55655766e-03,
             -7.97599554e-03, -8.26579332e-03, -8.47542286e-03,
             -8.66264105e-03, -8.83340836e-03, -8.99094343e-03,
             -9.13739204e-03, -9.27442312e-03, -9.40316916e-03,
             -9.52440500e-03, -9.63914394e-03, -9.74786282e-03,
             -9.85115767e-03, -9.94944572e-03, -1.00432038e-02,
             -1.01326108e-02, -1.02180243e-02, -1.02998614e-02,
             -1.03781223e-02, -1.04533434e-02, -1.05253458e-02,
             -1.05945468e-02, -1.06610060e-02, -1.07249022e-02,
             -1.07862353e-02, -1.08453631e-02, -1.09022856e-02,
             -1.09571218e-02, -1.1009872

In [78]:
u = np.array([0.4,0.26])

In [61]:
res = light_curve(u, ld_b, ro)

In [66]:
res = np.array(res)

In [76]:
np.allclose(quad_res, ld_x.flatten())

True

In [73]:
from jaxoplanet._src.core.quad import light_curve

In [74]:
quad_res = light_curve(0.4, 0.26, ld_b, ro)

In [75]:
quad_res = np.array(quad_res) + 1

In [69]:
ld_x

array([[0.999992  ],
       [0.99986955],
       [0.99965909],
       [0.99938145],
       [0.99904741],
       [0.99866477],
       [0.99823996],
       [0.99777874],
       [0.99728652],
       [0.99676862],
       [0.99623041],
       [0.99567743],
       [0.99511564],
       [0.99455151],
       [0.99399238],
       [0.99344689],
       [0.99292574],
       [0.99244351],
       [0.99202402],
       [0.99173426],
       [0.99152462],
       [0.99133737],
       [0.99116659],
       [0.99100905],
       [0.99086256],
       [0.99072557],
       [0.99059689],
       [0.99047558],
       [0.99036088],
       [0.99025215],
       [0.99014887],
       [0.99005058],
       [0.98995688],
       [0.98986744],
       [0.98978196],
       [0.98970017],
       [0.98962182],
       [0.98954672],
       [0.98947467],
       [0.98940549],
       [0.98933904],
       [0.98927515],
       [0.98921371],
       [0.9891546 ],
       [0.98909769],
       [0.98904289],
       [0.98899011],
       [0.988

### phase curve of secondary object (core.OpsYlm.X)

In [407]:
# inputs
sec_inc = 1.57079633
sec_obl = 0.
sec_u = -1.
sec_f = 3.14159265
n_col = 36 # from rTA1

#### get x shape

In [408]:
def _ylm_get_x_shape(theta, n_col):
    rows = theta.shape[0]
    cols = n_col
    x = np.zeros((rows, cols))
    return x

In [432]:
res_theta_sec[0].shape[0]

10000

In [434]:
res_ylm_x = _ylm_get_x_shape(res_theta_sec[0], n_col)

In [435]:
res_ylm_x.shape

(10000, 36)

In [436]:
res_ylm_x

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])

#### get occultation mask

In [437]:
def _ylm_get_occ_mask(xo,yo,zo, ro):
    b = np.sqrt(xo**2 + yo**2)
    cond1 = b >=ro + 1.0
    cond2 = zo <= 0.
    cond3 = ro == 0.
    conds = cond1| cond2| cond3
    b_occ = ~conds
    i_rot = np.arange(b.size)[conds]
    i_occ = np.arange(b.size)[b_occ]
    return i_rot, i_occ

In [456]:
inp_xo_sec = -xv
inp_yo_sec = -yv
inp_zo_sec = -zv
ro_sec = 0.

##### verification

In [457]:
res_ylm_irot, res_ylm_iocc = _ylm_get_occ_mask(inp_xo_sec,inp_yo_sec,inp_zo_sec,ro_sec)

In [458]:
res_ylm_irot.shape

(10000,)

In [441]:
ylm_irot = np.load(fixtures + "/ylm_irot.npy")

In [459]:
ylm_irot.shape

(10000,)

In [460]:
np.array_equal(res_ylm_irot, ylm_irot)

True

In [461]:
res_ylm_iocc

array([], dtype=int64)

#### get rotation operator

In [462]:
ylm_rta1 = np.load(fixtures + "/ylm_rta1.npy")

In [463]:
ylm_rta1.shape

(1, 36)

In [464]:
ylm_rta1

array([[ 1.00000000e+00,  0.00000000e+00,  1.15470054e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         5.59016994e-01,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00, -1.25000000e-01,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.11022302e-16,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        -6.66133815e-16,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00, -1.11022302e-16,  0.00000000e+00]])

In [468]:
def _ylm_get_rta1_f(rta1, theta, irot):
    res = np.tile(rta1, (theta[irot].shape[0], 1))
    return res

##### verification

In [469]:
res_ylm_rta1_2 = _ylm_get_rta1_f(ylm_rta1, res_theta_sec[0],res_ylm_irot)

In [465]:
ylm_rta1_2 = np.load(fixtures + "/ylm_rta1_2.npy")

In [467]:
ylm_rta1_2.shape

(10000, 36)

In [472]:
np.array_equal(res_ylm_rta1_2, ylm_rta1_2)

True

#### get rotation x

In [478]:
def _ylm_get_rta1_x(x,right_project, irot):
    x[irot] = right_project
    return x

In [481]:
# todo: understand how to get right project
ylm_right_project = np.load(fixtures + "/ylm_right_project.npy")

##### verification

In [479]:
res_ylm_rta1_x = _ylm_get_rta1_x(res_ylm_x, ylm_right_project, res_ylm_irot)

In [473]:
ylm_rta1_x = np.load(fixtures + "/ylm_rta1_x.npy")

In [474]:
ylm_rta1_x.shape

(10000, 36)

In [480]:
np.array_equal(res_ylm_rta1_x, res_ylm_rta1_x)

True

#### verification

In [494]:
# inputs (additional)
ylm_idx = res_socc_idx
ylm_xo_f = (-xv/sec_r)[ylm_idx]
ylm_yo_f = (-yv/sec_r)[ylm_idx]
ylm_zo_f = (-zv/sec_r)[ylm_idx]
ylm_ro = pri_r/sec_r

In [492]:
res_ylm_x_fi = _ylm_get_x_shape(res_theta_sec[0,ylm_idx], n_col)

In [493]:
res_ylm_x_fi.shape

(810, 36)

In [495]:
res_ylm_irot_f, res_ylm_iocc_f = _ylm_get_occ_mask(ylm_xo_f,ylm_yo_f,ylm_zo_f,ylm_ro)

In [499]:
res_ylm_rta1_2_f = _ylm_get_rta1_f(ylm_rta1, res_theta_sec[0,ylm_idx],res_ylm_irot_f)

In [500]:
res_ylm_rta1_2_f.shape

(0, 36)

In [501]:
ylm_st_f = np.load(fixtures + "/ylm_st_f.npy")

In [502]:
ylm_st_f.shape

(810, 36)

##### core.OpsYlm.A

In [504]:
import scipy

In [505]:
ylm_a_f = scipy.sparse.load_npz(fixtures + "/ylm_A_f.npz")

In [558]:
av = ylm_a_f.toarray()

In [560]:
av

array([[ 3.18309886e-01,  0.00000000e+00,  0.00000000e+00, ...,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  5.51328895e-01, ...,
         0.00000000e+00,  1.11022302e-16,  0.00000000e+00],
       ...,
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
         7.45262512e+00,  0.00000000e+00,  1.85161960e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
         0.00000000e+00, -7.02640234e+00,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00]])

In [516]:
def _ylm_get_sta(st,a):
    if scipy.sparse.issparse(a):
        a = a.toarray()
    return np.dot(st,a)

In [517]:
res_ylm_sta_f = _ylm_get_sta(ylm_st_f, ylm_a_f)

In [512]:
ylm_sta_f = np.load(fixtures + "/ylm_sta_f.npy")

In [520]:
np.allclose(res_ylm_sta_f, ylm_sta_f)

True

In [526]:
def _ylm_get_theta_z(xo, yo, iocc):
    return np.arctan2(xo[iocc], yo[iocc])

In [527]:
res_ylm_tz_f = _ylm_get_theta_z(ylm_xo_f, ylm_yo_f,res_ylm_iocc_f)

In [529]:
res_ylm_tz_f.shape

(810,)

In [521]:
ylm_theta_z_f = np.load(fixtures + "/ylm_theta_z_f.npy")

In [523]:
ylm_theta_z_f.shape

(810,)

In [536]:
np.allclose(res_ylm_tz_f, ylm_theta_z_f, atol=1.e-5)

False

In [537]:
diff = res_ylm_tz_f - ylm_theta_z_f

In [539]:
diff.min()

-1.3345817678978023e-05

In [541]:
res_ylm_xo_bcc = ylm_xo_f[res_ylm_iocc_f]

In [540]:
ylm_xo_bcc = np.load(fixtures + "/ylm_x_occ.npy")

In [547]:
np.allclose(res_ylm_xo_bcc, ylm_xo_bcc)

True

In [549]:
res_ylm_yo_bcc = ylm_yo_f[res_ylm_iocc_f]

In [548]:
ylm_yo_bcc = np.load(fixtures + "/ylm_y_occ.npy")

In [552]:
np.allclose(res_ylm_yo_bcc, ylm_yo_bcc, atol=1.e-6)

True

In [553]:
ylm_rp_occ = np.load(fixtures + "/ylm_right_project_occ.npy")

In [555]:
def _ylm_get_f_x(x,right_project, iocc):
    x[iocc] = right_project
    return x

In [556]:
res_ylm_x_f = _ylm_get_f_x(res_ylm_x_fi, ylm_rp_occ,res_ylm_iocc_f)

In [482]:
ylm_x_f = np.load(fixtures + "/ylm_x_f.npy")

In [491]:
ylm_x_f.shape

(810, 36)

In [557]:
np.array_equal(res_ylm_x_f, ylm_x_f)

True

### properties of surface map

In [561]:
import starry

In [562]:
map1 = starry.Map(ydeg=5)

Pre-computing some matrices... Done.


In [593]:
dir(map1.ops._c_ops)

['A',
 'A1',
 'A1Big',
 'A1Inv',
 'A2',
 'F',
 'N',
 'Nf',
 'Nu',
 'Ny',
 'OrenNayarPolynomial',
 '__class__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 'deg',
 'dotR',
 'fdeg',
 'pT',
 'rT',
 'rTA1',
 'rTReflected',
 'sT',
 'sTOblate',
 'sTReflected',
 'spotYlm',
 'tensordotRz',
 'udeg',
 'ydeg']

In [608]:
map1.ops._c_ops.A1

<36x36 sparse matrix of type '<class 'numpy.float64'>'
	with 104 stored elements in Compressed Sparse Column format>

In [609]:
map2 = starry.Map(udeg=2)

In [611]:
dir(map2.ops)

['X',
 '_LimbDarkIsPhysical',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_get_cl',
 '_limbdark',
 'flux',
 'intensity',
 'limbdark_is_physical',
 'nw',
 'render',
 'render_ld',
 'set_vector',
 'udeg']