Skip to content

Commit

Permalink
perf: change Text queries to ByteString
Browse files Browse the repository at this point in the history
Improves performance by not utf8 encoding the whole query with
encodeUtf8. Only certain parts.

It's also a gradual step needed to use the Snippet type
from hasql-dynamic-statements.
  • Loading branch information
steve-chavez committed Oct 17, 2020
1 parent 7e3e19a commit d1d0c67
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 85 deletions.
2 changes: 1 addition & 1 deletion src/PostgREST/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ app dbStructure proc cols conf apiRequest =
, Just $ contentRangeH 1 0 $ if shouldCount then Just queryTotal else Nothing
, if null pkCols && isNothing (iOnConflict apiRequest)
then Nothing
else (\x -> ("Preference-Applied", encodeUtf8 (show x))) <$> iPreferResolution apiRequest
else (\x -> ("Preference-Applied", BS.pack (show x))) <$> iPreferResolution apiRequest
] ++ ctHeaders)) (unwrapGucHeader <$> ghdrs)
if contentType == CTSingularJSON && queryTotal /= 1
then do
Expand Down
7 changes: 4 additions & 3 deletions src/PostgREST/DbRequestBuilder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,14 @@ readRequest schema rootTableName maxRows allRels apiRequest =
rootWithRels :: Schema -> TableName -> [Relation] -> Action -> (QualifiedIdentifier, [Relation])
rootWithRels schema rootTableName allRels action = case action of
ActionRead _ -> (QualifiedIdentifier schema rootTableName, allRels) -- normal read case
_ -> (QualifiedIdentifier mempty sourceCTEName, mapMaybe toSourceRel allRels ++ allRels) -- mutation cases and calling proc
_ -> (QualifiedIdentifier mempty _sourceCTEName, mapMaybe toSourceRel allRels ++ allRels) -- mutation cases and calling proc
where
_sourceCTEName = decodeUtf8 sourceCTEName
-- To enable embedding in the sourceCTEName cases we need to replace the foreign key tableName in the Relation
-- with {sourceCTEName}. This way findRel can find relationships with sourceCTEName.
toSourceRel :: Relation -> Maybe Relation
toSourceRel r@Relation{relTable=t}
| rootTableName == tableName t = Just $ r {relTable=t {tableName=sourceCTEName}}
| rootTableName == tableName t = Just $ r {relTable=t {tableName=_sourceCTEName}}
| otherwise = Nothing

-- Build the initial tree with a Depth attribute so when a self join occurs we can differentiate the parent and child tables by having
Expand Down Expand Up @@ -228,7 +229,7 @@ getJoinConditions previousAlias newAlias (Relation Table{tableSchema=tSchema, ta
-- if this happens remove the schema `FROM "schema"."{sourceCTEName}"` and use only the
-- `FROM "{sourceCTEName}"`. If the schema remains the FROM would be invalid.
removeSourceCTESchema :: Schema -> TableName -> QualifiedIdentifier
removeSourceCTESchema schema tbl = QualifiedIdentifier (if tbl == sourceCTEName then mempty else schema) tbl
removeSourceCTESchema schema tbl = QualifiedIdentifier (if tbl == decodeUtf8 sourceCTEName then mempty else schema) tbl

addFiltersOrdersRanges :: ApiRequest -> ReadRequest -> Either ApiRequestError ReadRequest
addFiltersOrdersRanges apiRequest rReq = do
Expand Down
80 changes: 41 additions & 39 deletions src/PostgREST/Private/QueryFragment.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ Any function that outputs a SqlFragment should be in this module.
-}
module PostgREST.Private.QueryFragment where

import qualified Data.ByteString.Char8 as BS (intercalate,
pack, unwords)
import qualified Data.HashMap.Strict as HM
import Data.Maybe
import Data.Text (intercalate,
isInfixOf, replace,
toLower)
import qualified Data.Text as T (map, null,
takeWhile)
import qualified Data.Text as T (intercalate,
isInfixOf, map,
null, replace,
takeWhile,
toLower)
import PostgREST.Types
import Protolude hiding (cast,
intercalate, replace,
Expand All @@ -36,7 +38,7 @@ ignoredBody = "pgrst_ignored_body AS (SELECT $1::text) "
-- We do this in SQL to avoid processing the JSON in application code
normalizedBody :: SqlFragment
normalizedBody =
unwords [
BS.unwords [
"pgrst_payload AS (SELECT $1::json AS json_data),",
"pgrst_body AS (",
"SELECT",
Expand All @@ -49,17 +51,20 @@ normalizedBody =
selectBody :: SqlFragment
selectBody = "(SELECT val FROM pgrst_body)"

pgFmtLit :: SqlFragment -> SqlFragment
pgFmtLit :: Text -> SqlFragment
pgFmtLit x =
let trimmed = trimNullChars x
escaped = "'" <> replace "'" "''" trimmed <> "'"
slashed = replace "\\" "\\\\" escaped in
if "\\" `isInfixOf` escaped
escaped = "'" <> T.replace "'" "''" trimmed <> "'"
slashed = T.replace "\\" "\\\\" escaped in
encodeUtf8 $ if "\\" `T.isInfixOf` escaped
then "E" <> slashed
else slashed

pgFmtIdent :: SqlFragment -> SqlFragment
pgFmtIdent x = "\"" <> replace "\"" "\"\"" (trimNullChars $ toS x) <> "\""
pgFmtIdent :: Text -> SqlFragment
pgFmtIdent x = encodeUtf8 $ "\"" <> T.replace "\"" "\"\"" (trimNullChars x) <> "\""

trimNullChars :: Text -> Text
trimNullChars = T.takeWhile (/= '\x0')

asCsvF :: SqlFragment
asCsvF = asCsvHeaderF <> " || '\n' || " <> asCsvBodyF
Expand Down Expand Up @@ -92,16 +97,16 @@ locationF pKeys = [qc|(
WHERE json_data.key IN ('{fmtPKeys}')
)|]
where
fmtPKeys = intercalate "','" pKeys
fmtPKeys = T.intercalate "','" pKeys

fromQi :: QualifiedIdentifier -> SqlFragment
fromQi t = (if s == "" then "" else pgFmtIdent s <> ".") <> pgFmtIdent n
fromQi t = (if T.null s then mempty else pgFmtIdent s <> ".") <> pgFmtIdent n
where
n = qiName t
s = qiSchema t

emptyOnFalse :: Text -> Bool -> Text
emptyOnFalse val cond = if cond then "" else val
emptyOnFalse :: SqlFragment -> Bool -> SqlFragment
emptyOnFalse val cond = if cond then mempty else val

pgFmtColumn :: QualifiedIdentifier -> Text -> SqlFragment
pgFmtColumn table "*" = fromQi table <> ".*"
Expand All @@ -112,13 +117,13 @@ pgFmtField table (c, jp) = pgFmtColumn table c <> pgFmtJsonPath jp

pgFmtSelectItem :: QualifiedIdentifier -> SelectItem -> SqlFragment
pgFmtSelectItem table (f@(fName, jp), Nothing, alias, _) = pgFmtField table f <> pgFmtAs fName jp alias
pgFmtSelectItem table (f@(fName, jp), Just cast, alias, _) = "CAST (" <> pgFmtField table f <> " AS " <> cast <> " )" <> pgFmtAs fName jp alias
pgFmtSelectItem table (f@(fName, jp), Just cast, alias, _) = "CAST (" <> pgFmtField table f <> " AS " <> encodeUtf8 cast <> " )" <> pgFmtAs fName jp alias

pgFmtOrderTerm :: QualifiedIdentifier -> OrderTerm -> SqlFragment
pgFmtOrderTerm qi ot = unwords [
toS . pgFmtField qi $ otTerm ot,
maybe "" show $ otDirection ot,
maybe "" show $ otNullOrder ot]
pgFmtOrderTerm qi ot = BS.unwords [
pgFmtField qi $ otTerm ot,
BS.pack $ maybe mempty show $ otDirection ot,
BS.pack $ maybe mempty show $ otNullOrder ot]

pgFmtFilter :: QualifiedIdentifier -> Filter -> SqlFragment
pgFmtFilter table (Filter fld (OpExpr hasNot oper)) = notOp <> " " <> case oper of
Expand All @@ -131,60 +136,57 @@ pgFmtFilter table (Filter fld (OpExpr hasNot oper)) = notOp <> " " <> case oper
In vals -> pgFmtField table fld <> " " <>
let emptyValForIn = "= any('{}') " in -- Workaround because for postgresql "col IN ()" is invalid syntax, we instead do "col = any('{}')"
case (&&) (length vals == 1) . T.null <$> headMay vals of
Just False -> sqlOperator "in" <> "(" <> intercalate ", " (map unknownLiteral vals) <> ") "
Just False -> sqlOperator "in" <> "(" <> BS.intercalate ", " (unknownLiteral <$> vals) <> ") "
Just True -> emptyValForIn
Nothing -> emptyValForIn

Fts op lang val ->
pgFmtFieldOp op
<> "("
<> maybe "" ((<> ", ") . pgFmtLit) lang
<> maybe mempty ((<> ", ") . pgFmtLit) lang
<> unknownLiteral val
<> ") "
where
pgFmtFieldOp op = pgFmtField table fld <> " " <> sqlOperator op
sqlOperator o = HM.lookupDefault "=" o operators
notOp = if hasNot then "NOT" else ""
notOp = if hasNot then "NOT" else mempty
star c = if c == '*' then '%' else c
unknownLiteral = (<> "::unknown ") . pgFmtLit
whiteList :: Text -> SqlFragment
whiteList v = fromMaybe
(toS (pgFmtLit v) <> "::unknown ")
(find ((==) . toLower $ v) ["null","true","false"])
whiteList v = maybe
(pgFmtLit v <> "::unknown") encodeUtf8
(find ((==) . T.toLower $ v) ["null","true","false"])

pgFmtJoinCondition :: JoinCondition -> SqlFragment
pgFmtJoinCondition (JoinCondition (qi1, col1) (qi2, col2)) =
pgFmtColumn qi1 col1 <> " = " <> pgFmtColumn qi2 col2

pgFmtLogicTree :: QualifiedIdentifier -> LogicTree -> SqlFragment
pgFmtLogicTree qi (Expr hasNot op forest) = notOp <> " (" <> intercalate (" " <> show op <> " ") (pgFmtLogicTree qi <$> forest) <> ")"
where notOp = if hasNot then "NOT" else ""
pgFmtLogicTree qi (Expr hasNot op forest) = notOp <> " (" <> BS.intercalate (" " <> BS.pack (show op) <> " ") (pgFmtLogicTree qi <$> forest) <> ")"
where notOp = if hasNot then "NOT" else mempty
pgFmtLogicTree qi (Stmnt flt) = pgFmtFilter qi flt

pgFmtJsonPath :: JsonPath -> SqlFragment
pgFmtJsonPath = \case
[] -> ""
[] -> mempty
(JArrow x:xs) -> "->" <> pgFmtJsonOperand x <> pgFmtJsonPath xs
(J2Arrow x:xs) -> "->>" <> pgFmtJsonOperand x <> pgFmtJsonPath xs
where
pgFmtJsonOperand (JKey k) = pgFmtLit k
pgFmtJsonOperand (JIdx i) = pgFmtLit i <> "::int"

pgFmtAs :: FieldName -> JsonPath -> Maybe Alias -> SqlFragment
pgFmtAs _ [] Nothing = ""
pgFmtAs _ [] Nothing = mempty
pgFmtAs fName jp Nothing = case jOp <$> lastMay jp of
Just (JKey key) -> " AS " <> pgFmtIdent key
Just (JIdx _) -> " AS " <> pgFmtIdent (fromMaybe fName lastKey)
-- We get the lastKey because on:
-- `select=data->1->mycol->>2`, we need to show the result as [ {"mycol": ..}, {"mycol": ..} ]
-- `select=data->3`, we need to show the result as [ {"data": ..}, {"data": ..} ]
where lastKey = jVal <$> find (\case JKey{} -> True; _ -> False) (jOp <$> reverse jp)
Nothing -> ""
Nothing -> mempty
pgFmtAs _ _ (Just alias) = " AS " <> pgFmtIdent alias

trimNullChars :: Text -> Text
trimNullChars = T.takeWhile (/= '\x0')

countF :: SqlQuery -> Bool -> (SqlFragment, SqlFragment)
countF countQuery shouldCount =
if shouldCount
Expand All @@ -199,21 +201,21 @@ returningF :: QualifiedIdentifier -> [FieldName] -> SqlFragment
returningF qi returnings =
if null returnings
then "RETURNING 1" -- For mutation cases where there's no ?select, we return 1 to know how many rows were modified
else "RETURNING " <> intercalate ", " (pgFmtColumn qi <$> returnings)
else "RETURNING " <> BS.intercalate ", " (pgFmtColumn qi <$> returnings)

responseHeadersF :: PgVersion -> SqlFragment
responseHeadersF pgVer =
if pgVer >= pgVersion96
then currentSettingF "response.headers"
else "null" :: Text
else "null"

responseStatusF :: PgVersion -> SqlFragment
responseStatusF pgVer =
if pgVer >= pgVersion96
then currentSettingF "response.status"
else "null" :: Text
else "null"

currentSettingF :: SqlFragment -> SqlFragment
currentSettingF :: Text -> SqlFragment
currentSettingF setting =
-- nullif is used because of https://gist.github.com/steve-chavez/8d7033ea5655096903f3b52f8ed09a15
"nullif(current_setting(" <> pgFmtLit setting <> ", true), '')"
62 changes: 31 additions & 31 deletions src/PostgREST/QueryBuilder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ module PostgREST.QueryBuilder (
, setLocalSearchPathQuery
) where

import qualified Data.Set as S
import qualified Data.ByteString.Char8 as BS
import qualified Data.Set as S

import Data.Text (intercalate)
import Data.Tree (Tree (..))

import Data.Maybe
Expand All @@ -36,14 +36,14 @@ import Protolude hiding (cast, intercalate,

readRequestToQuery :: ReadRequest -> SqlQuery
readRequestToQuery (Node (Select colSelects mainQi tblAlias implJoins logicForest joinConditions_ ordts range, _) forest) =
unwords [
"SELECT " <> intercalate ", " (map (pgFmtSelectItem qi) colSelects ++ selects),
"FROM " <> intercalate ", " (tabl : implJs),
unwords joins,
("WHERE " <> intercalate " AND " (map (pgFmtLogicTree qi) logicForest ++ map pgFmtJoinCondition joinConditions_))
BS.unwords [
"SELECT " <> BS.intercalate ", " (map (pgFmtSelectItem qi) colSelects ++ selects),
"FROM " <> BS.intercalate ", " (tabl : implJs),
BS.unwords joins,
("WHERE " <> BS.intercalate " AND " (map (pgFmtLogicTree qi) logicForest ++ map pgFmtJoinCondition joinConditions_))
`emptyOnFalse` (null logicForest && null joinConditions_),
("ORDER BY " <> intercalate ", " (map (pgFmtOrderTerm qi) ordts)) `emptyOnFalse` null ordts,
("LIMIT " <> maybe "ALL" show (rangeLimit range) <> " OFFSET " <> show (rangeOffset range)) `emptyOnFalse` (range == allRange)
("ORDER BY " <> BS.intercalate ", " (map (pgFmtOrderTerm qi) ordts)) `emptyOnFalse` null ordts,
("LIMIT " <> maybe "ALL" (BS.pack . show) (rangeLimit range) <> " OFFSET " <> (BS.pack . show) (rangeOffset range)) `emptyOnFalse` (range == allRange)
]
where
implJs = fromQi <$> implJoins
Expand Down Expand Up @@ -71,58 +71,58 @@ getJoinsSelects (Node (_, (_, Nothing, _, _, _)) _) _ = ([], [])

mutateRequestToQuery :: MutateRequest -> SqlQuery
mutateRequestToQuery (Insert mainQi iCols onConflct putConditions returnings) =
unwords [
BS.unwords [
"WITH " <> normalizedBody,
"INSERT INTO ", fromQi mainQi, if S.null iCols then " " else "(" <> cols <> ")",
unwords [
BS.unwords [
"SELECT " <> cols <> " FROM",
"json_populate_recordset", "(null::", fromQi mainQi, ", " <> selectBody <> ") _",
-- Only used for PUT
("WHERE " <> intercalate " AND " (pgFmtLogicTree (QualifiedIdentifier mempty "_") <$> putConditions)) `emptyOnFalse` null putConditions],
("WHERE " <> BS.intercalate " AND " (pgFmtLogicTree (QualifiedIdentifier mempty "_") <$> putConditions)) `emptyOnFalse` null putConditions],
maybe "" (\(oncDo, oncCols) -> (
"ON CONFLICT(" <> intercalate ", " (pgFmtIdent <$> oncCols) <> ") " <> case oncDo of
"ON CONFLICT(" <> BS.intercalate ", " (pgFmtIdent <$> oncCols) <> ") " <> case oncDo of
IgnoreDuplicates ->
"DO NOTHING"
MergeDuplicates ->
if S.null iCols
then "DO NOTHING"
else "DO UPDATE SET " <> intercalate ", " (pgFmtIdent <> const " = EXCLUDED." <> pgFmtIdent <$> S.toList iCols)
else "DO UPDATE SET " <> BS.intercalate ", " (pgFmtIdent <> const " = EXCLUDED." <> pgFmtIdent <$> S.toList iCols)
) `emptyOnFalse` null oncCols) onConflct,
returningF mainQi returnings
]
where
cols = intercalate ", " $ pgFmtIdent <$> S.toList iCols
cols = BS.intercalate ", " $ pgFmtIdent <$> S.toList iCols
mutateRequestToQuery (Update mainQi uCols logicForest returnings) =
if S.null uCols
-- if there are no columns we cannot do UPDATE table SET {empty}, it'd be invalid syntax
-- selecting an empty resultset from mainQi gives us the column names to prevent errors when using &select=
-- the select has to be based on "returnings" to make computed overloaded functions not throw
then "WITH " <> ignoredBody <> "SELECT " <> empty_body_returned_columns <> " FROM " <> fromQi mainQi <> " WHERE false"
else
unwords [
BS.unwords [
"WITH " <> normalizedBody,
"UPDATE " <> fromQi mainQi <> " SET " <> cols,
"FROM (SELECT * FROM json_populate_recordset", "(null::", fromQi mainQi, ", " <> selectBody <> ")) _ ",
("WHERE " <> intercalate " AND " (pgFmtLogicTree mainQi <$> logicForest)) `emptyOnFalse` null logicForest,
("WHERE " <> BS.intercalate " AND " (pgFmtLogicTree mainQi <$> logicForest)) `emptyOnFalse` null logicForest,
returningF mainQi returnings
]
where
cols = intercalate ", " (pgFmtIdent <> const " = _." <> pgFmtIdent <$> S.toList uCols)
cols = BS.intercalate ", " (pgFmtIdent <> const " = _." <> pgFmtIdent <$> S.toList uCols)
empty_body_returned_columns :: SqlFragment
empty_body_returned_columns
| null returnings = "NULL"
| otherwise = intercalate ", " (pgFmtColumn (QualifiedIdentifier mempty $ qiName mainQi) <$> returnings)
| otherwise = BS.intercalate ", " (pgFmtColumn (QualifiedIdentifier mempty $ qiName mainQi) <$> returnings)
mutateRequestToQuery (Delete mainQi logicForest returnings) =
unwords [
BS.unwords [
"WITH " <> ignoredBody,
"DELETE FROM ", fromQi mainQi,
("WHERE " <> intercalate " AND " (map (pgFmtLogicTree mainQi) logicForest)) `emptyOnFalse` null logicForest,
("WHERE " <> BS.intercalate " AND " (map (pgFmtLogicTree mainQi) logicForest)) `emptyOnFalse` null logicForest,
returningF mainQi returnings
]

requestToCallProcQuery :: QualifiedIdentifier -> [PgArg] -> Bool -> Maybe PreferParameters -> [FieldName] -> SqlQuery
requestToCallProcQuery qi pgArgs returnsScalar preferParams returnings =
unwords [
BS.unwords [
"WITH",
argsCTE,
sourceBody ]
Expand All @@ -134,25 +134,25 @@ requestToCallProcQuery qi pgArgs returnsScalar preferParams returnings =
| null pgArgs = (ignoredBody, "")
| paramsAsSingleObject = ("pgrst_args AS (SELECT NULL)", "$1::json")
| otherwise = (
unwords [
BS.unwords [
normalizedBody <> ",",
"pgrst_args AS (",
"SELECT * FROM json_to_recordset(" <> selectBody <> ") AS _(" <> fmtArgs (\a -> " " <> pgaType a) <> ")",
"SELECT * FROM json_to_recordset(" <> selectBody <> ") AS _(" <> fmtArgs (\a -> " " <> encodeUtf8 (pgaType a)) <> ")",
")"]
, if paramsAsMultipleObjects
then fmtArgs (\a -> " := pgrst_args." <> pgFmtIdent (pgaName a))
else fmtArgs (\a -> " := (SELECT " <> pgFmtIdent (pgaName a) <> " FROM pgrst_args LIMIT 1)")
)

fmtArgs :: (PgArg -> SqlFragment) -> SqlFragment
fmtArgs argFrag = intercalate ", " ((\a -> pgFmtIdent (pgaName a) <> argFrag a) <$> pgArgs)
fmtArgs argFrag = BS.intercalate ", " ((\a -> pgFmtIdent (pgaName a) <> argFrag a) <$> pgArgs)

sourceBody :: SqlFragment
sourceBody
| paramsAsMultipleObjects =
if returnsScalar
then "SELECT " <> callIt <> " AS pgrst_scalar FROM pgrst_args"
else unwords [ "SELECT pgrst_lat_args.*"
else BS.unwords [ "SELECT pgrst_lat_args.*"
, "FROM pgrst_args,"
, "LATERAL ( SELECT " <> returned_columns <> " FROM " <> callIt <> " ) pgrst_lat_args" ]
| otherwise =
Expand All @@ -166,7 +166,7 @@ requestToCallProcQuery qi pgArgs returnsScalar preferParams returnings =
returned_columns :: SqlFragment
returned_columns
| null returnings = "*"
| otherwise = intercalate ", " (pgFmtColumn (QualifiedIdentifier mempty $ qiName qi) <$> returnings)
| otherwise = BS.intercalate ", " (pgFmtColumn (QualifiedIdentifier mempty $ qiName qi) <$> returnings)


-- | SQL query meant for COUNTing the root node of the Tree.
Expand All @@ -175,19 +175,19 @@ requestToCallProcQuery qi pgArgs returnsScalar preferParams returnings =
-- inside the FROM target.
readRequestToCountQuery :: ReadRequest -> SqlQuery
readRequestToCountQuery (Node (Select{from=qi, where_=logicForest}, _) _) =
unwords [
BS.unwords [
"SELECT 1",
"FROM " <> fromQi qi,
("WHERE " <> intercalate " AND " (map (pgFmtLogicTree qi) logicForest)) `emptyOnFalse` null logicForest
("WHERE " <> BS.intercalate " AND " (map (pgFmtLogicTree qi) logicForest)) `emptyOnFalse` null logicForest
]

limitedQuery :: SqlQuery -> Maybe Integer -> SqlQuery
limitedQuery query maxRows = query <> maybe mempty (\x -> " LIMIT " <> show x) maxRows
limitedQuery query maxRows = query <> maybe mempty (\x -> " LIMIT " <> BS.pack (show x)) maxRows

setLocalQuery :: Text -> (Text, Text) -> SqlQuery
setLocalQuery prefix (k, v) =
"SET LOCAL " <> pgFmtIdent (prefix <> k) <> " = " <> pgFmtLit v <> ";"

setLocalSearchPathQuery :: [Text] -> SqlQuery
setLocalSearchPathQuery vals =
"SET LOCAL search_path = " <> intercalate ", " (pgFmtLit <$> vals) <> ";"
"SET LOCAL search_path = " <> BS.intercalate ", " (pgFmtLit <$> vals) <> ";"
Loading

0 comments on commit d1d0c67

Please sign in to comment.