In [16]:
import pandas as pd
import numpy as np
import jax.numpy as jnp
from jax import config
from jax import jacfwd, vmap
config.update('jax_enable_x64', True) 
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (12,4)
from microjax.inverse_ray.lightcurve import mag_lc, mag_uniform

data = pd.read_csv("../data/ogle-2014-blg-0124/phot.dat", 
                   delim_whitespace=True,header=None, names=["HJD", "mag", "mage", "seeing", "sky"])
data["HJD"] -= 2450000
data = data[data.HJD>6600]
mag0 = 18.0
data["flux"] = 10**(-0.4*(data.mag - mag0))
data["fluxe"] = data.flux * 0.4 * np.log(10) * data.mage
print("Number of data: %d"%(len(data)))

Number of data: 2283


In [17]:
def log_likelihood(params, t, y, yerr):
    t0, u0, lntE, lns, lnq, alpha, lnrho, fs, f_sum = params
    tE = jnp.exp(lntE)
    s = jnp.exp(lns)
    q = jnp.exp(lnq)
    alpha_rad = jnp.radians(alpha)
    rho = jnp.exp(lnrho)
    tau = (t - t0) / tE
    um = u0
    y1 = -um * jnp.sin(alpha_rad) + tau * jnp.cos(alpha_rad)
    y2 = um * jnp.cos(alpha_rad) + tau * jnp.sin(alpha_rad)
    w_points = jnp.array(y1 + y2 * 1j, dtype=complex)
    mag_func = lambda w: mag_uniform(w, rho, q=q, s=s, r_resolution=100, th_resolution=100)
    mags = vmap(mag_func)(w_points)
    fb = f_sum - fs
    f_model = mags * fs + fb
    residuals = y - f_model
    N = len(y)
    sigma = yerr 
    loglike = -0.5 * jnp.sum((residuals / sigma) ** 2) - 0.5 * N * jnp.log(2 * jnp.pi) - jnp.sum(jnp.log(sigma))
    return loglike

In [18]:
def get_gradients(params, t, y, yerr):
    return jacfwd(log_likelihood)(params, t, y, yerr)

params_init = jnp.array([6.83640951e+03, 2.24211333e-01, jnp.log(1.33559958e+02), jnp.log(9.16157288e-01),
                         jnp.log(5.87559438e-04), 1.00066409e+02, jnp.log(2.44003713e-03), 8.06074085e-01, 8.62216897e-01])

In [19]:
import optax
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params_init)

In [None]:
num_iterations = 1000
for i in range(num_iterations):
    grads = get_gradients(params_init, data.HJD.values, data.flux.values, data.fluxe.values)
    updates, opt_state = optimizer.update(grads, opt_state)
    params_init = optax.apply_updates(params_init, updates)
    
    if i % 10 == 0:
        loglike_val = log_likelihood(params_init, data.HJD.values, data.flux.values, data.fluxe.values)
        print(f"Iteration {i}, Log-Likelihood: {loglike_val}")

Iteration 0, Log-Likelihood: -1113959.5676480585
Iteration 10, Log-Likelihood: -1408550.9511158466
Iteration 20, Log-Likelihood: -1707445.4057680892
Iteration 30, Log-Likelihood: -2016732.856843921
Iteration 40, Log-Likelihood: -2337879.3302009525
Iteration 50, Log-Likelihood: -2659624.2271432285
Iteration 60, Log-Likelihood: -2986532.6398288063
Iteration 70, Log-Likelihood: -3322211.684732432
Iteration 80, Log-Likelihood: -3644704.553946016
Iteration 90, Log-Likelihood: -3963899.055905881
Iteration 100, Log-Likelihood: -4284765.789383183
Iteration 110, Log-Likelihood: -4603237.451563558
Iteration 120, Log-Likelihood: -4906923.113087061
