In [2]:
import numpyro
import numpyro.distributions as dist
import numpy as np
import jax.numpy as jnp
import jax
import nested_pandas as npd

In [3]:
# Read the nested light-curve table from disk.
df = npd.read_parquet("data/lightcurve_thindisk_eztaox_nested.parquet")

# Drop the "transfer_function" column and keep every other column as-is.
cols = [name for name in df.columns if name != "transfer_function"]
df = df[cols]
df

Unnamed: 0_level_0,ID,cos_inc,log_mbh,log_mdot,redshift,lag,light_curve
time,band,mag,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
time,band,mag,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2
time,band,mag,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3
time,band,mag,Unnamed: 3_level_4,Unnamed: 4_level_4,Unnamed: 5_level_4,Unnamed: 6_level_4,Unnamed: 7_level_4
time,band,mag,Unnamed: 3_level_5,Unnamed: 4_level_5,Unnamed: 5_level_5,Unnamed: 6_level_5,Unnamed: 7_level_5
0,0,0.919107,6.048582,1.688403,0.447895,"[1.8212359287936297, 2.479463007698864, 3.5135329173566037, 4.530151762976888, 5.445982479700769, 6.23571103759643]",time  band  mag  0.0  0  18.682762  +1199 rows  ...  ...
time,band,mag,,,,,
0.0,0,18.682762,,,,,
+1199 rows,...,...,,,,,
1,1,0.975773,6.049580,-1.306920,0.543590,"[0.21918880609402863, 0.2817659004204899, 0.3821933452707273, 0.48386308244618753, 0.5799001324487597, 0.668142247509635]",time  band  mag  0.0  0  26.55779  +1199 rows  ...  ...
time,band,mag,,,,,
0.0,0,26.55779,,,,,
+1199 rows,...,...,,,,,
2,2,0.784115,6.919634,1.331035,0.941750,"[2.6785613870700895, 3.6405118607085574, 5.0724912083771985, 6.330727224359494, 7.333063965660035, 8.112609569134229]",time  band  mag  0.0  0  21.419539  +1199 rows  ...  ...
time,band,mag,,,,,

time,band,mag
0.0,0,18.682762
+1199 rows,...,...

time,band,mag
0.0,0,26.55779
+1199 rows,...,...

time,band,mag
0.0,0,21.419539
+1199 rows,...,...

time,band,mag
0.0,0,24.799694
+1199 rows,...,...

time,band,mag
0.0,0,22.936393
+1199 rows,...,...


In [8]:
from eztaox.kernels.quasisep import Exp
from eztaox.fitter import random_search
from eztaox.models import MultiVarModel

# Signal-to-noise ratio used to synthesise per-point noise and error bars.
SNR = 1000.0


def run_eztaox(time, band, mag, seed=None):
    """Fit a two-band GP (``Exp`` kernel) with an inter-band lag to one light curve.

    The first and last values of ``np.unique(band)`` are selected, Gaussian
    noise with standard deviation ``mag / SNR`` is added to each selected
    band, each band is median-subtracted, and the merged, time-sorted series
    is handed to ``MultiVarModel`` / ``random_search``.

    Parameters
    ----------
    time, band, mag : array-like
        Per-observation arrays of equal length (``band`` is used as a boolean
        mask selector for ``time`` and ``mag``).
    seed : int or None, optional
        Seed for the synthetic noise. When ``None`` (default) the noise is
        drawn from NumPy's global ``np.random`` state, exactly as before, so
        a prior ``np.random.seed(...)`` call still controls it. When given, a
        dedicated ``np.random.default_rng(seed)`` generator is used instead.

    Returns
    -------
    dict
        ``{"best_params": ..., "log_likelihood": ...}`` holding the two values
        returned by ``random_search``.

    Raises
    ------
    ValueError
        If ``band`` contains fewer than two distinct values (an inter-band
        lag cannot be fitted against a single band).
    """
    time = np.asarray(time)
    band = np.asarray(band)
    mag = np.asarray(mag)

    unique_bands = np.unique(band)
    if unique_bands.size < 2:
        raise ValueError(
            "run_eztaox needs at least two distinct bands to fit an inter-band "
            f"lag, got {unique_bands.size}"
        )
    # Only the two extreme band labels are fitted.
    bands = [unique_bands[0], unique_bands[-1]]

    # Global NumPy state by default (backward compatible); private generator
    # when the caller asks for a reproducible draw.
    rng = np.random if seed is None else np.random.default_rng(seed)

    times, noisy_mags, mag_errs = {}, {}, {}
    for b in bands:
        mask = band == b
        times[b] = time[mask]
        clean = mag[mask]
        # NOTE(review): the error bar scales with the magnitude value itself
        # (mag / SNR); for magnitudes one would usually expect ~1.0857 / SNR.
        # Confirm this is the intended noise model.
        noisy_mags[b] = clean + rng.normal(0, 1, size=clean.shape) * clean / SNR
        mag_errs[b] = clean / SNR

    # Merge both bands into one series sorted by time; X is (time, band index).
    all_times = jnp.concatenate([times[b] for b in bands])
    inds = jnp.argsort(all_times)
    X = (
        all_times[inds],
        jnp.concatenate(
            [i * jnp.ones_like(times[b], dtype=int) for i, b in enumerate(bands)]
        )[inds],
    )

    # Centre each band on its own median before fitting.
    for b in bands:
        centred = jnp.array(noisy_mags[b])
        noisy_mags[b] = centred - jnp.median(centred)

    y = jnp.concatenate([noisy_mags[b] for b in bands])[inds]
    yerr = jnp.concatenate([mag_errs[b] for b in bands])[inds]

    has_lag = True  # if fit interband lags
    # NOTE(review): zero_mean=True is passed to the model while a "mean" prior
    # is still sampled below -- confirm whether the mean is actually fitted.
    zero_mean = True
    nBand = len(bands)

    # initialize a GP kernel, note the initial parameters are not used in the fitting
    k = Exp(scale=100.0, sigma=1.0)
    m = MultiVarModel(X, y, yerr, k, nBand, has_lag=has_lag, zero_mean=zero_mean)

    def initSampler():
        """Draw one candidate parameter set from the (two-band) priors."""
        # GP kernel params; the "drw_scale"/"drw_sigma" sites hold LOG values
        # (uniform between log(0.01) and log(1000) / log(10) respectively).
        log_drw_scale = numpyro.sample(
            "drw_scale", dist.Uniform(jnp.log(0.01), jnp.log(1000))
        )
        log_drw_sigma = numpyro.sample(
            "drw_sigma", dist.Uniform(jnp.log(0.01), jnp.log(10))
        )
        log_kernel_param = jnp.stack([log_drw_scale, log_drw_sigma])
        numpyro.deterministic("log_kernel_param", log_kernel_param)

        # parameters to relate the amplitudes in each band
        log_amp_scale = numpyro.sample("log_amp_scale", dist.Uniform(-2, 2))

        # one mean per band (two bands)
        mean = numpyro.sample(
            "mean",
            dist.Uniform(low=jnp.asarray([-0.1, -0.1]), high=jnp.asarray([0.1, 0.1])),
        )

        # interband lags
        lag = numpyro.sample("lag", dist.Uniform(-10, 10))

        return {
            "log_kernel_param": log_kernel_param,
            "log_amp_scale": log_amp_scale,
            "mean": mean,
            "lag": lag,
        }

    # The search key is fixed, so the random search itself is deterministic
    # for a given input series.
    fit_key = jax.random.PRNGKey(1)
    nSample = 1_000
    nBest = 5  # it seems like this number needs to be high

    bestP, ll = random_search(m, initSampler, fit_key, nSample, nBest)

    return {"best_params": bestP, "log_likelihood": ll}


In [9]:
# run_eztaox adds synthetic noise drawn from NumPy's global random state.
# Seed it here so the fits below are reproducible on Restart & Run All.
np.random.seed(0)

# Fit the first ten objects; each nested light-curve column is passed to
# run_eztaox as a positional argument.
res = df.iloc[:10].map_rows(
    run_eztaox,
    columns=["light_curve.time", "light_curve.band", "light_curve.mag"],
    row_container="args",
)
res

Unnamed: 0,best_params,log_likelihood
0,"{'lag': 4.532301193077516, 'log_amp_scale': -0.09623848826986242, 'log_kernel_param': [4.195422715209656, -2.745403534759837], 'mean': [0.0968087660558603, 0.03435811476643238]}",938.4027397740308
1,"{'lag': 0.5130291017083336, 'log_amp_scale': -0.0054624080656410084, 'log_kernel_param': [3.867325056728691, -0.36272231147101397], 'mean': [-0.0237469984259449, -0.010574931568028712]}",329.91030957062753
2,"{'lag': 4.48741383825741, 'log_amp_scale': -0.25205400752429913, 'log_kernel_param': [3.9633722347126197, -2.7410955027578945], 'mean': [0.027819531856516557, -0.02511603695248024]}",887.7757780708171
3,"{'lag': 1.5013800138875437, 'log_amp_scale': -0.023241268625562334, 'log_kernel_param': [4.411725176467439, -0.5443366374003742], 'mean': [-0.018465344919473426, -0.005279119829656765]}",454.6323668442375
4,"{'lag': 7.50591773690792, 'log_amp_scale': -0.07588872397780824, 'log_kernel_param': [6.04107791288455, -1.0292009466415055], 'mean': [0.018557203127101518, 0.0624817650788653]}",641.3743834160646
5,"{'lag': 2.4637657350431716, 'log_amp_scale': -0.035802532939098075, 'log_kernel_param': [4.596122283046429, -0.23271638735135886], 'mean': [-0.03511726390175549, 0.03982690277189787]}",390.5895719404094
6,"{'lag': 0.7533616571117219, 'log_amp_scale': -0.028660945699909873, 'log_kernel_param': [3.334585393112959, -0.4552441537172663], 'mean': [0.06320468770200219, 0.06898092721919617]}",294.04907834266766
7,"{'lag': 5.441585256248297, 'log_amp_scale': -0.06384586189789385, 'log_kernel_param': [5.248038052168193, -1.5841064255295194], 'mean': [-0.025166519615345842, 0.04818903698857083]}",746.7637262796477
8,"{'lag': 4.419365435890954, 'log_amp_scale': -0.10220278623446949, 'log_kernel_param': [7.083035382190981, -1.3044095546371874], 'mean': [0.027819531856516557, -0.02511603695248024]}",890.36299129427
9,"{'lag': 7.498734865922704, 'log_amp_scale': -0.12228195743115797, 'log_kernel_param': [4.1536493352771044, -0.22987043824467723], 'mean': [-0.07321655979687614, -0.09977668840628025]}",332.72285174831654


In [10]:
# `lag` column of the same first ten rows, for comparison with the fitted lags.
df.iloc[:10]["lag"]

0    [1.82123593 2.47946301 3.51353292 4.53015176 5...
1    [0.21918881 0.2817659  0.38219335 0.48386308 0...
2    [2.67856139 3.64051186 5.07249121 6.33072722 7...
3    [0.54609444 0.73203963 1.02645964 1.3220506  1...
4    [ 7.97327769  9.6016965  11.31389374 12.427101...
5    [0.82466225 1.11566874 1.57613279 2.0382705  2...
6    [0.26623999 0.34765564 0.47782724 0.6092818  0...
7    [1.87562207 2.55352454 3.61176364 4.6285085  5...
8    [ 3.49328897  4.75812787  6.63636557  8.260680...
9    [ 6.89691429  8.31472362 10.23132393 11.705777...
Name: lag, dtype: list<element: double>[pyarrow]

In [11]:
# Pull the fitted lag and log kernel parameters out of each row's best_params.
res.map_rows(
    lambda row: {key: row["best_params"][key] for key in ("lag", "log_kernel_param")}
)

Unnamed: 0,lag,log_kernel_param
0,4.532301193077516,"[4.195422715209656, -2.745403534759837]"
1,0.5130291017083336,"[3.867325056728691, -0.36272231147101397]"
2,4.48741383825741,"[3.9633722347126197, -2.7410955027578945]"
3,1.5013800138875435,"[4.411725176467439, -0.5443366374003742]"
4,7.50591773690792,"[6.04107791288455, -1.0292009466415055]"
5,2.463765735043172,"[4.596122283046429, -0.23271638735135886]"
6,0.7533616571117219,"[3.334585393112959, -0.4552441537172663]"
7,5.441585256248297,"[5.248038052168193, -1.5841064255295194]"
8,4.419365435890954,"[7.083035382190981, -1.3044095546371874]"
9,7.498734865922704,"[4.1536493352771044, -0.22987043824467723]"
