Skip to content

Commit 3bc65bd

Browse files
authored
Fix itertools.count step to take PyNumber instead of PyInt (RustPython#3834)
1 parent 9dce3e2 commit 3bc65bd

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

vm/src/stdlib/itertools.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ mod decl {
1717
AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, PyWeakRef, VirtualMachine,
1818
};
1919
use crossbeam_utils::atomic::AtomicCell;
20-
use num_traits::{One, Signed, ToPrimitive};
20+
use num_traits::{Signed, ToPrimitive};
2121
use std::fmt;
2222

2323
#[pyattr]
@@ -174,7 +174,7 @@ mod decl {
174174
#[derive(Debug, PyPayload)]
175175
struct PyItertoolsCount {
176176
cur: PyRwLock<PyObjectRef>,
177-
step: PyIntRef,
177+
step: PyObjectRef,
178178
}
179179

180180
#[derive(FromArgs)]
@@ -183,7 +183,7 @@ mod decl {
183183
start: OptionalArg<PyObjectRef>,
184184

185185
#[pyarg(positional, optional)]
186-
step: OptionalArg<PyIntRef>,
186+
step: OptionalArg<PyObjectRef>,
187187
}
188188

189189
impl Constructor for PyItertoolsCount {
@@ -194,9 +194,9 @@ mod decl {
194194
Self::Args { start, step }: Self::Args,
195195
vm: &VirtualMachine,
196196
) -> PyResult {
197-
let start: PyObjectRef = start.into_option().unwrap_or_else(|| vm.new_pyobj(0));
198-
let step: PyIntRef = step.into_option().unwrap_or_else(|| vm.new_pyref(1));
199-
if !PyNumber::check(&start, vm) {
197+
let start = start.into_option().unwrap_or_else(|| vm.new_pyobj(0));
198+
let step = step.into_option().unwrap_or_else(|| vm.new_pyobj(1));
199+
if !PyNumber::check(&start, vm) || !PyNumber::check(&step, vm) {
200200
return Err(vm.new_value_error("a number is require".to_owned()));
201201
}
202202

@@ -222,11 +222,11 @@ mod decl {
222222
#[pymethod(magic)]
223223
fn repr(&self, vm: &VirtualMachine) -> PyResult<String> {
224224
let cur = format!("{}", self.cur.read().clone().repr(vm)?);
225-
let step = self.step.as_bigint();
226-
if step.is_one() {
225+
let step = &self.step;
226+
if vm.bool_eq(step, vm.ctx.new_int(1).as_object())? {
227227
return Ok(format!("count({})", cur));
228228
}
229-
Ok(format!("count({}, {})", cur, step))
229+
Ok(format!("count({}, {})", cur, step.repr(vm)?))
230230
}
231231
}
232232
impl IterNextIterable for PyItertoolsCount {}

0 commit comments

Comments
 (0)