diff --git a/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/BigtableDataClientFactory.java b/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/BigtableDataClientFactory.java index b08613cc69..d4561ab4df 100644 --- a/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/BigtableDataClientFactory.java +++ b/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/BigtableDataClientFactory.java @@ -23,7 +23,7 @@ import com.google.api.gax.rpc.FixedHeaderProvider; import com.google.api.gax.rpc.FixedTransportChannelProvider; import com.google.api.gax.rpc.FixedWatchdogProvider; -import com.google.api.gax.rpc.StubSettings; +import com.google.cloud.bigtable.data.v2.stub.EnhancedBigtableStubSettings; import java.io.IOException; import javax.annotation.Nonnull; @@ -189,8 +189,11 @@ public BigtableDataClient createForInstance( } // Update stub settings to use shared resources in this factory - private void patchStubSettings(StubSettings.Builder stubSettings) { + private void patchStubSettings(EnhancedBigtableStubSettings.Builder stubSettings) { stubSettings + // Channel refreshing will be configured in the shared ClientContext. Derivative clients + // won't be able to reconfigure the refreshing logic + .setRefreshingChannel(false) .setTransportChannelProvider( FixedTransportChannelProvider.create(sharedClientContext.getTransportChannel())) .setCredentialsProvider( diff --git a/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubSettings.java b/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubSettings.java index e1b998d46f..bf2d88810c 100644 --- a/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubSettings.java +++ b/google-cloud-bigtable/src/main/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubSettings.java @@ -20,6 +20,7 @@ import com.google.api.gax.batching.BatchingSettings; import com.google.api.gax.batching.FlowControlSettings; import com.google.api.gax.batching.FlowController.LimitExceededBehavior; +import com.google.api.gax.core.FixedCredentialsProvider; import com.google.api.gax.core.GoogleCredentialsProvider; import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider; import com.google.api.gax.retrying.RetrySettings; @@ -29,6 +30,7 @@ import com.google.api.gax.rpc.StubSettings; import com.google.api.gax.rpc.TransportChannelProvider; import com.google.api.gax.rpc.UnaryCallSettings; +import com.google.auth.Credentials; import com.google.cloud.bigtable.Version; import com.google.cloud.bigtable.data.v2.models.ConditionalRowMutation; import com.google.cloud.bigtable.data.v2.models.KeyOffset; @@ -42,6 +44,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import java.io.IOException; import java.util.List; import java.util.Map; import java.util.Set; @@ -787,6 +790,22 @@ public EnhancedBigtableStubSettings build() { Preconditions.checkArgument( getTransportChannelProvider() instanceof InstantiatingGrpcChannelProvider, "refreshingChannel only works with InstantiatingGrpcChannelProviders"); + InstantiatingGrpcChannelProvider.Builder channelProviderBuilder = + ((InstantiatingGrpcChannelProvider) getTransportChannelProvider()).toBuilder(); + Credentials credentials = null; + if (getCredentialsProvider() != null) { + try { + credentials = getCredentialsProvider().getCredentials(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + // Use shared credentials + this.setCredentialsProvider(FixedCredentialsProvider.create(credentials)); + channelProviderBuilder.setChannelPrimer( + BigtableChannelPrimer.create( + credentials, projectId, instanceId, appProfileId, primedTableIds)); + this.setTransportChannelProvider(channelProviderBuilder.build()); } return new EnhancedBigtableStubSettings(this); } diff --git a/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/BigtableDataClientFactoryTest.java b/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/BigtableDataClientFactoryTest.java index d322654b81..25c341d650 100644 --- a/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/BigtableDataClientFactoryTest.java +++ b/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/BigtableDataClientFactoryTest.java @@ -18,19 +18,32 @@ import static com.google.common.truth.Truth.assertThat; import com.google.api.core.ApiClock; +import com.google.api.core.ApiFunction; import com.google.api.gax.core.CredentialsProvider; import com.google.api.gax.core.ExecutorProvider; +import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider; import com.google.api.gax.rpc.TransportChannelProvider; import com.google.api.gax.rpc.WatchdogProvider; import com.google.bigtable.v2.BigtableGrpc; import com.google.bigtable.v2.MutateRowRequest; import com.google.bigtable.v2.MutateRowResponse; +import com.google.bigtable.v2.ReadRowsRequest; +import com.google.bigtable.v2.ReadRowsResponse; +import com.google.bigtable.v2.RowFilter; +import com.google.bigtable.v2.RowSet; import com.google.cloud.bigtable.data.v2.internal.NameUtil; import com.google.cloud.bigtable.data.v2.models.RowMutation; import com.google.common.base.Preconditions; +import com.google.protobuf.ByteString; +import io.grpc.Attributes; +import io.grpc.ServerTransportFilter; import io.grpc.stub.StreamObserver; import java.io.IOException; import java.lang.reflect.Method; +import java.util.LinkedList; +import java.util.List; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingDeque; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -60,16 +73,33 @@ public class BigtableDataClientFactoryTest { private WatchdogProvider watchdogProvider; private ApiClock apiClock; private BigtableDataSettings defaultSettings; + private int port; + + private final BlockingQueue setUpAttributes = new LinkedBlockingDeque<>(); + private final BlockingQueue terminateAttributes = new LinkedBlockingDeque<>(); @Before public void setUp() throws IOException { service = new FakeBigtableService(); - - serviceHelper = new FakeServiceHelper(service); + ServerTransportFilter transportFilter = + new ServerTransportFilter() { + @Override + public Attributes transportReady(Attributes transportAttrs) { + setUpAttributes.add(transportAttrs); + return super.transportReady(transportAttrs); + } + + @Override + public void transportTerminated(Attributes transportAttrs) { + terminateAttributes.add(transportAttrs); + } + }; + serviceHelper = new FakeServiceHelper(null, transportFilter, service); + port = serviceHelper.getPort(); serviceHelper.start(); BigtableDataSettings.Builder builder = - BigtableDataSettings.newBuilderForEmulator(serviceHelper.getPort()) + BigtableDataSettings.newBuilderForEmulator(port) .setProjectId(DEFAULT_PROJECT_ID) .setInstanceId(DEFAULT_INSTANCE_ID) .setAppProfileId(DEFAULT_APP_PROFILE_ID); @@ -191,8 +221,94 @@ public void testCreateForInstanceWithAppProfileHasCorrectSettings() throws Excep assertThat(service.lastRequest.getAppProfileId()).isEqualTo("other-app-profile"); } + @Test + public void testCreateWithRefreshingChannel() throws Exception { + String[] tableIds = {"fake-table1", "fake-table2"}; + int poolSize = 3; + BigtableDataSettings.Builder builder = + BigtableDataSettings.newBuilderForEmulator(port) + .setProjectId(DEFAULT_PROJECT_ID) + .setInstanceId(DEFAULT_INSTANCE_ID) + .setAppProfileId(DEFAULT_APP_PROFILE_ID) + .setPrimingTableIds(tableIds) + .setRefreshingChannel(true); + builder + .stubSettings() + .setCredentialsProvider(credentialsProvider) + .setStreamWatchdogProvider(watchdogProvider) + .setExecutorProvider(executorProvider); + InstantiatingGrpcChannelProvider channelProvider = + (InstantiatingGrpcChannelProvider) builder.stubSettings().getTransportChannelProvider(); + InstantiatingGrpcChannelProvider.Builder channelProviderBuilder = channelProvider.toBuilder(); + channelProviderBuilder.setPoolSize(poolSize); + builder.stubSettings().setTransportChannelProvider(channelProviderBuilder.build()); + + BigtableDataClientFactory factory = BigtableDataClientFactory.create(builder.build()); + factory.createDefault(); + factory.createForAppProfile("other-appprofile"); + factory.createForInstance("other-project", "other-instance"); + + // Make sure that only 1 instance is created for all clients + Mockito.verify(credentialsProvider, Mockito.times(1)).getCredentials(); + Mockito.verify(executorProvider, Mockito.times(1)).getExecutor(); + Mockito.verify(watchdogProvider, Mockito.times(1)).getWatchdog(); + + // Make sure that the clients are sharing the same ChannelPool + assertThat(setUpAttributes).hasSize(poolSize); + + // Make sure that prime requests were sent only once per table per connection + assertThat(service.readRowsRequests).hasSize(poolSize * tableIds.length); + List expectedRequests = new LinkedList<>(); + for (String tableId : tableIds) { + for (int i = 0; i < poolSize; i++) { + expectedRequests.add( + ReadRowsRequest.newBuilder() + .setTableName( + String.format( + "projects/%s/instances/%s/tables/%s", + DEFAULT_PROJECT_ID, DEFAULT_INSTANCE_ID, tableId)) + .setAppProfileId(DEFAULT_APP_PROFILE_ID) + .setRows( + RowSet.newBuilder() + .addRowKeys(ByteString.copyFromUtf8("nonexistent-priming-row"))) + .setFilter(RowFilter.newBuilder().setBlockAllFilter(true).build()) + .setRowsLimit(1) + .build()); + } + } + assertThat(service.readRowsRequests).containsExactly(expectedRequests.toArray()); + + // Wait for all the connections to close asynchronously + factory.close(); + long sleepTimeMs = 1000; + Thread.sleep(sleepTimeMs); + // Verify that all the channels are closed + assertThat(terminateAttributes).hasSize(poolSize); + } + private static class FakeBigtableService extends BigtableGrpc.BigtableImplBase { + volatile MutateRowRequest lastRequest; + BlockingQueue readRowsRequests = new LinkedBlockingDeque<>(); + private ApiFunction readRowsCallback = + new ApiFunction() { + @Override + public ReadRowsResponse apply(ReadRowsRequest readRowsRequest) { + return ReadRowsResponse.getDefaultInstance(); + } + }; + + @Override + public void readRows( + ReadRowsRequest request, StreamObserver responseObserver) { + try { + readRowsRequests.add(request); + responseObserver.onNext(readRowsCallback.apply(request)); + responseObserver.onCompleted(); + } catch (RuntimeException e) { + responseObserver.onError(e); + } + } @Override public void mutateRow( @@ -204,6 +320,7 @@ public void mutateRow( } private static class BuilderAnswer implements Answer { + private final Class targetClass; private T targetInstance; diff --git a/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/FakeServiceHelper.java b/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/FakeServiceHelper.java index abd5569702..9ec5e59cb7 100644 --- a/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/FakeServiceHelper.java +++ b/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/FakeServiceHelper.java @@ -19,6 +19,7 @@ import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.ServerInterceptor; +import io.grpc.ServerTransportFilter; import java.io.IOException; import java.net.ServerSocket; @@ -33,6 +34,14 @@ public FakeServiceHelper(BindableService... services) throws IOException { public FakeServiceHelper(ServerInterceptor interceptor, BindableService... services) throws IOException { + this(interceptor, null, services); + } + + public FakeServiceHelper( + ServerInterceptor interceptor, + ServerTransportFilter transportFilter, + BindableService... services) + throws IOException { try (ServerSocket ss = new ServerSocket(0)) { port = ss.getLocalPort(); } @@ -40,6 +49,9 @@ public FakeServiceHelper(ServerInterceptor interceptor, BindableService... servi if (interceptor != null) { builder = builder.intercept(interceptor); } + if (transportFilter != null) { + builder = builder.addTransportFilter(transportFilter); + } for (BindableService service : services) { builder = builder.addService(service); } diff --git a/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubSettingsTest.java b/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubSettingsTest.java index 9a3eb874d1..d9273b5fd7 100644 --- a/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubSettingsTest.java +++ b/google-cloud-bigtable/src/test/java/com/google/cloud/bigtable/data/v2/stub/EnhancedBigtableStubSettingsTest.java @@ -19,20 +19,27 @@ import com.google.api.gax.batching.BatchingSettings; import com.google.api.gax.core.CredentialsProvider; +import com.google.api.gax.core.FixedCredentialsProvider; import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider; import com.google.api.gax.retrying.RetrySettings; import com.google.api.gax.rpc.ServerStreamingCallSettings; import com.google.api.gax.rpc.StatusCode.Code; import com.google.api.gax.rpc.UnaryCallSettings; import com.google.api.gax.rpc.WatchdogProvider; +import com.google.auth.Credentials; import com.google.cloud.bigtable.data.v2.models.ConditionalRowMutation; import com.google.cloud.bigtable.data.v2.models.KeyOffset; import com.google.cloud.bigtable.data.v2.models.Query; import com.google.cloud.bigtable.data.v2.models.Row; import com.google.cloud.bigtable.data.v2.models.RowMutation; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Range; +import java.io.IOException; +import java.net.URI; +import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.Set; import org.junit.Test; import org.junit.runner.RunWith; @@ -61,7 +68,7 @@ public void settingsAreNotLostTest() { String projectId = "my-project"; String instanceId = "my-instance"; String appProfileId = "my-app-profile-id"; - boolean isRefreshingChannel = true; + boolean isRefreshingChannel = false; String endpoint = "some.other.host:123"; CredentialsProvider credentialsProvider = Mockito.mock(CredentialsProvider.class); WatchdogProvider watchdogProvider = Mockito.mock(WatchdogProvider.class); @@ -612,4 +619,56 @@ public void isRefreshingChannelFalseValueTest() { assertThat(builder.build().isRefreshingChannel()).isFalse(); assertThat(builder.build().toBuilder().isRefreshingChannel()).isFalse(); } + + @Test + public void refreshingChannelSetFixedCredentialProvider() throws Exception { + String dummyProjectId = "my-project"; + String dummyInstanceId = "my-instance"; + CredentialsProvider credentialsProvider = Mockito.mock(CredentialsProvider.class); + FakeCredentials expectedCredentials = new FakeCredentials(); + Mockito.when(credentialsProvider.getCredentials()) + .thenReturn(expectedCredentials, new FakeCredentials(), new FakeCredentials()); + EnhancedBigtableStubSettings.Builder builder = + EnhancedBigtableStubSettings.newBuilder() + .setProjectId(dummyProjectId) + .setInstanceId(dummyInstanceId) + .setRefreshingChannel(true) + .setCredentialsProvider(credentialsProvider); + assertThat(builder.isRefreshingChannel()).isTrue(); + // Verify that isRefreshing setting is not lost and stubSettings will always return the same + // credential + EnhancedBigtableStubSettings stubSettings = builder.build(); + assertThat(stubSettings.isRefreshingChannel()).isTrue(); + assertThat(stubSettings.getCredentialsProvider()).isInstanceOf(FixedCredentialsProvider.class); + assertThat(stubSettings.getCredentialsProvider().getCredentials()) + .isEqualTo(expectedCredentials); + assertThat(stubSettings.toBuilder().isRefreshingChannel()).isTrue(); + assertThat(stubSettings.toBuilder().getCredentialsProvider().getCredentials()) + .isEqualTo(expectedCredentials); + } + + private static class FakeCredentials extends Credentials { + @Override + public String getAuthenticationType() { + return "fake"; + } + + @Override + public Map> getRequestMetadata(URI uri) throws IOException { + return ImmutableMap.of("my-header", Arrays.asList("fake-credential")); + } + + @Override + public boolean hasRequestMetadata() { + return true; + } + + @Override + public boolean hasRequestMetadataOnly() { + return true; + } + + @Override + public void refresh() throws IOException {} + } }