Skip to content

Commit 37e5c9b

Browse files
committed
Fix bytes.replace
1 parent f60b07a commit 37e5c9b

File tree

4 files changed

+102
-32
lines changed

4 files changed

+102
-32
lines changed

Lib/test/string_tests.py

-1
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,6 @@ def test_rsplit(self):
500500
self.checkraises(ValueError, 'hello', 'rsplit', '')
501501
self.checkraises(ValueError, 'hello', 'rsplit', '', 0)
502502

503-
@unittest.skip("TODO: RUSTPYTHON test_bytes")
504503
def test_replace(self):
505504
EQ = self.checkequal
506505

vm/src/obj/objbytearray.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -465,9 +465,10 @@ impl PyByteArray {
465465
&self,
466466
old: PyByteInner,
467467
new: PyByteInner,
468-
count: OptionalArg<PyIntRef>,
468+
count: OptionalArg<isize>,
469+
vm: &VirtualMachine,
469470
) -> PyResult<PyByteArray> {
470-
Ok(self.borrow_value().replace(old, new, count)?.into())
471+
Ok(self.borrow_value().replace(old, new, count, vm)?.into())
471472
}
472473

473474
#[pymethod(name = "clear")]

vm/src/obj/objbyteinner.rs

+96-27
Original file line numberDiff line numberDiff line change
@@ -1127,41 +1127,100 @@ impl PyByteInner {
11271127
bytes_zfill(&self.elements, width.to_usize().unwrap_or(0))
11281128
}
11291129

1130-
pub fn replace(
1130+
// len(self)>=1, from="", len(to)>=1, maxcount>=1
1131+
fn replace_interleave(&self, to: PyByteInner, maxcount: Option<usize>) -> Vec<u8> {
1132+
let place_count = self.elements.len() + 1;
1133+
let count = maxcount.map_or(place_count, |v| std::cmp::min(v, place_count)) - 1;
1134+
let capacity = self.elements.len() + count * to.len();
1135+
let mut result = Vec::with_capacity(capacity);
1136+
let to_slice = to.elements.as_slice();
1137+
result.extend_from_slice(to_slice);
1138+
for c in &self.elements[..count] {
1139+
result.push(*c);
1140+
result.extend_from_slice(to_slice);
1141+
}
1142+
result.extend_from_slice(&self.elements[count..]);
1143+
result
1144+
}
1145+
1146+
fn replace_general(
11311147
&self,
1132-
old: PyByteInner,
1133-
new: PyByteInner,
1134-
count: OptionalArg<PyIntRef>,
1148+
from: PyByteInner,
1149+
to: PyByteInner,
1150+
maxcount: Option<usize>,
1151+
vm: &VirtualMachine,
11351152
) -> PyResult<Vec<u8>> {
1136-
let count = match count.into_option() {
1137-
Some(int) => int
1138-
.as_bigint()
1139-
.to_u32()
1140-
.unwrap_or(self.elements.len() as u32),
1141-
None => self.elements.len() as u32,
1142-
};
1143-
1144-
let mut res = vec![];
1145-
let mut index = 0;
1146-
let mut done = 0;
1153+
let count = count_substring(self.elements.as_slice(), from.elements.as_slice(), maxcount);
1154+
if count == 0 {
1155+
// no matches, return unchanged
1156+
return Ok(self.elements.clone());
1157+
}
11471158

1148-
let slice = &self.elements;
1149-
loop {
1150-
if done == count || index > slice.len() - old.len() {
1151-
res.extend_from_slice(&slice[index..]);
1159+
// Check for overflow
1160+
// result_len = self_len + count * (to_len-from_len)
1161+
debug_assert!(count > 0);
1162+
if to.len() as isize - from.len() as isize
1163+
> (std::isize::MAX - self.elements.len() as isize) / count as isize
1164+
{
1165+
return Err(vm.new_overflow_error("replace bytes is too long".to_owned()));
1166+
}
1167+
let result_len = self.elements.len() + count * (to.len() - from.len());
1168+
1169+
let mut result = Vec::with_capacity(result_len);
1170+
let mut last_end = 0;
1171+
let mut count = count;
1172+
for offset in self.elements.find_iter(&from.elements) {
1173+
result.extend_from_slice(&self.elements[last_end..offset]);
1174+
result.extend_from_slice(to.elements.as_slice());
1175+
last_end = offset + from.len();
1176+
count -= 1;
1177+
if count == 0 {
11521178
break;
11531179
}
1154-
if &slice[index..index + old.len()] == old.elements.as_slice() {
1155-
res.extend_from_slice(&new.elements);
1156-
index += old.len();
1157-
done += 1;
1158-
} else {
1159-
res.push(slice[index]);
1160-
index += 1
1180+
}
1181+
result.extend_from_slice(&self.elements[last_end..]);
1182+
Ok(result)
1183+
}
1184+
1185+
pub fn replace(
1186+
&self,
1187+
from: PyByteInner,
1188+
to: PyByteInner,
1189+
maxcount: OptionalArg<isize>,
1190+
vm: &VirtualMachine,
1191+
) -> PyResult<Vec<u8>> {
1192+
// stringlib_replace in CPython
1193+
let maxcount = match maxcount {
1194+
OptionalArg::Present(maxcount) if maxcount >= 0 => {
1195+
if maxcount == 0 || self.elements.is_empty() {
1196+
// nothing to do; return the original bytes
1197+
return Ok(self.elements.clone());
1198+
}
1199+
Some(maxcount as usize)
1200+
}
1201+
_ => None,
1202+
};
1203+
1204+
// Handle zero-length special cases
1205+
if from.elements.is_empty() {
1206+
if to.elements.is_empty() {
1207+
// nothing to do; return the original bytes
1208+
return Ok(self.elements.clone());
11611209
}
1210+
// insert the 'to' bytes everywhere.
1211+
// >>> b"Python".replace(b"", b".")
1212+
// b'.P.y.t.h.o.n.'
1213+
return Ok(self.replace_interleave(to, maxcount));
11621214
}
11631215

1164-
Ok(res)
1216+
// Except for b"".replace(b"", b"A") == b"A" there is no way beyond this
1217+
// point for an empty self bytes to generate a non-empty bytes
1218+
// Special case so the remaining code always gets a non-empty bytes
1219+
if self.elements.is_empty() {
1220+
return Ok(self.elements.clone());
1221+
}
1222+
1223+
self.replace_general(from, to, maxcount, vm)
11651224
}
11661225

11671226
pub fn title(&self) -> Vec<u8> {
@@ -1233,6 +1292,16 @@ pub fn try_as_byte(obj: &PyObjectRef) -> Option<Vec<u8>> {
12331292
})
12341293
}
12351294

1295+
#[inline]
1296+
fn count_substring(haystack: &[u8], needle: &[u8], maxcount: Option<usize>) -> usize {
1297+
let substrings = haystack.find_iter(needle);
1298+
if let Some(maxcount) = maxcount {
1299+
std::cmp::min(substrings.take(maxcount).count(), maxcount)
1300+
} else {
1301+
substrings.count()
1302+
}
1303+
}
1304+
12361305
pub trait ByteOr: ToPrimitive {
12371306
fn byte_or(&self, vm: &VirtualMachine) -> PyResult<u8> {
12381307
match self.to_u8() {

vm/src/obj/objbytes.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -426,9 +426,10 @@ impl PyBytes {
426426
&self,
427427
old: PyByteInner,
428428
new: PyByteInner,
429-
count: OptionalArg<PyIntRef>,
429+
count: OptionalArg<isize>,
430+
vm: &VirtualMachine,
430431
) -> PyResult<PyBytes> {
431-
Ok(self.inner.replace(old, new, count)?.into())
432+
Ok(self.inner.replace(old, new, count, vm)?.into())
432433
}
433434

434435
#[pymethod(name = "title")]

0 commit comments

Comments
 (0)