Permalink
Browse files

added count method on has_many and habtm relationships

  • Loading branch information...
1 parent deea7dd commit 40d6314b59a6fa45ee8d2c72badb2cb88e011b4a @steeeveb steeeveb committed with xadhoom Jun 15, 2011
Showing with 79 additions and 0 deletions.
  1. +31 −0 twistar/relationships.py
  2. +48 −0 twistar/tests/test_relationships.py
View
31 twistar/relationships.py
@@ -127,6 +127,18 @@ def get(self, **kwargs):
kwargs['where'] = where
return self.otherklass.find(**kwargs)
+ def count(self, **kwargs):
+ if self.args.has_key('as'):
+ w = "%s_id = ? AND %s_type = ?" % (self.args['as'], self.args['as'])
+ where = [w, self.inst.id, self.thisclass.__name__]
+ else:
+ where = ["%s = ?" % self.thisname, self.inst.id]
+
+ if kwargs.has_key('where'):
+ kwargs['where'] = joinWheres(where, kwargs['where'])
+ else:
+ kwargs['where'] = where
+ return self.otherklass.count(**kwargs)
def _set_polymorphic(self, others):
ds = []
@@ -258,6 +270,25 @@ def _get(rows):
where = ["%s = ?" % self.thisname, self.inst.id]
return self.dbconfig.select(tablename, where=where).addCallback(_get)
+ def count(self, **kwargs):
+ def _get(rows):
+ if len(rows) == 0:
+ return defer.succeed(0)
+ if not kwargs.has_key('where'):
+ return defer.succeed(len(rows))
+ ids = [str(row[self.othername]) for row in rows]
+ where = ["id IN (%s)" % ",".join(ids)]
+ if kwargs.has_key('where'):
+ kwargs['where'] = joinWheres(where, kwargs['where'])
+ else:
+ kwargs['where'] = where
+ kwargs['select'] = 'count(*)'
+ d = self.dbconfig.select(self.otherklass.tablename(), **kwargs)
+ return d.addCallback(lambda x: x[0]['count(*)'])
+
+ tablename = self.tablename()
+ where = ["%s = ?" % self.thisname, self.inst.id]
+ return self.dbconfig.select(tablename, where=where).addCallback(_get)
def _set(self, _, others):
args = []
View
48 twistar/tests/test_relationships.py
@@ -108,6 +108,16 @@ def test_has_many(self):
picids = [pic.id for pic in pics]
self.assertEqual(ids, picids)
+ @inlineCallbacks
+ def test_has_many_count(self):
+ # First, make a few pics
+ ids = [self.picture.id]
+ for _ in range(3):
+ pic = yield Picture(user_id=self.user.id).save()
+ ids.append(pic.id)
+
+ totalnum = yield self.user.pictures.count()
+ self.assertEqual(totalnum, 4)
@inlineCallbacks
def test_has_many_get_with_args(self):
@@ -121,6 +131,16 @@ def test_has_many_get_with_args(self):
self.assertEqual(len(pics),1)
self.assertEqual(pics[0].name,'a pic')
+ @inlineCallbacks
+ def test_has_many_count_with_args(self):
+ # First, make a few pics
+ ids = [self.picture.id]
+ for _ in range(3):
+ pic = yield Picture(user_id=self.user.id).save()
+ ids.append(pic.id)
+
+ picsnum = yield self.user.pictures.count(where=['name = ?','a pic'])
+ self.assertEqual(picsnum,1)
@inlineCallbacks
def test_set_has_many(self):
@@ -198,6 +218,20 @@ def test_habtm(self):
newcolorids = [color.id for color in newcolors]
self.assertEqual(newcolorids, colorids)
+ @inlineCallbacks
+ def test_habtm_count(self):
+ color = yield FavoriteColor(name="red").save()
+ colors = [self.favcolor, color]
+ colorids = [color.id for color in colors]
+ yield FavoriteColor(name="green").save()
+
+ args = {'user_id': self.user.id, 'favorite_color_id': colors[0].id}
+ yield self.config.insert('favorite_colors_users', args)
+ args = {'user_id': self.user.id, 'favorite_color_id': colors[1].id}
+ yield self.config.insert('favorite_colors_users', args)
+
+ newcolorsnum = yield self.user.favorite_colors.count()
+ self.assertEqual(newcolorsnum, 2)
@inlineCallbacks
def test_habtm_get_with_args(self):
@@ -213,6 +247,20 @@ def test_habtm_get_with_args(self):
newcolor = yield self.user.favorite_colors.get(where=['name = ?','red'], limit=1)
self.assertEqual(newcolor.id, color.id)
+ @inlineCallbacks
+ def test_habtm_count_with_args(self):
+ color = yield FavoriteColor(name="red").save()
+ colors = [self.favcolor, color]
+ colorids = [color.id for color in colors]
+
+ args = {'user_id': self.user.id, 'favorite_color_id': colors[0].id}
+ yield self.config.insert('favorite_colors_users', args)
+ args = {'user_id': self.user.id, 'favorite_color_id': colors[1].id}
+ yield self.config.insert('favorite_colors_users', args)
+
+ newcolorsnum = yield self.user.favorite_colors.count(where=['name = ?','red'])
+ self.assertEqual(newcolorsnum, 1)
+
@inlineCallbacks
def test_set_habtm(self):

0 comments on commit 40d6314

Please sign in to comment.