Skip to content

Commit 7c0add6

Browse files
myyong and claude committed
Fix RANDOM()/LIMIT dialect incompatibility in ChoiceGeneratorFactory (closes #106)
Replace raw RANDOM()/LIMIT SQL in choice.py with the SQLAlchemy expression API:

- New _choice_stmt() helper builds SELECT expressions using func.random()/func.newid() and .limit()/.subquery(); SQLAlchemy compiles these to RANDOM()/LIMIT on PostgreSQL/DuckDB and NEWID()/TOP on MS-SQL automatically.
- ChoiceGenerator.__init__() gains an optional dialect= parameter; the stored _query string (written to src-stats.yaml for later make-stats execution) is now compiled against the source engine's dialect at configure-generators time.
- ChoiceGeneratorFactory.get_generators() builds both live queries as SQLAlchemy select() expressions and passes engine.dialect to all ChoiceGenerator constructors.
- The suppress-only subquery no longer contains ORDER BY without TOP, which MS-SQL rejects in derived table expressions.
- tests/test_generators_dialect.py: 7 new tests covering stored query SQL for all four sample/suppress combinations and both dialect-correct live queries.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent ad98acf commit 7c0add6

2 files changed

Lines changed: 252 additions & 92 deletions

File tree

datafaker/generators/choice.py

Lines changed: 117 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from abc import abstractmethod
77
from typing import Any, Sequence, Union
88

9-
from sqlalchemy import Column, CursorResult, Engine, text
9+
from sqlalchemy import Column, CursorResult, Engine, desc, func, literal_column, select, table
1010

1111
from datafaker.generators.base import (
1212
Generator,
@@ -44,6 +44,55 @@ def zipf_distribution(total: int, bins: int) -> typing.Generator[int, None, None
4444
yield x
4545

4646

47+
def _choice_stmt(
    column_name: str,
    table_name: str,
    store_counts: bool,
    sample_count: int | None,
    suppress_count: int,
    random_fn: Any,
) -> Any:
    """Assemble the SELECT that gathers the distinct values of one column.

    The statement is built with the SQLAlchemy expression API rather than
    as a raw string, so row limiting is rendered by whichever dialect
    eventually compiles it.

    :param column_name: Column whose non-NULL values are gathered; it is
        embedded in the statement wrapped in double quotes.
    :param table_name: Table that the column is read from.
    :param store_counts: If true, the result has a ``count`` column next to
        ``value``; otherwise only ``value`` is selected.
    :param sample_count: If not ``None``, values are counted over at most
        this many rows, chosen by ordering on ``random_fn``; if ``None``,
        every row is counted.
    :param suppress_count: If greater than zero, only values whose count
        exceeds this number are kept and no ordering is applied; otherwise
        the result is ordered by descending count.
    :param random_fn: SQL expression used as the sort key for sampling.
    :return: The (uncompiled) SQLAlchemy select statement.
    """
    quoted_col = literal_column(f'"{column_name}"')
    source = table(table_name)

    if sample_count is None:
        # Count straight off the source table.
        tally = (
            select(quoted_col.label("value"), func.count(quoted_col).label("count"))
            .where(quoted_col.isnot(None))
            .select_from(source)
            .group_by(quoted_col)
            .subquery("_counted")
        )
    else:
        # Draw the random sample first; this ORDER BY is always paired with
        # a row limit, so it never appears alone inside a derived table.
        sampled = (
            select(quoted_col.label("value"))
            .where(quoted_col.isnot(None))
            .select_from(source)
            .order_by(random_fn)
            .limit(sample_count)
            .subquery("_inner")
        )
        tally = (
            select(sampled.c.value, func.count(sampled.c.value).label("count"))
            .group_by(sampled.c.value)
            .subquery("_counted")
        )

    # "count" is looked up by key because ``.c.count`` would resolve to a
    # method of the column collection rather than to the labelled column.
    tally_count = tally.c["count"]
    selected = [tally.c.value, tally_count] if store_counts else [tally.c.value]
    query = select(*selected).select_from(tally)

    if suppress_count > 0:
        # Filtering only: deliberately no ORDER BY on this path.
        return query.where(tally_count > suppress_count)
    return query.order_by(desc(tally_count))
94+
95+
4796
class ChoiceGenerator(Generator):
4897
"""Base generator for all generators producing choices of items."""
4998

@@ -58,6 +107,7 @@ def __init__(
58107
counts: list[int],
59108
sample_count: int | None = None,
60109
suppress_count: int = 0,
110+
dialect: Any = None,
61111
) -> None:
62112
"""Initialise a ChoiceGenerator."""
63113
super().__init__()
@@ -67,60 +117,41 @@ def __init__(
67117
estimated_counts = self.get_estimated_counts(counts)
68118
self._fit = fit_from_buckets(counts, estimated_counts)
69119

70-
extra_results = ""
71-
extra_expo = ""
72-
extra_comment = ""
73-
if self.STORE_COUNTS:
74-
extra_results = f", COUNT({column_name}) AS count"
75-
extra_expo = ", count"
76-
extra_comment = " and their counts"
120+
extra_comment = " and their counts" if self.STORE_COUNTS else ""
121+
random_fn = (
122+
func.newid()
123+
if (dialect is not None and dialect.name == "mssql")
124+
else func.random()
125+
)
126+
stmt = _choice_stmt(
127+
column_name, table_name, self.STORE_COUNTS, sample_count, suppress_count, random_fn
128+
)
129+
compile_opts: dict[str, Any] = {"compile_kwargs": {"literal_binds": True}}
130+
if dialect is not None:
131+
compile_opts["dialect"] = dialect
132+
self._query = str(stmt.compile(**compile_opts))
133+
77134
if suppress_count == 0:
78135
if sample_count is None:
79-
self._query = (
80-
f"SELECT {column_name} AS value{extra_results} FROM {table_name}"
81-
f" WHERE {column_name} IS NOT NULL GROUP BY value"
82-
f" ORDER BY COUNT({column_name}) DESC"
83-
)
84136
self._comment = (
85137
f"All the values{extra_comment} that appear in column {column_name}"
86138
f" of table {table_name}"
87139
)
88140
self._annotation = None
89141
else:
90-
self._query = (
91-
f"SELECT {column_name} AS value{extra_results} FROM"
92-
f" (SELECT {column_name} FROM {table_name}"
93-
f" WHERE {column_name} IS NOT NULL"
94-
f" ORDER BY RANDOM() LIMIT {sample_count})"
95-
f" AS _inner GROUP BY value ORDER BY COUNT({column_name}) DESC"
96-
)
97142
self._comment = (
98143
f"The values{extra_comment} that appear in column {column_name}"
99144
f" of a random sample of {sample_count} rows of table {table_name}"
100145
)
101146
self._annotation = "sampled"
102147
else:
103148
if sample_count is None:
104-
self._query = (
105-
f"SELECT value{extra_expo} FROM"
106-
f" (SELECT {column_name} AS value, COUNT({column_name}) AS count"
107-
f" FROM {table_name} WHERE {column_name} IS NOT NULL"
108-
f" GROUP BY value ORDER BY count DESC) AS _inner"
109-
f" WHERE {suppress_count} < count"
110-
)
111149
self._comment = (
112150
f"All the values{extra_comment} that appear in column {column_name}"
113151
f" of table {table_name} more than {suppress_count} times"
114152
)
115153
self._annotation = "suppressed"
116154
else:
117-
self._query = (
118-
f"SELECT value{extra_expo} FROM (SELECT value, COUNT(value) AS count FROM"
119-
f" (SELECT {column_name} AS value FROM {table_name}"
120-
f" WHERE {column_name} IS NOT NULL ORDER BY RANDOM() LIMIT {sample_count})"
121-
f" AS _inner GROUP BY value ORDER BY count DESC)"
122-
f" AS _inner WHERE {suppress_count} < count"
123-
)
124155
self._comment = (
125156
f"The values{extra_comment} that appear more than {suppress_count} times"
126157
f" in column {column_name}, out of a random sample of {sample_count} rows"
@@ -302,112 +333,107 @@ def get_generators(
302333
column = columns[0]
303334
column_name = column.name
304335
table_name = column.table.name
336+
dialect = engine.dialect
337+
random_fn = func.newid() if dialect.name == "mssql" else func.random()
338+
col = literal_column(f'"{column_name}"')
339+
tbl = table(table_name)
305340
generators = []
306341
with engine.connect() as connection:
307-
results = connection.execute(
308-
text(
309-
f'SELECT "{column_name}" AS v, COUNT("{column_name}")'
310-
f' AS f FROM "{table_name}" GROUP BY v'
311-
f" ORDER BY f DESC LIMIT {MAXIMUM_CHOICES + 1}"
312-
)
342+
stmt_count = (
343+
select(col.label("v"), func.count(col).label("f"))
344+
.select_from(tbl)
345+
.group_by(col)
346+
.order_by(desc(func.count(col)))
347+
.limit(MAXIMUM_CHOICES + 1)
313348
)
349+
results = connection.execute(stmt_count)
314350
if results is not None and results.rowcount <= MAXIMUM_CHOICES:
315351
vg = ValueGatherer(results, self.SUPPRESS_COUNT)
316352
if vg.counts:
317353
generators += [
318354
ZipfChoiceGenerator(
319-
table_name, column_name, vg.values, vg.counts
355+
table_name, column_name, vg.values, vg.counts,
356+
dialect=dialect,
320357
),
321358
UniformChoiceGenerator(
322-
table_name, column_name, vg.values, vg.counts
359+
table_name, column_name, vg.values, vg.counts,
360+
dialect=dialect,
323361
),
324362
WeightedChoiceGenerator(
325-
table_name, column_name, vg.cvs, vg.counts
363+
table_name, column_name, vg.cvs, vg.counts,
364+
dialect=dialect,
326365
),
327366
]
328367
if vg.counts_not_suppressed:
329368
generators += [
330369
ZipfChoiceGenerator(
331-
table_name,
332-
column_name,
333-
vg.values_not_suppressed,
334-
vg.counts_not_suppressed,
335-
suppress_count=self.SUPPRESS_COUNT,
370+
table_name, column_name,
371+
vg.values_not_suppressed, vg.counts_not_suppressed,
372+
suppress_count=self.SUPPRESS_COUNT, dialect=dialect,
336373
),
337374
UniformChoiceGenerator(
338-
table_name,
339-
column_name,
340-
vg.values_not_suppressed,
341-
vg.counts_not_suppressed,
342-
suppress_count=self.SUPPRESS_COUNT,
375+
table_name, column_name,
376+
vg.values_not_suppressed, vg.counts_not_suppressed,
377+
suppress_count=self.SUPPRESS_COUNT, dialect=dialect,
343378
),
344379
WeightedChoiceGenerator(
345-
table_name=table_name,
346-
column_name=column_name,
380+
table_name=table_name, column_name=column_name,
347381
values=vg.cvs_not_suppressed,
348382
counts=vg.counts_not_suppressed,
349-
suppress_count=self.SUPPRESS_COUNT,
383+
suppress_count=self.SUPPRESS_COUNT, dialect=dialect,
350384
),
351385
]
352-
sampled_results = connection.execute(
353-
text(
354-
f"SELECT v, COUNT(v) AS f FROM"
355-
f' (SELECT "{column_name}" as v FROM "{table_name}"'
356-
f" ORDER BY RANDOM() LIMIT {self.SAMPLE_COUNT})"
357-
f" AS _inner GROUP BY v ORDER BY f DESC"
358-
)
386+
inner = (
387+
select(col.label("v"))
388+
.select_from(tbl)
389+
.order_by(random_fn)
390+
.limit(self.SAMPLE_COUNT)
391+
.subquery("_inner")
392+
)
393+
stmt_sample = (
394+
select(inner.c.v, func.count(inner.c.v).label("f"))
395+
.select_from(inner)
396+
.group_by(inner.c.v)
397+
.order_by(desc(func.count(inner.c.v)))
359398
)
399+
sampled_results = connection.execute(stmt_sample)
360400
if sampled_results is not None:
361401
vg = ValueGatherer(sampled_results, self.SUPPRESS_COUNT)
362402
if vg.counts:
363403
generators += [
364404
ZipfChoiceGenerator(
365-
table_name,
366-
column_name,
367-
vg.values,
368-
vg.counts,
369-
sample_count=self.SAMPLE_COUNT,
405+
table_name, column_name, vg.values, vg.counts,
406+
sample_count=self.SAMPLE_COUNT, dialect=dialect,
370407
),
371408
UniformChoiceGenerator(
372-
table_name,
373-
column_name,
374-
vg.values,
375-
vg.counts,
376-
sample_count=self.SAMPLE_COUNT,
409+
table_name, column_name, vg.values, vg.counts,
410+
sample_count=self.SAMPLE_COUNT, dialect=dialect,
377411
),
378412
WeightedChoiceGenerator(
379-
table_name,
380-
column_name,
381-
vg.cvs,
382-
vg.counts,
383-
sample_count=self.SAMPLE_COUNT,
413+
table_name, column_name, vg.cvs, vg.counts,
414+
sample_count=self.SAMPLE_COUNT, dialect=dialect,
384415
),
385416
]
386417
if vg.counts_not_suppressed:
387418
generators += [
388419
ZipfChoiceGenerator(
389-
table_name,
390-
column_name,
391-
vg.values_not_suppressed,
392-
vg.counts_not_suppressed,
420+
table_name, column_name,
421+
vg.values_not_suppressed, vg.counts_not_suppressed,
393422
sample_count=self.SAMPLE_COUNT,
394-
suppress_count=self.SUPPRESS_COUNT,
423+
suppress_count=self.SUPPRESS_COUNT, dialect=dialect,
395424
),
396425
UniformChoiceGenerator(
397-
table_name,
398-
column_name,
399-
vg.values_not_suppressed,
400-
vg.counts_not_suppressed,
426+
table_name, column_name,
427+
vg.values_not_suppressed, vg.counts_not_suppressed,
401428
sample_count=self.SAMPLE_COUNT,
402-
suppress_count=self.SUPPRESS_COUNT,
429+
suppress_count=self.SUPPRESS_COUNT, dialect=dialect,
403430
),
404431
WeightedChoiceGenerator(
405-
table_name=table_name,
406-
column_name=column_name,
432+
table_name=table_name, column_name=column_name,
407433
values=vg.cvs_not_suppressed,
408434
counts=vg.counts_not_suppressed,
409435
sample_count=self.SAMPLE_COUNT,
410-
suppress_count=self.SUPPRESS_COUNT,
436+
suppress_count=self.SUPPRESS_COUNT, dialect=dialect,
411437
),
412438
]
413439
return generators

0 commit comments

Comments
 (0)