Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,16 @@
import org.apache.paimon.flink.LogicalTypeConversion;
import org.apache.paimon.flink.dataevolution.DataEvolutionPartialWriteOperator;
import org.apache.paimon.flink.dataevolution.FirstRowIdAssigner;
import org.apache.paimon.flink.dataevolution.MergeIntoUpdateChecker;
import org.apache.paimon.flink.sink.Committable;
import org.apache.paimon.flink.sink.CommittableTypeInfo;
import org.apache.paimon.flink.sink.CommitterOperatorFactory;
import org.apache.paimon.flink.sink.NoopCommittableStateManager;
import org.apache.paimon.flink.sink.StoreCommitter;
import org.apache.paimon.flink.sorter.SortOperator;
import org.apache.paimon.flink.utils.FlinkCalciteClasses;
import org.apache.paimon.flink.utils.InternalTypeInfo;
import org.apache.paimon.manifest.ManifestCommittable;
import org.apache.paimon.table.FileStoreTable;
import org.apache.paimon.table.SpecialFields;
import org.apache.paimon.types.DataField;
Expand All @@ -41,16 +44,14 @@
import org.apache.paimon.types.DataTypeRoot;
import org.apache.paimon.types.RowType;
import org.apache.paimon.utils.Preconditions;
import org.apache.paimon.utils.StringUtils;

import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.functions.sink.v2.DiscardingSink;
import org.apache.flink.streaming.api.operators.OneInputStreamOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamMap;
import org.apache.flink.streaming.api.operators.StreamFlatMap;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableResult;
Expand All @@ -69,10 +70,9 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static org.apache.paimon.format.blob.BlobFileFormat.isBlobFile;
Expand All @@ -95,7 +95,6 @@
public class DataEvolutionMergeIntoAction extends TableActionBase {

private static final Logger LOG = LoggerFactory.getLogger(DataEvolutionMergeIntoAction.class);
public static final String IDENTIFIER_QUOTE = "`";

private final CoreOptions coreOptions;

Expand All @@ -120,6 +119,7 @@ public class DataEvolutionMergeIntoAction extends TableActionBase {

// merge condition
private String mergeCondition;
private MergeConditionParser mergeConditionParser;

// set statement
private String matchedUpdateSet;
Expand All @@ -137,6 +137,17 @@ public DataEvolutionMergeIntoAction(
table.getClass().getName()));
}

Long latestSnapshotId = ((FileStoreTable) table).snapshotManager().latestSnapshotId();
if (latestSnapshotId == null) {
throw new UnsupportedOperationException(
"merge-into action doesn't support updating an empty table.");
}
table =
table.copy(
Collections.singletonMap(
CoreOptions.COMMIT_STRICT_MODE_LAST_SAFE_SNAPSHOT.key(),
latestSnapshotId.toString()));

this.coreOptions = ((FileStoreTable) table).coreOptions();

if (!coreOptions.dataEvolutionEnabled()) {
Expand Down Expand Up @@ -168,6 +179,12 @@ public DataEvolutionMergeIntoAction withTargetAlias(String targetAlias) {

/**
 * Sets the merge condition used to match source rows against the target table (it is later
 * used as the join condition when building the source query).
 *
 * <p>The condition is parsed eagerly via {@link MergeConditionParser}, so a malformed
 * condition fails here at configuration time instead of later when the query is built.
 *
 * @param mergeCondition raw SQL condition; columns should be table-qualified
 * @return this action, for call chaining
 * @throws RuntimeException if the condition cannot be parsed
 */
public DataEvolutionMergeIntoAction withMergeCondition(String mergeCondition) {
    this.mergeCondition = mergeCondition;
    try {
        this.mergeConditionParser = new MergeConditionParser(mergeCondition);
    } catch (Exception e) {
        // Log with full stack trace, then surface as unchecked so the builder chain aborts.
        LOG.error("Failed to parse merge condition: {}", mergeCondition, e);
        throw new RuntimeException("Failed to parse merge condition " + mergeCondition, e);
    }
    return this;
}

Expand Down Expand Up @@ -196,7 +213,12 @@ public TableResult runInternal() {
DataStream<Committable> written =
writePartialColumns(shuffled, sourceWithType.f1, sinkParallelism);
// 4. commit
DataStream<?> committed = commit(written);
Set<String> updatedColumns =
sourceWithType.f1.getFields().stream()
.map(DataField::name)
.filter(name -> !SpecialFields.ROW_ID.name().equals(name))
.collect(Collectors.toSet());
DataStream<?> committed = commit(written, updatedColumns);

// execute internal
Transformation<?> transformations =
Expand All @@ -219,8 +241,7 @@ public Tuple2<DataStream<RowData>, RowType> buildSource() {
List<String> project;
if (matchedUpdateSet.equals("*")) {
// if sourceName is qualified like 'default.S', we should build a project like S.*
String[] splits = sourceTable.split("\\.");
project = Collections.singletonList(splits[splits.length - 1] + ".*");
project = Collections.singletonList(sourceTableName() + ".*");
} else {
// validate upsert changes
Map<String, String> changes = parseCommaSeparatedKeyValues(matchedUpdateSet);
Expand All @@ -245,16 +266,38 @@ public Tuple2<DataStream<RowData>, RowType> buildSource() {
.collect(Collectors.toList());
}

// use join to find matched rows and assign row id for each source row.
// _ROW_ID is the first field of joined table.
String query =
String.format(
"SELECT %s, %s FROM %s INNER JOIN %s AS RT ON %s",
"`RT`.`_ROW_ID` as `_ROW_ID`",
String.join(",", project),
escapedSourceName(),
escapedRowTrackingTargetName(),
rewriteMergeCondition(mergeCondition));
String query;
Optional<String> sourceRowIdField;
try {
sourceRowIdField = mergeConditionParser.extractRowIdFieldFromSource(targetTableName());
} catch (Exception e) {
LOG.error("Error happened when extract row id field from source table.", e);
throw new RuntimeException(
"Error happened when extract row id field from source table.", e);
}

// if source table already contains _ROW_ID field, we could avoid join
if (sourceRowIdField.isPresent()) {
query =
String.format(
// cast _ROW_ID to BIGINT
"SELECT CAST(`%s`.`%s` AS BIGINT) AS `_ROW_ID`, %s FROM %s",
sourceTableName(),
sourceRowIdField.get(),
String.join(",", project),
escapedSourceName());
} else {
// use join to find matched rows and assign row id for each source row.
// _ROW_ID is the first field of joined table.
query =
String.format(
"SELECT %s, %s FROM %s INNER JOIN %s AS RT ON %s",
"`RT`.`_ROW_ID` as `_ROW_ID`",
String.join(",", project),
escapedSourceName(),
escapedRowTrackingTargetName(),
rewriteMergeCondition(mergeCondition));
}

LOG.info("Source query: {}", query);

Expand Down Expand Up @@ -286,11 +329,15 @@ public DataStream<Tuple2<Long, RowData>> shuffleByFirstRowId(
Preconditions.checkState(
!firstRowIds.isEmpty(), "Should not MERGE INTO an empty target table.");

// if firstRowIds is not empty, there must be a valid nextRowId
long maxRowId = table.latestSnapshot().get().nextRowId() - 1;

OneInputTransformation<RowData, Tuple2<Long, RowData>> assignedFirstRowId =
new OneInputTransformation<>(
sourceTransformation,
"ASSIGN FIRST_ROW_ID",
new StreamMap<>(new FirstRowIdAssigner(firstRowIds, sourceType)),
new StreamFlatMap<>(
new FirstRowIdAssigner(firstRowIds, maxRowId, sourceType)),
new TupleTypeInfo<>(
BasicTypeInfo.LONG_TYPE_INFO, sourceTransformation.getOutputType()),
sourceTransformation.getParallelism(),
Expand Down Expand Up @@ -334,9 +381,20 @@ public DataStream<Committable> writePartialColumns(
.setParallelism(sinkParallelism);
}

public DataStream<Committable> commit(DataStream<Committable> written) {
public DataStream<Committable> commit(
DataStream<Committable> written, Set<String> updatedColumns) {
FileStoreTable storeTable = (FileStoreTable) table;
OneInputStreamOperatorFactory<Committable, Committable> committerOperator =

// Check if some global-indexed columns are updated
DataStream<Committable> checked =
written.transform(
"Updated Column Check",
new CommittableTypeInfo(),
new MergeIntoUpdateChecker(storeTable, updatedColumns))
.setParallelism(1)
.setMaxParallelism(1);

CommitterOperatorFactory<Committable, ManifestCommittable> committerOperator =
new CommitterOperatorFactory<>(
false,
true,
Expand All @@ -348,7 +406,7 @@ public DataStream<Committable> commit(DataStream<Committable> written) {
context),
new NoopCommittableStateManager());

return written.transform("COMMIT OPERATOR", new CommittableTypeInfo(), committerOperator)
return checked.transform("COMMIT OPERATOR", new CommittableTypeInfo(), committerOperator)
.setParallelism(1)
.setMaxParallelism(1);
}
Expand Down Expand Up @@ -382,28 +440,13 @@ private DataStream<RowData> toDataStream(Table source) {
*/
@VisibleForTesting
public String rewriteMergeCondition(String mergeCondition) {
// skip single and double-quoted chunks
String skipQuoted = "'(?:''|[^'])*'" + "|\"(?:\"\"|[^\"])*\"";
String targetTableRegex =
"(?i)(?:\\b"
+ Pattern.quote(targetTableName())
+ "\\b|`"
+ Pattern.quote(targetTableName())
+ "`)\\s*\\.";

Pattern pattern = Pattern.compile(skipQuoted + "|(" + targetTableRegex + ")");
Matcher matcher = pattern.matcher(mergeCondition);

StringBuffer sb = new StringBuffer();
while (matcher.find()) {
if (matcher.group(1) != null) {
matcher.appendReplacement(sb, Matcher.quoteReplacement("`RT`."));
} else {
matcher.appendReplacement(sb, Matcher.quoteReplacement(matcher.group(0)));
}
try {
Object rewrittenNode = mergeConditionParser.rewriteSqlNode(targetTableName(), "RT");
return rewrittenNode.toString();
} catch (Exception e) {
LOG.error("Failed to rewrite merge condition: {}", mergeCondition, e);
throw new RuntimeException("Failed to rewrite merge condition " + mergeCondition, e);
}
matcher.appendTail(sb);
return sb.toString();
}

/**
Expand Down Expand Up @@ -432,7 +475,8 @@ private void checkSchema(Table source) {
foundRowIdColumn = true;
Preconditions.checkState(
flinkColumn.getDataType().getLogicalType().getTypeRoot()
== LogicalTypeRoot.BIGINT);
== LogicalTypeRoot.BIGINT,
"_ROW_ID field should be BIGINT type.");
} else {
DataField targetField = targetFields.get(flinkColumn.getName());
if (targetField == null) {
Expand Down Expand Up @@ -497,6 +541,11 @@ private String targetTableName() {
return targetAlias == null ? identifier.getObjectName() : targetAlias;
}

/** Returns the unqualified source table name, i.e. the last segment of a dotted path. */
private String sourceTableName() {
    String[] segments = sourceTable.split("\\.");
    return segments[segments.length - 1];
}

private String escapedSourceName() {
return Arrays.stream(sourceTable.split("\\."))
.map(s -> String.format("`%s`", s))
Expand All @@ -514,28 +563,108 @@ private String escapedRowTrackingTargetName() {
catalogName, identifier.getDatabaseName(), identifier.getObjectName());
}

private List<String> normalizeFieldName(List<String> fieldNames) {
return fieldNames.stream().map(this::normalizeFieldName).collect(Collectors.toList());
}
/** The parser to parse merge condition through calcite sql parser. */
static class MergeConditionParser {

// Reflective accessors for Flink's shaded Calcite classes; because everything is accessed
// through these delegates, Calcite objects are handled as plain Object in this class.
private final FlinkCalciteClasses calciteClasses;
// The parsed merge condition (a Calcite SqlNode, held as Object — see note above).
private final Object sqlNode;

MergeConditionParser(String mergeCondition) throws Exception {
    // Parse eagerly so that an invalid merge condition fails at construction time.
    this.calciteClasses = new FlinkCalciteClasses();
    this.sqlNode = initializeSqlNode(mergeCondition);
}

private String normalizeFieldName(String fieldName) {
if (StringUtils.isNullOrWhitespaceOnly(fieldName) || fieldName.endsWith(IDENTIFIER_QUOTE)) {
return fieldName;
/** Parses the merge condition into a Calcite SqlNode (returned as Object). */
private Object initializeSqlNode(String mergeCondition) throws Exception {
    // Parse with the Java lex policy (Calcite Lex.JAVA), which uses backtick-quoted
    // identifiers — the convention Flink SQL follows.
    Object javaLex = calciteClasses.lexDelegate().java();
    Object parserConfig =
            calciteClasses
                    .configDelegate()
                    .withLex(calciteClasses.sqlParserDelegate().config(), javaLex);
    Object parser = calciteClasses.sqlParserDelegate().create(mergeCondition, parserConfig);
    return calciteClasses.sqlParserDelegate().parseExpression(parser);
}

String[] splitFieldNames = fieldName.split("\\.");
if (!targetFieldNames.contains(splitFieldNames[splitFieldNames.length - 1])) {
return fieldName;
/**
 * Rewrite the SQL node, replacing all references from the 'from' table to the 'to' table.
 *
 * <p>Used to re-qualify target-table column references with the join alias (e.g. "RT")
 * when the merge condition is embedded into the generated join query.
 *
 * @param from table name (or alias) whose column qualifiers should be replaced
 * @param to replacement table name
 * @return the rewritten node (a Calcite SqlNode, typed as Object because Calcite is
 *     accessed reflectively)
 * @throws Exception propagated from the reflective Calcite delegate calls
 */
public Object rewriteSqlNode(String from, String to) throws Exception {
    return rewriteNode(sqlNode, from, to);
}

return String.join(
".",
Arrays.stream(splitFieldNames)
.map(
part ->
part.endsWith(IDENTIFIER_QUOTE)
? part
: IDENTIFIER_QUOTE + part + IDENTIFIER_QUOTE)
.toArray(String[]::new));
/**
 * Recursively rewrites {@code node}, replacing every column qualifier equal to {@code from}
 * with {@code to}. Nodes other than calls and identifiers are returned unchanged.
 */
private Object rewriteNode(Object node, String from, String to) throws Exception {
    // A SqlBasicCall: rebuild the call with each operand rewritten recursively.
    if (calciteClasses.sqlBasicCallDelegate().instanceOfSqlBasicCall(node)) {
        List<?> operands = calciteClasses.sqlBasicCallDelegate().getOperandList(node);
        List<Object> rewritten = new java.util.ArrayList<>(operands.size());
        for (Object child : operands) {
            rewritten.add(rewriteNode(child, from, to));
        }
        return calciteClasses
                .sqlBasicCallDelegate()
                .create(
                        calciteClasses.sqlBasicCallDelegate().getOperator(node),
                        rewritten,
                        calciteClasses.sqlBasicCallDelegate().getParserPosition(node),
                        calciteClasses.sqlBasicCallDelegate().getFunctionQuantifier(node));
    }

    // A SqlIdentifier: swap the table qualifier (second-to-last segment) when it matches.
    if (calciteClasses.sqlIndentifierDelegate().instanceOfSqlIdentifier(node)) {
        List<String> names = calciteClasses.sqlIndentifierDelegate().getNames(node);
        Preconditions.checkState(
                names.size() >= 2, "Please specify the table name for the column: " + node);
        int qualifierIndex = names.size() - 2;
        return names.get(qualifierIndex).equals(from)
                ? calciteClasses.sqlIndentifierDelegate().setName(node, qualifierIndex, to)
                : node;
    }

    // Literals and any other node kinds pass through untouched.
    return node;
}

/**
 * Finds the source-table field that is equated with the target table's _ROW_ID column.
 *
 * <p>Only a top-level equality of two qualified identifiers is recognized, i.e.
 * {@code target._ROW_ID = source.field} or {@code source.field = target._ROW_ID}; in every
 * other case an empty Optional is returned.
 */
public Optional<String> extractRowIdFieldFromSource(String targetTable) throws Exception {
    Object operator = calciteClasses.sqlBasicCallDelegate().getOperator(sqlNode);
    if (calciteClasses.sqlOperatorDelegate().getKind(operator)
            != calciteClasses.sqlKindDelegate().equals()) {
        // Not a plain equality condition; no row id mapping can be extracted.
        return Optional.empty();
    }

    List<?> operands = calciteClasses.sqlBasicCallDelegate().getOperandList(sqlNode);
    Object left = operands.get(0);
    Object right = operands.get(1);
    if (!calciteClasses.sqlIndentifierDelegate().instanceOfSqlIdentifier(left)
            || !calciteClasses.sqlIndentifierDelegate().instanceOfSqlIdentifier(right)) {
        return Optional.empty();
    }

    List<String> leftNames = calciteClasses.sqlIndentifierDelegate().getNames(left);
    List<String> rightNames = calciteClasses.sqlIndentifierDelegate().getNames(right);
    Preconditions.checkState(
            leftNames.size() >= 2, "Please specify the table name for the column: " + left);
    Preconditions.checkState(
            rightNames.size() >= 2, "Please specify the table name for the column: " + right);

    if (referencesTargetRowId(leftNames, targetTable)) {
        return Optional.of(rightNames.get(rightNames.size() - 1));
    }
    if (referencesTargetRowId(rightNames, targetTable)) {
        return Optional.of(leftNames.get(leftNames.size() - 1));
    }
    return Optional.empty();
}

/** True when the qualified name is {@code <targetTable>._ROW_ID}. */
private static boolean referencesTargetRowId(List<String> names, String targetTable) {
    return names.get(names.size() - 1).equals(SpecialFields.ROW_ID.name())
            && names.get(names.size() - 2).equals(targetTable);
}
}
}
Loading
Loading