Skip to content

Commit 29e7889

Browse files
myyong and claude committed
Fix remaining RANDOM()/LIMIT dialect incompatibilities (closes #107)
Six locations in the generators and the interactive shell still emitted RANDOM() and LIMIT, neither of which MS-SQL supports.

Group A — interactive shell (base.py, table.py, generators.py): Replace the sqlalchemy.text() f-strings with the SQLAlchemy expression API. The random-ordering function is func.newid() when dialect.name == "mssql" and func.random() otherwise; .limit(n) compiles to TOP n or LIMIT n automatically, depending on the dialect.

Group B — CovariateQuery (continuous.py): Add a dialect_name parameter to __init__(); _inner_query() now emits SELECT TOP n … ORDER BY NEWID() on MS-SQL and SELECT * … ORDER BY RANDOM() LIMIT n elsewhere. get_generators() passes dialect_name from the engine.

Group C — MissingnessType (missingness.py): sampled_query() accepts dialect_name and returns the MS-SQL TOP/NEWID form when dialect_name == "mssql". do_sampled() passes self.sync_engine.dialect.name.

Tests: added to test_generators_dialect.py (CovariateQuery, Missingness) and to the new test_interactive_dialect.py (peek, _get_column_data, print_column_data), covering both MS-SQL and PostgreSQL.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 41b96f6 commit 29e7889

7 files changed

Lines changed: 306 additions & 35 deletions

File tree

datafaker/generators/continuous.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -396,13 +396,15 @@ def __init__(
396396
self,
397397
table: str,
398398
factory: MultivariateNormalGeneratorFactoryBase,
399+
dialect_name: str = "",
399400
) -> None:
400401
"""
401402
Initialize the query for the basics for multivariate normal/lognormal parameters.
402403
403404
:param table: The name of the table to be queried.
404405
:param factory: The generator factory, perhaps with overridden
405406
``query_var`` and ``query_predicate`` methods.
407+
:param dialect_name: The SQLAlchemy dialect name (e.g. ``"mssql"``).
406408
"""
407409
self.table = table
408410
self._columns: Sequence[Column] = []
@@ -412,6 +414,7 @@ def __init__(
412414
self.suppress_count = 1
413415
self._sample_count: int | None = None
414416
self._factory = factory
417+
self._dialect_name = dialect_name
415418
self._predicate_fn = lambda x: x + " IS NOT NULL"
416419

417420
def get_query_comment(self) -> str:
@@ -566,6 +569,11 @@ def _inner_query(self) -> str:
566569
where = " WHERE " + where
567570
if self._sample_count is None:
568571
return self.table + where
572+
if self._dialect_name == "mssql":
573+
return (
574+
f"(SELECT TOP {self._sample_count} * FROM {self.table}{where}"
575+
f" ORDER BY NEWID()) AS _sampled"
576+
)
569577
return (
570578
f"(SELECT * FROM {self.table}{where} ORDER BY RANDOM()"
571579
f" LIMIT {self._sample_count}) AS _sampled"
@@ -628,7 +636,7 @@ def get_generators(
628636
return []
629637
column_names = [c.name for c in columns]
630638
table = columns[0].table.name
631-
cq = CovariateQuery(table, self).columns(columns)
639+
cq = CovariateQuery(table, self, dialect_name=engine.dialect.name).columns(columns)
632640
query = cq.get()
633641
with engine.connect() as connection:
634642
try:

datafaker/interactive/base.py

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010

1111
import sqlalchemy
1212
from prettytable import PrettyTable
13-
from sqlalchemy import Engine, ForeignKey, MetaData, Table
13+
from sqlalchemy import Engine, ForeignKey, MetaData, Table, func, literal_column, or_, select, table as sa_table
1414
from typing_extensions import Self
1515

1616
from datafaker.utils import (
@@ -412,19 +412,23 @@ def do_peek(self, arg: str) -> None:
412412
col_names = arg.split()
413413
if not col_names:
414414
col_names = self._get_column_names()
415-
nonnulls = [f'"{cn}" IS NOT NULL' for cn in col_names]
415+
random_fn = (
416+
func.newid() if self.sync_engine.dialect.name == "mssql" else func.random()
417+
)
418+
col_exprs = [literal_column(f'"{cn}"') for cn in col_names]
419+
nonnull_clauses = [literal_column(f'"{cn}"').isnot(None) for cn in col_names]
420+
stmt = (
421+
select(*col_exprs)
422+
.select_from(sa_table(table_name))
423+
.where(or_(*nonnull_clauses))
424+
.order_by(random_fn)
425+
.limit(max_peek_rows)
426+
)
416427
with self.sync_engine.connect() as connection:
417-
cols = ", ".join(f'"{cn}"' for cn in col_names)
418-
where = "WHERE" if nonnulls else ""
419-
nonnull = " OR ".join(nonnulls)
420-
query = sqlalchemy.text(
421-
f"SELECT {cols} FROM {table_name} {where} {nonnull}"
422-
f" ORDER BY RANDOM() LIMIT {max_peek_rows}"
423-
)
424428
try:
425-
result = connection.execute(query)
429+
result = connection.execute(stmt)
426430
except sqlalchemy.exc.SQLAlchemyError as exc:
427-
self.print(self.ERROR_FAILED_SQL, exc=exc, query=query)
431+
self.print(self.ERROR_FAILED_SQL, exc=exc, query=stmt)
428432
return
429433
self.print_table(list(result.keys()), result.fetchmany(max_peek_rows))
430434

datafaker/interactive/generators.py

Lines changed: 14 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from typing import Any, Callable, Optional, cast
77

88
import sqlalchemy
9-
from sqlalchemy import Column
9+
from sqlalchemy import Column, and_, func, literal_column, select, table as sa_table
1010

1111
from datafaker.generators import everything_factory
1212
from datafaker.generators.base import Generator, PredefinedGenerator
@@ -754,15 +754,20 @@ def _get_column_data(
754754
self, count: int, to_str: Callable[[Any], str] = repr
755755
) -> list[list[str]]:
756756
columns = self._get_column_names()
757-
columns_string = ", ".join(columns)
758-
pred = " AND ".join(f"{column} IS NOT NULL" for column in columns)
757+
random_fn = (
758+
func.newid() if self.sync_engine.dialect.name == "mssql" else func.random()
759+
)
760+
col_exprs = [literal_column(col) for col in columns]
761+
nonnull_clauses = [literal_column(col).isnot(None) for col in columns]
762+
stmt = (
763+
select(*col_exprs)
764+
.select_from(sa_table(self.table_name()))
765+
.where(and_(*nonnull_clauses))
766+
.order_by(random_fn)
767+
.limit(count)
768+
)
759769
with self.sync_engine.connect() as connection:
760-
result = connection.execute(
761-
sqlalchemy.text(
762-
f"SELECT {columns_string} FROM {self.table_name()}"
763-
f" WHERE {pred} ORDER BY RANDOM() LIMIT {count}"
764-
)
765-
)
770+
result = connection.execute(stmt)
766771
return [[to_str(x) for x in xs] for xs in result.all()]
767772

768773
def do_propose(self, _arg: str) -> None:

datafaker/interactive/missingness.py

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,19 +23,32 @@ class MissingnessType:
2323
columns: list[str]
2424

2525
@classmethod
26-
def sampled_query(cls, table: str, count: int, column_names: Iterable[str]) -> str:
26+
def sampled_query(
27+
cls,
28+
table: str,
29+
count: int,
30+
column_names: Iterable[str],
31+
dialect_name: str = "",
32+
) -> str:
2733
"""
2834
Construct a query to make a sampling of the named rows of the table.
2935
3036
:param table: The name of the table to sample.
3137
:param count: The number of samples to get.
3238
:param column_names: The columns to fetch.
39+
:param dialect_name: The SQLAlchemy dialect name (e.g. ``"mssql"``).
3340
:return: The SQL query to do the sampling.
3441
"""
3542
result_names = ", ".join([f"{c}__is_null" for c in column_names])
3643
column_is_nulls = ", ".join(
3744
[f"{c} IS NULL AS {c}__is_null" for c in column_names]
3845
)
46+
if dialect_name == "mssql":
47+
return (
48+
f"SELECT COUNT(*) AS row_count, {result_names} FROM "
49+
f"(SELECT TOP {count} {column_is_nulls} FROM {table} ORDER BY NEWID())"
50+
f" AS __t GROUP BY {result_names}"
51+
)
3952
return cls.SAMPLED_QUERY.format(
4053
result_names=result_names,
4154
column_is_nulls=column_is_nulls,
@@ -330,6 +343,7 @@ def do_sampled(self, arg: str) -> None:
330343
entry.name,
331344
count,
332345
self.get_nullable_columns(entry.name),
346+
dialect_name=self.sync_engine.dialect.name,
333347
),
334348
[
335349
"The missingness patterns and how often they appear in a"

datafaker/interactive/table.py

Lines changed: 26 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
from typing import Any, cast
55

66
import sqlalchemy
7+
from sqlalchemy import func, literal_column, select, table as sa_table, text
78

89
from datafaker.interactive.base import (
910
TYPE_LETTER,
@@ -477,16 +478,23 @@ def print_column_data(self, column: str, count: int, min_length: int) -> None:
477478
:param count: The number of rows to sample.
478479
:param min_length: The minimum length of text to choose from (0 for any text).
479480
"""
480-
where = f"WHERE {column} IS NOT NULL"
481+
random_fn = (
482+
func.newid() if self.sync_engine.dialect.name == "mssql" else func.random()
483+
)
484+
col_expr = literal_column(column)
481485
if 0 < min_length:
482-
where = f"WHERE LENGTH({column}) >= {min_length}"
486+
where_clause = func.length(col_expr) >= min_length
487+
else:
488+
where_clause = col_expr.isnot(None)
489+
stmt = (
490+
select(col_expr)
491+
.select_from(sa_table(self.table_name()))
492+
.where(where_clause)
493+
.order_by(random_fn)
494+
.limit(count)
495+
)
483496
with self.sync_engine.connect() as connection:
484-
result = connection.execute(
485-
sqlalchemy.text(
486-
f"SELECT {column} FROM {self.table_name()}"
487-
f" {where} ORDER BY RANDOM() LIMIT {count}"
488-
)
489-
)
497+
result = connection.execute(stmt)
490498
self.columnize([str(x[0]) for x in result.all()])
491499

492500
def print_row_data(self, count: int) -> None:
@@ -495,12 +503,17 @@ def print_row_data(self, count: int) -> None:
495503
496504
:param count: The number of rows to report.
497505
"""
506+
random_fn = (
507+
func.newid() if self.sync_engine.dialect.name == "mssql" else func.random()
508+
)
509+
stmt = (
510+
select(text("*"))
511+
.select_from(sa_table(self.table_name()))
512+
.order_by(random_fn)
513+
.limit(count)
514+
)
498515
with self.sync_engine.connect() as connection:
499-
result = connection.execute(
500-
sqlalchemy.text(
501-
f"SELECT * FROM {self.table_name()} ORDER BY RANDOM() LIMIT {count}"
502-
)
503-
)
516+
result = connection.execute(stmt)
504517
if result is None:
505518
self.print("No rows in this table!")
506519
return

tests/test_generators_dialect.py

Lines changed: 83 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -309,3 +309,86 @@ def test_no_schema_omits_qualifier(self) -> None:
309309
"""Without schema, the FROM clause has no dot-qualifier."""
310310
sql = self._get_make_buckets_sql("postgresql", schema=None)
311311
self.assertNotIn(".", sql)
312+
313+
314+
class TestCovariateQueryDialect(unittest.TestCase):
315+
"""CovariateQuery._inner_query() uses TOP/NEWID on MS-SQL and RANDOM/LIMIT elsewhere."""
316+
317+
def _make_factory(self) -> MagicMock:
318+
factory = MagicMock()
319+
factory.query_predicate.return_value = ""
320+
return factory
321+
322+
def _inner_query(self, dialect_name: str) -> str:
323+
from datafaker.generators.continuous import CovariateQuery
324+
325+
cq = (
326+
CovariateQuery("person", self._make_factory(), dialect_name=dialect_name)
327+
.sample_count(500)
328+
)
329+
return cq._inner_query().upper()
330+
331+
def test_mssql_uses_top_and_newid(self) -> None:
332+
"""MS-SQL inner query uses SELECT TOP n … ORDER BY NEWID()."""
333+
sql = self._inner_query("mssql")
334+
self.assertIn("TOP 500", sql)
335+
self.assertIn("NEWID()", sql)
336+
self.assertNotIn("RANDOM()", sql)
337+
self.assertNotIn("LIMIT", sql)
338+
339+
def test_postgresql_uses_random_and_limit(self) -> None:
340+
"""PostgreSQL inner query uses ORDER BY RANDOM() LIMIT n."""
341+
sql = self._inner_query("postgresql")
342+
self.assertIn("RANDOM()", sql)
343+
self.assertIn("LIMIT 500", sql)
344+
self.assertNotIn("NEWID()", sql)
345+
self.assertNotIn("TOP", sql)
346+
347+
def test_no_sample_count_has_no_random_or_limit(self) -> None:
348+
"""When sample_count is None no random ordering is emitted."""
349+
from datafaker.generators.continuous import CovariateQuery
350+
351+
for dialect in ("mssql", "postgresql", ""):
352+
with self.subTest(dialect=dialect):
353+
cq = CovariateQuery("person", self._make_factory(), dialect_name=dialect)
354+
sql = cq._inner_query().upper()
355+
self.assertNotIn("RANDOM()", sql)
356+
self.assertNotIn("NEWID()", sql)
357+
self.assertNotIn("LIMIT", sql)
358+
self.assertNotIn("TOP", sql)
359+
360+
361+
class TestMissingnessQueryDialect(unittest.TestCase):
362+
"""MissingnessType.sampled_query() produces dialect-correct SQL."""
363+
364+
def test_mssql_uses_top_and_newid(self) -> None:
365+
"""MS-SQL sampled query uses SELECT TOP n … ORDER BY NEWID()."""
366+
from datafaker.interactive.missingness import MissingnessType
367+
368+
sql = MissingnessType.sampled_query(
369+
"person", 1000, ["col_a", "col_b"], dialect_name="mssql"
370+
).upper()
371+
self.assertIn("TOP 1000", sql)
372+
self.assertIn("NEWID()", sql)
373+
self.assertNotIn("RANDOM()", sql)
374+
self.assertNotIn("LIMIT", sql)
375+
376+
def test_default_uses_random_and_limit(self) -> None:
377+
"""Default (no dialect) sampled query uses RANDOM() and LIMIT."""
378+
from datafaker.interactive.missingness import MissingnessType
379+
380+
sql = MissingnessType.sampled_query("person", 1000, ["col_a"]).upper()
381+
self.assertIn("RANDOM()", sql)
382+
self.assertIn("LIMIT 1000", sql)
383+
self.assertNotIn("NEWID()", sql)
384+
self.assertNotIn("TOP", sql)
385+
386+
def test_mssql_result_contains_column_null_checks(self) -> None:
387+
"""MS-SQL sampled query retains IS NULL expressions for the named columns."""
388+
from datafaker.interactive.missingness import MissingnessType
389+
390+
sql = MissingnessType.sampled_query(
391+
"person", 500, ["gender_concept_id"], dialect_name="mssql"
392+
)
393+
self.assertIn("gender_concept_id IS NULL", sql)
394+
self.assertIn("gender_concept_id__is_null", sql)

0 commit comments

Comments
 (0)