diff --git a/sqlsynthgen/create.py b/sqlsynthgen/create.py index 6e7d9c2c..02bd2243 100644 --- a/sqlsynthgen/create.py +++ b/sqlsynthgen/create.py @@ -26,18 +26,23 @@ def create_db_tables(metadata: Any) -> Any: def create_db_data(sorted_tables: list, sorted_generators: list, num_rows: int) -> None: """Connect to a database and populate it with data.""" settings = get_settings() - engine = create_engine(settings.dst_postgres_dsn) + dst_engine = create_engine(settings.dst_postgres_dsn) + src_engine = create_engine(settings.src_postgres_dsn) - with engine.connect() as conn: - populate(conn, sorted_tables, sorted_generators, num_rows) + with dst_engine.connect() as dst_conn: + with src_engine.connect() as src_conn: + populate(src_conn, dst_conn, sorted_tables, sorted_generators, num_rows) -def populate(conn: Any, tables: list, generators: list, num_rows: int) -> None: +def populate( + src_conn: Any, dst_conn: Any, tables: list, generators: list, num_rows: int +) -> None: """Populate a database schema with dummy data.""" - for table, generator in zip(tables, generators): - # Run all the inserts for one table in a transaction - with conn.begin(): + for table, generator in zip( + tables, generators + ): # Run all the inserts for one table in a transaction + with dst_conn.begin(): for _ in range(num_rows): - stmt = insert(table).values(generator(conn).__dict__) - conn.execute(stmt) + stmt = insert(table).values(generator(src_conn, dst_conn).__dict__) + dst_conn.execute(stmt) diff --git a/sqlsynthgen/make.py b/sqlsynthgen/make.py index d3df0a2c..638b1b1d 100644 --- a/sqlsynthgen/make.py +++ b/sqlsynthgen/make.py @@ -9,11 +9,11 @@ '"""This file was auto-generated by sqlsynthgen but can be edited manually."""', "from mimesis import Generic", "from mimesis.locales import Locale", - "from sqlsynthgen.providers import BinaryProvider, ForeignKeyProvider", + "from sqlsynthgen.providers import BytesProvider, ColumnValueProvider", "", "generic = Generic(locale=Locale.EN)", - "generic.add_provider(ForeignKeyProvider)", - "generic.add_provider(BinaryProvider)", + "generic.add_provider(ColumnValueProvider)", + "generic.add_provider(BytesProvider)", "", ) ) @@ -38,12 +38,14 @@ def make_generators_from_tables(tables_module: ModuleType) -> str: sql_to_mimesis_map = { sqltypes.BigInteger: "generic.numeric.integer_number()", sqltypes.Boolean: "generic.development.boolean()", - sqltypes.DateTime: "generic.datetime.datetime()", sqltypes.Date: "generic.datetime.date()", + sqltypes.DateTime: "generic.datetime.datetime()", + sqltypes.Float: "generic.numeric.float_number()", sqltypes.Integer: "generic.numeric.integer_number()", + sqltypes.LargeBinary: "generic.bytes_provider.bytes()", + sqltypes.Numeric: "generic.numeric.float_number()", + sqltypes.String: "generic.text.color()", sqltypes.Text: "generic.text.color()", - sqltypes.Float: "generic.numeric.float_number()", - sqltypes.LargeBinary: "generic.binary_provider.bytes()", } for table in tables_module.Base.metadata.sorted_tables: @@ -54,15 +56,15 @@ def make_generators_from_tables(tables_module: ModuleType) -> str: + new_class_name + ":\n" + INDENTATION - + "def __init__(self, db_connection):\n" + + "def __init__(self, src_db_conn, dst_db_conn):\n" ) for column in table.columns: # We presume that primary keys are populated automatically if column.primary_key: - continue + new_content += f"{INDENTATION*2}pass\n" - if column.foreign_keys: + elif column.foreign_keys: if len(column.foreign_keys) > 1: raise NotImplementedError("Can't handle multiple foreign keys.") fkey = column.foreign_keys.pop() @@ -70,12 +72,12 @@ def make_generators_from_tables(tables_module: ModuleType) -> str: fk_schema, fk_table, fk_column = fk_column_path.split(".") new_content += ( f"{INDENTATION*2}self.{column.name} = " - f"generic.foreign_key_provider.key(db_connection, " + f"generic.column_value_provider.column_value(dst_db_conn, " f'"{fk_schema}", "{fk_table}", "{fk_column}"' ")\n" ) - else: + else: new_content += ( INDENTATION * 2 + "self." diff --git a/sqlsynthgen/providers.py b/sqlsynthgen/providers.py index b8e5e31a..cadf52c4 100644 --- a/sqlsynthgen/providers.py +++ b/sqlsynthgen/providers.py @@ -3,35 +3,33 @@ from mimesis import Text from mimesis.providers.base import BaseDataProvider, BaseProvider - -# from mimesis.locales import Locale from sqlalchemy.sql import text -# generic = Generic(locale=Locale.EN) - -class ForeignKeyProvider(BaseProvider): - """A Mimesis provider of foreign keys.""" +class ColumnValueProvider(BaseProvider): + """A Mimesis provider of random values from the source database.""" class Meta: - """Meta-class for ForeignKeyProvider settings.""" + """Meta-class for ColumnValueProvider settings.""" - name = "foreign_key_provider" + name = "column_value_provider" - def key(self, db_connection: Any, schema: str, table: str, column: str) -> Any: - """Return a random value from the table and column specified.""" + def column_value( + self, db_connection: Any, schema: str, table: str, column: str + ) -> Any: + """Return a random value from the column specified.""" query_str = f"SELECT {column} FROM {schema}.{table} ORDER BY random() LIMIT 1" key = db_connection.execute(text(query_str)).fetchone()[0] return key -class BinaryProvider(BaseDataProvider): +class BytesProvider(BaseDataProvider): """A Mimesis provider of binary data.""" class Meta: - """Meta-class for ForeignKeyProvider settings.""" + """Meta-class for BytesProvider settings.""" - name = "binary_provider" + name = "bytes_provider" def bytes(self) -> bytes: """Return a UTF-8 encoded sentence.""" diff --git a/tests/examples/example_orm.py b/tests/examples/example_orm.py index 950ba69d..e37ea654 100644 --- a/tests/examples/example_orm.py +++ b/tests/examples/example_orm.py @@ -35,3 +35,15 @@ class HopsitalVisit(Base): visit_end = Column(Date) visit_duration_seconds = Column(Float) visit_image = Column(LargeBinary) + + +class Entity(Base): + __tablename__ = "entity" + __table_args__ = {"schema": "myschema"} + + # NB Do not add any more columns to this table as + # we use it to test what happens in the one-column case + entity_id = Column( + Integer, + primary_key=True, + ) diff --git a/tests/examples/expected_ssg.py b/tests/examples/expected_ssg.py index ca4a00fb..ce14784f 100644 --- a/tests/examples/expected_ssg.py +++ b/tests/examples/expected_ssg.py @@ -1,15 +1,21 @@ """This file was auto-generated by sqlsynthgen but can be edited manually.""" from mimesis import Generic from mimesis.locales import Locale -from sqlsynthgen.providers import BinaryProvider, ForeignKeyProvider +from sqlsynthgen.providers import BytesProvider, ColumnValueProvider generic = Generic(locale=Locale.EN) -generic.add_provider(ForeignKeyProvider) -generic.add_provider(BinaryProvider) +generic.add_provider(ColumnValueProvider) +generic.add_provider(BytesProvider) + + +class entityGenerator: + def __init__(self, src_db_conn, dst_db_conn): + pass class personGenerator: - def __init__(self, db_connection): + def __init__(self, src_db_conn, dst_db_conn): + pass self.name = generic.text.color() self.nhs_number = generic.text.color() self.research_opt_out = generic.development.boolean() @@ -18,15 +24,17 @@ def __init__(self, db_connection): class hospital_visitGenerator: - def __init__(self, db_connection): - self.person_id = generic.foreign_key_provider.key(db_connection, "myschema", "person", "person_id") + def __init__(self, src_db_conn, dst_db_conn): + pass + self.person_id = generic.column_value_provider.column_value(dst_db_conn, "myschema", "person", "person_id") self.visit_start = generic.datetime.datetime() self.visit_end = generic.datetime.date() self.visit_duration_seconds = generic.numeric.float_number() - self.visit_image = generic.binary_provider.bytes() + self.visit_image = generic.bytes_provider.bytes() sorted_generators = [ + entityGenerator, personGenerator, hospital_visitGenerator, ] diff --git a/tests/examples/providers.dump b/tests/examples/providers.dump new file mode 100644 index 00000000..525cc832 --- /dev/null +++ b/tests/examples/providers.dump @@ -0,0 +1,52 @@ +-- +-- PostgreSQL database dump +-- + +-- Dumped from database version 14.2 (Debian 14.2-1.pgdg110+1) +-- Dumped by pg_dump version 14.6 (Homebrew) + +SET statement_timeout = 0; +SET lock_timeout = 0; +SET idle_in_transaction_session_timeout = 0; +SET client_encoding = 'UTF8'; +SET standard_conforming_strings = on; +SELECT pg_catalog.set_config('search_path', '', false); +SET check_function_bodies = false; +SET xmloption = content; +SET client_min_messages = warning; +SET row_security = off; + +DROP DATABASE IF EXISTS providers; +-- +-- Name: providers; Type: DATABASE; Schema: -; Owner: postgres +-- + +CREATE DATABASE providers WITH TEMPLATE = template0 ENCODING = 'UTF8' LOCALE = 'en_US.utf8'; + + +ALTER DATABASE providers OWNER TO postgres; + +\connect providers + +SET statement_timeout = 0; +SET lock_timeout = 0; +SET idle_in_transaction_session_timeout = 0; +SET client_encoding = 'UTF8'; +SET standard_conforming_strings = on; +SELECT pg_catalog.set_config('search_path', '', false); +SET check_function_bodies = false; +SET xmloption = content; +SET client_min_messages = warning; +SET row_security = off; + +SET default_tablespace = ''; + +SET default_table_access_method = heap; + +-- +-- Name: patient; Type: TABLE; Schema: public; Owner: postgres +-- + +CREATE TABLE public.patient ( + sex text NOT NULL +); diff --git a/tests/test_create.py b/tests/test_create.py index de1b163a..3041b8ba 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -2,7 +2,7 @@ from unittest import TestCase from unittest.mock import MagicMock, patch -from sqlsynthgen.create import create_db_data, create_db_tables +from sqlsynthgen.create import create_db_data, create_db_tables, populate from tests.utils import get_test_settings @@ -21,7 +21,7 @@ def test_create_db_data(self) -> None: create_db_data([], [], 0) mock_populate.assert_called_once() - mock_create_engine.assert_called_once() + mock_create_engine.assert_called() def test_create_db_tables(self) -> None: """Test the create_tables function.""" @@ -36,3 +36,21 @@ def test_create_db_tables(self) -> None: mock_create_engine.assert_called_once_with( mock_get_settings.return_value.dst_postgres_dsn ) + + def test_populate(self) -> None: + """Test the populate function.""" + with patch("sqlsynthgen.create.insert") as mock_insert: + mock_src_conn = MagicMock() + mock_dst_conn = MagicMock() + mock_gen = MagicMock() + tables = [None] + generators = [mock_gen] + populate(mock_src_conn, mock_dst_conn, tables, generators, 1) + + mock_gen.assert_called_once_with(mock_src_conn, mock_dst_conn) + mock_insert.return_value.values.assert_called_once_with( + mock_gen.return_value.__dict__ + ) + mock_dst_conn.execute.assert_called_once_with( + mock_insert.return_value.values.return_value + ) diff --git a/tests/test_functional.py b/tests/test_functional.py index e3dfeaba..2f4a83c7 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,9 +1,11 @@ -"""Tests for the main module.""" +"""Tests for the CLI.""" import os from pathlib import Path from subprocess import run from unittest import TestCase, skipUnless +from tests.utils import run_psql + @skipUnless( os.environ.get("FUNCTIONAL_TESTS") == "1", "Set 'FUNCTIONAL_TESTS=1' to enable." @@ -19,27 +21,7 @@ def setUp(self) -> None: self.orm_file_path.unlink(missing_ok=True) self.ssg_file_path.unlink(missing_ok=True) - # If you need to update src.dump or dst.dump, use - # pg_dump -d src|dst -h localhost -U postgres -C -c > tests/examples/src|dst.dump - - env = os.environ.copy() - env = {**env, "PGPASSWORD": "password"} - - # Clear and re-create the destination database - completed_process = run( - [ - "psql", - "--host=localhost", - "--username=postgres", - "--file=" + str(Path("tests/examples/dst.dump")), - ], - capture_output=True, - env=env, - check=True, - ) - - # psql doesn't always return != 0 if it fails - assert completed_process.stderr == b"", completed_process.stderr + run_psql("dst.dump") def test_workflow(self) -> None: """Test the recommended CLI workflow runs without errors.""" diff --git a/tests/test_providers.py b/tests/test_providers.py new file mode 100644 index 00000000..fe79fbb3 --- /dev/null +++ b/tests/test_providers.py @@ -0,0 +1,64 @@ +"""Tests for the providers module.""" +import os +from unittest import TestCase, skipUnless + +from sqlalchemy import Column, Integer, Text, create_engine, insert +from sqlalchemy.ext.declarative import declarative_base + +from sqlsynthgen.providers import BytesProvider, ColumnValueProvider +from tests.utils import run_psql + +# pylint: disable=invalid-name +Base = declarative_base() +# pylint: enable=invalid-name +metadata = Base.metadata + + +class Person(Base): # type: ignore + """A SQLAlchemy table.""" + + __tablename__ = "person" + person_id = Column( + Integer, + primary_key=True, + ) + # We don't actually need a foreign key constraint to test this + sex = Column(Text) + + +class BinaryProviderTestCase(TestCase): + """Tests for the BytesProvider class.""" + + def test_bytes(self) -> None: + """Test the bytes method.""" + self.assertTrue(BytesProvider().bytes().decode("utf-8") != "") + + +@skipUnless( + os.environ.get("FUNCTIONAL_TESTS") == "1", "Set 'FUNCTIONAL_TESTS=1' to enable." +) +class ColumnValueProviderTestCase(TestCase): + """Tests for the ColumnValueProvider class.""" + + def setUp(self) -> None: + """Pre-test setup.""" + + run_psql("providers.dump") + + self.engine = create_engine( + "postgresql://postgres:password@localhost:5432/providers" + ) + metadata.create_all(self.engine) + + def test_column_value(self) -> None: + """Test the key method.""" + # pylint: disable=invalid-name + + with self.engine.connect() as conn: + stmt = insert(Person).values(sex="M") + conn.execute(stmt) + + provider = ColumnValueProvider() + key = provider.column_value(conn, "public", "person", "sex") + + self.assertEqual("M", key) diff --git a/tests/utils.py b/tests/utils.py index 8635d78b..49624879 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,8 @@ """Utilities for testing.""" +import os from functools import lru_cache +from pathlib import Path +from subprocess import run from sqlsynthgen import settings @@ -20,3 +23,28 @@ def get_test_settings() -> settings.Settings: # To stop any local .env files influencing the test _env_file=None, ) + + +def run_psql(dump_file_name: str) -> None: + """Run psql and""" + + # If you need to update a .dump file, use + # pg_dump -d DBNAME -h localhost -U postgres -C -c > tests/examples/FILENAME.dump + + env = os.environ.copy() + env = {**env, "PGPASSWORD": "password"} + + # Clear and re-create the test database + completed_process = run( + [ + "psql", + "--host=localhost", + "--username=postgres", + "--file=" + str(Path(f"tests/examples/{dump_file_name}")), + ], + capture_output=True, + env=env, + check=True, + ) + # psql doesn't always return != 0 if it fails + assert completed_process.stderr == b"", completed_process.stderr