Skip to content

Commit

Permalink
[feature] add AliasMethod, with benchmark.
Browse files Browse the repository at this point in the history
  • Loading branch information
PhantomThief committed Apr 7, 2020
1 parent 9225de7 commit aa6c9a4
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 1 deletion.
100 changes: 100 additions & 0 deletions src/main/java/com/github/phantomthief/failover/util/AliasMethod.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package com.github.phantomthief.failover.util;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ThreadLocalRandom;

import javax.annotation.Nonnull;

/**
* http://www.keithschwarz.com/darts-dice-coins/
*
* @author w.vela
* Created on 2020-04-07.
*/
public class AliasMethod<T> {

private final Object[] values;
private final int[] alias;
private final double[] probability;

public AliasMethod(@Nonnull Map<T, Integer> weightMap) {
checkNotNull(weightMap);
checkArgument(weightMap.size() > 0);
List<Double> probabilities = new ArrayList<>(weightMap.size());
List<T> valueList = new ArrayList<>(weightMap.size());
long sum = 0;
for (Entry<T, Integer> entry : weightMap.entrySet()) {
Integer weight = entry.getValue();
if (weight > 0) {
sum += weight;
valueList.add(entry.getKey());
}
}
for (Entry<T, Integer> entry : weightMap.entrySet()) {
Integer weight = entry.getValue();
if (weight > 0) {
probabilities.add((double) weight / sum);
}
}
checkArgument(sum > 0);
values = valueList.toArray(new Object[0]);

int size = probabilities.size();
probability = new double[size];
alias = new int[size];

double average = 1.0 / size;

probabilities = new ArrayList<>(probabilities);

Deque<Integer> small = new ArrayDeque<>();
Deque<Integer> large = new ArrayDeque<>();

for (int i = 0; i < size; ++i) {
if (probabilities.get(i) >= average) {
large.add(i);
} else {
small.add(i);
}
}

while (!small.isEmpty() && !large.isEmpty()) {
int less = small.removeLast();
int more = large.removeLast();

probability[less] = probabilities.get(less) * size;
alias[less] = more;

probabilities.set(more, probabilities.get(more) + probabilities.get(less) - average);

if (probabilities.get(more) >= 1.0 / size) {
large.add(more);
} else {
small.add(more);
}
}

while (!small.isEmpty()) {
probability[small.removeLast()] = 1.0;
}
while (!large.isEmpty()) {
probability[large.removeLast()] = 1.0;
}
}

@SuppressWarnings("unchecked")
public T get() {
int column = ThreadLocalRandom.current().nextInt(probability.length);
boolean coinToss = ThreadLocalRandom.current().nextDouble() < probability[column];
int index = coinToss ? column : alias[column];
return (T) values[index];
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

/**
* 带权重的树
*
* 如果只使用 {@link #get()},可以考虑使用 {@link AliasMethod},性能更好
*
* @author w.vela
* @param <T>
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package com.github.phantomthief.failover.impl.benchmark;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;

import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Threads;
import org.openjdk.jmh.annotations.Warmup;

import com.github.phantomthief.failover.util.AliasMethod;
import com.github.phantomthief.failover.util.Weight;

/**
* Benchmark (totalSize) Mode Cnt Score Error Units
* WeightBenchmark.testAliasMethod 10 thrpt 3 531373408.041 ± 1310821420.530 ops/s
* WeightBenchmark.testAliasMethod 100 thrpt 3 524611078.970 ± 567444175.501 ops/s
* WeightBenchmark.testAliasMethod 1000 thrpt 3 442690581.938 ± 705466764.645 ops/s
* WeightBenchmark.testWeight 10 thrpt 3 101635713.806 ± 28607384.102 ops/s
* WeightBenchmark.testWeight 100 thrpt 3 59522839.677 ± 33694178.059 ops/s
* WeightBenchmark.testWeight 1000 thrpt 3 36993978.805 ± 4766898.860 ops/s
*
* @author w.vela
* Created on 2020-04-07.
*/
@BenchmarkMode(Mode.Throughput)
@Fork(1)
@Threads(10)
@Warmup(iterations = 1, time = 1)
@Measurement(iterations = 3, time = 1)
@State(Scope.Benchmark)
public class WeightBenchmark {

@Param({"10", "100", "1000"})
private int totalSize;

private Weight<String> weight;
private AliasMethod<String> aliasMethod;

@Setup
public void init() {
weight = new Weight<>();
Map<String, Integer> weightMap = new HashMap<>();
for (int i = 0; i < totalSize; i++) {
int weightValue = ThreadLocalRandom.current().nextInt(1, 100);
String node = "key" + i;
weight.add(node, weightValue);
weightMap.put(node, weightValue);
}
aliasMethod = new AliasMethod<>(weightMap);
}


@Benchmark
public void testWeight() {
weight.get();
}

@Benchmark
public void testAliasMethod() {
aliasMethod.get();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.junit.jupiter.api.Test;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Multiset;

/**
Expand Down Expand Up @@ -39,4 +40,26 @@ void test() {

assertEquals(3, weight.allNodes().size());
}

@Test
void testAliasMethod() {
AliasMethod<String> weight = new AliasMethod<String>(ImmutableMap.<String, Integer> builder()
.put("s1", 1)
.put("s2", 2)
.put("s3", 3)
.build()
);
Multiset<String> result = HashMultiset.create();
for (int i = 0; i < 10000; i++) {
result.add(weight.get());
}
assertTrue(checkRatio(result.count("s2"), result.count("s1"), 2));
assertTrue(checkRatio(result.count("s3"), result.count("s1"), 3));

result.clear();

for (int i = 0; i < 10000; i++) {
result.add(weight.get());
}
}
}

0 comments on commit aa6c9a4

Please sign in to comment.