From 2eba86bf1169e1c499c9068893dbccd7768da71e Mon Sep 17 00:00:00 2001 From: Abhinav Shukla <67401627+maxprogrammer007@users.noreply.github.com> Date: Fri, 22 May 2026 09:15:53 +0530 Subject: [PATCH 1/2] Update model.rs --- src/model.rs | 51 +++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/src/model.rs b/src/model.rs index 8b6501e..607aa79 100644 --- a/src/model.rs +++ b/src/model.rs @@ -44,6 +44,22 @@ fn match_local_layout(config_base: &Path, model_base: &Path, config_file: &str) }) } +fn decode_token_mapping(dtype: Dtype, raw: &[u8]) -> Result> { + let mapping = match dtype { + Dtype::I64 => raw + .chunks_exact(8) + .map(|b| i64::from_le_bytes(b.try_into().unwrap()) as usize) + .collect(), + Dtype::I32 => raw + .chunks_exact(4) + .map(|b| i32::from_le_bytes(b.try_into().unwrap()) as usize) + .collect(), + other => return Err(anyhow!("unsupported mapping dtype: {:?}", other)), + }; + + Ok(mapping) +} + #[cfg(all(feature = "hf-hub", not(feature = "local-only")))] fn is_not_found(e: &hf_hub::api::sync::ApiError) -> bool { use hf_hub::api::sync::ApiError; @@ -190,14 +206,7 @@ impl StaticModel { }; let token_mapping = match safet.tensor("mapping") { - Ok(t) => { - let raw = t.data(); - let v: Vec = raw - .chunks_exact(4) - .map(|b| i32::from_le_bytes(b.try_into().unwrap()) as usize) - .collect(); - Some(v) - } + Ok(t) => Some(decode_token_mapping(t.dtype(), t.data())?), Err(_) => None, }; @@ -423,6 +432,32 @@ impl StaticModel { } } +#[cfg(test)] +mod tests { + use super::decode_token_mapping; + use safetensors::tensor::Dtype; + + #[test] + fn decode_token_mapping_supports_i32_and_i64() { + let i32_raw = [1i32, 2, 3] + .into_iter() + .flat_map(|value| value.to_le_bytes()) + .collect::>(); + let i64_raw = [4i64, 5, 6] + .into_iter() + .flat_map(|value| value.to_le_bytes()) + .collect::>(); + + assert_eq!(decode_token_mapping(Dtype::I32, &i32_raw).unwrap(), vec![1, 2, 3]); + assert_eq!(decode_token_mapping(Dtype::I64, &i64_raw).unwrap(), vec![4, 5, 6]); + } + + #[test] + fn decode_token_mapping_rejects_unsupported_dtype() { + let err = decode_token_mapping(Dtype::F32, &[0, 0, 0, 0]).unwrap_err(); + assert!(err.to_string().contains("unsupported mapping dtype")); + } +} fn resolve_model_files>( repo_or_path: P, token: Option<&str>, From eb6d985e061213c3dc52a2dfe1e3ab3bd91f2f67 Mon Sep 17 00:00:00 2001 From: Pringled Date: Sat, 23 May 2026 12:17:43 +0200 Subject: [PATCH 2/2] fix CI --- src/model.rs | 53 ++++++++++++++++++++++++++-------------------------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/src/model.rs b/src/model.rs index 607aa79..990de8b 100644 --- a/src/model.rs +++ b/src/model.rs @@ -432,32 +432,6 @@ impl StaticModel { } } -#[cfg(test)] -mod tests { - use super::decode_token_mapping; - use safetensors::tensor::Dtype; - - #[test] - fn decode_token_mapping_supports_i32_and_i64() { - let i32_raw = [1i32, 2, 3] - .into_iter() - .flat_map(|value| value.to_le_bytes()) - .collect::>(); - let i64_raw = [4i64, 5, 6] - .into_iter() - .flat_map(|value| value.to_le_bytes()) - .collect::>(); - - assert_eq!(decode_token_mapping(Dtype::I32, &i32_raw).unwrap(), vec![1, 2, 3]); - assert_eq!(decode_token_mapping(Dtype::I64, &i64_raw).unwrap(), vec![4, 5, 6]); - } - - #[test] - fn decode_token_mapping_rejects_unsupported_dtype() { - let err = decode_token_mapping(Dtype::F32, &[0, 0, 0, 0]).unwrap_err(); - assert!(err.to_string().contains("unsupported mapping dtype")); - } -} fn resolve_model_files>( repo_or_path: P, token: Option<&str>, @@ -521,3 +495,30 @@ fn download_model_files(repo_id: &str, token: Option<&str>, subfolder: Option<&s result } + +#[cfg(test)] +mod tests { + use super::decode_token_mapping; + use safetensors::tensor::Dtype; + + #[test] + fn decode_token_mapping_supports_i32_and_i64() { + let i32_raw = [1i32, 2, 3] + .into_iter() + .flat_map(|value| value.to_le_bytes()) + .collect::>(); + let i64_raw = [4i64, 5, 6] + .into_iter() + .flat_map(|value| value.to_le_bytes()) + .collect::>(); + + assert_eq!(decode_token_mapping(Dtype::I32, &i32_raw).unwrap(), vec![1, 2, 3]); + assert_eq!(decode_token_mapping(Dtype::I64, &i64_raw).unwrap(), vec![4, 5, 6]); + } + + #[test] + fn decode_token_mapping_rejects_unsupported_dtype() { + let err = decode_token_mapping(Dtype::F32, &[0, 0, 0, 0]).unwrap_err(); + assert!(err.to_string().contains("unsupported mapping dtype")); + } +}