Skip to content

Commit

Permalink
Use DB count to get resource counts.
Browse files Browse the repository at this point in the history
Fixes bug 1075369.

During quota check we used to simply retrieve the entire collection
of resources from the database, then counting them in Python. This
patch introduces a specialized _get_collection_count() method, which
instead take advantage of the DB's built-in count capabilities.

In order to take advantage of this, plugins can now implement
get_*_count() methods for their resources. This is used (if present)
by the quota checking function.

Patch incorporates review feedback from Dan W, Alex Xu, Zhongyue Luo,
Edgar Magana, Akihiro Motoki and gongysh.

Change-Id: I87e2d0294e116e8147fed2ee90c9eb0cf1a54362
  • Loading branch information
Juergen Brendel committed Nov 16, 2012
1 parent ddeeee8 commit 513307f
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 28 deletions.
30 changes: 26 additions & 4 deletions quantum/db/db_base_plugin_v2.py
Expand Up @@ -206,11 +206,18 @@ def _apply_filters_to_query(self, query, model, filters):
query = query.filter(column.in_(value))
return query

def _get_collection(self, context, model, dict_func, filters=None,
fields=None):
def _get_collection_query(self, context, model, filters=None):
collection = self._model_query(context, model)
collection = self._apply_filters_to_query(collection, model, filters)
return [dict_func(c, fields) for c in collection.all()]
return collection

def _get_collection(self, context, model, dict_func, filters=None,
fields=None):
query = self._get_collection_query(context, model, filters)
return [dict_func(c, fields) for c in query.all()]

def _get_collection_count(self, context, model, filters=None):
return self._get_collection_query(context, model, filters).count()

@staticmethod
def _generate_mac(context, network_id):
Expand Down Expand Up @@ -952,6 +959,10 @@ def get_networks(self, context, filters=None, fields=None):
self._make_network_dict,
filters=filters, fields=fields)

def get_networks_count(self, context, filters=None):
return self._get_collection_count(context, models_v2.Network,
filters=filters)

def create_subnet_bulk(self, context, subnets):
return self._create_bulk('subnet', context, subnets)

Expand Down Expand Up @@ -1143,6 +1154,10 @@ def get_subnets(self, context, filters=None, fields=None):
self._make_subnet_dict,
filters=filters, fields=fields)

def get_subnets_count(self, context, filters=None):
return self._get_collection_count(context, models_v2.Subnet,
filters=filters)

def create_port_bulk(self, context, ports):
return self._create_bulk('port', context, ports)

Expand Down Expand Up @@ -1274,7 +1289,7 @@ def get_port(self, context, id, fields=None):
port = self._get_port(context, id)
return self._make_port_dict(port, fields)

def get_ports(self, context, filters=None, fields=None):
def _get_ports_query(self, context, filters=None):
Port = models_v2.Port
IPAllocation = models_v2.IPAllocation

Expand All @@ -1294,4 +1309,11 @@ def get_ports(self, context, filters=None, fields=None):
query = query.filter(IPAllocation.subnet_id.in_(subnet_ids))

query = self._apply_filters_to_query(query, Port, filters)
return query

def get_ports(self, context, filters=None, fields=None):
query = self._get_ports_query(context, filters)
return [self._make_port_dict(c, fields) for c in query.all()]

def get_ports_count(self, context, filters=None):
return self._get_ports_query(context, filters).count()
8 changes: 8 additions & 0 deletions quantum/db/l3_db.py
Expand Up @@ -251,6 +251,10 @@ def get_routers(self, context, filters=None, fields=None):
self._make_router_dict,
filters=filters, fields=fields)

def get_routers_count(self, context, filters=None):
return self._get_collection_count(context, Router,
filters=filters)

def _check_for_dup_router_subnet(self, context, router_id,
network_id, subnet_id):
try:
Expand Down Expand Up @@ -615,6 +619,10 @@ def get_floatingips(self, context, filters=None, fields=None):
self._make_floatingip_dict,
filters=filters, fields=fields)

def get_floatingips_count(self, context, filters=None):
return self._get_collection_count(context, FloatingIP,
filters=filters)

def prevent_l3_port_deletion(self, context, port_id):
""" Checks to make sure a port is allowed to be deleted, raising
an exception if this is not the case. This should be called by
Expand Down
8 changes: 8 additions & 0 deletions quantum/db/securitygroups_db.py
Expand Up @@ -150,6 +150,10 @@ def get_security_groups(self, context, filters=None, fields=None):
self._make_security_group_dict,
filters=filters, fields=fields)

def get_security_groups_count(self, context, filters=None):
return self._get_collection_count(context, SecurityGroup,
filters=filters)

def get_security_group(self, context, id, fields=None, tenant_id=None):
"""Tenant id is given to handle the case when we
are creating a security group or security group rule on behalf of
Expand Down Expand Up @@ -384,6 +388,10 @@ def get_security_group_rules(self, context, filters=None, fields=None):
self._make_security_group_rule_dict,
filters=filters, fields=fields)

def get_security_group_rules_count(self, context, filters=None):
return self._get_collection_count(context, SecurityGroupRule,
filters=filters)

def get_security_group_rule(self, context, id, fields=None):
security_group_rule = self._get_security_group_rule(context, id)
return self._make_security_group_rule_dict(security_group_rule, fields)
Expand Down
12 changes: 12 additions & 0 deletions quantum/plugins/cisco/models/virt_phy_sw_v2.py
Expand Up @@ -289,6 +289,10 @@ def get_networks(self, context, filters=None, fields=None):
"""For this model this method will be delegated to vswitch plugin"""
pass

def get_networks_count(self, context, filters=None):
"""For this model this method will be delegated to vswitch plugin"""
pass

def create_port(self, context, port):
"""For this model this method will be delegated to vswitch plugin"""
pass
Expand All @@ -301,6 +305,10 @@ def get_ports(self, context, filters=None, fields=None):
"""For this model this method will be delegated to vswitch plugin"""
pass

def get_ports_count(self, context, filters=None):
"""For this model this method will be delegated to vswitch plugin"""
pass

def update_port(self, context, id, port):
"""For this model this method will be delegated to vswitch plugin"""
pass
Expand Down Expand Up @@ -328,3 +336,7 @@ def delete_subnet(self, context, id, kwargs):
def get_subnets(self, context, filters=None, fields=None):
"""For this model this method will be delegated to vswitch plugin"""
pass

def get_subnets_count(self, context, filters=None):
"""For this model this method will be delegated to vswitch plugin"""
pass
5 changes: 5 additions & 0 deletions quantum/plugins/cisco/network_plugin.py
Expand Up @@ -92,6 +92,11 @@ def __getattr__(self, name):
"""
if hasattr(self._model, name):
return getattr(self._model, name)
else:
# Must make sure we re-raise the error that led us here, since
# otherwise getattr() and even hasattr() doesn't work corretly.
raise AttributeError("'%s' object has no attribute '%s'" %
(self._model, name))

"""
Core API implementation
Expand Down
57 changes: 54 additions & 3 deletions quantum/quantum_plugin_base_v2.py
Expand Up @@ -70,7 +70,7 @@ def get_subnet(self, context, id, fields=None):
def get_subnets(self, context, filters=None, fields=None):
"""
Retrieve a list of subnets. The contents of the list depends on
the identify of the user making the request (as indicated by the
the identity of the user making the request (as indicated by the
context) as well as any filters.
: param context: quantum api request context
: param filters: a dictionary with keys that are valid keys for
Expand All @@ -87,6 +87,23 @@ def get_subnets(self, context, filters=None, fields=None):
"""
pass

@abstractmethod
def get_subnets_count(self, context, filters=None):
"""
Return the number of subnets. The result depends on the identity of
the user making the request (as indicated by the context) as well as
any filters.
: param context: quantum api request context
: param filters: a dictionary with keys that are valid keys for
a network as listed in the RESOURCE_ATTRIBUTE_MAP object
in quantum/api/v2/attributes.py. Values in this dictiontary
are an iterable containing values that will be used for an exact
match comparison for that value. Each result returned by this
function will have matched one of the values for each key in
filters.
"""
pass

@abstractmethod
def delete_subnet(self, context, id):
"""
Expand Down Expand Up @@ -138,7 +155,7 @@ def get_network(self, context, id, fields=None):
def get_networks(self, context, filters=None, fields=None):
"""
Retrieve a list of networks. The contents of the list depends on
the identify of the user making the request (as indicated by the
the identity of the user making the request (as indicated by the
context) as well as any filters.
: param context: quantum api request context
: param filters: a dictionary with keys that are valid keys for
Expand All @@ -155,6 +172,23 @@ def get_networks(self, context, filters=None, fields=None):
"""
pass

@abstractmethod
def get_networks_count(self, context, filters=None):
"""
Return the number of networks. The result depends on the identity
of the user making the request (as indicated by the context) as well
as any filters.
: param context: quantum api request context
: param filters: a dictionary with keys that are valid keys for
a network as listed in the RESOURCE_ATTRIBUTE_MAP object
in quantum/api/v2/attributes.py. Values in this dictiontary
are an iterable containing values that will be used for an exact
match comparison for that value. Each result returned by this
function will have matched one of the values for each key in
filters.
"""
pass

@abstractmethod
def delete_network(self, context, id):
"""
Expand Down Expand Up @@ -206,7 +240,7 @@ def get_port(self, context, id, fields=None):
def get_ports(self, context, filters=None, fields=None):
"""
Retrieve a list of ports. The contents of the list depends on
the identify of the user making the request (as indicated by the
the identity of the user making the request (as indicated by the
context) as well as any filters.
: param context: quantum api request context
: param filters: a dictionary with keys that are valid keys for
Expand All @@ -223,6 +257,23 @@ def get_ports(self, context, filters=None, fields=None):
"""
pass

@abstractmethod
def get_ports_count(self, context, filters=None):
"""
Return the number of ports. The result depends on the identity of
the user making the request (as indicated by the context) as well as
any filters.
: param context: quantum api request context
: param filters: a dictionary with keys that are valid keys for
a network as listed in the RESOURCE_ATTRIBUTE_MAP object
in quantum/api/v2/attributes.py. Values in this dictiontary
are an iterable containing values that will be used for an exact
match comparison for that value. Each result returned by this
function will have matched one of the values for each key in
filters.
"""
pass

@abstractmethod
def delete_port(self, context, id):
"""
Expand Down
16 changes: 13 additions & 3 deletions quantum/quota.py
Expand Up @@ -266,9 +266,19 @@ def resources(self):


def _count_resource(context, plugin, resources, tenant_id):
obj_getter = getattr(plugin, "get_%s" % resources)
obj_list = obj_getter(context, filters={'tenant_id': [tenant_id]})
return len(obj_list) if obj_list else 0
count_getter_name = "get_%s_count" % resources

# Some plugins support a count method for particular resources,
# using a DB's optimized counting features. We try to use that one
# if present. Otherwise just use regular getter to retrieve all objects
# and count in python, allowing older plugins to still be supported
if hasattr(plugin, count_getter_name):
obj_count_getter = getattr(plugin, count_getter_name)
return obj_count_getter(context, filters={'tenant_id': [tenant_id]})
else:
obj_getter = getattr(plugin, "get_%s" % resources)
obj_list = obj_getter(context, filters={'tenant_id': [tenant_id]})
return len(obj_list) if obj_list else 0


resources = []
Expand Down
29 changes: 12 additions & 17 deletions quantum/tests/unit/test_api_v2.py
Expand Up @@ -369,6 +369,7 @@ def test_create(self):

instance = self.plugin.return_value
instance.create_network.return_value = return_value
instance.get_networks_count.return_value = 0

res = self.api.post_json(_get_path('networks'), data)

Expand All @@ -390,6 +391,7 @@ def test_create_use_defaults(self):

instance = self.plugin.return_value
instance.create_network.return_value = return_value
instance.get_networks_count.return_value = 0

res = self.api.post_json(_get_path('networks'), initial_input)

Expand Down Expand Up @@ -423,6 +425,7 @@ def test_create_with_keystone_env(self):

instance = self.plugin.return_value
instance.create_network.return_value = return_value
instance.get_networks_count.return_value = 0

res = self.api.post_json(_get_path('networks'), initial_input,
extra_environ=env)
Expand Down Expand Up @@ -479,6 +482,7 @@ def side_effect(context, network):

instance = self.plugin.return_value
instance.create_network.side_effect = side_effect
instance.get_networks_count.return_value = 0

res = self.api.post_json(_get_path('networks'), data)
self.assertEqual(res.status_int, exc.HTTPCreated.code)
Expand Down Expand Up @@ -525,6 +529,7 @@ def test_create_attr_not_specified(self):

instance = self.plugin.return_value
instance.get_network.return_value = {'tenant_id': unicode(tenant_id)}
instance.get_ports_count.return_value = 1
instance.create_port.return_value = return_value
res = self.api.post_json(_get_path('ports'), initial_input)

Expand All @@ -545,6 +550,7 @@ def test_create_return_extra_attr(self):

instance = self.plugin.return_value
instance.create_network.return_value = return_value
instance.get_networks_count.return_value = 0

res = self.api.post_json(_get_path('networks'), data)

Expand Down Expand Up @@ -699,6 +705,7 @@ def _resource_op_notifier(self, opname, resource, expected_errors=False):
initial_input = {resource: {'name': 'myname'}}
instance = self.plugin.return_value
instance.get_networks.return_value = initial_input
instance.get_networks_count.return_value = 0
expected_code = exc.HTTPCreated.code
with mock.patch.object(notifer_api, 'notify') as mynotifier:
if opname == 'create':
Expand Down Expand Up @@ -742,37 +749,24 @@ def test_network_update_notifer(self):
class QuotaTest(APIv2TestBase):
def test_create_network_quota(self):
cfg.CONF.set_override('quota_network', 1, group='QUOTAS')
net_id = _uuid()
initial_input = {'network': {'name': 'net1', 'tenant_id': _uuid()}}
full_input = {'network': {'admin_state_up': True, 'subnets': []}}
full_input['network'].update(initial_input['network'])

return_value = {'id': net_id, 'status': "ACTIVE"}
return_value.update(full_input['network'])
return_networks = {'networks': [return_value]}
instance = self.plugin.return_value
instance.get_networks.return_value = return_networks
instance.get_networks_count.return_value = 1
res = self.api.post_json(
_get_path('networks'), initial_input, expect_errors=True)
instance.get_networks.assert_called_with(mock.ANY,
filters=mock.ANY)
instance.get_networks_count.assert_called_with(mock.ANY,
filters=mock.ANY)
self.assertTrue("Quota exceeded for resources" in
res.json['QuantumError'])

def test_create_network_quota_without_limit(self):
cfg.CONF.set_override('quota_network', -1, group='QUOTAS')
net_id = _uuid()
initial_input = {'network': {'name': 'net1', 'tenant_id': _uuid()}}
full_input = {'network': {'admin_state_up': True, 'subnets': []}}
full_input['network'].update(initial_input['network'])
return_networks = []
for i in xrange(0, 3):
return_value = {'id': net_id + str(i), 'status': "ACTIVE"}
return_value.update(full_input['network'])
return_networks.append(return_value)
self.assertEquals(3, len(return_networks))
instance = self.plugin.return_value
instance.get_networks.return_value = return_networks
instance.get_networks_count.return_value = 3
res = self.api.post_json(
_get_path('networks'), initial_input)
self.assertEqual(res.status_int, exc.HTTPCreated.code)
Expand Down Expand Up @@ -836,6 +830,7 @@ def test_extended_create(self):

instance = self.plugin.return_value
instance.create_network.return_value = return_value
instance.get_networks_count.return_value = 0

res = self.api.post_json(_get_path('networks'), initial_input)

Expand Down

0 comments on commit 513307f

Please sign in to comment.