Skip to content

Commit cc83db3

Browse files
rebuntoyouknowone
authored andcommitted
implement key argument of bisect
1 parent f40643a commit cc83db3

File tree

1 file changed

+32
-8
lines changed

1 file changed

+32
-8
lines changed

stdlib/src/bisect.rs

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ mod _bisect {
1414
lo: OptionalArg<PyObjectRef>,
1515
#[pyarg(any, optional)]
1616
hi: OptionalArg<PyObjectRef>,
17+
#[pyarg(named, default)]
18+
key: Option<PyObjectRef>,
1719
}
1820

1921
// Handles objects that implement __index__ and makes sure index fits in needed isize.
@@ -66,17 +68,21 @@ mod _bisect {
6668
#[inline]
6769
#[pyfunction]
6870
fn bisect_left(
69-
BisectArgs { a, x, lo, hi }: BisectArgs,
71+
BisectArgs { a, x, lo, hi, key }: BisectArgs,
7072
vm: &VirtualMachine,
7173
) -> PyResult<usize> {
7274
let (mut lo, mut hi) = as_usize(lo, hi, a.length(vm)?, vm)?;
7375

7476
while lo < hi {
7577
// Handles issue 13496.
7678
let mid = (lo + hi) / 2;
77-
if a.get_item(&mid, vm)?
78-
.rich_compare_bool(&x, PyComparisonOp::Lt, vm)?
79-
{
79+
let a_mid = a.get_item(&mid, vm)?;
80+
let comp = if let Some(ref key) = key {
81+
vm.invoke(key, (a_mid,))?
82+
} else {
83+
a_mid
84+
};
85+
if comp.rich_compare_bool(&x, PyComparisonOp::Lt, vm)? {
8086
lo = mid + 1;
8187
} else {
8288
hi = mid;
@@ -96,15 +102,21 @@ mod _bisect {
96102
#[inline]
97103
#[pyfunction]
98104
fn bisect_right(
99-
BisectArgs { a, x, lo, hi }: BisectArgs,
105+
BisectArgs { a, x, lo, hi, key }: BisectArgs,
100106
vm: &VirtualMachine,
101107
) -> PyResult<usize> {
102108
let (mut lo, mut hi) = as_usize(lo, hi, a.length(vm)?, vm)?;
103109

104110
while lo < hi {
105111
// Handles issue 13496.
106112
let mid = (lo + hi) / 2;
107-
if x.rich_compare_bool(&*a.get_item(&mid, vm)?, PyComparisonOp::Lt, vm)? {
113+
let a_mid = a.get_item(&mid, vm)?;
114+
let comp = if let Some(ref key) = key {
115+
vm.invoke(key, (a_mid,))?
116+
} else {
117+
a_mid
118+
};
119+
if x.rich_compare_bool(&*comp, PyComparisonOp::Lt, vm)? {
108120
hi = mid;
109121
} else {
110122
lo = mid + 1;
@@ -120,13 +132,19 @@ mod _bisect {
120132
/// Optional args lo (default 0) and hi (default len(a)) bound the
121133
/// slice of a to be searched.
122134
#[pyfunction]
123-
fn insort_left(BisectArgs { a, x, lo, hi }: BisectArgs, vm: &VirtualMachine) -> PyResult {
135+
fn insort_left(BisectArgs { a, x, lo, hi, key }: BisectArgs, vm: &VirtualMachine) -> PyResult {
136+
let x = if let Some(ref key) = key {
137+
vm.invoke(key, (x,))?
138+
} else {
139+
x
140+
};
124141
let index = bisect_left(
125142
BisectArgs {
126143
a: a.clone(),
127144
x: x.clone(),
128145
lo,
129146
hi,
147+
key,
130148
},
131149
vm,
132150
)?;
@@ -140,13 +158,19 @@ mod _bisect {
140158
/// Optional args lo (default 0) and hi (default len(a)) bound the
141159
/// slice of a to be searched
142160
#[pyfunction]
143-
fn insort_right(BisectArgs { a, x, lo, hi }: BisectArgs, vm: &VirtualMachine) -> PyResult {
161+
fn insort_right(BisectArgs { a, x, lo, hi, key }: BisectArgs, vm: &VirtualMachine) -> PyResult {
162+
let x = if let Some(ref key) = key {
163+
vm.invoke(key, (x,))?
164+
} else {
165+
x
166+
};
144167
let index = bisect_right(
145168
BisectArgs {
146169
a: a.clone(),
147170
x: x.clone(),
148171
lo,
149172
hi,
173+
key,
150174
},
151175
vm,
152176
)?;

0 commit comments

Comments
 (0)