Skip to content

Commit

Permalink
Merge pull request #125 from Plutonomicon/mario/records-data
Browse files Browse the repository at this point in the history
Mario/records data
  • Loading branch information
L-as committed Jan 17, 2022
2 parents 9dd625d + cf139a0 commit 23bd0ca
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 5 deletions.
82 changes: 80 additions & 2 deletions Plutarch/Rec.hs
@@ -1,25 +1,35 @@
{-# LANGUAGE DefaultSignatures #-}

module Plutarch.Rec (
DataReader (DataReader, readData),
PRecord (PRecord, getRecord),
ScottEncoded,
ScottEncoding,
RecordFromData (fieldFoci, fieldListFoci),
field,
fieldFromData,
letrec,
pletrec,
recordFromFieldReaders,
) where

import Control.Monad.Trans.State.Lazy (State, evalState, get, put)
import Data.Functor.Compose (Compose)
import Data.Functor.Compose (Compose (Compose, getCompose))
import Data.Kind (Type)
import Data.Monoid (Dual (Dual, getDual), Endo (Endo, appEndo), Sum (Sum, getSum))
import Numeric.Natural (Natural)
import Plutarch (PlutusType (PInner, pcon', pmatch'), phoistAcyclic, plam, punsafeCoerce, (#), (:-->))
import Plutarch (PlutusType (PInner, pcon', pmatch'), pcon, phoistAcyclic, plam, plet, punsafeCoerce, (#), (:-->))
import Plutarch.Bool (pif, (#==))
import Plutarch.Builtin (PAsData, PBuiltinList, PData, pasConstr, pforgetData, pfstBuiltin, psndBuiltin)
import Plutarch.Internal (
PType,
RawTerm (RApply, RLamAbs, RVar),
Term (Term, asRawTerm),
TermResult (TermResult, getDeps, getTerm),
mapTerm,
)
import Plutarch.List (phead, ptail)
import Plutarch.Trace (ptraceError)
import qualified Rank2

newtype PRecord r s = PRecord {getRecord :: r (Term s)}
Expand Down Expand Up @@ -100,6 +110,74 @@ variables = Rank2.cotraverse var id
, getDeps = []
}

newtype DataReader s a = DataReader {readData :: Term s (PAsData a) -> Term s a}
newtype FocusFromData s a b = FocusFromData {getFocus :: Term s (PAsData a :--> PAsData b)}
newtype FocusFromDataList s a = FocusFromDataList {getItem :: Term s (PBuiltinList PData) -> Term s (PAsData a)}

{- | Converts a record of field DataReaders to a DataReader of the whole
record. If you only need a single field or two, use `fieldFromData`
instead.
-}
recordFromFieldReaders ::
forall r s.
(Rank2.Apply r, RecordFromData r) =>
r (DataReader s) ->
DataReader s (PRecord r)
recordFromFieldReaders reader = DataReader $ verifySoleConstructor readRecord
where
readRecord :: Term s (PBuiltinList PData) -> Term s (PRecord r)
readRecord dat = pcon $ PRecord $ Rank2.liftA2 (flip readData . getCompose) (fields dat) reader
fields :: Term s (PBuiltinList PData) -> r (Compose (Term s) PAsData)
fields bis = (\f -> Compose $ getItem f bis) Rank2.<$> fieldListFoci

{- | Converts a Haskell field function to a function term that extracts the 'Data' encoding of the field from the
encoding of the whole record.
-}
fieldFromData ::
RecordFromData r =>
(r (FocusFromData s (PRecord r)) -> FocusFromData s (PRecord r) t) ->
Term s (PAsData (PRecord r) :--> PAsData t)
fieldFromData f = getFocus (f fieldFoci)

{- | Instances of this class must know how to focus on individual fields of
the data-encoded record. If the declared order of the record fields doesn't
match the encoding order, you must override the method defaults.
-}
class (Rank2.Distributive r, Rank2.Traversable r) => RecordFromData r where
-- | Given the encoding of the whole record, every field focuses on its own encoding.
fieldFoci :: r (FocusFromData s (PRecord r))

-- | Given the encoding of the list of all fields, every field focuses on its own encoding.
fieldListFoci :: r (FocusFromDataList s)

fieldFoci = Rank2.cotraverse focus id
where
focus :: (r (FocusFromData s (PRecord r)) -> FocusFromData s (PRecord r) a) -> FocusFromData s (PRecord r) a
focus ref = ref foci
foci :: r (FocusFromData s (PRecord r))
foci = fieldsFromRecord Rank2.<$> fieldListFoci
fieldsFromRecord :: FocusFromDataList s a -> FocusFromData s (PRecord r) a
fieldsFromRecord (FocusFromDataList f) = FocusFromData $ plam $ verifySoleConstructor f
fieldListFoci = Rank2.cotraverse focus id
where
focus :: (r (FocusFromDataList s) -> FocusFromDataList s a) -> FocusFromDataList s a
focus ref = ref foci
foci :: r (FocusFromDataList s)
foci = evalState (Rank2.traverse next $ initial @r) id
next :: f a -> State (Term s (PBuiltinList PData) -> Term s (PBuiltinList PData)) (FocusFromDataList s a)
next _ = do
rest <- get
put ((ptail #) . rest)
return $ FocusFromDataList (punsafeCoerce . (phead #) . rest)

verifySoleConstructor :: (Term s (PBuiltinList PData) -> Term s a) -> (Term s (PAsData (PRecord r)) -> Term s a)
verifySoleConstructor f d =
plet (pasConstr # pforgetData d) $ \constr ->
pif
(pfstBuiltin # constr #== 0)
(f $ psndBuiltin # constr)
(ptraceError "verifySoleConstructor failed")

initial :: Rank2.Distributive r => r (Compose Maybe (Term s))
initial = Rank2.distribute Nothing

Expand Down
52 changes: 49 additions & 3 deletions examples/Examples/LetRec.hs
Expand Up @@ -2,13 +2,25 @@

module Examples.LetRec (tests) where

import Plutarch (pcon', pmatch', printTerm)
import Plutarch (pcon', pmatch', printTerm, punsafeBuiltin, punsafeCoerce)
import Plutarch.Bool (PBool (PFalse, PTrue), pif, (#==))
import Plutarch.Builtin (PAsData, PBuiltinList (PNil), PIsData, pdata, pforgetData, pfromData)
import Plutarch.Integer (PInteger)
import Plutarch.Prelude
import Plutarch.Rec (PRecord (PRecord), ScottEncoded, ScottEncoding, field, letrec)
import Plutarch.Rec (
DataReader (DataReader, readData),
PRecord (PRecord),
RecordFromData,
ScottEncoded,
ScottEncoding,
field,
fieldFromData,
letrec,
recordFromFieldReaders,
)
import Plutarch.Rec.TH (deriveAll)
import Plutarch.String (PString)
import Plutarch.String (PString, pdecodeUtf8, pencodeUtf8)
import qualified PlutusCore as PLC
import qualified Rank2.TH
import Test.Tasty (TestTree, testGroup)
import Test.Tasty.HUnit (testCase, (@?=))
Expand All @@ -30,6 +42,30 @@ type instance ScottEncoded EvenOdd a = (PInteger :--> PBool) :--> (PInteger :-->

$(Rank2.TH.deriveAll ''EvenOdd)
$(deriveAll ''SampleRecord) -- also autoderives the @type instance ScottEncoded@
instance RecordFromData SampleRecord

instance PIsData (PRecord SampleRecord) where
pfromData = readData (recordFromFieldReaders sampleReader)
pdata = recordData

recordData :: forall s. Term s (PRecord SampleRecord) -> Term s (PAsData (PRecord SampleRecord))
recordData r = pmatch r $ \(PRecord SampleRecord {sampleBool, sampleInt, sampleString}) ->
punsafeBuiltin PLC.ConstrData # (0 :: Term s PInteger)
#$ pconsBuiltin # pforgetData (pdata sampleBool)
#$ pconsBuiltin # pforgetData (pdata sampleInt)
#$ pconsBuiltin # pforgetData (pdata $ pencodeUtf8 # sampleString)
#$ pcon PNil

pconsBuiltin :: Term s (a :--> PBuiltinList a :--> PBuiltinList a)
pconsBuiltin = phoistAcyclic $ pforce $ punsafeBuiltin PLC.MkCons

sampleReader :: SampleRecord (DataReader s)
sampleReader =
SampleRecord
{ sampleBool = DataReader pfromData
, sampleInt = DataReader pfromData
, sampleString = DataReader $ \d -> pdecodeUtf8 #$ pfromData $ punsafeCoerce d
}

sampleRecord :: Term (s :: S) (ScottEncoding SampleRecord (t :: PType))
sampleRecord =
Expand Down Expand Up @@ -61,6 +97,9 @@ evenOdd = letrec evenOddRecursion
, odd = plam $ \n -> pif (n #== 0) (pcon PFalse) (even #$ n - 1)
}

sampleData :: Term s (PAsData (PRecord SampleRecord))
sampleData = pdata (punsafeCoerce sampleRecord)

tests :: HasTester => TestTree
tests =
testGroup
Expand All @@ -83,4 +122,11 @@ tests =
, testCase "even 4" $ equal' (evenOdd # field even # (4 :: Term s PInteger)) "(program 1.0.0 True)"
, testCase "even 5" $ equal' (evenOdd # field even # (5 :: Term s PInteger)) "(program 1.0.0 False)"
]
, testGroup
"Data"
[ testCase "pdata" $ printTerm sampleData @?= "(program 1.0.0 ((\\i0 -> i1 False 6 \"Salut, Monde!\") (\\i0 -> \\i0 -> \\i0 -> constrData 0 (force mkCons ((\\i0 -> constrData (force ifThenElse i1 1 0) [ ]) i3) (force mkCons (iData i2) (force mkCons (bData (encodeUtf8 i1)) [ ]))))))"
, testCase "fieldFromData term" $ (printTerm $ plam $ \dat -> plam pfromData #$ fieldFromData sampleInt # dat) @?= "(program 1.0.0 (\\i0 -> unIData ((\\i0 -> (\\i0 -> force (force ifThenElse (equalsInteger (force (force fstPair) i1) 0) (delay (force headList (force tailList (force (force sndPair) i1)))) (delay error))) (unConstrData i1)) i1)))"
, testCase "fieldFromData value" $ equal' (fieldFromData sampleInt # sampleData) "(program 1.0.0 #06)"
, testCase "pfromData" $ (printTerm $ plam $ \d -> punsafeCoerce (pfromData d :: Term _ (PRecord SampleRecord)) # field sampleInt) @?= "(program 1.0.0 ((\\i0 -> (\\i0 -> (\\i0 -> (\\i0 -> \\i0 -> (\\i0 -> force (force ifThenElse (equalsInteger (i3 i1) 0) (delay (\\i0 -> i1 ((\\i0 -> equalsInteger (i5 (unConstrData i1)) 1) (i5 (i7 i2))) (unIData (i5 (i6 (i7 i2)))) (decodeUtf8 (unBData (i5 (i6 (i6 (i7 i2)))))))) (delay error))) (unConstrData i1) (\\i0 -> \\i0 -> \\i0 -> i2)) (force (force fstPair))) (force headList)) (force tailList)) (force (force sndPair))))"
]
]

0 comments on commit 23bd0ca

Please sign in to comment.