diff --git a/plutus-core/plutus-ir/src/PlutusIR/Transform/Inline/CallSiteInline.hs b/plutus-core/plutus-ir/src/PlutusIR/Transform/Inline/CallSiteInline.hs index f46ba90b696..02b3eccfabe 100644 --- a/plutus-core/plutus-ir/src/PlutusIR/Transform/Inline/CallSiteInline.hs +++ b/plutus-core/plutus-ir/src/PlutusIR/Transform/Inline/CallSiteInline.hs @@ -12,11 +12,10 @@ See note [Inlining of fully applied functions]. module PlutusIR.Transform.Inline.CallSiteInline where +import Control.Lens (forMOf) import Control.Monad.State -import Data.Map.Strict qualified as Map import PlutusIR.Core import PlutusIR.Transform.Inline.Utils -import Prettyprinter {- Note [Inlining of fully applied functions] @@ -148,16 +147,16 @@ We may want to check their sizes instead of just rejecting them. -- | Computes the 'Arity' of a term. computeArity :: Term tyname name uni fun ann - -> Arity + -> (Arity, Term tyname name uni fun ann) computeArity = \case - LamAbs _ _ _ body -> MkTerm : computeArity body - TyAbs _ _ _ body -> MkType : computeArity body + LamAbs _ _ _ body -> (MkTerm : fst (computeArity body), body) + TyAbs _ _ _ body -> (MkType : fst (computeArity body), body) -- Whenever we encounter a body that is not a lambda or type abstraction, we are done counting - _ -> [] + tm -> ([],tm) -- | Inline fully applied functions iff the body of the function is `acceptable`. -considerInline :: - Term tyname name uni fun ann -- the variable that is a function +considerInline :: forall tyname name uni fun ann. InliningConstraints tyname name uni fun + => Term tyname name uni fun ann -- the variable that is a function -> InlineM tyname name uni fun ann (Term tyname name uni fun ann) considerInline v@(Var ann n) = do -- look up the variable in the `CalledVar` map @@ -166,14 +165,16 @@ considerInline v@(Var ann n) = do -- if it's not in the map, it's not a function, don't inline. Nothing -> pure v Just info -> do - let subst = calledVarDef info -- what we substitute in is its definition - isAcceptable <- acceptable subst + let + -- subst = calledVarDef info -- what we substitute in is its definition + bodyToCheckAcceptable = calledVarBody info + isAcceptable <- acceptable bodyToCheckAcceptable -- if the size and cost are not acceptable, don't inline if not isAcceptable then pure v -- if the size and cost are acceptable, then check if it's fully applied -- See note [Identifying fully applied call sites]. - else do - pure v + else + inlineSat v considerInline _notVar = -- this should not happen Prelude.error "considerInline: should be a variable." @@ -189,32 +190,41 @@ considerInline _notVar = -- this should not happen -- sites. -- type ApplicationMap ann name = Map.Map name [ApplicationOrder ann] +-- | A term or type argument. data Args tyname name uni fun ann = MkTermArg (Term tyname name uni fun ann) | MkTypeArg (Type tyname uni ann) +-- | A list of type or term argument(s) being applied. type ArgOrder tyname name uni fun ann = [Args tyname name uni fun ann] +-- | A pair of argument and the annotation of the term being applied to, +-- so the term can be built back in `mkApps`. +type ArgOrderWithAnn tyname name uni fun ann = + [(Args tyname name uni fun ann, ann)] + -- | Takes a term or type application expression and returns the function -- being applied and the arguments to which it is applied collectArgs :: Term tyname name uni fun ann - -> (Term tyname name uni fun ann, ArgOrder tyname name uni fun ann) + -> (Term tyname name uni fun ann, ArgOrderWithAnn tyname name uni fun ann) collectArgs expr = go expr [] where - go (Apply _ f a) as = go f (MkTermArg a:as) - go (TyInst _ f tyArg) as = go f (MkTypeArg tyArg:as) - go e as = (e, as) + go (Apply ann f a) as = go f ((MkTermArg a, ann):as) + go (TyInst ann f tyArg) as = go f ((MkTypeArg tyArg, ann):as) + go e as = (e, as) -- | Apply a list of term and type arguments to a function in potentially a nested fashion. mkApps :: Term tyname name uni fun ann - -> ArgOrder tyname name uni fun ann + -> ArgOrderWithAnn tyname name uni fun ann -> Term tyname name uni fun ann -mkApps f (MkTermArg tmArg : args) = Apply f args +mkApps f ((MkTermArg tmArg,ann) : args) = mkApps (Apply ann f tmArg) args +mkApps f ((MkTypeArg tyArg,ann) : args) = mkApps (TyInst ann f tyArg) args +mkApps f [] = f enoughArgs :: Arity -> ArgOrder tyname name uni fun ann -> Bool -enoughArgs [] argsOrder = True -enoughArgs arity [] = False +enoughArgs [] _argsOrder = True +enoughArgs _arity [] = False enoughArgs lamOrder argsOrder = -- start comparing from the end because there may be over-application case (last lamOrder, last argsOrder) of @@ -223,67 +233,68 @@ enoughArgs lamOrder argsOrder = _ -> False -- | Inline fully applied functions. See note [Identifying fully applied call sites]. -inlineFullyApplied :: forall tyname name uni fun ann. InliningConstraints tyname name uni fun +inlineSat :: forall tyname name uni fun ann. InliningConstraints tyname name uni fun => Term tyname name uni fun ann -- ^ The `body` of the `Let` term. -> InlineM tyname name uni fun ann (Term tyname name uni fun ann) --- If the term is a term application, get the `AppOrder` of the -inlineFullyApplied appTerm@(Apply _ fun arg) = do +-- If the term is a term application, see if we it's applying to something that we may inline +inlineSat appTerm@(Apply _varAnn _fun _arg) = do -- collect all the arguments of the term being applied to let argsAppliedTo = fst $ collectArgs appTerm args = snd $ collectArgs appTerm case argsAppliedTo of -- if it is a `Var` that is being applied to, check to see if it's fully applied - Var _ name -> do + Var _ann name -> do maybeVarInfo <- gets (lookupCalled name) case maybeVarInfo of - Nothing -> pure $ Apply _ (go fun) (go arg) + -- the variable is not in the map that contains all the in-scope functions, this shouldn't + -- happen? TODO maybe error out instead? + Nothing -> forMOf termSubterms appTerm inlineSat Just varInfo -> do - if enoughArgs (arity varInfo) args then + if enoughArgs (arity varInfo) (map fst args) then -- if the `Var` is fully applied (over-application is allowed) then inline it - mkApps (calledVarDef varInfo) (go <$> args) - else pure $ Apply _ (go fun) (go arg) -- otherwise just keep going - -- if the term being applied is not a `Var`, just keep going - _ -> pure $ Apply _ (go fun) (go arg) -inlineFullyApplied (TyInst _ fnBody _) = - -- If the term is a type application, add it to the application stack, and - -- keep on examining the body. - countLocal (appStack <> [MkType]) calledStack fnBody -inlineFullyApplied tm = pure tm - --- (Var ann name) = --- -- When we encounter a body that is a variable, we have found a call site of it. --- -- Using `insertWith` ensures that if a variable is called more than once, the new --- -- `ApplicationOrder` map will be appended to the existing one. --- Map.insertWith (<>) name [MkApplicationOrder ann appStack] calledStack --- go (Let _ _ bds letBody) = --- -- recursive or not, the bindings of this let term *may* contain the variable in --- -- question, so we need to check all the bindings and also the body --- let --- -- get the list of rhs's of the term bindings --- getRHS :: Binding tyname name uni fun ann -> Maybe (Term tyname name uni fun ann) --- getRHS (TermBind _ _ _ rhs) = Just rhs --- getRHS _ = --- -- no need to keep track of the type bindings. Even though this type variable --- -- called in the body, it does not affect the resulting `ApplicationMap` --- Nothing --- listOfRHSOfBindings = mapMaybe getRHS (toList bds) --- in --- foldr (flip $ countLocal []) (countLocal [] calledStack letBody) listOfRHSOfBindings --- go (TyAbs _ _ _ tyAbsBody) = --- -- start count in the body of the type lambda abstraction --- countLocal [] calledStack tyAbsBody --- go (LamAbs _ _ _ fnBody) = --- -- start the count in the body of the term lambda abstraction --- countLocal [] calledStack fnBody --- go (Constant _ _) = --- calledStack -- constants cannot call the variable --- go (Builtin _ _) = --- -- default builtin functions in `PlutusCore/Default/Builtins.hs` --- -- cannot call the variable --- calledStack --- go (Unwrap _ tm) = --- countLocal [] calledStack tm --- go (IWrap _ _ _ tm) = --- countLocal [] calledStack tm --- go (Error _ _) = calledStack - + pure $ mkApps (calledVarDef varInfo) args + -- otherwise just keep going + else forMOf termSubterms appTerm inlineSat + -- if the term being applied is not a `Var`, don't inline, but keep checking + v -> forMOf termSubterms v inlineSat -- keep checking all subterms +inlineSat tyInstTerm@(TyInst varAnn fun arg) = do + -- collect all the arguments of the term being applied to + let argsAppliedTo = fst $ collectArgs tyInstTerm + args = snd $ collectArgs tyInstTerm + case argsAppliedTo of + -- if it is a `Var` that is being applied to, check to see if it's fully applied + Var _ann name -> do + maybeVarInfo <- gets (lookupCalled name) + case maybeVarInfo of + Nothing -> forMOf termSubterms tyInstTerm inlineSat + Just varInfo -> do + if enoughArgs (arity varInfo) (map fst args) then + -- if the `Var` is fully applied (over-application is allowed) then inline it + pure $ mkApps (calledVarDef varInfo) args + -- otherwise just keep going + else forMOf termSubterms tyInstTerm inlineSat + -- if the term being applied is not a `Var`, don't inline but keep checking the subterms + v -> forMOf termSubterms v inlineSat +inlineSat letTm@(Let _ _ bds _letBody) = do + -- recursive or not, the bindings of this let term *may* contain a saturated function, + -- so we need to check all the bindings and also the body + -- `PlutusIR.Core.Plated.termSubterms` gives all that + forMOf termSubterms letTm inlineSat +inlineSat (TyAbs _ _ _ tyAbsBody) = + -- start count in the body of the type lambda abstraction + inlineSat tyAbsBody +inlineSat (LamAbs _ _ _ fnBody) = + -- start the count in the body of the term lambda abstraction + inlineSat fnBody +inlineSat con@(Constant _ _) = + -- constants cannot call the variable + pure con +inlineSat bi@(Builtin _ _) = + -- default builtin functions in `PlutusCore/Default/Builtins.hs` + -- cannot call the variable + pure bi +inlineSat v@(Var _ _) = + -- variables being applied should have been checked already, these ones aren't fully applied. + -- We don't inline them. + pure v +inlineSat others = pure others diff --git a/plutus-core/plutus-ir/src/PlutusIR/Transform/Inline/UnconditionalInline.hs b/plutus-core/plutus-ir/src/PlutusIR/Transform/Inline/UnconditionalInline.hs index ce05386a02b..4ca559aa142 100644 --- a/plutus-core/plutus-ir/src/PlutusIR/Transform/Inline/UnconditionalInline.hs +++ b/plutus-core/plutus-ir/src/PlutusIR/Transform/Inline/UnconditionalInline.hs @@ -34,6 +34,7 @@ import Control.Monad.State import Algebra.Graph qualified as G import Data.Map qualified as Map +import PlutusIR.Transform.Inline.CallSiteInline import Witherable (Witherable (wither)) {- Note [Inlining approach and 'Secrets of the GHC Inliner'] @@ -168,7 +169,7 @@ processTerm = handleTerm <=< traverseOf termSubtypes applyTypeSubstitution where Nothing -> do considerInline v -- If it's in the substitution map, do the substitution - Just v -> pure v + Just var -> pure var Let ann NonRec bs t -> do -- Process bindings, eliminating those which will be inlined unconditionally, -- and accumulating the new substitutions @@ -184,7 +185,8 @@ processTerm = handleTerm <=< traverseOf termSubtypes applyTypeSubstitution where -- This includes recursive let terms, we don't even consider inlining them at the moment t -> forMOf termSubterms t processTerm -applyTypeSubstitution :: Type tyname uni ann +applyTypeSubstitution :: forall tyname name uni fun ann. InliningConstraints tyname name uni fun + => Type tyname uni ann -> InlineM tyname name uni fun ann (Type tyname uni ann) applyTypeSubstitution t = gets isTypeSubstEmpty >>= \case -- The type substitution is very often empty, and there are lots of types in the program, @@ -193,17 +195,21 @@ applyTypeSubstitution t = gets isTypeSubstEmpty >>= \case _ -> typeSubstTyNamesM substTyName t -- See Note [Renaming strategy] -substTyName :: tyname -> InlineM tyname name uni fun ann (Maybe (Type tyname uni ann)) +substTyName :: forall tyname name uni fun ann. InliningConstraints tyname name uni fun + => tyname + -> InlineM tyname name uni fun ann (Maybe (Type tyname uni ann)) substTyName tyname = gets (lookupType tyname) >>= traverse liftDupable -- See Note [Renaming strategy] -substName :: name -> InlineM tyname name uni fun ann (Maybe (Term tyname name uni fun ann)) +substName :: forall tyname name uni fun ann. InliningConstraints tyname name uni fun + => name + -> InlineM tyname name uni fun ann (Maybe (Term tyname name uni fun ann)) substName name = gets (lookupTerm name) >>= traverse renameTerm -- See Note [Inlining approach and 'Secrets of the GHC Inliner'] -- Already processed term, just rename and put it in, don't do any further optimization here. -renameTerm :: - InlineTerm tyname name uni fun ann +renameTerm :: forall tyname name uni fun ann. InliningConstraints tyname name uni fun + => InlineTerm tyname name uni fun ann -> InlineM tyname name uni fun ann (Term tyname name uni fun ann) renameTerm (Done t) = liftDupable t @@ -228,50 +234,25 @@ processSingleBinding body = \case TermBind ann s v@(VarDecl _ n (TyFun _ _tyArg _tyBody)) rhs -> do let -- track the term and type lambda abstraction order of the function - varLamOrder = computeArity rhs - -- examine the `body` of the `Let` term and track all term/type applications. - appSites = countApp body - -- list of all call sites of this variable - listOfCallSites = Map.lookup n appSites - case listOfCallSites of - Nothing -> - -- we don't remove the binding because we decide *at the call site* whether we want to - -- inline, and it may be called more than once - pure $ TermBind ann s v rhs - Just list -> do - let - isEqAppOrder :: ApplicationOrder ann -> Bool - isEqAppOrder appOrder = applicationOrder appOrder == varLamOrder - -- filter the list to only call locations that are fully applied - filteredFullyApplied = filter isEqAppOrder list - fullyAppliedAnns = fmap annotation filteredFullyApplied - -- add the function to `CalledVarEnv` - void $ modify' $ extendCalled n (MkCalledVarInfo rhs varLamOrder fullyAppliedAnns) - pure $ TermBind ann s v rhs + varLamOrder = fst $ computeArity rhs + bodyToCheck = snd $ computeArity rhs + -- add the function to `CalledVarEnv` + void $ modify' $ extendCalled n (MkCalledVarInfo rhs varLamOrder bodyToCheck) + -- we still want to do unconditional inline + maybeRhs' <- maybeAddSubst body ann s n rhs + pure $ TermBind ann s v <$> maybeRhs' -- when the let binding is a type lambda abstraction, we add it to the `CalledVarEnv` and -- consider whether we want to inline at the call site. TermBind ann s v@(VarDecl _ n (TyLam _ann _tyname _tyArg _tyBody)) rhs -> do - let varLamOrder = countLam rhs - appSites = countApp body - listOfCallSites = Map.lookup n appSites - case listOfCallSites of - Nothing -> - -- we don't remove the binding because we decide *at the call site* whether we want to - -- inline, and it may be called more than once - pure $ TermBind ann s v rhs - Just list -> do - let - isEqAppOrder :: ApplicationOrder ann -> Bool - isEqAppOrder appOrder = applicationOrder appOrder == varLamOrder - -- filter the list to only call locations that are fully applied - filteredFullyApplied = filter isEqAppOrder list - fullyAppliedAnns = fmap annotation filteredFullyApplied - -- add the function to `CalledVarEnv` - -- add the type abstraction to `CalledVarEnv` - void $ modify' $ extendCalled n (MkCalledVarInfo rhs varLamOrder fullyAppliedAnns) - -- we don't remove the binding because we decide *at the call site* whether we want to - -- inline, and it may be called more than once - pure $ TermBind ann s v rhs + let varLamOrder = fst $ computeArity rhs + bodyToCheck = snd $ computeArity rhs + -- add the function to `CalledVarEnv` + -- add the type abstraction to `CalledVarEnv` + void $ modify' $ extendCalled n (MkCalledVarInfo rhs varLamOrder bodyToCheck) + -- we don't remove the binding because we decide *at the call site* whether we want to + -- inline, and it may be called more than once + maybeRhs' <- maybeAddSubst body ann s n rhs + pure $ TermBind ann s v <$> maybeRhs' -- for binding that aren't functions, maybe do unconditional inline TermBind ann s v@(VarDecl _ n _) rhs -> do maybeRhs' <- maybeAddSubst body ann s n rhs diff --git a/plutus-core/plutus-ir/test/TransformSpec.hs b/plutus-core/plutus-ir/test/TransformSpec.hs index 9fe9bfa825a..b3da1662565 100644 --- a/plutus-core/plutus-ir/test/TransformSpec.hs +++ b/plutus-core/plutus-ir/test/TransformSpec.hs @@ -214,7 +214,7 @@ inline = computeArityTest :: TestNested computeArityTest = testNested "computeArityTest" $ map - (goldenPir (computeArity . runQuote . PLC.rename) pTerm) + (goldenPir (fst . computeArity . runQuote . PLC.rename) pTerm) [ "var" -- from inline tests, testing let terms , "tyvar" , "single"