Skip to content

Commit

Permalink
Use new syntax (set literals and dict comprehensions)
Browse files Browse the repository at this point in the history
  • Loading branch information
Cito committed Apr 19, 2020
1 parent 517a40b commit b14a09b
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 42 deletions.
37 changes: 20 additions & 17 deletions pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ def dst(self, dt):

# time zones used in Postgres timestamptz output
_timezones = dict(CET='+0100', EET='+0200', EST='-0500',
GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
UCT='+0000', UTC='+0000', WET='+0000')
GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
UCT='+0000', UTC='+0000', WET='+0000')


def _timezone_as_offset(tz):
Expand Down Expand Up @@ -693,7 +693,7 @@ def format_query(self, command, values=None, types=None, inline=False):
len(types) != len(values)):
raise TypeError('The values and types do not match')
literals = [add(value, typ)
for value, typ in zip(values, types)]
for value, typ in zip(values, types)]
else:
literals = [add(value) for value in values]
command %= tuple(literals)
Expand All @@ -712,18 +712,18 @@ def format_query(self, command, values=None, types=None, inline=False):
values = used_values
if inline:
adapt = self.adapt_inline
literals = dict((key, adapt(value))
for key, value in values.items())
literals = {key: adapt(value)
for key, value in values.items()}
else:
add = params.add
if types:
if not isinstance(types, dict):
raise TypeError('The values and types do not match')
literals = dict((key, add(values[key], types.get(key)))
for key in sorted(values))
literals = {key: add(values[key], types.get(key))
for key in sorted(values)}
else:
literals = dict((key, add(values[key]))
for key in sorted(values))
literals = {key: add(values[key])
for key in sorted(values)}
command %= literals
else:
raise TypeError('The values must be passed as tuple, list or dict')
Expand Down Expand Up @@ -827,7 +827,7 @@ def cast_timestamp(value, connection):
if len(value[3]) > 4:
return datetime.max
fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
'%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
'%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
else:
if len(value[0]) > 10:
return datetime.max
Expand Down Expand Up @@ -1159,8 +1159,8 @@ class DbTypes(dict):
information on the associated database type.
"""

_num_types = frozenset('int float num money'
' int2 int4 int8 float4 float8 numeric money'.split())
_num_types = frozenset('int float num money int2 int4 int8'
' float4 float8 numeric money'.split())

def __init__(self, db):
"""Initialize type cache for connection."""
Expand Down Expand Up @@ -1768,7 +1768,7 @@ def get_parameter(self, parameter):
if param == 'all':
q = 'SHOW ALL'
values = self.db.query(q).getresult()
values = dict(value[:2] for value in values)
values = {value[0]: value[1] for value in values}
break
if isinstance(values, dict):
params[param] = key
Expand Down Expand Up @@ -1823,12 +1823,14 @@ def set_parameter(self, parameter, value=None, local=False):
if len(value) == 1:
value = value.pop()
if not(value is None or isinstance(value, basestring)):
raise ValueError('A single value must be specified'
raise ValueError(
'A single value must be specified'
' when parameter is a set')
parameter = dict.fromkeys(parameter, value)
elif isinstance(parameter, dict):
if value is not None:
raise ValueError('A value must not be specified'
raise ValueError(
'A value must not be specified'
' when parameter is a dictionary')
else:
raise TypeError(
Expand All @@ -1843,7 +1845,8 @@ def set_parameter(self, parameter, value=None, local=False):
raise TypeError('Invalid parameter')
if param == 'all':
if value is not None:
raise ValueError('A value must ot be specified'
raise ValueError(
'A value must ot be specified'
" when parameter is 'all'")
params = {'all': None}
break
Expand Down Expand Up @@ -1886,7 +1889,7 @@ def query(self, command, *args):
return self.db.query(command)

def query_formatted(self, command,
parameters=None, types=None, inline=False):
parameters=None, types=None, inline=False):
"""Execute a formatted SQL command string.
Similar to query, but using Python format placeholders of the form
Expand Down
6 changes: 3 additions & 3 deletions pgdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,8 @@ def dst(self, dt):

# time zones used in Postgres timestamptz output
_timezones = dict(CET='+0100', EET='+0200', EST='-0500',
GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
UCT='+0000', UTC='+0000', WET='+0000')
GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
UCT='+0000', UTC='+0000', WET='+0000')


def _timezone_as_offset(tz):
Expand Down Expand Up @@ -521,7 +521,7 @@ def cast_interval(value):
raise ValueError('Cannot parse interval: %s' % value)
days += 365 * years + 30 * mons
return timedelta(days=days, hours=hours, minutes=mins,
seconds=secs, microseconds=usecs)
seconds=secs, microseconds=usecs)


class Typecasts(dict):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_classic_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1685,7 +1685,7 @@ def testInserttableFromTupleOfLists(self):
self.assertEqual(self.get_back(), self.data)

def testInserttableFromSetofTuples(self):
data = set(row for row in self.data)
data = {row for row in self.data}
try:
self.c.inserttable('test', data)
except TypeError as e:
Expand Down
33 changes: 18 additions & 15 deletions tests/test_classic_dbwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import pg # the module under test

from collections import OrderedDict
from decimal import Decimal
from datetime import date, time, datetime, timedelta
from uuid import UUID
Expand Down Expand Up @@ -52,8 +53,6 @@
except NameError: # Python >= 3.0
unicode = str

from collections import OrderedDict

if str is bytes: # noinspection PyUnresolvedReferences
from StringIO import StringIO
else: # Python >= 3.0
Expand Down Expand Up @@ -660,10 +659,10 @@ def testGetParameter(self):
self.assertEqual(r, ['hex', 'C'])
r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
r = f(set(['bytea_output', 'lc_monetary']))
r = f({'bytea_output', 'lc_monetary'})
self.assertIsInstance(r, dict)
self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
r = f(set(['Bytea_Output', ' LC_Monetary ']))
r = f({'Bytea_Output', ' LC_Monetary '})
self.assertIsInstance(r, dict)
self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
s = dict.fromkeys(('bytea_output', 'lc_monetary'))
Expand Down Expand Up @@ -720,13 +719,15 @@ def testSetParameter(self):
f(('escape_string_warning', 'standard_conforming_strings'), 'off')
self.assertEqual(g('escape_string_warning'), 'off')
self.assertEqual(g('standard_conforming_strings'), 'off')
f(set(['escape_string_warning', 'standard_conforming_strings']), 'on')
f({'escape_string_warning', 'standard_conforming_strings'}, 'on')
self.assertEqual(g('escape_string_warning'), 'on')
self.assertEqual(g('standard_conforming_strings'), 'on')
self.assertRaises(ValueError, f, set(['escape_string_warning',
'standard_conforming_strings']), ['off', 'on'])
f(set(['escape_string_warning', 'standard_conforming_strings']),
['off', 'off'])
self.assertRaises(
ValueError, f,
{'escape_string_warning', 'standard_conforming_strings'},
['off', 'on'])
f({'escape_string_warning', 'standard_conforming_strings'},
['off', 'off'])
self.assertEqual(g('escape_string_warning'), 'off')
self.assertEqual(g('standard_conforming_strings'), 'off')
f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
Expand Down Expand Up @@ -769,7 +770,7 @@ def testResetParameter(self):
f('standard_conforming_strings', not_scs)
self.assertEqual(g('escape_string_warning'), not_esw)
self.assertEqual(g('standard_conforming_strings'), not_scs)
f(set(['escape_string_warning', 'standard_conforming_strings']))
f({'escape_string_warning', 'standard_conforming_strings'})
self.assertEqual(g('escape_string_warning'), esw)
self.assertEqual(g('standard_conforming_strings'), scs)
db.close()
Expand Down Expand Up @@ -2880,8 +2881,8 @@ def testGetAsList(self):
from_table = '(select lower(name) as n2 from "%s") as t2' % table
r = get_as_list(from_table)
self.assertIsInstance(r, list)
r = set(row[0] for row in r)
expected = set(row[1].lower() for row in names)
r = {row[0] for row in r}
expected = {row[1].lower() for row in names}
self.assertEqual(r, expected)
r = get_as_list(from_table, order='n2', scalar=True)
self.assertIsInstance(r, list)
Expand Down Expand Up @@ -3030,7 +3031,7 @@ def testGetAsDict(self):
self.assertIsInstance(r, OrderedDict)
self.assertEqual(len(r), 0)
# test with unordered query
expected = dict((row[0], row[1:]) for row in colors)
expected = {row[0]: row[1:] for row in colors}
r = get_as_dict(table, order=False)
self.assertIsInstance(r, dict)
self.assertEqual(r, expected)
Expand Down Expand Up @@ -3382,7 +3383,8 @@ def testInsertGetJsonb(self):

def testArray(self):
returns_arrays = pg.get_array()
self.createTable('arraytest',
self.createTable(
'arraytest',
'id smallint, i2 smallint[], i4 integer[], i8 bigint[],'
' d numeric[], f4 real[], f8 double precision[], m money[],'
' b bool[], v4 varchar(4)[], c4 char(4)[], t text[]')
Expand All @@ -3406,7 +3408,8 @@ def testArray(self):
long_decimal = decimal('12345671234.5')
odd_money = decimal('1234567123.25')
t, f = (True, False) if pg.get_bool() else ('t', 'f')
data = dict(id=42, i2=[42, 1234, None, 0, -1],
data = dict(
id=42, i2=[42, 1234, None, 0, -1],
i4=[42, 123456789, None, 0, 1, -1],
i8=[long(42), long(123456789123456789), None,
long(0), long(1), long(-1)],
Expand Down
10 changes: 4 additions & 6 deletions tests/test_dbapi20.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
except NameError: # Python >= 3.0
long = int

from collections import OrderedDict


class PgBitString:
"""Test object with a PostgreSQL representation as Bit String."""
Expand Down Expand Up @@ -161,8 +159,8 @@ def test_row_factory(self):
class TestCursor(pgdb.Cursor):

def row_factory(self, row):
return dict(('column %s' % desc[0], value)
for desc, value in zip(self.description, row))
return {'column %s' % desc[0]: value
for desc, value in zip(self.description, row)}

con = self._connect()
con.cursor_type = TestCursor
Expand All @@ -188,8 +186,8 @@ class TestCursor(pgdb.Cursor):

def build_row_factory(self):
keys = [desc[0] for desc in self.description]
return lambda row: dict((key, value)
for key, value in zip(keys, row))
return lambda row: {
key: value for key, value in zip(keys, row)}

con = self._connect()
con.cursor_type = TestCursor
Expand Down

0 comments on commit b14a09b

Please sign in to comment.