diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/BatchSpannerRead.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/BatchSpannerRead.java index efc983e39f5e..810f7ce8aaae 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/BatchSpannerRead.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/BatchSpannerRead.java @@ -218,7 +218,6 @@ public void processElement(ProcessContext c) throws Exception { BatchReadOnlyTransaction batchTx = spannerAccessor.getBatchClient().batchReadOnlyTransaction(tx.transactionId()); - serviceCallMetric.call("ok"); Partition p = c.element(); try (ResultSet resultSet = batchTx.execute(p)) { while (resultSet.next()) { @@ -227,7 +226,9 @@ public void processElement(ProcessContext c) throws Exception { } } catch (SpannerException e) { serviceCallMetric.call(e.getErrorCode().getGrpcStatusCode().toString()); + throw (e); } + serviceCallMetric.call("ok"); } private ServiceCallMetric createServiceCallMetric( diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java index c80490ea472f..48f7fed7feee 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java @@ -34,6 +34,7 @@ import com.google.cloud.spanner.Partition; import com.google.cloud.spanner.PartitionOptions; import com.google.cloud.spanner.ResultSets; +import com.google.cloud.spanner.SpannerException; import com.google.cloud.spanner.SpannerExceptionFactory; import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.Struct; @@ -41,6 +42,7 @@ import com.google.cloud.spanner.Type; import com.google.cloud.spanner.Value; import com.google.protobuf.ByteString; +import io.grpc.Status.Code; import java.io.Serializable; import java.util.Arrays; import java.util.HashMap; @@ -49,6 +51,7 @@ import org.apache.beam.runners.core.metrics.MetricsContainerImpl; import org.apache.beam.runners.core.metrics.MonitoringInfoConstants; import org.apache.beam.runners.core.metrics.MonitoringInfoMetricName; +import org.apache.beam.sdk.Pipeline.PipelineExecutionException; import org.apache.beam.sdk.io.gcp.spanner.SpannerIO.Read; import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.testing.PAssert; @@ -293,7 +296,7 @@ public void runReadWithPriority() throws Exception { } @Test - public void testQueryMetrics() throws Exception { + public void testQueryMetricsFail() throws Exception { Timestamp timestamp = Timestamp.ofTimeMicroseconds(12345); TimestampBound timestampBound = TimestampBound.ofReadTimestamp(timestamp); @@ -322,25 +325,74 @@ public void testQueryMetrics() throws Exception { any(PartitionOptions.class), eq(Statement.of("SELECT * FROM users")), any(ReadQueryUpdateTransactionOption.class))) - .thenReturn(Arrays.asList(fakePartition, fakePartition)); + .thenReturn(Arrays.asList(fakePartition)); when(mockBatchTx.execute(any(Partition.class))) .thenThrow( SpannerExceptionFactory.newSpannerException( - ErrorCode.DEADLINE_EXCEEDED, "Simulated Timeout 1")) - .thenThrow( - SpannerExceptionFactory.newSpannerException( - ErrorCode.DEADLINE_EXCEEDED, "Simulated Timeout 2")) + ErrorCode.DEADLINE_EXCEEDED, "Simulated Timeout 1")); + try { + pipeline.run(); + } catch (PipelineExecutionException e) { + if (e.getCause() instanceof SpannerException + && ((SpannerException) e.getCause()).getErrorCode().getGrpcStatusCode() + == Code.DEADLINE_EXCEEDED) { + // expected + } else { + throw e; + } + } + verifyMetricWasSet("test", "aaa", "123", "deadline_exceeded", null, 1); + verifyMetricWasSet("test", "aaa", "123", "ok", null, 0); + } + + @Test + public void testQueryMetricsSucceed() throws Exception { + Timestamp timestamp = Timestamp.ofTimeMicroseconds(12345); + TimestampBound timestampBound = TimestampBound.ofReadTimestamp(timestamp); + + SpannerConfig spannerConfig = getSpannerConfig(); + + pipeline.apply( + "read q", + SpannerIO.read() + .withSpannerConfig(spannerConfig) + .withQuery("SELECT * FROM users") + .withQueryName("queryName") + .withTimestampBound(timestampBound)); + + FakeBatchTransactionId id = new FakeBatchTransactionId("runQueryTest"); + when(mockBatchTx.getBatchTransactionId()).thenReturn(id); + + when(serviceFactory.mockBatchClient().batchReadOnlyTransaction(timestampBound)) + .thenReturn(mockBatchTx); + when(serviceFactory.mockBatchClient().batchReadOnlyTransaction(any(BatchTransactionId.class))) + .thenReturn(mockBatchTx); + + Partition fakePartition = + FakePartitionFactory.createFakeQueryPartition(ByteString.copyFromUtf8("one")); + + when(mockBatchTx.partitionQuery( + any(PartitionOptions.class), + eq(Statement.of("SELECT * FROM users")), + any(ReadQueryUpdateTransactionOption.class))) + .thenReturn(Arrays.asList(fakePartition, fakePartition)); + when(mockBatchTx.execute(any(Partition.class))) .thenReturn( ResultSets.forRows(FAKE_TYPE, FAKE_ROWS.subList(0, 2)), - ResultSets.forRows(FAKE_TYPE, FAKE_ROWS.subList(2, 6))); + ResultSets.forRows(FAKE_TYPE, FAKE_ROWS.subList(2, 4)), + ResultSets.forRows(FAKE_TYPE, FAKE_ROWS.subList(4, 6))) + .thenReturn( + ResultSets.forRows(FAKE_TYPE, FAKE_ROWS.subList(0, 2)), + ResultSets.forRows(FAKE_TYPE, FAKE_ROWS.subList(2, 4)), + ResultSets.forRows(FAKE_TYPE, FAKE_ROWS.subList(4, 6))); pipeline.run(); - verifyMetricWasSet("test", "aaa", "123", "deadline_exceeded", null, 2); + verifyMetricWasSet("test", "aaa", "123", "deadline_exceeded", null, 0); verifyMetricWasSet("test", "aaa", "123", "ok", null, 2); } @Test - public void testReadMetrics() throws Exception { + public void testReadMetricsFail() throws Exception { Timestamp timestamp = Timestamp.ofTimeMicroseconds(12345); TimestampBound timestampBound = TimestampBound.ofReadTimestamp(timestamp); @@ -371,21 +423,66 @@ public void testReadMetrics() throws Exception { eq(KeySet.all()), eq(Arrays.asList("id", "name")), any(ReadQueryUpdateTransactionOption.class))) - .thenReturn(Arrays.asList(fakePartition, fakePartition, fakePartition)); + .thenReturn(Arrays.asList(fakePartition)); when(mockBatchTx.execute(any(Partition.class))) .thenThrow( SpannerExceptionFactory.newSpannerException( - ErrorCode.DEADLINE_EXCEEDED, "Simulated Timeout 1")) - .thenThrow( - SpannerExceptionFactory.newSpannerException( - ErrorCode.DEADLINE_EXCEEDED, "Simulated Timeout 2")) + ErrorCode.DEADLINE_EXCEEDED, "Simulated Timeout 1")); + try { + pipeline.run(); + } catch (PipelineExecutionException e) { + if (e.getCause() instanceof SpannerException + && ((SpannerException) e.getCause()).getErrorCode().getGrpcStatusCode() + == Code.DEADLINE_EXCEEDED) { + // expected + } else { + throw e; + } + } + verifyMetricWasSet("test", "aaa", "123", "deadline_exceeded", null, 1); + verifyMetricWasSet("test", "aaa", "123", "ok", null, 0); + } + + @Test + public void testReadMetricsSucceed() throws Exception { + Timestamp timestamp = Timestamp.ofTimeMicroseconds(12345); + TimestampBound timestampBound = TimestampBound.ofReadTimestamp(timestamp); + + SpannerConfig spannerConfig = getSpannerConfig(); + + pipeline.apply( + "read q", + SpannerIO.read() + .withSpannerConfig(spannerConfig) + .withTable("users") + .withColumns("id", "name") + .withTimestampBound(timestampBound)); + + FakeBatchTransactionId id = new FakeBatchTransactionId("runReadTest"); + when(mockBatchTx.getBatchTransactionId()).thenReturn(id); + + when(serviceFactory.mockBatchClient().batchReadOnlyTransaction(timestampBound)) + .thenReturn(mockBatchTx); + when(serviceFactory.mockBatchClient().batchReadOnlyTransaction(any(BatchTransactionId.class))) + .thenReturn(mockBatchTx); + + Partition fakePartition = + FakePartitionFactory.createFakeReadPartition(ByteString.copyFromUtf8("one")); + + when(mockBatchTx.partitionRead( + any(PartitionOptions.class), + eq("users"), + eq(KeySet.all()), + eq(Arrays.asList("id", "name")), + any(ReadQueryUpdateTransactionOption.class))) + .thenReturn(Arrays.asList(fakePartition, fakePartition, fakePartition)); + when(mockBatchTx.execute(any(Partition.class))) .thenReturn( ResultSets.forRows(FAKE_TYPE, FAKE_ROWS.subList(0, 2)), ResultSets.forRows(FAKE_TYPE, FAKE_ROWS.subList(2, 4)), ResultSets.forRows(FAKE_TYPE, FAKE_ROWS.subList(4, 6))); pipeline.run(); - verifyMetricWasSet("test", "aaa", "123", "deadline_exceeded", null, 2); verifyMetricWasSet("test", "aaa", "123", "ok", null, 3); }