diff --git a/Plutarch/Rec.hs b/Plutarch/Rec.hs index c3fafc958..0f34c5b01 100644 --- a/Plutarch/Rec.hs +++ b/Plutarch/Rec.hs @@ -1,7 +1,7 @@ {-# LANGUAGE DefaultSignatures #-} module Plutarch.Rec ( - DataReader(DataReader, readData), + DataReader (DataReader, readData), PRecord (PRecord, getRecord), ScottEncoded, ScottEncoding, @@ -14,7 +14,7 @@ module Plutarch.Rec ( ) where import Control.Monad.Trans.State.Lazy (State, evalState, get, put) -import Data.Functor.Compose (Compose(Compose, getCompose)) +import Data.Functor.Compose (Compose (Compose, getCompose)) import Data.Kind (Type) import Data.Monoid (Dual (Dual, getDual), Endo (Endo, appEndo), Sum (Sum, getSum)) import Numeric.Natural (Natural) @@ -114,33 +114,42 @@ newtype DataReader s a = DataReader {readData :: Term s (PAsData a) -> Term s a} newtype FocusFromData s a b = FocusFromData {getFocus :: Term s (PAsData a :--> PAsData b)} newtype FocusFromDataList s a = FocusFromDataList {getItem :: Term s (PBuiltinList PData) -> Term s (PAsData a)} --- | Converts a record of field DataReaders to a DataReader of the whole --- record. If you only need a single field or two, use `fieldFromData` --- instead. -recordFromFieldReaders :: forall r s. (Rank2.Apply r, RecordFromData r) - => r (DataReader s) -> DataReader s (PRecord r) +{- | Converts a record of field DataReaders to a DataReader of the whole + record. If you only need a single field or two, use `fieldFromData` + instead. +-} +recordFromFieldReaders :: + forall r s. + (Rank2.Apply r, RecordFromData r) => + r (DataReader s) -> + DataReader s (PRecord r) recordFromFieldReaders reader = DataReader $ verifySoleConstructor readRecord where readRecord :: Term s (PBuiltinList PData) -> Term s (PRecord r) readRecord dat = pcon $ PRecord $ Rank2.liftA2 (flip readData . getCompose) (fields dat) reader fields :: Term s (PBuiltinList PData) -> r (Compose (Term s) PAsData) - fields bis = (\f-> Compose $ getItem f bis) Rank2.<$> fieldListFoci + fields bis = (\f -> Compose $ getItem f bis) Rank2.<$> fieldListFoci --- | Converts a Haskell field function to a function term that extracts the 'Data' encoding of the field from the --- encoding of the whole record. -fieldFromData :: RecordFromData r - => (r (FocusFromData s (PRecord r)) -> FocusFromData s (PRecord r) t) - -> Term s (PAsData (PRecord r) :--> PAsData t) +{- | Converts a Haskell field function to a function term that extracts the 'Data' encoding of the field from the + encoding of the whole record. +-} +fieldFromData :: + RecordFromData r => + (r (FocusFromData s (PRecord r)) -> FocusFromData s (PRecord r) t) -> + Term s (PAsData (PRecord r) :--> PAsData t) fieldFromData f = getFocus (f fieldFoci) --- | Instances of this class must know how to focus on individual fields of --- the data-encoded record. If the declared order of the record fields doesn't --- match the encoding order, you must override the method defaults. +{- | Instances of this class must know how to focus on individual fields of + the data-encoded record. If the declared order of the record fields doesn't + match the encoding order, you must override the method defaults. +-} class (Rank2.Distributive r, Rank2.Traversable r) => RecordFromData r where -- | Given the encoding of the whole record, every field focuses on its own encoding. fieldFoci :: r (FocusFromData s (PRecord r)) + -- | Given the encoding of the list of all fields, every field focuses on its own encoding. fieldListFoci :: r (FocusFromDataList s) + fieldFoci = Rank2.cotraverse focus id where focus :: (r (FocusFromData s (PRecord r)) -> FocusFromData s (PRecord r) a) -> FocusFromData s (PRecord r) a diff --git a/examples/Examples/LetRec.hs b/examples/Examples/LetRec.hs index 8981844b0..1bc37d7c7 100644 --- a/examples/Examples/LetRec.hs +++ b/examples/Examples/LetRec.hs @@ -7,8 +7,17 @@ import Plutarch.Bool (PBool (PFalse, PTrue), pif, (#==)) import Plutarch.Builtin (PAsData, PBuiltinList (PNil), PIsData, pdata, pforgetData, pfromData) import Plutarch.Integer (PInteger) import Plutarch.Prelude -import Plutarch.Rec (DataReader(DataReader, readData), RecordFromData, PRecord (PRecord), ScottEncoded, ScottEncoding, - field, fieldFromData, letrec, recordFromFieldReaders) +import Plutarch.Rec ( + DataReader (DataReader, readData), + PRecord (PRecord), + RecordFromData, + ScottEncoded, + ScottEncoding, + field, + fieldFromData, + letrec, + recordFromFieldReaders, + ) import Plutarch.Rec.TH (deriveAll) import Plutarch.String (PString, pdecodeUtf8, pencodeUtf8) import qualified PlutusCore as PLC @@ -40,21 +49,23 @@ instance PIsData (PRecord SampleRecord) where pdata = recordData recordData :: forall s. Term s (PRecord SampleRecord) -> Term s (PAsData (PRecord SampleRecord)) -recordData r = pmatch r $ \(PRecord SampleRecord{sampleBool, sampleInt, sampleString})-> - punsafeBuiltin PLC.ConstrData # (0 :: Term s PInteger) #$ - pconsBuiltin # pforgetData (pdata sampleBool) #$ - pconsBuiltin # pforgetData (pdata sampleInt) #$ - pconsBuiltin # pforgetData (pdata $ pencodeUtf8 # sampleString) #$ - pcon PNil +recordData r = pmatch r $ \(PRecord SampleRecord {sampleBool, sampleInt, sampleString}) -> + punsafeBuiltin PLC.ConstrData # (0 :: Term s PInteger) + #$ pconsBuiltin # pforgetData (pdata sampleBool) + #$ pconsBuiltin # pforgetData (pdata sampleInt) + #$ pconsBuiltin # pforgetData (pdata $ pencodeUtf8 # sampleString) + #$ pcon PNil pconsBuiltin :: Term s (a :--> PBuiltinList a :--> PBuiltinList a) pconsBuiltin = phoistAcyclic $ pforce $ punsafeBuiltin PLC.MkCons sampleReader :: SampleRecord (DataReader s) -sampleReader = SampleRecord{ - sampleBool = DataReader pfromData, - sampleInt = DataReader pfromData, - sampleString = DataReader $ \d-> pdecodeUtf8 #$ pfromData $ punsafeCoerce d} +sampleReader = + SampleRecord + { sampleBool = DataReader pfromData + , sampleInt = DataReader pfromData + , sampleString = DataReader $ \d -> pdecodeUtf8 #$ pfromData $ punsafeCoerce d + } sampleRecord :: Term (s :: S) (ScottEncoding SampleRecord (t :: PType)) sampleRecord = @@ -114,8 +125,8 @@ tests = , testGroup "Data" [ testCase "pdata" $ printTerm sampleData @?= "(program 1.0.0 ((\\i0 -> i1 False 6 \"Salut, Monde!\") (\\i0 -> \\i0 -> \\i0 -> constrData 0 (force mkCons ((\\i0 -> constrData (force ifThenElse i1 1 0) [ ]) i3) (force mkCons (iData i2) (force mkCons (bData (encodeUtf8 i1)) [ ]))))))" - , testCase "fieldFromData term" $ (printTerm $ plam $ \dat-> plam pfromData #$ fieldFromData sampleInt # dat) @?= "(program 1.0.0 (\\i0 -> unIData ((\\i0 -> (\\i0 -> force (force ifThenElse (equalsInteger (force (force fstPair) i1) 0) (delay (force headList (force tailList (force (force sndPair) i1)))) (delay error))) (unConstrData i1)) i1)))" + , testCase "fieldFromData term" $ (printTerm $ plam $ \dat -> plam pfromData #$ fieldFromData sampleInt # dat) @?= "(program 1.0.0 (\\i0 -> unIData ((\\i0 -> (\\i0 -> force (force ifThenElse (equalsInteger (force (force fstPair) i1) 0) (delay (force headList (force tailList (force (force sndPair) i1)))) (delay error))) (unConstrData i1)) i1)))" , testCase "fieldFromData value" $ equal' (fieldFromData sampleInt # sampleData) "(program 1.0.0 #06)" - , testCase "pfromData" $ (printTerm $ plam $ \d-> punsafeCoerce (pfromData d :: Term _ (PRecord SampleRecord)) # field sampleInt) @?= "(program 1.0.0 ((\\i0 -> (\\i0 -> (\\i0 -> (\\i0 -> \\i0 -> (\\i0 -> force (force ifThenElse (equalsInteger (i3 i1) 0) (delay (\\i0 -> i1 ((\\i0 -> equalsInteger (i5 (unConstrData i1)) 1) (i5 (i7 i2))) (unIData (i5 (i6 (i7 i2)))) (decodeUtf8 (unBData (i5 (i6 (i6 (i7 i2)))))))) (delay error))) (unConstrData i1) (\\i0 -> \\i0 -> \\i0 -> i2)) (force (force fstPair))) (force headList)) (force tailList)) (force (force sndPair))))" + , testCase "pfromData" $ (printTerm $ plam $ \d -> punsafeCoerce (pfromData d :: Term _ (PRecord SampleRecord)) # field sampleInt) @?= "(program 1.0.0 ((\\i0 -> (\\i0 -> (\\i0 -> (\\i0 -> \\i0 -> (\\i0 -> force (force ifThenElse (equalsInteger (i3 i1) 0) (delay (\\i0 -> i1 ((\\i0 -> equalsInteger (i5 (unConstrData i1)) 1) (i5 (i7 i2))) (unIData (i5 (i6 (i7 i2)))) (decodeUtf8 (unBData (i5 (i6 (i6 (i7 i2)))))))) (delay error))) (unConstrData i1) (\\i0 -> \\i0 -> \\i0 -> i2)) (force (force fstPair))) (force headList)) (force tailList)) (force (force sndPair))))" ] ]