Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions model/pipeline/src/main/proto/metrics.proto
Original file line number Diff line number Diff line change
Expand Up @@ -420,10 +420,11 @@ message MonitoringInfo {
BIGTABLE_PROJECT_ID = 20 [(label_props) = { name: "BIGTABLE_PROJECT_ID"}];
INSTANCE_ID = 21 [(label_props) = { name: "INSTANCE_ID"}];
TABLE_ID = 22 [(label_props) = { name: "TABLE_ID"}];
SPANNER_PROJECT_ID = 23 [(label_props) = { name: "SPANNER_PROJECT_ID"}];
SPANNER_DATABASE_ID = 24 [(label_props) = { name: "SPANNER_DATABASE_ID"}];
SPANNER_INSTANCE_ID = 25 [(label_props) = { name: "SPANNER_INSTANCE_ID" }];
SPANNER_QUERY_NAME = 26 [(label_props) = { name: "SPANNER_QUERY_NAME" }];
SPANNER_PROJECT_ID = 23 [(label_props) = { name: "SPANNER_PROJECT_ID" }];
SPANNER_DATABASE_ID = 24 [(label_props) = { name: "SPANNER_DATABASE_ID" }];
SPANNER_TABLE_ID = 25 [(label_props) = { name: "SPANNER_TABLE_ID" }];
SPANNER_INSTANCE_ID = 26 [(label_props) = { name: "SPANNER_INSTANCE_ID" }];
SPANNER_QUERY_NAME = 27 [(label_props) = { name: "SPANNER_QUERY_NAME" }];
}

// A set of key and value labels which define the scope of the metric. For
Expand Down
193 changes: 172 additions & 21 deletions sdks/python/apache_beam/io/gcp/experimental/spannerio.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,10 @@
from apache_beam import Flatten
from apache_beam import ParDo
from apache_beam import Reshuffle
from apache_beam.internal.metrics.metric import ServiceCallMetric
from apache_beam.io.gcp import resource_identifiers
from apache_beam.metrics import Metrics
from apache_beam.metrics import monitoring_infos
from apache_beam.pvalue import AsSingleton
from apache_beam.pvalue import PBegin
from apache_beam.pvalue import TaggedOutput
Expand All @@ -189,12 +192,17 @@
from apache_beam.typehints import with_output_types
from apache_beam.utils.annotations import experimental

# Protect against environments where spanner library is not available.
# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
# pylint: disable=unused-import
try:
from google.cloud.spanner import Client
from google.cloud.spanner import KeySet
from google.cloud.spanner_v1 import batch
from google.cloud.spanner_v1.database import BatchSnapshot
from google.cloud.spanner_v1.proto.mutation_pb2 import Mutation
from google.api_core.exceptions import ClientError, GoogleAPICallError
from apitools.base.py.exceptions import HttpError
except ImportError:
Client = None
KeySet = None
Expand Down Expand Up @@ -284,6 +292,8 @@ class _BeamSpannerConfiguration(namedtuple("_BeamSpannerConfiguration",
["project",
"instance",
"database",
"table",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use table_id as the name here to be consistent with other "table_id" variable names

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is consistent with the other parameters in this namedtuple; otherwise, there would also be project_id, instance_id, database_id, etc. What do you think?

"query_name",
"credentials",
"pool",
"snapshot_read_timestamp",
def __init__(self, spanner_configuration):
  """Keeps the Spanner configuration and precomputes the label set shared
  by every service-call metric this read DoFn reports.

  Args:
    spanner_configuration: a _BeamSpannerConfiguration namedtuple holding
      project, instance, database and related read settings.
  """
  self._spanner_configuration = spanner_configuration
  self._snapshot = None
  self._session = None
  # Labels common to all read metrics; per-call labels (resource and
  # table/query identifiers) are merged in by the metric helpers.
  self.base_labels = {
      monitoring_infos.SERVICE_LABEL: 'Spanner',
      monitoring_infos.METHOD_LABEL: 'Read',
      monitoring_infos.SPANNER_PROJECT_ID: spanner_configuration.project,
      monitoring_infos.SPANNER_DATABASE_ID: spanner_configuration.database,
  }

def _table_metric(self, table_id, status):
  """Reports a single service-call metric for a table-based read.

  Args:
    table_id: name of the Spanner table that was read.
    status: outcome of the call, e.g. 'ok' or an error code; stringified
      before being recorded.
  """
  config = self._spanner_configuration
  resource = resource_identifiers.SpannerTable(
      config.project, config.database, table_id)
  labels = dict(self.base_labels)
  labels[monitoring_infos.RESOURCE_LABEL] = resource
  labels[monitoring_infos.SPANNER_TABLE_ID] = table_id
  metric = ServiceCallMetric(
      request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN,
      base_labels=labels)
  metric.call(str(status))

def _query_metric(self, query_name, status):
  """Reports a single service-call metric for an SQL-based read.

  Args:
    query_name: user-supplied name identifying the query.
    status: outcome of the call, e.g. 'ok' or an error code; stringified
      before being recorded.
  """
  resource = resource_identifiers.SpannerSqlQuery(
      self._spanner_configuration.project, query_name)
  labels = dict(self.base_labels)
  labels[monitoring_infos.RESOURCE_LABEL] = resource
  labels[monitoring_infos.SPANNER_QUERY_NAME] = query_name
  metric = ServiceCallMetric(
      request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN,
      base_labels=labels)
  metric.call(str(status))

def _get_session(self):
if self._session is None:
Expand Down Expand Up @@ -357,16 +403,32 @@ def process(self, element, spanner_transaction):
# getting the transaction from the snapshot's session to run read operation.
# with self._snapshot.session().transaction() as transaction:
with self._get_session().transaction() as transaction:
table_id = self._spanner_configuration.table
query_name = self._spanner_configuration.query_name or ''

if element.is_sql is True:
transaction_read = transaction.execute_sql
metric_action = self._query_metric
metric_id = query_name
elif element.is_table is True:
transaction_read = transaction.read
metric_action = self._table_metric
metric_id = table_id
else:
raise ValueError(
"ReadOperation is improperly configure: %s" % str(element))

for row in transaction_read(**element.kwargs):
yield row
try:
for row in transaction_read(**element.kwargs):
yield row

metric_action(metric_id, 'ok')
except (ClientError, GoogleAPICallError) as e:
metric_action(metric_id, e.code.value)
raise
except HttpError as e:
metric_action(metric_id, e)
raise


@with_input_types(ReadOperation)
Expand Down Expand Up @@ -523,6 +585,43 @@ class _ReadFromPartitionFn(DoFn):
"""
def __init__(self, spanner_configuration):
  """Keeps the Spanner configuration and precomputes the label set shared
  by every service-call metric emitted from this partition-read DoFn.

  Args:
    spanner_configuration: a _BeamSpannerConfiguration namedtuple holding
      project, instance, database and related read settings.
  """
  self._spanner_configuration = spanner_configuration
  self.base_labels = {
      monitoring_infos.SERVICE_LABEL: 'Spanner',
      monitoring_infos.METHOD_LABEL: 'Read',
      monitoring_infos.SPANNER_PROJECT_ID: spanner_configuration.project,
      monitoring_infos.SPANNER_DATABASE_ID: spanner_configuration.database,
  }
  # Set per-element in process(); holds the metric for the current read.
  self.service_metric = None

def _table_metric(self, table_id):
  """Builds (without reporting) a ServiceCallMetric labelled for a
  table-based read; the caller invokes .call(status) once the read returns.

  Args:
    table_id: name of the Spanner table being read.

  Returns:
    A ServiceCallMetric ready to record the request outcome.
  """
  config = self._spanner_configuration
  resource = resource_identifiers.SpannerTable(
      config.project, config.database, table_id)
  labels = dict(self.base_labels)
  labels[monitoring_infos.RESOURCE_LABEL] = resource
  labels[monitoring_infos.SPANNER_TABLE_ID] = table_id
  return ServiceCallMetric(
      request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN,
      base_labels=labels)

def _query_metric(self, query_name):
  """Builds (without reporting) a ServiceCallMetric labelled for an
  SQL-based read; the caller invokes .call(status) once the query returns.

  Args:
    query_name: user-supplied name identifying the query.

  Returns:
    A ServiceCallMetric ready to record the request outcome.
  """
  resource = resource_identifiers.SpannerSqlQuery(
      self._spanner_configuration.project, query_name)
  labels = dict(self.base_labels)
  labels[monitoring_infos.RESOURCE_LABEL] = resource
  labels[monitoring_infos.SPANNER_QUERY_NAME] = query_name
  return ServiceCallMetric(
      request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN,
      base_labels=labels)

def setup(self):
spanner_client = Client(self._spanner_configuration.project)
def process(self, element):
  """Reads the rows of one partition of a batched read or query.

  Args:
    element: a dict carrying 'transaction_info', 'partitions', and the
      'is_sql' / 'is_table' flags produced upstream.

  Yields:
    Rows returned by the Spanner batch read/query.

  Raises:
    ValueError: if neither 'is_sql' nor 'is_table' is set.
    Propagates any Spanner client error after recording it as a metric.
  """
  self._snapshot = BatchSnapshot.from_dict(
      self._database, element['transaction_info'])

  table_id = self._spanner_configuration.table
  query_name = self._spanner_configuration.query_name or ''

  if element['is_sql'] is True:
    read_action = self._snapshot.process_query_batch
    self.service_metric = self._query_metric(query_name)
  elif element['is_table'] is True:
    read_action = self._snapshot.process_read_batch
    self.service_metric = self._table_metric(table_id)
  else:
    raise ValueError(
        "ReadOperation is improperly configure: %s" % str(element))

  try:
    for row in read_action(element['partitions']):
      yield row

    self.service_metric.call('ok')
  except (ClientError, GoogleAPICallError) as e:
    # Bug fix: must invoke .call(...) on the metric; calling the
    # ServiceCallMetric object itself raised TypeError and masked the
    # original Spanner error.
    self.service_metric.call(str(e.code.value))
    raise
  except HttpError as e:
    self.service_metric.call(str(e))
    raise

def teardown(self):
if self._snapshot:
Expand All @@ -563,7 +676,8 @@ class ReadFromSpanner(PTransform):
def __init__(self, project_id, instance_id, database_id, pool=None,
read_timestamp=None, exact_staleness=None, credentials=None,
sql=None, params=None, param_types=None, # with_query
table=None, columns=None, index="", keyset=None, # with_table
table=None, query_name=None, columns=None, index="",
keyset=None, # with_table
read_operations=None, # for read all
transaction=None
):
Expand Down Expand Up @@ -611,6 +725,8 @@ def __init__(self, project_id, instance_id, database_id, pool=None,
project=project_id,
instance=instance_id,
database=database_id,
table=table,
query_name=query_name,
credentials=credentials,
pool=pool,
snapshot_read_timestamp=read_timestamp,
Expand Down Expand Up @@ -725,6 +841,8 @@ def __init__(
project=project_id,
instance=instance_id,
database=database_id,
table=None,
query_name=None,
credentials=credentials,
pool=pool,
snapshot_read_timestamp=None,
Expand Down Expand Up @@ -1068,6 +1186,28 @@ def __init__(self, spanner_configuration):
self._spanner_configuration = spanner_configuration
self._db_instance = None
self.batches = Metrics.counter(self.__class__, 'SpannerBatches')
self.base_labels = {
monitoring_infos.SERVICE_LABEL: 'Spanner',
monitoring_infos.METHOD_LABEL: 'Write',
monitoring_infos.SPANNER_PROJECT_ID: spanner_configuration.project,
monitoring_infos.SPANNER_DATABASE_ID: spanner_configuration.database,
}
self.service_metric = None

def _table_metric(self, table_id):
  """Builds (without reporting) a ServiceCallMetric labelled for a write
  to one table; the caller invokes .call(status) once the write completes.

  Args:
    table_id: name of the Spanner table being written.

  Returns:
    A ServiceCallMetric ready to record the request outcome.
  """
  config = self._spanner_configuration
  resource = resource_identifiers.SpannerTable(
      config.project, config.database, table_id)
  labels = dict(self.base_labels)
  labels[monitoring_infos.RESOURCE_LABEL] = resource
  labels[monitoring_infos.SPANNER_TABLE_ID] = table_id
  return ServiceCallMetric(
      request_count_urn=monitoring_infos.API_REQUEST_COUNT_URN,
      base_labels=labels)

def setup(self):
spanner_client = Client(self._spanner_configuration.project)
Expand All @@ -1078,22 +1218,33 @@ def setup(self):

def process(self, element):
  """Applies one batch of mutations to Cloud Spanner.

  Args:
    element: an iterable of _Mutator objects forming one batch; each
      carries the operation name and the kwargs for the corresponding
      google.cloud.spanner batch method (including the 'table' key).

  Raises:
    ValueError: if a mutator carries an unknown operation.
    Propagates any Spanner client error after recording it as a metric.
  """
  self.batches.inc()
  try:
    with self._db_instance.batch() as b:
      for m in element:
        table_id = m.kwargs['table']
        # One service-call metric per mutation's target table; the last
        # one built is the one the commit outcome is attributed to.
        self.service_metric = self._table_metric(table_id)

        if m.operation == WriteMutation._OPERATION_DELETE:
          batch_func = b.delete
        elif m.operation == WriteMutation._OPERATION_REPLACE:
          batch_func = b.replace
        elif m.operation == WriteMutation._OPERATION_INSERT_OR_UPDATE:
          batch_func = b.insert_or_update
        elif m.operation == WriteMutation._OPERATION_INSERT:
          batch_func = b.insert
        elif m.operation == WriteMutation._OPERATION_UPDATE:
          batch_func = b.update
        else:
          raise ValueError("Unknown operation action: %s" % m.operation)
        batch_func(**m.kwargs)

    # Guard: with an empty batch (or a failure before any mutation is
    # seen) service_metric is still None; calling it unconditionally
    # raised AttributeError and, in the except paths, masked the real
    # Spanner error.
    if self.service_metric is not None:
      self.service_metric.call('ok')
  except (ClientError, GoogleAPICallError) as e:
    if self.service_metric is not None:
      self.service_metric.call(str(e.code.value))
    raise
  except HttpError as e:
    if self.service_metric is not None:
      self.service_metric.call(str(e))
    raise


@with_input_types(typing.Union[MutationGroup, _Mutator])
Expand Down
Loading