diff --git a/mlir-assigner/include/mlir-assigner/memory/stack_frame.hpp b/mlir-assigner/include/mlir-assigner/memory/stack_frame.hpp index f22af22..1c3621f 100644 --- a/mlir-assigner/include/mlir-assigner/memory/stack_frame.hpp +++ b/mlir-assigner/include/mlir-assigner/memory/stack_frame.hpp @@ -107,39 +107,6 @@ namespace nil { return false; } - template - bool peek_local(MlirType identifier) { - return peek_local(mlir::hash_value(identifier)); - } - - bool peek_local(llvm::hash_code hash_code) { - size_t hash = size_t(hash_code); - for (auto iter = frames.rbegin(); iter != frames.rend(); ++iter) { - if (iter->locals.find(hash) != iter->locals.end()) { - // yay we found it - return true; - } - } - return false; - } - - template - bool peek_memref(MlirType identifier) { - return peek_memref(mlir::hash_value(identifier)); - } - - bool peek_memref(llvm::hash_code hash_code) { - size_t hash = size_t(hash_code); - for (auto iter = frames.rbegin(); iter != frames.rend(); ++iter) { - if (iter->memrefs.find(hash) != iter->memrefs.end()) { - // yay we found it - return true; - } - } - return false; - } - - template void push_local(MlirType identifier, const VarType &local, bool allow_overwrite = true) { push_local(mlir::hash_value(identifier), local, allow_overwrite); diff --git a/mlir-assigner/include/mlir-assigner/parser/evaluator.hpp b/mlir-assigner/include/mlir-assigner/parser/evaluator.hpp index b4075c6..8d79b9a 100644 --- a/mlir-assigner/include/mlir-assigner/parser/evaluator.hpp +++ b/mlir-assigner/include/mlir-assigner/parser/evaluator.hpp @@ -373,9 +373,8 @@ namespace zk_ml_toolchain { "Select must operate on same type"); // check if we work on indices Type operandType = operation->getOperand(1).getType(); - auto i1Hash = mlir::hash_value(operation->getOperand(0)); - if (operandType.isa()) { - if (stack.get_constant(i1Hash)) { + if (operandType.isa()) { + if (stack.get_constant(operation->getOperand(0))) { auto truthy = stack.get_constant(operation->getOperand(1)); stack.push_constant(operation->getResult(0), truthy); } else { @@ -383,6 +382,7 @@ namespace zk_ml_toolchain { stack.push_constant(operation->getResult(0), falsy); } } else if (operandType.isa() || operandType.isa()) { + auto i1Hash = mlir::hash_value(operation->getOperand(0)); if (stack.peek_constant(i1Hash)) { // we can just pick as selector was produces by index if (stack.get_constant(i1Hash)) { @@ -420,17 +420,16 @@ namespace zk_ml_toolchain { "Or must operate on same type"); // check if we work on indices // TODO this seems like a hack, maybe we can do something better - auto lhsHash = mlir::hash_value(operation.getLhs()); - if (stack.peek_constant(lhsHash)) { - auto lhs = stack.get_constant(lhsHash); + mlir::Type LhsType = operation.getLhs().getType(); + mlir::Type RhsType = operation.getRhs().getType(); + assert(LhsType == RhsType && "must be same type for OrIOp"); + if (LhsType.isa()) { + auto lhs = stack.get_constant(operation.getLhs()); auto rhs = stack.get_constant(operation.getRhs()); auto result = lhs | rhs; stack.push_constant(operation.getResult(), result); } else { // check if logical and or bitwise and - mlir::Type LhsType = operation.getLhs().getType(); - mlir::Type RhsType = operation.getRhs().getType(); - assert(LhsType == RhsType && "must be same type for OrIOp"); unsigned bits = LhsType.getIntOrFloatBitWidth(); if (1 == bits) { nil::blueprint::handle_logic_or(operation, stack, bp, assignmnt, compParams); @@ -440,13 +439,11 @@ namespace zk_ml_toolchain { } } else if (arith::ExtFOp operation = llvm::dyn_cast(op)) { // nothing for us to do here. just copy in stack - auto opHash = mlir::hash_value(operation->getOperand(0)); - if (stack.peek_local(opHash)) { - stack.push_local(operation.getResult(), stack.get_local(opHash)); - } else if (stack.peek_memref(opHash)) { - stack.push_memref(operation.getResult(), stack.get_memref(opHash)); + mlir::Type operandType = operation->getOperand(0).getType(); + if (operandType.isa()) { + stack.push_memref(operation.getResult(), stack.get_memref(operation->getOperand(0))); } else { - UNREACHABLE("cannot find operand for extf"); + stack.push_local(operation.getResult(), stack.get_local(operation->getOperand(0))); } } else if (arith::XOrIOp operation = llvm::dyn_cast(op)) { // check if logical and or bitwise and