Skip to content

Commit

Permalink
Basic support for UINT4 in loader
Browse files Browse the repository at this point in the history
  • Loading branch information
chunseoklee committed Mar 18, 2024
1 parent a1f7e1e commit b4d5d88
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 33 deletions.
43 changes: 14 additions & 29 deletions runtime/libs/circle-schema/include/circle_schema_generated.h
Original file line number Diff line number Diff line change
@@ -1,19 +1,3 @@
/*
* Copyright (c) 2019-2024 Samsung Electronics Co., Ltd. All Rights Reserved
* Copyright 2018 The TensorFlow Authors. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// automatically generated by the FlatBuffers compiler, do not modify

#ifndef FLATBUFFERS_GENERATED_CIRCLESCHEMA_CIRCLE_H_
Expand Down Expand Up @@ -530,6 +514,7 @@ struct ModelBuilder;

enum TensorType : int8_t
{
TensorType_UINT4 = -1,
TensorType_FLOAT32 = 0,
TensorType_FLOAT16 = 1,
TensorType_INT32 = 2,
Expand All @@ -548,35 +533,35 @@ enum TensorType : int8_t
TensorType_UINT32 = 15,
TensorType_UINT16 = 16,
TensorType_INT4 = 17,
TensorType_MIN = TensorType_FLOAT32,
TensorType_MIN = TensorType_UINT4,
TensorType_MAX = TensorType_INT4
};

inline const TensorType (&EnumValuesTensorType())[18]
inline const TensorType (&EnumValuesTensorType())[19]
{
static const TensorType values[] = {
TensorType_FLOAT32, TensorType_FLOAT16, TensorType_INT32, TensorType_UINT8,
TensorType_INT64, TensorType_STRING, TensorType_BOOL, TensorType_INT16,
TensorType_COMPLEX64, TensorType_INT8, TensorType_FLOAT64, TensorType_COMPLEX128,
TensorType_UINT64, TensorType_RESOURCE, TensorType_VARIANT, TensorType_UINT32,
TensorType_UINT16, TensorType_INT4};
TensorType_UINT4, TensorType_FLOAT32, TensorType_FLOAT16, TensorType_INT32,
TensorType_UINT8, TensorType_INT64, TensorType_STRING, TensorType_BOOL,
TensorType_INT16, TensorType_COMPLEX64, TensorType_INT8, TensorType_FLOAT64,
TensorType_COMPLEX128, TensorType_UINT64, TensorType_RESOURCE, TensorType_VARIANT,
TensorType_UINT32, TensorType_UINT16, TensorType_INT4};
return values;
}

inline const char *const *EnumNamesTensorType()
{
static const char *const names[19] = {"FLOAT32", "FLOAT16", "INT32", "UINT8", "INT64",
"STRING", "BOOL", "INT16", "COMPLEX64", "INT8",
"FLOAT64", "COMPLEX128", "UINT64", "RESOURCE", "VARIANT",
"UINT32", "UINT16", "INT4", nullptr};
static const char *const names[20] = {"UINT4", "FLOAT32", "FLOAT16", "INT32", "UINT8",
"INT64", "STRING", "BOOL", "INT16", "COMPLEX64",
"INT8", "FLOAT64", "COMPLEX128", "UINT64", "RESOURCE",
"VARIANT", "UINT32", "UINT16", "INT4", nullptr};
return names;
}

inline const char *EnumNameTensorType(TensorType e)
{
if (::flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_INT4))
if (::flatbuffers::IsOutRange(e, TensorType_UINT4, TensorType_INT4))
return "";
const size_t index = static_cast<size_t>(e);
const size_t index = static_cast<size_t>(e) - static_cast<size_t>(TensorType_UINT4);
return EnumNamesTensorType()[index];
}

Expand Down
3 changes: 2 additions & 1 deletion runtime/onert/core/include/ir/DataType.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ enum class DataType
QUANT_INT16_ASYMM = 10,
QUANT_INT8_SYMM_PER_CHANNEL = 11,
QUANT_INT16_SYMM = 12,
QUANT_INT4_SYMM = 13
QUANT_INT4_SYMM = 13,
QUANT_UINT4_ASYMM = 14
};

size_t sizeOfDataType(DataType data_type);
Expand Down
2 changes: 2 additions & 0 deletions runtime/onert/core/src/ir/DataType.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ size_t sizeOfDataType(DataType data_type)
return sizeof(int16_t);
case DataType::QUANT_INT4_SYMM:
return sizeof(uint8_t); // Q: what is type size for int4?
case DataType::QUANT_UINT4_ASYMM:
return sizeof(uint8_t); // Q: what is type size for uint4?
default:
throw std::runtime_error{"Unsupported type size"};
}
Expand Down
4 changes: 1 addition & 3 deletions runtime/onert/core/src/loader/BaseLoader.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ template <typename LoaderDomain> class BaseLoader

// Helper functions
ir::Activation convertActivation(ActivationFunctionType type);
ir::DataType tensorTypeToDataType(TensorType type);
virtual ir::DataType tensorTypeToDataType(TensorType type);
ir::OperandIndex tensorIdxToOperandIdx(int32_t tensorIdx);
flexbuffers::Map getCustomOpAttrMap(const Operator *op);

Expand Down Expand Up @@ -321,8 +321,6 @@ ir::DataType BaseLoader<LoaderDomain>::BaseLoader::tensorTypeToDataType(const Te
// case TensorType::TensorType_FLOAT64
case TensorType::TensorType_UINT32:
return ir::DataType::UINT32;
case TensorType::TensorType_INT4:
return ir::DataType::QUANT_INT4_SYMM;
default:
throw std::runtime_error(
std::string("Unsupported tensor type: ").append(EnumNameTensorType(type)));
Expand Down
16 changes: 16 additions & 0 deletions runtime/onert/core/src/loader/CircleLoader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class CircleLoader final : public loader::BaseLoader<LoaderDomain>
void loadInstanceNorm(const Operator *op, ir::Graph &subg);
void loadBCQFullyConnected(const Operator *op, ir::Graph &subg);
void loadBCQGather(const Operator *op, ir::Graph &subg);
virtual ir::DataType tensorTypeToDataType(TensorType type) override;

public:
using BaseLoader::BaseLoader;
Expand Down Expand Up @@ -149,6 +150,21 @@ class CircleLoader final : public loader::BaseLoader<LoaderDomain>
}
};


/**
 * @brief Map a circle TensorType to the matching ir::DataType
 *
 * INT4 and UINT4 are translated here; any other tensor type is forwarded to
 * BaseLoader::tensorTypeToDataType.
 *
 * @param type Tensor type read from the model
 * @return Corresponding ir::DataType
 */
ir::DataType CircleLoader::tensorTypeToDataType(const TensorType type)
{
  if (type == TensorType::TensorType_INT4)
    return ir::DataType::QUANT_INT4_SYMM;

  if (type == TensorType::TensorType_UINT4)
    return ir::DataType::QUANT_UINT4_ASYMM;

  // Neither INT4 nor UINT4: let the base loader decide
  return BaseLoader::tensorTypeToDataType(type);
}


void CircleLoader::loadBatchMatMul(const Operator *op, ir::Graph &subg)
{
ir::OperandIndexSequence inputs;
Expand Down

0 comments on commit b4d5d88

Please sign in to comment.