From a1c489e02ab91712d912f04ce7db4a0225bb5194 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 10 Feb 2023 15:42:28 +0000 Subject: [PATCH] Use ORM in ColumnValueProvider --- sqlsynthgen/make.py | 30 +++++++++++++++++++++++------- sqlsynthgen/providers.py | 12 +++++------- tests/examples/expected_ssg.py | 3 ++- tests/test_providers.py | 2 +- 4 files changed, 31 insertions(+), 16 deletions(-) diff --git a/sqlsynthgen/make.py b/sqlsynthgen/make.py index 2237a03e..39501669 100644 --- a/sqlsynthgen/make.py +++ b/sqlsynthgen/make.py @@ -1,7 +1,7 @@ """Functions to make a module of generator classes.""" import inspect from types import ModuleType -from typing import Any, Final +from typing import Any, Final, Optional from mimesis.providers.base import BaseProvider from sqlalchemy.sql import sqltypes @@ -40,6 +40,15 @@ } +def _orm_class_from_table_name(tables_module: Any, full_name: str) -> Optional[Any]: + """Return the ORM class corresponding to a table name.""" + for mapper in tables_module.Base.registry.mappers: + cls = mapper.class_ + if cls.__table__.fullname == full_name: + return cls + return None + + def _add_custom_generators(content: str, table_config: dict) -> tuple[str, list[str]]: """Add to the generators file, written in the string `content`, the custom generators for the given table. @@ -67,7 +76,7 @@ def _add_custom_generators(content: str, table_config: dict) -> tuple[str, list[ return content, columns_covered -def _add_default_generator(content: str, column: Any) -> str: +def _add_default_generator(content: str, tables_module: ModuleType, column: Any) -> str: """Add to the generator file `content` a default generator for the given column, determined by the column's type. """ @@ -84,11 +93,17 @@ def _add_default_generator(content: str, column: Any) -> str: "Can't handle multiple foreign keys for one column." ) fkey = column.foreign_keys.pop() - fk_schema, fk_table, fk_column = fkey.target_fullname.split(".") + target_name_parts = fkey.target_fullname.split(".") + target_table_name = ".".join(target_name_parts[:-1]) + target_column_name = target_name_parts[-1] + target_orm_class = _orm_class_from_table_name(tables_module, target_table_name) + if target_orm_class is None: + raise ValueError(f"Could not find the ORM class for {target_table_name}.") content += ( f"self.{column.name} = " f"generic.column_value_provider.column_value(dst_db_conn, " - f'"{fk_schema}", "{fk_table}", "{fk_column}"' + f"{tables_module.__name__}.{target_orm_class.__name__}, " + f'"{target_column_name}"' ")" ) @@ -101,7 +116,7 @@ def _add_default_generator(content: str, column: Any) -> str: def _add_generator_for_table( - content: str, table_config: dict, table: Any + content: str, tables_module: ModuleType, table_config: dict, table: Any ) -> tuple[str, str]: """Add to the generator file `content` a generator for the given table.""" new_class_name = table.name + "Generator" @@ -117,7 +132,7 @@ def _add_generator_for_table( if column.name in columns_covered: # A generator for this column was already covered in the user config. continue - content = _add_default_generator(content, column) + content = _add_default_generator(content, tables_module, column) return content, new_class_name @@ -133,6 +148,7 @@ def make_generators_from_tables( A string that is a valid Python module, once written to file. """ new_content = HEADER_TEXT + new_content += f"\nimport {tables_module.__name__}" generator_module_name = generator_config.get("custom_generators_module", None) if generator_module_name is not None: new_content += f"\nfrom . import {generator_module_name}" @@ -141,7 +157,7 @@ def make_generators_from_tables( for table in tables_module.Base.metadata.sorted_tables: table_config = generator_config.get("tables", {}).get(table.name, {}) new_content, new_generator_name = _add_generator_for_table( - new_content, table_config, table + new_content, tables_module, table_config, table ) sorted_generators += f"{INDENTATION}{new_generator_name},\n" sorted_generators += "]" diff --git a/sqlsynthgen/providers.py b/sqlsynthgen/providers.py index 55ac3c37..374c25e7 100644 --- a/sqlsynthgen/providers.py +++ b/sqlsynthgen/providers.py @@ -5,7 +5,7 @@ from mimesis import Datetime, Text from mimesis.providers.base import BaseDataProvider, BaseProvider -from sqlalchemy.sql import text +from sqlalchemy.sql import func, select class ColumnValueProvider(BaseProvider): @@ -16,13 +16,11 @@ class Meta: name = "column_value_provider" - def column_value( - self, db_connection: Any, schema: str, table: str, column: str - ) -> Any: + def column_value(self, db_connection: Any, orm_class: Any, column_name: str) -> Any: """Return a random value from the column specified.""" - query_str = f"SELECT {column} FROM {schema}.{table} ORDER BY random() LIMIT 1" - key = db_connection.execute(text(query_str)).fetchone()[0] - return key + query = select(orm_class).order_by(func.random()).limit(1) + random_row = db_connection.execute(query).first() + return getattr(random_row, column_name) class BytesProvider(BaseDataProvider): diff --git a/tests/examples/expected_ssg.py b/tests/examples/expected_ssg.py index 927a878f..5a42583c 100644 --- a/tests/examples/expected_ssg.py +++ b/tests/examples/expected_ssg.py @@ -13,6 +13,7 @@ from sqlsynthgen.providers import TimespanProvider generic.add_provider(TimespanProvider) +import tests.examples.example_orm from . import custom_generators class entityGenerator: @@ -34,7 +35,7 @@ class hospital_visitGenerator: def __init__(self, src_db_conn, dst_db_conn): self.visit_start, self.visit_end, self.visit_duration_seconds = custom_generators.timespan_generator(generic=generic, earliest_start_year=2021, last_start_year=2022, min_dt_days=1, max_dt_days=30) pass - self.person_id = generic.column_value_provider.column_value(dst_db_conn, "myschema", "person", "person_id") + self.person_id = generic.column_value_provider.column_value(dst_db_conn, tests.examples.example_orm.Person, "person_id") self.visit_image = generic.bytes_provider.bytes() diff --git a/tests/test_providers.py b/tests/test_providers.py index 9c1c767f..13459383 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -60,7 +60,7 @@ def test_column_value(self) -> None: conn.execute(stmt) provider = providers.ColumnValueProvider() - key = provider.column_value(conn, "public", "person", "sex") + key = provider.column_value(conn, Person, "sex") self.assertEqual("M", key)