Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Groupby recursion cleanup - [MOD-7245] #4598

Merged
merged 3 commits into from
Jun 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 18 additions & 26 deletions src/aggregate/group_by.c
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,11 @@ static void invokeReducers(Grouper *g, Group *gr, RLookupRow *srcrow) {
* the `GROUPER_NSRCKEYS(g)` macro
* @param xpos the current position in xarr
* @param xlen cached value of GROUPER_NSRCKEYS
* @param ypos if xarr[xpos] is an array, this is the current position within
* the array
* @param hval current X-wise hash value. Note that members of the same Y array
* are not hashed together.
* @param res the row is passed to each reducer
*/
static void extractGroups(Grouper *g, const RSValue **xarr, size_t xpos, size_t xlen, size_t arridx,
uint64_t hval, RLookupRow *res) {
static void extractGroups(Grouper *g, const RSValue **xarr, size_t xpos, size_t xlen, uint64_t hval, RLookupRow *res) {
// end of the line - create/add to group
if (xpos == xlen) {
Group *group = NULL;
Expand All @@ -193,31 +190,26 @@ static void extractGroups(Grouper *g, const RSValue **xarr, size_t xpos, size_t
// regular value - just move one step -- increment XPOS
if (v->t != RSValue_Array) {
hval = RSValue_Hash(v, hval);
extractGroups(g, xarr, xpos + 1, xlen, 0, hval, res);
extractGroups(g, xarr, xpos + 1, xlen, hval, res);
} else if (RSValue_ArrayLen(v) == 0) {
// Empty array - hash as null
hval = RSValue_Hash(RS_NullVal(), hval);
const RSValue *array = xarr[xpos];
xarr[xpos] = RS_NullVal();
extractGroups(g, xarr, xpos + 1, xlen, hval, res);
xarr[xpos] = array;
} else {
// Array value. Replace current XPOS with child temporarily
// Array value. Replace current XPOS with child temporarily.
// Each value in the array will be a separate group
const RSValue *array = xarr[xpos];
const RSValue *elem;

if (arridx >= RSValue_ArrayLen(v)) {
elem = NULL;
} else {
elem = RSValue_ArrayItem(v, arridx);
for (size_t i = 0; i < RSValue_ArrayLen(v); i++) {
const RSValue *elem = RSValue_ArrayItem(v, i);
// hash the element, even if it's an array
uint64_t hh = RSValue_Hash(elem, hval);
xarr[xpos] = elem;
extractGroups(g, xarr, xpos + 1, xlen, hh, res);
}

if (elem == NULL) {
elem = RS_NullVal();
}
uint64_t hh = RSValue_Hash(elem, hval);

xarr[xpos] = elem;
extractGroups(g, xarr, xpos, xlen, arridx, hh, res);
xarr[xpos] = array;

// Replace the value back, and proceed to the next value of the array
if (++arridx < RSValue_ArrayLen(v)) {
extractGroups(g, xarr, xpos, xlen, arridx, hval, res);
}
}
}

Expand All @@ -234,7 +226,7 @@ static void invokeGroupReducers(Grouper *g, RLookupRow *srcrow) {
}
groupvals[ii] = v;
}
extractGroups(g, groupvals, 0, nkeys, 0, 0, srcrow);
extractGroups(g, groupvals, 0, nkeys, 0, srcrow);
}

static int Grouper_rpAccum(ResultProcessor *base, SearchResult *res) {
Expand Down
20 changes: 20 additions & 0 deletions tests/pytests/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,26 @@ def testAggregateGroupByOnEmptyField(env):
for var in expected:
env.assertContains(var, res)

def test_groupby_array(env: Env):
    """GROUPBY over two array-valued keys must produce one group per pair of elements.

    A single document stores comma-separated text in `t1` and `t2`. Both fields
    are passed through `split(...)` before `GROUPBY 2 @t1 @t2`, and the reply is
    expected to hold the count 4 followed by the four (t1, t2) combinations.
    The order of the groups in the reply is not asserted.
    """
    env.expect('FT.CREATE', 'idx', 'SCHEMA', 't1', 'TEXT', 'SORTABLE', 't2', 'TEXT', 'SORTABLE').ok()
    with env.getClusterConnectionIfNeeded() as con:
        con.execute_command('HSET', 'doc1', 't1', 'foo,bar', 't2', 'baz,qux')

    res = env.cmd('FT.AGGREGATE', 'idx', '*',
                  'APPLY', 'split(@t1, ",")', 'AS', 't1',
                  'APPLY', 'split(@t2, ",")', 'AS', 't2',
                  'GROUPBY', '2', '@t1', '@t2')

    exp = [4, ['t1', 'foo', 't2', 'baz'],
              ['t1', 'foo', 't2', 'qux'],
              ['t1', 'bar', 't2', 'baz'],
              ['t1', 'bar', 't2', 'qux']]

    # Every element of the reply must be an expected element...
    for row in res:
        env.assertContains(row, exp)
    # ...and every expected element must appear in the reply. Without this
    # reverse check, a reply that repeats one group and omits another would
    # still satisfy the containment + length checks.
    for row in exp:
        env.assertContains(row, res)
    env.assertEqual(len(res), len(exp), message=f'{res} != {exp}')

def testMultiSortBy(env):
conn = getConnectionByEnv(env)
env.cmd('FT.CREATE', 'sb_idx', 'SCHEMA', 't1', 'TEXT', 't2', 'TEXT')
Expand Down
Loading