Skip to content

Commit

Permalink
feat(map): cast null or map() to explicit map type (#3847)
Browse files Browse the repository at this point in the history
Auto casting to target type when necessary, which usually happens with
insert statement while a default schema defined by target table. This
supports constructing null or empty values from offline engine.
  • Loading branch information
aceforeverd committed Apr 17, 2024
1 parent d31526a commit 8934906
Show file tree
Hide file tree
Showing 27 changed files with 197 additions and 133 deletions.
2 changes: 1 addition & 1 deletion cases/plan/create.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,7 @@ cases:
| +-0:
| | +-node[kColumnDesc]
| | +-column_name: column1
| | +-column_type: string DEFAULT string(1)
| | +-column_type: string DEFAULT STRING(1)
| +-1:
| +-node[kColumnDesc]
| +-column_name: column3
Expand Down
15 changes: 15 additions & 0 deletions cases/query/udf_query.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ cases:
# ================================================================
# Map data type
# FIXME: request mode tests disabled, because TestRequestEngineForLastRow cause SEG FAULT
# ================================================================
- id: 13
mode: request-unsupport
Expand Down Expand Up @@ -637,3 +638,17 @@ cases:
1, abc
2, null
- id: 19
mode: request-unsupport
desc: empty or null map
sql: |
select cast (null as map<int, string>)[0] as o1,
cast (null as map<string, int>) ["12"] as o2,
cast (map() as map<string, int64>) ["12"] as o3,
cast (map() as map<int, timestamp>) [7] as o4,
cast (map(7, "9") as map<int, string>) [7] as o5,
cast (map() as map<date, timestamp>) [date("2012-12-12")] as o6,
expect:
columns: ["o1 string", "o2 int", "o3 int64", "o4 timestamp", "o5 string", "o6 timestamp"]
data: |
NULL, NULL, NULL, NULL, 9, NULL
22 changes: 8 additions & 14 deletions hybridse/include/node/node_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,26 +410,20 @@ class NodeManager {
return node_ptr;
}

private:
void SetNodeUniqueId(ExprNode *node);
void SetNodeUniqueId(TypeNode *node);
void SetNodeUniqueId(PlanNode *node);
void SetNodeUniqueId(vm::PhysicalOpNode *node);
void SetIdCounter(size_t i) {
assert(i > id_counter_);
id_counter_ = i;
}

private:
template <typename T>
void SetNodeUniqueId(T *node) {
node->SetNodeId(other_node_idx_counter_++);
node->SetNodeId(id_counter_++);
}

std::list<base::FeBaseObject *> node_list_;

// unique id counter for various types of node
size_t expr_idx_counter_ = 1;
size_t type_idx_counter_ = 1;
size_t plan_idx_counter_ = 1;
size_t physical_plan_idx_counter_ = 1;
size_t other_node_idx_counter_ = 1;
size_t exprid_idx_counter_ = 0;
size_t id_counter_ = 0;
size_t expr_id_counter_ = 0;
};

} // namespace node
Expand Down
18 changes: 12 additions & 6 deletions hybridse/include/node/sql_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -1520,22 +1520,28 @@ class FnDefNode : public SqlNode {

class CastExprNode : public ExprNode {
public:
explicit CastExprNode(const node::DataType cast_type, node::ExprNode *expr)
explicit CastExprNode(const node::TypeNode *cast_type, node::ExprNode *expr)
: ExprNode(kExprCast), cast_type_(cast_type) {
this->AddChild(expr);
}

~CastExprNode() {}
void Print(std::ostream &output, const std::string &org_tab) const;
const std::string GetExprString() const;
virtual bool Equals(const ExprNode *that) const;
void Print(std::ostream &output, const std::string &org_tab) const override;
const std::string GetExprString() const override;
bool Equals(const ExprNode *that) const override;
CastExprNode *ShadowCopy(NodeManager *) const override;
static CastExprNode *CastFrom(ExprNode *node);

ExprNode *expr() const { return GetChild(0); }
const DataType cast_type_;
const TypeNode *cast_type() const { return cast_type_; }

// legacy interface, required by offline batch
// pls use cast_type() as much as possible
node::DataType base_cast_type() const;

Status InferAttr(ExprAnalysisContext *ctx) override;

private:
const TypeNode *cast_type_;
};

class WhenExprNode : public ExprNode {
Expand Down
27 changes: 25 additions & 2 deletions hybridse/src/codegen/cast_expr_ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,10 @@ bool CastExprIRBuilder::IsSafeCast(::llvm::Type* lhs, ::llvm::Type* rhs) {
}
Status CastExprIRBuilder::Cast(const NativeValue& value,
::llvm::Type* cast_type, NativeValue* output) {
CHECK_STATUS(TypeIRBuilder::BinaryOpTypeInfer(node::ExprNode::IsCastAccept,
value.GetType(), cast_type));
if (value.GetType() == cast_type) {
*output = value;
return {};
}
if (IsSafeCast(value.GetType(), cast_type)) {
CHECK_STATUS(SafeCast(value, cast_type, output));
} else {
Expand All @@ -81,6 +83,7 @@ Status CastExprIRBuilder::SafeCast(const NativeValue& value, ::llvm::Type* dst_t
CHECK_TRUE(IsSafeCast(value.GetType(), dst_type), kCodegenError, "Safe cast fail: unsafe cast");
Status status;
if (value.IsConstNull()) {
// VOID type
auto res = CreateSafeNull(block_, dst_type);
CHECK_TRUE(res.ok(), kCodegenError, res.status().ToString());
*output = res.value();
Expand Down Expand Up @@ -114,6 +117,12 @@ Status CastExprIRBuilder::SafeCast(const NativeValue& value, ::llvm::Type* dst_t

Status CastExprIRBuilder::UnSafeCast(const NativeValue& value, ::llvm::Type* dst_type, NativeValue* output) {
::llvm::IRBuilder<> builder(block_);
node::NodeManager nm;
const node::TypeNode* src_node = nullptr;
const node::TypeNode* dst_node = nullptr;
CHECK_TRUE(GetFullType(&nm, value.GetType(), &src_node), kCodegenError);
CHECK_TRUE(GetFullType(&nm, dst_type, &dst_node), kCodegenError);

if (value.IsConstNull() || (TypeIRBuilder::IsNumber(dst_type) && TypeIRBuilder::IsDatePtr(value.GetType()))) {
// input is const null or (cast date to number)
auto res = CreateSafeNull(block_, dst_type);
Expand All @@ -135,6 +144,20 @@ Status CastExprIRBuilder::UnSafeCast(const NativeValue& value, ::llvm::Type* dst
StringIRBuilder string_ir_builder(block_->getModule());
CHECK_STATUS(string_ir_builder.CastToNumber(block_, value, dst_type, output));
return Status::OK();
} else if (src_node->IsMap() && dst_node->IsMap()) {
auto src_map_node = src_node->GetAsOrNull<node::MapType>();
assert(src_map_node != nullptr && "logic error: map type empty");
if (src_map_node->GetGenericType(0)->IsNull() && src_map_node->GetGenericType(1)->IsNull()) {
auto s = StructTypeIRBuilder::CreateStructTypeIRBuilder(block_->getModule(), dst_type);
CHECK_TRUE(s.ok(), kCodegenError, s.status().ToString());
llvm::Value* val = nullptr;
CHECK_TRUE(s.value()->CreateDefault(block_, &val), kCodegenError);
*output = NativeValue::Create(val);
return Status::OK();
} else {
CHECK_TRUE(false, kCodegenError, "unimplemented: casting ", src_node->DebugString(), " to ",
dst_node->DebugString());
}
} else {
Status status;
::llvm::Value* output_value = nullptr;
Expand Down
12 changes: 4 additions & 8 deletions hybridse/src/codegen/expr_ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ Status ExprIRBuilder::BuildConstExpr(
::llvm::IRBuilder<> builder(ctx_->GetCurrentBlock());
switch (const_node->GetDataType()) {
case ::hybridse::node::kNull: {
*output = NativeValue(nullptr, nullptr, llvm::Type::getTokenTy(builder.getContext()));
*output = NativeValue(nullptr, nullptr, llvm::Type::getVoidTy(builder.getContext()));
break;
}
case ::hybridse::node::kBool: {
Expand Down Expand Up @@ -649,14 +649,10 @@ Status ExprIRBuilder::BuildCastExpr(const ::hybridse::node::CastExprNode* node,

CastExprIRBuilder cast_builder(ctx_->GetCurrentBlock());
::llvm::Type* cast_type = NULL;
CHECK_TRUE(GetLlvmType(ctx_->GetModule(), node->cast_type_, &cast_type),
kCodegenError, "Fail to cast expr: dist type invalid");
CHECK_TRUE(GetLlvmType(ctx_->GetModule(), node->cast_type(), &cast_type), kCodegenError,
"Fail to cast expr: dist type invalid");

if (cast_builder.IsSafeCast(left.GetType(), cast_type)) {
return cast_builder.SafeCast(left, cast_type, output);
} else {
return cast_builder.UnSafeCast(left, cast_type, output);
}
return cast_builder.Cast(left, cast_type, output);
}

Status ExprIRBuilder::BuildBinaryExpr(const ::hybridse::node::BinaryExpr* node,
Expand Down
27 changes: 26 additions & 1 deletion hybridse/src/codegen/insert_row_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "codegen/insert_row_builder.h"

#include <algorithm>
#include <map>
#include <string>
#include <utility>
Expand All @@ -28,14 +29,18 @@
#include "codegen/buf_ir_builder.h"
#include "codegen/context.h"
#include "codegen/expr_ir_builder.h"
#include "codegen/ir_base_builder.h"
#include "node/node_manager.h"
#include "node/sql_node.h"
#include "passes/resolve_fn_and_attrs.h"
#include "udf/default_udf_library.h"
#include "vm/jit_wrapper.h"

namespace hybridse {
namespace codegen {

static size_t MaxExprId(absl::Span<node::ExprNode* const>);

InsertRowBuilder::InsertRowBuilder(vm::HybridSeJitWrapper* jit, const codec::Schema* schema)
: schema_(schema), jit_(jit) {}

Expand Down Expand Up @@ -63,6 +68,9 @@ absl::StatusOr<int8_t*> InsertRowBuilder::ComputeRowUnsafe(absl::Span<node::Expr
llvm::make_unique<llvm::Module>(absl::StrCat("insert_row_builder_", fn_counter_.load()), *llvm_ctx);
vm::SchemasContext empty_sc;
node::NodeManager nm;
// WORKAROUND. Set the id counter to the max of all input expr nodes,
// so there will no node id conflicts during codegen
nm.SetIdCounter(MaxExprId(values) + 1);
codec::Schema empty_param_types;
CodeGenContext dump_ctx(llvm_module.get(), &empty_sc, &empty_param_types, &nm);

Expand All @@ -71,9 +79,16 @@ absl::StatusOr<int8_t*> InsertRowBuilder::ComputeRowUnsafe(absl::Span<node::Expr
passes::ResolveFnAndAttrs resolver(&expr_ctx);

std::vector<node::ExprNode*> transformed;
for (auto& expr : values) {
for (size_t i = 0; i < values.size(); i++) {
auto expr = values[i];
node::ExprNode* out = nullptr;
CHECK_STATUS_TO_ABSL(resolver.VisitExpr(expr, &out));
auto tgt_type = ColumnSchema2Type(schema_->Get(i).schema(), &nm);
CHECK_ABSL_STATUSOR(tgt_type);
if (!tgt_type.value()->Equals(out->GetOutputType())) {
auto cast = nm.MakeNode<node::CastExprNode>(tgt_type.value(), out);
CHECK_STATUS_TO_ABSL(resolver.VisitExpr(cast, &out));
}
transformed.push_back(out);
}

Expand Down Expand Up @@ -140,5 +155,15 @@ absl::StatusOr<llvm::Function*> InsertRowBuilder::BuildFn(CodeGenContext* ctx, l
return fn;
}

size_t MaxExprId(absl::Span<node::ExprNode* const> exprs) {
size_t ret = 0;

for (auto& expr : exprs) {
ret = std::max(std::max(ret, expr->node_id()), MaxExprId(expr->children_));
}

return ret;
}

} // namespace codegen
} // namespace hybridse
1 change: 0 additions & 1 deletion hybridse/src/codegen/insert_row_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ class InsertRowBuilder {
absl::StatusOr<llvm::Function*> BuildFn(CodeGenContext* ctx, llvm::StringRef fn_name,
absl::Span<node::ExprNode* const>);

// CodeGenContextBase* ctx_;
const codec::Schema* schema_;
vm::HybridSeJitWrapper* jit_;
std::atomic<uint32_t> fn_counter_ = 0;
Expand Down
14 changes: 13 additions & 1 deletion hybridse/src/codegen/insert_row_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace codegen {
class InsertRowBuilderTest : public ::testing::Test {};

TEST_F(InsertRowBuilderTest, encode) {
std::string sql = "insert into t1 values (1, map (1, '12'))";
std::string sql = "insert into t1 values (1, map (1, '12'), null, map())";
vm::SqlContext ctx;
ctx.sql = sql;
auto s = plan::PlanAPI::CreatePlanTreeFromScript(&ctx);
Expand All @@ -51,6 +51,18 @@ TEST_F(InsertRowBuilderTest, encode) {
map_ty->mutable_key_type()->set_base_type(type::kInt32);
map_ty->mutable_value_type()->set_base_type(type::kVarchar);
}
{
auto col = sc.Add();
auto map_ty = col->mutable_schema()->mutable_map_type();
map_ty->mutable_key_type()->set_base_type(type::kFloat);
map_ty->mutable_value_type()->set_base_type(type::kTimestamp);
}
{
auto col = sc.Add();
auto map_ty = col->mutable_schema()->mutable_map_type();
map_ty->mutable_key_type()->set_base_type(type::kDate);
map_ty->mutable_value_type()->set_base_type(type::kVarchar);
}

auto jit = std::shared_ptr<vm::HybridSeJitWrapper>(vm::HybridSeJitWrapper::Create());
ASSERT_TRUE(jit->Init());
Expand Down
4 changes: 2 additions & 2 deletions hybridse/src/codegen/ir_base_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -610,8 +610,8 @@ bool GetBaseType(::llvm::Type* type, ::hybridse::node::DataType* output) {
return false;
}
switch (type->getTypeID()) {
case ::llvm::Type::TokenTyID: {
*output = ::hybridse::node::kNull;
case ::llvm::Type::VoidTyID: {
*output = ::hybridse::node::kVoid;
return true;
}
case ::llvm::Type::FloatTyID: {
Expand Down
6 changes: 2 additions & 4 deletions hybridse/src/codegen/native_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,8 @@ bool NativeValue::IsNullable() const { return IsConstNull() || HasFlag(); }

// NativeValue is null if:
// - raw_ is null
// - type_ is of token type.
// Currently there is no elsewhere using token type, so assert token type should be safe.
// token type represents SQL NULL may not appropriate, more work refer #926
bool NativeValue::IsConstNull() const { return raw_ == nullptr || (type_ != nullptr && type_->isTokenTy()); }
// - type_ is of void type.
bool NativeValue::IsConstNull() const { return raw_ == nullptr || (type_ != nullptr && type_->isVoidTy()); }

void NativeValue::SetName(const std::string& name) {
if (raw_ == nullptr) {
Expand Down
38 changes: 33 additions & 5 deletions hybridse/src/codegen/struct_ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
#include "codegen/context.h"
#include "codegen/date_ir_builder.h"
#include "codegen/ir_base_builder.h"
#include "codegen/map_ir_builder.h"
#include "codegen/string_ir_builder.h"
#include "codegen/timestamp_ir_builder.h"
#include "node/node_manager.h"

namespace hybridse {
namespace codegen {
Expand All @@ -40,19 +42,34 @@ bool StructTypeIRBuilder::StructCopyFrom(::llvm::BasicBlock* block, ::llvm::Valu

absl::StatusOr<std::unique_ptr<StructTypeIRBuilder>> StructTypeIRBuilder::CreateStructTypeIRBuilder(
::llvm::Module* m, ::llvm::Type* type) {
node::DataType base_type;
if (!GetBaseType(type, &base_type)) {
return absl::UnimplementedError(
absl::StrCat("fail to create struct type ir builder for ", GetLlvmObjectString(type)));
node::NodeManager nm;
const node::TypeNode* ctype = nullptr;
if (!GetFullType(&nm, type, &ctype)) {
return absl::InvalidArgumentError(absl::StrCat("can't get full type for: ", GetLlvmObjectString(type)));
}

switch (base_type) {
switch (ctype->base()) {
case node::kTimestamp:
return std::make_unique<TimestampIRBuilder>(m);
case node::kDate:
return std::make_unique<DateIRBuilder>(m);
case node::kVarchar:
return std::make_unique<StringIRBuilder>(m);
case node::DataType::kMap: {
assert(ctype->IsMap() && "logic error: not a map type");
auto map_type = ctype->GetAsOrNull<node::MapType>();
assert(map_type != nullptr && "logic error: map type empty");
::llvm::Type* key_type = nullptr;
::llvm::Type* value_type = nullptr;
if (codegen::GetLlvmType(m, map_type->key_type(), &key_type) &&
codegen::GetLlvmType(m, map_type->value_type(), &value_type)) {
return std::make_unique<MapIRBuilder>(m, key_type, value_type);
} else {
return absl::InvalidArgumentError(
absl::Substitute("not able to casting map type: $0", GetLlvmObjectString(type)));
}
break;
}
default: {
break;
}
Expand Down Expand Up @@ -224,5 +241,16 @@ absl::StatusOr<std::vector<llvm::Value*>> StructTypeIRBuilder::Load(CodeGenConte

return res;
}

absl::StatusOr<NativeValue> CreateSafeNull(::llvm::BasicBlock* block, ::llvm::Type* type) {
if (TypeIRBuilder::IsStructPtr(type)) {
auto s = StructTypeIRBuilder::CreateStructTypeIRBuilder(block->getModule(), type);
CHECK_ABSL_STATUSOR(s);
return s.value()->CreateNull(block);
}

return NativeValue(nullptr, nullptr, type);
}

} // namespace codegen
} // namespace hybridse
Loading

0 comments on commit 8934906

Please sign in to comment.