diff --git a/README.md b/README.md
index f344b314..e819bb80 100644
--- a/README.md
+++ b/README.md
@@ -45,3 +45,22 @@ This is the oauth flow we are using to authenticate users with Signonotron2
 - **GET** (to signonotron) `/user.json` uses access token to get user data and see if they have permissions to sign in to backdrop
 4. User is now signed in
 
+## Requesting data
+
+Requests return a JSON object containing a `data` array.
+
+`GET /bucket_name` returns the contents of the bucket as an array;
+each element is an object.
+
+`GET /bucket_name?collect=score&group_by=name` returns an array in which
+each element is an object containing a `name` value, a `score` array with
+the scores for that name, and a `_count` value with the number of scores.
+
+`GET /bucket_name?filter_by=name:Foo` returns all elements whose `name` equals "Foo".
+
+Other parameters:
+
+- `start_at` and `end_at` (timestamps in `YYYY-MM-DDTHH:MM:SS+HH:MM` format)
+- `period` (`"week"` or `"month"`)
+- `sort_by` (field to sort by)
+- `limit` (maximum number of results)
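The hunks that follow extend `collect` so that each field can carry a method (`sum`, `count`, `set`, `mean`) using a `collect=field:method` form. A minimal sketch of a client exercising this interface; the host, port, bucket name and field names are hypothetical:

```python
# Hypothetical client for the read API described above; the URL, bucket
# and field names are illustrative, not taken from a real deployment.
import requests

response = requests.get(
    "http://localhost:3038/bucket_name",
    params={
        "group_by": "name",
        # Repeated collect params; the ":sum" / ":mean" suffixes use the
        # collection methods introduced later in this diff.
        "collect": ["score:sum", "score:mean"],
    },
)

# Each element of the returned `data` array carries one key per
# field:method pair.
for group in response.json()["data"]:
    print("{0}: sum={1} mean={2}".format(
        group["name"], group["score:sum"], group["score:mean"]))
```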
diff --git a/backdrop/core/database.py b/backdrop/core/database.py
index b49fb070..888110b7 100644
--- a/backdrop/core/database.py
+++ b/backdrop/core/database.py
@@ -49,31 +49,32 @@ def _ignore_docs_without_grouping_keys(self, keys, query):
             query[key] = {"$ne": None}
         return query
 
-    def group(self, keys, query, collect):
+    def group(self, keys, query, collect_fields):
         return self._collection.group(
             key=keys,
             condition=self._ignore_docs_without_grouping_keys(keys, query),
-            initial=self._build_accumulator_initial_state(collect),
-            reduce=self._build_reducer_function(collect)
+            initial=self._build_accumulator_initial_state(collect_fields),
+            reduce=self._build_reducer_function(collect_fields)
         )
 
-    def _build_collector_code(self, collect):
+    def _build_collector_code(self, collect_fields):
         template = "if (current.{c} !== undefined) " \
                    "{{ previous.{c}.push(current.{c}); }}"
-        code = [template.format(c=collect_me) for collect_me in collect]
+        code = [template.format(c=collect_field)
+                for collect_field in collect_fields]
         return "\n".join(code)
 
-    def _build_accumulator_initial_state(self, collect):
+    def _build_accumulator_initial_state(self, collect_fields):
         initial = {'_count': 0}
-        for collect_me in collect:
-            initial.update({collect_me: []})
+        for collect_field in collect_fields:
+            initial.update({collect_field: []})
         return initial
 
-    def _build_reducer_function(self, collect):
+    def _build_reducer_function(self, collect_fields):
         reducer_skeleton = "function (current, previous)" + \
                            "{{ previous._count++; {collectors} }}"
         reducer_code = reducer_skeleton.format(
-            collectors=self._build_collector_code(collect)
+            collectors=self._build_collector_code(collect_fields)
         )
         reducer = Code(reducer_code)
         return reducer
@@ -143,7 +144,8 @@ def _require_keys_in_query(self, keys, query):
         return query
 
     def _group(self, keys, query, sort=None, limit=None, collect=None):
-        results = self._mongo.group(keys, query, collect)
+        collect_fields = unique_collect_fields(collect)
+        results = self._mongo.group(keys, query, list(collect_fields))
 
         results = nested_merge(keys, collect, results)
 
@@ -173,7 +175,7 @@ class InvalidSortError(ValueError):
 
 def extract_collected_values(collect, result):
     collected = {}
-    for collect_field in collect:
+    for collect_field in unique_collect_fields(collect):
         collected[collect_field] = result.pop(collect_field)
     return collected, result
 
@@ -181,14 +183,52 @@ def extract_collected_values(collect, result):
 def insert_collected_values(collected, group):
     for collect_field in collected.keys():
         if collect_field not in group:
-            group[collect_field] = set()
-        group[collect_field].update(collected[collect_field])
+            group[collect_field] = []
+        group[collect_field] += collected[collect_field]
 
 
-def convert_collected_values_to_list(collect, groups):
+def apply_collection_methods(collect, groups):
     for group in groups:
-        for collected_field in collect:
-            group[collected_field] = sorted(list(group[collected_field]))
+        for collect_field, collect_method in collect:
+            if collect_method == 'default':
+                collect_keys = [collect_field, '{0}:set'.format(collect_field)]
+            else:
+                collect_keys = ['{0}:{1}'.format(collect_field,
+                                                 collect_method)]
+            for collect_key in collect_keys:
+                group[collect_key] = apply_collection_method(
+                    group[collect_field], collect_method)
+        for collect_field in unique_collect_fields(collect):
+            del group[collect_field]
+            # This is to provide backwards compatibility with earlier interface
+            if (collect_field, 'default') in collect:
+                group[collect_field] = group['{0}:set'.format(collect_field)]
+
+
+def apply_collection_method(collected_data, collect_method):
+    if "sum" == collect_method:
+        try:
+            return sum(collected_data)
+        except TypeError:
+            raise InvalidOperationError("Unable to sum that data")
+    elif "count" == collect_method:
+        return len(collected_data)
+    elif "set" == collect_method:
+        return sorted(list(set(collected_data)))
+    elif "mean" == collect_method:
+        try:
+            return sum(collected_data) / float(len(collected_data))
+        except TypeError:
+            raise InvalidOperationError("Unable to find the mean of that data")
+    elif "default" == collect_method:
+        return sorted(list(set(collected_data)))
+    else:
+        raise ValueError("Unknown collection method")
+
+
+def unique_collect_fields(collect):
+    """Return the unique set of field names to collect."""
+    return set([collect_field for collect_field, _ in collect])
 
 
 def nested_merge(keys, collect, results):
@@ -200,7 +240,7 @@ def nested_merge(keys, collect, results):
 
         insert_collected_values(collected, group)
 
-    convert_collected_values_to_list(collect, groups)
+    apply_collection_methods(collect, groups)
 
     return groups
 
@@ -213,7 +253,7 @@ def _merge(groups, keys, result):
         group = _find_group(group for group in groups if group[key] == value)
         if not group:
             if is_leaf:
-                group = _new_leaf_node(key, value, result)
+                group = _new_leaf_node(key, value, result.get('_count'))
             else:
                 group = _new_branch_node(key, value)
             groups.append(group)
@@ -240,10 +280,14 @@ def _new_branch_node(key, value):
     }
 
 
-def _new_leaf_node(key, value, result):
+def _new_leaf_node(key, value, count=None):
     """Create a new node that has no further sub-groups"""
-    result[key] = value
-    return result
+    r = {
+        key: value,
+    }
+    if count is not None:
+        r['_count'] = count
+    return r
 
 
 def _merge_and_sort_subgroup(group, keys, result):
@@ -254,3 +298,7 @@ def _merge_and_sort_subgroup(group, keys, result):
 def _add_branch_node_counts(group):
     group['_count'] = sum(doc.get('_count', 0) for doc in group['_subgroup'])
     group['_group_count'] = len(group['_subgroup'])
+
+
+class InvalidOperationError(TypeError):
+    pass
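As a reference for the new `apply_collection_method` above, this is its behaviour on a small sample, derived directly from the implementation:

```python
# Behaviour of apply_collection_method on sample data, following the
# implementation in the hunk above.
from backdrop.core.database import apply_collection_method

scores = [2, 5, 8, 2]

assert apply_collection_method(scores, "sum") == 17
assert apply_collection_method(scores, "count") == 4
assert apply_collection_method(scores, "set") == [2, 5, 8]
assert apply_collection_method(scores, "mean") == 4.25
# "default" behaves like "set", preserving the pre-existing response shape.
assert apply_collection_method(scores, "default") == [2, 5, 8]

# Non-numeric values cause "sum" and "mean" to raise InvalidOperationError,
# which the read API turns into a 400 response (see backdrop/read/api.py).
```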
diff --git a/backdrop/read/api.py b/backdrop/read/api.py
index 8935972a..ff7e6d15 100644
--- a/backdrop/read/api.py
+++ b/backdrop/read/api.py
@@ -12,6 +12,7 @@
 from .validation import validate_request_args
 from ..core import database, log_handler, cache_control
 from ..core.bucket import Bucket
+from ..core.database import InvalidOperationError
 
 
 def setup_logging():
@@ -69,6 +70,11 @@ def health_check():
                    message='cannot connect to database'), 500
 
 
+def log_error_and_respond(message, status_code):
+    app.logger.error(message)
+    return jsonify(status='error', message=message), status_code
+
+
 @app.route('/', methods=['GET', 'OPTIONS'])
 @cache_control.set("max-age=3600, must-revalidate")
 @cache_control.etag
@@ -84,11 +90,14 @@ def query(bucket_name):
                                    raw_queries_allowed(bucket_name))
 
     if not result.is_valid:
-        app.logger.error(result.message)
-        return jsonify(status='error', message=result.message), 400
+        return log_error_and_respond(result.message, 400)
 
     bucket = Bucket(db, bucket_name)
-    result_data = bucket.query(Query.parse(request.args)).data()
+
+    try:
+        result_data = bucket.query(Query.parse(request.args)).data()
+    except InvalidOperationError:
+        return log_error_and_respond('invalid collect for that data', 400)
 
     # Taken from flask.helpers.jsonify to add JSONEncoder
     # NB. this can be removed once fix #471 works its way into a release
diff --git a/backdrop/read/query.py b/backdrop/read/query.py
index dad17d17..8f970375 100644
--- a/backdrop/read/query.py
+++ b/backdrop/read/query.py
@@ -37,7 +37,12 @@ def parse_request_args(request_args):
 
     args['limit'] = if_present(int, request_args.get('limit'))
 
-    args['collect'] = request_args.getlist('collect')
+    args['collect'] = []
+    for collect_arg in request_args.getlist('collect'):
+        if ':' in collect_arg:
+            args['collect'].append(tuple(collect_arg.split(':')))
+        else:
+            args['collect'].append((collect_arg, 'default'))
 
     return args
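The effect of the new parsing is that every `collect` argument becomes a `(field, method)` tuple, with `'default'` filled in when no method is given. A small illustration, matching the behaviour pinned down by `tests/read/test_parse_request_args.py` further below:

```python
# Collect arguments become (field, method) tuples after this change.
from werkzeug.datastructures import MultiDict

from backdrop.read.query import parse_request_args

args = parse_request_args(MultiDict([
    ("collect", "score"),       # no method given -> 'default'
    ("collect", "score:mean"),  # explicit method
]))

assert args["collect"] == [("score", "default"), ("score", "mean")]
```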
diff --git a/backdrop/read/validation.py b/backdrop/read/validation.py
index 04675716..aeae730d 100644
--- a/backdrop/read/validation.py
+++ b/backdrop/read/validation.py
@@ -176,6 +176,11 @@ def validate(self, request_args, context):
             validate_field_value=self.validate_field_value)
 
     def validate_field_value(self, value, request_args, _):
+        if ":" in value:
+            value, operator = value.split(":")
+            if operator not in ["sum", "count", "set", "mean"]:
+                self.add_error("Unknown collection method")
+
         if not key_is_valid(value):
             self.add_error('Cannot collect an invalid field name')
         if value.startswith('_'):
diff --git a/features/read_api/collect.feature b/features/read_api/collect.feature
index 8a951d2f..4bb91e43 100644
--- a/features/read_api/collect.feature
+++ b/features/read_api/collect.feature
@@ -23,8 +23,22 @@ Feature: collect fields into grouped responses
         when I go to "/foo?collect=authority"
         then I should get back a status of "400"
 
+
     Scenario: should be able to collect false values
         Given "licensing_2.json" is in "foo" bucket
         when I go to "/foo?group_by=licence_name&filter_by=isPaymentRequired:false&collect=isPaymentRequired"
         then I should get back a status of "200"
         and the "1st" result should have "isPaymentRequired" with item "false"
+
+    Scenario: should be able to perform maths on collect
+        Given "sort_and_limit.json" is in "foo" bucket
+        when I go to "/foo?group_by=type&filter_by=type:wild&collect=value:sum&collect=value:mean"
+        then I should get back a status of "200"
+        and the "1st" result should have "value:sum" with json "27"
+        and the "1st" result should have "value:mean" with json "6.75"
+
+    Scenario: should receive a nice error when performing invalid operation
+        Given "dinosaurs.json" is in "foo" bucket
+        when I go to "/foo?group_by=type&collect=name:sum"
+        then I should get back a status of "400"
+        and the error message should be "invalid collect for that data"
diff --git a/features/read_api/group.feature b/features/read_api/group.feature
index 3a671ae7..8a290f7e 100644
--- a/features/read_api/group.feature
+++ b/features/read_api/group.feature
@@ -57,6 +57,7 @@ Feature: grouping queries for read api
         then I should get back a status of "200"
         and the JSON should have "1" results
         and the "1st" result should have "values" with item "{"_start_at": "2013-03-11T00:00:00+00:00", "_end_at": "2013-03-18T00:00:00+00:00", "_count": 2.0}"
+        and the "1st" result should have "values" with item "{"_start_at": "2013-03-18T00:00:00+00:00", "_end_at": "2013-03-25T00:00:00+00:00", "_count": 1.0}"
 
 
     Scenario: grouping data by time period (week) and a name that doesn't exist
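For the maths-on-collect scenario above, the response body would look roughly like this. A sketch only: the exact `_count` depends on the `sort_and_limit.json` fixture, though a sum of 27 with a mean of 6.75 implies four collected values:

```python
# Approximate response for
# /foo?group_by=type&filter_by=type:wild&collect=value:sum&collect=value:mean
expected = {
    "data": [
        {
            "type": "wild",
            "value:sum": 27,
            "value:mean": 6.75,
            "_count": 4.0,  # assumed: implied by sum/mean, not stated in the scenario
        },
    ],
}
```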
diff --git a/features/steps/read_api.py b/features/steps/read_api.py
index ec24beaa..1d5a8d3b 100644
--- a/features/steps/read_api.py
+++ b/features/steps/read_api.py
@@ -106,6 +106,19 @@ def step(context, nth, key, value):
     assert_that(the_data[i][key], has_item(json.loads(value)))
 
 
+@then('the "{nth}" result should have "{key}" with json "{expected_json}"')
+def impl(context, nth, key, expected_json):
+    the_data = json.loads(context.response.data)['data']
+    i = parse_position(nth, the_data)
+    assert_that(the_data[i][key], is_(json.loads(expected_json)))
+
+
 @then('the "{header}" header should be "{value}"')
 def step(context, header, value):
     assert_that(context.response.headers.get(header), is_(value))
+
+
+@then(u'the error message should be "{expected_message}"')
+def impl(context, expected_message):
+    error_message = json.loads(context.response.data)['message']
+    assert_that(error_message, is_(expected_message))
diff --git a/tests/core/integration/test_database_integration.py b/tests/core/integration/test_database_integration.py
index 93e843f3..f4f0e2e9 100644
--- a/tests/core/integration/test_database_integration.py
+++ b/tests/core/integration/test_database_integration.py
@@ -82,7 +82,7 @@ def test_find_with_limit(self):
     def test_group(self):
         self._setup_musical_instruments()
 
-        results = self.mongo_driver.group(keys=["type"], query={}, collect=[])
+        results = self.mongo_driver.group(keys=["type"], query={}, collect_fields=[])
 
         assert_that(results, contains_inanyorder(
             has_entries({"_count": is_(2), "type": "wind"}),
@@ -94,7 +94,7 @@ def test_group_with_query(self):
 
         results = self.mongo_driver.group(keys=["type"],
                                           query={"range": "high"},
-                                          collect=[])
+                                          collect_fields=[])
 
         assert_that(results, contains_inanyorder(
             has_entries({"_count": is_(1), "type": "wind"}),
@@ -104,7 +104,7 @@ def test_group_with_query(self):
     def test_group_and_collect_additional_properties(self):
         self._setup_musical_instruments()
 
-        results = self.mongo_driver.group(keys=["type"], query={}, collect=["range"])
+        results = self.mongo_driver.group(keys=["type"], query={}, collect_fields=["range"])
 
         assert_that(results, contains(
             has_entries(
@@ -137,7 +137,7 @@ def test_group_and_collect_with_false_value(self):
     def test_group_without_keys(self):
         self._setup_people()
 
-        results = self.mongo_driver.group(keys=[], query={}, collect=[])
+        results = self.mongo_driver.group(keys=[], query={}, collect_fields=[])
 
         assert_that(results, contains(
             has_entries({"_count": is_(4)}),
@@ -148,7 +148,7 @@ def test_group_ignores_documents_without_grouping_keys(self):
         self._setup_people()
         self.mongo_collection.save({"name": "Yoko"})
 
-        results = self.mongo_driver.group(keys=["plays"], query={}, collect=[])
+        results = self.mongo_driver.group(keys=["plays"], query={}, collect_fields=[])
 
         assert_that(results, contains(
             has_entries({"_count": is_(2), "plays": "guitar"}),
@@ -295,32 +295,33 @@ def test_grouping_by_multiple_keys(self):
     def test_grouping_with_collect(self):
         self.setUpPeopleLocationData()
 
-        results = self.repo.group("person", Query.create(), None, None, ["place"])
+        results = self.repo.group("person", Query.create(), None, None, [("place", "set")])
 
         assert_that(results, has_item(has_entries({
             "person": "John",
-            "place": has_items("Kettering", "Kennington")
+            "place:set": has_items("Kettering", "Kennington")
         })))
 
     def test_another_grouping_with_collect(self):
         self.setUpPeopleLocationData()
 
-        results = self.repo.group("place", Query.create(), None, None, ["person"])
+        results = self.repo.group("place", Query.create(), None, None, [("person", "set")])
 
         assert_that(results, has_item(has_entries({
             "place": "Kettering",
-            "person": has_items("Jack", "John")
+            "person:set": has_items("Jack", "John")
         })))
 
     def test_grouping_with_collect_two_fields(self):
         self.setUpPeopleLocationData()
 
-        results = self.repo.group("place", Query.create(), None, None, ["person", "hair"])
+        results = self.repo.group("place", Query.create(), None, None,
+                                  [("person", "set"), ("hair", "set")])
 
         assert_that(results, has_item(has_entries({
             "place": "Kettering",
-            "person": ["Jack", "John"],
-            "hair": ["blond", "dark", "red"]
+            "person:set": ["Jack", "John"],
+            "hair:set": ["blond", "dark", "red"]
         })))
 
     def test_grouping_on_non_existent_keys(self):
@@ -426,12 +427,12 @@ def test_multi_group_with_collect(self):
             "place",
             "_week_start_at",
             Query.create(),
-            collect=["person"]
+            collect=[("person", "set")]
         )
 
         assert_that(results, has_item(has_entries({
             "place": "Kettering",
-            "person": ["Jack", "John"]
+            "person:set": ["Jack", "John"]
         })))
 
@@ -466,8 +467,7 @@ def test_query_for_data_with_different_missing_fields_some_results(self):
             "bar": "2"
         })
 
-        result = self.repo.multi_group("_week_start_at", "bar", Query.create(),
-                                       collect=["foo"])
+        result = self.repo.multi_group("_week_start_at", "bar", Query.create())
 
         assert_that(result, has_item(has_entry("_count", 1)))
         assert_that(result, has_item(has_entry("_group_count", 1)))
@@ -488,8 +488,7 @@ def test_query_for_data_with_different_missing_fields_with_filter(self):
         })
 
         result = self.repo.multi_group("_week_start_at", "bar",
-                                       Query.create(filter_by=[["bar", "2"]]),
-                                       collect=["foo"])
+                                       Query.create(filter_by=[["bar", "2"]]))
 
         assert_that(result, has_item(has_entry("_count", 1)))
         assert_that(result, has_item(has_entry("_group_count", 1)))
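The `MongoDriver.group` calls in these tests delegate to MongoDB's `group` command using the builder methods from `backdrop/core/database.py` above. For `collect_fields=["range"]`, the generated initial state and reduce function look roughly like this (whitespace differs in the real output):

```python
# Roughly what _build_accumulator_initial_state and _build_reducer_function
# produce for collect_fields=["range"], per the code above.
initial = {"_count": 0, "range": []}

reduce_js = """
function (current, previous) {
    previous._count++;
    if (current.range !== undefined) { previous.range.push(current.range); }
}
"""
```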
diff --git a/tests/core/test_database.py b/tests/core/test_database.py
index 44edc743..9eed2af3 100644
--- a/tests/core/test_database.py
+++ b/tests/core/test_database.py
@@ -3,7 +3,7 @@
 from mock import Mock, patch
 from pymongo.errors import AutoReconnect
 from backdrop.core import database
-from backdrop.core.database import Repository, InvalidSortError, MongoDriver
+from backdrop.core.database import Repository, InvalidSortError, InvalidOperationError, MongoDriver, apply_collection_method
 from backdrop.read.query import Query
 from tests.support.test_helpers import d_tz
 
@@ -51,8 +51,8 @@ def test_nested_merge_merges_dictionaries(self):
             "_count": 0,
             "_group_count": 2,
             "_subgroup": [
-                {"b": 1, "c": 3},
-                {"b": 2, "c": 3},
+                {"b": 1},
+                {"b": 2},
             ],
         }))
         assert_that(output[1], is_({
             "a": 2,
@@ -60,18 +60,88 @@ def test_nested_merge_merges_dictionaries(self):
             "_count": 0,
             "_group_count": 1,
             "_subgroup": [
-                {"b": 1, "c": 3}
+                {"b": 1}
             ],
         }))
 
     def test_nested_merge_squashes_duplicates(self):
         output = database.nested_merge(['a'], [], self.dictionaries)
 
         assert_that(output, is_([
-            {'a': 1, 'b': 2, 'c': 3},
-            {'a': 2, 'b': 1, 'c': 3}
+            {'a': 1},
+            {'a': 2}
+        ]))
+
+    def test_nested_merge_collect_default(self):
+        stub_dictionaries = [
+            {'a': 1, 'b': [2], 'c': 3},
+            {'a': 1, 'b': [1], 'c': 3},
+            {'a': 2, 'b': [1], 'c': 3}
+        ]
+        output = database.nested_merge(['a'], [('b', 'default')], stub_dictionaries)
+
+        assert_that(output, is_([
+            {'a': 1, 'b:set': [1, 2], 'b': [1, 2]},
+            {'a': 2, 'b:set': [1], 'b': [1]}
+        ]))
+
+    def test_nested_merge_collect_set(self):
+        stub_dictionaries = [
+            {'a': 1, 'b': [2], 'c': 3},
+            {'a': 1, 'b': [1], 'c': 3},
+            {'a': 2, 'b': [1], 'c': 3}
+        ]
+        output = database.nested_merge(['a'], [('b', 'set')], stub_dictionaries)
+
+        assert_that(output, is_([
+            {'a': 1, 'b:set': [1, 2]},
+            {'a': 2, 'b:set': [1]}
+        ]))
+
+    def test_nested_merge_collect_sum(self):
+        stub_dictionaries = [
+            {'a': 1, 'b': [2]},
+            {'a': 1, 'b': [1]},
+            {'a': 2, 'b': [1]}
+        ]
+        output = database.nested_merge(['a'], [('b', 'sum')], stub_dictionaries)
+
+        assert_that(output, is_([
+            {'a': 1, 'b:sum': 3},
+            {'a': 2, 'b:sum': 1}
         ]))
 
 
+class TestApplyCollectionMethod(unittest.TestCase):
+    def test_sum(self):
+        data = [2, 5, 8]
+        response = apply_collection_method(data, "sum")
+        assert_that(response, is_(15))
+
+    def test_count(self):
+        data = ['Sheep', 'Elephant', 'Wolf', 'Dog']
+        response = apply_collection_method(data, "count")
+        assert_that(response, is_(4))
+
+    def test_set(self):
+        data = ['Badger', 'Badger', 'Badger', 'Snake']
+        response = apply_collection_method(data, "set")
+        assert_that(response, is_(['Badger', 'Snake']))
+
+    def test_mean(self):
+        data = [13, 19, 15, 2]
+        response = apply_collection_method(data, "mean")
+        assert_that(response, is_(12.25))
+
+    def test_unknown_collection_method_raises_error(self):
+        self.assertRaises(ValueError,
+                          apply_collection_method, ['foo'], "unknown")
+
+    def test_bad_data_for_sum_raises_error(self):
+        self.assertRaises(InvalidOperationError,
+                          apply_collection_method, ['sum', 'this'], "sum")
+
+    def test_bad_data_for_mean_raises_error(self):
+        self.assertRaises(InvalidOperationError,
+                          apply_collection_method, ['average', 'this'], "mean")
+
+
 class TestRepository(unittest.TestCase):
     def setUp(self):
         self.mongo = Mock()
diff --git a/tests/read/test_parse_request_args.py b/tests/read/test_parse_request_args.py
index 78580751..0bdae9dc 100644
--- a/tests/read/test_parse_request_args.py
+++ b/tests/read/test_parse_request_args.py
@@ -103,16 +103,16 @@ def test_limit_is_parsed(self):
 
         assert_that(args['limit'], is_(123))
 
-    def test_one_collect_is_parsed(self):
+    def test_one_collect_is_parsed_with_default_method(self):
         request_args = MultiDict([
             ("collect", "some_key")
         ])
 
         args = parse_request_args(request_args)
 
-        assert_that(args['collect'], is_(["some_key"]))
+        assert_that(args['collect'], is_([("some_key", "default")]))
 
-    def test_two_collects_are_parsed(self):
+    def test_two_collects_are_parsed_with_default_methods(self):
         request_args = MultiDict([
             ("collect", "some_key"),
             ("collect", "some_other_key")
@@ -120,4 +120,14 @@ def test_two_collects_are_parsed(self):
 
         args = parse_request_args(request_args)
 
-        assert_that(args['collect'], is_(["some_key", "some_other_key"]))
+        assert_that(args['collect'], is_([("some_key", "default"),
+                                          ("some_other_key", "default")]))
+
+    def test_one_collect_is_parsed_with_custom_method(self):
+        request_args = MultiDict([
+            ("collect", "some_key:mean")
+        ])
+
+        args = parse_request_args(request_args)
+
+        assert_that(args['collect'], is_([("some_key", "mean")]))
diff --git a/tests/read/test_validation.py b/tests/read/test_validation.py
index 53290d63..0f0597cc 100644
--- a/tests/read/test_validation.py
+++ b/tests/read/test_validation.py
@@ -306,6 +306,27 @@ def test_that_queries_with_invalid_timezone_are_disallowed(self):
         assert_that(validation_result, is_invalid_with_message(
             "start_at is not a valid datetime"))
 
+    def test_that_collect_queries_with_valid_methods_are_allowed(self):
+        valid_collection_methods = ["sum", "count", "set", "mean"]
+
+        for method in valid_collection_methods:
+            validation_result = validate_request_args({
+                'group_by': 'foo',
+                'collect': 'field:{0}'.format(method),
+            })
+
+            assert_that(validation_result, is_valid())
+
+    def test_that_collect_queries_with_invalid_method_are_disallowed(self):
+        validation_result = validate_request_args({
+            'group_by': 'foo',
+            'collect': 'field:infinity',
+        })
+
+        assert_that(validation_result, is_invalid_with_message(
+            "Unknown collection method"
+        ))
+
 
 class TestValidationHelpers(TestCase):
     def test_timestamp_is_valid_method(self):