diff --git a/sqlsynthgen/create.py b/sqlsynthgen/create.py index ad42f602..9a308519 100644 --- a/sqlsynthgen/create.py +++ b/sqlsynthgen/create.py @@ -33,7 +33,9 @@ def create_db_vocab(sorted_vocab: List[Any]) -> None: vocab_table.load(dst_conn) -def create_db_data(sorted_tables: list, sorted_generators: list, num_rows: int) -> None: +def create_db_data( + sorted_tables: list, sorted_generators: list, num_passes: int +) -> None: """Connect to a database and populate it with data.""" settings = get_settings() dst_engine = create_engine(settings.dst_postgres_dsn) @@ -41,17 +43,18 @@ def create_db_data(sorted_tables: list, sorted_generators: list, num_rows: int) with dst_engine.connect() as dst_conn: with src_engine.connect() as src_conn: - populate(src_conn, dst_conn, sorted_tables, sorted_generators, num_rows) + populate(src_conn, dst_conn, sorted_tables, sorted_generators, num_passes) def populate( - src_conn: Any, dst_conn: Any, tables: list, generators: list, num_rows: int + src_conn: Any, dst_conn: Any, tables: list, generators: list, num_passes: int ) -> None: """Populate a database schema with dummy data.""" for table, generator in reversed( list(zip(reversed(tables), reversed(generators))) ): # Run all the inserts for one table in a transaction with dst_conn.begin(): - for _ in range(num_rows): - stmt = insert(table).values(generator(src_conn, dst_conn).__dict__) - dst_conn.execute(stmt) + for _ in range(num_passes): + for __ in range(generator.num_rows_per_pass): + stmt = insert(table).values(generator(src_conn, dst_conn).__dict__) + dst_conn.execute(stmt) diff --git a/sqlsynthgen/main.py b/sqlsynthgen/main.py index 5efb4166..60a294b5 100644 --- a/sqlsynthgen/main.py +++ b/sqlsynthgen/main.py @@ -45,7 +45,7 @@ def read_yaml_file(path: str) -> Any: def create_data( orm_file: str = typer.Argument(...), ssg_file: str = typer.Argument(...), - num_rows: int = typer.Argument(...), + num_passes: int = typer.Argument(...), ) -> None: """Populate schema with 
synthetic data. @@ -63,12 +63,12 @@ def create_data( Final input is the number of rows required. Example: - $ python sqlsynthgen/main.py create-data example_orm.py expected_ssg.py 100 + $ sqlsynthgen create-data example_orm.py expected_ssg.py 100 Args: orm_file (str): Path to object relational model. ssg_file (str): Path to sqlsyngen output. - num_rows (int): Number of rows of values required + num_passes (int): Number of passes to make. Returns: None @@ -76,7 +76,7 @@ def create_data( orm_module = import_file(orm_file) ssg_module = import_file(ssg_file) create_db_data( - orm_module.Base.metadata.sorted_tables, ssg_module.sorted_generators, num_rows + orm_module.Base.metadata.sorted_tables, ssg_module.sorted_generators, num_passes ) @@ -95,7 +95,7 @@ def create_tables(orm_file: str = typer.Argument(...)) -> None: declared as Python tables. (eg.) Example: - $ python sqlsynthgen/main.py create-tables example_orm.py + $ sqlsynthgen create-tables example_orm.py Args: orm_file (str): Path to Python tables file. @@ -119,7 +119,7 @@ def make_generators( returns a set of synthetic data generators for each attribute Example: - $ python sqlsynthgen/main.py make-generators example_orm.py + $ sqlsynthgen make-generators example_orm.py Args: orm_file (str): Path to Python tables file. @@ -140,7 +140,7 @@ def make_tables() -> None: as Python classes. 
Example: - $ python sqlsynthgen/main.py make_tables + $ sqlsynthgen make-tables """ settings = get_settings() diff --git a/sqlsynthgen/make.py b/sqlsynthgen/make.py index 580ec018..0f650852 100644 --- a/sqlsynthgen/make.py +++ b/sqlsynthgen/make.py @@ -121,19 +121,14 @@ def _add_generator_for_table( ) -> tuple[str, str]: """Add to the generator file `content` a generator for the given table.""" new_class_name = table.name + "Generator" - if table_config.get("vocabulary_table", False): - raise NotImplementedError("Vocabulary tables currently unimplemented.") - - content += ( - f"\n\nclass {new_class_name}:\n" - f"{INDENTATION}def __init__(self, src_db_conn, dst_db_conn):\n" - ) + content += f"\n\nclass {new_class_name}:\n" + content += f"{INDENTATION}num_rows_per_pass = {table_config.get('num_rows_per_pass', 1)}\n\n" + content += f"{INDENTATION}def __init__(self, src_db_conn, dst_db_conn):\n" content, columns_covered = _add_custom_generators(content, table_config) for column in table.columns: - if column.name in columns_covered: - # A generator for this column was already covered in the user config. - continue - content = _add_default_generator(content, tables_module, column) + if column.name not in columns_covered: + # No generator for this column in the user config. 
+ content = _add_default_generator(content, tables_module, column) return content, new_class_name @@ -176,11 +171,9 @@ def make_generators_from_tables( engine = create_engine(settings.src_postgres_dsn) for table in tables_module.Base.metadata.sorted_tables: - if table.name in [ - x - for x in generator_config.get("tables", {}).keys() - if generator_config["tables"][x].get("vocabulary_table") - ]: + table_config = generator_config.get("tables", {}).get(table.name, {}) + + if table_config.get("vocabulary_table") is True: orm_class = _orm_class_from_table_name(tables_module, table.fullname) if not orm_class: @@ -194,13 +187,11 @@ def make_generators_from_tables( _download_table(table, engine) - continue - - table_config = generator_config.get("tables", {}).get(table.name, {}) - new_content, new_generator_name = _add_generator_for_table( - new_content, tables_module, table_config, table - ) - sorted_generators += f"{INDENTATION}{new_generator_name},\n" + else: + new_content, new_generator_name = _add_generator_for_table( + new_content, tables_module, table_config, table + ) + sorted_generators += f"{INDENTATION}{new_generator_name},\n" sorted_generators += "]" sorted_vocab += "]" diff --git a/tests/examples/expected_ssg.py b/tests/examples/expected_ssg.py index dee1447c..796c5524 100644 --- a/tests/examples/expected_ssg.py +++ b/tests/examples/expected_ssg.py @@ -20,11 +20,15 @@ concept_vocab = FileUploader(tests.examples.example_orm.Concept.__table__) class entityGenerator: + num_rows_per_pass = 1 + def __init__(self, src_db_conn, dst_db_conn): pass class personGenerator: + num_rows_per_pass = 2 + def __init__(self, src_db_conn, dst_db_conn): self.name = generic.person.full_name() self.stored_from = generic.datetime.datetime(start=2022, end=2022) @@ -35,6 +39,8 @@ def __init__(self, src_db_conn, dst_db_conn): class hospital_visitGenerator: + num_rows_per_pass = 3 + def __init__(self, src_db_conn, dst_db_conn): self.visit_start, self.visit_end, 
self.visit_duration_seconds = custom_generators.timespan_generator(generic=generic, earliest_start_year=2021, last_start_year=2022, min_dt_days=1, max_dt_days=30) pass diff --git a/tests/test_create.py b/tests/test_create.py index 1e4dfe28..75dde7ea 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -48,16 +48,17 @@ def test_populate(self) -> None: mock_src_conn = MagicMock() mock_dst_conn = MagicMock() mock_gen = MagicMock() + mock_gen.num_rows_per_pass = 2 tables = [None] generators = [mock_gen] populate(mock_src_conn, mock_dst_conn, tables, generators, 1) - mock_gen.assert_called_once_with(mock_src_conn, mock_dst_conn) - mock_insert.return_value.values.assert_called_once_with( - mock_gen.return_value.__dict__ + mock_gen.assert_has_calls([call(mock_src_conn, mock_dst_conn)] * 2) + mock_insert.return_value.values.assert_has_calls( + [call(mock_gen.return_value.__dict__)] * 2 ) - mock_dst_conn.execute.assert_called_once_with( - mock_insert.return_value.values.return_value + mock_dst_conn.execute.assert_has_calls( + [call(mock_insert.return_value.values.return_value)] * 2 ) def test_populate_diff_length(self) -> None: