Skip to content

Commit

Permalink
Fix bug in aggregate function collect()
Browse files Browse the repository at this point in the history
Fixes a bug where the collect functions aggtransfn is not called but
the aggfinalfn is. While I'm not sure why this case happens, as it
'should' always call the aggtransfn, it crashes because the collect
aggregate state (castate) hasn't been initialized and is still null.

The fix is to have the aggfinalfn check for a null state and create
the state for finalizing the empty '[]' results.

Added a regression test for this case.
  • Loading branch information
jrgemignani committed May 10, 2022
1 parent 08d75d8 commit e36983b
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
6 changes: 6 additions & 0 deletions regress/expected/expr.out
Expand Up @@ -5224,6 +5224,12 @@ SELECT * FROM cypher('UCSC', $$ RETURN collect(NULL) $$) AS (empty agtype);
[]
(1 row)

SELECT * FROM cypher('UCSC', $$ MATCH (u) WHERE u.name =~ "doesn't exist" RETURN collect(u.name) $$) AS (name agtype);
name
------
[]
(1 row)

-- should fail
SELECT * FROM cypher('UCSC', $$ RETURN collect() $$) AS (collect agtype);
ERROR: function ag_catalog.age_collect() does not exist
Expand Down
2 changes: 2 additions & 0 deletions regress/sql/expr.sql
Expand Up @@ -2182,6 +2182,8 @@ AS (zip1 agtype, zip2 agtype);
SELECT * FROM cypher('UCSC', $$ RETURN collect(5) $$) AS (result agtype);
-- should return an empty aray
SELECT * FROM cypher('UCSC', $$ RETURN collect(NULL) $$) AS (empty agtype);
SELECT * FROM cypher('UCSC', $$ MATCH (u) WHERE u.name =~ "doesn't exist" RETURN collect(u.name) $$) AS (name agtype);

-- should fail
SELECT * FROM cypher('UCSC', $$ RETURN collect() $$) AS (collect agtype);

Expand Down
31 changes: 29 additions & 2 deletions src/backend/utils/adt/agtype.c
Expand Up @@ -8933,17 +8933,23 @@ Datum age_collect_aggtransfn(PG_FUNCTION_ARGS)
}
/* otherwise, retrieve the state */
else
{
castate = (agtype_in_state *) PG_GETARG_POINTER(0);
}

/*
* Extract the variadic args, of which there should only be one.
* Insert the arg into the array, unless it is null. Nulls are
* skipped over.
*/
if (PG_ARGISNULL(1))
{
nargs = 0;
}
else
{
nargs = extract_variadic_args(fcinfo, 1, true, &args, &types, &nulls);
}

if (nargs == 1)
{
Expand All @@ -8962,15 +8968,21 @@ Datum age_collect_aggtransfn(PG_FUNCTION_ARGS)
0);
/* add the arg if not agtype null */
if (agtv_value->type != AGTV_NULL)
{
add_agtype(args[0], nulls[0], castate, types[0], false);
}
}
else
{
add_agtype(args[0], nulls[0], castate, types[0], false);
}
}
}
else if (nargs > 1)
{
ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("collect() invalid number of arguments")));
}

/* restore the old context */
MemoryContextSwitchTo(old_mcxt);
Expand All @@ -8988,8 +9000,23 @@ Datum age_collect_aggfinalfn(PG_FUNCTION_ARGS)

/* verify we are in an aggregate context */
Assert(AggCheckCallContext(fcinfo, NULL) == AGG_CONTEXT_AGGREGATE);
/* get the state */
castate = (agtype_in_state *) PG_GETARG_POINTER(0);
/*
* Get the state. There are cases where the age_collect_aggtransfn never
* gets called. So, check to see if this is one.
*/
if (PG_ARGISNULL(0))
{
/* create and initialize the state */
castate = palloc0(sizeof(agtype_in_state));
memset(castate, 0, sizeof(agtype_in_state));
/* start the array */
castate->res = push_agtype_value(&castate->parse_state,
WAGT_BEGIN_ARRAY, NULL);
}
else
{
castate = (agtype_in_state *) PG_GETARG_POINTER(0);
}
/* switch to the correct aggregate context */
old_mcxt = MemoryContextSwitchTo(fcinfo->flinfo->fn_mcxt);
/* Finish/close the array */
Expand Down

0 comments on commit e36983b

Please sign in to comment.