Skip to content

Commit 7233579

Browse files
authored
Refactor list avoid duplicate the vec (RustPython#3241)
* list count bench * Refactor list count avoid duplicate the vec * optimize list count with HEAPTYPE flag * introduce generic safe iter functions for list * Refactor list functions (contains, index, remove) * Refactor list iter functions with const generics * optimize list with richcompare * optimize list iter_equal
1 parent 3ab48ba commit 7233579

File tree

5 files changed

+193
-62
lines changed

5 files changed

+193
-62
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
l = [i for i in range(ITERATIONS)]
2+
3+
# ---
4+
l.count(1)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
class A:
2+
def __eq__(self, other):
3+
return True
4+
5+
l = [A()] * ITERATIONS
6+
7+
# ---
8+
l.count(1)

extra_tests/snippets/list.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ def __eq__(self, other):
646646

647647
l = [1, 2, 3, m, 4]
648648
m.list = l
649-
l.count(4) # TODO: assert l.count(4) == 1
649+
assert l.count(4) == 1
650650

651651
l = [1, 2, 3, m, 4]
652652
m.list = l

vm/src/builtins/list.rs

+179-60
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,21 @@ use crate::{
66
function::{ArgIterable, FuncArgs, IntoPyObject, OptionalArg},
77
protocol::{PyIterReturn, PyMappingMethods},
88
sequence::{self, SimpleSeq},
9-
sliceable::{PySliceableSequence, PySliceableSequenceMut, SequenceIndex},
9+
sliceable::{saturate_index, PySliceableSequence, PySliceableSequenceMut, SequenceIndex},
1010
stdlib::sys,
1111
types::{
12-
AsMapping, Comparable, Constructor, Hashable, IterNext, IterNextIterable, Iterable,
13-
PyComparisonOp, Unconstructible, Unhashable,
12+
richcompare_wrapper, AsMapping, Comparable, Constructor, Hashable, IterNext,
13+
IterNextIterable, Iterable, PyComparisonOp, RichCompareFunc, Unconstructible, Unhashable,
1414
},
1515
utils::Either,
1616
vm::{ReprGuard, VirtualMachine},
17-
PyClassDef, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, PyResult, PyValue,
18-
TryFromObject, TypeProtocol,
17+
IdProtocol, PyClassDef, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef,
18+
PyResult, PyValue, TryFromObject, TypeProtocol,
1919
};
2020
use std::fmt;
2121
use std::iter::FromIterator;
2222
use std::mem::size_of;
23-
use std::ops::DerefMut;
23+
use std::ops::{DerefMut, Range};
2424

2525
/// Built-in mutable sequence.
2626
///
@@ -252,30 +252,173 @@ impl PyList {
252252
Ok(zelf.clone())
253253
}
254254

255+
fn _iter_equal<F: FnMut(), const SHORT: bool>(
256+
&self,
257+
needle: &PyObjectRef,
258+
range: Range<usize>,
259+
mut f: F,
260+
vm: &VirtualMachine,
261+
) -> PyResult<usize> {
262+
let needle_cls = needle.class();
263+
let needle_cmp = needle_cls
264+
.mro_find_map(|cls| cls.slots.richcompare.load())
265+
.unwrap();
266+
267+
let mut borrower = None;
268+
let mut i = range.start;
269+
270+
let index = loop {
271+
if i >= range.end {
272+
break usize::MAX;
273+
}
274+
let guard = if let Some(x) = borrower.take() {
275+
x
276+
} else {
277+
self.borrow_vec()
278+
};
279+
280+
let elem = if let Some(x) = guard.get(i) {
281+
x
282+
} else {
283+
break usize::MAX;
284+
};
285+
286+
if elem.is(needle) {
287+
f();
288+
if SHORT {
289+
break i;
290+
}
291+
borrower = Some(guard);
292+
} else {
293+
let elem_cls = elem.class();
294+
let reverse_first = !elem_cls.is(&needle_cls) && elem_cls.issubclass(&needle_cls);
295+
296+
let eq = if reverse_first {
297+
let elem_cmp = elem_cls
298+
.mro_find_map(|cls| cls.slots.richcompare.load())
299+
.unwrap();
300+
drop(elem_cls);
301+
302+
fn cmp(
303+
elem: &PyObjectRef,
304+
needle: &PyObjectRef,
305+
elem_cmp: RichCompareFunc,
306+
needle_cmp: RichCompareFunc,
307+
vm: &VirtualMachine,
308+
) -> PyResult<bool> {
309+
match elem_cmp(elem, needle, PyComparisonOp::Eq, vm)? {
310+
Either::B(PyComparisonValue::Implemented(value)) => Ok(value),
311+
Either::A(obj) if !obj.is(&vm.ctx.not_implemented) => {
312+
obj.try_to_bool(vm)
313+
}
314+
_ => match needle_cmp(needle, elem, PyComparisonOp::Eq, vm)? {
315+
Either::B(PyComparisonValue::Implemented(value)) => Ok(value),
316+
Either::A(obj) if !obj.is(&vm.ctx.not_implemented) => {
317+
obj.try_to_bool(vm)
318+
}
319+
_ => Ok(false),
320+
},
321+
}
322+
}
323+
324+
if elem_cmp as usize == richcompare_wrapper as usize {
325+
let elem = elem.clone();
326+
drop(guard);
327+
cmp(&elem, needle, elem_cmp, needle_cmp, vm)?
328+
} else {
329+
let eq = cmp(elem, needle, elem_cmp, needle_cmp, vm)?;
330+
borrower = Some(guard);
331+
eq
332+
}
333+
} else {
334+
match needle_cmp(needle, elem, PyComparisonOp::Eq, vm)? {
335+
Either::B(PyComparisonValue::Implemented(value)) => {
336+
drop(elem_cls);
337+
borrower = Some(guard);
338+
value
339+
}
340+
Either::A(obj) if !obj.is(&vm.ctx.not_implemented) => {
341+
drop(elem_cls);
342+
borrower = Some(guard);
343+
obj.try_to_bool(vm)?
344+
}
345+
_ => {
346+
let elem_cmp = elem_cls
347+
.mro_find_map(|cls| cls.slots.richcompare.load())
348+
.unwrap();
349+
drop(elem_cls);
350+
351+
fn cmp(
352+
elem: &PyObjectRef,
353+
needle: &PyObjectRef,
354+
elem_cmp: RichCompareFunc,
355+
vm: &VirtualMachine,
356+
) -> PyResult<bool> {
357+
match elem_cmp(elem, needle, PyComparisonOp::Eq, vm)? {
358+
Either::B(PyComparisonValue::Implemented(value)) => Ok(value),
359+
Either::A(obj) if !obj.is(&vm.ctx.not_implemented) => {
360+
obj.try_to_bool(vm)
361+
}
362+
_ => Ok(false),
363+
}
364+
}
365+
366+
if elem_cmp as usize == richcompare_wrapper as usize {
367+
let elem = elem.clone();
368+
drop(guard);
369+
cmp(&elem, needle, elem_cmp, vm)?
370+
} else {
371+
let eq = cmp(elem, needle, elem_cmp, vm)?;
372+
borrower = Some(guard);
373+
eq
374+
}
375+
}
376+
}
377+
};
378+
379+
if eq {
380+
f();
381+
if SHORT {
382+
break i;
383+
}
384+
}
385+
}
386+
i += 1;
387+
};
388+
389+
// TODO: Optioned<usize>
390+
Ok(index)
391+
}
392+
393+
fn foreach_equal<F: FnMut()>(
394+
&self,
395+
needle: &PyObjectRef,
396+
f: F,
397+
vm: &VirtualMachine,
398+
) -> PyResult<()> {
399+
self._iter_equal::<_, false>(needle, 0..usize::MAX, f, vm)
400+
.map(|_| ())
401+
}
402+
403+
fn find_equal(
404+
&self,
405+
needle: &PyObjectRef,
406+
range: Range<usize>,
407+
vm: &VirtualMachine,
408+
) -> PyResult<usize> {
409+
self._iter_equal::<_, true>(needle, range, || {}, vm)
410+
}
411+
255412
#[pymethod]
256413
fn count(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
257-
// TODO: to_vec() cause copy which leads to cost O(N). It need to be improved.
258-
let elements = self.borrow_vec().to_vec();
259-
let mut count: usize = 0;
260-
for elem in elements.iter() {
261-
if vm.identical_or_equal(elem, &needle)? {
262-
count += 1;
263-
}
264-
}
414+
let mut count = 0;
415+
self.foreach_equal(&needle, || count += 1, vm)?;
265416
Ok(count)
266417
}
267418

268419
#[pymethod(magic)]
269-
pub fn contains(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
270-
// TODO: to_vec() cause copy which leads to cost O(N). It need to be improved.
271-
let elements = self.borrow_vec().to_vec();
272-
for elem in elements.iter() {
273-
if vm.identical_or_equal(elem, &needle)? {
274-
return Ok(true);
275-
}
276-
}
277-
278-
Ok(false)
420+
pub(crate) fn contains(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
421+
Ok(self.find_equal(&needle, 0..usize::MAX, vm)? != usize::MAX)
279422
}
280423

281424
#[pymethod]
@@ -286,33 +429,17 @@ impl PyList {
286429
stop: OptionalArg<isize>,
287430
vm: &VirtualMachine,
288431
) -> PyResult<usize> {
289-
let mut start = start.into_option().unwrap_or(0);
290-
if start < 0 {
291-
start += self.borrow_vec().len() as isize;
292-
if start < 0 {
293-
start = 0;
294-
}
295-
}
296-
let mut stop = stop.into_option().unwrap_or(sys::MAXSIZE);
297-
if stop < 0 {
298-
stop += self.borrow_vec().len() as isize;
299-
if stop < 0 {
300-
stop = 0;
301-
}
302-
}
303-
// TODO: to_vec() cause copy which leads to cost O(N). It need to be improved.
304-
let elements = self.borrow_vec().to_vec();
305-
for (index, element) in elements
306-
.iter()
307-
.enumerate()
308-
.take(stop as usize)
309-
.skip(start as usize)
310-
{
311-
if vm.identical_or_equal(element, &needle)? {
312-
return Ok(index);
313-
}
432+
let len = self.len();
433+
let start = start.map(|i| saturate_index(i, len)).unwrap_or(0);
434+
let stop = stop
435+
.map(|i| saturate_index(i, len))
436+
.unwrap_or(sys::MAXSIZE as usize);
437+
let index = self.find_equal(&needle, start..stop, vm)?;
438+
if index == usize::MAX {
439+
Err(vm.new_value_error(format!("'{}' is not in list", vm.to_str(&needle)?)))
440+
} else {
441+
Ok(index)
314442
}
315-
Err(vm.new_value_error(format!("'{}' is not in list", vm.to_str(&needle)?)))
316443
}
317444

318445
#[pymethod]
@@ -333,17 +460,9 @@ impl PyList {
333460

334461
#[pymethod]
335462
fn remove(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
336-
// TODO: to_vec() cause copy which leads to cost O(N). It need to be improved.
337-
let elements = self.borrow_vec().to_vec();
338-
let mut ri: Option<usize> = None;
339-
for (index, element) in elements.iter().enumerate() {
340-
if vm.identical_or_equal(element, &needle)? {
341-
ri = Some(index);
342-
break;
343-
}
344-
}
463+
let index = self.find_equal(&needle, 0..usize::MAX, vm)?;
345464

346-
if let Some(index) = ri {
465+
if index != usize::MAX {
347466
// defer delete out of borrow
348467
Ok(self.borrow_vec_mut().remove(index))
349468
} else {

vm/src/types/slot.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ fn setattro_wrapper(
225225
Ok(())
226226
}
227227

228-
fn richcompare_wrapper(
228+
pub(crate) fn richcompare_wrapper(
229229
zelf: &PyObjectRef,
230230
other: &PyObjectRef,
231231
op: PyComparisonOp,

0 commit comments

Comments
 (0)