diff --git a/changelog/615.feature.md b/changelog/615.feature.md new file mode 100644 index 000000000..936f003f9 --- /dev/null +++ b/changelog/615.feature.md @@ -0,0 +1 @@ +Added `ref db` CLI subcommand group for database management. Includes commands for running migrations, checking schema status, viewing migration history, creating backups, executing SQL queries, and listing tables. diff --git a/packages/climate-ref/src/climate_ref/cli/__init__.py b/packages/climate-ref/src/climate_ref/cli/__init__.py index d11fe4208..0de2494ae 100644 --- a/packages/climate-ref/src/climate_ref/cli/__init__.py +++ b/packages/climate-ref/src/climate_ref/cli/__init__.py @@ -25,6 +25,10 @@ ("datasets", "list"), ("datasets", "list-columns"), ("datasets", "stats"), + ("db", "status"), + ("db", "heads"), + ("db", "history"), + ("db", "tables"), ("executions", "list-groups"), ("executions", "inspect"), ("executions", "stats"), @@ -76,6 +80,7 @@ class CLIContext: console: Console skip_backup: bool = False _database: Database | None = field(default=None, alias="_database") + _database_unmigrated: Database | None = field(default=None, alias="_database_unmigrated") @property def database(self) -> Database: @@ -89,10 +94,23 @@ def database(self) -> Database: self._database = Database.from_config(self.config, skip_backup=self.skip_backup) return self._database + @property + def database_unmigrated(self) -> Database: + """ + Get a database instance without running migrations. + + Used by ``db`` subcommands that inspect or manage migration state directly. + """ + if self._database_unmigrated is None: + self._database_unmigrated = Database.from_config(self.config, run_migrations=False) + return self._database_unmigrated + def close(self) -> None: """Close the database connection if it was opened.""" if self._database is not None: self._database.close() + if self._database_unmigrated is not None: + self._database_unmigrated.close() def _version_callback(value: bool) -> None: @@ -155,6 +173,7 @@ def build_app() -> typer.Typer: from climate_ref.cli import ( config, datasets, + db, executions, providers, solve, @@ -166,6 +185,7 @@ def build_app() -> typer.Typer: app.command(name="solve")(solve.solve) app.add_typer(config.app, name="config") app.add_typer(datasets.app, name="datasets") + app.add_typer(db.app, name="db") app.add_typer(executions.app, name="executions") app.add_typer(providers.app, name="providers") app.add_typer(test_cases.app, name="test-cases") diff --git a/packages/climate-ref/src/climate_ref/cli/db.py b/packages/climate-ref/src/climate_ref/cli/db.py new file mode 100644 index 000000000..2d43580ce --- /dev/null +++ b/packages/climate-ref/src/climate_ref/cli/db.py @@ -0,0 +1,235 @@ +""" +Database management commands +""" + +from typing import Annotated + +import sqlalchemy +import typer +from alembic.script import ScriptDirectory +from rich.table import Table + +from climate_ref.config import Config +from climate_ref.database import Database, _create_backup, _get_database_revision, _get_sqlite_path + +app = typer.Typer(help=__doc__) + + +def _get_script_directory(db: Database, config: Config) -> ScriptDirectory: + """Build an Alembic ScriptDirectory from a Database and Config.""" + alembic_cfg = db.alembic_config(config) + return ScriptDirectory.from_config(alembic_cfg) + + +@app.command() +def migrate(ctx: typer.Context) -> None: + """ + Run database migrations to bring the schema up to date. + + This applies any pending Alembic migrations. A backup is created + before migrating (SQLite only). + """ + db = ctx.obj.database_unmigrated + config = ctx.obj.config + console = ctx.obj.console + + script = _get_script_directory(db, config) + head_rev = script.get_current_head() + + with db._engine.connect() as connection: + current_rev = _get_database_revision(connection) + + if current_rev == head_rev: + console.print(f"Database is already up to date at revision [bold]{current_rev}[/bold].") + return + + console.print(f"Current revision: [yellow]{current_rev or '(empty)'}[/yellow]") + console.print(f"Target revision: [green]{head_rev}[/green]") + console.print("Running migrations...") + + db.migrate(config, skip_backup=False) + console.print("[green]Migrations applied successfully.[/green]") + + +@app.command() +def status(ctx: typer.Context) -> None: + """ + Check if the database schema is up to date. + + Shows the current revision, the latest available revision, + and whether any migrations are pending. + """ + db = ctx.obj.database_unmigrated + config = ctx.obj.config + console = ctx.obj.console + + script = _get_script_directory(db, config) + head_rev = script.get_current_head() + + with db._engine.connect() as connection: + current_rev = _get_database_revision(connection) + + console.print(f"Database URL: [bold]{db.url}[/bold]") + console.print(f"Current revision: [bold]{current_rev or '(empty)'}[/bold]") + console.print(f"Head revision: [bold]{head_rev}[/bold]") + + if current_rev == head_rev: + console.print("[green]Database is up to date.[/green]") + elif current_rev is None: + console.print("[yellow]Database has no revision stamp (new or unmanaged).[/yellow]") + else: + console.print( + "[yellow]Database is behind. Run 'ref db migrate' to apply pending migrations.[/yellow]" + ) + + +@app.command() +def heads(ctx: typer.Context) -> None: + """ + Show the latest migration revision(s). + """ + db = ctx.obj.database_unmigrated + config = ctx.obj.config + console = ctx.obj.console + + script = _get_script_directory(db, config) + + for head in script.get_heads(): + revision = script.get_revision(head) + if revision is not None: + console.print(f"[bold]{revision.revision}[/bold] — {revision.doc or '(no description)'}") + + +@app.command() +def history( + ctx: typer.Context, + last: Annotated[ + int | None, + typer.Option("--last", "-n", help="Show only the last N migrations"), + ] = None, +) -> None: + """ + Show the migration history. + """ + db = ctx.obj.database_unmigrated + config = ctx.obj.config + console = ctx.obj.console + + script = _get_script_directory(db, config) + + with db._engine.connect() as connection: + current_rev = _get_database_revision(connection) + + revisions = list(script.walk_revisions()) + if last is not None: + if last < 1: + raise typer.BadParameter("--last must be greater than or equal to 1") + revisions = revisions[:last] + + table = Table(title="Migration History") + table.add_column("Revision", style="bold") + table.add_column("Description") + table.add_column("Status") + + for rev in revisions: + is_current = rev.revision == current_rev + status_text = "[green]current[/green]" if is_current else "" + table.add_row( + rev.revision[:12], + rev.doc or "(no description)", + status_text, + ) + + console.print(table) + + +@app.command() +def backup(ctx: typer.Context) -> None: + """ + Create a manual backup of the database (SQLite only). + """ + config = ctx.obj.config + console = ctx.obj.console + + db_path = _get_sqlite_path(config.db.database_url) + if db_path is None: + console.print("[red]Backup is only supported for local SQLite databases.[/red]") + raise typer.Exit(1) + + if not db_path.exists(): + console.print(f"[red]Database file not found: {db_path}[/red]") + raise typer.Exit(1) + + backup_path = _create_backup(db_path, config.db.max_backups) + console.print(f"[green]Backup created at: {backup_path}[/green]") + + +@app.command() +def sql( + ctx: typer.Context, + query: Annotated[ + str, + typer.Argument(help="SQL query to execute"), + ], + limit: Annotated[ + int, + typer.Option("--limit", "-l", help="Maximum number of rows to display"), + ] = 100, +) -> None: + """ + Execute an arbitrary SQL query against the database. + + SELECT queries display results as a table (default limit: 100 rows). + Other statements report the number of rows affected. + """ + db = ctx.obj.database_unmigrated + console = ctx.obj.console + + with db._engine.connect() as connection: + result = connection.execute(sqlalchemy.text(query)) + + if result.returns_rows: + columns = list(result.keys()) + rows = result.fetchmany(limit) + total_remaining = len(result.fetchall()) + + total_rows = len(rows) + total_remaining + table = Table(title=f"Results ({len(rows)} of {total_rows} rows)") + for col in columns: + table.add_column(str(col)) + + for row in rows: + table.add_row(*(str(v) for v in row)) + + console.print(table) + + if total_remaining > 0: + console.print( + f"[yellow]{total_remaining} additional rows not shown. Use --limit to adjust.[/yellow]" + ) + else: + connection.commit() + console.print(f"[green]Query executed successfully. Rows affected: {result.rowcount}[/green]") + + +@app.command() +def tables(ctx: typer.Context) -> None: + """ + List all tables in the database. + """ + db = ctx.obj.database_unmigrated + console = ctx.obj.console + + with db._engine.connect() as connection: + inspector = sqlalchemy.inspect(connection) + table_names = inspector.get_table_names() + + table = Table(title="Database Tables") + table.add_column("Table Name", style="bold") + table.add_column("Columns", justify="right") + + for name in sorted(table_names): + columns = inspector.get_columns(name) + table.add_row(name, str(len(columns))) + + console.print(table) diff --git a/packages/climate-ref/src/climate_ref/database.py b/packages/climate-ref/src/climate_ref/database.py index 02edafaae..6129ac5d7 100644 --- a/packages/climate-ref/src/climate_ref/database.py +++ b/packages/climate-ref/src/climate_ref/database.py @@ -53,6 +53,21 @@ """ +def _get_sqlite_path(database_url: str) -> Path | None: + """ + Extract the file path from a SQLite database URL. + + Returns ``None`` for in-memory databases or non-SQLite URLs. + """ + split_url = urlparse.urlsplit(database_url) + if split_url.scheme != "sqlite": + return None + path = urlparse.unquote(split_url.path[1:]) + if not path or path == ":memory:": + return None + return Path(path) + + def _get_database_revision(connection: sqlalchemy.engine.Connection) -> str | None: context = MigrationContext.configure(connection) current_rev = context.get_current_revision() @@ -123,13 +138,13 @@ def validate_database_url(database_url: str) -> str: The validated database URL """ split_url = urlparse.urlsplit(database_url) - path = split_url.path[1:] if split_url.scheme == "sqlite": - if path == ":memory:": + sqlite_path = _get_sqlite_path(database_url) + if sqlite_path is None: logger.warning("Using an in-memory database") else: - Path(path).parent.mkdir(parents=True, exist_ok=True) + sqlite_path.parent.mkdir(parents=True, exist_ok=True) elif split_url.scheme == "postgresql": # We don't need to do anything special for PostgreSQL logger.warning("PostgreSQL support is currently experimental and untested") @@ -260,9 +275,8 @@ def migrate(self, config: "Config", skip_backup: bool = False) -> None: ) # Create backup before running migrations (unless skipped) - split_url = urlparse.urlsplit(self.url) - if not skip_backup and split_url.scheme == "sqlite" and split_url.path != ":memory:": - db_path = Path(split_url.path[1:]) + db_path = _get_sqlite_path(self.url) + if not skip_backup and db_path is not None: _create_backup(db_path, config.db.max_backups) alembic.command.upgrade(self.alembic_config(config), "heads") diff --git a/packages/climate-ref/tests/unit/cli/test_db.py b/packages/climate-ref/tests/unit/cli/test_db.py new file mode 100644 index 000000000..c865fefee --- /dev/null +++ b/packages/climate-ref/tests/unit/cli/test_db.py @@ -0,0 +1,189 @@ +import pytest + + +def test_without_subcommand(invoke_cli): + result = invoke_cli(["db"], expected_exit_code=2) + assert "Missing command." in result.stderr + + +def test_db_help(invoke_cli): + result = invoke_cli(["db", "--help"]) + + assert "Database management commands" in result.stdout + + +class TestDbStatus: + def test_status_fresh_database(self, invoke_cli): + result = invoke_cli(["db", "status"]) + + assert "Current revision:" in result.stdout + assert "Head revision:" in result.stdout + assert "Database URL:" in result.stdout + assert "no revision stamp" in result.stdout + + def test_status_up_to_date(self, invoke_cli): + # Migrate first, then check status + invoke_cli(["db", "migrate"]) + + result = invoke_cli(["db", "status"]) + + assert "Database is up to date" in result.stdout + + def test_status_behind(self, invoke_cli): + # Migrate, then stamp with a fake old revision so the DB appears behind + invoke_cli(["db", "migrate"]) + invoke_cli(["db", "sql", "UPDATE alembic_version SET version_num = 'fake_old_rev'"]) + + result = invoke_cli(["db", "status"]) + + assert "Database is behind" in result.stdout + + +class TestDbMigrate: + def test_migrate_fresh_database(self, invoke_cli): + result = invoke_cli(["db", "migrate"]) + + assert "Migrations applied successfully" in result.stdout + + def test_migrate_already_up_to_date(self, invoke_cli): + # Migrate first + invoke_cli(["db", "migrate"]) + + # Second migrate should report up to date + result = invoke_cli(["db", "migrate"]) + + assert "already up to date" in result.stdout + + +class TestDbHeads: + def test_heads(self, invoke_cli): + result = invoke_cli(["db", "heads"]) + + # Should show at least one head revision + assert result.exit_code == 0 + assert result.stdout.strip() != "" + + +class TestDbHistory: + def test_history(self, invoke_cli): + result = invoke_cli(["db", "history"]) + + assert "Migration History" in result.stdout + assert "Revision" in result.stdout + + def test_history_last(self, invoke_cli): + result = invoke_cli(["db", "history", "--last", "3"]) + + assert "Migration History" in result.stdout + + def test_history_last_invalid(self, invoke_cli): + result = invoke_cli(["db", "history", "--last", "0"], expected_exit_code=2) + + assert "must be greater than or equal to 1" in result.stderr + + def test_history_last_negative(self, invoke_cli): + result = invoke_cli(["db", "history", "--last", "-1"], expected_exit_code=2) + + assert "must be greater than or equal to 1" in result.stderr + + +class TestDbBackup: + def test_backup(self, invoke_cli): + # Trigger DB creation first + invoke_cli(["db", "migrate"]) + + result = invoke_cli(["db", "backup"]) + + assert "Backup created at" in result.stdout + + def test_backup_no_database_file(self, invoke_cli): + # No migrate, so no database file exists on disk + result = invoke_cli(["db", "backup"], expected_exit_code=1) + + assert "Database file not found" in result.stdout + + @pytest.mark.parametrize( + "url", + [ + "postgresql://localhost/test", # non-SQLite + "sqlite://", # canonical SQLAlchemy in-memory format + "sqlite:///:memory:", # alternative in-memory format + "sqlite://:memory:", # :memory: parsed as netloc, empty path + ], + ) + def test_backup_unsupported_database(self, config, invoke_cli, url): + config.db.database_url = url + config.save() + + result = invoke_cli(["db", "backup"], expected_exit_code=1) + + assert "only supported for local SQLite" in result.stdout + + +class TestDbSql: + def test_select_query(self, invoke_cli): + # Trigger DB creation first + invoke_cli(["db", "migrate"]) + + result = invoke_cli(["db", "sql", "SELECT COUNT(*) AS cnt FROM provider"]) + + assert "cnt" in result.stdout + assert "Results" in result.stdout + + def test_select_empty_table(self, invoke_cli): + invoke_cli(["db", "migrate"]) + + result = invoke_cli(["db", "sql", "SELECT * FROM provider"]) + + assert "Results (0 of 0 rows)" in result.stdout + + def test_update_query(self, invoke_cli): + invoke_cli(["db", "migrate"]) + + result = invoke_cli( + ["db", "sql", "INSERT INTO provider (slug, name, version) VALUES ('test', 'Test', '1.0')"] + ) + + assert "Query executed successfully" in result.stdout + + def test_select_with_limit(self, invoke_cli): + invoke_cli(["db", "migrate"]) + # Insert some rows + for i in range(5): + stmt = f"INSERT INTO provider (slug, name, version) VALUES ('t{i}', 'T{i}', '1.0')" # noqa: S608 + invoke_cli(["db", "sql", stmt]) + + result = invoke_cli(["db", "sql", "SELECT * FROM provider", "--limit", "2"]) + + assert "Results (2 of 5 rows)" in result.stdout + assert "additional rows not shown" in result.stdout + + def test_sql_no_tables(self, invoke_cli): + # No migrate, so database has no tables. SQLite auto-creates the file + # but querying a non-existent table should fail. + result = invoke_cli( + ["db", "sql", "SELECT * FROM provider"], + expected_exit_code=1, + ) + + assert result.exit_code == 1 + + +class TestDbTables: + def test_tables(self, invoke_cli): + invoke_cli(["db", "migrate"]) + + result = invoke_cli(["db", "tables"]) + + assert "Database Tables" in result.stdout + assert "provider" in result.stdout + assert "dataset" in result.stdout + assert "execution" in result.stdout + + def test_tables_no_database(self, invoke_cli): + # No migrate -- empty database with no tables + result = invoke_cli(["db", "tables"]) + + assert "Database Tables" in result.stdout + # No application tables should be listed + assert "provider" not in result.stdout diff --git a/packages/climate-ref/tests/unit/cli/test_root.py b/packages/climate-ref/tests/unit/cli/test_root.py index e4f98d033..c7c8fcc40 100644 --- a/packages/climate-ref/tests/unit/cli/test_root.py +++ b/packages/climate-ref/tests/unit/cli/test_root.py @@ -97,7 +97,7 @@ def test_config_directory_append(config, invoke_cli): @pytest.fixture() def expected_groups() -> set[str]: - return {"config", "datasets", "executions", "providers", "celery", "test-cases"} + return {"config", "datasets", "db", "executions", "providers", "celery", "test-cases"} def test_build_app(expected_groups): diff --git a/packages/climate-ref/tests/unit/test_database.py b/packages/climate-ref/tests/unit/test_database.py index 6b114bb1c..fd650c1fe 100644 --- a/packages/climate-ref/tests/unit/test_database.py +++ b/packages/climate-ref/tests/unit/test_database.py @@ -8,13 +8,55 @@ import sqlalchemy from sqlalchemy import inspect -from climate_ref.database import Database, _create_backup, _values_differ, validate_database_url +from climate_ref.database import ( + Database, + _create_backup, + _get_sqlite_path, + _values_differ, + validate_database_url, +) from climate_ref.models import MetricValue from climate_ref.models.dataset import CMIP6Dataset, Dataset, Obs4MIPsDataset from climate_ref_core.datasets import SourceDatasetType from climate_ref_core.pycmec.controlled_vocabulary import CV +class TestGetSqlitePath: + """Tests for _get_sqlite_path helper that extracts file paths from SQLite URLs.""" + + @pytest.mark.parametrize( + ("url", "expected"), + [ + ("sqlite:///climate_ref.db", Path("climate_ref.db")), + ("sqlite:////tmp/climate_ref.db", Path("/tmp/climate_ref.db")), # noqa: S108 + ("sqlite:///path%20with%20spaces/db.sqlite", Path("path with spaces/db.sqlite")), + ], + ) + def test_returns_path_for_file_databases(self, url, expected): + assert _get_sqlite_path(url) == expected + + @pytest.mark.parametrize( + "url", + [ + "sqlite://", # SQLAlchemy documented in-memory format + "sqlite:///:memory:", + "sqlite://:memory:", + ], + ) + def test_returns_none_for_in_memory(self, url): + assert _get_sqlite_path(url) is None + + @pytest.mark.parametrize( + "url", + [ + "postgresql://localhost/db", + "mysql://localhost/db", + ], + ) + def test_returns_none_for_non_sqlite(self, url): + assert _get_sqlite_path(url) is None + + @pytest.mark.parametrize( "database_url", [