Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/615.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added `ref db` CLI subcommand group for database management. Includes commands for running migrations, checking schema status, viewing migration history, creating backups, executing SQL queries, and listing tables.
20 changes: 20 additions & 0 deletions packages/climate-ref/src/climate_ref/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
("datasets", "list"),
("datasets", "list-columns"),
("datasets", "stats"),
("db", "status"),
("db", "heads"),
("db", "history"),
("db", "tables"),
("executions", "list-groups"),
("executions", "inspect"),
("executions", "stats"),
Expand Down Expand Up @@ -76,6 +80,7 @@ class CLIContext:
console: Console
skip_backup: bool = False
_database: Database | None = field(default=None, alias="_database")
_database_unmigrated: Database | None = field(default=None, alias="_database_unmigrated")

@property
def database(self) -> Database:
Expand All @@ -89,10 +94,23 @@ def database(self) -> Database:
self._database = Database.from_config(self.config, skip_backup=self.skip_backup)
return self._database

@property
def database_unmigrated(self) -> Database:
"""
Get a database instance without running migrations.

Used by ``db`` subcommands that inspect or manage migration state directly.
"""
if self._database_unmigrated is None:
self._database_unmigrated = Database.from_config(self.config, run_migrations=False)
return self._database_unmigrated

def close(self) -> None:
"""Close the database connection if it was opened."""
if self._database is not None:
self._database.close()
if self._database_unmigrated is not None:
self._database_unmigrated.close()


def _version_callback(value: bool) -> None:
Expand Down Expand Up @@ -155,6 +173,7 @@ def build_app() -> typer.Typer:
from climate_ref.cli import (
config,
datasets,
db,
executions,
providers,
solve,
Expand All @@ -166,6 +185,7 @@ def build_app() -> typer.Typer:
app.command(name="solve")(solve.solve)
app.add_typer(config.app, name="config")
app.add_typer(datasets.app, name="datasets")
app.add_typer(db.app, name="db")
app.add_typer(executions.app, name="executions")
app.add_typer(providers.app, name="providers")
app.add_typer(test_cases.app, name="test-cases")
Expand Down
235 changes: 235 additions & 0 deletions packages/climate-ref/src/climate_ref/cli/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
"""
Database management commands
"""

from typing import Annotated

import sqlalchemy
import typer
from alembic.script import ScriptDirectory
from rich.table import Table

from climate_ref.config import Config
from climate_ref.database import Database, _create_backup, _get_database_revision, _get_sqlite_path

Comment thread
lewisjared marked this conversation as resolved.
app = typer.Typer(help=__doc__)


def _get_script_directory(db: Database, config: Config) -> ScriptDirectory:
"""Build an Alembic ScriptDirectory from a Database and Config."""
alembic_cfg = db.alembic_config(config)
return ScriptDirectory.from_config(alembic_cfg)


@app.command()
def migrate(ctx: typer.Context) -> None:
"""
Run database migrations to bring the schema up to date.

This applies any pending Alembic migrations. A backup is created
before migrating (SQLite only).
"""
db = ctx.obj.database_unmigrated
config = ctx.obj.config
console = ctx.obj.console

script = _get_script_directory(db, config)
head_rev = script.get_current_head()

with db._engine.connect() as connection:
current_rev = _get_database_revision(connection)

if current_rev == head_rev:
console.print(f"Database is already up to date at revision [bold]{current_rev}[/bold].")
return

console.print(f"Current revision: [yellow]{current_rev or '(empty)'}[/yellow]")
console.print(f"Target revision: [green]{head_rev}[/green]")
console.print("Running migrations...")

db.migrate(config, skip_backup=False)
console.print("[green]Migrations applied successfully.[/green]")


@app.command()
def status(ctx: typer.Context) -> None:
"""
Check if the database schema is up to date.

Shows the current revision, the latest available revision,
and whether any migrations are pending.
"""
db = ctx.obj.database_unmigrated
config = ctx.obj.config
console = ctx.obj.console

script = _get_script_directory(db, config)
head_rev = script.get_current_head()

with db._engine.connect() as connection:
current_rev = _get_database_revision(connection)

console.print(f"Database URL: [bold]{db.url}[/bold]")
console.print(f"Current revision: [bold]{current_rev or '(empty)'}[/bold]")
console.print(f"Head revision: [bold]{head_rev}[/bold]")

if current_rev == head_rev:
console.print("[green]Database is up to date.[/green]")
elif current_rev is None:
console.print("[yellow]Database has no revision stamp (new or unmanaged).[/yellow]")
else:
console.print(
"[yellow]Database is behind. Run 'ref db migrate' to apply pending migrations.[/yellow]"
)


@app.command()
def heads(ctx: typer.Context) -> None:
"""
Show the latest migration revision(s).
"""
db = ctx.obj.database_unmigrated
config = ctx.obj.config
console = ctx.obj.console

script = _get_script_directory(db, config)

for head in script.get_heads():
revision = script.get_revision(head)
if revision is not None:
console.print(f"[bold]{revision.revision}[/bold] — {revision.doc or '(no description)'}")


@app.command()
def history(
ctx: typer.Context,
last: Annotated[
int | None,
typer.Option("--last", "-n", help="Show only the last N migrations"),
] = None,
) -> None:
"""
Show the migration history.
"""
db = ctx.obj.database_unmigrated
config = ctx.obj.config
console = ctx.obj.console

script = _get_script_directory(db, config)

with db._engine.connect() as connection:
current_rev = _get_database_revision(connection)

revisions = list(script.walk_revisions())
if last is not None:
Comment thread
lewisjared marked this conversation as resolved.
if last < 1:
raise typer.BadParameter("--last must be greater than or equal to 1")
revisions = revisions[:last]

table = Table(title="Migration History")
table.add_column("Revision", style="bold")
table.add_column("Description")
table.add_column("Status")

for rev in revisions:
is_current = rev.revision == current_rev
status_text = "[green]current[/green]" if is_current else ""
table.add_row(
rev.revision[:12],
rev.doc or "(no description)",
status_text,
)

console.print(table)


@app.command()
def backup(ctx: typer.Context) -> None:
"""
Create a manual backup of the database (SQLite only).
"""
config = ctx.obj.config
console = ctx.obj.console

db_path = _get_sqlite_path(config.db.database_url)
if db_path is None:
console.print("[red]Backup is only supported for local SQLite databases.[/red]")
raise typer.Exit(1)

if not db_path.exists():
console.print(f"[red]Database file not found: {db_path}[/red]")
raise typer.Exit(1)

backup_path = _create_backup(db_path, config.db.max_backups)
console.print(f"[green]Backup created at: {backup_path}[/green]")


@app.command()
def sql(
ctx: typer.Context,
query: Annotated[
str,
typer.Argument(help="SQL query to execute"),
],
limit: Annotated[
int,
typer.Option("--limit", "-l", help="Maximum number of rows to display"),
] = 100,
) -> None:
"""
Execute an arbitrary SQL query against the database.

SELECT queries display results as a table (default limit: 100 rows).
Other statements report the number of rows affected.
"""
db = ctx.obj.database_unmigrated
console = ctx.obj.console

with db._engine.connect() as connection:
result = connection.execute(sqlalchemy.text(query))

if result.returns_rows:
columns = list(result.keys())
rows = result.fetchmany(limit)
total_remaining = len(result.fetchall())

total_rows = len(rows) + total_remaining
table = Table(title=f"Results ({len(rows)} of {total_rows} rows)")
for col in columns:
table.add_column(str(col))

for row in rows:
table.add_row(*(str(v) for v in row))

console.print(table)

if total_remaining > 0:
console.print(
f"[yellow]{total_remaining} additional rows not shown. Use --limit to adjust.[/yellow]"
)
else:
connection.commit()
console.print(f"[green]Query executed successfully. Rows affected: {result.rowcount}[/green]")


@app.command()
def tables(ctx: typer.Context) -> None:
"""
List all tables in the database.
"""
db = ctx.obj.database_unmigrated
console = ctx.obj.console

with db._engine.connect() as connection:
inspector = sqlalchemy.inspect(connection)
table_names = inspector.get_table_names()

table = Table(title="Database Tables")
table.add_column("Table Name", style="bold")
table.add_column("Columns", justify="right")

for name in sorted(table_names):
columns = inspector.get_columns(name)
table.add_row(name, str(len(columns)))

console.print(table)
26 changes: 20 additions & 6 deletions packages/climate-ref/src/climate_ref/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,21 @@
"""


def _get_sqlite_path(database_url: str) -> Path | None:
"""
Extract the file path from a SQLite database URL.

Returns ``None`` for in-memory databases or non-SQLite URLs.
"""
split_url = urlparse.urlsplit(database_url)
if split_url.scheme != "sqlite":
return None
path = urlparse.unquote(split_url.path[1:])
if not path or path == ":memory:":
return None
return Path(path)


def _get_database_revision(connection: sqlalchemy.engine.Connection) -> str | None:
context = MigrationContext.configure(connection)
current_rev = context.get_current_revision()
Expand Down Expand Up @@ -123,13 +138,13 @@ def validate_database_url(database_url: str) -> str:
The validated database URL
"""
split_url = urlparse.urlsplit(database_url)
path = split_url.path[1:]

if split_url.scheme == "sqlite":
if path == ":memory:":
sqlite_path = _get_sqlite_path(database_url)
if sqlite_path is None:
logger.warning("Using an in-memory database")
else:
Path(path).parent.mkdir(parents=True, exist_ok=True)
sqlite_path.parent.mkdir(parents=True, exist_ok=True)
elif split_url.scheme == "postgresql":
# We don't need to do anything special for PostgreSQL
logger.warning("PostgreSQL support is currently experimental and untested")
Expand Down Expand Up @@ -260,9 +275,8 @@ def migrate(self, config: "Config", skip_backup: bool = False) -> None:
)

# Create backup before running migrations (unless skipped)
split_url = urlparse.urlsplit(self.url)
if not skip_backup and split_url.scheme == "sqlite" and split_url.path != ":memory:":
db_path = Path(split_url.path[1:])
db_path = _get_sqlite_path(self.url)
if not skip_backup and db_path is not None:
_create_backup(db_path, config.db.max_backups)

alembic.command.upgrade(self.alembic_config(config), "heads")
Expand Down
Loading
Loading