Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pydantic = {extras = ["dotenv"], version = "^1.10.2"}
psycopg2-binary = "^2.9.5"
sqlalchemy-utils = "^0.38.3"
mimesis = "^6.1.1"
typer = "^0.7.0"


[tool.poetry.group.dev.dependencies]
Expand All @@ -26,7 +27,7 @@ requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.poetry.scripts]
sqlsynthgen = "sqlsynthgen.main:main"
sqlsynthgen = "sqlsynthgen.main:app"

[tool.isort]
profile = "black"
30 changes: 30 additions & 0 deletions sqlsynthgen/create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Functions and classes to create and populate the target database."""
from typing import Any

from sqlalchemy import create_engine, insert

from sqlsynthgen.settings import get_settings


def create_db_tables(metadata: Any) -> Any:
"""Create tables described by the sqlalchemy metadata object."""
settings = get_settings()
engine = create_engine(settings.dst_postgres_dsn)
metadata.create_all(engine)


def generate(sorted_tables: list, sorted_generators: list) -> Any:
"""Connect to a database and populate it with data."""
settings = get_settings()
engine = create_engine(settings.dst_postgres_dsn)

with engine.connect() as conn:
populate(conn, sorted_tables, sorted_generators)


def populate(conn: Any, tables: list, generators: list) -> None:
"""Populate a database schema with dummy data."""

for table, generator in zip(tables, generators):
stmt = insert(table).values(generator(conn).__dict__)
conn.execute(stmt)
46 changes: 25 additions & 21 deletions sqlsynthgen/main.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,43 @@
"""Entrypoint for the sqlsynthgen package."""
from typing import Any
from subprocess import run

from sqlalchemy import create_engine, insert
import typer

from sqlsynthgen.settings import get_settings

app = typer.Typer()

def create_tables(metadata: Any) -> Any:
"""Create tables described by the sqlalchemy metadata object."""
settings = get_settings()
engine = create_engine(settings.postgres_dsn)
metadata.create_all(engine)

@app.command()
def create_data() -> None:
"""Fill tables with synthetic data."""


@app.command()
def create_tables() -> None:
"""Create tables using the SQLAlchemy file."""


def main() -> None:
"""Not implemented yet."""
raise NotImplementedError
@app.command()
def make_generators() -> None:
"""Make a SQLSynthGun file of generator classes."""


def generate(sorted_tables: list, sorted_generators: list) -> Any:
"""Connect to a database and populate it with data."""
@app.command()
def make_tables() -> None:
"""Make a SQLAlchemy file of Table classes."""
settings = get_settings()
engine = create_engine(settings.postgres_dsn)

with engine.connect() as conn:
populate(conn, sorted_tables, sorted_generators)
command = ["sqlacodegen"]

if settings.src_schema:
command.append(f"--schema={settings.src_schema}")

def populate(conn: Any, tables: list, generators: list) -> None:
"""Populate a database schema with dummy data."""
command.append(str(get_settings().src_postgres_dsn))

for table, generator in zip(tables, generators):
stmt = insert(table).values(generator(conn).__dict__)
conn.execute(stmt)
completed_process = run(command, capture_output=True, encoding="utf-8", check=True)
print(completed_process.stdout)


if __name__ == "__main__":
main()
app()
58 changes: 39 additions & 19 deletions sqlsynthgen/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,55 @@
class Settings(BaseSettings):
"""A Pydantic settings class with optional and mandatory settings."""

# Connection parameters for a PostgreSQL database. See also,
# Connection parameters for the source PostgreSQL database. See also
# https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS
db_host_name: str # e.g. "mydb.mydomain.com" or "0.0.0.0"
db_port: int = 5432
db_user_name: str # e.g. "postgres" or "myuser@mydb"
db_password: str
db_name: str = "" # leave empty to get the user's default db
ssl_required: bool = False # whether the db requires SSL

# postgres_dsn is calculated so do not provide it explicitly
postgres_dsn: Optional[PostgresDsn]

@validator("postgres_dsn", pre=True)
def validate_postgres_dsn(cls, _: Optional[PostgresDsn], values: Any) -> str:
src_host_name: str # e.g. "mydb.mydomain.com" or "0.0.0.0"
src_port: int = 5432
src_user_name: str # e.g. "postgres" or "myuser@mydb"
src_password: str
src_db_name: str = "" # leave empty to get the user's default db
src_ssl_required: bool = False # whether the db requires SSL
src_schema: Optional[str]

# Connection parameters for the destination PostgreSQL database.
dst_host_name: str # e.g. "mydb.mydomain.com" or "0.0.0.0"
dst_port: int = 5432
dst_user_name: str # e.g. "postgres" or "myuser@mydb"
dst_password: str
dst_db_name: str = "" # leave empty to get the user's default db
dst_ssl_required: bool = False # whether the db requires SSL

# These are calculated so do not provide them explicitly
src_postgres_dsn: Optional[PostgresDsn]
dst_postgres_dsn: Optional[PostgresDsn]

@validator("src_postgres_dsn", pre=True)
def validate_src_postgres_dsn(cls, _: Optional[PostgresDsn], values: Any) -> str:
"""Create and validate the source database DSN."""
return cls.check_postgres_dsn(_, values, "src")

@validator("dst_postgres_dsn", pre=True)
def validate_dst_postgres_dsn(cls, _: Optional[PostgresDsn], values: Any) -> str:
"""Create and validate the destination database DSN."""
return cls.check_postgres_dsn(_, values, "dst")

@staticmethod
def check_postgres_dsn(_: Optional[PostgresDsn], values: Any, prefix: str) -> str:
"""Build a DSN string from the host, db name, port, username and password."""

# We want to build the Data Source Name ourselves so none should be provided
if _:
raise ValueError("postgres_dsn should not be provided")

user = values["db_user_name"]
password = values["db_password"]
host = values["db_host_name"]
port = values["db_port"]
db_name = values["db_name"]
user = values[f"{prefix}_user_name"]
password = values[f"{prefix}_password"]
host = values[f"{prefix}_host_name"]
port = values[f"{prefix}_port"]
db_name = values[f"{prefix}_db_name"]

dsn = f"postgresql://{user}:{password}@{host}:{port}/{db_name}"

if values["ssl_required"]:
if values[f"{prefix}_ssl_required"]:
return dsn + "?sslmode=require"

return dsn
Expand Down
Empty file added tests/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions tests/test_create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Tests for the main module."""
from unittest import TestCase
from unittest.mock import MagicMock, patch

from sqlsynthgen.create import create_db_tables, generate
from tests.utils import get_test_settings


class MyTestCase(TestCase):
"""Module test case."""

def test_generate(self) -> None:
"""Test the generate function."""
with patch("sqlsynthgen.create.populate") as mock_populate, patch(
"sqlsynthgen.create.get_settings"
) as mock_get_settings, patch(
"sqlsynthgen.create.create_engine"
) as mock_create_engine:
mock_get_settings.return_value = get_test_settings()

generate([], [])

mock_populate.assert_called_once()
mock_create_engine.assert_called_once()

def test_create_tables(self) -> None:
"""Test the create_tables function."""
mock_meta = MagicMock()

with patch("sqlsynthgen.create.create_engine") as mock_create_engine, patch(
"sqlsynthgen.create.get_settings"
) as mock_get_settings:

create_db_tables(mock_meta)
mock_get_settings.assert_called_once()
mock_create_engine.assert_called_once_with(
mock_get_settings.return_value.dst_postgres_dsn
)
Loading