diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 80f8734a..195feaeb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -54,4 +54,4 @@ jobs: - name: Run Unit Tests shell: bash run: | - FUNCTIONAL_TESTS=1 poetry run python -m unittest discover --verbose tests + REQUIRES_DB=1 poetry run python -m unittest discover --verbose tests diff --git a/.pylintrc b/.pylintrc index 89698404..38176a3c 100644 --- a/.pylintrc +++ b/.pylintrc @@ -37,7 +37,7 @@ ignore-patterns= # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the # number of processors available to use. -jobs=0 +jobs=1 # Control the amount of potential inferred values when inferring a single # object. This can help the performance when dealing with large functions or @@ -86,7 +86,8 @@ disable=raw-checker-failed, suppressed-message, deprecated-pragma, use-symbolic-message-instead, - too-few-public-methods + too-few-public-methods, + duplicate-code # Enable the message, report, category or checker with the given id(s). 
You can # either give multiple identifier separated by comma (,) or put this option diff --git a/sqlsynthgen/base.py b/sqlsynthgen/base.py new file mode 100644 index 00000000..52a51e28 --- /dev/null +++ b/sqlsynthgen/base.py @@ -0,0 +1,23 @@ +"""Base generator classes.""" +import csv +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from sqlalchemy import insert + + +@dataclass +class FileUploader: + """For uploading data files.""" + + table: Any + + def load(self, connection: Any) -> None: + """Load the data from file.""" + with Path(self.table.fullname + ".csv").open( + "r", newline="", encoding="utf-8" + ) as csvfile: + reader = csv.DictReader(csvfile) + stmt = insert(self.table).values(list(reader)) + connection.execute(stmt) diff --git a/sqlsynthgen/create.py b/sqlsynthgen/create.py index 02bd2243..93bbb680 100644 --- a/sqlsynthgen/create.py +++ b/sqlsynthgen/create.py @@ -1,5 +1,5 @@ """Functions and classes to create and populate the target database.""" -from typing import Any +from typing import Any, List from sqlalchemy import create_engine, insert from sqlalchemy.schema import CreateSchema @@ -23,6 +23,16 @@ def create_db_tables(metadata: Any) -> Any: metadata.create_all(engine) +def create_db_vocab(sorted_vocab: List[Any]) -> None: + """Load vocabulary tables from files.""" + settings = get_settings() + dst_engine = create_engine(settings.dst_postgres_dsn) + + with dst_engine.connect() as dst_conn: + for vocab_table in sorted_vocab: + vocab_table.load(dst_conn) + + def create_db_data(sorted_tables: list, sorted_generators: list, num_rows: int) -> None: """Connect to a database and populate it with data.""" settings = get_settings() @@ -39,8 +49,8 @@ def populate( ) -> None: """Populate a database schema with dummy data.""" - for table, generator in zip( - tables, generators + for table, generator in reversed( + list(zip(reversed(tables), reversed(generators))) ): # Run all the inserts for one table in a transaction 
with dst_conn.begin(): for _ in range(num_rows): diff --git a/sqlsynthgen/main.py b/sqlsynthgen/main.py index 829986d7..0aa330f9 100644 --- a/sqlsynthgen/main.py +++ b/sqlsynthgen/main.py @@ -10,7 +10,7 @@ import typer import yaml -from sqlsynthgen.create import create_db_data, create_db_tables +from sqlsynthgen.create import create_db_data, create_db_tables, create_db_vocab from sqlsynthgen.make import make_generators_from_tables from sqlsynthgen.settings import get_settings @@ -45,6 +45,13 @@ def create_data( ) +@app.command() +def create_vocab(ssg_file: str = typer.Argument(...)) -> None: + """Load the vocabulary tables listed in the ssg file.""" + ssg_module = import_file(ssg_file) + create_db_vocab(ssg_module.sorted_vocab) + + @app.command() def create_tables(orm_file: str = typer.Argument(...)) -> None: """Create tables using the SQLAlchemy file.""" diff --git a/sqlsynthgen/make.py b/sqlsynthgen/make.py index 39501669..941cbf84 100644 --- a/sqlsynthgen/make.py +++ b/sqlsynthgen/make.py @@ -1,18 +1,23 @@ """Functions to make a module of generator classes.""" +import csv import inspect +from pathlib import Path from types import ModuleType from typing import Any, Final, Optional from mimesis.providers.base import BaseProvider +from sqlalchemy import create_engine, select from sqlalchemy.sql import sqltypes from sqlsynthgen import providers +from sqlsynthgen.settings import get_settings HEADER_TEXT: str = "\n".join( ( '"""This file was auto-generated by sqlsynthgen but can be edited manually."""', "from mimesis import Generic", "from mimesis.locales import Locale", + "from sqlsynthgen.base import FileUploader", "", "generic = Generic(locale=Locale.EN)", "", @@ -136,6 +141,20 @@ def _add_generator_for_table( return content, new_class_name +def _download_table(table: Any, engine: Any) -> None: + """Download a table and store it as a .csv file""" + stmt = select([table]) + with engine.connect() as conn: + result = list(conn.execute(stmt)) + with Path(table.fullname + 
".csv").open( + "w", newline="", encoding="utf-8" + ) as csvfile: + writer = csv.writer(csvfile, delimiter=",") + writer.writerow([x.name for x in table.columns]) + for row in result: + writer.writerow(row) + + def make_generators_from_tables( tables_module: ModuleType, generator_config: dict ) -> str: @@ -143,6 +162,7 @@ def make_generators_from_tables( Args: tables_module: A sqlacodegen-generated module. + generator_config: Configuration to control the generator creation. Returns: A string that is a valid Python module, once written to file. @@ -154,12 +174,42 @@ def make_generators_from_tables( new_content += f"\nfrom . import {generator_module_name}" sorted_generators = "[\n" + sorted_vocab = "[\n" + + settings = get_settings() + engine = create_engine(settings.src_postgres_dsn) + for table in tables_module.Base.metadata.sorted_tables: + if table.name in [ + x + for x in generator_config.get("tables", {}).keys() + if generator_config["tables"][x].get("vocabulary_table") + ]: + + orm_class = _orm_class_from_table_name(tables_module, table.fullname) + if not orm_class: + raise RuntimeError(f"Couldn't find {table.fullname} in {tables_module}") + class_name = orm_class.__name__ + new_content += ( + f"\n\n{class_name.lower()}_vocab " + f"= FileUploader({tables_module.__name__}.{class_name}.__table__)" + ) + sorted_vocab += f"{INDENTATION}{class_name.lower()}_vocab,\n" + + _download_table(table, engine) + + continue + table_config = generator_config.get("tables", {}).get(table.name, {}) new_content, new_generator_name = _add_generator_for_table( new_content, tables_module, table_config, table ) sorted_generators += f"{INDENTATION}{new_generator_name},\n" + sorted_generators += "]" + sorted_vocab += "]" + new_content += "\n\n" + "sorted_generators = " + sorted_generators + "\n" + new_content += "\n\n" + "sorted_vocab = " + sorted_vocab + "\n" + return new_content diff --git a/tests/examples/basetable.csv b/tests/examples/basetable.csv new file mode 100644 index 
00000000..7a8faf79 --- /dev/null +++ b/tests/examples/basetable.csv @@ -0,0 +1,4 @@ +id +1 +2 +3 diff --git a/tests/examples/dst.dump b/tests/examples/dst.dump index 909e2afe..a5dadd6c 100644 --- a/tests/examples/dst.dump +++ b/tests/examples/dst.dump @@ -16,7 +16,7 @@ SET xmloption = content; SET client_min_messages = warning; SET row_security = off; -DROP DATABASE IF EXISTS dst; +DROP DATABASE IF EXISTS dst WITH (FORCE); -- -- Name: dst; Type: DATABASE; Schema: -; Owner: postgres -- diff --git a/tests/examples/example_orm.py b/tests/examples/example_orm.py index e37ea654..b242a366 100644 --- a/tests/examples/example_orm.py +++ b/tests/examples/example_orm.py @@ -47,3 +47,14 @@ class Entity(Base): Integer, primary_key=True, ) + + +class Concept(Base): + __tablename__ = "concept" + __table_args__ = {"schema": "myschema"} + + concept_id = Column( + Integer, + primary_key=True, + ) + concept_name = Column(Text) diff --git a/tests/examples/expected.csv b/tests/examples/expected.csv new file mode 100644 index 00000000..3ff3deb8 --- /dev/null +++ b/tests/examples/expected.csv @@ -0,0 +1,2 @@ +id +1 diff --git a/tests/examples/expected_ssg.py b/tests/examples/expected_ssg.py index 5a42583c..dee1447c 100644 --- a/tests/examples/expected_ssg.py +++ b/tests/examples/expected_ssg.py @@ -1,6 +1,7 @@ """This file was auto-generated by sqlsynthgen but can be edited manually.""" from mimesis import Generic from mimesis.locales import Locale +from sqlsynthgen.base import FileUploader generic = Generic(locale=Locale.EN) @@ -16,6 +17,8 @@ import tests.examples.example_orm from . 
import custom_generators +concept_vocab = FileUploader(tests.examples.example_orm.Concept.__table__) + class entityGenerator: def __init__(self, src_db_conn, dst_db_conn): pass @@ -44,3 +47,8 @@ def __init__(self, src_db_conn, dst_db_conn): personGenerator, hospital_visitGenerator, ] + + +sorted_vocab = [ + concept_vocab, +] diff --git a/tests/examples/generator_conf.yaml b/tests/examples/generator_conf.yaml index 6566666d..26ec28a2 100644 --- a/tests/examples/generator_conf.yaml +++ b/tests/examples/generator_conf.yaml @@ -2,7 +2,6 @@ custom_generators_module: custom_generators tables: person: num_rows_per_pass: 2 - vocabulary_table: false custom_generators: - name: generic.person.full_name args: null @@ -27,3 +26,5 @@ tables: - visit_start - visit_end - visit_duration_seconds + concept: + vocabulary_table: true diff --git a/tests/examples/mytable.csv b/tests/examples/mytable.csv new file mode 100644 index 00000000..9c8175ee --- /dev/null +++ b/tests/examples/mytable.csv @@ -0,0 +1,2 @@ +id +1 diff --git a/tests/examples/providers.dump b/tests/examples/providers.dump index 525cc832..de8d5b36 100644 --- a/tests/examples/providers.dump +++ b/tests/examples/providers.dump @@ -16,7 +16,7 @@ SET xmloption = content; SET client_min_messages = warning; SET row_security = off; -DROP DATABASE IF EXISTS providers; +DROP DATABASE IF EXISTS providers WITH (FORCE); -- -- Name: providers; Type: DATABASE; Schema: -; Owner: postgres -- diff --git a/tests/examples/src.dump b/tests/examples/src.dump index f9a5258b..571504f7 100644 --- a/tests/examples/src.dump +++ b/tests/examples/src.dump @@ -16,7 +16,7 @@ SET xmloption = content; SET client_min_messages = warning; SET row_security = off; -DROP DATABASE IF EXISTS src; +DROP DATABASE IF EXISTS src WITH (FORCE); -- -- Name: src; Type: DATABASE; Schema: -; Owner: postgres -- @@ -44,26 +44,23 @@ SET default_tablespace = ''; SET default_table_access_method = heap; -- --- Name: hospital_visit; Type: TABLE; Schema: public; Owner: 
postgres +-- Name: concept; Type: TABLE; Schema: public; Owner: postgres -- -CREATE TABLE public.hospital_visit ( - hospital_visit_id bigint NOT NULL, - person_id integer NOT NULL, - visit_start date NOT NULL, - visit_duration_seconds real NOT NULL, - visit_image bytea NOT NULL +CREATE TABLE public.concept ( + concept_id integer NOT NULL PRIMARY KEY, + concept_name text NOT NULL ); -ALTER TABLE public.hospital_visit OWNER TO postgres; +ALTER TABLE public.concept OWNER TO postgres; -- -- Name: person; Type: TABLE; Schema: public; Owner: postgres -- CREATE TABLE public.person ( - person_id integer NOT NULL, + person_id integer NOT NULL PRIMARY KEY, name text NOT NULL, research_opt_out boolean NOT NULL, stored_from timestamp with time zone NOT NULL @@ -73,36 +70,34 @@ CREATE TABLE public.person ( ALTER TABLE public.person OWNER TO postgres; -- --- Data for Name: hospital_visit; Type: TABLE DATA; Schema: public; Owner: postgres +-- Name: hospital_visit; Type: TABLE; Schema: public; Owner: postgres -- -COPY public.hospital_visit (hospital_visit_id, person_id, visit_start, visit_duration_seconds, visit_image) FROM stdin; -\. +CREATE TABLE public.hospital_visit ( + hospital_visit_id bigint NOT NULL PRIMARY KEY, + person_id integer NOT NULL references public.person(person_id), + visit_start date NOT NULL, + visit_duration_seconds real NOT NULL, + visit_image bytea NOT NULL, + visit_type_concept_id integer NOT NULL references public.concept(concept_id) +); +ALTER TABLE public.hospital_visit OWNER TO postgres; -- --- Data for Name: person; Type: TABLE DATA; Schema: public; Owner: postgres +-- Data for Name: hospital_visit; Type: TABLE DATA; Schema: public; Owner: postgres -- -COPY public.person (person_id, name, research_opt_out, stored_from) FROM stdin; +COPY public.hospital_visit (hospital_visit_id, person_id, visit_start, visit_duration_seconds, visit_image) FROM stdin; \. 
-- --- Name: hospital_visit hospital_visit_pkey; Type: CONSTRAINT; Schema: public; Owner: postgres --- - -ALTER TABLE ONLY public.hospital_visit - ADD CONSTRAINT hospital_visit_pkey PRIMARY KEY (hospital_visit_id); - - --- --- Name: person person_pkey; Type: CONSTRAINT; Schema: public; Owner: postgres +-- Data for Name: person; Type: TABLE DATA; Schema: public; Owner: postgres -- -ALTER TABLE ONLY public.person - ADD CONSTRAINT person_pkey PRIMARY KEY (person_id); - +COPY public.person (person_id, name, research_opt_out, stored_from) FROM stdin; +\. -- -- PostgreSQL database dump complete diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 00000000..5b712ff2 --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,51 @@ +"""Tests for the base module.""" +import os + +from sqlalchemy import Column, Integer, create_engine, select +from sqlalchemy.orm import declarative_base + +from sqlsynthgen.base import FileUploader +from tests.utils import RequiresDBTestCase, run_psql + +# pylint: disable=invalid-name +Base = declarative_base() +# pylint: enable=invalid-name +metadata = Base.metadata + + +class BaseTable(Base): # type: ignore + """A SQLAlchemy table.""" + + __tablename__ = "basetable" + id = Column( + Integer, + primary_key=True, + ) + + +class VocabTests(RequiresDBTestCase): + """Module test case.""" + + def setUp(self) -> None: + """Pre-test setup.""" + + run_psql("providers.dump") + + self.engine = create_engine( + "postgresql://postgres:password@localhost:5432/providers" + ) + metadata.create_all(self.engine) + os.chdir("tests/examples") + + def tearDown(self) -> None: + os.chdir("../..") + + def test_load(self) -> None: + """Test the load method.""" + vocab_gen = FileUploader(BaseTable.__table__) + + with self.engine.connect() as conn: + vocab_gen.load(conn) + statement = select([BaseTable]) + rows = list(conn.execute(statement)) + self.assertEqual(3, len(rows)) diff --git a/tests/test_create.py b/tests/test_create.py index 
3041b8ba..1e4dfe28 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -1,8 +1,13 @@ """Tests for the main module.""" from unittest import TestCase -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch -from sqlsynthgen.create import create_db_data, create_db_tables, populate +from sqlsynthgen.create import ( + create_db_data, + create_db_tables, + create_db_vocab, + populate, +) from tests.utils import get_test_settings @@ -54,3 +59,32 @@ def test_populate(self) -> None: mock_dst_conn.execute.assert_called_once_with( mock_insert.return_value.values.return_value ) + + def test_populate_diff_length(self) -> None: + """Test when generators and tables differ in length.""" + mock_dst_conn = MagicMock() + mock_gen_two = MagicMock() + mock_gen_three = MagicMock() + tables = [1, 2, 3] + generators = [mock_gen_two, mock_gen_three] + + with patch("sqlsynthgen.create.insert") as mock_insert: + populate(2, mock_dst_conn, tables, generators, 1) + self.assertListEqual([call(2), call(3)], mock_insert.call_args_list) + + mock_gen_two.assert_called_once() + mock_gen_three.assert_called_once() + + def test_create_db_vocab(self) -> None: + """Test the create_db_vocab function.""" + with patch("sqlsynthgen.create.create_engine") as mock_create_engine, patch( + "sqlsynthgen.create.get_settings" + ) as mock_get_settings: + vocab_list = [MagicMock()] + create_db_vocab(vocab_list) + vocab_list[0].load.assert_called_once_with( + mock_create_engine.return_value.connect.return_value.__enter__.return_value + ) + mock_create_engine.assert_called_once_with( + mock_get_settings.return_value.dst_postgres_dsn + ) diff --git a/tests/test_functional.py b/tests/test_functional.py index 2f4a83c7..9bbfd6d4 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2,15 +2,11 @@ import os from pathlib import Path from subprocess import run -from unittest import TestCase, skipUnless -from tests.utils import run_psql +from tests.utils 
import RequiresDBTestCase, run_psql -@skipUnless( - os.environ.get("FUNCTIONAL_TESTS") == "1", "Set 'FUNCTIONAL_TESTS=1' to enable." -) -class FunctionalTests(TestCase): +class FunctionalTestCase(RequiresDBTestCase): """End-to-end tests.""" orm_file_path = Path("tests/tmp/orm.py") @@ -52,6 +48,11 @@ def test_workflow(self) -> None: ) run(["sqlsynthgen", "create-tables", self.orm_file_path], env=env, check=True) + run( + ["sqlsynthgen", "create-vocab", self.ssg_file_path], + env=env, + check=True, + ) run( ["sqlsynthgen", "create-data", self.orm_file_path, self.ssg_file_path, "1"], env=env, diff --git a/tests/test_make.py b/tests/test_make.py index 94e59a07..95407fbc 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -1,25 +1,79 @@ """Tests for the main module.""" -from unittest import TestCase +import os +from pathlib import Path +from unittest.mock import patch import yaml +from sqlalchemy import Column, Integer, create_engine, insert +from sqlalchemy.orm import declarative_base from sqlsynthgen import make from tests.examples import example_orm +from tests.utils import RequiresDBTestCase, run_psql +# pylint: disable=invalid-name +Base = declarative_base() +# pylint: enable=invalid-name +metadata = Base.metadata -class MyTestCase(TestCase): + +class MakeTable(Base): # type: ignore + """A SQLAlchemy table.""" + + __tablename__ = "maketable" + id = Column( + Integer, + primary_key=True, + ) + + +class MyTestCase(RequiresDBTestCase): """Module test case.""" + def setUp(self) -> None: + """Pre-test setup.""" + + run_psql("providers.dump") + + self.engine = create_engine( + "postgresql://postgres:password@localhost:5432/providers", + ) + metadata.create_all(self.engine) + os.chdir("tests/examples") + + def tearDown(self) -> None: + os.chdir("../..") + def test_make_generators_from_tables(self) -> None: """Check that we can make a generators file from a tables module.""" self.maxDiff = None # pylint: disable=invalid-name - with open( - 
"tests/examples/expected_ssg.py", encoding="utf-8" - ) as expected_output: + with open("expected_ssg.py", encoding="utf-8") as expected_output: expected = expected_output.read() - conf_path = "tests/examples/generator_conf.yaml" + conf_path = "generator_conf.yaml" with open(conf_path, "r", encoding="utf8") as f: config = yaml.safe_load(f) - actual = make.make_generators_from_tables(example_orm, config) + with patch("sqlsynthgen.make._download_table",) as mock_download, patch( + "sqlsynthgen.make.create_engine" + ) as mock_create_engine, patch("sqlsynthgen.make.get_settings"): + actual = make.make_generators_from_tables(example_orm, config) + mock_download.assert_called_once() + mock_create_engine.assert_called_once() + + self.assertEqual(expected, actual) + + def test__download_table(self) -> None: + """Test the _download_table function.""" + # pylint: disable=protected-access + with self.engine.connect() as conn: + conn.execute(insert(MakeTable).values({"id": 1})) + + make._download_table(MakeTable.__table__, self.engine) + + with Path("expected.csv").open(encoding="utf-8") as csvfile: + expected = csvfile.read() + + with Path("maketable.csv").open(encoding="utf-8") as csvfile: + actual = csvfile.read() + self.assertEqual(expected, actual) diff --git a/tests/test_providers.py b/tests/test_providers.py index 13459383..8e86f48b 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -47,7 +47,7 @@ def setUp(self) -> None: run_psql("providers.dump") self.engine = create_engine( - "postgresql://postgres:password@localhost:5432/providers" + "postgresql://postgres:password@localhost:5432/providers", ) metadata.create_all(self.engine) diff --git a/tests/utils.py b/tests/utils.py index 49624879..e7e9cefb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,6 +3,7 @@ from functools import lru_cache from pathlib import Path from subprocess import run +from unittest import TestCase, skipUnless from sqlsynthgen import settings @@ -48,3 +49,14 @@ def 
run_psql(dump_file_name: str) -> None: ) # psql doesn't always return != 0 if it fails assert completed_process.stderr == b"", completed_process.stderr + + +@skipUnless(os.environ.get("REQUIRES_DB") == "1", "Set 'REQUIRES_DB=1' to enable.") +class RequiresDBTestCase(TestCase): + """A test case that only runs if REQUIRES_DB has been set to 1.""" + + def setUp(self) -> None: + pass + + def tearDown(self) -> None: + pass