-
Notifications
You must be signed in to change notification settings - Fork 390
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[GLUTEN-4306][CH]Fix CI Failure of Cast Function (#4337)
* improve multi if * improve cast * remove unused code * improve cast * improve cast * Update SparkFunctionCastFloatToInt.cpp re-run ci
- Loading branch information
1 parent
40f72a5
commit a3c690a
Showing
6 changed files
with
277 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
67 changes: 67 additions & 0 deletions
67
cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <base/types.h> | ||
#include <Functions/SparkFunctionCastFloatToInt.h> | ||
|
||
using namespace DB; | ||
|
||
namespace local_engine | ||
{ | ||
|
||
/// Name tags: each struct carries the function name used by one instantiation of
/// SparkFunctionCastFloatToInt. Signed targets first, then unsigned.
struct NameToInt8
{
    static constexpr auto name = "sparkCastFloatToInt8";
};
struct NameToInt16
{
    static constexpr auto name = "sparkCastFloatToInt16";
};
struct NameToInt32
{
    static constexpr auto name = "sparkCastFloatToInt32";
};
struct NameToInt64
{
    static constexpr auto name = "sparkCastFloatToInt64";
};
struct NameToInt128
{
    static constexpr auto name = "sparkCastFloatToInt128";
};
struct NameToInt256
{
    static constexpr auto name = "sparkCastFloatToInt256";
};
struct NameToUInt8
{
    static constexpr auto name = "sparkCastFloatToUInt8";
};
struct NameToUInt16
{
    static constexpr auto name = "sparkCastFloatToUInt16";
};
struct NameToUInt32
{
    static constexpr auto name = "sparkCastFloatToUInt32";
};
struct NameToUInt64
{
    static constexpr auto name = "sparkCastFloatToUInt64";
};
struct NameToUInt128
{
    static constexpr auto name = "sparkCastFloatToUInt128";
};
struct NameToUInt256
{
    static constexpr auto name = "sparkCastFloatToUInt256";
};
|
||
using SparkFunctionCastFloatToInt8 = local_engine::SparkFunctionCastFloatToInt<Int8, NameToInt8>; | ||
using SparkFunctionCastFloatToInt16 = local_engine::SparkFunctionCastFloatToInt<Int16, NameToInt16>; | ||
using SparkFunctionCastFloatToInt32 = local_engine::SparkFunctionCastFloatToInt<Int32, NameToInt32>; | ||
using SparkFunctionCastFloatToInt64 = local_engine::SparkFunctionCastFloatToInt<Int64, NameToInt64>; | ||
using SparkFunctionCastFloatToInt128 = local_engine::SparkFunctionCastFloatToInt<Int128, NameToInt128>; | ||
using SparkFunctionCastFloatToInt256 = local_engine::SparkFunctionCastFloatToInt<Int256, NameToInt256>; | ||
using SparkFunctionCastFloatToUInt8 = local_engine::SparkFunctionCastFloatToInt<UInt8, NameToUInt8>; | ||
using SparkFunctionCastFloatToUInt16 = local_engine::SparkFunctionCastFloatToInt<UInt16, NameToUInt16>; | ||
using SparkFunctionCastFloatToUInt32 = local_engine::SparkFunctionCastFloatToInt<UInt32, NameToUInt32>; | ||
using SparkFunctionCastFloatToUInt64 = local_engine::SparkFunctionCastFloatToInt<UInt64, NameToUInt64>; | ||
using SparkFunctionCastFloatToUInt128 = local_engine::SparkFunctionCastFloatToInt<UInt128, NameToUInt128>; | ||
using SparkFunctionCastFloatToUInt256 = local_engine::SparkFunctionCastFloatToInt<UInt256, NameToUInt256>; | ||
|
||
/// Registers every float-to-integer Spark cast variant with the function factory.
REGISTER_FUNCTION(SparkFunctionCastToInt)
{
    // Signed targets.
    factory.registerFunction<SparkFunctionCastFloatToInt8>();
    factory.registerFunction<SparkFunctionCastFloatToInt16>();
    factory.registerFunction<SparkFunctionCastFloatToInt32>();
    factory.registerFunction<SparkFunctionCastFloatToInt64>();
    factory.registerFunction<SparkFunctionCastFloatToInt128>();
    factory.registerFunction<SparkFunctionCastFloatToInt256>();

    // Unsigned targets.
    factory.registerFunction<SparkFunctionCastFloatToUInt8>();
    factory.registerFunction<SparkFunctionCastFloatToUInt16>();
    factory.registerFunction<SparkFunctionCastFloatToUInt32>();
    factory.registerFunction<SparkFunctionCastFloatToUInt64>();
    factory.registerFunction<SparkFunctionCastFloatToUInt128>();
    factory.registerFunction<SparkFunctionCastFloatToUInt256>();
}
} |
115 changes: 115 additions & 0 deletions
115
cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <Common/NaNUtils.h> | ||
#include <DataTypes/IDataType.h> | ||
#include <DataTypes/DataTypeNullable.h> | ||
#include <DataTypes/DataTypesNumber.h> | ||
#include <Functions/IFunction.h> | ||
#include <Functions/FunctionFactory.h> | ||
#include <Columns/ColumnsNumber.h> | ||
#include <Columns/ColumnVector.h> | ||
#include <Columns/ColumnNullable.h> | ||
|
||
using namespace DB; | ||
|
||
namespace DB | ||
{ | ||
namespace ErrorCodes | ||
{ | ||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; | ||
extern const int ILLEGAL_TYPE_OF_ARGUMENT; | ||
extern const int TYPE_MISMATCH; | ||
} | ||
} | ||
|
||
namespace local_engine | ||
{ | ||
|
||
template <typename T, typename Name> | ||
class SparkFunctionCastFloatToInt : public DB::IFunction | ||
{ | ||
public: | ||
size_t getNumberOfArguments() const override { return 1; } | ||
static constexpr auto name = Name::name; | ||
static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared<SparkFunctionCastFloatToInt>(); } | ||
SparkFunctionCastFloatToInt() = default; | ||
~SparkFunctionCastFloatToInt() override = default; | ||
DB::String getName() const override { return name; } | ||
bool useDefaultImplementationForConstants() const override { return true; } | ||
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } | ||
|
||
DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes &) const override | ||
{ | ||
if constexpr (std::is_integral_v<T>) | ||
{ | ||
return DB::makeNullable(std::make_shared<const DB::DataTypeNumber<T>>()); | ||
} | ||
else | ||
throw DB::Exception(DB::ErrorCodes::TYPE_MISMATCH, "Function {}'s return type should be Int", name); | ||
} | ||
|
||
DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t) const override | ||
{ | ||
if (arguments.size() != 1) | ||
throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {}'s arguments number must be 1", name); | ||
|
||
if (!isFloat(removeNullable(arguments[0].type))) | ||
throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {}'s 1st argument must be float type", name); | ||
|
||
DB::ColumnPtr src_col = arguments[0].column; | ||
size_t size = src_col->size(); | ||
|
||
auto res_col = DB::ColumnVector<T>::create(size); | ||
auto null_map_col = DB::ColumnUInt8::create(size, 0); | ||
|
||
switch(removeNullable(arguments[0].type)->getTypeId()) | ||
{ | ||
case DB::TypeIndex::Float32: | ||
{ | ||
executeInternal<DB::Float32>(src_col, res_col->getData(), null_map_col->getData()); | ||
break; | ||
} | ||
case DB::TypeIndex::Float64: | ||
{ | ||
executeInternal<DB::Float64>(src_col, res_col->getData(), null_map_col->getData()); | ||
break; | ||
} | ||
} | ||
return DB::ColumnNullable::create(std::move(res_col), std::move(null_map_col)); | ||
} | ||
|
||
template <typename F> | ||
void executeInternal(const DB::ColumnPtr & src, DB::PaddedPODArray<T> & data, DB::PaddedPODArray<UInt8> & null_map_data) const | ||
{ | ||
const DB::ColumnVector<F> * src_vec = assert_cast<const DB::ColumnVector<F> *>(src.get()); | ||
for (size_t i = 0; i < src_vec->size(); ++i) | ||
{ | ||
F element = src_vec->getElement(i); | ||
if (isNaN(element) || !isFinite(element)) | ||
{ | ||
data[i] = 0; | ||
null_map_data[i] = 1; | ||
} | ||
else | ||
data[i] = static_cast<T>(element); | ||
} | ||
} | ||
|
||
}; | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
73 changes: 73 additions & 0 deletions
73
cpp-ch/local-engine/tests/benchmark_cast_float_function.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <Core/Block.h> | ||
#include <Columns/IColumn.h> | ||
#include <DataTypes/IDataType.h> | ||
#include <DataTypes/DataTypeFactory.h> | ||
#include <Functions/FunctionFactory.h> | ||
#include <Functions/FunctionsConversion.h> | ||
#include <Functions/SparkFunctionCastFloatToInt.h> | ||
#include <Parser/SerializedPlanParser.h> | ||
#include <benchmark/benchmark.h> | ||
|
||
using namespace DB; | ||
|
||
/// Builds a block with a single Float64 column named "d" holding 0, 1, ..., rows - 1.
static Block createDataBlock(size_t rows)
{
    auto type = DataTypeFactory::instance().get("Float64");
    auto column = type->createColumn();
    column->reserve(rows);
    for (size_t i = 0; i < rows; ++i)
    {
        // Convert via double: `i * 1.0f` is computed in float, which cannot represent every
        // integer above 2^24, so large row indices would be rounded.
        column->insert(static_cast<double>(i));
    }
    Block block;
    block.insert(ColumnWithTypeAndName(std::move(column), type, "d"));
    // Return the local by name so copy elision / implicit move applies (std::move here would block NRVO).
    return block;
}
|
||
/// Benchmarks the function registered as "CAST" converting a Float64 column to Int64.
static void BM_CHCastFloatToInt(benchmark::State & state)
{
    using namespace DB;
    auto & factory = FunctionFactory::instance();
    auto function = factory.get("CAST", local_engine::SerializedPlanParser::global_context);
    Block block = createDataBlock(30000000);

    // Arguments used to build the function: the data column plus a constant string naming the target type.
    DB::ColumnsWithTypeAndName args;
    args.emplace_back(block.getColumnsWithTypeAndName()[0]);
    DB::ColumnWithTypeAndName type_name_col;
    type_name_col.name = "Int64";
    type_name_col.column = DB::DataTypeString().createColumnConst(0, type_name_col.name);
    type_name_col.type = std::make_shared<DB::DataTypeString>();
    args.emplace_back(type_name_col);
    auto executable = function->build(args);

    // Hoist loop-invariant setup so the timed region contains only the execute() call.
    const auto exec_args = block.getColumnsWithTypeAndName();
    const auto result_type = executable->getResultType();
    const size_t rows = block.rows();

    for (auto _ : state)
    {
        auto result = executable->execute(exec_args, result_type, rows);
        benchmark::DoNotOptimize(result);
    }
}
|
||
/// Benchmarks the function registered as "sparkCastFloatToInt64" on a Float64 column.
static void BM_SparkCastFloatToInt(benchmark::State & state)
{
    using namespace DB;
    auto & factory = FunctionFactory::instance();
    auto function = factory.get("sparkCastFloatToInt64", local_engine::SerializedPlanParser::global_context);
    Block block = createDataBlock(30000000);

    // Hoist loop-invariant setup so the timed region contains only the execute() call.
    const auto exec_args = block.getColumnsWithTypeAndName();
    auto executable = function->build(exec_args);
    const auto result_type = executable->getResultType();
    const size_t rows = block.rows();

    for (auto _ : state)
    {
        auto result = executable->execute(exec_args, result_type, rows);
        benchmark::DoNotOptimize(result);
    }
}
|
||
/// Register both benchmarks: report in milliseconds, fixed at 10 iterations each.
BENCHMARK(BM_CHCastFloatToInt)
    ->Unit(benchmark::kMillisecond)
    ->Iterations(10);
BENCHMARK(BM_SparkCastFloatToInt)
    ->Unit(benchmark::kMillisecond)
    ->Iterations(10);