Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@

public class AggregateAddCols implements FunctionDescriptor.SerializableFunction<Record, Record> {
final List<AggregateCall> aggregateCalls;
public AggregateAddCols(final List<AggregateCall> aggregateCalls){

public AggregateAddCols(final List<AggregateCall> aggregateCalls) {
this.aggregateCalls = aggregateCalls;
}

Expand All @@ -36,7 +36,7 @@ public Record apply(final Record record) {
final int l = record.size();
final int newRecordSize = l + aggregateCalls.size() + 1;
final Object[] resValues = new Object[newRecordSize];

for (int i = 0; i < l; i++) {
resValues[i] = record.getField(i);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package org.apache.wayang.api.sql.calcite.converter.functions;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.function.BiFunction;
Expand All @@ -39,7 +40,7 @@ public AggregateFunction(final List<AggregateCall> aggregateCalls) {
public Record apply(final Record record1, final Record record2) {
final int l = record1.size();
final Object[] resValues = new Object[l];
final boolean countDone = false;
boolean countDone = false;

for (int i = 0; i < l - aggregateCalls.size() - 1; i++) {
resValues[i] = record1.getField(i);
Expand All @@ -61,25 +62,33 @@ public Record apply(final Record record1, final Record record2) {
case MAX:
resValues[counter] = this.castAndMap(field1, field2, SqlFunctions::greatest, SqlFunctions::greatest,
SqlFunctions::greatest, SqlFunctions::greatest);
break;
case COUNT:
// since aggregates inject an extra column for counting before,
// see AggregateAddCols. the column we operate on are integer counts,
// which means we can eagerly get the fields as integers and simply sum
assert (field1 instanceof Integer && field2 instanceof Integer)
: "Expected to find integers for count but found: " + field1 + " and " + field2;
Object obj = Integer.class.cast(field1) + Integer.class.cast(field2);
resValues[counter] = obj;
final Object count = Integer.class.cast(field1) + Integer.class.cast(field2);
resValues[counter] = count;
break;
case AVG:
throw new UnsupportedOperationException("Averages not currently supported");
// resValues[counter] = this.castAndMap(field1, field2, null, null, null, null);
// break;
assert (field1 instanceof Integer && field2 instanceof Integer)
: "Expected to find integers for count but found: " + field1 + " and " + field2;
final Object avg = Integer.class.cast(field1) + Integer.class.cast(field2);

resValues[counter] = avg;

if (!countDone) {
resValues[l - 1] = record1.getInt(l - 1) + record2.getInt(l - 1);
countDone = true;
}
break;
default:
throw new IllegalStateException("Unsupported operation: " + aggregateCall.getAggregation().kind);
}
counter++;
}

return new Record(resValues);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@

package org.apache.wayang.api.sql.calcite.converter.functions;

import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.sql.SqlKind;
import org.apache.wayang.basic.data.Record;
import org.apache.wayang.core.function.FunctionDescriptor;

Expand All @@ -36,31 +40,22 @@ public AggregateGetResult(final List<AggregateCall> aggregateCalls, final Set<In

@Override
public Record apply(final Record record) {
final int l = record.size();
final int outputRecordSize = aggregateCallList.size() + groupingfields.size();
final Object[] resValues = new Object[outputRecordSize];

int i = 0;
int j = 0;
for (i = 0; j < groupingfields.size(); i++) {
if (groupingfields.contains(i)) {
resValues[j] = record.getField(i);
j++;
}
}

i = l - aggregateCallList.size() - 1;
for (final AggregateCall aggregateCall : aggregateCallList) {
final String name = aggregateCall.getAggregation().getName();
if (name.equals("AVG")) {
resValues[j] = record.getDouble(i) / record.getDouble(l - 1);
} else {
resValues[j] = record.getField(i);
}
j++;
i++;
}

return new Record(resValues);
final int recordSize = record.size();
final int aggregateCallOffset = recordSize - aggregateCallList.size() - 1;

final Object[] fields = groupingfields.stream()
.map(record::getField)
.toArray();

final Object[] aggregateCallFields = IntStream.range(0, aggregateCallList.size())
.mapToObj(i -> aggregateCallList.get(i).getAggregation().getKind().equals(SqlKind.AVG)
? record.getDouble(i + aggregateCallOffset) / record.getDouble(recordSize - 1)
: record.getField(i + aggregateCallOffset))
.toArray();

final Object[] combinedFields = Stream.concat(Arrays.stream(fields), Arrays.stream(aggregateCallFields))
.toArray();

return new Record(combinedFields);
}
}
Loading
Loading