Skip to content

Commit

Permalink
Added class FieldsFromData
Browse files Browse the repository at this point in the history
  • Loading branch information
blamario committed Jan 17, 2022
1 parent 9dd625d commit 5e64279
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 4 deletions.
42 changes: 41 additions & 1 deletion Plutarch/Rec.hs
@@ -1,7 +1,10 @@
{-# LANGUAGE DefaultSignatures #-}

module Plutarch.Rec (
PRecord (PRecord, getRecord),
ScottEncoded,
ScottEncoding,
FieldsFromData (fieldFromData),
field,
letrec,
pletrec,
Expand All @@ -12,14 +15,18 @@ import Data.Functor.Compose (Compose)
import Data.Kind (Type)
import Data.Monoid (Dual (Dual, getDual), Endo (Endo, appEndo), Sum (Sum, getSum))
import Numeric.Natural (Natural)
import Plutarch (PlutusType (PInner, pcon', pmatch'), phoistAcyclic, plam, punsafeCoerce, (#), (:-->))
import Plutarch (PlutusType (PInner, pcon', pmatch'), phoistAcyclic, plam, plet, punsafeCoerce, (#), (:-->))
import Plutarch.Bool (pif, (#==))
import Plutarch.Builtin (PAsData, PBuiltinList, PData, pasConstr, pforgetData, pfstBuiltin, psndBuiltin)
import Plutarch.Internal (
PType,
RawTerm (RApply, RLamAbs, RVar),
Term (Term, asRawTerm),
TermResult (TermResult, getDeps, getTerm),
mapTerm,
)
import Plutarch.List (phead, ptail)
import Plutarch.Trace (ptraceError)
import qualified Rank2

newtype PRecord r s = PRecord {getRecord :: r (Term s)}
Expand Down Expand Up @@ -100,6 +107,39 @@ variables = Rank2.cotraverse var id
, getDeps = []
}

newtype FocusFromData s a b = FocusFromData {getFocus :: Term s (PAsData a :--> PAsData b)}

class FieldsFromData r where
-- | Converts a Haskell field function to a function term that extracts the 'Data' encoding of the field from the
-- encoding of the whole record.
fieldFromData :: (r (FocusFromData s (PRecord r)) -> FocusFromData s (PRecord r) t)
-> Term s (PAsData (PRecord r) :--> PAsData t)
default fieldFromData :: (Rank2.Distributive r, Rank2.Traversable r)
=> (r (FocusFromData s (PRecord r)) -> FocusFromData s (PRecord r) t)
-> Term s (PAsData (PRecord r) :--> PAsData t)
fieldFromData f = getFocus (f fieldFoci)

fieldFoci :: forall r s. (Rank2.Distributive r, Rank2.Traversable r) => r (FocusFromData s (PRecord r))
fieldFoci = Rank2.cotraverse focus id
where
focus :: (r (FocusFromData s (PRecord r)) -> FocusFromData s (PRecord r) a) -> FocusFromData s (PRecord r) a
focus ref = ref ordered
ordered :: r (FocusFromData s (PRecord r))
ordered = evalState (Rank2.traverse next $ initial @r) id
next :: f a -> State (Term s (PBuiltinList PData) -> Term s (PBuiltinList PData)) (FocusFromData s (PRecord r) a)
next _ = do
rest <- get
put ((ptail #) . rest)
return $
FocusFromData $ punsafeCoerce $ fromFields ((phead #) . rest)
fromFields :: (Term s (PBuiltinList PData) -> Term s a) -> Term s (PAsData (PRecord r) :--> a)
fromFields f = plam $ \d->
plet (pasConstr # pforgetData d) $ \constr ->
pif
(pfstBuiltin # constr #== 0)
(f $ psndBuiltin # constr)
(ptraceError "fieldFromData expects a sole constructor")

initial :: Rank2.Distributive r => r (Compose Maybe (Term s))
initial = Rank2.distribute Nothing

Expand Down
57 changes: 54 additions & 3 deletions examples/Examples/LetRec.hs
Expand Up @@ -2,13 +2,16 @@

module Examples.LetRec (tests) where

import Plutarch (pcon', pmatch', printTerm)
import Plutarch (pcon', pmatch', printTerm, punsafeBuiltin, punsafeCoerce)
import Plutarch.Bool (PBool (PFalse, PTrue), pif, (#==))
import Plutarch.Builtin (PAsData, PBuiltinList (PNil), PData, PIsData, pasConstr, pdata, pforgetData, pfromData, pfstBuiltin, psndBuiltin)
import Plutarch.Integer (PInteger)
import Plutarch.List (phead, ptail)
import Plutarch.Prelude
import Plutarch.Rec (PRecord (PRecord), ScottEncoded, ScottEncoding, field, letrec)
import Plutarch.Rec (FieldsFromData, PRecord (PRecord), ScottEncoded, ScottEncoding, field, fieldFromData, letrec)
import Plutarch.Rec.TH (deriveAll)
import Plutarch.String (PString)
import Plutarch.String (PString, pdecodeUtf8, pencodeUtf8)
import qualified PlutusCore as PLC
import qualified Rank2.TH
import Test.Tasty (TestTree, testGroup)
import Test.Tasty.HUnit (testCase, (@?=))
Expand All @@ -30,6 +33,45 @@ type instance ScottEncoded EvenOdd a = (PInteger :--> PBool) :--> (PInteger :-->

$(Rank2.TH.deriveAll ''EvenOdd)
$(deriveAll ''SampleRecord) -- also autoderives the @type instance ScottEncoded@
instance FieldsFromData SampleRecord

instance PIsData (PRecord SampleRecord) where
pfromData = strictRecordFromData
pdata = recordData

--recordData :: (forall t. Term s (ScottEncoding SampleRecord t)) -> Term s (PAsData (PRecord SampleRecord))
recordData :: forall s. Term s (PRecord SampleRecord) -> Term s (PAsData (PRecord SampleRecord))
recordData r = pmatch r $ \(PRecord SampleRecord{sampleBool, sampleInt, sampleString})->
punsafeBuiltin PLC.ConstrData # (0 :: Term s PInteger) #$
pconsBuiltin # pforgetData (pdata sampleBool) #$
pconsBuiltin # pforgetData (pdata sampleInt) #$
pconsBuiltin # pforgetData (pdata $ pencodeUtf8 # sampleString) #$
pcon PNil

pconsBuiltin :: Term s (a :--> PBuiltinList a :--> PBuiltinList a)
pconsBuiltin = phoistAcyclic $ pforce $ punsafeBuiltin PLC.MkCons

strictRecordFromData :: Term s (PAsData (PRecord SampleRecord)) -> Term s (PRecord SampleRecord)
strictRecordFromData d =
plet (pasConstr # pforgetData d) $ \constr ->
pif
(pfstBuiltin # constr #== 0)
(fillInFields #$ psndBuiltin # constr)
perror
where
fillInFields :: Term s (PBuiltinList PData :--> PRecord SampleRecord)
fillInFields = plam $ \bis ->
plet (phead # bis) $ \b ->
plet (ptail # bis) $ \is ->
plet (phead # is) $ \i ->
plet (phead #$ ptail # is) $ \s ->
pcon
( PRecord $
SampleRecord
(pfromData $ punsafeCoerce b)
(pfromData $ punsafeCoerce i)
(pdecodeUtf8 #$ pfromData $ punsafeCoerce s)
)

sampleRecord :: Term (s :: S) (ScottEncoding SampleRecord (t :: PType))
sampleRecord =
Expand Down Expand Up @@ -61,6 +103,9 @@ evenOdd = letrec evenOddRecursion
, odd = plam $ \n -> pif (n #== 0) (pcon PFalse) (even #$ n - 1)
}

sampleData :: Term s (PAsData (PRecord SampleRecord))
sampleData = pdata (punsafeCoerce sampleRecord)

tests :: HasTester => TestTree
tests =
testGroup
Expand All @@ -83,4 +128,10 @@ tests =
, testCase "even 4" $ equal' (evenOdd # field even # (4 :: Term s PInteger)) "(program 1.0.0 True)"
, testCase "even 5" $ equal' (evenOdd # field even # (5 :: Term s PInteger)) "(program 1.0.0 False)"
]
, testGroup
"Data"
[ testCase "pdata" $ printTerm sampleData @?= "(program 1.0.0 ((\\i0 -> i1 False 6 \"Salut, Monde!\") (\\i0 -> \\i0 -> \\i0 -> constrData 0 (force mkCons ((\\i0 -> constrData (force ifThenElse i1 1 0) [ ]) i3) (force mkCons (iData i2) (force mkCons (bData (encodeUtf8 i1)) [ ]))))))"
, testCase "fieldFromData term" $ (printTerm $ plam $ \dat-> plam pfromData #$ fieldFromData sampleInt # dat) @?= "(program 1.0.0 (\\i0 -> unIData ((\\i0 -> (\\i0 -> force (force ifThenElse (equalsInteger (force (force fstPair) i1) 0) (delay (force headList (force tailList (force (force sndPair) i1)))) (delay error))) (unConstrData i1)) i1)))"
, testCase "fieldFromData value" $ equal' (fieldFromData sampleInt # sampleData) "(program 1.0.0 #06)"
]
]

0 comments on commit 5e64279

Please sign in to comment.