Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,978 changes: 1,110 additions & 868 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@ license = "MIT"
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.9"
python = "^3.9,<3.11"
sqlacodegen = {git = "https://github.com/agronholm/sqlacodegen.git", rev = "3.0.0rc1"}
pydantic = {extras = ["dotenv"], version = "^1.10.2"}
psycopg2-binary = "^2.9.5"
sqlalchemy-utils = "^0.38.3"
mimesis = "^6.1.1"
typer = "^0.7.0"
pyyaml = "^6.0"
pyyaml = "^5.0"
sqlalchemy = "^1.4"
sphinx-rtd-theme = {version = "^1.2.0", optional = true}
sphinxcontrib-napoleon = {version = "^0.7", optional = true}
smartnoise-sql = "^0.2.9.1"

[tool.poetry.group.dev.dependencies]
black = "^22.10.0"
Expand All @@ -26,6 +27,7 @@ pylint = "^2.15.8"
mypy = "^0.991"
types-pyyaml = "^6.0.12.4"
pydocstyle = "^6.3.0"
pytest = "^7.2.0"

[tool.poetry.extras]
docs = ["sphinx-rtd-theme", "sphinxcontrib-napoleon"]
Expand Down
36 changes: 33 additions & 3 deletions sqlsynthgen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,20 @@
from typing import Final, Optional

import typer
import yaml

from sqlsynthgen.create import create_db_data, create_db_tables, create_db_vocab
from sqlsynthgen.make import make_generators_from_tables, make_tables_file
from sqlsynthgen.make import (
make_generators_from_tables,
make_src_stats,
make_tables_file,
)
from sqlsynthgen.settings import get_settings
from sqlsynthgen.utils import import_file, read_yaml_file

ORM_FILENAME: Final[str] = "orm.py"
SSG_FILENAME: Final[str] = "ssg.py"
STATS_FILENAME: Final[str] = "src-stats.yaml"

app = typer.Typer()

Expand Down Expand Up @@ -92,7 +98,8 @@ def create_tables(orm_file: str = typer.Option(ORM_FILENAME)) -> None:
def make_generators(
orm_file: str = typer.Option(ORM_FILENAME),
ssg_file: str = typer.Option(SSG_FILENAME),
config_file: Optional[str] = typer.Argument(None),
config_file: Optional[str] = typer.Option(None),
stats_file: Optional[str] = typer.Option(None),
) -> None:
"""Make a SQLSynthGen file of generator classes.

Expand All @@ -115,11 +122,34 @@ def make_generators(

orm_module = import_file(orm_file)
generator_config = read_yaml_file(config_file) if config_file is not None else {}
result = make_generators_from_tables(orm_module, generator_config)
result = make_generators_from_tables(orm_module, generator_config, stats_file)

ssg_file_path.write_text(result, encoding="utf-8")


@app.command()
def make_stats(
config_file: str = typer.Option(...),
stats_file: str = typer.Option(STATS_FILENAME),
) -> None:
"""Compute summary statistics from the source database, write them to a YAML file.

Example:
$ sqlsynthgen make_stats --config-file=example_config.yaml
"""
stats_file_path = Path(stats_file)
if stats_file_path.exists():
print(f"{stats_file} should not already exist. Exiting...", file=stderr)
sys.exit(1)
settings = get_settings()
generator_config = read_yaml_file(config_file) if config_file is not None else {}
src_dsn = settings.src_postgres_dsn
if src_dsn is None:
raise ValueError("Missing source database connection details.")
src_stats = make_src_stats(src_dsn, generator_config)
stats_file_path.write_text(yaml.dump(src_stats), encoding="utf-8")


@app.command()
def make_tables(
orm_file: str = typer.Option(ORM_FILENAME),
Expand Down
46 changes: 45 additions & 1 deletion sqlsynthgen/make.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from types import ModuleType
from typing import Any, Final, Optional

import snsql
from mimesis.providers.base import BaseProvider
from sqlalchemy import create_engine
from sqlalchemy.sql import sqltypes
Expand Down Expand Up @@ -135,13 +136,15 @@ def _add_generator_for_table(


def make_generators_from_tables(
tables_module: ModuleType, generator_config: dict
tables_module: ModuleType, generator_config: dict, src_stats_filename: Optional[str]
) -> str:
"""Create sqlsynthgen generator classes from a sqlacodegen-generated file.

Args:
tables_module: A sqlacodegen-generated module.
generator_config: Configuration to control the generator creation.
src_stats_filename: A filename for where to read src stats from. Optional, if
`None` this feature will be skipped

Returns:
A string that is a valid Python module, once written to file.
Expand All @@ -151,6 +154,14 @@ def make_generators_from_tables(
generator_module_name = generator_config.get("custom_generators_module", None)
if generator_module_name is not None:
new_content += f"\nimport {generator_module_name}"
if src_stats_filename:
new_content += "\nimport yaml"
new_content += (
f'\nwith open("{src_stats_filename}", "r", encoding="utf-8") as f:'
)
new_content += (
f"\n{INDENTATION}SRC_STATS = yaml.load(f, Loader=yaml.FullLoader)"
)

sorted_generators = "[\n"
sorted_vocab = "[\n"
Expand Down Expand Up @@ -219,3 +230,36 @@ def make_tables_file(db_dsn: str, schema_name: Optional[str]) -> str:
)

return completed_process.stdout


def make_src_stats(dsn: str, config: dict) -> dict:
"""Run the src-stats queries specified by the configuration.

Query the src database with the queries in the src-stats block of the `config`
dictionary, using the differential privacy parameters set in the `smartnoise-sql`
block of `config`. Record the results in a dictionary and returns it.
Args:
dsn: postgres connection string
config: a dictionary with the necessary configuration
stats_filename: path to the YAML file to write the output to

Returns:
The dictionary of src-stats.
"""
engine = create_engine(dsn, echo=False, future=True)
dp_config = config.get("smartnoise-sql", {})
snsql_metadata = {"": dp_config}
src_stats = {}
for stat_data in config.get("src-stats", []):
privacy = snsql.Privacy(epsilon=stat_data["epsilon"], delta=stat_data["delta"])
with engine.connect() as conn:
reader = snsql.from_connection(
conn.connection,
engine="postgres",
privacy=privacy,
metadata=snsql_metadata,
)
private_result = reader.execute(stat_data["query"])
# The first entry in the list names the columns, skip that.
src_stats[stat_data["name"]] = private_result[1:]
return src_stats
13 changes: 13 additions & 0 deletions sqlsynthgen/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,16 @@ def timespan(
start = Datetime().datetime(start=earliest_start_year, end=last_start_year)
end = start + delta
return start, end, delta


class WeightedBooleanProvider(BaseProvider):
"""A Mimesis provider for booleans with a given probability for True."""

class Meta:
"""Meta-class for WeightedBooleanProvider settings."""

name = "weighted_boolean_provider"

def bool(self, probability: float) -> bool:
"""Return True with given `probability`, otherwise False."""
return self.random.uniform(0, 1) < probability
Original file line number Diff line number Diff line change
@@ -1,3 +1,22 @@
smartnoise-sql:
public:
person:
# You may well want censor_dims to be on, but we turn it off for the
# tests to silence a smartnoise-sql nag warning.
censor_dims: False
person_id:
name: person_id
type: int
private_id: True
research_opt_out:
name: research_opt_out
type: boolean
private_id: False
src-stats:
- name: count_opt_outs
query: SELECT count(*) AS num, research_opt_out FROM person GROUP BY research_opt_out
epsilon: 0.1
delta: 0.0001
custom_generators_module: custom_generators
tables:
person:
Expand All @@ -11,6 +30,11 @@ tables:
start: 2022
end: 2022
columns_assigned: stored_from
- name: custom_generators.boolean_from_src_stats_generator
args:
generic: generic
src_stats: SRC_STATS["count_opt_outs"]
columns_assigned: research_opt_out

hospital_visit:
num_rows_per_pass: 3
Expand Down
7 changes: 6 additions & 1 deletion tests/examples/expected_ssg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,14 @@
generic.add_provider(TimedeltaProvider)
from sqlsynthgen.providers import TimespanProvider
generic.add_provider(TimespanProvider)
from sqlsynthgen.providers import WeightedBooleanProvider
generic.add_provider(WeightedBooleanProvider)

import tests.examples.example_orm
import custom_generators
import yaml
with open("example_stats.yaml", "r", encoding="utf-8") as f:
SRC_STATS = yaml.load(f, Loader=yaml.FullLoader)

concept_vocab = FileUploader(tests.examples.example_orm.Concept.__table__)

Expand All @@ -32,9 +37,9 @@ class personGenerator:
def __init__(self, src_db_conn, dst_db_conn):
self.name = generic.person.full_name()
self.stored_from = generic.datetime.datetime(start=2022, end=2022)
self.research_opt_out = custom_generators.boolean_from_src_stats_generator(generic=generic, src_stats=SRC_STATS["count_opt_outs"])
pass
self.nhs_number = generic.text.color()
self.research_opt_out = generic.development.boolean()
self.source_system = generic.text.color()


Expand Down
30 changes: 0 additions & 30 deletions tests/examples/generator_conf.yaml

This file was deleted.

Loading