Skip to content

Commit 06463f9

Browse files
committed
Fix bytes.find
1 parent ad111b0 commit 06463f9

File tree

6 files changed

+40
-44
lines changed

6 files changed

+40
-44
lines changed

Lib/test/string_tests.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,8 @@ def test_count(self):
157157
self.assertEqual(rem, 0, '%s != 0 for %s' % (rem, i))
158158
self.assertEqual(r1, r2, '%s != %s for %s' % (r1, r2, i))
159159

160-
@unittest.skip("TODO: RUSTPYTHON test_bytes")
160+
# TODO: RUSTPYTHON
161+
@unittest.expectedFailure
161162
def test_find(self):
162163
self.checkequal(0, 'abcdefghiabc', 'find', 'abc')
163164
self.checkequal(9, 'abcdefghiabc', 'find', 'abc', 1)
@@ -215,7 +216,8 @@ def test_find(self):
215216
if loc != -1:
216217
self.assertEqual(i[loc:loc+len(j)], j)
217218

218-
@unittest.skip("TODO: RUSTPYTHON test_bytes")
219+
# TODO: RUSTPYTHON
220+
@unittest.expectedFailure
219221
def test_rfind(self):
220222
self.checkequal(9, 'abcdefghiabc', 'rfind', 'abc')
221223
self.checkequal(12, 'abcdefghiabc', 'rfind', '')
@@ -294,7 +296,6 @@ def test_index(self):
294296
else:
295297
self.checkraises(TypeError, 'hello', 'index', 42)
296298

297-
@unittest.skip("TODO: RUSTPYTHON test_bytes")
298299
def test_rindex(self):
299300
self.checkequal(12, 'abcdefghiabc', 'rindex', '')
300301
self.checkequal(3, 'abcdefghiabc', 'rindex', 'def')

vm/src/obj/objbytearray.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
//! Implementation of the python bytearray object.
2+
use bstr::ByteSlice;
23
use crossbeam_utils::atomic::AtomicCell;
34
use std::convert::TryFrom;
5+
use std::mem::size_of;
6+
use std::str::FromStr;
47
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
58

69
use super::objbyteinner::{
@@ -21,8 +24,6 @@ use crate::pyobject::{
2124
PyValue, ThreadSafe, TryFromObject, TypeProtocol,
2225
};
2326
use crate::vm::VirtualMachine;
24-
use std::mem::size_of;
25-
use std::str::FromStr;
2627

2728
/// "bytearray(iterable_of_ints) -> bytearray\n\
2829
/// bytearray(string, encoding[, errors]) -> bytearray\n\
@@ -327,25 +328,25 @@ impl PyByteArray {
327328

328329
#[pymethod(name = "find")]
329330
fn find(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<isize> {
330-
let index = self.borrow_value().find(options, false, vm)?;
331+
let index = self.borrow_value().find(options, |h, n| h.find(n), vm)?;
331332
Ok(index.map_or(-1, |v| v as isize))
332333
}
333334

334335
#[pymethod(name = "index")]
335336
fn index(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<usize> {
336-
let index = self.borrow_value().find(options, false, vm)?;
337+
let index = self.borrow_value().find(options, |h, n| h.find(n), vm)?;
337338
index.ok_or_else(|| vm.new_value_error("substring not found".to_owned()))
338339
}
339340

340341
#[pymethod(name = "rfind")]
341342
fn rfind(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<isize> {
342-
let index = self.borrow_value().find(options, true, vm)?;
343+
let index = self.borrow_value().find(options, |h, n| h.rfind(n), vm)?;
343344
Ok(index.map_or(-1, |v| v as isize))
344345
}
345346

346347
#[pymethod(name = "rindex")]
347348
fn rindex(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<usize> {
348-
let index = self.borrow_value().find(options, true, vm)?;
349+
let index = self.borrow_value().find(options, |h, n| h.rfind(n), vm)?;
349350
index.ok_or_else(|| vm.new_value_error("substring not found".to_owned()))
350351
}
351352

vm/src/obj/objbyteinner.rs

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -823,35 +823,17 @@ impl PyByteInner {
823823
}
824824

825825
#[inline]
826-
pub fn find(
826+
pub fn find<F>(
827827
&self,
828828
options: ByteInnerFindOptions,
829-
reverse: bool,
829+
find: F,
830830
vm: &VirtualMachine,
831-
) -> PyResult<Option<usize>> {
831+
) -> PyResult<Option<usize>>
832+
where
833+
F: Fn(&[u8], &[u8]) -> Option<usize>,
834+
{
832835
let (needle, range) = options.get_value(self.elements.len(), vm)?;
833-
if !range.is_normal() {
834-
return Ok(None);
835-
}
836-
if needle.is_empty() {
837-
return Ok(Some(if reverse { range.end } else { range.start }));
838-
}
839-
let haystack = &self.elements[range.clone()];
840-
let windows = haystack.windows(needle.len());
841-
if reverse {
842-
for (i, w) in windows.rev().enumerate() {
843-
if w == needle.as_slice() {
844-
return Ok(Some(range.end - i - needle.len()));
845-
}
846-
}
847-
} else {
848-
for (i, w) in windows.enumerate() {
849-
if w == needle.as_slice() {
850-
return Ok(Some(range.start + i));
851-
}
852-
}
853-
}
854-
Ok(None)
836+
Ok(self.elements.py_find(&needle, range, find))
855837
}
856838

857839
pub fn maketrans(from: PyByteInner, to: PyByteInner, vm: &VirtualMachine) -> PyResult {

vm/src/obj/objbytes.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use bstr::ByteSlice;
12
use crossbeam_utils::atomic::AtomicCell;
23
use std::mem::size_of;
34
use std::ops::Deref;
@@ -300,25 +301,25 @@ impl PyBytes {
300301

301302
#[pymethod(name = "find")]
302303
fn find(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<isize> {
303-
let index = self.inner.find(options, false, vm)?;
304+
let index = self.inner.find(options, |h, n| h.find(n), vm)?;
304305
Ok(index.map_or(-1, |v| v as isize))
305306
}
306307

307308
#[pymethod(name = "index")]
308309
fn index(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<usize> {
309-
let index = self.inner.find(options, false, vm)?;
310+
let index = self.inner.find(options, |h, n| h.find(n), vm)?;
310311
index.ok_or_else(|| vm.new_value_error("substring not found".to_owned()))
311312
}
312313

313314
#[pymethod(name = "rfind")]
314315
fn rfind(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<isize> {
315-
let index = self.inner.find(options, true, vm)?;
316+
let index = self.inner.find(options, |h, n| h.rfind(n), vm)?;
316317
Ok(index.map_or(-1, |v| v as isize))
317318
}
318319

319320
#[pymethod(name = "rindex")]
320321
fn rindex(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<usize> {
321-
let index = self.inner.find(options, true, vm)?;
322+
let index = self.inner.find(options, |h, n| h.rfind(n), vm)?;
322323
index.ok_or_else(|| vm.new_value_error("substring not found".to_owned()))
323324
}
324325

vm/src/obj/objstr.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,7 @@ impl PyString {
811811
Ok(joined)
812812
}
813813

814+
#[inline]
814815
fn _find<F>(
815816
&self,
816817
sub: PyStringRef,
@@ -822,12 +823,7 @@ impl PyString {
822823
F: Fn(&str, &str) -> Option<usize>,
823824
{
824825
let range = adjust_indices(start, end, self.value.len());
825-
if range.is_normal() {
826-
if let Some(index) = find(&self.value[range.clone()], &sub.value) {
827-
return Some(range.start + index);
828-
}
829-
}
830-
None
826+
self.value.py_find(&sub.value, range, find)
831827
}
832828

833829
#[pymethod]

vm/src/obj/pystr.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ pub trait PyCommonString<E> {
186186
}
187187
}
188188

189+
#[inline]
189190
fn py_strip<'a, S, FC, FD>(
190191
&'a self,
191192
chars: OptionalOption<S>,
@@ -203,4 +204,18 @@ pub trait PyCommonString<E> {
203204
None => func_default(self),
204205
}
205206
}
207+
208+
#[inline]
209+
fn py_find<F>(&self, needle: &Self, range: std::ops::Range<usize>, find: F) -> Option<usize>
210+
where
211+
F: Fn(&Self, &Self) -> Option<usize>,
212+
{
213+
if range.is_normal() {
214+
let start = range.start;
215+
if let Some(index) = find(self.get_slice(range), &needle) {
216+
return Some(start + index);
217+
}
218+
}
219+
None
220+
}
206221
}

0 commit comments

Comments
 (0)