Skip to content

Commit

Permalink
Merge pull request #76 from jsangmeister/connection-pool
Browse files Browse the repository at this point in the history
Use a connection pool to support multiple parallel requests
  • Loading branch information
FinnStutzenstein committed Jul 6, 2020
2 parents 6b69ca7 + 41ac2b6 commit b4de8cf
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 150 deletions.
2 changes: 1 addition & 1 deletion Dockerfile.prod
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ ENV PORT=$PORT
EXPOSE $PORT

ENTRYPOINT ["./entrypoint.sh"]
CMD gunicorn -w 1 -b 0.0.0.0:$PORT $MODULE.app:application
CMD gunicorn -w 8 -b 0.0.0.0:$PORT $MODULE.app:application
117 changes: 61 additions & 56 deletions shared/shared/postgresql_backend/pg_connection_handler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Any, Optional
import threading
from threading import Semaphore

import psycopg2
from psycopg2 import sql
from psycopg2.extras import DictCursor, Json
from psycopg2.pool import ThreadedConnectionPool

from shared.di import service_as_singleton
from shared.services import EnvironmentService, ShutdownService
Expand All @@ -17,7 +19,7 @@
# be retried. Also it should create a connection, if it wasn't established before.


class ENVIRONMENT_VARIABLES:
class DATABASE_ENVIRONMENT_VARIABLES:
HOST = "DATASTORE_DATABASE_HOST"
PORT = "DATASTORE_DATABASE_PORT"
NAME = "DATASTORE_DATABASE_NAME"
Expand All @@ -30,83 +32,100 @@ def __init__(self, connection_handler):
self.connection_handler = connection_handler

def __enter__(self):
self.connection_handler.connection.__enter__()
self.connection_handler.set_transaction_running(True)
self.connection = self.connection_handler.get_connection()
self.connection.__enter__()

def __exit__(self, type, value, traceback):
self.connection_handler.set_transaction_running(False)
if self.connection_handler.connection:
self.connection_handler.connection.__exit__(type, value, traceback)
self.connection.__exit__(type, value, traceback)
self.connection_handler.put_connection(self.connection)


@service_as_singleton
class PgConnectionHandlerService:

connection: Optional[Any] = None
context = None
_storage: threading.local
connection_pool: ThreadedConnectionPool

environment: EnvironmentService
shutdown_service: ShutdownService

def __init__(self, shutdown_service: ShutdownService):
self.set_transaction_running(False)
shutdown_service.register(self)
self._storage = threading.local()

min_conn = int(self.environment.try_get("DATASTORE_MIN_CONNECTIONS") or 1)
max_conn = int(self.environment.try_get("DATASTORE_MAX_CONNECTIONS") or 1)
self._semaphore = Semaphore(max_conn)
try:
self.connection_pool = ThreadedConnectionPool(
min_conn, max_conn, **self.get_connection_params()
)
except psycopg2.Error as e:
self.raise_error(
f"Database connection error ({type(e).__name__}) {e.pgcode}: {e.pgerror}" # noqa
)

def set_transaction_running(self, value):
self.is_transaction_running = value
def get_current_connection(self):
try:
return self._storage.connection
except AttributeError:
return None

def set_current_connection(self, connection):
self._storage.connection = connection

def get_connection_params(self):
return {
"host": self.environment.get(ENVIRONMENT_VARIABLES.HOST),
"port": int(self.environment.try_get(ENVIRONMENT_VARIABLES.PORT) or 5432),
"database": self.environment.get(ENVIRONMENT_VARIABLES.NAME),
"user": self.environment.get(ENVIRONMENT_VARIABLES.USER),
"password": self.environment.get(ENVIRONMENT_VARIABLES.PASSWORD),
"host": self.environment.get(DATABASE_ENVIRONMENT_VARIABLES.HOST),
"port": int(
self.environment.try_get(DATABASE_ENVIRONMENT_VARIABLES.PORT) or 5432
),
"database": self.environment.get(DATABASE_ENVIRONMENT_VARIABLES.NAME),
"user": self.environment.get(DATABASE_ENVIRONMENT_VARIABLES.USER),
"password": self.environment.get(DATABASE_ENVIRONMENT_VARIABLES.PASSWORD),
"cursor_factory": DictCursor,
}

def ensure_connection(self):
if not self.connection:
try:
self.connection = psycopg2.connect(**self.get_connection_params())
self.connection.autocommit = False
except psycopg2.Error as e:
self.raise_error(
f"Database connection error ({type(e).__name__}) {e.pgcode}: {e.pgerror}" # noqa
)
else:
# TODO: check if alive
pass
def get_connection(self):
if self.get_current_connection():
raise BadCodingError(
"You cannot start multiple transactions in one thread!"
)
self._semaphore.acquire()
connection = self.connection_pool.getconn()
connection.autocommit = False
self.set_current_connection(connection)
return connection

def get_connection_context(self):
if self.is_transaction_running:
raise BadCodingError("You cannot start multiple transactions at once!")
def put_connection(self, connection):
if connection != self.get_current_connection():
raise BadCodingError("Invalid connection")

self.ensure_connection()
self.context = ConnectionContext(self)
return self.context
self.connection_pool.putconn(connection)
self.set_current_connection(None)
self._semaphore.release()

def get_connection_context(self):
return ConnectionContext(self)

def to_json(self, data):
return Json(data)

def execute(self, query, arguments, sql_parameters=[]):
connection = self.get_connection_with_open_transaction()
prepared_query = self.prepare_query(query, sql_parameters)
with connection.cursor() as cursor:
with self.get_current_connection().cursor() as cursor:
cursor.execute(prepared_query, arguments)

def query(self, query, arguments, sql_parameters=[]):
connection = self.get_connection_with_open_transaction()
prepared_query = self.prepare_query(query, sql_parameters)
with connection.cursor() as cursor:
with self.get_current_connection().cursor() as cursor:
cursor.execute(prepared_query, arguments)
result = cursor.fetchall()
return result

def query_single_value(self, query, arguments, sql_parameters=[]):
connection = self.get_connection_with_open_transaction()
prepared_query = self.prepare_query(query, sql_parameters)
with connection.cursor() as cursor:
with self.get_current_connection().cursor() as cursor:
cursor.execute(prepared_query, arguments)
result = cursor.fetchone()

Expand All @@ -118,17 +137,6 @@ def query_list_of_single_values(self, query, arguments, sql_parameters=[]):
result = self.query(query, arguments, sql_parameters)
return list(map(lambda row: row[0], result))

def get_connection_with_open_transaction(self) -> Any:
if not self.connection:
raise BadCodingError(
"You should open a db connection first with `get_connection_context()`!"
)
if not self.is_transaction_running:
raise BadCodingError(
"You should start a transaction with `get_connection_context()`!"
)
return self.connection

def prepare_query(self, query, sql_parameters):
prepared_query = sql.SQL(query).format(
*[sql.Identifier(param) for param in sql_parameters]
Expand All @@ -140,7 +148,4 @@ def raise_error(self, msg):
raise DatabaseError(msg)

def shutdown(self):
if self.connection:
self.connection.close()
self.context = None
self.connection = None
self.connection_pool.closeall()
2 changes: 1 addition & 1 deletion shared/shared/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from shared.di import injector
from shared.postgresql_backend.pg_connection_handler import (
ENVIRONMENT_VARIABLES as POSTGRESQL_ENVIRONMENT_VARIABLES,
DATABASE_ENVIRONMENT_VARIABLES as POSTGRESQL_ENVIRONMENT_VARIABLES,
)

from ..util import ALL_TABLES
Expand Down
Loading

0 comments on commit b4de8cf

Please sign in to comment.