diff --git a/poetry.lock b/poetry.lock index 6a88b5a7..5e9fe8eb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,3 +1,5 @@ +# This file is automatically @generated by Poetry and should not be changed by hand. + [[package]] name = "astroid" version = "2.12.13" @@ -557,6 +559,56 @@ files = [ [package.extras] cli = ["click (>=5.0)"] +[[package]] +name = "pyyaml" +version = "6.0" +description = "YAML parser and emitter for Python" +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"}, + {file = "PyYAML-6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c"}, + {file = "PyYAML-6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77f396e6ef4c73fdc33a9157446466f1cff553d979bd00ecb64385760c6babdc"}, + {file = "PyYAML-6.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a80a78046a72361de73f8f395f1f1e49f956c6be882eed58505a15f3e430962b"}, + {file = "PyYAML-6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f84fbc98b019fef2ee9a1cb3ce93e3187a6df0b2538a651bfb890254ba9f90b5"}, + {file = "PyYAML-6.0-cp310-cp310-win32.whl", hash = "sha256:2cd5df3de48857ed0544b34e2d40e9fac445930039f3cfe4bcc592a1f836d513"}, + {file = "PyYAML-6.0-cp310-cp310-win_amd64.whl", hash = "sha256:daf496c58a8c52083df09b80c860005194014c3698698d1a57cbcfa182142a3a"}, + {file = "PyYAML-6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d4b0ba9512519522b118090257be113b9468d804b19d63c71dbcf4a48fa32358"}, + {file = "PyYAML-6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:81957921f441d50af23654aa6c5e5eaf9b06aba7f0a19c18a538dc7ef291c5a1"}, + {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afa17f5bc4d1b10afd4466fd3a44dc0e245382deca5b3c353d8b757f9e3ecb8d"}, + {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbad0e9d368bb989f4515da330b88a057617d16b6a8245084f1b05400f24609f"}, + {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:432557aa2c09802be39460360ddffd48156e30721f5e8d917f01d31694216782"}, + {file = "PyYAML-6.0-cp311-cp311-win32.whl", hash = "sha256:bfaef573a63ba8923503d27530362590ff4f576c626d86a9fed95822a8255fd7"}, + {file = "PyYAML-6.0-cp311-cp311-win_amd64.whl", hash = "sha256:01b45c0191e6d66c470b6cf1b9531a771a83c1c4208272ead47a3ae4f2f603bf"}, + {file = "PyYAML-6.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:897b80890765f037df3403d22bab41627ca8811ae55e9a722fd0392850ec4d86"}, + {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50602afada6d6cbfad699b0c7bb50d5ccffa7e46a3d738092afddc1f9758427f"}, + {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:48c346915c114f5fdb3ead70312bd042a953a8ce5c7106d5bfb1a5254e47da92"}, + {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:98c4d36e99714e55cfbaaee6dd5badbc9a1ec339ebfc3b1f52e293aee6bb71a4"}, + {file = "PyYAML-6.0-cp36-cp36m-win32.whl", hash = "sha256:0283c35a6a9fbf047493e3a0ce8d79ef5030852c51e9d911a27badfde0605293"}, + {file = "PyYAML-6.0-cp36-cp36m-win_amd64.whl", hash = "sha256:07751360502caac1c067a8132d150cf3d61339af5691fe9e87803040dbc5db57"}, + {file = "PyYAML-6.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:819b3830a1543db06c4d4b865e70ded25be52a2e0631ccd2f6a47a2822f2fd7c"}, + {file = "PyYAML-6.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:473f9edb243cb1935ab5a084eb238d842fb8f404ed2193a915d1784b5a6b5fc0"}, + {file = "PyYAML-6.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ce82d761c532fe4ec3f87fc45688bdd3a4c1dc5e0b4a19814b9009a29baefd4"}, + {file = "PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:231710d57adfd809ef5d34183b8ed1eeae3f76459c18fb4a0b373ad56bedcdd9"}, + {file = "PyYAML-6.0-cp37-cp37m-win32.whl", hash = "sha256:c5687b8d43cf58545ade1fe3e055f70eac7a5a1a0bf42824308d868289a95737"}, + {file = "PyYAML-6.0-cp37-cp37m-win_amd64.whl", hash = "sha256:d15a181d1ecd0d4270dc32edb46f7cb7733c7c508857278d3d378d14d606db2d"}, + {file = "PyYAML-6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0b4624f379dab24d3725ffde76559cff63d9ec94e1736b556dacdfebe5ab6d4b"}, + {file = "PyYAML-6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:213c60cd50106436cc818accf5baa1aba61c0189ff610f64f4a3e8c6726218ba"}, + {file = "PyYAML-6.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9fa600030013c4de8165339db93d182b9431076eb98eb40ee068700c9c813e34"}, + {file = "PyYAML-6.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:277a0ef2981ca40581a47093e9e2d13b3f1fbbeffae064c1d21bfceba2030287"}, + {file = "PyYAML-6.0-cp38-cp38-win32.whl", hash = "sha256:d4eccecf9adf6fbcc6861a38015c2a64f38b9d94838ac1810a9023a0609e1b78"}, + {file = "PyYAML-6.0-cp38-cp38-win_amd64.whl", hash = "sha256:1e4747bc279b4f613a09eb64bba2ba602d8a6664c6ce6396a4d0cd413a50ce07"}, + {file = "PyYAML-6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:055d937d65826939cb044fc8c9b08889e8c743fdc6a32b33e2390f66013e449b"}, + {file = "PyYAML-6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e61ceaab6f49fb8bdfaa0f92c4b57bcfbea54c09277b1b4f7ac376bfb7a7c174"}, + {file = "PyYAML-6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d67d839ede4ed1b28a4e8909735fc992a923cdb84e618544973d7dfc71540803"}, + {file = "PyYAML-6.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cba8c411ef271aa037d7357a2bc8f9ee8b58b9965831d9e51baf703280dc73d3"}, + {file = "PyYAML-6.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:40527857252b61eacd1d9af500c3337ba8deb8fc298940291486c465c8b46ec0"}, + {file = "PyYAML-6.0-cp39-cp39-win32.whl", hash = "sha256:b5b9eccad747aabaaffbc6064800670f0c297e52c12754eb1d976c57e4f74dcb"}, + {file = "PyYAML-6.0-cp39-cp39-win_amd64.whl", hash = "sha256:b3d267842bf12586ba6c734f89d1f5b871df0273157918b0ccefa29deb05c21c"}, + {file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"}, +] + [[package]] name = "sqlacodegen" version = "3.0.0rc1" @@ -731,6 +783,18 @@ dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2 doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"] test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] +[[package]] +name = "types-pyyaml" +version = "6.0.12.4" +description = "Typing stubs for PyYAML" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "types-PyYAML-6.0.12.4.tar.gz", hash = "sha256:ade6e328a5a3df816c47c912c2e1e946ae2bace90744aa73111ee6834b03a314"}, + {file = "types_PyYAML-6.0.12.4-py3-none-any.whl", hash = "sha256:de3bacfc4e0772d9b1baf007c37354f3c34c8952e90307d5155b6de0fc183a67"}, +] + [[package]] name = "typing-extensions" version = "4.4.0" @@ -836,4 +900,4 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools" [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "bbd38d0c749f03d0b5a8290b5e915d209e9e9520a624872fd00fbc2c75f2650a" +content-hash = "9902165958da3e7071182cabfe601d608c371088a6f9849f2d6777745884c9f6" diff --git a/pyproject.toml b/pyproject.toml index a10efc56..230beb84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ psycopg2-binary = "^2.9.5" sqlalchemy-utils = "^0.38.3" mimesis = "^6.1.1" typer = "^0.7.0" +pyyaml = "^6.0" [tool.poetry.group.dev.dependencies] @@ -21,6 +22,7 @@ black = "^22.10.0" isort = "^5.10.1" pylint = "^2.15.8" mypy = "^0.991" +types-pyyaml = "^6.0.12.4" [build-system] requires = ["poetry-core"] diff --git a/sqlsynthgen/main.py b/sqlsynthgen/main.py index d1044aff..829986d7 100644 --- a/sqlsynthgen/main.py +++ b/sqlsynthgen/main.py @@ -5,8 +5,10 @@ from subprocess import CalledProcessError, run from sys import stderr from types import ModuleType +from typing import Any, Optional import typer +import yaml from sqlsynthgen.create import create_db_data, create_db_tables from sqlsynthgen.make import make_generators_from_tables @@ -22,6 +24,13 @@ def import_file(file_path: str) -> ModuleType: return import_module(module_path) +def read_yaml_file(path: str) -> Any: + """Read a yaml file in to dictionary, given a path.""" + with open(path, "r", encoding="utf8") as f: + config = yaml.safe_load(f) + return config + + @app.command() def create_data( orm_file: str = typer.Argument(...), @@ -44,10 +53,14 @@ def create_tables(orm_file: str = typer.Argument(...)) -> None: @app.command() -def make_generators(orm_file: str = typer.Argument(...)) -> None: +def make_generators( + orm_file: str = typer.Argument(...), + config_file: Optional[str] = typer.Argument(None), +) -> None: """Make a SQLSynthGen file of generator classes.""" orm_module = import_file(orm_file) - result = make_generators_from_tables(orm_module) + generator_config = read_yaml_file(config_file) if config_file is not None else {} + result = make_generators_from_tables(orm_module, generator_config) print(result) diff --git a/sqlsynthgen/make.py b/sqlsynthgen/make.py index 638b1b1d..2237a03e 100644 --- a/sqlsynthgen/make.py +++ b/sqlsynthgen/make.py @@ -1,27 +1,129 @@ """Functions to make a module of generator classes.""" +import inspect from types import ModuleType -from typing import Final +from typing import Any, Final +from mimesis.providers.base import BaseProvider from sqlalchemy.sql import sqltypes -HEADER_TEXT: Final[str] = "\n".join( +from sqlsynthgen import providers + +HEADER_TEXT: str = "\n".join( ( '"""This file was auto-generated by sqlsynthgen but can be edited manually."""', "from mimesis import Generic", "from mimesis.locales import Locale", - "from sqlsynthgen.providers import BytesProvider, ColumnValueProvider", "", "generic = Generic(locale=Locale.EN)", - "generic.add_provider(ColumnValueProvider)", - "generic.add_provider(BytesProvider)", "", ) ) +for entry_name, entry in inspect.getmembers(providers, inspect.isclass): + if issubclass(entry, BaseProvider) and entry.__module__ == "sqlsynthgen.providers": + HEADER_TEXT += f"\nfrom sqlsynthgen.providers import {entry_name}" + HEADER_TEXT += f"\ngeneric.add_provider({entry_name})" +HEADER_TEXT += "\n" INDENTATION: Final[str] = " " * 4 +SQL_TO_MIMESIS_MAP = { + sqltypes.BigInteger: "generic.numeric.integer_number()", + sqltypes.Boolean: "generic.development.boolean()", + sqltypes.Date: "generic.datetime.date()", + sqltypes.DateTime: "generic.datetime.datetime()", + sqltypes.Float: "generic.numeric.float_number()", + sqltypes.Integer: "generic.numeric.integer_number()", + sqltypes.LargeBinary: "generic.bytes_provider.bytes()", + sqltypes.Numeric: "generic.numeric.float_number()", + sqltypes.String: "generic.text.color()", + sqltypes.Text: "generic.text.color()", +} + + +def _add_custom_generators(content: str, table_config: dict) -> tuple[str, list[str]]: + """Add to the generators file, written in the string `content`, the custom + generators for the given table. + """ + generators_config = table_config.get("custom_generators", {}) + columns_covered = [] + for gen_conf in generators_config: + name = gen_conf["name"] + columns_assigned = gen_conf["columns_assigned"] + args = gen_conf["args"] + if isinstance(columns_assigned, str): + columns_assigned = [columns_assigned] + + content += INDENTATION * 2 + content += ", ".join(map(lambda x: f"self.{x}", columns_assigned)) + try: + columns_covered += columns_assigned + except TypeError: + # Might be a single string, rather than a list of strings. + columns_covered.append(columns_assigned) + content += f" = {name}(" + if args is not None: + content += ", ".join(f"{key}={value}" for key, value in args.items()) + content += ")\n" + return content, columns_covered + + +def _add_default_generator(content: str, column: Any) -> str: + """Add to the generator file `content` a default generator for the given column, + determined by the column's type. + """ + content += INDENTATION * 2 + # If it's a primary key column, we presume that primary keys are populated + # automatically. + if column.primary_key: + content += "pass" + # If it's a foreign key column, pull random values from the column it + # references. + elif column.foreign_keys: + if len(column.foreign_keys) > 1: + raise NotImplementedError( + "Can't handle multiple foreign keys for one column." + ) + fkey = column.foreign_keys.pop() + fk_schema, fk_table, fk_column = fkey.target_fullname.split(".") + content += ( + f"self.{column.name} = " + f"generic.column_value_provider.column_value(dst_db_conn, " + f'"{fk_schema}", "{fk_table}", "{fk_column}"' + ")" + ) -def make_generators_from_tables(tables_module: ModuleType) -> str: + # Otherwise generate values based on just the datatype of the column. + else: + provider = SQL_TO_MIMESIS_MAP[type(column.type)] + content += f"self.{column.name} = {provider}" + content += "\n" + return content + + +def _add_generator_for_table( + content: str, table_config: dict, table: Any +) -> tuple[str, str]: + """Add to the generator file `content` a generator for the given table.""" + new_class_name = table.name + "Generator" + if table_config.get("vocabulary_table", False): + raise NotImplementedError("Vocabulary tables currently unimplemented.") + + content += ( + f"\n\nclass {new_class_name}:\n" + f"{INDENTATION}def __init__(self, src_db_conn, dst_db_conn):\n" + ) + content, columns_covered = _add_custom_generators(content, table_config) + for column in table.columns: + if column.name in columns_covered: + # A generator for this column was already covered in the user config. + continue + content = _add_default_generator(content, column) + return content, new_class_name + + +def make_generators_from_tables( + tables_module: ModuleType, generator_config: dict +) -> str: """Creates sqlsynthgen generator classes from a sqlacodegen-generated file. Args: @@ -30,65 +132,18 @@ def make_generators_from_tables(tables_module: ModuleType) -> str: Returns: A string that is a valid Python module, once written to file. """ - new_content = HEADER_TEXT + generator_module_name = generator_config.get("custom_generators_module", None) + if generator_module_name is not None: + new_content += f"\nfrom . import {generator_module_name}" sorted_generators = "[\n" - - sql_to_mimesis_map = { - sqltypes.BigInteger: "generic.numeric.integer_number()", - sqltypes.Boolean: "generic.development.boolean()", - sqltypes.Date: "generic.datetime.date()", - sqltypes.DateTime: "generic.datetime.datetime()", - sqltypes.Float: "generic.numeric.float_number()", - sqltypes.Integer: "generic.numeric.integer_number()", - sqltypes.LargeBinary: "generic.bytes_provider.bytes()", - sqltypes.Numeric: "generic.numeric.float_number()", - sqltypes.String: "generic.text.color()", - sqltypes.Text: "generic.text.color()", - } - for table in tables_module.Base.metadata.sorted_tables: - new_class_name = table.name + "Generator" - sorted_generators += INDENTATION + new_class_name + ",\n" - new_content += ( - "\n\nclass " - + new_class_name - + ":\n" - + INDENTATION - + "def __init__(self, src_db_conn, dst_db_conn):\n" + table_config = generator_config.get("tables", {}).get(table.name, {}) + new_content, new_generator_name = _add_generator_for_table( + new_content, table_config, table ) - - for column in table.columns: - # We presume that primary keys are populated automatically - if column.primary_key: - new_content += f"{INDENTATION*2}pass\n" - - elif column.foreign_keys: - if len(column.foreign_keys) > 1: - raise NotImplementedError("Can't handle multiple foreign keys.") - fkey = column.foreign_keys.pop() - fk_column_path = fkey.target_fullname - fk_schema, fk_table, fk_column = fk_column_path.split(".") - new_content += ( - f"{INDENTATION*2}self.{column.name} = " - f"generic.column_value_provider.column_value(dst_db_conn, " - f'"{fk_schema}", "{fk_table}", "{fk_column}"' - ")\n" - ) - - else: - new_content += ( - INDENTATION * 2 - + "self." - + column.name - + " = " - + sql_to_mimesis_map[type(column.type)] - + "\n" - ) - + sorted_generators += f"{INDENTATION}{new_generator_name},\n" sorted_generators += "]" - new_content += "\n\n" + "sorted_generators = " + sorted_generators + "\n" - return new_content diff --git a/sqlsynthgen/providers.py b/sqlsynthgen/providers.py index cadf52c4..55ac3c37 100644 --- a/sqlsynthgen/providers.py +++ b/sqlsynthgen/providers.py @@ -1,7 +1,9 @@ """This module contains Mimesis Provider sub-classes.""" +import datetime as dt +import random from typing import Any -from mimesis import Text +from mimesis import Datetime, Text from mimesis.providers.base import BaseDataProvider, BaseProvider from sqlalchemy.sql import text @@ -34,3 +36,51 @@ class Meta: def bytes(self) -> bytes: """Return a UTF-8 encoded sentence.""" return Text(self.locale).sentence().encode("utf-8") + + +class TimedeltaProvider(BaseProvider): + """A Mimesis provider of timedeltas.""" + + class Meta: + """Meta-class for TimedeltaProvider settings.""" + + name = "timedelta_provider" + + def timedelta( + self, + min_dt: Any = dt.timedelta(seconds=0), + # ints bigger than this cause trouble + max_dt: Any = dt.timedelta(seconds=2**32), + ) -> dt.timedelta: + """Return a random timedelta object.""" + min_s = min_dt.total_seconds() + max_s = max_dt.total_seconds() + seconds = random.randint(min_s, max_s) + return dt.timedelta(seconds=seconds) + + +class TimespanProvider(BaseProvider): + """A Mimesis provider for timespans. + + A timespan consits of start datetime, end datetime, and the timedelta in between. + Returns a 3-tuple. + """ + + class Meta: + """Meta-class for TimespanProvider settings.""" + + name = "timespan_provider" + + def timespan( + self, + earliest_start_year: Any, + last_start_year: Any, + min_dt: Any = dt.timedelta(seconds=0), + # ints bigger than this cause trouble + max_dt: Any = dt.timedelta(seconds=2**32), + ) -> tuple[dt.datetime, dt.datetime, dt.timedelta]: + """Return a timespan as a 3-tuple of (start, end, delta).""" + delta = TimedeltaProvider().timedelta(min_dt, max_dt) + start = Datetime().datetime(start=earliest_start_year, end=last_start_year) + end = start + delta + return start, end, delta diff --git a/tests/examples/custom_generators.py b/tests/examples/custom_generators.py new file mode 100644 index 00000000..f906f9d1 --- /dev/null +++ b/tests/examples/custom_generators.py @@ -0,0 +1,16 @@ +import datetime as dt + + +def timespan_generator( + generic, + earliest_start_year, + last_start_year, + min_dt_days, + max_dt_days, +): + min_dt = dt.timedelta(days=min_dt_days) + max_dt = dt.timedelta(days=max_dt_days) + start, end, delta = generic.timespan_provider.timespan( + earliest_start_year, last_start_year, min_dt, max_dt + ) + return start, end, delta.total_seconds() diff --git a/tests/examples/expected_ssg.py b/tests/examples/expected_ssg.py index ce14784f..927a878f 100644 --- a/tests/examples/expected_ssg.py +++ b/tests/examples/expected_ssg.py @@ -1,12 +1,19 @@ """This file was auto-generated by sqlsynthgen but can be edited manually.""" from mimesis import Generic from mimesis.locales import Locale -from sqlsynthgen.providers import BytesProvider, ColumnValueProvider generic = Generic(locale=Locale.EN) -generic.add_provider(ColumnValueProvider) + +from sqlsynthgen.providers import BytesProvider generic.add_provider(BytesProvider) +from sqlsynthgen.providers import ColumnValueProvider +generic.add_provider(ColumnValueProvider) +from sqlsynthgen.providers import TimedeltaProvider +generic.add_provider(TimedeltaProvider) +from sqlsynthgen.providers import TimespanProvider +generic.add_provider(TimespanProvider) +from . import custom_generators class entityGenerator: def __init__(self, src_db_conn, dst_db_conn): @@ -15,21 +22,19 @@ def __init__(self, src_db_conn, dst_db_conn): class personGenerator: def __init__(self, src_db_conn, dst_db_conn): + self.name = generic.person.full_name() + self.stored_from = generic.datetime.datetime(start=2022, end=2022) pass - self.name = generic.text.color() self.nhs_number = generic.text.color() self.research_opt_out = generic.development.boolean() self.source_system = generic.text.color() - self.stored_from = generic.datetime.datetime() class hospital_visitGenerator: def __init__(self, src_db_conn, dst_db_conn): + self.visit_start, self.visit_end, self.visit_duration_seconds = custom_generators.timespan_generator(generic=generic, earliest_start_year=2021, last_start_year=2022, min_dt_days=1, max_dt_days=30) pass self.person_id = generic.column_value_provider.column_value(dst_db_conn, "myschema", "person", "person_id") - self.visit_start = generic.datetime.datetime() - self.visit_end = generic.datetime.date() - self.visit_duration_seconds = generic.numeric.float_number() self.visit_image = generic.bytes_provider.bytes() diff --git a/tests/examples/generator_conf.yaml b/tests/examples/generator_conf.yaml new file mode 100644 index 00000000..6566666d --- /dev/null +++ b/tests/examples/generator_conf.yaml @@ -0,0 +1,29 @@ +custom_generators_module: custom_generators +tables: + person: + num_rows_per_pass: 2 + vocabulary_table: false + custom_generators: + - name: generic.person.full_name + args: null + columns_assigned: name + - name: generic.datetime.datetime + args: + start: 2022 + end: 2022 + columns_assigned: stored_from + + hospital_visit: + num_rows_per_pass: 3 + custom_generators: + - name: custom_generators.timespan_generator + args: + generic: generic + earliest_start_year: 2021 + last_start_year: 2022 + min_dt_days: 1 + max_dt_days: 30 + columns_assigned: + - visit_start + - visit_end + - visit_duration_seconds diff --git a/tests/test_main.py b/tests/test_main.py index 04d6f3ae..4d75ec9c 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -3,6 +3,7 @@ from unittest import TestCase from unittest.mock import call, patch +import yaml from click.testing import Result from typer.testing import CliRunner @@ -119,14 +120,21 @@ def test_make_tables_handles_errors(self) -> None: def test_make_generators(self) -> None: """Test the make-generators sub-command.""" with patch("sqlsynthgen.main.make_generators_from_tables") as mock_make: + conf_path = "tests/examples/generator_conf.yaml" + with open(conf_path, "r", encoding="utf8") as f: + config = yaml.safe_load(f) result = runner.invoke( app, - ["make-generators", "tests/examples/example_orm.py"], + [ + "make-generators", + "tests/examples/example_orm.py", + conf_path, + ], catch_exceptions=False, ) self.assertSuccess(result) - mock_make.assert_called_once_with(example_orm) + mock_make.assert_called_once_with(example_orm, config) def test_create_tables(self) -> None: """Test the create-tables sub-command.""" diff --git a/tests/test_make.py b/tests/test_make.py index 7677e105..94e59a07 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -1,6 +1,8 @@ """Tests for the main module.""" from unittest import TestCase +import yaml + from sqlsynthgen import make from tests.examples import example_orm @@ -10,11 +12,14 @@ class MyTestCase(TestCase): def test_make_generators_from_tables(self) -> None: """Check that we can make a generators file from a tables module.""" - + self.maxDiff = None # pylint: disable=invalid-name with open( "tests/examples/expected_ssg.py", encoding="utf-8" ) as expected_output: expected = expected_output.read() + conf_path = "tests/examples/generator_conf.yaml" + with open(conf_path, "r", encoding="utf8") as f: + config = yaml.safe_load(f) - actual = make.make_generators_from_tables(example_orm) + actual = make.make_generators_from_tables(example_orm, config) self.assertEqual(expected, actual) diff --git a/tests/test_providers.py b/tests/test_providers.py index fe79fbb3..9c1c767f 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -1,11 +1,12 @@ """Tests for the providers module.""" +import datetime as dt import os from unittest import TestCase, skipUnless from sqlalchemy import Column, Integer, Text, create_engine, insert from sqlalchemy.ext.declarative import declarative_base -from sqlsynthgen.providers import BytesProvider, ColumnValueProvider +from sqlsynthgen import providers from tests.utils import run_psql # pylint: disable=invalid-name @@ -31,7 +32,7 @@ class BinaryProviderTestCase(TestCase): def test_bytes(self) -> None: """Test the bytes method.""" - self.assertTrue(BytesProvider().bytes().decode("utf-8") != "") + self.assertTrue(providers.BytesProvider().bytes().decode("utf-8") != "") @skipUnless( @@ -58,7 +59,39 @@ def test_column_value(self) -> None: stmt = insert(Person).values(sex="M") conn.execute(stmt) - provider = ColumnValueProvider() + provider = providers.ColumnValueProvider() key = provider.column_value(conn, "public", "person", "sex") self.assertEqual("M", key) + + +class TimedeltaProvider(TestCase): + """Tests for TimedeltaProvider""" + + def test_timedelta(self) -> None: + """Test the timedelta method.""" + min_dt = dt.timedelta(days=1) + max_dt = dt.timedelta(days=2) + delta = providers.TimedeltaProvider().timedelta(min_dt=min_dt, max_dt=max_dt) + assert isinstance(delta, dt.timedelta) + assert min_dt <= delta <= max_dt + + +class TimespanProvider(TestCase): + """Tests for TimespanProvider.""" + + def test_timespan(self) -> None: + """Test the timespan method""" + earliest_start_year = 1917 + last_start_year = 1923 + min_dt = dt.timedelta(seconds=2) + max_dt = dt.timedelta(days=10000) + start, end, delta = providers.TimespanProvider().timespan( + earliest_start_year, last_start_year, min_dt, max_dt + ) + assert isinstance(start, dt.datetime) + assert isinstance(end, dt.datetime) + assert isinstance(delta, dt.timedelta) + assert earliest_start_year <= start.year <= last_start_year + assert min_dt <= delta <= max_dt + assert end - start == delta