Skip to content

Commit

Permalink
feat: input ops (#21)
Browse files Browse the repository at this point in the history
* added logic or flag for math.ori
* added input ops
* readme update
  • Loading branch information
fnieddu committed Jan 5, 2024
1 parent 5399202 commit 15bb260
Show file tree
Hide file tree
Showing 36 changed files with 236 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

namespace nil {
namespace blueprint {
// TODO There is also the logic_and_flag. Should we use this one or should we use the logic_ops?????
template<typename BlueprintFieldType, typename ArithmetizationParams>
void handle_logic_and(
mlir::arith::AndIOp &operation,
Expand All @@ -54,7 +55,8 @@ namespace nil {
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));

component_type component(p.witness);
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams>;
component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs());
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

Expand All @@ -66,17 +68,23 @@ namespace nil {
assignment_proxy<crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>
&assignment,
std::uint32_t start_row) {
//FIXME logic_or is commented out. As soon as it is enabled, remove add the liens above and it SHOULD work
UNREACHABLE("LogicOR not enabled in blueprint");
// using component_type = components::lookup_logic_or<
// crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;
//
// auto input = PREPARE_BINARY_INPUT(mlir::arith::OrIOp);
// const auto p = detail::PolicyManager::get_parameters(
// detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
//
// component_type component(p.witness);
// fill_trace(component, input, operation, frame, bp, assignment, start_row);
using component_type = components::logic_or_flag<
crypto3::zk::snark::plonk_constraint_system<BlueprintFieldType, ArithmetizationParams>>;
using input_type = typename component_type::input_type;

auto lhs = frame.locals.find(mlir::hash_value(operation.getLhs()));
auto rhs = frame.locals.find(mlir::hash_value(operation.getRhs()));
ASSERT(lhs != frame.locals.end());
ASSERT(rhs != frame.locals.end());

input_type input;
input.x = lhs->second;
input.y = rhs->second;
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams>;
component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs());
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

template<typename BlueprintFieldType, typename ArithmetizationParams>
Expand All @@ -94,7 +102,8 @@ namespace nil {
const auto p = detail::PolicyManager::get_parameters(
detail::ManifestReader<component_type, ArithmetizationParams>::get_witness(0));

component_type component(p.witness);
using manifest_reader = detail::ManifestReader<component_type, ArithmetizationParams>;
component_type component(p.witness, manifest_reader::get_constants(), manifest_reader::get_public_inputs());
fill_trace(component, input, operation, frame, bp, assignment, start_row);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include <nil/blueprint/components/algebra/fixedpoint/plonk/argmax.hpp>
#include <nil/blueprint/components/algebra/fixedpoint/plonk/sqrt.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/non_native/lookup_logic_ops.hpp>
#include <nil/blueprint/components/algebra/fields/plonk/logic_or_flag.hpp>

#define PREPARE_UNARY_INPUT(OP) \
prepare_unary_operation_input<BlueprintFieldType, ArithmetizationParams, OP, \
Expand Down
159 changes: 139 additions & 20 deletions mlir-assigner/include/mlir-assigner/parser/evaluator.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef CRYPTO3_BLUEPRINT_COMPONENT_INSTRUCTION_MLIR_EVALUATOR_HPP
#define CRYPTO3_BLUEPRINT_COMPONENT_INSTRUCTION_MLIR_EVALUATOR_HPP

#include "nil/blueprint/blueprint/plonk/assignment.hpp"
#include <cassert>
#include <cstdint>
#define TEST_WITHOUT_LOOKUP_TABLES
Expand Down Expand Up @@ -53,7 +54,7 @@
#include <mlir-assigner/components/fixedpoint/subtraction.hpp>
#include <mlir-assigner/components/fixedpoint/dot_product.hpp>
#include <mlir-assigner/components/fixedpoint/trigonometric.hpp>
#include <mlir-assigner/components/boolean/and.hpp>
#include <mlir-assigner/components/boolean/logic_ops.hpp>
#include <mlir-assigner/components/fixedpoint/to_fixpoint.hpp>

#include <mlir-assigner/memory/memref.hpp>
Expand Down Expand Up @@ -190,6 +191,17 @@ namespace zk_ml_toolchain {
bool PrintCircuitOutput;
nil::blueprint::logger &logger;

template<typename NumberType>
NumberType resolve_number(VarType scalar) {
auto scalar_value = var_value(assignmnt, scalar);
static constexpr auto limit_value =
typename BlueprintFieldType::integral_type(std::numeric_limits<NumberType>::max());
auto integral_value = static_cast<typename BlueprintFieldType::integral_type>(scalar_value.data);
ASSERT_MSG(integral_value < limit_value, "Too large to cast");
NumberType number = static_cast<NumberType>(integral_value);
return number;
}

void doAffineFor(affine::AffineForOp &op, int64_t from, int64_t to, int64_t step) {
assert(from < to);
assert(step);
Expand Down Expand Up @@ -253,7 +265,45 @@ namespace zk_ml_toolchain {
} else if (arith::CmpFOp operation = llvm::dyn_cast<arith::CmpFOp>(op)) {
handle_fixedpoint_comparison_component(operation, frames.back(), bp, assignmnt, start_row);
} else if (arith::SelectOp operation = llvm::dyn_cast<arith::SelectOp>(op)) {
handle_select_component(operation, frames.back(), bp, assignmnt, start_row);
ASSERT(operation.getNumOperands() == 3 && "Select must have three operands");
ASSERT(operation->getOperand(1).getType() == operation->getOperand(2).getType() &&
"Select must operate on same type");
// check if we work on indices
Type operandType = operation->getOperand(1).getType();
auto i1Hash = mlir::hash_value(operation->getOperand(0));
if (operandType.isa<IndexType>()) {
// for now we expect that if we select on indices, that we also have the cmp result in
// constant values. Let's see if this holds true in the future
auto cmpResult = frames.back().constant_values.find(i1Hash);
ASSERT(cmpResult != frames.back().constant_values.end());
if (cmpResult->second) {
auto truthy = frames.back().constant_values.find(mlir::hash_value(operation->getOperand(1)));
ASSERT(truthy != frames.back().constant_values.end());
frames.back().constant_values[mlir::hash_value(operation->getResult(0))] = truthy->second;
} else {
auto falsy = frames.back().constant_values.find(mlir::hash_value(operation->getOperand(2)));
ASSERT(falsy != frames.back().constant_values.end());
frames.back().constant_values[mlir::hash_value(operation->getResult(0))] = falsy->second;
}
} else if (frames.back().constant_values.find(i1Hash) != frames.back().constant_values.end()) {
// we come from index comparision but we do not work on indices, ergo we need to get from locals
if (frames.back().constant_values[i1Hash]) {
auto truthy = frames.back().locals.find(mlir::hash_value(operation->getOperand(1)));
ASSERT(truthy != frames.back().locals.end());
frames.back().locals[mlir::hash_value(operation->getResult(0))] = truthy->second;
} else {
auto falsy = frames.back().locals.find(mlir::hash_value(operation->getOperand(2)));
ASSERT(falsy != frames.back().locals.end());
frames.back().locals[mlir::hash_value(operation->getResult(0))] = falsy->second;
}
} else if (operandType.isa<FloatType>()) {
handle_select_component(operation, frames.back(), bp, assignmnt, start_row);
} else {
std::string typeStr;
llvm::raw_string_ostream ss(typeStr);
ss << operandType;
UNREACHABLE(std::string("unhandled select operand: ") + typeStr);
}
} else if (arith::NegFOp operation = llvm::dyn_cast<arith::NegFOp>(op)) {
handle_fixedpoint_neg_component(operation, frames.back(), bp, assignmnt, start_row);
} else if (arith::AndIOp operation = llvm::dyn_cast<arith::AndIOp>(op)) {
Expand All @@ -267,14 +317,28 @@ namespace zk_ml_toolchain {
UNREACHABLE("TODO add Bitwise And Gadget");
}
} else if (arith::OrIOp operation = llvm::dyn_cast<arith::OrIOp>(op)) {
// check if logical and or bitwise and
mlir::Type LhsType = operation.getLhs().getType();
mlir::Type RhsType = operation.getRhs().getType();
assert(LhsType == RhsType && "must be same type for OrIOp");
if (LhsType.getIntOrFloatBitWidth() == 1) {
handle_logic_or(operation, frames.back(), bp, assignmnt, start_row);
ASSERT(operation.getNumOperands() == 2 && "Or must have two operands");
ASSERT(operation->getOperand(0).getType() == operation->getOperand(1).getType() &&
"Or must operate on same type");
// check if we work on indices
// TODO this seems like a hack, maybe we can do something better
auto lhsHash = mlir::hash_value(operation.getLhs());
if (frames.back().constant_values.find(lhsHash) != frames.back().constant_values.end()) {
auto lhs = frames.back().constant_values[lhsHash];
auto rhs = frames.back().constant_values.find(mlir::hash_value(operation.getRhs()));
assert(rhs != frames.back().constant_values.end());
auto result = lhs | rhs->second;
frames.back().constant_values[mlir::hash_value(operation.getResult())] = result;
} else {
UNREACHABLE("TODO add Bitwise Or Gadget");
// check if logical and or bitwise and
mlir::Type LhsType = operation.getLhs().getType();
mlir::Type RhsType = operation.getRhs().getType();
assert(LhsType == RhsType && "must be same type for OrIOp");
if (LhsType.getIntOrFloatBitWidth() == 1) {
handle_logic_or(operation, frames.back(), bp, assignmnt, start_row);
} else {
UNREACHABLE("TODO add Bitwise Or Gadget");
}
}
} else if (arith::XOrIOp operation = llvm::dyn_cast<arith::XOrIOp>(op)) {
// check if logical and or bitwise and
Expand All @@ -287,7 +351,6 @@ namespace zk_ml_toolchain {
UNREACHABLE("TODO add Bitwise XOr Gadget");
}
} else if (arith::AddIOp operation = llvm::dyn_cast<arith::AddIOp>(op)) {

// TODO: ATM, handle only the case where we work on indices that are
// constant values
auto lhs = frames.back().constant_values.find(mlir::hash_value(operation.getLhs()));
Expand Down Expand Up @@ -323,8 +386,53 @@ namespace zk_ml_toolchain {
frames.back().constant_values[mlir::hash_value(operation.getResult())] = result;

} else if (arith::CmpIOp operation = llvm::dyn_cast<arith::CmpIOp>(op)) {
llvm::outs() << "icmp\n";
exit(0);
assert(operation.getLhs().getType().isa<IndexType>());
assert(operation.getRhs().getType().isa<IndexType>());

// TODO: ATM, handle only the case where we work on indices that are
// constant values
auto lhs = frames.back().constant_values.find(mlir::hash_value(operation.getLhs()));
auto rhs = frames.back().constant_values.find(mlir::hash_value(operation.getRhs()));
assert(lhs != frames.back().constant_values.end());
assert(rhs != frames.back().constant_values.end());
int64_t cmpResult;
switch (operation.getPredicate()) {
case arith::CmpIPredicate::eq:
cmpResult = static_cast<int64_t>(lhs->second == rhs->second);
break;
case arith::CmpIPredicate::ne:
cmpResult = static_cast<int64_t>(lhs->second != rhs->second);
break;
case arith::CmpIPredicate::slt:
cmpResult = static_cast<int64_t>(lhs->second < rhs->second);
break;
case arith::CmpIPredicate::sle:
cmpResult = static_cast<int64_t>(lhs->second <= rhs->second);
break;
case arith::CmpIPredicate::sgt:
cmpResult = static_cast<int64_t>(lhs->second > rhs->second);
break;
case arith::CmpIPredicate::sge:
cmpResult = static_cast<int64_t>(lhs->second >= rhs->second);
break;
case arith::CmpIPredicate::ult:
cmpResult = static_cast<int64_t>(static_cast<uint64_t>(lhs->second) <
static_cast<uint64_t>(rhs->second));
break;
case arith::CmpIPredicate::ule:
cmpResult = static_cast<int64_t>(static_cast<uint64_t>(lhs->second) <=
static_cast<uint64_t>(rhs->second));
break;
case arith::CmpIPredicate::ugt:
cmpResult = static_cast<int64_t>(static_cast<uint64_t>(lhs->second) >
static_cast<uint64_t>(rhs->second));
break;
case arith::CmpIPredicate::uge:
cmpResult = static_cast<int64_t>(static_cast<uint64_t>(lhs->second) >=
static_cast<uint64_t>(rhs->second));
break;
}
frames.back().constant_values[mlir::hash_value(operation.getResult())] = cmpResult;
} else if (arith::ConstantOp operation = llvm::dyn_cast<arith::ConstantOp>(op)) {
TypedAttr constantValue = operation.getValueAttr();
if (constantValue.isa<IntegerAttr>()) {
Expand Down Expand Up @@ -358,10 +466,21 @@ namespace zk_ml_toolchain {
}
} else if (arith::IndexCastOp operation = llvm::dyn_cast<arith::IndexCastOp>(op)) {
assert(operation->getNumOperands() == 1 && "IndexCast must have exactly one operand");
auto index = frames.back().constant_values[mlir::hash_value(operation->getOperand(0))];
typename BlueprintFieldType::value_type field_constant = index;
auto val = put_into_assignment(field_constant);
frames.back().locals.insert(std::make_pair(mlir::hash_value(operation.getResult()), val));
auto opHash = mlir::hash_value(operation->getOperand(0));
// from int to index
if (operation->getOperand(0).getType().isa<IntegerType>()) {
auto i = frames.back().locals.find(opHash);
assert(i != frames.back().locals.end());
frames.back().constant_values[mlir::hash_value(operation.getResult())] = resolve_number<int64_t>(i->second);
} else if (operation->getOperand(0).getType().isa<IndexType>()) {
auto index = frames.back().constant_values.find(opHash);
assert(index != frames.back().constant_values.end());
typename BlueprintFieldType::value_type field_constant = index->second;
auto val = put_into_assignment(field_constant);
frames.back().locals.insert(std::make_pair(mlir::hash_value(operation.getResult()), val));
} else {
UNREACHABLE("unsupported Index Cast");
}
} else if (arith::SIToFPOp operation = llvm::dyn_cast<arith::SIToFPOp>(op)) {
// TODO this does not respect negative and no different ranges for ints...
handle_to_fixedpoint(operation, frames.back(), bp, assignmnt, start_row);
Expand Down Expand Up @@ -543,16 +662,16 @@ namespace zk_ml_toolchain {
// Create the global at the entry of the module.
assert(operation.getValue().has_value() && "Krnl Global must always have a value");
auto value = operation.getValue().value();
//TODO check other bit sizes. Also no range constraint is this necessary????
// TODO check other bit sizes. Also no range constraint is this necessary????
if (DenseElementsAttr attr = llvm::dyn_cast<DenseElementsAttr>(value)) {
mlir::Type attrType = attr.getElementType();
if (attrType.isa<mlir::IntegerType>()) {
auto ints = attr.tryGetValues<APInt>();
assert(!mlir::failed(ints) && "must work as we checked above");
size_t idx = 0;
for (auto a : ints.value()) {
auto var = put_into_assignment(a.getSExtValue());
m.put_flat(idx++, var);
auto var = put_into_assignment(a.getSExtValue());
m.put_flat(idx++, var);
}
} else if (attrType.isa<mlir::FloatType>()) {
auto floats = attr.tryGetValues<APFloat>();
Expand All @@ -572,7 +691,7 @@ namespace zk_ml_toolchain {
m.put_flat(idx++, var);
}
} else {
UNREACHABLE("Unsupported attribute type");
UNREACHABLE("Unsupported attribute type");
}
} else {
UNREACHABLE("Expected a DenseElementsAttr");
Expand Down
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Constant/ConstantScalar.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[]
Binary file not shown.
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Constant/ConstantScalar.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<f32>[1]
3
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Constant/ConstantSimple.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[]
Binary file not shown.
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Constant/ConstantSimple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<10xf32>[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
12
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[]
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
 :`

in_aout_a"ConstantOfShapeConstantOfShapeSimple*:
Bin_ab
out_a



B
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xf32>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
12
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[]
Binary file added mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.onnx
Binary file not shown.
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<3x12xf32>[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]
11
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[]
Binary file not shown.
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<3x12xi32>[5, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 5, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 5, 2, 2, 2]
11
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Or/OrSimple.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<1x10xi1>[1, 1, 0, 0, 1, 1, 0, 1, 1, 1]
23
1 change: 1 addition & 0 deletions mlir-assigner/tests/Ops/Onnx/Range/RangeFloat.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[]
Binary file added mlir-assigner/tests/Ops/Onnx/Range/RangeFloat.onnx
Binary file not shown.
3 changes: 3 additions & 0 deletions mlir-assigner/tests/Ops/Onnx/Range/RangeFloat.res
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Result:
memref<2xf32>[3, 6]
4
Loading

0 comments on commit 15bb260

Please sign in to comment.