In [9]:
import numpyro
import numpyro.distributions as dist
import numpy as np
import jax.numpy as jnp
import jax
import nested_pandas as npd

In [11]:
# Read the nested light-curve table from disk, keep every column except
# "transfer_function", and display the resulting frame.
df = npd.read_parquet("data/lightcurve_thindisk_fixed_nested.parquet")
cols = list(filter(lambda name: name != "transfer_function", df.columns))
df = df[cols]
df

Unnamed: 0_level_0,ID,cos_inc,log_mbh,log_mdot,redshift,lag,light_curve
time,band,mag,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
time,band,mag,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2
time,band,mag,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3
time,band,mag,Unnamed: 3_level_4,Unnamed: 4_level_4,Unnamed: 5_level_4,Unnamed: 6_level_4,Unnamed: 7_level_4
time,band,mag,Unnamed: 3_level_5,Unnamed: 4_level_5,Unnamed: 5_level_5,Unnamed: 6_level_5,Unnamed: 7_level_5
0,0,0.866000,8.000000,0.000000,0.100000,"[2.1842446730521603, 2.973574177527805, 4.200837845368423, 5.364424911008781, 6.361364586737727, 7.18104368469921]",time  band  mag  0.0  0  18.601379  +1199 rows  ...  ...
time,band,mag,,,,,
0.0,0,18.601379,,,,,
+1199 rows,...,...,,,,,
1,1,0.866000,8.000000,0.000000,0.100000,"[2.1842446730521603, 2.973574177527805, 4.200837845368423, 5.364424911008781, 6.361364586737727, 7.18104368469921]",time  band  mag  0.0  0  19.047684  +1199 rows  ...  ...
time,band,mag,,,,,
0.0,0,19.047684,,,,,
+1199 rows,...,...,,,,,
2,2,0.866000,8.000000,0.000000,0.100000,"[2.1842446730521603, 2.973574177527805, 4.200837845368423, 5.364424911008781, 6.361364586737727, 7.18104368469921]",time  band  mag  0.0  0  20.512323  +1199 rows  ...  ...
time,band,mag,,,,,

time,band,mag
0.0,0,18.601379
+1199 rows,...,...

time,band,mag
0.0,0,19.047684
+1199 rows,...,...

time,band,mag
0.0,0,20.512323
+1199 rows,...,...

time,band,mag
0.0,0,19.780731
+1199 rows,...,...

time,band,mag
0.0,0,19.738477
+1199 rows,...,...


In [12]:
from eztaox.kernels.quasisep import Exp
from eztaox.fitter import random_search
from eztaox.models import MultiVarModel

# Signal-to-noise ratio used by run_eztaox: the per-point error is mag / SNR and
# the injected Gaussian noise has that same standard deviation.
SNR = 1000.0


def run_eztaox(
    time, band, mag, *, snr=None, seed=None, fit_key=None, n_sample=1_000, n_best=5
):
    """Fit a two-band eztaox ``MultiVarModel`` (with interband lag) to one light curve.

    Only the lowest and highest band labels present in ``band`` are used. For
    each of those two bands, Gaussian noise with standard deviation ``mag / snr``
    is added to the magnitudes, the per-band median is subtracted, and the two
    bands are merged into a single time-sorted series before calling
    ``eztaox.fitter.random_search``.

    Parameters
    ----------
    time, band, mag : array-like
        Per-observation arrays of equal length for a single light curve:
        observation time, band label, and magnitude.
    snr : float, optional
        Signal-to-noise ratio used for both the injected noise and the reported
        error bars. Defaults to the module-level ``SNR`` (looked up at call time).
    seed : optional
        Anything accepted by ``numpy.random.default_rng`` (``None``, an int, an
        existing ``Generator``, ...). ``None`` (the default) draws fresh noise on
        every call; pass a value to make the injected noise reproducible. Note
        that passing the same integer seed for every row gives every row the
        same underlying normal draws.
    fit_key : optional
        JAX PRNG key handed to ``random_search``. Defaults to
        ``jax.random.PRNGKey(1)``.
    n_sample : int, optional
        Forwarded to ``random_search`` as its sample count (default 1000).
    n_best : int, optional
        Forwarded to ``random_search`` as its ``nBest`` argument (default 5).
        Original author's note: this number seems to need to be high.

    Returns
    -------
    dict
        ``{"best_params": ..., "log_likelihood": ...}`` holding the two values
        returned by ``random_search``.

    Raises
    ------
    ValueError
        If ``band`` contains fewer than two distinct labels, since an interband
        lag cannot be fit in that case.
    """
    time, band, mag = np.asarray(time), np.asarray(band), np.asarray(mag)

    # Use only the two extreme band labels. Guard explicitly: with a single band
    # the [0, -1] selection would silently duplicate the data into two "bands",
    # and with no data it would fail with an opaque IndexError.
    unique_bands = np.unique(band)
    if unique_bands.size < 2:
        raise ValueError(
            "run_eztaox needs at least two distinct bands to fit an interband lag; "
            f"got {unique_bands.size}"
        )
    bands = [unique_bands[0], unique_bands[-1]]
    nBand = len(bands)

    if snr is None:
        snr = SNR
    if fit_key is None:
        fit_key = jax.random.PRNGKey(1)
    rng = np.random.default_rng(seed)

    times, noisy_mags, mag_errs = {}, {}, {}
    for b in bands:
        mask = band == b
        times[b] = time[mask]
        band_mag = mag[mask]
        mag_errs[b] = band_mag / snr
        noisy = jnp.array(
            band_mag + rng.normal(0, 1, size=band_mag.shape) * band_mag / snr
        )
        # Centre each band on its own median before fitting.
        noisy_mags[b] = noisy - jnp.median(noisy)

    # Merge both bands into one series sorted by time; the second element of X
    # is the integer band index (0 for the first band, 1 for the second).
    all_times = jnp.concatenate([times[b] for b in bands])
    all_band_index = jnp.concatenate(
        [i * jnp.ones_like(times[b], dtype=int) for i, b in enumerate(bands)]
    )
    inds = jnp.argsort(all_times)
    X = (all_times[inds], all_band_index[inds])
    y = jnp.concatenate([noisy_mags[b] for b in bands])[inds]
    yerr = jnp.concatenate([mag_errs[b] for b in bands])[inds]

    has_lag = True  # fit interband lags
    # Passed straight to MultiVarModel as `zero_mean`.
    # NOTE(review): a "mean" parameter is still sampled below; confirm whether
    # the model actually uses it when zero_mean=True.
    zero_mean = True

    # initialize a GP kernel, note the initial parameters are not used in the fitting
    k = Exp(scale=100.0, sigma=1.0)
    model = MultiVarModel(X, y, yerr, k, nBand, has_lag=has_lag, zero_mean=zero_mean)

    def initSampler():
        """Draw one candidate parameter set from the (uniform) priors."""
        # GP kernel params; the sites are named "drw_scale"/"drw_sigma" but the
        # sampled values are natural logs (the prior bounds are logs).
        log_drw_scale = numpyro.sample(
            "drw_scale", dist.Uniform(jnp.log(0.01), jnp.log(1000))
        )
        log_drw_sigma = numpyro.sample(
            "drw_sigma", dist.Uniform(jnp.log(0.01), jnp.log(10))
        )
        log_kernel_param = jnp.stack([log_drw_scale, log_drw_sigma])
        numpyro.deterministic("log_kernel_param", log_kernel_param)

        # parameters to relate the amplitudes in each band
        log_amp_scale = numpyro.sample("log_amp_scale", dist.Uniform(-2, 2))

        # one mean per band, sized from nBand rather than hard-coded to two
        mean = numpyro.sample(
            "mean",
            dist.Uniform(
                low=jnp.asarray([-0.1] * nBand), high=jnp.asarray([0.1] * nBand)
            ),
        )

        # interband lags
        lag = numpyro.sample("lag", dist.Uniform(-10, 10))

        return {
            "log_kernel_param": log_kernel_param,
            "log_amp_scale": log_amp_scale,
            "mean": mean,
            "lag": lag,
        }

    bestP, ll = random_search(model, initSampler, fit_key, n_sample, n_best)

    return {"best_params": bestP, "log_likelihood": ll}

In [13]:
# Fit only the first ten rows; with row_container="args" the three nested
# light-curve sub-columns are handed to run_eztaox as (time, band, mag).
light_curve_columns = ["light_curve.time", "light_curve.band", "light_curve.mag"]
res = df.iloc[:10].map_rows(
    run_eztaox, columns=light_curve_columns, row_container="args"
)
res

Unnamed: 0,best_params,log_likelihood
0,"{'lag': 4.485634800377804, 'log_amp_scale': -0.012517831709503331, 'log_kernel_param': [6.280977131400453, -0.3222664536000996], 'mean': [-0.08006180995604205, 0.05387841598852514]}",645.8614086676298
1,"{'lag': 2.52834132302178, 'log_amp_scale': -0.0423906068996533, 'log_kernel_param': [3.0195855958455597, -1.1381191480470858], 'mean': [-0.0013556186603485899, 0.004095973980314672]}",451.4390238963537
2,"{'lag': 2.5234785598652016, 'log_amp_scale': 0.004733461536943287, 'log_kernel_param': [3.4904263502785318, -0.6709334454384429], 'mean': [-0.0013556186603485899, 0.004095973980314672]}",363.666495103379
3,"{'lag': 3.48099408975768, 'log_amp_scale': -0.09029771803181019, 'log_kernel_param': [3.434506166367917, -1.499840031767058], 'mean': [0.0020413791618782984, -0.03800461300923348]}",601.3049654979931
4,"{'lag': 2.542600098457406, 'log_amp_scale': -0.060920499975962085, 'log_kernel_param': [3.8905857997133064, -1.028932612590522], 'mean': [-0.0013556186603485899, 0.004095973980314672]}",541.8749600622378
5,"{'lag': 3.4997980549832777, 'log_amp_scale': -0.11608937320005511, 'log_kernel_param': [3.1631299820683223, -1.3626185410366085], 'mean': [0.0020413791618782984, -0.03800461300923348]}",546.0576337278795
6,"{'lag': 3.4939133619599896, 'log_amp_scale': -0.16852295112391893, 'log_kernel_param': [3.9096803848202293, -1.7146598255317873], 'mean': [0.0020413791618782984, -0.03800461300923348]}",728.1864690226193
7,"{'lag': 3.53368758467303, 'log_amp_scale': -0.04510635479932078, 'log_kernel_param': [3.8726327022708205, -0.9915142049711875], 'mean': [0.08278969359306863, -0.05068042678406135]}",533.2677042445933
8,"{'lag': 2.4899426259643582, 'log_amp_scale': -0.01336503240019363, 'log_kernel_param': [4.325341505010771, -0.48920224520760264], 'mean': [-0.03511726390175549, 0.03982690277189787]}",442.8771443186927
9,"{'lag': 4.465401841312857, 'log_amp_scale': -0.038288504595726525, 'log_kernel_param': [3.9615853978492, -1.2280408477747586], 'mean': [-0.07496934448001613, 0.03380726620443477]}",603.4410868008575


In [14]:
# Show the "lag" column for the same first ten rows that were fitted.
df["lag"].head(10)

0    [2.18424467 2.97357418 4.20083785 5.36442491 6...
1    [2.18424467 2.97357418 4.20083785 5.36442491 6...
2    [2.18424467 2.97357418 4.20083785 5.36442491 6...
3    [2.18424467 2.97357418 4.20083785 5.36442491 6...
4    [2.18424467 2.97357418 4.20083785 5.36442491 6...
5    [2.18424467 2.97357418 4.20083785 5.36442491 6...
6    [2.18424467 2.97357418 4.20083785 5.36442491 6...
7    [2.18424467 2.97357418 4.20083785 5.36442491 6...
8    [2.18424467 2.97357418 4.20083785 5.36442491 6...
9    [2.18424467 2.97357418 4.20083785 5.36442491 6...
Name: lag, dtype: list<element: double>[pyarrow]

In [15]:
def _lag_and_kernel(row):
    """Pull the fitted lag and log kernel parameters out of one result row."""
    best = row["best_params"]
    return {"lag": best["lag"], "log_kernel_param": best["log_kernel_param"]}


res.map_rows(_lag_and_kernel)

Unnamed: 0,lag,log_kernel_param
0,4.485634800377804,"[6.280977131400453, -0.3222664536000996]"
1,2.52834132302178,"[3.0195855958455597, -1.1381191480470858]"
2,2.5234785598652016,"[3.4904263502785318, -0.6709334454384429]"
3,3.48099408975768,"[3.434506166367917, -1.499840031767058]"
4,2.542600098457406,"[3.8905857997133064, -1.028932612590522]"
5,3.4997980549832777,"[3.1631299820683223, -1.3626185410366085]"
6,3.4939133619599896,"[3.9096803848202293, -1.7146598255317873]"
7,3.53368758467303,"[3.8726327022708205, -0.9915142049711875]"
8,2.4899426259643582,"[4.325341505010771, -0.48920224520760264]"
9,4.465401841312857,"[3.9615853978492, -1.2280408477747586]"
