Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions crates/vm/src/builtins/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use super::{
use crate::common::lock::OnceCell;
use crate::common::lock::PyMutex;
use crate::function::ArgMapping;
use crate::object::{Traverse, TraverseFn};
use crate::object::{PyAtomicRef, Traverse, TraverseFn};
use crate::{
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
bytecode,
Expand Down Expand Up @@ -61,7 +61,7 @@ fn format_missing_args(
#[pyclass(module = false, name = "function", traverse = "manual")]
#[derive(Debug)]
pub struct PyFunction {
code: PyMutex<PyRef<PyCode>>,
code: PyAtomicRef<PyCode>,
globals: PyDictRef,
builtins: PyObjectRef,
closure: Option<PyRef<PyTuple<PyCellRef>>>,
Expand Down Expand Up @@ -192,7 +192,7 @@ impl PyFunction {

let qualname = vm.ctx.new_str(code.qualname.as_str());
let func = Self {
code: PyMutex::new(code.clone()),
code: PyAtomicRef::from(code.clone()),
globals,
builtins,
closure: None,
Expand All @@ -217,7 +217,7 @@ impl PyFunction {
func_args: FuncArgs,
vm: &VirtualMachine,
) -> PyResult<()> {
let code = &*self.code.lock();
let code: &Py<PyCode> = &self.code;
let nargs = func_args.args.len();
let n_expected_args = code.arg_count as usize;
let total_args = code.arg_count as usize + code.kwonlyarg_count as usize;
Expand Down Expand Up @@ -539,13 +539,12 @@ impl Py<PyFunction> {
Err(err) => info!(
"jit: function `{}` is falling back to being interpreted because of the \
error: {}",
self.code.lock().obj_name,
err
self.code.obj_name, err
),
}
}

let code = self.code.lock().clone();
let code: PyRef<PyCode> = (*self.code).to_owned();

let locals = if code.flags.contains(bytecode::CodeFlags::NEWLOCALS) {
ArgMapping::from_dict_exact(vm.ctx.new_dict())
Expand Down Expand Up @@ -609,7 +608,7 @@ impl Py<PyFunction> {
/// Returns true if: no VARARGS, no VARKEYWORDS, no kwonly args, not generator/coroutine,
/// and effective_nargs matches co_argcount.
pub(crate) fn can_specialize_call(&self, effective_nargs: u32) -> bool {
let code = self.code.lock();
let code: &Py<PyCode> = &self.code;
let flags = code.flags;
flags.contains(bytecode::CodeFlags::NEWLOCALS)
&& !flags.intersects(
Expand All @@ -627,7 +626,7 @@ impl Py<PyFunction> {
/// Only valid when: no VARARGS, no VARKEYWORDS, no kwonlyargs, not generator/coroutine,
/// and nargs == co_argcount.
pub fn invoke_exact_args(&self, args: &[PyObjectRef], vm: &VirtualMachine) -> PyResult {
let code = self.code.lock().clone();
let code: PyRef<PyCode> = (*self.code).to_owned();

let locals = ArgMapping::from_dict_exact(vm.ctx.new_dict());

Expand Down Expand Up @@ -676,12 +675,12 @@ impl PyPayload for PyFunction {
impl PyFunction {
#[pygetset]
fn __code__(&self) -> PyRef<PyCode> {
self.code.lock().clone()
(*self.code).to_owned()
}

#[pygetset(setter)]
fn set___code__(&self, code: PyRef<PyCode>) {
*self.code.lock() = code;
fn set___code__(&self, code: PyRef<PyCode>, vm: &VirtualMachine) {
self.code.swap_to_temporary_refs(code, vm);
self.func_version.store(0, Relaxed);
Comment on lines +682 to 684
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Invalidate cached JIT code when __code__ is reassigned.

Line 683 swaps the code object, but the cached jitted_code is not reset. The JIT fast path in invoke_with_locals (Line 533 onward) still executes cached compiled code when present, so f.__code__ = ... can leave the function running stale machine code.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@crates/vm/src/builtins/function.rs` around lines 682 - 684, When assigning a
new code object in set___code__, also invalidate any cached JIT artifacts so the
old machine code cannot be used; specifically, after
self.code.swap_to_temporary_refs(...) and before/after
self.func_version.store(0, Relaxed), clear the cached jitted_code (and any
associated owner/version fields used by the JIT fast path) so invoke_with_locals
cannot find/execute stale compiled code — update the fields that hold the
compiled entry (jitted_code) to a neutral/empty state consistent with how they
are checked in invoke_with_locals.

}

Expand Down Expand Up @@ -923,7 +922,7 @@ impl PyFunction {
}
let arg_types = jit::get_jit_arg_types(&zelf, vm)?;
let ret_type = jit::jit_ret_type(&zelf, vm)?;
let code = zelf.code.lock();
let code: &Py<PyCode> = &zelf.code;
let compiled = rustpython_jit::compile(&code.code, &arg_types, ret_type)
.map_err(|err| jit::new_jit_error(err.to_string(), vm))?;
let _ = zelf.jitted_code.set(compiled);
Expand Down
8 changes: 3 additions & 5 deletions crates/vm/src/builtins/function/jit.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
AsObject, Py, PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine,
builtins::{
PyBaseExceptionRef, PyDict, PyDictRef, PyFunction, PyStrInterned, bool_, float, int,
PyBaseExceptionRef, PyCode, PyDict, PyDictRef, PyFunction, PyStrInterned, bool_, float, int,
},
bytecode::CodeFlags,
convert::ToPyObject,
Expand Down Expand Up @@ -67,7 +67,7 @@ fn get_jit_arg_type(dict: &Py<PyDict>, name: &str, vm: &VirtualMachine) -> PyRes
}

pub fn get_jit_arg_types(func: &Py<PyFunction>, vm: &VirtualMachine) -> PyResult<Vec<JitType>> {
let code = func.code.lock();
let code: &Py<PyCode> = &func.code;
let arg_names = code.arg_names();

if code
Expand Down Expand Up @@ -160,7 +160,7 @@ pub(crate) fn get_jit_args<'a>(
let mut jit_args = jitted_code.args_builder();
let nargs = func_args.args.len();

let code = func.code.lock();
let code: &Py<PyCode> = &func.code;
let arg_names = code.arg_names();
let arg_count = code.arg_count;
let posonlyarg_count = code.posonlyarg_count;
Expand Down Expand Up @@ -220,7 +220,5 @@ pub(crate) fn get_jit_args<'a>(
}
}

drop(code);

jit_args.into_args().ok_or(ArgsError::NotAllArgsPassed)
}
Loading