Skip to content

Commit

Permalink
feat: add ETA to progress bar and fix not showing the progress bar if…
Browse files Browse the repository at this point in the history
… irrelevant (#253)

* script for trying the progress bar

* TimeRemainingColumn added

* erase try

* adding elapsed when finished

* not showing labeller row when there's no labeller

* only showing progress bar if there are labellers and generators

* chore: applied linting

* chore: resolve missing pipeline argument in llm base.py by initializing with partial func

---------

Co-authored-by: davidberenstein1957 <david.m.berenstein@gmail.com>
  • Loading branch information
ignacioct and davidberenstein1957 committed Jan 16, 2024
1 parent e9dc2a1 commit d15b671
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
2 changes: 2 additions & 0 deletions src/distilabel/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,8 @@ def _generate( # noqa: C901
num_rows=len(dataset),
num_generations=num_generations,
display_progress_bar=display_progress_bar,
has_labeller=True if self.labeller else False,
has_generator=True if self.generator else False,
)

num_batches = math.ceil(len(dataset) / batch_size)
Expand Down
36 changes: 29 additions & 7 deletions src/distilabel/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import wraps
from functools import partial, wraps
from typing import Any, Callable, Tuple, TypeVar, Union

from rich.progress import (
Expand All @@ -21,6 +21,7 @@
Progress,
TaskProgressColumn,
TextColumn,
TimeRemainingColumn,
)
from typing_extensions import ParamSpec

Expand All @@ -29,6 +30,7 @@
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
TimeRemainingColumn(elapsed_when_finished=True),
)

P = ParamSpec("P")
Expand Down Expand Up @@ -61,23 +63,43 @@ def update_progress_bar(**kwargs: Any) -> None:


def get_progress_bars_for_pipeline(
num_rows: int, num_generations: int, display_progress_bar: bool
num_rows: int,
num_generations: int,
display_progress_bar: bool,
has_generator: bool,
has_labeller: bool,
) -> Tuple[ProgressFunc, ProgressFunc]:
if display_progress_bar:
generation_progress_bar = get_progress_bar(
description="Texts Generated", total=num_rows * num_generations
)

def _generation_progress_func(advance=None) -> None:
generation_progress_bar(advance=advance or num_generations)
def _generation_progress_func(has_generator: bool, advance=None) -> None:
# If there's no generator, we are not showing the progress bar.
# This information comes from pipeline.py
return (
generation_progress_bar(advance=advance or num_generations)
if has_generator
else None
)

labelling_progress_bar = get_progress_bar(
description="Rows labelled", total=num_rows
)

def _labelling_progress_func(advance=None) -> None:
labelling_progress_bar(advance=1)
def _labelling_progress_func(has_labeller: bool, advance=None) -> None:
# If there's no labeller, we are not showing the progress bar.
# This information comes from pipeline.py
return (
labelling_progress_bar(advance=advance or 1) if has_labeller else None
)

return _generation_progress_func, _labelling_progress_func
_partial_generation_progress_func = partial(
_generation_progress_func, has_generator=has_generator
)
_partial_labelling_progress_func = partial(
_labelling_progress_func, has_labeller=has_labeller
)
return _partial_generation_progress_func, _partial_labelling_progress_func

return None, None

0 comments on commit d15b671

Please sign in to comment.