diff --git a/raft/src/main/java/net/kuujo/copycat/raft/StateMachine.java b/raft/src/main/java/net/kuujo/copycat/raft/StateMachine.java index 1625a76a93..16ebd40930 100644 --- a/raft/src/main/java/net/kuujo/copycat/raft/StateMachine.java +++ b/raft/src/main/java/net/kuujo/copycat/raft/StateMachine.java @@ -24,6 +24,8 @@ import java.lang.reflect.Method; import java.util.HashMap; import java.util.Map; +import java.util.function.BiPredicate; +import java.util.function.Function; /** * Raft state machine. @@ -32,10 +34,10 @@ */ public abstract class StateMachine { private final Logger LOGGER = LoggerFactory.getLogger(getClass()); - private final Map, Method>> filters = new HashMap<>(); - private Map allFilters = new HashMap<>(); - private final Map, Method> operations = new HashMap<>(); - private Method allOperation; + private final Map, BiPredicate, Compaction>>> filters = new HashMap<>(); + private Map, Compaction>> allFilters = new HashMap<>(); + private final Map, Function, ?>> operations = new HashMap<>(); + private Function, ?> allOperation; protected StateMachine() { init(); @@ -79,37 +81,64 @@ private void declareFilters(Method method) { for (Class command : filter.value()) { if (command == Filter.All.class) { if (!allFilters.containsKey(filter.compaction())) { - allFilters.put(filter.compaction(), method); + if (method.getParameterCount() == 1) { + allFilters.put(filter.compaction(), wrapFilter(method)); + } } } else { - Map, Method> filters = this.filters.get(filter.compaction()); + Map, BiPredicate, Compaction>> filters = this.filters.get(filter.compaction()); if (filters == null) { filters = new HashMap<>(); this.filters.put(filter.compaction(), filters); } if (!filters.containsKey(command)) { - filters.put(command, method); + filters.put(command, wrapFilter(method)); } } } } } + /** + * Wraps a filter method. + */ + private BiPredicate, Compaction> wrapFilter(Method method) { + if (method.getParameterCount() == 1) { + return (commit, compaction) -> { + try { + return (boolean) method.invoke(this, commit); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new ApplicationException("failed to filter command", e); + } + }; + } else if (method.getParameterCount() == 2) { + return (commit, compaction) -> { + try { + return (boolean) method.invoke(this, commit, compaction); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new ApplicationException("failed to filter command", e); + } + }; + } else { + throw new IllegalStateException("invalid filter method: too many parameters"); + } + } + /** * Finds the filter method for the given command. */ - private Method findFilter(Class type, Compaction.Type compaction) { - Map, Method> filters = this.filters.get(compaction); + private BiPredicate, Compaction> findFilter(Class type, Compaction.Type compaction) { + Map, BiPredicate, Compaction>> filters = this.filters.get(compaction); if (filters == null) { - Method method = allFilters.get(compaction); - if (method == null) { + BiPredicate, Compaction> filter = allFilters.get(compaction); + if (filter == null) { throw new IllegalArgumentException("unknown command type: " + type); } - return method; + return filter; } - Method method = filters.computeIfAbsent(type, t -> { - for (Map.Entry, Method> entry : filters.entrySet()) { + BiPredicate, Compaction> filter = filters.computeIfAbsent(type, t -> { + for (Map.Entry, BiPredicate, Compaction>> entry : filters.entrySet()) { if (entry.getKey().isAssignableFrom(type)) { return entry.getValue(); } @@ -117,10 +146,10 @@ private Method findFilter(Class type, Compaction.Type compact return allFilters.get(compaction); }); - if (method == null) { + if (filter == null) { throw new IllegalArgumentException("unknown command type: " + type); } - return method; + return filter; } /** @@ -132,20 +161,39 @@ private void declareOperations(Method method) { method.setAccessible(true); for (Class operation : apply.value()) { if (operation == Apply.All.class) { - allOperation = method; + allOperation = wrapOperation(method); } else if (!operations.containsKey(operation)) { - operations.put(operation, method); + operations.put(operation, wrapOperation(method)); } } } } + /** + * Wraps an operation method. + */ + private Function, ?> wrapOperation(Method method) { + if (method.getParameterCount() < 1) { + throw new IllegalStateException("invalid operation method: not enough arguments"); + } else if (method.getParameterCount() > 1) { + throw new IllegalStateException("invalid operation method: too many arguments"); + } else { + return commit -> { + try { + return method.invoke(this, commit); + } catch (IllegalAccessException | InvocationTargetException e) { + return new ApplicationException("failed to invoke operation", e); + } + }; + } + } + /** * Finds the operation method for the given operation. */ - private Method findOperation(Class type) { - Method method = operations.computeIfAbsent(type, t -> { - for (Map.Entry, Method> entry : operations.entrySet()) { + private Function, ?> findOperation(Class type) { + Function, ?> operation = operations.computeIfAbsent(type, t -> { + for (Map.Entry, Function, ?>> entry : operations.entrySet()) { if (entry.getKey().isAssignableFrom(type)) { return entry.getValue(); } @@ -153,10 +201,10 @@ private Method findOperation(Class type) { return allOperation; }); - if (method == null) { + if (operation == null) { throw new IllegalArgumentException("unknown operation type: " + type); } - return method; + return operation; } /** @@ -176,12 +224,8 @@ public void register(Session session) { * @return Whether to keep the commit. */ public boolean filter(Commit commit, Compaction compaction) { - LOGGER.debug("filter {}", commit); - try { - return (boolean) findFilter(commit.type(), compaction.type()).invoke(this, commit, compaction); - } catch (IllegalAccessException | InvocationTargetException e) { - throw new ApplicationException("failed to filter command", e); - } + LOGGER.debug("Filtering {}", commit); + return findFilter(commit.type(), compaction.type()).test(commit, compaction); } /** @@ -191,12 +235,8 @@ public boolean filter(Commit commit, Compaction compaction) { * @return The operation result. */ public Object apply(Commit commit) { - LOGGER.debug("apply {}", commit); - try { - return findOperation(commit.type()).invoke(this, commit); - } catch (IllegalAccessException | InvocationTargetException e) { - return new ApplicationException("failed to invoke operation", e); - } + LOGGER.debug("Applying {}", commit); + return findOperation(commit.type()).apply(commit); } /**