Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for the base module."""
import os
from pathlib import Path

from sqlalchemy import Column, Integer, create_engine, select
from sqlalchemy.orm import declarative_base
Expand All @@ -26,19 +27,22 @@ class BaseTable(Base): # type: ignore
class VocabTests(RequiresDBTestCase):
"""Module test case."""

test_dir = Path("tests/examples")
start_dir = os.getcwd()

def setUp(self) -> None:
"""Pre-test setup."""

run_psql("providers.dump")
run_psql(Path("tests/examples/providers.dump"))

self.engine = create_engine(
"postgresql://postgres:password@localhost:5432/providers"
)
metadata.create_all(self.engine)
os.chdir("tests/examples")
os.chdir(self.test_dir)

def tearDown(self) -> None:
os.chdir("../..")
os.chdir(self.start_dir)

def test_load(self) -> None:
"""Test the load method."""
Expand Down
9 changes: 6 additions & 3 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ class FunctionalTestCase(RequiresDBTestCase):

concept_file_path = Path("concept.csv")

test_dir = Path("tests/workspace")
start_dir = os.getcwd()

env = os.environ.copy()
env = {
**env,
Expand All @@ -37,9 +40,9 @@ def setUp(self) -> None:
"""Pre-test setup."""

# Create a blank destination database
run_psql("dst.dump")
run_psql(Path("tests/examples/dst.dump"))

os.chdir("tests/workspace")
os.chdir(self.test_dir)

for file_path in (
self.orm_file_path,
Expand All @@ -51,7 +54,7 @@ def setUp(self) -> None:
file_path.unlink(missing_ok=True)

def tearDown(self) -> None:
os.chdir("../../")
os.chdir(self.start_dir)

def test_workflow_minimal_args(self) -> None:
"""Test the recommended CLI workflow runs without errors."""
Expand Down
9 changes: 6 additions & 3 deletions tests/test_make.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for the main module."""
import os
from io import StringIO
from pathlib import Path
from subprocess import CalledProcessError
from unittest import TestCase
from unittest.mock import MagicMock, call, patch
Expand All @@ -16,14 +17,16 @@
class TestMake(TestCase):
"""Tests that don't require a database."""

test_dir = Path("tests/examples")
start_dir = os.getcwd()

def setUp(self) -> None:
"""Pre-test setup."""

os.chdir("tests/examples")
os.chdir(self.test_dir)

def tearDown(self) -> None:
"""Post-test cleanup."""
os.chdir("../..")
os.chdir(self.start_dir)

@patch("sqlsynthgen.make.get_settings")
@patch("sqlsynthgen.make.create_engine")
Expand Down
3 changes: 2 additions & 1 deletion tests/test_providers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for the providers module."""
import datetime as dt
from pathlib import Path
from unittest import TestCase

from sqlalchemy import Column, Integer, Text, create_engine, insert
Expand Down Expand Up @@ -40,7 +41,7 @@ class ColumnValueProviderTestCase(RequiresDBTestCase):
def setUp(self) -> None:
"""Pre-test setup."""

run_psql("providers.dump")
run_psql(Path("tests/examples/providers.dump"))

self.engine = create_engine(
"postgresql://postgres:password@localhost:5432/providers",
Expand Down
18 changes: 12 additions & 6 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@ class MyTable(Base): # type: ignore
class TestImport(TestCase):
"""Tests for the import_file function."""

test_dir = Path("tests/examples")
start_dir = os.getcwd()

def setUp(self) -> None:
"""Pre-test setup."""

os.chdir("tests/examples")
os.chdir(self.test_dir)

def tearDown(self) -> None:
os.chdir("../../")
"""Post-test cleanup."""
os.chdir(self.start_dir)

def test_import_file(self) -> None:
"""Test that we can import an example module."""
Expand All @@ -53,23 +56,26 @@ class TestDownload(RequiresDBTestCase):

mytable_file_path = Path("mytable.csv")

test_dir = Path("tests/workspace")
start_dir = os.getcwd()

def setUp(self) -> None:
"""Pre-test setup."""

run_psql("providers.dump")
run_psql(Path("tests/examples/providers.dump"))

self.engine = create_engine(
"postgresql://postgres:password@localhost:5432/providers",
connect_args={"connect_timeout": 10},
)
metadata.create_all(self.engine)

os.chdir("tests/workspace")
os.chdir(self.test_dir)
self.mytable_file_path.unlink(missing_ok=True)

def tearDown(self) -> None:
"""Post-test cleanup."""
os.chdir("../..")
os.chdir(self.start_dir)

def test_download_table(self) -> None:
"""Test the download_table function."""
Expand Down
11 changes: 3 additions & 8 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_test_settings() -> settings.Settings:
)


def run_psql(dump_file_name: str) -> None:
def run_psql(dump_file: Path) -> None:
"""Run psql and pass dump_file_name as the --file option."""

# If you need to update a .dump file, use
Expand All @@ -41,12 +41,7 @@ def run_psql(dump_file_name: str) -> None:

# Clear and re-create the test database
completed_process = run(
[
"psql",
"--host=localhost",
"--username=postgres",
"--file=" + str(Path(f"tests/examples/{dump_file_name}")),
],
["psql", "--host=localhost", "--username=postgres", f"--file={dump_file}"],
capture_output=True,
env=env,
check=True,
Expand All @@ -57,7 +52,7 @@ def run_psql(dump_file_name: str) -> None:

@skipUnless(os.environ.get("REQUIRES_DB") == "1", "Set 'REQUIRES_DB=1' to enable.")
class RequiresDBTestCase(TestCase):
"""A test case that only runs if REQUIRES_DB has been set to true."""
"""A test case that only runs if REQUIRES_DB has been set to 1."""

def setUp(self) -> None:
pass
Expand Down