diff --git a/cinder/api/openstack/volume/snapshots.py b/cinder/api/openstack/volume/snapshots.py index 3264c3ab2e5..9fe84a7fd9e 100644 --- a/cinder/api/openstack/volume/snapshots.py +++ b/cinder/api/openstack/volume/snapshots.py @@ -129,7 +129,11 @@ def _items(self, req, entity_maker): """Returns a list of snapshots, transformed through entity_maker.""" context = req.environ['cinder.context'] - snapshots = self.volume_api.get_all_snapshots(context) + search_opts = {} + search_opts.update(req.GET) + + snapshots = self.volume_api.get_all_snapshots(context, + search_opts=search_opts) limited_list = common.limited(snapshots, req) res = [entity_maker(context, snapshot) for snapshot in limited_list] return {'snapshots': res} diff --git a/cinder/tests/api/openstack/fakes.py b/cinder/tests/api/openstack/fakes.py index 605d611eacf..b6d2e34e86e 100644 --- a/cinder/tests/api/openstack/fakes.py +++ b/cinder/tests/api/openstack/fakes.py @@ -239,3 +239,29 @@ def stub_volume_get_all(context, search_opts=None): def stub_volume_get_all_by_project(self, context, search_opts=None): return [stub_volume_get(self, context, '1')] + + +def stub_snapshot(id, **kwargs): + snapshot = { + 'id': id, + 'volume_id': 12, + 'status': 'available', + 'volume_size': 100, + 'created_at': None, + 'display_name': 'Default name', + 'display_description': 'Default description', + 'project_id': 'fake' + } + + snapshot.update(kwargs) + return snapshot + + +def stub_snapshot_get_all(self): + return [stub_snapshot(100, project_id='fake'), + stub_snapshot(101, project_id='superfake'), + stub_snapshot(102, project_id='superduperfake')] + + +def stub_snapshot_get_all_by_project(self, context): + return [stub_snapshot(1)] diff --git a/cinder/tests/api/openstack/volume/contrib/test_extended_snapshot_attributes.py b/cinder/tests/api/openstack/volume/contrib/test_extended_snapshot_attributes.py index 56e95e6fa74..43490fbc19f 100644 --- a/cinder/tests/api/openstack/volume/contrib/test_extended_snapshot_attributes.py +++ b/cinder/tests/api/openstack/volume/contrib/test_extended_snapshot_attributes.py @@ -51,7 +51,7 @@ def fake_snapshot_get(self, context, snapshot_id): return param -def fake_snapshot_get_all(self, context): +def fake_snapshot_get_all(self, context, search_opts=None): param = _get_default_snapshot_param() return [param] diff --git a/cinder/tests/api/openstack/volume/test_snapshots.py b/cinder/tests/api/openstack/volume/test_snapshots.py index 7dcdb023982..b224c4819ce 100644 --- a/cinder/tests/api/openstack/volume/test_snapshots.py +++ b/cinder/tests/api/openstack/volume/test_snapshots.py @@ -19,6 +19,7 @@ import webob from cinder.api.openstack.volume import snapshots +from cinder import db from cinder import exception from cinder import flags from cinder.openstack.common import log as logging @@ -67,7 +68,7 @@ def stub_snapshot_get(self, context, snapshot_id): return param -def stub_snapshot_get_all(self, context): +def stub_snapshot_get_all(self, context, search_opts=None): param = _get_default_snapshot_param() return [param] @@ -77,9 +78,10 @@ def setUp(self): super(SnapshotApiTest, self).setUp() self.controller = snapshots.SnapshotsController() - self.stubs.Set(volume.api.API, "get_snapshot", stub_snapshot_get) - self.stubs.Set(volume.api.API, "get_all_snapshots", - stub_snapshot_get_all) + self.stubs.Set(db, 'snapshot_get_all_by_project', + fakes.stub_snapshot_get_all_by_project) + self.stubs.Set(db, 'snapshot_get_all', + fakes.stub_snapshot_get_all) def test_snapshot_create(self): self.stubs.Set(volume.api.API, "create_snapshot", stub_snapshot_create) @@ -117,6 +119,7 @@ def test_snapshot_create_force(self): snapshot['display_description']) def test_snapshot_delete(self): + self.stubs.Set(volume.api.API, "get_snapshot", stub_snapshot_get) self.stubs.Set(volume.api.API, "delete_snapshot", stub_snapshot_delete) snapshot_id = UUID @@ -134,6 +137,7 @@ def test_snapshot_delete_invalid_id(self): snapshot_id) def test_snapshot_show(self): + self.stubs.Set(volume.api.API, "get_snapshot", stub_snapshot_get) req = fakes.HTTPRequest.blank('/v1/snapshots/%s' % UUID) resp_dict = self.controller.show(req, UUID) @@ -149,6 +153,8 @@ def test_snapshot_show_invalid_id(self): snapshot_id) def test_snapshot_detail(self): + self.stubs.Set(volume.api.API, "get_all_snapshots", + stub_snapshot_get_all) req = fakes.HTTPRequest.blank('/v1/snapshots/detail') resp_dict = self.controller.detail(req) @@ -159,6 +165,33 @@ def test_snapshot_detail(self): resp_snapshot = resp_snapshots.pop() self.assertEqual(resp_snapshot['id'], UUID) + def test_admin_list_snapshots_limited_to_project(self): + req = fakes.HTTPRequest.blank('/v1/fake/snapshots', + use_admin_context=True) + res = self.controller.index(req) + + self.assertTrue('snapshots' in res) + self.assertEqual(1, len(res['snapshots'])) + + def test_admin_list_snapshots_all_tenants(self): + req = fakes.HTTPRequest.blank('/v2/fake/snapshots?all_tenants=1', + use_admin_context=True) + res = self.controller.index(req) + self.assertTrue('snapshots' in res) + self.assertEqual(3, len(res['snapshots'])) + + def test_all_tenants_non_admin_gets_all_tenants(self): + req = fakes.HTTPRequest.blank('/v2/fake/snapshots?all_tenants=1') + res = self.controller.index(req) + self.assertTrue('snapshots' in res) + self.assertEqual(1, len(res['snapshots'])) + + def test_non_admin_get_by_project(self): + req = fakes.HTTPRequest.blank('/v2/fake/snapshots') + res = self.controller.index(req) + self.assertTrue('snapshots' in res) + self.assertEqual(1, len(res['snapshots'])) + class SnapshotSerializerTest(test.TestCase): def _verify_snapshot(self, snap, tree): diff --git a/cinder/volume/api.py b/cinder/volume/api.py index 0d83d79b840..80e656f5786 100644 --- a/cinder/volume/api.py +++ b/cinder/volume/api.py @@ -216,11 +216,18 @@ def get_snapshot(self, context, snapshot_id): rv = self.db.snapshot_get(context, snapshot_id) return dict(rv.iteritems()) - def get_all_snapshots(self, context): + def get_all_snapshots(self, context, search_opts=None): check_policy(context, 'get_all_snapshots') - if context.is_admin: + + search_opts = search_opts or {} + + if (context.is_admin and 'all_tenants' in search_opts): + # Need to remove all_tenants to pass the filtering below. + del search_opts['all_tenants'] return self.db.snapshot_get_all(context) - return self.db.snapshot_get_all_by_project(context, context.project_id) + else: + return self.db.snapshot_get_all_by_project(context, + context.project_id) @wrap_check_policy def check_attach(self, context, volume):