Skip to content
Permalink
Browse files

pgutil: split postgresql utils out of common.py into a new pgutil.py

The upstream for pgutil is now https://github.com/ohmu/ohmu_common_py but
the files from there are copied verbatim into this project.  The
ohmu_common_py repository may become a proper Python module at some point,
or we may end up phasing it out, but splitting utilities out of common.py
probably makes sense anyway.
  • Loading branch information...
saaros committed May 8, 2016
1 parent 70810ad commit 887addea4ab4d816eb1b3f3f699d992f77298713
Showing with 185 additions and 103 deletions.
  1. +2 −2 pghoard/basebackup.py
  2. +5 −76 pghoard/common.py
  3. +100 −0 pghoard/pgutil.py
  4. +2 −1 pghoard/receivexlog.py
  5. +1 −1 test/test_basebackup.py
  6. +0 −17 test/test_common.py
  7. +2 −1 test/test_pghoard.py
  8. +63 −0 test/test_pgutil.py
  9. +1 −1 test/test_webserver.py
  10. +9 −4 version.py
@@ -4,8 +4,8 @@
Copyright (c) 2016 Ohmu Ltd
See LICENSE for details
"""
from .common import (get_connection_info, set_stream_nonblocking, set_subprocess_stdout_and_stderr_nonblocking,
terminate_subprocess)
from .common import set_stream_nonblocking, set_subprocess_stdout_and_stderr_nonblocking, terminate_subprocess
from .pgutil import get_connection_info
from pghoard.rohmu.compat import suppress
from pghoard.rohmu.compressor import Compressor
from tempfile import NamedTemporaryFile
@@ -4,9 +4,9 @@
Copyright (c) 2016 Ohmu Ltd
See LICENSE for details
"""
from pghoard import pgutil
from pghoard.rohmu.compat import suppress
from pghoard.rohmu.errors import Error, InvalidConfigurationError
from urllib.parse import urlparse, parse_qs
import datetime
import fcntl
import json
@@ -25,84 +25,13 @@
syslog_format_str = "%(name)s %(threadName)s %(levelname)s: %(message)s"


def create_connection_string(connection_info):
return " ".join("{}='{}'".format(k, str(v).replace("'", "\\'"))
for k, v in sorted(connection_info.items()))


def get_connection_info(info):
"""turn a connection info object into a dict or return it if it was a
dict already. supports both the traditional libpq format and the new
url format"""
if isinstance(info, dict):
return info.copy()
elif info.startswith("postgres://") or info.startswith("postgresql://"):
return parse_connection_string_url(info)
else:
return parse_connection_string_libpq(info)


def parse_connection_string_url(url):
p = urlparse(url)
fields = {}
if p.hostname:
fields["host"] = p.hostname
if p.port:
fields["port"] = str(p.port)
if p.username:
fields["user"] = p.username
if p.password is not None:
fields["password"] = p.password
if p.path and p.path != "/":
fields["dbname"] = p.path[1:]
for k, v in parse_qs(p.query).items():
fields[k] = v[-1]
return fields


def parse_connection_string_libpq(connection_string):
"""parse a postgresql connection string as defined in
http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING"""
fields = {}
while True:
connection_string = connection_string.strip()
if not connection_string:
break
if "=" not in connection_string:
raise ValueError("expecting key=value format in connection_string fragment {!r}".format(connection_string))
key, rem = connection_string.split("=", 1)
if rem.startswith("'"):
asis, value = False, ""
for i in range(1, len(rem)):
if asis:
value += rem[i]
asis = False
elif rem[i] == "'":
break # end of entry
elif rem[i] == "\\":
asis = True
else:
value += rem[i]
else:
raise ValueError("invalid connection_string fragment {!r}".format(rem))
connection_string = rem[i + 1:] # pylint: disable=undefined-loop-variable
else:
res = rem.split(None, 1)
if len(res) > 1:
value, connection_string = res
else:
value, connection_string = rem, ""
fields[key] = value
return fields


def create_pgpass_file(connection_string_or_info):
"""Look up password from the given object which can be a dict or a
string and write a possible password in a pgpass file;
returns a connection_string without a password in it"""
info = get_connection_info(connection_string_or_info)
info = pgutil.get_connection_info(connection_string_or_info)
if "password" not in info:
return create_connection_string(info)
return pgutil.create_connection_string(info)
linekey = "{host}:{port}:{dbname}:{user}:".format(
host=info.get("host", "localhost"),
port=info.get("port", 5432),
@@ -125,7 +54,7 @@ def create_pgpass_file(connection_string_or_info):
os.fchmod(fp.fileno(), 0o600)
fp.write(content)
LOG.debug("Wrote %r to %r", pwline, pgpass_path)
return create_connection_string(info)
return pgutil.create_connection_string(info)


def replication_connection_string_using_pgpass(target_node_info):
@@ -144,7 +73,7 @@ def replication_connection_string_using_pgpass(target_node_info):
target_node_info = target_node_info["connection_string"]
# make sure it's a replication connection to the host
# pointed by the key using the "replication" pseudo-db
connection_info = get_connection_info(target_node_info)
connection_info = pgutil.get_connection_info(target_node_info)
connection_info["dbname"] = "replication"
connection_info["replication"] = "true"
connection_string = create_pgpass_file(connection_info)
@@ -0,0 +1,100 @@
# Copied from https://github.com/ohmu/ohmu_common_py ohmu_common_py/pgutil.py version 0.0.1-0-unknown-fa54b44
"""
pghoard - postgresql utility functions
Copyright (c) 2015 Ohmu Ltd
See LICENSE for details
"""

try:
from urllib.parse import urlparse, parse_qs # pylint: disable=no-name-in-module, import-error
except ImportError:
from urlparse import urlparse, parse_qs # pylint: disable=no-name-in-module, import-error


def create_connection_string(connection_info):
return " ".join("{}='{}'".format(k, str(v).replace("'", "\\'"))
for k, v in sorted(connection_info.items()))


def mask_connection_info(info):
masked_info = get_connection_info(info)
password = masked_info.pop("password", None)
return "{0}; {1} password".format(
create_connection_string(masked_info),
"no" if password is None else "hidden")


def get_connection_info_from_config_line(line):
_, value = line.split("=", 1)
value = value.strip()[1:-1].replace("''", "'")
return get_connection_info(value)


def get_connection_info(info):
"""turn a connection info object into a dict or return it if it was a
dict already. supports both the traditional libpq format and the new
url format"""
if isinstance(info, dict):
return info.copy()
elif info.startswith("postgres://") or info.startswith("postgresql://"):
return parse_connection_string_url(info)
else:
return parse_connection_string_libpq(info)


def parse_connection_string_url(url):
# drop scheme from the url as some versions of urlparse don't handle
# query and path properly for urls with a non-http scheme
schemeless_url = url.split(":", 1)[1]
p = urlparse(schemeless_url)
fields = {}
if p.hostname:
fields["host"] = p.hostname
if p.port:
fields["port"] = str(p.port)
if p.username:
fields["user"] = p.username
if p.password is not None:
fields["password"] = p.password
if p.path and p.path != "/":
fields["dbname"] = p.path[1:]
for k, v in parse_qs(p.query).items():
fields[k] = v[-1]
return fields


def parse_connection_string_libpq(connection_string):
"""parse a postgresql connection string as defined in
http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING"""
fields = {}
while True:
connection_string = connection_string.strip()
if not connection_string:
break
if "=" not in connection_string:
raise ValueError("expecting key=value format in connection_string fragment {!r}".format(connection_string))
key, rem = connection_string.split("=", 1)
if rem.startswith("'"):
asis, value = False, ""
for i in range(1, len(rem)):
if asis:
value += rem[i]
asis = False
elif rem[i] == "'":
break # end of entry
elif rem[i] == "\\":
asis = True
else:
value += rem[i]
else:
raise ValueError("invalid connection_string fragment {!r}".format(rem))
connection_string = rem[i + 1:] # pylint: disable=undefined-loop-variable
else:
res = rem.split(None, 1)
if len(res) > 1:
value, connection_string = res
else:
value, connection_string = rem, ""
fields[key] = value
return fields
@@ -11,7 +11,8 @@
import subprocess
import time

from . common import get_connection_info, set_subprocess_stdout_and_stderr_nonblocking, terminate_subprocess
from .common import set_subprocess_stdout_and_stderr_nonblocking, terminate_subprocess
from .pgutil import get_connection_info
from threading import Thread


@@ -6,7 +6,7 @@
"""
from copy import deepcopy
from pghoard.basebackup import PGBaseBackup
from pghoard.common import create_connection_string
from pghoard.pgutil import create_connection_string
from pghoard.restore import Restore, RestoreError
from queue import Queue
import os
@@ -9,7 +9,6 @@
create_pgpass_file,
convert_pg_command_version_to_number,
default_json_serialization,
get_connection_info,
json_encode,
write_json_file,
)
@@ -50,22 +49,6 @@ def get_pgpass_contents():
assert get_pgpass_contents() == b'localhost:5432:replication:another:bar\nlocalhost:5432:replication:foo:xyz\n'
os.environ['HOME'] = original_home

def test_connection_info(self):
url = "postgres://hannu:secret@dbhost.local:5555/abc?replication=true&sslmode=foobar&sslmode=require"
cs = "host=dbhost.local user='hannu' dbname='abc'\n" \
"replication=true password=secret sslmode=require port=5555"
ci = {
"host": "dbhost.local",
"port": "5555",
"user": "hannu",
"password": "secret",
"dbname": "abc",
"replication": "true",
"sslmode": "require",
}
assert get_connection_info(ci) == get_connection_info(cs)
assert get_connection_info(ci) == get_connection_info(url)

def test_pg_versions(self):
assert convert_pg_command_version_to_number("foobar (PostgreSQL) 9.4.1") == 90401
assert convert_pg_command_version_to_number("asdf (PostgreSQL) 9.5alpha1") == 90500
@@ -6,8 +6,9 @@
"""
# pylint: disable=attribute-defined-outside-init
from .base import PGHoardTestCase
from pghoard.common import create_alert_file, create_connection_string, delete_alert_file, write_json_file
from pghoard.common import create_alert_file, delete_alert_file, write_json_file
from pghoard.pghoard import PGHoard
from pghoard.pgutil import create_connection_string
from unittest.mock import Mock, patch
import datetime
import json
@@ -0,0 +1,63 @@
# Copied from https://github.com/ohmu/ohmu_common_py test/test_pgutil.py version 0.0.1-0-unknown-fa54b44
"""
pghoard - postgresql utility function tests
Copyright (c) 2015 Ohmu Ltd
See LICENSE for details
"""

from pghoard.pgutil import (
create_connection_string, get_connection_info, mask_connection_info,
)
from pytest import raises


def test_connection_info():
url = "postgres://hannu:secret@dbhost.local:5555/abc?replication=true&sslmode=foobar&sslmode=require"
cs = "host=dbhost.local user='hannu' dbname='abc'\n" \
"replication=true password=secret sslmode=require port=5555"
ci = {
"host": "dbhost.local",
"port": "5555",
"user": "hannu",
"password": "secret",
"dbname": "abc",
"replication": "true",
"sslmode": "require",
}
assert get_connection_info(ci) == get_connection_info(cs)
assert get_connection_info(ci) == get_connection_info(url)

basic_cstr = "host=localhost user=os"
assert create_connection_string(get_connection_info(basic_cstr)) == "host='localhost' user='os'"

assert get_connection_info("foo=bar bar='\\'x'") == {"foo": "bar", "bar": "'x"}

with raises(ValueError):
get_connection_info("foo=bar x")
with raises(ValueError):
get_connection_info("foo=bar bar='x")


def test_mask_connection_info():
url = "postgres://michael:secret@dbhost.local:5555/abc?replication=true&sslmode=foobar&sslmode=require"
cs = "host=dbhost.local user='michael' dbname='abc'\n" \
"replication=true password=secret sslmode=require port=5555"
ci = get_connection_info(cs)
masked_url = mask_connection_info(url)
masked_cs = mask_connection_info(url)
masked_ci = mask_connection_info(url)
assert masked_url == masked_cs
assert masked_url == masked_ci
assert "password" in ci # make sure we didn't modify the original dict

# the return format is a connection string without password, followed by
# a semicolon and comment about password presence
masked_str, password_info = masked_url.split("; ", 1)
assert "password" not in masked_str
assert password_info == "hidden password"

# remasking the masked string should yield a no password comment
masked_masked = mask_connection_info(masked_str)
_, masked_password_info = masked_masked.split("; ", 1)
assert masked_password_info == "no password"
@@ -11,7 +11,7 @@
from http.client import HTTPConnection
from pghoard import postgres_command, wal
from pghoard.archive_sync import ArchiveSync
from pghoard.common import create_connection_string
from pghoard.pgutil import create_connection_string
from pghoard.postgres_command import archive_command, restore_command
from pghoard.restore import HTTPRestore, Restore
from pghoard.rohmu.encryptor import Encryptor
@@ -1,6 +1,9 @@
# Copied from https://github.com/ohmu/ohmu_common_py version.py version 0.0.1-0-unknown-fa54b44
"""
automatically maintains the latest git tag + revision info in a python file
pghoard - version detection and version.py __version__ generation
Copyright (c) 2015 Ohmu Ltd
See LICENSE for details
"""

import imp
@@ -23,17 +26,19 @@ def get_project_version(version_file):
try:
module = imp.load_source("verfile", version_file)
file_ver = module.__version__
except:
except IOError:
file_ver = None

os.chdir(os.path.dirname(__file__) or ".")
try:
git_out = subprocess.check_output(["git", "describe", "--always"],
stderr=subprocess.DEVNULL)
except (FileNotFoundError, subprocess.CalledProcessError):
stderr=getattr(subprocess, "DEVNULL", None))
except (OSError, subprocess.CalledProcessError):
pass
else:
git_ver = git_out.splitlines()[0].strip().decode("utf-8")
if "." not in git_ver:
git_ver = "0.0.1-0-unknown-{}".format(git_ver)
if save_version(git_ver, file_ver, version_file):
return git_ver

0 comments on commit 887adde

Please sign in to comment.
You can’t perform that action at this time.