Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: merge WeightedDataGenerator into DataGenerator #458

Merged
merged 5 commits into from
Aug 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions docs/amplitude-analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"::::{margin}\n",
":::{tip}\n",
"{doc}`TR-018<compwa-org:report/018>` explains some of the mechanisms behind the phase space generator as well as how to do {ref}`importance sampling<compwa-org:report/018:Intensity distribution>`.\n",
":::\n",
"::::\n",
"\n",
"In this section, we use the {class}`~ampform.helicity.HelicityModel` that we created with {mod}`ampform` in {ref}`the previous step <compwa-step-1>` to generate a data sample via hit & miss Monte Carlo. We do this with the {mod}`.data` module.\n",
"\n",
"First, we {func}`~pickle.load` the {class}`~ampform.helicity.HelicityModel` that was created in the previous step. This does not have to be done if the model has been generated in the same script or notebook, but can be useful if the model was generated elsewhere."
Expand Down Expand Up @@ -353,7 +359,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The {class}`~qrules.transition.ReactionInfo` class defines the constraints of the phase space. As such, we have enough information to generate a **phase-space sample** for this particle reaction. We do this with a {class}`.TFPhaseSpaceGenerator` class, which is an implementation of the {class}`.DataGenerator` for a {obj}`.DataSample` of **four-momenta** arrays (using {obj}`tensorflow <tf.Tensor>` and the [`phasespace`](https://phasespace.readthedocs.io) package as a back-end). We also need to construct a {class}`.RealNumberGenerator` that can generate random numbers. {class}`.TFUniformRealNumberGenerator` is the natural choice here.\n",
"The {class}`~qrules.transition.ReactionInfo` class defines the constraints of the phase space. As such, we have enough information to generate a **phase-space sample** for this particle reaction. We do this with a {class}`.TFPhaseSpaceGenerator` class, which is a {class}`.DataGenerator` for a {obj}`.DataSample` of **four-momenta** arrays (using {obj}`tensorflow <tf.Tensor>` and the [`phasespace`](https://phasespace.readthedocs.io) package as a back-end). We also need to construct a {class}`.RealNumberGenerator` that can generate random numbers. {class}`.TFUniformRealNumberGenerator` is the natural choice here.\n",
"\n",
"As opposed to the main {ref}`amplitude-analysis:Step 2: Generate data` of the main usage example page, we will generate a **deterministic** data sample. This can be done by feeding a {class}`.RealNumberGenerator` with a specific {attr}`~.RealNumberGenerator.seed` and giving that generator to the {meth}`.TFPhaseSpaceGenerator.generate` method:"
]
Expand Down Expand Up @@ -1935,8 +1941,16 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"version": "3.8.12"
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
Expand Down
8 changes: 8 additions & 0 deletions docs/amplitude-analysis/analytic-continuation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,15 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
Expand Down
10 changes: 9 additions & 1 deletion docs/usage.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -857,8 +857,16 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"version": "3.8.12"
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
Expand Down
10 changes: 9 additions & 1 deletion docs/usage/basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1291,8 +1291,16 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"version": "3.8.12"
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
Expand Down
10 changes: 9 additions & 1 deletion docs/usage/binned-fit.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,16 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"version": "3.8.12"
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
Expand Down
10 changes: 9 additions & 1 deletion docs/usage/caching.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -540,8 +540,16 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"version": "3.8.12"
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
Expand Down
10 changes: 9 additions & 1 deletion docs/usage/chi-squared.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,16 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"version": "3.8.12"
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
Expand Down
10 changes: 9 additions & 1 deletion docs/usage/faster-lambdify.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,16 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"version": "3.8.12"
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
Expand Down
10 changes: 9 additions & 1 deletion docs/usage/unbinned-fit.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,16 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"version": "3.8.12"
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
Expand Down
19 changes: 7 additions & 12 deletions src/tensorwaves/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
DataTransformer,
Function,
RealNumberGenerator,
WeightedDataGenerator,
)

from ._data_sample import (
Expand Down Expand Up @@ -71,7 +70,7 @@ class IntensityDistributionGenerator(DataGenerator):

def __init__(
self,
domain_generator: DataGenerator | WeightedDataGenerator,
domain_generator: DataGenerator,
function: Function,
domain_transformer: DataTransformer | None = None,
bunch_size: int = 50_000,
Expand Down Expand Up @@ -115,18 +114,14 @@ def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
return select_events(returned_data, selector=slice(None, size))

def _generate_bunch(self, rng: RealNumberGenerator) -> tuple[DataSample, float]:
domain_generator = self.__domain_generator
if isinstance(domain_generator, WeightedDataGenerator):
domain, weights = domain_generator.generate(self.__bunch_size, rng)
else:
domain = _generate_without_progress_bar(
domain_generator, self.__bunch_size, rng
)
weights = 1 # type: ignore[assignment]
domain = _generate_without_progress_bar(
self.__domain_generator, self.__bunch_size, rng
)
transformed_domain = self.__domain_transformer(domain)
computed_intensities = self.__function(transformed_domain)
max_intensity: float = np.max(computed_intensities)
random_intensities = rng(size=self.__bunch_size, max_value=max_intensity)
weights = domain.get("weights", 1)
hit_and_miss_sample = select_events(
domain,
selector=weights * computed_intensities > random_intensities,
Expand All @@ -139,9 +134,9 @@ def _generate_without_progress_bar(
) -> DataSample:
# https://github.com/ComPWA/tensorwaves/issues/395
show_progress = getattr(domain_generator, "show_progress", None)
if show_progress:
if show_progress is not None:
domain_generator.show_progress = False # type: ignore[attr-defined]
domain = domain_generator.generate(bunch_size, rng)
if show_progress:
if show_progress is not None:
domain_generator.show_progress = show_progress # type: ignore[attr-defined]
return domain
39 changes: 22 additions & 17 deletions src/tensorwaves/data/phasespace.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
# pylint: disable=import-outside-toplevel
"""Implementations of `.DataGenerator` and `.WeightedDataGenerator`."""
"""Implementations of a `.DataGenerator` for four-momentum samples."""
from __future__ import annotations

import logging
from typing import Mapping

import numpy as np
from tqdm.auto import tqdm

from tensorwaves.function._backend import raise_missing_module_error
from tensorwaves.interface import (
DataGenerator,
DataSample,
RealNumberGenerator,
WeightedDataGenerator,
)
from tensorwaves.interface import DataGenerator, DataSample, RealNumberGenerator

from ._data_sample import (
finalize_progress_bar,
Expand Down Expand Up @@ -64,20 +58,30 @@ def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
)
momentum_pool: DataSample = {}
while get_number_of_events(momentum_pool) < size:
phsp_momenta, weights = self.__phsp_generator.generate(
self.__bunch_size, rng
)
phsp_momenta = self.__phsp_generator.generate(self.__bunch_size, rng)
weights = phsp_momenta.get("weights")
if weights is None:
raise ValueError(
"DataSample returned by"
f" {type(self.__phsp_generator).__name__} doesn't contain"
' "weights"'
)
hit_and_miss_randoms = rng(self.__bunch_size)
bunch = select_events(phsp_momenta, selector=weights > hit_and_miss_randoms)
momentum_pool = merge_events(momentum_pool, bunch)
progress_bar.update(n=get_number_of_events(bunch))
finalize_progress_bar(progress_bar)
return select_events(momentum_pool, selector=slice(None, size))
phsp = select_events(momentum_pool, selector=slice(None, size))
del phsp["weights"]
return phsp


class TFWeightedPhaseSpaceGenerator(WeightedDataGenerator):
class TFWeightedPhaseSpaceGenerator(DataGenerator):
"""Implements a phase space generator **with weights** using tensorflow.

The weights are provided in the returned `.DataSample` under the key
:code:`"weights"`.

Args:
initial_state_mass: Mass of the decaying state.
final_state_masses: A mapping of final state IDs to the corresponding masses.
Expand All @@ -102,9 +106,7 @@ def __init__(
names=list(map(str, sorted_ids)),
)

def generate(
self, size: int, rng: RealNumberGenerator
) -> tuple[DataSample, np.ndarray]:
def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
r"""Generate a `.DataSample` of phase space four-momenta with weights.

Returns:
Expand All @@ -122,4 +124,7 @@ def generate(
f"p{label}": momenta.numpy()[:, [3, 0, 1, 2]]
for label, momenta in particles.items()
}
return phsp_momenta, weights.numpy()
return {
"weights": weights.numpy(),
**phsp_momenta,
}
5 changes: 4 additions & 1 deletion src/tensorwaves/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ def __init__( # pylint: disable=too-many-arguments
backend: str = "numpy",
) -> None:
self.__data = {k: np.array(v) for k, v in data.items()}
self.__phsp = {k: np.array(v) for k, v in phsp.items()}
self.__phsp = {k: np.array(v) for k, v in phsp.items() if k != "weights"}
self.__phsp_weights = phsp.get("weights")
self.__function = function
self.__gradient = gradient_creator(self.__call__, backend)

Expand All @@ -207,6 +208,8 @@ def __call__(self, parameters: Mapping[str, ParameterValue]) -> float:
self.__function.update_parameters(parameters)
bare_intensities = self.__function(self.__data)
phsp_intensities = self.__function(self.__phsp)
if self.__phsp_weights is not None:
phsp_intensities *= self.__phsp_weights
normalization_factor = 1.0 / (
self.__phsp_volume * self.__mean_function(phsp_intensities)
)
Expand Down
12 changes: 1 addition & 11 deletions src/tensorwaves/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,17 +231,7 @@ class DataGenerator(ABC):

@abstractmethod
def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
...


class WeightedDataGenerator(ABC):
"""Abstract class for generating a `.DataSample` with weights."""

@abstractmethod
def generate(
self, size: int, rng: RealNumberGenerator
) -> tuple[DataSample, np.ndarray]:
r"""Generate `.DataSample` with weights.
r"""Generate a `.DataSample` with :code:`size` events.

Returns:
A `tuple` of a `.DataSample` with an array of weights.
Expand Down
4 changes: 3 additions & 1 deletion tests/data/test_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# pylint: disable=import-outside-toplevel
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -96,7 +98,7 @@ def test_generate_four_momenta_on_flat_distribution(self):
assert pytest.approx(phsp[i]) == data[i]


def test_generate_without_progress_bar(capsys: "CaptureFixture"):
def test_generate_without_progress_bar(capsys: CaptureFixture):
class SilentGenerator(DataGenerator):
def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
return {"x": 1} # type: ignore[dict-item]
Expand Down
5 changes: 4 additions & 1 deletion tests/data/test_phasespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,10 @@ def test_generate_deterministic(self, pdg: "ParticleCollection"):
i: pdg[name].mass for i, name in enumerate(final_state_names)
},
)
phsp_momenta, weights = phsp_generator.generate(sample_size, rng)
phsp_momenta = phsp_generator.generate(sample_size, rng)
assert list(phsp_momenta) == ["weights", "p0", "p1", "p2"]
weights = phsp_momenta.get("weights", [])
del phsp_momenta["weights"]
print("Expected values, get by running pytest with the -s flag")
pprint(
{
Expand Down