Skip to content

Commit

Permalink
Groupby recursion cleanup - [MOD-7245] (#4598)
Browse files Browse the repository at this point in the history
* cleanup recursion

* added a test

* remove unreachable code
  • Loading branch information
GuyAv46 committed Jun 16, 2024
1 parent 640bbcb commit 295e141
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 26 deletions.
44 changes: 18 additions & 26 deletions src/aggregate/group_by.c
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,11 @@ static void invokeReducers(Grouper *g, Group *gr, RLookupRow *srcrow) {
* the `GROUPER_NSRCKEYS(g)` macro
* @param xpos the current position in xarr
* @param xlen cached value of GROUPER_NSRCKEYS
* @param ypos if xarr[xpos] is an array, this is the current position within
* the array
* @param hval current X-wise hash value. Note that members of the same Y array
* are not hashed together.
* @param res the row is passed to each reducer
*/
static void extractGroups(Grouper *g, const RSValue **xarr, size_t xpos, size_t xlen, size_t arridx,
uint64_t hval, RLookupRow *res) {
static void extractGroups(Grouper *g, const RSValue **xarr, size_t xpos, size_t xlen, uint64_t hval, RLookupRow *res) {
// end of the line - create/add to group
if (xpos == xlen) {
Group *group = NULL;
Expand All @@ -193,31 +190,26 @@ static void extractGroups(Grouper *g, const RSValue **xarr, size_t xpos, size_t
// regular value - just move one step -- increment XPOS
if (v->t != RSValue_Array) {
hval = RSValue_Hash(v, hval);
extractGroups(g, xarr, xpos + 1, xlen, 0, hval, res);
extractGroups(g, xarr, xpos + 1, xlen, hval, res);
} else if (RSValue_ArrayLen(v) == 0) {
// Empty array - hash as null
hval = RSValue_Hash(RS_NullVal(), hval);
const RSValue *array = xarr[xpos];
xarr[xpos] = RS_NullVal();
extractGroups(g, xarr, xpos + 1, xlen, hval, res);
xarr[xpos] = array;
} else {
// Array value. Replace current XPOS with child temporarily
// Array value. Replace current XPOS with child temporarily.
// Each value in the array will be a separate group
const RSValue *array = xarr[xpos];
const RSValue *elem;

if (arridx >= RSValue_ArrayLen(v)) {
elem = NULL;
} else {
elem = RSValue_ArrayItem(v, arridx);
for (size_t i = 0; i < RSValue_ArrayLen(v); i++) {
const RSValue *elem = RSValue_ArrayItem(v, i);
// hash the element, even if it's an array
uint64_t hh = RSValue_Hash(elem, hval);
xarr[xpos] = elem;
extractGroups(g, xarr, xpos + 1, xlen, hh, res);
}

if (elem == NULL) {
elem = RS_NullVal();
}
uint64_t hh = RSValue_Hash(elem, hval);

xarr[xpos] = elem;
extractGroups(g, xarr, xpos, xlen, arridx, hh, res);
xarr[xpos] = array;

// Replace the value back, and proceed to the next value of the array
if (++arridx < RSValue_ArrayLen(v)) {
extractGroups(g, xarr, xpos, xlen, arridx, hval, res);
}
}
}

Expand All @@ -234,7 +226,7 @@ static void invokeGroupReducers(Grouper *g, RLookupRow *srcrow) {
}
groupvals[ii] = v;
}
extractGroups(g, groupvals, 0, nkeys, 0, 0, srcrow);
extractGroups(g, groupvals, 0, nkeys, 0, srcrow);
}

static int Grouper_rpAccum(ResultProcessor *base, SearchResult *res) {
Expand Down
20 changes: 20 additions & 0 deletions tests/pytests/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,26 @@ def testAggregateGroupByOnEmptyField(env):
for var in expected:
env.assertContains(var, res)

def test_groupby_array(env: Env):
    """GROUPBY on two array-valued properties yields one group per element pair.

    A single hash is indexed with comma-separated values in `t1` and `t2`.
    Both fields are split into arrays with APPLY and then used as GROUPBY
    keys; the reply is expected to hold the count 4 followed by every
    (t1 element, t2 element) combination.
    """
    env.expect('FT.CREATE', 'idx', 'SCHEMA',
               't1', 'TEXT', 'SORTABLE',
               't2', 'TEXT', 'SORTABLE').ok()
    with env.getClusterConnectionIfNeeded() as con:
        con.execute_command('HSET', 'doc1', 't1', 'foo,bar', 't2', 'baz,qux')

    # Turn both fields into arrays, then group by the pair of them.
    query = ['FT.AGGREGATE', 'idx', '*',
             'APPLY', 'split(@t1, ",")', 'AS', 't1',
             'APPLY', 'split(@t2, ",")', 'AS', 't2',
             'GROUPBY', '2', '@t1', '@t2']
    res = env.cmd(*query)

    # Expected reply: the number of groups, then one row per combination.
    exp = [4]
    exp.extend(['t1', first, 't2', second]
               for first in ('foo', 'bar')
               for second in ('baz', 'qux'))

    # Row order is not checked: every reply element must appear in the
    # expectation, and both must have the same number of elements.
    for row in res:
        env.assertContains(row, exp)
    env.assertEqual(len(res), len(exp), message=f'{res} != {exp}')

def testMultiSortBy(env):
conn = getConnectionByEnv(env)
env.cmd('FT.CREATE', 'sb_idx', 'SCHEMA', 't1', 'TEXT', 't2', 'TEXT')
Expand Down

0 comments on commit 295e141

Please sign in to comment.