Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions frontend/wasm/src/miden_abi/stdlib/crypto/hashes/rpo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::miden_abi::{FunctionTypeMap, ModuleFunctionTypeMap};
pub const MODULE_ID: &str = "std::crypto::hashes::rpo";

pub const HASH_MEMORY: &str = "hash_memory";
pub const HASH_MEMORY_WORDS: &str = "hash_memory_words";

pub(crate) fn signatures() -> ModuleFunctionTypeMap {
let mut m: ModuleFunctionTypeMap = Default::default();
Expand All @@ -18,6 +19,11 @@ pub(crate) fn signatures() -> ModuleFunctionTypeMap {
Symbol::from(HASH_MEMORY),
FunctionType::new(CallConv::Wasm, [I32, I32], [Felt, Felt, Felt, Felt]),
);
// hash_memory_words takes (start_addr: u32, end_addr: u32) and returns 4 Felt values on the stack
rpo.insert(
Symbol::from(HASH_MEMORY_WORDS),
FunctionType::new(CallConv::Wasm, [I32, I32], [Felt, Felt, Felt, Felt]),
);

let module_path = SymbolPath::from_iter([
SymbolNameComponent::Root,
Expand Down
3 changes: 2 additions & 1 deletion frontend/wasm/src/miden_abi/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ fn get_transform_strategy(path: &SymbolPath) -> Option<TransformStrategy> {
}
name if name == Symbol::intern("rpo") => {
match components.next_if(|c| c.is_leaf())?.as_symbol_name().as_str() {
stdlib::crypto::hashes::rpo::HASH_MEMORY => {
stdlib::crypto::hashes::rpo::HASH_MEMORY
| stdlib::crypto::hashes::rpo::HASH_MEMORY_WORDS => {
Some(TransformStrategy::ReturnViaPointer)
}
_ => None,
Expand Down
34 changes: 28 additions & 6 deletions sdk/stdlib-sys/src/stdlib/crypto/hashes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,18 @@ extern "C" {
/// The output is passed back to the caller via a pointer.
#[link_name = "std::crypto::hashes::rpo::hash_memory"]
pub fn extern_hash_memory(ptr: u32, num_elements: u32, result_ptr: *mut Felt);

/// Computes the hash of a sequence of words using the Rescue Prime Optimized (RPO) hash
/// function.
///
/// This maps to the `std::crypto::hashes::rpo::hash_memory_words` procedure in the Miden
/// stdlib.
///
/// Input: The start and end addresses (in field elements) of the words to hash.
/// Output: One digest (4 field elements)
/// The output is passed back to the caller via a pointer.
#[link_name = "std::crypto::hashes::rpo::hash_memory_words"]
pub fn extern_hash_memory_words(start_addr: u32, end_addr: u32, result_ptr: *mut Felt);
}

/// Hashes a 32-byte input to a 32-byte output using the given hash function.
Expand Down Expand Up @@ -196,13 +208,17 @@ pub fn sha256_hash_2to1(input: [u8; 64]) -> [u8; 32] {
/// Computes the hash of a sequence of field elements using the Rescue Prime Optimized (RPO)
/// hash function.
///
/// This maps to the `std::crypto::rpo::hash_memory` procedure in the Miden stdlib.
/// This maps to the `std::crypto::rpo::hash_memory` procedure in the Miden stdlib and to the
/// `std::crypto::hashes::rpo::hash_memory_words` word-optimized variant when the input length is a
/// multiple of 4.
///
/// # Arguments
/// * `elements` - A Vec of field elements to be hashed
#[inline]
pub fn hash_elements(elements: Vec<Felt>) -> Digest {
let rust_ptr = elements.as_ptr().addr() as u32;
let element_count = elements.len();
let num_elements = element_count as u32;

unsafe {
let mut ret_area = core::mem::MaybeUninit::<Word>::uninit();
Expand All @@ -211,7 +227,13 @@ pub fn hash_elements(elements: Vec<Felt>) -> Digest {
// Since our BumpAlloc produces word-aligned allocations the pointer should be word-aligned
assert_eq(Felt::from_u32(miden_ptr % 4), felt!(0));

extern_hash_memory(miden_ptr, elements.len() as u32, result_ptr);
if element_count.is_multiple_of(4) {
let start_addr = miden_ptr;
let end_addr = start_addr + num_elements;
extern_hash_memory_words(start_addr, end_addr, result_ptr);
} else {
extern_hash_memory(miden_ptr, num_elements, result_ptr);
}

Digest::from_word(ret_area.assume_init().reverse())
}
Expand All @@ -220,8 +242,7 @@ pub fn hash_elements(elements: Vec<Felt>) -> Digest {
/// Computes the hash of a sequence of words using the Rescue Prime Optimized (RPO)
/// hash function.
///
/// This maps to the `std::crypto::rpo::hash_memory` procedure in the Miden stdlib treating the
/// `words` as an array of fielt elements.
/// This maps to the `std::crypto::hashes::rpo::hash_memory_words` procedure in the Miden stdlib.
///
/// # Arguments
/// * `words` - A slice of words to be hashed
Expand All @@ -236,8 +257,9 @@ pub fn hash_words(words: &[Word]) -> Digest {
// It's safe to assume the `words` ptr is word-aligned.
assert_eq(Felt::from_u32(miden_ptr % 4), felt!(0));

let num_elements = (words.len() * 4) as u32;
extern_hash_memory(miden_ptr, num_elements, result_ptr);
let start_addr = miden_ptr;
let end_addr = start_addr + (words.len() as u32 * 4);
extern_hash_memory_words(start_addr, end_addr, result_ptr);

Digest::from_word(ret_area.assume_init().reverse())
}
Expand Down
10 changes: 10 additions & 0 deletions sdk/stdlib-sys/stubs/crypto/hashes_rpo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,13 @@ pub extern "C" fn rpo_hash_memory_stub(ptr: u32, num_elements: u32, result_ptr:
unsafe { core::hint::unreachable_unchecked() }
}

/// Unreachable stub for std::crypto::hashes::rpo::hash_memory_words
#[export_name = "std::crypto::hashes::rpo::hash_memory_words"]
pub extern "C" fn rpo_hash_memory_words_stub(
start_addr: u32,
end_addr: u32,
result_ptr: *mut c_void,
) {
let _ = (start_addr, end_addr, result_ptr);
unsafe { core::hint::unreachable_unchecked() }
}
Loading