From 295e141e4fe162313ef8e6b5a74d978003727fe6 Mon Sep 17 00:00:00 2001 From: GuyAv46 <47632673+GuyAv46@users.noreply.github.com> Date: Sun, 16 Jun 2024 17:25:24 +0300 Subject: [PATCH] Groupby recursion cleanup - [MOD-7245] (#4598) * cleanup recursion * added a test * remove unreachable code --- src/aggregate/group_by.c | 44 ++++++++++++++------------------- tests/pytests/test_aggregate.py | 20 +++++++++++++++ 2 files changed, 38 insertions(+), 26 deletions(-) diff --git a/src/aggregate/group_by.c b/src/aggregate/group_by.c index a462acef2b..9fac468d6e 100644 --- a/src/aggregate/group_by.c +++ b/src/aggregate/group_by.c @@ -162,14 +162,11 @@ static void invokeReducers(Grouper *g, Group *gr, RLookupRow *srcrow) { * the `GROUPER_NSRCKEYS(g)` macro * @param xpos the current position in xarr * @param xlen cached value of GROUPER_NSRCKEYS - * @param ypos if xarr[xpos] is an array, this is the current position within - * the array * @param hval current X-wise hash value. Note that members of the same Y array * are not hashed together. * @param res the row is passed to each reducer */ -static void extractGroups(Grouper *g, const RSValue **xarr, size_t xpos, size_t xlen, size_t arridx, - uint64_t hval, RLookupRow *res) { +static void extractGroups(Grouper *g, const RSValue **xarr, size_t xpos, size_t xlen, uint64_t hval, RLookupRow *res) { // end of the line - create/add to group if (xpos == xlen) { Group *group = NULL; @@ -193,31 +190,26 @@ static void extractGroups(Grouper *g, const RSValue **xarr, size_t xpos, size_t // regular value - just move one step -- increment XPOS if (v->t != RSValue_Array) { hval = RSValue_Hash(v, hval); - extractGroups(g, xarr, xpos + 1, xlen, 0, hval, res); + extractGroups(g, xarr, xpos + 1, xlen, hval, res); + } else if (RSValue_ArrayLen(v) == 0) { + // Empty array - hash as null + hval = RSValue_Hash(RS_NullVal(), hval); + const RSValue *array = xarr[xpos]; + xarr[xpos] = RS_NullVal(); + extractGroups(g, xarr, xpos + 1, xlen, hval, res); + xarr[xpos] = array; } else { - // Array value. Replace current XPOS with child temporarily + // Array value. Replace current XPOS with child temporarily. + // Each value in the array will be a separate group const RSValue *array = xarr[xpos]; - const RSValue *elem; - - if (arridx >= RSValue_ArrayLen(v)) { - elem = NULL; - } else { - elem = RSValue_ArrayItem(v, arridx); + for (size_t i = 0; i < RSValue_ArrayLen(v); i++) { + const RSValue *elem = RSValue_ArrayItem(v, i); + // hash the element, even if it's an array + uint64_t hh = RSValue_Hash(elem, hval); + xarr[xpos] = elem; + extractGroups(g, xarr, xpos + 1, xlen, hh, res); } - - if (elem == NULL) { - elem = RS_NullVal(); - } - uint64_t hh = RSValue_Hash(elem, hval); - - xarr[xpos] = elem; - extractGroups(g, xarr, xpos, xlen, arridx, hh, res); xarr[xpos] = array; - - // Replace the value back, and proceed to the next value of the array - if (++arridx < RSValue_ArrayLen(v)) { - extractGroups(g, xarr, xpos, xlen, arridx, hval, res); - } } } @@ -234,7 +226,7 @@ static void invokeGroupReducers(Grouper *g, RLookupRow *srcrow) { } groupvals[ii] = v; } - extractGroups(g, groupvals, 0, nkeys, 0, 0, srcrow); + extractGroups(g, groupvals, 0, nkeys, 0, srcrow); } static int Grouper_rpAccum(ResultProcessor *base, SearchResult *res) { diff --git a/tests/pytests/test_aggregate.py b/tests/pytests/test_aggregate.py index b33021b91c..680f22b88c 100644 --- a/tests/pytests/test_aggregate.py +++ b/tests/pytests/test_aggregate.py @@ -706,6 +706,26 @@ def testAggregateGroupByOnEmptyField(env): for var in expected: env.assertContains(var, res) +def test_groupby_array(env: Env): + env.expect('FT.CREATE', 'idx', 'SCHEMA', 't1', 'TEXT', 'SORTABLE', 't2', 'TEXT', 'SORTABLE').ok() + with env.getClusterConnectionIfNeeded() as con: + con.execute_command('HSET', 'doc1', 't1', 'foo,bar', 't2', 'baz,qux') + + res = env.cmd('FT.AGGREGATE', 'idx', '*', + 'APPLY', 'split(@t1, ",")', 'AS', 't1', + 'APPLY', 'split(@t2, ",")', 'AS', 't2', + 'GROUPBY', '2', '@t1', '@t2') + + exp = [4, ['t1', 'foo', 't2', 'baz'], + ['t1', 'foo', 't2', 'qux'], + ['t1', 'bar', 't2', 'baz'], + ['t1', 'bar', 't2', 'qux']] + + # Check that the result is as expected (res elements contained in exp, and same size) + for row in res: + env.assertContains(row, exp) + env.assertEqual(len(res), len(exp), message=f'{res} != {exp}') + def testMultiSortBy(env): conn = getConnectionByEnv(env) env.cmd('FT.CREATE', 'sb_idx', 'SCHEMA', 't1', 'TEXT', 't2', 'TEXT')