Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Custom
.vscode
.idea
.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
4 changes: 3 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,9 @@ good-names=i,
k,
ex,
Run,
_
_,
e,
f
# Good variable names regexes, separated by a comma. If names match any regex,
# they will always be accepted
good-names-rgxs=
Expand Down
2 changes: 1 addition & 1 deletion sqlsynthgen/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def create_db_tables(metadata: Any) -> Any:
metadata.create_all(engine)


def generate(sorted_tables: list, sorted_generators: list) -> Any:
def create_db_data(sorted_tables: list, sorted_generators: list) -> None:
"""Connect to a database and populate it with data."""
settings = get_settings()
engine = create_engine(settings.dst_postgres_dsn)
Expand Down
46 changes: 39 additions & 7 deletions sqlsynthgen/main.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,51 @@
"""Entrypoint for the sqlsynthgen package."""
from subprocess import run
"""Entrypoint for the SQLSynthGen package."""
import sys
from importlib import import_module
from pathlib import Path
from subprocess import CalledProcessError, run
from sys import stderr
from types import ModuleType

import typer

from sqlsynthgen.create import create_db_data, create_db_tables
from sqlsynthgen.make import make_generators_from_tables
from sqlsynthgen.settings import get_settings

app = typer.Typer()


def import_file(file_path: str) -> ModuleType:
"""Import a file given a relative path."""
file_path_path = Path(file_path)
module_path = ".".join(file_path_path.parts[:-1] + (file_path_path.stem,))
return import_module(module_path)


@app.command()
def create_data() -> None:
def create_data(
orm_file: str = typer.Argument(...),
ssg_file: str = typer.Argument(...),
) -> None:
"""Fill tables with synthetic data."""
orm_module = import_file(orm_file)
ssg_module = import_file(ssg_file)
create_db_data(orm_module.metadata.sorted_tables, ssg_module.sorted_generators)


@app.command()
def create_tables() -> None:
def create_tables(orm_file: str = typer.Argument(...)) -> None:
"""Create tables using the SQLAlchemy file."""
orm_module = import_file(orm_file)
create_db_tables(orm_module.metadata)


@app.command()
def make_generators() -> None:
"""Make a SQLSynthGun file of generator classes."""
def make_generators(orm_file: str = typer.Argument(...)) -> None:
"""Make a SQLSynthGen file of generator classes."""
orm_module = import_file(orm_file)
result = make_generators_from_tables(orm_module)
print(result)


@app.command()
Expand All @@ -35,7 +60,14 @@ def make_tables() -> None:

command.append(str(get_settings().src_postgres_dsn))

completed_process = run(command, capture_output=True, encoding="utf-8", check=True)
try:
completed_process = run(
command, capture_output=True, encoding="utf-8", check=True
)
except CalledProcessError as e:
print(e.stderr, file=stderr)
sys.exit(e.returncode)

print(completed_process.stdout)


Expand Down
10 changes: 4 additions & 6 deletions sqlsynthgen/create_generators.py → sqlsynthgen/make.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Functions to create a module of generator classes."""
import importlib
"""Functions to make a module of generator classes."""
from types import ModuleType
from typing import Final

from sqlalchemy.sql import sqltypes
Expand All @@ -21,12 +21,11 @@
INDENTATION: Final[str] = " " * 4


def create_generators_from_tables(tables_module_name: str) -> str:
def make_generators_from_tables(tables_module: ModuleType) -> str:
"""Creates sqlsynthgen generator classes from a sqlacodegen-generated file.

Args:
tables_module_name: The name of a sqlacodegen-generated module
as you would provide to importlib.import_module.
tables_module: A sqlacodegen-generated module.

Returns:
A string that is a valid Python module, once written to file.
Expand All @@ -47,7 +46,6 @@ def create_generators_from_tables(tables_module_name: str) -> str:
sqltypes.LargeBinary: "generic.binary_provider.bytes()",
}

tables_module = importlib.import_module(tables_module_name)
for table in tables_module.metadata.sorted_tables:
new_class_name = table.name + "Generator"
sorted_generators += INDENTATION + new_class_name + ",\n"
Expand Down
4 changes: 2 additions & 2 deletions sqlsynthgen/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Settings(BaseSettings):
src_port: int = 5432
src_user_name: str # e.g. "postgres" or "myuser@mydb"
src_password: str
src_db_name: str = "" # leave empty to get the user's default db
src_db_name: str
src_ssl_required: bool = False # whether the db requires SSL
src_schema: Optional[str]

Expand All @@ -25,7 +25,7 @@ class Settings(BaseSettings):
dst_port: int = 5432
dst_user_name: str # e.g. "postgres" or "myuser@mydb"
dst_password: str
dst_db_name: str = "" # leave empty to get the user's default db
dst_db_name: str
dst_ssl_required: bool = False # whether the db requires SSL

# These are calculated so do not provide them explicitly
Expand Down
8 changes: 4 additions & 4 deletions tests/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
from unittest import TestCase
from unittest.mock import MagicMock, patch

from sqlsynthgen.create import create_db_tables, generate
from sqlsynthgen.create import create_db_data, create_db_tables
from tests.utils import get_test_settings


class MyTestCase(TestCase):
"""Module test case."""

def test_generate(self) -> None:
def test_create_db_data(self) -> None:
"""Test the generate function."""
with patch("sqlsynthgen.create.populate") as mock_populate, patch(
"sqlsynthgen.create.get_settings"
Expand All @@ -18,12 +18,12 @@ def test_generate(self) -> None:
) as mock_create_engine:
mock_get_settings.return_value = get_test_settings()

generate([], [])
create_db_data([], [])

mock_populate.assert_called_once()
mock_create_engine.assert_called_once()

def test_create_tables(self) -> None:
def test_create_db_tables(self) -> None:
"""Test the create_tables function."""
mock_meta = MagicMock()

Expand Down
76 changes: 55 additions & 21 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Tests for the main module."""
from subprocess import CalledProcessError
from unittest import TestCase
from unittest.mock import call, patch

from click.testing import Result
from typer.testing import CliRunner

from sqlsynthgen.main import app
from tests.examples import example_tables, expected_output
from tests.utils import get_test_settings

runner = CliRunner()
Expand Down Expand Up @@ -90,38 +92,70 @@ def test_make_tables_with_schema(self) -> None:
)
self.assertNotEqual("", result.stdout)

def test_make_tables_handles_errors(self) -> None:
"""Test the make-tables sub-command handles sqlacodegen errors."""

with patch("sqlsynthgen.main.run") as mock_run, patch(
"sqlsynthgen.main.get_settings"
) as mock_get_settings, patch("sqlsynthgen.main.stderr") as mock_stderr:
mock_run.side_effect = CalledProcessError(
returncode=99, cmd="some-cmd", stderr="some-error-output"
)
mock_get_settings.return_value = get_test_settings()

result = runner.invoke(
app,
[
"make-tables",
],
catch_exceptions=False,
)

self.assertEqual(99, result.exit_code)
mock_stderr.assert_has_calls(
[call.write("some-error-output"), call.write("\n")]
)

def test_make_generators(self) -> None:
"""Test the make-generators sub-command."""
result = runner.invoke(
app,
[
"make-generators",
],
catch_exceptions=False,
)
with patch("sqlsynthgen.main.make_generators_from_tables") as mock_make:
result = runner.invoke(
app,
["make-generators", "tests/examples/example_tables.py"],
catch_exceptions=False,
)

self.assertSuccess(result)
mock_make.assert_called_once_with(example_tables)

def test_create_tables(self) -> None:
"""Test the create-tables sub-command."""
result = runner.invoke(
app,
[
"create-tables",
],
catch_exceptions=False,
)

with patch("sqlsynthgen.main.create_db_tables") as mock_create:
result = runner.invoke(
app,
["create-tables", "tests/examples/example_tables.py"],
catch_exceptions=False,
)

self.assertSuccess(result)
mock_create.assert_called_once_with(example_tables.metadata)

def test_create_data(self) -> None:
"""Test the create-data sub-command."""
result = runner.invoke(
app,
[
"create-data",
],
catch_exceptions=False,
)

with patch("sqlsynthgen.main.create_db_data") as mock_create_db_data:
result = runner.invoke(
app,
[
"create-data",
"tests/examples/example_tables.py",
"tests/examples/expected_output.py",
],
catch_exceptions=False,
)

self.assertSuccess(result)
mock_create_db_data.assert_called_once_with(
example_tables.metadata.sorted_tables, expected_output.sorted_generators
)
12 changes: 6 additions & 6 deletions tests/test_create_generators.py → tests/test_make.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
"""Tests for the main module."""
from unittest import TestCase

from sqlsynthgen import create_generators
from sqlsynthgen import make
from tests.examples import example_tables


class MyTestCase(TestCase):
"""Module test case."""

def test_generators_from_tables(self) -> None:
"""Check that we can create a generators file from a tables file."""
def test_make_generators_from_tables(self) -> None:
"""Check that we can make a generators file from a tables module."""

with open(
"tests/examples/expected_output.py", encoding="utf-8"
) as expected_output:
expected = expected_output.read()

actual = create_generators.create_generators_from_tables(
"tests.examples.example_tables"
)
actual = make.make_generators_from_tables(example_tables)
self.assertEqual(expected, actual)
8 changes: 6 additions & 2 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,24 @@ def test_default_settings(self) -> None:
src_host_name="shost",
src_user_name="suser",
src_password="spassword",
src_db_name="sdbname",
dst_host_name="dhost",
dst_user_name="duser",
dst_password="dpassword",
dst_db_name="ddbname",
# To stop any local .env files influencing the test
_env_file=None,
)

self.assertEqual(
"postgresql://suser:spassword@shost:5432/", str(settings.src_postgres_dsn)
"postgresql://suser:spassword@shost:5432/sdbname",
str(settings.src_postgres_dsn),
)
self.assertIsNone(settings.src_schema)

self.assertEqual(
"postgresql://duser:dpassword@dhost:5432/", str(settings.dst_postgres_dsn)
"postgresql://duser:dpassword@dhost:5432/ddbname",
str(settings.dst_postgres_dsn),
)

def test_maximal_settings(self) -> None:
Expand Down
2 changes: 2 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ def get_test_settings() -> settings.Settings:
src_host_name="shost",
src_user_name="suser",
src_password="spassword",
src_db_name="sdbname",
dst_host_name="dhost",
dst_user_name="duser",
dst_password="dpassword",
dst_db_name="ddbname",
# To stop any local .env files influencing the test
_env_file=None,
)