In [1]:
import tomllib
import warnings

import dill
import jax
# Enable 64-bit (double-precision) arrays in JAX. The flag only affects arrays
# created after it is set, so keep it ahead of any JAX computation.
jax.config.update("jax_enable_x64", True)

%matplotlib inline

In [2]:
from burst import BurstParameters, load_basis, simulate_population
from utils import segment_times

In [3]:
# Run configuration for the sine-Gaussian selection study.
CONFIG_PATH = "sinegaussian_selection.toml"

# tomllib requires a binary file object; the context manager guarantees the
# handle is closed as soon as parsing finishes.
with open(CONFIG_PATH, "rb") as config_file:
    config = tomllib.load(config_file)

# Per-parameter two-element lists (see the printed value in the next cell);
# their exact meaning is defined by burst.simulate_population — TODO confirm.
bounds = config["bounds"]
duration = config["duration"]
sample_rate = config["sample_rate"]
# Sample times for one segment, built by utils.segment_times.
times = segment_times(duration, sample_rate)

In [4]:
# Fixed PRNG seed; the key is passed to, and replaced by, every
# simulate_population call in the sweep below.
rng_key = jax.random.PRNGKey(10)

# Parameter names kept when the selection mask is applied to the truths
# (a live view over the keys of ``bounds``).
variables = bounds.keys()

print(bounds)

{'amplitude': [110, 5], 'frequency': [5, 0.2], 'bandwidth': [0.5, 0.01], 'phase': [0, 3.141592653589793], 'delta_t': [0.0, 0.01]}


In [5]:
# Load the (truncated) basis and its projection weights from the pickled
# file; presumably an SVD basis, per the file name. The shapes shown in the
# output are printed by load_basis itself.
basis, weights = load_basis(
    "sinegaussian_svd_250502.pkl",
    truncation=config["truncation"],
)

(2000, 20) (20, 1024)


In [6]:
# Population settings shared by every configuration in the sweep.
# NOTE(review): true_mean is presumably the population mean of the burst
# frequency (the sweep prints truths["frequency"].std(), which tracks
# true_sigma in the saved output) — confirm against burst.simulate_population.
true_mean = 5
offset = config["configurations"]["offset"]

# Not used in the cells shown here; presumably filled by later cells.
all_divs = []

In [7]:

# Keyword arguments shared by every simulate_population call in the sweep.
simulate_kwargs = dict(
    basis=basis,
    projection=weights,
    bounds=bounds,
    times=times,
    duration=duration,
    sample_rate=sample_rate,
    time_align=True,
)

# For each paired (sigma, threshold) configuration: simulate a population,
# keep only the events flagged by ``keep``, and write three pickles:
#   params_{label}.pkl - BurstParameters of the kept events
#   data_{label}.pkl   - ``filtered`` entries of the kept events
#   events_{label}.pkl - ``events`` entries of the kept events
for true_sigma, snr_threshold in zip(
    config["configurations"]["sigma"],
    config["configurations"]["threshold"],
    strict=True,  # a length mismatch would otherwise silently drop configurations
):
    label = f"{true_sigma}_{snr_threshold}"

    # Simulate twice as many events when a real SNR cut is applied
    # (threshold above 1), presumably to offset events lost to the selection.
    oversample = 2 if snr_threshold > 1 else 1
    events, truths, keep, rng_key, filtered = simulate_population(
        rng_key,
        mean=true_mean,
        sigma=true_sigma,
        offset=offset,
        threshold=snr_threshold,
        n_events=config["simulate"]["n_events"] * oversample,
        **simulate_kwargs,
    )
    # Spread of the simulated (pre-selection) frequencies.
    print(truths["frequency"].std())

    # Vectorised count; builtin sum() would iterate the array element by element.
    n_kept = int(jax.numpy.count_nonzero(keep))
    print(n_kept)
    if n_kept == 0:
        # Outputs are still written so downstream file names stay predictable,
        # but an empty selection is almost certainly not what was intended.
        warnings.warn(
            f"No events passed threshold={snr_threshold} for sigma={true_sigma}; "
            f"writing empty outputs for label {label!r}."
        )

    # Restrict every output to the events that passed the selection.
    events = events[keep]
    filtered = filtered[keep]
    truths = BurstParameters(**{k: truths[k][keep] for k in variables})

    outputs = {
        f"params_{label}.pkl": truths,
        f"data_{label}.pkl": filtered,
        f"events_{label}.pkl": events,
    }
    print(f"Writing parameters, data and events to {', '.join(outputs)}")
    for path, obj in outputs.items():
        with open(path, "wb") as f:
            dill.dump(obj, f)

2.955213642251156 0.08973178179879919
0.29835672142929226
0
Writing data and parameters to data_0.3_8.pkl and params_0.3_8.pkl
2.9528934760172962 0.09795205818273145
0.401775807430262
0
Writing data and parameters to data_0.4_8.pkl and params_0.4_8.pkl
2.9595129010574075 0.11926266413664297
0.5933618876278189
0
Writing data and parameters to data_0.6_8.pkl and params_0.6_8.pkl
2.9614582763092057 0.14510988906284902
0.7875178262805171
0
Writing data and parameters to data_0.8_8.pkl and params_0.8_8.pkl
2.9556957595073756 0.09259330783600461
0.3493128626495368
2000
Writing data and parameters to data_0.35_0.pkl and params_0.35_0.pkl
2.9535827011764058 0.10262752345693513
0.4419552923563357
2000
Writing data and parameters to data_0.45_0.pkl and params_0.45_0.pkl
2.9552922490334956 0.1264455129094698
0.6532798180158002
2000
Writing data and parameters to data_0.65_0.pkl and params_0.65_0.pkl
2.9701574462276894 0.15538138222416617
0.8670507234654697
2000
Writing data and parameters to data