Skip to content

Commit e732669

Browse files
authored
Make it possible for rust functions to increase recursion depth (RustPython#3252)
Make it possible for rust functions to increase recursion depth
1 parent 367b258 commit e732669

File tree

6 files changed

+45
-19
lines changed

6 files changed

+45
-19
lines changed

Lib/test/list_tests.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,6 @@ def test_repr(self):
6060
self.assertEqual(str(a2), "[0, 1, 2, [...], 3]")
6161
self.assertEqual(repr(a2), "[0, 1, 2, [...], 3]")
6262

63-
# TODO: RUSTPYTHON
64-
@unittest.expectedFailure
6563
def test_repr_deep(self):
6664
a = self.type2test([])
6765
for i in range(sys.getrecursionlimit() + 100):

Lib/test/test_dict.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -545,8 +545,6 @@ def __repr__(self):
545545
d = {1: BadRepr()}
546546
self.assertRaises(Exc, repr, d)
547547

548-
# TODO: RUSTPYTHON
549-
@unittest.expectedFailure
550548
@unittest.skipIf(sys.platform == 'win32', 'TODO: RUSTPYTHON Windows')
551549
def test_repr_deep(self):
552550
d = {}

Lib/test/test_dictviews.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,6 @@ def test_recursive_repr(self):
224224
# Again.
225225
self.assertIsInstance(r, str)
226226

227-
# TODO: RUSTPYTHON
228-
@unittest.expectedFailure
229227
@unittest.skipIf(sys.platform == "win32", "thread 'main' has overflowed its stack")
230228
def test_deeply_nested_repr(self):
231229
d = {}

extra_tests/snippets/recursion.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from testutils import assert_raises
2+
3+
class Foo(object):
4+
pass
5+
6+
Foo.__repr__ = Foo.__str__
7+
8+
foo = Foo()
9+
# Since the default __str__ implementation calls __repr__ and __repr__ is
10+
# actually __str__, str(foo) should raise a RecursionError.
11+
assert_raises(RecursionError, str, foo)

vm/src/stdlib/sys.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -478,16 +478,16 @@ mod sys {
478478
"recursion limit must be greater than or equal to one".to_owned(),
479479
)
480480
})?;
481-
let recursion_depth = vm.frames.borrow().len();
481+
let recursion_depth = vm.current_recursion_depth();
482482

483483
if recursion_limit > recursion_depth + 1 {
484484
vm.recursion_limit.set(recursion_limit);
485485
Ok(())
486486
} else {
487487
Err(vm.new_recursion_error(format!(
488-
"cannot set the recursion limit to {} at the recursion depth {}: the limit is too low",
489-
recursion_limit, recursion_depth
490-
)))
488+
"cannot set the recursion limit to {} at the recursion depth {}: the limit is too low",
489+
recursion_limit, recursion_depth
490+
)))
491491
}
492492
}
493493

vm/src/vm.rs

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ pub struct VirtualMachine {
6262
pub repr_guards: RefCell<HashSet<usize>>,
6363
pub state: PyRc<PyGlobalState>,
6464
pub initialized: bool,
65+
recursion_depth: Cell<usize>,
6566
}
6667

6768
#[derive(Debug, Default)]
@@ -297,6 +298,7 @@ impl VirtualMachine {
297298
codec_registry,
298299
}),
299300
initialized: false,
301+
recursion_depth: Cell::new(0),
300302
};
301303

302304
let frozen = frozen::map_frozen(&vm, frozen::get_module_inits()).collect();
@@ -464,6 +466,7 @@ impl VirtualMachine {
464466
repr_guards: RefCell::default(),
465467
state: self.state.clone(),
466468
initialized: self.initialized,
469+
recursion_depth: Cell::new(0),
467470
};
468471
PyThread { thread_vm }
469472
}
@@ -502,25 +505,41 @@ impl VirtualMachine {
502505
}
503506
}
504507

508+
pub fn current_recursion_depth(&self) -> usize {
509+
self.recursion_depth.get()
510+
}
511+
512+
/// Used to run the body of a (possibly) recursive function. It will raise a
513+
/// RecursionError if recursive functions are nested far too many times,
514+
/// preventing a stack overflow.
515+
pub fn with_recursion<R, F: FnOnce() -> PyResult<R>>(&self, _where: &str, f: F) -> PyResult<R> {
516+
self.check_recursive_call(_where)?;
517+
self.recursion_depth.set(self.recursion_depth.get() + 1);
518+
let result = f();
519+
self.recursion_depth.set(self.recursion_depth.get() - 1);
520+
result
521+
}
522+
505523
pub fn with_frame<R, F: FnOnce(FrameRef) -> PyResult<R>>(
506524
&self,
507525
frame: FrameRef,
508526
f: F,
509527
) -> PyResult<R> {
510-
self.check_recursive_call("")?;
511-
self.frames.borrow_mut().push(frame.clone());
512-
let result = f(frame);
513-
// defer dec frame
514-
let _popped = self.frames.borrow_mut().pop();
515-
result
528+
self.with_recursion("", || {
529+
self.frames.borrow_mut().push(frame.clone());
530+
let result = f(frame);
531+
// defer dec frame
532+
let _popped = self.frames.borrow_mut().pop();
533+
result
534+
})
516535
}
517536

518537
pub fn run_frame(&self, frame: FrameRef) -> PyResult<ExecutionResult> {
519538
self.with_frame(frame, |f| f.run(self))
520539
}
521540

522541
fn check_recursive_call(&self, _where: &str) -> PyResult<()> {
523-
if self.frames.borrow().len() > self.recursion_limit.get() {
542+
if self.recursion_depth.get() > self.recursion_limit.get() {
524543
Err(self.new_recursion_error(format!("maximum recursion depth exceeded {}", _where)))
525544
} else {
526545
Ok(())
@@ -887,8 +906,10 @@ impl VirtualMachine {
887906
}
888907

889908
pub fn to_repr(&self, obj: &PyObjectRef) -> PyResult<PyStrRef> {
890-
let repr = self.call_special_method(obj.clone(), "__repr__", ())?;
891-
repr.try_into_value(self)
909+
self.with_recursion(" while getting the repr of an object", || {
910+
let repr = self.call_special_method(obj.clone(), "__repr__", ())?;
911+
repr.try_into_value(self)
912+
})
892913
}
893914

894915
pub fn to_index_opt(&self, obj: PyObjectRef) -> Option<PyResult<PyIntRef>> {

0 commit comments

Comments
 (0)