From 89349065847f4f9793075645e0a68abf14e97a0b Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Wed, 17 Apr 2024 18:30:16 +0800 Subject: [PATCH] feat(map): cast null or map() to explicit map type (#3847) Auto casting to target type when necessary, which usually happens with insert statement while a default schema defined by target table. This supports constructing null or empty values from offline engine. --- cases/plan/create.yaml | 2 +- cases/query/udf_query.yaml | 15 +++++++ hybridse/include/node/node_manager.h | 22 ++++------- hybridse/include/node/sql_node.h | 18 ++++++--- hybridse/src/codegen/cast_expr_ir_builder.cc | 27 ++++++++++++- hybridse/src/codegen/expr_ir_builder.cc | 12 ++---- hybridse/src/codegen/insert_row_builder.cc | 27 ++++++++++++- hybridse/src/codegen/insert_row_builder.h | 1 - .../src/codegen/insert_row_builder_test.cc | 14 ++++++- hybridse/src/codegen/ir_base_builder.cc | 4 +- hybridse/src/codegen/native_value.cc | 6 +-- hybridse/src/codegen/struct_ir_builder.cc | 38 +++++++++++++++--- hybridse/src/codegen/struct_ir_builder.h | 7 ++++ hybridse/src/codegen/type_ir_builder.cc | 39 +------------------ hybridse/src/codegen/type_ir_builder.h | 4 -- hybridse/src/node/expr_node.cc | 29 ++++++++++---- hybridse/src/node/node_manager.cc | 9 +---- hybridse/src/node/sql_node.cc | 6 +-- hybridse/src/passes/lambdafy_projects.cc | 5 +++ hybridse/src/planv2/ast_node_converter.cc | 3 +- .../openmldb/batch/SparkRowCodec.scala | 2 - .../batch/nodes/ConstProjectPlan.scala | 2 +- .../batch/nodes/SimpleProjectPlan.scala | 2 +- .../openmldb/batch/utils/ExpressionUtil.scala | 2 +- .../openmldb/batch/TestProjectPlan.scala | 25 ++++-------- .../openmldb/batch/end2end/TestProject.scala | 5 ++- .../batch/utils/TestGraphvizUtil.scala | 4 +- 27 files changed, 197 insertions(+), 133 deletions(-) diff --git a/cases/plan/create.yaml b/cases/plan/create.yaml index 00c7e583406..7c0eed558b1 100644 --- a/cases/plan/create.yaml +++ b/cases/plan/create.yaml @@ -798,7 +798,7 @@ cases: | +-0: | | +-node[kColumnDesc] | | +-column_name: column1 - | | +-column_type: string DEFAULT string(1) + | | +-column_type: string DEFAULT STRING(1) | +-1: | +-node[kColumnDesc] | +-column_name: column3 diff --git a/cases/query/udf_query.yaml b/cases/query/udf_query.yaml index 562d7672664..fefe1380dbb 100644 --- a/cases/query/udf_query.yaml +++ b/cases/query/udf_query.yaml @@ -557,6 +557,7 @@ cases: # ================================================================ # Map data type + # FIXME: request mode tests disabled, because TestRequestEngineForLastRow cause SEG FAULT # ================================================================ - id: 13 mode: request-unsupport @@ -637,3 +638,17 @@ cases: 1, abc 2, null + - id: 19 + mode: request-unsupport + desc: empty or null map + sql: | + select cast (null as map)[0] as o1, + cast (null as map) ["12"] as o2, + cast (map() as map) ["12"] as o3, + cast (map() as map) [7] as o4, + cast (map(7, "9") as map) [7] as o5, + cast (map() as map) [date("2012-12-12")] as o6, + expect: + columns: ["o1 string", "o2 int", "o3 int64", "o4 timestamp", "o5 string", "o6 timestamp"] + data: | + NULL, NULL, NULL, NULL, 9, NULL diff --git a/hybridse/include/node/node_manager.h b/hybridse/include/node/node_manager.h index d210304cdf3..9fc217d6f82 100644 --- a/hybridse/include/node/node_manager.h +++ b/hybridse/include/node/node_manager.h @@ -410,26 +410,20 @@ class NodeManager { return node_ptr; } - private: - void SetNodeUniqueId(ExprNode *node); - void SetNodeUniqueId(TypeNode *node); - void SetNodeUniqueId(PlanNode *node); - void SetNodeUniqueId(vm::PhysicalOpNode *node); + void SetIdCounter(size_t i) { + assert(i > id_counter_); + id_counter_ = i; + } + private: template void SetNodeUniqueId(T *node) { - node->SetNodeId(other_node_idx_counter_++); + node->SetNodeId(id_counter_++); } std::list node_list_; - - // unique id counter for various types of node - size_t expr_idx_counter_ = 1; - size_t type_idx_counter_ = 1; - size_t plan_idx_counter_ = 1; - size_t physical_plan_idx_counter_ = 1; - size_t other_node_idx_counter_ = 1; - size_t exprid_idx_counter_ = 0; + size_t id_counter_ = 0; + size_t expr_id_counter_ = 0; }; } // namespace node diff --git a/hybridse/include/node/sql_node.h b/hybridse/include/node/sql_node.h index a5c8579a9d9..b21b68bb49b 100644 --- a/hybridse/include/node/sql_node.h +++ b/hybridse/include/node/sql_node.h @@ -1520,22 +1520,28 @@ class FnDefNode : public SqlNode { class CastExprNode : public ExprNode { public: - explicit CastExprNode(const node::DataType cast_type, node::ExprNode *expr) + explicit CastExprNode(const node::TypeNode *cast_type, node::ExprNode *expr) : ExprNode(kExprCast), cast_type_(cast_type) { this->AddChild(expr); } - ~CastExprNode() {} - void Print(std::ostream &output, const std::string &org_tab) const; - const std::string GetExprString() const; - virtual bool Equals(const ExprNode *that) const; + void Print(std::ostream &output, const std::string &org_tab) const override; + const std::string GetExprString() const override; + bool Equals(const ExprNode *that) const override; CastExprNode *ShadowCopy(NodeManager *) const override; static CastExprNode *CastFrom(ExprNode *node); ExprNode *expr() const { return GetChild(0); } - const DataType cast_type_; + const TypeNode *cast_type() const { return cast_type_; } + + // legacy interface, required by offline batch + // pls use cast_type() as much as possible + node::DataType base_cast_type() const; Status InferAttr(ExprAnalysisContext *ctx) override; + + private: + const TypeNode *cast_type_; }; class WhenExprNode : public ExprNode { diff --git a/hybridse/src/codegen/cast_expr_ir_builder.cc b/hybridse/src/codegen/cast_expr_ir_builder.cc index 57e4103cba6..d971f08496f 100644 --- a/hybridse/src/codegen/cast_expr_ir_builder.cc +++ b/hybridse/src/codegen/cast_expr_ir_builder.cc @@ -66,8 +66,10 @@ bool CastExprIRBuilder::IsSafeCast(::llvm::Type* lhs, ::llvm::Type* rhs) { } Status CastExprIRBuilder::Cast(const NativeValue& value, ::llvm::Type* cast_type, NativeValue* output) { - CHECK_STATUS(TypeIRBuilder::BinaryOpTypeInfer(node::ExprNode::IsCastAccept, - value.GetType(), cast_type)); + if (value.GetType() == cast_type) { + *output = value; + return {}; + } if (IsSafeCast(value.GetType(), cast_type)) { CHECK_STATUS(SafeCast(value, cast_type, output)); } else { @@ -81,6 +83,7 @@ Status CastExprIRBuilder::SafeCast(const NativeValue& value, ::llvm::Type* dst_t CHECK_TRUE(IsSafeCast(value.GetType(), dst_type), kCodegenError, "Safe cast fail: unsafe cast"); Status status; if (value.IsConstNull()) { + // VOID type auto res = CreateSafeNull(block_, dst_type); CHECK_TRUE(res.ok(), kCodegenError, res.status().ToString()); *output = res.value(); @@ -114,6 +117,12 @@ Status CastExprIRBuilder::SafeCast(const NativeValue& value, ::llvm::Type* dst_t Status CastExprIRBuilder::UnSafeCast(const NativeValue& value, ::llvm::Type* dst_type, NativeValue* output) { ::llvm::IRBuilder<> builder(block_); + node::NodeManager nm; + const node::TypeNode* src_node = nullptr; + const node::TypeNode* dst_node = nullptr; + CHECK_TRUE(GetFullType(&nm, value.GetType(), &src_node), kCodegenError); + CHECK_TRUE(GetFullType(&nm, dst_type, &dst_node), kCodegenError); + if (value.IsConstNull() || (TypeIRBuilder::IsNumber(dst_type) && TypeIRBuilder::IsDatePtr(value.GetType()))) { // input is const null or (cast date to number) auto res = CreateSafeNull(block_, dst_type); @@ -135,6 +144,20 @@ Status CastExprIRBuilder::UnSafeCast(const NativeValue& value, ::llvm::Type* dst StringIRBuilder string_ir_builder(block_->getModule()); CHECK_STATUS(string_ir_builder.CastToNumber(block_, value, dst_type, output)); return Status::OK(); + } else if (src_node->IsMap() && dst_node->IsMap()) { + auto src_map_node = src_node->GetAsOrNull(); + assert(src_map_node != nullptr && "logic error: map type empty"); + if (src_map_node->GetGenericType(0)->IsNull() && src_map_node->GetGenericType(1)->IsNull()) { + auto s = StructTypeIRBuilder::CreateStructTypeIRBuilder(block_->getModule(), dst_type); + CHECK_TRUE(s.ok(), kCodegenError, s.status().ToString()); + llvm::Value* val = nullptr; + CHECK_TRUE(s.value()->CreateDefault(block_, &val), kCodegenError); + *output = NativeValue::Create(val); + return Status::OK(); + } else { + CHECK_TRUE(false, kCodegenError, "unimplemented: casting ", src_node->DebugString(), " to ", + dst_node->DebugString()); + } } else { Status status; ::llvm::Value* output_value = nullptr; diff --git a/hybridse/src/codegen/expr_ir_builder.cc b/hybridse/src/codegen/expr_ir_builder.cc index d6f75cd8b4a..b3ccffd5e0d 100644 --- a/hybridse/src/codegen/expr_ir_builder.cc +++ b/hybridse/src/codegen/expr_ir_builder.cc @@ -229,7 +229,7 @@ Status ExprIRBuilder::BuildConstExpr( ::llvm::IRBuilder<> builder(ctx_->GetCurrentBlock()); switch (const_node->GetDataType()) { case ::hybridse::node::kNull: { - *output = NativeValue(nullptr, nullptr, llvm::Type::getTokenTy(builder.getContext())); + *output = NativeValue(nullptr, nullptr, llvm::Type::getVoidTy(builder.getContext())); break; } case ::hybridse::node::kBool: { @@ -649,14 +649,10 @@ Status ExprIRBuilder::BuildCastExpr(const ::hybridse::node::CastExprNode* node, CastExprIRBuilder cast_builder(ctx_->GetCurrentBlock()); ::llvm::Type* cast_type = NULL; - CHECK_TRUE(GetLlvmType(ctx_->GetModule(), node->cast_type_, &cast_type), - kCodegenError, "Fail to cast expr: dist type invalid"); + CHECK_TRUE(GetLlvmType(ctx_->GetModule(), node->cast_type(), &cast_type), kCodegenError, + "Fail to cast expr: dist type invalid"); - if (cast_builder.IsSafeCast(left.GetType(), cast_type)) { - return cast_builder.SafeCast(left, cast_type, output); - } else { - return cast_builder.UnSafeCast(left, cast_type, output); - } + return cast_builder.Cast(left, cast_type, output); } Status ExprIRBuilder::BuildBinaryExpr(const ::hybridse::node::BinaryExpr* node, diff --git a/hybridse/src/codegen/insert_row_builder.cc b/hybridse/src/codegen/insert_row_builder.cc index bea754485c6..a4dd41eb4aa 100644 --- a/hybridse/src/codegen/insert_row_builder.cc +++ b/hybridse/src/codegen/insert_row_builder.cc @@ -16,6 +16,7 @@ #include "codegen/insert_row_builder.h" +#include #include #include #include @@ -28,7 +29,9 @@ #include "codegen/buf_ir_builder.h" #include "codegen/context.h" #include "codegen/expr_ir_builder.h" +#include "codegen/ir_base_builder.h" #include "node/node_manager.h" +#include "node/sql_node.h" #include "passes/resolve_fn_and_attrs.h" #include "udf/default_udf_library.h" #include "vm/jit_wrapper.h" @@ -36,6 +39,8 @@ namespace hybridse { namespace codegen { +static size_t MaxExprId(absl::Span); + InsertRowBuilder::InsertRowBuilder(vm::HybridSeJitWrapper* jit, const codec::Schema* schema) : schema_(schema), jit_(jit) {} @@ -63,6 +68,9 @@ absl::StatusOr InsertRowBuilder::ComputeRowUnsafe(absl::Span(absl::StrCat("insert_row_builder_", fn_counter_.load()), *llvm_ctx); vm::SchemasContext empty_sc; node::NodeManager nm; + // WORKAROUND. Set the id counter to the max of all input expr nodes, + // so there will no node id conflicts during codegen + nm.SetIdCounter(MaxExprId(values) + 1); codec::Schema empty_param_types; CodeGenContext dump_ctx(llvm_module.get(), &empty_sc, &empty_param_types, &nm); @@ -71,9 +79,16 @@ absl::StatusOr InsertRowBuilder::ComputeRowUnsafe(absl::Span transformed; - for (auto& expr : values) { + for (size_t i = 0; i < values.size(); i++) { + auto expr = values[i]; node::ExprNode* out = nullptr; CHECK_STATUS_TO_ABSL(resolver.VisitExpr(expr, &out)); + auto tgt_type = ColumnSchema2Type(schema_->Get(i).schema(), &nm); + CHECK_ABSL_STATUSOR(tgt_type); + if (!tgt_type.value()->Equals(out->GetOutputType())) { + auto cast = nm.MakeNode(tgt_type.value(), out); + CHECK_STATUS_TO_ABSL(resolver.VisitExpr(cast, &out)); + } transformed.push_back(out); } @@ -140,5 +155,15 @@ absl::StatusOr InsertRowBuilder::BuildFn(CodeGenContext* ctx, l return fn; } +size_t MaxExprId(absl::Span exprs) { + size_t ret = 0; + + for (auto& expr : exprs) { + ret = std::max(std::max(ret, expr->node_id()), MaxExprId(expr->children_)); + } + + return ret; +} + } // namespace codegen } // namespace hybridse diff --git a/hybridse/src/codegen/insert_row_builder.h b/hybridse/src/codegen/insert_row_builder.h index 3391bbe9e57..4576590eaf7 100644 --- a/hybridse/src/codegen/insert_row_builder.h +++ b/hybridse/src/codegen/insert_row_builder.h @@ -60,7 +60,6 @@ class InsertRowBuilder { absl::StatusOr BuildFn(CodeGenContext* ctx, llvm::StringRef fn_name, absl::Span); - // CodeGenContextBase* ctx_; const codec::Schema* schema_; vm::HybridSeJitWrapper* jit_; std::atomic fn_counter_ = 0; diff --git a/hybridse/src/codegen/insert_row_builder_test.cc b/hybridse/src/codegen/insert_row_builder_test.cc index 900fc9b2e04..6c4f761fddb 100644 --- a/hybridse/src/codegen/insert_row_builder_test.cc +++ b/hybridse/src/codegen/insert_row_builder_test.cc @@ -30,7 +30,7 @@ namespace codegen { class InsertRowBuilderTest : public ::testing::Test {}; TEST_F(InsertRowBuilderTest, encode) { - std::string sql = "insert into t1 values (1, map (1, '12'))"; + std::string sql = "insert into t1 values (1, map (1, '12'), null, map())"; vm::SqlContext ctx; ctx.sql = sql; auto s = plan::PlanAPI::CreatePlanTreeFromScript(&ctx); @@ -51,6 +51,18 @@ TEST_F(InsertRowBuilderTest, encode) { map_ty->mutable_key_type()->set_base_type(type::kInt32); map_ty->mutable_value_type()->set_base_type(type::kVarchar); } + { + auto col = sc.Add(); + auto map_ty = col->mutable_schema()->mutable_map_type(); + map_ty->mutable_key_type()->set_base_type(type::kFloat); + map_ty->mutable_value_type()->set_base_type(type::kTimestamp); + } + { + auto col = sc.Add(); + auto map_ty = col->mutable_schema()->mutable_map_type(); + map_ty->mutable_key_type()->set_base_type(type::kDate); + map_ty->mutable_value_type()->set_base_type(type::kVarchar); + } auto jit = std::shared_ptr(vm::HybridSeJitWrapper::Create()); ASSERT_TRUE(jit->Init()); diff --git a/hybridse/src/codegen/ir_base_builder.cc b/hybridse/src/codegen/ir_base_builder.cc index f51b9600725..72dc7c93de7 100644 --- a/hybridse/src/codegen/ir_base_builder.cc +++ b/hybridse/src/codegen/ir_base_builder.cc @@ -610,8 +610,8 @@ bool GetBaseType(::llvm::Type* type, ::hybridse::node::DataType* output) { return false; } switch (type->getTypeID()) { - case ::llvm::Type::TokenTyID: { - *output = ::hybridse::node::kNull; + case ::llvm::Type::VoidTyID: { + *output = ::hybridse::node::kVoid; return true; } case ::llvm::Type::FloatTyID: { diff --git a/hybridse/src/codegen/native_value.cc b/hybridse/src/codegen/native_value.cc index 23cd400922c..88592799720 100644 --- a/hybridse/src/codegen/native_value.cc +++ b/hybridse/src/codegen/native_value.cc @@ -100,10 +100,8 @@ bool NativeValue::IsNullable() const { return IsConstNull() || HasFlag(); } // NativeValue is null if: // - raw_ is null -// - type_ is of token type. -// Currently there is no elsewhere using token type, so assert token type should be safe. -// token type represents SQL NULL may not appropriate, more work refer #926 -bool NativeValue::IsConstNull() const { return raw_ == nullptr || (type_ != nullptr && type_->isTokenTy()); } +// - type_ is of void type. +bool NativeValue::IsConstNull() const { return raw_ == nullptr || (type_ != nullptr && type_->isVoidTy()); } void NativeValue::SetName(const std::string& name) { if (raw_ == nullptr) { diff --git a/hybridse/src/codegen/struct_ir_builder.cc b/hybridse/src/codegen/struct_ir_builder.cc index 0d08e89aefb..4b00b052c29 100644 --- a/hybridse/src/codegen/struct_ir_builder.cc +++ b/hybridse/src/codegen/struct_ir_builder.cc @@ -21,8 +21,10 @@ #include "codegen/context.h" #include "codegen/date_ir_builder.h" #include "codegen/ir_base_builder.h" +#include "codegen/map_ir_builder.h" #include "codegen/string_ir_builder.h" #include "codegen/timestamp_ir_builder.h" +#include "node/node_manager.h" namespace hybridse { namespace codegen { @@ -40,19 +42,34 @@ bool StructTypeIRBuilder::StructCopyFrom(::llvm::BasicBlock* block, ::llvm::Valu absl::StatusOr> StructTypeIRBuilder::CreateStructTypeIRBuilder( ::llvm::Module* m, ::llvm::Type* type) { - node::DataType base_type; - if (!GetBaseType(type, &base_type)) { - return absl::UnimplementedError( - absl::StrCat("fail to create struct type ir builder for ", GetLlvmObjectString(type))); + node::NodeManager nm; + const node::TypeNode* ctype = nullptr; + if (!GetFullType(&nm, type, &ctype)) { + return absl::InvalidArgumentError(absl::StrCat("can't get full type for: ", GetLlvmObjectString(type))); } - switch (base_type) { + switch (ctype->base()) { case node::kTimestamp: return std::make_unique(m); case node::kDate: return std::make_unique(m); case node::kVarchar: return std::make_unique(m); + case node::DataType::kMap: { + assert(ctype->IsMap() && "logic error: not a map type"); + auto map_type = ctype->GetAsOrNull(); + assert(map_type != nullptr && "logic error: map type empty"); + ::llvm::Type* key_type = nullptr; + ::llvm::Type* value_type = nullptr; + if (codegen::GetLlvmType(m, map_type->key_type(), &key_type) && + codegen::GetLlvmType(m, map_type->value_type(), &value_type)) { + return std::make_unique(m, key_type, value_type); + } else { + return absl::InvalidArgumentError( + absl::Substitute("not able to casting map type: $0", GetLlvmObjectString(type))); + } + break; + } default: { break; } @@ -224,5 +241,16 @@ absl::StatusOr> StructTypeIRBuilder::Load(CodeGenConte return res; } + +absl::StatusOr CreateSafeNull(::llvm::BasicBlock* block, ::llvm::Type* type) { + if (TypeIRBuilder::IsStructPtr(type)) { + auto s = StructTypeIRBuilder::CreateStructTypeIRBuilder(block->getModule(), type); + CHECK_ABSL_STATUSOR(s); + return s.value()->CreateNull(block); + } + + return NativeValue(nullptr, nullptr, type); +} + } // namespace codegen } // namespace hybridse diff --git a/hybridse/src/codegen/struct_ir_builder.h b/hybridse/src/codegen/struct_ir_builder.h index 4c09e488ce9..8529c2d7848 100644 --- a/hybridse/src/codegen/struct_ir_builder.h +++ b/hybridse/src/codegen/struct_ir_builder.h @@ -34,6 +34,8 @@ class StructTypeIRBuilder : public TypeIRBuilder { explicit StructTypeIRBuilder(::llvm::Module*); ~StructTypeIRBuilder(); + // construct corresponding struct ir builder if exists for input type, + // otherwise, error status returned static absl::StatusOr> CreateStructTypeIRBuilder(::llvm::Module*, ::llvm::Type*); static bool StructCopyFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value* dist); @@ -93,6 +95,11 @@ class StructTypeIRBuilder : public TypeIRBuilder { ::llvm::Module* m_; ::llvm::StructType* struct_type_; }; + +// construct a safe null value for type +// returns NativeValue{raw, is_null=true} on success, raw is ensured to be not nullptr +absl::StatusOr CreateSafeNull(::llvm::BasicBlock* block, ::llvm::Type* type); + } // namespace codegen } // namespace hybridse #endif // HYBRIDSE_SRC_CODEGEN_STRUCT_IR_BUILDER_H_ diff --git a/hybridse/src/codegen/type_ir_builder.cc b/hybridse/src/codegen/type_ir_builder.cc index 0cba6015b9d..e49f1420787 100644 --- a/hybridse/src/codegen/type_ir_builder.cc +++ b/hybridse/src/codegen/type_ir_builder.cc @@ -16,11 +16,8 @@ #include "codegen/type_ir_builder.h" -#include "absl/status/status.h" -#include "codegen/date_ir_builder.h" +#include "base/fe_status.h" #include "codegen/ir_base_builder.h" -#include "codegen/string_ir_builder.h" -#include "codegen/timestamp_ir_builder.h" #include "node/node_manager.h" namespace hybridse { @@ -54,7 +51,7 @@ bool TypeIRBuilder::IsBool(::llvm::Type* type) { return data_type == node::kBool; } -bool TypeIRBuilder::IsNull(::llvm::Type* type) { return type->isTokenTy(); } +bool TypeIRBuilder::IsNull(::llvm::Type* type) { return type->isVoidTy(); } bool TypeIRBuilder::IsInterger(::llvm::Type* type) { return type->isIntegerTy(); @@ -132,37 +129,5 @@ base::Status TypeIRBuilder::BinaryOpTypeInfer( return base::Status::OK(); } -absl::StatusOr CreateSafeNull(::llvm::BasicBlock* block, ::llvm::Type* type) { - node::DataType data_type; - if (!GetBaseType(type, &data_type)) { - return absl::InvalidArgumentError(absl::StrCat("can't get base type for: ", GetLlvmObjectString(type))); - } - - if (TypeIRBuilder::IsStructPtr(type)) { - std::unique_ptr builder = nullptr; - - switch (data_type) { - case node::DataType::kTimestamp: { - builder.reset(new TimestampIRBuilder(block->getModule())); - break; - } - case node::DataType::kDate: { - builder.reset(new DateIRBuilder(block->getModule())); - break; - } - case node::DataType::kVarchar: { - builder.reset(new StringIRBuilder(block->getModule())); - break; - } - default: - return absl::InvalidArgumentError(absl::StrCat("invalid struct type: ", GetLlvmObjectString(type))); - } - - return builder->CreateNull(block); - } - - return NativeValue(nullptr, nullptr, type); -} - } // namespace codegen } // namespace hybridse diff --git a/hybridse/src/codegen/type_ir_builder.h b/hybridse/src/codegen/type_ir_builder.h index e68d7f0233b..1ecbed9382d 100644 --- a/hybridse/src/codegen/type_ir_builder.h +++ b/hybridse/src/codegen/type_ir_builder.h @@ -92,10 +92,6 @@ class BoolIRBuilder : public TypeIRBuilder { } }; -// construct a safe null value for type -// returns NativeValue{raw, is_null=true} on success, raw is ensured to be not nullptr -absl::StatusOr CreateSafeNull(::llvm::BasicBlock* block, ::llvm::Type* type); - } // namespace codegen } // namespace hybridse #endif // HYBRIDSE_SRC_CODEGEN_TYPE_IR_BUILDER_H_ diff --git a/hybridse/src/node/expr_node.cc b/hybridse/src/node/expr_node.cc index 94b7613d3cc..bfc46b2867e 100644 --- a/hybridse/src/node/expr_node.cc +++ b/hybridse/src/node/expr_node.cc @@ -119,8 +119,9 @@ Status ExprIdNode::InferAttr(ExprAnalysisContext* ctx) { return Status::OK(); } +node::DataType CastExprNode::base_cast_type() const { return cast_type_->base(); } Status CastExprNode::InferAttr(ExprAnalysisContext* ctx) { - SetOutputType(ctx->node_manager()->MakeTypeNode(cast_type_)); + SetOutputType(cast_type_); return Status::OK(); } @@ -281,17 +282,28 @@ absl::StatusOr ExprNode::CompatibleType(NodeManager* nm, const /** * support rules: -* case target_type -* bool -> from_type is bool -* int* -> from_type is bool or from_type is equal/smaller integral type -* float|double -> from_type is bool or equal/smaller float type -* timestamp -> from_type is timestamp or integral type +* 1. case target_type +* bool -> from_type is bool +* intXX -> from_type is bool or from_type is equal/smaller integral type +* float | double -> from_type is bool or equal/smaller float type +* timestamp -> from_type is timestamp or integral type +* string | date -> not convertible from other type +* MAP -> +* from_type: MAP (consturct by map()) -> OK +* from_type: MAP -> SafeCast(K -> KEY) && SafeCast(V -> VALUE) +* +* 2. from_type of NOT_NULL = false can not cast to target_type of NOT_NULL = True +* TODO(someone): TypeNode should contains NOT_NULL ATtribute. */ bool ExprNode::IsSafeCast(const TypeNode* from_type, const TypeNode* target_type) { if (from_type == nullptr || target_type == nullptr) { return false; } + if (from_type->IsNull()) { + // VOID -> T + return true; + } if (TypeEquals(from_type, target_type)) { return true; } @@ -314,8 +326,9 @@ bool ExprNode::IsSafeCast(const TypeNode* from_type, case kTimestamp: return from_base == kTimestamp || from_type->IsInteger(); default: - return false; + break; } + return false; } bool ExprNode::IsIntFloat2PointerCast(const TypeNode* lhs, @@ -895,7 +908,7 @@ ExprIdNode* ExprIdNode::ShadowCopy(NodeManager* nm) const { } CastExprNode* CastExprNode::ShadowCopy(NodeManager* nm) const { - return nm->MakeCastNode(cast_type_, GetChild(0)); + return nm->MakeNode(cast_type_, GetChild(0)); } WhenExprNode* WhenExprNode::ShadowCopy(NodeManager* nm) const { diff --git a/hybridse/src/node/node_manager.cc b/hybridse/src/node/node_manager.cc index 77689a61927..59d581b3039 100644 --- a/hybridse/src/node/node_manager.cc +++ b/hybridse/src/node/node_manager.cc @@ -322,7 +322,7 @@ ColumnRefNode *NodeManager::MakeColumnRefNode(const std::string &column_name, co return MakeColumnRefNode(column_name, relation_name, ""); } CastExprNode *NodeManager::MakeCastNode(const node::DataType cast_type, ExprNode *expr) { - CastExprNode *node_ptr = new CastExprNode(cast_type, expr); + CastExprNode *node_ptr = new CastExprNode(MakeNode(cast_type), expr); return RegisterNode(node_ptr); } WhenExprNode *NodeManager::MakeWhenNode(ExprNode *when_expr, ExprNode *then_expr) { @@ -416,7 +416,7 @@ ParameterExpr *NodeManager::MakeParameterExpr(int position) { return RegisterNode(node_ptr); } ExprIdNode *NodeManager::MakeExprIdNode(const std::string &name) { - return RegisterNode(new ::hybridse::node::ExprIdNode(name, exprid_idx_counter_++)); + return RegisterNode(new ::hybridse::node::ExprIdNode(name, expr_id_counter_++)); } ExprIdNode *NodeManager::MakeUnresolvedExprId(const std::string &name) { return RegisterNode(new ::hybridse::node::ExprIdNode(name, -1)); @@ -1065,10 +1065,5 @@ SqlNode *NodeManager::MakeInputParameterNode(bool is_constant, const std::string return RegisterNode(node_ptr); } -void NodeManager::SetNodeUniqueId(ExprNode *node) { node->SetNodeId(expr_idx_counter_++); } -void NodeManager::SetNodeUniqueId(TypeNode *node) { node->SetNodeId(type_idx_counter_++); } -void NodeManager::SetNodeUniqueId(PlanNode *node) { node->SetNodeId(plan_idx_counter_++); } -void NodeManager::SetNodeUniqueId(vm::PhysicalOpNode *node) { node->SetNodeId(physical_plan_idx_counter_++); } - } // namespace node } // namespace hybridse diff --git a/hybridse/src/node/sql_node.cc b/hybridse/src/node/sql_node.cc index 3fc8c067ca6..4e35793e1b4 100644 --- a/hybridse/src/node/sql_node.cc +++ b/hybridse/src/node/sql_node.cc @@ -851,12 +851,12 @@ void CastExprNode::Print(std::ostream &output, const std::string &org_tab) const ExprNode::Print(output, org_tab); output << "\n"; const std::string tab = org_tab + INDENT + SPACE_ED; - PrintValue(output, tab, DataTypeName(cast_type_), "cast_type", false); + PrintValue(output, tab, cast_type_->DebugString(), "cast_type", false); output << "\n"; PrintSqlNode(output, tab, expr(), "expr", true); } const std::string CastExprNode::GetExprString() const { - std::string str = DataTypeName(cast_type_); + std::string str = cast_type_->DebugString(); str.append("(").append(ExprString(expr())).append(")"); return str; } @@ -868,7 +868,7 @@ bool CastExprNode::Equals(const ExprNode *node) const { return false; } const CastExprNode *that = dynamic_cast(node); - return this->cast_type_ == that->cast_type_ && ExprEquals(expr(), that->expr()); + return TypeEquals(cast_type_, that->cast_type()) && ExprEquals(expr(), that->expr()); } CastExprNode *CastExprNode::CastFrom(ExprNode *node) { return dynamic_cast(node); } diff --git a/hybridse/src/passes/lambdafy_projects.cc b/hybridse/src/passes/lambdafy_projects.cc index dc8f6380b9a..aa30291ec51 100644 --- a/hybridse/src/passes/lambdafy_projects.cc +++ b/hybridse/src/passes/lambdafy_projects.cc @@ -208,6 +208,11 @@ Status LambdafyProjects::VisitLeafExpr(node::ExprNode* expr, *out = expr; break; } + case node::kExprCall: { + // fn with empty args + *out = expr; + break; + } default: FAIL_STATUS(common::kCodegenError, "Unknown leaf expr type: " + ExprTypeName(expr->GetExprType())) } diff --git a/hybridse/src/planv2/ast_node_converter.cc b/hybridse/src/planv2/ast_node_converter.cc index 0427b5c6ba6..46bb285f993 100644 --- a/hybridse/src/planv2/ast_node_converter.cc +++ b/hybridse/src/planv2/ast_node_converter.cc @@ -399,8 +399,7 @@ base::Status ConvertExprNode(const zetasql::ASTExpression* ast_expression, node: node::TypeNode* tp = nullptr; CHECK_STATUS(ConvertASTType(cast_expression->type(), node_manager, &tp)) - // TODO(ace): cast from base type is not enough for type like array - *output = node_manager->MakeCastNode(tp->base(), expr_node); + *output = node_manager->MakeNode(tp, expr_node); return base::Status::OK(); } case zetasql::AST_PARAMETER_EXPR: { diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/SparkRowCodec.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/SparkRowCodec.scala index f89a97d2d4d..478d570a669 100644 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/SparkRowCodec.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/SparkRowCodec.scala @@ -418,8 +418,6 @@ class SparkRowCodec(sliceSchemas: Array[StructType]) { args.AddChild(valToNativeExpr(kv._1, keyType, nm)) args.AddChild(valToNativeExpr(kv._2, valType, nm)) }) - // TODO(someone): support empty map, since 'map()' inferred as map - // we need construst a extra cast operation to hint the true type from schema nm.MakeFuncNode("map", args, null) } case _ => throw new IllegalArgumentException( diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/ConstProjectPlan.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/ConstProjectPlan.scala index 48927ef3fb4..f8eb326a103 100644 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/ConstProjectPlan.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/ConstProjectPlan.scala @@ -72,7 +72,7 @@ object ConstProjectPlan { case ExprType.kExprCast => val cast = CastExprNode.CastFrom(expr) - val castType = cast.getCast_type_ + val castType = cast.base_cast_type val (childCol, childType) = createSparkColumn(spark, cast.GetChild(0)) val castColumn = castSparkOutputCol(spark, childCol, childType, castType) castColumn -> castType diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/SimpleProjectPlan.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/SimpleProjectPlan.scala index 2e1d1c429c4..950bb92356e 100644 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/SimpleProjectPlan.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/SimpleProjectPlan.scala @@ -109,7 +109,7 @@ object SimpleProjectPlan { case ExprType.kExprCast => val cast = CastExprNode.CastFrom(expr) - val castType = cast.getCast_type_ + val castType = cast.base_cast_type val (childCol, childType) = createSparkColumn(spark, inputDf, node, cast.GetChild(0)) val castColumn = ConstProjectPlan.castSparkOutputCol( diff --git a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/ExpressionUtil.scala b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/ExpressionUtil.scala index c849104ddc5..62448414e47 100644 --- a/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/ExpressionUtil.scala +++ b/java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/utils/ExpressionUtil.scala @@ -175,7 +175,7 @@ object ExpressionUtil { ExpressionUtil.constExprToSparkColumn(const) case ExprType.kExprCast => val cast = CastExprNode.CastFrom(expr) - val castType = cast.getCast_type_ + val castType = cast.base_cast_type val childCol = recursiveGetSparkColumnFromExpr(cast.GetChild(0), node, leftDf, rightDf, hasIndexColumn) // Convert OpenMLDB node datatype to Spark datatype childCol.cast(DataTypeUtil.openmldbTypeToSparkType(castType)) diff --git a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/TestProjectPlan.scala b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/TestProjectPlan.scala index a50bb6ca56c..c0170aa1ed2 100644 --- a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/TestProjectPlan.scala +++ b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/TestProjectPlan.scala @@ -17,6 +17,7 @@ package com._4paradigm.openmldb.batch import com._4paradigm.openmldb.batch.utils.SparkUtil +import com._4paradigm.openmldb.batch.api.OpenmldbSession import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.{DateType, DoubleType, IntegerType, StringType, StructField, StructType} @@ -52,24 +53,12 @@ class TestProjectPlan extends SparkTestSuite { } test("Test const project") { - val sess = getSparkSession - - val schema = StructType(Seq( - StructField("1", IntegerType, nullable = false), - StructField("3.500000", DoubleType, nullable = false), - StructField("a", StringType, nullable = false), - StructField("date(2024-03-25)", DateType, nullable = true), - StructField("string(int32(int64(1)))", StringType, nullable = false) - )) - val expectDf = sess.createDataFrame(Seq( - (1, 3.5d, "a", Date.valueOf("2024-03-25"), "1") - ).map(Row.fromTuple(_)).asJava, schema) - - val planner = new SparkPlanner(sess) - val res = planner.plan("select 1, 3.5, \"a\", date('2024-03-25'), string(int(bigint(1)));", - mutable.HashMap[String, mutable.Map[String, DataFrame]]()) - val output = res.getDf() + val sess = new OpenmldbSession(getSparkSession) + val sql = "select 1, 3.5, 'a', date('2024-03-25'), string(int(bigint(1)));" + val res = sess.sql(sql) + res.show() + val sparkDf = sess.sparksql(sql) - assert(SparkUtil.approximateDfEqual(expectDf, output)) + assert(SparkUtil.approximateDfEqual(sparkDf, res.getSparkDf(), false)) } } diff --git a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestProject.scala b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestProject.scala index 09990296c9b..e6260a2ab48 100644 --- a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestProject.scala +++ b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/end2end/TestProject.scala @@ -49,8 +49,9 @@ class TestProject extends SparkTestSuite { val data = Seq( Row(1, Map.apply(1 -> "11", 12 -> "99")), - Row(2, Map.apply(13 -> "99"))) - // Row(2, Map.empty[Int, String])) + Row(2, Map.apply(13 -> "99")), + Row(3, Map.empty[Int, String]), + Row(4, null)) val schema = StructType(List( StructField("id", IntegerType), StructField("val", MapType(IntegerType, StringType)))) diff --git a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/utils/TestGraphvizUtil.scala b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/utils/TestGraphvizUtil.scala index e56423039f9..b9d5523ebbb 100644 --- a/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/utils/TestGraphvizUtil.scala +++ b/java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/utils/TestGraphvizUtil.scala @@ -86,7 +86,7 @@ class TestGraphvizUtil extends SparkTestSuite { engine.close() } - assert(mutablenode.toString == "[65]GroupAgg{}->[22]GroupBy::") + assert(mutablenode.toString == "[87]GroupAgg{}->[22]GroupBy::") } test("Test visitPhysicalOp") { @@ -103,6 +103,6 @@ class TestGraphvizUtil extends SparkTestSuite { if (engine != null) { engine.close() } - assert(mutablenode.toString == "[65]GroupAgg{}->") + assert(mutablenode.toString == "[87]GroupAgg{}->") } }