diff --git a/plutus-core/plutus-core.cabal b/plutus-core/plutus-core.cabal index dc4b4d7997f..5cbdd4fe325 100644 --- a/plutus-core/plutus-core.cabal +++ b/plutus-core/plutus-core.cabal @@ -538,6 +538,7 @@ library plutus-ir PlutusIR.Transform.Rename PlutusIR.Transform.RewriteRules PlutusIR.Transform.RewriteRules.CommuteFnWithConst + PlutusIR.Transform.RewriteRules.RemoveTrace PlutusIR.Transform.StrictifyBindings PlutusIR.Transform.Substitute PlutusIR.Transform.ThunkRecursions diff --git a/plutus-core/plutus-ir/src/PlutusIR/Transform/RewriteRules/RemoveTrace.hs b/plutus-core/plutus-ir/src/PlutusIR/Transform/RewriteRules/RemoveTrace.hs new file mode 100644 index 00000000000..2ce86849eec --- /dev/null +++ b/plutus-core/plutus-ir/src/PlutusIR/Transform/RewriteRules/RemoveTrace.hs @@ -0,0 +1,17 @@ +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PatternSynonyms #-} + +module PlutusIR.Transform.RewriteRules.RemoveTrace + ( rewriteRuleRemoveTrace + ) where + +import PlutusCore.Default (DefaultFun) +import PlutusCore.Default.Builtins qualified as Builtin +import PlutusIR.Transform.RewriteRules.Common (pattern A, pattern B, pattern I) +import PlutusIR.Transform.RewriteRules.Internal (RewriteRules (..)) + +rewriteRuleRemoveTrace :: RewriteRules uni DefaultFun +rewriteRuleRemoveTrace = RewriteRules \_varsInfo -> \case + B Builtin.Trace `I` _ty `A` _msg `A` arg -> pure arg + term -> pure term diff --git a/plutus-tx-plugin/src/PlutusTx/Compiler/Builtins.hs b/plutus-tx-plugin/src/PlutusTx/Compiler/Builtins.hs index 103111167d7..077f99932eb 100644 --- a/plutus-tx-plugin/src/PlutusTx/Compiler/Builtins.hs +++ b/plutus-tx-plugin/src/PlutusTx/Compiler/Builtins.hs @@ -46,7 +46,7 @@ import GHC.Types.TyThing qualified as GHC import Language.Haskell.TH.Syntax qualified as TH -import Control.Monad.Reader (ask, asks) +import Control.Monad.Reader (asks) import Data.ByteString qualified as BS import Data.Foldable (for_) @@ -301,8 +301,6 @@ defineBuiltinType name ty = do -- | Add definitions for all the builtin terms to the environment. defineBuiltinTerms :: CompilingDefault uni fun m ann => m () defineBuiltinTerms = do - CompileContext {ccOpts=compileOpts} <- ask - -- Error -- See Note [Delaying error] func <- delayedErrorFunc @@ -380,28 +378,7 @@ defineBuiltinTerms = do PLC.EqualsInteger -> defineBuiltinInl 'Builtins.equalsInteger -- Tracing - -- When `remove-trace` is specified, we define `trace` as `\_ a -> a` instead of the - -- version. - PLC.Trace -> do - (traceTerm, ann) <- - if coRemoveTrace compileOpts - then liftQuote $ do - ta <- freshTyName "a" - t <- freshName "t" - a <- freshName "a" - pure - ( PIR.tyAbs annMayInline ta (PLC.Type annMayInline) $ - PIR.mkIterLamAbs - [ PIR.VarDecl annMayInline t $ - PIR.mkTyBuiltin @_ @Text annMayInline - , PIR.VarDecl annMayInline a $ - PLC.TyVar annMayInline ta - ] - $ PIR.Var annMayInline a - , annMayInline - ) - else pure (mkBuiltin PLC.Trace, annMayInline) - defineBuiltinTerm ann 'Builtins.trace traceTerm + PLC.Trace -> defineBuiltinInl 'Builtins.trace -- Pairs PLC.FstPair -> defineBuiltinInl 'Builtins.fst diff --git a/plutus-tx-plugin/src/PlutusTx/Plugin.hs b/plutus-tx-plugin/src/PlutusTx/Plugin.hs index 38370932e28..e04b53f6311 100644 --- a/plutus-tx-plugin/src/PlutusTx/Plugin.hs +++ b/plutus-tx-plugin/src/PlutusTx/Plugin.hs @@ -79,13 +79,16 @@ import Data.ByteString qualified as BS import Data.ByteString.Unsafe qualified as BSUnsafe import Data.Either.Validation import Data.Map qualified as Map +import Data.Monoid.Extra (mwhen) import Data.Set qualified as Set import Data.Type.Bool qualified as PlutusTx.Bool import GHC.Num.Integer qualified +import PlutusCore.Default (DefaultFun, DefaultUni) import PlutusIR.Analysis.Builtins import PlutusIR.Compiler.Provenance (noProvenance, original) import PlutusIR.Compiler.Types qualified as PIR import PlutusIR.Transform.RewriteRules +import PlutusIR.Transform.RewriteRules.RemoveTrace (rewriteRuleRemoveTrace) import Prettyprinter qualified as PP import System.IO (openTempFile) import System.IO.Unsafe (unsafePerformIO) @@ -423,7 +426,7 @@ compileMarkedExpr locStr codeTy origE = do ccBuiltinsInfo = def, ccBuiltinCostModel = def, ccDebugTraceOn = _posDumpCompilationTrace opts, - ccRewriteRules = def + ccRewriteRules = makeRewriteRules opts } st = CompileState 0 mempty -- See Note [Occurrence analysis] @@ -482,6 +485,9 @@ runCompiler moduleName opts expr = do PIR.DatatypeComponent PIR.Destructor _ -> True _ -> AlwaysInline `elem` fmap annInline (toList ann) + + rewriteRules <- asks ccRewriteRules + -- Compilation configuration -- pir's tc-config is based on plc tcconfig let pirTcConfig = PIR.PirTCConfig plcTcConfig PIR.YesEscape @@ -524,6 +530,7 @@ runCompiler moduleName opts expr = do -- TODO: ensure the same as the one used in the plugin & set PIR.ccBuiltinsInfo def & set PIR.ccBuiltinCostModel def + & set PIR.ccRewriteRules rewriteRules plcOpts = PLC.defaultCompilationOpts & set (PLC.coSimplifyOpts . UPLC.soMaxSimplifierIterations) (opts ^. posMaxSimplifierIterationsUPlc) @@ -642,3 +649,10 @@ makePrimitiveNameInfo names = do thing <- lift . lift $ GHC.lookupThing ghcName pure (name, thing) pure $ Map.fromList infos + +makeRewriteRules :: PluginOptions -> RewriteRules DefaultUni DefaultFun +makeRewriteRules options = + fold + [ mwhen (options ^. posRemoveTrace) rewriteRuleRemoveTrace + , defaultUniRewriteRules + ]