Skip to content

Commit 0c08d30

Browse files
committed
Add rich comparison for dict_iterator
Modified inner_eq to inner_cmp so that dict_iterator can use and added rich comparison for dict_iterator.
1 parent 3d88545 commit 0c08d30

File tree

2 files changed

+150
-26
lines changed

2 files changed

+150
-26
lines changed

Lib/test/test_dict.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,6 @@ def __hash__(self):
582582
with self.assertRaises(Exc):
583583
d1 == d2
584584

585-
@unittest.skip("TODO: RUSTPYTHON")
586585
def test_keys_contained(self):
587586
self.helper_keys_contained(lambda x: x.keys())
588587
self.helper_keys_contained(lambda x: x.items())
@@ -631,8 +630,6 @@ def helper_keys_contained(self, fn):
631630
self.assertTrue(larger != larger3)
632631
self.assertFalse(larger == larger3)
633632

634-
# TODO: RUSTPYTHON
635-
@unittest.expectedFailure
636633
def test_errors_in_view_containment_check(self):
637634
class C:
638635
def __eq__(self, other):

vm/src/obj/objdict.rs

Lines changed: 150 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,21 @@ use std::fmt;
33
use crossbeam_utils::atomic::AtomicCell;
44

55
use super::objiter;
6+
use super::objset::PySet;
67
use super::objstr;
78
use super::objtype::{self, PyClassRef};
89
use crate::dictdatatype::{self, DictKey};
910
use crate::exceptions::PyBaseExceptionRef;
1011
use crate::function::{KwArgs, OptionalArg, PyFuncArgs};
1112
use crate::pyobject::{
12-
BorrowValue, IdProtocol, IntoPyObject, ItemProtocol, PyAttributes, PyClassImpl, PyContext,
13-
PyIterable, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol,
13+
BorrowValue, IdProtocol, IntoPyObject, ItemProtocol, PyArithmaticValue, PyAttributes,
14+
PyClassImpl, PyComparisonValue, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue,
15+
TryFromObject, TypeProtocol,
1416
};
1517
use crate::vm::{ReprGuard, VirtualMachine};
1618

1719
use std::mem::size_of;
20+
use PyArithmaticValue::{Implemented, NotImplemented};
1821

1922
pub type DictContentType = dictdatatype::Dict;
2023

@@ -146,46 +149,65 @@ impl PyDict {
146149
!self.entries.is_empty()
147150
}
148151

149-
fn inner_eq(zelf: PyRef<Self>, other: &PyDict, vm: &VirtualMachine) -> PyResult<bool> {
150-
if other.entries.len() != zelf.entries.len() {
151-
return Ok(false);
152+
fn inner_cmp(
153+
zelf: PyRef<Self>,
154+
other: PyDictRef,
155+
size_func: fn(usize, usize) -> bool,
156+
item: bool,
157+
vm: &VirtualMachine,
158+
) -> PyResult<PyComparisonValue> {
159+
if size_func(zelf.len(), other.len()) {
160+
return Ok(Implemented(false));
152161
}
153-
for (k, v1) in zelf {
154-
match other.entries.get(vm, &k)? {
162+
let (zelf, other) = if zelf.len() < other.len() {
163+
(other, zelf)
164+
} else {
165+
(zelf, other)
166+
};
167+
for (k, v1) in other {
168+
match zelf.get_item_option(k, vm)? {
155169
Some(v2) => {
156170
if v1.is(&v2) {
157171
continue;
158172
}
159-
if !vm.bool_eq(v1, v2)? {
160-
return Ok(false);
173+
if item && !vm.bool_eq(v1, v2)? {
174+
return Ok(Implemented(false));
161175
}
162176
}
163177
None => {
164-
return Ok(false);
178+
return Ok(Implemented(false));
165179
}
166180
}
167181
}
168-
Ok(true)
182+
Ok(Implemented(true))
169183
}
170184

171185
#[pymethod(magic)]
172-
fn eq(zelf: PyRef<Self>, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
173-
if let Some(other) = other.payload::<PyDict>() {
174-
let eq = Self::inner_eq(zelf, other, vm)?;
175-
Ok(vm.ctx.new_bool(eq))
186+
fn eq(
187+
zelf: PyRef<Self>,
188+
other: PyObjectRef,
189+
vm: &VirtualMachine,
190+
) -> PyResult<PyComparisonValue> {
191+
if let Ok(other) = other.downcast::<PyDict>() {
192+
Self::inner_cmp(
193+
zelf,
194+
other,
195+
|zelf: usize, other: usize| -> bool { zelf != other },
196+
true,
197+
vm,
198+
)
176199
} else {
177-
Ok(vm.ctx.not_implemented())
200+
Ok(NotImplemented)
178201
}
179202
}
180203

181204
#[pymethod(magic)]
182-
fn ne(zelf: PyRef<Self>, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
183-
if let Some(other) = other.payload::<PyDict>() {
184-
let neq = !Self::inner_eq(zelf, other, vm)?;
185-
Ok(vm.ctx.new_bool(neq))
186-
} else {
187-
Ok(vm.ctx.not_implemented())
188-
}
205+
fn ne(
206+
zelf: PyRef<Self>,
207+
other: PyObjectRef,
208+
vm: &VirtualMachine,
209+
) -> PyResult<PyComparisonValue> {
210+
Ok(Self::eq(zelf, other, vm)?.map(|v| !v))
189211
}
190212

191213
#[pymethod(magic)]
@@ -611,6 +633,111 @@ macro_rules! dict_iterator {
611633
fn reversed(&self) -> $reverse_iter_name {
612634
$reverse_iter_name::new(self.dict.clone())
613635
}
636+
637+
fn cmp(
638+
zelf: PyRef<Self>,
639+
other: PyObjectRef,
640+
size_func: fn(usize, usize) -> bool,
641+
vm: &VirtualMachine,
642+
) -> PyResult<PyComparisonValue> {
643+
match_class!(match other {
644+
dictview @ Self => {
645+
PyDict::inner_cmp(
646+
zelf.dict.clone(),
647+
dictview.dict.clone(),
648+
size_func,
649+
!zelf.class().is(&vm.ctx.types.dict_keys_type),
650+
vm,
651+
)
652+
}
653+
_set @ PySet => {
654+
// TODO: Implement comparison for set
655+
Ok(NotImplemented)
656+
}
657+
_ => {
658+
Ok(NotImplemented)
659+
}
660+
})
661+
}
662+
663+
#[pymethod(name = "__eq__")]
664+
fn eq(
665+
zelf: PyRef<Self>,
666+
other: PyObjectRef,
667+
vm: &VirtualMachine,
668+
) -> PyResult<PyComparisonValue> {
669+
Self::cmp(
670+
zelf,
671+
other,
672+
|zelf: usize, other: usize| -> bool { zelf != other },
673+
vm,
674+
)
675+
}
676+
677+
#[pymethod(name = "__ne__")]
678+
fn ne(
679+
zelf: PyRef<Self>,
680+
other: PyObjectRef,
681+
vm: &VirtualMachine,
682+
) -> PyResult<PyComparisonValue> {
683+
Ok(Self::eq(zelf, other, vm)?.map(|v| !v))
684+
}
685+
686+
#[pymethod(name = "__lt__")]
687+
fn lt(
688+
zelf: PyRef<Self>,
689+
other: PyObjectRef,
690+
vm: &VirtualMachine,
691+
) -> PyResult<PyComparisonValue> {
692+
Self::cmp(
693+
zelf,
694+
other,
695+
|zelf: usize, other: usize| -> bool { zelf >= other },
696+
vm,
697+
)
698+
}
699+
700+
#[pymethod(name = "__le__")]
701+
fn le(
702+
zelf: PyRef<Self>,
703+
other: PyObjectRef,
704+
vm: &VirtualMachine,
705+
) -> PyResult<PyComparisonValue> {
706+
Self::cmp(
707+
zelf,
708+
other,
709+
|zelf: usize, other: usize| -> bool { zelf > other },
710+
vm,
711+
)
712+
}
713+
714+
#[pymethod(name = "__gt__")]
715+
fn gt(
716+
zelf: PyRef<Self>,
717+
other: PyObjectRef,
718+
vm: &VirtualMachine,
719+
) -> PyResult<PyComparisonValue> {
720+
Self::cmp(
721+
zelf,
722+
other,
723+
|zelf: usize, other: usize| -> bool { zelf <= other },
724+
vm,
725+
)
726+
}
727+
728+
#[pymethod(name = "__ge__")]
729+
fn ge(
730+
zelf: PyRef<Self>,
731+
other: PyObjectRef,
732+
vm: &VirtualMachine,
733+
) -> PyResult<PyComparisonValue> {
734+
Self::cmp(
735+
zelf,
736+
other,
737+
|zelf: usize, other: usize| -> bool { zelf < other },
738+
vm,
739+
)
740+
}
614741
}
615742

616743
impl PyValue for $name {

0 commit comments

Comments
 (0)