|
1 |
| -import asyncio |
2 |
| -import getpass |
3 |
| -import os |
4 |
| -import urllib.parse |
5 |
| - |
| 1 | +from .connection import connect # NOQA |
6 | 2 | from .exceptions import * # NOQA
|
7 |
| -from . import connection |
8 |
| -from . import protocol |
9 | 3 | from .types import * # NOQA
|
10 | 4 |
|
11 | 5 |
|
12 | 6 | __all__ = ('connect',) + exceptions.__all__ # NOQA
|
13 |
| - |
14 |
| - |
15 |
| -async def connect(dsn=None, *, |
16 |
| - host=None, port=None, |
17 |
| - user=None, password=None, |
18 |
| - database=None, |
19 |
| - loop=None, |
20 |
| - timeout=60, |
21 |
| - statement_cache_size=100, |
22 |
| - command_timeout=None, |
23 |
| - **opts): |
24 |
| - |
25 |
| - if loop is None: |
26 |
| - loop = asyncio.get_event_loop() |
27 |
| - |
28 |
| - host, port, opts = _parse_connect_params( |
29 |
| - dsn=dsn, host=host, port=port, user=user, password=password, |
30 |
| - database=database, opts=opts) |
31 |
| - |
32 |
| - last_ex = None |
33 |
| - addr = None |
34 |
| - for h in host: |
35 |
| - connected = _create_future(loop) |
36 |
| - unix = h.startswith('/') |
37 |
| - |
38 |
| - if unix: |
39 |
| - # UNIX socket name |
40 |
| - addr = os.path.join(h, '.s.PGSQL.{}'.format(port)) |
41 |
| - conn = loop.create_unix_connection( |
42 |
| - lambda: protocol.Protocol(addr, connected, opts, loop), |
43 |
| - addr) |
44 |
| - else: |
45 |
| - addr = (h, port) |
46 |
| - conn = loop.create_connection( |
47 |
| - lambda: protocol.Protocol(addr, connected, opts, loop), |
48 |
| - h, port) |
49 |
| - |
50 |
| - try: |
51 |
| - tr, pr = await asyncio.wait_for(conn, timeout=timeout, loop=loop) |
52 |
| - except (OSError, asyncio.TimeoutError) as ex: |
53 |
| - last_ex = ex |
54 |
| - else: |
55 |
| - break |
56 |
| - else: |
57 |
| - raise last_ex |
58 |
| - |
59 |
| - try: |
60 |
| - await connected |
61 |
| - except: |
62 |
| - tr.close() |
63 |
| - raise |
64 |
| - |
65 |
| - con = connection.Connection( |
66 |
| - pr, tr, loop, addr, opts, |
67 |
| - statement_cache_size=statement_cache_size, |
68 |
| - command_timeout=command_timeout) |
69 |
| - pr.set_connection(con) |
70 |
| - return con |
71 |
| - |
72 |
| - |
73 |
| -def _parse_connect_params(*, dsn, host, port, user, |
74 |
| - password, database, opts): |
75 |
| - |
76 |
| - if dsn: |
77 |
| - parsed = urllib.parse.urlparse(dsn) |
78 |
| - |
79 |
| - if parsed.scheme not in {'postgresql', 'postgres'}: |
80 |
| - raise ValueError( |
81 |
| - 'invalid DSN: scheme is expected to be either of ' |
82 |
| - '"postgresql" or "postgres", got {!r}'.format(parsed.scheme)) |
83 |
| - |
84 |
| - if parsed.port and port is None: |
85 |
| - port = int(parsed.port) |
86 |
| - |
87 |
| - if parsed.hostname and host is None: |
88 |
| - host = parsed.hostname |
89 |
| - |
90 |
| - if parsed.path and database is None: |
91 |
| - database = parsed.path |
92 |
| - if database.startswith('/'): |
93 |
| - database = database[1:] |
94 |
| - |
95 |
| - if parsed.username and user is None: |
96 |
| - user = parsed.username |
97 |
| - |
98 |
| - if parsed.password and password is None: |
99 |
| - password = parsed.password |
100 |
| - |
101 |
| - if parsed.query: |
102 |
| - query = urllib.parse.parse_qs(parsed.query, strict_parsing=True) |
103 |
| - for key, val in query.items(): |
104 |
| - if isinstance(val, list): |
105 |
| - query[key] = val[-1] |
106 |
| - |
107 |
| - if 'host' in query: |
108 |
| - val = query.pop('host') |
109 |
| - if host is None: |
110 |
| - host = val |
111 |
| - |
112 |
| - if 'port' in query: |
113 |
| - val = int(query.pop('port')) |
114 |
| - if port is None: |
115 |
| - port = val |
116 |
| - |
117 |
| - if 'dbname' in query: |
118 |
| - val = query.pop('dbname') |
119 |
| - if database is None: |
120 |
| - database = val |
121 |
| - |
122 |
| - if 'database' in query: |
123 |
| - val = query.pop('database') |
124 |
| - if database is None: |
125 |
| - database = val |
126 |
| - |
127 |
| - if 'user' in query: |
128 |
| - val = query.pop('user') |
129 |
| - if user is None: |
130 |
| - user = val |
131 |
| - |
132 |
| - if 'password' in query: |
133 |
| - val = query.pop('password') |
134 |
| - if password is None: |
135 |
| - password = val |
136 |
| - |
137 |
| - if query: |
138 |
| - opts = {**query, **opts} |
139 |
| - |
140 |
| - # On env-var -> connection parameter conversion read here: |
141 |
| - # https://www.postgresql.org/docs/current/static/libpq-envars.html |
142 |
| - # Note that env values may be an empty string in cases when |
143 |
| - # the variable is "unset" by setting it to an empty value |
144 |
| - # |
145 |
| - if host is None: |
146 |
| - host = os.getenv('PGHOST') |
147 |
| - if not host: |
148 |
| - host = ['/tmp', '/private/tmp', |
149 |
| - '/var/pgsql_socket', '/run/postgresql', |
150 |
| - 'localhost'] |
151 |
| - if not isinstance(host, list): |
152 |
| - host = [host] |
153 |
| - |
154 |
| - if port is None: |
155 |
| - port = os.getenv('PGPORT') |
156 |
| - if port: |
157 |
| - port = int(port) |
158 |
| - else: |
159 |
| - port = 5432 |
160 |
| - else: |
161 |
| - port = int(port) |
162 |
| - |
163 |
| - if user is None: |
164 |
| - user = os.getenv('PGUSER') |
165 |
| - if not user: |
166 |
| - user = getpass.getuser() |
167 |
| - |
168 |
| - if password is None: |
169 |
| - password = os.getenv('PGPASSWORD') |
170 |
| - |
171 |
| - if database is None: |
172 |
| - database = os.getenv('PGDATABASE') |
173 |
| - |
174 |
| - if user is not None: |
175 |
| - opts['user'] = user |
176 |
| - if password is not None: |
177 |
| - opts['password'] = password |
178 |
| - if database is not None: |
179 |
| - opts['database'] = database |
180 |
| - |
181 |
| - for param in opts: |
182 |
| - if not isinstance(param, str): |
183 |
| - raise ValueError( |
184 |
| - 'invalid connection parameter {!r} (str expected)' |
185 |
| - .format(param)) |
186 |
| - if not isinstance(opts[param], str): |
187 |
| - raise ValueError( |
188 |
| - 'invalid connection parameter {!r}: {!r} (str expected)' |
189 |
| - .format(param, opts[param])) |
190 |
| - |
191 |
| - return host, port, opts |
192 |
| - |
193 |
| - |
194 |
| -def _create_future(loop): |
195 |
| - try: |
196 |
| - create_future = loop.create_future |
197 |
| - except AttributeError: |
198 |
| - return asyncio.Future(loop=loop) |
199 |
| - else: |
200 |
| - return create_future() |
0 commit comments