In [1]:
import os

import dill
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as np
import matplotlib.pyplot as plt
from numpyro.infer import MCMC, NUTS
from tqdm.auto import tqdm

%matplotlib inline

In [2]:
from burst import (
    BurstParameters,
    construct_numpyro_model,
    load_basis,
    simulate_population,
)
from utils import (
    approximate_divergence,
    segment_times,
)

In [3]:
# Segment settings shared by the time grid and the NumPyro model.
duration = 4
sample_rate = 256
times = segment_times(duration, sample_rate)

# One 2-tuple per burst parameter (presumably (lower, upper) -- confirm
# against `construct_numpyro_model`).
bounds = {
    "amplitude": (10, 80),
    "frequency": (1, 9),
    "bandwidth": (0.3, 0.7),
    "phase": (0, np.pi),
    "delta_t": (-0.1, 0.1),
}

# The model constructor receives a shallow copy, so the `bounds` dict kept in
# this notebook is a separate object from the one handed to the model.
numpyro_model = construct_numpyro_model(
    duration,
    sample_rate,
    dict(bounds),
)
# del bounds["delta_t"]

In [4]:
# Root JAX PRNG key (seed 10).
rng_key = jax.random.PRNGKey(10)

# Names of the burst parameters handled by this notebook.
variables = [
    "amplitude",
    "frequency",
    "bandwidth",
    "phase",
    "delta_t",
]

# Sampler settings: 100 warm-up steps, 500 retained samples, a single chain,
# jitted model arguments, and no per-run progress bar.
mcmc_kwargs = {
    "num_warmup": 100,
    "num_samples": 500,
    "num_chains": 1,
    "jit_model_args": True,
    "progress_bar": False,
}

# A single NUTS-driven MCMC object built from `numpyro_model`.
mcmc = MCMC(
    NUTS(numpyro_model),
    **mcmc_kwargs,
)

In [None]:
# Load the three objects returned by `load_basis` for the stored
# "sinegaussian_svd.pkl" file, passing truncation=20.
basis, weights, projection = load_basis(
    "sinegaussian_svd.pkl",
    truncation=20,
)

In [None]:
# Population settings forwarded to `simulate_population` as `mean`, `sigma`
# and `offset`.
true_mean, true_sigma = 5, 0.1
offset = 0.0

# Grid of 100 evenly spaced values from 4.8 to 5.2.
fpeaks = np.linspace(start=4.8, stop=5.2, num=100)

# Accumulator: one array of divergences is appended per simulated population.
all_divs = []

In [5]:

# Keyword arguments shared by every `simulate_population` call in this cell.
simulate_kwargs = dict(
    basis=basis,
    project=projection,
    bounds=bounds,
    times=times,
    duration=duration,
    sample_rate=sample_rate,
    time_align=True,
)

# Reading previously written data_*.pkl / params_*.pkl files is switched off;
# every run re-simulates and overwrites them.
# NOTE(review): the data_*.pkl files written below hold `filtered`, while the
# cached branch loads them back as `events` -- confirm that is intended before
# enabling this. `dill.load` executes arbitrary code, so only enable it for
# files produced locally.
use_cached_data = False

# Maps each thresholded `true_sigma` to the quartic polynomial coefficients
# fitted to the log of the kept fraction as a function of `fpeaks`.
polyfits = dict()

# NOTE: `all_divs` is only appended to here, never reset, so re-running this
# cell without re-running the cell that creates it accumulates entries.
for true_sigma in [0.3, 0.4, 0.6, 0.8, 0.35, 0.45, 0.65, 0.85]:
    # Widths that are whole multiples of 0.1 are simulated with a threshold of
    # 8; the remaining widths use a threshold of 0. The comparison is done in
    # integer hundredths because a float remainder such as
    # `true_sigma * 10 % 1 == 0` is unreliable (0.7 * 10 == 7.000000000000001).
    apply_selection = round(true_sigma * 100) % 10 == 0

    if apply_selection:
        snr_threshold = 8
        # For each candidate mean on the grid, simulate 20000 events and take
        # the mean of the third return value (`keep`). The same `rng_key` is
        # reused for every grid point.
        pdets = np.array([
            np.mean(simulate_population(
                rng_key,
                mean=candidate_mean,
                sigma=true_sigma,
                offset=offset,
                threshold=snr_threshold,
                n_events=20000,
                **simulate_kwargs,
            )[2])
            for candidate_mean in fpeaks
        ])
        pfit = np.polyfit(fpeaks, np.log(pdets), 4)
        polyfits[true_sigma] = pfit
    else:
        snr_threshold = 0

    label = f"{true_sigma}_{snr_threshold}"

    if use_cached_data and os.path.exists(f"data_{label}.pkl"):
        with open(f"data_{label}.pkl", "rb") as f:
            events = dill.load(f)
        with open(f"params_{label}.pkl", "rb") as f:
            truths = dill.load(f)
    else:
        # 5000 simulated events when the threshold is applied, 3000 otherwise.
        n_events = 5000 if apply_selection else 3000
        events, truths, keep, rng_key, filtered = simulate_population(
            rng_key,
            mean=true_mean,
            sigma=true_sigma,
            offset=offset,
            threshold=snr_threshold,
            n_events=n_events,
            **simulate_kwargs,
        )
        # Restrict everything to the entries selected by `keep`.
        events = events[keep]
        filtered = filtered[keep]
        truths = BurstParameters(**{k: truths[k][keep] for k in variables})

        print(f"Writing data and parameters to data_{label}.pkl and params_{label}.pkl")
        with open(f"params_{label}.pkl", "wb") as f:
            dill.dump(truths, f)
        with open(f"data_{label}.pkl", "wb") as f:
            # dill.dump(events, f)
            dill.dump(filtered, f)

    # Run the sampler once per event and keep only the "frequency" samples.
    # NOTE(review): `truth` is never used, but `truths` is left in the zip
    # because it also bounds the number of iterations -- confirm whether
    # iterating `events` alone is equivalent.
    fpeak_posteriors = list()
    for event, truth in zip(tqdm(events), truths):
        rng_key, subkey = jax.random.split(rng_key)
        mcmc.run(subkey, event.squeeze())
        fpeak_posteriors.append(mcmc.get_samples()["frequency"])

    fpeak_posteriors = np.array(fpeak_posteriors)
    np.save(f"fpeak_posteriors_{label}.npy", fpeak_posteriors)

    # Evaluate the divergence at every candidate mean on the grid, holding the
    # width at the simulated `true_sigma`.
    all_divs.append(np.array([
        approximate_divergence(fpeak_posteriors.T, mean_frequency, sigma_frequency=true_sigma)
        for mean_frequency in fpeaks
    ]))

    # Plot relative to the curve's minimum so all populations share a baseline.
    plt.plot(
        fpeaks,
        all_divs[-1] - np.min(all_divs[-1]),
        label=f"sigma={true_sigma}, threshold={snr_threshold}",
    )

plt.xlabel("mean frequency")
plt.ylabel("divergence - min(divergence)")
plt.legend()
plt.show()
plt.close()

with open("polyfits.pkl", "wb") as f:
    dill.dump(polyfits, f)

Writing data and parameters to data_0.3_8.pkl and params_0.3_8.pkl


  0%|          | 0/2559 [00:00<?, ?it/s]