From 98ad0fe07e642e19a293f9db282541ced2201d8a Mon Sep 17 00:00:00 2001 From: Richard Schwab Date: Sun, 28 May 2023 04:36:34 +0200 Subject: [PATCH] Modernize code with `pyupgrade --py37-plus` --- CHANGES.txt | 2 ++ aiomysql/connection.py | 24 ++++++++++++------------ aiomysql/cursors.py | 4 ++-- aiomysql/sa/transaction.py | 4 ++-- docs/conf.py | 1 - tests/conftest.py | 6 +++--- tests/test_basic.py | 6 +++--- tests/test_issues.py | 8 ++++---- tests/test_sha_connection.py | 2 +- tests/test_sscursor.py | 8 ++++---- 10 files changed, 33 insertions(+), 32 deletions(-) diff --git a/CHANGES.txt b/CHANGES.txt index ac78620b..cc9ebe1a 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -14,6 +14,8 @@ next (unreleased) * Fix debug log level with sha256_password authentication #863 +* Modernized code with `pyupgrade `_ to Python 3.7+ syntax #930 + 0.1.1 (2022-05-08) ^^^^^^^^^^^^^^^^^^ diff --git a/aiomysql/connection.py b/aiomysql/connection.py index 2c559f92..3520dfcc 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -542,7 +542,7 @@ async def _connect(self): self.connected_time = self._loop.time() if self.sql_mode is not None: - await self.query("SET sql_mode=%s" % (self.sql_mode,)) + await self.query(f"SET sql_mode={self.sql_mode}") if self.init_command is not None: await self.query(self.init_command) @@ -659,8 +659,8 @@ async def _read_bytes(self, num_bytes): msg = "Lost connection to MySQL server during query" self.close() raise OperationalError(CR.CR_SERVER_LOST, msg) from e - except (IOError, OSError) as e: - msg = "Lost connection to MySQL server during query (%s)" % (e,) + except OSError as e: + msg = f"Lost connection to MySQL server during query ({e})" self.close() raise OperationalError(CR.CR_SERVER_LOST, msg) from e return data @@ -899,7 +899,7 @@ async def _process_auth(self, plugin_name, auth_packet): data = self._password.encode('latin1') + b'\0' else: raise OperationalError( - 2059, "Authentication plugin '{0}'" + 2059, "Authentication plugin '{}'" " not configured".format(plugin_name) ) @@ -936,7 +936,7 @@ async def caching_sha2_password_auth(self, pkt): if not pkt.is_extra_auth_data(): raise OperationalError( "caching sha2: Unknown packet " - "for fast auth: {0}".format(pkt._data[:1]) + "for fast auth: {}".format(pkt._data[:1]) ) # magic numbers: @@ -955,7 +955,7 @@ async def caching_sha2_password_auth(self, pkt): if n != 4: raise OperationalError("caching sha2: Unknown " - "result for fast auth: {0}".format(n)) + "result for fast auth: {}".format(n)) logger.debug("caching sha2: Trying full auth...") @@ -975,7 +975,7 @@ async def caching_sha2_password_auth(self, pkt): if not pkt.is_extra_auth_data(): raise OperationalError( "caching sha2: Unknown packet " - "for public key: {0}".format(pkt._data[:1]) + "for public key: {}".format(pkt._data[:1]) ) self.server_public_key = pkt._data[1:] @@ -1126,7 +1126,7 @@ def _ensure_alive(self): def __del__(self): if self._writer: - warnings.warn("Unclosed connection {!r}".format(self), + warnings.warn(f"Unclosed connection {self!r}", ResourceWarning) self.close() @@ -1351,7 +1351,7 @@ async def _get_descriptions(self): self.description = tuple(description) -class LoadLocalFile(object): +class LoadLocalFile: def __init__(self, filename, connection): self.filename = filename self.connection = connection @@ -1364,8 +1364,8 @@ def _open_file(self): def opener(filename): try: self._file_object = open(filename, 'rb') - except IOError as e: - msg = "Can't find file '{0}'".format(filename) + except OSError as e: + msg = f"Can't find file '{filename}'" raise OperationalError(1017, msg) from e fut = self._loop.run_in_executor(self._executor, opener, self.filename) @@ -1384,7 +1384,7 @@ def freader(chunk_size): except Exception as e: self._file_object.close() self._file_object = None - msg = "Error reading file {}".format(self.filename) + msg = f"Error reading file {self.filename}" raise OperationalError(1024, msg) from e return chunk diff --git a/aiomysql/cursors.py b/aiomysql/cursors.py index 3401bdbf..35d29d72 100644 --- a/aiomysql/cursors.py +++ b/aiomysql/cursors.py @@ -196,7 +196,7 @@ def _escape_args(self, args, conn): if isinstance(args, (tuple, list)): return tuple(conn.escape(arg) for arg in args) elif isinstance(args, dict): - return dict((key, conn.escape(val)) for (key, val) in args.items()) + return {key: conn.escape(val) for (key, val) in args.items()} else: # If it's not a dictionary let's try escaping it anyways. # Worst case it will throw a Value error @@ -357,7 +357,7 @@ async def callproc(self, procname, args=()): await self.nextset() _args = ','.join('@_%s_%d' % (procname, i) for i in range(len(args))) - q = "CALL %s(%s)" % (procname, _args) + q = f"CALL {procname}({_args})" await self._query(q) self._executed = q return args diff --git a/aiomysql/sa/transaction.py b/aiomysql/sa/transaction.py index ff15ac08..bb17e04a 100644 --- a/aiomysql/sa/transaction.py +++ b/aiomysql/sa/transaction.py @@ -3,7 +3,7 @@ from . import exc -class Transaction(object): +class Transaction: """Represent a database transaction in progress. The Transaction object is procured by @@ -114,7 +114,7 @@ class NestedTransaction(Transaction): _savepoint = None def __init__(self, connection, parent): - super(NestedTransaction, self).__init__(connection, parent) + super().__init__(connection, parent) async def _do_rollback(self): assert self._savepoint is not None, "Broken transaction logic" diff --git a/docs/conf.py b/docs/conf.py index 027c81c5..159f3ac9 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- # # aiomysql documentation build configuration file, created by # sphinx-quickstart on Sun Jan 18 22:02:31 2015. diff --git a/tests/conftest.py b/tests/conftest.py index 6aaae6a9..047dce87 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -40,7 +40,7 @@ def pytest_generate_tests(metafunc): ids.append(label) else: mysql_addresses.append(opt_mysql_unix_socket[i]) - ids.append("unix{}".format(i)) + ids.append(f"unix{i}") opt_mysql_address = list(metafunc.config.getoption("mysql_address")) for i in range(len(opt_mysql_address)): @@ -49,7 +49,7 @@ def pytest_generate_tests(metafunc): ids.append(label) else: addr = opt_mysql_address[i] - ids.append("tcp{}".format(i)) + ids.append(f"tcp{i}") if ":" in addr: addr = addr.split(":", 1) @@ -232,7 +232,7 @@ def _register_table(table_name): yield _register_table for t in table_list: # TODO: probably this is not safe code - sql = "DROP TABLE IF EXISTS {};".format(t) + sql = f"DROP TABLE IF EXISTS {t};" loop.run_until_complete(cursor.execute(sql)) diff --git a/tests/test_basic.py b/tests/test_basic.py index b0a939bf..78f0daa7 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -28,7 +28,7 @@ async def test_datatypes(connection, cursor, datatype_table): # insert values v = ( True, -3, 123456789012, 5.7, "hello'\" world", - u"Espa\xc3\xb1ol", + "Espa\xc3\xb1ol", "binary\x00data".encode(encoding), datetime.date(1988, 2, 2), datetime.datetime.now().replace(microsecond=0), @@ -148,10 +148,10 @@ async def test_binary_data(cursor, table_cleanup): async def test_untyped_convertion_to_null_and_empty_string(cursor): await cursor.execute("select null,''") r = await cursor.fetchone() - assert (None, u'') == r + assert (None, '') == r await cursor.execute("select '',null") r = await cursor.fetchone() - assert (u'', None) == r + assert ('', None) == r @pytest.mark.run_loop diff --git a/tests/test_issues.py b/tests/test_issues.py index e60a5103..77ed4267 100644 --- a/tests/test_issues.py +++ b/tests/test_issues.py @@ -122,10 +122,10 @@ async def test_issue_15(connection): await c.execute("create table issue15 (t varchar(32))") try: await c.execute("insert into issue15 (t) values (%s)", - (u'\xe4\xf6\xfc',)) + ('\xe4\xf6\xfc',)) await c.execute("select t from issue15") r = await c.fetchone() - assert u'\xe4\xf6\xfc' == r[0] + assert '\xe4\xf6\xfc' == r[0] finally: await c.execute("drop table issue15") @@ -412,8 +412,8 @@ async def test_issue_175(connection): conn = connection cur = await conn.cursor() for length in (200, 300): - cols = ', '.join('c{0} integer'.format(i) for i in range(length)) - sql = 'create table test_field_count ({0})'.format(cols) + cols = ', '.join(f'c{i} integer' for i in range(length)) + sql = f'create table test_field_count ({cols})' try: await cur.execute(sql) await cur.execute('select * from test_field_count') diff --git a/tests/test_sha_connection.py b/tests/test_sha_connection.py index 47baa0a6..c716a57a 100644 --- a/tests/test_sha_connection.py +++ b/tests/test_sha_connection.py @@ -24,7 +24,7 @@ def ensure_mysql_version(mysql_server): if mysql_server["db_type"] != "mysql" \ or mysql_server["server_version_tuple_short"] != (8, 0): - pytest.skip("Not applicable for {0} version: {1}" + pytest.skip("Not applicable for {} version: {}" .format(mysql_server["db_type"], mysql_server["server_version_tuple_short"])) diff --git a/tests/test_sscursor.py b/tests/test_sscursor.py index eff2ee33..de9da609 100644 --- a/tests/test_sscursor.py +++ b/tests/test_sscursor.py @@ -42,10 +42,10 @@ async def test_ssursor(connection): cursor = await conn.cursor(SSCursor) # Create table await cursor.execute('DROP TABLE IF EXISTS tz_data;') - await cursor.execute(('CREATE TABLE tz_data (' - 'region VARCHAR(64),' - 'zone VARCHAR(64),' - 'name VARCHAR(64))')) + await cursor.execute('CREATE TABLE tz_data (' + 'region VARCHAR(64),' + 'zone VARCHAR(64),' + 'name VARCHAR(64))') # Test INSERT for i in DATA: