Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,11 @@ def _maybe_report(self) -> None:

def _update_bar(self) -> None:
elapsed = time.perf_counter() - self._start_time
updates: dict[str, tuple[int, int, int]] = {}
for col, tracker in self._trackers.items():
completed, _total, success, failed, _skipped, _pct, _rate, _emoji = tracker.get_snapshot(elapsed)
self._bar.update(col, completed=completed, success=success, failed=failed)
updates[col] = (completed, success, failed)
self._bar.update_many(updates)

def _emit(self) -> None:
current_total = sum(tracker.get_snapshot(0.0)[0] for tracker in self._trackers.values())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import logging
import re
import shutil
import sys
import time
Expand All @@ -13,6 +14,7 @@

BAR_FILLED = "β–ˆ"
BAR_EMPTY = "β–‘"
_ANSI_RE = re.compile(r"\033\[[0-9;]*m")


def _compute_stats_width(total: int) -> int:
Expand Down Expand Up @@ -67,6 +69,10 @@ def __init__(self, stream: TextIO | None = None) -> None:
def is_active(self) -> bool:
return self._active

@property
def drawn_lines(self) -> int:
return self._drawn_lines

# -- context manager --

def __enter__(self) -> StickyProgressBar:
Expand Down Expand Up @@ -108,6 +114,16 @@ def update(
if self._active:
self._redraw()

def update_many(self, updates: dict[str, tuple[int, int, int]]) -> None:
with self._lock:
for key, (completed, success, failed) in updates.items():
if bar := self._bars.get(key):
bar.completed = completed
bar.success = success
bar.failed = failed
if self._active:
self._redraw()

def remove_bar(self, key: str) -> None:
with self._lock:
self._bars.pop(key, None)
Expand Down Expand Up @@ -163,23 +179,30 @@ def _redraw(self) -> None:
for bar in self._bars.values():
line = self._format_bar(bar, width, max_label)
self._write(line + "\n")
self._drawn_lines += 1
visible = len(_ANSI_RE.sub("", line))
if width > 0 and visible > width:
self._drawn_lines += (visible + width - 1) // width
else:
self._drawn_lines += 1

def _format_bar(self, bar: _BarState, width: int, label_width: int) -> str:
completed = min(bar.completed, bar.total)
pct = (completed / bar.total * 100) if bar.total > 0 else 100.0
elapsed = time.perf_counter() - bar.start_time
rate = bar.completed / elapsed if elapsed > 0 else 0.0
rate = min(bar.completed / elapsed if elapsed > 0 else 0.0, 9999.9)
remaining = max(0, bar.total - completed)
eta = f"{remaining / rate:.0f}s" if rate > 0 else "?"
eta = f"{min(remaining / rate, 999):.0f}s" if rate > 0 else "?"

label = bar.label.ljust(label_width)
total_w = len(str(bar.total))
count_str = f"{completed:>{total_w}}/{bar.total}"
stats = f" {pct:3.0f}% | {count_str} | {rate:6.1f} rec/s | eta {eta:>4s} | {bar.failed:>{total_w}} failed"
stats = stats.ljust(bar.stats_width)

bar_width = max(10, width - len(label) - bar.stats_width - 4)
bar_width = width - len(label) - bar.stats_width - 4
if bar_width < 1:
return f" {label} {stats}"[: max(0, width - 1)]

filled = int(bar_width * pct / 100)
empty = bar_width - filled

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import io
import logging
import os
import re
import shutil
from unittest.mock import patch

import pytest

from data_designer.engine.dataset_builders.utils.async_progress_reporter import AsyncProgressReporter
from data_designer.engine.dataset_builders.utils.progress_tracker import ProgressTracker
from data_designer.engine.dataset_builders.utils.sticky_progress_bar import (
StickyProgressBar,
)

CURSOR_UP_CLEAR = "\033[A\033[2K"
HIDE_CURSOR = "\033[?25l"
SHOW_CURSOR = "\033[?25h"
_ALL_ANSI_RE = re.compile(r"\033\[[0-9;?]*[a-zA-Z]")


class FakeTTY(io.StringIO):
"""StringIO that reports itself as a TTY so StickyProgressBar activates."""

def isatty(self) -> bool:
return True


@pytest.fixture
def tty_stream() -> FakeTTY:
return FakeTTY()


def test_no_output_when_not_tty() -> None:
stream = io.StringIO()
with StickyProgressBar(stream=stream) as bar:
bar.add_bar("a", "col_a", 10)
bar.update("a", completed=5, success=5)
assert stream.getvalue() == ""


def test_hides_and_shows_cursor(tty_stream: FakeTTY) -> None:
with StickyProgressBar(stream=tty_stream):
pass
output = tty_stream.getvalue()
assert output.startswith(HIDE_CURSOR)
assert output.endswith(SHOW_CURSOR)


def test_drawn_lines_tracks_add_and_remove(tty_stream: FakeTTY) -> None:
with StickyProgressBar(stream=tty_stream) as bar:
bar.add_bar("a", "col_a", 10)
bar.add_bar("b", "col_b", 10)
bar.add_bar("c", "col_c", 10)
assert bar.drawn_lines == 3

bar.remove_bar("a")
assert bar.drawn_lines == 2

bar.add_bar("d", "col_d", 10)
assert bar.drawn_lines == 3

bar.update("b", completed=5, success=5)
assert bar.drawn_lines == 3

bar.remove_bar("b")
bar.remove_bar("c")
bar.remove_bar("d")
assert bar.drawn_lines == 0


def test_drawn_lines_stable_across_many_updates(tty_stream: FakeTTY) -> None:
with StickyProgressBar(stream=tty_stream) as bar:
bar.add_bar("a", "col_a", 100)
bar.add_bar("b", "col_b", 100)
bar.add_bar("c", "col_c", 100)
for i in range(50):
bar.update("a", completed=i, success=i)
bar.update("b", completed=i, success=i)
bar.update("c", completed=i, success=i)

snapshot = tty_stream.getvalue()
bar.update("a", completed=50, success=50)
assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) == 3


def test_log_interleaving_preserves_drawn_lines(tty_stream: FakeTTY) -> None:
root_logger = logging.getLogger()
handler = logging.StreamHandler(tty_stream)
handler.setFormatter(logging.Formatter("%(message)s"))
root_logger.addHandler(handler)

try:
with StickyProgressBar(stream=tty_stream) as bar:
bar.add_bar("x", "col_x", 100)
bar.add_bar("y", "col_y", 100)
bar.add_bar("z", "col_z", 100)

for i in range(20):
bar.update("x", completed=i, success=i)
root_logger.info("log at step %d", i)
bar.update("y", completed=i, success=i)
bar.update("z", completed=i, success=i)

snapshot = tty_stream.getvalue()
bar.update("x", completed=20, success=20)
assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) == 3
finally:
root_logger.removeHandler(handler)


def test_wrapping_counts_physical_lines(tty_stream: FakeTTY) -> None:
narrow = os.terminal_size((40, 24))
with patch.object(shutil, "get_terminal_size", return_value=narrow):
with StickyProgressBar(stream=tty_stream) as bar:
bar.add_bar("a", "col_a", 100)
bar.add_bar("b", "col_b", 100)

original_format = bar._format_bar

def oversized_format(b: object, width: int, label_width: int) -> str:
line = original_format(b, width, label_width)
return line + "X" * 20

with patch.object(bar, "_format_bar", side_effect=oversized_format):
bar.update("a", completed=5, success=5)

snapshot = tty_stream.getvalue()
bar.update("b", completed=1, success=1)
assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) > 2


def test_wrapping_stable_across_updates(tty_stream: FakeTTY) -> None:
narrow = os.terminal_size((40, 24))
with patch.object(shutil, "get_terminal_size", return_value=narrow):
with StickyProgressBar(stream=tty_stream) as bar:
bar.add_bar("a", "col_a", 100)
bar.add_bar("b", "col_b", 100)

snapshot = tty_stream.getvalue()
bar.update("a", completed=0, success=0)
initial_clears = tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR)

for i in range(1, 21):
bar.update("a", completed=i, success=i)
bar.update("b", completed=i, success=i)

snapshot = tty_stream.getvalue()
bar.update("a", completed=21, success=21)
assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) == initial_clears


def test_narrow_terminal_graceful_degradation(tty_stream: FakeTTY) -> None:
narrow = os.terminal_size((30, 24))
with patch.object(shutil, "get_terminal_size", return_value=narrow):
with StickyProgressBar(stream=tty_stream) as bar:
bar.add_bar("a", "column 'verification_1'", 300)
bar.update("a", completed=50, success=50)

snapshot = tty_stream.getvalue()
bar.update("a", completed=51, success=51)
assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) == 1

output = tty_stream.getvalue()
clean = _ALL_ANSI_RE.sub("", output).replace("\r", "")
for line in clean.split("\n"):
assert len(line) <= 29


def test_update_many_single_redraw(tty_stream: FakeTTY) -> None:
with StickyProgressBar(stream=tty_stream) as bar:
bar.add_bar("a", "col_a", 100)
bar.add_bar("b", "col_b", 100)
before = tty_stream.getvalue()

bar.update_many({"a": (10, 10, 0), "b": (20, 20, 0)})
after = tty_stream.getvalue()

new_output = after[len(before) :]
assert new_output.count(CURSOR_UP_CLEAR) == 2

clean = _ALL_ANSI_RE.sub("", after)
assert "10/100" in clean
assert "20/100" in clean


def test_update_many_ignores_unknown_keys(tty_stream: FakeTTY) -> None:
with StickyProgressBar(stream=tty_stream) as bar:
bar.add_bar("a", "col_a", 100)
bar.update_many({"a": (10, 10, 0), "unknown": (5, 5, 0)})

clean = _ALL_ANSI_RE.sub("", tty_stream.getvalue())
assert "10/100" in clean
assert "unknown" not in clean

snapshot = tty_stream.getvalue()
bar.update("a", completed=11, success=11)
assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) == 1


def test_reporter_updates_and_logs_keep_drawn_lines_in_sync(tty_stream: FakeTTY) -> None:
root_logger = logging.getLogger()
handler = logging.StreamHandler(tty_stream)
handler.setFormatter(logging.Formatter("%(message)s"))
root_logger.addHandler(handler)

try:
bar = StickyProgressBar(stream=tty_stream)
trackers = {
"col_a": ProgressTracker(total_records=100, label="column 'a'", quiet=True),
"col_b": ProgressTracker(total_records=100, label="column 'b'", quiet=True),
"col_c": ProgressTracker(total_records=100, label="column 'c'", quiet=True),
}

with bar:
reporter = AsyncProgressReporter(trackers, report_interval=0.1, progress_bar=bar)
reporter.log_start(num_row_groups=1)

snapshot = tty_stream.getvalue()
reporter.record_success("col_a")
assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) == 3

for i in range(49):
if i % 10 == 0:
root_logger.info("Processing batch %d", i)
reporter.record_success("col_b")
reporter.record_success("col_c")

snapshot = tty_stream.getvalue()
reporter.record_success("col_a")
assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) == 3

reporter.log_final()
assert bar.drawn_lines == 0
finally:
root_logger.removeHandler(handler)
Loading