diff --git a/docs/changes.rst b/docs/changes.rst index 2fa4d50e9..f294845d4 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -1,5 +1,11 @@ .. _changes: +4.0.0 (2018-01-14) +------------------ + +* Accept multiple keys in :py:meth:`MultiDict.update` and + :py:meth:`CIMultiDict.update` (:pr:`199`) + 3.3.2 (2017-11-02) ------------------ diff --git a/multidict/_multidict.pyx b/multidict/_multidict.pyx index ef1ba83da..1fcff3f6e 100644 --- a/multidict/_multidict.pyx +++ b/multidict/_multidict.pyx @@ -287,6 +287,9 @@ cdef class MultiDict(_Base): cdef _extend(self, tuple args, dict kwargs, name, bint do_add): cdef _Pair item cdef object key + cdef object value + cdef object arg + cdef object i if len(args) > 1: raise TypeError("{} takes at most 1 positional argument" @@ -295,45 +298,17 @@ cdef class MultiDict(_Base): if args: arg = args[0] if isinstance(arg, _Base): - for i in (<_Base>arg)._impl._items: - item = <_Pair>i - key = item._key - value = item._value - if do_add: - self._add(key, value) - else: - self._replace(key, value) - elif hasattr(arg, 'items'): - for i in arg.items(): - if isinstance(i, _Pair): - item = <_Pair>i - key = item._key - value = item._value - else: - key = i[0] - value = i[1] - if do_add: - self._add(key, value) - else: - self._replace(key, value) + if do_add: + self._append_items((<_Base>arg)._impl) + else: + self._update_items((<_Base>arg)._impl) else: - for i in arg: - if isinstance(i, _Pair): - item = <_Pair>i - key = item._key - value = item._value - else: - if not len(i) == 2: - raise TypeError( - "{} takes either dict or list of (key, value) " - "tuples".format(name)) - key = i[0] - value = i[1] - if do_add: - self._add(key, value) - else: - self._replace(key, value) - + if hasattr(arg, 'items'): + arg = arg.items() + if do_add: + self._append_items_seq(arg, name) + else: + self._update_items_seq(arg, name) for key, value in kwargs.items(): if do_add: @@ -341,6 +316,119 @@ cdef class MultiDict(_Base): else: self._replace(key, value) + cdef object _update_items(self, _Impl impl): + cdef _Pair item, item2 + cdef object i + cdef dict used_keys = {} + cdef Py_ssize_t start + cdef Py_ssize_t post + cdef Py_ssize_t size = len(self._impl._items) + cdef Py_hash_t h + + for i in impl._items: + item = <_Pair>i + + start = used_keys.get(item._identity, 0) + for pos in range(start, size): + item2 = <_Pair>(self._impl._items[pos]) + if item2._hash != item._hash: + continue + if item2._identity == item._identity: + used_keys[item._identity] = pos + 1 + item2._key = item._key + item2._value = item._value + break + else: + self._impl._items.append(_Pair.__new__( + _Pair, item._identity, item._key, item._value)) + size += 1 + used_keys[item._identity] = size + + self._post_update(used_keys) + + cdef object _update_items_seq(self, object arg, object name): + cdef _Pair item + cdef object i + cdef object identity + cdef object key + cdef object value + cdef dict used_keys = {} + cdef Py_ssize_t start + cdef Py_ssize_t post + cdef Py_ssize_t size = len(self._impl._items) + cdef Py_hash_t h + for i in arg: + if not len(i) == 2: + raise TypeError( + "{} takes either dict or list of (key, value) " + "tuples".format(name)) + key = _str(i[0]) + value = i[1] + identity = self._title(key) + h = hash(identity) + + start = used_keys.get(identity, 0) + for pos in range(start, size): + item = <_Pair>(self._impl._items[pos]) + if item._hash != h: + continue + if item._identity == identity: + used_keys[identity] = pos + 1 + item._key = key + item._value = value + break + else: + self._impl._items.append(_Pair.__new__( + _Pair, identity, key, value)) + size += 1 + used_keys[identity] = size + + self._post_update(used_keys) + + cdef object _post_update(self, dict used_keys): + cdef Py_ssize_t i = 0 + cdef _Pair item + while i < len(self._impl._items): + item = <_Pair>self._impl._items[i] + pos = used_keys.get(item._identity) + if pos is None: + i += 1 + continue + if i >= pos: + del self._impl._items[i] + else: + i += 1 + + self._impl.incr_version() + + cdef object _append_items(self, _Impl impl): + cdef _Pair item + cdef object i + cdef str key + cdef object value + for i in impl._items: + item = <_Pair>i + key = item._key + value = item._value + self._impl._items.append(_Pair.__new__( + _Pair, self._title(key), key, value)) + self._impl.incr_version() + + cdef object _append_items_seq(self, object arg, object name): + cdef object i + cdef object key + cdef object value + for i in arg: + if not len(i) == 2: + raise TypeError( + "{} takes either dict or list of (key, value) " + "tuples".format(name)) + key = i[0] + value = i[1] + self._impl._items.append(_Pair.__new__( + _Pair, self._title(key), _str(key), value)) + self._impl.incr_version() + cdef _add(self, key, value): self._impl._items.append(_Pair.__new__( _Pair, self._title(key), _str(key), value)) diff --git a/multidict/_multidict_py.py b/multidict/_multidict_py.py index aba485b3c..b698e2589 100644 --- a/multidict/_multidict_py.py +++ b/multidict/_multidict_py.py @@ -187,7 +187,8 @@ class MultiDict(_Base, MutableMultiMapping): def __init__(self, *args, **kwargs): self._impl = _Impl() - self._extend(args, kwargs, self.__class__.__name__, self.add) + self._extend(args, kwargs, self.__class__.__name__, + self._extend_items) def __reduce__(self): return (self.__class__, (list(self.items()),)) @@ -217,7 +218,7 @@ def extend(self, *args, **kwargs): This method must be used instead of update. """ - self._extend(args, kwargs, 'extend', self.add) + self._extend(args, kwargs, 'extend', self._extend_items) def _extend(self, args, kwargs, name, method): if len(args) > 1: @@ -225,26 +226,29 @@ def _extend(self, args, kwargs, name, method): " ({} given)".format(name, len(args))) if args: arg = args[0] - if isinstance(args[0], MultiDictProxy): + if isinstance(args[0], (MultiDict, MultiDictProxy)): items = arg._impl._items - elif isinstance(args[0], MultiDict): - items = arg._impl._items - elif hasattr(arg, 'items'): - items = [(k, k, v) for k, v in arg.items()] else: + if hasattr(arg, 'items'): + arg = arg.items() items = [] for item in arg: if not len(item) == 2: raise TypeError( "{} takes either dict or list of (key, value) " "tuples".format(name)) - items.append((item[0], item[0], item[1])) + items.append((self._title(item[0]), + self._key(item[0]), + item[1])) + + method(items) - for identity, key, value in items: - method(key, value) + method([(self._title(key), key, value) + for key, value in kwargs.items()]) - for key, value in kwargs.items(): - method(key, value) + def _extend_items(self, items): + for identity, key, value in items: + self.add(key, value) def clear(self): """Remove all items from MultiDict.""" @@ -338,7 +342,39 @@ def popitem(self): def update(self, *args, **kwargs): """Update the dictionary from *other*, overwriting existing keys.""" - self._extend(args, kwargs, 'update', self._replace) + self._extend(args, kwargs, 'update', self._update_items) + + def _update_items(self, items): + if not items: + return + used_keys = {} + for identity, key, value in items: + start = used_keys.get(identity, 0) + for i in range(start, len(self._impl._items)): + item = self._impl._items[i] + if item[0] == identity: + used_keys[identity] = i + 1 + self._impl._items[i] = (identity, key, value) + break + else: + self._impl._items.append((identity, key, value)) + used_keys[identity] = len(self._impl._items) + + # drop tails + i = 0 + while i < len(self._impl._items): + item = self._impl._items[i] + identity = item[0] + pos = used_keys.get(identity) + if pos is None: + i += 1 + continue + if i >= pos: + del self._impl._items[i] + else: + i += 1 + + self._impl.incr_version() def _replace(self, key, value): key = self._key(key) diff --git a/tests/test_mutable_multidict.py b/tests/test_mutable_multidict.py index f515d739d..bab9270e9 100644 --- a/tests/test_mutable_multidict.py +++ b/tests/test_mutable_multidict.py @@ -171,16 +171,6 @@ def test_pop_raises(self, cls): assert 'other' in d - def test_update(self, cls): - d = cls() - d.add('key', 'val1') - d.add('key', 'val2') - d.add('key2', 'val3') - - d.update(key='val') - - assert [('key', 'val'), ('key2', 'val3')] == list(d.items()) - def test_replacement_order(self, cls): d = cls() d.add('key1', 'val1') @@ -419,16 +409,6 @@ def test_pop_raises(self, cls): assert 'other' in d - def test_update(self, cls): - d = cls() - d.add('KEY', 'val1') - d.add('key', 'val2') - d.add('key2', 'val3') - - d.update(Key='val') - - assert [('Key', 'val'), ('key2', 'val3')] == list(d.items()) - def test_extend_with_istr(self, cls, istr): us = istr('a') d = cls() @@ -436,16 +416,6 @@ def test_extend_with_istr(self, cls, istr): d.extend([(us, 'val')]) assert [('A', 'val')] == list(d.items()) - def test_update_istr(self, cls, istr): - d = cls() - d.add(istr('KEY'), 'val1') - d.add('key', 'val2') - d.add('key2', 'val3') - - d.update({istr('key'): 'val'}) - - assert [('Key', 'val'), ('key2', 'val3')] == list(d.items()) - def test_copy_istr(self, cls, istr): d = cls({istr('Foo'): 'bar'}) d2 = d.copy() diff --git a/tests/test_update.py b/tests/test_update.py new file mode 100644 index 000000000..78cbc1651 --- /dev/null +++ b/tests/test_update.py @@ -0,0 +1,135 @@ +import pytest + +from multidict._compat import USE_CYTHON + +if USE_CYTHON: + from multidict._multidict import (MultiDict, CIMultiDict) + +from multidict._multidict_py import (MultiDict as PyMultiDict, # noqa: E402 + CIMultiDict as PyCIMultiDict) + + +@pytest.fixture( + params=( + [ + MultiDict, + CIMultiDict, + ] + if USE_CYTHON else + [] + ) + + [ + PyMultiDict, + PyCIMultiDict + ], + ids=( + [ + 'MultiDict', + 'CIMultiDict', + ] + if USE_CYTHON else + [] + ) + + [ + 'PyMultiDict', + 'PyCIMultiDict' + ] +) +def cls(request): + return request.param + + +@pytest.fixture +def md_cls(_multidict): + return _multidict.MultiDict + + +@pytest.fixture +def ci_md_cls(_multidict): + return _multidict.CIMultiDict + + +@pytest.fixture +def istr(_multidict): + return _multidict.istr + + +def test_update_replace(cls): + obj1 = cls([('a', 1), ('b', 2), ('a', 3), ('c', 10)]) + obj2 = cls([('a', 4), ('b', 5), ('a', 6)]) + obj1.update(obj2) + expected = [('a', 4), ('b', 5), ('a', 6), ('c', 10)] + assert list(obj1.items()) == expected + + +def test_update_append(cls): + obj1 = cls([('a', 1), ('b', 2), ('a', 3), ('c', 10)]) + obj2 = cls([('a', 4), ('a', 5), ('a', 6)]) + obj1.update(obj2) + expected = [('a', 4), ('b', 2), ('a', 5), ('c', 10), ('a', 6)] + assert list(obj1.items()) == expected + + +def test_update_remove(cls): + obj1 = cls([('a', 1), ('b', 2), ('a', 3), ('c', 10)]) + obj2 = cls([('a', 4)]) + obj1.update(obj2) + expected = [('a', 4), ('b', 2), ('c', 10)] + assert list(obj1.items()) == expected + + +def test_update_replace_seq(cls): + obj1 = cls([('a', 1), ('b', 2), ('a', 3), ('c', 10)]) + obj2 = [('a', 4), ('b', 5), ('a', 6)] + obj1.update(obj2) + expected = [('a', 4), ('b', 5), ('a', 6), ('c', 10)] + assert list(obj1.items()) == expected + + +def test_update_append_seq(cls): + obj1 = cls([('a', 1), ('b', 2), ('a', 3), ('c', 10)]) + obj2 = [('a', 4), ('a', 5), ('a', 6)] + obj1.update(obj2) + expected = [('a', 4), ('b', 2), ('a', 5), ('c', 10), ('a', 6)] + assert list(obj1.items()) == expected + + +def test_update_remove_seq(cls): + obj1 = cls([('a', 1), ('b', 2), ('a', 3), ('c', 10)]) + obj2 = [('a', 4)] + obj1.update(obj2) + expected = [('a', 4), ('b', 2), ('c', 10)] + assert list(obj1.items()) == expected + + +def test_update_md(md_cls): + d = md_cls() + d.add('key', 'val1') + d.add('key', 'val2') + d.add('key2', 'val3') + + d.update(key='val') + + assert [('key', 'val'), ('key2', 'val3')] == list(d.items()) + + +def test_update_istr_ci_md(ci_md_cls, istr): + d = ci_md_cls() + d.add(istr('KEY'), 'val1') + d.add('key', 'val2') + d.add('key2', 'val3') + + d.update({istr('key'): 'val'}) + + assert [('Key', 'val'), ('key2', 'val3')] == list(d.items()) + + +def test_update_ci_md(ci_md_cls): + d = ci_md_cls() + d.add('KEY', 'val1') + d.add('key', 'val2') + d.add('key2', 'val3') + + d.update(Key='val') + + assert [('Key', 'val'), ('key2', 'val3')] == list(d.items())