Skip to content

Commit b6c576e

Browse files
committed
Move 'connect()' function into asyncpg.connection module.
1 parent 2a427d3 commit b6c576e

File tree

3 files changed

+194
-196
lines changed

3 files changed

+194
-196
lines changed

asyncpg/__init__.py

+1-195
Original file line numberDiff line numberDiff line change
@@ -1,200 +1,6 @@
1-
import asyncio
2-
import getpass
3-
import os
4-
import urllib.parse
5-
1+
from .connection import connect # NOQA
62
from .exceptions import * # NOQA
7-
from . import connection
8-
from . import protocol
93
from .types import * # NOQA
104

115

126
__all__ = ('connect',) + exceptions.__all__ # NOQA
13-
14-
15-
async def connect(dsn=None, *,
16-
host=None, port=None,
17-
user=None, password=None,
18-
database=None,
19-
loop=None,
20-
timeout=60,
21-
statement_cache_size=100,
22-
command_timeout=None,
23-
**opts):
24-
25-
if loop is None:
26-
loop = asyncio.get_event_loop()
27-
28-
host, port, opts = _parse_connect_params(
29-
dsn=dsn, host=host, port=port, user=user, password=password,
30-
database=database, opts=opts)
31-
32-
last_ex = None
33-
addr = None
34-
for h in host:
35-
connected = _create_future(loop)
36-
unix = h.startswith('/')
37-
38-
if unix:
39-
# UNIX socket name
40-
addr = os.path.join(h, '.s.PGSQL.{}'.format(port))
41-
conn = loop.create_unix_connection(
42-
lambda: protocol.Protocol(addr, connected, opts, loop),
43-
addr)
44-
else:
45-
addr = (h, port)
46-
conn = loop.create_connection(
47-
lambda: protocol.Protocol(addr, connected, opts, loop),
48-
h, port)
49-
50-
try:
51-
tr, pr = await asyncio.wait_for(conn, timeout=timeout, loop=loop)
52-
except (OSError, asyncio.TimeoutError) as ex:
53-
last_ex = ex
54-
else:
55-
break
56-
else:
57-
raise last_ex
58-
59-
try:
60-
await connected
61-
except:
62-
tr.close()
63-
raise
64-
65-
con = connection.Connection(
66-
pr, tr, loop, addr, opts,
67-
statement_cache_size=statement_cache_size,
68-
command_timeout=command_timeout)
69-
pr.set_connection(con)
70-
return con
71-
72-
73-
def _parse_connect_params(*, dsn, host, port, user,
74-
password, database, opts):
75-
76-
if dsn:
77-
parsed = urllib.parse.urlparse(dsn)
78-
79-
if parsed.scheme not in {'postgresql', 'postgres'}:
80-
raise ValueError(
81-
'invalid DSN: scheme is expected to be either of '
82-
'"postgresql" or "postgres", got {!r}'.format(parsed.scheme))
83-
84-
if parsed.port and port is None:
85-
port = int(parsed.port)
86-
87-
if parsed.hostname and host is None:
88-
host = parsed.hostname
89-
90-
if parsed.path and database is None:
91-
database = parsed.path
92-
if database.startswith('/'):
93-
database = database[1:]
94-
95-
if parsed.username and user is None:
96-
user = parsed.username
97-
98-
if parsed.password and password is None:
99-
password = parsed.password
100-
101-
if parsed.query:
102-
query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
103-
for key, val in query.items():
104-
if isinstance(val, list):
105-
query[key] = val[-1]
106-
107-
if 'host' in query:
108-
val = query.pop('host')
109-
if host is None:
110-
host = val
111-
112-
if 'port' in query:
113-
val = int(query.pop('port'))
114-
if port is None:
115-
port = val
116-
117-
if 'dbname' in query:
118-
val = query.pop('dbname')
119-
if database is None:
120-
database = val
121-
122-
if 'database' in query:
123-
val = query.pop('database')
124-
if database is None:
125-
database = val
126-
127-
if 'user' in query:
128-
val = query.pop('user')
129-
if user is None:
130-
user = val
131-
132-
if 'password' in query:
133-
val = query.pop('password')
134-
if password is None:
135-
password = val
136-
137-
if query:
138-
opts = {**query, **opts}
139-
140-
# On env-var -> connection parameter conversion read here:
141-
# https://www.postgresql.org/docs/current/static/libpq-envars.html
142-
# Note that env values may be an empty string in cases when
143-
# the variable is "unset" by setting it to an empty value
144-
#
145-
if host is None:
146-
host = os.getenv('PGHOST')
147-
if not host:
148-
host = ['/tmp', '/private/tmp',
149-
'/var/pgsql_socket', '/run/postgresql',
150-
'localhost']
151-
if not isinstance(host, list):
152-
host = [host]
153-
154-
if port is None:
155-
port = os.getenv('PGPORT')
156-
if port:
157-
port = int(port)
158-
else:
159-
port = 5432
160-
else:
161-
port = int(port)
162-
163-
if user is None:
164-
user = os.getenv('PGUSER')
165-
if not user:
166-
user = getpass.getuser()
167-
168-
if password is None:
169-
password = os.getenv('PGPASSWORD')
170-
171-
if database is None:
172-
database = os.getenv('PGDATABASE')
173-
174-
if user is not None:
175-
opts['user'] = user
176-
if password is not None:
177-
opts['password'] = password
178-
if database is not None:
179-
opts['database'] = database
180-
181-
for param in opts:
182-
if not isinstance(param, str):
183-
raise ValueError(
184-
'invalid connection parameter {!r} (str expected)'
185-
.format(param))
186-
if not isinstance(opts[param], str):
187-
raise ValueError(
188-
'invalid connection parameter {!r}: {!r} (str expected)'
189-
.format(param, opts[param]))
190-
191-
return host, port, opts
192-
193-
194-
def _create_future(loop):
195-
try:
196-
create_future = loop.create_future
197-
except AttributeError:
198-
return asyncio.Future(loop=loop)
199-
else:
200-
return create_future()

0 commit comments

Comments
 (0)