From a7e47924f8ea341181ee6f626fffdeefb49a719e Mon Sep 17 00:00:00 2001 From: Iain <25081046+Iain-S@users.noreply.github.com> Date: Fri, 3 Feb 2023 15:59:40 +0000 Subject: [PATCH 01/13] Vocab WIP --- sqlsynthgen/create.py | 6 +++++ sqlsynthgen/main.py | 10 +++++++++ sqlsynthgen/make.py | 6 +++++ tests/examples/src.dump | 47 ++++++++++++++++++---------------------- tests/test_functional.py | 5 +++++ 5 files changed, 48 insertions(+), 26 deletions(-) diff --git a/sqlsynthgen/create.py b/sqlsynthgen/create.py index 02bd2243..40eb1365 100644 --- a/sqlsynthgen/create.py +++ b/sqlsynthgen/create.py @@ -23,6 +23,12 @@ def create_db_tables(metadata: Any) -> Any: metadata.create_all(engine) +def create_db_vocab(sorted_vocab): + settings = get_settings() + for vocab_table in sorted_vocab: + vocab_table.load() + + def create_db_data(sorted_tables: list, sorted_generators: list, num_rows: int) -> None: """Connect to a database and populate it with data.""" settings = get_settings() diff --git a/sqlsynthgen/main.py b/sqlsynthgen/main.py index 829986d7..bea1d1b7 100644 --- a/sqlsynthgen/main.py +++ b/sqlsynthgen/main.py @@ -45,6 +45,16 @@ def create_data( ) +@app.command() +def create_vocab( + orm_file: str = typer.Argument(...), ssg_file: str = typer.Argument(...) +) -> None: + """Create tables using the SQLAlchemy file.""" + orm_module = import_file(orm_file) + ssg_module = import_file(ssg_file) + create_db_vocab(ssg_module.sorted_vocab) + + @app.command() def create_tables(orm_file: str = typer.Argument(...)) -> None: """Create tables using the SQLAlchemy file.""" diff --git a/sqlsynthgen/make.py b/sqlsynthgen/make.py index 39501669..a17fe618 100644 --- a/sqlsynthgen/make.py +++ b/sqlsynthgen/make.py @@ -155,11 +155,17 @@ def make_generators_from_tables( sorted_generators = "[\n" for table in tables_module.Base.metadata.sorted_tables: + # ToDo Get list of vocab tables from config file + if table.name in ("",): + pass + continue + table_config = generator_config.get("tables", {}).get(table.name, {}) new_content, new_generator_name = _add_generator_for_table( new_content, tables_module, table_config, table ) sorted_generators += f"{INDENTATION}{new_generator_name},\n" + sorted_generators += "]" new_content += "\n\n" + "sorted_generators = " + sorted_generators + "\n" return new_content diff --git a/tests/examples/src.dump b/tests/examples/src.dump index f9a5258b..dffa05dd 100644 --- a/tests/examples/src.dump +++ b/tests/examples/src.dump @@ -44,26 +44,23 @@ SET default_tablespace = ''; SET default_table_access_method = heap; -- --- Name: hospital_visit; Type: TABLE; Schema: public; Owner: postgres +-- Name: concept; Type: TABLE; Schema: public; Owner: postgres -- -CREATE TABLE public.hospital_visit ( - hospital_visit_id bigint NOT NULL, - person_id integer NOT NULL, - visit_start date NOT NULL, - visit_duration_seconds real NOT NULL, - visit_image bytea NOT NULL +CREATE TABLE public.concept ( + concept_id integer NOT NULL PRIMARY KEY, + concept_name text NOT NULL ); -ALTER TABLE public.hospital_visit OWNER TO postgres; +ALTER TABLE public.concept OWNER TO postgres; -- -- Name: person; Type: TABLE; Schema: public; Owner: postgres -- CREATE TABLE public.person ( - person_id integer NOT NULL, + person_id integer NOT NULL PRIMARY KEY, name text NOT NULL, research_opt_out boolean NOT NULL, stored_from timestamp with time zone NOT NULL @@ -73,36 +70,34 @@ CREATE TABLE public.person ( ALTER TABLE public.person OWNER TO postgres; -- --- Data for Name: hospital_visit; Type: TABLE DATA; Schema: public; Owner: postgres +-- Name: hospital_visit; Type: TABLE; Schema: public; Owner: postgres -- -COPY public.hospital_visit (hospital_visit_id, person_id, visit_start, visit_duration_seconds, visit_image) FROM stdin; -\. +CREATE TABLE public.hospital_visit ( + hospital_visit_id bigint NOT NULL PRIMARY KEY, + person_id integer NOT NULL references public.person(person_id), + visit_start date NOT NULL, + visit_duration_seconds real NOT NULL, + visit_image bytea NOT NULL, + visit_type_concept_id integer NOT NULL references public.concept(concept_id) +); +ALTER TABLE public.hospital_visit OWNER TO postgres; -- --- Data for Name: person; Type: TABLE DATA; Schema: public; Owner: postgres +-- Data for Name: hospital_visit; Type: TABLE DATA; Schema: public; Owner: postgres -- -COPY public.person (person_id, name, research_opt_out, stored_from) FROM stdin; +COPY public.hospital_visit (hospital_visit_id, person_id, visit_start, visit_duration_seconds, visit_image) FROM stdin; \. -- --- Name: hospital_visit hospital_visit_pkey; Type: CONSTRAINT; Schema: public; Owner: postgres --- - -ALTER TABLE ONLY public.hospital_visit - ADD CONSTRAINT hospital_visit_pkey PRIMARY KEY (hospital_visit_id); - - --- --- Name: person person_pkey; Type: CONSTRAINT; Schema: public; Owner: postgres +-- Data for Name: person; Type: TABLE DATA; Schema: public; Owner: postgres -- -ALTER TABLE ONLY public.person - ADD CONSTRAINT person_pkey PRIMARY KEY (person_id); - +COPY public.person (person_id, name, research_opt_out, stored_from) FROM stdin; +\. -- -- PostgreSQL database dump complete diff --git a/tests/test_functional.py b/tests/test_functional.py index 2f4a83c7..4898e9b6 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -52,6 +52,11 @@ def test_workflow(self) -> None: ) run(["sqlsynthgen", "create-tables", self.orm_file_path], env=env, check=True) + run( + ["sqlsynthgen", "create-vocab", self.orm_file_path, self.ssg_file_path], + env=env, + check=True, + ) run( ["sqlsynthgen", "create-data", self.orm_file_path, self.ssg_file_path, "1"], env=env, From 295cd2ab6ce49df5e9cdb8eb68ea04e4acafa608 Mon Sep 17 00:00:00 2001 From: Iain <25081046+Iain-S@users.noreply.github.com> Date: Thu, 9 Feb 2023 18:27:13 +0000 Subject: [PATCH 02/13] Make vocab vars if there are vocab tables --- sqlsynthgen/base.py | 15 +++++++++++++++ sqlsynthgen/create.py | 7 ++++--- sqlsynthgen/main.py | 2 +- sqlsynthgen/make.py | 21 +++++++++++++++++++-- tests/examples/example_orm.py | 11 +++++++++++ tests/examples/expected_ssg.py | 8 ++++++++ tests/test_base.py | 15 +++++++++++++++ tests/test_create.py | 13 ++++++++++++- 8 files changed, 85 insertions(+), 7 deletions(-) create mode 100644 sqlsynthgen/base.py create mode 100644 tests/test_base.py diff --git a/sqlsynthgen/base.py b/sqlsynthgen/base.py new file mode 100644 index 00000000..6551369a --- /dev/null +++ b/sqlsynthgen/base.py @@ -0,0 +1,15 @@ +"""Base generator classes.""" +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional + + +@dataclass +class FileUploader: + """For uploading data files.""" + + table: Any + file_name: Optional[Path] = None + + def load(self) -> None: + """Load the data from file.""" diff --git a/sqlsynthgen/create.py b/sqlsynthgen/create.py index 40eb1365..4419d968 100644 --- a/sqlsynthgen/create.py +++ b/sqlsynthgen/create.py @@ -1,5 +1,5 @@ """Functions and classes to create and populate the target database.""" -from typing import Any +from typing import Any, List from sqlalchemy import create_engine, insert from sqlalchemy.schema import CreateSchema @@ -23,8 +23,9 @@ def create_db_tables(metadata: Any) -> Any: metadata.create_all(engine) -def create_db_vocab(sorted_vocab): - settings = get_settings() +def create_db_vocab(sorted_vocab: List[Any]) -> None: + """Load vocabulary tables from files.""" + # settings = get_settings() for vocab_table in sorted_vocab: vocab_table.load() diff --git a/sqlsynthgen/main.py b/sqlsynthgen/main.py index bea1d1b7..fa21cd0e 100644 --- a/sqlsynthgen/main.py +++ b/sqlsynthgen/main.py @@ -10,7 +10,7 @@ import typer import yaml -from sqlsynthgen.create import create_db_data, create_db_tables +from sqlsynthgen.create import create_db_data, create_db_tables, create_db_vocab from sqlsynthgen.make import make_generators_from_tables from sqlsynthgen.settings import get_settings diff --git a/sqlsynthgen/make.py b/sqlsynthgen/make.py index a17fe618..e1ac129a 100644 --- a/sqlsynthgen/make.py +++ b/sqlsynthgen/make.py @@ -13,6 +13,7 @@ '"""This file was auto-generated by sqlsynthgen but can be edited manually."""', "from mimesis import Generic", "from mimesis.locales import Locale", + "from sqlsynthgen.base import FileUploader", "", "generic = Generic(locale=Locale.EN)", "", @@ -143,6 +144,7 @@ def make_generators_from_tables( Args: tables_module: A sqlacodegen-generated module. + generator_config: Configuration to control the generator creation. Returns: A string that is a valid Python module, once written to file. @@ -154,10 +156,21 @@ def make_generators_from_tables( new_content += f"\nfrom . import {generator_module_name}" sorted_generators = "[\n" + sorted_vocab = "[\n" + for table in tables_module.Base.metadata.sorted_tables: # ToDo Get list of vocab tables from config file - if table.name in ("",): - pass + if table.name in ("concept",): + + orm_class = _orm_class_from_table_name(tables_module, table.fullname) + if not orm_class: + raise RuntimeError(f"Couldn't find {table.fullname} in {tables_module}") + class_name = orm_class.__name__ + new_content += ( + f"\n\n{class_name.lower()}_vocab " + f"= FileUploader({tables_module.__name__}.{class_name})" + ) + sorted_vocab += f"{INDENTATION}{class_name.lower()}_vocab,\n" continue table_config = generator_config.get("tables", {}).get(table.name, {}) @@ -167,5 +180,9 @@ def make_generators_from_tables( sorted_generators += f"{INDENTATION}{new_generator_name},\n" sorted_generators += "]" + sorted_vocab += "]" + new_content += "\n\n" + "sorted_generators = " + sorted_generators + "\n" + new_content += "\n\n" + "sorted_vocab = " + sorted_vocab + "\n" + return new_content diff --git a/tests/examples/example_orm.py b/tests/examples/example_orm.py index e37ea654..b242a366 100644 --- a/tests/examples/example_orm.py +++ b/tests/examples/example_orm.py @@ -47,3 +47,14 @@ class Entity(Base): Integer, primary_key=True, ) + + +class Concept(Base): + __tablename__ = "concept" + __table_args__ = {"schema": "myschema"} + + concept_id = Column( + Integer, + primary_key=True, + ) + concept_name = Column(Text) diff --git a/tests/examples/expected_ssg.py b/tests/examples/expected_ssg.py index 5a42583c..720a1787 100644 --- a/tests/examples/expected_ssg.py +++ b/tests/examples/expected_ssg.py @@ -1,6 +1,7 @@ """This file was auto-generated by sqlsynthgen but can be edited manually.""" from mimesis import Generic from mimesis.locales import Locale +from sqlsynthgen.base import FileUploader generic = Generic(locale=Locale.EN) @@ -16,6 +17,8 @@ import tests.examples.example_orm from . import custom_generators +concept_vocab = FileUploader(tests.examples.example_orm.Concept) + class entityGenerator: def __init__(self, src_db_conn, dst_db_conn): pass @@ -44,3 +47,8 @@ def __init__(self, src_db_conn, dst_db_conn): personGenerator, hospital_visitGenerator, ] + + +sorted_vocab = [ + concept_vocab, +] diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 00000000..8f41aab7 --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,15 @@ +"""Tests for the main module.""" +from unittest import TestCase +from unittest.mock import MagicMock + +from sqlsynthgen.base import FileUploader + + +class VocabTests(TestCase): + """Module test case.""" + + def test_load(self) -> None: + """Test the load method.""" + mock_table = MagicMock() + vocab_gen = FileUploader(mock_table) + vocab_gen.load() diff --git a/tests/test_create.py b/tests/test_create.py index 3041b8ba..8c51d454 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -2,7 +2,12 @@ from unittest import TestCase from unittest.mock import MagicMock, patch -from sqlsynthgen.create import create_db_data, create_db_tables, populate +from sqlsynthgen.create import ( + create_db_data, + create_db_tables, + create_db_vocab, + populate, +) from tests.utils import get_test_settings @@ -54,3 +59,9 @@ def test_populate(self) -> None: mock_dst_conn.execute.assert_called_once_with( mock_insert.return_value.values.return_value ) + + def test_create_db_vocab(self) -> None: + """Test the create_db_vocab function.""" + vocab_list = [MagicMock()] + create_db_vocab(vocab_list) + vocab_list[0].load.assert_called_once() From 9501d0fee783e47fc0911658db754467a6e30854 Mon Sep 17 00:00:00 2001 From: Iain <25081046+Iain-S@users.noreply.github.com> Date: Tue, 14 Feb 2023 16:51:26 +0000 Subject: [PATCH 03/13] Implement FileUploader.load() --- sqlsynthgen/base.py | 14 ++++++++--- tests/examples/expected_ssg.py | 2 +- tests/examples/mytable.csv | 4 +++ tests/test_base.py | 45 +++++++++++++++++++++++++++++++--- 4 files changed, 57 insertions(+), 8 deletions(-) create mode 100644 tests/examples/mytable.csv diff --git a/sqlsynthgen/base.py b/sqlsynthgen/base.py index 6551369a..52a51e28 100644 --- a/sqlsynthgen/base.py +++ b/sqlsynthgen/base.py @@ -1,7 +1,10 @@ """Base generator classes.""" +import csv from dataclasses import dataclass from pathlib import Path -from typing import Any, Optional +from typing import Any + +from sqlalchemy import insert @dataclass @@ -9,7 +12,12 @@ class FileUploader: """For uploading data files.""" table: Any - file_name: Optional[Path] = None - def load(self) -> None: + def load(self, connection: Any) -> None: """Load the data from file.""" + with Path(self.table.fullname + ".csv").open( + "r", newline="", encoding="utf-8" + ) as csvfile: + reader = csv.DictReader(csvfile) + stmt = insert(self.table).values(list(reader)) + connection.execute(stmt) diff --git a/tests/examples/expected_ssg.py b/tests/examples/expected_ssg.py index 720a1787..dee1447c 100644 --- a/tests/examples/expected_ssg.py +++ b/tests/examples/expected_ssg.py @@ -17,7 +17,7 @@ import tests.examples.example_orm from . import custom_generators -concept_vocab = FileUploader(tests.examples.example_orm.Concept) +concept_vocab = FileUploader(tests.examples.example_orm.Concept.__table__) class entityGenerator: def __init__(self, src_db_conn, dst_db_conn): diff --git a/tests/examples/mytable.csv b/tests/examples/mytable.csv new file mode 100644 index 00000000..7a8faf79 --- /dev/null +++ b/tests/examples/mytable.csv @@ -0,0 +1,4 @@ +id +1 +2 +3 diff --git a/tests/test_base.py b/tests/test_base.py index 8f41aab7..7c9c982e 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,15 +1,52 @@ """Tests for the main module.""" +import os from unittest import TestCase -from unittest.mock import MagicMock + +from sqlalchemy import Column, Integer, create_engine, select +from sqlalchemy.orm import declarative_base from sqlsynthgen.base import FileUploader +from tests.utils import run_psql + +# pylint: disable=invalid-name +Base = declarative_base() +# pylint: enable=invalid-name +metadata = Base.metadata + + +class MyTable(Base): # type: ignore + """A SQLAlchemy table.""" + + __tablename__ = "mytable" + id = Column( + Integer, + primary_key=True, + ) class VocabTests(TestCase): """Module test case.""" + def setUp(self) -> None: + """Pre-test setup.""" + + run_psql("providers.dump") + + self.engine = create_engine( + "postgresql://postgres:password@localhost:5432/providers" + ) + metadata.create_all(self.engine) + os.chdir("tests/examples") + + def tearDown(self) -> None: + os.chdir("../..") + def test_load(self) -> None: """Test the load method.""" - mock_table = MagicMock() - vocab_gen = FileUploader(mock_table) - vocab_gen.load() + vocab_gen = FileUploader(MyTable.__table__) + + with self.engine.connect() as conn: + vocab_gen.load(conn) + statement = select([MyTable]) + rows = list(conn.execute(statement)) + self.assertEqual(3, len(rows)) From ba69f24e0f8d051e47e8365711f506777e6827bf Mon Sep 17 00:00:00 2001 From: Iain <25081046+Iain-S@users.noreply.github.com> Date: Tue, 14 Feb 2023 18:11:59 +0000 Subject: [PATCH 04/13] Fix unit tests --- sqlsynthgen/create.py | 11 ++++++++--- sqlsynthgen/make.py | 2 +- tests/test_create.py | 14 +++++++++++--- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/sqlsynthgen/create.py b/sqlsynthgen/create.py index 4419d968..66d9ab48 100644 --- a/sqlsynthgen/create.py +++ b/sqlsynthgen/create.py @@ -25,9 +25,12 @@ def create_db_tables(metadata: Any) -> Any: def create_db_vocab(sorted_vocab: List[Any]) -> None: """Load vocabulary tables from files.""" - # settings = get_settings() - for vocab_table in sorted_vocab: - vocab_table.load() + settings = get_settings() + dst_engine = create_engine(settings.dst_postgres_dsn) + + with dst_engine.connect() as dst_conn: + for vocab_table in sorted_vocab: + vocab_table.load(dst_conn) def create_db_data(sorted_tables: list, sorted_generators: list, num_rows: int) -> None: @@ -46,6 +49,8 @@ def populate( ) -> None: """Populate a database schema with dummy data.""" + # ToDo Now that we have the vocab list, we can't assume that these are the same length + # We could reverse them, zip them and then reverse them again? for table, generator in zip( tables, generators ): # Run all the inserts for one table in a transaction diff --git a/sqlsynthgen/make.py b/sqlsynthgen/make.py index e1ac129a..15d00436 100644 --- a/sqlsynthgen/make.py +++ b/sqlsynthgen/make.py @@ -168,7 +168,7 @@ def make_generators_from_tables( class_name = orm_class.__name__ new_content += ( f"\n\n{class_name.lower()}_vocab " - f"= FileUploader({tables_module.__name__}.{class_name})" + f"= FileUploader({tables_module.__name__}.{class_name}.__table__)" ) sorted_vocab += f"{INDENTATION}{class_name.lower()}_vocab,\n" continue diff --git a/tests/test_create.py b/tests/test_create.py index 8c51d454..1de41536 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -62,6 +62,14 @@ def test_populate(self) -> None: def test_create_db_vocab(self) -> None: """Test the create_db_vocab function.""" - vocab_list = [MagicMock()] - create_db_vocab(vocab_list) - vocab_list[0].load.assert_called_once() + with patch("sqlsynthgen.create.create_engine") as mock_create_engine, patch( + "sqlsynthgen.create.get_settings" + ) as mock_get_settings: + vocab_list = [MagicMock()] + create_db_vocab(vocab_list) + vocab_list[0].load.assert_called_once_with( + mock_create_engine.return_value.connect.return_value.__enter__.return_value + ) + mock_create_engine.assert_called_once_with( + mock_get_settings.return_value.dst_postgres_dsn + ) From fd44f982b97a8e9f000d053b091333bb19496512 Mon Sep 17 00:00:00 2001 From: Iain <25081046+Iain-S@users.noreply.github.com> Date: Wed, 15 Feb 2023 19:59:13 +0000 Subject: [PATCH 05/13] Handle fewer generators than tables --- sqlsynthgen/create.py | 6 +-- sqlsynthgen/make.py | 31 ++++++++++++++++ tests/examples/basetable.csv | 4 ++ tests/examples/expected.csv | 2 + tests/examples/mytable.csv | 6 +-- tests/test_base.py | 8 ++-- tests/test_create.py | 17 ++++++++- tests/test_make.py | 71 +++++++++++++++++++++++++++++++++--- tests/test_providers.py | 9 ++++- 9 files changed, 135 insertions(+), 19 deletions(-) create mode 100644 tests/examples/basetable.csv create mode 100644 tests/examples/expected.csv diff --git a/sqlsynthgen/create.py b/sqlsynthgen/create.py index 66d9ab48..93bbb680 100644 --- a/sqlsynthgen/create.py +++ b/sqlsynthgen/create.py @@ -49,10 +49,8 @@ def populate( ) -> None: """Populate a database schema with dummy data.""" - # ToDo Now that we have the vocab list, we can't assume that these are the same length - # We could reverse them, zip them and then reverse them again? - for table, generator in zip( - tables, generators + for table, generator in reversed( + list(zip(reversed(tables), reversed(generators))) ): # Run all the inserts for one table in a transaction with dst_conn.begin(): for _ in range(num_rows): diff --git a/sqlsynthgen/make.py b/sqlsynthgen/make.py index 15d00436..06215d5d 100644 --- a/sqlsynthgen/make.py +++ b/sqlsynthgen/make.py @@ -1,12 +1,16 @@ """Functions to make a module of generator classes.""" +import csv import inspect +from pathlib import Path from types import ModuleType from typing import Any, Final, Optional from mimesis.providers.base import BaseProvider +from sqlalchemy import create_engine, select from sqlalchemy.sql import sqltypes from sqlsynthgen import providers +from sqlsynthgen.settings import get_settings HEADER_TEXT: str = "\n".join( ( @@ -137,6 +141,27 @@ def _add_generator_for_table( return content, new_class_name +def _download_table(table: Any, engine: Any) -> None: + """Download a table and store it as a .csv file""" + stmt = select([table]) + with engine.connect() as conn: + # result = conn.execute("select * from concept") + result = list(conn.execute(stmt)) + with Path(table.fullname + ".csv").open( + "w", newline="", encoding="utf-8" + ) as csvfile: + writer = csv.writer( + csvfile, delimiter="," + ) # , quotechar='|', quoting=csv.quote_minimal) + writer.writerow([x.name for x in table.columns]) + # writer = csv.dictwriter(csvfile, fieldnames=[x.name for x in table.columns]) + for row in result: + writer.writerow(row) + + # writer.writeheader() + # writer.writerow({'first_name': 'baked', 'last_name': 'beans'}) + + def make_generators_from_tables( tables_module: ModuleType, generator_config: dict ) -> str: @@ -158,6 +183,9 @@ def make_generators_from_tables( sorted_generators = "[\n" sorted_vocab = "[\n" + settings = get_settings() + engine = create_engine(settings.src_postgres_dsn) + for table in tables_module.Base.metadata.sorted_tables: # ToDo Get list of vocab tables from config file if table.name in ("concept",): @@ -171,6 +199,9 @@ def make_generators_from_tables( f"= FileUploader({tables_module.__name__}.{class_name}.__table__)" ) sorted_vocab += f"{INDENTATION}{class_name.lower()}_vocab,\n" + + _download_table(table, engine) + continue table_config = generator_config.get("tables", {}).get(table.name, {}) diff --git a/tests/examples/basetable.csv b/tests/examples/basetable.csv new file mode 100644 index 00000000..7a8faf79 --- /dev/null +++ b/tests/examples/basetable.csv @@ -0,0 +1,4 @@ +id +1 +2 +3 diff --git a/tests/examples/expected.csv b/tests/examples/expected.csv new file mode 100644 index 00000000..3ff3deb8 --- /dev/null +++ b/tests/examples/expected.csv @@ -0,0 +1,2 @@ +id +1 diff --git a/tests/examples/mytable.csv b/tests/examples/mytable.csv index 7a8faf79..9c8175ee 100644 --- a/tests/examples/mytable.csv +++ b/tests/examples/mytable.csv @@ -1,4 +1,2 @@ -id -1 -2 -3 +id +1 diff --git a/tests/test_base.py b/tests/test_base.py index 7c9c982e..f35079ec 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -14,10 +14,10 @@ metadata = Base.metadata -class MyTable(Base): # type: ignore +class BaseTable(Base): # type: ignore """A SQLAlchemy table.""" - __tablename__ = "mytable" + __tablename__ = "basetable" id = Column( Integer, primary_key=True, @@ -43,10 +43,10 @@ def tearDown(self) -> None: def test_load(self) -> None: """Test the load method.""" - vocab_gen = FileUploader(MyTable.__table__) + vocab_gen = FileUploader(BaseTable.__table__) with self.engine.connect() as conn: vocab_gen.load(conn) - statement = select([MyTable]) + statement = select([BaseTable]) rows = list(conn.execute(statement)) self.assertEqual(3, len(rows)) diff --git a/tests/test_create.py b/tests/test_create.py index 1de41536..1e4dfe28 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -1,6 +1,6 @@ """Tests for the main module.""" from unittest import TestCase -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch from sqlsynthgen.create import ( create_db_data, @@ -60,6 +60,21 @@ def test_populate(self) -> None: mock_insert.return_value.values.return_value ) + def test_populate_diff_length(self) -> None: + """Test when generators and tables differ in length.""" + mock_dst_conn = MagicMock() + mock_gen_two = MagicMock() + mock_gen_three = MagicMock() + tables = [1, 2, 3] + generators = [mock_gen_two, mock_gen_three] + + with patch("sqlsynthgen.create.insert") as mock_insert: + populate(2, mock_dst_conn, tables, generators, 1) + self.assertListEqual([call(2), call(3)], mock_insert.call_args_list) + + mock_gen_two.assert_called_once() + mock_gen_three.assert_called_once() + def test_create_db_vocab(self) -> None: """Test the create_db_vocab function.""" with patch("sqlsynthgen.create.create_engine") as mock_create_engine, patch( diff --git a/tests/test_make.py b/tests/test_make.py index 94e59a07..f849fd57 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -1,25 +1,86 @@ """Tests for the main module.""" +import os +from pathlib import Path from unittest import TestCase +from unittest.mock import patch import yaml +from sqlalchemy import Column, Integer, create_engine, insert +from sqlalchemy.orm import declarative_base +from sqlalchemy.pool import NullPool from sqlsynthgen import make from tests.examples import example_orm +from tests.utils import run_psql + +# pylint: disable=protected-access +# pylint: disable=invalid-name +Base = declarative_base() +# pylint: enable=invalid-name +metadata = Base.metadata + + +class MakeTable(Base): # type: ignore + """A SQLAlchemy table.""" + + __tablename__ = "maketable" + id = Column( + Integer, + primary_key=True, + ) class MyTestCase(TestCase): """Module test case.""" + def setUp(self) -> None: + """Pre-test setup.""" + + run_psql("providers.dump") + + self.engine = create_engine( + "postgresql://postgres:password@localhost:5432/providers", + # pool=NullPool + poolclass=NullPool, + ) + metadata.create_all(self.engine) + os.chdir("tests/examples") + + def tearDown(self) -> None: + os.chdir("../..") + + self.engine.dispose() + def test_make_generators_from_tables(self) -> None: """Check that we can make a generators file from a tables module.""" self.maxDiff = None # pylint: disable=invalid-name - with open( - "tests/examples/expected_ssg.py", encoding="utf-8" - ) as expected_output: + with open("expected_ssg.py", encoding="utf-8") as expected_output: expected = expected_output.read() - conf_path = "tests/examples/generator_conf.yaml" + conf_path = "generator_conf.yaml" with open(conf_path, "r", encoding="utf8") as f: config = yaml.safe_load(f) - actual = make.make_generators_from_tables(example_orm, config) + with patch("sqlsynthgen.make._download_table",) as mock_download, patch( + "sqlsynthgen.make.create_engine" + ) as mock_create_engine, patch("sqlsynthgen.make.get_settings"): + # pass + actual = make.make_generators_from_tables(example_orm, config) + mock_download.assert_called_once() + mock_create_engine.assert_called_once() + + self.assertEqual(expected, actual) + + def test__download_table(self) -> None: + """Test the _download_table function.""" + with self.engine.connect() as conn: + conn.execute(insert(MakeTable).values({"id": 1})) + + make._download_table(MakeTable.__table__, self.engine) + + with Path("expected.csv").open(encoding="utf-8") as csvfile: + expected = csvfile.read() + + with Path("maketable.csv").open(encoding="utf-8") as csvfile: + actual = csvfile.read() + self.assertEqual(expected, actual) diff --git a/tests/test_providers.py b/tests/test_providers.py index 13459383..5a6ec07c 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -5,6 +5,7 @@ from sqlalchemy import Column, Integer, Text, create_engine, insert from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.pool import NullPool from sqlsynthgen import providers from tests.utils import run_psql @@ -47,10 +48,15 @@ def setUp(self) -> None: run_psql("providers.dump") self.engine = create_engine( - "postgresql://postgres:password@localhost:5432/providers" + "postgresql://postgres:password@localhost:5432/providers", + # pool=NullPool + poolclass=NullPool, ) metadata.create_all(self.engine) + def tearDown(self) -> None: + self.engine.dispose() + def test_column_value(self) -> None: """Test the key method.""" # pylint: disable=invalid-name @@ -63,6 +69,7 @@ def test_column_value(self) -> None: key = provider.column_value(conn, Person, "sex") self.assertEqual("M", key) + # pass class TimedeltaProvider(TestCase): From 2ebdca81969f72ec42bade32cefa8f0690d57ebe Mon Sep 17 00:00:00 2001 From: Iain <25081046+Iain-S@users.noreply.github.com> Date: Wed, 15 Feb 2023 19:59:41 +0000 Subject: [PATCH 06/13] Force deletion of test fixture dbs To avoid the "error database is being accessed by other users" errors. --- tests/examples/dst.dump | 2 +- tests/examples/providers.dump | 2 +- tests/examples/src.dump | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/examples/dst.dump b/tests/examples/dst.dump index 909e2afe..a5dadd6c 100644 --- a/tests/examples/dst.dump +++ b/tests/examples/dst.dump @@ -16,7 +16,7 @@ SET xmloption = content; SET client_min_messages = warning; SET row_security = off; -DROP DATABASE IF EXISTS dst; +DROP DATABASE IF EXISTS dst WITH (FORCE); -- -- Name: dst; Type: DATABASE; Schema: -; Owner: postgres -- diff --git a/tests/examples/providers.dump b/tests/examples/providers.dump index 525cc832..de8d5b36 100644 --- a/tests/examples/providers.dump +++ b/tests/examples/providers.dump @@ -16,7 +16,7 @@ SET xmloption = content; SET client_min_messages = warning; SET row_security = off; -DROP DATABASE IF EXISTS providers; +DROP DATABASE IF EXISTS providers WITH (FORCE); -- -- Name: providers; Type: DATABASE; Schema: -; Owner: postgres -- diff --git a/tests/examples/src.dump b/tests/examples/src.dump index dffa05dd..571504f7 100644 --- a/tests/examples/src.dump +++ b/tests/examples/src.dump @@ -16,7 +16,7 @@ SET xmloption = content; SET client_min_messages = warning; SET row_security = off; -DROP DATABASE IF EXISTS src; +DROP DATABASE IF EXISTS src WITH (FORCE); -- -- Name: src; Type: DATABASE; Schema: -; Owner: postgres -- From 8d3e57b605566b03482a50d5e97bd4465b02a11a Mon Sep 17 00:00:00 2001 From: Iain <25081046+Iain-S@users.noreply.github.com> Date: Wed, 15 Feb 2023 20:40:07 +0000 Subject: [PATCH 07/13] Get vocab table list from config file --- sqlsynthgen/make.py | 7 +++++-- tests/examples/generator_conf.yaml | 3 ++- tests/test_make.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/sqlsynthgen/make.py b/sqlsynthgen/make.py index 06215d5d..2da07f56 100644 --- a/sqlsynthgen/make.py +++ b/sqlsynthgen/make.py @@ -187,8 +187,11 @@ def make_generators_from_tables( engine = create_engine(settings.src_postgres_dsn) for table in tables_module.Base.metadata.sorted_tables: - # ToDo Get list of vocab tables from config file - if table.name in ("concept",): + if table.name in [ + x + for x in generator_config.get("tables", {}).keys() + if generator_config["tables"][x].get("vocabulary_table") + ]: orm_class = _orm_class_from_table_name(tables_module, table.fullname) if not orm_class: diff --git a/tests/examples/generator_conf.yaml b/tests/examples/generator_conf.yaml index 6566666d..26ec28a2 100644 --- a/tests/examples/generator_conf.yaml +++ b/tests/examples/generator_conf.yaml @@ -2,7 +2,6 @@ custom_generators_module: custom_generators tables: person: num_rows_per_pass: 2 - vocabulary_table: false custom_generators: - name: generic.person.full_name args: null @@ -27,3 +26,5 @@ tables: - visit_start - visit_end - visit_duration_seconds + concept: + vocabulary_table: true diff --git a/tests/test_make.py b/tests/test_make.py index f849fd57..173ecbd3 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -63,9 +63,9 @@ def test_make_generators_from_tables(self) -> None: with patch("sqlsynthgen.make._download_table",) as mock_download, patch( "sqlsynthgen.make.create_engine" ) as mock_create_engine, patch("sqlsynthgen.make.get_settings"): - # pass actual = make.make_generators_from_tables(example_orm, config) mock_download.assert_called_once() + # self.assertEqual(, mock_download.call_args_list[0]) mock_create_engine.assert_called_once() self.assertEqual(expected, actual) From 4994b7ec4138a6bcba4e4b5325019ff03bdec8dd Mon Sep 17 00:00:00 2001 From: Iain <25081046+Iain-S@users.noreply.github.com> Date: Thu, 16 Feb 2023 08:53:13 +0000 Subject: [PATCH 08/13] Remove unnecessary vocab parameter --- sqlsynthgen/main.py | 5 +---- tests/test_functional.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/sqlsynthgen/main.py b/sqlsynthgen/main.py index fa21cd0e..0aa330f9 100644 --- a/sqlsynthgen/main.py +++ b/sqlsynthgen/main.py @@ -46,11 +46,8 @@ def create_data( @app.command() -def create_vocab( - orm_file: str = typer.Argument(...), ssg_file: str = typer.Argument(...) -) -> None: +def create_vocab(ssg_file: str = typer.Argument(...)) -> None: """Create tables using the SQLAlchemy file.""" - orm_module = import_file(orm_file) ssg_module = import_file(ssg_file) create_db_vocab(ssg_module.sorted_vocab) diff --git a/tests/test_functional.py b/tests/test_functional.py index 4898e9b6..87f4e10e 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -53,7 +53,7 @@ def test_workflow(self) -> None: run(["sqlsynthgen", "create-tables", self.orm_file_path], env=env, check=True) run( - ["sqlsynthgen", "create-vocab", self.orm_file_path, self.ssg_file_path], + ["sqlsynthgen", "create-vocab", self.ssg_file_path], env=env, check=True, ) From 11b0a9211632155b3a42b636eeefe131cdf0dfcd Mon Sep 17 00:00:00 2001 From: Iain <25081046+Iain-S@users.noreply.github.com> Date: Thu, 16 Feb 2023 15:06:46 +0000 Subject: [PATCH 09/13] Remove unnecessary engine pool options --- tests/test_make.py | 5 ----- tests/test_providers.py | 7 ------- 2 files changed, 12 deletions(-) diff --git a/tests/test_make.py b/tests/test_make.py index 173ecbd3..36d55ff9 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -7,7 +7,6 @@ import yaml from sqlalchemy import Column, Integer, create_engine, insert from sqlalchemy.orm import declarative_base -from sqlalchemy.pool import NullPool from sqlsynthgen import make from tests.examples import example_orm @@ -40,8 +39,6 @@ def setUp(self) -> None: self.engine = create_engine( "postgresql://postgres:password@localhost:5432/providers", - # pool=NullPool - poolclass=NullPool, ) metadata.create_all(self.engine) os.chdir("tests/examples") @@ -49,8 +46,6 @@ def setUp(self) -> None: def tearDown(self) -> None: os.chdir("../..") - self.engine.dispose() - def test_make_generators_from_tables(self) -> None: """Check that we can make a generators file from a tables module.""" self.maxDiff = None # pylint: disable=invalid-name diff --git a/tests/test_providers.py b/tests/test_providers.py index 5a6ec07c..8e86f48b 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -5,7 +5,6 @@ from sqlalchemy import Column, Integer, Text, create_engine, insert from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.pool import NullPool from sqlsynthgen import providers from tests.utils import run_psql @@ -49,14 +48,9 @@ def setUp(self) -> None: self.engine = create_engine( "postgresql://postgres:password@localhost:5432/providers", - # pool=NullPool - poolclass=NullPool, ) metadata.create_all(self.engine) - def tearDown(self) -> None: - self.engine.dispose() - def test_column_value(self) -> None: """Test the key method.""" # pylint: disable=invalid-name @@ -69,7 +63,6 @@ def test_column_value(self) -> None: key = provider.column_value(conn, Person, "sex") self.assertEqual("M", key) - # pass class TimedeltaProvider(TestCase): From a423a0d0e7b6d866d19901387a0b59d4c3d47c24 Mon Sep 17 00:00:00 2001 From: Iain <25081046+Iain-S@users.noreply.github.com> Date: Thu, 16 Feb 2023 15:54:04 +0000 Subject: [PATCH 10/13] Address PR comments --- sqlsynthgen/make.py | 9 +-------- tests/test_base.py | 2 +- tests/test_make.py | 3 +-- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/sqlsynthgen/make.py b/sqlsynthgen/make.py index 2da07f56..941cbf84 100644 --- a/sqlsynthgen/make.py +++ b/sqlsynthgen/make.py @@ -145,22 +145,15 @@ def _download_table(table: Any, engine: Any) -> None: """Download a table and store it as a .csv file""" stmt = select([table]) with engine.connect() as conn: - # result = conn.execute("select * from concept") result = list(conn.execute(stmt)) with Path(table.fullname + ".csv").open( "w", newline="", encoding="utf-8" ) as csvfile: - writer = csv.writer( - csvfile, delimiter="," - ) # , quotechar='|', quoting=csv.quote_minimal) + writer = csv.writer(csvfile, delimiter=",") writer.writerow([x.name for x in table.columns]) - # writer = csv.dictwriter(csvfile, fieldnames=[x.name for x in table.columns]) for row in result: writer.writerow(row) - # writer.writeheader() - # writer.writerow({'first_name': 'baked', 'last_name': 'beans'}) - def make_generators_from_tables( tables_module: ModuleType, generator_config: dict diff --git a/tests/test_base.py b/tests/test_base.py index f35079ec..5952d3f9 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,4 +1,4 @@ -"""Tests for the main module.""" +"""Tests for the base module.""" import os from unittest import TestCase diff --git a/tests/test_make.py b/tests/test_make.py index 36d55ff9..bad07da8 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -12,7 +12,6 @@ from tests.examples import example_orm from tests.utils import run_psql -# pylint: disable=protected-access # pylint: disable=invalid-name Base = declarative_base() # pylint: enable=invalid-name @@ -60,13 +59,13 @@ def test_make_generators_from_tables(self) -> None: ) as mock_create_engine, patch("sqlsynthgen.make.get_settings"): actual = make.make_generators_from_tables(example_orm, config) mock_download.assert_called_once() - # self.assertEqual(, mock_download.call_args_list[0]) mock_create_engine.assert_called_once() self.assertEqual(expected, actual) def test__download_table(self) -> None: """Test the _download_table function.""" + # pylint: disable=protected-access with self.engine.connect() as conn: conn.execute(insert(MakeTable).values({"id": 1})) From de22df0cf63edbfe1f9cd8a1f4df786e462ee150 Mon Sep 17 00:00:00 2001 From: Iain <25081046+Iain-S@users.noreply.github.com> Date: Thu, 16 Feb 2023 16:03:39 +0000 Subject: [PATCH 11/13] Add skipUnless decorator to classes that need it --- .github/workflows/tests.yml | 2 +- tests/test_base.py | 5 ++--- tests/test_functional.py | 8 ++------ tests/test_make.py | 5 ++--- tests/utils.py | 12 ++++++++++++ 5 files changed, 19 insertions(+), 13 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 80f8734a..195feaeb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -54,4 +54,4 @@ jobs: - name: Run Unit Tests shell: bash run: | - FUNCTIONAL_TESTS=1 poetry run python -m unittest discover --verbose tests + REQUIRES_DB=1 poetry run python -m unittest discover --verbose tests diff --git a/tests/test_base.py b/tests/test_base.py index 5952d3f9..5b712ff2 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,12 +1,11 @@ """Tests for the base module.""" import os -from unittest import TestCase from sqlalchemy import Column, Integer, create_engine, select from sqlalchemy.orm import declarative_base from sqlsynthgen.base import FileUploader -from tests.utils import run_psql +from tests.utils import RequiresDBTestCase, run_psql # pylint: disable=invalid-name Base = declarative_base() @@ -24,7 +23,7 @@ class BaseTable(Base): # type: ignore ) -class VocabTests(TestCase): +class VocabTests(RequiresDBTestCase): """Module test case.""" def setUp(self) -> None: diff --git a/tests/test_functional.py b/tests/test_functional.py index 87f4e10e..9bbfd6d4 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2,15 +2,11 @@ import os from pathlib import Path from subprocess import run -from unittest import TestCase, skipUnless -from tests.utils import run_psql +from tests.utils import RequiresDBTestCase, run_psql -@skipUnless( - os.environ.get("FUNCTIONAL_TESTS") == "1", "Set 'FUNCTIONAL_TESTS=1' to enable." -) -class FunctionalTests(TestCase): +class FunctionalTestCase(RequiresDBTestCase): """End-to-end tests.""" orm_file_path = Path("tests/tmp/orm.py") diff --git a/tests/test_make.py b/tests/test_make.py index bad07da8..95407fbc 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -1,7 +1,6 @@ """Tests for the main module.""" import os from pathlib import Path -from unittest import TestCase from unittest.mock import patch import yaml @@ -10,7 +9,7 @@ from sqlsynthgen import make from tests.examples import example_orm -from tests.utils import run_psql +from tests.utils import RequiresDBTestCase, run_psql # pylint: disable=invalid-name Base = declarative_base() @@ -28,7 +27,7 @@ class MakeTable(Base): # type: ignore ) -class MyTestCase(TestCase): +class MyTestCase(RequiresDBTestCase): """Module test case.""" def setUp(self) -> None: diff --git a/tests/utils.py b/tests/utils.py index 49624879..e7e9cefb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,6 +3,7 @@ from functools import lru_cache from pathlib import Path from subprocess import run +from unittest import TestCase, skipUnless from sqlsynthgen import settings @@ -48,3 +49,14 @@ def run_psql(dump_file_name: str) -> None: ) # psql doesn't always return != 0 if it fails assert completed_process.stderr == b"", completed_process.stderr + + +@skipUnless(os.environ.get("REQUIRES_DB") == "1", "Set 'REQUIRES_DB=1' to enable.") +class RequiresDBTestCase(TestCase): + """A test case that only runs if REQUIRES_DB has been set to true.""" + + def setUp(self) -> None: + pass + + def tearDown(self) -> None: + pass From 842b4d6cac211003588db9f66302b5467bc4a27c Mon Sep 17 00:00:00 2001 From: Iain <25081046+Iain-S@users.noreply.github.com> Date: Thu, 16 Feb 2023 16:04:33 +0000 Subject: [PATCH 12/13] Set jobs=1 in pylint config For some reason, it runs quicker that way. Tested on x2 dev machines. --- .pylintrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pylintrc b/.pylintrc index 89698404..09fc6f7c 100644 --- a/.pylintrc +++ b/.pylintrc @@ -37,7 +37,7 @@ ignore-patterns= # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the # number of processors available to use. -jobs=0 +jobs=1 # Control the amount of potential inferred values when inferring a single # object. This can help the performance when dealing with large functions or From ced32a3a0a0a7979ab8c1cd1fa99d7dff92fdac7 Mon Sep 17 00:00:00 2001 From: Iain <25081046+Iain-S@users.noreply.github.com> Date: Thu, 16 Feb 2023 16:20:25 +0000 Subject: [PATCH 13/13] Disable duplicate-code pylint error --- .pylintrc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.pylintrc b/.pylintrc index 09fc6f7c..38176a3c 100644 --- a/.pylintrc +++ b/.pylintrc @@ -86,7 +86,8 @@ disable=raw-checker-failed, suppressed-message, deprecated-pragma, use-symbolic-message-instead, - too-few-public-methods + too-few-public-methods, + duplicate-code # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option