Skip to content

Commit

Permalink
Many fixups based on dep upgrades.
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul Prescod committed May 10, 2023
2 parents 3f4274b + 1f98720 commit e08c6f6
Show file tree
Hide file tree
Showing 12 changed files with 55 additions and 19 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ jobs:
python -VV
python -m pip install --upgrade pip
pip install -r requirements_dev.txt
pip install git+https://github.com/SFDO-Tooling/CumulusCI.git@main
- name: Run Tests
run: python -m pytest
Expand Down
12 changes: 9 additions & 3 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#
# This file is autogenerated by pip-compile with Python 3.9
# by the following command:
# This file is autogenerated by pip-compile with python 3.8
# To update, run:
#
# pip-compile --allow-unsafe requirements/dev.in
#
Expand Down Expand Up @@ -75,6 +75,8 @@ importlib-metadata==6.6.0
# via
# markdown
# mkdocs
importlib-resources==5.12.0
# via jsonschema
iniconfig==2.0.0
# via pytest
jinja2==3.1.2
Expand Down Expand Up @@ -115,6 +117,8 @@ packaging==23.1
# tox
pathspec==0.11.1
# via black
pkgutil-resolve-name==1.3.10
# via jsonschema
platformdirs==3.5.0
# via
# black
Expand Down Expand Up @@ -218,7 +222,9 @@ wrapt==1.15.0
yarl==1.9.2
# via vcrpy
zipp==3.15.0
# via importlib-metadata
# via
# importlib-metadata
# importlib-resources

# The following packages are considered to be unsafe in a requirements file:
setuptools==67.7.2
Expand Down
5 changes: 4 additions & 1 deletion requirements/prod.in
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,7 @@ gvgen
pydantic
python-baseconv
requests
urllib3<2.0

# remove this line when VCR is fixed
# https://github.com/kevin1024/vcrpy/issues/688
urllib3<2.0.0
4 changes: 2 additions & 2 deletions requirements/prod.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#
# This file is autogenerated by pip-compile with Python 3.9
# by the following command:
# This file is autogenerated by pip-compile with python 3.8
# To update, run:
#
# pip-compile --allow-unsafe requirements/prod.in
#
Expand Down
16 changes: 14 additions & 2 deletions snowfakery/output_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,19 +319,26 @@ def from_url(cls, db_url: str, mappings: None = None):
if mappings: # pragma: no cover -- should not be triggered.
warn("Please do not pass mappings argument to from_url", DeprecationWarning)
try:
print("create_engine from_url", db_url)
engine = create_engine(db_url)
except ModuleNotFoundError as e:
raise DataGenError(f"Cannot find a driver for your database: {e}")
except Exception as e:
raise DataGenError(f"Cannot connect to database: {e}")
self = cls(engine)
setattr(self, "url", db_url)
return self

def write_single_row(self, tablename: str, row: Dict) -> None:
# cache the value for later insert
self.buffered_rows[tablename].append(row)

def flush(self):
with self.session.begin():
self._flush_rows()
self.session.flush()

def _flush_rows(self):
for tablename, (insert_statement, fallback_dict) in self.table_info.items():
# Make sure every row has the same records per SQLAlchemy's rules

Expand All @@ -350,16 +357,16 @@ def flush(self):
if values:
self.session.execute(insert_statement, values)
self.buffered_rows[tablename] = []
self.session.flush()

def commit(self):
if any(self.buffered_rows):
self.flush()
self.session.commit()

def close(self, **kwargs) -> Optional[Sequence[str]]:
print("Starting close, SqlDbOutputStream")
self.commit()
self.session.close()
self.engine.dispose()

def create_or_validate_tables(self, inferred_tables: Dict[str, TableInfo]) -> None:
try:
Expand Down Expand Up @@ -410,6 +417,7 @@ def __init__(self, stream_or_path=None, **kwargs):
def _init_db(self):
"Initialize a db through an owned output stream"
db_url = f"sqlite:///{self.tempdir.name}/tempdb.db"
print("create_engine _init_db", db_url)
engine = create_engine(db_url)
return SqlDbOutputStream(engine)

Expand Down Expand Up @@ -437,11 +445,15 @@ def _dump_db(self):
assert self.text_output.stream
self.text_output.stream.write("%s\n" % line)

con.close()

def close(self, *args, **kwargs):
print("starting close", self.tempdir)
self._dump_db()
self.sql_db.close(*args, **kwargs)
self.text_output.close(*args, **kwargs)
self.tempdir.cleanup()
print("Ended close", self.tempdir)


def create_tables_from_inferred_fields(
Expand Down
1 change: 1 addition & 0 deletions snowfakery/salesforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

def create_cci_record_type_tables(db_url: str):
"""Create record type tables that CCI expects"""
print("create_cci_record_type_tables", db_url)
engine = create_engine(db_url)
metadata = MetaData()
metadata.reflect(views=True, bind=engine)
Expand Down
1 change: 1 addition & 0 deletions snowfakery/standard_plugins/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
def _open_db(db_url):
"Internal function for opening the database up."
engine = create_engine(db_url)
print("datasets pluginn", db_url)
metadata = MetaData()
metadata.reflect(views=True, bind=engine)
return engine, metadata
Expand Down
6 changes: 2 additions & 4 deletions snowfakery/tools/mkdocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
from unittest.mock import patch

from mkdocs.plugins import BasePlugin
import mkdocs
from mkdocs.config import config_options


class Plugin(BasePlugin):
config_scheme = (
("build_locales", mkdocs.config.config_options.Type(bool, default=False)),
)
config_scheme = (("build_locales", config_options.Type(bool, default=False)),)

def on_config(self, config):
"""Look for and load main_mkdocs_plugin in tools/faker_docs_utils/mkdocs_plugins.py
Expand Down
4 changes: 2 additions & 2 deletions snowfakery/tools/snowbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import locale

import click
from sqlalchemy import create_engine, inspect
from sqlalchemy import create_engine, inspect, text

from snowfakery import generate_data

Expand Down Expand Up @@ -170,7 +170,7 @@ def count_database(filename, counts):

def count_table(engine, tablename):
with engine.connect() as c:
return c.execute(f"select count(Id) from '{tablename}'").first()[0]
return c.execute(text(f"select count(Id) from '{tablename}'")).first()[0]


def snowfakery(recipe, num_records, tablename, outputfile):
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def doit(recipe_data, *args, **kwargs):
**kwargs,
)
mapping = yaml.safe_load(mapping_file.read_text())
print("DOIT", dburl)
e = create_engine(dburl)
with e.connect() as connection:
yield mapping, connection
Expand Down
18 changes: 15 additions & 3 deletions tests/test_output_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,18 +170,21 @@ def do_output(self, yaml, url=None):
results = generate(StringIO(yaml), {}, output_stream)
table_names = results.tables.keys()
output_stream.close()
print("do_output", url)
engine = create_engine(url)
with engine.connect() as connection:
tables = {
table_name: [
row._mapping
dict(row._mapping)
for row in connection.execute(
text(f"select * from {table_name}")
)
]
for table_name in table_names
}
return tables
engine.dispose()
del engine
return tables

def test_null(self):
yaml = """
Expand All @@ -206,10 +209,14 @@ def test_table_already_exists(self):
metadata.create_all(bind=engine)
with engine.begin() as c:
c.execute(t.insert().values([[5]]))
engine.dispose()

with pytest.raises(exc.DataGenError, match="Table already exists"):
output_stream = SqlDbOutputStream.from_url(url)
generate(StringIO(yaml), {}, output_stream)
try:
generate(StringIO(yaml), {}, output_stream)
finally:
output_stream.close()

def test_bad_database_connection(self):
yaml = """
Expand Down Expand Up @@ -288,6 +295,7 @@ def test_json_output_mocked(self):
def test_from_cli(self):
x = StringIO()
with redirect_stdout(x):
assert generate_cli.callback
generate_cli.callback(yaml_file=sample_yaml, output_format="json")
data = json.loads(x.getvalue())
print(data)
Expand Down Expand Up @@ -363,6 +371,7 @@ def test_csv_output(self):
output_stream = CSVOutputStream(Path(t) / "csvoutput")
generate(StringIO(yaml), {}, output_stream)
messages = output_stream.close()
assert messages
assert "foo.csv" in messages[0]
assert "bar.csv" in messages[1]
assert "csvw" in messages[2]
Expand Down Expand Up @@ -417,6 +426,7 @@ class TestExternalOutputStream:
def test_external_output_stream(self):
x = StringIO()
with redirect_stdout(x):
assert generate_cli.callback
generate_cli.callback(
yaml_file=sample_yaml, output_format="package1.TestOutputStream"
)
Expand All @@ -430,6 +440,7 @@ def test_external_output_stream(self):
def test_external_output_stream_yaml(self):
x = StringIO()
with redirect_stdout(x):
assert generate_cli.callback
generate_cli.callback(
yaml_file=sample_yaml, output_format="examples.YamlOutputStream"
)
Expand All @@ -451,6 +462,7 @@ def test_external_output_stream_yaml(self):

def test_external_output_stream__failure(self):
with pytest.raises(ClickException, match="no.such.output.Stream"):
assert generate_cli.callback
generate_cli.callback(
yaml_file=sample_yaml, output_format="no.such.output.Stream"
)
5 changes: 3 additions & 2 deletions tests/test_with_cci.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from snowfakery.standard_plugins import Salesforce

try:
import cumulusci
import cumulusci # type: ignore
except ImportError:
cumulusci = False

Expand All @@ -43,7 +43,7 @@ def test_mapping_file(self):
],
standalone_mode=False,
)

print("test_with_cci", url)
engine = create_engine(url)
with engine.connect() as connection:
result = [
Expand All @@ -52,6 +52,7 @@ def test_mapping_file(self):
]
assert result[0]["id"] == 1
assert result[0]["BillingCountry"] == "Canada"
engine.dispose()


class FakeSimpleSalesforce:
Expand Down

0 comments on commit e08c6f6

Please sign in to comment.