diff --git a/plutus-core/plutus-ir/src/PlutusIR/Transform/CommuteConst.hs b/plutus-core/plutus-ir/src/PlutusIR/Transform/CommuteConst.hs index ff6ca0a9b87..6cef7daf0b0 100644 --- a/plutus-core/plutus-ir/src/PlutusIR/Transform/CommuteConst.hs +++ b/plutus-core/plutus-ir/src/PlutusIR/Transform/CommuteConst.hs @@ -1,11 +1,12 @@ {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeApplications #-} module PlutusIR.Transform.CommuteConst (commuteConst) where -import Data.Typeable -import PlutusCore qualified as PLC -import PlutusIR +import Data.Typeable (Typeable, eqT) +import PlutusCore.Default +import PlutusIR.Core (Term (Apply, Builtin, Constant)) {- | Commute such that constants are the first arguments. Consider: @@ -28,28 +29,95 @@ invocations of `equalsInteger`. So the second one is harder to share, which is w So commuting `equalsInteger` so that it has the constant first both a) makes various occurrences of `equalsInteger` more likely to look similar, and b) gives us a maximally-shareable node for CSE. -This applies to any commutative builtin function, although we might expect that `equalsInteger` is -the one that will benefit the most. Plutonomy only commutes `EqualsInteger`. +This applies to any commutative builtin function that takes constants as arguments, although we +might expect that `equalsInteger` is the one that will benefit the most. +Plutonomy only commutes `EqualsInteger` in their `commEquals`. -} +isConstant :: Term tyname name uni fun a -> Bool +isConstant Constant{} = True +isConstant _ = False + commuteConstDefault :: forall tyname name uni a. - Term tyname name uni PLC.DefaultFun a -> - Term tyname name uni PLC.DefaultFun a -commuteConstDefault (Apply ann (Builtin annB PLC.EqualsInteger) (Apply ann1 x y@(Constant{}))) = - Apply ann (Builtin annB PLC.EqualsInteger) (Apply ann1 y x) -commuteConstDefault (Apply ann (Builtin annB PLC.EqualsByteString) (Apply ann1 x y@(Constant{}))) = - Apply ann (Builtin annB PLC.EqualsByteString) (Apply ann1 y x) -commuteConstDefault (Apply ann (Builtin annB PLC.EqualsString) (Apply ann1 x y@(Constant{}))) = - Apply ann (Builtin annB PLC.EqualsString) (Apply ann1 y x) -commuteConstDefault (Apply ann (Builtin annB PLC.AddInteger) (Apply ann1 x y@(Constant{}))) = - Apply ann (Builtin annB PLC.AddInteger) (Apply ann1 y x) -commuteConstDefault (Apply ann (Builtin annB PLC.MultiplyInteger) (Apply ann1 x y@(Constant{}))) = - Apply ann (Builtin annB PLC.MultiplyInteger) (Apply ann1 y x) + Term tyname name uni DefaultFun a -> + Term tyname name uni DefaultFun a +commuteConstDefault tm@(Apply ann (Apply ann1 (Builtin annB fun) x) y) = + case (isCommutativeWithConstant fun, isConstant x, isConstant y) of + (True, False, True) -> Apply ann (Apply ann1 (Builtin annB fun) y) x + _ -> tm commuteConstDefault tm = tm commuteConst :: forall tyname name uni fun a. Typeable fun => Term tyname name uni fun a -> Term tyname name uni fun a -commuteConst = case eqT @fun @PLC.DefaultFun of +commuteConst = case eqT @fun @DefaultFun of Just Refl -> commuteConstDefault Nothing -> id + +-- | Returns whether a `DefaultFun` is commutative with `Constant`'s as arguments. Not using +-- catchall to make sure that this function catches newly added `DefaultFun`. +isCommutativeWithConstant :: DefaultFun -> Bool +isCommutativeWithConstant = \case + AddInteger -> False + SubtractInteger -> False + MultiplyInteger -> True + DivideInteger -> False + QuotientInteger -> False + RemainderInteger -> False + ModInteger -> False + EqualsInteger -> True + LessThanInteger -> False + LessThanEqualsInteger -> False + -- Bytestrings + AppendByteString -> False + ConsByteString -> False + SliceByteString -> False + LengthOfByteString -> False + IndexByteString -> False + EqualsByteString -> True + LessThanByteString -> False + LessThanEqualsByteString -> False + -- Cryptography and hashes + Sha2_256 -> False + Sha3_256 -> False + Blake2b_256 -> False + VerifyEd25519Signature -> False + VerifyEcdsaSecp256k1Signature -> False + VerifySchnorrSecp256k1Signature -> False + -- Strings + AppendString -> False + EqualsString -> True + EncodeUtf8 -> False + DecodeUtf8 -> False + -- Bool + IfThenElse -> False + -- Unit + ChooseUnit -> False + -- Tracing + Trace -> False + -- Pairs + FstPair -> False + SndPair -> False + -- Lists + ChooseList -> False + MkCons -> False + HeadList -> False + TailList -> False + NullList -> False + -- Data + ChooseData -> False + ConstrData -> False + MapData -> False + ListData -> False + IData -> False + BData -> False + UnConstrData -> False + UnMapData -> False + UnListData -> False + UnIData -> False + UnBData -> False + EqualsData -> False -- doesn't take constant + SerialiseData -> False + MkPairData -> False + MkNilData -> False + MkNilPairData -> False