Skip to content

Commit

Permalink
feat: add statement_timeout set on functions (#3056)
Browse files Browse the repository at this point in the history
  • Loading branch information
taimoorzaeem committed Nov 17, 2023
1 parent 3c1a7f2 commit 125f10a
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
- #2825, SQL handlers for custom media types - @steve-chavez
+ Solves #1548, #2699, #2763, #2170, #1462, #1102, #1374, #2901
- #2799, Add timezone in Prefer header - @taimoorzaeem
- #3001, Add `statement_timeout` set on functions - @taimoorzaeem

### Fixed

Expand Down
18 changes: 9 additions & 9 deletions src/PostgREST/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -179,49 +179,49 @@ handleRequest AuthResult{..} conf appState authenticated prepared pgVer apiReq@A
case (iAction, iTarget) of
(ActionRead headersOnly, TargetIdent identifier) -> do
(planTime', wrPlan) <- withTiming $ liftEither $ Plan.wrappedReadPlan identifier conf sCache apiReq
(rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.wrTxMode wrPlan) $ Query.readQuery wrPlan conf apiReq
(rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.wrTxMode wrPlan) $ Query.readQuery wrPlan conf apiReq
(renderTime', pgrst) <- withTiming $ liftEither $ Response.readResponse wrPlan headersOnly identifier apiReq resultSet
let metrics = Map.fromList [(SMPlan, planTime'), (SMQuery, rsTime'), (SMRender, renderTime'), jwtTime]
return $ pgrstResponse metrics pgrst

(ActionMutate MutationCreate, TargetIdent identifier) -> do
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationCreate apiReq identifier conf sCache
(rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) $ Query.createQuery mrPlan apiReq conf
(rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.createQuery mrPlan apiReq conf
(renderTime', pgrst) <- withTiming $ liftEither $ Response.createResponse identifier mrPlan apiReq resultSet
let metrics = Map.fromList [(SMPlan, planTime'), (SMQuery, rsTime'), (SMRender, renderTime'), jwtTime]
return $ pgrstResponse metrics pgrst

(ActionMutate MutationUpdate, TargetIdent identifier) -> do
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationUpdate apiReq identifier conf sCache
(rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) $ Query.updateQuery mrPlan apiReq conf
(rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.updateQuery mrPlan apiReq conf
(renderTime', pgrst) <- withTiming $ liftEither $ Response.updateResponse mrPlan apiReq resultSet
let metrics = Map.fromList [(SMPlan, planTime'), (SMQuery, rsTime'), (SMRender, renderTime'), jwtTime]
return $ pgrstResponse metrics pgrst

(ActionMutate MutationSingleUpsert, TargetIdent identifier) -> do
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationSingleUpsert apiReq identifier conf sCache
(rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) $ Query.singleUpsertQuery mrPlan apiReq conf
(rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.singleUpsertQuery mrPlan apiReq conf
(renderTime', pgrst) <- withTiming $ liftEither $ Response.singleUpsertResponse mrPlan apiReq resultSet
let metrics = Map.fromList [(SMPlan, planTime'), (SMQuery, rsTime'), (SMRender, renderTime'), jwtTime]
return $ pgrstResponse metrics pgrst

(ActionMutate MutationDelete, TargetIdent identifier) -> do
(planTime', mrPlan) <- withTiming $ liftEither $ Plan.mutateReadPlan MutationDelete apiReq identifier conf sCache
(rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl (Plan.mrTxMode mrPlan) $ Query.deleteQuery mrPlan apiReq conf
(rsTime', resultSet) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.mrTxMode mrPlan) $ Query.deleteQuery mrPlan apiReq conf
(renderTime', pgrst) <- withTiming $ liftEither $ Response.deleteResponse mrPlan apiReq resultSet
let metrics = Map.fromList [(SMPlan, planTime'), (SMQuery, rsTime'), (SMRender, renderTime'), jwtTime]
return $ pgrstResponse metrics pgrst

(ActionInvoke invMethod, TargetProc identifier _) -> do
(planTime', cPlan) <- withTiming $ liftEither $ Plan.callReadPlan identifier conf sCache apiReq invMethod
(rsTime', resultSet) <- withTiming $ runQuery (fromMaybe roleIsoLvl $ pdIsoLvl (Plan.crProc cPlan)) (Plan.crTxMode cPlan) $ Query.invokeQuery (Plan.crProc cPlan) cPlan apiReq conf pgVer
(rsTime', resultSet) <- withTiming $ runQuery (fromMaybe roleIsoLvl $ pdIsoLvl (Plan.crProc cPlan)) (pdTimeout $ Plan.crProc cPlan) (Plan.crTxMode cPlan) $ Query.invokeQuery (Plan.crProc cPlan) cPlan apiReq conf pgVer
(renderTime', pgrst) <- withTiming $ liftEither $ Response.invokeResponse cPlan invMethod (Plan.crProc cPlan) apiReq resultSet
let metrics = Map.fromList [(SMPlan, planTime'), (SMQuery, rsTime'), (SMRender, renderTime'), jwtTime]
return $ pgrstResponse metrics pgrst

(ActionInspect headersOnly, TargetDefaultSpec tSchema) -> do
(planTime', iPlan) <- withTiming $ liftEither $ Plan.inspectPlan apiReq
(rsTime', oaiResult) <- withTiming $ runQuery roleIsoLvl (Plan.ipTxmode iPlan) $ Query.openApiQuery sCache pgVer conf tSchema
(rsTime', oaiResult) <- withTiming $ runQuery roleIsoLvl Nothing (Plan.ipTxmode iPlan) $ Query.openApiQuery sCache pgVer conf tSchema
(renderTime', pgrst) <- withTiming $ liftEither $ Response.openApiResponse (T.decodeUtf8 prettyVersion, docsVersion) headersOnly oaiResult conf sCache iSchema iNegotiatedByProfile
let metrics = Map.fromList [(SMPlan, planTime'), (SMQuery, rsTime'), (SMRender, renderTime'), jwtTime]
return $ pgrstResponse metrics pgrst
Expand Down Expand Up @@ -249,9 +249,9 @@ handleRequest AuthResult{..} conf appState authenticated prepared pgVer apiReq@A
where
roleSettings = fromMaybe mempty (HM.lookup authRole $ configRoleSettings conf)
roleIsoLvl = HM.findWithDefault SQL.ReadCommitted authRole $ configRoleIsoLvl conf
runQuery isoLvl mode query =
runQuery isoLvl timeout mode query =
runDbHandler appState isoLvl mode authenticated prepared $ do
Query.setPgLocals conf authClaims authRole (HM.toList roleSettings) apiReq
Query.setPgLocals conf authClaims authRole (HM.toList roleSettings) apiReq timeout
Query.runPreReq conf
query

Expand Down
7 changes: 4 additions & 3 deletions src/PostgREST/Query.hs
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,10 @@ optionalRollback AppConfig{..} ApiRequest{iPreferences=Preferences{..}} = do

-- | Runs local (transaction scoped) GUCs for every request.
setPgLocals :: AppConfig -> KM.KeyMap JSON.Value -> BS.ByteString -> [(ByteString, ByteString)] ->
ApiRequest -> DbHandler ()
setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} = lift $
ApiRequest -> Maybe Text -> DbHandler ()
setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} tout = lift $
SQL.statement mempty $ SQL.dynamicallyParameterized
("select " <> intercalateSnippet ", " (searchPathSql : roleSql ++ roleSettingsSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ timezoneSql ++ appSettingsSql))
("select " <> intercalateSnippet ", " (searchPathSql : roleSql ++ roleSettingsSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ timezoneSql ++ timeoutSql ++ appSettingsSql))
HD.noResult configDbPreparedStatements
where
methodSql = setConfigWithConstantName ("request.method", iMethod)
Expand All @@ -250,6 +250,7 @@ setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} = lift $
roleSettingsSql = setConfigWithDynamicName <$> roleSettings
appSettingsSql = setConfigWithDynamicName <$> (join bimap toUtf8 <$> configAppSettings)
timezoneSql = maybe mempty (\(PreferTimezone tz) -> [setConfigWithConstantName ("timezone", tz)]) $ preferTimezone iPreferences
timeoutSql = maybe mempty ((\t -> [setConfigWithConstantName ("statement_timeout", t)]) . encodeUtf8) tout
searchPathSql =
let schemas = escapeIdentList (iSchema : configDbExtraSearchPath) in
setConfigWithConstantName ("search_path", schemas)
Expand Down
7 changes: 5 additions & 2 deletions src/PostgREST/SchemaCache.hs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ decodeFuncs =
<*> (parseVolatility <$> column HD.char)
<*> column HD.bool
<*> nullableColumn (toIsolationLevel <$> HD.text)
<*> nullableColumn HD.text

addKey :: Routine -> (QualifiedIdentifier, Routine)
addKey pd = (QualifiedIdentifier (pdSchema pd) (pdName pd), pd)
Expand Down Expand Up @@ -430,7 +431,8 @@ funcsSqlQuery pgVer = [q|
bt.oid <> bt.base as rettype_is_composite_alias,
p.provolatile,
p.provariadic > 0 as hasvariadic,
lower((regexp_split_to_array((regexp_split_to_array(config, '='))[2], ','))[1]) AS transaction_isolation_level
lower((regexp_split_to_array((regexp_split_to_array(iso_config, '='))[2], ','))[1]) AS transaction_isolation_level,
lower((regexp_split_to_array((regexp_split_to_array(timeout_config, '='))[2], ','))[1]) AS statement_timeout
FROM pg_proc p
LEFT JOIN arguments a ON a.oid = p.oid
JOIN pg_namespace pn ON pn.oid = p.pronamespace
Expand All @@ -439,7 +441,8 @@ funcsSqlQuery pgVer = [q|
JOIN pg_namespace tn ON tn.oid = t.typnamespace
LEFT JOIN pg_class comp ON comp.oid = t.typrelid
LEFT JOIN pg_description as d ON d.objoid = p.oid
LEFT JOIN LATERAL unnest(proconfig) config ON config like 'default_transaction_isolation%'
LEFT JOIN LATERAL unnest(proconfig) iso_config ON iso_config like 'default_transaction_isolation%'
LEFT JOIN LATERAL unnest(proconfig) timeout_config ON timeout_config like 'statement_timeout%'
WHERE t.oid <> 'trigger'::regtype AND COALESCE(a.callable, true)
|] <> (if pgVer >= pgVersion110 then "AND prokind = 'f'" else "AND NOT (proisagg OR proiswindow)")

Expand Down
8 changes: 5 additions & 3 deletions src/PostgREST/SchemaCache/Routine.hs
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,12 @@ data Routine = Function
, pdVolatility :: FuncVolatility
, pdHasVariadic :: Bool
, pdIsoLvl :: Maybe SQL.IsolationLevel
, pdTimeout :: Maybe Text
}
deriving (Eq, Show, Generic)
-- need to define JSON manually bc SQL.IsolationLevel doesn't have a JSON instance(and we can't define one for that type without getting a compiler error)
instance JSON.ToJSON Routine where
toJSON (Function sch nam desc params ret vol hasVar _) = JSON.object
toJSON (Function sch nam desc params ret vol hasVar _ tout) = JSON.object
[
"pdSchema" .= sch
, "pdName" .= nam
Expand All @@ -70,6 +71,7 @@ instance JSON.ToJSON Routine where
, "pdReturnType" .= JSON.toJSON ret
, "pdVolatility" .= JSON.toJSON vol
, "pdHasVariadic" .= JSON.toJSON hasVar
, "pdTimeout" .= tout
]

data RoutineParam = RoutineParam
Expand All @@ -83,10 +85,10 @@ data RoutineParam = RoutineParam

-- Order by least number of params in the case of overloaded functions
instance Ord Routine where
Function schema1 name1 des1 prms1 rt1 vol1 hasVar1 iso1 `compare` Function schema2 name2 des2 prms2 rt2 vol2 hasVar2 iso2
Function schema1 name1 des1 prms1 rt1 vol1 hasVar1 iso1 tout1 `compare` Function schema2 name2 des2 prms2 rt2 vol2 hasVar2 iso2 tout2
| schema1 == schema2 && name1 == name2 && length prms1 < length prms2 = LT
| schema2 == schema2 && name1 == name2 && length prms1 > length prms2 = GT
| otherwise = (schema1, name1, des1, prms1, rt1, vol1, hasVar1, iso1) `compare` (schema2, name2, des2, prms2, rt2, vol2, hasVar2, iso2)
| otherwise = (schema1, name1, des1, prms1, rt1, vol1, hasVar1, iso1, tout1) `compare` (schema2, name2, des2, prms2, rt2, vol2, hasVar2, iso2, tout2)

-- | A map of all procs, all of which can be overloaded(one entry will have more than one Routine).
-- | It uses a HashMap for a faster lookup.
Expand Down
8 changes: 8 additions & 0 deletions test/io/fixtures.sql
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,11 @@ $$;
create function terminate_pgrst() returns setof record as $$
select pg_terminate_backend(pid) from pg_stat_activity where application_name iLIKE '%postgrest%';
$$ language sql security definer;

create or replace function one_sec_timeout() returns void as $$
select pg_sleep(3);
$$ language sql set statement_timeout = '1s';

create or replace function four_sec_timeout() returns void as $$
select pg_sleep(3);
$$ language sql set statement_timeout = '4s';
22 changes: 22 additions & 0 deletions test/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,3 +1294,25 @@ def test_no_preflight_request_with_CORS_config_should_not_return_header(defaulte
with run(env=env) as postgrest:
response = postgrest.session.get("/items", headers=headers)
assert "Access-Control-Allow-Origin" not in response.headers


def test_fail_with_3_sec_statement_and_1_sec_statement_timeout(defaultenv):
"statement that takes three seconds to execute should fail with one second timeout"

with run(env=defaultenv) as postgrest:
response = postgrest.session.post("/rpc/one_sec_timeout")

assert response.status_code == 500
assert (
response.text
== '{"code":"57014","details":null,"hint":null,"message":"canceling statement due to statement timeout"}'
)


def test_passes_with_3_sec_statement_and_4_sec_statement_timeout(defaultenv):
"statement that takes three seconds to execute should succeed with four second timeout"

with run(env=defaultenv) as postgrest:
response = postgrest.session.post("/rpc/four_sec_timeout")

assert response.status_code == 204

0 comments on commit 125f10a

Please sign in to comment.