From be9953f85e15188eb5dc70a574be560d3853ad7d Mon Sep 17 00:00:00 2001 From: danielsun1106 Date: Sat, 24 Nov 2018 21:01:18 +0800 Subject: [PATCH] GROOVY-8901: Add DGSM `countDistinct[By]`, `sum[By]`, `avg[By]`, etc. --- .../runtime/DefaultGroovyStaticMethods.java | 575 ++++++++++++++++++ .../DefaultGroovyStaticMethodsTest.groovy | 98 ++- 2 files changed, 668 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/codehaus/groovy/runtime/DefaultGroovyStaticMethods.java b/src/main/java/org/codehaus/groovy/runtime/DefaultGroovyStaticMethods.java index 0e801577d7e..54e4c62bf10 100644 --- a/src/main/java/org/codehaus/groovy/runtime/DefaultGroovyStaticMethods.java +++ b/src/main/java/org/codehaus/groovy/runtime/DefaultGroovyStaticMethods.java @@ -24,17 +24,27 @@ import java.io.File; import java.io.IOException; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.math.RoundingMode; import java.text.ParseException; import java.text.SimpleDateFormat; +import java.util.Comparator; import java.util.Date; +import java.util.HashSet; import java.util.Locale; import java.util.Optional; import java.util.ResourceBundle; import java.util.TimeZone; +import java.util.function.Function; +import java.util.function.Predicate; import java.util.regex.Matcher; import java.util.stream.Collector; import java.util.stream.Collectors; +import static java.util.Comparator.naturalOrder; +import static java.util.stream.Collectors.collectingAndThen; + /** * This class defines all the new static groovy methods which appear on normal * JDK classes inside the Groovy environment. Static methods are used with the @@ -286,4 +296,569 @@ public static long currentTimeSeconds(System self){ return Collectors.reducing((v1, v2) -> v2); } + /** + * Returns a {@link Collector} that calculates the count + * + * @return a {@link Collector} which implements the count operation + * @since 3.0.0 + */ + public static Collector count(Collectors self) { + return Collectors.counting(); + } + + /** + * Returns a {@link Collector} that calculates the distinct count + * + * @return a {@link Collector} which implements the distinct count operation + * @since 3.0.0 + */ + public static Collector countDistinct(Collectors self) { + return countDistinctBy(self, t -> t); + } + + /** + * Returns a {@link Collector} that calculates the distinct count by function + * + * @return a {@link Collector} which implements the distinct count operation by function + * @since 3.0.0 + */ + public static Collector countDistinctBy(Collectors self, Function function) { + return Collector.of( + () -> new HashSet(), + (s, v) -> s.add(function.apply(v)), + (s1, s2) -> { + s1.addAll(s2); + return s1; + }, + s -> (long) s.size(), + Collector.Characteristics.UNORDERED + ); + } + + /** + * Returns a {@link Collector} that calculates sum for any type of {@link Number}. + * + * @return a {@link Collector} which implements the sum operation + * @since 3.0.0 + */ + public static Collector> sum(Collectors self) { + return sumBy(self, t -> t); + } + + /** + * Returns a {@link Collector} that calculates sum for any type of {@link Number} by function + * + * @return a {@link Collector} which implements the sum operation by function + * @since 3.0.0 + */ + @SuppressWarnings({"rawtypes", "unchecked"}) + public static Collector> sumBy(Collectors self, Function function) { + return Collector.of(() -> (Sum[]) new Sum[1], + (s, v) -> { + if (s[0] == null) + s[0] = Sum.create(function.apply(v)); + else + s[0].add(function.apply(v)); + }, + (s1, s2) -> { + s1[0].add(s2[0]); + return s1; + }, + s -> s[0] == null ? Optional.empty() : Optional.of(s[0].result()), + Collector.Characteristics.UNORDERED + ); + } + + /** + * Returns a {@link Collector} that calculates avg for any type of {@link Number}. + * + * @return a {@link Collector} which implements the avg operation + * @since 3.0.0 + */ + public static Collector> avg(Collectors self) { + return avgBy(self, t -> t); + } + + /** + * Returns a {@link Collector} that calculates avg for any type of {@link Number} by function + * + * @return a {@link Collector} which implements the avg operation by function + * @since 3.0.0 + */ + @SuppressWarnings({"rawtypes", "unchecked"}) + public static Collector> avgBy(Collectors self, Function function) { + return Collector.of( + () -> (Sum[]) new Sum[1], + (s, v) -> { + if (s[0] == null) + s[0] = Sum.create(function.apply(v)); + else + s[0].add(function.apply(v)); + }, + (s1, s2) -> { + s1[0].add(s2[0]); + return s1; + }, + s -> s[0] == null ? Optional.empty() : Optional.of(s[0].avg()), + Collector.Characteristics.UNORDERED + ); + } + + /** + * Returns a {@link Collector} that calculates min + * + * @return a {@link Collector} which implements the min operation + * @since 3.0.0 + */ + public static > Collector> min(Collectors self) { + return minBy(self, naturalOrder(), t -> t); + } + + /** + * Returns a {@link Collector} that calculates min by the order specified by the comparator + * + * @return a {@link Collector} which implements the min operation + * @since 3.0.0 + */ + public static Collector> min(Collectors self, Comparator comparator) { + return minBy(self, comparator, t -> t); + } + + /** + * Returns a {@link Collector} that calculates min by the function and the order specified by the comparator + * + * @return a {@link Collector} which implements the min operation + * @since 3.0.0 + */ + public static Collector> minBy(Collectors self, Comparator comparator, Function function) { + return maxBy(self, comparator.reversed(), function); + } + + /** + * Returns a {@link Collector} that calculates max + * + * @return a {@link Collector} which implements the max operation + * @since 3.0.0 + */ + public static > Collector> max(Collectors self) { + return maxBy(self, naturalOrder(), t -> t); + } + + /** + * Returns a {@link Collector} that calculates max by the order specified by the comparator + * + * @return a {@link Collector} which implements the max operation + * @since 3.0.0 + */ + public static Collector> max(Collectors self, Comparator comparator) { + return maxBy(self, comparator, t -> t); + } + + /** + * Returns a {@link Collector} that calculates max by the function and the order specified by the comparator + * + * @return a {@link Collector} which implements the max operation + * @since 3.0.0 + */ + public static Collector> maxBy(Collectors self, Comparator comparator, Function function) { + class Accumulator { + T t; + U u; + } + + return Collector.of( + () -> new Accumulator(), + (a, t) -> { + U u = function.apply(t); + + if (a.u == null || comparator.compare(a.u, u) < 0) { + a.t = t; + a.u = u; + } + }, + (a1, a2) -> + a1.u == null + ? a2 + : a2.u == null + ? a1 + : comparator.compare(a1.u, a2.u) < 0 + ? a2 + : a1, + a -> Optional.ofNullable(a.t) + ); + } + + /** + * Returns a {@link Collector} that calculates whether all match + * + * @return a {@link Collector} which implements the all match operation + * @since 3.0.0 + */ + public static Collector allMatch(Collectors self) { + return allMatchBy(self, t -> t); + } + + /** + * Returns a {@link Collector} that calculates whether all match + * + * @return a {@link Collector} which implements the all match operation + * @since 3.0.0 + */ + public static Collector allMatchBy(Collectors self, Predicate predicate) { + return Collector.of( + () -> new Boolean[1], + (a, t) -> { + if (a[0] == null) + a[0] = predicate.test(t); + else + a[0] = a[0] && predicate.test(t); + }, + (a1, a2) -> { + a1[0] = a1[0] && a2[0]; + return a1; + }, + a -> a[0] == null || a[0], + Collector.Characteristics.UNORDERED + ); + } + + /** + * Returns a {@link Collector} that calculates whether any match + * + * @return a {@link Collector} which implements the any match operation + * @since 3.0.0 + */ + public static Collector anyMatch(Collectors self) { + return anyMatchBy(self, t -> t); + } + + /** + * Returns a {@link Collector} that calculates whether any match + * + * @return a {@link Collector} which implements the any match operation + * @since 3.0.0 + */ + public static Collector anyMatchBy(Collectors self, Predicate predicate) { + return collectingAndThen(noneMatchBy(self, predicate), t -> !t); + } + + /** + * Returns a {@link Collector} that calculates whether none match + * + * @return a {@link Collector} which implements the none match operation + * @since 3.0.0 + */ + public static Collector noneMatch(Collectors self) { + return noneMatchBy(self, t -> t); + } + + /** + * Returns a {@link Collector} that calculates whether none match + * + * @return a {@link Collector} which implements the none match operation + * @since 3.0.0 + */ + public static Collector noneMatchBy(Collectors self, Predicate predicate) { + return allMatchBy(self, predicate.negate()); + } + + + private static abstract class Sum { + long count; + + void add(Sum sum) { + add0(sum.result()); + count += sum.count; + } + + void add(N value) { + add0(value); + count += 1; + } + + void and(Sum sum) { + and0(sum.result()); + } + + void and(N value) { + and0(value); + } + + void or(Sum sum) { + or0(sum.result()); + } + + void or(N value) { + or0(value); + } + + abstract void add0(N value); + + abstract void and0(N value); + + abstract void or0(N value); + + abstract N result(); + + abstract N avg(); + + @SuppressWarnings({"unchecked", "rawtypes"}) + static Sum create(N value) { + Sum result; + + if (value instanceof Byte) + result = (Sum) new OfByte(); + else if (value instanceof Short) + result = (Sum) new OfShort(); + else if (value instanceof Integer) + result = (Sum) new OfInt(); + else if (value instanceof Long) + result = (Sum) new OfLong(); + else if (value instanceof Float) + result = (Sum) new OfFloat(); + else if (value instanceof Double) + result = (Sum) new OfDouble(); + else if (value instanceof BigInteger) + result = (Sum) new OfBigInteger(); + else if (value instanceof BigDecimal) + result = (Sum) new OfBigDecimal(); + else + throw new IllegalArgumentException("Cannot calculate sums for value : " + value); + + result.add(value); + return result; + } + + static class OfByte extends Sum { + byte result; + + @Override + void add0(Byte value) { + result += value; + } + + @Override + Byte result() { + return result; + } + + @Override + void and0(Byte value) { + result &= value; + } + + @Override + void or0(Byte value) { + result |= value; + } + + @Override + Byte avg() { + return (byte) (result / count); + } + } + + static class OfShort extends Sum { + short sum; + + @Override + void add0(Short value) { + sum += value; + } + + @Override + void and0(Short value) { + sum &= value; + } + + @Override + void or0(Short value) { + sum |= value; + } + + @Override + Short result() { + return sum; + } + + @Override + Short avg() { + return (short) (sum / count); + } + } + + static class OfInt extends Sum { + int sum; + + @Override + void add0(Integer value) { + sum += value; + } + + @Override + void and0(Integer value) { + sum &= value; + } + + @Override + void or0(Integer value) { + sum |= value; + } + + @Override + Integer result() { + return sum; + } + + @Override + Integer avg() { + return (int) (sum / count); + } + } + + static class OfLong extends Sum { + long sum; + + @Override + void add0(Long value) { + sum += value; + } + + @Override + void and0(Long value) { + sum &= value; + } + + @Override + void or0(Long value) { + sum |= value; + } + + @Override + Long result() { + return sum; + } + + @Override + Long avg() { + return sum / count; + } + } + + static class OfFloat extends Sum { + float sum; + + @Override + void add0(Float value) { + sum += value; + } + + @Override + void and0(Float value) { + throw new UnsupportedOperationException(); + } + + @Override + void or0(Float value) { + throw new UnsupportedOperationException(); + } + + @Override + Float result() { + return sum; + } + + @Override + Float avg() { + return sum / (float) count; + } + } + + static class OfDouble extends Sum { + double sum; + + @Override + void add0(Double value) { + sum += value; + } + + @Override + void and0(Double value) { + throw new UnsupportedOperationException(); + } + + @Override + void or0(Double value) { + throw new UnsupportedOperationException(); + } + + @Override + Double result() { + return sum; + } + + @Override + Double avg() { + return sum / (double) count; + } + } + + static class OfBigInteger extends Sum { + BigInteger sum = BigInteger.ZERO; + + @Override + void add0(BigInteger value) { + sum = sum.add(value); + } + + @Override + void and0(BigInteger value) { + throw new UnsupportedOperationException(); + } + + @Override + void or0(BigInteger value) { + throw new UnsupportedOperationException(); + } + + @Override + BigInteger result() { + return sum; + } + + @Override + BigInteger avg() { + return sum.divide(BigInteger.valueOf(count)); + } + } + + static class OfBigDecimal extends Sum { + BigDecimal sum = BigDecimal.ZERO; + + @Override + void add0(BigDecimal value) { + sum = sum.add(value); + } + + @Override + void and0(BigDecimal value) { + throw new UnsupportedOperationException(); + } + + @Override + void or0(BigDecimal value) { + throw new UnsupportedOperationException(); + } + + @Override + BigDecimal result() { + return sum; + } + + @Override + BigDecimal avg() { + return sum.divide(BigDecimal.valueOf(count), RoundingMode.HALF_EVEN); + } + } + } } diff --git a/src/test/org/codehaus/groovy/runtime/DefaultGroovyStaticMethodsTest.groovy b/src/test/org/codehaus/groovy/runtime/DefaultGroovyStaticMethodsTest.groovy index daedb2406db..9de2fc9e758 100644 --- a/src/test/org/codehaus/groovy/runtime/DefaultGroovyStaticMethodsTest.groovy +++ b/src/test/org/codehaus/groovy/runtime/DefaultGroovyStaticMethodsTest.groovy @@ -21,6 +21,9 @@ package org.codehaus.groovy.runtime import java.util.stream.Collectors import java.util.stream.Stream +import static groovy.lang.Tuple.collectors +import static groovy.lang.Tuple.tuple + /** * Tests for DefaultGroovyStaticMethods */ @@ -30,8 +33,8 @@ class DefaultGroovyStaticMethodsTest extends GroovyTestCase { long timeMillis = System.currentTimeMillis() long timeSeconds = System.currentTimeSeconds() long timeMillis2 = System.currentTimeMillis() - assert timeMillis/1000 as int <= timeSeconds - assert timeMillis2/1000 as int >= timeSeconds + assert timeMillis/1000 as long <= timeSeconds + assert timeMillis2/1000 as long >= timeSeconds } void testFirst() { @@ -43,10 +46,95 @@ class DefaultGroovyStaticMethodsTest extends GroovyTestCase { } void testFirstAndLast() { - Tuple2 firstAndLastTuple = + Tuple2 t = + Stream.of(2, 3, 6, 5) + .collect(collectors(Collectors.first(), Collectors.last())) + .map1(Optional::get).map2(Optional::get) + assert tuple(2, 5) == t + } + + void testCountDistinct() { + Tuple2 t = Stream.of(2 , 3, 4, 5, 6, 2, 3, 4, 5, 6) + .collect(collectors(Collectors.count(), Collectors.countDistinct())) + + assert tuple(10L, 5L) == t + } + + void testCountDistinctBy() { + Tuple2 t = Stream.of('a', 'ab', 'abc', 'a', 'ab', 'abc') + .collect(collectors(Collectors.count(), Collectors.countDistinctBy(String::length))) + + assert tuple(6L, 3L) == t + } + + void testSum() { + Tuple1 t = Stream.of(1, 2, 3) + .collect(collectors(Collectors.sum())) + .map1(Optional::get) + + assert tuple(6) == t + } + + void testSumBy() { + Tuple1 t = Stream.of('a', 'ab', 'abc', 'abcd') + .collect(collectors(Collectors.sumBy(String::length))) + .map1(Optional::get) + + assert tuple(10) == t + } + + void testAvg() { + Tuple1 t = Stream.of(1, 2, 3) + .collect(collectors(Collectors.avg())) + .map1(Optional::get) + + assert tuple(2) == t + } + + void testAvgBy() { + Tuple1 t = Stream.of('ab', 'abcd') + .collect(collectors(Collectors.avgBy(String::length))) + .map1(Optional::get) + + assert tuple(3) == t + } + + void testMinAndMax() { + Tuple2 t = Stream.of(2, 3, 6, 5) - .collect(Tuple.collectors(Collectors.first(), Collectors.last())) + .collect(collectors(Collectors.min(), Collectors.max())) + .map1(Optional::get).map2(Optional::get) + + assert tuple(2, 6) == t + + Tuple2 t2 = + Stream.of('ab', 'c', 'abc', 'efgh', 'de', 'fgh') + .collect(collectors(Collectors.min((o1, o2) -> o1.length() <=> o2.length()), Collectors.max((o1, o2) -> o1.length() <=> o2.length()))) + .map1(Optional::get).map2(Optional::get) + + assert tuple('c', 'efgh') == t2 + } + + void testMinByAndMaxBy() { + Tuple2 t = + Stream.of('ab', 'c', 'abc', 'efgh', 'de', 'fgh') + .collect(collectors(Collectors.minBy((o1, o2) -> o1 <=> o2, e -> e.length()), Collectors.maxBy((o1, o2) -> o1 <=> o2, e -> e.length()))) .map1(Optional::get).map2(Optional::get) - assert Tuple.tuple(2, 5) == firstAndLastTuple + + assert tuple('c', 'efgh') == t + } + + void testAllMatchAndNoneMatchAndAnyMatch() { + Tuple3 t = + Stream.of(true, false, true) + .collect(collectors(Collectors.allMatch(), Collectors.noneMatch(), Collectors.anyMatch())) + + assert tuple(false, false, true) == t + + Tuple3 t2 = + Stream.of(2, 4, 6, 8, 10) + .collect(collectors(Collectors.allMatchBy(e -> 0 == e % 2), Collectors.noneMatchBy(e -> 1 == e % 2), Collectors.anyMatchBy(e -> e > 10))) + + assert tuple(true, true, false) == t2 } }