Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 141 additions & 20 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

"""DB-API Connection for the Google Cloud Spanner."""

import warnings

from google.api_core.gapic_v1.client_info import ClientInfo
from google.cloud import spanner_v1 as spanner

Expand All @@ -15,19 +17,115 @@
from google.cloud.spanner_dbapi.version import PY_VERSION


class Connection(object):
"""DB-API Connection to a Google Cloud Spanner database.
AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode"


class Connection:
"""Representation of a DB-API connection to a Cloud Spanner database.

You most likely don't need to instantiate `Connection` objects
directly, use the `connect` module function instead.

:type instance: :class:`~google.cloud.spanner_v1.instance.Instance`
:param instance: Cloud Spanner instance to connect to.

:type database: :class:`~google.cloud.spanner_v1.database.Database`
:param database: The database to which the connection is linked.
"""

def __init__(self, instance, database):
self.instance = instance
self.database = database
self._instance = instance
self._database = database
self._ddl_statements = []

self._transaction = None
self._session = None

self.is_closed = False
self._autocommit = False

@property
def autocommit(self):
"""Autocommit mode flag for this connection.

:rtype: bool
:returns: Autocommit mode flag value.
"""
return self._autocommit

@autocommit.setter
def autocommit(self, value):
"""Change this connection autocommit mode.

:type value: bool
:param value: New autocommit mode state.
"""
if value and not self._autocommit:
self.commit()

self._autocommit = value

@property
def database(self):
"""Database to which this connection relates.

:rtype: :class:`~google.cloud.spanner_v1.database.Database`
:returns: The related database object.
"""
return self._database

@property
def instance(self):
"""Instance to which this connection relates.

:rtype: :class:`~google.cloud.spanner_v1.instance.Instance`
:returns: The related instance object.
"""
return self._instance

def _session_checkout(self):
"""Get a Cloud Spanner session from the pool.

If there is already a session associated with
this connection, it'll be used instead.

:rtype: :class:`google.cloud.spanner_v1.session.Session`
:returns: Cloud Spanner session object ready to use.
"""
if not self._session:
self._session = self.database._pool.get()

return self._session

def _release_session(self):
"""Release the currently used Spanner session.

The session will be returned into the sessions pool.
"""
self.database._pool.put(self._session)
self._session = None

def transaction_checkout(self):
"""Get a Cloud Spanner transaction.

self.ddl_statements = []
Begin a new transaction, if there is no transaction in
this connection yet. Return the begun one otherwise.

The method is non operational in autocommit mode.

:rtype: :class:`google.cloud.spanner_v1.transaction.Transaction`
:returns: A Cloud Spanner transaction object, ready to use.
"""
if not self.autocommit:
if (
not self._transaction
or self._transaction.committed
or self._transaction.rolled_back
):
self._transaction = self._session_checkout().transaction()
self._transaction.begin()

return self._transaction

def _raise_if_closed(self):
"""Helper to check the connection state before running a query.
Expand All @@ -41,20 +139,33 @@ def _raise_if_closed(self):
def close(self):
"""Closes this connection.

The connection will be unusable from this point forward.
The connection will be unusable from this point forward. If the
connection has an active transaction, it will be rolled back.
"""
self.database = None
if (
self._transaction
and not self._transaction.committed
and not self._transaction.rolled_back
):
self._transaction.rollback()

self.is_closed = True

def commit(self):
"""Commits any pending transaction to the database."""
self._raise_if_closed()

self.run_prior_DDL_statements()
if self._autocommit:
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
elif self._transaction:
self._transaction.commit()
self._release_session()

def rollback(self):
"""A no-op, raising an error if the connection is closed."""
self._raise_if_closed()
"""Rollback all the pending transactions."""
if self._autocommit:
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
elif self._transaction:
self._transaction.rollback()
self._release_session()

def cursor(self):
"""Factory to create a DB-API Cursor."""
Expand All @@ -65,9 +176,9 @@ def cursor(self):
def run_prior_DDL_statements(self):
self._raise_if_closed()

if self.ddl_statements:
ddl_statements = self.ddl_statements
self.ddl_statements = []
if self._ddl_statements:
ddl_statements = self._ddl_statements
self._ddl_statements = []

return self.database.update_ddl(ddl_statements).result()

Expand All @@ -80,15 +191,20 @@ def __exit__(self, etype, value, traceback):


def connect(
instance_id, database_id, project=None, credentials=None, user_agent=None
instance_id,
database_id,
project=None,
credentials=None,
pool=None,
user_agent=None,
):
"""Creates a connection to a Google Cloud Spanner database.

:type instance_id: str
:param instance_id: ID of the instance to connect to.
:param instance_id: The ID of the instance to connect to.

:type database_id: str
:param database_id: The name of the database to connect to.
:param database_id: The ID of the database to connect to.

:type project: str
:param project: (Optional) The ID of the project which owns the
Expand All @@ -102,8 +218,13 @@ def connect(
attempt to ascertain the credentials from the
environment.

:type pool: Concrete subclass of
:class:`~google.cloud.spanner_v1.pool.AbstractSessionPool`.
:param pool: (Optional). Session pool to be used by database.

:type user_agent: str
:param user_agent: (Optional) Prefix to the user agent header.
:param user_agent: (Optional) User agent to be used with this connection's
requests.

:rtype: :class:`google.cloud.spanner_dbapi.connection.Connection`
:returns: Connection object associated with the given Google Cloud Spanner
Expand All @@ -125,7 +246,7 @@ def connect(
if not instance.exists():
raise ValueError("instance '%s' does not exist." % instance_id)

database = instance.database(database_id, pool=spanner.pool.BurstyPool())
database = instance.database(database_id, pool=pool)
if not database.exists():
raise ValueError("database '%s' does not exist." % database_id)

Expand Down
18 changes: 15 additions & 3 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@

from google.cloud.spanner_dbapi import parse_utils
from google.cloud.spanner_dbapi.parse_utils import get_param_types

from .utils import PeekIterator
from google.cloud.spanner_dbapi.utils import PeekIterator

_UNSET_COUNT = -1

Expand Down Expand Up @@ -150,14 +149,27 @@ def execute(self, sql, args=None):
try:
classification = parse_utils.classify_stmt(sql)
if classification == parse_utils.STMT_DDL:
self.connection.ddl_statements.append(sql)
self.connection._ddl_statements.append(sql)
return

# For every other operation, we've got to ensure that
# any prior DDL statements were run.
# self._run_prior_DDL_statements()
self.connection.run_prior_DDL_statements()

if not self.connection.autocommit:
transaction = self.connection.transaction_checkout()

sql, params = parse_utils.sql_pyformat_args_to_spanner(
sql, args
)

self._result_set = transaction.execute_sql(
sql, params, param_types=get_param_types(params)
)
self._itr = PeekIterator(self._result_set)
return

if classification == parse_utils.STMT_NON_UPDATING:
self._handle_DQL(sql, args or None)
elif classification == parse_utils.STMT_INSERT:
Expand Down
Loading