Merge pull request #127 from jeff-zou/dev
#125 #126: support parsing UDTFs where the number of inputs equals the number of arguments
HamaWhiteGG committed Apr 21, 2024
2 parents 70c0a2c + d2cb8dc commit cd81704
Showing 4 changed files with 242 additions and 30 deletions.
RelColumnOrigin.java
@@ -104,6 +104,10 @@ public String getTransform() {
return transform;
}

public void setTransform(String transform) {
this.transform = transform;
}

@Override
public boolean equals(Object obj) {
if (!(obj instanceof RelColumnOrigin)) {
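The new setter is the hook the aggregate handler in RelMdColumnOrigins relies on: once the final transform string for an aggregate call has been assembled, it is stamped onto every origin in the result set. A minimal sketch of that call pattern, assuming the shaded org.apache.calcite.rel.metadata package referenced in the next file (the helper class and method name are illustrative, not part of this commit):

import org.apache.calcite.rel.metadata.RelColumnOrigin;

import java.util.Set;

// Illustrative helper, not committed code: resolve one transform expression,
// then overwrite it on every origin via the new setter.
class TransformStamper {
    static void stamp(Set<RelColumnOrigin> origins, String resolvedTransform) {
        origins.forEach(origin -> origin.setTransform(resolvedTransform));
    }
}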
RelMdColumnOrigins.java
@@ -30,8 +30,6 @@

import java.util.*;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static com.hw.lineage.common.util.Constant.DELIMITER;
@@ -59,8 +57,6 @@ public class RelMdColumnOrigins implements MetadataHandler<BuiltInMetadata.Colum

private static final Logger LOG = LoggerFactory.getLogger(RelMdColumnOrigins.class);

private final Pattern pattern = Pattern.compile("\\$[\\w.]+");

public static final RelMetadataProvider SOURCE =
ReflectiveRelMetadataProvider.reflectiveSource(
BuiltInMethod.COLUMN_ORIGIN.method, new RelMdColumnOrigins());
@@ -85,10 +81,37 @@ public Set<RelColumnOrigin> getColumnOrigins(Aggregate rel, RelMetadataQuery mq,
// Aggregate columns are derived from input columns
AggregateCall call = rel.getAggCallList().get(iOutputColumn - rel.getGroupCount());
final Set<RelColumnOrigin> set = new LinkedHashSet<>();
String transform = call.toString();
for (Integer iInput : call.getArgList()) {
set.addAll(mq.getColumnOrigins(rel.getInput(), iInput));

RexNode rexNode = ((Project) rel.getInput()).getProjects().get(iInput);

if (rexNode instanceof RexLiteral) {
RexLiteral literal = (RexLiteral) rexNode;
transform = transform.replace("$" + iInput, literal.toString().replace("_UTF-16LE", ""));
continue;
}

Set<org.apache.calcite.rel.metadata.RelColumnOrigin> subSet =
mq.getColumnOrigins(rel.getInput(), iInput);

if (!(rexNode instanceof RexCall)) {
subSet = createDerivedColumnOrigins(subSet, rexNode);
}

for (org.apache.calcite.rel.metadata.RelColumnOrigin relColumnOrigin : subSet) {
if (relColumnOrigin.getTransform() != null) {
transform = transform.replace("$" + iInput, relColumnOrigin.getTransform());
}
break;
}
set.addAll(subSet);
}
return createDerivedColumnOrigins(set, call);

// replace all the transforms
final String finalTransform = transform;
set.forEach(s -> s.setTransform(finalTransform));
return set;
}
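The substitution above can be traced with plain strings. This standalone sketch (values lifted from the #125 test case further down; not committed code) shows how the AggregateCall's printed form, whose operands appear as $<input index> placeholders, becomes the final lineage transform: a RexLiteral contributes its text with the _UTF-16LE prefix stripped, while other operands contribute their upstream origin's transform.

import java.util.LinkedHashMap;
import java.util.Map;

public class TransformSubstitutionDemo {
    public static void main(String[] args) {
        // Assumed printed form of the aggregate call, one $<index> per operand.
        String transform = "test_aggregate($0, $1, $2)";
        Map<String, String> resolved = new LinkedHashMap<>();
        resolved.put("$0", "CONCAT_WS('_', name, 'test')"); // RexCall: upstream transform
        resolved.put("$1", "name");                         // plain input ref
        resolved.put("$2", "'test'");                       // RexLiteral, _UTF-16LE stripped
        for (Map.Entry<String, String> entry : resolved.entrySet()) {
            transform = transform.replace(entry.getKey(), entry.getValue());
        }
        // Prints: test_aggregate(CONCAT_WS('_', name, 'test'), name, 'test')
        System.out.println(transform);
    }
}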

public Set<RelColumnOrigin> getColumnOrigins(Join rel, RelMetadataQuery mq, int iOutputColumn) {
@@ -411,32 +434,12 @@ private Set<RelColumnOrigin> createDerivedColumnOrigins(Set<RelColumnOrigin> inp
private String computeTransform(Set<RelColumnOrigin> inputSet, Object transform) {
LOG.debug("origin transform: {}, class: {}", transform, transform.getClass());
String finalTransform = transform.toString();

Matcher matcher = pattern.matcher(finalTransform);

Set<String> operandSet = new LinkedHashSet<>();
while (matcher.find()) {
operandSet.add(matcher.group());
}

if (operandSet.isEmpty()) {
return finalTransform;
}
if (inputSet.size() != operandSet.size()) {
LOG.warn("The number [{}] of fields in the source tables are not equal to operands [{}]", inputSet.size(),
operandSet.size());
return null;
}

Map<String, String> sourceColumnMap = buildSourceColumnMap(inputSet, transform);

matcher = pattern.matcher(finalTransform);
String temp;
while (matcher.find()) {
temp = matcher.group();
finalTransform = finalTransform.replace(temp, sourceColumnMap.get(temp));
for (Map.Entry<String, String> entry : sourceColumnMap.entrySet()) {
finalTransform = finalTransform.replace(entry.getKey(), entry.getValue());
}
// temporary special treatment

finalTransform = finalTransform.replace("_UTF-16LE", "");
LOG.debug("final transform: {}", finalTransform);
return finalTransform;
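With the regex pre-check removed, computeTransform no longer returns null when the number of origins and operands differ; it applies whatever mappings buildSourceColumnMap produced and leaves unmapped placeholders in place. A standalone sketch of that behavior (the transform string mirrors the expected result of the #128 test below; not committed code):

import java.util.HashMap;
import java.util.Map;

public class PartialSubstitutionDemo {
    public static void main(String[] args) {
        // COUNT(*) is backed by no source column, so $2 has no mapping.
        String transform = "ROUND(/(COUNT(DISTINCT $0), $2), 2)";
        Map<String, String> sourceColumnMap = new HashMap<>();
        sourceColumnMap.put("$0", "name");
        for (Map.Entry<String, String> entry : sourceColumnMap.entrySet()) {
            transform = transform.replace(entry.getKey(), entry.getValue());
        }
        // Prints: ROUND(/(COUNT(DISTINCT name), $2), 2)
        System.out.println(transform);
    }
}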
@@ -481,7 +484,12 @@ public Void visitFieldAccess(RexFieldAccess fieldAccess) {
}
Map<String, String> sourceColumnMap = new HashMap<>(INITIAL_CAPACITY);
Iterator<String> iterator = optimizeSourceColumnSet(inputSet).iterator();
traversalSet.forEach(index -> sourceColumnMap.put("$" + index, iterator.next()));
traversalSet.forEach(
index -> {
if (iterator.hasNext()) {
sourceColumnMap.put("$" + index, iterator.next());
}
});
LOG.debug("sourceColumnMap: {}", sourceColumnMap);
return sourceColumnMap;
}
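The hasNext() guard exists for the same mismatch: when an argument is a literal or an expression such as COUNT(*), there are more traversal indexes than optimized source columns, and the previous unguarded iterator.next() threw NoSuchElementException. A quick illustration with hypothetical values (not committed code):

import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

public class GuardedMappingDemo {
    public static void main(String[] args) {
        // Three operand indexes, but only two columns survive optimization.
        Set<Integer> traversalSet = new LinkedHashSet<>(Arrays.asList(0, 1, 2));
        Iterator<String> iterator = Arrays.asList("name", "email").iterator();
        Map<String, String> sourceColumnMap = new HashMap<>();
        traversalSet.forEach(index -> {
            if (iterator.hasNext()) { // without the guard: NoSuchElementException
                sourceColumnMap.put("$" + index, iterator.next());
            }
        });
        System.out.println(sourceColumnMap); // $2 gets no mapping
    }
}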
AggregateFunctionTest.java
@@ -0,0 +1,157 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.lineage.flink.aggregatefunction;

import com.hw.lineage.flink.basic.AbstractBasicTest;

import org.junit.Before;
import org.junit.Test;

public class AggregateFunctionTest extends AbstractBasicTest {

@Before
public void createTable() {

createTableOfOdsMysqlUsers();

createTableOfDwdHudiUsers();

createTableOfOdsMysqlUsersDetail();

createPrintTable();

createFunction();
}

/**
* Issue #125
*/
@Test
public void testAggregateFunction() {
String sql = "INSERT INTO dwd_hudi_users " +
"SELECT " +
" id ," +
" name ," +
" test_aggregate(concat_ws('_', name, 'test'), name, 'test')," +
" birthday ," +
" ts ," +
" DATE_FORMAT(birthday, 'yyyyMMdd') " +
"FROM" +
" ods_mysql_users group by id, name, birthday, ts ";

String[][] expectedArray = {
{"ods_mysql_users", "id", "dwd_hudi_users", "id"},
{"ods_mysql_users", "name", "dwd_hudi_users", "name"},
{"ods_mysql_users", "name", "dwd_hudi_users", "company_name",
"test_aggregate(CONCAT_WS('_', name, 'test'), name, 'test')"},
{"ods_mysql_users", "birthday", "dwd_hudi_users", "birthday"},
{"ods_mysql_users", "ts", "dwd_hudi_users", "ts"},
{"ods_mysql_users", "birthday", "dwd_hudi_users", "partition", "DATE_FORMAT(birthday, 'yyyyMMdd')"}
};

analyzeLineage(sql, expectedArray);
}

/**
* Issue #128
*/
@Test
public void testMultiTierUdf() {
String sql = "INSERT INTO print_table " +
"SELECT " +
" round( COUNT(*) / COUNT( DISTINCT name ) , 2 )" +
"FROM" +
" ods_mysql_users group by ts ";

String[][] expectedArray = {
{"ods_mysql_users", "name", "print_table", "num", "ROUND(/(COUNT(DISTINCT name), $2), 2)"},

};

analyzeLineage(sql, expectedArray);
}

/**
* Issue #126
*/
@Test
public void testAggregateFunctionInputArgument() {
String sql = "INSERT INTO dwd_hudi_users " +
"SELECT " +
" id ," +
" name ," +
" test_aggregate(concat_ws('_', name, email), address, 'test')," +
" birthday ," +
" ts ," +
" DATE_FORMAT(birthday, 'yyyyMMdd') " +
"FROM" +
" ods_mysql_user_detail group by id, name, birthday, ts ";

String[][] expectedArray = {
{"ods_mysql_user_detail", "id", "dwd_hudi_users", "id"},
{"ods_mysql_user_detail", "name", "dwd_hudi_users", "name"},
{"ods_mysql_user_detail", "name", "dwd_hudi_users", "company_name",
"test_aggregate(CONCAT_WS('_', name, email), address, 'test')"},
{"ods_mysql_user_detail", "email", "dwd_hudi_users", "company_name",
"test_aggregate(CONCAT_WS('_', name, email), address, 'test')"},
{"ods_mysql_user_detail", "address", "dwd_hudi_users", "company_name",
"test_aggregate(CONCAT_WS('_', name, email), address, 'test')"},
{"ods_mysql_user_detail", "birthday", "dwd_hudi_users", "birthday"},
{"ods_mysql_user_detail", "ts", "dwd_hudi_users", "ts"},
{"ods_mysql_user_detail", "birthday", "dwd_hudi_users", "partition",
"DATE_FORMAT(birthday, 'yyyyMMdd')"}
};

analyzeLineage(sql, expectedArray);
}

private void createPrintTable() {
context.execute("drop table if exists print_table");
context.execute("create table print_table (num double) with ('connector'='print')");
}

protected void createTableOfOdsMysqlUsersDetail() {
context.execute("DROP TABLE IF EXISTS ods_mysql_user_detail ");

context.execute("CREATE TABLE IF NOT EXISTS ods_mysql_user_detail (" +
" id BIGINT PRIMARY KEY NOT ENFORCED ," +
" name STRING ," +
" birthday TIMESTAMP(3) ," +
" ts TIMESTAMP(3) ," +
" email STRING ," +
" address STRING ," +
" proc_time as proctime() " +
") WITH ( " +
" 'connector' = 'mysql-cdc' ," +
" 'hostname' = '127.0.0.1' ," +
" 'port' = '3306' ," +
" 'username' = 'root' ," +
" 'password' = 'xxx' ," +
" 'server-time-zone' = 'Asia/Shanghai' ," +
" 'database-name' = 'demo' ," +
" 'table-name' = 'users' " +
")");
}

private void createFunction() {
context.execute("drop function if exists test_aggregate");
context.execute(
"create function test_aggregate as 'com.hw.lineage.flink.aggregatefunction.TestAggregateFunction'");
}

}
TestAggregateFunction.java
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.lineage.flink.aggregatefunction;

import org.apache.flink.table.functions.AggregateFunction;

public class TestAggregateFunction extends AggregateFunction<String, TestAggregateFunction.TestAggregateAcc> {

public void accumulate(TestAggregateAcc acc, String param1, String param2, String param3) {
acc.test = param1 + param2 + param3;
}

@Override
public String getValue(TestAggregateAcc accumulator) {
return accumulator.test;
}

@Override
public TestAggregateAcc createAccumulator() {
return new TestAggregateAcc();
}

public static class TestAggregateAcc {

public String test;
}
}
