diff --git a/Plutarch/Rec.hs b/Plutarch/Rec.hs index 02594ad24..68c7b6c81 100644 --- a/Plutarch/Rec.hs +++ b/Plutarch/Rec.hs @@ -1,7 +1,10 @@ +{-# LANGUAGE DefaultSignatures #-} + module Plutarch.Rec ( PRecord (PRecord, getRecord), ScottEncoded, ScottEncoding, + FieldsFromData (fieldFromData), field, letrec, pletrec, @@ -12,7 +15,9 @@ import Data.Functor.Compose (Compose) import Data.Kind (Type) import Data.Monoid (Dual (Dual, getDual), Endo (Endo, appEndo), Sum (Sum, getSum)) import Numeric.Natural (Natural) -import Plutarch (PlutusType (PInner, pcon', pmatch'), phoistAcyclic, plam, punsafeCoerce, (#), (:-->)) +import Plutarch (PlutusType (PInner, pcon', pmatch'), phoistAcyclic, plam, plet, punsafeCoerce, (#), (:-->)) +import Plutarch.Bool (pif, (#==)) +import Plutarch.Builtin (PAsData, PBuiltinList, PData, pasConstr, pforgetData, pfstBuiltin, psndBuiltin) import Plutarch.Internal ( PType, RawTerm (RApply, RLamAbs, RVar), @@ -20,6 +25,8 @@ import Plutarch.Internal ( TermResult (TermResult, getDeps, getTerm), mapTerm, ) +import Plutarch.List (phead, ptail) +import Plutarch.Trace (ptraceError) import qualified Rank2 newtype PRecord r s = PRecord {getRecord :: r (Term s)} @@ -100,6 +107,39 @@ variables = Rank2.cotraverse var id , getDeps = [] } +newtype FocusFromData s a b = FocusFromData {getFocus :: Term s (PAsData a :--> PAsData b)} + +class FieldsFromData r where + -- | Converts a Haskell field function to a function term that extracts the 'Data' encoding of the field from the + -- encoding of the whole record. + fieldFromData :: (r (FocusFromData s (PRecord r)) -> FocusFromData s (PRecord r) t) + -> Term s (PAsData (PRecord r) :--> PAsData t) + default fieldFromData :: (Rank2.Distributive r, Rank2.Traversable r) + => (r (FocusFromData s (PRecord r)) -> FocusFromData s (PRecord r) t) + -> Term s (PAsData (PRecord r) :--> PAsData t) + fieldFromData f = getFocus (f fieldFoci) + +fieldFoci :: forall r s. (Rank2.Distributive r, Rank2.Traversable r) => r (FocusFromData s (PRecord r)) +fieldFoci = Rank2.cotraverse focus id + where + focus :: (r (FocusFromData s (PRecord r)) -> FocusFromData s (PRecord r) a) -> FocusFromData s (PRecord r) a + focus ref = ref ordered + ordered :: r (FocusFromData s (PRecord r)) + ordered = evalState (Rank2.traverse next $ initial @r) id + next :: f a -> State (Term s (PBuiltinList PData) -> Term s (PBuiltinList PData)) (FocusFromData s (PRecord r) a) + next _ = do + rest <- get + put ((ptail #) . rest) + return $ + FocusFromData $ punsafeCoerce $ fromFields ((phead #) . rest) + fromFields :: (Term s (PBuiltinList PData) -> Term s a) -> Term s (PAsData (PRecord r) :--> a) + fromFields f = plam $ \d-> + plet (pasConstr # pforgetData d) $ \constr -> + pif + (pfstBuiltin # constr #== 0) + (f $ psndBuiltin # constr) + (ptraceError "fieldFromData expects a sole constructor") + initial :: Rank2.Distributive r => r (Compose Maybe (Term s)) initial = Rank2.distribute Nothing diff --git a/examples/Examples/LetRec.hs b/examples/Examples/LetRec.hs index 890e825fc..273260fb6 100644 --- a/examples/Examples/LetRec.hs +++ b/examples/Examples/LetRec.hs @@ -2,13 +2,16 @@ module Examples.LetRec (tests) where -import Plutarch (pcon', pmatch', printTerm) +import Plutarch (pcon', pmatch', printTerm, punsafeBuiltin, punsafeCoerce) import Plutarch.Bool (PBool (PFalse, PTrue), pif, (#==)) +import Plutarch.Builtin (PAsData, PBuiltinList (PNil), PData, PIsData, pasConstr, pdata, pforgetData, pfromData, pfstBuiltin, psndBuiltin) import Plutarch.Integer (PInteger) +import Plutarch.List (phead, ptail) import Plutarch.Prelude -import Plutarch.Rec (PRecord (PRecord), ScottEncoded, ScottEncoding, field, letrec) +import Plutarch.Rec (FieldsFromData, PRecord (PRecord), ScottEncoded, ScottEncoding, field, fieldFromData, letrec) import Plutarch.Rec.TH (deriveAll) -import Plutarch.String (PString) +import Plutarch.String (PString, pdecodeUtf8, pencodeUtf8) +import qualified PlutusCore as PLC import qualified Rank2.TH import Test.Tasty (TestTree, testGroup) import Test.Tasty.HUnit (testCase, (@?=)) @@ -30,6 +33,45 @@ type instance ScottEncoded EvenOdd a = (PInteger :--> PBool) :--> (PInteger :--> $(Rank2.TH.deriveAll ''EvenOdd) $(deriveAll ''SampleRecord) -- also autoderives the @type instance ScottEncoded@ +instance FieldsFromData SampleRecord + +instance PIsData (PRecord SampleRecord) where + pfromData = strictRecordFromData + pdata = recordData + +--recordData :: (forall t. Term s (ScottEncoding SampleRecord t)) -> Term s (PAsData (PRecord SampleRecord)) +recordData :: forall s. Term s (PRecord SampleRecord) -> Term s (PAsData (PRecord SampleRecord)) +recordData r = pmatch r $ \(PRecord SampleRecord{sampleBool, sampleInt, sampleString})-> + punsafeBuiltin PLC.ConstrData # (0 :: Term s PInteger) #$ + pconsBuiltin # pforgetData (pdata sampleBool) #$ + pconsBuiltin # pforgetData (pdata sampleInt) #$ + pconsBuiltin # pforgetData (pdata $ pencodeUtf8 # sampleString) #$ + pcon PNil + +pconsBuiltin :: Term s (a :--> PBuiltinList a :--> PBuiltinList a) +pconsBuiltin = phoistAcyclic $ pforce $ punsafeBuiltin PLC.MkCons + +strictRecordFromData :: Term s (PAsData (PRecord SampleRecord)) -> Term s (PRecord SampleRecord) +strictRecordFromData d = + plet (pasConstr # pforgetData d) $ \constr -> + pif + (pfstBuiltin # constr #== 0) + (fillInFields #$ psndBuiltin # constr) + perror + where + fillInFields :: Term s (PBuiltinList PData :--> PRecord SampleRecord) + fillInFields = plam $ \bis -> + plet (phead # bis) $ \b -> + plet (ptail # bis) $ \is -> + plet (phead # is) $ \i -> + plet (phead #$ ptail # is) $ \s -> + pcon + ( PRecord $ + SampleRecord + (pfromData $ punsafeCoerce b) + (pfromData $ punsafeCoerce i) + (pdecodeUtf8 #$ pfromData $ punsafeCoerce s) + ) sampleRecord :: Term (s :: S) (ScottEncoding SampleRecord (t :: PType)) sampleRecord = @@ -61,6 +103,9 @@ evenOdd = letrec evenOddRecursion , odd = plam $ \n -> pif (n #== 0) (pcon PFalse) (even #$ n - 1) } +sampleData :: Term s (PAsData (PRecord SampleRecord)) +sampleData = pdata (punsafeCoerce sampleRecord) + tests :: HasTester => TestTree tests = testGroup @@ -83,4 +128,10 @@ tests = , testCase "even 4" $ equal' (evenOdd # field even # (4 :: Term s PInteger)) "(program 1.0.0 True)" , testCase "even 5" $ equal' (evenOdd # field even # (5 :: Term s PInteger)) "(program 1.0.0 False)" ] + , testGroup + "Data" + [ testCase "pdata" $ printTerm sampleData @?= "(program 1.0.0 ((\\i0 -> i1 False 6 \"Salut, Monde!\") (\\i0 -> \\i0 -> \\i0 -> constrData 0 (force mkCons ((\\i0 -> constrData (force ifThenElse i1 1 0) [ ]) i3) (force mkCons (iData i2) (force mkCons (bData (encodeUtf8 i1)) [ ]))))))" + , testCase "fieldFromData term" $ (printTerm $ plam $ \dat-> plam pfromData #$ fieldFromData sampleInt # dat) @?= "(program 1.0.0 (\\i0 -> unIData ((\\i0 -> (\\i0 -> force (force ifThenElse (equalsInteger (force (force fstPair) i1) 0) (delay (force headList (force tailList (force (force sndPair) i1)))) (delay error))) (unConstrData i1)) i1)))" + , testCase "fieldFromData value" $ equal' (fieldFromData sampleInt # sampleData) "(program 1.0.0 #06)" + ] ]