Skip to content

Commit 77408f3

Browse files
stephengreenannalena-kmax-dax
authored
Multibanding (#292)
* Implement MultibandedFrequencyDomain. Lots of code taken from Max's implementation. Not sure this works with the new waveform interface. * Enable inference for MultibandedFrequencyDomain. * Add brute force phase grid to likelihood (for MultibandedFrequencyDomain). * Add toggle for uniform vs multi-banded domain for Result, Injection classes. * Raise NotImplementedError if calling torch_sample_frequencies in MFD. * Update MFD for batched time translation. * Fix domain_update test for ASD dataset. * Fix domain_update test for ASD dataset. * Whiten data before decimating at inference time. * Improve usage of MultibandedFrequencyDomain / storage of event data. No longer use context setter in Sampler, instead always store event data as it was generated. Use transform in Sampler to decimate data. Ensure DecimateWaveformsAndASDS does nothing to data already in MFD. * Update dingo-pipe for MFD. By default use the base domain. Can be changed with option `importance-sampling-settings = {use_base_domain: False}`. * Implement torch frequency arrays for MFD, update numpy arrays to float32. This is needed for GNPE sampling. * Extend waveform generator to multi-banded frequency domain for all relevant waveform approximants. For TD-native waveforms, generate in FD then decimate. Also extend methods that generate waveform modes. Add relevant unit tests. * Remove IrregularFrequencyDomain. * Add tests for MFD. * Add additional test. * Restructure UniformFrequencyDomain and MultibandedFrequencyDomain to inherit from common BaseFrequencyDomain superclass. * Fix bug in build_domain(). * Resolve several comments in PR. * Resolve several more comments in PR. * Address several additional comments, fix some docstrings. * Updated link in documentation to point to public notebook. * Improve check for TD waveforms for decimation method. * Fix bug in recursive_hdf5_load and other small changes. * Remove adaptive MFD generation. * Remove adaptive MFD generation. * Fix test. * Add diagnostic script to test MFD performance. * Improve MFD diagnostic script. * fix bug in error message * fix error when checking time domain type * Update stripping of window factor from ASD dataset for MFD. * elif -> if in Injection.asd setter --------- Co-authored-by: Annalena Kofler <annalena.kofler@tuebingen.mpg.de> Co-authored-by: Maximilian Dax <maximilian.dax@tuebingen.mpg.de>
1 parent 90cb85f commit 77408f3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+2606
-946
lines changed

dingo/core/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def from_file(self, file_name: str):
162162
# Set the keys that the class expects
163163
for k, v in loaded_dict.items():
164164
assert k in self._data_keys
165-
vars(self)[k] = v
165+
setattr(self, k, v)
166166
try:
167167
self.settings = ast.literal_eval(f.attrs["settings"])
168168
except KeyError:
@@ -179,6 +179,6 @@ def to_dictionary(self):
179179
def from_dictionary(self, dictionary: dict):
180180
for k, v in dictionary.items():
181181
if k in self._data_keys or k == "settings":
182-
vars(self)[k] = v
182+
setattr(self, k, v)
183183
if "version" not in dictionary:
184184
self.version = f"dingo={get_version()}"

dingo/core/samplers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,12 @@ def __init__(
7575
# is needed for calculating the likelihood for importance sampling.
7676
# However, it will not be used when sampling from the model, since it is
7777
# unconditional.
78-
self.context = self.model.context
78+
self._context = self.model.context
7979
self.event_metadata = self.model.event_metadata
8080
self.base_model_metadata = self.metadata["base"]
8181
else:
8282
self.unconditional_model = False
83-
self.context = None
83+
self._context = None
8484
self.event_metadata = None
8585
self.base_model_metadata = self.metadata
8686

dingo/gw/data/data_preparation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from dingo.core.utils.misc import recursive_check_dicts_are_equal
88
from dingo.gw.data.data_download import download_raw_data
99
from dingo.gw.gwutils import get_window
10-
from dingo.gw.domains import build_domain_from_model_metadata, FrequencyDomain
10+
from dingo.gw.domains import UniformFrequencyDomain
11+
from dingo.gw.domains import build_domain_from_model_metadata
1112

1213

1314
def load_raw_data(time_event, settings, event_dataset=None):
@@ -35,7 +36,6 @@ def load_raw_data(time_event, settings, event_dataset=None):
3536

3637
# first try to load the event data from the saved dataset
3738
if event_dataset is not None:
38-
3939
if isfile(event_dataset):
4040
dataset = DingoDataset(file_name=event_dataset, data_keys=[event])
4141
if settings is not None:
@@ -66,7 +66,7 @@ def load_raw_data(time_event, settings, event_dataset=None):
6666
def parse_settings_for_raw_data(model_metadata, time_psd, time_buffer):
6767
domain_type = model_metadata["dataset_settings"]["domain"]["type"]
6868

69-
if domain_type == "FrequencyDomain":
69+
if domain_type == "UniformFrequencyDomain":
7070
data_settings = model_metadata["train_settings"]["data"]
7171
settings = {
7272
"window": data_settings["window"],
@@ -98,7 +98,7 @@ def data_to_domain(raw_data, settings_raw_data, domain, **kwargs):
9898
9999
"""
100100

101-
if type(domain) == FrequencyDomain:
101+
if isinstance(domain, UniformFrequencyDomain):
102102
window = get_window(kwargs["window"])
103103
data = {"waveform": {}, "asds": {}}
104104
# convert event strains to frequency domain
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import argparse
2+
3+
import numpy as np
4+
import yaml
5+
from scipy.interpolate import interp1d
6+
7+
from dingo.gw.dataset import generate_parameters_and_polarizations
8+
from dingo.gw.domains import build_domain, MultibandedFrequencyDomain
9+
from dingo.gw.gwutils import get_mismatch
10+
from dingo.gw.prior import build_prior_with_defaults
11+
from dingo.gw.waveform_generator import (
12+
NewInterfaceWaveformGenerator,
13+
WaveformGenerator,
14+
generate_waveforms_parallel,
15+
)
16+
17+
18+
def _evaluate_multibanding_main(
19+
settings_file: str,
20+
num_samples: int,
21+
):
22+
with open(settings_file, "r") as f:
23+
settings = yaml.safe_load(f)
24+
25+
# Ignore any compression settings
26+
if "compression" in settings:
27+
del settings["compression"]
28+
29+
# Update prior to challenge the multi-banding:
30+
#
31+
# (a) Set geocent_time = 0.12 s (boundary of usual prior + Earth-radius crossing time)
32+
# (b) Set chirp mass to bottom end of prior.
33+
prior = build_prior_with_defaults(settings["intrinsic_prior"])
34+
settings["intrinsic_prior"]["geocent_time"] = 0.12
35+
settings["intrinsic_prior"]["chirp_mass"] = prior["chirp_mass"].minimum
36+
# Rebuild prior with updated settings.
37+
prior = build_prior_with_defaults(settings["intrinsic_prior"])
38+
print("Prior")
39+
for k, v in prior.items():
40+
print(f"{k}: {v}")
41+
42+
domain = build_domain(settings["domain"])
43+
print("\nDomain")
44+
print(domain.domain_dict)
45+
46+
if not isinstance(domain, MultibandedFrequencyDomain):
47+
raise ValueError("Waveform dataset domain not a MultibandedFrequencyDomain.")
48+
49+
if settings["waveform_generator"].get("new_interface", False):
50+
waveform_generator_mfd = NewInterfaceWaveformGenerator(
51+
domain=domain,
52+
**settings["waveform_generator"],
53+
)
54+
waveform_generator_ufd = NewInterfaceWaveformGenerator(
55+
domain=domain.base_domain,
56+
**settings["waveform_generator"],
57+
)
58+
else:
59+
waveform_generator_mfd = WaveformGenerator(
60+
domain=domain,
61+
**settings["waveform_generator"],
62+
)
63+
waveform_generator_ufd = WaveformGenerator(
64+
domain=domain.base_domain,
65+
**settings["waveform_generator"],
66+
)
67+
68+
# Generate MFD waveforms.
69+
parameters, polarizations_mfd = generate_parameters_and_polarizations(
70+
waveform_generator_mfd, prior, num_samples, 1
71+
)
72+
73+
# Generate UFD waveforms, re-using the parameter choices from before.
74+
polarizations_ufd = generate_waveforms_parallel(waveform_generator_ufd, parameters)
75+
76+
# Compare UFD waveforms against MFD waveforms interpolated to MFD.
77+
mismatches = {}
78+
ufd = domain.base_domain
79+
mfd = domain
80+
for pol, d in polarizations_mfd.items():
81+
mismatches[pol] = np.empty(len(d))
82+
for i in range(len(d)):
83+
mfd_interpolated = interp1d(mfd(), d[i], fill_value="extrapolate")(ufd())
84+
mismatches[pol][i] = get_mismatch(
85+
polarizations_ufd[pol][i],
86+
mfd_interpolated,
87+
ufd,
88+
asd_file="aLIGO_ZERO_DET_high_P_asd.txt",
89+
)
90+
91+
print("\nMismatches between UFD waveforms and MFD waveforms interpolated to MFD.")
92+
print(
93+
"This is a conservative estimate of the MFD performance when training "
94+
"networks."
95+
)
96+
mismatches = np.concatenate([v for v in mismatches.values()])
97+
print(f"num_samples = {num_samples}")
98+
print(" Mean mismatch = {}".format(np.mean(mismatches)))
99+
print(" Standard deviation = {}".format(np.std(mismatches)))
100+
print(" Max mismatch = {}".format(np.max(mismatches)))
101+
print(" Median mismatch = {}".format(np.median(mismatches)))
102+
print(" Percentiles:")
103+
print(" 99 -> {}".format(np.percentile(mismatches, 99)))
104+
print(" 99.9 -> {}".format(np.percentile(mismatches, 99.9)))
105+
print(" 99.99 -> {}".format(np.percentile(mismatches, 99.99)))
106+
107+
108+
def parse_args():
109+
parser = argparse.ArgumentParser(
110+
formatter_class=argparse.RawDescriptionHelpFormatter,
111+
description="Evaluate performance of multibanding on waveform dataset.",
112+
)
113+
parser.add_argument(
114+
"--settings-file",
115+
type=str,
116+
required=True,
117+
help="YAML file containing database settings",
118+
)
119+
parser.add_argument(
120+
"--num-samples",
121+
type=int,
122+
default=5000,
123+
help="Number of waveform evaluations for comparison.",
124+
)
125+
return parser.parse_args()
126+
127+
128+
def main() -> None:
129+
args = parse_args()
130+
_evaluate_multibanding_main(args.settings_file, args.num_samples)

dingo/gw/dataset/generate_dataset.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626
from dingo.core.utils.misc import call_func_strict_output_dim
2727

28+
2829
def generate_parameters_and_polarizations(
2930
waveform_generator: WaveformGenerator,
3031
prior: BBHPriorDict,
@@ -198,11 +199,14 @@ def generate_dataset(settings: Dict, num_processes: int) -> WaveformDataset:
198199
n_test = svd_settings.get("num_validation_samples", 0)
199200

200201
func = partial(
201-
generate_parameters_and_polarizations,
202-
waveform_generator,
202+
generate_parameters_and_polarizations,
203+
waveform_generator,
203204
prior,
204-
num_processes=num_processes)
205-
parameters, polarizations = call_func_strict_output_dim(func, n_train + n_test)
205+
num_processes=num_processes,
206+
)
207+
parameters, polarizations = call_func_strict_output_dim(
208+
func, n_train + n_test
209+
)
206210
svd_dataset_settings = copy.deepcopy(settings)
207211
svd_dataset_settings["num_samples"] = len(parameters)
208212
del svd_dataset_settings["compression"]["svd"]
@@ -231,11 +235,14 @@ def generate_dataset(settings: Dict, num_processes: int) -> WaveformDataset:
231235
waveform_generator.transform = Compose(compression_transforms)
232236

233237
func = partial(
234-
generate_parameters_and_polarizations,
235-
waveform_generator,
238+
generate_parameters_and_polarizations,
239+
waveform_generator,
236240
prior,
237-
num_processes=num_processes)
238-
parameters, polarizations = call_func_strict_output_dim(func, settings["num_samples"])
241+
num_processes=num_processes,
242+
)
243+
parameters, polarizations = call_func_strict_output_dim(
244+
func, settings["num_samples"]
245+
)
239246
dataset_dict["parameters"] = parameters
240247
dataset_dict["polarizations"] = polarizations
241248

@@ -277,7 +284,6 @@ def parse_args():
277284
def _generate_dataset_main(
278285
settings_file: str, out_file: str, num_processes: int
279286
) -> None:
280-
281287
if not Path(settings_file).is_file():
282288
raise FileNotFoundError(f"dataset generation, failed to find {settings_file}")
283289
if not Path(out_file).parent.is_dir():

0 commit comments

Comments
 (0)