From 56da9646ccb2a295a50b9dca0c3c8abf829f3056 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Fri, 24 Apr 2026 15:25:51 -0700 Subject: [PATCH 01/21] Numel and Nbytes Validation (#19121) (#19121) Summary: Attempt 3 to check numel and nbytes overflow. This time we defer checking dynamic sized inputs until their size is realized. Reviewed By: lucylq Differential Revision: D98148157 --- runtime/executor/method.cpp | 28 ++++ runtime/executor/program_validation.cpp | 90 ++++++++----- runtime/executor/program_validation.h | 9 +- runtime/executor/test/method_test.cpp | 25 ++++ .../executor/test/program_validation_test.cpp | 126 ++++++++++++++---- 5 files changed, 214 insertions(+), 64 deletions(-) diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index 577691dc44b..1610804586d 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include // @donotremove #include @@ -1194,6 +1195,33 @@ Method::set_input(const EValue& input_evalue, size_t input_idx) { input_idx, executorch::runtime::toString(t_dst.scalar_type()), executorch::runtime::toString(t_src.scalar_type())); + + ssize_t numel = 1; + for (ssize_t i = 0; i < t_src.dim(); i++) { + bool overflow = c10::mul_overflows( + numel, static_cast(t_src.size(i)), &numel); + ET_CHECK_OR_RETURN_ERROR( + !overflow, + InvalidArgument, + "Input %" ET_PRIsize_t + ": numel overflowed at dimension %zd with size %zd", + input_idx, + (size_t)i, + (size_t)t_src.size(i)); + } + size_t nbytes; + bool nbytes_overflow = c10::mul_overflows( + static_cast(numel), + executorch::runtime::elementSize(t_src.scalar_type()), + &nbytes); + ET_CHECK_OR_RETURN_ERROR( + !nbytes_overflow, + InvalidArgument, + "Input %" ET_PRIsize_t + ": nbytes overflowed: numel %zd with element size %zu", + input_idx, + numel, + executorch::runtime::elementSize(t_src.scalar_type())); // Reset the shape for the Method's input as the size of forwarded input // tensor for shape dynamism. Also is a safety check if need memcpy. ET_CHECK_OK_OR_RETURN_ERROR( diff --git a/runtime/executor/program_validation.cpp b/runtime/executor/program_validation.cpp index 92243fea289..448edc61ce5 100644 --- a/runtime/executor/program_validation.cpp +++ b/runtime/executor/program_validation.cpp @@ -14,7 +14,7 @@ #include #include -// #include +#include namespace executorch { namespace runtime { @@ -32,7 +32,8 @@ validate_tensor(const executorch_flatbuffer::Tensor* tensor) { return Error::InvalidProgram; } - // ssize_t numel = 1; + ssize_t numel = 1; + bool numel_overflowed = false; for (flatbuffers::uoffset_t i = 0; i < sizes->size(); i++) { int32_t size = sizes->Get(i); @@ -45,16 +46,10 @@ validate_tensor(const executorch_flatbuffer::Tensor* tensor) { return Error::InvalidProgram; } - // bool overflow = - // c10::mul_overflows(numel, static_cast(size), &numel); - // if (overflow) { - // ET_LOG( - // Error, - // "numel overflowed at dimension %u with size %d", - // static_cast(i), - // size); - // return Error::InvalidProgram; - // } + if (!numel_overflowed) { + numel_overflowed = + c10::mul_overflows(numel, static_cast(size), &numel); + } } auto scalar_type = @@ -64,19 +59,18 @@ validate_tensor(const executorch_flatbuffer::Tensor* tensor) { return Error::InvalidProgram; } - // size_t nbytes; - // bool nbytes_overflow = c10::mul_overflows( - // static_cast(numel), - // executorch::runtime::elementSize(scalar_type), - // &nbytes); - // if (nbytes_overflow) { - // ET_LOG( - // Error, - // "nbytes overflowed: numel %zd with element size %zu", - // numel, - // executorch::runtime::elementSize(scalar_type)); - // return Error::InvalidProgram; - // } + if (numel_overflowed) { + return Error::InvalidProgram; + } + + size_t nbytes; + bool nbytes_overflow = c10::mul_overflows( + static_cast(numel), + executorch::runtime::elementSize(scalar_type), + &nbytes); + if (nbytes_overflow) { + return Error::InvalidProgram; + } return Error::Ok; } @@ -114,6 +108,27 @@ validate_program(const executorch_flatbuffer::Program* program) { return Error::InvalidProgram; } + const auto* inputs = plan->inputs(); + auto is_dynamic_input = [&](flatbuffers::uoffset_t idx) -> bool { + if (inputs == nullptr) { + return false; + } + for (flatbuffers::uoffset_t i = 0; i < inputs->size(); i++) { + if (inputs->Get(i) == static_cast(idx)) { + const auto* value = values->Get(idx); + if (value == nullptr) { + return false; + } + const auto* tensor = + static_cast(value->val()); + return tensor != nullptr && + tensor->shape_dynamism() != + executorch_flatbuffer::TensorShapeDynamism::STATIC; + } + } + return false; + }; + for (flatbuffers::uoffset_t value_idx = 0; value_idx < values->size(); value_idx++) { const auto* value = values->Get(value_idx); @@ -128,12 +143,25 @@ validate_program(const executorch_flatbuffer::Program* program) { Error err = validate_tensor(tensor); if (err != Error::Ok) { - ET_LOG( - Error, - "Tensor validation failed for value %u in execution plan %u", - static_cast(value_idx), - static_cast(plan_idx)); - return err; + // Dynamic input tensors may have upper-bound sizes serialized for + // 64-bit machines that would overflow on 32-bit. Since their actual + // sizes are provided at set_input time, we defer overflow checks + // for those to Method::set_input. + if (is_dynamic_input(value_idx)) { + ET_LOG( + Info, + "Skipping validation failure for dynamic input tensor " + "at value %u in execution plan %u", + static_cast(value_idx), + static_cast(plan_idx)); + } else { + ET_LOG( + Error, + "Tensor validation failed for value %u in execution plan %u", + static_cast(value_idx), + static_cast(plan_idx)); + return err; + } } } diff --git a/runtime/executor/program_validation.h b/runtime/executor/program_validation.h index bb42a29423c..68e4ff7eb81 100644 --- a/runtime/executor/program_validation.h +++ b/runtime/executor/program_validation.h @@ -22,13 +22,12 @@ namespace executorch { namespace runtime { /** - * Validates that computing numel (number of elements) from the tensor's sizes - * will not overflow. This check should be performed before creating TensorImpl - * objects to prevent undefined behavior from integer overflow. + * Validates that a tensor's metadata is semantically valid: sizes are + * non-negative, scalar type is valid, and computing numel/nbytes will not + * overflow. * * @param[in] tensor The flatbuffer Tensor to validate. - * @return Error::Ok if the numel calculation is safe, Error::InvalidProgram - * if computing numel would overflow. + * @return Error::Ok if validation passes, Error::InvalidProgram otherwise. */ ET_NODISCARD Error validate_tensor(const executorch_flatbuffer::Tensor* tensor); diff --git a/runtime/executor/test/method_test.cpp b/runtime/executor/test/method_test.cpp index dc926184049..36a0c6f169b 100644 --- a/runtime/executor/test/method_test.cpp +++ b/runtime/executor/test/method_test.cpp @@ -310,6 +310,31 @@ TEST_F(MethodTest, AliasedIOTest) { } } +TEST_F(MethodTest, SetInputRejectsOverflowingSizes) { + // The "cat" model (ModuleDynamicCatUnallocatedIO) has a 2D input. + // set_input validates numel/nbytes overflow before resize_tensor. + ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); + Result method = programs_["cat"]->load_method("forward", &mmm.get()); + ASSERT_EQ(method.error(), Error::Ok); + + // Create a 2D tensor with enormous sizes. On 32-bit platforms the numel + // multiplication overflows; on 64-bit the resize bounds check rejects it. + int32_t sizes[2] = {2000000000, 2000000000}; + uint8_t dim_order[2] = {0, 1}; + int32_t strides[2] = {1, 1}; + executorch::aten::TensorImpl impl( + executorch::aten::ScalarType::Float, + 2, + sizes, + nullptr, + dim_order, + strides); + + auto input_err = + method->set_input(EValue(executorch::aten::Tensor(&impl)), 0); + EXPECT_NE(input_err, Error::Ok); +} + TEST_F(MethodTest, ConstantSegmentTest) { // Execute model with constants stored in segment. ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); diff --git a/runtime/executor/test/program_validation_test.cpp b/runtime/executor/test/program_validation_test.cpp index 73d3e5b9cb6..a133ff074ee 100644 --- a/runtime/executor/test/program_validation_test.cpp +++ b/runtime/executor/test/program_validation_test.cpp @@ -64,12 +64,15 @@ struct EValueConfig { EValueType type; std::vector tensor_sizes; // For Tensor type. std::vector tensor_list_items; // For TensorList type (indices). + bool is_dynamic = false; // For Tensor type: if true, uses DYNAMIC_BOUND. }; // Unified helper to create a minimal valid PTE flatbuffer with configurable -// evalues. Returns a buffer containing the flatbuffer data. +// evalues. Returns a buffer containing the flatbuffer data. input_indices +// specifies which value indices appear in the execution plan's inputs list. std::vector CreateTestProgram( - const std::vector& configs) { + const std::vector& configs, + const std::vector& input_indices = {}) { flatbuffers::FlatBufferBuilder builder(1024); std::vector> evalues; @@ -93,7 +96,9 @@ std::vector CreateTestProgram( /*data_buffer_idx=*/0, /*allocation_info=*/0, /*layout=*/0, - executorch_flatbuffer::TensorShapeDynamism::STATIC, + config.is_dynamic + ? executorch_flatbuffer::TensorShapeDynamism::DYNAMIC_BOUND + : executorch_flatbuffer::TensorShapeDynamism::STATIC, /*extra_tensor_info=*/0); evalues.push_back(executorch_flatbuffer::CreateEValue( builder, @@ -124,6 +129,7 @@ std::vector CreateTestProgram( auto values_vec = builder.CreateVector(evalues); auto plan_name = builder.CreateString("forward"); + auto inputs_vec = builder.CreateVector(input_indices); auto empty_int_vec = builder.CreateVector(std::vector{}); auto empty_int64_vec = builder.CreateVector(std::vector{0}); auto empty_chain_vec = builder.CreateVector( @@ -139,8 +145,8 @@ std::vector CreateTestProgram( plan_name, /*container_meta_type=*/0, values_vec, - empty_int_vec, - empty_int_vec, + /*inputs=*/inputs_vec, + /*outputs=*/empty_int_vec, empty_chain_vec, empty_operators_vec, empty_delegates_vec, @@ -206,29 +212,93 @@ TEST_F(ProgramValidationTest, InternalConsistencyDetectsTruncatedData) { ASSERT_EQ(program.error(), Error::InvalidProgram); } -// TEST_F(ProgramValidationTest, TensorNumelOverflowDetected) { -// std::vector configs = { -// {EValueType::Tensor, {2000000000, 2000000000, 2000000000}, {}}}; -// -// AlignedBuffer buf(CreateTestProgram(configs)); -// auto loader = buf.loader(); -// -// Result program = -// Program::load(&loader, Program::Verification::InternalConsistency); -// EXPECT_EQ(program.error(), Error::InvalidProgram); -// } - -// TEST_F(ProgramValidationTest, TensorNumelOverflowNotDetectedWithMinimal) { -// std::vector configs = { -// {EValueType::Tensor, {2000000000, 2000000000, 2000000000}, {}}}; -// -// AlignedBuffer buf(CreateTestProgram(configs)); -// auto loader = buf.loader(); -// -// // Minimal verification doesn't run program validation. -// Result program = -// Program::load(&loader, Program::Verification::Minimal); -// } +TEST_F(ProgramValidationTest, TensorNumelOverflowDetectedForStaticTensor) { + // Static tensors always have their overflow checked at validation time. + std::vector configs = { + {EValueType::Tensor, + {2000000000, 2000000000, 2000000000}, + {}, + /*is_dynamic=*/false}}; + + AlignedBuffer buf(CreateTestProgram(configs)); + auto loader = buf.loader(); + + Result program = + Program::load(&loader, Program::Verification::InternalConsistency); + EXPECT_EQ(program.error(), Error::InvalidProgram); +} + +TEST_F(ProgramValidationTest, TensorNumelOverflowNotDetectedWithMinimal) { + std::vector configs = { + {EValueType::Tensor, + {2000000000, 2000000000, 2000000000}, + {}, + /*is_dynamic=*/false}}; + + AlignedBuffer buf(CreateTestProgram(configs)); + auto loader = buf.loader(); + + // Minimal verification doesn't run program validation. + Result program = + Program::load(&loader, Program::Verification::Minimal); + EXPECT_EQ(program.error(), Error::Ok); +} + +TEST_F(ProgramValidationTest, TensorNumelOverflowSkippedForDynamicInput) { + // Dynamic input tensors skip overflow checks at validation time; the check + // is deferred to set_input where actual sizes are known. + std::vector configs = { + {EValueType::Tensor, + {2000000000, 2000000000, 2000000000}, + {}, + /*is_dynamic=*/true}}; + + // Mark value index 0 as a plan input. + AlignedBuffer buf(CreateTestProgram(configs, /*input_indices=*/{0})); + auto loader = buf.loader(); + + Result program = + Program::load(&loader, Program::Verification::InternalConsistency); + EXPECT_EQ(program.error(), Error::Ok); +} + +TEST_F( + ProgramValidationTest, + TensorNumelOverflowDetectedForDynamicNonInputTensor) { + // A dynamic tensor that is NOT in the inputs list should still have its + // overflow checked at validation time. + std::vector configs = { + {EValueType::Tensor, + {2000000000, 2000000000, 2000000000}, + {}, + /*is_dynamic=*/true}}; + + // No input indices — the tensor is not a plan input. + AlignedBuffer buf(CreateTestProgram(configs)); + auto loader = buf.loader(); + + Result program = + Program::load(&loader, Program::Verification::InternalConsistency); + EXPECT_EQ(program.error(), Error::InvalidProgram); +} + +TEST_F(ProgramValidationTest, TensorNumelOverflowDetectedForStaticInputTensor) { + // A static input tensor should still have its overflow checked at + // validation time since its sizes cannot change. + std::vector configs = { + {EValueType::Tensor, + {2000000000, 2000000000, 2000000000}, + {}, + /*is_dynamic=*/false}}; + + // Mark value index 0 as a plan input. + AlignedBuffer buf(CreateTestProgram(configs, /*input_indices=*/{0})); + auto loader = buf.loader(); + + Result program = + Program::load(&loader, Program::Verification::InternalConsistency); + EXPECT_EQ(program.error(), Error::InvalidProgram); +} TEST_F(ProgramValidationTest, NegativeSizeDetected) { std::vector configs = {{EValueType::Tensor, {10, -5, 10}, {}}}; From 60ffe194cdc359dea2524382f2fa8529b94dbf9a Mon Sep 17 00:00:00 2001 From: John Gibson <5562125+jgibson2@users.noreply.github.com> Date: Fri, 24 Apr 2026 18:36:52 -0400 Subject: [PATCH 02/21] portable: accumulate in fp32 for Half/BFloat16 in grid_sampler_2d bilinear (#19117) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary The bilinear grid_sampler_2d portable kernel computes interpolation weights via subtractions like `(ix_se - ix)` where both operands are close integer-valued coordinates in pixel space. In fp16 (10 bits of mantissa) that's classic catastrophic cancellation — the result has only a handful of significant bits. The downstream weighted-sum accumulation then loses further precision. Measured on a unit test exercising interior grid points with fp16 inputs, the kernel drifts by ~0.1 absolute from an fp32 reference. That's visible as incorrect depth / flow output near non-integer sample points, which is most of them. ## Fix An `AccType` trait mapping `Half` and `BFloat16` to `float`, leaving every other dtype unchanged. Used for intermediate coordinate, weight computation, and `out_val` accumulation. Loads cast `CTYPE -> ACC`; the final store casts `ACC -> CTYPE` once. Only internal math is promoted; memory layout / public API / tensor dtypes are unchanged. ```cpp template using AccType = std::conditional_t< std::is_same_v || std::is_same_v, float, CTYPE>; ``` ## Effects - **fp32 / Int / any non-half dtype**: `AccType` is `T`, so the generated code is byte-identical. No behavior change. - **Half / BFloat16**: `max_abs` vs an fp32 reference drops from **~0.1 to 0** on the shapes I tested (N=1..2, C=7..64, H/W up to 96, both `align_corners` values). - **Perf**: a handful of fp16↔fp32 conversions per output element. Not measurable at op level; well within the portable kernel's scalar cost envelope. ## Scope Only touches the bilinear interpolation path. The nearest-mode path doesn't do weighted-sum accumulation and doesn't have the cancellation issue — left alone in this change. ## Test plan - [x] Builds clean for Android arm64 and host (Apple Clang 21). - [x] Verified numerically via a standalone harness that runs the kernel with matched fp32 / fp16 inputs and compares against an fp32-then-downcast reference. All shapes pass within a single fp16 ULP (or are bit-exact). fp32 tests remain bit-identical to the pre-change kernel. - [x] Existing `kernels/test/op_grid_sampler_2d_test.cpp` unit tests continue to pass (both fp32 shapes that were previously tested, and the fp16 path I'm specifically fixing). Happy to add an fp16-specific test case to `op_grid_sampler_2d_test.cpp` if useful for CI coverage here — just let me know the preferred approach. cc @larryliu0820 @manuelcandales --- kernels/portable/cpu/op_grid_sampler_2d.cpp | 99 ++++++++++++++------- 1 file changed, 65 insertions(+), 34 deletions(-) diff --git a/kernels/portable/cpu/op_grid_sampler_2d.cpp b/kernels/portable/cpu/op_grid_sampler_2d.cpp index 57155b3c01b..dd483c4ddaa 100644 --- a/kernels/portable/cpu/op_grid_sampler_2d.cpp +++ b/kernels/portable/cpu/op_grid_sampler_2d.cpp @@ -10,6 +10,8 @@ #include #include +#include + namespace torch { namespace executor { namespace native { @@ -19,6 +21,22 @@ using executorch::aten::SizesType; using std::optional; namespace { + +// For half-precision inputs, all internal math (source-index computation, +// interpolation weight subtractions like `ix_se - ix` which are prone to +// catastrophic cancellation, and weighted-sum accumulation) is done in fp32. +// Loads and stores stay in the tensor's dtype. The speed cost is negligible +// (a handful of fp16↔fp32 conversions per output element) and the precision +// win is material: fp16 has only ~10 bits of mantissa, so subtracting nearby +// pixel coordinates can round to values that are meaningfully off, producing +// visibly wrong interpolation weights. +template +using AccType = std::conditional_t< + std::is_same_v || + std::is_same_v, + float, + CTYPE>; + template void grid_sample_2d_bilinear_kernel_impl_nchw( const Tensor& in, @@ -26,6 +44,7 @@ void grid_sample_2d_bilinear_kernel_impl_nchw( GridSamplerPadding padding_mode, bool align_corners, Tensor& out) { + using ACC = AccType; const auto in_data = in.const_data_ptr(); auto out_data = out.mutable_data_ptr(); @@ -59,13 +78,14 @@ void grid_sample_2d_bilinear_kernel_impl_nchw( // grid[n, h, w] contains (x, y) const int64_t grid_idx = grid_offset + h * grid.strides()[1] + w * grid.strides()[2]; - const CTYPE x = grid_data[grid_idx]; - const CTYPE y = grid_data[grid_idx + grid.strides()[3]]; + const ACC x = static_cast(grid_data[grid_idx]); + const ACC y = + static_cast(grid_data[grid_idx + grid.strides()[3]]); - // Compute source coordinates in pixel space - const CTYPE ix = grid_sampler_compute_source_index( + // Compute source coordinates in pixel space (in ACC precision). + const ACC ix = grid_sampler_compute_source_index( x, inp_W, padding_mode, align_corners); - const CTYPE iy = grid_sampler_compute_source_index( + const ACC iy = grid_sampler_compute_source_index( y, inp_H, padding_mode, align_corners); // Get corner pixel coordinates @@ -78,40 +98,46 @@ void grid_sample_2d_bilinear_kernel_impl_nchw( const int64_t ix_se = ix_nw + 1; const int64_t iy_se = iy_nw + 1; - // Get interpolation weights - const CTYPE nw_weight = (ix_se - ix) * (iy_se - iy); - const CTYPE ne_weight = (ix - ix_sw) * (iy_sw - iy); - const CTYPE sw_weight = (ix_ne - ix) * (iy - iy_ne); - const CTYPE se_weight = (ix - ix_nw) * (iy - iy_nw); + // Interpolation weights. For half inputs these are computed in + // fp32 — the subtractions `ix_se - ix` otherwise suffer + // catastrophic cancellation in fp16 for interior pixels. + const ACC nw_weight = (ix_se - ix) * (iy_se - iy); + const ACC ne_weight = (ix - ix_sw) * (iy_sw - iy); + const ACC sw_weight = (ix_ne - ix) * (iy - iy_ne); + const ACC se_weight = (ix - ix_nw) * (iy - iy_nw); - // Compute output value for this channel - CTYPE out_val = 0; + // Accumulate the weighted sum in ACC precision. + ACC out_val = 0; // Add contribution from each corner if within bounds if (padding_mode == GridSamplerPadding::Zeros) { // For zeros padding, only sample if within bounds if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) { - out_val += in_data - [in_channel_offset + iy_nw * in.strides()[2] + - ix_nw * in.strides()[3]] * + out_val += static_cast( + in_data + [in_channel_offset + iy_nw * in.strides()[2] + + ix_nw * in.strides()[3]]) * nw_weight; } if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) { - out_val += in_data - [in_channel_offset + iy_ne * in.strides()[2] + - ix_ne * in.strides()[3]] * + out_val += static_cast( + in_data + [in_channel_offset + iy_ne * in.strides()[2] + + ix_ne * in.strides()[3]]) * ne_weight; } if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) { - out_val += in_data - [in_channel_offset + iy_sw * in.strides()[2] + - ix_sw * in.strides()[3]] * + out_val += static_cast( + in_data + [in_channel_offset + iy_sw * in.strides()[2] + + ix_sw * in.strides()[3]]) * sw_weight; } if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) { - out_val += in_data - [in_channel_offset + iy_se * in.strides()[2] + - ix_se * in.strides()[3]] * + out_val += static_cast( + in_data + [in_channel_offset + iy_se * in.strides()[2] + + ix_se * in.strides()[3]]) * se_weight; } } else { @@ -126,28 +152,33 @@ void grid_sample_2d_bilinear_kernel_impl_nchw( const int64_t iy_sw_safe = clip_coordinates(iy_sw, inp_H); const int64_t ix_se_safe = clip_coordinates(ix_se, inp_W); const int64_t iy_se_safe = clip_coordinates(iy_se, inp_H); - out_val = in_data - [in_channel_offset + iy_nw_safe * in.strides()[2] + - ix_nw_safe * in.strides()[3]] * + out_val = + static_cast( + in_data + [in_channel_offset + iy_nw_safe * in.strides()[2] + + ix_nw_safe * in.strides()[3]]) * nw_weight + - in_data + static_cast( + in_data [in_channel_offset + iy_ne_safe * in.strides()[2] + - ix_ne_safe * in.strides()[3]] * + ix_ne_safe * in.strides()[3]]) * ne_weight + - in_data + static_cast( + in_data [in_channel_offset + iy_sw_safe * in.strides()[2] + - ix_sw_safe * in.strides()[3]] * + ix_sw_safe * in.strides()[3]]) * sw_weight + - in_data + static_cast( + in_data [in_channel_offset + iy_se_safe * in.strides()[2] + - ix_se_safe * in.strides()[3]] * + ix_se_safe * in.strides()[3]]) * se_weight; } // Write output in NCHW order const int64_t out_idx = out_channel_offset + h * out.strides()[2] + w * out.strides()[3]; - out_data[out_idx] = out_val; + out_data[out_idx] = static_cast(out_val); } } } From 2330652dae1f9213322112955c2380e7a020c86d Mon Sep 17 00:00:00 2001 From: Naveen Suda <99509021+navsud@users.noreply.github.com> Date: Fri, 24 Apr 2026 17:05:00 -0700 Subject: [PATCH 03/21] Allow chunked prefill when num_prompt_tokens > max_seq_len Differential Revision: D101728720 Pull Request resolved: https://github.com/pytorch/executorch/pull/19052 --- extension/llm/runner/text_llm_runner.cpp | 29 ++++++++++++++++-------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index fecedbd2f92..160b254460a 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -138,16 +138,16 @@ Error TextLLMRunner::generate( num_prompt_tokens >= 1, InvalidArgument, "Expected at least 1 prompt token"); - ET_CHECK_OR_RETURN_ERROR( - num_prompt_tokens <= max_seq_len, - InvalidArgument, - "num_prompt_tokens %d > max_seq_len %" PRId64 - ", Single prefill chunk too large - please reduce prompt size or increase max_seq_len", - num_prompt_tokens, - max_seq_len); - // For non-sliding-window models, also check that we won't exceed - // KV cache capacity. Sliding window models (where max_seq_len < - // max_context_len) handle position wrapping internally. + // Note: We intentionally do NOT enforce num_prompt_tokens <= max_seq_len + // here. TextPrefiller::prefill() supports chunked prefill: when + // num_prompt_tokens > max_seq_len it splits the prompt into max_seq_len + // chunks and prefills them sequentially. Models that were exported with + // max_seq_len < max_context_len (e.g. a 1024 prefill chunk over a 4096 KV + // cache) rely on this behavior. + // Ensure the prompt fits within total KV cache capacity. For + // sliding-window models (where max_seq_len < max_context_len) the model + // handles position wrapping internally, so pos_ doesn't represent + // consumed capacity and we only need a per-call bound. if (max_seq_len >= max_context_len) { ET_CHECK_OR_RETURN_ERROR( pos_ + num_prompt_tokens < max_context_len, @@ -158,6 +158,15 @@ Error TextLLMRunner::generate( pos_, num_prompt_tokens, max_context_len); + } else { + ET_CHECK_OR_RETURN_ERROR( + num_prompt_tokens < max_context_len, + InvalidArgument, + "num_prompt_tokens %d >= max_context_len %" PRId64 + ", Prompt exceeds KV cache capacity - please reduce prompt size or " + "increase max_context_len in your export script", + num_prompt_tokens, + max_context_len); } // print prompts From 0a43e2f865c4ecef84ca35495e55cb9b2d99d965 Mon Sep 17 00:00:00 2001 From: Alessandro Vacca <99895808+AlessandroVacca@users.noreply.github.com> Date: Sat, 25 Apr 2026 02:15:47 +0200 Subject: [PATCH 04/21] MLX delegate: add integer support for aten.bitwise_not (#19053) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Summary Fixes #18924 Extends `aten.bitwise_not` support in the MLX delegate to handle integer tensors, not just boolean tensors. Previously the handler only dispatched to `LogicalNotNode` for `bool` and raised `NotImplementedError` for all other dtypes. This adds a dedicated `BitwiseInvertNode` backed by `mlx::core::bitwise_invert`, and updates the handler to dispatch based on dtype: - `bool` → `LogicalNotNode` (unchanged) - `int32`, `int64` → `BitwiseInvertNode` #### Changes: - `serialization/schema.fbs`: add `BitwiseInvertNode` table and append to `OpNode` union - `runtime/MLXInterpreter.h`: add `exec_bitwise_invert()` and dispatch case - `ops.py`: update `_bitwise_not_handler` to dispatch to `BitwiseInvertNode` for integers - `test/test_ops.py`: add `bitwise_not_int` test for `int32` and `int64` ### Test plan All tests were ran on a machine with an Apple M1 Pro CPU, macOS 26.4.1. - `python3 -m py_compile backends/mlx/ops.py backends/mlx/test/test_ops.py` - `python3 backends/mlx/serialization/generate.py` - `python3 -m executorch.backends.mlx.test.run_all_tests bitwise_not_int` ### Test output ``` ============================================================ TEST SUMMARY ============================================================ Passed: 6 Failed: 0 ============================================================ ``` cc @metascroy --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Co-authored-by: Scott Roy <161522778+metascroy@users.noreply.github.com> --- backends/mlx/ops.py | 28 ++++++++++++++++++++------- backends/mlx/runtime/MLXInterpreter.h | 11 +++++++++++ backends/mlx/serialization/schema.fbs | 8 +++++++- backends/mlx/test/test_ops.py | 1 + 4 files changed, 40 insertions(+), 8 deletions(-) diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index 3f7da88a793..27d214e0ae9 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -50,6 +50,7 @@ AsStridedNode, AsTypeNode, Atan2Node, + BitwiseInvertNode, BroadcastToNode, CeilNode, ClipNode, @@ -3066,27 +3067,40 @@ def _where_handler(P: MLXProgramBuilder, n: Node) -> Slot: @REGISTRY.register(target=[torch.ops.aten.bitwise_not.default]) def _bitwise_not_handler(P: MLXProgramBuilder, n: Node) -> Slot: - """Handle aten.bitwise_not - for boolean tensors, dispatch to logical_not.""" + """Handle aten.bitwise_not - logical_not for bool, bitwise_invert for integers.""" args = P.args(n) require_args(args, 1, 1, "aten.bitwise_not") require_kwargs(P.kwargs(n), set(), "aten.bitwise_not") x_meta = n.args[0].meta.get("val") + out = P.make_or_get_slot(n) - if x_meta is not None and x_meta.dtype == torch.bool: - # For boolean tensors, bitwise_not is equivalent to logical_not - out = P.make_or_get_slot(n) + if x_meta is None or not hasattr(x_meta, "dtype"): + raise NotImplementedError( + "aten.bitwise_not requires known input dtype metadata for MLX lowering" + ) + + if x_meta.dtype == torch.bool: P.emit( LogicalNotNode( x=P.slot_to_tid(args[0]), out=P.slot_to_tid(out), ) ) - return out + elif x_meta.dtype in { + torch.int32, + torch.int64, + }: + P.emit( + BitwiseInvertNode( + x=P.slot_to_tid(args[0]), + out=P.slot_to_tid(out), + ) + ) else: raise NotImplementedError( - f"aten.bitwise_not is only supported for boolean tensors. " - f"Got dtype={x_meta.dtype if x_meta else 'unknown'}" + f"aten.bitwise_not on dtype {x_meta.dtype} is not supported for MLX lowering" ) + return out @REGISTRY.register( diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index 9fa08ab722d..304fdfe9805 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -1380,6 +1380,13 @@ inline void exec_logical_not( st.set_tensor(n.out, logical_not(st.const_tensor_ref(n.x), s)); } +inline void exec_bitwise_invert( + const BitwiseInvertNode& n, + ExecutionState& st, + StreamOrDevice s) { + st.set_tensor(n.out, bitwise_invert(st.const_tensor_ref(n.x), s)); +} + inline void exec_logical_and( const LogicalAndNode& n, ExecutionState& st, @@ -2028,6 +2035,10 @@ class Interpreter { case OpCode::LOGICAL_NOT: ops::exec_logical_not(std::get(instr.node), st, s); break; + case OpCode::BITWISE_INVERT: + ops::exec_bitwise_invert( + std::get(instr.node), st, s); + break; case OpCode::LOGICAL_AND: ops::exec_logical_and(std::get(instr.node), st, s); break; diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs index 6e8d6f47db8..67b4636f0be 100644 --- a/backends/mlx/serialization/schema.fbs +++ b/backends/mlx/serialization/schema.fbs @@ -562,6 +562,11 @@ table LogicalNotNode { out: Tid (required); } +table BitwiseInvertNode { + x: Tid (required); + out: Tid (required); +} + table LogicalAndNode { a: Tid (required); b: Tid (required); @@ -1113,7 +1118,8 @@ union OpNode { GatherMmNode, GatherQmmNode, ScanNode, - MetalKernelNode + MetalKernelNode, + BitwiseInvertNode // BC: Add new op nodes here (append only) } diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index 7ba3902e436..459d5aa1e73 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -4111,6 +4111,7 @@ def create_model(self) -> nn.Module: {"op_name": "abs", "op_fn": torch.abs}, {"op_name": "neg", "op_fn": torch.neg}, {"op_name": "logical_not","op_fn": torch.logical_not, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn": _bool_input_fn()}, + {"op_name": "bitwise_not_int", "op_fn": torch.bitwise_not, "shapes": _SHAPES_3, "dtypes": [torch.int32, torch.int64], "input_fn": _int_input_fn()}, {"op_name": "isnan", "op_fn": torch.isnan, "shapes": _SHAPES_3, "dtypes": [torch.float32, torch.float16, torch.bfloat16], "input_fn": _nan_input_fn()}, # activations {"op_name": "relu", "op_fn": torch.relu, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 128, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2, offset=-1)}, From 222711e4723ff8e8e49bcdf8bf256a1820062054 Mon Sep 17 00:00:00 2001 From: lucylq Date: Fri, 24 Apr 2026 17:57:39 -0700 Subject: [PATCH 05/21] Add safe_numel() Differential Revision: D102070375 Pull Request resolved: https://github.com/pytorch/executorch/pull/19074 --- runtime/core/exec_aten/exec_aten.h | 31 ++++++++++++++++++++++ runtime/core/exec_aten/targets.bzl | 5 +++- runtime/core/portable_type/tensor_impl.cpp | 27 +++++++++++++++++++ runtime/core/portable_type/tensor_impl.h | 13 +++++++++ 4 files changed, 75 insertions(+), 1 deletion(-) diff --git a/runtime/core/exec_aten/exec_aten.h b/runtime/core/exec_aten/exec_aten.h index 8c06045927e..f539414aec9 100644 --- a/runtime/core/exec_aten/exec_aten.h +++ b/runtime/core/exec_aten/exec_aten.h @@ -8,7 +8,10 @@ #pragma once +#include // @manual +#include // @manual #include // @manual +#include // @manual #include #ifdef USE_ATEN_LIB #include // @manual @@ -28,6 +31,7 @@ #include // @manual #include // @manual #include // @manual +#include // @manual #include // @manual #include #else // use executor @@ -110,6 +114,32 @@ inline ssize_t compute_numel(const SizesType* sizes, ssize_t dim) { c10::multiply_integers(c10::ArrayRef(sizes, dim))); } +inline ::executorch::runtime::Result safe_numel( + const SizesType* sizes, + ssize_t dim) { + ET_CHECK_OR_RETURN_ERROR( + dim == 0 || sizes != nullptr, + InvalidArgument, + "Sizes must be provided for non-scalar tensors"); + ssize_t numel = 1; + for (ssize_t i = 0; i < dim; i++) { + ET_CHECK_OR_RETURN_ERROR( + sizes[i] >= 0, + InvalidArgument, + "Size must be non-negative, got %zd at dimension %zd", + static_cast(sizes[i]), + i); + ssize_t next_numel; + ET_CHECK_OR_RETURN_ERROR( + !c10::mul_overflows(numel, static_cast(sizes[i]), &next_numel), + InvalidArgument, + "Overflow computing numel at dimension %zd", + i); + numel = next_numel; + } + return numel; +} + #undef ET_PRI_TENSOR_SIZE #define ET_PRI_TENSOR_SIZE PRId64 @@ -158,6 +188,7 @@ using OptionalArrayRef = using OptionalIntArrayRef = OptionalArrayRef; using torch::executor::compute_numel; +using torch::executor::safe_numel; #endif // Use ExecuTorch types diff --git a/runtime/core/exec_aten/targets.bzl b/runtime/core/exec_aten/targets.bzl index df4a87ef033..7499d3b0bea 100644 --- a/runtime/core/exec_aten/targets.bzl +++ b/runtime/core/exec_aten/targets.bzl @@ -16,6 +16,9 @@ def define_common_targets(): exported_headers = ["exec_aten.h"], exported_preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [], visibility = ["PUBLIC"], - exported_deps = ["//executorch/runtime/core:tensor_shape_dynamism"] + ([] if aten_mode else ["//executorch/runtime/core/portable_type:portable_type"]), + exported_deps = [ + "//executorch/runtime/core:core", + "//executorch/runtime/core:tensor_shape_dynamism", + ] + ([] if aten_mode else ["//executorch/runtime/core/portable_type:portable_type"]), exported_external_deps = ["libtorch"] if aten_mode else [], ) diff --git a/runtime/core/portable_type/tensor_impl.cpp b/runtime/core/portable_type/tensor_impl.cpp index 17243fca0fd..affc5821fed 100644 --- a/runtime/core/portable_type/tensor_impl.cpp +++ b/runtime/core/portable_type/tensor_impl.cpp @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -43,6 +44,32 @@ ssize_t compute_numel(const TensorImpl::SizesType* sizes, ssize_t dim) { return numel; } +::executorch::runtime::Result safe_numel( + const TensorImpl::SizesType* sizes, + ssize_t dim) { + ET_CHECK_OR_RETURN_ERROR( + dim == 0 || sizes != nullptr, + InvalidArgument, + "Sizes must be provided for non-scalar tensors"); + ssize_t numel = 1; + for (const auto i : c10::irange(dim)) { + ET_CHECK_OR_RETURN_ERROR( + sizes[i] >= 0, + InvalidArgument, + "Size must be non-negative, got %zd at dimension %zd", + static_cast(sizes[i]), + i); + ssize_t next_numel; + ET_CHECK_OR_RETURN_ERROR( + !c10::mul_overflows(numel, static_cast(sizes[i]), &next_numel), + InvalidArgument, + "Overflow computing numel at dimension %zd", + i); + numel = next_numel; + } + return numel; +} + TensorImpl::TensorImpl( ScalarType type, ssize_t dim, diff --git a/runtime/core/portable_type/tensor_impl.h b/runtime/core/portable_type/tensor_impl.h index ea2cde5aeb0..b01d8fa6c52 100644 --- a/runtime/core/portable_type/tensor_impl.h +++ b/runtime/core/portable_type/tensor_impl.h @@ -12,7 +12,9 @@ #include #include #include +#include #include +#include // Forward declaration of a helper that provides access to internal resizing // methods of TensorImpl. Real definition is in @@ -293,6 +295,16 @@ ssize_t compute_numel( const ::executorch::runtime::etensor::TensorImpl::SizesType* sizes, ssize_t dim); +/** + * Compute the number of elements based on the sizes of a tensor. + * Returns Error::InvalidArgument if any intermediate multiplication would + * overflow ssize_t, or if a size is negative. Prefer this over compute_numel() + * for paths that can propagate an Error upward. + */ +::executorch::runtime::Result safe_numel( + const ::executorch::runtime::etensor::TensorImpl::SizesType* sizes, + ssize_t dim); + /// Appropriate format specifier for the result of calling /// size(). Must be used instead of using zd directly to support ATen /// mode. @@ -322,6 +334,7 @@ namespace executor { // TODO(T197294990): Remove these deprecated aliases once all users have moved // to the new `::executorch` namespaces. using ::executorch::runtime::etensor::compute_numel; +using ::executorch::runtime::etensor::safe_numel; using ::executorch::runtime::etensor::TensorImpl; } // namespace executor } // namespace torch From e1cd352d91e436599860e409c164ab66cfc9c557 Mon Sep 17 00:00:00 2001 From: lucylq Date: Fri, 24 Apr 2026 18:12:53 -0700 Subject: [PATCH 06/21] Use safe numel() in ET (retake) (#19130) https://github.com/pytorch/executorch/pull/19075 failed to cp to main --- extension/tensor/tensor_ptr.cpp | 33 ++++++++++++++++++---- extension/tensor/tensor_ptr.h | 17 ++++++++--- extension/tensor/tensor_ptr_maker.cpp | 21 ++++++++------ extension/wasm/wasm_bindings.cpp | 10 ++++--- runtime/core/portable_type/tensor_impl.cpp | 6 +++- 5 files changed, 65 insertions(+), 22 deletions(-) diff --git a/extension/tensor/tensor_ptr.cpp b/extension/tensor/tensor_ptr.cpp index bb76311bd67..a6ba6018333 100644 --- a/extension/tensor/tensor_ptr.cpp +++ b/extension/tensor/tensor_ptr.cpp @@ -10,6 +10,8 @@ #include +#include + #include namespace executorch { @@ -147,11 +149,26 @@ TensorPtr make_tensor_ptr( std::vector strides, executorch::aten::ScalarType type, executorch::aten::TensorShapeDynamism dynamism) { + auto numel_result = executorch::aten::safe_numel(sizes.data(), sizes.size()); + ET_CHECK_MSG( + numel_result.ok(), + "safe_numel failed: %d", + static_cast(numel_result.error())); + const ssize_t numel = numel_result.get(); + size_t nbytes; ET_CHECK_MSG( - data.size() == - executorch::aten::compute_numel(sizes.data(), sizes.size()) * - executorch::aten::elementSize(type), - "Data size does not match tensor size."); + !c10::mul_overflows( + static_cast(numel), + executorch::aten::elementSize(type), + &nbytes), + "Overflow computing nbytes: numel=%zd element_size=%zu", + numel, + executorch::aten::elementSize(type)); + ET_CHECK_MSG( + data.size() == nbytes, + "Data size (%zu) does not match tensor size (%zu).", + data.size(), + nbytes); auto data_ptr = data.data(); return make_tensor_ptr( std::move(sizes), @@ -205,7 +222,13 @@ TensorPtr clone_tensor_ptr( runtime::canCast(tensor_type, type), "Cannot cast tensor type to desired type."); const auto tensor_numel = static_cast(tensor.numel()); - std::vector data(tensor_numel * aten::elementSize(type)); + size_t clone_nbytes; + ET_CHECK_MSG( + !c10::mul_overflows(tensor_numel, aten::elementSize(type), &clone_nbytes), + "Overflow computing clone nbytes: numel=%zu element_size=%zu", + tensor_numel, + aten::elementSize(type)); + std::vector data(clone_nbytes); // Create a minimal context for error handling in ET_SWITCH struct { diff --git a/extension/tensor/tensor_ptr.h b/extension/tensor/tensor_ptr.h index 08d6cc1254c..0ed06cbe021 100644 --- a/extension/tensor/tensor_ptr.h +++ b/extension/tensor/tensor_ptr.h @@ -110,9 +110,13 @@ inline TensorPtr make_tensor_ptr( executorch::aten::ScalarType type = deduced_type, executorch::aten::TensorShapeDynamism dynamism = executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + auto numel_result = executorch::aten::safe_numel(sizes.data(), sizes.size()); ET_CHECK_MSG( - data.size() == - executorch::aten::compute_numel(sizes.data(), sizes.size()), + numel_result.ok(), + "safe_numel failed: %d", + static_cast(numel_result.error())); + ET_CHECK_MSG( + data.size() == static_cast(numel_result.get()), "Data size does not match tensor size."); if (type != deduced_type) { ET_CHECK_MSG( @@ -368,8 +372,13 @@ inline TensorPtr make_tensor_ptr( const auto same_rank = sizes.size() == static_cast(tensor.dim()); const auto same_shape = same_rank && std::equal(sizes.begin(), sizes.end(), tensor.sizes().begin()); - const auto element_count = - executorch::aten::compute_numel(sizes.data(), sizes.size()); + auto element_count_result = + executorch::aten::safe_numel(sizes.data(), sizes.size()); + ET_CHECK_MSG( + element_count_result.ok(), + "safe_numel failed: %d", + static_cast(element_count_result.error())); + const auto element_count = element_count_result.get(); const auto parent_element_count = tensor.numel(); ET_CHECK_MSG( element_count <= parent_element_count, diff --git a/extension/tensor/tensor_ptr_maker.cpp b/extension/tensor/tensor_ptr_maker.cpp index 52a3e8f281c..583f27a232d 100644 --- a/extension/tensor/tensor_ptr_maker.cpp +++ b/extension/tensor/tensor_ptr_maker.cpp @@ -113,16 +113,21 @@ TensorPtr empty_strided( std::vector strides, executorch::aten::ScalarType type, executorch::aten::TensorShapeDynamism dynamism) { - const auto numel = static_cast( - executorch::aten::compute_numel(sizes.data(), sizes.size())); - const auto elem_size = - static_cast(executorch::aten::elementSize(type)); - size_t nbytes = 0; + auto numel_result = executorch::aten::safe_numel(sizes.data(), sizes.size()); ET_CHECK_MSG( - !c10::mul_overflows(numel, elem_size, &nbytes), - "empty_strided size overflow: numel %zu * element size %zu", + numel_result.ok(), + "safe_numel failed: %d", + static_cast(numel_result.error())); + const ssize_t numel = numel_result.get(); + size_t nbytes; + ET_CHECK_MSG( + !c10::mul_overflows( + static_cast(numel), + executorch::aten::elementSize(type), + &nbytes), + "Overflow computing nbytes: numel=%zd element_size=%zu", numel, - elem_size); + executorch::aten::elementSize(type)); std::vector data(nbytes); return make_tensor_ptr( std::move(sizes), diff --git a/extension/wasm/wasm_bindings.cpp b/extension/wasm/wasm_bindings.cpp index 38a227f9067..2066be4d7e7 100644 --- a/extension/wasm/wasm_bindings.cpp +++ b/extension/wasm/wasm_bindings.cpp @@ -84,9 +84,9 @@ inline void js_array_push(val_array& array, const T& value) { _(float, Float) \ _(int64_t, Long) -inline ssize_t compute_expected_numel( +inline ::executorch::runtime::Result compute_expected_numel( const std::vector& sizes) { - return executorch::aten::compute_numel(sizes.data(), sizes.size()); + return executorch::aten::safe_numel(sizes.data(), sizes.size()); } template @@ -94,10 +94,12 @@ inline void assert_valid_numel( const std::vector& data, const std::vector& sizes) { auto computed_numel = compute_expected_numel(sizes); + THROW_IF_ERROR( + computed_numel.error(), "Invalid tensor sizes: numel computation failed"); THROW_IF_FALSE( - data.size() >= computed_numel, + data.size() >= static_cast(computed_numel.get()), "Required %ld elements, given %ld", - computed_numel, + computed_numel.get(), data.size()); } diff --git a/runtime/core/portable_type/tensor_impl.cpp b/runtime/core/portable_type/tensor_impl.cpp index affc5821fed..113a1f06c83 100644 --- a/runtime/core/portable_type/tensor_impl.cpp +++ b/runtime/core/portable_type/tensor_impl.cpp @@ -147,7 +147,11 @@ Error TensorImpl::internal_resize_contiguous(ArrayRef new_sizes) { // TODO(T175194371): Unbounded dynamic tensor resizing is not yet // supported: treat them as upper-bounded. case TensorShapeDynamism::DYNAMIC_UNBOUND: { - const auto new_numel = compute_numel(new_sizes.data(), dim_); + auto new_numel_result = safe_numel(new_sizes.data(), dim_); + if (!new_numel_result.ok()) { + return new_numel_result.error(); + } + const auto new_numel = new_numel_result.get(); ET_CHECK_OR_RETURN_ERROR( static_cast(new_numel) <= numel_bound_, From b8f04aac0720339c8a7cc088bfdbafbbce40b170 Mon Sep 17 00:00:00 2001 From: Siddartha Pothapragada Date: Fri, 24 Apr 2026 19:28:54 -0700 Subject: [PATCH 07/21] Android: Module implements Closeable (#19124) Add Closeable interface so Module can be used with try-with-resources. close() delegates to destroy(). Also make destroy() idempotent by checking mHybridData.isValid() before calling resetNative(), satisfying the Closeable contract. This commit was authored with the help of Claude. --- .../src/main/java/org/pytorch/executorch/Module.java | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index 05e1e5b88cf..6cf99966e6a 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -12,6 +12,7 @@ import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; +import java.io.Closeable; import java.util.HashMap; import java.util.Map; import java.util.concurrent.locks.Lock; @@ -24,7 +25,7 @@ *

Warning: These APIs are experimental and subject to change without notice */ @Experimental -public class Module { +public class Module implements Closeable { static { if (!NativeLoader.isInitialized()) { @@ -274,7 +275,9 @@ public boolean etdump() { public void destroy() { if (mLock.tryLock()) { try { - mHybridData.resetNative(); + if (mHybridData.isValid()) { + mHybridData.resetNative(); + } } finally { mLock.unlock(); } @@ -282,4 +285,9 @@ public void destroy() { throw new IllegalStateException("Cannot destroy module while method is executing"); } } + + @Override + public void close() { + destroy(); + } } From 7e2ff8ae55147f3949539d09b1f85f9d8ae5b0d1 Mon Sep 17 00:00:00 2001 From: Siddartha Pothapragada Date: Fri, 24 Apr 2026 20:43:52 -0700 Subject: [PATCH 08/21] Android: consistent error types across all modules (#19099) TrainingModule: implement Closeable, replace Log.e + silent empty returns with IllegalStateException throws. Add checkNotDestroyed() guard on all public methods. SGD: throw IllegalStateException instead of bare RuntimeException when optimizer is destroyed. AsrModule: throw ExecutorchRuntimeException instead of bare RuntimeException on transcription failure. ExecuTorchRuntime.validateFilePath: throw IllegalArgumentException instead of bare RuntimeException, with descriptive message. JNI constructors: wrap ExecuTorchJni and ExecuTorchLlmJni constructor bodies in try-catch so C++ exceptions become ExecutorchRuntimeException instead of generic RuntimeException. This commit was authored with the help of Claude. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../pytorch/executorch/ExecuTorchRuntime.java | 16 +- .../executorch/extension/asr/AsrModule.kt | 10 +- .../org/pytorch/executorch/training/SGD.java | 2 +- .../executorch/training/TrainingModule.java | 32 +-- extension/android/jni/jni_layer.cpp | 14 +- extension/android/jni/jni_layer_llama.cpp | 190 ++++++++++-------- 6 files changed, 152 insertions(+), 112 deletions(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java index 30ebf1a2c1d..53ee4d3f33a 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java @@ -36,12 +36,22 @@ public static ExecuTorchRuntime getRuntime() { /** * Validates that the given path points to a readable file. * - * @throws RuntimeException if the file does not exist or is not readable. + * @throws IllegalArgumentException if the path is null, does not exist, is not a file, or is not + * readable. */ public static void validateFilePath(String path, String description) { + if (path == null) { + throw new IllegalArgumentException("Cannot load " + description + ": path is null"); + } File file = new File(path); - if (!file.canRead() || !file.isFile()) { - throw new RuntimeException("Cannot load " + description + " " + path); + if (!file.exists()) { + throw new IllegalArgumentException("Cannot load " + description + "!! " + path); + } + if (!file.isFile()) { + throw new IllegalArgumentException("Cannot load " + description + "!! " + path); + } + if (!file.canRead()) { + throw new IllegalArgumentException("Cannot load " + description + "!! " + path); } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt index 987cb3ec3be..ab9099ba405 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt @@ -11,6 +11,7 @@ package org.pytorch.executorch.extension.asr import java.io.Closeable import java.io.File import java.util.concurrent.atomic.AtomicLong +import org.pytorch.executorch.ExecutorchRuntimeException import org.pytorch.executorch.annotations.Experimental /** @@ -53,7 +54,10 @@ class AsrModule( val handle = nativeCreate(modelPath, tokenizerPath, dataPath, preprocessorPath) if (handle == 0L) { - throw RuntimeException("Failed to create native AsrModule") + throw ExecutorchRuntimeException( + ExecutorchRuntimeException.INTERNAL, + "Failed to create native AsrModule", + ) } nativeHandle.set(handle) } @@ -129,7 +133,7 @@ class AsrModule( * @param callback Optional callback to receive tokens as they are generated (can be null) * @return The complete transcribed text * @throws IllegalStateException if the module has been destroyed - * @throws RuntimeException if transcription fails (non-zero result code) + * @throws ExecutorchRuntimeException if transcription fails (error code carried in exception) */ @JvmOverloads fun transcribe( @@ -160,7 +164,7 @@ class AsrModule( ) if (status != 0) { - throw RuntimeException("Transcription failed with error code: $status") + throw ExecutorchRuntimeException(status, "Transcription failed") } return result.toString() diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java index 8f4292c1bc8..58c7704b83e 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java @@ -93,7 +93,7 @@ public static SGD create(Map namedParameters, double learningRat */ public void step(Map namedGradients) { if (!mHybridData.isValid()) { - throw new RuntimeException("Attempt to use a destroyed SGD optimizer"); + throw new IllegalStateException("SGD optimizer has been destroyed"); } stepNative(namedGradients); } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java index 4a6653cb7a1..ca4bac9aa54 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java @@ -8,12 +8,11 @@ package org.pytorch.executorch.training; -import android.util.Log; import com.facebook.jni.HybridData; import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; -import java.util.HashMap; +import java.io.Closeable; import java.util.Map; import org.pytorch.executorch.EValue; import org.pytorch.executorch.ExecuTorchRuntime; @@ -26,7 +25,7 @@ *

Warning: These APIs are experimental and subject to change without notice */ @Experimental -public class TrainingModule { +public class TrainingModule implements Closeable { static { if (!NativeLoader.isInitialized()) { @@ -37,6 +36,7 @@ public class TrainingModule { } private final HybridData mHybridData; + private boolean mDestroyed = false; @DoNotStrip private static native HybridData initHybrid(String moduleAbsolutePath, String dataAbsolutePath); @@ -45,6 +45,10 @@ private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) { mHybridData = initHybrid(moduleAbsolutePath, dataAbsolutePath); } + private void checkNotDestroyed() { + if (mDestroyed) throw new IllegalStateException("TrainingModule has been destroyed"); + } + /** * Loads a serialized ExecuTorch Training Module from the specified path on the disk. * @@ -78,10 +82,7 @@ public static TrainingModule load(final String modelPath) { * @return return value(s) from the method. */ public EValue[] executeForwardBackward(String methodName, EValue... inputs) { - if (!mHybridData.isValid()) { - Log.e("ExecuTorch", "Attempt to use a destroyed module"); - return new EValue[0]; - } + checkNotDestroyed(); return executeForwardBackwardNative(methodName, inputs); } @@ -89,10 +90,7 @@ public EValue[] executeForwardBackward(String methodName, EValue... inputs) { private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs); public Map namedParameters(String methodName) { - if (!mHybridData.isValid()) { - Log.e("ExecuTorch", "Attempt to use a destroyed module"); - return new HashMap(); - } + checkNotDestroyed(); return namedParametersNative(methodName); } @@ -100,13 +98,17 @@ public Map namedParameters(String methodName) { private native Map namedParametersNative(String methodName); public Map namedGradients(String methodName) { - if (!mHybridData.isValid()) { - Log.e("ExecuTorch", "Attempt to use a destroyed module"); - return new HashMap(); - } + checkNotDestroyed(); return namedGradientsNative(methodName); } @DoNotStrip private native Map namedGradientsNative(String methodName); + + @Override + public void close() { + if (mDestroyed) return; + mDestroyed = true; + mHybridData.resetNative(); + } } diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 88e9f9e2a12..0cf08e41983 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -284,8 +284,18 @@ class ExecuTorchJni : public facebook::jni::HybridClass { #else auto etdump_gen = nullptr; #endif - module_ = std::make_unique( - modelPath->toStdString(), load_mode, std::move(etdump_gen)); + try { + module_ = std::make_unique( + modelPath->toStdString(), load_mode, std::move(etdump_gen)); + } catch (const std::exception& e) { + executorch::jni_helper::throwExecutorchException( + static_cast(Error::Internal), + std::string("Failed to create Module: ") + e.what()); + } catch (...) { + executorch::jni_helper::throwExecutorchException( + static_cast(Error::Internal), + "Failed to create Module: unknown native error"); + } #ifdef ET_USE_THREADPOOL // Default to using cores/2 threadpool threads. The long-term plan is to diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 94c0efff335..0c1ff5c67b9 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -149,103 +149,117 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { jint num_bos = 0, jint num_eos = 0, jint load_mode = 1) { - temperature_ = temperature; - num_bos_ = num_bos; - num_eos_ = num_eos; + try { + temperature_ = temperature; + num_bos_ = num_bos; + num_eos_ = num_eos; #if defined(ET_USE_THREADPOOL) - // Reserve 1 thread for the main thread. - int32_t num_performant_cores = - ::executorch::extension::cpuinfo::get_num_performant_cores() - 1; - if (num_performant_cores > 0) { - ET_LOG(Info, "Resetting threadpool to %d threads", num_performant_cores); - ::executorch::extension::threadpool::get_threadpool() - ->_unsafe_reset_threadpool(num_performant_cores); - } + // Reserve 1 thread for the main thread. + int32_t num_performant_cores = + ::executorch::extension::cpuinfo::get_num_performant_cores() - 1; + if (num_performant_cores > 0) { + ET_LOG( + Info, "Resetting threadpool to %d threads", num_performant_cores); + ::executorch::extension::threadpool::get_threadpool() + ->_unsafe_reset_threadpool(num_performant_cores); + } #endif - model_type_category_ = model_type_category; - auto cpp_load_mode = load_mode_from_int(load_mode); - std::vector data_files_vector; - if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) { - runner_ = llm::create_multimodal_runner( - model_path->toStdString().c_str(), - llm::load_tokenizer(tokenizer_path->toStdString()), - std::nullopt, - cpp_load_mode); - } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) { - if (data_files != nullptr) { - // Convert Java List to C++ std::vector - auto list_class = facebook::jni::findClassStatic("java/util/List"); - auto size_method = list_class->getMethod("size"); - auto get_method = - list_class->getMethod(jint)>( - "get"); - - jint size = size_method(data_files); - for (jint i = 0; i < size; ++i) { - auto str_obj = get_method(data_files, i); - auto jstr = facebook::jni::static_ref_cast(str_obj); - data_files_vector.push_back(jstr->toStdString()); + model_type_category_ = model_type_category; + auto cpp_load_mode = load_mode_from_int(load_mode); + std::vector data_files_vector; + if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) { + runner_ = llm::create_multimodal_runner( + model_path->toStdString().c_str(), + llm::load_tokenizer(tokenizer_path->toStdString()), + std::nullopt, + cpp_load_mode); + } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) { + if (data_files != nullptr) { + // Convert Java List to C++ std::vector + auto list_class = facebook::jni::findClassStatic("java/util/List"); + auto size_method = list_class->getMethod("size"); + auto get_method = + list_class->getMethod(jint)>( + "get"); + + jint size = size_method(data_files); + for (jint i = 0; i < size; ++i) { + auto str_obj = get_method(data_files, i); + auto jstr = facebook::jni::static_ref_cast(str_obj); + data_files_vector.push_back(jstr->toStdString()); + } } - } - runner_ = executorch::extension::llm::create_text_llm_runner( - model_path->toStdString(), - llm::load_tokenizer(tokenizer_path->toStdString()), - data_files_vector, - /*temperature=*/-1.0f, - /*event_tracer=*/nullptr, - /*method_name=*/"forward", - cpp_load_mode); + runner_ = executorch::extension::llm::create_text_llm_runner( + model_path->toStdString(), + llm::load_tokenizer(tokenizer_path->toStdString()), + data_files_vector, + /*temperature=*/-1.0f, + /*event_tracer=*/nullptr, + /*method_name=*/"forward", + cpp_load_mode); #if defined(EXECUTORCH_BUILD_QNN) - } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) { - std::unique_ptr module = - std::make_unique( - model_path->toStdString().c_str(), - data_files_vector, - cpp_load_mode); - std::string decoder_model = "llama3"; // use llama3 for now - // Using 8bit as default since this meta is introduced with 16bit kv io - // support and older models only have 8bit kv io. - example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; - if (module->method_names()->count("get_kv_io_bit_width") > 0) { - kv_bitwidth = static_cast( - module->get("get_kv_io_bit_width").get().toScalar().to()); - } + } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) { + std::unique_ptr module = + std::make_unique( + model_path->toStdString().c_str(), + data_files_vector, + cpp_load_mode); + std::string decoder_model = "llama3"; // use llama3 for now + // Using 8bit as default since this meta is introduced with 16bit kv io + // support and older models only have 8bit kv io. + example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; + if (module->method_names()->count("get_kv_io_bit_width") > 0) { + kv_bitwidth = static_cast( + module->get("get_kv_io_bit_width") + .get() + .toScalar() + .to()); + } - if (kv_bitwidth == example::KvBitWidth::kWidth8) { - runner_ = std::make_unique>( - std::move(module), - decoder_model.c_str(), - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - "", - "", - temperature_); - } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { - runner_ = std::make_unique>( - std::move(module), - decoder_model.c_str(), - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - "", - "", - temperature_); - } else { - ET_CHECK_MSG( - false, - "Unsupported kv bitwidth: %ld", - static_cast(kv_bitwidth)); - } - model_type_category_ = MODEL_TYPE_CATEGORY_LLM; + if (kv_bitwidth == example::KvBitWidth::kWidth8) { + runner_ = std::make_unique>( + std::move(module), + decoder_model.c_str(), + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), + "", + "", + temperature_); + } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { + runner_ = std::make_unique>( + std::move(module), + decoder_model.c_str(), + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), + "", + "", + temperature_); + } else { + ET_CHECK_MSG( + false, + "Unsupported kv bitwidth: %ld", + static_cast(kv_bitwidth)); + } + model_type_category_ = MODEL_TYPE_CATEGORY_LLM; #endif #if defined(EXECUTORCH_BUILD_MEDIATEK) - } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) { - runner_ = std::make_unique( - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str()); - // Interpret the model type as LLM - model_type_category_ = MODEL_TYPE_CATEGORY_LLM; + } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) { + runner_ = std::make_unique( + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str()); + // Interpret the model type as LLM + model_type_category_ = MODEL_TYPE_CATEGORY_LLM; #endif + } + } catch (const std::exception& e) { + executorch::jni_helper::throwExecutorchException( + static_cast(Error::Internal), + std::string("Failed to create LlmModule: ") + e.what()); + } catch (...) { + executorch::jni_helper::throwExecutorchException( + static_cast(Error::Internal), + "Failed to create LlmModule: unknown native error"); } } From c1d482e2362f60a2c2e467f32c8da70aa594a1bc Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan <33402477+DrJessop@users.noreply.github.com> Date: Fri, 24 Apr 2026 21:40:50 -0700 Subject: [PATCH 09/21] Merge back to back slices on the same dim Differential Revision: D102425537 Pull Request resolved: https://github.com/pytorch/executorch/pull/19128 --- backends/cadence/aot/fuse_ops.py | 71 +++++ .../aot/tests/test_fusion_ops_passes.py | 255 +++++++++++++++++- 2 files changed, 325 insertions(+), 1 deletion(-) diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py index d6ee88e94c6..97e396fba51 100644 --- a/backends/cadence/aot/fuse_ops.py +++ b/backends/cadence/aot/fuse_ops.py @@ -31,6 +31,7 @@ HierarchicalInplacePassInterface, register_cadence_pass, RemoveOrReplacePassInterface, + set_arg, ) from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import ( @@ -1003,6 +1004,75 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return True +@register_cadence_pass(CadencePassAttribute(opt_level=0)) +class FuseSliceSameDimPass(RemoveOrReplacePassInterface): + """Fuse chained slices on the same dim into a single slice. + + When a slice_copy's input is another slice_copy on the same dimension + with step=1, the child slice can read directly from the grandparent + with merged indices, eliminating the intermediate slice. + + Handles negative start/end indices by canonicalizing them against the + relevant dimension size before merging. + """ + + @staticmethod + def _canonicalize(val: int, dim_size: int) -> int: + return val + dim_size if val < 0 else val + + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.slice_copy.Tensor] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + parent = get_arg(node, "input", torch.fx.Node) + if parent.target != exir_ops.edge.aten.slice_copy.Tensor: + return False + + grandparent = get_arg(parent, "input", torch.fx.Node) + ndim = len(grandparent.meta["val"].shape) + child_dim = get_arg(node, "dim", int) % ndim + parent_dim = get_arg(parent, "dim", int) % ndim + if child_dim != parent_dim: + return False + + child_start = get_arg(node, "start", Optional[int]) + child_end = get_arg(node, "end", Optional[int]) + child_step = get_arg(node, "step", int) + parent_start = get_arg(parent, "start", Optional[int]) + parent_end = get_arg(parent, "end", Optional[int]) + parent_step = get_arg(parent, "step", int) + + if child_step != 1 or parent_step != 1: + return False + if ( + child_start is None + or child_end is None + or parent_start is None + or parent_end is None + ): + return False + + grandparent_dim_size = grandparent.meta["val"].shape[parent_dim] + parent_dim_size = parent.meta["val"].shape[parent_dim] + + p_start = self._canonicalize(parent_start, grandparent_dim_size) + p_end = self._canonicalize(parent_end, grandparent_dim_size) + c_start = self._canonicalize(child_start, parent_dim_size) + c_end = self._canonicalize(child_end, parent_dim_size) + + new_start = p_start + c_start + new_end = min(p_start + c_end, p_end) + + if new_end > grandparent_dim_size: + return False + + node.replace_input_with(parent, grandparent) + set_arg(node, "start", new_start) + set_arg(node, "end", new_end) + return True + + class HierarchicalCSEPass(HierarchicalInplacePassInterface): """ A hierarchical Common Subexpression Elimination (CSE) pass that recursively @@ -1035,4 +1105,5 @@ class CadenceFuseOpsInGraph: FuseMulScalarIntoDequantPass, FuseFullThenReshapePass, FuseTransposeOrPermuteOpPairsPass, + FuseSliceSameDimPass, ] diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py index f5afbe243f8..57145404726 100644 --- a/backends/cadence/aot/tests/test_fusion_ops_passes.py +++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py @@ -25,10 +25,15 @@ FuseMulTensorIntoQuantPass, FuseQuantDequantToRequantizePass, FuseQuantizedBatchNormWithConv, + FuseSliceSameDimPass, FuseTransposeOrPermuteOpPairsPass, HierarchicalCSEPass, ) -from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match +from executorch.backends.cadence.aot.pass_utils import ( + count_node, + get_arg, + op_counts_match, +) from executorch.backends.cadence.aot.typing_stubs import expand from executorch.backends.test.graph_builder import GraphBuilder from executorch.exir.dialects._ops import ops as exir_ops @@ -1696,3 +1701,251 @@ def __init__(self) -> None: # Verify fusion occurred: bn should be removed, conv remains self.assertEqual(count_node(gm, conv_op), 1) self.assertEqual(count_node(gm, bn_op), 0) + + +class TestFuseSliceSameDimPass(TestFusionPassesBase): + def _get_single_slice(self, gm: torch.fx.GraphModule) -> torch.fx.Node: + slices = gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor + ) + self.assertEqual(len(slices), 1) + return slices[0] + + def test_basic_chain_bypass(self) -> None: + """slice(dim=3, 0:78) → slice(dim=3, 0:60) → direct slice(dim=3, 0:60).""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 80)) + parent = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(x, 3, 0, 78, 1), + ) + child = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(parent, 3, 0, 60, 1), + ) + builder.output([child]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + result = cast(PassResult, FuseSliceSameDimPass()(original)) + self.assertTrue(result.modified) + self.assertEqual( + count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1 + ) + merged = self._get_single_slice(result.graph_module) + self.assertEqual(get_arg(merged, "start"), 0) + self.assertEqual(get_arg(merged, "end"), 60) + validate_numerics( + gm_before, + result.graph_module, + (torch.randn(2, 3, 4, 80),), + "FuseSliceSameDimPass", + ) + + def test_chain_with_offset(self) -> None: + """slice(dim=1, 10:50) → slice(dim=1, 5:20) → direct slice(dim=1, 15:30).""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(4, 64)) + parent = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(x, 1, 10, 50, 1), + ) + child = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(parent, 1, 5, 20, 1), + ) + builder.output([child]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + result = cast(PassResult, FuseSliceSameDimPass()(original)) + self.assertTrue(result.modified) + self.assertEqual( + count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1 + ) + merged = self._get_single_slice(result.graph_module) + self.assertEqual(get_arg(merged, "start"), 15) + self.assertEqual(get_arg(merged, "end"), 30) + validate_numerics( + gm_before, + result.graph_module, + (torch.randn(4, 64),), + "FuseSliceSameDimPass", + ) + + def test_parent_kept_with_other_users(self) -> None: + """Parent slice has another user besides the child → parent stays.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 80)) + parent = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(x, 3, 0, 78, 1), + ) + child = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(parent, 3, 0, 60, 1), + ) + neg = builder.call_operator(op=exir_ops.edge.aten.neg.default, args=(parent,)) + builder.output([child, neg]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + result = cast(PassResult, FuseSliceSameDimPass()(original)) + self.assertTrue(result.modified) + self.assertEqual( + count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 2 + ) + slices = result.graph_module.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor + ) + ends = sorted(get_arg(s, "end") for s in slices) + self.assertEqual(ends, [60, 78]) + validate_numerics( + gm_before, + result.graph_module, + (torch.randn(2, 3, 4, 80),), + "FuseSliceSameDimPass", + ) + + def test_different_dims_no_change(self) -> None: + """Chained slices on different dims → no change.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(8, 16, 32)) + parent = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(x, 1, 0, 10, 1), + ) + child = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(parent, 2, 0, 5, 1), + ) + builder.output([child]) + original = builder.get_graph_module() + + result = cast(PassResult, FuseSliceSameDimPass()(original)) + self.assertFalse(result.modified) + + def test_step_not_one_no_change(self) -> None: + """Parent has step != 1 → no change.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(4, 64)) + parent = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(x, 1, 0, 60, 2), + ) + child = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(parent, 1, 0, 10, 1), + ) + builder.output([child]) + original = builder.get_graph_module() + + result = cast(PassResult, FuseSliceSameDimPass()(original)) + self.assertFalse(result.modified) + + def test_no_chain_no_change(self) -> None: + """Single slice with no slice user → no change.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(4, 64)) + sliced = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(x, 1, 0, 32, 1), + ) + builder.output([sliced]) + original = builder.get_graph_module() + + result = cast(PassResult, FuseSliceSameDimPass()(original)) + self.assertFalse(result.modified) + + def test_child_end_clamped_to_parent_range(self) -> None: + """Child end exceeds parent output size → clamped to parent_end.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(1, 100)) + parent = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(x, 1, 10, 50, 1), + ) + child = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(parent, 1, 5, 45, 1), + ) + builder.output([child]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + result = cast(PassResult, FuseSliceSameDimPass()(original)) + self.assertTrue(result.modified) + self.assertEqual( + count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1 + ) + merged = self._get_single_slice(result.graph_module) + self.assertEqual(get_arg(merged, "start"), 15) + self.assertEqual(get_arg(merged, "end"), 50) + validate_numerics( + gm_before, + result.graph_module, + (torch.randn(1, 100),), + "FuseSliceSameDimPass", + ) + + def test_negative_indices(self) -> None: + """Negative start/end are canonicalized before merging.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(1, 100)) + parent = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(x, 1, 10, -10, 1), + ) + child = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(parent, 1, 5, -5, 1), + ) + builder.output([child]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + result = cast(PassResult, FuseSliceSameDimPass()(original)) + self.assertTrue(result.modified) + self.assertEqual( + count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1 + ) + merged = self._get_single_slice(result.graph_module) + self.assertEqual(get_arg(merged, "start"), 15) + self.assertEqual(get_arg(merged, "end"), 85) + validate_numerics( + gm_before, + result.graph_module, + (torch.randn(1, 100),), + "FuseSliceSameDimPass", + ) + + def test_negative_dim(self) -> None: + """Negative dim is canonicalized so matching works across conventions.""" + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + parent = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(x, -1, 0, 4, 1), + ) + child = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(parent, 3, 0, 2, 1), + ) + builder.output([child]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + result = cast(PassResult, FuseSliceSameDimPass()(original)) + self.assertTrue(result.modified) + self.assertEqual( + count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1 + ) + merged = self._get_single_slice(result.graph_module) + self.assertEqual(get_arg(merged, "start"), 0) + self.assertEqual(get_arg(merged, "end"), 2) + validate_numerics( + gm_before, + result.graph_module, + (torch.randn(2, 3, 4, 5),), + "FuseSliceSameDimPass", + ) From 52527040f5833fbc9077d5de45ab05135d443461 Mon Sep 17 00:00:00 2001 From: Hardik Sharma Date: Sat, 25 Apr 2026 12:50:24 -0700 Subject: [PATCH 10/21] Add C++ unit tests for cadence::quantized_conv2d_nhwc + add depthwise_nhwc operator + tests Differential Revision: D96507563 Pull Request resolved: https://github.com/pytorch/executorch/pull/18479 --- backends/cadence/aot/BUCK | 2 + backends/cadence/aot/functions.yaml | 5 ++ backends/cadence/aot/ops_registrations.py | 49 +++++++++++++++++++ backends/cadence/aot/ref_implementations.py | 35 +++++++++++++ .../generic/operators/op_quantized_conv2d.cpp | 34 +++++++++++++ .../generic/operators/op_quantized_conv2d.h | 18 +++++++ 6 files changed, 143 insertions(+) diff --git a/backends/cadence/aot/BUCK b/backends/cadence/aot/BUCK index 5b5316245f8..ae884e29deb 100644 --- a/backends/cadence/aot/BUCK +++ b/backends/cadence/aot/BUCK @@ -156,6 +156,8 @@ fbcode_target(_kind = executorch_generated_lib, "//executorch/backends/cadence/generic/operators:op_quantized_conv2d", "//executorch/backends/cadence/generic/operators:op_quantized_conv1d_ncl", "//executorch/backends/cadence/generic/operators:op_quantized_conv1d_nlc", + "//executorch/backends/cadence/generic/operators:op_quantized_depthwise_conv1d_ncl", + "//executorch/backends/cadence/generic/operators:op_quantized_depthwise_conv1d_nlc", "//executorch/backends/cadence/generic/operators:op_quantized_fully_connected", "//executorch/backends/cadence/generic/operators:op_quantized_layer_norm", "//executorch/backends/cadence/generic/operators:op_quantized_linear", diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index 2eed2f4c486..60fda2853a3 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -389,6 +389,11 @@ - arg_meta: null kernel_name: impl::generic::quantized_conv2d_nhwc_per_tensor_out +- func: cadence::quantized_conv2d_depthwise_nhwc.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::generic::quantized_conv2d_depthwise_nhwc_out + - func: cadence::quantized_conv1d_ncl.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 0effaf3e029..131c85c9ab1 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -238,6 +238,12 @@ def register_fake( lib.define( "quantized_conv2d_nhwc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, Tensor? offset=None, *, Tensor(a!) out) -> Tensor(a!)" ) +lib.define( + "quantized_conv2d_depthwise_nhwc(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" +) +lib.define( + "quantized_conv2d_depthwise_nhwc.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" +) lib.define( "quantized_conv1d_ncl(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Z)" ) @@ -2105,6 +2111,49 @@ def quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_meta( return input.new_empty(output_size, dtype=input.dtype) +@register_fake("cadence::quantized_conv2d_depthwise_nhwc") +def quantized_conv2d_depthwise_nhwc_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, +) -> torch.Tensor: + in_size = input.shape + assert len(in_size) > 2 + assert len(in_size) < 6 + # Depthwise weight is always [*kernel_size, OC]: + # 2D: [KH, KW, OC], 1D: [K, OC] + *kernel_size, out_channels = weight.shape + + output_size = ( + get_conv1d_output_size( + in_size, + out_channels, + stride[-1], + padding[-1], + dilation[-1], + kernel_size[0], + True, + ) + if len(in_size) == 3 + else get_conv2d_output_size( + in_size, out_channels, stride, padding, dilation, kernel_size, True + ) + ) + + return input.new_empty(output_size, dtype=input.dtype) + + @register_fake("cadence::quantized_layer_norm") def quantized_layer_norm_meta( input: torch.Tensor, diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 32558166fbf..6c780782070 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -1556,6 +1556,41 @@ def quantized_conv2d_nhwc( ) +@impl_tracked(m, "quantized_conv2d_depthwise_nhwc") +def quantized_conv2d_depthwise_nhwc( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int, int], + padding: tuple[int, int], + dilation: tuple[int, int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, +) -> torch.Tensor: + return quantized_conv2d_nhwc_per_tensor( + input_tensor, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + ) + + def quantized_conv_variant( layout: str, input_dtype: torch.dtype, diff --git a/backends/cadence/generic/operators/op_quantized_conv2d.cpp b/backends/cadence/generic/operators/op_quantized_conv2d.cpp index 8cf24015893..0811267a3b8 100644 --- a/backends/cadence/generic/operators/op_quantized_conv2d.cpp +++ b/backends/cadence/generic/operators/op_quantized_conv2d.cpp @@ -955,6 +955,40 @@ Tensor& quantized_conv2d_nhwc_per_tensor_out( return out; } +Tensor& quantized_conv2d_depthwise_nhwc_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + static_cast(groups), + static_cast(in_zero_point), + static_cast(weight_zero_point), + static_cast(bias_scale), + static_cast(output_scale), + static_cast(output_zero_point), + out); + return out; +} + Tensor& quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_out( ET_UNUSED KernelRuntimeContext& ctx, const Tensor& input, diff --git a/backends/cadence/generic/operators/op_quantized_conv2d.h b/backends/cadence/generic/operators/op_quantized_conv2d.h index 00cf62eba70..bb9476e2644 100644 --- a/backends/cadence/generic/operators/op_quantized_conv2d.h +++ b/backends/cadence/generic/operators/op_quantized_conv2d.h @@ -208,6 +208,24 @@ ::executorch::aten::Tensor& quantized_conv2d_nhwc_per_tensor_out( const ::executorch::aten::optional& offset, Tensor& out); +::executorch::aten::Tensor& quantized_conv2d_depthwise_nhwc_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + ::executorch::aten::Tensor& quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_out( KernelRuntimeContext& ctx, From 6b175ff5940f58a07fed726e7014ac29522a999c Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Sat, 25 Apr 2026 17:29:05 -0700 Subject: [PATCH 11/21] Revert Android PRs #19099, #19124, #19092, #19028 (#19133) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Reverts the following Android PRs: - #19099 — Android: consistent error types across all modules - #19124 — Android: Module implements Closeable - #19092 — Android: improve error diagnostics for LlmModule and exceptions - #19028 — Ignored Module tests: provide required input tensor Authored with Claude. --- .../executorch/ModuleInstrumentationTest.kt | 70 +++--- .../pytorch/executorch/ExecuTorchRuntime.java | 16 +- .../ExecutorchRuntimeException.java | 5 - .../java/org/pytorch/executorch/Module.java | 12 +- .../executorch/extension/asr/AsrModule.kt | 10 +- .../org/pytorch/executorch/training/SGD.java | 2 +- .../executorch/training/TrainingModule.java | 32 ++- extension/android/jni/jni_layer.cpp | 14 +- extension/android/jni/jni_layer_llama.cpp | 211 ++++++++---------- 9 files changed, 158 insertions(+), 214 deletions(-) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index eb2b6f096a1..ba91f444287 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -17,6 +17,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.commons.io.FileUtils import org.junit.Assert import org.junit.Before +import org.junit.Ignore import org.junit.Test import org.junit.runner.RunWith import org.pytorch.executorch.TestFileUtils.getTestFilePath @@ -39,49 +40,48 @@ class ModuleInstrumentationTest { inputStream.close() } + @Ignore( + "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " + ) @Test @Throws(IOException::class, URISyntaxException::class) fun testModuleLoadAndForward() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - try { - val results = module.forward(EValue.from(dummyInput())) - Assert.assertTrue(results[0].isTensor) - } finally { - module.destroy() - } + + val results = module.forward() + Assert.assertTrue(results[0].isTensor) } @Test @Throws(IOException::class, URISyntaxException::class) fun testMethodMetadata() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - module.destroy() } + @Ignore( + "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " + ) @Test @Throws(IOException::class) fun testModuleLoadMethodAndForward() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - try { - module.loadMethod(FORWARD_METHOD) - val results = module.forward(EValue.from(dummyInput())) - Assert.assertTrue(results[0].isTensor) - } finally { - module.destroy() - } + module.loadMethod(FORWARD_METHOD) + + val results = module.forward() + Assert.assertTrue(results[0].isTensor) } + @Ignore( + "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " + ) @Test @Throws(IOException::class) fun testModuleLoadForwardExplicit() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - try { - val results = module.execute(FORWARD_METHOD, EValue.from(dummyInput())) - Assert.assertTrue(results[0].isTensor) - } finally { - module.destroy() - } + + val results = module.execute(FORWARD_METHOD) + Assert.assertTrue(results[0].isTensor) } @Test(expected = RuntimeException::class) @@ -94,18 +94,15 @@ class ModuleInstrumentationTest { @Throws(IOException::class) fun testModuleLoadMethodNonExistantMethod() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - try { - val exception = - Assert.assertThrows(ExecutorchRuntimeException::class.java) { - module.loadMethod(NONE_METHOD) - } - Assert.assertEquals( - ExecutorchRuntimeException.INVALID_ARGUMENT, - exception.getErrorCode(), - ) - } finally { - module.destroy() - } + + val exception = + Assert.assertThrows(ExecutorchRuntimeException::class.java) { + module.loadMethod(NONE_METHOD) + } + Assert.assertEquals( + ExecutorchRuntimeException.INVALID_ARGUMENT, + exception.getErrorCode(), + ) } @Test(expected = RuntimeException::class) @@ -138,6 +135,9 @@ class ModuleInstrumentationTest { Assert.assertThrows(IllegalStateException::class.java) { module.forward() } } + @Ignore( + "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " + ) @Test @Throws(InterruptedException::class, IOException::class) fun testForwardFromMultipleThreads() { @@ -151,7 +151,7 @@ class ModuleInstrumentationTest { try { latch.countDown() latch.await(5000, TimeUnit.MILLISECONDS) - val results = module.forward(EValue.from(dummyInput())) + val results = module.forward() Assert.assertTrue(results[0].isTensor) completed.incrementAndGet() } catch (_: InterruptedException) {} @@ -168,7 +168,6 @@ class ModuleInstrumentationTest { } Assert.assertEquals(numThreads.toLong(), completed.get().toLong()) - module.destroy() } companion object { @@ -177,8 +176,5 @@ class ModuleInstrumentationTest { private const val NON_PTE_FILE_NAME = "/test.txt" private const val FORWARD_METHOD = "forward" private const val NONE_METHOD = "none" - private val inputShape = longArrayOf(1, 3, 224, 224) - - private fun dummyInput(): Tensor = Tensor.ones(inputShape, DType.FLOAT) } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java index 53ee4d3f33a..30ebf1a2c1d 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java @@ -36,22 +36,12 @@ public static ExecuTorchRuntime getRuntime() { /** * Validates that the given path points to a readable file. * - * @throws IllegalArgumentException if the path is null, does not exist, is not a file, or is not - * readable. + * @throws RuntimeException if the file does not exist or is not readable. */ public static void validateFilePath(String path, String description) { - if (path == null) { - throw new IllegalArgumentException("Cannot load " + description + ": path is null"); - } File file = new File(path); - if (!file.exists()) { - throw new IllegalArgumentException("Cannot load " + description + "!! " + path); - } - if (!file.isFile()) { - throw new IllegalArgumentException("Cannot load " + description + "!! " + path); - } - if (!file.canRead()) { - throw new IllegalArgumentException("Cannot load " + description + "!! " + path); + if (!file.canRead() || !file.isFile()) { + throw new RuntimeException("Cannot load " + description + " " + path); } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java index e72ed9e3d28..e0fda73cc06 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java @@ -161,11 +161,6 @@ public ExecutorchRuntimeException(int errorCode, String details) { this.errorCode = errorCode; } - public ExecutorchRuntimeException(int errorCode, String details, Throwable cause) { - super(ErrorHelper.formatMessage(errorCode, details), cause); - this.errorCode = errorCode; - } - /** Returns the numeric error code from {@code runtime/core/error.h}. */ public int getErrorCode() { return errorCode; diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index 6cf99966e6a..05e1e5b88cf 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -12,7 +12,6 @@ import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; -import java.io.Closeable; import java.util.HashMap; import java.util.Map; import java.util.concurrent.locks.Lock; @@ -25,7 +24,7 @@ *

Warning: These APIs are experimental and subject to change without notice */ @Experimental -public class Module implements Closeable { +public class Module { static { if (!NativeLoader.isInitialized()) { @@ -275,9 +274,7 @@ public boolean etdump() { public void destroy() { if (mLock.tryLock()) { try { - if (mHybridData.isValid()) { - mHybridData.resetNative(); - } + mHybridData.resetNative(); } finally { mLock.unlock(); } @@ -285,9 +282,4 @@ public void destroy() { throw new IllegalStateException("Cannot destroy module while method is executing"); } } - - @Override - public void close() { - destroy(); - } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt index ab9099ba405..987cb3ec3be 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt @@ -11,7 +11,6 @@ package org.pytorch.executorch.extension.asr import java.io.Closeable import java.io.File import java.util.concurrent.atomic.AtomicLong -import org.pytorch.executorch.ExecutorchRuntimeException import org.pytorch.executorch.annotations.Experimental /** @@ -54,10 +53,7 @@ class AsrModule( val handle = nativeCreate(modelPath, tokenizerPath, dataPath, preprocessorPath) if (handle == 0L) { - throw ExecutorchRuntimeException( - ExecutorchRuntimeException.INTERNAL, - "Failed to create native AsrModule", - ) + throw RuntimeException("Failed to create native AsrModule") } nativeHandle.set(handle) } @@ -133,7 +129,7 @@ class AsrModule( * @param callback Optional callback to receive tokens as they are generated (can be null) * @return The complete transcribed text * @throws IllegalStateException if the module has been destroyed - * @throws ExecutorchRuntimeException if transcription fails (error code carried in exception) + * @throws RuntimeException if transcription fails (non-zero result code) */ @JvmOverloads fun transcribe( @@ -164,7 +160,7 @@ class AsrModule( ) if (status != 0) { - throw ExecutorchRuntimeException(status, "Transcription failed") + throw RuntimeException("Transcription failed with error code: $status") } return result.toString() diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java index 58c7704b83e..8f4292c1bc8 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java @@ -93,7 +93,7 @@ public static SGD create(Map namedParameters, double learningRat */ public void step(Map namedGradients) { if (!mHybridData.isValid()) { - throw new IllegalStateException("SGD optimizer has been destroyed"); + throw new RuntimeException("Attempt to use a destroyed SGD optimizer"); } stepNative(namedGradients); } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java index ca4bac9aa54..4a6653cb7a1 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java @@ -8,11 +8,12 @@ package org.pytorch.executorch.training; +import android.util.Log; import com.facebook.jni.HybridData; import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; -import java.io.Closeable; +import java.util.HashMap; import java.util.Map; import org.pytorch.executorch.EValue; import org.pytorch.executorch.ExecuTorchRuntime; @@ -25,7 +26,7 @@ *

Warning: These APIs are experimental and subject to change without notice */ @Experimental -public class TrainingModule implements Closeable { +public class TrainingModule { static { if (!NativeLoader.isInitialized()) { @@ -36,7 +37,6 @@ public class TrainingModule implements Closeable { } private final HybridData mHybridData; - private boolean mDestroyed = false; @DoNotStrip private static native HybridData initHybrid(String moduleAbsolutePath, String dataAbsolutePath); @@ -45,10 +45,6 @@ private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) { mHybridData = initHybrid(moduleAbsolutePath, dataAbsolutePath); } - private void checkNotDestroyed() { - if (mDestroyed) throw new IllegalStateException("TrainingModule has been destroyed"); - } - /** * Loads a serialized ExecuTorch Training Module from the specified path on the disk. * @@ -82,7 +78,10 @@ public static TrainingModule load(final String modelPath) { * @return return value(s) from the method. */ public EValue[] executeForwardBackward(String methodName, EValue... inputs) { - checkNotDestroyed(); + if (!mHybridData.isValid()) { + Log.e("ExecuTorch", "Attempt to use a destroyed module"); + return new EValue[0]; + } return executeForwardBackwardNative(methodName, inputs); } @@ -90,7 +89,10 @@ public EValue[] executeForwardBackward(String methodName, EValue... inputs) { private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs); public Map namedParameters(String methodName) { - checkNotDestroyed(); + if (!mHybridData.isValid()) { + Log.e("ExecuTorch", "Attempt to use a destroyed module"); + return new HashMap(); + } return namedParametersNative(methodName); } @@ -98,17 +100,13 @@ public Map namedParameters(String methodName) { private native Map namedParametersNative(String methodName); public Map namedGradients(String methodName) { - checkNotDestroyed(); + if (!mHybridData.isValid()) { + Log.e("ExecuTorch", "Attempt to use a destroyed module"); + return new HashMap(); + } return namedGradientsNative(methodName); } @DoNotStrip private native Map namedGradientsNative(String methodName); - - @Override - public void close() { - if (mDestroyed) return; - mDestroyed = true; - mHybridData.resetNative(); - } } diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 0cf08e41983..88e9f9e2a12 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -284,18 +284,8 @@ class ExecuTorchJni : public facebook::jni::HybridClass { #else auto etdump_gen = nullptr; #endif - try { - module_ = std::make_unique( - modelPath->toStdString(), load_mode, std::move(etdump_gen)); - } catch (const std::exception& e) { - executorch::jni_helper::throwExecutorchException( - static_cast(Error::Internal), - std::string("Failed to create Module: ") + e.what()); - } catch (...) { - executorch::jni_helper::throwExecutorchException( - static_cast(Error::Internal), - "Failed to create Module: unknown native error"); - } + module_ = std::make_unique( + modelPath->toStdString(), load_mode, std::move(etdump_gen)); #ifdef ET_USE_THREADPOOL // Default to using cores/2 threadpool threads. The long-term plan is to diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 0c1ff5c67b9..2c0117dc576 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #include #include @@ -149,117 +148,103 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { jint num_bos = 0, jint num_eos = 0, jint load_mode = 1) { - try { - temperature_ = temperature; - num_bos_ = num_bos; - num_eos_ = num_eos; + temperature_ = temperature; + num_bos_ = num_bos; + num_eos_ = num_eos; #if defined(ET_USE_THREADPOOL) - // Reserve 1 thread for the main thread. - int32_t num_performant_cores = - ::executorch::extension::cpuinfo::get_num_performant_cores() - 1; - if (num_performant_cores > 0) { - ET_LOG( - Info, "Resetting threadpool to %d threads", num_performant_cores); - ::executorch::extension::threadpool::get_threadpool() - ->_unsafe_reset_threadpool(num_performant_cores); - } + // Reserve 1 thread for the main thread. + int32_t num_performant_cores = + ::executorch::extension::cpuinfo::get_num_performant_cores() - 1; + if (num_performant_cores > 0) { + ET_LOG(Info, "Resetting threadpool to %d threads", num_performant_cores); + ::executorch::extension::threadpool::get_threadpool() + ->_unsafe_reset_threadpool(num_performant_cores); + } #endif - model_type_category_ = model_type_category; - auto cpp_load_mode = load_mode_from_int(load_mode); - std::vector data_files_vector; - if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) { - runner_ = llm::create_multimodal_runner( - model_path->toStdString().c_str(), - llm::load_tokenizer(tokenizer_path->toStdString()), - std::nullopt, - cpp_load_mode); - } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) { - if (data_files != nullptr) { - // Convert Java List to C++ std::vector - auto list_class = facebook::jni::findClassStatic("java/util/List"); - auto size_method = list_class->getMethod("size"); - auto get_method = - list_class->getMethod(jint)>( - "get"); - - jint size = size_method(data_files); - for (jint i = 0; i < size; ++i) { - auto str_obj = get_method(data_files, i); - auto jstr = facebook::jni::static_ref_cast(str_obj); - data_files_vector.push_back(jstr->toStdString()); - } + model_type_category_ = model_type_category; + auto cpp_load_mode = load_mode_from_int(load_mode); + std::vector data_files_vector; + if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) { + runner_ = llm::create_multimodal_runner( + model_path->toStdString().c_str(), + llm::load_tokenizer(tokenizer_path->toStdString()), + std::nullopt, + cpp_load_mode); + } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) { + if (data_files != nullptr) { + // Convert Java List to C++ std::vector + auto list_class = facebook::jni::findClassStatic("java/util/List"); + auto size_method = list_class->getMethod("size"); + auto get_method = + list_class->getMethod(jint)>( + "get"); + + jint size = size_method(data_files); + for (jint i = 0; i < size; ++i) { + auto str_obj = get_method(data_files, i); + auto jstr = facebook::jni::static_ref_cast(str_obj); + data_files_vector.push_back(jstr->toStdString()); } - runner_ = executorch::extension::llm::create_text_llm_runner( - model_path->toStdString(), - llm::load_tokenizer(tokenizer_path->toStdString()), - data_files_vector, - /*temperature=*/-1.0f, - /*event_tracer=*/nullptr, - /*method_name=*/"forward", - cpp_load_mode); + } + runner_ = executorch::extension::llm::create_text_llm_runner( + model_path->toStdString(), + llm::load_tokenizer(tokenizer_path->toStdString()), + data_files_vector, + /*temperature=*/-1.0f, + /*event_tracer=*/nullptr, + /*method_name=*/"forward", + cpp_load_mode); #if defined(EXECUTORCH_BUILD_QNN) - } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) { - std::unique_ptr module = - std::make_unique( - model_path->toStdString().c_str(), - data_files_vector, - cpp_load_mode); - std::string decoder_model = "llama3"; // use llama3 for now - // Using 8bit as default since this meta is introduced with 16bit kv io - // support and older models only have 8bit kv io. - example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; - if (module->method_names()->count("get_kv_io_bit_width") > 0) { - kv_bitwidth = static_cast( - module->get("get_kv_io_bit_width") - .get() - .toScalar() - .to()); - } - - if (kv_bitwidth == example::KvBitWidth::kWidth8) { - runner_ = std::make_unique>( - std::move(module), - decoder_model.c_str(), - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - "", - "", - temperature_); - } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { - runner_ = std::make_unique>( - std::move(module), - decoder_model.c_str(), + } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) { + std::unique_ptr module = + std::make_unique( model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - "", - "", - temperature_); - } else { - ET_CHECK_MSG( - false, - "Unsupported kv bitwidth: %ld", - static_cast(kv_bitwidth)); - } - model_type_category_ = MODEL_TYPE_CATEGORY_LLM; + data_files_vector, + cpp_load_mode); + std::string decoder_model = "llama3"; // use llama3 for now + // Using 8bit as default since this meta is introduced with 16bit kv io + // support and older models only have 8bit kv io. + example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; + if (module->method_names()->count("get_kv_io_bit_width") > 0) { + kv_bitwidth = static_cast( + module->get("get_kv_io_bit_width").get().toScalar().to()); + } + + if (kv_bitwidth == example::KvBitWidth::kWidth8) { + runner_ = std::make_unique>( + std::move(module), + decoder_model.c_str(), + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), + "", + "", + temperature_); + } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { + runner_ = std::make_unique>( + std::move(module), + decoder_model.c_str(), + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), + "", + "", + temperature_); + } else { + ET_CHECK_MSG( + false, + "Unsupported kv bitwidth: %ld", + static_cast(kv_bitwidth)); + } + model_type_category_ = MODEL_TYPE_CATEGORY_LLM; #endif #if defined(EXECUTORCH_BUILD_MEDIATEK) - } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) { - runner_ = std::make_unique( - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str()); - // Interpret the model type as LLM - model_type_category_ = MODEL_TYPE_CATEGORY_LLM; + } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) { + runner_ = std::make_unique( + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str()); + // Interpret the model type as LLM + model_type_category_ = MODEL_TYPE_CATEGORY_LLM; #endif - } - } catch (const std::exception& e) { - executorch::jni_helper::throwExecutorchException( - static_cast(Error::Internal), - std::string("Failed to create LlmModule: ") + e.what()); - } catch (...) { - executorch::jni_helper::throwExecutorchException( - static_cast(Error::Internal), - "Failed to create LlmModule: unknown native error"); } } @@ -609,19 +594,21 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { jint load() { if (!runner_) { - std::stringstream ss; - ss << "Model runner was not created. model_type_category=" - << model_type_category_ - << ". Valid values: " << MODEL_TYPE_CATEGORY_LLM << " (LLM), " - << MODEL_TYPE_CATEGORY_MULTIMODAL << " (Multimodal)"; - executorch::jni_helper::throwExecutorchException( - static_cast(Error::InvalidState), ss.str().c_str()); + ET_LOG( + Error, + "ExecuTorchLlmJni::load() called but runner_ is null. " + "The model runner was not created or failed to initialize due to a " + "previous configuration or initialization error. " + "Model type category: %d.", + model_type_category_); return static_cast(Error::InvalidState); } const auto load_result = static_cast(runner_->load()); if (load_result != static_cast(Error::Ok)) { - executorch::jni_helper::throwExecutorchException( - static_cast(load_result), "Failed to load model runner"); + ET_LOG( + Error, + "ExecuTorchLlmJni::load() failed in runner_->load() with error code %d.", + static_cast(load_result)); } return load_result; } From bdf1bf4e9dcb6b814bb41482222ca1ac03f77b01 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Sat, 25 Apr 2026 18:37:41 -0700 Subject: [PATCH 12/21] Back out D102011505 and D101260086 (#19134) Differential Revision: D102488314 Pull Request resolved: https://github.com/pytorch/executorch/pull/19134 --- .../LlmModuleInstrumentationTest.kt | 21 +- .../executorch/ModuleInstrumentationTest.kt | 29 +- .../ExecutorchRuntimeException.java | 54 +-- .../java/org/pytorch/executorch/Module.java | 79 +--- .../executorch/extension/llm/LlmCallback.java | 8 +- .../executorch/extension/llm/LlmModule.java | 399 ++++++------------ extension/android/jni/jni_layer.cpp | 10 +- extension/android/jni/jni_layer_llama.cpp | 39 +- .../org/pytorch/minibench/LlmModelRunner.java | 15 +- .../org/pytorch/minibench/ModelRunner.java | 102 ++--- 10 files changed, 241 insertions(+), 515 deletions(-) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt index 0974a04af44..4b6c3caed94 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt @@ -54,7 +54,9 @@ class LlmModuleInstrumentationTest : LlmCallback { @Test @Throws(IOException::class, URISyntaxException::class) fun testGenerate() { - llmModule.load() + val loadResult = llmModule.load() + // Check that the model can be load successfully + assertEquals(OK.toLong(), loadResult.toLong()) llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest) assertEquals(results.size.toLong(), SEQ_LEN.toLong()) @@ -271,26 +273,11 @@ class LlmModuleInstrumentationTest : LlmCallback { } } - // --- Lifecycle tests --- - - @Test - fun testUseAfterCloseThrows() { - llmModule.close() - assertThrows(IllegalStateException::class.java) { - llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest) - } - } - - @Test - fun testCloseIsIdempotent() { - llmModule.close() - llmModule.close() - } - companion object { private const val TEST_FILE_NAME = "/stories.pte" private const val TOKENIZER_FILE_NAME = "/tokenizer.bin" private const val TEST_PROMPT = "Hello" + private const val OK = 0x00 private const val SEQ_LEN = 32 } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index ba91f444287..99d53b6dba3 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -66,7 +66,8 @@ class ModuleInstrumentationTest { fun testModuleLoadMethodAndForward() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - module.loadMethod(FORWARD_METHOD) + val loadMethod = module.loadMethod(FORWARD_METHOD) + Assert.assertEquals(loadMethod.toLong(), OK.toLong()) val results = module.forward() Assert.assertTrue(results[0].isTensor) @@ -95,14 +96,8 @@ class ModuleInstrumentationTest { fun testModuleLoadMethodNonExistantMethod() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - val exception = - Assert.assertThrows(ExecutorchRuntimeException::class.java) { - module.loadMethod(NONE_METHOD) - } - Assert.assertEquals( - ExecutorchRuntimeException.INVALID_ARGUMENT, - exception.getErrorCode(), - ) + val loadMethod = module.loadMethod(NONE_METHOD) + Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) } @Test(expected = RuntimeException::class) @@ -110,7 +105,8 @@ class ModuleInstrumentationTest { fun testNonPteFile() { val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME)) - module.loadMethod(FORWARD_METHOD) + val loadMethod = module.loadMethod(FORWARD_METHOD) + Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) } @Test @@ -120,7 +116,8 @@ class ModuleInstrumentationTest { module.destroy() - Assert.assertThrows(IllegalStateException::class.java) { module.loadMethod(FORWARD_METHOD) } + val loadMethod = module.loadMethod(FORWARD_METHOD) + Assert.assertEquals(loadMethod.toLong(), INVALID_STATE.toLong()) } @Test @@ -128,11 +125,13 @@ class ModuleInstrumentationTest { fun testForwardOnDestroyedModule() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - module.loadMethod(FORWARD_METHOD) + val loadMethod = module.loadMethod(FORWARD_METHOD) + Assert.assertEquals(loadMethod.toLong(), OK.toLong()) module.destroy() - Assert.assertThrows(IllegalStateException::class.java) { module.forward() } + val results = module.forward() + Assert.assertEquals(0, results.size.toLong()) } @Ignore( @@ -176,5 +175,9 @@ class ModuleInstrumentationTest { private const val NON_PTE_FILE_NAME = "/test.txt" private const val FORWARD_METHOD = "forward" private const val NONE_METHOD = "none" + private const val OK = 0x00 + private const val INVALID_STATE = 0x2 + private const val INVALID_ARGUMENT = 0x12 + private const val ACCESS_FAILED = 0x22 } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java index e0fda73cc06..102b96ab686 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java @@ -12,83 +12,34 @@ import java.util.HashMap; import java.util.Map; -/** - * Base exception for all ExecuTorch runtime errors. Each instance carries an integer error code - * corresponding to the native {@code runtime/core/error.h} values, accessible via {@link - * #getErrorCode()}. - */ public class ExecutorchRuntimeException extends RuntimeException { // Error code constants - keep in sync with runtime/core/error.h - // System errors - - /** Operation completed successfully. */ public static final int OK = 0x00; - - /** An unexpected internal error occurred in the runtime. */ public static final int INTERNAL = 0x01; - - /** The runtime or method is in an invalid state for the requested operation. */ public static final int INVALID_STATE = 0x02; - - /** The method has finished execution and has no more work to do. */ public static final int END_OF_METHOD = 0x03; - /** A required resource has already been loaded. */ - public static final int ALREADY_LOADED = 0x04; - // Logical errors - - /** The requested operation is not supported by this build or backend. */ public static final int NOT_SUPPORTED = 0x10; - - /** The requested operation has not been implemented. */ public static final int NOT_IMPLEMENTED = 0x11; - - /** One or more arguments passed to the operation are invalid. */ public static final int INVALID_ARGUMENT = 0x12; - - /** A value or tensor has an unexpected type. */ public static final int INVALID_TYPE = 0x13; - - /** A required operator kernel is not registered. */ public static final int OPERATOR_MISSING = 0x14; - - /** The maximum number of registered kernels has been exceeded. */ public static final int REGISTRATION_EXCEEDING_MAX_KERNELS = 0x15; - - /** A kernel with the same name is already registered. */ public static final int REGISTRATION_ALREADY_REGISTERED = 0x16; // Resource errors - - /** A required resource (file, tensor, program) was not found. */ public static final int NOT_FOUND = 0x20; - - /** A memory allocation failed. */ public static final int MEMORY_ALLOCATION_FAILED = 0x21; - - /** Access to a resource was denied or failed. */ public static final int ACCESS_FAILED = 0x22; - - /** The loaded program is malformed or incompatible. */ public static final int INVALID_PROGRAM = 0x23; - - /** External data referenced by the program is invalid or missing. */ public static final int INVALID_EXTERNAL_DATA = 0x24; - - /** The system has run out of a required resource. */ public static final int OUT_OF_RESOURCES = 0x25; // Delegate errors - - /** A delegate reported an incompatible model or configuration. */ public static final int DELEGATE_INVALID_COMPATIBILITY = 0x30; - - /** A delegate failed to allocate required memory. */ public static final int DELEGATE_MEMORY_ALLOCATION_FAILED = 0x31; - - /** A delegate received an invalid or stale handle. */ public static final int DELEGATE_INVALID_HANDLE = 0x32; private static final Map ERROR_CODE_MESSAGES; @@ -101,7 +52,6 @@ public class ExecutorchRuntimeException extends RuntimeException { map.put(INTERNAL, "Internal error"); map.put(INVALID_STATE, "Invalid state"); map.put(END_OF_METHOD, "End of method reached"); - map.put(ALREADY_LOADED, "Already loaded"); // Logical errors map.put(NOT_SUPPORTED, "Operation not supported"); map.put(NOT_IMPLEMENTED, "Operation not implemented"); @@ -133,7 +83,7 @@ static String formatMessage(int errorCode, String details) { String safeDetails = details != null ? details : "No details provided"; return String.format( - "[ExecuTorch Error 0x%s] %s: %s", + "[Executorch Error 0x%s] %s: %s", Integer.toHexString(errorCode), baseMessage, safeDetails); } @@ -161,12 +111,10 @@ public ExecutorchRuntimeException(int errorCode, String details) { this.errorCode = errorCode; } - /** Returns the numeric error code from {@code runtime/core/error.h}. */ public int getErrorCode() { return errorCode; } - /** Returns detailed log output captured from the native runtime, if available. */ public String getDetailedError() { return ErrorHelper.getDetailedErrorLogs(); } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index 05e1e5b88cf..f7e2e37dcec 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -8,6 +8,7 @@ package org.pytorch.executorch; +import android.util.Log; import com.facebook.jni.HybridData; import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; @@ -129,10 +130,11 @@ public EValue[] forward(EValue... inputs) { * @return return value from the method. */ public EValue[] execute(String methodName, EValue... inputs) { - mLock.lock(); try { + mLock.lock(); if (!mHybridData.isValid()) { - throw new IllegalStateException("Module has been destroyed"); + Log.e("ExecuTorch", "Attempt to use a destroyed module"); + return new EValue[0]; } return executeNative(methodName, inputs); } finally { @@ -149,17 +151,17 @@ public EValue[] execute(String methodName, EValue... inputs) { * synchronous, and will block until the method is loaded. Therefore, it is recommended to call * this on a background thread. However, users need to make sure that they don't execute before * this function returns. + * + * @return the Error code if there was an error loading the method */ - public void loadMethod(String methodName) { - mLock.lock(); + public int loadMethod(String methodName) { try { + mLock.lock(); if (!mHybridData.isValid()) { - throw new IllegalStateException("Module has been destroyed"); - } - int errorCode = loadMethodNative(methodName); - if (errorCode != 0) { - throw new ExecutorchRuntimeException(errorCode, "Failed to load method: " + methodName); + Log.e("ExecuTorch", "Attempt to use a destroyed module"); + return 0x2; // InvalidState } + return loadMethodNative(methodName); } finally { mLock.unlock(); } @@ -182,20 +184,8 @@ public void loadMethod(String methodName) { * * @return name of methods in this Module */ - public String[] getMethods() { - mLock.lock(); - try { - if (!mHybridData.isValid()) { - throw new IllegalStateException("Module has been destroyed"); - } - return getMethodsNative(); - } finally { - mLock.unlock(); - } - } - @DoNotStrip - private native String[] getMethodsNative(); + public native String[] getMethods(); /** * Get the corresponding @MethodMetadata for a method @@ -204,19 +194,11 @@ public String[] getMethods() { * @return @MethodMetadata for this method */ public MethodMetadata getMethodMetadata(String name) { - mLock.lock(); - try { - if (!mHybridData.isValid()) { - throw new IllegalStateException("Module has been destroyed"); - } - MethodMetadata methodMetadata = mMethodMetadata.get(name); - if (methodMetadata == null) { - throw new IllegalArgumentException("method " + name + " does not exist for this module"); - } - return methodMetadata; - } finally { - mLock.unlock(); + MethodMetadata methodMetadata = mMethodMetadata.get(name); + if (methodMetadata == null) { + throw new IllegalArgumentException("method " + name + " does not exist for this module"); } + return methodMetadata; } @DoNotStrip @@ -228,15 +210,7 @@ public static String[] readLogBufferStatic() { /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ public String[] readLogBuffer() { - mLock.lock(); - try { - if (!mHybridData.isValid()) { - throw new IllegalStateException("Module has been destroyed"); - } - return readLogBufferNative(); - } finally { - mLock.unlock(); - } + return readLogBufferNative(); } @DoNotStrip @@ -250,20 +224,8 @@ public String[] readLogBuffer() { * @return true if the etdump was successfully written, false otherwise. */ @Experimental - public boolean etdump() { - mLock.lock(); - try { - if (!mHybridData.isValid()) { - throw new IllegalStateException("Module has been destroyed"); - } - return etdumpNative(); - } finally { - mLock.unlock(); - } - } - @DoNotStrip - private native boolean etdumpNative(); + public native boolean etdump(); /** * Explicitly destroys the native Module object. Calling this method is not required, as the @@ -279,7 +241,10 @@ public void destroy() { mLock.unlock(); } } else { - throw new IllegalStateException("Cannot destroy module while method is executing"); + Log.w( + "ExecuTorch", + "Destroy was called while the module was in use. Resources will not be immediately" + + " released."); } } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java index ec0413caf2e..4e834d06721 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java @@ -46,11 +46,5 @@ default void onStats(String stats) {} * @param message Human-readable error description */ @DoNotStrip - default void onError(int errorCode, String message) { - try { - android.util.Log.e("ExecuTorch", "LLM error " + errorCode + ": " + message); - } catch (Throwable t) { - System.err.println("ExecuTorch LLM error " + errorCode + ": " + message); - } - } + default void onError(int errorCode, String message) {} } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java index ecce94827d4..a563dc6bcc7 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -10,12 +10,9 @@ import com.facebook.jni.HybridData; import com.facebook.jni.annotations.DoNotStrip; -import java.io.Closeable; import java.nio.ByteBuffer; import java.util.List; -import java.util.concurrent.locks.ReentrantLock; import org.pytorch.executorch.ExecuTorchRuntime; -import org.pytorch.executorch.ExecutorchRuntimeException; import org.pytorch.executorch.annotations.Experimental; /** @@ -25,15 +22,13 @@ *

Warning: These APIs are experimental and subject to change without notice */ @Experimental -public class LlmModule implements Closeable { +public class LlmModule { public static final int MODEL_TYPE_TEXT = 1; public static final int MODEL_TYPE_TEXT_VISION = 2; public static final int MODEL_TYPE_MULTIMODAL = 2; private final HybridData mHybridData; - private final ReentrantLock mLock = new ReentrantLock(); - private boolean mDestroyed = false; private static final int DEFAULT_SEQ_LEN = 128; private static final boolean DEFAULT_ECHO = true; private static final float DEFAULT_TEMPERATURE = -1.0f; @@ -190,41 +185,8 @@ public LlmModule(LlmModuleConfig config) { config.getLoadMode()); } - private void checkNotDestroyed() { - if (mDestroyed) throw new IllegalStateException("LlmModule has been destroyed"); - } - - /** - * Releases native resources. Callers must ensure no other methods are in-flight. Call {@link - * #stop()} and wait for {@link #generate(String, LlmCallback)} to return before calling this - * method. - */ - @Override - public void close() { - if (mLock.tryLock()) { - try { - if (mLock.getHoldCount() > 1) { - throw new IllegalStateException( - "Cannot close module from within a callback during execution"); - } - if (!mDestroyed) { - mDestroyed = true; - mHybridData.resetNative(); - } - } finally { - mLock.unlock(); - } - } else { - throw new IllegalStateException("Cannot close module while method is executing"); - } - } - - /** - * @deprecated Use {@link #close()} instead. - */ - @Deprecated public void resetNative() { - close(); + mHybridData.resetNative(); } /** @@ -233,8 +195,8 @@ public void resetNative() { * @param prompt Input prompt * @param llmCallback callback object to receive results. */ - public void generate(String prompt, LlmCallback llmCallback) { - generate( + public int generate(String prompt, LlmCallback llmCallback) { + return generate( prompt, DEFAULT_SEQ_LEN, llmCallback, @@ -251,8 +213,8 @@ public void generate(String prompt, LlmCallback llmCallback) { * @param seqLen sequence length * @param llmCallback callback object to receive results. */ - public void generate(String prompt, int seqLen, LlmCallback llmCallback) { - generate( + public int generate(String prompt, int seqLen, LlmCallback llmCallback) { + return generate( null, 0, 0, @@ -273,8 +235,8 @@ public void generate(String prompt, int seqLen, LlmCallback llmCallback) { * @param llmCallback callback object to receive results * @param echo indicate whether to echo the input prompt or not (text completion vs chat) */ - public void generate(String prompt, LlmCallback llmCallback, boolean echo) { - generate( + public int generate(String prompt, LlmCallback llmCallback, boolean echo) { + return generate( null, 0, 0, @@ -296,8 +258,9 @@ public void generate(String prompt, LlmCallback llmCallback, boolean echo) { * @param llmCallback callback object to receive results * @param echo indicate whether to echo the input prompt or not (text completion vs chat) */ - public void generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) { - generate(prompt, seqLen, llmCallback, echo, DEFAULT_TEMPERATURE, DEFAULT_BOS, DEFAULT_EOS); + public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) { + return generate( + prompt, seqLen, llmCallback, echo, DEFAULT_TEMPERATURE, DEFAULT_BOS, DEFAULT_EOS); } /** @@ -311,28 +274,7 @@ public void generate(String prompt, int seqLen, LlmCallback llmCallback, boolean * @param numBos number of BOS tokens to prepend * @param numEos number of EOS tokens to append */ - public void generate( - String prompt, - int seqLen, - LlmCallback llmCallback, - boolean echo, - float temperature, - int numBos, - int numEos) { - mLock.lock(); - try { - checkNotDestroyed(); - int err = generateNative(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); - if (err != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to generate"); - } - } finally { - mLock.unlock(); - } - } - - @DoNotStrip - private native int generateNative( + public native int generate( String prompt, int seqLen, LlmCallback llmCallback, @@ -348,13 +290,13 @@ private native int generateNative( * @param config the config for generation * @param llmCallback callback object to receive results */ - public void generate(String prompt, LlmGenerationConfig config, LlmCallback llmCallback) { + public int generate(String prompt, LlmGenerationConfig config, LlmCallback llmCallback) { int seqLen = config.getSeqLen(); boolean echo = config.isEcho(); float temperature = config.getTemperature(); int numBos = config.getNumBos(); int numEos = config.getNumEos(); - generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); + return generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); } /** @@ -369,7 +311,7 @@ public void generate(String prompt, LlmGenerationConfig config, LlmCallback llmC * @param llmCallback callback object to receive results. * @param echo indicate whether to echo the input prompt or not (text completion vs chat) */ - public void generate( + public int generate( int[] image, int width, int height, @@ -378,7 +320,7 @@ public void generate( int seqLen, LlmCallback llmCallback, boolean echo) { - generate( + return generate( image, width, height, @@ -405,7 +347,7 @@ public void generate( * @param echo indicate whether to echo the input prompt or not (text completion vs chat) * @param temperature temperature for sampling (use negative value to use module default) */ - public void generate( + public int generate( int[] image, int width, int height, @@ -415,7 +357,7 @@ public void generate( LlmCallback llmCallback, boolean echo, float temperature) { - generate( + return generate( image, width, height, @@ -444,7 +386,7 @@ public void generate( * @param numBos number of BOS tokens to prepend * @param numEos number of EOS tokens to append */ - public void generate( + public int generate( int[] image, int width, int height, @@ -456,22 +398,10 @@ public void generate( float temperature, int numBos, int numEos) { - mLock.lock(); - try { - checkNotDestroyed(); - if (image != null) { - int nativeResult = prefillImagesInput(image, width, height, channels); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } - int err = generateNative(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); - if (err != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to generate"); - } - } finally { - mLock.unlock(); + if (image != null) { + prefillImages(image, width, height, channels); } + return generate(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); } /** @@ -481,20 +411,16 @@ public void generate( * @param width Input image width * @param height Input image height * @param channels Input image number of channels - * @throws ExecutorchRuntimeException if the prefill failed + * @return 0 on success + * @throws RuntimeException if the prefill failed */ @Experimental - public void prefillImages(int[] image, int width, int height, int channels) { - mLock.lock(); - try { - checkNotDestroyed(); - int nativeResult = prefillImagesInput(image, width, height, channels); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); + public long prefillImages(int[] image, int width, int height, int channels) { + int nativeResult = prefillImagesInput(image, width, height, channels); + if (nativeResult != 0) { + throw new RuntimeException("Prefill failed with error code: " + nativeResult); } + return 0; } /** @@ -514,40 +440,34 @@ public void prefillImages(int[] image, int width, int height, int channels) { */ @Experimental public void prefillImages(ByteBuffer image, int width, int height, int channels) { - mLock.lock(); + if (!image.isDirect()) { + throw new IllegalArgumentException("Input ByteBuffer must be direct."); + } + long expectedBytes; try { - checkNotDestroyed(); - if (!image.isDirect()) { - throw new IllegalArgumentException("Input ByteBuffer must be direct."); - } - long expectedBytes; - try { - long pixels = Math.multiplyExact((long) width, (long) height); - expectedBytes = Math.multiplyExact(pixels, (long) channels); - } catch (ArithmeticException ex) { - throw new IllegalArgumentException( - "width*height*channels is too large and overflows the allowed range.", ex); - } - if (width <= 0 - || height <= 0 - || channels <= 0 - || expectedBytes > Integer.MAX_VALUE - || image.remaining() < expectedBytes) { - throw new IllegalArgumentException( - "ByteBuffer remaining (" - + image.remaining() - + ") must be at least width*height*channels (" - + expectedBytes - + ")."); - } - // slice() so that getDirectBufferAddress on the native side returns a pointer - // starting at the current position, not the base address. - int nativeResult = prefillImagesInputBuffer(image.slice(), width, height, channels); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); + long pixels = Math.multiplyExact((long) width, (long) height); + expectedBytes = Math.multiplyExact(pixels, (long) channels); + } catch (ArithmeticException ex) { + throw new IllegalArgumentException( + "width*height*channels is too large and overflows the allowed range.", ex); + } + if (width <= 0 + || height <= 0 + || channels <= 0 + || expectedBytes > Integer.MAX_VALUE + || image.remaining() < expectedBytes) { + throw new IllegalArgumentException( + "ByteBuffer remaining (" + + image.remaining() + + ") must be at least width*height*channels (" + + expectedBytes + + ")."); + } + // slice() so that getDirectBufferAddress on the native side returns a pointer + // starting at the current position, not the base address. + int nativeResult = prefillImagesInputBuffer(image.slice(), width, height, channels); + if (nativeResult != 0) { + throw new RuntimeException("Prefill failed with error code: " + nativeResult); } } @@ -571,57 +491,49 @@ public void prefillImages(ByteBuffer image, int width, int height, int channels) */ @Experimental public void prefillNormalizedImage(ByteBuffer image, int width, int height, int channels) { - mLock.lock(); + if (!image.isDirect()) { + throw new IllegalArgumentException("Input ByteBuffer must be direct."); + } + if (image.order() != java.nio.ByteOrder.nativeOrder()) { + throw new IllegalArgumentException( + "Input ByteBuffer must use native byte order (ByteOrder.nativeOrder())."); + } + if (image.position() % Float.BYTES != 0) { + throw new IllegalArgumentException( + "Input ByteBuffer position (" + image.position() + ") must be 4-byte aligned."); + } + final long expectedBytes; try { - checkNotDestroyed(); - if (!image.isDirect()) { - throw new IllegalArgumentException("Input ByteBuffer must be direct."); - } - if (image.order() != java.nio.ByteOrder.nativeOrder()) { - throw new IllegalArgumentException( - "Input ByteBuffer must use native byte order (ByteOrder.nativeOrder())."); - } - if (image.position() % Float.BYTES != 0) { - throw new IllegalArgumentException( - "Input ByteBuffer position (" + image.position() + ") must be 4-byte aligned."); - } - final long expectedBytes; - try { - int wh = Math.multiplyExact(width, height); - long whc = Math.multiplyExact((long) wh, (long) channels); - long totalBytes = Math.multiplyExact(whc, (long) Float.BYTES); - if (totalBytes > Integer.MAX_VALUE) { - throw new IllegalArgumentException( - "ByteBuffer size (width*height*channels*4) exceeds Integer.MAX_VALUE bytes: " - + totalBytes); - } - expectedBytes = totalBytes; - } catch (ArithmeticException e) { - throw new IllegalArgumentException( - "Overflow while computing width*height*channels*4 for ByteBuffer size.", e); - } - if (width <= 0 || height <= 0 || channels <= 0 || image.remaining() < expectedBytes) { - throw new IllegalArgumentException( - "ByteBuffer remaining (" - + image.remaining() - + ") must be at least width*height*channels*4 (" - + expectedBytes - + ")."); - } - if (image.remaining() % Float.BYTES != 0) { + int wh = Math.multiplyExact(width, height); + long whc = Math.multiplyExact((long) wh, (long) channels); + long totalBytes = Math.multiplyExact(whc, (long) Float.BYTES); + if (totalBytes > Integer.MAX_VALUE) { throw new IllegalArgumentException( - "ByteBuffer remaining (" - + image.remaining() - + ") must be a multiple of 4 (float size)."); + "ByteBuffer size (width*height*channels*4) exceeds Integer.MAX_VALUE bytes: " + + totalBytes); } - // slice() so that getDirectBufferAddress on the native side returns a pointer - // starting at the current position, not the base address. - int nativeResult = prefillNormalizedImagesInputBuffer(image.slice(), width, height, channels); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); + expectedBytes = totalBytes; + } catch (ArithmeticException e) { + throw new IllegalArgumentException( + "Overflow while computing width*height*channels*4 for ByteBuffer size.", e); + } + if (width <= 0 || height <= 0 || channels <= 0 || image.remaining() < expectedBytes) { + throw new IllegalArgumentException( + "ByteBuffer remaining (" + + image.remaining() + + ") must be at least width*height*channels*4 (" + + expectedBytes + + ")."); + } + if (image.remaining() % Float.BYTES != 0) { + throw new IllegalArgumentException( + "ByteBuffer remaining (" + image.remaining() + ") must be a multiple of 4 (float size)."); + } + // slice() so that getDirectBufferAddress on the native side returns a pointer + // starting at the current position, not the base address. + int nativeResult = prefillNormalizedImagesInputBuffer(image.slice(), width, height, channels); + if (nativeResult != 0) { + throw new RuntimeException("Prefill failed with error code: " + nativeResult); } } @@ -640,20 +552,16 @@ private native int prefillNormalizedImagesInputBuffer( * @param width Input image width * @param height Input image height * @param channels Input image number of channels - * @throws ExecutorchRuntimeException if the prefill failed + * @return 0 on success + * @throws RuntimeException if the prefill failed */ @Experimental - public void prefillImages(float[] image, int width, int height, int channels) { - mLock.lock(); - try { - checkNotDestroyed(); - int nativeResult = prefillNormalizedImagesInput(image, width, height, channels); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); + public long prefillImages(float[] image, int width, int height, int channels) { + int nativeResult = prefillNormalizedImagesInput(image, width, height, channels); + if (nativeResult != 0) { + throw new RuntimeException("Prefill failed with error code: " + nativeResult); } + return 0; } private native int prefillNormalizedImagesInput( @@ -666,20 +574,16 @@ private native int prefillNormalizedImagesInput( * @param batch_size Input batch size * @param n_bins Input number of bins * @param n_frames Input number of frames - * @throws ExecutorchRuntimeException if the prefill failed + * @return 0 on success + * @throws RuntimeException if the prefill failed */ @Experimental - public void prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) { - mLock.lock(); - try { - checkNotDestroyed(); - int nativeResult = prefillAudioInput(audio, batch_size, n_bins, n_frames); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); + public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) { + int nativeResult = prefillAudioInput(audio, batch_size, n_bins, n_frames); + if (nativeResult != 0) { + throw new RuntimeException("Prefill failed with error code: " + nativeResult); } + return 0; } private native int prefillAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames); @@ -691,20 +595,16 @@ public void prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) * @param batch_size Input batch size * @param n_bins Input number of bins * @param n_frames Input number of frames - * @throws ExecutorchRuntimeException if the prefill failed + * @return 0 on success + * @throws RuntimeException if the prefill failed */ @Experimental - public void prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) { - mLock.lock(); - try { - checkNotDestroyed(); - int nativeResult = prefillAudioInputFloat(audio, batch_size, n_bins, n_frames); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); + public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) { + int nativeResult = prefillAudioInputFloat(audio, batch_size, n_bins, n_frames); + if (nativeResult != 0) { + throw new RuntimeException("Prefill failed with error code: " + nativeResult); } + return 0; } private native int prefillAudioInputFloat( @@ -717,20 +617,16 @@ private native int prefillAudioInputFloat( * @param batch_size Input batch size * @param n_channels Input number of channels * @param n_samples Input number of samples - * @throws ExecutorchRuntimeException if the prefill failed + * @return 0 on success + * @throws RuntimeException if the prefill failed */ @Experimental - public void prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) { - mLock.lock(); - try { - checkNotDestroyed(); - int nativeResult = prefillRawAudioInput(audio, batch_size, n_channels, n_samples); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); + public long prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) { + int nativeResult = prefillRawAudioInput(audio, batch_size, n_channels, n_samples); + if (nativeResult != 0) { + throw new RuntimeException("Prefill failed with error code: " + nativeResult); } + return 0; } private native int prefillRawAudioInput( @@ -740,20 +636,16 @@ private native int prefillRawAudioInput( * Prefill the KV cache with the given text prompt. * * @param prompt The text prompt to prefill. - * @throws ExecutorchRuntimeException if the prefill failed + * @return 0 on success + * @throws RuntimeException if the prefill failed */ @Experimental - public void prefillPrompt(String prompt) { - mLock.lock(); - try { - checkNotDestroyed(); - int nativeResult = prefillTextInput(prompt); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); + public long prefillPrompt(String prompt) { + int nativeResult = prefillTextInput(prompt); + if (nativeResult != 0) { + throw new RuntimeException("Prefill failed with error code: " + nativeResult); } + return 0; } // returns status @@ -764,18 +656,7 @@ public void prefillPrompt(String prompt) { * *

The startPos will be reset to 0. */ - public void resetContext() { - mLock.lock(); - try { - checkNotDestroyed(); - resetContextNative(); - } finally { - mLock.unlock(); - } - } - - @DoNotStrip - private native void resetContextNative(); + public native void resetContext(); /** Stop current generate() before it finishes. */ @DoNotStrip @@ -783,19 +664,5 @@ public void resetContext() { /** Force loading the module. Otherwise the model is loaded during first generate(). */ @DoNotStrip - public void load() { - mLock.lock(); - try { - checkNotDestroyed(); - int err = loadNative(); - if (err != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to load model"); - } - } finally { - mLock.unlock(); - } - } - - @DoNotStrip - private native int loadNative(); + public native int load(); } diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 88e9f9e2a12..beff72119b8 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -385,12 +385,6 @@ class ExecuTorchJni : public facebook::jni::HybridClass { static const auto toBoolMethod = JEValue::javaClassStatic()->getMethod("toBool"); evalues.emplace_back(static_cast(toBoolMethod(jevalue))); - } else { - std::stringstream ss; - ss << "Unsupported input EValue type code: " << typeCode; - jni_helper::throwExecutorchException( - static_cast(Error::InvalidArgument), ss.str()); - return {}; } } @@ -570,8 +564,8 @@ class ExecuTorchJni : public facebook::jni::HybridClass { makeNativeMethod("readLogBufferNative", ExecuTorchJni::readLogBuffer), makeNativeMethod( "readLogBufferStaticNative", ExecuTorchJni::readLogBufferStatic), - makeNativeMethod("etdumpNative", ExecuTorchJni::etdump), - makeNativeMethod("getMethodsNative", ExecuTorchJni::getMethods), + makeNativeMethod("etdump", ExecuTorchJni::etdump), + makeNativeMethod("getMethods", ExecuTorchJni::getMethods), makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends), }); } diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 2c0117dc576..ed144acb14b 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -594,31 +594,30 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { jint load() { if (!runner_) { - ET_LOG( - Error, - "ExecuTorchLlmJni::load() called but runner_ is null. " - "The model runner was not created or failed to initialize due to a " - "previous configuration or initialization error. " - "Model type category: %d.", - model_type_category_); - return static_cast(Error::InvalidState); - } - const auto load_result = static_cast(runner_->load()); - if (load_result != static_cast(Error::Ok)) { - ET_LOG( - Error, - "ExecuTorchLlmJni::load() failed in runner_->load() with error code %d.", - static_cast(load_result)); - } - return load_result; + std::stringstream ss; + ss << "Invalid model type category: " << model_type_category_ + << ". Valid values are: " << MODEL_TYPE_CATEGORY_LLM << " or " + << MODEL_TYPE_CATEGORY_MULTIMODAL; + executorch::jni_helper::throwExecutorchException( + static_cast(Error::InvalidArgument), ss.str().c_str()); + return -1; + } + int result = static_cast(runner_->load()); + if (result != 0) { + std::stringstream ss; + ss << "Failed to load runner: [" << result << "]"; + executorch::jni_helper::throwExecutorchException( + static_cast(result), ss.str().c_str()); + } + return result; } static void registerNatives() { registerHybrid({ makeNativeMethod("initHybrid", ExecuTorchLlmJni::initHybrid), - makeNativeMethod("generateNative", ExecuTorchLlmJni::generate), + makeNativeMethod("generate", ExecuTorchLlmJni::generate), makeNativeMethod("stop", ExecuTorchLlmJni::stop), - makeNativeMethod("loadNative", ExecuTorchLlmJni::load), + makeNativeMethod("load", ExecuTorchLlmJni::load), makeNativeMethod( "prefillImagesInput", ExecuTorchLlmJni::prefill_images_input), makeNativeMethod( @@ -639,7 +638,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { "prefillRawAudioInput", ExecuTorchLlmJni::prefill_raw_audio_input), makeNativeMethod( "prefillTextInput", ExecuTorchLlmJni::prefill_text_input), - makeNativeMethod("resetContextNative", ExecuTorchLlmJni::reset_context), + makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context), }); } }; diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java index 2c8770ca33e..a1b434a37bf 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java @@ -87,21 +87,10 @@ public LlmModelRunnerHandler(Looper looper, LlmModelRunner llmModelRunner) { @Override public void handleMessage(android.os.Message msg) { if (msg.what == MESSAGE_LOAD_MODEL) { - int status = 0; - try { - mLlmModelRunner.mModule.load(); - } catch (org.pytorch.executorch.ExecutorchRuntimeException e) { - status = e.getErrorCode(); - } catch (Exception e) { - status = -1; - } + int status = mLlmModelRunner.mModule.load(); mLlmModelRunner.mCallback.onModelLoaded(status); } else if (msg.what == MESSAGE_GENERATE) { - try { - mLlmModelRunner.mModule.generate((String) msg.obj, mLlmModelRunner); - } catch (Exception e) { - android.util.Log.e("LlmModelRunner", "generate() failed", e); - } + mLlmModelRunner.mModule.generate((String) msg.obj, mLlmModelRunner); mLlmModelRunner.mCallback.onGenerationStopped(); } } diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java index 915496a25af..28f4e3728f0 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java @@ -27,73 +27,53 @@ public void runBenchmark( long loadStart = System.nanoTime(); Module module = Module.load(model.getPath()); - int errorCode = 0; - try { - module.loadMethod("forward"); - } catch (Exception e) { - errorCode = - (e instanceof org.pytorch.executorch.ExecutorchRuntimeException) - ? ((org.pytorch.executorch.ExecutorchRuntimeException) e).getErrorCode() - : -1; - } + int errorCode = module.loadMethod("forward"); long loadEnd = System.nanoTime(); - final BenchmarkMetric.BenchmarkModel benchmarkModel = - BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", "")); - - if (errorCode != 0) { - results.add( - new BenchmarkMetric( - benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0f)); - results.add(new BenchmarkMetric(benchmarkModel, "load_status", errorCode, 0)); - module.destroy(); - return; + for (int i = 0; i < numWarmupIter; i++) { + module.forward(); } - try { - for (int i = 0; i < numWarmupIter; i++) { - module.forward(); - } - - for (int i = 0; i < numIter; i++) { - long start = System.nanoTime(); - module.forward(); - double forwardMs = (System.nanoTime() - start) * 1e-6; - latency.add(forwardMs); - } + for (int i = 0; i < numIter; i++) { + long start = System.nanoTime(); + module.forward(); + double forwardMs = (System.nanoTime() - start) * 1e-6; + latency.add(forwardMs); + } - module.etdump(); + module.etdump(); - // Currently the result has large variance from outliers, so only use - // 80% samples in the middle (trimmean 0.2) - Collections.sort(latency); - int resultSize = latency.size(); - List usedLatencyResults = latency.subList(resultSize / 10, resultSize * 9 / 10); + final BenchmarkMetric.BenchmarkModel benchmarkModel = + BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", "")); + // The list of metrics we have atm includes: + // Avg inference latency after N iterations + // Currently the result has large variance from outliers, so only use + // 80% samples in the middle (trimmean 0.2) + Collections.sort(latency); + int resultSize = latency.size(); + List usedLatencyResults = latency.subList(resultSize / 10, resultSize * 9 / 10); - results.add( - new BenchmarkMetric( - benchmarkModel, - "avg_inference_latency(ms)", - latency.stream().mapToDouble(l -> l).average().orElse(0.0f), - 0.0f)); - results.add( - new BenchmarkMetric( - benchmarkModel, - "trimmean_inference_latency(ms)", - usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f), - 0.0f)); - // Model load time - results.add( - new BenchmarkMetric( - benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0f)); - // Load status - results.add(new BenchmarkMetric(benchmarkModel, "load_status", errorCode, 0)); - // RAM PSS usage - results.add( - new BenchmarkMetric( - benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0)); - } finally { - module.destroy(); - } + results.add( + new BenchmarkMetric( + benchmarkModel, + "avg_inference_latency(ms)", + latency.stream().mapToDouble(l -> l).average().orElse(0.0f), + 0.0f)); + results.add( + new BenchmarkMetric( + benchmarkModel, + "trimmean_inference_latency(ms)", + usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f), + 0.0f)); + // Model load time + results.add( + new BenchmarkMetric( + benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0f)); + // Load status + results.add(new BenchmarkMetric(benchmarkModel, "load_status", errorCode, 0)); + // RAM PSS usage + results.add( + new BenchmarkMetric( + benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0)); } } From 2d9bbc1ea3eb2a883a69c72997d675dc13d68a10 Mon Sep 17 00:00:00 2001 From: Hansong Zhang <107070759+kirklandsign@users.noreply.github.com> Date: Sat, 25 Apr 2026 20:02:12 -0700 Subject: [PATCH 13/21] Add top-k sampling support to llm Sampler (#19122) Differential Revision: D102385104 Pull Request resolved: https://github.com/pytorch/executorch/pull/19122 --- extension/llm/sampler/sampler.cpp | 55 +++++++++- extension/llm/sampler/sampler.h | 11 ++ extension/llm/sampler/test/test_sampler.cpp | 113 ++++++++++++++++++++ 3 files changed, 178 insertions(+), 1 deletion(-) diff --git a/extension/llm/sampler/sampler.cpp b/extension/llm/sampler/sampler.cpp index 3beda885d6f..d41da96f07e 100644 --- a/extension/llm/sampler/sampler.cpp +++ b/extension/llm/sampler/sampler.cpp @@ -69,6 +69,56 @@ int32_t Sampler::sample_mult(T* probabilities, float coin) { return vocab_size_ - 1; // in case of rounding errors } +template +int32_t Sampler::sample_topk(T* probabilities, float coin) { + // top-k sampling samples from the k highest-probability tokens. + // coin is a random number in [0, 1), usually from random_f32(). + // + // TODO: probindex is allocated on every call; lifting it to a member + // would avoid per-token heap allocation in autoregressive loops. + const int n = vocab_size_; + const int k = std::min(topk_, n); + // Defensive: callers gate on topk_ > 0, but a private helper should not + // rely on external invariants. Fall back to a deterministic index. + if (k <= 0) { + return 0; + } + + std::unique_ptr[]> probindex = + std::make_unique[]>(n); + for (int i = 0; i < n; i++) { + probindex[i].index = i; + probindex[i].prob = probabilities[i]; + } + + auto compare = [](const ProbIndex& a, const ProbIndex& b) { + return a.prob > b.prob; + }; + // Partial sort: only the top-k entries need to be sorted in descending order. + std::partial_sort( + probindex.get(), probindex.get() + k, probindex.get() + n, compare); + + // Sum of the top-k probabilities. Used to scale `coin` instead of + // explicitly renormalizing the k probs — mathematically equivalent and + // saves k divisions. Accumulate in float so FP16/BF16 inputs don't lose + // precision over k summands. + float topk_sum = 0.0f; + for (int i = 0; i < k; i++) { + topk_sum += static_cast(probindex[i].prob); + } + + // Sample from the (implicitly renormalized) top-k distribution. + const float r = coin * topk_sum; + float cdf = 0.0f; + for (int i = 0; i < k; i++) { + cdf += static_cast(probindex[i].prob); + if (r < cdf) { + return probindex[i].index; + } + } + return probindex[k - 1].index; // in case of rounding errors +} + template int32_t Sampler::sample_topp(T* probabilities, float coin) { // top-p sampling (or "nucleus sampling") samples from the smallest set of @@ -186,7 +236,10 @@ int32_t Sampler::sample(T* logits) { // flip a (float) coin (this is our source of entropy for sampling) float coin = random_f32(&rng_state_); // we sample from this distribution to get the next token - if (topp_ <= 0 || topp_ >= 1) { + if (topk_ > 0 && topk_ < vocab_size_) { + // top-k sampling, restrict to the k most likely tokens + next = sample_topk(logits, coin); + } else if (topp_ <= 0 || topp_ >= 1) { // simply sample from the predicted probability distribution next = sample_mult(logits, coin); } else { diff --git a/extension/llm/sampler/sampler.h b/extension/llm/sampler/sampler.h index 1525f38692a..4a480edc1ef 100644 --- a/extension/llm/sampler/sampler.h +++ b/extension/llm/sampler/sampler.h @@ -44,6 +44,13 @@ class ET_EXPERIMENTAL Sampler { Sampler(int32_t vocab_size, float temperature); + // Enable top-k filtering. k <= 0 or k >= vocab_size disables top-k. + // When top-k is enabled, top-p is ignored — the two modes are mutually + // exclusive in this implementation. + void set_topk(int32_t topk) { + topk_ = topk; + } + template int32_t sample(T* logits); @@ -51,6 +58,8 @@ class ET_EXPERIMENTAL Sampler { template int32_t sample_topp(T* probabilities, float coin); template + int32_t sample_topk(T* probabilities, float coin); + template int32_t sample_mult(T* probabilities, float coin); template int32_t sample_argmax(T* probabilities); @@ -60,6 +69,8 @@ class ET_EXPERIMENTAL Sampler { // reciprocal of temperature, or 0 if temperature == 0. float inv_temperature_; float topp_; + // 0 (or >= vocab_size_) means top-k is disabled. + int32_t topk_ = 0; unsigned long long rng_state_; }; diff --git a/extension/llm/sampler/test/test_sampler.cpp b/extension/llm/sampler/test/test_sampler.cpp index 044a39458ea..8463c2e9678 100644 --- a/extension/llm/sampler/test/test_sampler.cpp +++ b/extension/llm/sampler/test/test_sampler.cpp @@ -8,6 +8,8 @@ #include +#include + #include #include @@ -39,3 +41,114 @@ TEST(SamplerTest, TestArgMaxWithFP16) { input[0][0][396] = 1.0f; EXPECT_EQ(sampler.sample(input.data_ptr()), 396); } + +TEST(SamplerTest, TestTopKRestrictsToCandidates) { + // With topk=3, sampling must always return one of the top-3 indices, + // regardless of the random draw. + Sampler sampler{ + /*vocab_size*/ 100, + /*temperature*/ 1.0f, + /*topp*/ 0.0f, // disable top-p so we exercise top-k alone + /*rng_seed*/ 42}; + sampler.set_topk(3); + + // Construct logits where indices {7, 13, 42} dominate. + torch::Tensor input = torch::full({100}, -10.0f, at::kFloat); + input[7] = 5.0f; + input[13] = 4.5f; + input[42] = 4.0f; + + std::set allowed = {7, 13, 42}; + for (int trial = 0; trial < 50; ++trial) { + // Re-fill logits each trial because sample() mutates them in place. + torch::Tensor logits = input.clone(); + int32_t out = sampler.sample(logits.data_ptr()); + EXPECT_TRUE(allowed.count(out)) << "trial " << trial << " got " << out; + } +} + +TEST(SamplerTest, TestTopKDisabledByZero) { + // topk=0 means disabled. With topp disabled, sampling collapses to + // multinomial over the full vocab, but the dominant token should still + // win the vast majority of the time. + Sampler sampler{ + /*vocab_size*/ 50, + /*temperature*/ 1.0f, + /*topp*/ 0.0f, + /*rng_seed*/ 7}; + sampler.set_topk(0); // disabled + + torch::Tensor input = torch::full({50}, -10.0f, at::kFloat); + input[11] = 20.0f; // dominant + + int hits = 0; + for (int trial = 0; trial < 20; ++trial) { + torch::Tensor logits = input.clone(); + if (sampler.sample(logits.data_ptr()) == 11) { + hits++; + } + } + EXPECT_GE(hits, 18); // dominant token should win nearly every time +} + +TEST(SamplerTest, TestTopKWithFP16) { + // Smoke test the FP16 template instantiation of the top-k path. + Sampler sampler{ + /*vocab_size*/ 50, + /*temperature*/ 1.0f, + /*topp*/ 0.0f, + /*rng_seed*/ 99}; + sampler.set_topk(2); + + torch::Tensor input = torch::full({50}, -10.0f, at::kHalf); + input[3] = 5.0f; + input[8] = 4.5f; + + std::set allowed = {3, 8}; + for (int trial = 0; trial < 30; ++trial) { + torch::Tensor logits = input.clone(); + int32_t out = sampler.sample(logits.data_ptr()); + EXPECT_TRUE(allowed.count(out)) << "trial " << trial << " got " << out; + } +} + +TEST(SamplerTest, TestTopKEqualsOneIsArgmax) { + // topk=1 should behave like greedy argmax even with temperature > 0. + Sampler sampler{ + /*vocab_size*/ 100, + /*temperature*/ 1.0f, + /*topp*/ 0.0f, + /*rng_seed*/ 123}; + sampler.set_topk(1); + + torch::Tensor input = torch::rand({100}, at::kFloat); + input[57] = 100.0f; // make 57 the unambiguous max + + for (int trial = 0; trial < 10; ++trial) { + torch::Tensor logits = input.clone(); + EXPECT_EQ(sampler.sample(logits.data_ptr()), 57); + } +} + +TEST(SamplerTest, TestTopKTakesPrecedenceOverTopP) { + // When both top-k and top-p are set, top-k should restrict the candidate + // set; top-p alone would admit a third token that top-k=2 must exclude. + Sampler sampler{ + /*vocab_size*/ 100, + /*temperature*/ 1.0f, + /*topp*/ 0.99f, // would keep nearly the whole vocab on its own + /*rng_seed*/ 99}; + sampler.set_topk(2); + + torch::Tensor input = torch::full({100}, -10.0f, at::kFloat); + input[3] = 5.0f; + input[8] = 4.5f; + input[19] = 4.0f; // would be in the top-p set but is excluded by top-k=2 + + std::set allowed = {3, 8}; + for (int trial = 0; trial < 50; ++trial) { + torch::Tensor logits = input.clone(); + int32_t out = sampler.sample(logits.data_ptr()); + EXPECT_TRUE(allowed.count(out)) << "trial " << trial << " got " << out; + } +} From 563be2f899cf6e3ee7e9aa726cac3ce10b76f66c Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Sat, 25 Apr 2026 22:39:01 -0700 Subject: [PATCH 14/21] Re-apply D101260086: Android unified error reporting Differential Revision: D102493794 Pull Request resolved: https://github.com/pytorch/executorch/pull/19136 --- .../LlmModuleInstrumentationTest.kt | 5 +- .../executorch/ModuleInstrumentationTest.kt | 29 ++-- .../ExecutorchRuntimeException.java | 54 ++++++- .../java/org/pytorch/executorch/Module.java | 79 +++++++--- .../executorch/extension/llm/LlmCallback.java | 8 +- .../executorch/extension/llm/LlmModule.java | 147 +++++++++++------- extension/android/jni/jni_layer.cpp | 10 +- extension/android/jni/jni_layer_llama.cpp | 39 ++--- .../org/pytorch/minibench/LlmModelRunner.java | 15 +- .../org/pytorch/minibench/ModelRunner.java | 10 +- 10 files changed, 276 insertions(+), 120 deletions(-) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt index 4b6c3caed94..d5738773577 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt @@ -54,9 +54,7 @@ class LlmModuleInstrumentationTest : LlmCallback { @Test @Throws(IOException::class, URISyntaxException::class) fun testGenerate() { - val loadResult = llmModule.load() - // Check that the model can be load successfully - assertEquals(OK.toLong(), loadResult.toLong()) + llmModule.load() llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest) assertEquals(results.size.toLong(), SEQ_LEN.toLong()) @@ -277,7 +275,6 @@ class LlmModuleInstrumentationTest : LlmCallback { private const val TEST_FILE_NAME = "/stories.pte" private const val TOKENIZER_FILE_NAME = "/tokenizer.bin" private const val TEST_PROMPT = "Hello" - private const val OK = 0x00 private const val SEQ_LEN = 32 } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index 99d53b6dba3..ba91f444287 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -66,8 +66,7 @@ class ModuleInstrumentationTest { fun testModuleLoadMethodAndForward() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), OK.toLong()) + module.loadMethod(FORWARD_METHOD) val results = module.forward() Assert.assertTrue(results[0].isTensor) @@ -96,8 +95,14 @@ class ModuleInstrumentationTest { fun testModuleLoadMethodNonExistantMethod() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - val loadMethod = module.loadMethod(NONE_METHOD) - Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) + val exception = + Assert.assertThrows(ExecutorchRuntimeException::class.java) { + module.loadMethod(NONE_METHOD) + } + Assert.assertEquals( + ExecutorchRuntimeException.INVALID_ARGUMENT, + exception.getErrorCode(), + ) } @Test(expected = RuntimeException::class) @@ -105,8 +110,7 @@ class ModuleInstrumentationTest { fun testNonPteFile() { val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME)) - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) + module.loadMethod(FORWARD_METHOD) } @Test @@ -116,8 +120,7 @@ class ModuleInstrumentationTest { module.destroy() - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), INVALID_STATE.toLong()) + Assert.assertThrows(IllegalStateException::class.java) { module.loadMethod(FORWARD_METHOD) } } @Test @@ -125,13 +128,11 @@ class ModuleInstrumentationTest { fun testForwardOnDestroyedModule() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), OK.toLong()) + module.loadMethod(FORWARD_METHOD) module.destroy() - val results = module.forward() - Assert.assertEquals(0, results.size.toLong()) + Assert.assertThrows(IllegalStateException::class.java) { module.forward() } } @Ignore( @@ -175,9 +176,5 @@ class ModuleInstrumentationTest { private const val NON_PTE_FILE_NAME = "/test.txt" private const val FORWARD_METHOD = "forward" private const val NONE_METHOD = "none" - private const val OK = 0x00 - private const val INVALID_STATE = 0x2 - private const val INVALID_ARGUMENT = 0x12 - private const val ACCESS_FAILED = 0x22 } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java index 102b96ab686..e0fda73cc06 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java @@ -12,34 +12,83 @@ import java.util.HashMap; import java.util.Map; +/** + * Base exception for all ExecuTorch runtime errors. Each instance carries an integer error code + * corresponding to the native {@code runtime/core/error.h} values, accessible via {@link + * #getErrorCode()}. + */ public class ExecutorchRuntimeException extends RuntimeException { // Error code constants - keep in sync with runtime/core/error.h + // System errors + + /** Operation completed successfully. */ public static final int OK = 0x00; + + /** An unexpected internal error occurred in the runtime. */ public static final int INTERNAL = 0x01; + + /** The runtime or method is in an invalid state for the requested operation. */ public static final int INVALID_STATE = 0x02; + + /** The method has finished execution and has no more work to do. */ public static final int END_OF_METHOD = 0x03; + /** A required resource has already been loaded. */ + public static final int ALREADY_LOADED = 0x04; + // Logical errors + + /** The requested operation is not supported by this build or backend. */ public static final int NOT_SUPPORTED = 0x10; + + /** The requested operation has not been implemented. */ public static final int NOT_IMPLEMENTED = 0x11; + + /** One or more arguments passed to the operation are invalid. */ public static final int INVALID_ARGUMENT = 0x12; + + /** A value or tensor has an unexpected type. */ public static final int INVALID_TYPE = 0x13; + + /** A required operator kernel is not registered. */ public static final int OPERATOR_MISSING = 0x14; + + /** The maximum number of registered kernels has been exceeded. */ public static final int REGISTRATION_EXCEEDING_MAX_KERNELS = 0x15; + + /** A kernel with the same name is already registered. */ public static final int REGISTRATION_ALREADY_REGISTERED = 0x16; // Resource errors + + /** A required resource (file, tensor, program) was not found. */ public static final int NOT_FOUND = 0x20; + + /** A memory allocation failed. */ public static final int MEMORY_ALLOCATION_FAILED = 0x21; + + /** Access to a resource was denied or failed. */ public static final int ACCESS_FAILED = 0x22; + + /** The loaded program is malformed or incompatible. */ public static final int INVALID_PROGRAM = 0x23; + + /** External data referenced by the program is invalid or missing. */ public static final int INVALID_EXTERNAL_DATA = 0x24; + + /** The system has run out of a required resource. */ public static final int OUT_OF_RESOURCES = 0x25; // Delegate errors + + /** A delegate reported an incompatible model or configuration. */ public static final int DELEGATE_INVALID_COMPATIBILITY = 0x30; + + /** A delegate failed to allocate required memory. */ public static final int DELEGATE_MEMORY_ALLOCATION_FAILED = 0x31; + + /** A delegate received an invalid or stale handle. */ public static final int DELEGATE_INVALID_HANDLE = 0x32; private static final Map ERROR_CODE_MESSAGES; @@ -52,6 +101,7 @@ public class ExecutorchRuntimeException extends RuntimeException { map.put(INTERNAL, "Internal error"); map.put(INVALID_STATE, "Invalid state"); map.put(END_OF_METHOD, "End of method reached"); + map.put(ALREADY_LOADED, "Already loaded"); // Logical errors map.put(NOT_SUPPORTED, "Operation not supported"); map.put(NOT_IMPLEMENTED, "Operation not implemented"); @@ -83,7 +133,7 @@ static String formatMessage(int errorCode, String details) { String safeDetails = details != null ? details : "No details provided"; return String.format( - "[Executorch Error 0x%s] %s: %s", + "[ExecuTorch Error 0x%s] %s: %s", Integer.toHexString(errorCode), baseMessage, safeDetails); } @@ -111,10 +161,12 @@ public ExecutorchRuntimeException(int errorCode, String details) { this.errorCode = errorCode; } + /** Returns the numeric error code from {@code runtime/core/error.h}. */ public int getErrorCode() { return errorCode; } + /** Returns detailed log output captured from the native runtime, if available. */ public String getDetailedError() { return ErrorHelper.getDetailedErrorLogs(); } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index f7e2e37dcec..05e1e5b88cf 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -8,7 +8,6 @@ package org.pytorch.executorch; -import android.util.Log; import com.facebook.jni.HybridData; import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; @@ -130,11 +129,10 @@ public EValue[] forward(EValue... inputs) { * @return return value from the method. */ public EValue[] execute(String methodName, EValue... inputs) { + mLock.lock(); try { - mLock.lock(); if (!mHybridData.isValid()) { - Log.e("ExecuTorch", "Attempt to use a destroyed module"); - return new EValue[0]; + throw new IllegalStateException("Module has been destroyed"); } return executeNative(methodName, inputs); } finally { @@ -151,17 +149,17 @@ public EValue[] execute(String methodName, EValue... inputs) { * synchronous, and will block until the method is loaded. Therefore, it is recommended to call * this on a background thread. However, users need to make sure that they don't execute before * this function returns. - * - * @return the Error code if there was an error loading the method */ - public int loadMethod(String methodName) { + public void loadMethod(String methodName) { + mLock.lock(); try { - mLock.lock(); if (!mHybridData.isValid()) { - Log.e("ExecuTorch", "Attempt to use a destroyed module"); - return 0x2; // InvalidState + throw new IllegalStateException("Module has been destroyed"); + } + int errorCode = loadMethodNative(methodName); + if (errorCode != 0) { + throw new ExecutorchRuntimeException(errorCode, "Failed to load method: " + methodName); } - return loadMethodNative(methodName); } finally { mLock.unlock(); } @@ -184,8 +182,20 @@ public int loadMethod(String methodName) { * * @return name of methods in this Module */ + public String[] getMethods() { + mLock.lock(); + try { + if (!mHybridData.isValid()) { + throw new IllegalStateException("Module has been destroyed"); + } + return getMethodsNative(); + } finally { + mLock.unlock(); + } + } + @DoNotStrip - public native String[] getMethods(); + private native String[] getMethodsNative(); /** * Get the corresponding @MethodMetadata for a method @@ -194,11 +204,19 @@ public int loadMethod(String methodName) { * @return @MethodMetadata for this method */ public MethodMetadata getMethodMetadata(String name) { - MethodMetadata methodMetadata = mMethodMetadata.get(name); - if (methodMetadata == null) { - throw new IllegalArgumentException("method " + name + " does not exist for this module"); + mLock.lock(); + try { + if (!mHybridData.isValid()) { + throw new IllegalStateException("Module has been destroyed"); + } + MethodMetadata methodMetadata = mMethodMetadata.get(name); + if (methodMetadata == null) { + throw new IllegalArgumentException("method " + name + " does not exist for this module"); + } + return methodMetadata; + } finally { + mLock.unlock(); } - return methodMetadata; } @DoNotStrip @@ -210,7 +228,15 @@ public static String[] readLogBufferStatic() { /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ public String[] readLogBuffer() { - return readLogBufferNative(); + mLock.lock(); + try { + if (!mHybridData.isValid()) { + throw new IllegalStateException("Module has been destroyed"); + } + return readLogBufferNative(); + } finally { + mLock.unlock(); + } } @DoNotStrip @@ -224,8 +250,20 @@ public String[] readLogBuffer() { * @return true if the etdump was successfully written, false otherwise. */ @Experimental + public boolean etdump() { + mLock.lock(); + try { + if (!mHybridData.isValid()) { + throw new IllegalStateException("Module has been destroyed"); + } + return etdumpNative(); + } finally { + mLock.unlock(); + } + } + @DoNotStrip - public native boolean etdump(); + private native boolean etdumpNative(); /** * Explicitly destroys the native Module object. Calling this method is not required, as the @@ -241,10 +279,7 @@ public void destroy() { mLock.unlock(); } } else { - Log.w( - "ExecuTorch", - "Destroy was called while the module was in use. Resources will not be immediately" - + " released."); + throw new IllegalStateException("Cannot destroy module while method is executing"); } } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java index 4e834d06721..ec0413caf2e 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java @@ -46,5 +46,11 @@ default void onStats(String stats) {} * @param message Human-readable error description */ @DoNotStrip - default void onError(int errorCode, String message) {} + default void onError(int errorCode, String message) { + try { + android.util.Log.e("ExecuTorch", "LLM error " + errorCode + ": " + message); + } catch (Throwable t) { + System.err.println("ExecuTorch LLM error " + errorCode + ": " + message); + } + } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java index a563dc6bcc7..8ea43b4f38d 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -13,6 +13,7 @@ import java.nio.ByteBuffer; import java.util.List; import org.pytorch.executorch.ExecuTorchRuntime; +import org.pytorch.executorch.ExecutorchRuntimeException; import org.pytorch.executorch.annotations.Experimental; /** @@ -29,6 +30,7 @@ public class LlmModule { public static final int MODEL_TYPE_MULTIMODAL = 2; private final HybridData mHybridData; + private volatile boolean mDestroyed = false; private static final int DEFAULT_SEQ_LEN = 128; private static final boolean DEFAULT_ECHO = true; private static final float DEFAULT_TEMPERATURE = -1.0f; @@ -185,7 +187,14 @@ public LlmModule(LlmModuleConfig config) { config.getLoadMode()); } + private void checkNotDestroyed() { + if (mDestroyed) throw new IllegalStateException("LlmModule has been destroyed"); + } + + @Deprecated public void resetNative() { + if (mDestroyed) return; + mDestroyed = true; mHybridData.resetNative(); } @@ -195,8 +204,9 @@ public void resetNative() { * @param prompt Input prompt * @param llmCallback callback object to receive results. */ - public int generate(String prompt, LlmCallback llmCallback) { - return generate( + public void generate(String prompt, LlmCallback llmCallback) { + checkNotDestroyed(); + generate( prompt, DEFAULT_SEQ_LEN, llmCallback, @@ -213,8 +223,9 @@ public int generate(String prompt, LlmCallback llmCallback) { * @param seqLen sequence length * @param llmCallback callback object to receive results. */ - public int generate(String prompt, int seqLen, LlmCallback llmCallback) { - return generate( + public void generate(String prompt, int seqLen, LlmCallback llmCallback) { + checkNotDestroyed(); + generate( null, 0, 0, @@ -235,8 +246,9 @@ public int generate(String prompt, int seqLen, LlmCallback llmCallback) { * @param llmCallback callback object to receive results * @param echo indicate whether to echo the input prompt or not (text completion vs chat) */ - public int generate(String prompt, LlmCallback llmCallback, boolean echo) { - return generate( + public void generate(String prompt, LlmCallback llmCallback, boolean echo) { + checkNotDestroyed(); + generate( null, 0, 0, @@ -258,9 +270,9 @@ public int generate(String prompt, LlmCallback llmCallback, boolean echo) { * @param llmCallback callback object to receive results * @param echo indicate whether to echo the input prompt or not (text completion vs chat) */ - public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) { - return generate( - prompt, seqLen, llmCallback, echo, DEFAULT_TEMPERATURE, DEFAULT_BOS, DEFAULT_EOS); + public void generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) { + checkNotDestroyed(); + generate(prompt, seqLen, llmCallback, echo, DEFAULT_TEMPERATURE, DEFAULT_BOS, DEFAULT_EOS); } /** @@ -274,7 +286,23 @@ public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean * @param numBos number of BOS tokens to prepend * @param numEos number of EOS tokens to append */ - public native int generate( + public void generate( + String prompt, + int seqLen, + LlmCallback llmCallback, + boolean echo, + float temperature, + int numBos, + int numEos) { + checkNotDestroyed(); + int err = generateNative(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); + if (err != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to generate"); + } + } + + @DoNotStrip + private native int generateNative( String prompt, int seqLen, LlmCallback llmCallback, @@ -290,13 +318,14 @@ public native int generate( * @param config the config for generation * @param llmCallback callback object to receive results */ - public int generate(String prompt, LlmGenerationConfig config, LlmCallback llmCallback) { + public void generate(String prompt, LlmGenerationConfig config, LlmCallback llmCallback) { + checkNotDestroyed(); int seqLen = config.getSeqLen(); boolean echo = config.isEcho(); float temperature = config.getTemperature(); int numBos = config.getNumBos(); int numEos = config.getNumEos(); - return generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); + generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); } /** @@ -311,7 +340,7 @@ public int generate(String prompt, LlmGenerationConfig config, LlmCallback llmCa * @param llmCallback callback object to receive results. * @param echo indicate whether to echo the input prompt or not (text completion vs chat) */ - public int generate( + public void generate( int[] image, int width, int height, @@ -320,7 +349,8 @@ public int generate( int seqLen, LlmCallback llmCallback, boolean echo) { - return generate( + checkNotDestroyed(); + generate( image, width, height, @@ -347,7 +377,7 @@ public int generate( * @param echo indicate whether to echo the input prompt or not (text completion vs chat) * @param temperature temperature for sampling (use negative value to use module default) */ - public int generate( + public void generate( int[] image, int width, int height, @@ -357,7 +387,8 @@ public int generate( LlmCallback llmCallback, boolean echo, float temperature) { - return generate( + checkNotDestroyed(); + generate( image, width, height, @@ -386,7 +417,7 @@ public int generate( * @param numBos number of BOS tokens to prepend * @param numEos number of EOS tokens to append */ - public int generate( + public void generate( int[] image, int width, int height, @@ -398,10 +429,11 @@ public int generate( float temperature, int numBos, int numEos) { + checkNotDestroyed(); if (image != null) { prefillImages(image, width, height, channels); } - return generate(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); + generate(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); } /** @@ -411,16 +443,15 @@ public int generate( * @param width Input image width * @param height Input image height * @param channels Input image number of channels - * @return 0 on success - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental - public long prefillImages(int[] image, int width, int height, int channels) { + public void prefillImages(int[] image, int width, int height, int channels) { + checkNotDestroyed(); int nativeResult = prefillImagesInput(image, width, height, channels); if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); } - return 0; } /** @@ -440,6 +471,7 @@ public long prefillImages(int[] image, int width, int height, int channels) { */ @Experimental public void prefillImages(ByteBuffer image, int width, int height, int channels) { + checkNotDestroyed(); if (!image.isDirect()) { throw new IllegalArgumentException("Input ByteBuffer must be direct."); } @@ -467,7 +499,7 @@ public void prefillImages(ByteBuffer image, int width, int height, int channels) // starting at the current position, not the base address. int nativeResult = prefillImagesInputBuffer(image.slice(), width, height, channels); if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); } } @@ -491,6 +523,7 @@ public void prefillImages(ByteBuffer image, int width, int height, int channels) */ @Experimental public void prefillNormalizedImage(ByteBuffer image, int width, int height, int channels) { + checkNotDestroyed(); if (!image.isDirect()) { throw new IllegalArgumentException("Input ByteBuffer must be direct."); } @@ -533,7 +566,7 @@ public void prefillNormalizedImage(ByteBuffer image, int width, int height, int // starting at the current position, not the base address. int nativeResult = prefillNormalizedImagesInputBuffer(image.slice(), width, height, channels); if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); } } @@ -552,16 +585,15 @@ private native int prefillNormalizedImagesInputBuffer( * @param width Input image width * @param height Input image height * @param channels Input image number of channels - * @return 0 on success - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental - public long prefillImages(float[] image, int width, int height, int channels) { + public void prefillImages(float[] image, int width, int height, int channels) { + checkNotDestroyed(); int nativeResult = prefillNormalizedImagesInput(image, width, height, channels); if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); } - return 0; } private native int prefillNormalizedImagesInput( @@ -574,16 +606,15 @@ private native int prefillNormalizedImagesInput( * @param batch_size Input batch size * @param n_bins Input number of bins * @param n_frames Input number of frames - * @return 0 on success - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental - public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) { + public void prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) { + checkNotDestroyed(); int nativeResult = prefillAudioInput(audio, batch_size, n_bins, n_frames); if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); } - return 0; } private native int prefillAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames); @@ -595,16 +626,15 @@ public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) * @param batch_size Input batch size * @param n_bins Input number of bins * @param n_frames Input number of frames - * @return 0 on success - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental - public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) { + public void prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) { + checkNotDestroyed(); int nativeResult = prefillAudioInputFloat(audio, batch_size, n_bins, n_frames); if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); } - return 0; } private native int prefillAudioInputFloat( @@ -617,16 +647,15 @@ private native int prefillAudioInputFloat( * @param batch_size Input batch size * @param n_channels Input number of channels * @param n_samples Input number of samples - * @return 0 on success - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental - public long prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) { + public void prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) { + checkNotDestroyed(); int nativeResult = prefillRawAudioInput(audio, batch_size, n_channels, n_samples); if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); } - return 0; } private native int prefillRawAudioInput( @@ -636,16 +665,15 @@ private native int prefillRawAudioInput( * Prefill the KV cache with the given text prompt. * * @param prompt The text prompt to prefill. - * @return 0 on success - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental - public long prefillPrompt(String prompt) { + public void prefillPrompt(String prompt) { + checkNotDestroyed(); int nativeResult = prefillTextInput(prompt); if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); } - return 0; } // returns status @@ -656,7 +684,13 @@ public long prefillPrompt(String prompt) { * *

The startPos will be reset to 0. */ - public native void resetContext(); + public void resetContext() { + checkNotDestroyed(); + resetContextNative(); + } + + @DoNotStrip + private native void resetContextNative(); /** Stop current generate() before it finishes. */ @DoNotStrip @@ -664,5 +698,14 @@ public long prefillPrompt(String prompt) { /** Force loading the module. Otherwise the model is loaded during first generate(). */ @DoNotStrip - public native int load(); + public void load() { + checkNotDestroyed(); + int err = loadNative(); + if (err != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to load model"); + } + } + + @DoNotStrip + private native int loadNative(); } diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index beff72119b8..88e9f9e2a12 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -385,6 +385,12 @@ class ExecuTorchJni : public facebook::jni::HybridClass { static const auto toBoolMethod = JEValue::javaClassStatic()->getMethod("toBool"); evalues.emplace_back(static_cast(toBoolMethod(jevalue))); + } else { + std::stringstream ss; + ss << "Unsupported input EValue type code: " << typeCode; + jni_helper::throwExecutorchException( + static_cast(Error::InvalidArgument), ss.str()); + return {}; } } @@ -564,8 +570,8 @@ class ExecuTorchJni : public facebook::jni::HybridClass { makeNativeMethod("readLogBufferNative", ExecuTorchJni::readLogBuffer), makeNativeMethod( "readLogBufferStaticNative", ExecuTorchJni::readLogBufferStatic), - makeNativeMethod("etdump", ExecuTorchJni::etdump), - makeNativeMethod("getMethods", ExecuTorchJni::getMethods), + makeNativeMethod("etdumpNative", ExecuTorchJni::etdump), + makeNativeMethod("getMethodsNative", ExecuTorchJni::getMethods), makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends), }); } diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index ed144acb14b..2c0117dc576 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -594,30 +594,31 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { jint load() { if (!runner_) { - std::stringstream ss; - ss << "Invalid model type category: " << model_type_category_ - << ". Valid values are: " << MODEL_TYPE_CATEGORY_LLM << " or " - << MODEL_TYPE_CATEGORY_MULTIMODAL; - executorch::jni_helper::throwExecutorchException( - static_cast(Error::InvalidArgument), ss.str().c_str()); - return -1; - } - int result = static_cast(runner_->load()); - if (result != 0) { - std::stringstream ss; - ss << "Failed to load runner: [" << result << "]"; - executorch::jni_helper::throwExecutorchException( - static_cast(result), ss.str().c_str()); - } - return result; + ET_LOG( + Error, + "ExecuTorchLlmJni::load() called but runner_ is null. " + "The model runner was not created or failed to initialize due to a " + "previous configuration or initialization error. " + "Model type category: %d.", + model_type_category_); + return static_cast(Error::InvalidState); + } + const auto load_result = static_cast(runner_->load()); + if (load_result != static_cast(Error::Ok)) { + ET_LOG( + Error, + "ExecuTorchLlmJni::load() failed in runner_->load() with error code %d.", + static_cast(load_result)); + } + return load_result; } static void registerNatives() { registerHybrid({ makeNativeMethod("initHybrid", ExecuTorchLlmJni::initHybrid), - makeNativeMethod("generate", ExecuTorchLlmJni::generate), + makeNativeMethod("generateNative", ExecuTorchLlmJni::generate), makeNativeMethod("stop", ExecuTorchLlmJni::stop), - makeNativeMethod("load", ExecuTorchLlmJni::load), + makeNativeMethod("loadNative", ExecuTorchLlmJni::load), makeNativeMethod( "prefillImagesInput", ExecuTorchLlmJni::prefill_images_input), makeNativeMethod( @@ -638,7 +639,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { "prefillRawAudioInput", ExecuTorchLlmJni::prefill_raw_audio_input), makeNativeMethod( "prefillTextInput", ExecuTorchLlmJni::prefill_text_input), - makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context), + makeNativeMethod("resetContextNative", ExecuTorchLlmJni::reset_context), }); } }; diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java index a1b434a37bf..2c8770ca33e 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java @@ -87,10 +87,21 @@ public LlmModelRunnerHandler(Looper looper, LlmModelRunner llmModelRunner) { @Override public void handleMessage(android.os.Message msg) { if (msg.what == MESSAGE_LOAD_MODEL) { - int status = mLlmModelRunner.mModule.load(); + int status = 0; + try { + mLlmModelRunner.mModule.load(); + } catch (org.pytorch.executorch.ExecutorchRuntimeException e) { + status = e.getErrorCode(); + } catch (Exception e) { + status = -1; + } mLlmModelRunner.mCallback.onModelLoaded(status); } else if (msg.what == MESSAGE_GENERATE) { - mLlmModelRunner.mModule.generate((String) msg.obj, mLlmModelRunner); + try { + mLlmModelRunner.mModule.generate((String) msg.obj, mLlmModelRunner); + } catch (Exception e) { + android.util.Log.e("LlmModelRunner", "generate() failed", e); + } mLlmModelRunner.mCallback.onGenerationStopped(); } } diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java index 28f4e3728f0..b2fdeed9bab 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java @@ -27,7 +27,15 @@ public void runBenchmark( long loadStart = System.nanoTime(); Module module = Module.load(model.getPath()); - int errorCode = module.loadMethod("forward"); + int errorCode = 0; + try { + module.loadMethod("forward"); + } catch (Exception e) { + errorCode = + (e instanceof org.pytorch.executorch.ExecutorchRuntimeException) + ? ((org.pytorch.executorch.ExecutorchRuntimeException) e).getErrorCode() + : -1; + } long loadEnd = System.nanoTime(); for (int i = 0; i < numWarmupIter; i++) { From bf64fa159fcca4bcd4cf24e012e578dafdf1d048 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Sat, 25 Apr 2026 23:32:20 -0700 Subject: [PATCH 15/21] Back out "Re-apply D101260086: Android unified error reporting" (#19137) Differential Revision: D102493838 Pull Request resolved: https://github.com/pytorch/executorch/pull/19137 From dd5f6b15a0fa74deac1fec63c5097318b8e5575a Mon Sep 17 00:00:00 2001 From: Winston Kuo Date: Tue, 21 Apr 2026 17:30:58 +0800 Subject: [PATCH 16/21] delete aihub readme under executorch/examples/qualcomm/README.md --- examples/qualcomm/README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/qualcomm/README.md b/examples/qualcomm/README.md index 34a0b4e8fe8..eca0442c8fb 100644 --- a/examples/qualcomm/README.md +++ b/examples/qualcomm/README.md @@ -9,8 +9,7 @@ We have separated the example scripts into the following subfolders, please refe 2. oss_scripts: OSS stands for Open Source Software. This folder contains python scripts for open source models. Some models under this folder might also have their own customized runner. For example, [llama](oss_scripts/llama/qnn_llama_runner.cpp) contains not only the python scripts to prepare the model but also a customized runner for executing the model. -3. qaihub_scripts: QAIHub stands for [Qualcomm AI Hub](https://aihub.qualcomm.com/). On QAIHub, users can find pre-compiled context binaries, a format used by QNN to save its models. This provides users with a new option for model deployment. Different from oss_scripts & scripts, which the example scripts are converting a model from nn.Module to ExecuTorch .pte files, qaihub_scripts provides example scripts for converting pre-compiled context binaries to ExecuTorch .pte files. Additionally, users can find customized example runners specific to the QAIHub models for execution. For example [qaihub_llama2_7b](qaihub_scripts/llama/llama2/qaihub_llama2_7b.py) is a script converting context binaries to ExecuTorch .pte files, and [qaihub_llama2_7b_runner](qaihub_scripts/llama/llama2/qaihub_llama2_7b_runner.cpp) is a customized example runner to execute llama2 .pte files. Please be aware that context-binaries downloaded from QAIHub are tied to a specific QNN SDK version. -Before executing the scripts and runner, please ensure that you are using the QNN SDK version that is matching the context binary. Please refer to [Check context binary version](#check-context-binary-version) for tutorial on how to check the QNN Version for a context binary. + 4. scripts: This folder contains scripts to build models provided by ExecuTorch. From bafc40c12df15caf6b333eb89eeabe70280afea3 Mon Sep 17 00:00:00 2001 From: Winston Kuo Date: Tue, 21 Apr 2026 17:31:22 +0800 Subject: [PATCH 17/21] Delete more aihub readme under executorch/examples/qualcomm/README.md --- examples/qualcomm/README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/qualcomm/README.md b/examples/qualcomm/README.md index eca0442c8fb..a40a5bdfe85 100644 --- a/examples/qualcomm/README.md +++ b/examples/qualcomm/README.md @@ -9,8 +9,6 @@ We have separated the example scripts into the following subfolders, please refe 2. oss_scripts: OSS stands for Open Source Software. This folder contains python scripts for open source models. Some models under this folder might also have their own customized runner. For example, [llama](oss_scripts/llama/qnn_llama_runner.cpp) contains not only the python scripts to prepare the model but also a customized runner for executing the model. - - 4. scripts: This folder contains scripts to build models provided by ExecuTorch. 5. util_scripts: This folder includes tutorial example scripts designed to showcase the utilities we've developed. For example, we provide a debugging tool [qnn_intermediate_debugger](./util_scripts/qnn_intermediate_debugger_demo.py) that allow users to compare the intermediate outputs of QNNs V.S. CPUs. By reviewing these scripts, we aim to help users smoothly integrate these utilities into their own projects. From 69d56bcba9598f25609745b8863541de0f35b923 Mon Sep 17 00:00:00 2001 From: Winston Kuo Date: Wed, 22 Apr 2026 10:09:21 +0800 Subject: [PATCH 18/21] Update README to fix minor error --- examples/qualcomm/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/qualcomm/README.md b/examples/qualcomm/README.md index a40a5bdfe85..95c33dc75b8 100644 --- a/examples/qualcomm/README.md +++ b/examples/qualcomm/README.md @@ -9,9 +9,9 @@ We have separated the example scripts into the following subfolders, please refe 2. oss_scripts: OSS stands for Open Source Software. This folder contains python scripts for open source models. Some models under this folder might also have their own customized runner. For example, [llama](oss_scripts/llama/qnn_llama_runner.cpp) contains not only the python scripts to prepare the model but also a customized runner for executing the model. -4. scripts: This folder contains scripts to build models provided by ExecuTorch. +3. scripts: This folder contains scripts to build models provided by ExecuTorch. -5. util_scripts: This folder includes tutorial example scripts designed to showcase the utilities we've developed. For example, we provide a debugging tool [qnn_intermediate_debugger](./util_scripts/qnn_intermediate_debugger_demo.py) that allow users to compare the intermediate outputs of QNNs V.S. CPUs. By reviewing these scripts, we aim to help users smoothly integrate these utilities into their own projects. +4. util_scripts: This folder includes tutorial example scripts designed to showcase the utilities we've developed. For example, we provide a debugging tool [qnn_intermediate_debugger](./util_scripts/qnn_intermediate_debugger_demo.py) that allow users to compare the intermediate outputs of QNNs V.S. CPUs. By reviewing these scripts, we aim to help users smoothly integrate these utilities into their own projects. From dea073183584b68e26e6e7d78f2880e21af7a33c Mon Sep 17 00:00:00 2001 From: Winston Kuo Date: Wed, 22 Apr 2026 10:12:08 +0800 Subject: [PATCH 19/21] More fix on readme removal --- examples/qualcomm/README.md | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/examples/qualcomm/README.md b/examples/qualcomm/README.md index 95c33dc75b8..c55f38dab95 100644 --- a/examples/qualcomm/README.md +++ b/examples/qualcomm/README.md @@ -71,17 +71,6 @@ python mobilenet_v2.py -s -m "SM8550" -b path/to/build-android/ python deeplab_v3.py -s -m "SM8550" -b path/to/build-android/ --download ``` -#### Check context binary version -This is typically useful when users want to run any models under `qaihub_scripts`. When users retrieve context binaries from Qualcomm AI Hub, we need to ensure the QNN SDK used to run the `qaihub_scripts` is the same version as the QNN SDK that Qualcomm AI Hub used to compile the context binaries. To do so, please run the following script to retrieve the JSON file that contains the metadata about the context binary: -```bash -cd ${QNN_SDK_ROOT}/bin/x86_64-linux-clang -./qnn-context-binary-utility --context_binary ${PATH_TO_CONTEXT_BINARY} --json_file ${OUTPUT_JSON_NAME} -``` -After retrieving the json file, search in the json file for the field "buildId" and ensure it matches the `${QNN_SDK_ROOT}` you are using for the environment variable. -If you run into the following error, that means the ${QNN_SDK_ROOT} that you are using is older than the context binary's QNN SDK version. In this case, please download a newer QNN SDK version. -``` -Error: Failed to get context binary info. -``` ## Model Structure This section outlines the essential APIs and utilities provided to streamline the process of model conversion, deployment, and evaluation on Qualcomm hardware using ExecuTorch. The official APIs can be found under [export_utils.py](../../backends/qualcomm/export_utils.py) From 3d4c00d3d031dce4d17707d21d1145c319800ae0 Mon Sep 17 00:00:00 2001 From: Winston Kuo Date: Wed, 22 Apr 2026 10:35:46 +0800 Subject: [PATCH 20/21] Remove SKILL.md info related to aihub and aihub info under backends/qualcomm/README.md --- .claude/skills/qualcomm/SKILL.md | 1 - backends/qualcomm/README.md | 1 - 2 files changed, 2 deletions(-) diff --git a/.claude/skills/qualcomm/SKILL.md b/.claude/skills/qualcomm/SKILL.md index bb0c5f017e7..bcbd581e293 100644 --- a/.claude/skills/qualcomm/SKILL.md +++ b/.claude/skills/qualcomm/SKILL.md @@ -93,6 +93,5 @@ Required flags: `-m` (SoC model), `-b` (Android build dir). Optional: `-s` (devi | `TestExampleLLMScript` | LLM script tests | | `TestExampleMultimodalityScript` | Multimodality script tests | | `TestExampleOssScript` | OSS model script tests | -| `TestExampleQaihubScript` | QAI Hub script tests | | `TestExampleScript` | General example script tests | | `TestUtilsScript` | Utility script tests | diff --git a/backends/qualcomm/README.md b/backends/qualcomm/README.md index 68375d1287b..6b51169af2f 100644 --- a/backends/qualcomm/README.md +++ b/backends/qualcomm/README.md @@ -61,7 +61,6 @@ backends/qualcomm examples/qualcomm ├── executor_runner # A general runner that is capable of running most of the basic models. ├── oss_scripts # Scripts for OSS(Open Source Software) models and customized runner for some specific models. -├── qaihub_scripts # Scripts for Qaihub models and corresponding customized runner for these models. └── scripts # Scripts for models provided by executorch. ``` From 98bc51856ad898f1bf7caa28c69f1da33c0bb080 Mon Sep 17 00:00:00 2001 From: Winston Kuo Date: Tue, 21 Apr 2026 17:43:14 +0800 Subject: [PATCH 21/21] Remove cpp, python, and cmake files --- backends/qualcomm/tests/test_qnn_delegate.py | 202 ------ examples/qualcomm/CMakeLists.txt | 6 - .../qaihub_scripts/llama/CMakeLists.txt | 85 --- .../qualcomm/qaihub_scripts/llama/README.md | 57 -- .../llama/llama2/qaihub_llama2_7b.py | 241 ------- .../llama/llama2/qaihub_llama2_7b_runner.cpp | 81 --- .../llama/llama3/qaihub_llama3_8b.py | 248 ------- .../llama/llama3/qaihub_llama3_8b_runner.cpp | 88 --- .../qaihub_scripts/llama/runner/io_memory.cpp | 538 --------------- .../qaihub_scripts/llama/runner/io_memory.h | 168 ----- .../qaihub_scripts/llama/runner/runner.cpp | 418 ------------ .../qaihub_scripts/llama/runner/runner.h | 113 ---- .../stable_diffusion/CMakeLists.txt | 37 -- .../qaihub_scripts/stable_diffusion/README.md | 38 -- .../stable_diffusion/install_requirements.sh | 3 - .../qaihub_stable_diffusion.py | 403 ------------ .../qaihub_stable_diffusion_runner.cpp | 141 ---- .../stable_diffusion/runner/runner.cpp | 617 ------------------ .../stable_diffusion/runner/runner.h | 141 ---- .../stable_diffusion/stable_diffusion_lib.py | 22 - .../qualcomm/qaihub_scripts/utils/README.md | 102 --- .../qualcomm/qaihub_scripts/utils/export.py | 507 -------------- .../qualcomm/qaihub_scripts/utils/utils.py | 82 --- examples/qualcomm/util_scripts/cli.py | 8 +- 24 files changed, 7 insertions(+), 4339 deletions(-) delete mode 100644 examples/qualcomm/qaihub_scripts/llama/CMakeLists.txt delete mode 100644 examples/qualcomm/qaihub_scripts/llama/README.md delete mode 100644 examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b.py delete mode 100644 examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b_runner.cpp delete mode 100644 examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b.py delete mode 100644 examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b_runner.cpp delete mode 100644 examples/qualcomm/qaihub_scripts/llama/runner/io_memory.cpp delete mode 100644 examples/qualcomm/qaihub_scripts/llama/runner/io_memory.h delete mode 100644 examples/qualcomm/qaihub_scripts/llama/runner/runner.cpp delete mode 100644 examples/qualcomm/qaihub_scripts/llama/runner/runner.h delete mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt delete mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/README.md delete mode 100755 examples/qualcomm/qaihub_scripts/stable_diffusion/install_requirements.sh delete mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py delete mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner.cpp delete mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.cpp delete mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.h delete mode 100644 examples/qualcomm/qaihub_scripts/stable_diffusion/stable_diffusion_lib.py delete mode 100644 examples/qualcomm/qaihub_scripts/utils/README.md delete mode 100644 examples/qualcomm/qaihub_scripts/utils/export.py delete mode 100644 examples/qualcomm/qaihub_scripts/utils/utils.py diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 3e236952933..9b89e1661f6 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import csv -import io import itertools import json import logging @@ -8517,207 +8516,6 @@ def test_whisper(self): self.assertLessEqual(msg["wer"], 0.25) -class TestExampleQaihubScript(TestQNN): - def test_utils_export(self): - with tempfile.TemporaryDirectory() as tmp_dir: - module = ContextBinaryExample() # noqa: F405 - generate_context_binary( - module=module, - inputs=module.example_inputs(), - quantized=True, - artifact_dir=tmp_dir, - ) - ctx_path = f"{tmp_dir}/model_ctx.bin" - fpath = f"{self.executorch_root}/examples/qualcomm/qaihub_scripts/utils/export.py" - - # do compilation - compile_cmds = [ - "python", - fpath, - "compile", - "-a", - ctx_path, - "-m", - self.soc_model, - "-l", - "False", - "-b", - self.build_folder, - "-o", - f"{tmp_dir}/output_pte", - ] - compile_process = subprocess.Popen( - compile_cmds, stdout=subprocess.DEVNULL, cwd=self.executorch_root - ) - output_pte_dir = f"{tmp_dir}/output_pte/model_ctx" - compile_process.communicate() - - # check artifacts are correctly generated - self.assertTrue( - all( - [ - Path(output_pte_dir).exists(), - Path(f"{output_pte_dir}/model_ctx.json").exists(), - Path(f"{output_pte_dir}/model_ctx.svg").exists(), - ] - ) - ) - - # prepare input files - input_list, inputs = [], module.example_inputs() - for name, tensor in inputs.items(): - tensor_path = f"{output_pte_dir}/{name}.pt" - torch.save(tensor, tensor_path) - input_list.append(tensor_path) - - # do execution - output_data_dir = f"{tmp_dir}/output_data" - execute_cmds = [ - "python", - fpath, - "execute", - "-p", - output_pte_dir, - "-i", - *input_list, - "-s", - self.device, - "-z", - "-b", - self.build_folder, - "-o", - output_data_dir, - ] - if self.host is not None: - execute_cmds.append(f"-H {self.host}") - execute_process = subprocess.Popen(execute_cmds, cwd=self.executorch_root) - execute_process.communicate() - - # read outputs - with open(f"{output_pte_dir}/model_ctx.json", "r") as f: - graph_info = json.load(f) - - device_output = [] - for output in graph_info["outputs"]: - with open(f"{output_data_dir}/{output['name']}.pt", "rb") as f: - buffer = io.BytesIO(f.read()) - device_output.append(torch.load(buffer, weights_only=False)) - - # validate outputs - golden_output = module.forward(inputs["x"], inputs["y"]) - self.atol, self.rtol = 1e-1, 1 - self._assert_outputs_equal(golden_output, device_output) - - def test_llama2_7b(self): - if not self.required_envs(): - self.skipTest("missing required envs") - - prompt = "Explain the rules of baseball" - cmds = [ - "python", - f"{self.executorch_root}/examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b.py", - "--artifact", - self.artifact_dir, - "--build_folder", - self.build_folder, - "--tokenizer_bin", - f"{self.artifact_dir}/tokenizer.bin", - "--context_binaries", - f"{self.artifact_dir}", - "--prompt", - f"{prompt}", - ] - self.add_default_cmds(cmds) - - p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) - with Listener((self.ip, self.port)) as listener: - conn = listener.accept() - p.communicate() - msg = json.loads(conn.recv()) - if "Error" in msg: - self.fail(msg["Error"]) - else: - model_out = msg["result"] - self.assertTrue(model_out.startswith(prompt)) - - def test_llama3_8b(self): - if not self.required_envs(): - self.skipTest("missing required envs") - - prompt = "Explain the rules of baseball" - cmds = [ - "python", - f"{self.executorch_root}/examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b.py", - "--artifact", - self.artifact_dir, - "--build_folder", - self.build_folder, - "--tokenizer_model", - f"{self.artifact_dir}/tokenizer.model", - "--context_binaries", - f"{self.artifact_dir}", - "--prompt", - f"{prompt}", - ] - self.add_default_cmds(cmds) - - p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) - with Listener((self.ip, self.port)) as listener: - conn = listener.accept() - p.communicate() - msg = json.loads(conn.recv()) - if "Error" in msg: - self.fail(msg["Error"]) - else: - model_out = msg["result"] - expected_result = ( - "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" - + prompt - + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>" - ) - self.assertTrue(model_out.startswith(expected_result)) - - def test_stable_diffusion(self): - if not self.required_envs(): - self.skipTest("missing required envs") - - prompt = "a photo of an astronaut riding a horse on mars" - cmds = [ - "python", - f"{self.executorch_root}/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py", - "--artifact", - self.artifact_dir, - "--build_folder", - self.build_folder, - "--text_encoder_bin", - f"{self.artifact_dir}/text_encoder.serialized.bin", - "--unet_bin", - f"{self.artifact_dir}/unet.serialized.bin", - "--vae_bin", - f"{self.artifact_dir}/vae.serialized.bin", - "--vocab_json", - f"{self.artifact_dir}/vocab.json", - "--num_time_steps", - "20", - "--prompt", - f"{prompt}", - "--fix_latents", - ] - self.add_default_cmds(cmds) - - p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) - with Listener((self.ip, self.port)) as listener: - conn = listener.accept() - p.communicate() - msg = json.loads(conn.recv()) - if "Error" in msg: - self.fail(msg["Error"]) - else: - # For the default settings and prompt, the expected results will be {PSNR: 23.258, SSIM: 0.852} - self.assertGreaterEqual(msg["PSNR"], 20) - self.assertGreaterEqual(msg["SSIM"], 0.8) - - class TestExampleScript(TestQNN): def test_mobilenet_v2(self): if not self.required_envs([self.image_dataset]): diff --git a/examples/qualcomm/CMakeLists.txt b/examples/qualcomm/CMakeLists.txt index d7403030ca6..31c9f4f0be4 100644 --- a/examples/qualcomm/CMakeLists.txt +++ b/examples/qualcomm/CMakeLists.txt @@ -80,12 +80,6 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/t5) # build qnn_whisper_runner for whisper add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/whisper) -# build qaihub_llama2_7b_runner and qaihub_llama3_8b_runner -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/qaihub_scripts/llama) - -# build qaihub_stable_diffusion_runner -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/qaihub_scripts/stable_diffusion) - # direct-mode qnn_executor_direct_runner is build with ndk toolchain, paired # with libraries built with hexagon toolchain, communicate through self-defined # fastrpc protocol. diff --git a/examples/qualcomm/qaihub_scripts/llama/CMakeLists.txt b/examples/qualcomm/qaihub_scripts/llama/CMakeLists.txt deleted file mode 100644 index b42ceef6eae..00000000000 --- a/examples/qualcomm/qaihub_scripts/llama/CMakeLists.txt +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# preprocess qaihub runner src files for llama2,3 -set(_qaihub_llama_runner__srcs ${_llama_runner__srcs}) -set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) - -list(TRANSFORM _qaihub_llama_runner__srcs PREPEND "${EXECUTORCH_SOURCE_DIR}/") -list(FILTER _qaihub_llama_runner__srcs EXCLUDE REGEX ".*(/runner/).*") -list( - PREPEND - _qaihub_llama_runner__srcs - ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp - ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h - ${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.cpp - ${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.h -) - -# preprocess qaihub llama2 7b runner src files -set(_qaihub_llama2_7b_runner__srcs ${_qaihub_llama_runner__srcs}) - -list(PREPEND _qaihub_llama2_7b_runner__srcs - ${CMAKE_CURRENT_LIST_DIR}/llama2/qaihub_llama2_7b_runner.cpp -) - -# build qaihub llama2 7b runner -add_executable(qaihub_llama2_7b_runner ${_qaihub_llama2_7b_runner__srcs}) - -target_include_directories( - qaihub_llama2_7b_runner PUBLIC ${_common_include_directories} -) -target_link_libraries( - qaihub_llama2_7b_runner - qnn_executorch_backend - executorch_core - extension_data_loader - extension_flat_tensor - extension_llm_runner - extension_module - extension_tensor - gflags -) -target_compile_options( - qaihub_llama2_7b_runner PUBLIC ${_common_compile_options} -) -set_target_properties( - qaihub_llama2_7b_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'" -) - -# preprocess qaihub llama3 8b runner src files -set(_qaihub_llama3_8b_runner__srcs ${_qaihub_llama_runner__srcs}) - -list(PREPEND _qaihub_llama3_8b_runner__srcs - ${CMAKE_CURRENT_LIST_DIR}/llama3/qaihub_llama3_8b_runner.cpp -) - -# Adding a compile option to differentiate llama2 with llama3 logic -list(APPEND _common_compile_options -DQAIHUB_LLAMA3_RUNNER) - -# build qaihub llama3 8b runner -add_executable(qaihub_llama3_8b_runner ${_qaihub_llama3_8b_runner__srcs}) -target_include_directories( - qaihub_llama3_8b_runner PUBLIC ${_common_include_directories} -) - -target_link_libraries( - qaihub_llama3_8b_runner - qnn_executorch_backend - executorch_core - extension_data_loader - extension_flat_tensor - extension_llm_runner - extension_module - extension_tensor - gflags -) -target_compile_options( - qaihub_llama3_8b_runner PUBLIC ${_common_compile_options} -) -set_target_properties( - qaihub_llama3_8b_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'" -) diff --git a/examples/qualcomm/qaihub_scripts/llama/README.md b/examples/qualcomm/qaihub_scripts/llama/README.md deleted file mode 100644 index 887aeb0394f..00000000000 --- a/examples/qualcomm/qaihub_scripts/llama/README.md +++ /dev/null @@ -1,57 +0,0 @@ -# Summary - -## Overview -This file provides you the instructions to run LLAMA2 and LLAMA3 with different parameters via Qualcomm HTP backend. Following settings support for Llama-2-7b-chat-hf and Llama-3-8b-chat-hf - -Please check corresponding section for more information. - -## Llama-2-7b-chat-hf -This example demonstrates how to run Llama-2-7b-chat-hf on mobile via Qualcomm HTP backend. Model was precompiled into context binaries by [Qualcomm AI HUB](https://aihub.qualcomm.com/). -Note that the pre-compiled context binaries could not be further fine-tuned for other downstream tasks. - -### Instructions -#### Step 1: Setup -1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch. -2. Follow the [tutorial](https://pytorch.org/executorch/main/backends-qualcomm) to build Qualcomm AI Engine Direct Backend. - -#### Step2: Prepare Model -1. Create account for https://aihub.qualcomm.com/ -2. Follow instructions in https://huggingface.co/qualcomm/Llama-v2-7B-Chat to export context binaries (will take some time to finish) - -```bash -# tokenizer.model: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main -# tokenizer.bin: -python -m examples.models.llama.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin -``` - -#### Step3: Verify context binary's version -Please refer to [Check context binary version](../../README.md#check-context-binary-version) for more info on why and how to verify the context binary's version - -#### Step4: Run default examples -```bash -# AIHUB_CONTEXT_BINARIES: ${PATH_TO_AIHUB_WORKSPACE}/build/llama_v2_7b_chat_quantized -python examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --context_binaries ${AIHUB_CONTEXT_BINARIES} --tokenizer_bin tokenizer.bin --prompt "What is Python?" -``` - -## Llama-3-8b-chat-hf -This example demonstrates how to run Llama-3-8b-chat-hf on mobile via Qualcomm HTP backend. Model was precompiled into context binaries by [Qualcomm AI HUB](https://aihub.qualcomm.com/). -Note that the pre-compiled context binaries could not be further fine-tuned for other downstream tasks. This example script has been tested on a 16GB RAM device and verified to work. - -### Instructions -#### Step 1: Setup -1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch. -2. Follow the [tutorial](https://pytorch.org/executorch/main/backends-qualcomm) to build Qualcomm AI Engine Direct Backend. - -#### Step2: Prepare Model -1. Create account for https://aihub.qualcomm.com/ -2. Follow instructions in https://huggingface.co/qualcomm/Llama-v3-8B-Chat to export context binaries (will take some time to finish) -3. For Llama 3 tokenizer, please refer to https://github.com/meta-llama/llama-models/blob/main/README.md for further instructions on how to download tokenizer.model. - -#### Step3: Verify context binary's version -Please refer to [Check context binary version](../../README.md#check-context-binary-version) for more info on why and how to verify the context binary's version - -#### Step4: Run default examples -```bash -# AIHUB_CONTEXT_BINARIES: ${PATH_TO_AIHUB_WORKSPACE}/build/llama_v3_8b_chat_quantized -python examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --context_binaries ${AIHUB_CONTEXT_BINARIES} --tokenizer_model tokenizer.model --prompt "What is baseball?" -``` diff --git a/examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b.py b/examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b.py deleted file mode 100644 index bc999a67de9..00000000000 --- a/examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b.py +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import json -import os -from multiprocessing.connection import Client - -import torch - -from executorch.backends.qualcomm.export_utils import ( - QnnConfig, - setup_common_args_and_variables, - SimpleADB, -) -from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset -from executorch.backends.qualcomm.utils.utils import ( - from_context_binary, - generate_htp_compiler_spec, - generate_qnn_executorch_compiler_spec, - get_soc_to_chipset_map, -) -from executorch.examples.qualcomm.qaihub_scripts.utils.utils import ( - gen_pte_from_ctx_bin, - get_encoding, -) -from executorch.exir.capture._config import ExecutorchBackendConfig -from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass - - -def main(args): - qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) - - os.makedirs(args.artifact, exist_ok=True) - - target_names = ( - [ - f"llama_v2_7b_chat_quantized_PromptProcessor_{i}_Quantized.bin" - for i in range(1, 5) - ] - if args.use_prompt_processor - else [ - f"llama_v2_7b_chat_quantized_TokenGenerator_{i}_Quantized.bin" - for i in range(1, 5) - ] - ) - - # common part for compile & inference - backend_options = generate_htp_compiler_spec( - use_fp16=False, - use_multi_contexts=True, - ) - compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=getattr(QcomChipset, args.soc_model), - backend_options=backend_options, - is_from_context_binary=True, - ) - - if args.use_prompt_processor: - pte_name = "qaihub_llama2_7b_prompt" - last_shard_num_inputs = 4 - last_shard_num_outputs = 513 - else: - pte_name = "qaihub_llama2_7b_token" - last_shard_num_inputs = 516 - last_shard_num_outputs = 513 - - if args.pre_gen_pte is None: - # create custom operators as context loader - soc_model = get_soc_to_chipset_map()[args.soc_model] - bundle_programs = [ - from_context_binary( - ctx_path=f"{args.context_binaries}/{target}", - op_name=f"ctx_loader_{i}", - soc_model=soc_model, - ) - for i, target in enumerate(target_names) - ] - pte_names = [f"{pte_name}_{i}" for i in range(len(target_names))] - memory_planning_pass = MemoryPlanningPass( - alloc_graph_input=False, - alloc_graph_output=False, - ) - pte_files = gen_pte_from_ctx_bin( - artifact=args.artifact, - pte_names=pte_names, - bundle_programs=bundle_programs, - backend_config=ExecutorchBackendConfig( - memory_planning_pass=memory_planning_pass - ), - ) - else: - pte_files = [f"{args.pre_gen_pte}/{pte_name}_{i}.pte" for i in range(4)] - - if args.compile_only: - return - - adb = SimpleADB( - qnn_config=qnn_config, - pte_path=pte_files, - workspace=f"/data/local/tmp/executorch/{pte_name}", - runner="examples/qualcomm/qaihub_scripts/llama/qaihub_llama2_7b_runner", - ) - output_file = "result.txt" - pos_embs_file = ["freq_cos", "freq_sin"] - encoding = get_encoding( - path_to_shard=f"{args.context_binaries}/{target_names[-1]}", - compiler_specs=compiler_specs, - get_input=False, - get_output=True, - num_input=last_shard_num_inputs, - num_output=last_shard_num_outputs, - )[0] - scale = encoding["scale"][-1] - offset = encoding["offset"][-1] - outputs = [] - runner_args = [ - *[ - f"--sharded_{i+1}_path {os.path.basename(pte_file)}" - for i, pte_file in enumerate(pte_files) - ], - *[f"--{fname}_path {fname}.raw" for fname in pos_embs_file], - f"--output_path {adb.output_folder}/{output_file}", - f"--tokenizer_path {os.path.basename(args.tokenizer_bin)}", - f"--prompt '{args.prompt}'", - f"--temperature {args.temperature}", - f"--seq_len {args.seq_len}", - f"--eval_mode {0 if args.use_prompt_processor else 1}", - f"--logits_scale {scale}", - f"--logits_offset {-offset}", - ] - runner_cmds = " ".join( - [ - f"cd {adb.workspace} &&", - f"./qaihub_llama2_7b_runner {' '.join(runner_args)}", - ] - ) - - def compute_pos_embedding(): - head_dim, max_seq_len, theta = 128, 1024, 10000.0 - base = torch.arange(0, head_dim, 2) - freqs = 1.0 / (theta ** (base[: (head_dim // 2)].float() / head_dim)) - t = torch.arange(max_seq_len * 2) - freqs = torch.outer(t, freqs).float() - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) - freqs_cis = freqs_cis[0:max_seq_len] - freqs_real = torch.view_as_real(freqs_cis) - return freqs_real[:, :, 0], freqs_real[:, :, 1] - - def post_process(): - with open(f"{args.artifact}/outputs/{output_file}", "r") as f: - outputs.append(f.read()) - - custom_files = [args.tokenizer_bin] - for var_name, freq in zip(pos_embs_file, compute_pos_embedding()): - custom_files.append(f"{adb.working_dir}/{var_name}.raw") - scale, offset = (freq.max() - freq.min()) / 65535, 32768 - freq = (freq / scale + offset).clip(min=0, max=65535).detach() - freq.to(dtype=torch.uint16).numpy().tofile(custom_files[-1]) - - if not args.skip_push: - adb.push(files=custom_files) - adb.execute(custom_runner_cmd=runner_cmds) - adb.pull(args.artifact, callback=post_process) - if args.ip and args.port != -1: - with Client((args.ip, args.port)) as conn: - conn.send( - json.dumps( - { - "result": outputs[0], - } - ) - ) - else: - print(outputs[0]) - - -if __name__ == "__main__": - parser = setup_common_args_and_variables() - - parser.add_argument( - "-a", - "--artifact", - help="path for storing generated artifacts by this example. Default ./llama2_qai_hub", - default="./llama2_qai_hub", - type=str, - ) - - parser.add_argument( - "--context_binaries", - help="path to context binaries generated from qai_hub", - required=True, - ) - - parser.add_argument( - "--use_prompt_processor", - help="tokens will be evaluated all at once", - default=False, - action="store_true", - ) - - parser.add_argument( - "--tokenizer_bin", - help="llama2 tokenizer binary", - required=True, - type=str, - ) - - parser.add_argument( - "--seq_len", - help="ouput sequence length for llama2", - default=128, - type=int, - ) - - parser.add_argument( - "--temperature", - help="sampling temperature for llama2", - default=0.0, - type=float, - ) - - parser.add_argument( - "--prompt", - help="user prompts for llama2", - required=True, - type=str, - ) - - args = parser.parse_args() - - try: - main(args) - except Exception as e: - if args.ip and args.port != -1: - with Client((args.ip, args.port)) as conn: - conn.send(json.dumps({"Error": str(e)})) - else: - raise Exception(e) diff --git a/examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b_runner.cpp b/examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b_runner.cpp deleted file mode 100644 index 3de97cde7e8..00000000000 --- a/examples/qualcomm/qaihub_scripts/llama/llama2/qaihub_llama2_7b_runner.cpp +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -/** - * @file - * - * This tool can run Llama2 7b with Qualcomm AI Engine Direct. - * - * User could specify arguments like desired prompt, eval_mode, etc. - */ - -#include -#include -#include - -#include - -#include - -DEFINE_string(sharded_1_path, "", "Path to 1st sharded pte file"); -DEFINE_string(sharded_2_path, "", "Path to 2nd sharded pte file"); -DEFINE_string(sharded_3_path, "", "Path to 3rd sharded pte file"); -DEFINE_string(sharded_4_path, "", "Path to 4th sharded pte file"); - -DEFINE_string(freq_cos_path, "", "Path to precomputed position embeddings"); -DEFINE_string(freq_sin_path, "", "Path to precomputed position embeddings"); - -DEFINE_string(output_path, "outputs", "Executorch inference data output path."); -DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff."); -DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt."); -DEFINE_double( - temperature, - 0.0f, - "Temperature; Default is 0.0f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic"); -DEFINE_int32( - eval_mode, - 0, - "0: PromptProcessor / 1: TokenGenerator / 2: MixedMode (TBD)"); -DEFINE_int32( - seq_len, - 128, - "Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens."); -DEFINE_double(logits_scale, 0.0, "Path to logits scale file"); -DEFINE_int32(logits_offset, 0, "Path to logits offset file"); - -int main(int argc, char** argv) { - gflags::ParseCommandLineFlags(&argc, &argv, true); - - std::vector models_path = { - FLAGS_sharded_1_path, - FLAGS_sharded_2_path, - FLAGS_sharded_3_path, - FLAGS_sharded_4_path}; - std::vector pos_embs_path = { - FLAGS_freq_cos_path, FLAGS_freq_sin_path}; - - // create llama runner - example::Runner runner( - models_path, - pos_embs_path, - {8, 8, 8, 8}, - FLAGS_tokenizer_path.c_str(), - FLAGS_eval_mode, - FLAGS_temperature, - FLAGS_logits_scale, - FLAGS_logits_offset); - - // generate tokens & store inference output - std::ofstream fout(FLAGS_output_path.c_str()); - runner.generate( - FLAGS_prompt, "", FLAGS_seq_len, [&](const std::string& piece) { - fout << piece; - }); - fout.close(); - return 0; -} diff --git a/examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b.py b/examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b.py deleted file mode 100644 index 9da728767af..00000000000 --- a/examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import json -import os -from multiprocessing.connection import Client - -import torch -from executorch.backends.qualcomm.export_utils import ( - QnnConfig, - setup_common_args_and_variables, - SimpleADB, -) -from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset -from executorch.backends.qualcomm.utils.utils import ( - from_context_binary, - generate_htp_compiler_spec, - generate_qnn_executorch_compiler_spec, - get_soc_to_chipset_map, -) -from executorch.examples.qualcomm.qaihub_scripts.utils.utils import ( - gen_pte_from_ctx_bin, - get_encoding, -) -from executorch.exir.capture._config import ExecutorchBackendConfig -from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass - - -def main(args): - qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) - - os.makedirs(args.artifact, exist_ok=True) - - target_names = ( - [ - f"llama_v3_8b_chat_quantized_PromptProcessor_{i}_Quantized.bin" - for i in range(1, 6) - ] - if args.use_prompt_processor - else [ - f"llama_v3_8b_chat_quantized_TokenGenerator_{i}_Quantized.bin" - for i in range(1, 6) - ] - ) - - # common part for compile & inference - backend_options = generate_htp_compiler_spec( - use_fp16=False, - use_multi_contexts=True, - ) - compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=getattr(QcomChipset, args.soc_model), - backend_options=backend_options, - is_from_context_binary=True, - ) - - if args.use_prompt_processor: - pte_name = "qaihub_llama3_8b_prompt" - last_shard_num_inputs = 4 - last_shard_num_outputs = 65 - else: - pte_name = "qaihub_llama3_8b_token" - last_shard_num_inputs = 68 - last_shard_num_outputs = 65 - - if args.pre_gen_pte is None: - # create custom operators as context loader - soc_model = get_soc_to_chipset_map()[args.soc_model] - bundle_programs = [ - from_context_binary( - ctx_path=f"{args.context_binaries}/{target}", - op_name=f"ctx_loader_{i}", - soc_model=soc_model, - ) - for i, target in enumerate(target_names) - ] - pte_names = [f"{pte_name}_{i}" for i in range(len(target_names))] - memory_planning_pass = MemoryPlanningPass( - alloc_graph_input=False, - alloc_graph_output=False, - ) - pte_files = gen_pte_from_ctx_bin( - artifact=args.artifact, - pte_names=pte_names, - bundle_programs=bundle_programs, - backend_config=ExecutorchBackendConfig( - memory_planning_pass=memory_planning_pass - ), - ) - else: - pte_files = [f"{args.pre_gen_pte}/{pte_name}_{i}.pte" for i in range(5)] - - if args.compile_only: - return - - adb = SimpleADB( - qnn_config=qnn_config, - pte_path=pte_files, - workspace=f"/data/local/tmp/executorch/{pte_name}", - runner="examples/qualcomm/qaihub_scripts/llama/qaihub_llama3_8b_runner", - ) - output_file = "result.txt" - pos_embs_file = ["freq_cos", "freq_sin"] - - encoding = get_encoding( - path_to_shard=f"{args.context_binaries}/{target_names[-1]}", - compiler_specs=compiler_specs, - get_input=False, - get_output=True, - num_input=last_shard_num_inputs, - num_output=last_shard_num_outputs, - )[0] - scale = encoding["scale"][-1] - offset = encoding["offset"][-1] - outputs = [] - runner_args = [ - *[ - f"--sharded_{i+1}_path {os.path.basename(pte_file)}" - for i, pte_file in enumerate(pte_files) - ], - *[f"--{fname}_path {fname}.raw" for fname in pos_embs_file], - f"--output_path {adb.output_folder}/{output_file}", - f"--tokenizer_path {os.path.basename(args.tokenizer_model)}", - f"--prompt '{args.prompt}'", - f"--temperature {args.temperature}", - f"--seq_len {args.seq_len}", - f"--eval_mode {0 if args.use_prompt_processor else 1}", - f"--logits_scale {scale}", - f"--logits_offset {-offset}", - f"--system_prompt '{args.system_prompt}'", - ] - runner_cmds = " ".join( - [ - f"cd {adb.workspace} &&", - f"./qaihub_llama3_8b_runner {' '.join(runner_args)}", - ] - ) - - def compute_pos_embedding(): - head_dim, max_seq_len, theta = 128, 1024, 10000.0 - base = torch.arange(0, head_dim, 2) - freqs = 1.0 / (theta ** (base[: (head_dim // 2)].float() / head_dim)) - t = torch.arange(max_seq_len * 2) - freqs = torch.outer(t, freqs).float() - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) - freqs_cis = freqs_cis[0:max_seq_len] - freqs_real = torch.view_as_real(freqs_cis) - return freqs_real[:, :, 0], freqs_real[:, :, 1] - - def post_process(): - with open(f"{args.artifact}/outputs/{output_file}", "r") as f: - outputs.append(f.read()) - - custom_files = [args.tokenizer_model] - for var_name, freq in zip(pos_embs_file, compute_pos_embedding()): - custom_files.append(f"{adb.working_dir}/{var_name}.raw") - scale, offset = (freq.max() - freq.min()) / 65535, 32768 - freq = (freq / scale + offset).clip(min=0, max=65535).detach() - freq.to(dtype=torch.uint16).numpy().tofile(custom_files[-1]) - - adb.push(files=custom_files) - adb.execute(custom_runner_cmd=runner_cmds) - adb.pull(args.artifact, callback=post_process) - if args.ip and args.port != -1: - with Client((args.ip, args.port)) as conn: - conn.send( - json.dumps( - { - "result": outputs[0], - } - ) - ) - else: - print(outputs[0]) - - -if __name__ == "__main__": - parser = setup_common_args_and_variables() - - parser.add_argument( - "-a", - "--artifact", - help="path for storing generated artifacts by this example. Default ./llama3_qai_hub", - default="./llama3_qai_hub", - type=str, - ) - - parser.add_argument( - "--context_binaries", - help="path to context binaries generated from qai_hub", - required=True, - ) - - parser.add_argument( - "--use_prompt_processor", - help="tokens will be evaluated all at once", - default=False, - action="store_true", - ) - - parser.add_argument( - "--tokenizer_model", - help="llama3 tokenizer model", - required=True, - type=str, - ) - - parser.add_argument( - "--seq_len", - help="ouput sequence length for llama3", - default=128, - type=int, - ) - - parser.add_argument( - "--temperature", - help="sampling temperature for llama3", - default=0.0, - type=float, - ) - - parser.add_argument( - "--prompt", - help="user prompts for llama3", - required=True, - type=str, - ) - - parser.add_argument( - "--system_prompt", - help="Tells the model what kind of assistant it should be. For example, You are a helpful AI assistant for travel tips and recommendations. Default is None", - default="", - type=str, - ) - - args = parser.parse_args() - - try: - main(args) - except Exception as e: - if args.ip and args.port != -1: - with Client((args.ip, args.port)) as conn: - conn.send(json.dumps({"Error": str(e)})) - else: - raise Exception(e) diff --git a/examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b_runner.cpp b/examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b_runner.cpp deleted file mode 100644 index 7591b7ae1e9..00000000000 --- a/examples/qualcomm/qaihub_scripts/llama/llama3/qaihub_llama3_8b_runner.cpp +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -/** - * @file - * - * This tool can run Llama3 8b with Qualcomm AI Engine Direct. - * - * User could specify arguments like desired prompt, eval_mode, etc. - */ - -#include -#include -#include - -#include - -#include - -DEFINE_string(sharded_1_path, "", "Path to 1st sharded pte file"); -DEFINE_string(sharded_2_path, "", "Path to 2nd sharded pte file"); -DEFINE_string(sharded_3_path, "", "Path to 3rd sharded pte file"); -DEFINE_string(sharded_4_path, "", "Path to 4th sharded pte file"); -DEFINE_string(sharded_5_path, "", "Path to 5th sharded pte file"); - -DEFINE_string(freq_cos_path, "", "Path to precomputed position embeddings"); -DEFINE_string(freq_sin_path, "", "Path to precomputed position embeddings"); - -DEFINE_string(output_path, "outputs", "Executorch inference data output path."); -DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff."); -DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt."); -DEFINE_string( - system_prompt, - "", - "Tells the model what kind of assistant it should be. For example, You are a helpful AI assistant for travel tips and recommendations. Default is None"); -DEFINE_double( - temperature, - 0.0f, - "Temperature; Default is 0.0f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic"); -DEFINE_int32( - eval_mode, - 0, - "0: PromptProcessor / 1: TokenGenerator / 2: MixedMode (TBD)"); -DEFINE_int32( - seq_len, - 128, - "Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens."); -DEFINE_double(logits_scale, 0.0, "Path to logits scale file"); -DEFINE_int32(logits_offset, 0, "Path to logits offset file"); - -int main(int argc, char** argv) { - gflags::ParseCommandLineFlags(&argc, &argv, true); - - std::vector models_path = { - FLAGS_sharded_1_path, - FLAGS_sharded_2_path, - FLAGS_sharded_3_path, - FLAGS_sharded_4_path, - FLAGS_sharded_5_path}; - std::vector pos_embs_path = { - FLAGS_freq_cos_path, FLAGS_freq_sin_path}; - - // create llama runner - example::Runner runner( - models_path, - pos_embs_path, - {4, 8, 8, 8, 4}, - FLAGS_tokenizer_path.c_str(), - FLAGS_eval_mode, - FLAGS_temperature, - FLAGS_logits_scale, - FLAGS_logits_offset); - - // generate tokens & store inference output - std::ofstream fout(FLAGS_output_path.c_str()); - runner.generate( - FLAGS_prompt, - FLAGS_system_prompt, - FLAGS_seq_len, - [&](const std::string& piece) { fout << piece; }); - fout.close(); - return 0; -} diff --git a/examples/qualcomm/qaihub_scripts/llama/runner/io_memory.cpp b/examples/qualcomm/qaihub_scripts/llama/runner/io_memory.cpp deleted file mode 100644 index 9ee7551650a..00000000000 --- a/examples/qualcomm/qaihub_scripts/llama/runner/io_memory.cpp +++ /dev/null @@ -1,538 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include - -#include -#include - -using executorch::aten::Tensor; -using executorch::aten::TensorImpl; -using executorch::extension::Module; -using executorch::runtime::Error; -using executorch::runtime::MethodMeta; -using executorch::runtime::Result; -using executorch::runtime::TensorInfo; - -namespace example { - -Memory::Memory( - const std::vector& pos_embs_path, - std::vector>& modules) - : data_ptr_(nullptr, [](void*) {}), - input_tensors_(modules.size()), - output_tensors_(modules.size()), - pos_embs_path_(pos_embs_path), - modules_(modules) { - for (std::shared_ptr& module : modules_) { - method_names_.emplace_back(*module->method_names()->begin()); - } -} - -Memory::~Memory() {} - -void* Memory::get_mutable_ptr() { - return data_ptr_.get(); -} - -std::vector Memory::get_input_tensors(int shard_index) { - std::vector ret; - ret.reserve(input_tensors_.size()); - for (TensorImpl* impl : input_tensors_[shard_index]) { - ret.emplace_back(Tensor(impl)); - } - return ret; -} - -std::vector Memory::get_output_tensors(int shard_index) { - std::vector ret; - ret.reserve(output_tensors_.size()); - for (TensorImpl* impl : output_tensors_[shard_index]) { - ret.emplace_back(Tensor(impl)); - } - return ret; -} - -BertMemory::BertMemory( - const std::vector& pos_embs_path, - std::vector>& modules, - std::vector shard_layers) - : Memory(pos_embs_path, modules), - shard_layers_(shard_layers), - num_heads_(QAIHUB_LLAMA_NUM_HEADS) { - data_ptr_ = std::unique_ptr( - new IO, [](void* ptr) { delete static_cast(ptr); }); -} - -void BertMemory::prepare_io( - const std::vector>& methods_meta) { - IO* ptr = static_cast(data_ptr_.get()); - std::memset(ptr, 0, sizeof(IO)); - - for (int i = 0; i < modules_.size(); ++i) { - ET_CHECK_MSG( - methods_meta[i].ok(), - "Failed to get method_meta 0x%x", - static_cast(methods_meta[i].error())); - } - // [I] position embedding initialization - for (size_t i = 0; i < pos_embs_path_.size(); ++i) { - std::ifstream fin(pos_embs_path_[i], std::ios::binary); - fin.read( - reinterpret_cast( - i == 0 ? ptr->position_ids_cos : ptr->position_ids_sin), - 1024 * 64 * 2); - fin.close(); - } - // [I]: all shards (4 shards for llama2, 5 shards for llama) - { - // [I]: input_ids - Result input_ids = methods_meta[0]->input_tensor_meta(0); - input_ids_ = std::make_unique( - input_ids->scalar_type(), - input_ids->sizes().size(), - const_cast(input_ids->sizes().data()), - ptr->input_ids, - const_cast(input_ids->dim_order().data())); - input_tensors_[0].push_back(input_ids_.get()); - // [I]: atten_mask - Result atten_mask = methods_meta[0]->input_tensor_meta(1); - attention_mask_ = std::make_unique( - atten_mask->scalar_type(), - atten_mask->sizes().size(), - const_cast(atten_mask->sizes().data()), - ptr->attention_mask, - const_cast(atten_mask->dim_order().data())); - input_tensors_[0].push_back(attention_mask_.get()); - // [I]: pos_ids_cos - Result pos_ids_cos = methods_meta[0]->input_tensor_meta(2); - position_ids_cos_ = std::make_unique( - pos_ids_cos->scalar_type(), - pos_ids_cos->sizes().size(), - const_cast(pos_ids_cos->sizes().data()), - ptr->position_ids_cos, - const_cast(pos_ids_cos->dim_order().data())); - input_tensors_[0].push_back(position_ids_cos_.get()); - // [I]: pos_ids_sin - Result pos_ids_sin = methods_meta[0]->input_tensor_meta(3); - position_ids_sin_ = std::make_unique( - pos_ids_sin->scalar_type(), - pos_ids_sin->sizes().size(), - const_cast(pos_ids_sin->sizes().data()), - ptr->position_ids_sin, - const_cast(pos_ids_sin->dim_order().data())); - input_tensors_[0].push_back(position_ids_sin_.get()); - // [IO]: hidden_state => [I] shard2,3,4 - int output_index = - shard_layers_[0] * 2 * num_heads_; // layers*(k + v caches)*heads - Result hidden_state = - methods_meta[0]->output_tensor_meta(output_index); - hidden_state_ = std::make_unique( - hidden_state->scalar_type(), - hidden_state->sizes().size(), - const_cast(hidden_state->sizes().data()), - ptr->hidden_state, - const_cast( - hidden_state->dim_order().data())); - // reuse inputs for following tensors - for (int shard_index = 1; shard_index < modules_.size(); ++shard_index) { - // inputs of shards 1 to n: hidden_state, atten_mask, pos_ids_cos, - // pos_ids_sin - input_tensors_[shard_index].push_back(hidden_state_.get()); - input_tensors_[shard_index].push_back(attention_mask_.get()); - input_tensors_[shard_index].push_back(position_ids_cos_.get()); - input_tensors_[shard_index].push_back(position_ids_sin_.get()); - } - } - // [O] kv_cache for all shards (4 shards for llama2 and 5 shards for llama3) - for (int offset = 0, shard_index = 0; shard_index < modules_.size(); - offset += shard_layers_[shard_index], shard_index++) { - for (int layer = 0; layer < shard_layers_[shard_index]; ++layer) { - for (int cache_group = 0; cache_group < 2; ++cache_group) { - for (int head = 0; head < num_heads_; ++head) { - int index = num_heads_ * 2 * layer + cache_group * num_heads_ + head; - Result kv_cache = - methods_meta[shard_index]->output_tensor_meta(index); - std::vector>& cache = - (cache_group == 0 ? v_cache_ : k_cache_); - cache.emplace_back(std::make_unique( - kv_cache->scalar_type(), - kv_cache->sizes().size(), - const_cast(kv_cache->sizes().data()), - cache_group == 0 ? ptr->v_cache[layer + offset][head] - : ptr->k_cache[layer + offset][head], - const_cast( - kv_cache->dim_order().data()))); - output_tensors_[shard_index].push_back(cache.back().get()); - } - } - } - } - // [O]: hidden_state for shard 0 to n-1 - for (int shard_index = 0; shard_index < modules_.size() - 1; ++shard_index) { - output_tensors_[shard_index].push_back(hidden_state_.get()); - } - // [O]: logits - { - int output_index = shard_layers_[modules_.size() - 1] * 2 * - num_heads_; // layers*(k + v caches)*heads - Result logits = - methods_meta[modules_.size() - 1]->output_tensor_meta(output_index); - logits_ = std::make_unique( - logits->scalar_type(), - logits->sizes().size(), - const_cast(logits->sizes().data()), - ptr->logits, - const_cast(logits->dim_order().data())); - output_tensors_[modules_.size() - 1].push_back(logits_.get()); - } -} - -void BertMemory::update_io( - int64_t cur_token, - int64_t pos, - std::vector>& output_tensors) { - (void)output_tensors; - IO* ptr = static_cast(data_ptr_.get()); - static int num_tokens_generated = 0; - int seq_len = 1024, last_index = seq_len - 1; - // refill past token ids, which is equivalent to following snippet: - // ---> - // for (int i = 0; i < last_index; ++i) { - // ptr->input_ids[i] = ptr->input_ids[i + 1]; - // } - // ptr->input_ids[last_index] = static_cast(cur_token); - // <--- - int32_t* new_addr = ++num_tokens_generated + ptr->input_ids; - new_addr[last_index] = static_cast(cur_token); - input_ids_->set_data(new_addr); - // update causal mask for next token - int tokens = pos + 1, start = last_index - tokens; - for (int i = last_index; tokens >= 0; --i, --tokens) { - ptr->attention_mask[i * seq_len + start] = 65535; - } -} - -KVCachedMemory::KVCachedMemory( - const std::vector& pos_embs_path, - std::vector>& modules, - std::vector shard_layers) - : Memory(pos_embs_path, modules), - shard_layers_(shard_layers), - num_heads_(QAIHUB_LLAMA_NUM_HEADS) { - data_ptr_ = std::unique_ptr( - new IO, [](void* ptr) { delete static_cast(ptr); }); - if (num_heads_ == 32) { - futures_ = std::vector>(thread_pool_.num_workers()); - } -} - -void KVCachedMemory::prepare_io( - const std::vector>& methods_meta) { - IO* ptr = static_cast(data_ptr_.get()); - std::memset(ptr, 0, sizeof(IO)); - for (int i = 0; i < modules_.size(); ++i) { - ET_CHECK_MSG( - methods_meta[i].ok(), - "Failed to get method_meta 0x%x", - static_cast(methods_meta[i].error())); - } - // [I] position embedding initialization - for (size_t i = 0; i < pos_embs_path_.size(); ++i) { - std::ifstream fin(pos_embs_path_[i], std::ios::binary); - fin.read( - reinterpret_cast( - i == 0 ? ptr->position_ids_cos : ptr->position_ids_sin), - 1024 * 64 * 2); - fin.close(); - } - // [I]: all shards (4 shards for llama2, 5 shards for llama) - { - // [I]: input_ids - Result input_ids = methods_meta[0]->input_tensor_meta(0); - input_ids_ = std::make_unique( - input_ids->scalar_type(), - input_ids->sizes().size(), - const_cast(input_ids->sizes().data()), - &ptr->input_ids, - const_cast(input_ids->dim_order().data())); - input_tensors_[0].push_back(input_ids_.get()); - // [I]: atten_mask - Result atten_mask = methods_meta[0]->input_tensor_meta(1); - attention_mask_ = std::make_unique( - atten_mask->scalar_type(), - atten_mask->sizes().size(), - const_cast(atten_mask->sizes().data()), - ptr->attention_mask, - const_cast(atten_mask->dim_order().data())); - input_tensors_[0].push_back(attention_mask_.get()); - // [I]: pos_ids_cos - Result pos_ids_cos = methods_meta[0]->input_tensor_meta(2); - position_ids_cos_ = std::make_unique( - pos_ids_cos->scalar_type(), - pos_ids_cos->sizes().size(), - const_cast(pos_ids_cos->sizes().data()), - ptr->position_ids_cos, - const_cast(pos_ids_cos->dim_order().data())); - input_tensors_[0].push_back(position_ids_cos_.get()); - // [I]: pos_ids_sin - Result pos_ids_sin = methods_meta[0]->input_tensor_meta(3); - position_ids_sin_ = std::make_unique( - pos_ids_sin->scalar_type(), - pos_ids_sin->sizes().size(), - const_cast(pos_ids_sin->sizes().data()), - ptr->position_ids_sin, - const_cast(pos_ids_sin->dim_order().data())); - input_tensors_[0].push_back(position_ids_sin_.get()); - // [IO]: hidden_state => [I] shard2,3,4 - int output_index = - shard_layers_[0] * 2 * num_heads_; // layers*(k + v caches)*heads - Result hidden_state = - methods_meta[0]->output_tensor_meta(output_index); - hidden_state_ = std::make_unique( - hidden_state->scalar_type(), - hidden_state->sizes().size(), - const_cast(hidden_state->sizes().data()), - ptr->hidden_state, - const_cast( - hidden_state->dim_order().data())); - // reuse inputs for following tensors - for (int shard_index = 1; shard_index < modules_.size(); ++shard_index) { - // inputs of shards 1 to n: hidden_state, atten_mask, pos_ids_cos, - // pos_ids_sin - input_tensors_[shard_index].push_back(hidden_state_.get()); - input_tensors_[shard_index].push_back(attention_mask_.get()); - input_tensors_[shard_index].push_back(position_ids_cos_.get()); - input_tensors_[shard_index].push_back(position_ids_sin_.get()); - } - } - // [I] kv_cache for all shards (4 shards for llama2 and 5 shards for llama3) - for (int offset = 0, shard_index = 0, v_stride = 1023 * 128; - shard_index < modules_.size(); - offset += shard_layers_[shard_index], shard_index++) { - for (int layer = 0; layer < shard_layers_[shard_index]; ++layer) { - for (int cache_group = 0; cache_group < 2; ++cache_group) { - for (int head = 0; head < num_heads_; ++head) { - // bypass hidden_state(input_ids), atten_mask, pos_cos, pos_sin - int index = - num_heads_ * 2 * layer + cache_group * num_heads_ + head + 4; - Result kv_cache = - methods_meta[shard_index]->input_tensor_meta(index); - std::vector>& cache = - (cache_group == 0 ? k_cache_in_ : v_cache_in_); - - void* cache_ptr = (cache_group == 0) - ? static_cast(ptr->k_cache[layer + offset][head]) - : static_cast( - ptr->v_cache[layer + offset] + head * v_stride); - - cache.emplace_back(std::make_unique( - kv_cache->scalar_type(), - kv_cache->sizes().size(), - const_cast(kv_cache->sizes().data()), - cache_ptr, - const_cast( - kv_cache->dim_order().data()))); - input_tensors_[shard_index].push_back(cache.back().get()); - } - } - } - } - // [O] kv_cache for all shards (4 shards for llama2 and 5 shards for llama3) - for (int offset = 0, shard_index = 0, v_stride = 1023 * 128; - shard_index < modules_.size(); - offset += shard_layers_[shard_index], shard_index++) { - for (int layer = 0; layer < shard_layers_[shard_index]; ++layer) { - for (int cache_group = 0; cache_group < 2; ++cache_group) { - for (int head = 0; head < num_heads_; ++head) { - int index = num_heads_ * 2 * layer + cache_group * num_heads_ + head; - Result kv_cache = - methods_meta[shard_index]->output_tensor_meta(index); - std::vector>& cache = - (cache_group == 0 ? v_cache_out_ : k_cache_out_); - - void* cache_ptr = (cache_group == 0) - ? static_cast( - ptr->v_cache[layer + offset] + (head + 1) * v_stride) - : static_cast(ptr->k_cache_out[layer + offset][head]); - - cache.emplace_back(std::make_unique( - kv_cache->scalar_type(), - kv_cache->sizes().size(), - const_cast(kv_cache->sizes().data()), - cache_ptr, - const_cast( - kv_cache->dim_order().data()))); - output_tensors_[shard_index].push_back(cache.back().get()); - } - } - } - } - // [O]: hidden_state for shard 0 to n-1 - for (int shard_index = 0; shard_index < modules_.size() - 1; ++shard_index) { - output_tensors_[shard_index].push_back(hidden_state_.get()); - } - // [O]: logits - { - int output_index = shard_layers_[modules_.size() - 1] * 2 * - num_heads_; // layers*(k + v caches)*heads - Result logits = - methods_meta[modules_.size() - 1]->output_tensor_meta(output_index); - logits_ = std::make_unique( - logits->scalar_type(), - logits->sizes().size(), - const_cast(logits->sizes().data()), - ptr->logits, - const_cast(logits->dim_order().data())); - output_tensors_[modules_.size() - 1].push_back(logits_.get()); - } - - // QAIHub Llama2 have 4* io compared to QAIHub Llama3, - // so we use multi-threading for Llama2 when updating io - if (num_heads_ == 32) { - // thread pool jobs - for (int i = 0, range = 1024 / thread_pool_.num_workers(); - i < thread_pool_.num_workers(); - ++i) { - lr_update_kv_.push_back( - {.start = i * range, .end = (i + 1) * range, .step = 1}); - } - } -} - -void KVCachedMemory::update_io( - int64_t cur_token, - int64_t pos, - std::vector>& output_tensors) { - IO* ptr = static_cast(data_ptr_.get()); - int seq_len = 1023; - // update input_ids - ptr->input_ids = static_cast(cur_token); - // update causal mask for next token - ptr->attention_mask[seq_len - pos] = 65535; - // update position_ids - position_ids_cos_->set_data(position_ids_cos_->mutable_data() + 64); - position_ids_sin_->set_data(position_ids_sin_->mutable_data() + 64); - - // use multithreading when we have a lot of ios, Llama2 in this case - if (num_heads_ == 32) { - auto update_kv = [&](void* arg) { - LoopRange* lr = static_cast(arg); - // update v_cache - for (int i = lr->start; i < lr->end; i += lr->step) { - v_cache_in_[i]->set_data(v_cache_in_[i]->mutable_data() + 128); - v_cache_out_[i]->set_data( - v_cache_out_[i]->mutable_data() + 128); - } - // update output tensors of v_cache, 256 is the number of kvs per shard - int shard = lr->start >> 8, offset = shard << 8; - int start = lr->start - offset, end = lr->end - offset; - for (int cache_stride = start; cache_stride < end; cache_stride += 32) { - for (int cache_group = 0; cache_group < 2; ++cache_group) { - for (int head = 0; head < 32; ++head) { - // k, v are placed interleaved - int index = (cache_stride << 1) + (cache_group << 5) + head; - ET_CHECK_MSG( - modules_[shard]->set_output( - method_names_[shard], - output_tensors[shard][index], - index) == Error::Ok, - "failed to set output tensor for module %d's %d'th output " - "while updating kv_cache output tensors", - shard, - index); - } - } - } - }; - - for (int i = 0; i < lr_update_kv_.size(); ++i) { - futures_[i] = std::move(thread_pool_.issue(update_kv, &lr_update_kv_[i])); - } - } else { - // update v_cache - for (int i = 0; i < v_cache_in_.size(); i++) { - v_cache_in_[i]->set_data(v_cache_in_[i]->mutable_data() + 128); - v_cache_out_[i]->set_data(v_cache_out_[i]->mutable_data() + 128); - } - for (int shard = 0; shard < output_tensors.size(); shard++) { - for (int index = 0; index < output_tensors[shard].size(); index++) { - ET_CHECK_MSG( - modules_[shard]->set_output( - method_names_[shard], output_tensors[shard][index], index) == - Error::Ok, - "failed to set output tensor for module %d's %d'th output " - "while updating kv_cache output tensors", - shard, - index); - } - } - } - // update k_cache by single thread, this part is cpu cache sensitive - for (int i = 0; i < k_cache_in_.size(); ++i) { - uint8_t* ptr_in = k_cache_in_[i]->mutable_data(); - const uint8_t* ptr_out = k_cache_out_[i]->data(); - for (size_t j = 0, offset = seq_len; j < 128; ++j, offset += seq_len) { - ptr_in[offset] = ptr_out[j]; - } - k_cache_in_[i]->set_data(ptr_in + 1); - } - for (auto& future : futures_) { - future.wait(); - } -} - -ThreadPool::ThreadPool() : stop_(false) { - size_t hc = (std::thread::hardware_concurrency() + 3) / 4; - // maximum number should be divisible by head dimension which equals to 32 - num_workers_ = std::min(32, hc * 4); - for (size_t i = 0; i < num_workers_; ++i) { - threads_.emplace_back([this]() { - while (1) { - std::unique_lock lock(mutex_); - cv_.wait(lock, [this] { return !jobs_.empty() || stop_; }); - - if (stop_ && jobs_.empty()) - return; - - JobInfo job_info(std::move(jobs_.front())); - jobs_.pop(); - lock.unlock(); - job_info.func(job_info.arg); - } - }); - } -} - -ThreadPool::~ThreadPool() { - std::unique_lock lock(mutex_); - stop_ = true; - lock.unlock(); - cv_.notify_all(); - for (auto& thread : threads_) { - thread.join(); - } -} - -std::future ThreadPool::issue( - std::function func, - void* arg) { - std::unique_lock lock(mutex_); - jobs_.push(JobInfo(std::packaged_task(func), arg)); - std::future f = std::move(jobs_.back().func.get_future()); - lock.unlock(); - cv_.notify_one(); - return f; -} - -size_t ThreadPool::num_workers() { - return num_workers_; -} - -} // namespace example diff --git a/examples/qualcomm/qaihub_scripts/llama/runner/io_memory.h b/examples/qualcomm/qaihub_scripts/llama/runner/io_memory.h deleted file mode 100644 index 445be2ed21a..00000000000 --- a/examples/qualcomm/qaihub_scripts/llama/runner/io_memory.h +++ /dev/null @@ -1,168 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include -#include - -#if defined(QAIHUB_LLAMA3_RUNNER) -#define QAIHUB_LLAMA_NUM_HEADS 8 -#define QAIHUB_LLAMA_LOGITS 128256 -#else -#define QAIHUB_LLAMA_NUM_HEADS 32 -#define QAIHUB_LLAMA_LOGITS 32000 -#endif - -namespace example { - -class Memory { - public: - Memory( - const std::vector& pos_embs_path, - std::vector>& modules); - virtual ~Memory(); - virtual void prepare_io( - const std::vector< - executorch::runtime::Result>& - methods_meta) = 0; - virtual void update_io( - int64_t cur_token, - int64_t pos, - std::vector>& output_tensors) = 0; - void* get_mutable_ptr(); - std::vector get_input_tensors(int shard_index); - std::vector get_output_tensors(int shard_index); - - protected: - std::unique_ptr data_ptr_; - std::vector> input_tensors_; - std::vector> output_tensors_; - std::vector pos_embs_path_; - std::vector> modules_; - std::vector method_names_; -}; - -class BertMemory : public Memory { - public: - BertMemory( - const std::vector& pos_embs_path, - std::vector>& modules, - std::vector shard_layers); - void prepare_io(const std::vector>& methods_meta) override; - void update_io( - int64_t cur_token, - int64_t pos, - std::vector>& output_tensors) - override; - struct IO { - int32_t input_ids[1024 * 2]; - uint16_t hidden_state[1024 * 4096]; - uint16_t attention_mask[1024 * 1024]; - uint16_t position_ids_cos[1024 * 64]; - uint16_t position_ids_sin[1024 * 64]; - uint8_t k_cache[32][QAIHUB_LLAMA_NUM_HEADS][128 * 1024]; - uint8_t v_cache[32][QAIHUB_LLAMA_NUM_HEADS][1024 * 128]; - uint16_t logits[QAIHUB_LLAMA_LOGITS]; - }; - - private: - std::unique_ptr input_ids_; - std::unique_ptr hidden_state_; - std::unique_ptr attention_mask_; - std::unique_ptr position_ids_cos_; - std::unique_ptr position_ids_sin_; - std::vector> k_cache_; - std::vector> v_cache_; - std::unique_ptr logits_; - std::vector shard_layers_; - int num_heads_; -}; - -class ThreadPool { - public: - ThreadPool(); - ~ThreadPool(); - - std::future issue(std::function func, void* arg); - size_t num_workers(); - - private: - struct JobInfo { - explicit JobInfo(std::packaged_task&& func, void* arg) - : func(std::move(func)), arg(arg) {} - explicit JobInfo(JobInfo&& job_info) - : func(std::move(job_info.func)), arg(job_info.arg) {} - std::packaged_task func; - void* arg; - }; - size_t num_workers_; - std::vector threads_; - std::queue jobs_; - std::mutex mutex_; - std::condition_variable cv_; - bool stop_; -}; - -class KVCachedMemory : public Memory { - public: - KVCachedMemory( - const std::vector& pos_embs_path, - std::vector>& modules, - std::vector shard_layers); - void prepare_io(const std::vector>& methods_meta) override; - void update_io( - int64_t cur_token, - int64_t pos, - std::vector>& output_tensors) - override; - struct IO { - int32_t input_ids; - uint16_t hidden_state[4096]; - uint16_t attention_mask[1024]; - uint16_t position_ids_cos[1024 * 64]; - uint16_t position_ids_sin[1024 * 64]; - uint8_t k_cache[32][QAIHUB_LLAMA_NUM_HEADS][129 * 1023]; - uint8_t v_cache[32][(QAIHUB_LLAMA_NUM_HEADS + 1) * 1023 * 128]; - uint8_t k_cache_out[32][QAIHUB_LLAMA_NUM_HEADS][128]; - uint16_t logits[QAIHUB_LLAMA_LOGITS]; - }; - struct LoopRange { - int32_t start; - int32_t end; - int32_t step; - }; - - private: - std::unique_ptr input_ids_; - std::unique_ptr hidden_state_; - std::unique_ptr attention_mask_; - std::unique_ptr position_ids_cos_; - std::unique_ptr position_ids_sin_; - std::vector> k_cache_in_; - std::vector> v_cache_in_; - std::vector> k_cache_out_; - std::vector> v_cache_out_; - std::unique_ptr logits_; - std::vector lr_update_kv_; - std::vector> futures_; - ThreadPool thread_pool_; - std::vector shard_layers_; - int num_heads_; -}; - -} // namespace example diff --git a/examples/qualcomm/qaihub_scripts/llama/runner/runner.cpp b/examples/qualcomm/qaihub_scripts/llama/runner/runner.cpp deleted file mode 100644 index 06ea324ef6f..00000000000 --- a/examples/qualcomm/qaihub_scripts/llama/runner/runner.cpp +++ /dev/null @@ -1,418 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -// A simple llama2/3 runner that includes preprocessing and post processing -// logic. The module takes in a string as input and emits a string as output. - -#if defined(QAIHUB_LLAMA3_RUNNER) -#include -#else -#include -#endif -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#if defined(__aarch64__) -#include "arm_neon.h" -#endif - -using executorch::aten::Tensor; -using executorch::extension::Module; -using executorch::extension::llm::Sampler; -using executorch::extension::llm::time_in_ms; -using executorch::runtime::Error; -using executorch::runtime::EValue; -using executorch::runtime::MethodMeta; -using executorch::runtime::Result; - -namespace example { - -namespace { -static constexpr auto kTopp = 0.9f; -void printReport(const Runner::Stats& stats); -std::string statsToJsonString(const Runner::Stats& stats); -} // namespace - -Runner::Runner( - const std::vector& models_path, - const std::vector& pos_embs_path, - const std::vector& shard_layers, - const std::string& tokenizer_path, - const int eval_mode, - const float temperature, - const float logits_scale, - const int logits_offset) - : tokenizer_path_(tokenizer_path), - temperature_(temperature), - n_bos_(1), - n_eos_(1), - vocab_size_(QAIHUB_LLAMA_LOGITS), - max_seq_len_(1024), - eval_mode_(eval_mode), - stats_({}), - logits_scale_(logits_scale), - logits_offset_(logits_offset) { - for (size_t i = 0; i < models_path.size(); ++i) { - modules_.push_back(std::make_shared( - models_path[i], Module::LoadMode::MmapUseMlockIgnoreErrors)); - ET_LOG(Info, "creating module: model_path=%s", models_path[i].c_str()); - } - ET_LOG(Info, "creating runner: tokenizer_path=%s", tokenizer_path_.c_str()); - -// load tokenizer -#if defined(QAIHUB_LLAMA3_RUNNER) - tokenizer_ = example::get_tiktoken_for_llama(); - tokenizer_->load(tokenizer_path_); - eos_id_.insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]); - version_ = LlamaVersion::kLlama3; -#else - tokenizer_ = std::make_unique(); - tokenizer_->load(tokenizer_path_); - version_ = LlamaVersion::kLlama2; -#endif - - bos_id_ = tokenizer_->bos_tok(); - eos_id_.insert(tokenizer_->eos_tok()); - - switch (eval_mode_) { - case EvalMode::kBert: - io_mem_ = - std::make_unique(pos_embs_path, modules_, shard_layers); - break; - case EvalMode::kKVCached: - io_mem_ = std::make_unique( - pos_embs_path, modules_, shard_layers); - break; - default: - ET_CHECK_MSG(false, "unsupported evaluation mode"); - } - ET_LOG(Info, "creating io_memory"); -} - -bool Runner::is_loaded() const { - bool loaded = true; - for (const std::shared_ptr& module : modules_) { - loaded &= module->is_loaded(); - } - return loaded && tokenizer_ && sampler_; -} - -Error Runner::load() { - if (is_loaded()) { - return Error::Ok; - } - for (std::shared_ptr& module : modules_) { - method_names_.emplace_back(*module->method_names()->begin()); - ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(method_names_.back())); - } - - // create sampler - sampler_ = std::make_unique( - vocab_size_, - temperature_, - kTopp, - static_cast(std::time(nullptr))); - - // prepare io - auto methods_meta = get_methods_meta(); - io_mem_->prepare_io(methods_meta); - return Error::Ok; -} - -int32_t Runner::logitsToToken(const Tensor& logits_tensor) { - static std::vector logits_f(vocab_size_); - const uint16_t* logits = logits_tensor.data_ptr(); - -#if defined(__aarch64__) - static int32x4_t offset = vmovq_n_s32(logits_offset_); - static float32x4_t scale = vmovq_n_f32(logits_scale_); - // dequantize - for (int i = 0; i < vocab_size_; i += 4) { - const uint16_t* in = logits + i; - float* out = logits_f.data() + i; - int32_t data[4] = {in[0], in[1], in[2], in[3]}; - int32x4_t quantized = vld1q_s32(data); - int32x4_t shifted = vsubq_s32(quantized, offset); - float32x4_t shifted_f = vcvtq_f32_s32(shifted); - vst1q_f32(out, vmulq_f32(shifted_f, scale)); - } -#else - // dequantize - for (int i = 0; i < vocab_size_; i++) { - logits_f[i] = (logits[i] - logits_offset_) * logits_scale_; - } -#endif - - return sampler_->sample(logits_f.data()); -} - -void Runner::run_model_step(std::vector>& inputs) { - for (size_t i = 0, num_modules = modules_.size(); i < num_modules; ++i) { - Result> outputs_res = - modules_[i]->execute(method_names_[i], inputs[i]); - ET_CHECK_MSG( - outputs_res.error() == Error::Ok, "shard %zu inference failed", i); - } -} - -// TODO: add overloaded method for on-device tokenize -Error Runner::generate( - const std::string& prompt, - const std::string& system_prompt, - int32_t seq_len, - std::function token_callback, - std::function stats_callback) { - ET_CHECK_MSG(!prompt.empty(), "prompt cannot be null"); - - std::vector> input_tensors, output_tensors; - std::vector> inputs; - if (!is_loaded()) { - stats_.model_load_start_ms = time_in_ms(); - ET_CHECK_OK_OR_RETURN_ERROR(load()); - for (int i = 0; i < modules_.size(); ++i) { - input_tensors.emplace_back(io_mem_->get_input_tensors(i)); - output_tensors.emplace_back(io_mem_->get_output_tensors(i)); - for (size_t j = 0; j < output_tensors[i].size(); ++j) { - ET_CHECK_MSG( - modules_[i]->set_output( - method_names_[i], output_tensors[i][j], j) == Error::Ok, - "failed to set output tensor for module %d's %zu'th output", - i, - j); - } - inputs.emplace_back( - std::vector(begin(input_tensors[i]), end(input_tensors[i]))); - } - stats_.model_load_end_ms = time_in_ms(); - } - - stats_.inference_start_ms = time_in_ms(); - seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_; - - std::string post_process_prompt; - switch (version_) { - case LlamaVersion::kLlama2: - post_process_prompt.append(prompt); - break; - case LlamaVersion::kLlama3: - if (!system_prompt.empty()) { - post_process_prompt.append( - "<|start_header_id|>system<|end_header_id|>\n\n"); - post_process_prompt.append(system_prompt); - post_process_prompt.append("<|eot_id|>\n"); - } - post_process_prompt.append( - "<|start_header_id|>user<|end_header_id|>\n\n"); - post_process_prompt.append(prompt); - post_process_prompt.append( - "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"); - // tokenizer_->encode will add <|begin_of_text|> token for us. - // For now, do token call back so the output format looks the same as - // llama3 model card. - if (token_callback && eval_mode_ == EvalMode::kKVCached) { - token_callback("<|begin_of_text|>"); - } - break; - default: - ET_CHECK_MSG(false, "unsupported llama version"); - break; - } - - tokenizers::Result> encode_res = - tokenizer_->encode(post_process_prompt, n_bos_, 0); - ET_CHECK_TK_OK_OR_RETURN_ERROR( - encode_res.error(), - "failed to encode prompt %s", - post_process_prompt.c_str()); - - std::vector prompt_tokens = encode_res.get(); - int num_prompt_tokens = prompt_tokens.size(); - ET_CHECK_MSG(num_prompt_tokens < max_seq_len_, "max seq length exceeded"); - ET_CHECK_MSG( - num_prompt_tokens < seq_len, - "sequence length exceeded - please increase the seq_len value"); - - int64_t pos = 0, prev_token, cur_token = prompt_tokens[0]; - if (eval_mode_ == EvalMode::kBert) { - BertMemory::IO* ptr = - static_cast(io_mem_->get_mutable_ptr()); - - int start_index = max_seq_len_ - num_prompt_tokens; - // indices are filled from behind, take 3 tokens as an example: - // > tokens : [...tok_pad, tok_bos, tok1, tok2] - // > indices: [0.....1020, 1021, 1022, 1023] - for (int i = 0; i < num_prompt_tokens; i++) { - ptr->input_ids[start_index + i] = static_cast(prompt_tokens[i]); - } - // causal attention mask is filled as following: - // 0, 65535 maps to -100.0, 0.0 after dequantizing - // 0 : [0,...................0, 0, 0, 0] - // 1-1019 : ... - // 1020 : [0,...............65535, 0, 0, 0] - // 1021 : [0,...............65535, 65535, 0, 0] - // 1022 : [0,...............65535, 65535, 65535, 0] - // 1023 : [0,...............65535, 65535, 65535, 65535] - for (int i = max_seq_len_ - 1, len = num_prompt_tokens; len >= 0; - --i, --len) { - for (int j = 0; j <= len; ++j) { - ptr->attention_mask[i * max_seq_len_ + start_index - 1 + j] = 65535; - } - } - pos = num_prompt_tokens - 1; - cur_token = prompt_tokens[pos]; - } else if (eval_mode_ == EvalMode::kKVCached) { - KVCachedMemory::IO* ptr = - static_cast(io_mem_->get_mutable_ptr()); - ptr->input_ids = static_cast(cur_token); - ptr->attention_mask[max_seq_len_ - 1] = 65535; - } - - while (pos < seq_len - 1) { - // inference - run_model_step(inputs); - Tensor& logits_tensor = output_tensors.back().back(); - - if (pos == num_prompt_tokens) { - stats_.first_token_ms = time_in_ms(); - } else if (pos == num_prompt_tokens - 1) { - stats_.prompt_eval_end_ms = time_in_ms(); - } - - long sample_start_time_ms = time_in_ms(); - prev_token = cur_token; - cur_token = logitsToToken(logits_tensor); - stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms; - - if (pos < num_prompt_tokens - 1) { - cur_token = prompt_tokens[pos + 1]; - } - io_mem_->update_io(cur_token, ++pos, output_tensors); - - auto piece_res = tokenizer_->decode(prev_token, cur_token); - ET_CHECK(piece_res.ok()); - - if (token_callback) { - token_callback(piece_res.get().c_str()); - } - - if (pos >= num_prompt_tokens && eos_id_.count(cur_token) > 0) { - ET_LOG(Info, "\nReached to the end of generation"); - break; - } - } - stats_.inference_end_ms = time_in_ms(); - - if (pos == seq_len) { - ET_LOG(Info, "\nSequence length (%i tokens) reached!", seq_len); - } - - stats_.num_prompt_tokens = num_prompt_tokens; - stats_.num_generated_tokens = pos - num_prompt_tokens; - printReport(stats_); - if (stats_callback) { - stats_callback(stats_); - } - - return Error::Ok; -} - -namespace { -void printReport(const Runner::Stats& stats) { - printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str()); - - ET_LOG( - Info, - "\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64, - stats.num_prompt_tokens, - stats.num_generated_tokens); - - ET_LOG( - Info, - "\tModel Load Time:\t\t%f (seconds)", - ((double)(stats.model_load_end_ms - stats.model_load_start_ms) / - stats.SCALING_FACTOR_UNITS_PER_SECOND)); - double inference_time_ms = - (double)(stats.inference_end_ms - stats.inference_start_ms); - ET_LOG( - Info, - "\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)", - inference_time_ms / stats.SCALING_FACTOR_UNITS_PER_SECOND, - - (stats.num_generated_tokens) / - (double)(stats.inference_end_ms - stats.inference_start_ms) * - stats.SCALING_FACTOR_UNITS_PER_SECOND); - double prompt_eval_time = - (double)(stats.prompt_eval_end_ms - stats.inference_start_ms); - ET_LOG( - Info, - "\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)", - prompt_eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND, - (stats.num_prompt_tokens) / prompt_eval_time * - stats.SCALING_FACTOR_UNITS_PER_SECOND); - - double eval_time = - (double)(stats.inference_end_ms - stats.prompt_eval_end_ms); - ET_LOG( - Info, - "\t\tGenerated %" PRIu64 - " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)", - stats.num_generated_tokens, - eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND, - stats.num_generated_tokens / eval_time * - stats.SCALING_FACTOR_UNITS_PER_SECOND); - - // Time to first token is measured from the start of inference, excluding - // model load time. - ET_LOG( - Info, - "\tTime to first generated token:\t%f (seconds)", - ((double)(stats.first_token_ms - stats.inference_start_ms) / - stats.SCALING_FACTOR_UNITS_PER_SECOND)); - - ET_LOG( - Info, - "\tSampling time over %" PRIu64 " tokens:\t%f (seconds)", - stats.num_prompt_tokens + stats.num_generated_tokens, - (double)stats.aggregate_sampling_time_ms / - stats.SCALING_FACTOR_UNITS_PER_SECOND); -} - -std::string statsToJsonString(const Runner::Stats& stats) { - std::stringstream ss; - ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << "," - << "\"generated_tokens\":" << stats.num_generated_tokens << "," - << "\"model_load_start_ms\":" << stats.model_load_start_ms << "," - << "\"model_load_end_ms\":" << stats.model_load_end_ms << "," - << "\"inference_start_ms\":" << stats.inference_start_ms << "," - << "\"inference_end_ms\":" << stats.inference_end_ms << "," - << "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << "," - << "\"first_token_ms\":" << stats.first_token_ms << "," - << "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms - << "," << "\"SCALING_FACTOR_UNITS_PER_SECOND\":" - << stats.SCALING_FACTOR_UNITS_PER_SECOND << "}"; - return ss.str(); -} -} // namespace - -std::vector> Runner::get_methods_meta() { - std::vector> methods_meta; - methods_meta.reserve(modules_.size()); - for (size_t i = 0; i < modules_.size(); ++i) { - methods_meta.emplace_back(modules_[i]->method_meta(method_names_[i])); - } - return methods_meta; -} -} // namespace example diff --git a/examples/qualcomm/qaihub_scripts/llama/runner/runner.h b/examples/qualcomm/qaihub_scripts/llama/runner/runner.h deleted file mode 100644 index 215930392ba..00000000000 --- a/examples/qualcomm/qaihub_scripts/llama/runner/runner.h +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -// A simple llama2/3 runner that includes preprocessing and post processing -// logic. The module takes in a string as input and emits a string as output. - -#pragma once - -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace example { - -class Runner { - public: - explicit Runner( - const std::vector& models_path, - const std::vector& pos_embs_path, - const std::vector& shard_layers, - const std::string& tokenizer_path, - const int eval_mode, - const float temperature, - const float logits_scale, - const int logits_offset); - - struct Stats { - // Scaling factor for timestamps - in this case, we use ms. - const long SCALING_FACTOR_UNITS_PER_SECOND = 1000; - // Time stamps for the different stages of the execution - // model_load_start_ms: Start of model loading. - long model_load_start_ms; - // model_load_end_ms: End of model loading. - long model_load_end_ms; - // inference_start_ms: Immediately after the model is loaded (or we check - // for model load), measure the inference time. - long inference_start_ms; - // prompt_eval_end_ms: Prompt array allocation and tokenization. Ends right - // before the inference loop starts - long prompt_eval_end_ms; - // first_token: Timestamp when the first generated token is emitted - long first_token_ms; - // inference_end_ms: End of inference/generation. - long inference_end_ms; - // Keep a running total of the time spent in sampling. - long aggregate_sampling_time_ms = 0; - // Token count from prompt - int64_t num_prompt_tokens; - // Token count from generated (total - prompt) - int64_t num_generated_tokens; - }; - - bool is_loaded() const; - executorch::runtime::Error load(); - executorch::runtime::Error generate( - const std::string& prompt, - const std::string& system_prompt, - int32_t seq_len, - std::function token_callback = {}, - std::function stats_callback = {}); - void stop(); - std::vector> - get_methods_meta(); - - private: - enum EvalMode { - kBert = 0, - kKVCached, - kUnsupported, - }; - - enum LlamaVersion { - kLlama2 = 0, - kLlama3, - }; - - int32_t logitsToToken(const executorch::aten::Tensor& logits_tensor); - void run_model_step( - std::vector>& inputs); - // metadata - int32_t bos_id_; - std::unordered_set eos_id_; - const int32_t n_bos_; - const int32_t n_eos_; - const int32_t vocab_size_; - const int32_t max_seq_len_; - int32_t eval_mode_; - std::vector> modules_; - std::vector method_names_; - std::string tokenizer_path_; - float temperature_; - std::unique_ptr tokenizer_; - std::unique_ptr sampler_; - Stats stats_; - std::unique_ptr io_mem_; - const float logits_scale_; - const int32_t logits_offset_; - LlamaVersion version_; -}; - -} // namespace example diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt b/examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt deleted file mode 100644 index 5b63a6678fc..00000000000 --- a/examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# preprocess qaihub_stable_diffusion_runner_src files -set(_qaihub_stable_diffusion_runner__srcs - ${CMAKE_CURRENT_LIST_DIR}/qaihub_stable_diffusion_runner.cpp - ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp - ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h -) - -# build qaihub_stable_diffusion_runner -add_executable( - qaihub_stable_diffusion_runner ${_qaihub_stable_diffusion_runner__srcs} -) -target_include_directories( - qaihub_stable_diffusion_runner PUBLIC ${_common_include_directories} -) -target_link_libraries( - qaihub_stable_diffusion_runner - qnn_executorch_backend - executorch_core - extension_data_loader - extension_flat_tensor - extension_module - extension_tensor - gflags - re2::re2 -) -target_compile_options( - qaihub_stable_diffusion_runner PUBLIC ${_common_compile_options} -) -set_target_properties( - qaihub_stable_diffusion_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'" -) diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/README.md b/examples/qualcomm/qaihub_scripts/stable_diffusion/README.md deleted file mode 100644 index 40a911ce280..00000000000 --- a/examples/qualcomm/qaihub_scripts/stable_diffusion/README.md +++ /dev/null @@ -1,38 +0,0 @@ -# Summary - -## Overview -This file provides you the instructions to run Stable-Diffusion-v2.1 with different parameters via Qualcomm HTP backend. We will demonstrate how to run Stable Diffusion v2.1 on mobile devices using context binaries from Qualcomm AI Hub’s Stable Diffusion v2.1 - -Please check corresponding section for more information. - -## Stable-Diffusion-v2.1 -The model architecture, scheduler, and time embedding are from the [stabilityai/stable-diffusion-2-1-base](https://huggingface.co/stabilityai/stable-diffusion-2-1-base). - -### Instructions -#### Step 1: Setup -1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch. -2. Follow the [tutorial](https://pytorch.org/executorch/main/backends-qualcomm) to build Qualcomm AI Engine Direct Backend. - -#### Step2: Prepare Model -1. Download the context binaries for TextEncoder, UNet, and VAEDecoder under https://huggingface.co/qualcomm/Stable-Diffusion-v2.1/tree/main -2. Download vocab.json under https://huggingface.co/openai/clip-vit-base-patch32/tree/main - - -#### Step3: Install Requirements -Before running the code, you need to install the necessary Python packages. - -We have verified the code with `diffusers`==0.29.0 and `piq`==0.8.0. Please follow the instructions here to install the required items: -```bash -sh examples/qualcomm/qaihub_scripts/stable_diffusion/install_requirements.sh -``` - -#### Step4: Verify context binary's version -Please refer to [Check context binary version](../../README.md#check-context-binary-version) for more info on why and how to verify the context binary's version - -#### Step5: Run default example -In this example, we execute the script for 20 time steps with the `prompt` 'a photo of an astronaut riding a horse on mars': -```bash -python examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py -b build-android -m ${SOC_MODEL} --s ${SERIAL_NUM} --text_encoder_bin ${PATH_TO_TEXT_ENCODER_CONTEXT_BINARY} --unet_bin ${PATH_TO_UNET_CONTEXT_BINARY} --vae_bin ${PATH_TO_VAE_CONTEXT_BINARY} --vocab_json ${PATH_TO_VOCAB_JSON_FILE} --num_time_steps 20 --prompt "a photo of an astronaut riding a horse on mars" -``` -- Please replace `${PATH_TO_TEXT_ENCODER_CONTEXT_BINARY}`, `${PATH_TO_UNET_CONTEXT_BINARY}`, and `${PATH_TO_VAE_CONTEXT_BINARY}` with the actual paths to your AI Hub context binary files. -- Please replace `${PATH_TO_VOCAB_JSON_FILE}` with the actual path to your vocab.json file. diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/install_requirements.sh b/examples/qualcomm/qaihub_scripts/stable_diffusion/install_requirements.sh deleted file mode 100755 index bbb4767bee3..00000000000 --- a/examples/qualcomm/qaihub_scripts/stable_diffusion/install_requirements.sh +++ /dev/null @@ -1,3 +0,0 @@ -# For Stable Diffusion V2.1 -pip install diffusers==0.29.0 -pip install piq==0.8.0 diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py b/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py deleted file mode 100644 index 931056c5444..00000000000 --- a/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion.py +++ /dev/null @@ -1,403 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import json -import os -from multiprocessing.connection import Client - -import numpy as np -import piq -import torch -from diffusers import EulerDiscreteScheduler, UNet2DConditionModel -from diffusers.models.embeddings import get_timestep_embedding -from executorch.backends.qualcomm.export_utils import ( - QnnConfig, - setup_common_args_and_variables, - SimpleADB, -) - -from executorch.backends.qualcomm.utils.utils import ( - ExecutorchBackendConfig, - from_context_binary, - generate_htp_compiler_spec, - generate_qnn_executorch_compiler_spec, - get_soc_to_chipset_map, - QcomChipset, -) - -from executorch.examples.qualcomm.qaihub_scripts.stable_diffusion.stable_diffusion_lib import ( - StableDiffusion, -) -from executorch.examples.qualcomm.qaihub_scripts.utils.utils import ( - gen_pte_from_ctx_bin, - get_encoding, -) -from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass -from PIL import Image -from torchvision.transforms import ToTensor - -target_names = ("text_encoder", "unet", "vae") - - -def get_quant_data( - encoding: dict, data: torch.Tensor, input_model: str, input_index: int -): - scale = encoding[f"{input_model}_input"]["scale"][input_index] - offset = encoding[f"{input_model}_input"]["offset"][input_index] - if offset < 0: - quant_data = data.div(scale).sub(offset).clip(min=0, max=65535).detach() - else: - quant_data = data.div(scale).add(offset).clip(min=0, max=65535).detach() - - return quant_data.to(dtype=torch.uint16) - - -def get_encodings( - path_to_shard_encoder: str, - path_to_shard_unet: str, - path_to_shard_vae: str, - compiler_specs, -): - text_encoder_encoding = get_encoding( - path_to_shard=path_to_shard_encoder, - compiler_specs=compiler_specs, - get_input=False, - get_output=True, - num_input=1, - num_output=1, - ) - unet_encoding = get_encoding( - path_to_shard=path_to_shard_unet, - compiler_specs=compiler_specs, - get_input=True, - get_output=True, - num_input=3, - num_output=1, - ) - vae_encoding = get_encoding( - path_to_shard=path_to_shard_vae, - compiler_specs=compiler_specs, - get_input=True, - get_output=True, - num_input=1, - num_output=1, - ) - - return ( - text_encoder_encoding[0], - unet_encoding[0], - unet_encoding[1], - vae_encoding[0], - vae_encoding[1], - ) - - -def get_time_embedding(timestep, time_embedding): - timestep = torch.tensor([timestep]) - t_emb = get_timestep_embedding(timestep, 320, True, 0) - emb = time_embedding(t_emb) - - return emb - - -def build_args_parser(): - parser = setup_common_args_and_variables() - - parser.add_argument( - "-a", - "--artifact", - help="Path for storing generated artifacts by this example. Default ./stable_diffusion_qai_hub", - default="./stable_diffusion_qai_hub", - type=str, - ) - - parser.add_argument( - "--pte_prefix", - help="Prefix of pte files name. Default qaihub_stable_diffusion", - default="qaihub_stable_diffusion", - type=str, - ) - - parser.add_argument( - "--text_encoder_bin", - type=str, - default=None, - help="[For AI hub ctx binary] Path to Text Encoder.", - required=True, - ) - - parser.add_argument( - "--unet_bin", - type=str, - default=None, - help="[For AI hub ctx binary] Path to UNet.", - required=True, - ) - - parser.add_argument( - "--vae_bin", - type=str, - default=None, - help="[For AI hub ctx binary] Path to Vae Decoder.", - required=True, - ) - - parser.add_argument( - "--prompt", - default="a photo of an astronaut riding a horse on mars", - type=str, - help="Prompt to generate image from.", - ) - - parser.add_argument( - "--num_time_steps", - default=20, - type=int, - help="The number of diffusion time steps.", - ) - - parser.add_argument( - "--guidance_scale", - type=float, - default=7.5, - help="Strength of guidance (higher means more influence from prompt).", - ) - - parser.add_argument( - "--vocab_json", - type=str, - help="Path to tokenizer vocab.json file. Can get vocab.json under https://huggingface.co/openai/clip-vit-base-patch32/tree/main", - required=True, - ) - - parser.add_argument( - "--fix_latents", - help="Enable this option to fix the latents in the unet diffuse step.", - action="store_true", - ) - - return parser - - -def broadcast_ut_result(output_image, seed): - sd = StableDiffusion(seed) - to_tensor = ToTensor() - target = sd(args.prompt, 512, 512, args.num_time_steps) - target = to_tensor(target).unsqueeze(0) - output_tensor = to_tensor( - Image.fromarray(np.round(output_image[0] * 255).astype(np.uint8)[0]) - ).unsqueeze(0) - - psnr_piq = piq.psnr(target, output_tensor) - ssim_piq = piq.ssim(target, output_tensor) - print(f"PSNR: {round(psnr_piq.item(), 3)}, SSIM: {round(ssim_piq.item(), 3)}") - if args.ip and args.port != -1: - with Client((args.ip, args.port)) as conn: - conn.send(json.dumps({"PSNR": psnr_piq.item(), "SSIM": ssim_piq.item()})) - - -def save_result(output_image): - img = Image.fromarray(np.round(output_image[0] * 255).astype(np.uint8)[0]) - save_path = f"{args.artifact}/outputs/output_image.jpg" - img.save(save_path) - print(f"Output image saved at {save_path}") - - -def inference(args, qnn_config, compiler_specs, pte_files): - # Loading a pretrained EulerDiscreteScheduler from the https://huggingface.co/stabilityai/stable-diffusion-2-1-base. # @lint-ignore - scheduler = EulerDiscreteScheduler.from_pretrained( - "stabilityai/stable-diffusion-2-1-base", subfolder="scheduler", revision="main" - ) - - # Loading a pretrained UNet2DConditionModel (which includes the time embedding) from the https://huggingface.co/stabilityai/stable-diffusion-2-1-base. # @lint-ignore - time_embedding = UNet2DConditionModel.from_pretrained( - "stabilityai/stable-diffusion-2-1-base", subfolder="unet", revision="main" - ).time_embedding - - scheduler.set_timesteps(args.num_time_steps) - scheduler.config.prediction_type = "epsilon" - # Get encoding of unet and vae - ( - encoder_output, - unet_input, - unet_output, - vae_input, - vae_output, - ) = get_encodings( - args.text_encoder_bin, - args.unet_bin, - args.vae_bin, - compiler_specs, - ) - encoding = { - "encoder_output": encoder_output, - "unet_input": unet_input, - "unet_output": unet_output, - "vae_input": vae_input, - "vae_output": vae_output, - } - - adb = SimpleADB( - qnn_config=qnn_config, - pte_path=pte_files, - workspace=f"/data/local/tmp/executorch/{args.pte_prefix}", - runner="examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner", - ) - - input_unet = () - - for t in scheduler.timesteps: - time_emb = get_quant_data( - encoding, get_time_embedding(t, time_embedding), "unet", 1 - ) - input_unet = input_unet + (time_emb,) - - qnn_executor_runner_args = [ - f"--text_encoder_path {adb.workspace}/{args.pte_prefix}_text_encoder.pte", - f"--unet_path {adb.workspace}/{args.pte_prefix}_unet.pte", - f"--vae_path {adb.workspace}/{args.pte_prefix}_vae.pte", - f"--input_list_path {adb.workspace}/input_list.txt", - f"--output_folder_path {adb.output_folder}", - f'--prompt "{args.prompt}"', - f"--guidance_scale {args.guidance_scale}", - f"--num_time_steps {args.num_time_steps}", - f"--vocab_json {adb.workspace}/vocab.json", - ] - if args.fix_latents: - qnn_executor_runner_args.append("--fix_latents") - - text_encoder_output_scale = encoding["encoder_output"]["scale"][0] - text_encoder_output_offset = encoding["encoder_output"]["offset"][0] - unet_input_latent_scale = encoding["unet_input"]["scale"][0] - unet_input_latent_offset = encoding["unet_input"]["offset"][0] - unet_input_text_emb_scale = encoding["unet_input"]["scale"][2] - unet_input_text_emb_offset = encoding["unet_input"]["offset"][2] - unet_output_scale = encoding["unet_output"]["scale"][0] - unet_output_offset = encoding["unet_output"]["offset"][0] - vae_input_scale = encoding["vae_input"]["scale"][0] - vae_input_offset = encoding["vae_input"]["offset"][0] - vae_output_scale = encoding["vae_output"]["scale"][0] - vae_output_offset = encoding["vae_output"]["offset"][0] - - qnn_executor_runner_args = qnn_executor_runner_args + [ - f"--text_encoder_output_scale {text_encoder_output_scale}", - f"--text_encoder_output_offset {text_encoder_output_offset}", - f"--unet_input_latent_scale {unet_input_latent_scale}", - f"--unet_input_latent_offset {unet_input_latent_offset}", - f"--unet_input_text_emb_scale {unet_input_text_emb_scale}", - f"--unet_input_text_emb_offset {unet_input_text_emb_offset}", - f"--unet_output_scale {unet_output_scale}", - f"--unet_output_offset {unet_output_offset}", - f"--vae_input_scale {vae_input_scale}", - f"--vae_input_offset {vae_input_offset}", - f"--vae_output_scale {vae_output_scale}", - f"--vae_output_offset {vae_output_offset}", - ] - - qnn_executor_runner_args = " ".join( - [ - f"cd {adb.workspace} &&", - f"./qaihub_stable_diffusion_runner {' '.join(qnn_executor_runner_args)}", - ] - ) - - files = [args.vocab_json] - - if args.fix_latents: - seed = 42 - latents = torch.randn((1, 4, 64, 64), generator=torch.manual_seed(seed)).to( - "cpu" - ) - # We need to explicitly permute after init tensor or else the random value will be different - latents = latents.permute(0, 2, 3, 1).contiguous() - latents = latents * scheduler.init_noise_sigma - flattened_tensor = latents.view(-1) - # Save the flattened tensor to a .raw file - with open(os.path.join(args.artifact, "latents.raw"), "wb") as file: - file.write(flattened_tensor.numpy().tobytes()) - files.append(os.path.join(args.artifact, "latents.raw")) - - adb.push(inputs=input_unet, files=files) - adb.execute(custom_runner_cmd=qnn_executor_runner_args) - - output_image = [] - - def post_process_vae(): - with open(f"{args.artifact}/outputs/output_0_0.raw", "rb") as f: - output_image.append( - np.fromfile(f, dtype=np.float32).reshape(1, 512, 512, 3) - ) - - adb.pull(host_output_path=args.artifact, callback=post_process_vae) - - if args.fix_latents: - broadcast_ut_result(output_image, seed) - else: - save_result(output_image) - - -def main(args): - qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) - os.makedirs(args.artifact, exist_ok=True) - # common part for compile & inference - backend_options = generate_htp_compiler_spec( - use_fp16=False, - use_multi_contexts=True, - ) - compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=getattr(QcomChipset, args.soc_model), - backend_options=backend_options, - is_from_context_binary=True, - ) - - if args.pre_gen_pte is None: - # Create custom operators as context loader - soc_model = get_soc_to_chipset_map()[args.soc_model] - bundle_programs = [ - from_context_binary(args.text_encoder_bin, "ctx_loader_0", soc_model), - from_context_binary(args.unet_bin, "ctx_loader_1", soc_model), - from_context_binary(args.vae_bin, "ctx_loader_2", soc_model), - ] - pte_names = [f"{args.pte_prefix}_{target_name}" for target_name in target_names] - memory_planning_pass = MemoryPlanningPass( - alloc_graph_input=False, - alloc_graph_output=False, - ) - pte_files = gen_pte_from_ctx_bin( - artifact=args.artifact, - pte_names=pte_names, - bundle_programs=bundle_programs, - backend_config=ExecutorchBackendConfig( - memory_planning_pass=memory_planning_pass - ), - ) - assert ( - len(pte_files) == 3 - ), f"Error: Expected 3 PTE files, but got {len(pte_files)} files." - - else: - pte_files = [ - f"{args.pre_gen_pte}/{args.pte_prefix}_{target_name}.pte" - for target_name in target_names - ] - if args.compile_only: - return - - inference(args, qnn_config, compiler_specs, pte_files) - - -if __name__ == "__main__": # noqa: C901 - parser = build_args_parser() - args = parser.parse_args() - - try: - main(args) - except Exception as e: - if args.ip and args.port != -1: - with Client((args.ip, args.port)) as conn: - conn.send(json.dumps({"Error": str(e)})) - else: - raise Exception(e) diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner.cpp b/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner.cpp deleted file mode 100644 index 9c15ceadf8a..00000000000 --- a/examples/qualcomm/qaihub_scripts/stable_diffusion/qaihub_stable_diffusion_runner.cpp +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -DEFINE_string( - text_encoder_path, - "qaihub_stable_diffusion_text_encoder.pte", - "Text Encoder Model serialized in flatbuffer format."); -DEFINE_string( - unet_path, - "qaihub_stable_diffusion_unet.pte", - "Unet Model serialized in flatbuffer format."); -DEFINE_string( - vae_path, - "qaihub_stable_diffusion_vae.pte", - "Vae Model serialized in flatbuffer format."); -DEFINE_string( - output_folder_path, - "outputs", - "Executorch inference data output path."); -DEFINE_string( - input_list_path, - "input_list.txt", - "Input list storing time embedding."); -DEFINE_string( - vocab_json, - "vocab.json", - "Json path to retrieve a list of vocabs."); -DEFINE_string( - prompt, - "a photo of an astronaut riding a horse on mars", - "User input prompt"); -DEFINE_int32(num_time_steps, 20, "Number of time steps."); -DEFINE_double(guidance_scale, 7.5, "Guidance Scale"); - -DEFINE_double(text_encoder_output_scale, 0.0, "Text encoder output scale"); -DEFINE_int32(text_encoder_output_offset, 0, "Text encoder output offset"); -DEFINE_double(unet_input_latent_scale, 0.0, "Unet input latent scale"); -DEFINE_int32(unet_input_latent_offset, 0, "Unet input latent offset"); -DEFINE_double(unet_input_text_emb_scale, 0.0, "Unet input text emb scale"); -DEFINE_int32(unet_input_text_emb_offset, 0, "Unet input text emb offset"); -DEFINE_double(unet_output_scale, 0.0, "Unet output scale"); -DEFINE_int32(unet_output_offset, 0, "Unet output offset"); -DEFINE_double(vae_input_scale, 0.0, "Vae input scale"); -DEFINE_int32(vae_input_offset, 0, "Vae input offset"); -DEFINE_double(vae_output_scale, 0.0, "Vae output scale"); -DEFINE_int32(vae_output_offset, 0, "Vae output offset"); -DEFINE_bool( - fix_latents, - false, - "Enable this option to fix the latents in the unet diffuse step."); - -void usage_message() { - std::string usage_message = - "This is a sample executor runner capable of executing stable diffusion models." - "Users will need binary .pte program files for text_encoder, unet, and vae. Below are the options to retrieve required .pte program files:\n" - "For further information on how to generate the .pte program files and example command to execute this runner, please refer to qaihub_stable_diffsion.py."; - gflags::SetUsageMessage(usage_message); -} - -using executorch::runtime::Error; - -int main(int argc, char** argv) { - executorch::runtime::runtime_init(); - usage_message(); - gflags::ParseCommandLineFlags(&argc, &argv, true); - bool is_default = - gflags::GetCommandLineFlagInfoOrDie("text_encoder_output_scale") - .is_default || - gflags::GetCommandLineFlagInfoOrDie("text_encoder_output_offset") - .is_default || - gflags::GetCommandLineFlagInfoOrDie("unet_input_latent_scale") - .is_default || - gflags::GetCommandLineFlagInfoOrDie("unet_input_latent_offset") - .is_default || - gflags::GetCommandLineFlagInfoOrDie("unet_input_text_emb_scale") - .is_default || - gflags::GetCommandLineFlagInfoOrDie("unet_input_text_emb_offset") - .is_default || - gflags::GetCommandLineFlagInfoOrDie("unet_output_scale").is_default || - gflags::GetCommandLineFlagInfoOrDie("unet_output_offset").is_default || - gflags::GetCommandLineFlagInfoOrDie("vae_input_scale").is_default || - gflags::GetCommandLineFlagInfoOrDie("vae_input_offset").is_default || - gflags::GetCommandLineFlagInfoOrDie("vae_output_scale").is_default || - gflags::GetCommandLineFlagInfoOrDie("vae_output_offset").is_default; - - ET_CHECK_MSG( - !is_default, - "Please provide scale and offset for unet latent input, unet output, and vae input/output." - "Please refer to qaihub_stable_diffusion.py if you are unsure how to retrieve these values."); - - ET_LOG(Info, "Stable Diffusion runner started"); - std::vector models_path = { - FLAGS_text_encoder_path, FLAGS_unet_path, FLAGS_vae_path}; - - // Create stable_diffusion_runner - example::Runner runner( - models_path, - FLAGS_num_time_steps, - FLAGS_guidance_scale, - FLAGS_text_encoder_output_scale, - FLAGS_text_encoder_output_offset, - FLAGS_unet_input_latent_scale, - FLAGS_unet_input_latent_offset, - FLAGS_unet_input_text_emb_scale, - FLAGS_unet_input_text_emb_offset, - FLAGS_unet_output_scale, - FLAGS_unet_output_offset, - FLAGS_vae_input_scale, - FLAGS_vae_input_offset, - FLAGS_vae_output_scale, - FLAGS_vae_output_offset, - FLAGS_output_folder_path, - FLAGS_fix_latents); - - ET_CHECK_MSG( - runner.init_tokenizer(FLAGS_vocab_json) == Error::Ok, - "Runner failed to init tokenizer"); - - ET_CHECK_MSG(runner.load() == Error::Ok, "Runner failed to load method"); - - ET_CHECK_MSG( - runner.parse_input_list(FLAGS_input_list_path) == Error::Ok, - "Failed to parse time embedding input list"); - ET_CHECK_MSG( - runner.generate(FLAGS_prompt) == Error::Ok, "Runner failed to generate"); - - ET_CHECK_MSG( - runner.print_performance() == Error::Ok, - "Runner failed to print performance"); - - return 0; -} diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.cpp b/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.cpp deleted file mode 100644 index 585d58b21ee..00000000000 --- a/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.cpp +++ /dev/null @@ -1,617 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -// A simple stable diffusion runner that includes preprocessing and post -// processing logic. The module takes in a string as input and emits a tensor as -// output. - -#include -#include -#include - -#include -#include -#include -#include - -#include -#include - -using executorch::extension::from_blob; -using executorch::extension::Module; -using executorch::extension::TensorPtr; -using executorch::extension::llm::time_in_ms; -using executorch::runtime::Error; -using executorch::runtime::MethodMeta; -using executorch::runtime::Result; - -namespace example { - -Runner::Runner( - const std::vector& models_path, - const int num_time_steps, - const float guidance_scale, - const float text_encoder_output_scale, - const int text_encoder_output_offset, - const float unet_input_latent_scale, - const int unet_input_latent_offset, - const float unet_input_text_emb_scale, - const float unet_input_text_emb_offset, - const float unet_output_scale, - const int unet_output_offset, - const float vae_input_scale, - const int vae_input_offset, - const float vae_output_scale, - const int vae_output_offset, - const std::string output_path, - const bool fix_latents) - : num_time_steps_(num_time_steps), - guidance_scale_(guidance_scale), - text_encoder_output_scale_(text_encoder_output_scale), - text_encoder_output_offset_(text_encoder_output_offset), - unet_input_latent_scale_(unet_input_latent_scale), - unet_input_latent_offset_(unet_input_latent_offset), - unet_input_text_emb_scale_(unet_input_text_emb_scale), - unet_input_text_emb_offset_(unet_input_text_emb_offset), - unet_output_scale_(unet_output_scale), - unet_output_offset_(unet_output_offset), - vae_input_scale_(vae_input_scale), - vae_input_offset_(vae_input_offset), - vae_output_scale_(vae_output_scale), - vae_output_offset_(vae_output_offset), - output_path_(output_path), - fix_latents_(fix_latents) { - for (int i = 0; i < models_path.size(); i++) { - modules_.push_back(std::make_unique( - models_path[i], Module::LoadMode::MmapUseMlockIgnoreErrors)); - ET_LOG(Info, "creating module: model_path=%s", models_path[i].c_str()); - } -} - -std::vector> Runner::get_methods_meta() { - std::vector> methods_meta; - for (size_t i = 0; i < modules_.size(); ++i) { - methods_meta.emplace_back(modules_[i]->method_meta(method_names_[i])); - } - return methods_meta; -} - -bool Runner::is_loaded() const { - bool loaded = true; - for (const std::unique_ptr& module : modules_) { - loaded &= module->is_loaded(); - } - return loaded; -} - -Error Runner::load() { - if (is_loaded()) { - return Error::Ok; - } - stats_.model_load_start_ms = time_in_ms(); - for (auto& module : modules_) { - method_names_.emplace_back(*module->method_names()->begin()); - ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(method_names_.back())); - } - stats_.model_load_end_ms = time_in_ms(); - return Error::Ok; -} - -Error Runner::parse_input_list(std::string& path) { - // Fill in data for input - std::ifstream input_list(path); - time_emb_list_.reserve(num_time_steps_); - ET_CHECK_MSG(input_list.is_open(), "Input list error opening file"); - std::string time_emb_file; - for (int i = 0; i < num_time_steps_; i++) { - std::getline(input_list, time_emb_file); - std::ifstream is; - is.open(time_emb_file, std::ios::binary); - is.seekg(0, std::ios::end); - size_t filesize = is.tellg(); - is.seekg(0, std::ios::beg); - std::vector time_emb; - time_emb.resize(filesize / sizeof(uint16_t)); - is.read(reinterpret_cast(time_emb.data()), filesize); - time_emb_list_.push_back(time_emb); - } - return Error::Ok; -} - -Error Runner::init_tokenizer(const std::string& vocab_json_path) { - ET_LOG(Info, "Loading Tokenizer from json"); - stats_.tokenizer_load_start_ms = time_in_ms(); - std::ifstream fin(vocab_json_path); - auto update_map = [this](std::string& target, std::regex& re) { - std::smatch sm; - std::regex_search(target, sm, re); - // replace special character, please extend this if any cornor case found - std::string text = sm[1]; - std::unordered_map post_process = { - {"\"", std::regex(R"(\\\")")}, - {" ", std::regex(R"()")}, - {"\\", std::regex(R"(\\\\)")}}; - for (auto& p : post_process) { - text = std::regex_replace(text, p.second, p.first); - } - vocab_to_token_map_[text] = std::stoi(sm[2]); - }; - - if (fin.is_open()) { - std::string line, text; - while (getline(fin, line)) { - text += line; - } - fin.close(); - - std::regex re_anchor(R"(\d,\")"); - std::regex re_pattern(R"(\{?\"(.*)\":([\d]+)\}?)"); - auto begin = std::sregex_iterator(text.begin(), text.end(), re_anchor); - auto end = std::sregex_iterator(); - size_t pos = 0; - for (std::sregex_iterator iter = begin; iter != end; ++iter) { - std::smatch match; - size_t len = iter->position() - pos + 1; - std::string target = text.substr(pos, len); - update_map(target, re_pattern); - pos = iter->position() + 1; - } - // process last vocabulary - std::string target = text.substr(pos); - update_map(target, re_pattern); - } - stats_.tokenizer_load_end_ms = time_in_ms(); - return Error::Ok; -} - -std::vector Runner::tokenize(std::string prompt) { - std::string bos("<|startoftext|>"), eos("<|endoftext|>"); - std::vector vocabs; - vocabs.reserve(max_tokens_); - std::vector tokens(1, vocab_to_token_map_[bos]); - - // pretokenize - // ref: https://github.com/monatis/clip.cpp - // https://huggingface.co/openai/clip-vit-base-patch32 - std::string text; - std::regex re( - R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"); - std::smatch sm; - while (std::regex_search(prompt, sm, re)) { - for (auto& v : sm) { - vocabs.push_back(v); - } - prompt = sm.suffix(); - } - for (std::string& v : vocabs) { - std::string word = (v[0] == ' ') ? v.substr(1) : v; - word += " "; - auto iter = vocab_to_token_map_.find(word); - if (iter != vocab_to_token_map_.end()) { - tokens.push_back(iter->second); - continue; - } - for (int i = 0; i < v.size(); ++i) { - for (int j = v.size() - 1; j >= i; --j) { - std::string token = v.substr(i, j - 1 + 1); - auto iter = vocab_to_token_map_.find(token); - if (iter != vocab_to_token_map_.end()) { - tokens.push_back(iter->second); - i = j + 1; - break; - } else if (j == i) { - ET_LOG(Error, "unknown token found: %s", token.c_str()); - } - } - } - } - tokens.push_back(vocab_to_token_map_[eos]); - return tokens; -} - -std::vector Runner::gen_latent_from_file() { - std::vector tensor_vector; - std::ifstream file("latents.raw", std::ios::binary); - if (!file.is_open()) { - ET_LOG(Error, "Error opening file!"); - return tensor_vector; - } - - // Read the tensor data - float value; - while (file.read(reinterpret_cast(&value), sizeof(float))) { - tensor_vector.push_back(value); - } - file.close(); - return tensor_vector; -} - -std::vector Runner::gen_random_latent(float sigma) { - std::random_device rnd_device; - std::mt19937 mersenne_engine{rnd_device()}; - std::normal_distribution dist{0.0f, 1.0f}; - - constexpr int latent_size = 1 * 64 * 64 * 4; - std::vector random_vector(latent_size); - - for (float& value : random_vector) { - value = dist(mersenne_engine) * sigma; - } - return random_vector; -} - -std::vector Runner::get_time_steps() { - std::vector time_steps(num_time_steps_); - for (int i = 0; i < num_time_steps_; ++i) { - time_steps[i] = (num_train_timesteps_ - 1) * - (1.0f - static_cast(i) / (num_time_steps_ - 1)); - } - return time_steps; -} - -std::vector Runner::get_sigmas(const std::vector& time_steps) { - float start = std::sqrt(beta_start_); - float end = std::sqrt(beta_end_); - std::vector betas(num_train_timesteps_); - float step = (end - start) / (num_train_timesteps_ - 1); - for (int i = 0; i < num_train_timesteps_; ++i) { - float value = start + i * step; - betas[i] = 1 - (value * value); - } - - std::vector alphas_cumprod(num_train_timesteps_); - float cumprod = 1.0; - for (int i = 0; i < num_train_timesteps_; ++i) { - cumprod *= betas[i]; - alphas_cumprod[i] = cumprod; - } - - std::vector sigmas(num_train_timesteps_); - for (int i = 0; i < num_train_timesteps_; ++i) { - sigmas[i] = std::sqrt((1.0 - alphas_cumprod[i]) / alphas_cumprod[i]); - } - - std::vector res(time_steps.size()); - for (size_t i = 0; i < time_steps.size(); ++i) { - float index = - static_cast(i) * (sigmas.size() - 1) / (time_steps.size() - 1); - size_t lower_index = static_cast(std::floor(index)); - size_t upper_index = static_cast(std::ceil(index)); - - float weight = index - lower_index; - res[i] = - (1.0 - weight) * sigmas[lower_index] + weight * sigmas[upper_index]; - } - std::reverse(res.begin(), res.end()); - res.push_back(0); - - return res; -} - -void Runner::scale_model_input( - const std::vector& latents, - std::vector& latent_model_input, - float sigma) { - for (int i = 0; i < latents.size(); i++) { - latent_model_input[i] = (latents[i] / std::sqrt(sigma * sigma + 1)); - } -} - -void Runner::quant_tensor( - const std::vector& fp_vec, - std::vector& quant_vec, - float scale, - int offset) { - offset = abs(offset); - for (int i = 0; i < fp_vec.size(); i++) { - quant_vec[i] = static_cast((fp_vec[i] / scale) + offset); - } -} - -void Runner::dequant_tensor( - const std::vector& quant_vec, - std::vector& fp_vec, - float scale, - int offset) { - offset = abs(offset); - for (int i = 0; i < quant_vec.size(); i++) { - fp_vec[i] = (quant_vec[i] - offset) * scale; - } -} - -// Using the same algorithm as EulerDiscreteScheduler in python. -void Runner::step( - const std::vector& model_output, - const std::vector& sigmas, - std::vector& sample, - std::vector& prev_sample, - int step_index) { - float sigma = sigmas[step_index]; - float dt = sigmas[step_index + 1] - sigma; - - for (int i = 0; i < sample.size(); ++i) { - float sigma_hat = sample[i] - (sigma * model_output[i]); - prev_sample[i] = (sample[i] - sigma_hat) / sigma; - prev_sample[i] = sample[i] + (prev_sample[i] * dt); - } - sample = prev_sample; -} - -Error Runner::generate(std::string prompt) { - ET_LOG(Info, "Start generating"); - stats_.generate_start_ms = time_in_ms(); - - // Start tokenize - stats_.tokenizer_parsing_start_ms = time_in_ms(); - std::vector cond_tokens = tokenize(prompt); - cond_tokens.resize(max_tokens_); - std::vector uncond_tokens = tokenize(""); - uncond_tokens.resize(max_tokens_); - stats_.tokenizer_parsing_end_ms = time_in_ms(); - - std::vector> method_metas = get_methods_meta(); - - MethodMeta encoder_method_meta = method_metas[0].get(); - // Initialize text_encoder input tensors: cond/uncond tokenized_input[1,77] - auto cond_tokens_tensor = from_blob( - cond_tokens.data(), - {1, 77}, - encoder_method_meta.input_tensor_meta(0)->scalar_type()); - auto uncond_tokens_tensor = from_blob( - uncond_tokens.data(), - {1, 77}, - encoder_method_meta.input_tensor_meta(0)->scalar_type()); - // Initialize text_encoder output tensors: cond/uncond embedding[1, 77, 1024] - constexpr int emb_size = 1 * 77 * 1024; - std::vector cond_emb_vec(emb_size); - std::vector uncond_emb_vec(emb_size); - std::vector fp_emb_vec(emb_size); - auto cond_emb_tensor = from_blob( - cond_emb_vec.data(), - {1, 77, 1024}, - encoder_method_meta.output_tensor_meta(0)->scalar_type()); - auto uncond_emb_tensor = from_blob( - uncond_emb_vec.data(), - {1, 77, 1024}, - encoder_method_meta.output_tensor_meta(0)->scalar_type()); - auto ret = modules_[0]->set_output(method_names_[0], cond_emb_tensor); - long encoder_start = time_in_ms(); - auto cond_res = modules_[0]->execute(method_names_[0], cond_tokens_tensor); - stats_.text_encoder_execution_time += (time_in_ms() - encoder_start); - ret = modules_[0]->set_output(method_names_[0], uncond_emb_tensor); - encoder_start = time_in_ms(); - auto uncond_res = - modules_[0]->execute(method_names_[0], uncond_tokens_tensor); - stats_.text_encoder_execution_time += (time_in_ms() - encoder_start); - - // Initialize unet parameters - MethodMeta unet_method_meta = method_metas[1].get(); - std::vector time_steps = get_time_steps(); - std::vector sigmas = get_sigmas(time_steps); - float max_sigma = *std::max_element(sigmas.begin(), sigmas.end()); - std::vector latent; - if (fix_latents_) { - latent = gen_latent_from_file(); - } else { - latent = gen_random_latent(max_sigma); - } - std::vector prev_sample(latent.size()); - - // Initialize unet input tensors - // 1. latent[1,64,64,4] - // 2. time_embedding[1,1280] - // 3. cond/uncond embedding[1,77,1024] - std::vector latent_model_input(latent.size()); - std::vector fp_latent_model_input(latent.size()); - auto latent_tensor = from_blob( - latent_model_input.data(), - {1, 64, 64, 4}, - unet_method_meta.input_tensor_meta(0)->scalar_type()); - std::vector time_emb_tensors; - time_emb_tensors.reserve(num_time_steps_); - for (auto step_index = 0; step_index < num_time_steps_; step_index++) { - time_emb_tensors.emplace_back(from_blob( - time_emb_list_[step_index].data(), - {1, 1280}, - unet_method_meta.input_tensor_meta(1)->scalar_type())); - } - // requantize text encoders output - dequant_tensor( - cond_emb_vec, - fp_emb_vec, - text_encoder_output_scale_, - text_encoder_output_offset_); - quant_tensor( - fp_emb_vec, - cond_emb_vec, - unet_input_text_emb_scale_, - unet_input_text_emb_offset_); - dequant_tensor( - uncond_emb_vec, - fp_emb_vec, - text_encoder_output_scale_, - text_encoder_output_offset_); - quant_tensor( - fp_emb_vec, - uncond_emb_vec, - unet_input_text_emb_scale_, - unet_input_text_emb_offset_); - - // Initialize unet output tensors: text/uncond noise_pred[1,64,64,4] - std::vector noise_pred_text(latent.size()); - std::vector noise_pred_uncond(latent.size()); - std::vector fp_noise_pred_text(noise_pred_text.size()); - std::vector fp_noise_pred_uncond(noise_pred_uncond.size()); - auto noise_pred_text_tensor = from_blob( - noise_pred_text.data(), - {1, 64, 64, 4}, - unet_method_meta.output_tensor_meta(0)->scalar_type()); - auto noise_pred_uncond_tensor = from_blob( - noise_pred_uncond.data(), - {1, 64, 64, 4}, - unet_method_meta.output_tensor_meta(0)->scalar_type()); - - // Execute unet - for (int step_index = 0; step_index < num_time_steps_; step_index++) { - long start_post_process = time_in_ms(); - scale_model_input(latent, fp_latent_model_input, sigmas[step_index]); - - quant_tensor( - fp_latent_model_input, - latent_model_input, - unet_input_latent_scale_, - unet_input_latent_offset_); - - stats_.unet_aggregate_post_processing_time += - (time_in_ms() - start_post_process); - ret = modules_[1]->set_output(method_names_[1], noise_pred_text_tensor); - long start_unet_execution = time_in_ms(); - auto cond_res = modules_[1]->execute( - method_names_[1], - {latent_tensor, time_emb_tensors[step_index], cond_emb_tensor}); - stats_.unet_aggregate_execution_time += - (time_in_ms() - start_unet_execution); - ret = modules_[1]->set_output(method_names_[1], noise_pred_uncond_tensor); - start_unet_execution = time_in_ms(); - auto uncond_res = modules_[1]->execute( - method_names_[1], - {latent_tensor, - time_emb_tensors[step_index], - uncond_emb_tensor}); // results in noise_pred_uncond_vec - stats_.unet_aggregate_execution_time += - (time_in_ms() - start_unet_execution); - - // start unet post processing - start_post_process = time_in_ms(); - - dequant_tensor( - noise_pred_text, - fp_noise_pred_text, - unet_output_scale_, - unet_output_offset_); - dequant_tensor( - noise_pred_uncond, - fp_noise_pred_uncond, - unet_output_scale_, - unet_output_offset_); - - for (int i = 0; i < fp_noise_pred_text.size(); i++) { - fp_noise_pred_text[i] = fp_noise_pred_uncond[i] + - guidance_scale_ * (fp_noise_pred_text[i] - fp_noise_pred_uncond[i]); - } - step(fp_noise_pred_text, sigmas, latent, prev_sample, step_index); - stats_.unet_aggregate_post_processing_time += - (time_in_ms() - start_post_process); - } - - // Start VAE - MethodMeta vae_method_meta = method_metas[2].get(); - // Initialize vae input tensor : latent[1,64,64,4] - std::vector vae_input(latent.size()); - auto vae_input_tensor = from_blob( - vae_input.data(), - {1, 64, 64, 4}, - vae_method_meta.input_tensor_meta(0)->scalar_type()); - // Intialize vae output tensor: output[1,512,512,3] - constexpr int image_size = 1 * 512 * 512 * 3; - std::vector q_out(image_size); - std::vector out(image_size); - auto output_tensor = from_blob( - q_out.data(), - {1, 512, 512, 3}, - vae_method_meta.output_tensor_meta(0)->scalar_type()); - - quant_tensor(latent, vae_input, vae_input_scale_, vae_input_offset_); - - ret = modules_[2]->set_output(method_names_[2], output_tensor); - long start_vae_execution = time_in_ms(); - auto vae_res = modules_[2]->execute(method_names_[2], vae_input_tensor); - stats_.vae_execution_time = (time_in_ms() - start_vae_execution); - stats_.generate_end_ms = time_in_ms(); - - // Dequant uint16 output to fp32 output - dequant_tensor(q_out, out, vae_output_scale_, vae_output_offset_); - - // Saving outputs - auto output_file_name = output_path_ + "/output_0_0.raw"; - std::ofstream fout(output_file_name.c_str(), std::ios::binary); - fout.write( - reinterpret_cast(out.data()), out.size() * sizeof(float)); - fout.close(); - - return Error::Ok; -} - -Error Runner::print_performance() { - ET_LOG(Info, "\tTotal Number of steps:\t\t\t\t%d", num_time_steps_); - - ET_LOG( - Info, - "\tTokenizer Load Time:\t\t\t\t%f (seconds)", - ((double)(stats_.tokenizer_load_end_ms - stats_.tokenizer_load_start_ms) / - stats_.SCALING_FACTOR_UNITS_PER_SECOND)); - - ET_LOG( - Info, - "\tModel Load Time:\t\t\t\t%f (seconds)", - ((double)(stats_.model_load_end_ms - stats_.model_load_start_ms) / - stats_.SCALING_FACTOR_UNITS_PER_SECOND)); - - ET_LOG( - Info, - "\tGenerate Time(Tokenize + Encoder + UNet + VAE):\t%f (seconds)", - ((double)(stats_.generate_end_ms - stats_.generate_start_ms) / - stats_.SCALING_FACTOR_UNITS_PER_SECOND)); - - ET_LOG( - Info, - "\tTokenize Time:\t\t\t\t\t%f (seconds)", - ((double)(stats_.tokenizer_parsing_end_ms - - stats_.tokenizer_parsing_start_ms) / - stats_.SCALING_FACTOR_UNITS_PER_SECOND)); - - ET_LOG( - Info, - "\tText Encoder Execution Time:\t\t\t%f (seconds)", - ((double)(stats_.text_encoder_execution_time) / - stats_.SCALING_FACTOR_UNITS_PER_SECOND)); - - ET_LOG( - Info, - "\tUnet Aggregate (Cond + Uncond) Execution Time:\t%f (seconds)", - ((double)stats_.unet_aggregate_execution_time / - (stats_.SCALING_FACTOR_UNITS_PER_SECOND))); - - ET_LOG( - Info, - "\tUnet Average Execution Time:\t\t\t%f (seconds)", - ((double)(stats_.unet_aggregate_execution_time / (num_time_steps_ * 2)) / - (stats_.SCALING_FACTOR_UNITS_PER_SECOND))); - - ET_LOG( - Info, - "\tUnet Aggregate Post-Processing Time:\t\t%f (seconds)", - ((double)(stats_.unet_aggregate_post_processing_time) / - stats_.SCALING_FACTOR_UNITS_PER_SECOND)); - - ET_LOG( - Info, - "\tUnet Average Post-Processing Time:\t\t%f (seconds)", - ((double)(stats_.unet_aggregate_post_processing_time / - (num_time_steps_ * 2)) / - (stats_.SCALING_FACTOR_UNITS_PER_SECOND))); - - ET_LOG( - Info, - "\tVAE Execution Time:\t\t\t\t%f (seconds)", - ((double)(stats_.vae_execution_time) / - stats_.SCALING_FACTOR_UNITS_PER_SECOND)); - return Error::Ok; -} - -} // namespace example diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.h b/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.h deleted file mode 100644 index e49201bca25..00000000000 --- a/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.h +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Copyright (c) Qualcomm Innovation Center, Inc. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -// A simple diffusion runner that includes preprocessing and post processing -// logic. The module takes in a string as input and emites a tensor as output. - -#pragma once - -#include -#include -#include - -#include - -namespace example { - -class Runner { - public: - explicit Runner( - const std::vector& models_path, - const int num_time_steps, - const float guidance_scale, - const float text_encoder_output_scale, - const int text_encoder_output_offset, - const float unet_input_latent_scale, - const int unet_input_latent_offset, - const float unet_input_text_emb_scale, - const float unet_input_text_emb_offset, - const float unet_output_scale, - const int unet_output_offset, - const float vae_input_scale, - const int vae_input_offset, - const float vae_output_scale, - const int vae_output_offset, - const std::string output_path, - const bool fix_latents); - - struct Stats { - // Scaling factor for timestamps - in this case, we use ms. - const long SCALING_FACTOR_UNITS_PER_SECOND = 1000; - // Time stamps for the different stages of the execution - // model_load_start_ms: Model loading time - long model_load_start_ms; - long model_load_end_ms; - - // tokenizer loading time - long tokenizer_load_start_ms = 0; - long tokenizer_load_end_ms = 0; - - // tokenizer parsing time - long tokenizer_parsing_start_ms = 0; - long tokenizer_parsing_end_ms = 0; - - // Total time to run generate - long generate_start_ms = 0; - long generate_end_ms = 0; - - // text encoder execution time - long text_encoder_execution_time = 0; - - // Unet aggregation execution time over n steps for cond + uncond - long unet_aggregate_execution_time = 0; - - // UNet aggregation post processing time over n steps for cond + uncond. - // This is the time from processing unet's output until feeding it into the - // next iteration. - long unet_aggregate_post_processing_time = 0; - - // VAE execution time - long vae_execution_time = 0; - }; - - bool is_loaded() const; - executorch::runtime::Error load(); - executorch::runtime::Error init_tokenizer(const std::string& vocab_json_path); - executorch::runtime::Error print_performance(); - std::vector tokenize(std::string prompt); - std::vector gen_latent_from_file(); - std::vector gen_random_latent(float sigma); - void step( - const std::vector& model_output, - const std::vector& sigmas, - std::vector& sample, - std::vector& prev_sample, - int step_index); - std::vector> - get_methods_meta(); - std::vector get_time_steps(); - std::vector get_sigmas(const std::vector& time_steps); - void scale_model_input( - const std::vector& vec, - std::vector& latent_model_input, - float sigma); - executorch::runtime::Error parse_input_list(std::string& path); - executorch::runtime::Error generate(std::string prompt); - void quant_tensor( - const std::vector& fp_vec, - std::vector& quant_vec, - float scale, - int offset); - void dequant_tensor( - const std::vector& quant_vec, - std::vector& fp_vec, - float scale, - int offset); - - private: - Stats stats_; - std::vector> modules_; - std::vector method_names_; - std::vector> time_emb_list_; - std::unordered_map vocab_to_token_map_; - - std::string output_path_; - int num_time_steps_; - float guidance_scale_; - float text_encoder_output_scale_; - int text_encoder_output_offset_; - float unet_input_latent_scale_; - int unet_input_latent_offset_; - float unet_input_text_emb_scale_; - int unet_input_text_emb_offset_; - float unet_output_scale_; - int unet_output_offset_; - float vae_input_scale_; - int vae_input_offset_; - float vae_output_scale_; - int vae_output_offset_; - const float beta_start_ = 0.00085; - const float beta_end_ = 0.012; - const int num_train_timesteps_ = 1000; - const int max_tokens_ = 77; - const bool fix_latents_ = false; -}; - -} // namespace example diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/stable_diffusion_lib.py b/examples/qualcomm/qaihub_scripts/stable_diffusion/stable_diffusion_lib.py deleted file mode 100644 index 8ec5783131d..00000000000 --- a/examples/qualcomm/qaihub_scripts/stable_diffusion/stable_diffusion_lib.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch -from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline - - -class StableDiffusion: - def __init__(self, seed=42): - self.model_id: str = "stabilityai/stable-diffusion-2-1-base" - self.generator = torch.manual_seed(seed) - self.scheduler = EulerDiscreteScheduler.from_pretrained( - self.model_id, subfolder="scheduler" - ) - - self.pipe = StableDiffusionPipeline.from_pretrained( - self.model_id, scheduler=self.scheduler, torch_dtype=torch.float32 - ) - self.pipe = self.pipe.to("cpu") - - def __call__(self, prompt, height, width, num_time_steps): - image = self.pipe( - prompt, height, width, num_time_steps, generator=self.generator - ).images[0] - return image diff --git a/examples/qualcomm/qaihub_scripts/utils/README.md b/examples/qualcomm/qaihub_scripts/utils/README.md deleted file mode 100644 index df4b989714a..00000000000 --- a/examples/qualcomm/qaihub_scripts/utils/README.md +++ /dev/null @@ -1,102 +0,0 @@ -# CLI Tool for Compile / Deploy Pre-Built QNN Artifacts - -An easy-to-use tool for generating / executing .pte program from pre-built model libraries / context binaries from Qualcomm AI Engine Direct. Tool is verified with [host environement](../../../../docs/source/backends-qualcomm.md#host-os). - -## Description - -This tool aims for users who want to leverage ExecuTorch runtime framework with their existent artifacts generated by QNN. It's possible for them to produce .pte program in few steps.
-If users are interested in well-known applications, [Qualcomm AI HUB](https://aihub.qualcomm.com/) is a great approach which provides tons of optimized state-of-the-art models ready for deploying. All of them could be downloaded in model library or context binary format. - -* Model libraries(.so) came from `qnn-model-lib-generator` | AI HUB, or context binaries(.bin) came from `qnn-context-binary-generator` | AI HUB, could apply tool directly with: - - To produce .pte program: - ```bash - $ python export.py compile - ``` - - To perform inference with generated .pte program: - ```bash - $ python export.py execute - ``` - -### Dependencies - -* Register for Qualcomm AI HUB. -* Download the corresponding QNN SDK via [link](https://www.qualcomm.com/developer/software/qualcomm-ai-engine-direct-sdk) which your favorite model is compiled with. Ths link will automatically download the latest version at this moment (users should be able to specify version soon, please refer to [this](../../../../docs/source/backends-qualcomm.md#software) for earlier releases). - -### Target Model - -* Consider using [virtual environment](https://app.aihub.qualcomm.com/docs/hub/getting_started.html) for AI HUB scripts to prevent package conflict against ExecuTorch. Please finish the [installation section](https://app.aihub.qualcomm.com/docs/hub/getting_started.html#installation) before proceeding following steps. -* Take [QuickSRNetLarge](https://aihub.qualcomm.com/iot/models/quicksrnetlarge) as an example, please [install](https://huggingface.co/qualcomm/QuickSRNetLarge-Quantized#installation) package as instructed. -* Create workspace and export pre-built model library: - ```bash - mkdir $MY_WS && cd $MY_WS - # target chipset is `SM8650` - python -m qai_hub_models.models.quicksrnetlarge_quantized.export --target-runtime qnn --chipset qualcomm-snapdragon-8gen3 - ``` -* The compiled model library will be located under `$MY_WS/build/quicksrnetlarge_quantized/quicksrnetlarge_quantized.so`. This model library maps to the artifacts generated by SDK tools mentioned in `Integration workflow` section on [Qualcomm AI Engine Direct document](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-10/overview.html). - -### Compiling Program - -* Compile .pte program - ```bash - # `pip install pydot` if package is missing - # Note that device serial & hostname might not be required if given artifacts is in context binary format - PYTHONPATH=$EXECUTORCH_ROOT/.. python $EXECUTORCH_ROOT/examples/qualcomm/qaihub_scripts/utils/export.py compile -a $MY_WS/build/quicksrnetlarge_quantized/quicksrnetlarge_quantized.so -m SM8650 -s $DEVICE_SERIAL -b $EXECUTORCH_ROOT/build-android - ``` -* Artifacts for checking IO information - - `output_pte/quicksrnetlarge_quantized/quicksrnetlarge_quantized.json` - - `output_pte/quicksrnetlarge_quantized/quicksrnetlarge_quantized.svg` - -### Executing Program - -* Prepare test image - ```bash - cd $MY_WS - wget https://user-images.githubusercontent.com/12981474/40157448-eff91f06-5953-11e8-9a37-f6b5693fa03f.png -O baboon.png - ``` - Execute following python script to generate input data: - ```python - import torch - import torchvision.transforms as transforms - from PIL import Image - img = Image.open('baboon.png').resize((128, 128)) - transform = transforms.Compose([transforms.PILToTensor()]) - # convert (C, H, W) to (N, H, W, C) - # IO tensor info. could be checked with quicksrnetlarge_quantized.json | .svg - img = transform(img).permute(1, 2, 0).unsqueeze(0) - torch.save(img, 'baboon.pt') - ``` -* Execute .pte program - ```bash - PYTHONPATH=$EXECUTORCH_ROOT/.. python $EXECUTORCH_ROOT/examples/qualcomm/qaihub_scripts/utils/export.py execute -p output_pte/quicksrnetlarge_quantized -i baboon.pt -s $DEVICE_SERIAL -b $EXECUTORCH_ROOT/build-android - ``` -* Post-process generated data - ```bash - cd output_data - ``` - Execute following python script to generate output image: - ```python - import io - import torch - import torchvision.transforms as transforms - # IO tensor info. could be checked with quicksrnetlarge_quantized.json | .svg - # generally we would have same layout for input / output tensors: e.g. either NHWC or NCHW - # this might not be true under different converter configurations - # learn more with converter tool from Qualcomm AI Engine Direct documentation - # https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-10/tools.html#model-conversion - with open('output__142.pt', 'rb') as f: - buffer = io.BytesIO(f.read()) - img = torch.load(buffer, weights_only=False) - transform = transforms.Compose([transforms.ToPILImage()]) - img_pil = transform(img.squeeze(0)) - img_pil.save('baboon_upscaled.png') - ``` - You could check the upscaled result now! - -## Help - -Please check help messages for more information: -```bash -PYTHONPATH=$EXECUTORCH_ROOT/.. python $EXECUTORCH_ROOT/examples/qualcomm/aihub/utils/export.py -h -PYTHONPATH=$EXECUTORCH_ROOT/.. python $EXECUTORCH_ROOT/examples/qualcomm/aihub/utils/python export.py compile -h -PYTHONPATH=$EXECUTORCH_ROOT/.. python $EXECUTORCH_ROOT/examples/qualcomm/aihub/utils/python export.py execute -h -``` diff --git a/examples/qualcomm/qaihub_scripts/utils/export.py b/examples/qualcomm/qaihub_scripts/utils/export.py deleted file mode 100644 index a144e74a82c..00000000000 --- a/examples/qualcomm/qaihub_scripts/utils/export.py +++ /dev/null @@ -1,507 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import argparse -import io -import json -import logging -import os -from pathlib import Path - -import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor -import numpy as np - -import torch -from executorch.backends.qualcomm.export_utils import QnnConfig, SimpleADB -from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset -from executorch.backends.qualcomm.utils.utils import ( - draw_graph, - from_context_binary, - generate_htp_compiler_spec, - generate_qnn_executorch_compiler_spec, - generate_qnn_executorch_option, -) -from executorch.examples.qualcomm.qaihub_scripts.utils.utils import preprocess_binary -from executorch.examples.qualcomm.utils import make_output_dir -from executorch.exir import ExecutorchBackendConfig -from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass - - -def get_logger(): - logger = logging.getLogger("aihub.utils.export") - handler = logging.StreamHandler() - handler.setFormatter( - logging.Formatter( - fmt="[%(asctime)s %(prefix)s] %(levelname)-8s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - ) - logger.addHandler(handler) - logger.setLevel(logging.INFO) - logger.propagate = False - return logging.LoggerAdapter(logger, extra={"prefix": "UTILS.EXPORT"}) - - -def get_io_info(prog_info, ctx_bin_path, compiler_specs): - def fill_tensor_info(info, qnn_tensors, category): - # fetch related IO info stored in prog_info - for i, (name, tensor) in enumerate(prog_info[category].items()): - assert qnn_tensors[i].GetName() == name, "tensor name unmatch" - encoding = qnn_tensors[i].GetEncodings() - quantization_info = { - "scale": encoding.data["scale"].tolist(), - "offset": encoding.data["offset"].tolist(), - "axis": encoding.axis, - } - info[category].append( - { - "name": name, - "shape": tuple(tensor.shape), - "dtype": str(tensor.dtype), - "encoding": quantization_info, - } - ) - - # dictionary to be serialized into json format - in_key, out_key = "inputs", "outputs" - tensor_info = {in_key: [], out_key: []} - - with open(ctx_bin_path, "rb") as f: - ctx_bin = preprocess_binary(f.read(), compiler_specs) - # leverage QNN pybind interface to retrieve tensor encodings - qnn_mgr = PyQnnManagerAdaptor.QnnManager( - generate_qnn_executorch_option(compiler_specs), ctx_bin - ) - assert qnn_mgr.Init().value == 0, "failed to load context binary" - graph_name = qnn_mgr.GetGraphNames()[0] - qnn_mgr.AllocateTensor(graph_name) - fill_tensor_info(tensor_info, qnn_mgr.GetGraphInputs(graph_name), in_key) - fill_tensor_info(tensor_info, qnn_mgr.GetGraphOutputs(graph_name), out_key) - qnn_mgr.Destroy() - - return tensor_info - - -def get_ones_tensor(tensor_info, logger): - logger.warning( - f"tensor '{tensor_info['name']}' use ones tensor, " - "unexpected outputs might generate" - ) - return torch.ones(tensor_info["shape"], dtype=eval(tensor_info["dtype"])) - - -def get_tensor_with_encoding(tensor, tensor_info, logger): - scale = tensor_info["encoding"]["scale"] - offset = tensor_info["encoding"]["offset"] - - # user gave wrong tensor for no encoding appears - if len(scale) == 0: - logger.error(f"tensor '{tensor_info['name']}' has no encoding") - return get_ones_tensor(tensor_info, logger) - - # quant if tensor is float with encoding - return ( - tensor.div(scale).add(offset).round().to(eval(tensor_info["dtype"])) - if tensor.dtype == torch.float - else tensor.sub(offset).mul(scale).to(torch.float32) - ) - - -def get_tensor(io_info, tensors, logger, checking_output=False): - # check if enough tensors have been given - if len(tensors) != len(io_info): - logger.error( - "given tensor numbers mismatch, " - f"expected {len(io_info)} but got {len(tensors)}" - ) - if checking_output: - logger.error( - "output tensors failed to generate, " - "please check executor_runner logs." - ) - exit(-1) - - return [get_ones_tensor(t, logger) for t in io_info] - - # list of tensors to be returned - ret_tensors = [] - for i, info in enumerate(io_info): - if list(tensors[i].shape) != info["shape"]: - logger.error( - f"tensor '{info['name']}' shape mismatch: " - f"users > {tensors[i].shape} - " - f"required > {info['shape']}" - ) - ret_tensors.append(get_ones_tensor(info, logger)) - continue - - ret_tensors.append( - tensors[i] - if tensors[i].dtype == eval(info["dtype"]) - else - # try quant / dequant for given tensor if possible - ret_tensors.append(get_tensor_with_encoding(tensors[i], info, logger)) - ) - return [ret_tensors] - - -def to_context_binary( - model_lib, - soc_model, - device, - host, - target, - build_folder, - output_folder, - logger, -): - ext = Path(model_lib).suffix - if ext == ".bin": - return model_lib - - assert ( - device is not None - ), "Please assign device serial for model library conversion." - logger.info(f"Generating context binary for {model_lib}") - # leverage SimpleADB for model library conversion - lib_name = Path(model_lib).stem - sdk_root = os.getenv("QNN_SDK_ROOT") - qnn_config = QnnConfig( - soc_model=soc_model, - build_folder=build_folder, - device=device, - host=host, - target=target, - ) - adb = SimpleADB( - qnn_config=qnn_config, - pte_path=model_lib, - workspace=f"/data/local/tmp/executorch/{lib_name}", - ) - - logger.info("pushing QNN libraries & tool") - arch = adb.arch_table[soc_model] - files = [ - f"{sdk_root}/bin/aarch64-android/qnn-context-binary-generator", - f"{sdk_root}/lib/aarch64-android/libQnnHtp.so", - f"{sdk_root}/lib/aarch64-android/libQnnHtpV{arch}Stub.so", - f"{sdk_root}/lib/aarch64-android/libQnnHtpPrepare.so", - f"{sdk_root}/lib/hexagon-v{arch}/unsigned/libQnnHtpV{arch}Skel.so", - ] - adb.push(files=files) - - logger.info("starting conversion") - commands = " ".join( - [ - f"cd {adb.workspace} &&", - "export LD_LIBRARY_PATH=. &&", - "./qnn-context-binary-generator", - f"--model {Path(model_lib).name}", - "--backend libQnnHtp.so", - f"--binary_file {lib_name}", - ] - ) - adb.execute(custom_runner_cmd=commands) - - logger.info(f"collecting converted context binary - {lib_name}.bin") - adb._adb(["pull", f"{adb.workspace}/output/{lib_name}.bin", output_folder]) - - bin_path = f"{output_folder}/{lib_name}.bin" - assert os.path.exists(bin_path), ( - "Failed to convert context binary, " "please check logcat for more details." - ) - return bin_path - - -def compile(args): - logger = get_logger() - logger.info("prepare compiler spec for qualcomm backend") - - # setup compiler spec dedicated to QNN HTP backend - backend_options = generate_htp_compiler_spec(use_fp16=False) - # setup general compiler spec for QNN - compiler_specs = generate_qnn_executorch_compiler_spec( - soc_model=getattr(QcomChipset, args.soc_model), - backend_options=backend_options, - is_from_context_binary=True, - ) - # setup memory planning - memory_planning_pass = MemoryPlanningPass( - alloc_graph_input=args.allocate_graph_io, - alloc_graph_output=args.allocate_graph_io, - ) - - # dictionary for avoiding name collision when creating custom ops - name_map = {} - num_bins = len(args.artifacts) - for i, ctx_bin in enumerate(args.artifacts): - index = i + 1 - binary_name = Path(ctx_bin).stem - output_dir = f"{args.output_pte_folder}/{binary_name}" - make_output_dir(output_dir) - # conversion model library into context binary if required - ctx_bin = to_context_binary( - model_lib=ctx_bin, - soc_model=args.soc_model, - device=args.device, - host=args.host, - target=args.target, - build_folder=args.build_folder, - output_folder=output_dir, - logger=logger, - ) - # step 0: check if name collision happens for context binaries - logger.info(f"({index}/{num_bins}) checking custom op name of {ctx_bin}") - custom_op_name = f"ctx_loader_{binary_name}" - postfix = name_map.get(custom_op_name, 0) - if postfix > 0: - postfix += 1 - custom_op_name = f"{custom_op_name}_{postfix}" - name_map[custom_op_name] = postfix - # step 1: generate ExportedProgram with custom op as binary loader & lower to QnnBackend - logger.info(f"({index}/{num_bins}) exporting program for {ctx_bin}") - prog_info = from_context_binary( - ctx_bin, custom_op_name, getattr(QcomChipset, args.soc_model) - ) - # step 2: write pte files and IO information - logger.info(f"({index}/{num_bins}) exporting {binary_name}.pte") - with open(f"{output_dir}/{binary_name}.pte", "wb") as f: - prog_info["edge_program_manager"].to_executorch( - config=ExecutorchBackendConfig( - memory_planning_pass=memory_planning_pass - ) - ).write_to_file(f) - - logger.info( - f"({index}/{num_bins}) exporting network graph with {binary_name}.svg" - ) - draw_graph(binary_name, output_dir, prog_info["exported_program"].graph_module) - logger.info( - f"({index}/{num_bins}) exporting graph description with {binary_name}.json" - ) - with open(f"{output_dir}/{binary_name}.json", "w") as f: - graph_info = get_io_info(prog_info, ctx_bin, compiler_specs) - graph_info["soc_model"] = args.soc_model - json.dump(graph_info, f, indent=2) - - -def execute(args): - logger = get_logger() - - # load graph description file - pte_name = Path(args.pte_directory).stem - graph_desc = f"{args.pte_directory}/{pte_name}.json" - logger.info(f"loading graph description: {graph_desc}") - with open(graph_desc, "r") as f: - graph_info = json.load(f) - - # load input files - logger.info("loading user inputs") - user_inputs = [] - for input_file in args.input_files: - with open(input_file, "rb") as f: - buffer = io.BytesIO(f.read()) - user_inputs.append(torch.load(buffer, weights_only=False)) - - # check if inputs are valid, fallback to ones tensor if any - logger.info("generating input data") - inputs = get_tensor(graph_info["inputs"], user_inputs, logger) - - logger.info("preparing ADB connection") - qnn_config = QnnConfig.load_config(args.config_file if args.config_file else args) - # leverage SimpleADB for e2e inference - adb = SimpleADB( - qnn_config=qnn_config, - pte_path=f"{args.pte_directory}/{pte_name}.pte", - workspace=f"/data/local/tmp/executorch/{pte_name}", - ) - - logger.info("pushing QNN libraries & other artifacts") - adb.push(inputs=inputs) - - logger.info("starting inference") - adb.execute() - - logger.info("collecting output data") - - def post_process(): - output_info, outputs = graph_info["outputs"], [] - output_folder = f"{args.output_data_folder}/outputs" - for i, f in enumerate(sorted(os.listdir(output_folder))): - filename = os.path.join(output_folder, f) - output = np.fromfile( - filename, dtype=eval(f"np.{output_info[i]['dtype'].split('.')[-1]}") - ) - outputs.append(torch.from_numpy(output.reshape(output_info[i]["shape"]))) - os.remove(filename) - - os.rmdir(output_folder) - outputs, _ = get_tensor(output_info, outputs, logger, checking_output=True) - # dataset length equals to 1 - for i, output in enumerate(outputs[0]): - torch.save(output, f"{args.output_data_folder}/{output_info[i]['name']}.pt") - - make_output_dir(args.output_data_folder) - adb.pull(args.output_data_folder, post_process) - logger.info( - f"execution finished, please check {args.output_data_folder} for results" - ) - - -def main(): - parser = argparse.ArgumentParser( - description=( - "Utility to lower precompiled model libraries / " - "context binaries from Qualcomm AI Engine Direct to executorch" - " .pte program. Please visit https://aihub.qualcomm.com/ to " - "download your favorite models." - ), - ) - subparsers = parser.add_subparsers( - title="subcommands", - description=( - "[compile]: Compile designated model libraries / " - "context binaries into .pte files. " - "[execute]: Perform on-device inference with given .pte." - ), - ) - - sub_compile = subparsers.add_parser( - name="compile", - help=( - "e.g. python export.py compile -a model.bin -m SM8650 " - "-b /path/to/build-android" - ), - ) - sub_compile.add_argument( - "-a", - "--artifacts", - nargs="+", - type=str, - required=True, - help=( - "Path to AI HUB or QNN tool generated artifacts, " - "batch process is supported. " - "e.g. python export.py compile -a a.bin b.so c.bin " - "-m SM8650 -s $SERIAL_NO -b /path/to/build-android" - ), - ) - sub_compile.add_argument( - "-m", - "--model", - type=str, - required=True, - help="SoC model. e.g. SM8650", - ) - sub_compile.add_argument( - "-s", - "--device", - type=str, - help="Serial no of device which could be obtained by 'adb devices'.", - ) - sub_compile.add_argument( - "-o", - "--output_pte_folder", - type=str, - default="./output_pte", - help=( - "Path to output artifacts, store in 'output_pte' if not given. " - "graph descriptions & diagram will also be exported." - ), - ) - sub_compile.add_argument( - "-b", - "--build_folder", - help="Path to cmake binary directory for android, e.g., /path/to/build-android", - type=str, - required=True, - ) - sub_compile.add_argument( - "-l", - "--allocate_graph_io", - type=bool, - default=True, - help=( - "True if IO tensors are pre-allocated by framework. " - "False for users who want to manage resources in runtime." - ), - ) - sub_compile.add_argument( - "-H", - "--host", - type=str, - help="Gateway hostname.", - ) - sub_compile.set_defaults(callback=compile) - - sub_execute = subparsers.add_parser( - name="execute", - help=( - "e.g. python export.py execute -p model_dir -i inp.raw " "-s device_serial" - ), - ) - sub_execute.add_argument( - "-p", - "--pte_directory", - type=str, - required=True, - help="Path to .pte file folder generated from 'compile' subcommand.", - ) - sub_execute.add_argument( - "-i", - "--input_files", - nargs="*", - type=str, - help=( - "Path to input files stored via torch.save. " - "If the number / spec of input files doesn't match given .pte file, " - "tensors filled with value 1 will be taken as inputs." - ), - ) - sub_execute.add_argument( - "-s", - "--device", - type=str, - required=True, - help="Serial no of device which could be obtained by 'adb devices'.", - ) - sub_execute.add_argument( - "-o", - "--output_data_folder", - type=str, - default="./output_data", - help="Path to output data, store in 'output_data' if not given.", - ) - sub_execute.add_argument( - "-b", - "--build_folder", - help="Path to cmake binary directory for android, e.g., /path/to/build-android", - type=str, - required=True, - ) - sub_execute.add_argument( - "-z", - "--shared_buffer", - help=( - "Enables usage of shared buffer between application and backend for graph I/O." - " Please use with `--allocate_graph_io False` in compile command." - ), - action="store_true", - ) - sub_execute.add_argument( - "-H", - "--host", - type=str, - help="Gateway hostname.", - ) - sub_execute.set_defaults(callback=execute) - - args = parser.parse_args() - args.callback(args) - - -if __name__ == "__main__": - main() diff --git a/examples/qualcomm/qaihub_scripts/utils/utils.py b/examples/qualcomm/qaihub_scripts/utils/utils.py deleted file mode 100644 index fc065b79af5..00000000000 --- a/examples/qualcomm/qaihub_scripts/utils/utils.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import gc - -import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor - -from executorch.backends.qualcomm.utils.utils import ( - generate_qnn_executorch_option, - update_spill_fill_size, -) - - -def preprocess_binary(ctx_bin, compiler_specs): - qnn_mgr = PyQnnManagerAdaptor.QnnManager( - generate_qnn_executorch_option(compiler_specs), - ) - return bytes(qnn_mgr.MakeBinaryInfo(ctx_bin)) - - -def get_encoding( - path_to_shard: str, - compiler_specs: str, - get_input: bool, - get_output: bool, - num_input: int, - num_output: int, -): - encoding_list = [] - with open(path_to_shard, "rb") as f: - ctx_bin = preprocess_binary(f.read(), compiler_specs) - qnn_mgr = PyQnnManagerAdaptor.QnnManager( - generate_qnn_executorch_option(compiler_specs), ctx_bin - ) - assert qnn_mgr.Init().value == 0, "failed to load context binary" - graph_name = qnn_mgr.GetGraphNames()[0] - qnn_mgr.AllocateTensor(graph_name) - if get_input: - encoding_input = {"scale": [], "offset": []} - for i in range(num_input): - inputs = qnn_mgr.GetGraphInputs(graph_name)[i] - encoding = inputs.GetEncodings() - encoding_input["scale"].append(encoding.data["scale"].item()) - encoding_input["offset"].append(encoding.data["offset"].item()) - encoding_list.append(encoding_input) - if get_output: - encoding_output = {"scale": [], "offset": []} - for i in range(num_output): - outputs = qnn_mgr.GetGraphOutputs(graph_name)[i] - encoding = outputs.GetEncodings() - encoding_output["scale"].append(encoding.data["scale"].item()) - encoding_output["offset"].append(encoding.data["offset"].item()) - encoding_list.append(encoding_output) - qnn_mgr.Destroy() - return encoding_list - - -def gen_pte_from_ctx_bin(artifact, pte_names, bundle_programs, backend_config): - edge_prog_mgrs = [prog["edge_program_manager"] for prog in bundle_programs] - # Setup spill-fill buffer for relieving runtime memory usage - update_spill_fill_size( - [ - prog_mgr._edge_programs[list(prog_mgr.methods)[0]] - for prog_mgr in edge_prog_mgrs - ] - ) - # Export pte files - pte_files = [] - for pte_name in pte_names: - print(f"{pte_name} generating...") - pte_files.append(f"{artifact}/{pte_name}.pte") - with open(pte_files[-1], "wb") as f: - edge_prog_mgrs[0].to_executorch(config=backend_config).write_to_file(f) - # GC for reducing host memory consuming - bundle_programs.pop(0) - edge_prog_mgrs.pop(0) - gc.collect() - - return pte_files diff --git a/examples/qualcomm/util_scripts/cli.py b/examples/qualcomm/util_scripts/cli.py index 02af78e3dd4..a267a0b1fc7 100644 --- a/examples/qualcomm/util_scripts/cli.py +++ b/examples/qualcomm/util_scripts/cli.py @@ -48,7 +48,6 @@ QNN_TENSOR_TYPE_MAP, to_edge_transform_and_lower_to_qnn, ) -from executorch.examples.qualcomm.qaihub_scripts.utils.utils import preprocess_binary from executorch.exir import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from torchao.quantization import pt2e @@ -72,6 +71,13 @@ def get_logger(): return logging.LoggerAdapter(logger, extra={"prefix": "QNN_BACKEND"}) +def preprocess_binary(ctx_bin, compiler_specs): + qnn_mgr = PyQnnManagerAdaptor.QnnManager( + generate_qnn_executorch_option(compiler_specs), + ) + return bytes(qnn_mgr.MakeBinaryInfo(ctx_bin)) + + def get_io_info(pte_path, compiler_specs): dtype_map = {} for type_map in (QNN_QUANT_TYPE_MAP, QNN_TENSOR_TYPE_MAP):