diff --git a/dubbo-cluster/src/main/java/com/alibaba/dubbo/rpc/cluster/loadbalance/RoundRobinLoadBalance.java b/dubbo-cluster/src/main/java/com/alibaba/dubbo/rpc/cluster/loadbalance/RoundRobinLoadBalance.java index 9d4d0267d67..fa03e9d3b43 100644 --- a/dubbo-cluster/src/main/java/com/alibaba/dubbo/rpc/cluster/loadbalance/RoundRobinLoadBalance.java +++ b/dubbo-cluster/src/main/java/com/alibaba/dubbo/rpc/cluster/loadbalance/RoundRobinLoadBalance.java @@ -17,87 +17,141 @@ package com.alibaba.dubbo.rpc.cluster.loadbalance; import com.alibaba.dubbo.common.URL; -import com.alibaba.dubbo.common.utils.AtomicPositiveInteger; import com.alibaba.dubbo.rpc.Invocation; import com.alibaba.dubbo.rpc.Invoker; -import java.util.LinkedHashMap; +import java.util.Collection; +import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; /** * Round robin load balance. + * + * Smoothly round robin's implementation @since 2.6.5 + * @author jason * */ public class RoundRobinLoadBalance extends AbstractLoadBalance { - public static final String NAME = "roundrobin"; + + private static int RECYCLE_PERIOD = 60000; + + protected static class WeightedRoundRobin { + private int weight; + private AtomicLong current = new AtomicLong(0); + private long lastUpdate; + public int getWeight() { + return weight; + } + public void setWeight(int weight) { + this.weight = weight; + current.set(0); + } + public long increaseCurrent() { + return current.addAndGet(weight); + } + public void sel(int total) { + current.addAndGet(-1 * total); + } + public long getLastUpdate() { + return lastUpdate; + } + public void setLastUpdate(long lastUpdate) { + this.lastUpdate = lastUpdate; + } + } - private final ConcurrentMap sequences = new ConcurrentHashMap(); - + private ConcurrentMap> methodWeightMap = new ConcurrentHashMap>(); + private AtomicBoolean updateLock = new AtomicBoolean(); + + /** + * get invoker addr list cached for specified invocation + *

+ * for unit test only + * + * @param invokers + * @param invocation + * @return + */ + protected Collection getInvokerAddrList(List> invokers, Invocation invocation) { + String key = invokers.get(0).getUrl().getServiceKey() + "." + invocation.getMethodName(); + Map map = methodWeightMap.get(key); + if (map != null) { + return map.keySet(); + } + return null; + } + @Override protected Invoker doSelect(List> invokers, URL url, Invocation invocation) { String key = invokers.get(0).getUrl().getServiceKey() + "." + invocation.getMethodName(); - int length = invokers.size(); // Number of invokers - int maxWeight = 0; // The maximum weight - int minWeight = Integer.MAX_VALUE; // The minimum weight - final LinkedHashMap, IntegerWrapper> invokerToWeightMap = new LinkedHashMap, IntegerWrapper>(); - int weightSum = 0; - for (int i = 0; i < length; i++) { - int weight = getWeight(invokers.get(i), invocation); - maxWeight = Math.max(maxWeight, weight); // Choose the maximum weight - minWeight = Math.min(minWeight, weight); // Choose the minimum weight - if (weight > 0) { - invokerToWeightMap.put(invokers.get(i), new IntegerWrapper(weight)); - weightSum += weight; - } + ConcurrentMap map = methodWeightMap.get(key); + if (map == null) { + methodWeightMap.putIfAbsent(key, new ConcurrentHashMap()); + map = methodWeightMap.get(key); } - AtomicPositiveInteger sequence = sequences.get(key); - if (sequence == null) { - sequences.putIfAbsent(key, new AtomicPositiveInteger()); - sequence = sequences.get(key); + int totalWeight = 0; + long maxCurrent = Long.MIN_VALUE; + long now = System.currentTimeMillis(); + Invoker selectedInvoker = null; + WeightedRoundRobin selectedWRR = null; + for (Invoker invoker : invokers) { + String identifyString = invoker.getUrl().toIdentityString(); + WeightedRoundRobin weightedRoundRobin = map.get(identifyString); + int weight = getWeight(invoker, invocation); + if (weight < 0) { + weight = 0; + } + if (weightedRoundRobin == null) { + weightedRoundRobin = new WeightedRoundRobin(); + weightedRoundRobin.setWeight(weight); + map.putIfAbsent(identifyString, weightedRoundRobin); + weightedRoundRobin = map.get(identifyString); + } + if (weight != weightedRoundRobin.getWeight()) { + //weight changed + weightedRoundRobin.setWeight(weight); + } + long cur = weightedRoundRobin.increaseCurrent(); + weightedRoundRobin.setLastUpdate(now); + if (cur > maxCurrent) { + maxCurrent = cur; + selectedInvoker = invoker; + selectedWRR = weightedRoundRobin; + } + totalWeight += weight; } - int currentSequence = sequence.getAndIncrement(); - if (maxWeight > 0 && minWeight < maxWeight) { - int mod = currentSequence % weightSum; - for (int i = 0; i < maxWeight; i++) { - for (Map.Entry, IntegerWrapper> each : invokerToWeightMap.entrySet()) { - final Invoker k = each.getKey(); - final IntegerWrapper v = each.getValue(); - if (mod == 0 && v.getValue() > 0) { - return k; - } - if (v.getValue() > 0) { - v.decrement(); - mod--; + if (!updateLock.get() && invokers.size() != map.size()) { + if (updateLock.compareAndSet(false, true)) { + try { + // copy -> modify -> update reference + ConcurrentMap newMap = new ConcurrentHashMap(); + newMap.putAll(map); + Iterator> it = newMap.entrySet().iterator(); + while (it.hasNext()) { + Entry item = it.next(); + if (now - item.getValue().getLastUpdate() > RECYCLE_PERIOD) { + it.remove(); + } } + methodWeightMap.put(key, newMap); + } finally { + updateLock.set(false); } } } - // Round robin - return invokers.get(currentSequence % length); - } - - private static final class IntegerWrapper { - private int value; - - public IntegerWrapper(int value) { - this.value = value; - } - - public int getValue() { - return value; - } - - public void setValue(int value) { - this.value = value; - } - - public void decrement() { - this.value--; + if (selectedInvoker != null) { + selectedWRR.sel(totalWeight); + return selectedInvoker; } + // should not happen here + return invokers.get(0); } } diff --git a/dubbo-cluster/src/test/java/com/alibaba/dubbo/rpc/cluster/StickyTest.java b/dubbo-cluster/src/test/java/com/alibaba/dubbo/rpc/cluster/StickyTest.java index 080003824e2..b3a022570c5 100644 --- a/dubbo-cluster/src/test/java/com/alibaba/dubbo/rpc/cluster/StickyTest.java +++ b/dubbo-cluster/src/test/java/com/alibaba/dubbo/rpc/cluster/StickyTest.java @@ -114,12 +114,12 @@ public int testSticky(String methodName, boolean check) { given(invoker1.invoke(invocation)).willReturn(result); given(invoker1.isAvailable()).willReturn(true); - given(invoker1.getUrl()).willReturn(url); + given(invoker1.getUrl()).willReturn(url.setPort(1)); given(invoker1.getInterface()).willReturn(StickyTest.class); given(invoker2.invoke(invocation)).willReturn(result); given(invoker2.isAvailable()).willReturn(true); - given(invoker2.getUrl()).willReturn(url); + given(invoker2.getUrl()).willReturn(url.setPort(2)); given(invoker2.getInterface()).willReturn(StickyTest.class); invocation.setMethodName(methodName); @@ -158,4 +158,4 @@ public Invoker getSelectedInvoker() { return selectedInvoker; } } -} \ No newline at end of file +} diff --git a/dubbo-cluster/src/test/java/com/alibaba/dubbo/rpc/cluster/loadbalance/LoadBalanceTest.java b/dubbo-cluster/src/test/java/com/alibaba/dubbo/rpc/cluster/loadbalance/LoadBalanceTest.java index bdede14a4fa..d231c63c3b2 100644 --- a/dubbo-cluster/src/test/java/com/alibaba/dubbo/rpc/cluster/loadbalance/LoadBalanceTest.java +++ b/dubbo-cluster/src/test/java/com/alibaba/dubbo/rpc/cluster/loadbalance/LoadBalanceTest.java @@ -24,15 +24,21 @@ import com.alibaba.dubbo.rpc.RpcInvocation; import com.alibaba.dubbo.rpc.RpcStatus; import com.alibaba.dubbo.rpc.cluster.LoadBalance; -import junit.framework.Assert; +import com.alibaba.fastjson.JSON; + +import org.junit.Assert; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; +import java.lang.reflect.Field; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import static org.mockito.BDDMockito.given; @@ -55,7 +61,7 @@ public class LoadBalanceTest { RpcStatus weightTestRpcStatus2; RpcStatus weightTestRpcStatus3; RpcInvocation weightTestInvocation; - + /** * @throws java.lang.Exception */ @@ -110,7 +116,11 @@ public void setUp() throws Exception { invokers.add(invoker4); invokers.add(invoker5); } - + + private AbstractLoadBalance getLoadBalance(String loadbalanceName) { + return (AbstractLoadBalance) ExtensionLoader.getExtensionLoader(LoadBalance.class).getExtension(loadbalanceName); + } + @Test public void testRoundRobinLoadBalance_select() { int runs = 10000; @@ -120,6 +130,120 @@ public void testRoundRobinLoadBalance_select() { Assert.assertTrue("abs diff shoud < 1", Math.abs(count - runs / (0f + invokers.size())) < 1f); } } + + private void assertStrictWRRResult(int runs, Map resultMap) { + for (InvokeResult invokeResult : resultMap.values()) { + // Because it's a strictly round robin, so the abs delta should be < 10 too + Assert.assertTrue("delta with expected count should < 10", + Math.abs(invokeResult.getExpected(runs) - invokeResult.getCount().get()) < 10); + } + } + + /** + * a multi-threaded test on weighted round robin + */ + @Test + public void testRoundRobinLoadBalanceWithWeight() { + final Map totalMap = new HashMap(); + final AtomicBoolean shouldBegin = new AtomicBoolean(false); + final int runs = 10000; + List threads = new ArrayList(); + int threadNum = 10; + for (int i = 0; i < threadNum; i ++) { + threads.add(new Thread() { + @Override + public void run() { + while (!shouldBegin.get()) { + try { + sleep(5); + } catch (InterruptedException e) { + } + } + Map resultMap = getWeightedInvokeResult(runs, RoundRobinLoadBalance.NAME); + synchronized (totalMap) { + for (Entry entry : resultMap.entrySet()) { + if (!totalMap.containsKey(entry.getKey())) { + totalMap.put(entry.getKey(), entry.getValue()); + } else { + totalMap.get(entry.getKey()).getCount().addAndGet(entry.getValue().getCount().get()); + } + } + } + } + }); + } + for (Thread thread : threads) { + thread.start(); + } + // let's rock it! + shouldBegin.set(true); + for (Thread thread : threads) { + try { + thread.join(); + } catch (InterruptedException e) { + } + } + assertStrictWRRResult(runs * threadNum, totalMap); + } + + @Test + public void testRoundRobinLoadBalanceWithWeightShouldNotRecycle() { + int runs = 10000; + //tmperately add a new invoker + weightInvokers.add(weightInvokerTmp); + try { + Map resultMap = getWeightedInvokeResult(runs, RoundRobinLoadBalance.NAME); + assertStrictWRRResult(runs, resultMap); + RoundRobinLoadBalance lb = (RoundRobinLoadBalance)getLoadBalance(RoundRobinLoadBalance.NAME); + Assert.assertEquals(weightInvokers.size(), lb.getInvokerAddrList(weightInvokers, weightTestInvocation).size()); + + //remove the last invoker and retry + weightInvokers.remove(weightInvokerTmp); + resultMap = getWeightedInvokeResult(runs, RoundRobinLoadBalance.NAME); + assertStrictWRRResult(runs, resultMap); + Assert.assertNotEquals(weightInvokers.size(), lb.getInvokerAddrList(weightInvokers, weightTestInvocation).size()); + } finally { + weightInvokers.remove(weightInvokerTmp); + } + } + + @Test + public void testRoundRobinLoadBalanceWithWeightShouldRecycle() { + { + Field recycleTimeField = null; + try { + //change recycle time to 1 ms + recycleTimeField = RoundRobinLoadBalance.class.getDeclaredField("RECYCLE_PERIOD"); + recycleTimeField.setAccessible(true); + recycleTimeField.setInt(RoundRobinLoadBalance.class, 10); + } catch (NoSuchFieldException e) { + Assert.assertTrue("getField failed", true); + } catch (SecurityException e) { + Assert.assertTrue("getField failed", true); + } catch (IllegalArgumentException e) { + Assert.assertTrue("getField failed", true); + } catch (IllegalAccessException e) { + Assert.assertTrue("getField failed", true); + } + } + int runs = 10000; + //temporarily add a new invoker + weightInvokers.add(weightInvokerTmp); + try { + Map resultMap = getWeightedInvokeResult(runs, RoundRobinLoadBalance.NAME); + assertStrictWRRResult(runs, resultMap); + RoundRobinLoadBalance lb = (RoundRobinLoadBalance)getLoadBalance(RoundRobinLoadBalance.NAME); + Assert.assertEquals(weightInvokers.size(), lb.getInvokerAddrList(weightInvokers, weightTestInvocation).size()); + + //remove the tmp invoker and retry, should recycle its cache + weightInvokers.remove(weightInvokerTmp); + resultMap = getWeightedInvokeResult(runs, RoundRobinLoadBalance.NAME); + assertStrictWRRResult(runs, resultMap); + Assert.assertEquals(weightInvokers.size(), lb.getInvokerAddrList(weightInvokers, weightTestInvocation).size()); + } finally { + weightInvokers.remove(weightInvokerTmp); + } + } @Test public void testSelectByWeightLeastActive() { @@ -170,32 +294,6 @@ public void testSelectByWeightRandom() { Assert.assertEquals("select failed!", sumInvoker1 + sumInvoker2 + sumInvoker3, loop); } - @Test - public void testSelectByWeight() { - int sumInvoker1 = 0; - int sumInvoker2 = 0; - int sumInvoker3 = 0; - int loop = 10000; - RoundRobinLoadBalance lb = new RoundRobinLoadBalance(); - for (int i = 0; i < loop; i++) { - Invoker selected = lb.select(weightInvokers, null, weightTestInvocation); - if (selected.getUrl().getProtocol().equals("test1")) { - sumInvoker1++; - } - if (selected.getUrl().getProtocol().equals("test2")) { - sumInvoker2++; - } - if (selected.getUrl().getProtocol().equals("test3")) { - sumInvoker3++; - } - } - // 1 : 9 : 6 - System.out.println(sumInvoker1); - System.out.println(sumInvoker2); - System.out.println(sumInvoker3); - Assert.assertEquals("select failed!", sumInvoker1 + sumInvoker2 + sumInvoker3, loop); - } - @Test public void testRandomLoadBalance_select() { int runs = 1000; @@ -234,15 +332,16 @@ public void testLeastActiveLoadBalance_select() { Math.abs(count - runs / (0f + invokers.size())) < runs / (0f + invokers.size())); } } - - public Map getInvokeCounter(int runs, String loadbalanceName) { + + private Map getInvokeCounter(int runs, String loadbalanceName) { Map counter = new ConcurrentHashMap(); - LoadBalance lb = ExtensionLoader.getExtensionLoader(LoadBalance.class).getExtension(loadbalanceName); + LoadBalance lb = getLoadBalance(loadbalanceName); for (Invoker invoker : invokers) { counter.put(invoker, new AtomicLong(0)); } + URL url = invokers.get(0).getUrl(); for (int i = 0; i < runs; i++) { - Invoker sinvoker = lb.select(invokers, invokers.get(0).getUrl(), invocation); + Invoker sinvoker = lb.select(invokers, url, invocation); counter.get(sinvoker).incrementAndGet(); } return counter; @@ -279,46 +378,115 @@ public void testLoadBalanceWarmup() { Assert.assertEquals(100, AbstractLoadBalance .calculateWarmupWeight(20 * 60 * 1000, Constants.DEFAULT_WARMUP, Constants.DEFAULT_WEIGHT)); } - + /*------------------------------------test invokers for weight---------------------------------------*/ protected List> weightInvokers = new ArrayList>(); protected Invoker weightInvoker1; protected Invoker weightInvoker2; protected Invoker weightInvoker3; + protected Invoker weightInvokerTmp; @Before - public void before() throws Exception { + public void setUpWeightInvokers() throws Exception { weightInvoker1 = mock(Invoker.class); weightInvoker2 = mock(Invoker.class); weightInvoker3 = mock(Invoker.class); + weightInvokerTmp = mock(Invoker.class); + weightTestInvocation = new RpcInvocation(); weightTestInvocation.setMethodName("test"); - URL url1 = URL.valueOf("test1://0:1/DemoService"); - url1 = url1.addParameter(Constants.WEIGHT_KEY, 1); - url1 = url1.addParameter(weightTestInvocation.getMethodName() + "." + Constants.WEIGHT_KEY, 1); - url1 = url1.addParameter("active", 0); - URL url2 = URL.valueOf("test2://0:9/DemoService"); - url2 = url2.addParameter(Constants.WEIGHT_KEY, 9); - url2 = url2.addParameter(weightTestInvocation.getMethodName() + "." + Constants.WEIGHT_KEY, 9); - url2 = url2.addParameter("active", 0); - URL url3 = URL.valueOf("test3://1:6/DemoService"); - url3 = url3.addParameter(Constants.WEIGHT_KEY, 6); - url3 = url3.addParameter(weightTestInvocation.getMethodName() + "." + Constants.WEIGHT_KEY, 6); - url3 = url3.addParameter("active", 1); + + URL url1 = URL.valueOf("test1://127.0.0.1:11/DemoService?weight=11&active=0"); + URL url2 = URL.valueOf("test2://127.0.0.1:12/DemoService?weight=97&active=0"); + URL url3 = URL.valueOf("test3://127.0.0.1:13/DemoService?weight=67&active=1"); + URL urlTmp = URL.valueOf("test4://127.0.0.1:9999/DemoService?weight=601&active=0"); + given(weightInvoker1.isAvailable()).willReturn(true); + given(weightInvoker1.getInterface()).willReturn(LoadBalanceTest.class); given(weightInvoker1.getUrl()).willReturn(url1); + given(weightInvoker2.isAvailable()).willReturn(true); + given(weightInvoker2.getInterface()).willReturn(LoadBalanceTest.class); given(weightInvoker2.getUrl()).willReturn(url2); + given(weightInvoker3.isAvailable()).willReturn(true); + given(weightInvoker3.getInterface()).willReturn(LoadBalanceTest.class); given(weightInvoker3.getUrl()).willReturn(url3); + + given(weightInvokerTmp.isAvailable()).willReturn(true); + given(weightInvokerTmp.getInterface()).willReturn(LoadBalanceTest.class); + given(weightInvokerTmp.getUrl()).willReturn(urlTmp); + weightInvokers.add(weightInvoker1); weightInvokers.add(weightInvoker2); weightInvokers.add(weightInvoker3); + weightTestRpcStatus1 = RpcStatus.getStatus(weightInvoker1.getUrl(), weightTestInvocation.getMethodName()); weightTestRpcStatus2 = RpcStatus.getStatus(weightInvoker2.getUrl(), weightTestInvocation.getMethodName()); weightTestRpcStatus3 = RpcStatus.getStatus(weightInvoker3.getUrl(), weightTestInvocation.getMethodName()); + // weightTestRpcStatus3 active is 1 RpcStatus.beginCount(weightInvoker3.getUrl(), weightTestInvocation.getMethodName()); } + + private static class InvokeResult { + private AtomicLong count = new AtomicLong(); + private int weight = 0; + private int totalWeight = 0; + + public InvokeResult(int weight) { + this.weight = weight; + } + + public AtomicLong getCount() { + return count; + } + + public int getWeight() { + return weight; + } + + public int getTotalWeight() { + return totalWeight; + } + + public void setTotalWeight(int totalWeight) { + this.totalWeight = totalWeight; + } + + public int getExpected(int runCount) { + return getWeight() * runCount / getTotalWeight(); + } + + public float getDeltaPercentage(int runCount) { + int expected = getExpected(runCount); + return Math.abs((expected - getCount().get()) * 100.0f / expected); + } + + @Override + public String toString() { + return JSON.toJSONString(this); + } + } + + private Map getWeightedInvokeResult(int runs, String loadbalanceName) { + Map counter = new ConcurrentHashMap(); + AbstractLoadBalance lb = getLoadBalance(loadbalanceName); + int totalWeight = 0; + for (int i = 0; i < weightInvokers.size(); i ++) { + InvokeResult invokeResult = new InvokeResult(lb.getWeight(weightInvokers.get(i), weightTestInvocation)); + counter.put(weightInvokers.get(i), invokeResult); + totalWeight += invokeResult.getWeight(); + } + for (InvokeResult invokeResult : counter.values()) { + invokeResult.setTotalWeight(totalWeight); + } + for (int i = 0; i < runs; i++) { + Invoker sinvoker = lb.select(weightInvokers, weightInvokers.get(0).getUrl(), weightTestInvocation); + counter.get(sinvoker).getCount().incrementAndGet(); + } + return counter; + } + } \ No newline at end of file diff --git a/dubbo-cluster/src/test/java/com/alibaba/dubbo/rpc/cluster/support/AbstractClusterInvokerTest.java b/dubbo-cluster/src/test/java/com/alibaba/dubbo/rpc/cluster/support/AbstractClusterInvokerTest.java index cac83e93ea5..8883ed68a23 100644 --- a/dubbo-cluster/src/test/java/com/alibaba/dubbo/rpc/cluster/support/AbstractClusterInvokerTest.java +++ b/dubbo-cluster/src/test/java/com/alibaba/dubbo/rpc/cluster/support/AbstractClusterInvokerTest.java @@ -91,23 +91,23 @@ public void setUp() throws Exception { given(invoker1.isAvailable()).willReturn(false); given(invoker1.getInterface()).willReturn(IHelloService.class); - given(invoker1.getUrl()).willReturn(turl.addParameter("name", "invoker1")); + given(invoker1.getUrl()).willReturn(turl.setPort(1).addParameter("name", "invoker1")); given(invoker2.isAvailable()).willReturn(true); given(invoker2.getInterface()).willReturn(IHelloService.class); - given(invoker2.getUrl()).willReturn(turl.addParameter("name", "invoker2")); + given(invoker2.getUrl()).willReturn(turl.setPort(2).addParameter("name", "invoker2")); given(invoker3.isAvailable()).willReturn(false); given(invoker3.getInterface()).willReturn(IHelloService.class); - given(invoker3.getUrl()).willReturn(turl.addParameter("name", "invoker3")); + given(invoker3.getUrl()).willReturn(turl.setPort(3).addParameter("name", "invoker3")); given(invoker4.isAvailable()).willReturn(true); given(invoker4.getInterface()).willReturn(IHelloService.class); - given(invoker4.getUrl()).willReturn(turl.addParameter("name", "invoker4")); + given(invoker4.getUrl()).willReturn(turl.setPort(4).addParameter("name", "invoker4")); given(invoker5.isAvailable()).willReturn(false); given(invoker5.getInterface()).willReturn(IHelloService.class); - given(invoker5.getUrl()).willReturn(turl.addParameter("name", "invoker5")); + given(invoker5.getUrl()).willReturn(turl.setPort(5).addParameter("name", "invoker5")); given(mockedInvoker1.isAvailable()).willReturn(false); given(mockedInvoker1.getInterface()).willReturn(IHelloService.class);