From 5df27deecff56824d5f93648be71b10585b2757f Mon Sep 17 00:00:00 2001 From: Vitaly Burovoy Date: Fri, 6 Jan 2017 07:45:58 +0000 Subject: [PATCH] Fix parsing server version string for non-final releases (devel, beta etc.) It seems broken from the beginning: d4a66cbd948c45d8ff43f5a383a7ce9b189f08b8 Detect numbers by both sides of a string in the last part (they are not split[1][2] by '.'). [1] 9.4beta2 https://git.postgresql.org/pg/commitdiff/c85374626f680902c207285b986ac38a134535eb [2] 10devel https://git.postgresql.org/pg/commitdiff/ca9112a424ff68ec4f2ef67b47122f7d61412964 --- asyncpg/serverversion.py | 17 +++++++++++++---- tests/test_connect.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/asyncpg/serverversion.py b/asyncpg/serverversion.py index 2e310311..88852e3e 100644 --- a/asyncpg/serverversion.py +++ b/asyncpg/serverversion.py @@ -16,14 +16,23 @@ def split_server_version_string(version_string): parts = version_string.strip().split('.') if not parts[-1].isdigit(): # release level specified - level = parts[-1].rstrip('0123456789').lower() - serial = parts[-1][level:] - versions = [int(p) for p in parts[:-1]][:3] + lastitem = parts[-1] + levelpart = lastitem.rstrip('0123456789').lower() + if levelpart != lastitem: + serial = int(lastitem[len(levelpart):]) + else: + serial = 0 + + level = levelpart.lstrip('0123456789') + if level != levelpart: + parts[-1] = levelpart[:-len(level)] + else: + parts[-1] = 0 else: level = 'final' serial = 0 - versions = [int(p) for p in parts][:3] + versions = [int(p) for p in parts][:3] if len(versions) < 3: versions += [0] * (3 - len(versions)) diff --git a/tests/test_connect.py b/tests/test_connect.py index 09c987e7..89132be4 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -14,6 +14,7 @@ import asyncpg from asyncpg import _testbase as tb from asyncpg.connection import _parse_connect_params +from asyncpg.serverversion import split_server_version_string _system = platform.uname().system @@ -25,6 +26,33 @@ async def test_get_settings_01(self): self.con.get_settings().client_encoding, 'UTF8') + async def test_server_version_01(self): + version = self.con.get_server_version() + version_num = await self.con.fetchval("SELECT current_setting($1)", + 'server_version_num', column=0) + ver_maj = int(version_num[:-4]) + ver_min = int(version_num[-4:-2]) + ver_fix = int(version_num[-2:]) + + self.assertEqual(version[:3], (ver_maj, ver_min, ver_fix)) + + def test_server_version_02(self): + versions = [ + ("9.2", (9, 2, 0, 'final', 0),), + ("9.2.1", (9, 2, 1, 'final', 0),), + ("9.4beta1", (9, 4, 0, 'beta', 1),), + ("10devel", (10, 0, 0, 'devel', 0),), + ("10beta2", (10, 0, 0, 'beta', 2),), + + # Despite the fact after version 10 Postgre's second number + # means "micro", it is parsed "as is" to be + # less confusing in comparisons. + ("10.1", (10, 1, 0, 'final', 0),), + ] + for version, expected in versions: + result = split_server_version_string(version) + self.assertEqual(expected, result) + class TestAuthentication(tb.ConnectedTestCase): def setUp(self):