Skip to content

Commit

Permalink
Support generated columns in classic module (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
Cito committed Aug 26, 2023
1 parent a2f5185 commit 538bb9f
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 9 deletions.
7 changes: 7 additions & 0 deletions docs/contents/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@ ChangeLog
Version 5.2.5 (to be released)
------------------------------
- This version officially supports the new Python 3.11 and PostgreSQL 15.
- Two more improvements in the `inserttable()` method of the `pg` module
(thanks to Justin Pryzby for this contribution):
- error handling has been improved (#72)
- the method now returns the number of inserted rows (#73)
- Another improvement in the `pg` module (#83):
- generated columns can be requested with the `get_generated()` method
- generated columns are ignored by the insert, update and upsert method

Version 5.2.4 (2022-03-26)
--------------------------
Expand Down
14 changes: 14 additions & 0 deletions docs/contents/pg/db_wrapper.rst
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,20 @@ By default, only a limited number of simple types will be returned.
You can get the registered types instead, if enabled by calling the
:meth:`DB.use_regtypes` method.

get_generated -- get the generated columns of a table
-----------------------------------------------------

.. method:: DB.get_generated(table)

Get the generated columns of a table

:param str table: name of table
:returns: an frozenset of column names

Given the name of a table, digs out the set of generated columns.

.. versionadded:: 5.2.5

has_table_privilege -- check table privilege
--------------------------------------------

Expand Down
55 changes: 48 additions & 7 deletions pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1629,6 +1629,7 @@ def __init__(self, *args, **kw):
self.dbname = db.db
self._regtypes = False
self._attnames = {}
self._generated = {}
self._pkeys = {}
self._privileges = {}
self.adapter = Adapter(self)
Expand Down Expand Up @@ -1657,6 +1658,17 @@ def __init__(self, *args, **kw):
" WHERE a.attrelid OPERATOR(pg_catalog.=)"
" %s::pg_catalog.regclass"
" AND %s AND NOT a.attisdropped ORDER BY a.attnum")
if db.server_version < 100000:
self._query_generated = None
elif db.server_version < 120000:
self._query_generated = (
"a.attidentity OPERATOR(pg_catalog.=) 'a'"
)
else:
self._query_generated = (
"(a.attidentity OPERATOR(pg_catalog.=) 'a' OR"
" a.attgenerated OPERATOR(pg_catalog.!=) '')"
)
db.set_cast_hook(self.dbtypes.typecast)
# For debugging scripts, self.debug can be set
# * to a string format specification (e.g. in CGI set to "%s<BR>"),
Expand Down Expand Up @@ -2130,7 +2142,7 @@ def get_relations(self, kinds=None, system=False):
"""Get list of relations in connected database of specified kinds.
If kinds is None or empty, all kinds of relations are returned.
Otherwise kinds can be a string or sequence of type letters
Otherwise, kinds can be a string or sequence of type letters
specifying which kind of relations you want to list.
Set the system flag if you want to get the system relations as well.
Expand Down Expand Up @@ -2190,6 +2202,32 @@ def get_attnames(self, table, with_oid=True, flush=False):
attnames[table] = names # cache it
return names

def get_generated(self, table, flush=False):
"""Given the name of a table, dig out the set of generated columns.
Returns a set of column names that are generated and unalterable.
If flush is set, then the internal cache for generated columns will
be flushed. This may be necessary after the database schema or
the search path has been changed.
"""
query_generated = self._query_generated
if not query_generated:
return frozenset()
generated = self._generated
if flush:
generated.clear()
self._do_debug('The generated cache has been flushed')
try: # cache lookup
names = generated[table]
except KeyError: # cache miss, check the database
q = "a.attnum OPERATOR(pg_catalog.>) 0 AND " + query_generated
q = self._query_attnames % (_quote_if_unqualified('$1', table), q)
names = self.db.query(q, (table,)).getresult()
names = frozenset(name[0] for name in names)
generated[table] = names # cache it
return names

def use_regtypes(self, regtypes=None):
"""Use registered type names instead of simplified type names."""
if regtypes is None:
Expand Down Expand Up @@ -2307,8 +2345,8 @@ def insert(self, table, row=None, **kw):
be passed as the first parameter. The other parameters are used for
providing the data of the row that shall be inserted into the table.
If a dictionary is supplied as the second parameter, it starts with
that. Otherwise it uses a blank dictionary. Either way the dictionary
is updated from the keywords.
that. Otherwise, it uses a blank dictionary.
Either way the dictionary is updated from the keywords.
The dictionary is then reloaded with the values actually inserted in
order to pick up values modified by rules, triggers, etc.
Expand All @@ -2321,13 +2359,14 @@ def insert(self, table, row=None, **kw):
if 'oid' in row:
del row['oid'] # do not insert oid
attnames = self.get_attnames(table)
generated = self.get_generated(table)
qoid = _oid_key(table) if 'oid' in attnames else None
params = self.adapter.parameter_list()
adapt = params.add
col = self.escape_identifier
names, values = [], []
for n in attnames:
if n in row:
if n in row and n not in generated:
names.append(col(n))
values.append(adapt(row[n], attnames[n]))
if not names:
Expand Down Expand Up @@ -2360,6 +2399,7 @@ def update(self, table, row=None, **kw):
if table.endswith('*'):
table = table[:-1].rstrip() # need parent table name
attnames = self.get_attnames(table)
generated = self.get_generated(table)
qoid = _oid_key(table) if 'oid' in attnames else None
if row is None:
row = {}
Expand Down Expand Up @@ -2390,7 +2430,7 @@ def update(self, table, row=None, **kw):
values = []
keyname = set(keyname)
for n in attnames:
if n in row and n not in keyname:
if n in row and n not in keyname and n not in generated:
values.append('%s = %s' % (col(n), adapt(row[n], attnames[n])))
if not values:
return row
Expand Down Expand Up @@ -2461,13 +2501,14 @@ def upsert(self, table, row=None, **kw):
if 'oid' in kw:
del kw['oid'] # do not update oid
attnames = self.get_attnames(table)
generated = self.get_generated(table)
qoid = _oid_key(table) if 'oid' in attnames else None
params = self.adapter.parameter_list()
adapt = params.add
col = self.escape_identifier
names, values = [], []
for n in attnames:
if n in row:
if n in row and n not in generated:
names.append(col(n))
values.append(adapt(row[n], attnames[n]))
names, values = ', '.join(names), ', '.join(values)
Expand All @@ -2480,7 +2521,7 @@ def upsert(self, table, row=None, **kw):
keyname = set(keyname)
keyname.add('oid')
for n in attnames:
if n not in keyname:
if n not in keyname and n not in generated:
value = kw.get(n, n in row)
if value:
if not isinstance(value, basestring):
Expand Down
140 changes: 138 additions & 2 deletions tests/test_classic_dbwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ def testAllDBAttributes(self):
'escape_literal', 'escape_string',
'fileno',
'get', 'get_as_dict', 'get_as_list',
'get_attnames', 'get_cast_hook',
'get_databases', 'get_notice_receiver',
'get_attnames', 'get_cast_hook', 'get_databases',
'get_generated', 'get_notice_receiver',
'get_parameter', 'get_relations', 'get_tables',
'getline', 'getlo', 'getnotify',
'has_table_privilege', 'host',
Expand Down Expand Up @@ -1473,6 +1473,53 @@ def testGetAttnamesIsAttrDict(self):
r = ' '.join(list(r.keys()))
self.assertEqual(r, 'n alpha v gamma tau beta')

def testGetGenerated(self):
get_generated = self.db.get_generated
server_version = self.db.server_version
if server_version >= 100000:
self.assertRaises(pg.ProgrammingError,
self.db.get_generated, 'does_not_exist')
self.assertRaises(pg.ProgrammingError,
self.db.get_generated, 'has.too.many.dots')
r = get_generated('test')
self.assertIsInstance(r, frozenset)
self.assertFalse(r)
if server_version >= 100000:
table = 'test_get_generated_1'
self.createTable(
table,
'i int generated always as identity primary key,'
' j int generated always as identity,'
' k int generated by default as identity,'
' n serial, m int')
r = get_generated(table)
self.assertIsInstance(r, frozenset)
self.assertEqual(r, {'i', 'j'})
if server_version >= 120000:
table = 'test_get_generated_2'
self.createTable(
table,
'n int, m int generated always as (n + 3) stored,'
' i int generated always as identity,'
' j int generated by default as identity')
r = get_generated(table)
self.assertIsInstance(r, frozenset)
self.assertEqual(r, {'m', 'i'})

def testGetGeneratedIsCached(self):
server_version = self.db.server_version
if server_version < 100000:
return
get_generated = self.db.get_generated
query = self.db.query
table = 'test_get_generated_2'
self.createTable(table, 'i int primary key')
self.assertFalse(get_generated(table))
query('alter table %s alter column i'
' add generated always as identity' % table)
self.assertFalse(get_generated(table))
self.assertEqual(get_generated(table, flush=True), {'i'})

def testHasTablePrivilege(self):
can = self.db.has_table_privilege
self.assertEqual(can('test'), True)
Expand Down Expand Up @@ -1918,6 +1965,32 @@ def testInsertIntoView(self):
r = query(q).getresult()
self.assertEqual(r, [(1234, 'abcd'), (5678, 'efgh')])

def testInsertWithGeneratedColumns(self):
insert = self.db.insert
get = self.db.get
server_version = self.db.server_version
table = 'insert_test_table_2'
table_def = 'i int not null'
if server_version >= 100000:
table_def += (
', a int generated always as identity'
', d int generated by default as identity primary key')
else:
table_def += ', a int not null default 1, d int primary key'
if server_version >= 120000:
table_def += ', j int generated always as (i + 7) stored'
else:
table_def += ', j int not null default 42'
self.createTable(table, table_def)
i, d = 35, 1001
j = i + 7
r = insert(table, {'i': i, 'd': d, 'a': 1, 'j': j})
self.assertIsInstance(r, dict)
self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j})
r = get(table, d)
self.assertIsInstance(r, dict)
self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j})

def testUpdate(self):
update = self.db.update
query = self.db.query
Expand Down Expand Up @@ -2089,6 +2162,38 @@ def testUpdateWithQuotedNames(self):
self.assertEqual(r['much space'], 7007)
self.assertEqual(r['Questions?'], 'When?')

def testUpdateWithGeneratedColumns(self):
update = self.db.update
get = self.db.get
query = self.db.query
server_version = self.db.server_version
table = 'update_test_table_2'
table_def = 'i int not null'
if server_version >= 100000:
table_def += (
', a int generated always as identity'
', d int generated by default as identity primary key')
else:
table_def += ', a int not null default 1, d int primary key'
if server_version >= 120000:
table_def += ', j int generated always as (i + 7) stored'
else:
table_def += ', j int not null default 42'
self.createTable(table, table_def)
i, d = 35, 1001
j = i + 7
r = query('insert into %s (i, d) values (%d, %d)' % (table, i, d))
self.assertEqual(r, '1')
r = get(table, d)
self.assertIsInstance(r, dict)
self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j})
r['i'] += 1
r = update(table, r)
i += 1
if server_version >= 120000:
j += 1
self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j})

def testUpsert(self):
upsert = self.db.upsert
query = self.db.query
Expand Down Expand Up @@ -2349,6 +2454,37 @@ def testUpsertWithQuotedNames(self):
r = query(q).getresult()
self.assertEqual(r, [(31, 9009, 'No.')])

def testUpsertWithGeneratedColumns(self):
upsert = self.db.upsert
get = self.db.get
server_version = self.db.server_version
table = 'upsert_test_table_2'
table_def = 'i int not null'
if server_version >= 100000:
table_def += (
', a int generated always as identity'
', d int generated by default as identity primary key')
else:
table_def += ', a int not null default 1, d int primary key'
if server_version >= 120000:
table_def += ', j int generated always as (i + 7) stored'
else:
table_def += ', j int not null default 42'
self.createTable(table, table_def)
i, d = 35, 1001
j = i + 7
r = upsert(table, {'i': i, 'd': d, 'a': 1, 'j': j})
self.assertIsInstance(r, dict)
self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j})
r['i'] += 1
r = upsert(table, r)
i += 1
if server_version >= 120000:
j += 1
self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j})
r = get(table, d)
self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j})

def testClear(self):
clear = self.db.clear
f = False if pg.get_bool() else 'f'
Expand Down

0 comments on commit 538bb9f

Please sign in to comment.