Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion be/src/format/reader/column_mapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ static VExprSPtr rewrite_literal_to_file_type(const VExprSPtr& literal_expr,
original_literal->data_type(), original_field);
}

// TODO: rewrite InPredicate
static bool rewrite_binary_slot_literal_predicate(
const VExprSPtr& expr,
const std::map<int32_t, FileSlotRewriteInfo>& table_column_to_file_slot) {
Expand Down Expand Up @@ -227,6 +226,51 @@ static bool rewrite_binary_slot_literal_predicate(
return true;
}

static bool rewrite_in_slot_literal_predicate(
const VExprSPtr& expr,
const std::map<int32_t, FileSlotRewriteInfo>& table_column_to_file_slot) {
if (expr->node_type() != TExprNodeType::IN_PRED || expr->get_num_children() < 2) {
return false;
}
auto children = expr->children();
const VSlotRef* slot_ref = nullptr;
const FileSlotRewriteInfo* rewrite_info =
find_slot_rewrite_info(children[0], table_column_to_file_slot, &slot_ref);
if (rewrite_info == nullptr || slot_ref == nullptr) {
return false;
}

VExprSPtrs rewritten_literals;
rewritten_literals.reserve(children.size() - 1);
for (size_t child_idx = 1; child_idx < children.size(); ++child_idx) {
auto literal_expr =
unwrap_literal_for_file_cast(children[child_idx], rewrite_info->table_type);
if (literal_expr == nullptr) {
return false;
}
auto rewritten_literal = rewrite_literal_to_file_type(literal_expr, *rewrite_info);
if (rewritten_literal == nullptr) {
for (size_t restore_idx = 1; restore_idx < children.size(); ++restore_idx) {
auto restore_literal = unwrap_literal_for_file_cast(children[restore_idx],
rewrite_info->table_type);
if (restore_literal != nullptr) {
children[restore_idx] = original_table_literal(restore_literal);
}
}
expr->set_children(std::move(children));
return false;
}
rewritten_literals.push_back(std::move(rewritten_literal));
}

children[0] = create_file_slot_ref(*slot_ref, *rewrite_info);
for (size_t literal_idx = 0; literal_idx < rewritten_literals.size(); ++literal_idx) {
children[literal_idx + 1] = std::move(rewritten_literals[literal_idx]);
}
expr->set_children(std::move(children));
return true;
}

static VExprSPtr rewrite_table_expr_to_file_expr(
const VExprSPtr& expr,
const std::map<int32_t, FileSlotRewriteInfo>& table_column_to_file_slot) {
Expand All @@ -236,6 +280,9 @@ static VExprSPtr rewrite_table_expr_to_file_expr(
if (rewrite_binary_slot_literal_predicate(expr, table_column_to_file_slot)) {
return expr;
}
if (rewrite_in_slot_literal_predicate(expr, table_column_to_file_slot)) {
return expr;
}
if (expr->is_slot_ref()) {
const auto* slot_ref = assert_cast<const VSlotRef*>(expr.get());
const auto rewrite_it = table_column_to_file_slot.find(slot_ref->slot_id());
Expand Down
149 changes: 149 additions & 0 deletions be/test/format/reader/expr/cast_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "format/reader/table_reader.h"
#include "runtime/descriptors.h"
#include "testutil/column_helper.h"
#include "testutil/mock/mock_in_expr.h"
#include "testutil/mock/mock_runtime_state.h"

namespace doris {
Expand Down Expand Up @@ -493,6 +494,154 @@ TEST_F(CastTest, ColumnMapperCastsLiteralForLiteralSlotPredicateTypeMismatch) {
file_request.conjuncts[0]->close();
}

// A BIGINT table column mapped onto an INT file column: the localized IN predicate is
// expected to reference the file slot and carry literals typed like the file column.
TEST_F(CastTest, ColumnMapperCastsInPredicateLiteralsForTypeMismatch) {
    // Table side: column id 7, BIGINT.
    reader::TableColumn bigint_column;
    bigint_column.id = 7;
    bigint_column.name = "value";
    bigint_column.type = std::make_shared<DataTypeInt64>();
    std::vector<reader::TableColumn> projection {bigint_column};

    // File side: field id 0, INT.
    reader::SchemaField int_field;
    int_field.id = 0;
    int_field.name = "value";
    int_field.type = std::make_shared<DataTypeInt32>();
    std::vector<reader::SchemaField> schema {int_field};

    reader::TableColumnMapper mapper;
    auto mapping_status = mapper.create_mapping(projection, {}, schema);
    ASSERT_TRUE(mapping_status.ok()) << mapping_status;

    // value IN (15, 22), expressed with table-typed literals.
    auto in_predicate = MockInExpr::create();
    in_predicate->add_child(TableSlotRef::create_shared(7, 7, -1, bigint_column.type, "value"));
    in_predicate->add_child(TableLiteral::create_shared(bigint_column.type,
                                                        Field::create_field<TYPE_BIGINT>(15)));
    in_predicate->add_child(TableLiteral::create_shared(bigint_column.type,
                                                        Field::create_field<TYPE_BIGINT>(22)));
    reader::TableFilter filter;
    filter.conjunct = VExprContext::create_shared(in_predicate);
    filter.slot_ids = {7};

    reader::FileScanRequest request;
    ASSERT_TRUE(mapper.create_scan_request({filter}, {}, projection, &request).ok());
    ASSERT_EQ(request.conjuncts.size(), 1);
    ASSERT_EQ(projection_ids(request.predicate_columns), std::vector<reader::ColumnId>({0}));

    const auto& localized = request.conjuncts[0]->root();
    ASSERT_EQ(localized->get_num_children(), 3);

    // Child 0 is the slot, now pointing at the file column with the file type.
    const auto* file_slot = assert_cast<const TableSlotRef*>(localized->children()[0].get());
    EXPECT_EQ(file_slot->column_id(), 0);
    EXPECT_TRUE(file_slot->data_type()->equals(*int_field.type));

    // Children 1 and 2 are still literals, but typed like the file column.
    for (size_t idx = 1; idx <= 2; ++idx) {
        const auto literal = localized->children()[idx];
        EXPECT_TRUE(literal->is_literal());
        EXPECT_TRUE(literal->data_type()->equals(*int_field.type));
    }
}

// A STRING table column mapped onto an INT file column with literals "10" and "bad":
// the localized predicate is expected to keep table-typed literals and wrap the file
// slot in a Cast back to the table type.
// NOTE(review): presumably "bad" is the literal that cannot be converted to INT — confirm.
TEST_F(CastTest, ColumnMapperFallsBackToSlotCastWhenInPredicateLiteralRewriteFails) {
    // Table side: column id 7, STRING.
    reader::TableColumn string_column;
    string_column.id = 7;
    string_column.name = "value";
    string_column.type = std::make_shared<DataTypeString>();
    std::vector<reader::TableColumn> projection {string_column};

    // File side: field id 0, INT.
    reader::SchemaField int_field;
    int_field.id = 0;
    int_field.name = "value";
    int_field.type = std::make_shared<DataTypeInt32>();
    std::vector<reader::SchemaField> schema {int_field};

    reader::TableColumnMapper mapper;
    auto mapping_status = mapper.create_mapping(projection, {}, schema);
    ASSERT_TRUE(mapping_status.ok()) << mapping_status;

    // value IN ("10", "bad")
    auto in_predicate = MockInExpr::create();
    in_predicate->add_child(TableSlotRef::create_shared(7, 7, -1, string_column.type, "value"));
    in_predicate->add_child(TableLiteral::create_shared(string_column.type,
                                                        Field::create_field<TYPE_STRING>("10")));
    in_predicate->add_child(TableLiteral::create_shared(string_column.type,
                                                        Field::create_field<TYPE_STRING>("bad")));
    reader::TableFilter filter;
    filter.conjunct = VExprContext::create_shared(in_predicate);
    filter.slot_ids = {7};

    reader::FileScanRequest request;
    ASSERT_TRUE(mapper.create_scan_request({filter}, {}, projection, &request).ok());
    ASSERT_EQ(request.conjuncts.size(), 1);

    const auto& localized = request.conjuncts[0]->root();
    ASSERT_EQ(localized->get_num_children(), 3);

    // Child 0 must be Cast(file slot) producing the table type.
    const auto& cast_child = localized->children()[0];
    ASSERT_NE(dynamic_cast<const Cast*>(cast_child.get()), nullptr);
    ASSERT_EQ(cast_child->get_num_children(), 1);
    const auto* file_slot = assert_cast<const TableSlotRef*>(cast_child->children()[0].get());
    EXPECT_EQ(file_slot->column_id(), 0);
    EXPECT_TRUE(file_slot->data_type()->equals(*int_field.type));
    EXPECT_TRUE(cast_child->data_type()->equals(*string_column.type));

    // Both literals stay literals of the table type.
    for (size_t idx = 1; idx <= 2; ++idx) {
        const auto literal = localized->children()[idx];
        EXPECT_TRUE(literal->is_literal());
        EXPECT_TRUE(literal->data_type()->equals(*string_column.type));
    }
}

// The same table filter is localized twice, first against an INT file column and then
// against a BIGINT file column. The second request must see BIGINT-typed literals, i.e.
// the INT rewrite from the first request must not be visible in the second one.
TEST_F(CastTest, ColumnMapperDoesNotLeakRewrittenInPredicateLiteralAcrossSplits) {
    // Table side: column id 7, BIGINT.
    reader::TableColumn bigint_column;
    bigint_column.id = 7;
    bigint_column.name = "value";
    bigint_column.type = std::make_shared<DataTypeInt64>();
    std::vector<reader::TableColumn> projection {bigint_column};

    // Shared filter: value IN (15, 22)
    auto in_predicate = MockInExpr::create();
    in_predicate->add_child(TableSlotRef::create_shared(7, 7, -1, bigint_column.type, "value"));
    in_predicate->add_child(TableLiteral::create_shared(bigint_column.type,
                                                        Field::create_field<TYPE_BIGINT>(15)));
    in_predicate->add_child(TableLiteral::create_shared(bigint_column.type,
                                                        Field::create_field<TYPE_BIGINT>(22)));
    reader::TableFilter filter;
    filter.conjunct = VExprContext::create_shared(in_predicate);
    filter.slot_ids = {7};

    // First mapping: the file stores the column as INT.
    reader::SchemaField narrow_field;
    narrow_field.id = 0;
    narrow_field.name = "value";
    narrow_field.type = std::make_shared<DataTypeInt32>();
    reader::TableColumnMapper narrow_mapper;
    ASSERT_TRUE(narrow_mapper.create_mapping(projection, {}, {narrow_field}).ok());
    reader::FileScanRequest narrow_request;
    ASSERT_TRUE(
            narrow_mapper.create_scan_request({filter}, {}, projection, &narrow_request).ok());
    ASSERT_EQ(narrow_request.conjuncts.size(), 1);
    const auto& narrow_expr = narrow_request.conjuncts[0]->root();
    ASSERT_EQ(narrow_expr->get_num_children(), 3);
    for (size_t idx = 1; idx <= 2; ++idx) {
        const auto literal = narrow_expr->children()[idx];
        EXPECT_TRUE(literal->is_literal());
        EXPECT_TRUE(literal->data_type()->equals(*narrow_field.type));
    }

    // Second mapping: the file stores the column as BIGINT, using the same filter object.
    reader::SchemaField wide_field;
    wide_field.id = 0;
    wide_field.name = "value";
    wide_field.type = std::make_shared<DataTypeInt64>();
    reader::TableColumnMapper wide_mapper;
    ASSERT_TRUE(wide_mapper.create_mapping(projection, {}, {wide_field}).ok());
    reader::FileScanRequest wide_request;
    ASSERT_TRUE(wide_mapper.create_scan_request({filter}, {}, projection, &wide_request).ok());
    ASSERT_EQ(wide_request.conjuncts.size(), 1);
    const auto& wide_expr = wide_request.conjuncts[0]->root();
    ASSERT_EQ(wide_expr->get_num_children(), 3);

    const auto* file_slot = assert_cast<const TableSlotRef*>(wide_expr->children()[0].get());
    EXPECT_EQ(file_slot->column_id(), 0);
    EXPECT_TRUE(file_slot->data_type()->equals(*wide_field.type));
    for (size_t idx = 1; idx <= 2; ++idx) {
        const auto literal = wide_expr->children()[idx];
        EXPECT_TRUE(literal->is_literal());
        EXPECT_TRUE(literal->data_type()->equals(*wide_field.type));
    }
}

TEST_F(CastTest, ColumnMapperFallsBackToSlotCastWhenLiteralRewriteFails) {
reader::TableColumnMapper mapper;
reader::TableColumn table_column;
Expand Down
46 changes: 23 additions & 23 deletions be/test/format/reader/table_reader_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2483,17 +2483,17 @@ TEST(TableColumnMapperByIndexTest, MapsTopLevelColumnsByPositionIgnoringFileName
const auto& mappings = mapper.mappings();
ASSERT_EQ(mappings.size(), 3);

ASSERT_TRUE(mappings[0].file_column_id.has_value());
EXPECT_EQ(*mappings[0].file_column_id, 0);
ASSERT_TRUE(mappings[0].field_id.has_value());
EXPECT_EQ(*mappings[0].field_id, 0);
EXPECT_EQ(mappings[0].file_column_name, "_col0");
EXPECT_FALSE(mappings[0].is_constant);

ASSERT_TRUE(mappings[1].file_column_id.has_value());
EXPECT_EQ(*mappings[1].file_column_id, 1);
ASSERT_TRUE(mappings[1].field_id.has_value());
EXPECT_EQ(*mappings[1].field_id, 1);
EXPECT_EQ(mappings[1].file_column_name, "_col1");

ASSERT_TRUE(mappings[2].file_column_id.has_value());
EXPECT_EQ(*mappings[2].file_column_id, 2);
ASSERT_TRUE(mappings[2].field_id.has_value());
EXPECT_EQ(*mappings[2].field_id, 2);
EXPECT_EQ(mappings[2].file_column_name, "_col2");
}

Expand All @@ -2519,12 +2519,12 @@ TEST(TableColumnMapperByIndexTest, SparseProjectionMapsByExplicitFileIndex) {

const auto& mappings = mapper.mappings();
ASSERT_EQ(mappings.size(), 2);
ASSERT_TRUE(mappings[0].file_column_id.has_value());
EXPECT_EQ(*mappings[0].file_column_id, 2);
ASSERT_TRUE(mappings[0].field_id.has_value());
EXPECT_EQ(*mappings[0].field_id, 2);
EXPECT_EQ(mappings[0].file_column_name, "_col2");

ASSERT_TRUE(mappings[1].file_column_id.has_value());
EXPECT_EQ(*mappings[1].file_column_id, 4);
ASSERT_TRUE(mappings[1].field_id.has_value());
EXPECT_EQ(*mappings[1].field_id, 4);
EXPECT_EQ(mappings[1].file_column_name, "_col4");
}

Expand Down Expand Up @@ -2559,15 +2559,15 @@ TEST(TableColumnMapperByIndexTest, PartitionColumnsTakeConstantAndDoNotConsumeFi
ASSERT_EQ(mappings.size(), 3);

EXPECT_TRUE(mappings[0].is_constant);
EXPECT_FALSE(mappings[0].file_column_id.has_value());
EXPECT_FALSE(mappings[0].field_id.has_value());
EXPECT_NE(mappings[0].default_expr, nullptr);

ASSERT_TRUE(mappings[1].file_column_id.has_value());
EXPECT_EQ(*mappings[1].file_column_id, 0);
ASSERT_TRUE(mappings[1].field_id.has_value());
EXPECT_EQ(*mappings[1].field_id, 0);
EXPECT_EQ(mappings[1].file_column_name, "_col0");

ASSERT_TRUE(mappings[2].file_column_id.has_value());
EXPECT_EQ(*mappings[2].file_column_id, 1);
ASSERT_TRUE(mappings[2].field_id.has_value());
EXPECT_EQ(*mappings[2].field_id, 1);
EXPECT_EQ(mappings[2].file_column_name, "_col1");
}

Expand Down Expand Up @@ -2601,14 +2601,14 @@ TEST(TableColumnMapperByIndexTest, FileIndexOutOfRangeFallsBackToDefaultOrMissin
const auto& mappings = mapper.mappings();
ASSERT_EQ(mappings.size(), 3);

ASSERT_TRUE(mappings[0].file_column_id.has_value());
EXPECT_EQ(*mappings[0].file_column_id, 0);
ASSERT_TRUE(mappings[0].field_id.has_value());
EXPECT_EQ(*mappings[0].field_id, 0);

EXPECT_FALSE(mappings[1].file_column_id.has_value());
EXPECT_FALSE(mappings[1].field_id.has_value());
EXPECT_TRUE(mappings[1].is_constant);
EXPECT_EQ(mappings[1].default_expr, literal_expr);

EXPECT_FALSE(mappings[2].file_column_id.has_value());
EXPECT_FALSE(mappings[2].field_id.has_value());
EXPECT_FALSE(mappings[2].is_constant);
EXPECT_EQ(mappings[2].default_expr, nullptr);
}
Expand Down Expand Up @@ -2652,8 +2652,8 @@ TEST(TableColumnMapperByIndexTest, ExtraFileColumnsAreSimplyIgnored) {

const auto& mappings = mapper.mappings();
ASSERT_EQ(mappings.size(), 1);
ASSERT_TRUE(mappings[0].file_column_id.has_value());
EXPECT_EQ(*mappings[0].file_column_id, 0);
ASSERT_TRUE(mappings[0].field_id.has_value());
EXPECT_EQ(*mappings[0].field_id, 0);
}

TEST(TableColumnMapperByIndexTest, IgnoresFileColumnNames) {
Expand All @@ -2678,8 +2678,8 @@ TEST(TableColumnMapperByIndexTest, IgnoresFileColumnNames) {

const auto& mappings = mapper.mappings();
ASSERT_EQ(mappings.size(), 1);
ASSERT_TRUE(mappings[0].file_column_id.has_value());
EXPECT_EQ(*mappings[0].file_column_id, 20);
ASSERT_TRUE(mappings[0].field_id.has_value());
EXPECT_EQ(*mappings[0].field_id, 20);
EXPECT_EQ(mappings[0].file_column_name, "b");
}

Expand Down
Loading