Skip to content

Commit

Permalink
refactor: removed unnecessary peeks in stack
Browse files Browse the repository at this point in the history
  • Loading branch information
fnieddu committed Feb 16, 2024
1 parent 1d3ff9e commit ecc3fd5
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 48 deletions.
33 changes: 0 additions & 33 deletions mlir-assigner/include/mlir-assigner/memory/stack_frame.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,39 +107,6 @@ namespace nil {
return false;
}

template<typename MlirType>
bool peek_local(MlirType identifier) {
return peek_local(mlir::hash_value(identifier));
}

bool peek_local(llvm::hash_code hash_code) {
size_t hash = size_t(hash_code);
for (auto iter = frames.rbegin(); iter != frames.rend(); ++iter) {
if (iter->locals.find(hash) != iter->locals.end()) {
// yay we found it
return true;
}
}
return false;
}

template<typename MlirType>
bool peek_memref(MlirType identifier) {
return peek_memref(mlir::hash_value(identifier));
}

bool peek_memref(llvm::hash_code hash_code) {
size_t hash = size_t(hash_code);
for (auto iter = frames.rbegin(); iter != frames.rend(); ++iter) {
if (iter->memrefs.find(hash) != iter->memrefs.end()) {
// yay we found it
return true;
}
}
return false;
}


template<typename MlirType>
void push_local(MlirType identifier, const VarType &local, bool allow_overwrite = true) {
push_local(mlir::hash_value(identifier), local, allow_overwrite);
Expand Down
27 changes: 12 additions & 15 deletions mlir-assigner/include/mlir-assigner/parser/evaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,16 +373,16 @@ namespace zk_ml_toolchain {
"Select must operate on same type");
// check if we work on indices
Type operandType = operation->getOperand(1).getType();
auto i1Hash = mlir::hash_value(operation->getOperand(0));
if (operandType.isa<IndexType>()) {
if (stack.get_constant(i1Hash)) {
if (operandType.isa<mlir::IndexType>()) {
if (stack.get_constant(operation->getOperand(0))) {
auto truthy = stack.get_constant(operation->getOperand(1));
stack.push_constant(operation->getResult(0), truthy);
} else {
auto falsy = stack.get_constant(operation->getOperand(2));
stack.push_constant(operation->getResult(0), falsy);
}
} else if (operandType.isa<FloatType>() || operandType.isa<IntegerType>()) {
auto i1Hash = mlir::hash_value(operation->getOperand(0));
if (stack.peek_constant(i1Hash)) {
// we can just pick as selector was produces by index
if (stack.get_constant(i1Hash)) {
Expand Down Expand Up @@ -420,17 +420,16 @@ namespace zk_ml_toolchain {
"Or must operate on same type");
// check if we work on indices
// TODO this seems like a hack, maybe we can do something better
auto lhsHash = mlir::hash_value(operation.getLhs());
if (stack.peek_constant(lhsHash)) {
auto lhs = stack.get_constant(lhsHash);
mlir::Type LhsType = operation.getLhs().getType();
mlir::Type RhsType = operation.getRhs().getType();
assert(LhsType == RhsType && "must be same type for OrIOp");
if (LhsType.isa<IndexType>()) {
auto lhs = stack.get_constant(operation.getLhs());
auto rhs = stack.get_constant(operation.getRhs());
auto result = lhs | rhs;
stack.push_constant(operation.getResult(), result);
} else {
// check if logical and or bitwise and
mlir::Type LhsType = operation.getLhs().getType();
mlir::Type RhsType = operation.getRhs().getType();
assert(LhsType == RhsType && "must be same type for OrIOp");
unsigned bits = LhsType.getIntOrFloatBitWidth();
if (1 == bits) {
nil::blueprint::handle_logic_or(operation, stack, bp, assignmnt, compParams);
Expand All @@ -440,13 +439,11 @@ namespace zk_ml_toolchain {
}
} else if (arith::ExtFOp operation = llvm::dyn_cast<arith::ExtFOp>(op)) {
// nothing for us to do here. just copy in stack
auto opHash = mlir::hash_value(operation->getOperand(0));
if (stack.peek_local(opHash)) {
stack.push_local(operation.getResult(), stack.get_local(opHash));
} else if (stack.peek_memref(opHash)) {
stack.push_memref(operation.getResult(), stack.get_memref(opHash));
mlir::Type operandType = operation->getOperand(0).getType();
if (operandType.isa<mlir::MemRefType>()) {
stack.push_memref(operation.getResult(), stack.get_memref(operation->getOperand(0)));
} else {
UNREACHABLE("cannot find operand for extf");
stack.push_local(operation.getResult(), stack.get_local(operation->getOperand(0)));
}
} else if (arith::XOrIOp operation = llvm::dyn_cast<arith::XOrIOp>(op)) {
// check if logical and or bitwise and
Expand Down

0 comments on commit ecc3fd5

Please sign in to comment.