Skip to content

Commit

Permalink
feat: provide time estimate data generation
Browse files Browse the repository at this point in the history
Fixes the tqdm progress bar so that it gives a time estimate
  • Loading branch information
redeboer committed Jun 10, 2021
1 parent 493c7de commit fe4f828
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions src/tensorwaves/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""The `.data` module takes care of data generation."""

import logging
import math
from typing import Optional, Tuple

import numpy as np
Expand Down Expand Up @@ -85,7 +84,7 @@ def generate_data(
random_generator = TFUniformRealNumberGenerator()

progress_bar = tqdm(
total=math.ceil(size / bunch_size),
total=size,
desc="Generating intensity-based sample",
disable=logging.getLogger().level > logging.WARNING,
)
Expand All @@ -109,14 +108,14 @@ def generate_data(
current_max,
)
momentum_pool = EventCollection({})
progress_bar.update()
progress_bar.update(n=-progress_bar.n) # reset progress bar
continue
if np.size(momentum_pool, 0) > 0:
momentum_pool.append(bunch)
else:
momentum_pool = bunch
progress_bar.update()
progress_bar.close()
progress_bar.update(n=momentum_pool.n_events - progress_bar.n)
finalize_progress_bar(progress_bar)
return momentum_pool.select_events(slice(0, size))


Expand Down Expand Up @@ -148,7 +147,7 @@ def generate_phsp(
random_generator = TFUniformRealNumberGenerator()

progress_bar = tqdm(
total=size / bunch_size,
total=size,
desc="Generating phase space sample",
disable=logging.getLogger().level > logging.WARNING,
)
Expand All @@ -166,6 +165,12 @@ def generate_phsp(
momentum_pool.append(bunch)
else:
momentum_pool = bunch
progress_bar.update()
progress_bar.close()
progress_bar.update(n=bunch.n_events)
finalize_progress_bar(progress_bar)
return momentum_pool.select_events(slice(0, size))


def finalize_progress_bar(progress_bar: tqdm) -> None:
remainder = progress_bar.total - progress_bar.n
progress_bar.update(n=remainder) # pylint crashes if total is set directly
progress_bar.close()

0 comments on commit fe4f828

Please sign in to comment.