diff --git a/Plutarch/Rec.hs b/Plutarch/Rec.hs index 02594ad24..0f34c5b01 100644 --- a/Plutarch/Rec.hs +++ b/Plutarch/Rec.hs @@ -1,18 +1,26 @@ +{-# LANGUAGE DefaultSignatures #-} + module Plutarch.Rec ( + DataReader (DataReader, readData), PRecord (PRecord, getRecord), ScottEncoded, ScottEncoding, + RecordFromData (fieldFoci, fieldListFoci), field, + fieldFromData, letrec, pletrec, + recordFromFieldReaders, ) where import Control.Monad.Trans.State.Lazy (State, evalState, get, put) -import Data.Functor.Compose (Compose) +import Data.Functor.Compose (Compose (Compose, getCompose)) import Data.Kind (Type) import Data.Monoid (Dual (Dual, getDual), Endo (Endo, appEndo), Sum (Sum, getSum)) import Numeric.Natural (Natural) -import Plutarch (PlutusType (PInner, pcon', pmatch'), phoistAcyclic, plam, punsafeCoerce, (#), (:-->)) +import Plutarch (PlutusType (PInner, pcon', pmatch'), pcon, phoistAcyclic, plam, plet, punsafeCoerce, (#), (:-->)) +import Plutarch.Bool (pif, (#==)) +import Plutarch.Builtin (PAsData, PBuiltinList, PData, pasConstr, pforgetData, pfstBuiltin, psndBuiltin) import Plutarch.Internal ( PType, RawTerm (RApply, RLamAbs, RVar), @@ -20,6 +28,8 @@ import Plutarch.Internal ( TermResult (TermResult, getDeps, getTerm), mapTerm, ) +import Plutarch.List (phead, ptail) +import Plutarch.Trace (ptraceError) import qualified Rank2 newtype PRecord r s = PRecord {getRecord :: r (Term s)} @@ -100,6 +110,74 @@ variables = Rank2.cotraverse var id , getDeps = [] } +newtype DataReader s a = DataReader {readData :: Term s (PAsData a) -> Term s a} +newtype FocusFromData s a b = FocusFromData {getFocus :: Term s (PAsData a :--> PAsData b)} +newtype FocusFromDataList s a = FocusFromDataList {getItem :: Term s (PBuiltinList PData) -> Term s (PAsData a)} + +{- | Converts a record of field DataReaders to a DataReader of the whole + record. If you only need a single field or two, use `fieldFromData` + instead. +-} +recordFromFieldReaders :: + forall r s. + (Rank2.Apply r, RecordFromData r) => + r (DataReader s) -> + DataReader s (PRecord r) +recordFromFieldReaders reader = DataReader $ verifySoleConstructor readRecord + where + readRecord :: Term s (PBuiltinList PData) -> Term s (PRecord r) + readRecord dat = pcon $ PRecord $ Rank2.liftA2 (flip readData . getCompose) (fields dat) reader + fields :: Term s (PBuiltinList PData) -> r (Compose (Term s) PAsData) + fields bis = (\f -> Compose $ getItem f bis) Rank2.<$> fieldListFoci + +{- | Converts a Haskell field function to a function term that extracts the 'Data' encoding of the field from the + encoding of the whole record. +-} +fieldFromData :: + RecordFromData r => + (r (FocusFromData s (PRecord r)) -> FocusFromData s (PRecord r) t) -> + Term s (PAsData (PRecord r) :--> PAsData t) +fieldFromData f = getFocus (f fieldFoci) + +{- | Instances of this class must know how to focus on individual fields of + the data-encoded record. If the declared order of the record fields doesn't + match the encoding order, you must override the method defaults. +-} +class (Rank2.Distributive r, Rank2.Traversable r) => RecordFromData r where + -- | Given the encoding of the whole record, every field focuses on its own encoding. + fieldFoci :: r (FocusFromData s (PRecord r)) + + -- | Given the encoding of the list of all fields, every field focuses on its own encoding. + fieldListFoci :: r (FocusFromDataList s) + + fieldFoci = Rank2.cotraverse focus id + where + focus :: (r (FocusFromData s (PRecord r)) -> FocusFromData s (PRecord r) a) -> FocusFromData s (PRecord r) a + focus ref = ref foci + foci :: r (FocusFromData s (PRecord r)) + foci = fieldsFromRecord Rank2.<$> fieldListFoci + fieldsFromRecord :: FocusFromDataList s a -> FocusFromData s (PRecord r) a + fieldsFromRecord (FocusFromDataList f) = FocusFromData $ plam $ verifySoleConstructor f + fieldListFoci = Rank2.cotraverse focus id + where + focus :: (r (FocusFromDataList s) -> FocusFromDataList s a) -> FocusFromDataList s a + focus ref = ref foci + foci :: r (FocusFromDataList s) + foci = evalState (Rank2.traverse next $ initial @r) id + next :: f a -> State (Term s (PBuiltinList PData) -> Term s (PBuiltinList PData)) (FocusFromDataList s a) + next _ = do + rest <- get + put ((ptail #) . rest) + return $ FocusFromDataList (punsafeCoerce . (phead #) . rest) + +verifySoleConstructor :: (Term s (PBuiltinList PData) -> Term s a) -> (Term s (PAsData (PRecord r)) -> Term s a) +verifySoleConstructor f d = + plet (pasConstr # pforgetData d) $ \constr -> + pif + (pfstBuiltin # constr #== 0) + (f $ psndBuiltin # constr) + (ptraceError "verifySoleConstructor failed") + initial :: Rank2.Distributive r => r (Compose Maybe (Term s)) initial = Rank2.distribute Nothing diff --git a/examples/Examples/LetRec.hs b/examples/Examples/LetRec.hs index 890e825fc..1bc37d7c7 100644 --- a/examples/Examples/LetRec.hs +++ b/examples/Examples/LetRec.hs @@ -2,13 +2,25 @@ module Examples.LetRec (tests) where -import Plutarch (pcon', pmatch', printTerm) +import Plutarch (pcon', pmatch', printTerm, punsafeBuiltin, punsafeCoerce) import Plutarch.Bool (PBool (PFalse, PTrue), pif, (#==)) +import Plutarch.Builtin (PAsData, PBuiltinList (PNil), PIsData, pdata, pforgetData, pfromData) import Plutarch.Integer (PInteger) import Plutarch.Prelude -import Plutarch.Rec (PRecord (PRecord), ScottEncoded, ScottEncoding, field, letrec) +import Plutarch.Rec ( + DataReader (DataReader, readData), + PRecord (PRecord), + RecordFromData, + ScottEncoded, + ScottEncoding, + field, + fieldFromData, + letrec, + recordFromFieldReaders, + ) import Plutarch.Rec.TH (deriveAll) -import Plutarch.String (PString) +import Plutarch.String (PString, pdecodeUtf8, pencodeUtf8) +import qualified PlutusCore as PLC import qualified Rank2.TH import Test.Tasty (TestTree, testGroup) import Test.Tasty.HUnit (testCase, (@?=)) @@ -30,6 +42,30 @@ type instance ScottEncoded EvenOdd a = (PInteger :--> PBool) :--> (PInteger :--> $(Rank2.TH.deriveAll ''EvenOdd) $(deriveAll ''SampleRecord) -- also autoderives the @type instance ScottEncoded@ +instance RecordFromData SampleRecord + +instance PIsData (PRecord SampleRecord) where + pfromData = readData (recordFromFieldReaders sampleReader) + pdata = recordData + +recordData :: forall s. Term s (PRecord SampleRecord) -> Term s (PAsData (PRecord SampleRecord)) +recordData r = pmatch r $ \(PRecord SampleRecord {sampleBool, sampleInt, sampleString}) -> + punsafeBuiltin PLC.ConstrData # (0 :: Term s PInteger) + #$ pconsBuiltin # pforgetData (pdata sampleBool) + #$ pconsBuiltin # pforgetData (pdata sampleInt) + #$ pconsBuiltin # pforgetData (pdata $ pencodeUtf8 # sampleString) + #$ pcon PNil + +pconsBuiltin :: Term s (a :--> PBuiltinList a :--> PBuiltinList a) +pconsBuiltin = phoistAcyclic $ pforce $ punsafeBuiltin PLC.MkCons + +sampleReader :: SampleRecord (DataReader s) +sampleReader = + SampleRecord + { sampleBool = DataReader pfromData + , sampleInt = DataReader pfromData + , sampleString = DataReader $ \d -> pdecodeUtf8 #$ pfromData $ punsafeCoerce d + } sampleRecord :: Term (s :: S) (ScottEncoding SampleRecord (t :: PType)) sampleRecord = @@ -61,6 +97,9 @@ evenOdd = letrec evenOddRecursion , odd = plam $ \n -> pif (n #== 0) (pcon PFalse) (even #$ n - 1) } +sampleData :: Term s (PAsData (PRecord SampleRecord)) +sampleData = pdata (punsafeCoerce sampleRecord) + tests :: HasTester => TestTree tests = testGroup @@ -83,4 +122,11 @@ tests = , testCase "even 4" $ equal' (evenOdd # field even # (4 :: Term s PInteger)) "(program 1.0.0 True)" , testCase "even 5" $ equal' (evenOdd # field even # (5 :: Term s PInteger)) "(program 1.0.0 False)" ] + , testGroup + "Data" + [ testCase "pdata" $ printTerm sampleData @?= "(program 1.0.0 ((\\i0 -> i1 False 6 \"Salut, Monde!\") (\\i0 -> \\i0 -> \\i0 -> constrData 0 (force mkCons ((\\i0 -> constrData (force ifThenElse i1 1 0) [ ]) i3) (force mkCons (iData i2) (force mkCons (bData (encodeUtf8 i1)) [ ]))))))" + , testCase "fieldFromData term" $ (printTerm $ plam $ \dat -> plam pfromData #$ fieldFromData sampleInt # dat) @?= "(program 1.0.0 (\\i0 -> unIData ((\\i0 -> (\\i0 -> force (force ifThenElse (equalsInteger (force (force fstPair) i1) 0) (delay (force headList (force tailList (force (force sndPair) i1)))) (delay error))) (unConstrData i1)) i1)))" + , testCase "fieldFromData value" $ equal' (fieldFromData sampleInt # sampleData) "(program 1.0.0 #06)" + , testCase "pfromData" $ (printTerm $ plam $ \d -> punsafeCoerce (pfromData d :: Term _ (PRecord SampleRecord)) # field sampleInt) @?= "(program 1.0.0 ((\\i0 -> (\\i0 -> (\\i0 -> (\\i0 -> \\i0 -> (\\i0 -> force (force ifThenElse (equalsInteger (i3 i1) 0) (delay (\\i0 -> i1 ((\\i0 -> equalsInteger (i5 (unConstrData i1)) 1) (i5 (i7 i2))) (unIData (i5 (i6 (i7 i2)))) (decodeUtf8 (unBData (i5 (i6 (i6 (i7 i2)))))))) (delay error))) (unConstrData i1) (\\i0 -> \\i0 -> \\i0 -> i2)) (force (force fstPair))) (force headList)) (force tailList)) (force (force sndPair))))" + ] ]