Skip to content

Commit

Permalink
Fixes #2182: added apoc.agg.rollup procedure (#4064)
Browse files Browse the repository at this point in the history
* Fixes #2182: added apoc.agg.rollup procedure

* updated extended.txt
  • Loading branch information
vga91 committed May 15, 2024
1 parent 081514b commit 4dcd9ca
Show file tree
Hide file tree
Showing 13 changed files with 1,260 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ apoc.agg.multiStats(value :: NODE | RELATIONSHIP, keys :: LIST OF STRING) :: (MA
|===


[[usage-apoc.data.email]]
[[usage-apoc.agg.multiStats]]
== Usage Examples

Given this dataset:
Expand Down

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions docs/asciidoc/modules/ROOT/pages/overview/apoc.agg/index.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,13 @@ Returns index of the `element` that match the given `predicate`

apoc.agg.multiStats(nodeOrRel, keys) - Return a multi-dimensional aggregation
|label:function[]


|xref::overview/apoc.agg/apoc.agg.rollup.adoc[apoc.agg.rollup icon:book[]]

apoc.agg.rollup(<ANY>, [groupKeys], [aggKeys])

Emulate an Oracle/Mysql rollup command: `ROLLUP groupKeys, SUM(aggKey1), AVG(aggKey1), COUNT(aggKey1), SUM(aggKey2), AVG(aggKey2), ... `
|label:function[]
|===

Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ Returns index of the `element` that match the given `predicate`

apoc.agg.multiStats(nodeOrRel, keys) - Return a multi-dimensional aggregation
|label:procedure[]

|xref::overview/apoc.agg/apoc.agg.rollup.adoc[apoc.agg.rollup icon:book[]]

apoc.agg.rollup(<ANY>, [groupKeys], [aggKeys])

Emulate an Oracle/Mysql rollup command: `ROLLUP groupKeys, SUM(aggKey1), AVG(aggKey1), COUNT(aggKey1), SUM(aggKey2), AVG(aggKey2), ... `
|label:procedure[]
|===


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ This file is generated by DocsTest, so don't change it!
*** xref::overview/apoc.agg/apoc.agg.row.adoc[]
*** xref::overview/apoc.agg/apoc.agg.position.adoc[]
*** xref::overview/apoc.agg/apoc.agg.multiStats.adoc[]
*** xref::overview/apoc.agg/apoc.agg.rollup.adoc[]
** xref::overview/apoc.bolt/index.adoc[]
*** xref::overview/apoc.bolt/apoc.bolt.execute.adoc[]
*** xref::overview/apoc.bolt/apoc.bolt.load.adoc[]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
The function supports the following properties in its optional `config` map parameter:

.Config parameters
[opts=header, cols="1,1,1,3"]
|===
| name | type | default | description
| cube | boolean | false| to emulate the https://docs.oracle.com/cd/F49540_01/DOC/server.815/a68003/rollup_c.htm#32311[CUBE] clause,
instead of the https://docs.oracle.com/cd/F49540_01/DOC/server.815/a68003/rollup_c.htm#32084[ROLLUP] one.
|===
51 changes: 51 additions & 0 deletions extended/src/main/java/apoc/agg/AggregationUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package apoc.agg;

import java.util.Map;

/**
 * Shared helpers that keep running COUNT / SUM / AVG aggregates inside a mutable map.
 * The caller chooses the map keys under which each aggregate is stored.
 */
public class AggregationUtil {

    /**
     * Registers one more observed property value in {@code partialResult}.
     * <p>
     * The count stored under {@code countKey} is always incremented. The sum and the average
     * (stored under {@code sumKey} and {@code avgKey}) are only touched when {@code property}
     * is a {@link Number}.
     *
     * @param partialResult mutable map holding the running aggregates; updated in place
     * @param property      the value to aggregate; non-numeric values only affect the count
     * @param countKey      map key of the running count
     * @param sumKey        map key of the running sum
     * @param avgKey        map key of the running average
     */
    public static void updateAggregationValues(Map<String, Number> partialResult, Object property, String countKey, String sumKey, String avgKey) {
        Number count = updateCountValue(partialResult, countKey);

        updateSumAndAvgValues(partialResult, property, count.doubleValue(), sumKey, avgKey);
    }

    /**
     * Increments the counter stored under {@code countKey}, starting from 1 when it is absent.
     *
     * @return the updated count (always a {@code Long})
     */
    private static Number updateCountValue(Map<String, Number> partialResult, String countKey) {
        return partialResult.compute(countKey,
                (key, current) -> current == null ? 1L : current.longValue() + 1);
    }

    /**
     * Adds {@code property} to the running sum and recomputes the average as {@code sum / count}.
     * Does nothing when {@code property} is not a {@link Number}.
     * <p>
     * The sum stays a {@code Long} as long as every added value is a {@code Long};
     * as soon as any other numeric type is involved it becomes a {@code Double}.
     *
     * @param count the already-updated count, which includes non-numeric values seen so far
     */
    private static void updateSumAndAvgValues(Map<String, Number> partialResult, Object property, double count, String sumKey, String avgKey) {
        if (!(property instanceof Number numberProp)) {
            return;
        }

        Number sum = partialResult.compute(sumKey, (key, current) -> {
            if (current == null) {
                if (numberProp instanceof Long longProp) {
                    return longProp;
                }
                return numberProp.doubleValue();
            }
            if (current instanceof Long runningLong && numberProp instanceof Long addedLong) {
                return runningLong + addedLong;
            }
            return current.doubleValue() + numberProp.doubleValue();
        });

        // Always derive the average from sum and count. Previously the very first numeric value
        // was stored verbatim, which disagreed with sum / count whenever non-numeric values had
        // already been counted, while every later update did use sum / count.
        partialResult.put(avgKey, sum.doubleValue() / count);
    }
}
21 changes: 6 additions & 15 deletions extended/src/main/java/apoc/agg/MultiStats.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import java.util.Map;
import java.util.Objects;

import static apoc.agg.AggregationUtil.updateAggregationValues;

@Extended
public class MultiStats {

Expand Down Expand Up @@ -44,22 +46,11 @@ public void aggregate(

Map<String, Number> propMap = Objects.requireNonNullElseGet(propVal, HashMap::new);

Number count = propMap.compute("count",
((subKey, subVal) -> subVal == null ? 1 : subVal.longValue() + 1) );
String countKey = "count";
String sumKey = "sum";
String avgKey = "avg";

if (property instanceof Number numberProp) {
Number sum = propMap.compute("sum",
((subKey, subVal) -> {
if (subVal == null) return numberProp;
if (subVal instanceof Long long1 && numberProp instanceof Long long2) {
return long1 + long2;
}
return subVal.doubleValue() + numberProp.doubleValue();
}));

propMap.compute("avg",
((subKey, subVal) -> subVal == null ? numberProp.doubleValue() : sum.doubleValue() / count.doubleValue() ));
}
updateAggregationValues(propMap, property, countKey, sumKey, avgKey);

return propMap;
});
Expand Down
194 changes: 194 additions & 0 deletions extended/src/main/java/apoc/agg/Rollup.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
package apoc.agg;

import apoc.Extended;
import apoc.util.Util;
import org.apache.commons.collections4.ListUtils;
import org.neo4j.graphdb.Entity;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.UserAggregationFunction;
import org.neo4j.procedure.UserAggregationResult;
import org.neo4j.procedure.UserAggregationUpdate;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static apoc.agg.AggregationUtil.updateAggregationValues;


/**
 * Registers the {@code apoc.agg.rollup} user-defined aggregation function, which groups
 * entities by a list of property keys and keeps COUNT / SUM / AVG aggregates for every
 * grouping level, in the spirit of the SQL ROLLUP (or, optionally, CUBE) clause.
 */
@Extended
public class Rollup {
    /** Placeholder stored in a grouping key for every column that has been rolled up. */
    public static final String NULL_ROLLUP = "[NULL]";

    @UserAggregationFunction("apoc.agg.rollup")
    @Description("apoc.agg.rollup(<ANY>, [groupKeys], [aggKeys])" +
            "\n Emulate an Oracle/Mysql rollup command: `ROLLUP groupKeys, SUM(aggKey1), AVG(aggKey1), COUNT(aggKey1), SUM(aggKey2), AVG(aggKey2), ... `")
    public RollupFunction rollup() {
        return new RollupFunction();
    }

    /**
     * Aggregation state of a single {@code apoc.agg.rollup} invocation.
     */
    public static class RollupFunction {

        /**
         * Generates every variant of {@code elements} in which each position holds either the
         * original element or the {@link #NULL_ROLLUP} placeholder (2^n lists for n elements).
         * For every position the variant with the original element is emitted first.
         *
         * @param elements the list to expand; in this class it is always a list of strings
         * @return all combinations, each of the same length as {@code elements}
         */
        public static <T> List<List<T>> generateCombinationsWithPlaceholder(List<T> elements) {
            List<List<T>> result = new ArrayList<>();
            generateCombinationsWithPlaceholder(elements, 0, new ArrayList<>(), result);
            return result;
        }

        /**
         * Recursive helper: decides the content of position {@code index} and recurses on the
         * following one; a completed combination is copied into {@code result}.
         */
        // The placeholder is a String, so the cast is only sound when T is String (or a supertype);
        // this is the case for the only caller in this class.
        @SuppressWarnings("unchecked")
        private static <T> void generateCombinationsWithPlaceholder(List<T> elements, int index, List<T> current, List<List<T>> result) {
            if (index == elements.size()) {
                result.add(new ArrayList<>(current));
                return;
            }

            // variant 1: keep the real element at this position
            current.add(elements.get(index));
            generateCombinationsWithPlaceholder(elements, index + 1, current, result);
            current.remove(current.size() - 1);

            // variant 2: replace the element at this position with the NULL_ROLLUP placeholder
            current.add((T) NULL_ROLLUP);
            generateCombinationsWithPlaceholder(elements, index + 1, current, result);
            current.remove(current.size() - 1);
        }

        /** Aggregates per grouping key; a grouping key is the list of group values / placeholders. */
        private final Map<List<Object>, Map<String, Number>> rolledUpData = new HashMap<>();

        /** Group keys received by the last processed row; used to label and sort the output. */
        private List<String> groupKeysRes = null;

        /**
         * Consumes one row.
         * <p>
         * Rows with a {@code null} value, or with null / empty {@code groupKeys}, are skipped.
         *
         * @param value     the entity (node or relationship) to aggregate
         * @param groupKeys property keys to group by
         * @param aggKeys   property keys for which COUNT / SUM / AVG are maintained;
         *                  {@code null} is treated as an empty list
         * @param config    optional settings; {@code cube: true} emulates CUBE instead of ROLLUP
         * @throws ClassCastException if {@code value} is not an {@link Entity}
         */
        @UserAggregationUpdate
        public void aggregate(
                @Name("value") Object value,
                @Name(value = "groupKeys") List<String> groupKeys,
                @Name(value = "aggKeys") List<String> aggKeys,
                @Name(value = "config", defaultValue = "{}") Map<String, Object> config) {

            // nothing to aggregate for a null row (previously this ended in a NullPointerException)
            if (value == null) {
                return;
            }

            boolean cube = config != null && Util.toBoolean(config.get("cube"));

            Entity entity = (Entity) value;

            if (groupKeys == null || groupKeys.isEmpty()) {
                return;
            }
            groupKeysRes = groupKeys;

            List<String> keysToAggregate = aggKeys == null ? List.of() : aggKeys;

            /*
             if true:
                emulate the CUBE command: https://docs.oracle.com/cd/F49540_01/DOC/server.815/a68003/rollup_c.htm#32311
             else:
                emulate the ROLLUP command: https://docs.oracle.com/cd/F49540_01/DOC/server.815/a68003/rollup_c.htm#32084
             */
            if (cube) {
                aggregateCube(entity, groupKeys, keysToAggregate);
            } else {
                aggregateRollup(entity, groupKeys, keysToAggregate);
            }
        }

        /**
         * CUBE: updates the aggregates of every combination of "real group value" and
         * {@link #NULL_ROLLUP} placeholder.
         */
        private void aggregateCube(Entity entity, List<String> groupKeys, List<String> aggKeys) {
            for (List<String> groupingSet : generateCombinationsWithPlaceholder(groupKeys)) {
                List<Object> partialKey = new ArrayList<>(groupingSet.size());
                for (String column : groupingSet) {
                    // A placeholder column is looked up as a property name as well and falls back
                    // to NULL_ROLLUP; a missing real property also yields NULL_ROLLUP here
                    // (whereas the ROLLUP branch yields null for a missing property).
                    partialKey.add(entity.getProperty(column, NULL_ROLLUP));
                }
                rollupAggregationProperties(aggKeys, entity, partialKey);
            }
        }

        /**
         * ROLLUP: updates the aggregates of every prefix of the group values, from the grand
         * total (all placeholders) down to the fully specified group.
         */
        private void aggregateRollup(Entity entity, List<String> groupKeys, List<String> aggKeys) {
            List<Object> groupValues = groupKeys.stream()
                    .map(key -> entity.getProperty(key, null))
                    .toList();

            int size = groupValues.size();
            for (int kept = 0; kept <= size; kept++) {
                // add NULL_ROLLUP to remaining elements,
                // e.g. `[<firstGroupKey>, `NULL_ROLLUP`, `NULL_ROLLUP`]`
                List<Object> partialKey = ListUtils.union(groupValues.subList(0, kept), Collections.nCopies(size - kept, NULL_ROLLUP));
                rollupAggregationProperties(aggKeys, entity, partialKey);
            }
        }

        /**
         * Updates COUNT / SUM / AVG of every {@code aggKeys} property present on {@code entity}
         * inside the bucket identified by {@code partialKey}; the bucket is created on first use.
         */
        private void rollupAggregationProperties(List<String> aggKeys, Entity entity, List<Object> partialKey) {
            Map<String, Number> partialResult = rolledUpData.computeIfAbsent(partialKey, ignored -> new HashMap<>());
            for (String aggKey : aggKeys) {
                if (!entity.hasProperty(aggKey)) {
                    continue;
                }

                Object property = entity.getProperty(aggKey);

                String countKey = "COUNT(%s)".formatted(aggKey);
                String sumKey = "SUM(%s)".formatted(aggKey);
                String avgKey = "AVG(%s)".formatted(aggKey);

                updateAggregationValues(partialResult, property, countKey, sumKey, avgKey);
            }
        }

        /**
         * Transform a Map.of(ListGroupKeys, MapOfAggResults) in a List of Map.of(AggResult + ListGroupKeyToMap)
         *
         * @return one map per grouping key, holding the group values (keyed by group key name)
         *         plus the aggregates, sorted by the group values in group key order
         */
        @UserAggregationResult
        public Object result() {
            List<Map<String, Object>> rows = rolledUpData.entrySet().stream()
                    .map(entry -> toRow(entry.getKey(), entry.getValue()))
                    .sorted(this::compareRows)
                    .toList();

            return rows;
        }

        /** Merges the group values (labelled with their group key names) and the aggregates into one row. */
        private Map<String, Object> toRow(List<Object> groupValues, Map<String, Number> aggregates) {
            Map<String, Object> row = new HashMap<>();
            for (int i = 0; i < groupKeysRes.size(); i++) {
                row.put(groupKeysRes.get(i), groupValues.get(i));
            }
            row.putAll(aggregates);
            return row;
        }

        /** Orders two rows by their group values, comparing column by column in group key order. */
        private int compareRows(Map<String, Object> row1, Map<String, Object> row2) {
            for (String key : groupKeysRes) {
                int cmp = compareValues(row1.get(key), row2.get(key));
                if (cmp != 0) {
                    return cmp;
                }
            }
            return 0;
        }

        /**
         * We use this instead of e.g. apoc.coll.sortMulti
         * since we have to handle the NULL_ROLLUP values as well
         * <p>
         * Ordering: regular values first, then {@link #NULL_ROLLUP}, then {@code null}.
         * Values that cannot be compared with each other are treated as equal.
         */
        // The raw Comparable cast may fail for mixed types; that failure is handled in the catch below.
        @SuppressWarnings("unchecked")
        private static int compareValues(Object value1, Object value2) {
            if (value1 == null && value2 == null) {
                return 0;
            } else if (value1 == null) {
                return 1;
            } else if (value2 == null) {
                return -1;
            } else if (NULL_ROLLUP.equals(value1) && NULL_ROLLUP.equals(value2)) {
                return 0;
            } else if (NULL_ROLLUP.equals(value1)) {
                return 1;
            } else if (NULL_ROLLUP.equals(value2)) {
                return -1;
            } else if (value1 instanceof Comparable && value2 instanceof Comparable) {
                try {
                    return ((Comparable<Object>) value1).compareTo(value2);
                } catch (Exception ignored) {
                    // deliberate best effort: e.g. different data types, like int and strings
                    return 0;
                }
            } else {
                return 0;
            }
        }
    }
}
1 change: 1 addition & 0 deletions extended/src/main/resources/extended.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
apoc.agg.position
apoc.agg.row
apoc.agg.multiStats
apoc.agg.rollup
apoc.algo.aStarWithPoint
apoc.bolt.execute
apoc.bolt.load
Expand Down

0 comments on commit 4dcd9ca

Please sign in to comment.