Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions sqlsynthgen/make.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Functions to make a module of generator classes."""
import inspect
from types import ModuleType
from typing import Any, Final
from typing import Any, Final, Optional

from mimesis.providers.base import BaseProvider
from sqlalchemy.sql import sqltypes
Expand Down Expand Up @@ -40,6 +40,15 @@
}


def _orm_class_from_table_name(tables_module: Any, full_name: str) -> Optional[Any]:
"""Return the ORM class corresponding to a table name."""
for mapper in tables_module.Base.registry.mappers:
cls = mapper.class_
if cls.__table__.fullname == full_name:
return cls
return None


def _add_custom_generators(content: str, table_config: dict) -> tuple[str, list[str]]:
"""Add to the generators file, written in the string `content`, the custom
generators for the given table.
Expand Down Expand Up @@ -67,7 +76,7 @@ def _add_custom_generators(content: str, table_config: dict) -> tuple[str, list[
return content, columns_covered


def _add_default_generator(content: str, column: Any) -> str:
def _add_default_generator(content: str, tables_module: ModuleType, column: Any) -> str:
"""Add to the generator file `content` a default generator for the given column,
determined by the column's type.
"""
Expand All @@ -84,11 +93,17 @@ def _add_default_generator(content: str, column: Any) -> str:
"Can't handle multiple foreign keys for one column."
)
fkey = column.foreign_keys.pop()
fk_schema, fk_table, fk_column = fkey.target_fullname.split(".")
target_name_parts = fkey.target_fullname.split(".")
target_table_name = ".".join(target_name_parts[:-1])
target_column_name = target_name_parts[-1]
target_orm_class = _orm_class_from_table_name(tables_module, target_table_name)
if target_orm_class is None:
raise ValueError(f"Could not find the ORM class for {target_table_name}.")
content += (
f"self.{column.name} = "
f"generic.column_value_provider.column_value(dst_db_conn, "
f'"{fk_schema}", "{fk_table}", "{fk_column}"'
f"{tables_module.__name__}.{target_orm_class.__name__}, "
f'"{target_column_name}"'
")"
)

Expand All @@ -101,7 +116,7 @@ def _add_default_generator(content: str, column: Any) -> str:


def _add_generator_for_table(
content: str, table_config: dict, table: Any
content: str, tables_module: ModuleType, table_config: dict, table: Any
) -> tuple[str, str]:
"""Add to the generator file `content` a generator for the given table."""
new_class_name = table.name + "Generator"
Expand All @@ -117,7 +132,7 @@ def _add_generator_for_table(
if column.name in columns_covered:
# A generator for this column was already covered in the user config.
continue
content = _add_default_generator(content, column)
content = _add_default_generator(content, tables_module, column)
return content, new_class_name


Expand All @@ -133,6 +148,7 @@ def make_generators_from_tables(
A string that is a valid Python module, once written to file.
"""
new_content = HEADER_TEXT
new_content += f"\nimport {tables_module.__name__}"
generator_module_name = generator_config.get("custom_generators_module", None)
if generator_module_name is not None:
new_content += f"\nfrom . import {generator_module_name}"
Expand All @@ -141,7 +157,7 @@ def make_generators_from_tables(
for table in tables_module.Base.metadata.sorted_tables:
table_config = generator_config.get("tables", {}).get(table.name, {})
new_content, new_generator_name = _add_generator_for_table(
new_content, table_config, table
new_content, tables_module, table_config, table
)
sorted_generators += f"{INDENTATION}{new_generator_name},\n"
sorted_generators += "]"
Expand Down
12 changes: 5 additions & 7 deletions sqlsynthgen/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from mimesis import Datetime, Text
from mimesis.providers.base import BaseDataProvider, BaseProvider
from sqlalchemy.sql import text
from sqlalchemy.sql import func, select


class ColumnValueProvider(BaseProvider):
Expand All @@ -16,13 +16,11 @@ class Meta:

name = "column_value_provider"

def column_value(
self, db_connection: Any, schema: str, table: str, column: str
) -> Any:
def column_value(self, db_connection: Any, orm_class: Any, column_name: str) -> Any:
"""Return a random value from the column specified."""
query_str = f"SELECT {column} FROM {schema}.{table} ORDER BY random() LIMIT 1"
key = db_connection.execute(text(query_str)).fetchone()[0]
return key
query = select(orm_class).order_by(func.random()).limit(1)
random_row = db_connection.execute(query).first()
return getattr(random_row, column_name)


class BytesProvider(BaseDataProvider):
Expand Down
3 changes: 2 additions & 1 deletion tests/examples/expected_ssg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sqlsynthgen.providers import TimespanProvider
generic.add_provider(TimespanProvider)

import tests.examples.example_orm
from . import custom_generators

class entityGenerator:
Expand All @@ -34,7 +35,7 @@ class hospital_visitGenerator:
def __init__(self, src_db_conn, dst_db_conn):
self.visit_start, self.visit_end, self.visit_duration_seconds = custom_generators.timespan_generator(generic=generic, earliest_start_year=2021, last_start_year=2022, min_dt_days=1, max_dt_days=30)
pass
self.person_id = generic.column_value_provider.column_value(dst_db_conn, "myschema", "person", "person_id")
self.person_id = generic.column_value_provider.column_value(dst_db_conn, tests.examples.example_orm.Person, "person_id")
self.visit_image = generic.bytes_provider.bytes()


Expand Down
2 changes: 1 addition & 1 deletion tests/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_column_value(self) -> None:
conn.execute(stmt)

provider = providers.ColumnValueProvider()
key = provider.column_value(conn, "public", "person", "sex")
key = provider.column_value(conn, Person, "sex")

self.assertEqual("M", key)

Expand Down