diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f682f4e..06bfe7f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,7 +2,7 @@ ### License - is licensed under the terms in [LICENSE]. By contributing to the project, you agree to the license and copyright terms therein and release your contribution under these terms. +IDEAS is licensed under the terms in [LICENSE](LICENSE). By contributing to the project, you agree to the license and copyright terms therein and release your contribution under these terms. ### Sign your work diff --git a/IDEAS.mk b/IDEAS.mk index 4ac0635..4f2b588 100644 --- a/IDEAS.mk +++ b/IDEAS.mk @@ -6,7 +6,7 @@ MAKEFILE_PATH := $(abspath $(lastword $(MAKEFILE_LIST))) MAKEFILE_DIR := $(realpath $(dir $(MAKEFILE_PATH))) -CARGO_TOML_CMAKE := ${MAKEFILE_DIR}/cargo_toml.cmake +EXTRACT_INFO_CMAKE := ${MAKEFILE_DIR}/extract_info.cmake IDEAS_MAKEFILE := $(MAKEFILE_DIR)/IDEAS.mk PROVIDER ?= hosted_vllm @@ -18,6 +18,7 @@ TRANSLATION_DIR ?= translation.$(shell git --git-dir=${MAKEFILE_DIR}/.git rev-pa ifeq (${PROVIDER},hosted_vllm) override TRANSLATE_ARGS += model.base_url=${BASE_URL} override REPAIR_ARGS += model.base_url=${BASE_URL} +override WRAPPER_ARGS += model.base_url=${BASE_URL} endif RUSTFLAGS ?= -Awarnings## Ignore Rust compiler warnings CFLAGS ?= -w## Ignore C compiler warnings @@ -32,6 +33,13 @@ GREEN_COL := \033[1;32m PROJECT_C_FILES = $(shell jq -r 'map(.file) | .[] | @text' build-ninja/compile_commands.json) C_FILES = $(subst ${CURDIR}/test_case/,,${PROJECT_C_FILES}) TEST_FILES := $(wildcard test_vectors/*.json) +TARGETS := $(shell find build-ninja -maxdepth 1 -type f -executable -exec basename {} \; | cut -d. -f1 | sed -e "s/^lib//gi") +ARTIFACTS := $(shell find build-ninja -maxdepth 1 -type f -executable -exec basename {} \;) +ifeq (${TARGETS},) +ifneq (${MAKECMDGOALS},cmake) +$(error No TARGETS found! You need to run cmake!) 
+endif +endif AFL_TAG = aflplusplus/aflplusplus:stable FUZZING_TIMEOUT ?= 60 @@ -40,9 +48,6 @@ FUZZING_TEST_VECTORS := $(subst :,\:, $(wildcard afl/out/default/queue/*)) CRATEIFY_BIN = ${MAKEFILE_DIR}/tools/crateify/target/debug/crateify -.PHONY: FORCE -FORCE: - # cmake cmake: build-ninja/build.log @@ -51,41 +56,48 @@ build-ninja/translate.log: build-ninja/compile_commands.json @$(MAKE) --no-print-directory -f ${IDEAS_MAKEFILE} $(addprefix test_case/,$(addsuffix .i,${C_FILES})) @touch $@ -.PRECIOUS: build-ninja/% -build-ninja/%: build-ninja/CMakeCache.txt ; - .PRECIOUS: build-ninja/CMakeCache.txt -build-ninja/CMakeCache.txt: test_case/CMakeLists.txt ${CARGO_TOML_CMAKE} +build-ninja/CMakeCache.txt: test_case/CMakeLists.txt ${EXTRACT_INFO_CMAKE} @rm -rf build-ninja ifeq ($(wildcard CMakePresets.json),) cmake -S test_case -B build-ninja -G Ninja \ - -DCMAKE_PROJECT_TOP_LEVEL_INCLUDES=${CARGO_TOML_CMAKE} \ - -DCMAKE_C_FLAGS="${CFLAGS}" \ - -DCMAKE_EXPORT_COMPILE_COMMANDS=ON + -DCMAKE_BUILD_TYPE=Debug \ + -DCMAKE_C_FLAGS_DEBUG="-g -O0" \ + -DCMAKE_PROJECT_TOP_LEVEL_INCLUDES="${EXTRACT_INFO_CMAKE}" \ + -DCMAKE_C_FLAGS="${CFLAGS}" \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON else cmake -S . 
--preset test \ - -DCMAKE_PROJECT_TOP_LEVEL_INCLUDES=${CARGO_TOML_CMAKE} \ - -DCMAKE_C_FLAGS="${CFLAGS}" \ - -DCMAKE_EXPORT_COMPILE_COMMANDS=ON + -DCMAKE_BUILD_TYPE=Debug \ + -DCMAKE_C_FLAGS_DEBUG="-g -O0" \ + -DCMAKE_PROJECT_TOP_LEVEL_INCLUDES="${EXTRACT_INFO_CMAKE}" \ + -DCMAKE_C_FLAGS="${CFLAGS}" \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON endif .PRECIOUS: build-ninja/compile_commands.json build-ninja/compile_commands.json: build-ninja/CMakeCache.txt ; .PRECIOUS: build-ninja/build.log -build-ninja/build.log: build-ninja/translate.log +build-ninja/build.log: build-ninja/CMakeCache.txt ifeq ($(wildcard CMakePresets.json),) -cmake --build build-ninja --target all 2> $@ else -cmake --build build-ninja --target all --preset test 2> $@ endif + @find build-ninja -maxdepth 1 -type f -executable | \ + xargs -I{} sh -c "nm --extern-only {} | \ + awk '{if (\$$2 == \"T\") print \$$NF}' | \ + grep -v ^_ > {}.symbols" .PRECIOUS: test_case/%.c.i test_case/%.c.i: build-ninja/compile_commands.json - $(shell cat build-ninja/compile_commands.json | \ + cat build-ninja/compile_commands.json | \ jq -r '.[] | select(.file == "${CURDIR}/test_case/$*.c") | .command' | \ sed -e 's/-o [^ ]*//g' | \ - xargs -I{} echo "{} -E -o $@") + xargs -I{} echo "{} -E -o $@" | \ + sh + # Add more tests from fuzzing. The procedure is # 1. 
Copy test input from the initial JSON test vectors; @@ -168,95 +180,148 @@ test_vectors/%.json: afl/out/default/queue/% # init .PHONY: init -init: ${TRANSLATION_DIR}/Cargo.toml build-ninja/compile_commands.json ${CRATEIFY_BIN} - @$(MAKE) --no-print-directory -f${IDEAS_MAKEFILE} $(addprefix ${TRANSLATION_DIR}/src/,$(patsubst src/%,%,$(patsubst %.c,%.rs,${C_FILES}))) - ${CRATEIFY_BIN} ${TRANSLATION_DIR}/src +init: $(patsubst %,${TRANSLATION_DIR}/%/init.log,${TARGETS}) ; .PRECIOUS: ${TRANSLATION_DIR}/Cargo.toml -${TRANSLATION_DIR}/Cargo.toml: build-ninja/Cargo.toml - @mkdir -p $(@D) - cp build-ninja/Cargo.toml $@ - -${TRANSLATION_DIR}/%.rs: - @mkdir -p $(@D) - echo 'fn main() {\n println!("Hello, world!");\n}' > $@ +${TRANSLATION_DIR}/Cargo.toml: + mkdir -p $(@D) + printf '[workspace]\nresolver = "3"' > $@ +# Library targets (those with a build-ninja/lib%.so.type file) get a lib+cdylib crate; printf is used instead of "echo -n" because echo's handling of -n and \n differs between shells +.PRECIOUS: ${TRANSLATION_DIR}/%/Cargo.toml +${TRANSLATION_DIR}/%/Cargo.toml: | ${TRANSLATION_DIR}/Cargo.toml build-ninja/lib%.so.type + cargo new --quiet --lib --vcs=none $(@D) + printf '\n[lib]\ncrate-type = ["lib", "cdylib"]' >> $@ + cargo add --quiet --manifest-path $@ --dev assert_cmd@2.0.17 ntest@0.9.3 predicates@3.1.3 + cargo add --quiet --manifest-path $@ openssl@0.10.75 + +.PRECIOUS: ${TRANSLATION_DIR}/%/Cargo.toml +${TRANSLATION_DIR}/%/Cargo.toml: | ${TRANSLATION_DIR}/Cargo.toml build-ninja/%.type + cargo new --quiet --bin --vcs=none $(@D) + cargo add --quiet --manifest-path $@ --dev assert_cmd@2.0.17 ntest@0.9.3 predicates@3.1.3 + cargo add --quiet --manifest-path $@ openssl@0.10.75 + +.PRECIOUS: ${TRANSLATION_DIR}/%/src/lib.c +${TRANSLATION_DIR}/%/src/lib.c: ${TRANSLATION_DIR}/%/init.log ; + +.PRECIOUS: ${TRANSLATION_DIR}/%/init.log +${TRANSLATION_DIR}/%/init.log: | ${TRANSLATION_DIR}/%/Cargo.toml build-ninja/lib%.so.type + uv run python -m ideas.init filename=build-ninja/compile_commands.json \ + export_symbols=build-ninja/lib$*.so.symbols \ + source_priority=build-ninja/lib$*.so.sources \ + hydra.output_subdir=.init \ + 
hydra.run.dir=${TRANSLATION_DIR}/$* + +.PRECIOUS: ${TRANSLATION_DIR}/%/src/main.c +${TRANSLATION_DIR}/%/src/main.c: ${TRANSLATION_DIR}/%/init.log ; + +.PRECIOUS: ${TRANSLATION_DIR}/%/init.log +${TRANSLATION_DIR}/%/init.log: | ${TRANSLATION_DIR}/%/Cargo.toml build-ninja/%.type + uv run python -m ideas.init filename=build-ninja/compile_commands.json \ + export_symbols=build-ninja/$*.symbols \ + source_priority=build-ninja/$*.sources \ + hydra.output_subdir=.init \ + hydra.run.dir=${TRANSLATION_DIR}/$* .PRECIOUS: ${CRATEIFY_BIN} ${CRATEIFY_BIN}: @cd ${MAKEFILE_DIR}/tools/crateify && cargo build -.PRECIOUS: runner/release/runner -runner/release/runner: runner/Cargo.toml - @cd runner && cargo build --release --target-dir . # translate .PHONY: translate -translate: ${TRANSLATION_DIR}/translate.log ; -${TRANSLATION_DIR}/translate.log: build-ninja/compile_commands.json - -uv run python -m ideas.translate model.name=${PROVIDER}/${MODEL} filename=build-ninja/compile_commands.json hydra.run.dir=${TRANSLATION_DIR} ${TRANSLATE_ARGS} +translate: $(patsubst %,${TRANSLATION_DIR}/%/translate.log,${TARGETS}) ; + +${TRANSLATION_DIR}/translate.log: $(patsubst %,${TRANSLATION_DIR}/%/translate.log,${TARGETS}) + cat $^ > $@ + +.PRECIOUS: ${TRANSLATION_DIR}/%/src/lib.rs +${TRANSLATION_DIR}/%/src/lib.rs: ${TRANSLATION_DIR}/%/translate.log ; + +.PRECIOUS: ${TRANSLATION_DIR}/%/translate.log +${TRANSLATION_DIR}/%/translate.log: | ${TRANSLATION_DIR}/%/src/lib.c build-ninja/compile_commands.json build-ninja/lib%.so.symbols build-ninja/lib%.so.sources + -uv run python -m ideas.translate_recurrent model.name=${PROVIDER}/${MODEL} \ + filename=${TRANSLATION_DIR}/$*/src/lib.c \ + hydra.output_subdir=.translate \ + hydra.job.name=translate \ + hydra.run.dir=${TRANSLATION_DIR}/$* ${TRANSLATE_ARGS} + +.PRECIOUS: ${TRANSLATION_DIR}/%/src/main.rs +${TRANSLATION_DIR}/%/src/main.rs: ${TRANSLATION_DIR}/%/translate.log ; + +.PRECIOUS: ${TRANSLATION_DIR}/%/translate.log +${TRANSLATION_DIR}/%/translate.log: | 
${TRANSLATION_DIR}/%/src/main.c build-ninja/compile_commands.json build-ninja/%.symbols build-ninja/%.sources + -uv run python -m ideas.translate_recurrent model.name=${PROVIDER}/${MODEL} \ + filename=${TRANSLATION_DIR}/$*/src/main.c \ + hydra.output_subdir=.translate \ + hydra.job.name=translate \ + hydra.run.dir=${TRANSLATION_DIR}/$* ${TRANSLATE_ARGS} + + +# wrapper +.PHONY: wrapper +wrapper: $(patsubst %,${TRANSLATION_DIR}/%/wrapper.log,${TARGETS}) ; + +${TRANSLATION_DIR}/%/wrapper.log: ${TRANSLATION_DIR}/%/translate.log | build-ninja/lib%.so.symbols + @mkdir -p $(@D)/src/wrapper + -@cat build-ninja/lib$*.so.symbols | xargs -t -I{} bindgen --disable-header-comment --no-doc-comments --no-layout-tests $(@D)/src/lib.c --allowlist-function {} -o $(@D)/src/wrapper/{}.rs + -@cat build-ninja/lib$*.so.symbols | xargs -t -I{} sed -zEe 's/\nunsafe extern "C" \{\s+(.*);\s+}/\n\#[unsafe(export_name = "{}")]\1 {\n unimplemented!();\n}/gi' -i $(@D)/src/wrapper/{}.rs + -@cat build-ninja/lib$*.so.symbols | xargs -t -I{} sed -e 's/pub fn/pub extern "C" fn/gi' -i $(@D)/src/wrapper/{}.rs + -@cat build-ninja/lib$*.so.symbols | xargs -t -I{} rustfmt ${@D}/src/wrapper/{}.rs + -uv run python -m ideas.wrapper model.name=${PROVIDER}/${MODEL} \ + symbols=build-ninja/lib$*.so.symbols \ + cargo_toml=${TRANSLATION_DIR}/$*/Cargo.toml \ + hydra.output_subdir=.wrapper \ + hydra.job.name=wrapper \ + hydra.run.dir=${TRANSLATION_DIR}/$* ${WRAPPER_ARGS} + +${TRANSLATION_DIR}/%/wrapper.log: ${TRANSLATION_DIR}/%/translate.log | build-ninja/%.symbols ; # build .PHONY: build build: ${TRANSLATION_DIR}/build.log ; -.PRECIOUS: ${TRANSLATION_DIR}/build.log -${TRANSLATION_DIR}/build.log: ${TRANSLATION_DIR}/translate.log ${TRANSLATION_DIR}/Cargo.toml ${CRATEIFY_BIN} FORCE - ${CRATEIFY_BIN} ${TRANSLATION_DIR}/src +${TRANSLATION_DIR}/build.log: $(patsubst %,${TRANSLATION_DIR}/%/build.log,${TARGETS}) + cat $^ > $@ + +${TRANSLATION_DIR}/%/build.log: ${TRANSLATION_DIR}/%/wrapper.log -export 
RUSTFLAGS=${RUSTFLAGS} && cargo build --quiet --manifest-path $(@D)/Cargo.toml 2> $@ + @cat $@ + +${TRANSLATION_DIR}/target/debug/lib%.so: ${TRANSLATION_DIR}/%/build.log | build-ninja/lib%.so.type ; +${TRANSLATION_DIR}/target/debug/%: ${TRANSLATION_DIR}/%/build.log | build-ninja/%.type ; -# tests for executables +# tests for TARGETS .PHONY: test test: ${TRANSLATION_DIR}/cargo_test.log ; .PRECIOUS: ${TRANSLATION_DIR}/cargo_test.log -${TRANSLATION_DIR}/cargo_test.log: ${TRANSLATION_DIR}/Cargo.toml \ - ${TRANSLATION_DIR}/tests/test_cases.rs \ - ${TRANSLATION_DIR}/build.log - @if [ $$(stat -c %s ${TRANSLATION_DIR}/build.log) = 0 ]; then \ +${TRANSLATION_DIR}/cargo_test.log: ${TRANSLATION_DIR}/build.log $(patsubst %,${TRANSLATION_DIR}/%/tests/test_cases.rs,${TARGETS}) + if [ $$(stat -c %s ${TRANSLATION_DIR}/build.log) = 0 ]; then \ cargo test --manifest-path ${TRANSLATION_DIR}/Cargo.toml --test test_cases | tee $@ ; \ else \ find test_vectors -name '*.json' -exec echo "test {} ... FAILED" \; | tee $@ ; \ fi -.PRECIOUS: ${TRANSLATION_DIR}/tests/test_cases.rs -${TRANSLATION_DIR}/tests/test_cases.rs: ${TEST_FILES} +.PRECIOUS: ${TRANSLATION_DIR}/%/tests/test_cases.rs +${TRANSLATION_DIR}/%/tests/test_cases.rs: ${TEST_FILES} @mkdir -p $(@D) - -uv run python -m ideas.convert_tests $^ | rustfmt > $@ + -uv run python -m ideas.convert_tests ${TEST_FILES} | rustfmt > $@ .PRECIOUS: test_vectors/%.json test_vectors/%.json: $(error $@ not found) -# tests for C libraries -.PHONY: test_libc -test_libc: runner/test_libc.log ; - -.PRECIOUS: runner/test_libc.log -runner/test_libc.log: build-ninja/build.log \ - runner/release/runner - find test_vectors -name '*.json' \ - | sort \ - | xargs -I {} sh -c './runner/release/runner lib -c ../{} -v' \ - | tee $@ - -# tests for Rust libraries -.PHONY: test_librs -test_librs: runner/test_librs.log ; - -.PRECIOUS: runner/test_librs.log -runner/test_librs.log: ${TRANSLATION_DIR}/build.log \ - runner/release/runner - find test_vectors -name '*.json' \ 
- | sort \ - | xargs -I {} sh -c './runner/release/runner -b ${TRANSLATION_DIR}/target/debug lib -c ../{} -v' \ - | tee $@ # repair .PHONY: repair -repair: ${TRANSLATION_DIR}/translate.log ${TRANSLATION_DIR}/Cargo.toml ${TRANSLATION_DIR}/tests/test_cases.rs - -uv run python -m ideas.repair model.name=${PROVIDER}/${MODEL} cargo_toml=${TRANSLATION_DIR}/Cargo.toml ${REPAIR_ARGS} +repair: ${TRANSLATION_DIR}/translate.log \ + ${TRANSLATION_DIR}/Cargo.toml \ + $(patsubst %,${TRANSLATION_DIR}/%/tests/test_cases.rs,${TARGETS}) + -uv run python -m ideas.repair model.name=${PROVIDER}/${MODEL} \ + cargo_toml=${TRANSLATION_DIR}/Cargo.toml \ + ${REPAIR_ARGS} # clean .PHONY: clean diff --git a/Makefile b/Makefile index afc6945..f3f8e7a 100644 --- a/Makefile +++ b/Makefile @@ -18,28 +18,26 @@ BASE_URL ?= http://${HOST}:${PORT}/v1## Base URL of vLLM server VLLM_ARGS ?= --tensor-parallel-size 8 --enable-expert-parallel --max-num-seqs 32 --max-model-len 128k## Args to pass to vllm serve TRANSLATION_DIR ?= translation.$(shell git rev-parse HEAD)## Directory to put IDEAS translation TRANSLATE_ARGS ?= ## Args to pass to IDEAS translation +TESTGEN_DIR ?= testgen.$(shell git rev-parse HEAD)## Directory to put IDEAS test generation +TESTGEN_ARGS ?= ## Args to pass to IDEAS test generation RUSTFLAGS ?= -Awarnings## Flags to build Rust translation VERBOSE ?= 0## Whether to output failed/partial projects in summaries AFL_TAG = aflplusplus/aflplusplus:stable # Pass these variables to IDEAS.mk -export MODEL BASE_URL TRANSLATION_DIR RUSTFLAGS +export MODEL BASE_URL TRANSLATION_DIR TESTGEN_DIR RUSTFLAGS -EXAMPLES ?= $(sort $(shell find ${EXAMPLES_DIR} -name test_case -type d))## List of examples to run on +EXAMPLES ?= $(sort $(shell find ${EXAMPLES_DIR} -maxdepth 3 -name test_case -type d))## List of examples to run on ifeq ($(EXAMPLES),) $(warning No projects found in ${EXAMPLES_DIR}. You may need to re-run commands!) 
endif -EXAMPLES_WITH_RUNNERS := $(filter-out $(foreach ex,$(EXAMPLES),$(if $(wildcard $(ex)/../runner $(ex)/../test_vectors),,$(ex))),$(EXAMPLES)) -ifeq ($(EXAMPLES_WITH_RUNNERS),) -$(warning No projects with runners found in ${EXAMPLES_DIR}. You may need to re-run commands!) -endif all: help ; .PHONY: install -install: install-uv install-rust## Install uv and Rust +install: install-uv install-rust install-deno ## Install uv, Rust, and Deno .PHONY: install-uv install-uv:## Install uv@0.7.12 @@ -49,6 +47,10 @@ install-uv:## Install uv@0.7.12 install-rust:## Install Rust@1.88.0 curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.88.0 +.PHONY: install-deno +install-deno:## Install Deno, which is required by dspy.PythonInterpreter() + curl -fsSL https://deno.land/install.sh | sh + .PHONY: test test:## Run pytest uv run pytest @@ -69,6 +71,7 @@ examples/init: $(subst /test_case,/init,${EXAMPLES}) ; @echo "# ${TRANSLATION_DIR}" examples/%/init:## Initialize specific example examples/%/init: FORCE + -@$(MAKE) -j1 -f $(IDEAS_MAKEFILE) -C $(@D) cmake -@$(MAKE) -j1 -f $(IDEAS_MAKEFILE) -C $(@D) init @@ -95,13 +98,25 @@ examples/translate: $(subst /test_case,/translate,${EXAMPLES}) @echo "\`\`\`" examples/%/translate:## Translate specific example examples/%/translate: FORCE + -@$(MAKE) -j1 -f $(IDEAS_MAKEFILE) -C $(@D) cmake -@$(MAKE) -j1 -f $(IDEAS_MAKEFILE) -C $(@D) translate + +.PHONY: examples/wrapper +examples/wrapper:## Generate C FFI wrappers for all examples +examples/wrapper: $(subst /test_case,/wrapper,${EXAMPLES}) +examples/%/wrapper:## Generate C FFI wrappers for specific example +examples/%/wrapper: FORCE + -@$(MAKE) -j1 -f $(IDEAS_MAKEFILE) -C $(@D) cmake + -@$(MAKE) -j1 -f $(IDEAS_MAKEFILE) -C $(@D) wrapper + + .PHONY: examples/add_test_vectors examples/add_test_vectors:## Build all translated examples examples/add_test_vectors: $(subst /test_case,/add_test_vectors,${EXAMPLES}) examples/%/add_test_vectors:## Build specific 
translated example examples/%/add_test_vectors: FORCE + -@$(MAKE) -j1 -f $(IDEAS_MAKEFILE) -C $(@D) cmake -@$(MAKE) -j1 -f $(IDEAS_MAKEFILE) -C $(@D) add_test_vectors .PHONY: examples/build @@ -120,6 +135,7 @@ endif @echo "\`\`\`" examples/%/build:## Build specific translated example examples/%/build: FORCE + -@$(MAKE) -j1 -f $(IDEAS_MAKEFILE) -C $(@D) cmake -@$(MAKE) -j1 -f $(IDEAS_MAKEFILE) -C $(@D) build .PHONY: examples/test @@ -149,27 +165,16 @@ endif @echo "\`\`\`" examples/%/test:## Test specific translated example examples/%/test: FORCE + -@$(MAKE) -j1 -f $(IDEAS_MAKEFILE) -C $(@D) cmake -@$(MAKE) -j1 -f $(IDEAS_MAKEFILE) -C $(@D) test -.PHONY: examples/test_libc -examples/test_libc:## Run tests using runners on all C library examples -examples/test_libc: $(subst /test_case,/test_libc,${EXAMPLES_WITH_RUNNERS}) ; -examples/%/test_libc:## Run tests using runners on specific C library example -examples/%/test_libc: FORCE - -@$(MAKE) -j1 -f $(IDEAS_MAKEFILE) -C $(@D) test_libc - -.PHONY: examples/test_librs -examples/test_librs:## Run tests using runners on all translated Rust library examples -examples/test_librs: $(subst /test_case,/test_librs,${EXAMPLES_WITH_RUNNERS}) ; -examples/%/test_librs:## Run tests using runners on specific translated Rust library example -examples/%/test_librs: FORCE - -@$(MAKE) -j1 -f $(IDEAS_MAKEFILE) -C $(@D) test_librs .PHONY: examples/repair examples/repair:## Repair all examples examples/repair: $(subst /test_case,/repair,${EXAMPLES}) examples/%/repair:## Repair specific example examples/%/repair: FORCE + -@$(MAKE) -j1 -f $(IDEAS_MAKEFILE) -C $(@D) cmake -@$(MAKE) -j1 -f $(IDEAS_MAKEFILE) -C $(@D) repair .PHONY: examples/clean @@ -200,6 +205,7 @@ examples/update_tests:## Update test cases to use TRANSLATION_DIR test cases for examples/update_tests: $(subst /test_case,/update_tests,${EXAMPLES}) examples/%/update_tests:## Update specific test cases to use TRANSLATION_DIR test cases examples/%/update_tests: FORCE + -@$(MAKE) -j1 
-f $(IDEAS_MAKEFILE) -C $(@D) cmake -@$(MAKE) -j1 -f $(IDEAS_MAKEFILE) -C $(@D) update_tests # clean diff --git a/cargo_toml.cmake b/cargo_toml.cmake deleted file mode 100644 index 8ae2177..0000000 --- a/cargo_toml.cmake +++ /dev/null @@ -1,82 +0,0 @@ -cmake_minimum_required(VERSION 3.20) - -# https://stackoverflow.com/questions/60211516/programmatically-get-all-targets-in-a-cmake-project -function(get_all_targets _result _dir) - get_property(_subdirs DIRECTORY "${_dir}" PROPERTY SUBDIRECTORIES) - foreach(_subdir IN LISTS _subdirs) - get_all_targets(${_result} "${_subdir}") - endforeach() - - get_directory_property(_sub_targets DIRECTORY "${_dir}" BUILDSYSTEM_TARGETS) - set(${_result} ${${_result}} ${_sub_targets} PARENT_SCOPE) -endfunction() - -function(generate-cargo-toml) - # Sanitize project name for cargo target - # FIXME: Add more sanitizations? - set(CARGO_NAME ${PROJECT_NAME}) - string(REGEX REPLACE "-" "_" CARGO_NAME ${CARGO_NAME}) - - get_all_targets(ALL_TARGETS ${PROJECT_SOURCE_DIR}) - message(VERBOSE "ALL_TARGETS: ${ALL_TARGETS}") - # FIXME: This is super error prone because it uses the last target as the desired target - list(GET ALL_TARGETS -1 PROJECT_TARGET) - message(VERBOSE "PROJECT_TARGET: ${PROJECT_TARGET}") - - get_target_property(PROJECT_TYPE ${PROJECT_TARGET} TYPE) - message(VERBOSE "PROJECT_TYPE: ${PROJECT_TYPE}") - if(${PROJECT_TYPE} STREQUAL "EXECUTABLE") - set(CARGO_TARGET "[bin]") - set(CRATE_TYPE "") - else() - set(CARGO_TARGET "lib") - set(CRATE_TYPE "crate-type = [\"lib\", \"cdylib\"]") - endif() - - get_target_property(TARGET_DIR ${PROJECT_TARGET} SOURCE_DIR) - message(VERBOSE "TARGET_DIR: ${TARGET_DIR}") - message(VERBOSE "CMAKE_CURRENT_SOURCE_DIR: ${CMAKE_CURRENT_SOURCE_DIR}") - - get_target_property(PROJECT_SOURCES ${PROJECT_TARGET} SOURCES) - message(VERBOSE "PROJECT_SOURCES: ${PROJECT_SOURCES}") - - # FIXME: This is super error prone because it uses the first C file as the path for the crate bin - # Ideally we'd use the one with 
the main function in it... - list(GET PROJECT_SOURCES 0 CARGO_PATH) - cmake_path(REPLACE_EXTENSION CARGO_PATH rs OUTPUT_VARIABLE CARGO_PATH) - cmake_path(ABSOLUTE_PATH CARGO_PATH BASE_DIRECTORY "${TARGET_DIR}") - cmake_path(RELATIVE_PATH CARGO_PATH BASE_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}") - # P01 projects have test_case prefix, so remove it - if (${CARGO_PATH} MATCHES "test_case/.*") - message(VERBOSE "Removing test_case from ${CARGO_PATH}!") - string(REGEX REPLACE "^test_case/" "" CARGO_PATH ${CARGO_PATH}) - endif() - - # Prepend a top-level src/ directory only if the C target does not already - # have one - if (NOT(${CARGO_PATH} MATCHES "^src/.*")) - message(VERBOSE "Prepending src/ to ${CARGO_PATH}!") - set(CARGO_PATH "src/${CARGO_PATH}") - endif() - message(VERBOSE "CARGO_PATH: ${CARGO_PATH}") - - file(GENERATE OUTPUT "Cargo.toml" CONTENT "[package] -name = \"${PROJECT_NAME}\" -version = \"0.1.0\" -edition = \"2024\" - -[${CARGO_TARGET}] -name = \"${CARGO_NAME}\" -path = \"${CARGO_PATH}\" -${CRATE_TYPE} - -[dev-dependencies] -assert_cmd = \"2.0.17\" -ntest = \"0.9.3\" -predicates = \"3.1.3\" -") - - message(STATUS "Generated Cargo.toml for ${PROJECT_TARGET} target!") -endfunction() - -cmake_language(DEFER CALL generate-cargo-toml) diff --git a/extract_info.cmake b/extract_info.cmake new file mode 100644 index 0000000..6516e32 --- /dev/null +++ b/extract_info.cmake @@ -0,0 +1,67 @@ +cmake_minimum_required(VERSION 3.20) + +# https://stackoverflow.com/questions/60211516/programmatically-get-all-targets-in-a-cmake-project +function(get_all_targets _result _dir) + get_property(_subdirs DIRECTORY "${_dir}" PROPERTY SUBDIRECTORIES) + foreach(_subdir IN LISTS _subdirs) + get_all_targets(${_result} "${_subdir}") + endforeach() + + get_directory_property(_sub_targets DIRECTORY "${_dir}" BUILDSYSTEM_TARGETS) + set(${_result} ${${_result}} ${_sub_targets} PARENT_SCOPE) +endfunction() + +function(get_target_sources SOURCES TARGET) + get_target_property(TARGET_DIR 
${TARGET} SOURCE_DIR) + get_target_property(TARGET_SOURCES_RAW ${TARGET} SOURCES) + set(TARGET_SOURCES "") + foreach(TARGET_SOURCE ${TARGET_SOURCES_RAW}) + # A source entry may be a $<TARGET_OBJECTS:tgt> generator expression; there is no way to evaluate it here, so recurse into the sources of the referenced target instead + if(TARGET_SOURCE MATCHES "\\$<TARGET_OBJECTS:([^>]+)>") + set(SOURCE_TARGET ${CMAKE_MATCH_1}) + if(TARGET "${SOURCE_TARGET}") + get_target_sources(SOURCE_SOURCES ${SOURCE_TARGET}) + if(SOURCE_SOURCES) + list(APPEND TARGET_SOURCES ${SOURCE_SOURCES}) + endif() + endif() + else() + set(TARGET_SOURCE "${TARGET_DIR}/${TARGET_SOURCE}") + list(APPEND TARGET_SOURCES ${TARGET_SOURCE}) + endif() + endforeach() + set(${SOURCES} ${TARGET_SOURCES} PARENT_SCOPE) +endfunction() + +function(extract_info) + message(STATUS "Detecting targets ...") + get_all_targets(ALL_TARGETS ${PROJECT_SOURCE_DIR}) + foreach(TARGET ${ALL_TARGETS}) + get_target_property(TARGET_TYPE ${TARGET} TYPE) + get_target_property(TARGET_DIR ${TARGET} SOURCE_DIR) + get_target_property(TARGET_LINK_LIBRARIES ${TARGET} LINK_LIBRARIES) + + message(STATUS " Found ${TARGET_TYPE} ${TARGET}") + + # Recursively get target sources and shared library/object sources + get_target_sources(TARGET_SOURCES ${TARGET}) + foreach(LINK_TARGET IN LISTS TARGET_LINK_LIBRARIES) + if(TARGET ${LINK_TARGET}) + get_target_sources(LINK_SOURCES ${LINK_TARGET}) + list(APPEND TARGET_SOURCES ${LINK_SOURCES}) + endif() + endforeach() + list(JOIN TARGET_SOURCES "\n" TARGET_SOURCES) + # Name the .type/.dir/.sources files after the built artifact's file name (e.g. libfoo.so); object libraries fall back to the target name + if(${TARGET_TYPE} STREQUAL "OBJECT_LIBRARY") + set(TARGET_NAME ${TARGET}) + else() + set(TARGET_NAME $<TARGET_FILE_NAME:${TARGET}>) + endif() + file(GENERATE OUTPUT "${TARGET_NAME}.type" CONTENT "${TARGET_TYPE}") + file(GENERATE OUTPUT "${TARGET_NAME}.dir" CONTENT "${TARGET_DIR}") + file(GENERATE OUTPUT "${TARGET_NAME}.sources" CONTENT "${TARGET_SOURCES}") + endforeach() +endfunction() + +cmake_language(DEFER CALL extract_info) diff --git a/pyproject.toml b/pyproject.toml index 3e7174c..984eb26 100644 --- a/pyproject.toml +++ 
b/pyproject.toml @@ -8,7 +8,7 @@ requires-python = "~=3.11.0" dependencies = [ "clang==14.0", - "dspy==3.0.3", + "dspy==3.0.4", "hydra-core==1.3.2", "tree-sitter==0.24.0", "tree-sitter-c==0.23.4", diff --git a/src/ideas/__init__.py b/src/ideas/__init__.py index 2606a3b..c744651 100644 --- a/src/ideas/__init__.py +++ b/src/ideas/__init__.py @@ -6,10 +6,11 @@ from .logging import JSONFormatter, CodePairFilter from .ast import create_translation_unit, extract_info_c, TreeResult +from .ast_rust import ensure_no_mangle_in_module from .ltu import build_unit from .model import ModelConfig, GenerateConfig from .agents import TranslateAgent - +from .tools import get_info_from_cargo_toml __version__ = "2025.10" @@ -17,10 +18,12 @@ "create_translation_unit", "extract_info_c", "TreeResult", + "ensure_no_mangle_in_module", "ModelConfig", "GenerateConfig", "JSONFormatter", "CodePairFilter", "build_unit", "TranslateAgent", + "get_info_from_cargo_toml", ] diff --git a/src/ideas/agents.py b/src/ideas/agents.py index e9e88e8..307bf85 100644 --- a/src/ideas/agents.py +++ b/src/ideas/agents.py @@ -11,9 +11,9 @@ from pathlib import Path import dspy -from clang.cindex import TranslationUnit, CursorKind +from clang.cindex import TranslationUnit -from .ast import extract_info_c, get_cursor_prettyprinted +from .ast import extract_info_c, get_cursor_code from .tools import check_rust from .ltu import build_unit from .cover import CoVeR @@ -38,7 +38,7 @@ def __init__(self, preproc_strategy: str): def forward( self, c_code: str, c_full_code: str, tu: TranslationUnit - ) -> dict[str, list[str]]: + ) -> dict[str, dict[str, str] | list[str]]: output_code = [] # Use Clang to analyze the pre-processed C code ast_info = extract_info_c(tu) @@ -81,13 +81,7 @@ def forward( if child.extent.start.column != 1: continue - output_code += get_cursor_prettyprinted(child) - - # Non-function definitions require statement terminations - if child.kind != CursorKind.FUNCTION_DECL or not child.is_definition(): # 
type: ignore - output_code += ";" - - output_code += "\n" + output_code += get_cursor_code(child, pretty_print=True) + "\n" output_code = [output_code] case "tu-sys-filter": @@ -100,13 +94,7 @@ def forward( if child.location.file.name.startswith("/usr"): continue - output_code += get_cursor_prettyprinted(child) - - # Non-function definitions require statement terminations - if child.kind != CursorKind.FUNCTION_DECL or not child.is_definition(): # type: ignore - output_code += ";" - - output_code += "\n" + output_code += get_cursor_code(child, pretty_print=True) + "\n" output_code = [output_code] case "c": @@ -114,11 +102,11 @@ def forward( case "ltu-max": output_units = build_unit(ast_info, type="functional_maximal") - output_code = [str(unit) for unit in output_units] + output_code = {unit.symbol_name: str(unit) for unit in output_units} case "ltu-min": output_units = build_unit(ast_info, type="functional_minimal") - output_code = [str(unit) for unit in output_units] + output_code = {unit.symbol_name: str(unit) for unit in output_units} return {"input_code": output_code} @@ -130,6 +118,7 @@ def __init__( translator: str, max_iters: int, use_raw_fixer_output: bool, + patch_no_mangle: bool, ): super().__init__() @@ -166,6 +155,7 @@ def compile_rust(output_code: dspy.Code["Rust"]) -> str: # noqa: F821 success=success_message, max_iters=max_iters, use_raw_fixer_output=use_raw_fixer_output, + patch_no_mangle=patch_no_mangle, ) case "Predict": diff --git a/src/ideas/ast.py b/src/ideas/ast.py index 2df5897..ab0c74e 100644 --- a/src/ideas/ast.py +++ b/src/ideas/ast.py @@ -81,13 +81,13 @@ def extract_info_c(tu: TranslationUnit) -> TreeResult: match (kind, node.is_definition()): # Function definition case (CursorKind.FUNCTION_DECL, True): # type: ignore[reportAttributeAccessIssue] - result.symbols[usr] = Symbol(usr, kind, decl) + result.symbols[usr] = Symbol(usr, node, decl) fn_defn = get_code_from_tu_range(tu, node.extent) result.fn_definitions[usr] = fn_defn # Data 
structures case (kind, _) if kind in DATA_STRUCT_NODE_MAP: - result.symbols[usr] = Symbol(usr, kind, decl) + result.symbols[usr] = Symbol(usr, node, decl) previous_usr = usr for child in node.get_children(): @@ -97,7 +97,7 @@ def extract_info_c(tu: TranslationUnit) -> TreeResult: # Typedefs case (CursorKind.TYPEDEF_DECL, _): # type: ignore[reportAttributeAccessIssue] - result.symbols[usr] = Symbol(usr, kind, decl) + result.symbols[usr] = Symbol(usr, node, decl) # NOTE: If this is a typedef the data structure was visited just before for child in node.walk_preorder(): @@ -123,7 +123,13 @@ def extract_info_c(tu: TranslationUnit) -> TreeResult: # All other declarations case (kind, _) if kind in DECL_NODE_KIND: - result.symbols[usr] = Symbol(usr, kind, decl) + # Handle the case where we're a declaration occurs after a definition, e.g.: + # static const tflac_u16 tflac_crc16_tables[8][256] = { .. }; + # static const tflac_u16 tflac_crc16_tables[8][256]; + if usr not in result.symbols or ( + usr in result.symbols and node.is_definition() + ): + result.symbols[usr] = Symbol(usr, node, decl) case (_, _): raise NotImplementedError() @@ -153,7 +159,7 @@ def extract_referenced_symbols(node: Cursor) -> list[Symbol]: # Ignore internal references to, e.g., function parameters if child_node.referenced is None: continue - symbol_uses.append(Symbol(child_node.referenced.get_usr(), child_node.kind)) + symbol_uses.append(Symbol(child_node.referenced.get_usr(), child_node)) return symbol_uses @@ -195,6 +201,12 @@ def get_code_from_tu_range( def get_cursor_prettyprinted(cursor: Cursor) -> str: + # Include tag definition when typedef cursor with non-typeref child + include_tag_definition = 0 + if cursor.kind == CursorKind.TYPEDEF_DECL: # type: ignore[reportAttributeAccessIssue] + children = list(cursor.get_children()) + include_tag_definition = len(children) == 1 and children[0].kind != CursorKind.TYPE_REF # type: ignore[reportAttributeAccessIssue] + # Setup FFI for unsupported python 
libclang functions # NOTE: Upgrade libclang to get these? clang_getCursorPrintingPolicy = conf.lib.clang_getCursorPrintingPolicy # type: ignore[reportAttributeAccessIssue] @@ -214,7 +226,21 @@ def get_cursor_prettyprinted(cursor: Cursor) -> str: clang_getCursorPrettyPrinted.errcheck = _CXString.from_result policy = clang_getCursorPrintingPolicy(cursor) - clang_PrintingPolicy_setProperty(policy, 3, 1) # IncludeTagDefinition - clang_PrintingPolicy_setProperty(policy, 23, 1) # ConstantsAsWritten + clang_PrintingPolicy_setProperty(policy, 3, include_tag_definition) + clang_PrintingPolicy_setProperty(policy, 23, 0) # ConstantsAsWritten + # clang_PrintingPolicy_setProperty(policy, 26, 1) # PrintAsCanonical + + return clang_getCursorPrettyPrinted(cursor, policy).rstrip() + + +def get_cursor_code(cursor: Cursor, pretty_print: bool = False) -> str: + if pretty_print: + code = get_cursor_prettyprinted(cursor) + else: + code = get_code_from_tu_range(cursor.translation_unit, cursor.extent) + + # Non-function definitions require statement terminations + if cursor.kind != CursorKind.FUNCTION_DECL or not cursor.is_definition(): # type: ignore[reportAttributeAccessIssue] + code += ";" - return clang_getCursorPrettyPrinted(cursor, policy) + return code diff --git a/src/ideas/ast_rust.py b/src/ideas/ast_rust.py new file mode 100644 index 0000000..5f0bbed --- /dev/null +++ b/src/ideas/ast_rust.py @@ -0,0 +1,414 @@ +# +# Copyright (C) 2025 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 +# + +import logging +from collections import defaultdict +from dataclasses import dataclass, field + +from typing import Callable +from tree_sitter import Language, Parser, Node +import tree_sitter_rust as tsrust + +from .utils import RustSymbol + +logger = logging.getLogger("ideas.ast_rust") + + +@dataclass +class RustTreeResult: + symbols: dict[str, RustSymbol] = field(default_factory=dict) + fn_definitions: dict[str, str | None] = field( + default_factory=lambda: defaultdict(lambda: 
None) + ) + + +def extract_all_attributes(function_node: Node, source_code: str) -> list[str]: + attributes = [] + prev_sibling = function_node.prev_sibling + while prev_sibling and prev_sibling.type in [ + "attribute_item", + "line_comment", + "block_comment", + ]: + if prev_sibling.type == "attribute_item": + attr_text = source_code[prev_sibling.start_byte : prev_sibling.end_byte] + attributes.append(attr_text) + prev_sibling = prev_sibling.prev_sibling + + return list(reversed(attributes)) + + +def has_attribute(function_node: Node, source_code: str, attribute: str) -> bool: + prev_sibling = function_node.prev_sibling + while prev_sibling and prev_sibling.type in [ + "attribute_item", + "line_comment", + "block_comment", + ]: + if prev_sibling.type == "attribute_item": + attr_text = source_code[prev_sibling.start_byte : prev_sibling.end_byte] + if attr_text == attribute: + return True + + prev_sibling = prev_sibling.prev_sibling + + return False + + +def find_attribute_node(function_node: Node, source_code: str, attribute: str) -> Node | None: + prev_sibling = function_node.prev_sibling + while prev_sibling and prev_sibling.type in [ + "attribute_item", + "line_comment", + "block_comment", + ]: + if prev_sibling.type == "attribute_item": + attr_text = source_code[prev_sibling.start_byte : prev_sibling.end_byte] + if attr_text == attribute: + return prev_sibling + prev_sibling = prev_sibling.prev_sibling + + return None + + +def find_function_node( + root_node: Node, source_code: str, symbol_name: str, module_name: str | None = None +) -> Node | None: + def find_in_node(node: Node, in_target_module: bool = False) -> Node | None: + for child in node.children: + if child.type == "function_item": + func_name = get_function_name(child, source_code) + if func_name == symbol_name and ( + (module_name is None and node == root_node) or in_target_module + ): + return child + + elif child.type == "mod_item" and module_name is not None: + name_node = 
child.child_by_field_name("name") + if ( + not name_node + or source_code[name_node.start_byte : name_node.end_byte] != module_name + ): + continue + + # Search within this module's body + body = child.child_by_field_name("body") + if not body: + continue + + result = find_in_node(body, in_target_module=True) + if result: + return result + return None + + if module_name is None: + # Search only top-level functions + for node in root_node.children: + if ( + node.type == "function_item" + and get_function_name(node, source_code) == symbol_name + ): + return node + return None + else: + # Search for function in the specified module + return find_in_node(root_node) + + +def get_function_name(function_node: Node, source_code: str) -> str | None: + for child in function_node.children: + if child.type == "identifier": + return source_code[child.start_byte : child.end_byte] + + return None + + +def get_extern_symbols(node: Node, source_code: str) -> list[RustSymbol]: + symbols: list[RustSymbol] = [] + for child in node.children: + # Only look for the "declaration_list" node + if child.type != "declaration_list": + continue + + for decl_child in child.children: + # Only look for "function_signature_item" nodes + if decl_child.type != "function_signature_item": + continue + + name_node = decl_child.child_by_field_name("name") + if not name_node: + raise ValueError( + f"Encountered an extern ABI function signature item without a name in {source_code[decl_child.start_byte:decl_child.end_byte] = }!" 
+ ) + fn_name = source_code[name_node.start_byte : name_node.end_byte] + fn_signature = source_code[decl_child.start_byte : decl_child.end_byte] + attributes = extract_all_attributes(decl_child, source_code) + + symbols.append(RustSymbol(fn_name, decl_child, fn_signature, attributes=attributes)) + + return symbols + + +def get_function_names_in_module( + root_node: Node, source_code: str, module_name: str | None = None +) -> list[str]: + function_names = [] + + if module_name is None: + # Search only top-level functions + for node in root_node.children: + if node.type == "function_item": + func_name = get_function_name(node, source_code) + if func_name: + function_names.append(func_name) + elif node.type == "foreign_mod_item": + extern_symbols = get_extern_symbols(node, source_code) + function_names.extend([symbol.name for symbol in extern_symbols]) + else: + # Find the module body + module_node = find_module_body(root_node, source_code, module_name) + if not module_node: + raise ValueError(f"Module {module_name} not found!") + + # Extract all functions from the module + for child in module_node.children: + if child.type == "function_item": + func_name = get_function_name(child, source_code) + if func_name: + function_names.append(func_name) + elif child.type == "foreign_mod_item": + extern_symbols = get_extern_symbols(child, source_code) + function_names.extend([symbol.name for symbol in extern_symbols]) + + return function_names + + +def find_module_body(node: Node, source_code: str, module_name: str) -> Node | None: + for child in node.children: + if child.type == "mod_item": + name_node = child.child_by_field_name("name") + if ( + name_node + and source_code[name_node.start_byte : name_node.end_byte] == module_name + ): + return child.child_by_field_name("body") + + # Check nested modules recursively + body = child.child_by_field_name("body") + if body: + result = find_module_body(body, source_code, module_name) + if result: + return result + return None + + 
+def traverse_rust(source_code: str, node: Node) -> RustTreeResult: + result = RustTreeResult() + + if node.type == "function_item": + # Extract function name from the function_item node + name_node = node.child_by_field_name("name") + if name_node: + fn_name = source_code[name_node.start_byte : name_node.end_byte] + fn_definition = source_code[node.start_byte : node.end_byte] + # Extract just the signature (everything before the body) + body_node = node.child_by_field_name("body") + if body_node: + fn_signature = source_code[node.start_byte : body_node.start_byte] + else: + fn_signature = fn_definition + attributes = extract_all_attributes(node, source_code) + + result.symbols[fn_name] = RustSymbol( + fn_name, node, fn_signature, attributes=attributes + ) + result.fn_definitions[fn_name] = fn_definition + + # Non-ABI nested function signatures (e.g., in trait definitions) + elif node.type == "function_signature_item": + name_node = node.child_by_field_name("name") + if name_node: + fn_name = source_code[name_node.start_byte : name_node.end_byte] + fn_signature = source_code[node.start_byte : node.end_byte] + attributes = extract_all_attributes(node, source_code) + + result.symbols[fn_name] = RustSymbol( + fn_name, node, fn_signature, attributes=attributes + ) + result.fn_definitions[fn_name] = fn_signature + + # Handle extern blocks with FFI function declarations + elif node.type == "foreign_mod_item": + # The ABI is the "string_literal" child of the "extern_modifier" child + abi = None + for child in node.children: + if child.type != "extern_modifier": + continue + + if abi: + raise ValueError( + f"Multiple extern modifiers found in extern block {source_code[node.start_byte:node.end_byte] = }!" 
+ ) + + for modifier_child in child.children: + if modifier_child.type == "string_literal": + abi = source_code[modifier_child.start_byte : modifier_child.end_byte] + break + + if abi is None: + raise ValueError( + f"Unable to determine ABI for extern block {source_code[node.start_byte:node.end_byte] = }!" + ) + + # Extract all symbols in this extern block + symbols: list[RustSymbol] = get_extern_symbols(node, source_code) + for symbol in symbols: + fn_name = symbol.name + decl = symbol.decl + + # Construct standalone declaration with ABI specifier + complete_decl = f"extern {abi} {{ {decl} }}" + result.fn_definitions[fn_name] = complete_decl + result.symbols[fn_name] = RustSymbol( + fn_name, symbol.node, complete_decl, symbol.attributes + ) + + # Don't recursively traverse extern block children since we handled them with the ABI identifier + return result + + # Recursively traverse all child nodes + for child in node.children: + child_result = traverse_rust(source_code, child) + result.symbols.update(child_result.symbols) + result.fn_definitions.update(child_result.fn_definitions) + + return result + + +def extract_info(code: str, traverse_fn: Callable, parser: Parser) -> RustTreeResult: + tree = parser.parse(bytes(code, "utf8")) + return traverse_fn(code, tree.root_node) + + +def extract_info_rust(code: str) -> RustTreeResult: + lang = Language(tsrust.language()) + parser = Parser(lang) + return extract_info(code, traverse_fn=traverse_rust, parser=parser) + + +def remove_attribute_from_fn( + source_code: str, + symbol_name: str, + attr: str, + module_name: str | None = None, +) -> str: + # Parse the code + parser = Parser(Language(tsrust.language())) + source_bytes = source_code.encode("utf-8") + tree = parser.parse(source_bytes) + root_node = tree.root_node + + # Find the target function in the target module (or top-level if module_name is None) + target_function = find_function_node(root_node, source_code, symbol_name, module_name) + if target_function is None: 
+ logger.warning( + f"Function {symbol_name} not found in {module_name or 'top-level'} module!" + ) + return source_code + + # Find the attribute node + attr_node = find_attribute_node(target_function, source_code, attr) + if attr_node is None: + logger.debug(f"Function {symbol_name} does not have attribute {attr}, no changes made.") + return source_code + + # Remove the attribute from the source code + modified_source = source_code[: attr_node.start_byte] + source_code[attr_node.end_byte :] + + return modified_source + + +def ensure_attribute_for_fn( + source_code: str, + symbol_name: str, + attr: str, + module_name: str | None = None, +) -> str: + # 'main' should not have #[unsafe(no_mangle)] in Rust + if symbol_name == "main" and attr == "#[unsafe(no_mangle)]": + logger.warning("Skipping addition of #[unsafe(no_mangle)] to 'main' function!") + return source_code + + # Parse the code + parser = Parser(Language(tsrust.language())) + source_bytes = source_code.encode("utf-8") + tree = parser.parse(source_bytes) + root_node = tree.root_node + + # Find the target function in the target module (or top-level if module_name is None) + target_function = find_function_node(root_node, source_code, symbol_name, module_name) + if target_function is None: + logger.warning( + f"Function {symbol_name} not found in {module_name or 'top-level'} module!" 
+ ) + return source_code + + # If function already has attribute, do nothing + if has_attribute(target_function, source_code, attr): + logger.debug(f"Function {symbol_name} already has attribute {attr}, no changes made.") + return source_code + + # Add as the last attribute + func_line_start = source_code.rfind("\n", 0, target_function.start_byte) + 1 + modified_source = ( + source_code[:func_line_start] + f"{attr}\n" + source_code[func_line_start:] + ) + + return modified_source + + +def ensure_no_mangle_in_module( + source_code: str, module_name: str | None = None, add: bool = False +) -> str: + """Ensure that all functions in the specified module have the #[unsafe(no_mangle)] attribute. + + Args: + source_code (str): The Rust source code. + module_name (str | None): The module name to target, or None for top-level + add (bool): Whether to also add the attribute to all functions or only replace existing invalid occurrences. + + Returns: + str: The modified source code. + """ + # Parse the code + parser = Parser(Language(tsrust.language())) + source_bytes = source_code.encode("utf-8") + tree = parser.parse(source_bytes) + root_node = tree.root_node + function_names = get_function_names_in_module( + root_node=root_node, source_code=source_code, module_name=module_name + ) + + # Patch each function if needed + modified_source = source_code + for func_name in function_names: + fn_node = find_function_node(root_node, source_code, func_name, module_name) + if fn_node is None: + raise ValueError(f"Function {func_name} not found in {module_name = }!") + + attributes = extract_all_attributes(fn_node, source_code) + # Remove invalid use + if "#[no_mangle]" in attributes: + modified_source = remove_attribute_from_fn( + modified_source, func_name, "#[no_mangle]", module_name=module_name + ) + # Add the correct attribute + if add or "#[no_mangle]" in attributes: + modified_source = ensure_attribute_for_fn( + modified_source, func_name, "#[unsafe(no_mangle)]", 
module_name=module_name + ) + + return modified_source diff --git a/src/ideas/convert_tests.py b/src/ideas/convert_tests.py index ddbc270..71a5018 100644 --- a/src/ideas/convert_tests.py +++ b/src/ideas/convert_tests.py @@ -5,7 +5,7 @@ # """ -Convert executable JSON test cases like: +Use JSON test vectors like: ```json { @@ -18,7 +18,7 @@ } ``` -to: +to generate tests for binary targets: ```rust use assert_cmd::Command; @@ -26,6 +26,7 @@ use predicates::prelude::*; #[test] +#[timeout(some_timeout)] fn test1() { Command::cargo_bin(assert_cmd::crate_name!()).unwrap() .args(&["--flag", "value"]) @@ -38,10 +39,12 @@ .code(rc); } ``` + +TODO: Use a template to generate tests for library targets. """ -import sys import json +import argparse from pathlib import Path @@ -49,7 +52,13 @@ def to_rust_str(string): return '"' + repr(string)[1:-1] + '"' -def convert_tests(test_cases: list[Path]): +def is_bin_test(test_case: Path): + test_case_json = json.loads(test_case.read_text()) + return "lib_state_in" not in test_case_json and "lib_state_out" not in test_case_json + + +def convert_tests_for_exec(test_cases: list[Path], timeout: int = 60000): + test_cases = list(filter(is_bin_test, test_cases)) if len(test_cases) == 0: return @@ -103,7 +112,7 @@ def convert_tests(test_cases: list[Path]): raise ValueError(f"stderr.is_regex must be a boolean, got {type(is_stderr_regex)}") print("#[test]") - print("#[timeout(600000)]") # 10 minutes + print(f"#[timeout({timeout})]") print(f"fn test_case_{test_case.stem}() {{") print(" Command::cargo_bin(assert_cmd::crate_name!()).unwrap()") if len(args) > 0: @@ -126,5 +135,30 @@ def convert_tests(test_cases: list[Path]): print("") +def is_lib_test(test_case: Path): + test_case_json = json.loads(test_case.read_text()) + return "lib_state_in" in test_case_json and "lib_state_out" in test_case_json + + +def convert_tests_for_lib( + test_cases: list[Path], template_path: Path | None, timeout: int = 60000 +): + raise ValueError("Library test 
conversion not implemented yet!") + + if __name__ == "__main__": - convert_tests([Path(test_case) for test_case in sys.argv[1:]]) + parser = argparse.ArgumentParser(description="Convert JSON test vectors to cargo tests") + parser.add_argument( + "test_vectors", type=Path, nargs="+", help="Path(s) to JSON test vector(s)" + ) + parser.add_argument( + "--template", type=Path, help="Path to Rust test template", required=False + ) + parser.add_argument( + "--timeout", type=int, help="Timeout for each test in milliseconds", default=60000 + ) + args = parser.parse_args() + + test_vectors = [Path(path) for path in args.test_vectors] + convert_tests_for_exec(test_vectors, args.timeout) + convert_tests_for_lib(test_vectors, args.template, args.timeout) diff --git a/src/ideas/cover.py b/src/ideas/cover.py index 96bce0a..bcac731 100644 --- a/src/ideas/cover.py +++ b/src/ideas/cover.py @@ -6,13 +6,15 @@ # import logging -from typing import Any +from typing import Any, Callable from litellm.exceptions import ContextWindowExceededError import dspy from dspy.predict.react import _fmt_exc +from ideas import ensure_no_mangle_in_module + logger = logging.getLogger(__name__) @@ -20,20 +22,23 @@ class CoVeR(dspy.Module): def __init__( self, signature: dspy.SignatureMeta, - tools: list[dspy.Tool], + tools: list[dspy.Tool | Callable], success: str = "Success!", max_iters: int = 5, use_raw_fixer_output: bool = True, + patch_no_mangle: bool = True, ): super().__init__() self.signature = signature = dspy.ensure_signature(signature) # type: ignore self.success = success self.max_iters = max_iters self.use_raw_fixer_output = use_raw_fixer_output + self.patch_no_mangle = patch_no_mangle if len(tools) == 0: raise ValueError("Need at least one valid dspy.Tool!") - tool_dict = {tool.name: tool for tool in tools} + dspy_tools = [t if isinstance(t, dspy.Tool) else dspy.Tool(t) for t in tools] + tool_dict = {tool.name: tool for tool in dspy_tools} inputs = ", ".join([f"`{k}`" for k in 
signature.input_fields.keys()]) outputs = ", ".join([f"`{k}`" for k in signature.output_fields.keys()]) @@ -72,7 +77,7 @@ def __init__( f"No valid outputs in {self.task_outputs} can be routed to tool {name}!" ) - cover_signature = ( + self.cover_signature = ( dspy.Signature( {**signature.input_fields, **signature.output_fields}, # type: ignore "\n".join(instr), @@ -86,7 +91,7 @@ def __init__( signature.instructions, ).append("trajectory", dspy.InputField(), type_=str) - self.cover = dspy.Predict(cover_signature) + self.cover = dspy.Predict(self.cover_signature) self.extract = dspy.ChainOfThought(fallback_signature) def _format_trajectory(self, trajectory: dict[str, Any]): @@ -109,6 +114,14 @@ def forward(self, **input_args): break assert pred is not None, "Prediction should not be None!" + + # Patch #[no_mangle] -> #[unsafe(no_mangle)] for all Rust code outputs + # But do not add the attribute if it does not exist + if self.patch_no_mangle: + for key, value in self.cover_signature.output_fields.items(): + if isinstance(value.annotation, dspy.Code["Rust"].__class__): + pred[key].code = ensure_no_mangle_in_module(pred[key].code, add=False) + trajectory[f"thought_{idx}"] = pred.next_thought # type: ignore tool_names, tool_args, observations = [], [], [] for name in self.tools.keys(): diff --git a/src/ideas/init.py b/src/ideas/init.py new file mode 100644 index 0000000..7c6b487 --- /dev/null +++ b/src/ideas/init.py @@ -0,0 +1,297 @@ +# +# Copyright (C) 2025 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 +# + +import os +import logging +from pathlib import Path +from graphlib import TopologicalSorter +from collections.abc import Iterable, Container +from dataclasses import dataclass + +import hydra +from omegaconf import MISSING +from hydra.core.config_store import ConfigStore +from hydra.core.hydra_config import HydraConfig +from clang.cindex import CompilationDatabase, TranslationUnit, CursorKind + +from ideas import get_info_from_cargo_toml +from .ast 
import get_cursor_code, extract_info_c, TreeResult +from .utils import Symbol +from .tools import Crate + +logger = logging.getLogger("ideas.preprocess") + + +@dataclass +class InitConfig: + filename: Path = MISSING + pretty_print: bool = True + export_symbols: Path | None = None + source_priority: Path | None = None + + +cs = ConfigStore.instance() +cs.store(name="init", node=InitConfig) + + +def init( + compile_commands: Path, + crate: Crate, + export_symbols: list[str] | None = None, + source_priority: list[Path] | None = None, + pretty_print: bool = True, +): + # Get symbol table and dependencies taking into account source priority and exported symbols + asts = get_asts(compile_commands, valid_paths=source_priority) + symbols, dependencies = get_symbols_and_dependencies(asts, source_priority, export_symbols) + logger.info(f"Found {len(symbols)} symbols in {compile_commands}!") + + # Assemble C sources in topological order + includes = get_includes(symbols) + sources = [] + for symbol_name in TopologicalSorter(dependencies).static_order(): + # Ignore tag definitions and function declarations + if symbol_name not in symbols: + logger.warning(f"Skipping `{symbol_name}` ...") + continue + symbol_code = get_cursor_code(symbols[symbol_name].cursor, pretty_print=pretty_print) + sources.append(symbol_code) + + # Create initial outputs + crate.rust_src_path.parent.mkdir(exist_ok=True, parents=True) + crate.rust_src_path.with_suffix(".c").write_text( + "\n".join(includes) + "\n\n" + "\n\n".join(sources) + ) + + +def get_symbols_and_dependencies( + asts: list[TreeResult], + source_priority: list[Path] | None = None, + export_symbols: list[str] | None = None, +) -> tuple[dict[str, Symbol], dict[str, list[str]]]: + global_symbols = merge_symbols( + [ast.symbols for ast in asts], source_priority=source_priority + ) + + # Filter global symbols to create project symbols + project_symbols = filter_symbols(global_symbols, filter_system=True) + project_dependencies = 
merge_complete_graphs(asts, valid_names=project_symbols) + + # Use export_symbols to filter project symbols and dependencies + dependencies = project_dependencies + if export_symbols is not None: + export_symbols = [c14n_symbol_name(name, project_symbols) for name in export_symbols] + dependencies = reachable_subgraph(project_dependencies, export_symbols) + symbols = filter_symbols( + project_symbols, filter_tag_definitions=True, filter_function_declarations=True + ) + + return symbols, dependencies + + +def get_includes(symbols: dict[str, Symbol]) -> set[str]: + includes: set[str] = set() + for symbol in symbols.values(): + tu = symbol.cursor.translation_unit + for inclusion in tu.get_includes(): + # Source of the include should be in same path as TU while the include should NOT be + tu_path = str(Path(tu.spelling).resolve()) + inclusion_source_path = str(Path(inclusion.source.name).resolve()) + inclusion_include_path = str(Path(inclusion.include.name).resolve()) + if (os.path.commonprefix((tu_path, inclusion_source_path)) != "/") and ( + os.path.commonprefix((tu_path, inclusion_include_path)) == "/" + ): + # Get include directive from source + with open(inclusion.location.file.name, "rb") as f: + f.seek(inclusion.location.offset) + include = f.readline().decode().strip() + includes.add(f"#include {include}") + return includes + + +def get_asts(filename: Path, valid_paths: list[Path] | None = None) -> list[TreeResult]: + assert filename.name == "compile_commands.json" + db = CompilationDatabase.fromDirectory(filename.parent) + cmds = db.getAllCompileCommands() + tus = [TranslationUnit.from_source(None, args=list(cmd.arguments)) for cmd in cmds] + if valid_paths is not None: + tus = [tu for tu in tus if Path(tu.cursor.spelling) in valid_paths] + asts = [extract_info_c(tu) for tu in tus] + return asts + + +def filter_symbols( + symbols: dict[str, Symbol], + filter_system: bool = True, + filter_tag_definitions: bool = False, + filter_function_declarations: bool = 
False, +) -> dict[str, Symbol]: + filtered_symbols = {} + for name, symbol in symbols.items(): + # Ignore "system" symbols + if filter_system: + tu_path = Path(symbol.cursor.translation_unit.spelling).resolve() + sym_path = Path(symbol.cursor.location.file.name).resolve() + if os.path.commonprefix([str(tu_path), str(sym_path)]) == "/": + continue + + # Ignore "tag definitions" that are contained with in other symbols like: + # typedef enum name { value } name; + # This produces two ENUM cursors. + if filter_tag_definitions: + children = list(symbol.cursor.get_children()) + if ( + symbol.cursor.kind == CursorKind.TYPEDEF_DECL # type: ignore[reportAttributeAccessIssue] + and len(children) == 1 + and children[0].kind != CursorKind.TYPE_REF # type: ignore[reportAttributeAccessIssue] + ): + contained_name = children[0].get_usr() + filtered_symbols.pop(contained_name, None) + + # Filter function declarations + if filter_function_declarations: + if ( + symbol.cursor.kind == CursorKind.FUNCTION_DECL # type: ignore[reportAttributeAccessIssue] + and not symbol.cursor.is_definition() + ): + continue + + filtered_symbols[name] = symbols[name] + return filtered_symbols + + +def merge_symbols( + list_of_symbols: list[dict[str, Symbol]], source_priority: list[Path] | None = None +) -> dict[str, Symbol]: + if source_priority is None: + source_priority = [] + + global_symbols: dict[str, Symbol] = {} + for symbols in list_of_symbols: + # Gather symbols + for name, symbol in symbols.items(): + # If not in global symbol table add it + if name not in global_symbols: + global_symbols[name] = symbol + continue + + # If code matches, then don't bother replacing + global_code = get_cursor_code(global_symbols[name].cursor) + code = get_cursor_code(symbol.cursor) + if global_code == code: + continue + + global_source = Path( + global_symbols[name].cursor.translation_unit.spelling + ).resolve() + symbol_source = Path(symbol.cursor.translation_unit.spelling).resolve() + + # If overwriting a 
symbol, then prefer one with a definition + if ( + global_symbols[name].cursor.is_definition() + and not symbol.cursor.is_definition() + ): + continue + elif ( + not global_symbols[name].cursor.is_definition() + and symbol.cursor.is_definition() + ): + global_symbols[name] = symbol + # Or prefer the symbol with source priority + elif global_source in source_priority and symbol_source not in source_priority: + continue + elif global_source not in source_priority and symbol_source in source_priority: + global_symbols[name] = symbol + elif ( + global_source in source_priority + and symbol_source in source_priority + and source_priority.index(global_source) > source_priority.index(symbol_source) + ): + global_symbols[name] = symbol + elif ( + global_source in source_priority + and symbol_source in source_priority + and source_priority.index(global_source) < source_priority.index(symbol_source) + ): + continue + else: + # Two symbols have similar names but different declarations or definitions and no source priority! + raise NotImplementedError( + f"Unable to handle symbol {name} with multiple different definitions and unknown source priority!\nSymbol found in {global_source} and {symbol_source}." 
+ ) + return global_symbols + + +def merge_complete_graphs( + asts: list[TreeResult], valid_names: Container[str] +) -> dict[str, list[str]]: + graph: dict[str, list[str]] = {} + for ast in asts: + # FIXME: Would be nice if complete_graph was a dict[str, list[str]] instead of dict[str, list[Symbol]] + for node, neighbor_symbols in ast.complete_graph.items(): + neighbors = [sym.name for sym in neighbor_symbols] + if node not in valid_names: + continue + if node not in graph: + graph[node] = [] + for neighbor in neighbors: + if neighbor not in valid_names or neighbor in graph[node]: + continue + graph[node].append(neighbor) + return dict(graph) + + +def reachable_subgraph( + dependencies: dict[str, list[str]], names: Iterable[str] +) -> dict[str, list[str]]: + subgraph: dict[str, list[str]] = {} + for name in names: + subgraph[name] = dependencies[name] + subgraph.update(reachable_subgraph(dependencies, subgraph[name])) + return subgraph + + +def c14n_symbol_name(name: str, symbols: dict[str, Symbol]): + if name in symbols: + return name + if f"c:@F@{name}" in symbols: + return f"c:@F@{name}" + + # Find symbols with spelling of name + potential_names = {s.name for s in symbols.values() if name in s.cursor.spelling} + if len(potential_names) == 0: + symbol_names = "\n".join(symbols.keys()) + raise ValueError(f"Unable to find {name} in symbols:\n{symbol_names}") + elif len(potential_names) != 1: + raise ValueError(f"Unable to find {name} in symbols! 
Found: {potential_names}") + return potential_names.pop() + + +@hydra.main(version_base=None, config_name="init") +def main(cfg: InitConfig) -> None: + output_dir = Path(HydraConfig.get().runtime.output_dir) + crate = get_info_from_cargo_toml(output_dir / "Cargo.toml") + + export_symbols = None + if isinstance(cfg.export_symbols, Path): + export_symbols = cfg.export_symbols.read_text().splitlines() + + source_priority = None + if isinstance(cfg.source_priority, Path): + source_priority = [Path(path) for path in cfg.source_priority.read_text().splitlines()] + + init( + cfg.filename, + crate, + export_symbols=export_symbols, + source_priority=source_priority, + pretty_print=cfg.pretty_print, + ) + logger.info(f"Prepared translation in {output_dir}") + + +if __name__ == "__main__": + main() diff --git a/src/ideas/model.py b/src/ideas/model.py index b36e9c1..ba99bee 100644 --- a/src/ideas/model.py +++ b/src/ideas/model.py @@ -34,7 +34,7 @@ class GenerateConfig: cs.store(name="generate", node=GenerateConfig) -def configure(model: ModelConfig, generate: GenerateConfig): +def get_lm(model: ModelConfig, generate: GenerateConfig) -> dspy.LM: lm = dspy.LM( model=model.name, cache=model.cache, @@ -58,4 +58,9 @@ def configure(model: ModelConfig, generate: GenerateConfig): lm.kwargs["provider"] = provider # type: ignore[reportArgumentType] + return lm + + +def configure(model: ModelConfig, generate: GenerateConfig): + lm = get_lm(model, generate) dspy.configure(lm=lm) diff --git a/src/ideas/repair.py b/src/ideas/repair.py index 7c17e3b..57979a7 100644 --- a/src/ideas/repair.py +++ b/src/ideas/repair.py @@ -6,7 +6,6 @@ import io import sys -import json import logging import subprocess from pathlib import Path @@ -19,8 +18,10 @@ from hydra.core.config_store import ConfigStore from hydra.core.hydra_config import HydraConfig - from ideas import model, ModelConfig, GenerateConfig +from ideas import get_info_from_cargo_toml +from .ast_rust import ensure_no_mangle_in_module +from 
.tools import Crate logger = logging.getLogger("ideas.repair") OmegaConf.register_new_resolver( @@ -46,6 +47,8 @@ class RepairConfig: cargo_toml: Path = MISSING max_iters: int = 100 + ensure_no_mangle: bool = True + cs = ConfigStore.instance() cs.store(name="repair", node=RepairConfig) @@ -57,7 +60,7 @@ def repair(cfg: RepairConfig) -> None: model.configure(cfg.model, cfg.generate) - agent = RepairAgent(max_iters=cfg.max_iters) + agent = RepairAgent(max_iters=cfg.max_iters, ensure_no_mangle=cfg.ensure_no_mangle) reparation = agent(cfg.cargo_toml) if "history" in reparation: @@ -77,40 +80,24 @@ class Reparation(dspy.Signature): class RepairAgent(dspy.Module): - def __init__(self, max_iters: int = 1): + def __init__(self, max_iters: int = 1, ensure_no_mangle: bool = True): super().__init__() self.max_iters = max_iters self.repair = dspy.ChainOfThought(Reparation) + self.ensure_no_mangle = ensure_no_mangle def forward(self, cargo_toml: Path) -> dict[str, str]: if not cargo_toml.exists(): raise ValueError(f"{cargo_toml=} must exist!") - # Extract root package metadata from Cargo.toml - out = subprocess.run( - ["cargo", "metadata", "--manifest-path", cargo_toml], text=True, capture_output=True - ) - if out.returncode != 0: - raise ValueError(f"Failed to get cargo metadata from {cargo_toml}!\n{out.stderr}") - metadata = json.loads(out.stdout) - root = metadata["resolve"]["root"] - root_package = next(filter(lambda p: p["id"] == root, metadata["packages"])) - - # Get rust source path for bin or lib - bin_targets = list(filter(lambda t: "bin" in t["kind"], root_package["targets"])) - lib_targets = list(filter(lambda t: "lib" in t["kind"], root_package["targets"])) - if len(bin_targets) == 1 and len(lib_targets) == 0: - rust_src_path = Path(bin_targets[0]["src_path"]) - elif len(bin_targets) == 0 and len(lib_targets) == 1: - rust_src_path = Path(lib_targets[0]["src_path"]) - else: - raise ValueError( - f"Unhandled bin/lib targets configuration in Cargo.toml: {bin_targets=} 
{lib_targets=}" - ) + # Get target source path + crate: Crate = get_info_from_cargo_toml(cargo_toml) # Get test source path - test_targets = list(filter(lambda t: "test" in t["kind"], root_package["targets"])) + test_targets = list( + filter(lambda t: "test" in t["kind"], crate.root_package["targets"]) + ) if len(test_targets) != 1: raise ValueError( f"Unhandled test targets configuration in Cargo.toml: {test_targets=}" @@ -134,13 +121,17 @@ def forward(self, cargo_toml: Path) -> dict[str, str]: break reparation: dspy.Prediction = self.repair( - input_code=rust_src_path.read_text(), + input_code=crate.rust_src_path.read_text(), test_code=test_src_path.read_text(), cargo_test_output=out.stdout, ) repaired_code = reparation["repaired_code"].code - rust_src_path.write_text(repaired_code) + # Guarantee #[unsafe(no_mangle)] for all top-level symbols + if self.ensure_no_mangle: + repaired_code = ensure_no_mangle_in_module(repaired_code, add=True) + + crate.rust_src_path.write_text(repaired_code) # Save agent history to string history = io.StringIO() diff --git a/src/ideas/tools.py b/src/ideas/tools.py index 26da06b..a7d13aa 100644 --- a/src/ideas/tools.py +++ b/src/ideas/tools.py @@ -4,7 +4,10 @@ # SPDX-License-Identifier: Apache-2.0 # +import json from json import loads as js_loads +from dataclasses import dataclass + import logging import subprocess from typing import Any @@ -19,6 +22,14 @@ DEFAULT_TEST_TIMEOUT = 10.0 # seconds +@dataclass +class Crate: + cargo_toml: Path + rust_src_path: Path + root_package: dict[str, Any] + is_bin: bool + + def run_subprocess( cmd: list[str], input: str | None = None, @@ -116,6 +127,41 @@ def check_rust( return run_subprocess(cmd, input=code) +def get_info_from_cargo_toml(cargo_toml: Path) -> Crate: + # Extract root package metadata from Cargo.toml + out = subprocess.run( + ["cargo", "metadata", "--manifest-path", cargo_toml], text=True, capture_output=True + ) + if out.returncode != 0: + raise ValueError(f"Failed to get cargo 
metadata from {cargo_toml}!\n{out.stderr}") + metadata = json.loads(out.stdout) + root = metadata["resolve"]["root"] + if root is None: + raise ValueError("No root package specified!") + root_package = next(filter(lambda p: p["id"] == root, metadata["packages"])) + + # Get rust source path for bin or lib + bin_targets = list(filter(lambda t: "bin" in t["kind"], root_package["targets"])) + lib_targets = list(filter(lambda t: "lib" in t["kind"], root_package["targets"])) + if len(bin_targets) == 1 and len(lib_targets) == 0: + rust_src_path = Path(bin_targets[0]["src_path"]) + is_bin = True + elif len(bin_targets) == 0 and len(lib_targets) == 1: + rust_src_path = Path(lib_targets[0]["src_path"]) + is_bin = False + else: + raise ValueError( + f"Unhandled bin/lib targets configuration in Cargo.toml: {bin_targets=} {lib_targets=}" + ) + + return Crate( + cargo_toml=cargo_toml, + rust_src_path=rust_src_path, + root_package=root_package, + is_bin=is_bin, + ) + + def run_clippy( source_file: str, flags: list[list[str]] | None = None, structured_output: bool = False ) -> list[tuple[bool, str]]: diff --git a/src/ideas/translate.py b/src/ideas/translate.py index c6fd85b..9767244 100644 --- a/src/ideas/translate.py +++ b/src/ideas/translate.py @@ -4,7 +4,6 @@ # SPDX-License-Identifier: Apache-2.0 # -import os import logging from pathlib import Path from dataclasses import dataclass, field @@ -19,10 +18,12 @@ from ideas import model, ModelConfig, GenerateConfig from ideas import TranslateAgent, tools +from .ast_rust import ensure_no_mangle_in_module @dataclass class TranslateConfig: + c_project_dir: Path = MISSING filename: Path = MISSING model: ModelConfig = field(default_factory=ModelConfig) generate: GenerateConfig = field(default_factory=GenerateConfig) @@ -32,9 +33,12 @@ class TranslateConfig: translator: str = "CoVeR" use_raw_fixer_output: bool = True + patch_no_mangle: bool = True batched: bool = False + ensure_no_mangle: bool = True + cs = ConfigStore.instance() 
cs.store(name="translate", node=TranslateConfig) @@ -81,20 +85,23 @@ def translate(cfg: TranslateConfig) -> None: "filename must be a pre-processed C source file with .i extension or compile_commands.json" ) - # Find common path amongst input paths and construct examples to pass to agent - common_dir = Path(os.path.commonpath(c_paths)) - if not common_dir.is_dir(): - common_dir = common_dir.parent + # Construct examples to pass to agent + if not cfg.c_project_dir.is_dir(): + raise ValueError(f"Common directory {cfg.c_project_dir} is not a directory!") + path_prefixes = [ + "" if c_path.relative_to(cfg.c_project_dir).is_relative_to("src") else "src" + for c_path in c_paths + ] examples = [ dspy.Example( - input_code_path=c_path.relative_to(common_dir), + input_code_path=path_prefix / c_path.relative_to(cfg.c_project_dir), input_code=c_code, - full_code_path=c_i_path.relative_to(common_dir), + full_code_path=path_prefix / c_i_path.relative_to(cfg.c_project_dir), full_code=c_i_code, tu=tu, ).with_inputs("input_code_path", "input_code", "full_code_path", "full_code", "tu") - for c_path, c_code, c_i_path, c_i_code, tu in zip( - c_paths, c_codes, c_i_paths, c_i_codes, tus + for c_path, c_code, c_i_path, c_i_code, tu, path_prefix in zip( + c_paths, c_codes, c_i_paths, c_i_codes, tus, path_prefixes ) ] @@ -104,8 +111,9 @@ def translate(cfg: TranslateConfig) -> None: cfg.translator, cfg.max_iters, cfg.use_raw_fixer_output, + cfg.patch_no_mangle, ) - with chdir(common_dir): + with chdir(cfg.c_project_dir): if cfg.batched: translations = agent.batch( examples, @@ -124,16 +132,21 @@ def translate(cfg: TranslateConfig) -> None: filename = translation["input_code_path"] logger.info(f"Translated {filename} ...") + output_code = translation["output_code"] + # Guarantee #[unsafe(no_mangle)] for all top-level symbols + if cfg.ensure_no_mangle: + output_code = ensure_no_mangle_in_module(output_code, add=True) + # Write rust translation to disk - rs_translation_path = output_dir / 
"src" / filename.with_suffix(".rs") + rs_translation_path = output_dir / filename.with_suffix(".rs") rs_translation_path.parent.mkdir(parents=True, exist_ok=True) - rs_translation_path.write_text(translation["output_code"]) + rs_translation_path.write_text(output_code) - prompt_path = output_dir / "src" / filename.with_suffix(".translate_prompt") + prompt_path = output_dir / filename.with_suffix(".translate_prompt") prompt_path.parent.mkdir(parents=True, exist_ok=True) prompt_path.write_text(translation["input_code"]) - history_path = output_dir / "src" / filename.with_suffix(".translate_history") + history_path = output_dir / filename.with_suffix(".translate_history") history_path.parent.mkdir(parents=True, exist_ok=True) history_path.write_text(translation["history"]) diff --git a/src/ideas/translate_recurrent.py b/src/ideas/translate_recurrent.py new file mode 100644 index 0000000..b401795 --- /dev/null +++ b/src/ideas/translate_recurrent.py @@ -0,0 +1,325 @@ +# +# Copyright (C) 2025 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 +# + +import io +import json +import logging +from pathlib import Path +from difflib import unified_diff +from graphlib import TopologicalSorter +from contextlib import redirect_stdout +from dataclasses import dataclass, field +from collections import OrderedDict, defaultdict, deque + +import dspy +import hydra +from omegaconf import MISSING +from clang.cindex import TranslationUnit +from hydra.core.config_store import ConfigStore +from hydra.core.hydra_config import HydraConfig + +from ideas import model, ModelConfig, GenerateConfig, tools +from ideas import get_info_from_cargo_toml, extract_info_c +from .tools import Crate +from .init import get_symbols_and_dependencies +from .ast import get_cursor_code + +logger = logging.getLogger("ideas.translate") + + +@dataclass +class TranslateConfig: + filename: Path = MISSING + model: ModelConfig = field(default_factory=ModelConfig) + generate: GenerateConfig = 
field(default_factory=GenerateConfig) + + translator: str = "ChainOfThought" + max_iters: int = 5 + + +cs = ConfigStore.instance() +cs.store(name="translate", node=TranslateConfig) + + +class Translator(dspy.Module): + def __init__( + self, + translator: type[dspy.Module], + max_iters: int = 5, + ): + super().__init__() + self.translator = translator + self.max_iters = max_iters + + def forward(self, filename: Path, crate: Crate) -> dspy.Prediction: + # Get global symbol table + tu = TranslationUnit.from_source(filename) + asts = [extract_info_c(tu)] + symbols, dependencies = get_symbols_and_dependencies(asts) + references = transpose_graph(dependencies) + + # Assemble C sources in topological order + sources: dict[str, str] = OrderedDict() + for symbol_name in TopologicalSorter(dependencies).static_order(): + # Ignore tag definitions and function declarations + if symbol_name not in symbols: + logger.warning(f"Skipping `{symbol_name}` ...") + continue + sources[symbol_name] = get_cursor_code(symbols[symbol_name].cursor) + + # Translate symbol by symbol + translations: dict[str, str] = OrderedDict() + for symbol_name, symbol_code in sources.items(): + ref_names = [ + name for name in bfs(symbol_name, dependencies) if name in translations + ] + dep_names = [ + name for name in bfs(symbol_name, references, max_depth=1) if name in sources + ] + # move dependent names that has a translation to reference names + for dep_name in dep_names: + assert dep_name not in translations, ( + f"Dependency {dep_name} should not already be translated" + ) + for ref_name in bfs(dep_name, dependencies): + if ref_name in translations and ref_name not in ref_names: + ref_names.append(ref_name) + + # Gather reference and dependent code in order of translations and sources, respectively + ref_translations = "\n\n".join( + [translation for name, translation in translations.items() if name in ref_names] + ) + dep_sources = "\n\n".join( + [source for name, source in sources.items() if name in 
dep_names] + ) + + logger.info(f"Translating `{symbol_name}` ...") + logger.debug(f"```c\n{symbol_code}\n```") + + pred = self.translate_with_feedback( + ref_translations, + symbol_code, + dep_sources, + crate, + max_iters=self.max_iters, + ) + # pred = dspy.Prediction(translation=dspy.Code(code="")) + + # Logging + logger.info(f"Translated `{symbol_name}` ...") + logger.debug(f"```rust\n{pred.translation.code}\n```") + + # Update state + translations[symbol_name] = pred.translation.code + with crate.rust_src_path.with_suffix(".jsonl").open("a") as f: + f.write( + json.dumps( + { + "name": symbol_name, + "source": symbol_code, + "translation": pred.translation.code, + } + ) + + "\n" + ) + + translation = "\n\n".join(translations.values()) + return dspy.Prediction(translation=translation) + + class TranslateSignature(dspy.Signature): + """ + Generate an idiomatic, memory-safe Rust translation of the snippet. + The reference_code contains Rust code that should be used by the translation. + The snippet contains a single C definition to translate to idiomatic, memory-safe Rust. + The dependent_code contains C code that uses the C snippet. + Reason about the dependent_code to understand any special memory management or complex ownership requirements a safe and idiomatic translation may need to take into account. + Ensure the translation of the snippet does not use any unsafe constructs! + Do not refactor the reference_code in the translation! + Do not translate dependent_code to Rust in the translation! + Do not define any implementations (`impl`) in the translation! + Always assume all C integer arithmetic operations on the underlying value are intended to have wrapping semantics, and thus any translation should use Rust's wrapping arithmetic functions like `wrapping_add`, `wrapping_shr`, etc.. + Analyze all bitwise operations carefully, especially rotations. 
+ For all bitwise operations, including those that may appear to swap bits for bytes, implement the behavior exactly as written in the C code, without making assumptions about intent. + Use the `cargo build` feedback about the prior_translation, if provided, when generating the Rust translation. + """ + + # For example, reason about how a Rust translation of the dependent_code would inform a safe and idiomatic translation of the C snippet. + + reference_code: dspy.Code["Rust"] = dspy.InputField() # noqa: F821 + snippet: dspy.Code["C"] = dspy.InputField() # noqa: F821 + dependent_code: dspy.Code["C"] = dspy.InputField() # noqa: F821 + prior_translation: dspy.Code["Rust"] = dspy.InputField() # noqa: F821 + feedback: str = dspy.InputField() + translation: dspy.Code["Rust"] = dspy.OutputField() # noqa: F821 + + def translate( + self, + reference_code: str, + snippet: str, + dependent_code: str, + *, + prior_translation: str = "", + feedback: str = "", + ) -> dspy.Prediction: + translate = self.translator(Translator.TranslateSignature) + + pred = translate( + reference_code=reference_code, + snippet=snippet, + dependent_code=dependent_code, + prior_translation=prior_translation, + feedback=feedback, + ) + # FIXME: Add rustfmt and FeedbackException? 
+ return pred + + # FIXME: convert to using symbol + def translate_with_feedback( + self, + reference_code: str, + snippet: str, + dependent_code: str, + crate: Crate, + *, + max_iters: int = 0, + ) -> dspy.Prediction: + pred = self.translate(reference_code, snippet, dependent_code) + i = 0 + for i in range(max_iters): + rust_src = "" + if len(reference_code) > 0: + rust_src += reference_code + "\n\n" + rust_src += pred.translation.code + "\n\n" + if crate.is_bin and "fn main()" not in rust_src: + # Work around E0601 error + rust_src += 'fn main() {\n println!("Hello, world!");\n}\n' + + crate.rust_src_path.write_text(rust_src) + success, feedback = tools.run_subprocess( + [ + "cargo", + "build", + "--quiet", + "--color=never", + f"--manifest-path={crate.cargo_toml}", + ] + ) + if success: + break + logger.debug( + f"Feedback\n```rust\n{reference_code}\n{pred.translation.code}\n```\n\n# Feedback\n{feedback}\n\n# reasoning\n{pred.reasoning}" + ) + + pred = self.translate( + reference_code, + snippet, + dependent_code, + prior_translation=pred.translation, + feedback=feedback, + ) + else: + logger.warning( + f"Translation failed to build after {max_iters} feedback iterations!" + ) + pred["iters"] = i + return pred + + def get_history(self, n: int = 1, clear: bool = False) -> str: + f = io.StringIO() + with redirect_stdout(f): + self.inspect_history(n=n, clear=clear) + return f.getvalue().strip() + + def inspect_history(self, n: int = 1, clear: bool = True): + super().inspect_history(n) + if clear: + self.history = [] + + +def diff(old: str, new: str): + diff = "\n".join( + unified_diff( + old.splitlines(), new.splitlines(), lineterm="", fromfile="old", tofile="new" + ) + ) + return diff + + +def rustc(translation: dspy.Code["Rust"] | str) -> str: # noqa: F821 + "Compiles the translation using rustc and returns any errors." 
+ if isinstance(translation, dspy.Code): + code = translation.code + else: + code = translation + success, output = tools.run_subprocess( + ["rustc", "-A", "warnings", "--crate-type", "lib", "--edition", "2024", "-"], + input=code, + ) + return "" if success else output + + +def rustfmt(translation: dspy.Code["Rust"] | str) -> str: # noqa: F821 + "Formats the Rust code using rustfmt." + if isinstance(translation, dspy.Code): + code = translation.code + else: + code = translation + _, formatted_code = tools.run_subprocess( + ["rustfmt", "--edition", "2024", "--color", "never"], input=code + ) + return formatted_code + + +def transpose_graph(graph: dict[str, list[str]]) -> dict[str, list[str]]: + transpose: dict[str, list[str]] = defaultdict(list) + for node, neighbors in graph.items(): + for neighbor in neighbors: + transpose[neighbor].append(node) + return dict(transpose) + + +def bfs(node: str, graph: dict[str, list[str]], max_depth: int = -1) -> list[str]: + nodes = [node] + queue = deque() + queue.append((node, 0)) + while queue: + curr_node, level = queue.popleft() + for neighbor in graph.get(curr_node, []): + # ignore visited or too deep nodes + if neighbor in nodes or (max_depth >= 0 and level + 1 > max_depth): + continue + nodes.append(neighbor) + queue.append((neighbor, level + 1)) + # ignore initial node + return nodes[1:] + + +@hydra.main(version_base=None, config_name="translate") +def main(cfg: TranslateConfig) -> None: + output_dir = Path(HydraConfig.get().runtime.output_dir) + logger.info(f"Saving results to {output_dir}") + crate = get_info_from_cargo_toml(output_dir / "Cargo.toml") + + model.configure(cfg.model, cfg.generate) + if cfg.translator == "ChainOfThought": + translator = dspy.ChainOfThought + elif cfg.translator == "Predict": + translator = dspy.Predict + else: + raise ValueError(f"Unknown translator: {cfg.translator}!") + + agent = Translator(translator, cfg.max_iters) + pred = agent(cfg.filename, crate) + translation = pred.translation + 
# Write translation and history to disk + crate.rust_src_path.parent.mkdir(exist_ok=True, parents=True) + crate.rust_src_path.write_text(translation) + crate.rust_src_path.with_suffix(".history").write_text(agent.get_history(n=100000)) + logger.info(f"Saved translation to {crate.rust_src_path}") + + +if __name__ == "__main__": + main() diff --git a/src/ideas/treesitter.py b/src/ideas/treesitter.py deleted file mode 100644 index 735ba7d..0000000 --- a/src/ideas/treesitter.py +++ /dev/null @@ -1,56 +0,0 @@ -# -# Copyright (C) 2025 Intel Corporation -# -# SPDX-License-Identifier: Apache-2.0 -# - -from typing import Callable - -import tree_sitter_rust as ts_rs -from tree_sitter import Language, Parser, Node - -from ideas import TreeResult - - -def traverse_rust(code: str, node: Node) -> TreeResult: - result = TreeResult() - - if node.type == "function_item": - def_text = code[node.start_byte : node.end_byte] - brace_pos = def_text.find("{") - if brace_pos != -1: - signature = def_text[:brace_pos].strip() - result.fn_definitions[signature] = result.fn_definitions.get(signature) - - # NOTE: need this for signatures in traits - elif node.type == "function_signature_item": - sig_text = code[node.start_byte : node.end_byte].strip() - if sig_text.endswith(";"): - sig_text = sig_text[:-1].strip() - result.fn_definitions[sig_text] = result.fn_definitions.get(sig_text) - - # NOTE: need this for FFI signatures - elif ( - node.type == "function_signature" and node.parent and node.parent.type == "extern_block" - ): - sig_text = code[node.start_byte : node.end_byte].strip() - if sig_text.endswith(";"): - sig_text = sig_text[:-1].strip() - result.fn_definitions[sig_text] = result.fn_definitions.get(sig_text) - - for child in node.children: - child_result = traverse_rust(code, child) - result.fn_definitions.update(child_result.fn_definitions) - - return result - - -def extract_info(code: str, traverse_fn: Callable, parser: Parser) -> TreeResult: - tree = parser.parse(bytes(code, 
"utf8")) - return traverse_fn(code, tree.root_node) - - -def extract_info_rust(code: str) -> TreeResult: - lang = Language(ts_rs.language()) - parser = Parser(lang) - return extract_info(code, traverse_fn=traverse_rust, parser=parser) diff --git a/src/ideas/utils.py b/src/ideas/utils.py index 9ec9e4c..15fba4b 100644 --- a/src/ideas/utils.py +++ b/src/ideas/utils.py @@ -10,16 +10,29 @@ from dataclasses import dataclass from collections.abc import KeysView -from clang.cindex import CursorKind +from clang.cindex import Cursor, CursorKind +from tree_sitter import Node @dataclass(frozen=True) class Symbol: name: str - kind: CursorKind + cursor: Cursor decl: str = "" usr: str | None = None + @property + def kind(self) -> CursorKind: + return self.cursor.kind + + +@dataclass(frozen=True) +class RustSymbol: + name: str + node: Node + decl: str = "" + attributes: list[str] | None = None + # Modify graph edges to only count called symbols from a set def filter_edges_by_set(inputs: list[Symbol], symbols: KeysView | set[str]) -> list[Symbol]: diff --git a/src/ideas/wrapper.py b/src/ideas/wrapper.py new file mode 100644 index 0000000..d56ebad --- /dev/null +++ b/src/ideas/wrapper.py @@ -0,0 +1,180 @@ +# +# Copyright (C) 2025 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 +# + +import io +import re +import logging +from pathlib import Path +from contextlib import redirect_stdout +from dataclasses import dataclass, field + +import dspy +import hydra +from omegaconf import MISSING +from hydra.core.config_store import ConfigStore +from hydra.core.hydra_config import HydraConfig + +from ideas import model, ModelConfig, GenerateConfig, tools +from ideas.tools import Crate, get_info_from_cargo_toml + +logger = logging.getLogger("ideas.wrapper") + + +@dataclass +class WrapperConfig: + model: ModelConfig = field(default_factory=ModelConfig) + generate: GenerateConfig = field(default_factory=GenerateConfig) + + symbols: Path = MISSING + cargo_toml: Path = MISSING + + 
max_iters: int = 5 + + +cs = ConfigStore.instance() +cs.store(name="wrapper", node=WrapperConfig) + + +class WrapperGenerator(dspy.Module): + class Signature(dspy.Signature): + """ + Implement a C-compatible FFI wrapper for `crate::{symbol_name}` by replacing the `unimplemented!()` macro in `example_wrapper`. + The implementation for `crate::{symbol_name}` is in a crate that was read from "{crate_path}". + Assume the types in `crate::wrapper::` do not have the same memory layout as those in `crate::`. + The wrapper should properly convert between `crate::wrapper::` and `crate::` types by copying the values from the wrapper type to the crate type before calling `crate::{symbol_name}`. + After this conversion, the wrapper should call the Rust function `crate::{symbol_name}`. + After the call to `crate::{symbol_name}`, the wrapper should convert back the `crate::` types to `crate::wrapper::` types. + The wrapper will be written to "{wrapper_path}". + Use the feedback, if provided, from `cargo build` about the `prior_wrapper` when generating the wrapper. 
+ """ + + crate: dspy.Code["Rust"] = dspy.InputField() # noqa: F821 + example_wrapper: dspy.Code["Rust"] = dspy.InputField() # noqa: F821 + wrapper: dspy.Code["Rust"] = dspy.OutputField() # noqa: F821 + prior_wrapper: dspy.Code["Rust"] = dspy.InputField() # noqa: F821 + feedback: str = dspy.InputField() + + def __init__( + self, + max_iters: int, + ) -> None: + super().__init__() + + self.max_iters = max_iters + + def forward(self, symbol_name: str, crate: Crate) -> dspy.Prediction: + crate_path = crate.rust_src_path + wrapper_path = crate.rust_src_path.parent / "wrapper.rs" + symbol_wrapper_path = crate.rust_src_path.parent / "wrapper" / f"{symbol_name}.rs" + example_wrapper = symbol_wrapper_path.read_text() + + signature = WrapperGenerator.Signature.with_instructions( + WrapperGenerator.Signature.instructions.format( + symbol_name=symbol_name, + crate_path="src/lib.rs", + wrapper_path="src/wrapper.rs", + ) + ) + generate_wrapper = dspy.ChainOfThought(signature) + + orig_wrapper = wrapper_path.read_text() + orig_code = crate_path.read_text() + + # Add "pub mod wrapper;" to code + code = orig_code + if not re.search(r"^pub mod wrapper;$", code, flags=re.MULTILINE): + code = f"pub mod wrapper;\n\n{code}" + crate_path.write_text(code) + + i, wrapper, success, feedback, prior_wrapper = 0, "", False, "", example_wrapper + for i in range(self.max_iters): + pred = generate_wrapper( + crate=code, + example_wrapper=example_wrapper, + feedback=feedback, + prior_wrapper=prior_wrapper, + ) + if pred.wrapper is None: + continue + prior_wrapper = pred.wrapper.code + + # Statically convert #[no_mangle] to #[unsafe(export_name="...")] + # TODO: Attempt to use `cargo fix` + wrapper = re.sub( + r"^(\s*)#\[no_mangle\]", + f'\\1#[unsafe(export_name="{symbol_name}")]', + prior_wrapper, + flags=re.MULTILINE, + ) + + # Write wrapper to disk and check if we build + wrapper_path.write_text(wrapper) + success, feedback = tools.run_subprocess( + [ + "cargo", + "build", + "--quiet", + 
"--color=never", + f"--manifest-path={crate.cargo_toml}", + ] + ) + if success: + break + + # Write original code and wrapper and return wrapper + crate_path.write_text(orig_code) + wrapper_path.write_text(orig_wrapper) + return dspy.Prediction(wrapper=wrapper, success=success, iters=i) + + def get_history(self, n: int = 1, clear: bool = False) -> str: + f = io.StringIO() + with redirect_stdout(f): + self.inspect_history(n=n, clear=clear) + return f.getvalue().strip() + + def inspect_history(self, n: int = 1, clear: bool = True): + super().inspect_history(n) + if clear: + self.history = [] + + +@hydra.main(version_base=None, config_name="wrapper") +def main(cfg: WrapperConfig) -> None: + output_dir = Path(HydraConfig.get().runtime.output_dir) + logger.info(f"Saving results to {output_dir}") + + model.configure(cfg.model, cfg.generate) + + # Get crate info and update src to include reference to wrapper + crate = get_info_from_cargo_toml(cfg.cargo_toml) + if not re.search( + r"^pub mod wrapper;$", crate.rust_src_path.read_text(), flags=re.MULTILINE + ): + with crate.rust_src_path.open("a+") as f: + f.write("\n\npub mod wrapper;\n") + wrapper_path = crate.rust_src_path.parent / "wrapper.rs" + wrapper_path.write_text("") + + # Generate wrappers for each symbol + for symbol_name in cfg.symbols.read_text().splitlines(): + logger.info(f"Generating wrapper for `{symbol_name}` ...") + agent = WrapperGenerator(max_iters=cfg.max_iters) + pred = agent(symbol_name, crate) + + # Write wrapper to disk and reference in wrapper.rs + symbol_wrapper = pred.wrapper + symbol_wrapper_path = crate.rust_src_path.parent / "wrapper" / f"{symbol_name}.rs" + symbol_wrapper_path.parent.mkdir(exist_ok=True, parents=True) + symbol_wrapper_path.write_text(symbol_wrapper) + symbol_wrapper_path.with_suffix(".history").write_text(agent.get_history(n=100000)) + + if pred.success: + with wrapper_path.open("a+") as f: + f.write(f"pub mod {symbol_name};\n") + + +if __name__ == "__main__": + main() diff 
--git a/test/fixtures/ast/data_structures/src/main.c b/test/fixtures/ast/data_structures/src/main.c index b24bbe0..24717e2 100644 --- a/test/fixtures/ast/data_structures/src/main.c +++ b/test/fixtures/ast/data_structures/src/main.c @@ -8,6 +8,7 @@ struct Point p1; int num_dimensions = 2; const double half_pi = PI / 2.0; +const double half_pi; static double one_third_pi = PI / 3.0; static const double quarter_pi = PI / 4.0; diff --git a/test/fixtures/ast/data_structures/src/main.c.i b/test/fixtures/ast/data_structures/src/main.c.i index f696308..efa799e 100644 --- a/test/fixtures/ast/data_structures/src/main.c.i +++ b/test/fixtures/ast/data_structures/src/main.c.i @@ -776,6 +776,7 @@ struct Point p1; int num_dimensions = 2; const double half_pi = PI / 2.0; +const double half_pi; static double one_third_pi = PI / 3.0; static const double quarter_pi = PI / 4.0; diff --git a/test/fixtures/ast/signatures.rs b/test/fixtures/ast/signatures.rs deleted file mode 100644 index 2f1504b..0000000 --- a/test/fixtures/ast/signatures.rs +++ /dev/null @@ -1,21 +0,0 @@ -fn add(a: i32, b: i32) -> i32 { - a + b -} - -pub fn multiply(x: f64, y: f64) -> f64 { - x * y -} - -fn identity(value: T) -> T { - value -} - -trait Calculator { - fn calculate(&self, a: i32, b: i32) -> i32; - fn reset(&mut self); -} - -extern "C" { - fn printf(format: *const i8, ...) 
-> c_int; - fn malloc(size: usize) -> *mut u8; -} diff --git a/test/fixtures/ast_rust/functions.rs b/test/fixtures/ast_rust/functions.rs new file mode 100644 index 0000000..47363cb --- /dev/null +++ b/test/fixtures/ast_rust/functions.rs @@ -0,0 +1,90 @@ +#![allow(dead_code, unused_variables)] + +use std::ffi::c_int; + +pub fn simple_public() { + println!("Hello"); +} + +fn private_function() { + println!("Hello"); +} + +pub fn generic_function(value: T) { + println!("Hello"); +} + +pub fn lifetime_function<'a>(s: &'a str) -> &'a str { + println!("Hello"); + s +} + +pub unsafe fn unsafe_function(ptr: *const i32) { + println!("Hello"); +} + +pub const fn const_function(x: i32) -> i32 { + println!("Hello"); + x * 2 +} + +pub async fn async_function() { + println!("Hello"); +} + +pub fn where_clause_function(value: T) where T: Clone + std::fmt::Debug { + println!("Hello"); +} + +#[unsafe(no_mangle)] +pub extern "C" fn ffi_function(x: c_int) -> c_int { + println!("Hello"); + x +} + +#[inline(always)] +#[must_use] +pub fn attributed_function() -> i32 { + println!("Hello"); + 42 +} + +pub fn complex_return() -> Result>, Box> { + println!("Hello"); + Ok(vec![]) +} + +pub fn multi_lifetime<'a, 'b>(x: &'a mut i32, y: &'b str) -> &'a i32 { + println!("Hello"); + x +} + +pub unsafe extern "system" fn system_abi_function(code: i32) { + println!("Hello"); +} + +#[allow(clippy::all)] +pub fn complex_generic(first: T, second: U) -> String +where + T: std::fmt::Display + Clone, + U: std::fmt::Debug + Send + Sync, +{ + println!("Hello"); + format!("{}", first) +} + +// External function declarations (no body) - typically used for FFI +extern "C" { + pub fn external_c_function(x: c_int) -> c_int; + + unsafe fn unsafe_external_function(ptr: *mut u8, len: usize); + + static EXTERNAL_GLOBAL: c_int; + + fn printf(format: *const u8, ...) 
-> c_int; +} + +// Extern block with different ABI +extern "system" { + pub fn system_api_call(code: u32) -> i32; +} diff --git a/test/fixtures/ast_rust/no_mangle.rs b/test/fixtures/ast_rust/no_mangle.rs new file mode 100644 index 0000000..1d864ff --- /dev/null +++ b/test/fixtures/ast_rust/no_mangle.rs @@ -0,0 +1,76 @@ +fn no_attributes() { + println!("No attributes"); +} + +#[no_mangle] +fn needs_unsafe_single() { + println!("Single no_mangle"); +} + +#[no_mangle] #[inline] +fn needs_unsafe_same_line() { + println!("Multiple on same line"); +} + +#[inline] +#[no_mangle] +fn needs_unsafe_different_lines() { + println!("Different lines"); +} + +#[inline] +#[no_mangle] +#[must_use] +fn needs_unsafe_between() { + println!("In between lines"); +} + +#[inline] +#[no_mangle] #[must_use] +fn needs_unsafe_irregular_1_2() { + println!("Irregular pattern"); +} + +#[inline] #[no_mangle] +#[must_use] +fn needs_unsafe_irregular_2_1() { + println!("Irregular pattern"); +} + +#[unsafe(no_mangle)] +fn already_safe() { + println!("Already has unsafe"); +} + +#[inline] #[unsafe(no_mangle)] +fn already_safe_with_others() { + println!("Already safe with others"); +} + +#[inline] #[must_use] +fn other_attributes_only() { + println!("Has others only"); +} + +#[inline] #[no_mangle] #[must_use] +fn needs_unsafe_three_same_line() { + println!("Middle of line"); +} + +#[no_mangle] +pub extern "C" fn extern_c_function() { + println!("Extern C function"); +} + +#[no_mangle] +extern "C" fn extern_c_with_args(arg1: i32, arg2: f64) -> i32 { + println!("Extern C with args: {}, {}", arg1, arg2); + arg1 + arg2 as i32 +} + +pub mod foo { + #[no_mangle] + pub fn namespaced_function() { + println!("Namespaced function"); + } +} diff --git a/test/test_clang.py b/test/test_clang.py index e2200eb..9a2eb41 100644 --- a/test/test_clang.py +++ b/test/test_clang.py @@ -200,6 +200,10 @@ def test_variables(c_data_structures_code: str): assert variables["c:@num_dimensions"] == "int num_dimensions" assert 
variables["c:@anonymous_struct"] == "struct { int a; int b; } anonymous_struct" assert variables["c:@half_pi"] == "const double half_pi" + assert ( + ast.get_cursor_code(result.symbols["c:@half_pi"].cursor) + == "const double half_pi = PI / 2.0;" + ) assert variables["c:file.c@one_third_pi"] == "static double one_third_pi" assert variables["c:file.c@quarter_pi"] == "static const double quarter_pi" assert variables["c:@PI"] == "extern const double PI" diff --git a/test/test_convert_json_to_rust.py b/test/test_convert_json_to_rust.py index 028f44e..8bda659 100644 --- a/test/test_convert_json_to_rust.py +++ b/test/test_convert_json_to_rust.py @@ -42,7 +42,7 @@ def test_convert_to_cargo_test( # Temporarily redirect stdout with contextlib.redirect_stdout(captured_output): - convert_tests.convert_tests(json_test_cases) + convert_tests.convert_tests_for_exec(json_test_cases) # Write the captured Rust code to a fresh tests/test_cases.rs original_harness = rust_tests_harness.read_text() diff --git a/test/test_no_mangle.py b/test/test_no_mangle.py new file mode 100644 index 0000000..132e67c --- /dev/null +++ b/test/test_no_mangle.py @@ -0,0 +1,108 @@ +# +# Copyright (C) 2025 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 +# + +import re +import pytest + +from pathlib import Path + +from ideas.ast_rust import extract_info_rust +from ideas.ast_rust import ensure_no_mangle_in_module + + +@pytest.fixture +def fixtures_dir() -> Path: + return Path(__file__).parent / "fixtures" / "ast_rust" + + +@pytest.fixture +def no_mangle_code(fixtures_dir: Path) -> str: + return (fixtures_dir / "no_mangle.rs").read_text() + + +@pytest.fixture +def test_rust_file(tmp_path: Path, no_mangle_code: str) -> Path: + """Create a temporary copy of the no_mangle.rs file for testing.""" + test_file = tmp_path / "no_mangle_test.rs" + test_file.write_text(no_mangle_code, encoding="utf-8") + return test_file + + +def extract_function_with_attributes(source_code: str, function_name: str) -> 
str: + # Match attributes and function signature for the given function + pattern = rf"(?:^#\[.*\].*\n)*^fn {re.escape(function_name)}\(\)" + match = re.search(pattern, source_code, re.MULTILINE) + if not match: + raise ValueError(f"Function {function_name} not found") + return match.group(0) + + +def test_add_no_mangle_top(test_rust_file: Path): + original_code = test_rust_file.read_text(encoding="utf-8") + # Add the no_mangle attribute at the top-level and inside the module + modified_code = ensure_no_mangle_in_module(original_code, module_name=None, add=True) + modified_code = ensure_no_mangle_in_module(modified_code, module_name="foo", add=True) + info = extract_info_rust(modified_code) + + for function_name in [ + "no_attributes", + "needs_unsafe_single", + "needs_unsafe_same_line", + "needs_unsafe_different_lines", + "needs_unsafe_between", + "needs_unsafe_irregular_1_2", + "needs_unsafe_irregular_2_1", + "already_safe", + "already_safe_with_others", + "other_attributes_only", + "needs_unsafe_three_same_line", + "extern_c_function", + "extern_c_with_args", + "namespaced_function", + ]: + attributes = info.symbols[function_name].attributes + assert function_name in info.symbols + assert attributes is not None + assert "#[unsafe(no_mangle)]" in attributes + assert "#[no_mangle]" not in attributes + + +def test_patch_no_mangle(no_mangle_code: str): + # Apply the patching function at the top-level + modified_code = ensure_no_mangle_in_module(no_mangle_code, module_name=None, add=False) + info = extract_info_rust(modified_code) + + # Check each expected function + for function_name in [ + "needs_unsafe_single", + "needs_unsafe_same_line", + "needs_unsafe_different_lines", + "needs_unsafe_between", + "needs_unsafe_irregular_1_2", + "needs_unsafe_irregular_2_1", + "already_safe", + "already_safe_with_others", + "needs_unsafe_three_same_line", + "extern_c_function", + "extern_c_with_args", + ]: + attributes = info.symbols[function_name].attributes + assert 
function_name in info.symbols + assert attributes is not None + assert "#[unsafe(no_mangle)]" in attributes + assert "#[no_mangle]" not in attributes + + # Apply it inside the module + modified_code = ensure_no_mangle_in_module(no_mangle_code, module_name="foo", add=False) + info = extract_info_rust(modified_code) + + # Check the namespaced function + function_name = "namespaced_function" + attributes = info.symbols[function_name].attributes + assert function_name in info.symbols + assert attributes is not None + assert "#[unsafe(no_mangle)]" in attributes + assert "#[no_mangle]" not in attributes diff --git a/test/test_rust_ast.py b/test/test_rust_ast.py new file mode 100644 index 0000000..dbf2b51 --- /dev/null +++ b/test/test_rust_ast.py @@ -0,0 +1,229 @@ +# +# Copyright (C) 2025 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 +# + + +import pytest +from pathlib import Path + +from tree_sitter import Language, Parser +import tree_sitter_rust as tsrust + +from ideas.ast_rust import get_function_names_in_module, extract_info_rust + + +@pytest.fixture +def fixtures_dir() -> Path: + return Path(__file__).parent / "fixtures" / "ast_rust" + + +@pytest.fixture +def functions(fixtures_dir: Path) -> str: + return (fixtures_dir / "functions.rs").read_text() + + +def test_extract_functions_rust(functions: str): + # Parse the code + parser = Parser(Language(tsrust.language())) + source_bytes = functions.encode("utf-8") + tree = parser.parse(source_bytes) + root_node = tree.root_node + function_names = get_function_names_in_module( + root_node=root_node, source_code=functions, module_name=None + ) + + # Check that all expected functions are extracted + expected_functions = [ + "simple_public", + "private_function", + "unsafe_function", + "const_function", + "async_function", + "generic_function", + "lifetime_function", + "where_clause_function", + "ffi_function", + "attributed_function", + "complex_return", + "multi_lifetime", + "system_abi_function", + 
"complex_generic", + # FFI signature declarations + "external_c_function", + "unsafe_external_function", + "printf", + "system_api_call", + ] + + for func in expected_functions: + assert func in function_names + + +def test_extract_info_rust(functions: str): + result = extract_info_rust(functions) + + assert "simple_public" in result.symbols + assert ( + result.fn_definitions["simple_public"] + == """ +pub fn simple_public() { + println!("Hello"); +}""".strip() + ) + + assert "private_function" in result.symbols + assert ( + result.fn_definitions["private_function"] + == """ +fn private_function() { + println!("Hello"); +}""".strip() + ) + + assert "unsafe_function" in result.symbols + assert ( + result.fn_definitions["unsafe_function"] + == """ +pub unsafe fn unsafe_function(ptr: *const i32) { + println!("Hello"); +}""".strip() + ) + + assert "const_function" in result.symbols + assert ( + result.fn_definitions["const_function"] + == """ +pub const fn const_function(x: i32) -> i32 { + println!("Hello"); + x * 2 +}""".strip() + ) + + assert "async_function" in result.symbols + assert ( + result.fn_definitions["async_function"] + == """ +pub async fn async_function() { + println!("Hello"); +}""".strip() + ) + + assert "generic_function" in result.symbols + assert ( + result.fn_definitions["generic_function"] + == """ +pub fn generic_function(value: T) { + println!("Hello"); +}""".strip() + ) + + assert "lifetime_function" in result.symbols + assert ( + result.fn_definitions["lifetime_function"] + == """ +pub fn lifetime_function<'a>(s: &'a str) -> &'a str { + println!("Hello"); + s +}""".strip() + ) + + assert "where_clause_function" in result.symbols + assert ( + result.fn_definitions["where_clause_function"] + == """ +pub fn where_clause_function(value: T) where T: Clone + std::fmt::Debug { + println!("Hello"); +}""".strip() + ) + + assert "ffi_function" in result.symbols + assert ( + result.fn_definitions["ffi_function"] + == """ +pub extern "C" fn 
ffi_function(x: c_int) -> c_int { + println!("Hello"); + x +}""".strip() + ) + + assert "attributed_function" in result.symbols + assert ( + result.fn_definitions["attributed_function"] + == """ +pub fn attributed_function() -> i32 { + println!("Hello"); + 42 +}""".strip() + ) + + assert "complex_return" in result.symbols + assert ( + result.fn_definitions["complex_return"] + == """ +pub fn complex_return() -> Result>, Box> { + println!("Hello"); + Ok(vec![]) +}""".strip() + ) + + assert "multi_lifetime" in result.symbols + assert ( + result.fn_definitions["multi_lifetime"] + == """ +pub fn multi_lifetime<'a, 'b>(x: &'a mut i32, y: &'b str) -> &'a i32 { + println!("Hello"); + x +}""".strip() + ) + + assert "system_abi_function" in result.symbols + assert ( + result.fn_definitions["system_abi_function"] + == """ +pub unsafe extern "system" fn system_abi_function(code: i32) { + println!("Hello"); +}""".strip() + ) + + assert "complex_generic" in result.symbols + assert ( + result.fn_definitions["complex_generic"] + == """ +pub fn complex_generic(first: T, second: U) -> String +where + T: std::fmt::Display + Clone, + U: std::fmt::Debug + Send + Sync, +{ + println!("Hello"); + format!("{}", first) +}""".strip() + ) + + # Check FFI function signatures + assert "external_c_function" in result.symbols + assert ( + result.fn_definitions["external_c_function"] + == 'extern "C" { pub fn external_c_function(x: c_int) -> c_int; }' + ) + + assert "unsafe_external_function" in result.symbols + assert ( + result.fn_definitions["unsafe_external_function"] + == 'extern "C" { unsafe fn unsafe_external_function(ptr: *mut u8, len: usize); }' + ) + + assert "printf" in result.symbols + assert ( + result.fn_definitions["printf"] + == 'extern "C" { fn printf(format: *const u8, ...) 
-> c_int; }' + ) + assert "system_api_call" in result.symbols + assert ( + result.fn_definitions["system_api_call"] + == 'extern "system" { pub fn system_api_call(code: u32) -> i32; }' + ) + + # Check FFI variables for not being captured + assert "EXTERNAL_GLOBAL" not in result.symbols diff --git a/test/test_treesitter.py b/test/test_treesitter.py deleted file mode 100644 index 9ab119c..0000000 --- a/test/test_treesitter.py +++ /dev/null @@ -1,37 +0,0 @@ -# -# Copyright (C) 2025 Intel Corporation -# -# SPDX-License-Identifier: Apache-2.0 -# - -import pytest -from pathlib import Path - -from ideas import treesitter - - -@pytest.fixture -def fixtures_dir() -> Path: - return Path(__file__).parent / "fixtures" / "ast" - - -@pytest.fixture -def rust_code(fixtures_dir: Path) -> str: - return (fixtures_dir / "signatures.rs").read_text() - - -def test_info_rust(rust_code: str): - expected = [ - "fn add(a: i32, b: i32) -> i32", - "pub fn multiply(x: f64, y: f64) -> f64", - "fn identity(value: T) -> T", - "fn calculate(&self, a: i32, b: i32) -> i32", - "fn reset(&mut self)", - "fn printf(format: *const i8, ...) 
-> c_int", - "fn malloc(size: usize) -> *mut u8", - ] - - result = treesitter.extract_info_rust(rust_code) - signatures = result.fn_definitions.keys() - for sig, exp in zip(signatures, expected): - assert sig == exp diff --git a/uv.lock b/uv.lock index d29507f..b57dc70 100644 --- a/uv.lock +++ b/uv.lock @@ -423,7 +423,7 @@ wheels = [ [[package]] name = "dspy" -version = "3.0.3" +version = "3.0.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -441,6 +441,7 @@ dependencies = [ { name = "openai" }, { name = "optuna" }, { name = "orjson" }, + { name = "pillow" }, { name = "pydantic" }, { name = "regex" }, { name = "requests" }, @@ -449,9 +450,9 @@ dependencies = [ { name = "tqdm" }, { name = "xxhash" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/3b/19/49fd72c0b4f905ba7b6eee306efa8d3350098e1b3392f7592147ee7dc092/dspy-3.0.3.tar.gz", hash = "sha256:4f77c9571a0f5071495b81acedd44ded1dacd4cdcb4e9fe942da144274f7fbf8", size = 215658, upload-time = "2025-08-31T18:49:31.337Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/18/0042d299cd5e85fdb381568f0cfcc7769122e8f70ea0a2d33e12fd63e705/dspy-3.0.4.tar.gz", hash = "sha256:cb4529df9a91353a16144d9d94ba6ff25f36fc5adfd921f127f4c49d0e309fb8", size = 236376, upload-time = "2025-11-10T17:43:37.619Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e3/4f/58e7dce7985b35f98fcaba7b366de5baaf4637bc0811be66df4025c1885f/dspy-3.0.3-py3-none-any.whl", hash = "sha256:d19cc38ab3ec7edcb3db56a3463a606268dd2e83280595062b052bcfe0cfd24f", size = 261742, upload-time = "2025-08-31T18:49:30.129Z" }, + { url = "https://files.pythonhosted.org/packages/94/52/56eed4828175f48f712a50a994293065afa7cc98cb112992a0b071179b6c/dspy-3.0.4-py3-none-any.whl", hash = "sha256:c0a88c7936f41f6f613ee6ca8cd92e63746ff2bd780e3896615ade7628eb6a6a", size = 285224, upload-time = "2025-11-10T17:43:36.263Z" }, ] [[package]] @@ -589,11 +590,11 @@ wheels = [ [[package]] name = "gepa" -version = 
"0.0.7" +version = "0.0.17" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/29/e2/4f8f56ebabac609a2e5e43840c8f6955096906e6e7899e40953cf2adb353/gepa-0.0.7.tar.gz", hash = "sha256:3fb98c2908f6e4cbe701a6f0088c4ea599185a801a02b7872b0c624142679cf7", size = 50763, upload-time = "2025-08-25T03:46:41.471Z" } +sdist = { url = "https://files.pythonhosted.org/packages/61/f0/fe312ed4405ddc2ca97dc1ce8915c4dd707e413503e6832910ab088fceb6/gepa-0.0.17.tar.gz", hash = "sha256:641ed46f8127618341b66ee82a87fb46a21c5d2d427a5e0b91c850a7f7f64e7f", size = 99816, upload-time = "2025-09-25T22:13:45.476Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7d/de/6b36d65bb85f46b40b96e04eb7facfcdb674b6cec554a821be2e44cd4871/gepa-0.0.7-py3-none-any.whl", hash = "sha256:59b8b74f5e384a62d6f590ac6ffe0fa8a0e62fee8d8d6c539f490823d0ffb25c", size = 52316, upload-time = "2025-08-25T03:46:40.424Z" }, + { url = "https://files.pythonhosted.org/packages/88/dc/2bc81a01caa887ed58db3c725bebf1e98f37807a4d06c51ecaa85a7cabe0/gepa-0.0.17-py3-none-any.whl", hash = "sha256:0ea98f4179dbc8dd83bdf53494f302e663ee1da8300d086c4cc8ce4aefa4042c", size = 110464, upload-time = "2025-09-25T22:13:44.14Z" }, ] [[package]] @@ -758,7 +759,7 @@ requires-dist = [ { name = "accelerate", marker = "extra == 'gpu'", specifier = "==1.6.0" }, { name = "basedpyright", marker = "extra == 'dev'", specifier = "==1.29.4" }, { name = "clang", specifier = "==14.0" }, - { name = "dspy", specifier = "==3.0.3" }, + { name = "dspy", specifier = "==3.0.4" }, { name = "hydra-core", specifier = "==1.3.2" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = "==4.2.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = "==8.4.0" },