diff --git a/hbase-common/src/main/java/org/apache/hadoop/hbase/util/ReservoirSample.java b/hbase-common/src/main/java/org/apache/hadoop/hbase/util/ReservoirSample.java new file mode 100644 index 000000000000..2cc502eb537b --- /dev/null +++ b/hbase-common/src/main/java/org/apache/hadoop/hbase/util/ReservoirSample.java @@ -0,0 +1,72 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.util; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.apache.yetus.audience.InterfaceAudience; + +import org.apache.hbase.thirdparty.com.google.common.base.Preconditions; + +/** + * The simple version of reservoir sampling implementation. It is enough for the usage in HBase. + *

+ * See https://en.wikipedia.org/wiki/Reservoir_sampling. + */ +@InterfaceAudience.Private +public class ReservoirSample { + + private final List r; + + private final int k; + + private int n; + + public ReservoirSample(int k) { + Preconditions.checkArgument(k > 0, "negative sampling number(%d) is not allowed"); + r = new ArrayList<>(k); + this.k = k; + } + + public void add(T t) { + if (n < k) { + r.add(t); + } else { + int j = ThreadLocalRandom.current().nextInt(n + 1); + if (j < k) { + r.set(j, t); + } + } + n++; + } + + public void add(Iterator iter) { + iter.forEachRemaining(this::add); + } + + public void add(Stream s) { + s.forEachOrdered(this::add); + } + + public List getSamplingResult() { + return r; + } +} diff --git a/hbase-common/src/test/java/org/apache/hadoop/hbase/util/TestReservoirSample.java b/hbase-common/src/test/java/org/apache/hadoop/hbase/util/TestReservoirSample.java new file mode 100644 index 000000000000..a4d23d47c47b --- /dev/null +++ b/hbase-common/src/test/java/org/apache/hadoop/hbase/util/TestReservoirSample.java @@ -0,0 +1,92 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.util; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.util.stream.IntStream; +import org.apache.hadoop.hbase.HBaseClassTestRule; +import org.apache.hadoop.hbase.testclassification.MiscTests; +import org.apache.hadoop.hbase.testclassification.SmallTests; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +@Category({ MiscTests.class, SmallTests.class }) +public class TestReservoirSample { + + @ClassRule + public static final HBaseClassTestRule CLASS_RULE = + HBaseClassTestRule.forClass(TestReservoirSample.class); + + @Test + public void test() { + int round = 100000; + int containsOne = 0; + for (int i = 0; i < round; i++) { + ReservoirSample rs = new ReservoirSample<>(10); + for (int j = 0; j < 100; j++) { + rs.add(j); + if (j < 10) { + assertEquals(j + 1, rs.getSamplingResult().size()); + } else { + assertEquals(10, rs.getSamplingResult().size()); + } + } + if (rs.getSamplingResult().contains(1)) { + containsOne++; + } + } + // we assume a 5% error rate + assertTrue(containsOne > round / 10 * 0.95); + assertTrue(containsOne < round / 10 * 1.05); + } + + @Test + public void testIterator() { + int round = 100000; + int containsOne = 0; + for (int i = 0; i < round; i++) { + ReservoirSample rs = new ReservoirSample<>(10); + rs.add(IntStream.range(0, 100).mapToObj(Integer::valueOf).iterator()); + if (rs.getSamplingResult().contains(1)) { + containsOne++; + } + } + // we assume a 5% error rate + assertTrue(containsOne > round / 10 * 0.95); + assertTrue(containsOne < round / 10 * 1.05); + } + + @Test + public void testStream() { + int round = 100000; + int containsOne = 0; + for (int i = 0; i < round; i++) { + ReservoirSample rs = new ReservoirSample<>(10); + rs.add(IntStream.range(0, 100).mapToObj(Integer::valueOf)); + if (rs.getSamplingResult().contains(1)) { + containsOne++; + } + } + // we assume a 5% error rate + assertTrue(containsOne > round / 10 * 0.95); + assertTrue(containsOne < round / 10 * 1.05); + } +} diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/HBaseRpcServicesBase.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/HBaseRpcServicesBase.java index 4bd0b3304d40..21b2d757463e 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/HBaseRpcServicesBase.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/HBaseRpcServicesBase.java @@ -22,11 +22,9 @@ import java.lang.reflect.Method; import java.net.BindException; import java.net.InetSocketAddress; -import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Optional; -import java.util.concurrent.ThreadLocalRandom; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hbase.client.ConnectionUtils; import org.apache.hadoop.hbase.conf.ConfigurationObserver; @@ -53,6 +51,7 @@ import org.apache.hadoop.hbase.security.access.ZKPermissionWatcher; import org.apache.hadoop.hbase.util.DNS; import org.apache.hadoop.hbase.util.OOMEChecker; +import org.apache.hadoop.hbase.util.ReservoirSample; import org.apache.hadoop.hbase.zookeeper.ZKWatcher; import org.apache.yetus.audience.InterfaceAudience; import org.apache.zookeeper.KeeperException; @@ -299,12 +298,13 @@ public GetMetaRegionLocationsResponse getMetaRegionLocations(RpcController contr @Override public final GetBootstrapNodesResponse getBootstrapNodes(RpcController controller, GetBootstrapNodesRequest request) throws ServiceException { - List bootstrapNodes = new ArrayList<>(server.getRegionServers()); - Collections.shuffle(bootstrapNodes, ThreadLocalRandom.current()); int maxNodeCount = server.getConfiguration().getInt(CLIENT_BOOTSTRAP_NODE_LIMIT, DEFAULT_CLIENT_BOOTSTRAP_NODE_LIMIT); + ReservoirSample sample = new ReservoirSample<>(maxNodeCount); + sample.add(server.getRegionServers()); + GetBootstrapNodesResponse.Builder builder = GetBootstrapNodesResponse.newBuilder(); - bootstrapNodes.stream().limit(maxNodeCount).map(ProtobufUtil::toServerName) + sample.getSamplingResult().stream().map(ProtobufUtil::toServerName) .forEach(builder::addServerName); return builder.build(); } diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/client/ConnectionRegistryEndpoint.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/client/ConnectionRegistryEndpoint.java index 420c6d6b98e1..0a1557139b32 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/client/ConnectionRegistryEndpoint.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/client/ConnectionRegistryEndpoint.java @@ -17,7 +17,7 @@ */ package org.apache.hadoop.hbase.client; -import java.util.Collection; +import java.util.Iterator; import java.util.List; import java.util.Optional; import org.apache.hadoop.hbase.HRegionLocation; @@ -46,9 +46,9 @@ public interface ConnectionRegistryEndpoint { List getBackupMasters(); /** - * Get all the region servers address. + * Get a iterator of the region servers which could be used as bootstrap nodes. */ - Collection getRegionServers(); + Iterator getRegionServers(); /** * Get the location of meta regions. diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/master/HMaster.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/master/HMaster.java index 9706149e82d5..45b980d86333 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/master/HMaster.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/master/HMaster.java @@ -3979,8 +3979,8 @@ public List getBackupMasters() { } @Override - public Collection getRegionServers() { - return regionServerTracker.getRegionServers(); + public Iterator getRegionServers() { + return regionServerTracker.getRegionServers().iterator(); } @Override diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/regionserver/HRegionServer.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/regionserver/HRegionServer.java index cc22124dbeee..b45f33abeb4f 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/regionserver/HRegionServer.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/regionserver/HRegionServer.java @@ -34,6 +34,7 @@ import java.util.Collections; import java.util.Comparator; import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -3429,8 +3430,8 @@ public List getBackupMasters() { } @Override - public Collection getRegionServers() { - return regionServerAddressTracker.getRegionServers(); + public Iterator getRegionServers() { + return regionServerAddressTracker.getRegionServers().iterator(); } @Override