Skip to content

Commit

Permalink
pgutil: split postgresql utils out of common.py into a new pgutil.py
Browse files Browse the repository at this point in the history
The upstream for pgutil is now https://github.com/ohmu/ohmu_common_py but
the files from there are copied verbatim into this project.  The
ohmu_common_py repository may become a proper Python module at some point,
or we may end up phasing it out, but splitting utilities out of common.py
probably makes sense anyway.
  • Loading branch information
saaros committed May 27, 2016
1 parent bfa976d commit 99603a5
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 154 deletions.
5 changes: 2 additions & 3 deletions pglookout/cluster_monitor.py
Expand Up @@ -8,9 +8,8 @@
See the file `LICENSE` for details.
"""

from .common import (
mask_connection_info, get_iso_timestamp, parse_iso_datetime,
set_syslog_handler)
from .common import get_iso_timestamp, parse_iso_datetime, set_syslog_handler
from .pgutil import mask_connection_info
from concurrent.futures import as_completed, ThreadPoolExecutor
from email.utils import parsedate
from psycopg2.extras import RealDictCursor
Expand Down
94 changes: 0 additions & 94 deletions pglookout/common.py
Expand Up @@ -4,108 +4,14 @@
Copyright (c) 2015 Ohmu Ltd
See LICENSE for details
"""

import datetime
import logging
import re
try:
from urllib.parse import urlparse, parse_qs # pylint: disable=no-name-in-module, import-error
except ImportError:
from urlparse import urlparse, parse_qs # pylint: disable=no-name-in-module, import-error


# log line layout: timestamp, logger name, level and message separated by tabs
LOG_FORMAT = "%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s"
# log line layout without a timestamp field; going by the name this is the
# variant used for syslog output (presumably syslog adds its own timestamp -- TODO confirm)
LOG_FORMAT_SYSLOG = '%(name)s %(levelname)s: %(message)s'


def create_connection_string(connection_info):
    """Build a libpq keyword/value connection string from a dict.

    Every item is emitted as ``key='value'`` and the items are joined with
    single spaces in sorted key order.  Values are converted with ``str()``
    and, as libpq requires inside quoted values, backslashes and single
    quotes are escaped with a backslash.
    """
    def quote(value):
        # escape backslashes first so that the backslashes added for the
        # single quotes are not escaped a second time
        return str(value).replace("\\", "\\\\").replace("'", "\\'")

    return " ".join("{0}='{1}'".format(k, quote(v))
                    for k, v in sorted(connection_info.items()))


def mask_connection_info(info):
    """Return a description of *info* that is safe to log.

    The result is the connection string built from every field except
    ``password``, followed by ``"; hidden password"`` when a password value
    was set, or ``"; no password"`` when the field is missing or None.
    The *info* argument itself is not modified.
    """
    fields = get_connection_info(info)
    if fields.pop("password", None) is None:
        password_note = "no"
    else:
        password_note = "hidden"
    visible = create_connection_string(fields)
    return "{0}; {1} password".format(visible, password_note)


def get_connection_info_from_config_line(line):
    """Extract connection info from a ``name = 'value'`` style config line.

    The text after the first ``=`` is stripped of surrounding whitespace,
    its first and last characters (expected to be the enclosing quotes) are
    dropped, doubled single quotes are collapsed into one and the result is
    parsed with get_connection_info().
    """
    _name, raw_value = line.split("=", 1)
    quoted = raw_value.strip()
    # drop the enclosing quote characters, then undo the '' quote escaping
    conninfo = quoted[1:-1].replace("''", "'")
    return get_connection_info(conninfo)


def get_connection_info(info):
    """Normalize connection info into a dict.

    A dict is returned as a shallow copy.  A string starting with
    ``postgres://`` or ``postgresql://`` is parsed as a connection URL and
    any other string is parsed as a traditional libpq connection string.
    """
    if isinstance(info, dict):
        return info.copy()
    if info.startswith(("postgres://", "postgresql://")):
        return parse_connection_string_url(info)
    return parse_connection_string_libpq(info)


def parse_connection_string_url(url):
    """Parse a ``postgres://`` / ``postgresql://`` URL into a dict of
    connection parameters.

    ``host``, ``port``, ``user``, ``password`` and ``dbname`` are taken from
    the matching URL components when they are present; the port is returned
    as a string.  Every query string parameter is then stored under its own
    name, and when a parameter is repeated its last value is used.
    Percent-encoded characters in the user, password and database name are
    decoded, as parse_qs() already does for the query parameters.
    """
    try:
        from urllib.parse import unquote  # pylint: disable=no-name-in-module, import-error
    except ImportError:
        from urllib import unquote  # pylint: disable=no-name-in-module, import-error

    # drop scheme from the url as some versions of urlparse don't handle
    # query and path properly for urls with a non-http scheme
    schemeless_url = url.split(":", 1)[1]
    p = urlparse(schemeless_url)
    fields = {}
    if p.hostname:
        fields["host"] = p.hostname
    if p.port:
        fields["port"] = str(p.port)
    if p.username:
        fields["user"] = unquote(p.username)
    # an empty password is still a password, so only skip a missing one
    if p.password is not None:
        fields["password"] = unquote(p.password)
    if p.path and p.path != "/":
        # the path always starts with a slash which is not part of the name
        fields["dbname"] = unquote(p.path[1:])
    for k, v in parse_qs(p.query).items():
        fields[k] = v[-1]
    return fields


def parse_connection_string_libpq(connection_string):
    """Parse a libpq keyword/value connection string into a dict.

    The format is defined in
    http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING

    Entries are ``key=value`` pairs separated by whitespace; whitespace
    around the equals sign is optional.  A value may be enclosed in single
    quotes, in which case a backslash makes the following character a
    literal part of the value.  An unquoted value ends at the next
    whitespace.  When a key is repeated the last value is used.

    Raises ValueError when a fragment contains no ``=`` or when a quoted
    value is not terminated.
    """
    fields = {}
    while True:
        connection_string = connection_string.strip()
        if not connection_string:
            break
        if "=" not in connection_string:
            raise ValueError("expecting key=value format in connection_string fragment {!r}".format(connection_string))
        key, rem = connection_string.split("=", 1)
        # libpq allows optional whitespace on both sides of the equals sign
        key = key.rstrip()
        rem = rem.lstrip()
        if rem.startswith("'"):
            asis, chars = False, []
            for i in range(1, len(rem)):
                if asis:
                    # character following a backslash is taken literally
                    chars.append(rem[i])
                    asis = False
                elif rem[i] == "'":
                    break  # end of entry
                elif rem[i] == "\\":
                    asis = True
                else:
                    chars.append(rem[i])
            else:
                # ran out of input without seeing the closing quote
                raise ValueError("invalid connection_string fragment {!r}".format(rem))
            value = "".join(chars)
            connection_string = rem[i + 1:]  # pylint: disable=undefined-loop-variable
        else:
            res = rem.split(None, 1)
            if len(res) > 1:
                value, connection_string = res
            else:
                value, connection_string = rem, ""
        fields[key] = value
    return fields


def convert_xlog_location_to_offset(xlog_location):
    """Convert a ``<hex>/<hex>`` xlog location string into a single integer.

    The part before the slash becomes the bits above the low 32 bits and
    the part after the slash is OR'ed into the result.
    """
    high_hex, low_hex = xlog_location.split("/")
    high_bits = int(high_hex, 16) << 32
    return high_bits | int(low_hex, 16)
Expand Down
3 changes: 2 additions & 1 deletion pglookout/pglookout.py
Expand Up @@ -12,9 +12,10 @@
from . import statsd, version
from .cluster_monitor import ClusterMonitor
from .common import (
create_connection_string, get_connection_info, get_connection_info_from_config_line,
convert_xlog_location_to_offset, parse_iso_datetime, get_iso_timestamp,
set_syslog_handler, LOG_FORMAT, LOG_FORMAT_SYSLOG)
from .pgutil import (
create_connection_string, get_connection_info, get_connection_info_from_config_line)
from .webserver import WebServer
from psycopg2.extensions import adapt
import argparse
Expand Down
100 changes: 100 additions & 0 deletions pglookout/pgutil.py
@@ -0,0 +1,100 @@
# Copied from https://github.com/ohmu/ohmu_common_py ohmu_common_py/pgutil.py version 0.0.1-0-unknown-fa54b44
"""
pglookout - postgresql utility functions
Copyright (c) 2015 Ohmu Ltd
See LICENSE for details
"""

try:
from urllib.parse import urlparse, parse_qs # pylint: disable=no-name-in-module, import-error
except ImportError:
from urlparse import urlparse, parse_qs # pylint: disable=no-name-in-module, import-error


def create_connection_string(connection_info):
    """Build a libpq keyword/value connection string from a dict.

    Every item is emitted as ``key='value'`` and the items are joined with
    single spaces in sorted key order.  Values are converted with ``str()``
    and, as libpq requires inside quoted values, backslashes and single
    quotes are escaped with a backslash.
    """
    def quote(value):
        # escape backslashes first so that the backslashes added for the
        # single quotes are not escaped a second time
        return str(value).replace("\\", "\\\\").replace("'", "\\'")

    return " ".join("{}='{}'".format(k, quote(v))
                    for k, v in sorted(connection_info.items()))


def mask_connection_info(info):
    """Return a description of *info* that is safe to log.

    The result is the connection string built from every field except
    ``password``, followed by ``"; hidden password"`` when a password value
    was set, or ``"; no password"`` when the field is missing or None.
    The *info* argument itself is not modified.
    """
    fields = get_connection_info(info)
    if fields.pop("password", None) is None:
        password_note = "no"
    else:
        password_note = "hidden"
    visible = create_connection_string(fields)
    return "{0}; {1} password".format(visible, password_note)


def get_connection_info_from_config_line(line):
    """Extract connection info from a ``name = 'value'`` style config line.

    The text after the first ``=`` is stripped of surrounding whitespace,
    its first and last characters (expected to be the enclosing quotes) are
    dropped, doubled single quotes are collapsed into one and the result is
    parsed with get_connection_info().
    """
    _name, raw_value = line.split("=", 1)
    quoted = raw_value.strip()
    # drop the enclosing quote characters, then undo the '' quote escaping
    conninfo = quoted[1:-1].replace("''", "'")
    return get_connection_info(conninfo)


def get_connection_info(info):
    """Normalize connection info into a dict.

    A dict is returned as a shallow copy.  A string starting with
    ``postgres://`` or ``postgresql://`` is parsed as a connection URL and
    any other string is parsed as a traditional libpq connection string.
    """
    if isinstance(info, dict):
        return info.copy()
    if info.startswith(("postgres://", "postgresql://")):
        return parse_connection_string_url(info)
    return parse_connection_string_libpq(info)


def parse_connection_string_url(url):
    """Parse a ``postgres://`` / ``postgresql://`` URL into a dict of
    connection parameters.

    ``host``, ``port``, ``user``, ``password`` and ``dbname`` are taken from
    the matching URL components when they are present; the port is returned
    as a string.  Every query string parameter is then stored under its own
    name, and when a parameter is repeated its last value is used.
    Percent-encoded characters in the user, password and database name are
    decoded, as parse_qs() already does for the query parameters.
    """
    try:
        from urllib.parse import unquote  # pylint: disable=no-name-in-module, import-error
    except ImportError:
        from urllib import unquote  # pylint: disable=no-name-in-module, import-error

    # drop scheme from the url as some versions of urlparse don't handle
    # query and path properly for urls with a non-http scheme
    schemeless_url = url.split(":", 1)[1]
    p = urlparse(schemeless_url)
    fields = {}
    if p.hostname:
        fields["host"] = p.hostname
    if p.port:
        fields["port"] = str(p.port)
    if p.username:
        fields["user"] = unquote(p.username)
    # an empty password is still a password, so only skip a missing one
    if p.password is not None:
        fields["password"] = unquote(p.password)
    if p.path and p.path != "/":
        # the path always starts with a slash which is not part of the name
        fields["dbname"] = unquote(p.path[1:])
    for k, v in parse_qs(p.query).items():
        fields[k] = v[-1]
    return fields


def parse_connection_string_libpq(connection_string):
    """Parse a libpq keyword/value connection string into a dict.

    The format is defined in
    http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING

    Entries are ``key=value`` pairs separated by whitespace; whitespace
    around the equals sign is optional.  A value may be enclosed in single
    quotes, in which case a backslash makes the following character a
    literal part of the value.  An unquoted value ends at the next
    whitespace.  When a key is repeated the last value is used.

    Raises ValueError when a fragment contains no ``=`` or when a quoted
    value is not terminated.
    """
    fields = {}
    while True:
        connection_string = connection_string.strip()
        if not connection_string:
            break
        if "=" not in connection_string:
            raise ValueError("expecting key=value format in connection_string fragment {!r}".format(connection_string))
        key, rem = connection_string.split("=", 1)
        # libpq allows optional whitespace on both sides of the equals sign
        key = key.rstrip()
        rem = rem.lstrip()
        if rem.startswith("'"):
            asis, chars = False, []
            for i in range(1, len(rem)):
                if asis:
                    # character following a backslash is taken literally
                    chars.append(rem[i])
                    asis = False
                elif rem[i] == "'":
                    break  # end of entry
                elif rem[i] == "\\":
                    asis = True
                else:
                    chars.append(rem[i])
            else:
                # ran out of input without seeing the closing quote
                raise ValueError("invalid connection_string fragment {!r}".format(rem))
            value = "".join(chars)
            connection_string = rem[i + 1:]  # pylint: disable=undefined-loop-variable
        else:
            res = rem.split(None, 1)
            if len(res) > 1:
                value, connection_string = res
            else:
                value, connection_string = rem, ""
        fields[key] = value
    return fields
52 changes: 0 additions & 52 deletions test/test_common.py
Expand Up @@ -6,65 +6,13 @@
"""

from pglookout.common import (
create_connection_string, get_connection_info, mask_connection_info,
convert_xlog_location_to_offset,
parse_iso_datetime, get_iso_timestamp, ISO_EXT_RE,
)
from pytest import raises
import datetime


def test_connection_info():
    """Dict, libpq string and URL inputs must all yield the same fields,
    and malformed libpq strings must be rejected with ValueError."""
    as_url = "postgres://hannu:secret@dbhost.local:5555/abc?replication=true&sslmode=foobar&sslmode=require"
    as_libpq = (
        "host=dbhost.local user='hannu' dbname='abc'\n"
        "replication=true password=secret sslmode=require port=5555"
    )
    as_dict = {
        "host": "dbhost.local",
        "port": "5555",
        "user": "hannu",
        "password": "secret",
        "dbname": "abc",
        "replication": "true",
        "sslmode": "require",
    }
    assert get_connection_info(as_dict) == get_connection_info(as_libpq)
    assert get_connection_info(as_dict) == get_connection_info(as_url)

    # parsing and re-creating a simple string quotes every value
    parsed_basic = get_connection_info("host=localhost user=os")
    assert create_connection_string(parsed_basic) == "host='localhost' user='os'"

    # a backslash-escaped quote inside a quoted value is kept literally
    assert get_connection_info("foo=bar bar='\\'x'") == {"foo": "bar", "bar": "'x"}

    # a fragment without "=" and an unterminated quoted value are both errors
    for malformed in ("foo=bar x", "foo=bar bar='x"):
        with raises(ValueError):
            get_connection_info(malformed)


def test_mask_connection_info():
    """Masking must give the same result for the URL, libpq string and dict
    forms of the same connection info, and must not modify a dict argument."""
    url = "postgres://michael:secret@dbhost.local:5555/abc?replication=true&sslmode=foobar&sslmode=require"
    cs = "host=dbhost.local user='michael' dbname='abc'\n" \
        "replication=true password=secret sslmode=require port=5555"
    ci = get_connection_info(cs)
    # mask each input form separately; masking `url` three times would make
    # the comparisons and the "original dict" check below meaningless
    masked_url = mask_connection_info(url)
    masked_cs = mask_connection_info(cs)
    masked_ci = mask_connection_info(ci)
    assert masked_url == masked_cs
    assert masked_url == masked_ci
    assert "password" in ci  # make sure we didn't modify the original dict

    # the return format is a connection string without password, followed by
    # a semicolon and comment about password presence
    masked_str, password_info = masked_url.split("; ", 1)
    assert "password" not in masked_str
    assert password_info == "hidden password"

    # remasking the masked string should yield a no password comment
    masked_masked = mask_connection_info(masked_str)
    _, masked_password_info = masked_masked.split("; ", 1)
    assert masked_password_info == "no password"


def test_convert_xlog_location_to_offset():
assert convert_xlog_location_to_offset("1/00000000") == 1 << 32
assert convert_xlog_location_to_offset("F/AAAAAAAA") == (0xF << 32) | 0xAAAAAAAA
Expand Down
5 changes: 2 additions & 3 deletions test/test_lookout.py
Expand Up @@ -8,10 +8,9 @@
See the file `LICENSE` for details.
"""

from pglookout.common import (
get_connection_info, get_connection_info_from_config_line,
get_iso_timestamp)
from pglookout.common import get_iso_timestamp
from pglookout.pglookout import PgLookout
from pglookout.pgutil import get_connection_info, get_connection_info_from_config_line
try:
from mock import Mock # pylint: disable=import-error
except ImportError: # py3k import location
Expand Down
63 changes: 63 additions & 0 deletions test/test_pgutil.py
@@ -0,0 +1,63 @@
# Copied from https://github.com/ohmu/ohmu_common_py test/test_pgutil.py version 0.0.1-0-unknown-fa54b44
"""
pglookout - postgresql utility function tests
Copyright (c) 2015 Ohmu Ltd
See LICENSE for details
"""

from pglookout.pgutil import (
create_connection_string, get_connection_info, mask_connection_info,
)
from pytest import raises


def test_connection_info():
    """Dict, libpq string and URL inputs must all yield the same fields,
    and malformed libpq strings must be rejected with ValueError."""
    as_url = "postgres://hannu:secret@dbhost.local:5555/abc?replication=true&sslmode=foobar&sslmode=require"
    as_libpq = (
        "host=dbhost.local user='hannu' dbname='abc'\n"
        "replication=true password=secret sslmode=require port=5555"
    )
    as_dict = {
        "host": "dbhost.local",
        "port": "5555",
        "user": "hannu",
        "password": "secret",
        "dbname": "abc",
        "replication": "true",
        "sslmode": "require",
    }
    assert get_connection_info(as_dict) == get_connection_info(as_libpq)
    assert get_connection_info(as_dict) == get_connection_info(as_url)

    # parsing and re-creating a simple string quotes every value
    parsed_basic = get_connection_info("host=localhost user=os")
    assert create_connection_string(parsed_basic) == "host='localhost' user='os'"

    # a backslash-escaped quote inside a quoted value is kept literally
    assert get_connection_info("foo=bar bar='\\'x'") == {"foo": "bar", "bar": "'x"}

    # a fragment without "=" and an unterminated quoted value are both errors
    for malformed in ("foo=bar x", "foo=bar bar='x"):
        with raises(ValueError):
            get_connection_info(malformed)


def test_mask_connection_info():
    """Masking must give the same result for the URL, libpq string and dict
    forms of the same connection info, and must not modify a dict argument."""
    url = "postgres://michael:secret@dbhost.local:5555/abc?replication=true&sslmode=foobar&sslmode=require"
    cs = "host=dbhost.local user='michael' dbname='abc'\n" \
        "replication=true password=secret sslmode=require port=5555"
    ci = get_connection_info(cs)
    # mask each input form separately; masking `url` three times would make
    # the comparisons and the "original dict" check below meaningless
    masked_url = mask_connection_info(url)
    masked_cs = mask_connection_info(cs)
    masked_ci = mask_connection_info(ci)
    assert masked_url == masked_cs
    assert masked_url == masked_ci
    assert "password" in ci  # make sure we didn't modify the original dict

    # the return format is a connection string without password, followed by
    # a semicolon and comment about password presence
    masked_str, password_info = masked_url.split("; ", 1)
    assert "password" not in masked_str
    assert password_info == "hidden password"

    # remasking the masked string should yield a no password comment
    masked_masked = mask_connection_info(masked_str)
    _, masked_password_info = masked_masked.split("; ", 1)
    assert masked_password_info == "no password"
7 changes: 6 additions & 1 deletion version.py
@@ -1,6 +1,9 @@
# Copied from https://github.com/ohmu/ohmu_common_py version.py version 0.0.1-0-unknown-fa54b44
"""
automatically maintains the latest git tag + revision info in a python file
pglookout - version detection and version.py __version__ generation
Copyright (c) 2015 Ohmu Ltd
See LICENSE for details
"""

import imp
Expand Down Expand Up @@ -34,6 +37,8 @@ def get_project_version(version_file):
pass
else:
git_ver = git_out.splitlines()[0].strip().decode("utf-8")
if "." not in git_ver:
git_ver = "0.0.1-0-unknown-{}".format(git_ver)
if save_version(git_ver, file_ver, version_file):
return git_ver

Expand Down

0 comments on commit 99603a5

Please sign in to comment.