/
__init__.py
142 lines (122 loc) · 5.37 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""The `.data` module takes care of data generation."""
from __future__ import annotations
import logging
import numpy as np
from tqdm.auto import tqdm
from tensorwaves.interface import (
DataGenerator,
DataSample,
DataTransformer,
Function,
RealNumberGenerator,
)
from ._data_sample import (
finalize_progress_bar,
get_number_of_events,
merge_events,
select_events,
)
# pyright: reportUnusedImport=false
from .phasespace import (
TFPhaseSpaceGenerator, # noqa: F401
TFWeightedPhaseSpaceGenerator, # noqa: F401
)
from .rng import NumpyUniformRNG, TFUniformRealNumberGenerator # noqa: F401
from .transform import IdentityTransformer, SympyDataTransformer # noqa: F401
# Module-level logger. Besides receiving info messages, its level is consulted
# by `IntensityDistributionGenerator.generate` to decide whether the tqdm
# progress bar is disabled.
_LOGGER = logging.getLogger(__name__)
class NumpyDomainGenerator(DataGenerator):
    """Sample a uniform `.DataSample` that serves as the domain of a `.Function`.

    Args:
        boundaries: Mapping from every key of the `.DataSample` to be generated
            to a :code:`(minimum, maximum)` `tuple`. The values produced for a key
            are drawn from the range spanned by its pair.
    """

    def __init__(self, boundaries: dict[str, tuple[float, float]]) -> None:
        # Stored by reference; the mapping is read afresh on every `generate` call.
        self.__boundaries = boundaries

    def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
        """Draw :code:`size` values for each key within that key's boundaries."""
        domain: DataSample = {}
        # One call to the random generator per key, in the mapping's order.
        for key, (lower, upper) in self.__boundaries.items():
            domain[key] = rng(size, lower, upper)
        return domain
class IntensityDistributionGenerator(DataGenerator):
    """Generate a hit-and-miss `.DataSample` distribution for a `.Function`.

    Args:
        domain_generator: A `.DataGenerator` that can be used to generate a
            **domain** `.DataSample` over which to evaluate the :code:`function`.
        function: An **intensity** `.Function` with which the output
            distribution `.DataSample` is generated using a :ref:`hit-and-miss
            strategy <usage/basics:Hit & miss>`.
        domain_transformer: Optional `.DataTransformer` that can convert a
            generated **domain** `.DataSample` to a `.DataSample` that the
            :code:`function` can take as input. Defaults to an
            `.IdentityTransformer`.
        bunch_size: Size of a bunch that is generated during a hit-and-miss
            iteration.
    """

    def __init__(
        self,
        domain_generator: DataGenerator,
        function: Function[DataSample, np.ndarray],
        domain_transformer: DataTransformer | None = None,
        bunch_size: int = 50_000,
    ) -> None:
        self.__domain_generator = domain_generator
        if domain_transformer is not None:
            self.__domain_transformer = domain_transformer
        else:
            self.__domain_transformer = IdentityTransformer()
        self.__function = function
        self.__bunch_size = bunch_size

    def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
        """Generate :code:`size` events distributed according to the intensity.

        Bunches of domain events are filtered with hit-and-miss until at least
        :code:`size` events have been accepted, and the result is truncated to
        :code:`size` events. Whenever a bunch has a maximum above the running
        maximum, the running maximum becomes 1.05 times that bunch maximum. If
        events had already been collected at that point, they are discarded
        (together with the triggering bunch) and generation restarts.
        """
        progress_bar = tqdm(
            total=size,
            desc="Generating intensity-based sample",
            # NOTE(review): `Logger.level` only reflects a level set directly on
            # this logger (NOTSET otherwise), not one inherited from ancestor
            # loggers; `getEffectiveLevel()` may be what is intended here.
            disable=_LOGGER.level > logging.WARNING,
        )
        returned_data: DataSample = {}
        current_max_intensity = 0.0
        while get_number_of_events(returned_data) < size:
            data_bunch, bunch_max = self._generate_bunch(rng)
            # NOTE(review): each bunch is filtered against its own maximum (see
            # `_generate_bunch`); `current_max_intensity` only decides restarts.
            if bunch_max > current_max_intensity:
                previous_max_intensity = current_max_intensity
                # 5% head-room, presumably so that slightly higher maxima in
                # later bunches do not trigger yet another restart.
                current_max_intensity = 1.05 * bunch_max
                if get_number_of_events(returned_data) > 0:
                    _LOGGER.info(
                        "Processed bunch maximum of %s is over current maximum"
                        " %s. Restarting generation!",
                        bunch_max,
                        previous_max_intensity,
                    )
                    returned_data = {}
                    progress_bar.update(n=-progress_bar.n)  # reset progress bar
                    continue
            if returned_data:
                returned_data = merge_events(returned_data, data_bunch)
            else:
                returned_data = data_bunch
            progress_bar.update(n=get_number_of_events(returned_data) - progress_bar.n)
        finalize_progress_bar(progress_bar)
        return select_events(returned_data, selector=slice(None, size))

    def _generate_bunch(self, rng: RealNumberGenerator) -> tuple[DataSample, float]:
        """Generate one domain bunch and filter it with hit-and-miss.

        Returns:
            The accepted domain events and the maximum of the (weighted)
            intensities of the bunch. That maximum was used as the upper bound
            of the uniform random numbers.
        """
        domain = _generate_without_progress_bar(
            self.__domain_generator, self.__bunch_size, rng
        )
        transformed_domain = self.__domain_transformer(domain)
        computed_intensities = self.__function(transformed_domain)
        # Per-event weights, if the domain generator provides them (presumably
        # a weighted phase-space generator); otherwise every event weighs 1.
        weights = domain.get("weights", 1)
        weighted_intensities = weights * computed_intensities
        # The envelope must bound the quantity compared below. With the plain
        # intensity maximum, events with weight > 1 could exceed every random
        # number and would always be accepted.
        max_intensity = float(np.max(weighted_intensities))
        random_intensities = rng(
            size=get_number_of_events(domain), max_value=max_intensity
        )
        hit_and_miss_sample = select_events(
            domain,
            selector=weighted_intensities > random_intensities,
        )
        return hit_and_miss_sample, max_intensity
def _generate_without_progress_bar(
domain_generator: DataGenerator, bunch_size: int, rng: RealNumberGenerator
) -> DataSample:
# https://github.com/ComPWA/tensorwaves/issues/395
show_progress = getattr(domain_generator, "show_progress", None)
if show_progress is not None:
domain_generator.show_progress = False # type: ignore[attr-defined]
domain = domain_generator.generate(bunch_size, rng)
if show_progress is not None:
domain_generator.show_progress = show_progress # type: ignore[attr-defined]
return domain