Skip to content

Commit

Permalink
Allow additional keys to be added to DBRef instances. PYTHON-133
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike Dirolf committed Jun 21, 2010
1 parent 8b76667 commit 6b0a9cc
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 12 deletions.
35 changes: 32 additions & 3 deletions pymongo/_cbsonmodule.c
Expand Up @@ -1279,10 +1279,39 @@ static PyObject* get_value(const char* buffer, int* position, int type,
PyObject* id = PyDict_GetItemString(value, "$id");
PyObject* collection = PyDict_GetItemString(value, "$ref");
PyObject* database = PyDict_GetItemString(value, "$db");
PyObject* args;

Py_INCREF(id);
PyDict_DelItemString(value, "$id");
Py_INCREF(collection);
PyDict_DelItemString(value, "$ref");

if (database != NULL) {
Py_INCREF(database);
PyDict_DelItemString(value, "$db");
args = Py_BuildValue("(OOO)", collection, id, database);
} else {
args = Py_BuildValue("(OO)", collection, id);
}
if (!args) {
Py_DECREF(id);
Py_DECREF(collection);
if (database != NULL) {
Py_DECREF(database);
}
return NULL;
}

/* This works even if there is no $db since database will be NULL and
the call will be as if there were only two arguments specified. */
value = PyObject_CallFunctionObjArgs(DBRef, collection, id, database, NULL);
value = PyObject_Call(DBRef, args, value);
Py_DECREF(args);
Py_DECREF(id);
Py_DECREF(collection);
if (database != NULL) {
Py_DECREF(database);
}
if (!value) {
return NULL;
}
}

*position += size;
Expand Down
4 changes: 2 additions & 2 deletions pymongo/bson.py
Expand Up @@ -98,8 +98,8 @@ def _get_string(data, as_class):
def _get_object(data, as_class):
(object, data) = _bson_to_dict(data, as_class)
if "$ref" in object:
return (DBRef(object["$ref"], object["$id"],
object.get("$db", None)), data)
return (DBRef(object.pop("$ref"), object.pop("$id"),
object.pop("$db", None), **object), data)
return (object, data)


Expand Down
27 changes: 20 additions & 7 deletions pymongo/dbref.py
Expand Up @@ -21,18 +21,24 @@ class DBRef(object):
"""A reference to a document stored in a Mongo database.
"""

def __init__(self, collection, id, database=None):
def __init__(self, collection, id, database=None, **kwargs):
"""Initialize a new :class:`DBRef`.
Raises :class:`TypeError` if `collection` or `database` is not
an instance of :class:`basestring`. `database` is optional and
allows references to documents to work across databases.
allows references to documents to work across databases. Any
additional keyword arguments will create additional fields in
the resultant embedded document.
:Parameters:
- `collection`: name of the collection the document is stored in
- `id`: the value of the document's ``"_id"`` field
- `database` (optional): name of the database to reference
- `**kwargs` (optional): additional keyword arguments will
create additional, custom fields
.. versionchanged:: 1.7+
Now takes keyword arguments to specify additional fields.
.. versionadded:: 1.1.1
The `database` parameter.
Expand All @@ -46,6 +52,7 @@ def __init__(self, collection, id, database=None):
self.__collection = collection
self.__id = id
self.__database = database
self.__kwargs = kwargs

@property
def collection(self):
Expand All @@ -69,6 +76,9 @@ def database(self):
"""
return self.__database

def __getattr__(self, key):
return self.__kwargs[key]

def as_doc(self):
"""Get the SON document representation of this DBRef.
Expand All @@ -78,22 +88,25 @@ def as_doc(self):
("$id", self.id)])
if self.database is not None:
doc["$db"] = self.database
doc.update(self.__kwargs)
return doc

def __repr__(self):
extra = "".join([", %s=%r" % (k,v) for k,v in self.__kwargs.iteritems()])
if self.database is None:
return "DBRef(%r, %r)" % (self.collection, self.id)
return "DBRef(%r, %r, %r)" % (self.collection, self.id, self.database)
return "DBRef(%r, %r%s)" % (self.collection, self.id, extra)
return "DBRef(%r, %r, %r%s)" % (self.collection, self.id, self.database,
extra)

def __cmp__(self, other):
if isinstance(other, DBRef):
return cmp([self.__database, self.__collection, self.__id],
[other.__database, other.__collection, other.__id])
return cmp([self.__database, self.__collection, self.__id, self.__kwargs],
[other.__database, other.__collection, other.__id, other.__kwargs])
return NotImplemented

def __hash__(self):
"""Get a hash value for this :class:`DBRef`.
.. versionadded:: 1.1
"""
return hash((self.__collection, self.__id, self.__database))
return hash((self.__collection, self.__id, self.__database, self.__kwargs))
2 changes: 2 additions & 0 deletions test/test_bson.py
Expand Up @@ -174,7 +174,9 @@ def helper(dict):
helper(SON([(u'test dst', datetime.datetime(1993, 4, 4, 2))]))
helper({"big float": float(10000000000)})
helper({"ref": DBRef("coll", 5)})
helper({"ref": DBRef("coll", 5, foo="bar", bar=4)})
helper({"ref": DBRef("coll", 5, "foo")})
helper({"ref": DBRef("coll", 5, "foo", foo="bar")})
helper({"ref": Timestamp(1,2)})
helper({"foo": MinKey()})
helper({"foo": MaxKey()})
Expand Down
11 changes: 11 additions & 0 deletions test/test_dbref.py
Expand Up @@ -61,8 +61,12 @@ def test_repr(self):
"DBRef('coll', ObjectId('1234567890abcdef12345678'))")
self.assertEqual(repr(DBRef(u"coll", ObjectId("1234567890abcdef12345678"))),
"DBRef(u'coll', ObjectId('1234567890abcdef12345678'))")
self.assertEqual(repr(DBRef("coll", 5, foo="bar")),
"DBRef('coll', 5, foo='bar')")
self.assertEqual(repr(DBRef("coll", ObjectId("1234567890abcdef12345678"), "foo")),
"DBRef('coll', ObjectId('1234567890abcdef12345678'), 'foo')")
self.assertEqual(repr(DBRef("coll", 5, "baz", foo="bar", baz=4)),
"DBRef('coll', 5, 'baz', foo='bar', baz=4)")

def test_cmp(self):
self.assertEqual(DBRef("coll", ObjectId("1234567890abcdef12345678")),
Expand All @@ -79,6 +83,13 @@ def test_cmp(self):
self.assertNotEqual(DBRef("coll", ObjectId("1234567890abcdef12345678"), "foo"),
DBRef(u"coll", ObjectId("1234567890abcdef12345678"), "bar"))

def test_kwargs(self):
self.assertEqual(DBRef("coll", 5, foo="bar"), DBRef("coll", 5, foo="bar"))
self.assertNotEqual(DBRef("coll", 5, foo="bar"), DBRef("coll", 5))
self.assertNotEqual(DBRef("coll", 5, foo="bar"), DBRef("coll", 5, foo="baz"))
self.assertEqual("bar", DBRef("coll", 5, foo="bar").foo)
self.assertRaises(KeyError, getattr, DBRef("coll", 5, foo="bar"), "bar")


if __name__ == "__main__":
unittest.main()

0 comments on commit 6b0a9cc

Please sign in to comment.