Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectors in UPLC #5816

Merged
merged 6 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
### Changed

- Use `Vector` in the datastructure for `case` terms during evaluation. This speeds
up evaluation fairly significantly.

2 changes: 2 additions & 0 deletions plutus-core/plutus-core.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ library
, time
, transformers
, unordered-containers
, vector
, witherable

if impl(ghc <9.0)
Expand Down Expand Up @@ -448,6 +449,7 @@ test-suite untyped-plutus-core-test
, tasty-hunit
, tasty-quickcheck
, text
, vector

executable plc
import: lang
Expand Down
3 changes: 2 additions & 1 deletion plutus-core/plutus-core/src/PlutusCore/Compiler/Erase.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module PlutusCore.Compiler.Erase (eraseTerm, eraseProgram) where

import Data.Vector (fromList)
import PlutusCore.Core
import UntypedPlutusCore.Core qualified as UPLC

Expand All @@ -16,7 +17,7 @@ eraseTerm (Unwrap _ term) = eraseTerm term
eraseTerm (IWrap _ _ _ term) = eraseTerm term
eraseTerm (Error ann _) = UPLC.Error ann
eraseTerm (Constr ann _ i args) = UPLC.Constr ann i (fmap eraseTerm args)
eraseTerm (Case ann _ arg cs) = UPLC.Case ann (eraseTerm arg) (fmap eraseTerm cs)
eraseTerm (Case ann _ arg cs) = UPLC.Case ann (eraseTerm arg) (fromList $ fmap eraseTerm cs)

eraseProgram :: Program tyname name uni fun ann -> UPLC.Program name uni fun ann
eraseProgram (Program a v t) = UPLC.Program a v $ eraseTerm t
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import Universe

import Data.Foldable (for_)
import Data.Hashable
import Data.Vector qualified as V

instance (GEq uni, Closed uni, uni `Everywhere` Eq, Eq fun, Eq ann) =>
Eq (Term Name uni fun ann) where
Expand All @@ -35,6 +36,11 @@ type HashableTermConstraints uni fun ann =
, Hashable fun
)

-- This instance is the only logical one, and exists also in the package `vector-instances`.
-- Since this is the same implementation as that one, there isn't even much risk of incoherence.
instance Hashable a => Hashable (V.Vector a) where
hashWithSalt s = hashWithSalt s . toList

instance HashableTermConstraints uni fun ann => Hashable (Term Name uni fun ann)

-- Simple Structural Equality of a `Term NamedDeBruijn`. This implies three things:
Expand Down Expand Up @@ -100,7 +106,7 @@ eqTermM (Constr ann1 i1 args1) (Constr ann2 i2 args2) = do
eqTermM (Case ann1 a1 cs1) (Case ann2 a2 cs2) = do
eqM ann1 ann2
eqTermM a1 a2
case zipExact cs1 cs2 of
case zipExact (toList cs1) (toList cs2) of
Just ps -> for_ ps $ \(t1, t2) -> eqTermM t1 t2
Nothing -> empty
eqTermM Constant{} _ = empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import PlutusCore.Version qualified as PLC
import UntypedPlutusCore.Core.Type

import Control.Monad
import Data.Vector qualified as V
import Data.Word (Word8)
import Flat
import Flat.Decoder
Expand Down Expand Up @@ -128,7 +129,7 @@ encodeTerm = \case
Error ann -> encodeTermTag 6 <> encode ann
Builtin ann bn -> encodeTermTag 7 <> encode ann <> encode bn
Constr ann i es -> encodeTermTag 8 <> encode ann <> encode i <> encodeListWith encodeTerm es
Case ann arg cs -> encodeTermTag 9 <> encode ann <> encodeTerm arg <> encodeListWith encodeTerm cs
Case ann arg cs -> encodeTermTag 9 <> encode ann <> encodeTerm arg <> encodeListWith encodeTerm (V.toList cs)

decodeTerm
:: forall name uni fun ann
Expand Down Expand Up @@ -165,7 +166,7 @@ decodeTerm version builtinPred = go
Constr <$> decode <*> decode <*> decodeListWith go
handleTerm 9 = do
unless (version >= PLC.plcVersion110) $ fail $ "'case' is not allowed before version 1.1.0, this program has version: " ++ (show $ pretty version)
Case <$> decode <*> go <*> decodeListWith go
Case <$> decode <*> go <*> (V.fromList <$> decodeListWith go)
handleTerm t = fail $ "Unknown term constructor tag: " ++ show t

sizeTerm
Expand Down Expand Up @@ -193,7 +194,7 @@ sizeTerm tm sz =
Error ann -> size ann sz'
Builtin ann bn -> size ann $ size bn sz'
Constr ann i es -> size ann $ size i $ sizeListWith sizeTerm es sz'
Case ann arg cs -> size ann $ sizeTerm arg $ sizeListWith sizeTerm cs sz'
Case ann arg cs -> size ann $ sizeTerm arg $ sizeListWith sizeTerm (V.toList cs) sz'

-- | An encoder for programs.
--
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ instance (PrettyClassicBy configName name, PrettyUni uni, Pretty fun, Pretty ann
Constr ann i es ->
sexp "constr" (consAnnIf config ann (pretty i : fmap (prettyBy config) es))
Case ann arg cs ->
sexp "case" (consAnnIf config ann (prettyBy config arg : fmap (prettyBy config) cs))
sexp "case" (consAnnIf config ann
(prettyBy config arg : fmap (prettyBy config) (toList cs)))
where
prettyTypeOf :: Some (ValueOf uni) -> Doc dann
prettyTypeOf (Some (ValueOf uni _ )) = prettyBy juxtRenderContext $ SomeTypeIn uni
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ instance
Error _ -> unitDocM "error"
-- Always rendering the tag on the same line for more compact output, it's just a tiny integer
-- anyway.
Constr _ i es -> iterAppDocM $ \_ prettyArg -> ("constr" <+> prettyArg i) :| [prettyArg es]
Case _ arg cs -> iterAppDocM $ \_ prettyArg -> "case" :| [prettyArg arg, prettyArg cs]
Constr _ i es -> iterAppDocM $ \_ prettyArg ->
("constr" <+> prettyArg i) :| [prettyArg es]
Case _ arg cs -> iterAppDocM $ \_ prettyArg -> "case" :| [prettyArg arg, prettyArg (toList cs)]

instance
(PrettyReadableBy configName (Term name uni fun a)) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ module UntypedPlutusCore.Core.Type
import Control.Lens
import PlutusPrelude

import Data.Vector
import Data.Word
import PlutusCore.Builtin qualified as TPLC
import PlutusCore.Core qualified as TPLC
Expand Down Expand Up @@ -85,7 +86,7 @@ data Term name uni fun ann
-- TODO: try spine-strict list or strict list or vector
-- See Note [Constr tag type]
| Constr !ann !Word64 ![Term name uni fun ann]
| Case !ann !(Term name uni fun ann) ![Term name uni fun ann]
| Case !ann !(Term name uni fun ann) !(Vector (Term name uni fun ann))
deriving stock (Functor, Generic)

deriving stock instance (Show name, GShow uni, Everywhere uni Show, Show fun, Show ann, Closed uni)
Expand Down Expand Up @@ -123,7 +124,7 @@ instance TermLike (Term name uni fun) TPLC.TyName name uni fun where
iWrap = \_ _ _ -> id
error = \ann _ -> Error ann
constr = \ann _ i es -> Constr ann i es
kase = \ann _ arg cs -> Case ann arg cs
kase = \ann _ arg cs -> Case ann arg (fromList cs)

instance TPLC.HasConstant (Term name uni fun ()) where
asConstant (Constant _ val) = pure val
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ import UntypedPlutusCore.Evaluation.Machine.Cek.CekMachineCosts (CekMachineCosts
CekMachineCostsBase (..))
import UntypedPlutusCore.Evaluation.Machine.Cek.StepCounter

import Control.Lens ((^?))
import Control.Lens.Review
import Control.Monad (unless, when)
import Control.Monad.Catch
Expand All @@ -97,10 +96,10 @@ import Data.DList qualified as DList
import Data.Functor.Identity
import Data.Hashable (Hashable)
import Data.Kind qualified as GHC
import Data.List.Extras (wix)
import Data.Proxy
import Data.Semigroup (stimes)
import Data.Text (Text)
import Data.Vector qualified as V
import Data.Word
import GHC.TypeLits
import Prettyprinter
Expand Down Expand Up @@ -574,7 +573,7 @@ data Context uni fun ann
-- See Note [Accumulators for terms]
| FrameConstr !(CekValEnv uni fun ann) {-# UNPACK #-} !Word64 ![NTerm uni fun ann] !(ArgStack uni fun ann) !(Context uni fun ann)
-- ^ @(constr i V0 ... Vj-1 _ Nj ... Nn)@
| FrameCases !(CekValEnv uni fun ann) ![NTerm uni fun ann] !(Context uni fun ann)
| FrameCases !(CekValEnv uni fun ann) !(V.Vector (NTerm uni fun ann)) !(Context uni fun ann)
-- ^ @(case _ C0 .. Cn)@
| NoFrame

Expand Down Expand Up @@ -765,10 +764,17 @@ enterComputeCek = computeCek
let done' = ConsStack e done
case todo of
(next : todo') -> computeCek (FrameConstr env i todo' done' ctx) env next
_ -> returnCek ctx $ VConstr i done'
[] -> returnCek ctx $ VConstr i done'
-- s , case _ (C0 ... CN, ρ) ◅ constr i V1 .. Vm ↦ s , [_ V1 ... Vm] ; ρ ▻ Ci
returnCek (FrameCases env cs ctx) e = case e of
(VConstr i args) -> case cs ^? wix i of
-- If the index is larger than the max bound of an Int, or negative, then it's a bad index
-- As it happens, this will currently never trigger, since i is a Word64, and the largest
-- Word64 value wraps to -1 as an Int64. So you can't wrap around enough to get an
-- "apparently good" value.
(VConstr i _) | fromIntegral @_ @Integer i > fromIntegral @Int @Integer maxBound ->
throwingDischarged _MachineError (MissingCaseBranch i) e
-- Otherwise, we can safely convert the index to an Int and use it
(VConstr i args) -> case (V.!?) cs (fromIntegral i) of
Just t -> computeCek (transferArgStack args ctx) env t
Nothing -> throwingDischarged _MachineError (MissingCaseBranch i) e
_ -> throwingDischarged _MachineError NonConstrScrutinized e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ import UntypedPlutusCore.Evaluation.Machine.Cek.StepCounter

import Control.Lens hiding (Context)
import Control.Monad
import Data.List.Extras (wix)
import Data.Proxy
import Data.RandomAccessList.Class qualified as Env
import Data.Semigroup (stimes)
import Data.Text (Text)
import Data.Vector qualified as V
import Data.Word (Word64)
import GHC.TypeNats

Expand Down Expand Up @@ -99,7 +99,7 @@ data Context uni fun ann
| FrameAwaitFunValue ann !(CekValue uni fun ann) !(Context uni fun ann)
| FrameForce ann !(Context uni fun ann) -- ^ @(force _)@
| FrameConstr ann !(CekValEnv uni fun ann) {-# UNPACK #-} !Word64 ![NTerm uni fun ann] !(ArgStack uni fun ann) !(Context uni fun ann)
| FrameCases ann !(CekValEnv uni fun ann) ![NTerm uni fun ann] !(Context uni fun ann)
| FrameCases ann !(CekValEnv uni fun ann) !(V.Vector (NTerm uni fun ann)) !(Context uni fun ann)
| NoFrame

deriving stock instance (GShow uni, Everywhere uni Show, Show fun, Show ann, Closed uni)
Expand Down Expand Up @@ -154,7 +154,7 @@ computeCek !ctx !env (Constr ann i es) = do
stepAndMaybeSpend BConstr
case es of
(t : rest) -> computeCek (FrameConstr ann env i rest EmptyStack ctx) env t
_ -> returnCek ctx $ VConstr i EmptyStack
[] -> returnCek ctx $ VConstr i EmptyStack
-- s ; ρ ▻ case S C0 ... Cn ↦ s , case _ (C0 ... Cn, ρ) ; ρ ▻ S
computeCek !ctx !env (Case ann scrut cs) = do
stepAndMaybeSpend BCase
Expand Down Expand Up @@ -192,10 +192,16 @@ returnCek (FrameConstr ann env i todo done ctx) e = do
let done' = ConsStack e done
case todo of
(next : todo') -> computeCek (FrameConstr ann env i todo' done' ctx) env next
_ -> returnCek ctx $ VConstr i done'
[] -> returnCek ctx $ VConstr i done'
-- s , case _ (C0 ... CN, ρ) ◅ constr i V1 .. Vm ↦ s , [_ V1 ... Vm] ; ρ ▻ Ci
returnCek (FrameCases ann env cs ctx) e = case e of
(VConstr i args) -> case cs ^? wix i of
-- If the index is larger than the max bound of an Int, or negative, then it's a bad index
-- As it happens, this will currently never trigger, since i is a Word64, and the largest
-- Word64 value wraps to -1 as an Int64. So you can't wrap around enough to get an
-- "apparently good" value.
(VConstr i _) | fromIntegral @_ @Integer i > fromIntegral @Int @Integer maxBound ->
throwingDischarged _MachineError (MissingCaseBranch i) e
(VConstr i args) -> case (V.!?) cs (fromIntegral i) of
Just t ->
let ctx' = transferArgStack ann args ctx
in computeCek ctx' env t
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import UntypedPlutusCore.Core.Type qualified as UPLC
import UntypedPlutusCore.Rename (Rename (rename))

import Data.Text (Text)
import Data.Vector qualified as V
import PlutusCore.Error (AsParserErrorBundle)
import PlutusCore.MkPlc (mkIterApp)
import PlutusCore.Parser hiding (parseProgram, parseTerm, program)
Expand Down Expand Up @@ -81,7 +82,7 @@ constrTerm = withSpan $ \sp ->
caseTerm :: Parser PTerm
caseTerm = withSpan $ \sp ->
inParens $ do
res <- UPLC.Case sp <$> (symbol "case" *> term) <*> many term
res <- UPLC.Case sp <$> (symbol "case" *> term) <*> (V.fromList <$> many term)
whenVersion (\v -> v < plcVersion110) $ fail "'case' is not allowed before version 1.1.0"
pure res

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ module UntypedPlutusCore.Transform.CaseReduce
import PlutusCore.MkPlc
import UntypedPlutusCore.Core

import Control.Lens (transformOf, (^?))
import Data.List.Extras
import Control.Lens (transformOf)
import Data.Vector qualified as V

caseReduce :: Term name uni fun a -> Term name uni fun a
caseReduce = transformOf termSubterms processTerm

processTerm :: Term name uni fun a -> Term name uni fun a
processTerm = \case
Case ann (Constr _ i args) cs | Just c <- cs ^? wix i -> mkIterApp c ((ann,) <$> args)
Case ann (Constr _ i args) cs | Just c <- (V.!?) cs (fromIntegral i) ->
mkIterApp c ((ann,) <$> args)
t -> t
3 changes: 2 additions & 1 deletion plutus-core/untyped-plutus-core/test/Generators.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import UntypedPlutusCore.Parser (parseProgram, parseTerm)
import Control.Lens (view)
import Data.Text (Text)
import Data.Text qualified as T
import Data.Vector qualified as V

import Hedgehog (annotate, failure, property, tripping, (===))
import Hedgehog.Gen qualified as Gen
Expand Down Expand Up @@ -58,7 +59,7 @@ compareTerm (Delay _ t ) (Delay _ t') = compareTerm t t'
compareTerm (Constant _ x) (Constant _ y) = x == y
compareTerm (Builtin _ bi) (Builtin _ bi') = bi == bi'
compareTerm (Constr _ i es) (Constr _ i' es') = i == i' && maybe False (all (uncurry compareTerm)) (zipExact es es')
compareTerm (Case _ arg cs) (Case _ arg' cs') = compareTerm arg arg' && maybe False (all (uncurry compareTerm)) (zipExact cs cs')
compareTerm (Case _ arg cs) (Case _ arg' cs') = compareTerm arg arg' && maybe False (all (uncurry compareTerm)) (zipExact (V.toList cs) (V.toList cs'))
compareTerm (Error _ ) (Error _ ) = True
compareTerm _ _ = False

Expand Down
9 changes: 5 additions & 4 deletions plutus-core/untyped-plutus-core/test/Transform/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import Control.Lens ((&), (.~))
import Data.ByteString.Lazy qualified as BSL
import Data.Text (Text)
import Data.Text.Encoding (encodeUtf8)
import Data.Vector qualified as V
import Test.Tasty
import Test.Tasty.Golden

Expand All @@ -41,7 +42,7 @@ caseOfCase1 = runQuote $ do
let ite = Force () (Builtin () PLC.IfThenElse)
true = Constr () 0 []
false = Constr () 1 []
alts = [mkConstant @Integer () 1, mkConstant @Integer () 2]
alts = V.fromList [mkConstant @Integer () 1, mkConstant @Integer () 2]
pure $ Case () (mkIterApp ite [((), Var () b), ((), true), ((), false)]) alts

{- | This should not simplify, because one of the branches of `ifThenElse` is not a `Constr`.
Expand All @@ -55,7 +56,7 @@ caseOfCase2 = runQuote $ do
let ite = Force () (Builtin () PLC.IfThenElse)
true = Var () t
false = Constr () 1 []
alts = [mkConstant @Integer () 1, mkConstant @Integer () 2]
alts = V.fromList [mkConstant @Integer () 1, mkConstant @Integer () 2]
pure $ Case () (mkIterApp ite [((), Var () b), ((), true), ((), false)]) alts

{- | Similar to `caseOfCase1`, but the type of the @true@ and @false@ branches is
Expand All @@ -72,7 +73,7 @@ caseOfCase3 = runQuote $ do
false = Constr () 1 []
altTrue = Var () f
altFalse = mkConstant @Integer () 2
alts = [altTrue, altFalse]
alts = V.fromList [altTrue, altFalse]
pure $ Case () (mkIterApp ite [((), Var () b), ((), true), ((), false)]) alts

-- | The `Delay` should be floated into the lambda.
Expand Down Expand Up @@ -357,7 +358,7 @@ cse1 = runQuote $ do
branch1 = plus onePlusTwoPlusX threePlusX
branch2 = plus twoPlusX threePlusX
branch3 = fourPlusX
caseExpr = Case () (Var () y) [branch1, branch2, branch3]
caseExpr = Case () (Var () y) (V.fromList [branch1, branch2, branch3])
pure $ LamAbs () x (LamAbs () y body)

-- | This is the second example in Note [CSE].
Expand Down
6 changes: 3 additions & 3 deletions plutus-ledger-api/test/Spec/Versions.hs
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,13 @@ errorScript :: SerialisedScript
errorScript = serialiseUPLC $ UPLC.Program () PLC.plcVersion100 $ UPLC.Error ()

v110script :: UPLC.Program UPLC.DeBruijn UPLC.DefaultUni UPLC.DefaultFun ()
v110script = UPLC.Program () PLC.plcVersion110 $ UPLC.Constr () 0 []
v110script = UPLC.Program () PLC.plcVersion110 $ UPLC.Constr () 0 mempty

badConstrScript :: UPLC.Program UPLC.DeBruijn UPLC.DefaultUni UPLC.DefaultFun ()
badConstrScript = UPLC.Program () PLC.plcVersion100 $ UPLC.Constr () 0 []
badConstrScript = UPLC.Program () PLC.plcVersion100 $ UPLC.Constr () 0 mempty

badCaseScript :: UPLC.Program UPLC.DeBruijn UPLC.DefaultUni UPLC.DefaultFun ()
badCaseScript = UPLC.Program () PLC.plcVersion100 $ UPLC.Case () (UPLC.Error ()) []
badCaseScript = UPLC.Program () PLC.plcVersion100 $ UPLC.Case () (UPLC.Error ()) mempty

-- Note that bls can work also with plcversion==1.0.0
blsExScript :: SerialisedScript
Expand Down
7 changes: 4 additions & 3 deletions plutus-metatheory/src/Untyped.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import UntypedPlutusCore
import Data.ByteString as BS hiding (map)
import Data.Text as T hiding (map)
import Data.Word (Word64)
import GHC.Exts (IsList (..))
import Universe

-- Untyped (Raw) syntax
Expand Down Expand Up @@ -41,8 +42,8 @@ conv (Constant _ c) = UCon c
conv (Error _) = UError
conv (Delay _ t) = UDelay (conv t)
conv (Force _ t) = UForce (conv t)
conv (Constr _ i es) = UConstr (toInteger i) (fmap conv es)
conv (Case _ arg cs) = UCase (conv arg) (fmap conv cs)
conv (Constr _ i es) = UConstr (toInteger i) (toList (fmap conv es))
conv (Case _ arg cs) = UCase (conv arg) (toList (fmap conv cs))

tmnames = ['a' .. 'z']

Expand All @@ -63,5 +64,5 @@ uconv i (UBuiltin b) = Builtin () b
uconv i (UDelay t) = Delay () (uconv i t)
uconv i (UForce t) = Force () (uconv i t)
uconv i (UConstr j xs) = Constr () (fromInteger j) (fmap (uconv i) xs)
uconv i (UCase t xs) = Case () (uconv i t) (fmap (uconv i) xs)
uconv i (UCase t xs) = Case () (uconv i t) (fromList (fmap (uconv i) xs))