Skip to content

Commit

Permalink
Use Vector in Case
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelpj committed Apr 2, 2024
1 parent cb91fa6 commit 4c4f04b
Show file tree
Hide file tree
Showing 17 changed files with 84 additions and 68 deletions.
2 changes: 2 additions & 0 deletions plutus-core/plutus-core.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ library
, time
, transformers
, unordered-containers
, vector
, witherable

if impl(ghc <9.0)
Expand Down Expand Up @@ -445,6 +446,7 @@ test-suite untyped-plutus-core-test
, tasty-hunit
, tasty-quickcheck
, text
, vector

executable plc
import: lang
Expand Down
5 changes: 3 additions & 2 deletions plutus-core/plutus-core/src/PlutusCore/Compiler/Erase.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module PlutusCore.Compiler.Erase (eraseTerm, eraseProgram) where

import Data.Vector (fromList)
import PlutusCore.Core
import UntypedPlutusCore.Core qualified as UPLC

Expand All @@ -15,8 +16,8 @@ eraseTerm (TyInst ann term _) = UPLC.Force ann (eraseTerm term)
eraseTerm (Unwrap _ term) = eraseTerm term
eraseTerm (IWrap _ _ _ term) = eraseTerm term
eraseTerm (Error ann _) = UPLC.Error ann
eraseTerm (Constr ann _ i args) = UPLC.Constr ann i (fmap eraseTerm args)
eraseTerm (Case ann _ arg cs) = UPLC.Case ann (eraseTerm arg) (fmap eraseTerm cs)
eraseTerm (Constr ann _ i args) = UPLC.Constr ann i (fromList $ fmap eraseTerm args)
eraseTerm (Case ann _ arg cs) = UPLC.Case ann (eraseTerm arg) (fromList $ fmap eraseTerm cs)

eraseProgram :: Program tyname name uni fun ann -> UPLC.Program name uni fun ann
eraseProgram (Program a v t) = UPLC.Program a v $ eraseTerm t
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import Universe

import Data.Foldable (for_)
import Data.Hashable
import Data.Vector qualified as V

instance (GEq uni, Closed uni, uni `Everywhere` Eq, Eq fun, Eq ann) =>
Eq (Term Name uni fun ann) where
Expand All @@ -35,6 +36,10 @@ type HashableTermConstraints uni fun ann =
, Hashable fun
)

-- TODO: dubious
instance Hashable a => Hashable (V.Vector a) where
hashWithSalt s = hashWithSalt s . toList

instance HashableTermConstraints uni fun ann => Hashable (Term Name uni fun ann)

-- Simple Structural Equality of a `Term NamedDeBruijn`. This implies three things:
Expand Down Expand Up @@ -94,13 +99,13 @@ eqTermM (Error ann1) (Error ann2) = eqM ann1 ann2
eqTermM (Constr ann1 i1 args1) (Constr ann2 i2 args2) = do
eqM ann1 ann2
eqM i1 i2
case zipExact args1 args2 of
case zipExact (toList args1) (toList args2) of
Just ps -> for_ ps $ \(t1, t2) -> eqTermM t1 t2
Nothing -> empty
eqTermM (Case ann1 a1 cs1) (Case ann2 a2 cs2) = do
eqM ann1 ann2
eqTermM a1 a2
case zipExact cs1 cs2 of
case zipExact (toList cs1) (toList cs2) of
Just ps -> for_ ps $ \(t1, t2) -> eqTermM t1 t2
Nothing -> empty
eqTermM Constant{} _ = empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import PlutusCore.Version qualified as PLC
import UntypedPlutusCore.Core.Type

import Control.Monad
import Data.Vector qualified as V
import Data.Word (Word8)
import Flat
import Flat.Decoder
Expand Down Expand Up @@ -127,8 +128,8 @@ encodeTerm = \case
Force ann t -> encodeTermTag 5 <> encode ann <> encodeTerm t
Error ann -> encodeTermTag 6 <> encode ann
Builtin ann bn -> encodeTermTag 7 <> encode ann <> encode bn
Constr ann i es -> encodeTermTag 8 <> encode ann <> encode i <> encodeListWith encodeTerm es
Case ann arg cs -> encodeTermTag 9 <> encode ann <> encodeTerm arg <> encodeListWith encodeTerm cs
Constr ann i es -> encodeTermTag 8 <> encode ann <> encode i <> encodeListWith encodeTerm (V.toList es)
Case ann arg cs -> encodeTermTag 9 <> encode ann <> encodeTerm arg <> encodeListWith encodeTerm (V.toList cs)

decodeTerm
:: forall name uni fun ann
Expand Down Expand Up @@ -162,10 +163,10 @@ decodeTerm version builtinPred = go
Just e -> fail e
handleTerm 8 = do
unless (version >= PLC.plcVersion110) $ fail $ "'constr' is not allowed before version 1.1.0, this program has version: " ++ (show $ pretty version)
Constr <$> decode <*> decode <*> decodeListWith go
Constr <$> decode <*> decode <*> (V.fromList <$> decodeListWith go)
handleTerm 9 = do
unless (version >= PLC.plcVersion110) $ fail $ "'case' is not allowed before version 1.1.0, this program has version: " ++ (show $ pretty version)
Case <$> decode <*> go <*> decodeListWith go
Case <$> decode <*> go <*> (V.fromList <$> decodeListWith go)
handleTerm t = fail $ "Unknown term constructor tag: " ++ show t

sizeTerm
Expand All @@ -192,8 +193,8 @@ sizeTerm tm sz =
Force ann t -> size ann $ sizeTerm t sz'
Error ann -> size ann sz'
Builtin ann bn -> size ann $ size bn sz'
Constr ann i es -> size ann $ size i $ sizeListWith sizeTerm es sz'
Case ann arg cs -> size ann $ sizeTerm arg $ sizeListWith sizeTerm cs sz'
Constr ann i es -> size ann $ size i $ sizeListWith sizeTerm (V.toList es) sz'
Case ann arg cs -> size ann $ sizeTerm arg $ sizeListWith sizeTerm (V.toList cs) sz'

-- | An encoder for programs.
--
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ instance (PrettyClassicBy configName name, PrettyUni uni, Pretty fun, Pretty ann
sexp "force" (consAnnIf config ann
[prettyBy config term])
Constr ann i es ->
sexp "constr" (consAnnIf config ann (pretty i : fmap (prettyBy config) es))
sexp "constr" (consAnnIf config ann (pretty i : fmap (prettyBy config) (toList es)))
Case ann arg cs ->
sexp "case" (consAnnIf config ann (prettyBy config arg : fmap (prettyBy config) cs))
sexp "case" (consAnnIf config ann (prettyBy config arg : fmap (prettyBy config) (toList cs)))
where
prettyTypeOf :: Some (ValueOf uni) -> Doc dann
prettyTypeOf (Some (ValueOf uni _ )) = prettyBy juxtRenderContext $ SomeTypeIn uni
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ instance
Error _ -> unitDocM "error"
-- Always rendering the tag on the same line for more compact output, it's just a tiny integer
-- anyway.
Constr _ i es -> iterAppDocM $ \_ prettyArg -> ("constr" <+> prettyArg i) :| [prettyArg es]
Case _ arg cs -> iterAppDocM $ \_ prettyArg -> "case" :| [prettyArg arg, prettyArg cs]
Constr _ i es -> iterAppDocM $ \_ prettyArg -> ("constr" <+> prettyArg i) :| [prettyArg (toList es)]
Case _ arg cs -> iterAppDocM $ \_ prettyArg -> "case" :| [prettyArg arg, prettyArg (toList cs)]

instance
(PrettyReadableBy configName (Term name uni fun a)) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ module UntypedPlutusCore.Core.Type
import Control.Lens
import PlutusPrelude

import Data.Vector
import Data.Word
import PlutusCore.Builtin qualified as TPLC
import PlutusCore.Core qualified as TPLC
Expand Down Expand Up @@ -84,8 +85,8 @@ data Term name uni fun ann
-- TODO: worry about overflow, maybe use an Integer
-- TODO: try spine-strict list or strict list or vector
-- See Note [Constr tag type]
| Constr !ann !Word64 ![Term name uni fun ann]
| Case !ann !(Term name uni fun ann) ![Term name uni fun ann]
| Constr !ann !Word64 !(Vector (Term name uni fun ann))
| Case !ann !(Term name uni fun ann) !(Vector (Term name uni fun ann))
deriving stock (Functor, Generic)

deriving stock instance (Show name, GShow uni, Everywhere uni Show, Show fun, Show ann, Closed uni)
Expand Down Expand Up @@ -122,8 +123,8 @@ instance TermLike (Term name uni fun) TPLC.TyName name uni fun where
unwrap = const id
iWrap = \_ _ _ -> id
error = \ann _ -> Error ann
constr = \ann _ i es -> Constr ann i es
kase = \ann _ arg cs -> Case ann arg cs
constr = \ann _ i es -> Constr ann i (fromList es)
kase = \ann _ arg cs -> Case ann arg (fromList cs)

instance TPLC.HasConstant (Term name uni fun ()) where
asConstant (Constant _ val) = pure val
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ import UntypedPlutusCore.Evaluation.Machine.Cek.CekMachineCosts (CekMachineCosts
CekMachineCostsBase (..))
import UntypedPlutusCore.Evaluation.Machine.Cek.StepCounter

import Control.Lens ((^?))
import Control.Lens.Review
import Control.Monad (unless, when)
import Control.Monad.Catch
Expand All @@ -97,10 +96,10 @@ import Data.DList qualified as DList
import Data.Functor.Identity
import Data.Hashable (Hashable)
import Data.Kind qualified as GHC
import Data.List.Extras (wix)
import Data.Proxy
import Data.Semigroup (stimes)
import Data.Text (Text)
import Data.Vector qualified as V
import Data.Word
import GHC.TypeLits
import Prettyprinter
Expand Down Expand Up @@ -537,11 +536,11 @@ dischargeCekValue = \case
-- or (b) it's needed for an error message.
-- @term@ is fully discharged, so we can return it directly without any further discharging.
VBuiltin _ term _ -> term
VConstr i es -> Constr () i (fmap dischargeCekValue $ stack2list es)
VConstr i es -> Constr () i (fmap dischargeCekValue $ stack2vec es)
where
stack2list = go []
stack2vec = go mempty
go acc EmptyStack = acc
go acc (ConsStack arg rest) = go (arg : acc) rest
go acc (ConsStack arg rest) = go (arg `V.cons` acc) rest

instance (PrettyUni uni, Pretty fun) => PrettyBy PrettyConfigPlc (CekValue uni fun ann) where
prettyBy cfg = prettyBy cfg . dischargeCekValue
Expand Down Expand Up @@ -572,9 +571,9 @@ data Context uni fun ann
| FrameForce !(Context uni fun ann)
-- ^ @(force _)@
-- See Note [Accumulators for terms]
| FrameConstr !(CekValEnv uni fun ann) {-# UNPACK #-} !Word64 ![NTerm uni fun ann] !(ArgStack uni fun ann) !(Context uni fun ann)
| FrameConstr !(CekValEnv uni fun ann) {-# UNPACK #-} !Word64 !(V.Vector (NTerm uni fun ann)) !(ArgStack uni fun ann) !(Context uni fun ann)
-- ^ @(constr i V0 ... Vj-1 _ Nj ... Nn)@
| FrameCases !(CekValEnv uni fun ann) ![NTerm uni fun ann] !(Context uni fun ann)
| FrameCases !(CekValEnv uni fun ann) !(V.Vector (NTerm uni fun ann)) !(Context uni fun ann)
-- ^ @(case _ C0 .. Cn)@
| NoFrame

Expand Down Expand Up @@ -719,9 +718,9 @@ enterComputeCek = computeCek
-- s ; ρ ▻ constr I T0 .. Tn ↦ s , constr I _ (T1 ... Tn, ρ) ; ρ ▻ T0
computeCek !ctx !env (Constr _ i es) = do
stepAndMaybeSpend BConstr
case es of
(t : rest) -> computeCek (FrameConstr env i rest EmptyStack ctx) env t
[] -> returnCek ctx $ VConstr i EmptyStack
case V.uncons es of
Just (t, rest) -> computeCek (FrameConstr env i rest EmptyStack ctx) env t
Nothing -> returnCek ctx $ VConstr i EmptyStack
-- s ; ρ ▻ case S C0 ... Cn ↦ s , case _ (C0 ... Cn, ρ) ; ρ ▻ S
computeCek !ctx !env (Case _ scrut cs) = do
stepAndMaybeSpend BCase
Expand Down Expand Up @@ -763,12 +762,13 @@ enterComputeCek = computeCek
-- s , constr I V0 ... Vj-1 _ (Tj+1 ... Tn, ρ) ◅ Vj ↦ s , constr i V0 ... Vj _ (Tj+2... Tn, ρ) ; ρ ▻ Tj+1
returnCek (FrameConstr env i todo done ctx) e = do
let done' = ConsStack e done
case todo of
(next : todo') -> computeCek (FrameConstr env i todo' done' ctx) env next
_ -> returnCek ctx $ VConstr i done'
case V.uncons todo of
Just (next, todo') -> computeCek (FrameConstr env i todo' done' ctx) env next
Nothing -> returnCek ctx $ VConstr i done'
-- s , case _ (C0 ... CN, ρ) ◅ constr i V1 .. Vm ↦ s , [_ V1 ... Vm] ; ρ ▻ Ci
returnCek (FrameCases env cs ctx) e = case e of
(VConstr i args) -> case cs ^? wix i of
-- TODO: handle word/int conversion better
(VConstr i args) -> case (V.!?) cs (fromIntegral i) of
Just t -> computeCek (transferArgStack args ctx) env t
Nothing -> throwingDischarged _MachineError (MissingCaseBranch i) e
_ -> throwingDischarged _MachineError NonConstrScrutinized e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ import UntypedPlutusCore.Evaluation.Machine.Cek.StepCounter

import Control.Lens hiding (Context)
import Control.Monad
import Data.List.Extras (wix)
import Data.Proxy
import Data.RandomAccessList.Class qualified as Env
import Data.Semigroup (stimes)
import Data.Text (Text)
import Data.Vector qualified as V
import Data.Word (Word64)
import GHC.TypeNats

Expand Down Expand Up @@ -98,8 +98,8 @@ data Context uni fun ann
| FrameAwaitFunTerm ann !(CekValEnv uni fun ann) !(NTerm uni fun ann) !(Context uni fun ann) -- ^ @[_ N]@
| FrameAwaitFunValue ann !(CekValue uni fun ann) !(Context uni fun ann)
| FrameForce ann !(Context uni fun ann) -- ^ @(force _)@
| FrameConstr ann !(CekValEnv uni fun ann) {-# UNPACK #-} !Word64 ![NTerm uni fun ann] !(ArgStack uni fun ann) !(Context uni fun ann)
| FrameCases ann !(CekValEnv uni fun ann) ![NTerm uni fun ann] !(Context uni fun ann)
| FrameConstr ann !(CekValEnv uni fun ann) {-# UNPACK #-} !Word64 !(V.Vector (NTerm uni fun ann)) !(ArgStack uni fun ann) !(Context uni fun ann)
| FrameCases ann !(CekValEnv uni fun ann) !(V.Vector (NTerm uni fun ann)) !(Context uni fun ann)
| NoFrame

deriving stock instance (GShow uni, Everywhere uni Show, Show fun, Show ann, Closed uni)
Expand Down Expand Up @@ -152,9 +152,9 @@ computeCek !ctx !_ (Builtin _ bn) = do
-- s ; ρ ▻ constr I T0 .. Tn ↦ s , constr I _ (T1 ... Tn, ρ) ; ρ ▻ T0
computeCek !ctx !env (Constr ann i es) = do
stepAndMaybeSpend BConstr
case es of
(t : rest) -> computeCek (FrameConstr ann env i rest EmptyStack ctx) env t
_ -> returnCek ctx $ VConstr i EmptyStack
case V.uncons es of
Just (t, rest) -> computeCek (FrameConstr ann env i rest EmptyStack ctx) env t
Nothing -> returnCek ctx $ VConstr i EmptyStack
-- s ; ρ ▻ case S C0 ... Cn ↦ s , case _ (C0 ... Cn, ρ) ; ρ ▻ S
computeCek !ctx !env (Case ann scrut cs) = do
stepAndMaybeSpend BCase
Expand Down Expand Up @@ -190,12 +190,12 @@ returnCek (FrameAwaitFunValue _ arg ctx) fun =
-- s , constr I V0 ... Vj-1 _ (Tj+1 ... Tn, ρ) ◅ Vj ↦ s , constr i V0 ... Vj _ (Tj+2... Tn, ρ) ; ρ ▻ Tj+1
returnCek (FrameConstr ann env i todo done ctx) e = do
let done' = ConsStack e done
case todo of
(next : todo') -> computeCek (FrameConstr ann env i todo' done' ctx) env next
_ -> returnCek ctx $ VConstr i done'
case V.uncons todo of
Just (next, todo') -> computeCek (FrameConstr ann env i todo' done' ctx) env next
Nothing -> returnCek ctx $ VConstr i done'
-- s , case _ (C0 ... CN, ρ) ◅ constr i V1 .. Vm ↦ s , [_ V1 ... Vm] ; ρ ▻ Ci
returnCek (FrameCases ann env cs ctx) e = case e of
(VConstr i args) -> case cs ^? wix i of
(VConstr i args) -> case (V.!?) cs (fromIntegral i) of
Just t ->
let ctx' = transferArgStack ann args ctx
in computeCek ctx' env t
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import UntypedPlutusCore.Core.Type qualified as UPLC
import UntypedPlutusCore.Rename (Rename (rename))

import Data.Text (Text)
import Data.Vector qualified as V
import PlutusCore.Error (AsParserErrorBundle)
import PlutusCore.MkPlc (mkIterApp)
import PlutusCore.Parser hiding (parseProgram, parseTerm, program)
Expand Down Expand Up @@ -74,14 +75,14 @@ errorTerm = withSpan $ \sp ->
constrTerm :: Parser PTerm
constrTerm = withSpan $ \sp ->
inParens $ do
res <- UPLC.Constr sp <$> (symbol "constr" *> lexeme Lex.decimal) <*> many term
res <- UPLC.Constr sp <$> (symbol "constr" *> lexeme Lex.decimal) <*> (V.fromList <$> many term)
whenVersion (\v -> v < plcVersion110) $ fail "'constr' is not allowed before version 1.1.0"
pure res

caseTerm :: Parser PTerm
caseTerm = withSpan $ \sp ->
inParens $ do
res <- UPLC.Case sp <$> (symbol "case" *> term) <*> many term
res <- UPLC.Case sp <$> (symbol "case" *> term) <*> (V.fromList <$> many term)
whenVersion (\v -> v < plcVersion110) $ fail "'case' is not allowed before version 1.1.0"
pure res

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ module UntypedPlutusCore.Transform.CaseReduce
import PlutusCore.MkPlc
import UntypedPlutusCore.Core

import Control.Lens (transformOf, (^?))
import Data.List.Extras
import Control.Lens (transformOf)
import Data.Vector qualified as V

caseReduce :: Term name uni fun a -> Term name uni fun a
caseReduce = transformOf termSubterms processTerm

processTerm :: Term name uni fun a -> Term name uni fun a
processTerm = \case
Case ann (Constr _ i args) cs | Just c <- cs ^? wix i -> mkIterApp c ((ann,) <$> args)
Case ann (Constr _ i args) cs | Just c <- (V.!?) cs (fromIntegral i) ->
mkIterApp c ((ann,) <$> (V.toList args))
t -> t
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ costIsAcceptable = \case
Apply{} -> False
-- Inlining constructors of size 1 or 0 seems okay, but does result in doing
-- the work for the elements at each use site.
Constr _ _ es -> case es of
Constr _ _ es -> case toList es of
[] -> True
[e] -> costIsAcceptable e
_ -> False
Expand All @@ -425,7 +425,7 @@ sizeIsAcceptable inlineConstants = \case
-- See Note [Differences from PIR inliner] 4
LamAbs{} -> False
-- Inlining constructors of size 1 or 0 seems okay
Constr _ _ es -> case es of
Constr _ _ es -> case toList es of
[] -> True
[e] -> sizeIsAcceptable inlineConstants e
_ -> False
Expand Down
3 changes: 2 additions & 1 deletion plutus-core/untyped-plutus-core/test/Analysis/Spec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ module Analysis.Spec where

import Test.Tasty.Extras

import Data.Vector qualified as V
import PlutusCore qualified as PLC
import PlutusCore.MkPlc
import PlutusCore.Pretty (prettyPlcReadableDef)
Expand All @@ -28,7 +29,7 @@ dangerTerm = runQuote $ do
-- The UPLC term type is strict, so it's hard to hide an undefined in there
-- Take advantage of the fact that it's still using lazy lists for constr
-- arguments for now.
pure $ Apply () (Apply () (Var () n) (Var () m)) (Constr () 1 [undefined])
pure $ Apply () (Apply () (Var () n) (Var () m)) (Constr () 1 (V.fromList [undefined]))

letFun :: Term Name PLC.DefaultUni PLC.DefaultFun ()
letFun = runQuote $ do
Expand Down
5 changes: 3 additions & 2 deletions plutus-core/untyped-plutus-core/test/Generators.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import UntypedPlutusCore.Parser (parseProgram, parseTerm)
import Control.Lens (view)
import Data.Text (Text)
import Data.Text qualified as T
import Data.Vector qualified as V

import Hedgehog (annotate, failure, property, tripping, (===))
import Hedgehog.Gen qualified as Gen
Expand Down Expand Up @@ -57,8 +58,8 @@ compareTerm (Force _ t ) (Force _ t') = compareTerm t t'
compareTerm (Delay _ t ) (Delay _ t') = compareTerm t t'
compareTerm (Constant _ x) (Constant _ y) = x == y
compareTerm (Builtin _ bi) (Builtin _ bi') = bi == bi'
compareTerm (Constr _ i es) (Constr _ i' es') = i == i' && maybe False (all (uncurry compareTerm)) (zipExact es es')
compareTerm (Case _ arg cs) (Case _ arg' cs') = compareTerm arg arg' && maybe False (all (uncurry compareTerm)) (zipExact cs cs')
compareTerm (Constr _ i es) (Constr _ i' es') = i == i' && maybe False (all (uncurry compareTerm)) (zipExact (V.toList es) (V.toList es'))
compareTerm (Case _ arg cs) (Case _ arg' cs') = compareTerm arg arg' && maybe False (all (uncurry compareTerm)) (zipExact (V.toList cs) (V.toList cs'))
compareTerm (Error _ ) (Error _ ) = True
compareTerm _ _ = False

Expand Down

0 comments on commit 4c4f04b

Please sign in to comment.