From 15bb2609e9984b2c115f86d3e884f97a6e5aa682 Mon Sep 17 00:00:00 2001 From: fnieddu <118167989+fnieddu@users.noreply.github.com> Date: Fri, 5 Jan 2024 14:15:21 +0100 Subject: [PATCH] feat: input ops (#21) * added logic or flag for math.ori * added input ops * readme update --- .../boolean/{and.hpp => logic_ops.hpp} | 35 ++-- .../components/handle_component.hpp | 1 + .../mlir-assigner/parser/evaluator.hpp | 159 +++++++++++++++--- .../Ops/Onnx/Constant/ConstantScalar.json | 1 + .../Ops/Onnx/Constant/ConstantScalar.onnx | Bin 0 -> 83 bytes .../Ops/Onnx/Constant/ConstantScalar.res | 3 + .../Ops/Onnx/Constant/ConstantSimple.json | 1 + .../Ops/Onnx/Constant/ConstantSimple.onnx | Bin 0 -> 133 bytes .../Ops/Onnx/Constant/ConstantSimple.res | 3 + .../ConstantOfShapeSimple.json | 1 + .../ConstantOfShapeSimple.onnx | 9 + .../ConstantOfShape/ConstantOfShapeSimple.res | 3 + .../tests/Ops/Onnx/OneHot/OneHotFloat.json | 1 + .../tests/Ops/Onnx/OneHot/OneHotFloat.onnx | Bin 0 -> 133 bytes .../tests/Ops/Onnx/OneHot/OneHotFloat.res | 3 + .../tests/Ops/Onnx/OneHot/OneHotSimple.json | 1 + .../tests/Ops/Onnx/OneHot/OneHotSimple.onnx | Bin 0 -> 128 bytes .../tests/Ops/Onnx/OneHot/OneHotSimple.res | 3 + .../Or/OrSimple.json | 0 .../Or/OrSimple.onnx | 0 mlir-assigner/tests/Ops/Onnx/Or/OrSimple.res | 3 + .../tests/Ops/Onnx/Range/RangeFloat.json | 1 + .../tests/Ops/Onnx/Range/RangeFloat.onnx | Bin 0 -> 123 bytes .../tests/Ops/Onnx/Range/RangeFloat.res | 3 + .../tests/Ops/Onnx/Range/RangeFloatLarge.json | 1 + .../tests/Ops/Onnx/Range/RangeFloatLarge.onnx | Bin 0 -> 128 bytes .../tests/Ops/Onnx/Range/RangeFloatLarge.res | 3 + .../tests/Ops/Onnx/Range/RangeIntLarge.json | 1 + .../tests/Ops/Onnx/Range/RangeIntLarge.onnx | 9 + .../tests/Ops/Onnx/Range/RangeIntLarge.res | 3 + .../tests/Ops/Onnx/Range/RangeSimple.json | 1 + .../tests/Ops/Onnx/Range/RangeSimple.onnx | 9 + .../tests/Ops/Onnx/Range/RangeSimple.res | 3 + .../Or/OrSimple.mlir | 16 -- .../Or/OrSimple.res | 3 - mlir-assigner/tests/README.md | 16 +- 36 files changed, 236 insertions(+), 60 deletions(-) rename mlir-assigner/include/mlir-assigner/components/boolean/{and.hpp => logic_ops.hpp} (74%) create mode 100644 mlir-assigner/tests/Ops/Onnx/Constant/ConstantScalar.json create mode 100644 mlir-assigner/tests/Ops/Onnx/Constant/ConstantScalar.onnx create mode 100644 mlir-assigner/tests/Ops/Onnx/Constant/ConstantScalar.res create mode 100644 mlir-assigner/tests/Ops/Onnx/Constant/ConstantSimple.json create mode 100644 mlir-assigner/tests/Ops/Onnx/Constant/ConstantSimple.onnx create mode 100644 mlir-assigner/tests/Ops/Onnx/Constant/ConstantSimple.res create mode 100644 mlir-assigner/tests/Ops/Onnx/ConstantOfShape/ConstantOfShapeSimple.json create mode 100644 mlir-assigner/tests/Ops/Onnx/ConstantOfShape/ConstantOfShapeSimple.onnx create mode 100644 mlir-assigner/tests/Ops/Onnx/ConstantOfShape/ConstantOfShapeSimple.res create mode 100644 mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.json create mode 100644 mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.onnx create mode 100644 mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.res create mode 100644 mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.json create mode 100644 mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.onnx create mode 100644 mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.res rename mlir-assigner/tests/Ops/{TheirBluePrintNotWorkingNOTHING_FOR_US => Onnx}/Or/OrSimple.json (100%) rename mlir-assigner/tests/Ops/{TheirBluePrintNotWorkingNOTHING_FOR_US => Onnx}/Or/OrSimple.onnx (100%) create mode 100644 mlir-assigner/tests/Ops/Onnx/Or/OrSimple.res create mode 100644 mlir-assigner/tests/Ops/Onnx/Range/RangeFloat.json create mode 100644 mlir-assigner/tests/Ops/Onnx/Range/RangeFloat.onnx create mode 100644 mlir-assigner/tests/Ops/Onnx/Range/RangeFloat.res create mode 100644 mlir-assigner/tests/Ops/Onnx/Range/RangeFloatLarge.json create mode 100644 mlir-assigner/tests/Ops/Onnx/Range/RangeFloatLarge.onnx create mode 100644 mlir-assigner/tests/Ops/Onnx/Range/RangeFloatLarge.res create mode 100644 mlir-assigner/tests/Ops/Onnx/Range/RangeIntLarge.json create mode 100644 mlir-assigner/tests/Ops/Onnx/Range/RangeIntLarge.onnx create mode 100644 mlir-assigner/tests/Ops/Onnx/Range/RangeIntLarge.res create mode 100644 mlir-assigner/tests/Ops/Onnx/Range/RangeSimple.json create mode 100644 mlir-assigner/tests/Ops/Onnx/Range/RangeSimple.onnx create mode 100644 mlir-assigner/tests/Ops/Onnx/Range/RangeSimple.res delete mode 100644 mlir-assigner/tests/Ops/TheirBluePrintNotWorkingNOTHING_FOR_US/Or/OrSimple.mlir delete mode 100644 mlir-assigner/tests/Ops/TheirBluePrintNotWorkingNOTHING_FOR_US/Or/OrSimple.res diff --git a/mlir-assigner/include/mlir-assigner/components/boolean/and.hpp b/mlir-assigner/include/mlir-assigner/components/boolean/logic_ops.hpp similarity index 74% rename from mlir-assigner/include/mlir-assigner/components/boolean/and.hpp rename to mlir-assigner/include/mlir-assigner/components/boolean/logic_ops.hpp index d52cb8f..2e421f5 100644 --- a/mlir-assigner/include/mlir-assigner/components/boolean/and.hpp +++ b/mlir-assigner/include/mlir-assigner/components/boolean/logic_ops.hpp @@ -39,6 +39,7 @@ namespace nil { namespace blueprint { + // TODO There is also the logic_and_flag. Should we use this one or should we use the logic_ops????? template void handle_logic_and( mlir::arith::AndIOp &operation, @@ -54,7 +55,8 @@ namespace nil { const auto p = detail::PolicyManager::get_parameters( detail::ManifestReader::get_witness(0)); - component_type component(p.witness); + using manifest_reader = detail::ManifestReader; + component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs()); fill_trace(component, input, operation, frame, bp, assignment, start_row); } @@ -66,17 +68,23 @@ namespace nil { assignment_proxy> &assignment, std::uint32_t start_row) { - //FIXME logic_or is commented out. As soon as it is enabled, remove add the liens above and it SHOULD work - UNREACHABLE("LogicOR not enabled in blueprint"); - // using component_type = components::lookup_logic_or< - // crypto3::zk::snark::plonk_constraint_system>; - // - // auto input = PREPARE_BINARY_INPUT(mlir::arith::OrIOp); - // const auto p = detail::PolicyManager::get_parameters( - // detail::ManifestReader::get_witness(0)); - // - // component_type component(p.witness); - // fill_trace(component, input, operation, frame, bp, assignment, start_row); + using component_type = components::logic_or_flag< + crypto3::zk::snark::plonk_constraint_system>; + using input_type = typename component_type::input_type; + + auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs())); + auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs())); + ASSERT(lhs != frame.locals.end()); + ASSERT(rhs != frame.locals.end()); + + input_type input; + input.x = lhs->second; + input.y = rhs->second; + const auto p = detail::PolicyManager::get_parameters( + detail::ManifestReader::get_witness(0)); + using manifest_reader = detail::ManifestReader; + component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs()); + fill_trace(component, input, operation, frame, bp, assignment, start_row); } template @@ -94,7 +102,8 @@ namespace nil { const auto p = detail::PolicyManager::get_parameters( detail::ManifestReader::get_witness(0)); - component_type component(p.witness); + using manifest_reader = detail::ManifestReader; + component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs()); fill_trace(component, input, operation, frame, bp, assignment, start_row); } diff --git a/mlir-assigner/include/mlir-assigner/components/handle_component.hpp b/mlir-assigner/include/mlir-assigner/components/handle_component.hpp index b838458..d21cfb0 100644 --- a/mlir-assigner/include/mlir-assigner/components/handle_component.hpp +++ b/mlir-assigner/include/mlir-assigner/components/handle_component.hpp @@ -35,6 +35,7 @@ #include #include #include +#include #define PREPARE_UNARY_INPUT(OP) \ prepare_unary_operation_input #include #define TEST_WITHOUT_LOOKUP_TABLES @@ -53,7 +54,7 @@ #include #include #include -#include +#include #include #include @@ -190,6 +191,17 @@ namespace zk_ml_toolchain { bool PrintCircuitOutput; nil::blueprint::logger &logger; + template + NumberType resolve_number(VarType scalar) { + auto scalar_value = var_value(assignmnt, scalar); + static constexpr auto limit_value = + typename BlueprintFieldType::integral_type(std::numeric_limits::max()); + auto integral_value = static_cast(scalar_value.data); + ASSERT_MSG(integral_value < limit_value, "Too large to cast"); + NumberType number = static_cast(integral_value); + return number; + } + void doAffineFor(affine::AffineForOp &op, int64_t from, int64_t to, int64_t step) { assert(from < to); assert(step); @@ -253,7 +265,45 @@ namespace zk_ml_toolchain { } else if (arith::CmpFOp operation = llvm::dyn_cast(op)) { handle_fixedpoint_comparison_component(operation, frames.back(), bp, assignmnt, start_row); } else if (arith::SelectOp operation = llvm::dyn_cast(op)) { - handle_select_component(operation, frames.back(), bp, assignmnt, start_row); + ASSERT(operation.getNumOperands() == 3 && "Select must have three operands"); + ASSERT(operation->getOperand(1).getType() == operation->getOperand(2).getType() && + "Select must operate on same type"); + // check if we work on indices + Type operandType = operation->getOperand(1).getType(); + auto i1Hash = mlir::hash_value(operation->getOperand(0)); + if (operandType.isa()) { + // for now we expect that if we select on indices, that we also have the cmp result in + // constant values. Let's see if this holds true in the future + auto cmpResult = frames.back().constant_values.find(i1Hash); + ASSERT(cmpResult != frames.back().constant_values.end()); + if (cmpResult->second) { + auto truthy = frames.back().constant_values.find(mlir::hash_value(operation->getOperand(1))); + ASSERT(truthy != frames.back().constant_values.end()); + frames.back().constant_values[mlir::hash_value(operation->getResult(0))] = truthy->second; + } else { + auto falsy = frames.back().constant_values.find(mlir::hash_value(operation->getOperand(2))); + ASSERT(falsy != frames.back().constant_values.end()); + frames.back().constant_values[mlir::hash_value(operation->getResult(0))] = falsy->second; + } + } else if (frames.back().constant_values.find(i1Hash) != frames.back().constant_values.end()) { + // we come from index comparision but we do not work on indices, ergo we need to get from locals + if (frames.back().constant_values[i1Hash]) { + auto truthy = frames.back().locals.find(mlir::hash_value(operation->getOperand(1))); + ASSERT(truthy != frames.back().locals.end()); + frames.back().locals[mlir::hash_value(operation->getResult(0))] = truthy->second; + } else { + auto falsy = frames.back().locals.find(mlir::hash_value(operation->getOperand(2))); + ASSERT(falsy != frames.back().locals.end()); + frames.back().locals[mlir::hash_value(operation->getResult(0))] = falsy->second; + } + } else if (operandType.isa()) { + handle_select_component(operation, frames.back(), bp, assignmnt, start_row); + } else { + std::string typeStr; + llvm::raw_string_ostream ss(typeStr); + ss << operandType; + UNREACHABLE(std::string("unhandled select operand: ") + typeStr); + } } else if (arith::NegFOp operation = llvm::dyn_cast(op)) { handle_fixedpoint_neg_component(operation, frames.back(), bp, assignmnt, start_row); } else if (arith::AndIOp operation = llvm::dyn_cast(op)) { @@ -267,14 +317,28 @@ namespace zk_ml_toolchain { UNREACHABLE("TODO add Bitwise And Gadget"); } } else if (arith::OrIOp operation = llvm::dyn_cast(op)) { - // check if logical and or bitwise and - mlir::Type LhsType = operation.getLhs().getType(); - mlir::Type RhsType = operation.getRhs().getType(); - assert(LhsType == RhsType && "must be same type for OrIOp"); - if (LhsType.getIntOrFloatBitWidth() == 1) { - handle_logic_or(operation, frames.back(), bp, assignmnt, start_row); + ASSERT(operation.getNumOperands() == 2 && "Or must have two operands"); + ASSERT(operation->getOperand(0).getType() == operation->getOperand(1).getType() && + "Or must operate on same type"); + // check if we work on indices + // TODO this seems like a hack, maybe we can do something better + auto lhsHash = mlir::hash_value(operation.getLhs()); + if (frames.back().constant_values.find(lhsHash) != frames.back().constant_values.end()) { + auto lhs = frames.back().constant_values[lhsHash]; + auto rhs = frames.back().constant_values.find(mlir::hash_value(operation.getRhs())); + assert(rhs != frames.back().constant_values.end()); + auto result = lhs | rhs->second; + frames.back().constant_values[mlir::hash_value(operation.getResult())] = result; } else { - UNREACHABLE("TODO add Bitwise Or Gadget"); + // check if logical and or bitwise and + mlir::Type LhsType = operation.getLhs().getType(); + mlir::Type RhsType = operation.getRhs().getType(); + assert(LhsType == RhsType && "must be same type for OrIOp"); + if (LhsType.getIntOrFloatBitWidth() == 1) { + handle_logic_or(operation, frames.back(), bp, assignmnt, start_row); + } else { + UNREACHABLE("TODO add Bitwise Or Gadget"); + } } } else if (arith::XOrIOp operation = llvm::dyn_cast(op)) { // check if logical and or bitwise and @@ -287,7 +351,6 @@ namespace zk_ml_toolchain { UNREACHABLE("TODO add Bitwise XOr Gadget"); } } else if (arith::AddIOp operation = llvm::dyn_cast(op)) { - // TODO: ATM, handle only the case where we work on indices that are // constant values auto lhs = frames.back().constant_values.find(mlir::hash_value(operation.getLhs())); @@ -323,8 +386,53 @@ namespace zk_ml_toolchain { frames.back().constant_values[mlir::hash_value(operation.getResult())] = result; } else if (arith::CmpIOp operation = llvm::dyn_cast(op)) { - llvm::outs() << "icmp\n"; - exit(0); + assert(operation.getLhs().getType().isa()); + assert(operation.getRhs().getType().isa()); + + // TODO: ATM, handle only the case where we work on indices that are + // constant values + auto lhs = frames.back().constant_values.find(mlir::hash_value(operation.getLhs())); + auto rhs = frames.back().constant_values.find(mlir::hash_value(operation.getRhs())); + assert(lhs != frames.back().constant_values.end()); + assert(rhs != frames.back().constant_values.end()); + int64_t cmpResult; + switch (operation.getPredicate()) { + case arith::CmpIPredicate::eq: + cmpResult = static_cast(lhs->second == rhs->second); + break; + case arith::CmpIPredicate::ne: + cmpResult = static_cast(lhs->second != rhs->second); + break; + case arith::CmpIPredicate::slt: + cmpResult = static_cast(lhs->second < rhs->second); + break; + case arith::CmpIPredicate::sle: + cmpResult = static_cast(lhs->second <= rhs->second); + break; + case arith::CmpIPredicate::sgt: + cmpResult = static_cast(lhs->second > rhs->second); + break; + case arith::CmpIPredicate::sge: + cmpResult = static_cast(lhs->second >= rhs->second); + break; + case arith::CmpIPredicate::ult: + cmpResult = static_cast(static_cast(lhs->second) < + static_cast(rhs->second)); + break; + case arith::CmpIPredicate::ule: + cmpResult = static_cast(static_cast(lhs->second) <= + static_cast(rhs->second)); + break; + case arith::CmpIPredicate::ugt: + cmpResult = static_cast(static_cast(lhs->second) > + static_cast(rhs->second)); + break; + case arith::CmpIPredicate::uge: + cmpResult = static_cast(static_cast(lhs->second) >= + static_cast(rhs->second)); + break; + } + frames.back().constant_values[mlir::hash_value(operation.getResult())] = cmpResult; } else if (arith::ConstantOp operation = llvm::dyn_cast(op)) { TypedAttr constantValue = operation.getValueAttr(); if (constantValue.isa()) { @@ -358,10 +466,21 @@ namespace zk_ml_toolchain { } } else if (arith::IndexCastOp operation = llvm::dyn_cast(op)) { assert(operation->getNumOperands() == 1 && "IndexCast must have exactly one operand"); - auto index = frames.back().constant_values[mlir::hash_value(operation->getOperand(0))]; - typename BlueprintFieldType::value_type field_constant = index; - auto val = put_into_assignment(field_constant); - frames.back().locals.insert(std::make_pair(mlir::hash_value(operation.getResult()), val)); + auto opHash = mlir::hash_value(operation->getOperand(0)); + // from int to index + if (operation->getOperand(0).getType().isa()) { + auto i = frames.back().locals.find(opHash); + assert(i != frames.back().locals.end()); + frames.back().constant_values[mlir::hash_value(operation.getResult())] = resolve_number(i->second); + } else if (operation->getOperand(0).getType().isa()) { + auto index = frames.back().constant_values.find(opHash); + assert(index != frames.back().constant_values.end()); + typename BlueprintFieldType::value_type field_constant = index->second; + auto val = put_into_assignment(field_constant); + frames.back().locals.insert(std::make_pair(mlir::hash_value(operation.getResult()), val)); + } else { + UNREACHABLE("unsupported Index Cast"); + } } else if (arith::SIToFPOp operation = llvm::dyn_cast(op)) { // TODO this does not respect negative and no different ranges for ints... handle_to_fixedpoint(operation, frames.back(), bp, assignmnt, start_row); @@ -543,7 +662,7 @@ namespace zk_ml_toolchain { // Create the global at the entry of the module. assert(operation.getValue().has_value() && "Krnl Global must always have a value"); auto value = operation.getValue().value(); - //TODO check other bit sizes. Also no range constraint is this necessary???? + // TODO check other bit sizes. Also no range constraint is this necessary???? if (DenseElementsAttr attr = llvm::dyn_cast(value)) { mlir::Type attrType = attr.getElementType(); if (attrType.isa()) { @@ -551,8 +670,8 @@ namespace zk_ml_toolchain { assert(!mlir::failed(ints) && "must work as we checked above"); size_t idx = 0; for (auto a : ints.value()) { - auto var = put_into_assignment(a.getSExtValue()); - m.put_flat(idx++, var); + auto var = put_into_assignment(a.getSExtValue()); + m.put_flat(idx++, var); } } else if (attrType.isa()) { auto floats = attr.tryGetValues(); @@ -572,7 +691,7 @@ namespace zk_ml_toolchain { m.put_flat(idx++, var); } } else { - UNREACHABLE("Unsupported attribute type"); + UNREACHABLE("Unsupported attribute type"); } } else { UNREACHABLE("Expected a DenseElementsAttr"); diff --git a/mlir-assigner/tests/Ops/Onnx/Constant/ConstantScalar.json b/mlir-assigner/tests/Ops/Onnx/Constant/ConstantScalar.json new file mode 100644 index 0000000..0637a08 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Constant/ConstantScalar.json @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/mlir-assigner/tests/Ops/Onnx/Constant/ConstantScalar.onnx b/mlir-assigner/tests/Ops/Onnx/Constant/ConstantScalar.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4f58cabb11c68dd43c7ad1ce47e90ed046147b9a GIT binary patch literal 83 zcmd[1] +3 diff --git a/mlir-assigner/tests/Ops/Onnx/Constant/ConstantSimple.json b/mlir-assigner/tests/Ops/Onnx/Constant/ConstantSimple.json new file mode 100644 index 0000000..0637a08 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Constant/ConstantSimple.json @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/mlir-assigner/tests/Ops/Onnx/Constant/ConstantSimple.onnx b/mlir-assigner/tests/Ops/Onnx/Constant/ConstantSimple.onnx new file mode 100644 index 0000000000000000000000000000000000000000..6c3682144a87f6f09d5f5fc0aae9a5ece3be526b GIT binary patch literal 133 zcmd~4`Mk0 z2?r3-03sHEhyx(v0f=C51QG&{3mDmi_+Z8bXXX~[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] +12 diff --git a/mlir-assigner/tests/Ops/Onnx/ConstantOfShape/ConstantOfShapeSimple.json b/mlir-assigner/tests/Ops/Onnx/ConstantOfShape/ConstantOfShapeSimple.json new file mode 100644 index 0000000..fe51488 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/ConstantOfShape/ConstantOfShapeSimple.json @@ -0,0 +1 @@ +[] diff --git a/mlir-assigner/tests/Ops/Onnx/ConstantOfShape/ConstantOfShapeSimple.onnx b/mlir-assigner/tests/Ops/Onnx/ConstantOfShape/ConstantOfShapeSimple.onnx new file mode 100644 index 0000000..a2cd304 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/ConstantOfShape/ConstantOfShapeSimple.onnx @@ -0,0 +1,9 @@ + :` + +in_aout_a"ConstantOfShapeConstantOfShapeSimple*: +Bin_ab +out_a +  + + +B \ No newline at end of file diff --git a/mlir-assigner/tests/Ops/Onnx/ConstantOfShape/ConstantOfShapeSimple.res b/mlir-assigner/tests/Ops/Onnx/ConstantOfShape/ConstantOfShapeSimple.res new file mode 100644 index 0000000..9b632d3 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/ConstantOfShape/ConstantOfShapeSimple.res @@ -0,0 +1,3 @@ +Result: +memref<1x10xf32>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0] +12 diff --git a/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.json b/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.json new file mode 100644 index 0000000..fe51488 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.json @@ -0,0 +1 @@ +[] diff --git a/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.onnx b/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.onnx new file mode 100644 index 0000000000000000000000000000000000000000..6dfa88fe9962ef6aa6caac7b98cb8d7145734802 GIT binary patch literal 133 zcmdIBfmt58$!F~uxT+f zuyZ(pHEMAKxr{tuPLh@g2a^D!5(fhqG}wbhlas``ASMd&aq)043UP2TaWDfhj}w!C F2mt$K8g>8x literal 0 HcmV?d00001 diff --git a/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.res b/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.res new file mode 100644 index 0000000..ded5694 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.res @@ -0,0 +1,3 @@ +Result: +memref<3x12xf32>[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0] +11 diff --git a/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.json b/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.json new file mode 100644 index 0000000..fe51488 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.json @@ -0,0 +1 @@ +[] diff --git a/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.onnx b/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.onnx new file mode 100644 index 0000000000000000000000000000000000000000..536d89cc62829b9ce46feaa05af61d2e55587d19 GIT binary patch literal 128 zcmdIBfmt52SNvD<`(3nYVmV03$R-; zGq7_wfpuze3$ST1@_;!>T6`Q#KrRy#E0~|0B+dmfOo)$*hl5RsgNuoS8HjnDm;^)s DoB|p} literal 0 HcmV?d00001 diff --git a/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.res b/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.res new file mode 100644 index 0000000..85e653a --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.res @@ -0,0 +1,3 @@ +Result: +memref<3x12xi32>[5, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 5, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 5, 2, 2, 2] +11 diff --git a/mlir-assigner/tests/Ops/TheirBluePrintNotWorkingNOTHING_FOR_US/Or/OrSimple.json b/mlir-assigner/tests/Ops/Onnx/Or/OrSimple.json similarity index 100% rename from mlir-assigner/tests/Ops/TheirBluePrintNotWorkingNOTHING_FOR_US/Or/OrSimple.json rename to mlir-assigner/tests/Ops/Onnx/Or/OrSimple.json diff --git a/mlir-assigner/tests/Ops/TheirBluePrintNotWorkingNOTHING_FOR_US/Or/OrSimple.onnx b/mlir-assigner/tests/Ops/Onnx/Or/OrSimple.onnx similarity index 100% rename from mlir-assigner/tests/Ops/TheirBluePrintNotWorkingNOTHING_FOR_US/Or/OrSimple.onnx rename to mlir-assigner/tests/Ops/Onnx/Or/OrSimple.onnx diff --git a/mlir-assigner/tests/Ops/Onnx/Or/OrSimple.res b/mlir-assigner/tests/Ops/Onnx/Or/OrSimple.res new file mode 100644 index 0000000..b642caf --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Or/OrSimple.res @@ -0,0 +1,3 @@ +Result: +memref<1x10xi1>[1, 1, 0, 0, 1, 1, 0, 1, 1, 1] +23 diff --git a/mlir-assigner/tests/Ops/Onnx/Range/RangeFloat.json b/mlir-assigner/tests/Ops/Onnx/Range/RangeFloat.json new file mode 100644 index 0000000..0637a08 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Range/RangeFloat.json @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/mlir-assigner/tests/Ops/Onnx/Range/RangeFloat.onnx b/mlir-assigner/tests/Ops/Onnx/Range/RangeFloat.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b36bdb05a618b041b9ffb196446c0a6ed3169f7f GIT binary patch literal 123 zcmd[3, 6] +4 diff --git a/mlir-assigner/tests/Ops/Onnx/Range/RangeFloatLarge.json b/mlir-assigner/tests/Ops/Onnx/Range/RangeFloatLarge.json new file mode 100644 index 0000000..0637a08 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Range/RangeFloatLarge.json @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/mlir-assigner/tests/Ops/Onnx/Range/RangeFloatLarge.onnx b/mlir-assigner/tests/Ops/Onnx/Range/RangeFloatLarge.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f516b7bc03a1713164707de7276d3cce223039a5 GIT binary patch literal 128 zcmdf|!uSG?V15#e&tMPcCnpJWK@1e);^N?76k_3G;^1*&5)c6Z&$}CX literal 0 HcmV?d00001 diff --git a/mlir-assigner/tests/Ops/Onnx/Range/RangeFloatLarge.res b/mlir-assigner/tests/Ops/Onnx/Range/RangeFloatLarge.res new file mode 100644 index 0000000..270bafe --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Range/RangeFloatLarge.res @@ -0,0 +1,3 @@ +Result: +memref<12xf32>[3, 3.5, 4, 4.5, 5, 5.5, 6, 6.5, 7, 7.5, 8, 8.5] +14 diff --git a/mlir-assigner/tests/Ops/Onnx/Range/RangeIntLarge.json b/mlir-assigner/tests/Ops/Onnx/Range/RangeIntLarge.json new file mode 100644 index 0000000..0637a08 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Range/RangeIntLarge.json @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/mlir-assigner/tests/Ops/Onnx/Range/RangeIntLarge.onnx b/mlir-assigner/tests/Ops/Onnx/Range/RangeIntLarge.onnx new file mode 100644 index 0000000..76daac6 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Range/RangeIntLarge.onnx @@ -0,0 +1,9 @@ + :m + +in_a +in_b +in_cout_a"Range RangeIntLarge* :Bin_a* : Bin_b* :Bin_cb +out_a + + +B \ No newline at end of file diff --git a/mlir-assigner/tests/Ops/Onnx/Range/RangeIntLarge.res b/mlir-assigner/tests/Ops/Onnx/Range/RangeIntLarge.res new file mode 100644 index 0000000..3a9b604 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Range/RangeIntLarge.res @@ -0,0 +1,3 @@ +Result: +memref<6xi64>[3, 4, 5, 6, 7, 8] +8 diff --git a/mlir-assigner/tests/Ops/Onnx/Range/RangeSimple.json b/mlir-assigner/tests/Ops/Onnx/Range/RangeSimple.json new file mode 100644 index 0000000..0637a08 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Range/RangeSimple.json @@ -0,0 +1 @@ +[] \ No newline at end of file diff --git a/mlir-assigner/tests/Ops/Onnx/Range/RangeSimple.onnx b/mlir-assigner/tests/Ops/Onnx/Range/RangeSimple.onnx new file mode 100644 index 0000000..e105f8e --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Range/RangeSimple.onnx @@ -0,0 +1,9 @@ + :k + +in_a +in_b +in_cout_a"Range RangeSimple* :Bin_a* : Bin_b* :Bin_cb +out_a + + +B \ No newline at end of file diff --git a/mlir-assigner/tests/Ops/Onnx/Range/RangeSimple.res b/mlir-assigner/tests/Ops/Onnx/Range/RangeSimple.res new file mode 100644 index 0000000..a853199 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Range/RangeSimple.res @@ -0,0 +1,3 @@ +Result: +memref<2xi64>[3, 6] +4 diff --git a/mlir-assigner/tests/Ops/TheirBluePrintNotWorkingNOTHING_FOR_US/Or/OrSimple.mlir b/mlir-assigner/tests/Ops/TheirBluePrintNotWorkingNOTHING_FOR_US/Or/OrSimple.mlir deleted file mode 100644 index 5a0758d..0000000 --- a/mlir-assigner/tests/Ops/TheirBluePrintNotWorkingNOTHING_FOR_US/Or/OrSimple.mlir +++ /dev/null @@ -1,16 +0,0 @@ -module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-pc-linux-gnu", "onnx-mlir.symbol-postfix" = "orsimple.mlir"} { - func.func @main_graph(%arg0: memref<1x10xi1>, %arg1: memref<1x10xi1>) -> memref<1x10xi1> attributes {input_names = ["in_a", "in_b"], llvm.emit_c_interface, output_names = ["out_a"]} { - %c0 = arith.constant 0 : index - %alloc = memref.alloc() {alignment = 16 : i64} : memref<1x10xi1> - affine.for %arg2 = 0 to 1 { - affine.for %arg3 = 0 to 10 { - %0 = affine.load %arg0[%c0, %arg3] : memref<1x10xi1> - %1 = affine.load %arg1[%c0, %arg3] : memref<1x10xi1> - %2 = arith.ori %0, %1 : i1 - affine.store %2, %alloc[%arg2, %arg3] : memref<1x10xi1> - } - } - return %alloc : memref<1x10xi1> - } - "krnl.entry_point"() {func = @main_graph, numInputs = 2 : i32, numOutputs = 1 : i32, signature = "[ { \22type\22 : \22i1\22 , \22dims\22 : [1 , 10] , \22name\22 : \22in_a\22 }\0A , { \22type\22 : \22i1\22 , \22dims\22 : [1 , 10] , \22name\22 : \22in_b\22 }\0A\0A]\00@[ { \22type\22 : \22i1\22 , \22dims\22 : [1 , 10] , \22name\22 : \22out_a\22 }\0A\0A]\00"} : () -> () -} diff --git a/mlir-assigner/tests/Ops/TheirBluePrintNotWorkingNOTHING_FOR_US/Or/OrSimple.res b/mlir-assigner/tests/Ops/TheirBluePrintNotWorkingNOTHING_FOR_US/Or/OrSimple.res deleted file mode 100644 index 990b388..0000000 --- a/mlir-assigner/tests/Ops/TheirBluePrintNotWorkingNOTHING_FOR_US/Or/OrSimple.res +++ /dev/null @@ -1,3 +0,0 @@ -Result: -memref<1x10xbool>[1, 1, 0, 0, 1, 1, 0, 1, 1, 1] -ADD THE ROWS HERE \ No newline at end of file diff --git a/mlir-assigner/tests/README.md b/mlir-assigner/tests/README.md index 010f170..1794895 100644 --- a/mlir-assigner/tests/README.md +++ b/mlir-assigner/tests/README.md @@ -113,8 +113,8 @@ long as it is applicable for ZK). | **Compress** | :x: | :white_check_mark: | | | **Concat** | :x: | :white_check_mark: | | | **ConcatFromSequence** | :x: | :x: | | -| **Constant** | :x: | :white_check_mark: | | -| **ConstantOfShape** | :x: | :white_check_mark: | | +| **Constant** | :white_check_mark: | :white_check_mark: | | +| **ConstantOfShape** | :white_check_mark: | :white_check_mark: | | | **Conv** | :white_check_mark: | :white_check_mark: | | | **ConvInteger** | :x: | :x: | | | **ConvTranspose** | :x: | :white_check_mark: | | @@ -173,7 +173,7 @@ long as it is applicable for ZK). | **LessOrEqual** | :white_check_mark: | :white_check_mark: | | | **LinearClassifier** | :x: | :x: | | | **LinearRegressor** | :x: | :x: | | -| **Log** | :x: | :white_check_mark: | | +| **Log** | :white_check_mark: | :white_check_mark: | | | **LogSoftmax** | :white_check_mark: | :white_check_mark: | | | **Loop** | :x: | :white_check_mark: | | | **LpNormalization** | :x: | :x: | | @@ -199,12 +199,12 @@ long as it is applicable for ZK). | **NonZero** | :x: | :white_check_mark: | | | **Normalizer** | :x: | :x: | | | **Not** | :white_check_mark: | :white_check_mark: | | -| **OneHot** | :x: | :white_check_mark: | | +| **OneHot** | :white_check_mark: | :white_check_mark: | | | **OneHotEncoder** | :x: | :x: | | | **Optional** | :x: | :x: | | | **OptionalGetElement** | :x: | :x: | | | **OptionalHasElement** | :x: | :x: | | -| **Or** | :x: | :white_check_mark: | | +| **Or** | :white_check_mark: | :white_check_mark: | | | **PRelu** | :white_check_mark: | :white_check_mark: | | | **Pad** | :x: | :white_check_mark: | | | **Pow** | :x: | :white_check_mark: | | @@ -216,7 +216,7 @@ long as it is applicable for ZK). | **RandomNormalLike** | :x: | :x: | | | **RandomUniform** | :x: | :x: | | | **RandomUniformLike** | :x: | :x: | | -| **Range** | :x: | :white_check_mark: | | +| **Range** | :white_check_mark: | :white_check_mark: | | | **Reciprocal** | :white_check_mark: | :white_check_mark: | No support for integers at the moment. | | **ReduceL1** | :white_check_mark: | :white_check_mark: | | | **ReduceL2** | :white_check_mark: | :white_check_mark: | | @@ -250,13 +250,13 @@ long as it is applicable for ZK). | **SequenceInsert** | :x: | :white_check_mark: | | | **SequenceLength** | :x: | :x: | | | **SequenceMap** | :x: | :x: | | -| **Shape** | :x: | :white_check_mark: | | +| **Shape** | :white_check_mark: | :white_check_mark: | | | **Shrink** | :x: | :x: | | | **Sigmoid** | :white_check_mark: | :white_check_mark: | | | **Sign** | :white_check_mark: | :white_check_mark: | No support for integers at the moment. | | **Sin** | :white_check_mark: | :white_check_mark: | | | **Sinh** | :white_check_mark: | :white_check_mark: | | -| **Size** | :x: | :white_check_mark: | | +| **Size** | :white_check_mark: | :white_check_mark: | | | **Slice** | :x: | :white_check_mark: | | | **Softmax** | :white_check_mark: | :white_check_mark: | No support for integers at the moment. | | **SoftmaxCrossEntropyLoss** | :x: | :x: | |