Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/data-designer/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ path = "dev-tools/hatch_build.py"
dependencies = [
"data-designer-config=={{ version }}",
"data-designer-engine=={{ version }}",
"packaging>=25,<27",
"prompt-toolkit>=3.0.0,<4",
"typer>=0.12.0,<1",
]
Expand Down
23 changes: 22 additions & 1 deletion packages/data-designer/src/data_designer/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import importlib.metadata
import sys
from typing import TextIO

import typer

Expand All @@ -16,14 +17,34 @@
_PACKAGE_NAME = "data-designer"


def _should_show_update_notice(stream: TextIO | None = None) -> bool:
stream = sys.stdout if stream is None else stream
return stream.isatty()


def _version_callback(value: bool) -> None:
if not value:
return
try:
typer.echo(importlib.metadata.version(_PACKAGE_NAME))
installed_version = importlib.metadata.version(_PACKAGE_NAME)
except importlib.metadata.PackageNotFoundError:
typer.echo(f"Unable to resolve installed {_PACKAGE_NAME} package version.", err=True)
raise typer.Exit(1) from None

typer.echo(installed_version)
if not _should_show_update_notice():
raise typer.Exit()

from data_designer.cli.ui import print_update_notice
from data_designer.cli.version_notice import get_update_notice

try:
# The update CTA is opportunistic; version output should stay usable if lookup fails.
notice = get_update_notice(installed_version)
except Exception:
notice = None
if notice is not None:
print_update_notice(notice.latest_version, notice.upgrade_command)
raise typer.Exit()


Expand Down
23 changes: 23 additions & 0 deletions packages/data-designer/src/data_designer/cli/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from rich.console import Console
from rich.padding import Padding
from rich.panel import Panel
from rich.text import Text

from data_designer.config.utils.constants import RICH_CONSOLE_THEME, NordColor

Expand Down Expand Up @@ -574,6 +575,28 @@ def print_info(message: str) -> None:
_print_with_padding(f"๐Ÿ’ก {message}")


def print_update_notice(latest_version: str, upgrade_command: str) -> None:
"""Print a compact version update notice.

Args:
latest_version: Latest available Data Designer version.
upgrade_command: Command users can run to upgrade.
"""
content = Text.assemble(
"New Data Designer version available: ",
(latest_version, f"bold {NordColor.NORD14.value}"),
"\nUpgrade with: ",
(upgrade_command, f"bold {NordColor.NORD8.value}"),
)
panel = Panel.fit(
content,
title="๐Ÿš€ Update available",
title_align="left",
border_style=NordColor.NORD8.value,
)
_console.print(panel)


def print_text(message: str) -> None:
"""Print a text message.

Expand Down
236 changes: 236 additions & 0 deletions packages/data-designer/src/data_designer/cli/version_notice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import json
import os
import sys
import time
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen

from packaging.version import InvalidVersion, Version

from data_designer.config.utils.constants import DATA_DESIGNER_HOME

_PACKAGE_NAME = "data-designer"
_PYPI_JSON_URL = f"https://pypi.org/pypi/{_PACKAGE_NAME}/json"
_VERSION_CHECK_TIMEOUT_SECONDS = 0.75
_CACHE_TTL_SECONDS = 6 * 60 * 60
_CACHE_SCHEMA_VERSION = 1
_CACHE_FILE_NAME = "version-check.json"
_DISABLE_VERSION_CHECK_ENV_VAR = "DATA_DESIGNER_DISABLE_VERSION_CHECK"
_INCLUDE_PRERELEASES_ENV_VAR = "DATA_DESIGNER_VERSION_CHECK_PRERELEASES"
_UV_TOOL_UPGRADE_COMMAND = "uv tool upgrade data-designer"
_PROJECT_UPGRADE_COMMAND = "uv add --upgrade data-designer"
_PIPX_UPGRADE_COMMAND = "pipx upgrade data-designer"


@dataclass(frozen=True)
class UpdateNotice:
latest_version: str
upgrade_command: str


def get_update_notice(
installed_version: str,
*,
cache_dir: Path = DATA_DESIGNER_HOME,
environ: Mapping[str, str] | None = None,
now: Callable[[], float] = time.time,
python_prefix: str | None = None,
) -> UpdateNotice | None:
env = os.environ if environ is None else environ
if _env_flag_enabled(env, _DISABLE_VERSION_CHECK_ENV_VAR):
return None

try:
installed = Version(installed_version)
except InvalidVersion:
return None
if installed.local is not None:
return None

include_prereleases = installed.is_prerelease or _env_flag_enabled(env, _INCLUDE_PRERELEASES_ENV_VAR)
latest_version = _get_latest_version(
include_prereleases=include_prereleases,
cache_dir=cache_dir,
now=now,
)
if latest_version is None:
return None

try:
latest = Version(latest_version)
except InvalidVersion:
return None

if latest <= installed:
return None

return UpdateNotice(
latest_version=latest.public,
upgrade_command=select_upgrade_command(environ=env, python_prefix=python_prefix),
)


def select_upgrade_command(
*,
environ: Mapping[str, str] | None = None,
python_prefix: str | None = None,
) -> str:
env = os.environ if environ is None else environ
prefix = Path(sys.prefix if python_prefix is None else python_prefix)
prefix_parts = prefix.parts
if env.get("UV_PROJECT_ENVIRONMENT") or prefix.name == ".venv":
return _PROJECT_UPGRADE_COMMAND
if _has_direct_child_path(prefix_parts, "pipx", "venvs"):
return _PIPX_UPGRADE_COMMAND
if _has_direct_child_path(prefix_parts, "uv", "tools"):
return _UV_TOOL_UPGRADE_COMMAND
if env.get("VIRTUAL_ENV"):
return _PROJECT_UPGRADE_COMMAND
return _UV_TOOL_UPGRADE_COMMAND


def _has_direct_child_path(parts: tuple[str, ...], parent: str, child: str) -> bool:
return any(
parts[index] == parent and parts[index + 1] == child and index + 2 == len(parts) - 1
for index in range(len(parts) - 1)
)


def _get_latest_version(
*,
include_prereleases: bool,
cache_dir: Path,
now: Callable[[], float],
) -> str | None:
cache_path = cache_dir / _CACHE_FILE_NAME
cached_version = _read_cached_version(
cache_path=cache_path,
include_prereleases=include_prereleases,
now=now,
)
if cached_version is not None:
return cached_version

try:
latest_version = _fetch_latest_version(include_prereleases=include_prereleases)
except (HTTPError, URLError, TimeoutError, OSError, json.JSONDecodeError):
return None

if latest_version is not None:
_write_cached_version(
cache_path=cache_path,
latest_version=latest_version,
include_prereleases=include_prereleases,
checked_at=now(),
)
return latest_version


def _fetch_latest_version(*, include_prereleases: bool) -> str | None:
request = Request(_PYPI_JSON_URL, headers={"Accept": "application/json", "User-Agent": "data-designer"})
with urlopen(request, timeout=_VERSION_CHECK_TIMEOUT_SECONDS) as response:
payload = json.load(response)
if not isinstance(payload, dict):
return None
return _latest_version_from_pypi_payload(payload, include_prereleases=include_prereleases)


def _latest_version_from_pypi_payload(payload: Mapping[str, Any], *, include_prereleases: bool) -> str | None:
releases = payload.get("releases")
if not isinstance(releases, dict):
return None

candidates: list[Version] = []
for version_text, release_files in releases.items():
if not isinstance(version_text, str) or not _has_installable_release_file(release_files):
continue
try:
version = Version(version_text)
except InvalidVersion:
continue
if version.is_prerelease and not include_prereleases:
continue
candidates.append(version)

if not candidates:
return None

return max(candidates).public


def _has_installable_release_file(release_files: Any) -> bool:
if not isinstance(release_files, list):
return False
return any(
isinstance(release_file, dict) and not release_file.get("yanked", False) for release_file in release_files
)


def _read_cached_version(
*,
cache_path: Path,
include_prereleases: bool,
now: Callable[[], float],
) -> str | None:
try:
cache_data = json.loads(cache_path.read_text(encoding="utf-8"))
except (OSError, json.JSONDecodeError):
return None
if not isinstance(cache_data, dict):
return None

if cache_data.get("schema_version") != _CACHE_SCHEMA_VERSION:
return None
if cache_data.get("package_name") != _PACKAGE_NAME:
return None
if cache_data.get("include_prereleases") != include_prereleases:
return None

checked_at = cache_data.get("checked_at")
latest_version = cache_data.get("latest_version")
if not isinstance(checked_at, (int, float)) or not isinstance(latest_version, str):
return None
if now() - float(checked_at) > _CACHE_TTL_SECONDS:
return None
return latest_version


def _write_cached_version(
*,
cache_path: Path,
latest_version: str,
include_prereleases: bool,
checked_at: float,
) -> None:
cache_data = {
"checked_at": checked_at,
"include_prereleases": include_prereleases,
"latest_version": latest_version,
"package_name": _PACKAGE_NAME,
"schema_version": _CACHE_SCHEMA_VERSION,
}
temp_path = cache_path.with_name(f"{cache_path.name}.{os.getpid()}.tmp")
try:
cache_path.parent.mkdir(parents=True, exist_ok=True)
temp_path.write_text(json.dumps(cache_data), encoding="utf-8")
temp_path.replace(cache_path)
except OSError:
try:
temp_path.unlink()
except OSError:
pass
return


def _env_flag_enabled(env: Mapping[str, str], name: str) -> bool:
value = env.get(name, "")
return value.strip().lower() in {"1", "true", "yes", "on"}
48 changes: 47 additions & 1 deletion packages/data-designer/tests/cli/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typer.testing import CliRunner

from data_designer.cli.main import app, main
from data_designer.cli.version_notice import UpdateNotice
from data_designer.config.utils.constants import DEFAULT_NUM_RECORDS

runner = CliRunner()
Expand Down Expand Up @@ -51,12 +52,57 @@ def test_main_skips_bootstrap_when_version_follows_another_flag(mock_bootstrap:


def test_app_version_prints_installed_data_designer_version() -> None:
with patch("data_designer.cli.main.importlib.metadata.version", return_value="0.6.0") as mock_version:
with (
patch("data_designer.cli.main.importlib.metadata.version", return_value="0.6.0") as mock_version,
patch("data_designer.cli.main._should_show_update_notice", return_value=True),
patch("data_designer.cli.version_notice.get_update_notice", return_value=None) as mock_notice,
):
result = runner.invoke(app, ["--version"])

assert result.exit_code == 0
assert result.output == "0.6.0\n"
mock_version.assert_called_once_with("data-designer")
mock_notice.assert_called_once_with("0.6.0")


def test_app_version_prints_update_notice_after_installed_version() -> None:
notice = UpdateNotice(latest_version="0.6.1", upgrade_command="uv tool upgrade data-designer")
with (
patch("data_designer.cli.main.importlib.metadata.version", return_value="0.6.0"),
patch("data_designer.cli.main._should_show_update_notice", return_value=True),
patch("data_designer.cli.version_notice.get_update_notice", return_value=notice),
):
result = runner.invoke(app, ["--version"])

assert result.exit_code == 0
assert result.output.splitlines()[0] == "0.6.0"
assert "New Data Designer version available: 0.6.1" in result.output
assert "Upgrade with: uv tool upgrade data-designer" in result.output


def test_app_version_skips_update_notice_when_stdout_is_not_tty() -> None:
with (
patch("data_designer.cli.main.importlib.metadata.version", return_value="0.6.0"),
patch("data_designer.cli.main._should_show_update_notice", return_value=False),
patch("data_designer.cli.version_notice.get_update_notice") as mock_notice,
):
result = runner.invoke(app, ["--version"])

assert result.exit_code == 0
assert result.output == "0.6.0\n"
mock_notice.assert_not_called()


def test_app_version_skips_update_notice_when_lookup_fails() -> None:
with (
patch("data_designer.cli.main.importlib.metadata.version", return_value="0.6.0"),
patch("data_designer.cli.main._should_show_update_notice", return_value=True),
patch("data_designer.cli.version_notice.get_update_notice", side_effect=RuntimeError("network failure")),
):
result = runner.invoke(app, ["--version"])

assert result.exit_code == 0
assert result.output == "0.6.0\n"


def test_app_version_errors_when_package_version_is_missing() -> None:
Expand Down
Loading
Loading