Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 150 additions & 2 deletions Lib/test/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,66 @@ def test_filter_pickle(self):
f2 = filter(filter_char, "abcdeabcde")
self.check_iter_pickle(f1, list(f2), proto)

def test_zip_pickle_strict(self):
a = (1, 2, 3)
b = (4, 5, 6)
t = [(1, 4), (2, 5), (3, 6)]
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
z1 = zip(a, b, strict=True)
self.check_iter_pickle(z1, t, proto)

def test_zip_pickle_strict_fail(self):
a = (1, 2, 3)
b = (4, 5, 6, 7)
t = [(1, 4), (2, 5), (3, 6)]
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
z1 = zip(a, b, strict=True)
z2 = pickle.loads(pickle.dumps(z1, proto))
self.assertEqual(self.iter_error(z1, ValueError), t)
self.assertEqual(self.iter_error(z2, ValueError), t)

def test_zip_pickle_stability(self):
# Pickles of zip((1, 2, 3), (4, 5, 6)) dumped from 3.9:
pickles = [
b'citertools\nizip\np0\n(c__builtin__\niter\np1\n((I1\nI2\nI3\ntp2\ntp3\nRp4\nI0\nbg1\n((I4\nI5\nI6\ntp5\ntp6\nRp7\nI0\nbtp8\nRp9\n.',
b'citertools\nizip\nq\x00(c__builtin__\niter\nq\x01((K\x01K\x02K\x03tq\x02tq\x03Rq\x04K\x00bh\x01((K\x04K\x05K\x06tq\x05tq\x06Rq\x07K\x00btq\x08Rq\t.',
b'\x80\x02citertools\nizip\nq\x00c__builtin__\niter\nq\x01K\x01K\x02K\x03\x87q\x02\x85q\x03Rq\x04K\x00bh\x01K\x04K\x05K\x06\x87q\x05\x85q\x06Rq\x07K\x00b\x86q\x08Rq\t.',
b'\x80\x03cbuiltins\nzip\nq\x00cbuiltins\niter\nq\x01K\x01K\x02K\x03\x87q\x02\x85q\x03Rq\x04K\x00bh\x01K\x04K\x05K\x06\x87q\x05\x85q\x06Rq\x07K\x00b\x86q\x08Rq\t.',
b'\x80\x04\x95L\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x94\x8c\x03zip\x94\x93\x94\x8c\x08builtins\x94\x8c\x04iter\x94\x93\x94K\x01K\x02K\x03\x87\x94\x85\x94R\x94K\x00bh\x05K\x04K\x05K\x06\x87\x94\x85\x94R\x94K\x00b\x86\x94R\x94.',
b'\x80\x05\x95L\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x94\x8c\x03zip\x94\x93\x94\x8c\x08builtins\x94\x8c\x04iter\x94\x93\x94K\x01K\x02K\x03\x87\x94\x85\x94R\x94K\x00bh\x05K\x04K\x05K\x06\x87\x94\x85\x94R\x94K\x00b\x86\x94R\x94.',
]
for protocol, dump in enumerate(pickles):
z1 = zip((1, 2, 3), (4, 5, 6))
z2 = zip((1, 2, 3), (4, 5, 6), strict=False)
z3 = pickle.loads(dump)
l3 = list(z3)
self.assertEqual(type(z3), zip)
self.assertEqual(pickle.dumps(z1, protocol), dump)
self.assertEqual(pickle.dumps(z2, protocol), dump)
self.assertEqual(list(z1), l3)
self.assertEqual(list(z2), l3)

def test_zip_pickle_strict_stability(self):
# Pickles of zip((1, 2, 3), (4, 5), strict=True) dumped from 3.10:
pickles = [
b'citertools\nizip\np0\n(c__builtin__\niter\np1\n((I1\nI2\nI3\ntp2\ntp3\nRp4\nI0\nbg1\n((I4\nI5\ntp5\ntp6\nRp7\nI0\nbtp8\nRp9\nI01\nb.',
b'citertools\nizip\nq\x00(c__builtin__\niter\nq\x01((K\x01K\x02K\x03tq\x02tq\x03Rq\x04K\x00bh\x01((K\x04K\x05tq\x05tq\x06Rq\x07K\x00btq\x08Rq\tI01\nb.',
b'\x80\x02citertools\nizip\nq\x00c__builtin__\niter\nq\x01K\x01K\x02K\x03\x87q\x02\x85q\x03Rq\x04K\x00bh\x01K\x04K\x05\x86q\x05\x85q\x06Rq\x07K\x00b\x86q\x08Rq\t\x88b.',
b'\x80\x03cbuiltins\nzip\nq\x00cbuiltins\niter\nq\x01K\x01K\x02K\x03\x87q\x02\x85q\x03Rq\x04K\x00bh\x01K\x04K\x05\x86q\x05\x85q\x06Rq\x07K\x00b\x86q\x08Rq\t\x88b.',
b'\x80\x04\x95L\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x94\x8c\x03zip\x94\x93\x94\x8c\x08builtins\x94\x8c\x04iter\x94\x93\x94K\x01K\x02K\x03\x87\x94\x85\x94R\x94K\x00bh\x05K\x04K\x05\x86\x94\x85\x94R\x94K\x00b\x86\x94R\x94\x88b.',
b'\x80\x05\x95L\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x94\x8c\x03zip\x94\x93\x94\x8c\x08builtins\x94\x8c\x04iter\x94\x93\x94K\x01K\x02K\x03\x87\x94\x85\x94R\x94K\x00bh\x05K\x04K\x05\x86\x94\x85\x94R\x94K\x00b\x86\x94R\x94\x88b.',
]
a = (1, 2, 3)
b = (4, 5)
t = [(1, 4), (2, 5)]
for protocol, dump in enumerate(pickles):
z1 = zip(a, b, strict=True)
z2 = pickle.loads(dump)
self.assertEqual(pickle.dumps(z1, protocol), dump)
self.assertEqual(type(z2), zip)
self.assertEqual(self.iter_error(z1, ValueError), t)
self.assertEqual(self.iter_error(z2, ValueError), t)

def test_getattr(self):
self.assertTrue(getattr(sys, 'stdout') is sys.stdout)
self.assertRaises(TypeError, getattr, sys, 1)
Expand Down Expand Up @@ -1384,6 +1444,14 @@ def test_vars(self):
self.assertRaises(TypeError, vars, 42)
self.assertEqual(vars(self.C_get_vars()), {'a':2})

def iter_error(self, iterable, error):
"""Collect `iterable` into a list, catching an expected `error`."""
items = []
with self.assertRaises(error):
for item in iterable:
items.append(item)
return items

def test_zip(self):
a = (1, 2, 3)
b = (4, 5, 6)
Expand Down Expand Up @@ -1428,8 +1496,6 @@ def __getitem__(self, i):
return i
self.assertRaises(ValueError, list, zip(BadSeq(), BadSeq()))

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_zip_pickle(self):
a = (1, 2, 3)
b = (4, 5, 6)
Expand All @@ -1438,6 +1504,88 @@ def test_zip_pickle(self):
z1 = zip(a, b)
self.check_iter_pickle(z1, t, proto)

def test_zip_strict(self):
self.assertEqual(tuple(zip((1, 2, 3), 'abc', strict=True)),
((1, 'a'), (2, 'b'), (3, 'c')))
self.assertRaises(ValueError, tuple,
zip((1, 2, 3, 4), 'abc', strict=True))
self.assertRaises(ValueError, tuple,
zip((1, 2), 'abc', strict=True))
self.assertRaises(ValueError, tuple,
zip((1, 2), (1, 2), 'abc', strict=True))

def test_zip_strict_iterators(self):
x = iter(range(5))
y = [0]
z = iter(range(5))
self.assertRaises(ValueError, list,
(zip(x, y, z, strict=True)))
self.assertEqual(next(x), 2)
self.assertEqual(next(z), 1)

def test_zip_strict_error_handling(self):

class Error(Exception):
pass

class Iter:
def __init__(self, size):
self.size = size
def __iter__(self):
return self
def __next__(self):
self.size -= 1
if self.size < 0:
raise Error
return self.size

l1 = self.iter_error(zip("AB", Iter(1), strict=True), Error)
self.assertEqual(l1, [("A", 0)])
l2 = self.iter_error(zip("AB", Iter(2), "A", strict=True), ValueError)
self.assertEqual(l2, [("A", 1, "A")])
l3 = self.iter_error(zip("AB", Iter(2), "ABC", strict=True), Error)
self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")])
l4 = self.iter_error(zip("AB", Iter(3), strict=True), ValueError)
self.assertEqual(l4, [("A", 2), ("B", 1)])
l5 = self.iter_error(zip(Iter(1), "AB", strict=True), Error)
self.assertEqual(l5, [(0, "A")])
l6 = self.iter_error(zip(Iter(2), "A", strict=True), ValueError)
self.assertEqual(l6, [(1, "A")])
l7 = self.iter_error(zip(Iter(2), "ABC", strict=True), Error)
self.assertEqual(l7, [(1, "A"), (0, "B")])
l8 = self.iter_error(zip(Iter(3), "AB", strict=True), ValueError)
self.assertEqual(l8, [(2, "A"), (1, "B")])

def test_zip_strict_error_handling_stopiteration(self):

class Iter:
def __init__(self, size):
self.size = size
def __iter__(self):
return self
def __next__(self):
self.size -= 1
if self.size < 0:
raise StopIteration
return self.size

l1 = self.iter_error(zip("AB", Iter(1), strict=True), ValueError)
self.assertEqual(l1, [("A", 0)])
l2 = self.iter_error(zip("AB", Iter(2), "A", strict=True), ValueError)
self.assertEqual(l2, [("A", 1, "A")])
l3 = self.iter_error(zip("AB", Iter(2), "ABC", strict=True), ValueError)
self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")])
l4 = self.iter_error(zip("AB", Iter(3), strict=True), ValueError)
self.assertEqual(l4, [("A", 2), ("B", 1)])
l5 = self.iter_error(zip(Iter(1), "AB", strict=True), ValueError)
self.assertEqual(l5, [(0, "A")])
l6 = self.iter_error(zip(Iter(2), "A", strict=True), ValueError)
self.assertEqual(l6, [(1, "A")])
l7 = self.iter_error(zip(Iter(2), "ABC", strict=True), ValueError)
self.assertEqual(l7, [(1, "A"), (0, "B")])
l8 = self.iter_error(zip(Iter(3), "AB", strict=True), ValueError)
self.assertEqual(l8, [(2, "A"), (1, "B")])

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_format(self):
Expand Down
76 changes: 68 additions & 8 deletions vm/src/builtins/zip.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
use super::PyTypeRef;
use crate::{
function::PosArgs,
builtins::IntoPyBool,
function::{IntoPyObject, OptionalArg, PosArgs},
protocol::{PyIter, PyIterReturn},
slots::{IteratorIterable, SlotConstructor, SlotIterator},
PyClassImpl, PyContext, PyRef, PyResult, PyValue, VirtualMachine,
PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol,
VirtualMachine,
};
use rustpython_common::atomic::{self, PyAtomic, Radium};

#[pyclass(module = false, name = "zip")]
#[derive(Debug)]
pub struct PyZip {
iterators: Vec<PyIter>,
strict: PyAtomic<bool>,
}

impl PyValue for PyZip {
Expand All @@ -18,17 +22,49 @@ impl PyValue for PyZip {
}
}

#[derive(FromArgs)]
pub struct PyZipNewArgs {
#[pyarg(named, optional)]
strict: OptionalArg<bool>,
}

impl SlotConstructor for PyZip {
type Args = PosArgs<PyIter>;
type Args = (PosArgs<PyIter>, PyZipNewArgs);

fn py_new(cls: PyTypeRef, iterators: Self::Args, vm: &VirtualMachine) -> PyResult {
fn py_new(cls: PyTypeRef, (iterators, args): Self::Args, vm: &VirtualMachine) -> PyResult {
let iterators = iterators.into_vec();
PyZip { iterators }.into_pyresult_with_type(vm, cls)
let strict = Radium::new(args.strict.unwrap_or(false));
PyZip { iterators, strict }.into_pyresult_with_type(vm, cls)
}
}

#[pyimpl(with(SlotIterator, SlotConstructor), flags(BASETYPE))]
impl PyZip {}
impl PyZip {
#[pymethod(magic)]
fn reduce(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
let cls = zelf.clone_class().into_pyobject(vm);
let iterators = zelf
.iterators
.iter()
.map(|obj| obj.clone().into_object())
.collect::<Vec<_>>();
let tuple_iter = vm.ctx.new_tuple(iterators);
Ok(if zelf.strict.load(atomic::Ordering::Acquire) {
vm.ctx
.new_tuple(vec![cls, tuple_iter, vm.ctx.new_bool(true)])
} else {
vm.ctx.new_tuple(vec![cls, tuple_iter])
})
}

#[pymethod(magic)]
fn setstate(zelf: PyRef<Self>, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
if let Ok(obj) = IntoPyBool::try_from_object(vm, state) {
zelf.strict.store(obj.to_bool(), atomic::Ordering::Release);
}
Ok(())
}
}

impl IteratorIterable for PyZip {}
impl SlotIterator for PyZip {
Expand All @@ -37,10 +73,34 @@ impl SlotIterator for PyZip {
return Ok(PyIterReturn::StopIteration(None));
}
let mut next_objs = Vec::new();
for iterator in zelf.iterators.iter() {
for (idx, iterator) in zelf.iterators.iter().enumerate() {
let item = match iterator.next(vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)),
PyIterReturn::StopIteration(v) => {
if zelf.strict.load(atomic::Ordering::Acquire) {
if idx > 0 {
let plural = if idx == 1 { " " } else { "s 1-" };
return Err(vm.new_value_error(format!(
"zip() argument {} is shorter than argument{}{}",
idx + 1,
plural,
idx
)));
}
for (idx, iterator) in zelf.iterators[1..].iter().enumerate() {
if let PyIterReturn::Return(_obj) = iterator.next(vm)? {
let plural = if idx == 0 { " " } else { "s 1-" };
return Err(vm.new_value_error(format!(
"zip() argument {} is longer than argument{}{}",
idx + 2,
plural,
idx + 1
)));
}
}
}
return Ok(PyIterReturn::StopIteration(v));
}
};
next_objs.push(item);
}
Expand Down