-
Notifications
You must be signed in to change notification settings - Fork 4.2k
/
BeamAggregationRel.java
394 lines (359 loc) · 16.7 KB
/
BeamAggregationRel.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.sql.impl.rel;
import static java.util.stream.Collectors.toList;
import static org.apache.beam.sdk.values.PCollection.IsBounded.BOUNDED;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import java.io.Serializable;
import java.util.List;
import org.apache.beam.sdk.extensions.sql.impl.BeamSqlPipelineOptions;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRelMetadataQuery;
import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.AggregationCombineFnAdapter;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.Field;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.transforms.Group;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.WithTimestamps;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.transforms.windowing.Sessions;
import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelOptCluster;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelOptPlanner;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelTraitSet;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.RelWriter;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Aggregate;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.AggregateCall;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.util.ImmutableBitSet;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Duration;
/**
 * {@link BeamRelNode} to replace a {@link Aggregate} node.
 *
 * <p>Only simple aggregations (no grouping sets / CUBE / ROLLUP) are supported. Windowed
 * aggregations carry an optional {@link WindowFn} plus the index of the group-by field that holds
 * the event timestamp on input and the window start on output.
 */
@SuppressWarnings({
  "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
public class BeamAggregationRel extends Aggregate implements BeamRelNode {
  /** Window function for windowed aggregations, or {@code null} when the query is unwindowed. */
  private final @Nullable WindowFn<Row, IntervalWindow> windowFn;

  /**
   * Index of the group-by field used for windowing: read as the event timestamp from input rows and
   * replaced by the window start in output rows. NOTE(review): presumably -1 when {@link #windowFn}
   * is null (see the {@code != -1} check in {@code mergeRecord}) — confirm with callers.
   */
  private final int windowFieldIndex;

  public BeamAggregationRel(
      RelOptCluster cluster,
      RelTraitSet traits,
      RelNode child,
      ImmutableBitSet groupSet,
      List<ImmutableBitSet> groupSets,
      List<AggregateCall> aggCalls,
      @Nullable WindowFn<Row, IntervalWindow> windowFn,
      int windowFieldIndex) {
    super(cluster, traits, child, groupSet, groupSets, aggCalls);
    // Inside this class "Group" is Calcite's Aggregate.Group, not Beam's schema Group transform.
    assert getGroupType() == Group.SIMPLE;
    this.windowFn = windowFn;
    this.windowFieldIndex = windowFieldIndex;
  }

  /**
   * Cost is the (window-adjusted) input volume scaled by the number of aggregate calls: a base of 1
   * plus 0.125 per call, plus an extra 0.0125 for every {@code SUM}.
   */
  @Override
  public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, BeamRelMetadataQuery mq) {
    NodeStats inputStat = BeamSqlRelUtils.getNodeStats(this.input, mq);
    inputStat = computeWindowingCostEffect(inputStat);

    // Aggregates with more aggregate functions cost a bit more
    float multiplier = 1f + (float) aggCalls.size() * 0.125f;
    for (AggregateCall aggCall : aggCalls) {
      if (aggCall.getAggregation().getName().equals("SUM")) {
        // Pretend that SUM costs a little bit more than $SUM0,
        // to make things deterministic.
        multiplier += 0.0125f;
      }
    }

    return BeamCostModel.FACTORY.makeCost(
        inputStat.getRowCount() * multiplier, inputStat.getRate() * multiplier);
  }

  /**
   * Estimates output statistics.
   *
   * <p>With no (non-window) grouping columns, at most one row is produced (per window); otherwise
   * each grouping column keeps more rows, approaching the input size as columns are added.
   */
  @Override
  public NodeStats estimateNodeStats(BeamRelMetadataQuery mq) {
    NodeStats inputEstimate = BeamSqlRelUtils.getNodeStats(this.input, mq);
    inputEstimate = computeWindowingCostEffect(inputEstimate);

    // groupCount shows how many columns do we have in group by. One of them might be the windowing.
    int groupCount = groupSet.cardinality() - (windowFn == null ? 0 : 1);
    // This is similar to what Calcite does. If groupCount is zero then we have only one value
    // per window for unbounded and we have only one value for bounded. e.g select count(*) from A
    // If group count is non-zero then the more columns we include in the group by, the more rows
    // will be preserved.
    return (groupCount == 0)
        ? NodeStats.create(
            Math.min(inputEstimate.getRowCount(), 1d),
            inputEstimate.getRate() / inputEstimate.getWindow(),
            1d)
        : inputEstimate.multiply(1.0 - Math.pow(.5, groupCount));
  }

  /**
   * Adjusts input statistics for windowing. Unwindowed input is returned unchanged; windowed input
   * gets {@code BeamIOSourceRel.CONSTANT_WINDOW_SIZE} as its window, and sliding windows
   * additionally scale row count and rate by the number of windows each element lands in.
   */
  private NodeStats computeWindowingCostEffect(NodeStats inputStat) {
    if (windowFn == null) {
      return inputStat;
    }
    WindowFn w = windowFn;
    double multiplicationFactor = 1;
    // If the window is SlidingWindow, the number of tuples will increase. (Because, some of the
    // tuples repeat in multiple windows).
    if (w instanceof SlidingWindows) {
      SlidingWindows sliding = (SlidingWindows) w;
      // Compare in milliseconds: getStandardSeconds() truncates, so a sub-second period used to
      // divide by zero (Infinity/NaN cost) and non-integral-second durations skewed the ratio.
      long periodMillis = sliding.getPeriod().getMillis();
      if (periodMillis > 0) {
        multiplicationFactor = ((double) sliding.getSize().getMillis()) / periodMillis;
      }
    }

    return NodeStats.create(
        inputStat.getRowCount() * multiplicationFactor,
        inputStat.getRate() * multiplicationFactor,
        BeamIOSourceRel.CONSTANT_WINDOW_SIZE);
  }

  /**
   * Adds a {@code window} item describing the window function, e.g. {@code
   * FixedWindows($idx, size, offset)}, to the plan explanation.
   *
   * @throws UnsupportedOperationException for window functions other than fixed, sliding or
   *     session windows
   */
  @Override
  public RelWriter explainTerms(RelWriter pw) {
    super.explainTerms(pw);
    if (this.windowFn != null) {
      WindowFn windowFn = this.windowFn;
      String window = windowFn.getClass().getSimpleName() + "($" + String.valueOf(windowFieldIndex);
      if (windowFn instanceof FixedWindows) {
        FixedWindows fn = (FixedWindows) windowFn;
        window = window + ", " + fn.getSize().toString() + ", " + fn.getOffset().toString();
      } else if (windowFn instanceof SlidingWindows) {
        SlidingWindows fn = (SlidingWindows) windowFn;
        window =
            window
                + ", "
                + fn.getPeriod().toString()
                + ", "
                + fn.getSize().toString()
                + ", "
                + fn.getOffset().toString();
      } else if (windowFn instanceof Sessions) {
        Sessions fn = (Sessions) windowFn;
        window = window + ", " + fn.getGapDuration().toString();
      } else {
        throw new UnsupportedOperationException(
            "Unknown window function " + windowFn.getClass().getSimpleName());
      }
      window = window + ")";
      pw.item("window", window);
    }
    return pw;
  }

  /** Builds the Beam transform: one {@link FieldAggregation} per named aggregate call. */
  @Override
  public PTransform<PCollectionList<Row>, PCollection<Row>> buildPTransform() {
    Schema outputSchema = CalciteUtils.toSchema(getRowType());
    List<FieldAggregation> aggregationAdapters =
        getNamedAggCalls().stream()
            .map(aggCall -> new FieldAggregation(aggCall.getKey(), aggCall.getValue()))
            .collect(toList());

    return new Transform(
        windowFn, windowFieldIndex, getGroupSet(), aggregationAdapters, outputSchema);
  }

  /** Serializable description of one aggregate call: input field ids, combiner and output field. */
  private static class FieldAggregation implements Serializable {
    /** Ids of the input fields the aggregate reads (empty for e.g. COUNT(*)). */
    final List<Integer> inputs;

    /** Combiner that performs the aggregation. */
    final CombineFn combineFn;

    /** Output field named after the call's alias, typed from the call's return type. */
    final Field outputField;

    FieldAggregation(AggregateCall call, String alias) {
      inputs = call.getArgList();
      outputField = CalciteUtils.toField(alias, call.getType());
      combineFn =
          AggregationCombineFnAdapter.createCombineFn(
              call, outputField, call.getAggregation().getName());
    }
  }

  /** Transform applying optional windowing followed by a keyed or global schema aggregation. */
  private static class Transform extends PTransform<PCollectionList<Row>, PCollection<Row>> {

    /** Group-by field ids excluding the window field. */
    private final List<Integer> keyFieldsIds;

    private final Schema outputSchema;
    private final @Nullable WindowFn<Row, IntervalWindow> windowFn;
    private final int windowFieldIndex;
    private final List<FieldAggregation> fieldAggregations;

    /** Number of group-by fields, including the window field when present. */
    private final int groupSetCount;

    /**
     * Set by {@link #createCombiner} when no aggregate calls exist (SELECT DISTINCT plan), meaning
     * the combined value row carries no output fields and must be dropped by {@code mergeRecord}.
     */
    private boolean ignoreValues;

    private Transform(
        @Nullable WindowFn<Row, IntervalWindow> windowFn,
        int windowFieldIndex,
        ImmutableBitSet groupSet,
        List<FieldAggregation> fieldAggregations,
        Schema outputSchema) {
      this.windowFn = windowFn;
      this.windowFieldIndex = windowFieldIndex;
      this.fieldAggregations = fieldAggregations;
      this.outputSchema = outputSchema;
      this.groupSetCount = groupSet.asList().size();
      this.ignoreValues = false;
      this.keyFieldsIds =
          groupSet.asList().stream().filter(i -> i != windowFieldIndex).collect(toList());
    }

    /**
     * Windows the single input (if a window function is set), validates the windowing, then
     * aggregates by key fields (merging key, values and window start into output rows) or, with
     * no group-by fields, globally.
     */
    @Override
    public PCollection<Row> expand(PCollectionList<Row> pinput) {
      checkArgument(
          pinput.size() == 1,
          "Wrong number of inputs for %s: %s",
          BeamAggregationRel.class.getSimpleName(),
          pinput);
      PCollection<Row> upstream = pinput.get(0);
      PCollection<Row> windowedStream = upstream;
      if (windowFn != null) {
        windowedStream = assignTimestampsAndWindow(upstream);
      }

      validateWindowIsSupported(windowedStream);
      // Check if have fields to be grouped
      if (groupSetCount > 0) {
        org.apache.beam.sdk.schemas.transforms.Group.AggregateCombiner<Row> byFields =
            org.apache.beam.sdk.schemas.transforms.Group.byFieldIds(keyFieldsIds);
        // createCombiner may set ignoreValues, so it must run before mergeRecord reads it.
        PTransform<PCollection<Row>, PCollection<Row>> combiner = createCombiner(byFields);
        boolean verifyRowValues =
            pinput.getPipeline().getOptions().as(BeamSqlPipelineOptions.class).getVerifyRowValues();
        return windowedStream
            .apply(combiner)
            .apply(
                "mergeRecord",
                ParDo.of(
                    mergeRecord(outputSchema, windowFieldIndex, ignoreValues, verifyRowValues)))
            .setRowSchema(outputSchema);
      }
      org.apache.beam.sdk.schemas.transforms.Group.AggregateCombiner<Row> globally =
          (org.apache.beam.sdk.schemas.transforms.Group.AggregateCombiner<Row>)
              org.apache.beam.sdk.schemas.transforms.Group.CombineFieldsGlobally.create();
      PTransform<PCollection<Row>, PCollection<Row>> combiner = createCombiner(globally);
      return windowedStream.apply(combiner).setRowSchema(outputSchema);
    }

    /**
     * Chains one aggregation per {@link FieldAggregation} onto {@code initialCombiner}. When there
     * are none, installs a constant combiner and sets {@link #ignoreValues}.
     */
    private PTransform<PCollection<Row>, PCollection<Row>> createCombiner(
        org.apache.beam.sdk.schemas.transforms.Group.AggregateCombiner<Row> initialCombiner) {

      org.apache.beam.sdk.schemas.transforms.Group.AggregateCombiner combined = null;
      for (FieldAggregation fieldAggregation : fieldAggregations) {
        List<Integer> inputs = fieldAggregation.inputs;
        CombineFn combineFn = fieldAggregation.combineFn;
        if (inputs.size() == 1) {
          // Combining over a single field, so extract just that field.
          combined =
              (combined == null)
                  ? initialCombiner.aggregateField(
                      inputs.get(0), combineFn, fieldAggregation.outputField)
                  : combined.aggregateField(inputs.get(0), combineFn, fieldAggregation.outputField);
        } else {
          // In this path we extract a Row (an empty row if inputs.isEmpty).
          combined =
              (combined == null)
                  ? initialCombiner.aggregateFieldsById(
                      inputs, combineFn, fieldAggregation.outputField)
                  : combined.aggregateFieldsById(inputs, combineFn, fieldAggregation.outputField);
        }
      }

      PTransform<PCollection<Row>, PCollection<Row>> combiner = combined;
      if (combiner == null) {
        // If no field aggregations were specified, we run a constant combiner that always returns
        // a single empty row for each key. This is used by the SELECT DISTINCT query plan - in this
        // case a group by is generated to determine unique keys, and a constant null combiner is
        // used.
        combiner =
            initialCombiner.aggregateField(
                "*",
                AggregationCombineFnAdapter.createConstantCombineFn(),
                Field.of(
                    "e",
                    FieldType.row(AggregationCombineFnAdapter.EMPTY_SCHEMA).withNullable(true)));
        ignoreValues = true;
      }
      return combiner;
    }

    /** Extract timestamps from the windowFieldIndex, then window into windowFns. */
    private PCollection<Row> assignTimestampsAndWindow(PCollection<Row> upstream) {
      PCollection<Row> windowedStream;
      windowedStream =
          upstream
              .apply(
                  "assignEventTimestamp",
                  WithTimestamps.<Row>of(row -> row.getDateTime(windowFieldIndex).toInstant())
                      // Allow any skew: event timestamps may move backwards relative to input.
                      .withAllowedTimestampSkew(Duration.millis(Long.MAX_VALUE)))
              .setCoder(upstream.getCoder())
              .apply(Window.into(windowFn));
      return windowedStream;
    }

    /**
     * Performs the same check as {@link GroupByKey}, provides more context in exception.
     *
     * <p>Verifies that the input PCollection is bounded, or that there is windowing/triggering
     * being used. Without this, the watermark (at end of global window) will never be reached.
     *
     * <p>Throws {@link UnsupportedOperationException} if validation fails.
     */
    private void validateWindowIsSupported(PCollection<Row> upstream) {
      WindowingStrategy<?, ?> windowingStrategy = upstream.getWindowingStrategy();
      if (windowingStrategy.getWindowFn() instanceof GlobalWindows
          && windowingStrategy.getTrigger() instanceof DefaultTrigger
          && upstream.isBounded() != BOUNDED) {

        throw new UnsupportedOperationException(
            "Please explicitly specify windowing in SQL query using HOP/TUMBLE/SESSION functions "
                + "(default trigger will be used in this case). "
                + "Unbounded input with global windowing and default trigger is not supported "
                + "in Beam SQL aggregations. "
                + "See GroupByKey section in Beam Programming Guide");
      }
    }

    /**
     * Returns a DoFn that flattens a (key row, value row) pair into one output row.
     *
     * <p>Output values are the key row's values, then the value row's values (skipped when {@code
     * ignoreValues}), with the window start inserted at {@code windowStartFieldIndex} unless it is
     * -1. With {@code verifyRowValues} the row is built through the validating builder; otherwise
     * values are attached without verification.
     */
    static DoFn<Row, Row> mergeRecord(
        Schema outputSchema,
        int windowStartFieldIndex,
        boolean ignoreValues,
        boolean verifyRowValues) {
      return new DoFn<Row, Row>() {
        @ProcessElement
        public void processElement(
            @Element Row kvRow, BoundedWindow window, OutputReceiver<Row> o) {
          int capacity =
              kvRow.getRow(0).getFieldCount()
                  + (!ignoreValues ? kvRow.getRow(1).getFieldCount() : 0);
          List<Object> fieldValues = Lists.newArrayListWithCapacity(capacity);

          fieldValues.addAll(kvRow.getRow(0).getValues());
          if (!ignoreValues) {
            fieldValues.addAll(kvRow.getRow(1).getValues());
          }

          if (windowStartFieldIndex != -1) {
            fieldValues.add(windowStartFieldIndex, ((IntervalWindow) window).start());
          }

          Row row =
              verifyRowValues
                  ? Row.withSchema(outputSchema).addValues(fieldValues).build()
                  : Row.withSchema(outputSchema).attachValues(fieldValues);
          o.output(row);
        }
      };
    }
  }

  /** Copies this node, preserving its window function and window field index. */
  @Override
  public Aggregate copy(
      RelTraitSet traitSet,
      RelNode input,
      ImmutableBitSet groupSet,
      List<ImmutableBitSet> groupSets,
      List<AggregateCall> aggCalls) {
    return new BeamAggregationRel(
        getCluster(), traitSet, input, groupSet, groupSets, aggCalls, windowFn, windowFieldIndex);
  }
}