66from abc import abstractmethod
77from typing import Any , Sequence , Union
88
9- from sqlalchemy import Column , CursorResult , Engine , text
9+ from sqlalchemy import Column , CursorResult , Engine , desc , func , literal_column , select , table
1010
1111from datafaker .generators .base import (
1212 Generator ,
@@ -44,6 +44,55 @@ def zipf_distribution(total: int, bins: int) -> typing.Generator[int, None, None
4444 yield x
4545
4646
47+ def _choice_stmt (
48+ column_name : str ,
49+ table_name : str ,
50+ store_counts : bool ,
51+ sample_count : int | None ,
52+ suppress_count : int ,
53+ random_fn : Any ,
54+ ) -> Any :
55+ """Build a SQLAlchemy SELECT for gathering choice value distributions.
56+
57+ Compiles to dialect-correct SQL: LIMIT/random() on PostgreSQL/DuckDB,
58+ TOP/newid() on MS-SQL. MS-SQL also forbids ORDER BY inside a subquery
59+ without TOP; this function never emits such a clause.
60+ """
61+ col = literal_column (f'"{ column_name } "' )
62+ tbl = table (table_name )
63+ if sample_count is not None :
64+ sample_sub = (
65+ select (col .label ("value" ))
66+ .where (col .isnot (None ))
67+ .select_from (tbl )
68+ .order_by (random_fn )
69+ .limit (sample_count )
70+ .subquery ("_inner" )
71+ )
72+ counted_sub = (
73+ select (sample_sub .c .value , func .count (sample_sub .c .value ).label ("count" ))
74+ .group_by (sample_sub .c .value )
75+ .subquery ("_counted" )
76+ )
77+ else :
78+ counted_sub = (
79+ select (col .label ("value" ), func .count (col ).label ("count" ))
80+ .where (col .isnot (None ))
81+ .select_from (tbl )
82+ .group_by (col )
83+ .subquery ("_counted" )
84+ )
85+ out_cols = [counted_sub .c .value ]
86+ if store_counts :
87+ out_cols .append (counted_sub .c ["count" ])
88+ stmt = select (* out_cols ).select_from (counted_sub )
89+ if suppress_count > 0 :
90+ stmt = stmt .where (counted_sub .c ["count" ] > suppress_count )
91+ else :
92+ stmt = stmt .order_by (desc (counted_sub .c ["count" ]))
93+ return stmt
94+
95+
4796class ChoiceGenerator (Generator ):
4897 """Base generator for all generators producing choices of items."""
4998
@@ -58,6 +107,7 @@ def __init__(
58107 counts : list [int ],
59108 sample_count : int | None = None ,
60109 suppress_count : int = 0 ,
110+ dialect : Any = None ,
61111 ) -> None :
62112 """Initialise a ChoiceGenerator."""
63113 super ().__init__ ()
@@ -67,60 +117,41 @@ def __init__(
67117 estimated_counts = self .get_estimated_counts (counts )
68118 self ._fit = fit_from_buckets (counts , estimated_counts )
69119
70- extra_results = ""
71- extra_expo = ""
72- extra_comment = ""
73- if self .STORE_COUNTS :
74- extra_results = f", COUNT({ column_name } ) AS count"
75- extra_expo = ", count"
76- extra_comment = " and their counts"
120+ extra_comment = " and their counts" if self .STORE_COUNTS else ""
121+ random_fn = (
122+ func .newid ()
123+ if (dialect is not None and dialect .name == "mssql" )
124+ else func .random ()
125+ )
126+ stmt = _choice_stmt (
127+ column_name , table_name , self .STORE_COUNTS , sample_count , suppress_count , random_fn
128+ )
129+ compile_opts : dict [str , Any ] = {"compile_kwargs" : {"literal_binds" : True }}
130+ if dialect is not None :
131+ compile_opts ["dialect" ] = dialect
132+ self ._query = str (stmt .compile (** compile_opts ))
133+
77134 if suppress_count == 0 :
78135 if sample_count is None :
79- self ._query = (
80- f"SELECT { column_name } AS value{ extra_results } FROM { table_name } "
81- f" WHERE { column_name } IS NOT NULL GROUP BY value"
82- f" ORDER BY COUNT({ column_name } ) DESC"
83- )
84136 self ._comment = (
85137 f"All the values{ extra_comment } that appear in column { column_name } "
86138 f" of table { table_name } "
87139 )
88140 self ._annotation = None
89141 else :
90- self ._query = (
91- f"SELECT { column_name } AS value{ extra_results } FROM"
92- f" (SELECT { column_name } FROM { table_name } "
93- f" WHERE { column_name } IS NOT NULL"
94- f" ORDER BY RANDOM() LIMIT { sample_count } )"
95- f" AS _inner GROUP BY value ORDER BY COUNT({ column_name } ) DESC"
96- )
97142 self ._comment = (
98143 f"The values{ extra_comment } that appear in column { column_name } "
99144 f" of a random sample of { sample_count } rows of table { table_name } "
100145 )
101146 self ._annotation = "sampled"
102147 else :
103148 if sample_count is None :
104- self ._query = (
105- f"SELECT value{ extra_expo } FROM"
106- f" (SELECT { column_name } AS value, COUNT({ column_name } ) AS count"
107- f" FROM { table_name } WHERE { column_name } IS NOT NULL"
108- f" GROUP BY value ORDER BY count DESC) AS _inner"
109- f" WHERE { suppress_count } < count"
110- )
111149 self ._comment = (
112150 f"All the values{ extra_comment } that appear in column { column_name } "
113151 f" of table { table_name } more than { suppress_count } times"
114152 )
115153 self ._annotation = "suppressed"
116154 else :
117- self ._query = (
118- f"SELECT value{ extra_expo } FROM (SELECT value, COUNT(value) AS count FROM"
119- f" (SELECT { column_name } AS value FROM { table_name } "
120- f" WHERE { column_name } IS NOT NULL ORDER BY RANDOM() LIMIT { sample_count } )"
121- f" AS _inner GROUP BY value ORDER BY count DESC)"
122- f" AS _inner WHERE { suppress_count } < count"
123- )
124155 self ._comment = (
125156 f"The values{ extra_comment } that appear more than { suppress_count } times"
126157 f" in column { column_name } , out of a random sample of { sample_count } rows"
@@ -302,112 +333,107 @@ def get_generators(
302333 column = columns [0 ]
303334 column_name = column .name
304335 table_name = column .table .name
336+ dialect = engine .dialect
337+ random_fn = func .newid () if dialect .name == "mssql" else func .random ()
338+ col = literal_column (f'"{ column_name } "' )
339+ tbl = table (table_name )
305340 generators = []
306341 with engine .connect () as connection :
307- results = connection . execute (
308- text (
309- f'SELECT " { column_name } " AS v, COUNT(" { column_name } ")'
310- f' AS f FROM " { table_name } " GROUP BY v'
311- f" ORDER BY f DESC LIMIT { MAXIMUM_CHOICES + 1 } "
312- )
342+ stmt_count = (
343+ select ( col . label ( "v" ), func . count ( col ). label ( "f" ))
344+ . select_from ( tbl )
345+ . group_by ( col )
346+ . order_by ( desc ( func . count ( col )))
347+ . limit ( MAXIMUM_CHOICES + 1 )
313348 )
349+ results = connection .execute (stmt_count )
314350 if results is not None and results .rowcount <= MAXIMUM_CHOICES :
315351 vg = ValueGatherer (results , self .SUPPRESS_COUNT )
316352 if vg .counts :
317353 generators += [
318354 ZipfChoiceGenerator (
319- table_name , column_name , vg .values , vg .counts
355+ table_name , column_name , vg .values , vg .counts ,
356+ dialect = dialect ,
320357 ),
321358 UniformChoiceGenerator (
322- table_name , column_name , vg .values , vg .counts
359+ table_name , column_name , vg .values , vg .counts ,
360+ dialect = dialect ,
323361 ),
324362 WeightedChoiceGenerator (
325- table_name , column_name , vg .cvs , vg .counts
363+ table_name , column_name , vg .cvs , vg .counts ,
364+ dialect = dialect ,
326365 ),
327366 ]
328367 if vg .counts_not_suppressed :
329368 generators += [
330369 ZipfChoiceGenerator (
331- table_name ,
332- column_name ,
333- vg .values_not_suppressed ,
334- vg .counts_not_suppressed ,
335- suppress_count = self .SUPPRESS_COUNT ,
370+ table_name , column_name ,
371+ vg .values_not_suppressed , vg .counts_not_suppressed ,
372+ suppress_count = self .SUPPRESS_COUNT , dialect = dialect ,
336373 ),
337374 UniformChoiceGenerator (
338- table_name ,
339- column_name ,
340- vg .values_not_suppressed ,
341- vg .counts_not_suppressed ,
342- suppress_count = self .SUPPRESS_COUNT ,
375+ table_name , column_name ,
376+ vg .values_not_suppressed , vg .counts_not_suppressed ,
377+ suppress_count = self .SUPPRESS_COUNT , dialect = dialect ,
343378 ),
344379 WeightedChoiceGenerator (
345- table_name = table_name ,
346- column_name = column_name ,
380+ table_name = table_name , column_name = column_name ,
347381 values = vg .cvs_not_suppressed ,
348382 counts = vg .counts_not_suppressed ,
349- suppress_count = self .SUPPRESS_COUNT ,
383+ suppress_count = self .SUPPRESS_COUNT , dialect = dialect ,
350384 ),
351385 ]
352- sampled_results = connection .execute (
353- text (
354- f"SELECT v, COUNT(v) AS f FROM"
355- f' (SELECT "{ column_name } " as v FROM "{ table_name } "'
356- f" ORDER BY RANDOM() LIMIT { self .SAMPLE_COUNT } )"
357- f" AS _inner GROUP BY v ORDER BY f DESC"
358- )
386+ inner = (
387+ select (col .label ("v" ))
388+ .select_from (tbl )
389+ .order_by (random_fn )
390+ .limit (self .SAMPLE_COUNT )
391+ .subquery ("_inner" )
392+ )
393+ stmt_sample = (
394+ select (inner .c .v , func .count (inner .c .v ).label ("f" ))
395+ .select_from (inner )
396+ .group_by (inner .c .v )
397+ .order_by (desc (func .count (inner .c .v )))
359398 )
399+ sampled_results = connection .execute (stmt_sample )
360400 if sampled_results is not None :
361401 vg = ValueGatherer (sampled_results , self .SUPPRESS_COUNT )
362402 if vg .counts :
363403 generators += [
364404 ZipfChoiceGenerator (
365- table_name ,
366- column_name ,
367- vg .values ,
368- vg .counts ,
369- sample_count = self .SAMPLE_COUNT ,
405+ table_name , column_name , vg .values , vg .counts ,
406+ sample_count = self .SAMPLE_COUNT , dialect = dialect ,
370407 ),
371408 UniformChoiceGenerator (
372- table_name ,
373- column_name ,
374- vg .values ,
375- vg .counts ,
376- sample_count = self .SAMPLE_COUNT ,
409+ table_name , column_name , vg .values , vg .counts ,
410+ sample_count = self .SAMPLE_COUNT , dialect = dialect ,
377411 ),
378412 WeightedChoiceGenerator (
379- table_name ,
380- column_name ,
381- vg .cvs ,
382- vg .counts ,
383- sample_count = self .SAMPLE_COUNT ,
413+ table_name , column_name , vg .cvs , vg .counts ,
414+ sample_count = self .SAMPLE_COUNT , dialect = dialect ,
384415 ),
385416 ]
386417 if vg .counts_not_suppressed :
387418 generators += [
388419 ZipfChoiceGenerator (
389- table_name ,
390- column_name ,
391- vg .values_not_suppressed ,
392- vg .counts_not_suppressed ,
420+ table_name , column_name ,
421+ vg .values_not_suppressed , vg .counts_not_suppressed ,
393422 sample_count = self .SAMPLE_COUNT ,
394- suppress_count = self .SUPPRESS_COUNT ,
423+ suppress_count = self .SUPPRESS_COUNT , dialect = dialect ,
395424 ),
396425 UniformChoiceGenerator (
397- table_name ,
398- column_name ,
399- vg .values_not_suppressed ,
400- vg .counts_not_suppressed ,
426+ table_name , column_name ,
427+ vg .values_not_suppressed , vg .counts_not_suppressed ,
401428 sample_count = self .SAMPLE_COUNT ,
402- suppress_count = self .SUPPRESS_COUNT ,
429+ suppress_count = self .SUPPRESS_COUNT , dialect = dialect ,
403430 ),
404431 WeightedChoiceGenerator (
405- table_name = table_name ,
406- column_name = column_name ,
432+ table_name = table_name , column_name = column_name ,
407433 values = vg .cvs_not_suppressed ,
408434 counts = vg .counts_not_suppressed ,
409435 sample_count = self .SAMPLE_COUNT ,
410- suppress_count = self .SUPPRESS_COUNT ,
436+ suppress_count = self .SUPPRESS_COUNT , dialect = dialect ,
411437 ),
412438 ]
413439 return generators
0 commit comments