Skip to content

Commit

Permalink
constrained-generators: Add flip_ to avoid having to add new native
Browse files Browse the repository at this point in the history
functions
  • Loading branch information
MaximilianAlgehed committed May 29, 2024
1 parent b3b8b3d commit c05e040
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 4 deletions.
2 changes: 2 additions & 0 deletions libs/constrained-generators/src/Constrained.hs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ module Constrained (
snd_,
pair_,
(<=.),
(>=.),
(<.),
(>.),
(==.),
(/=.),
(||.),
Expand Down
75 changes: 71 additions & 4 deletions libs/constrained-generators/src/Constrained/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1102,14 +1102,16 @@ prepareLinearization p = do
graph = transitiveClosure $ hints <> respecting hints (foldMap computeDependencies preds)
-- TODO: clean this up
subst =
[ x := t
[ (x := t, freeVarSet t)
| Assert _ (App (extractFn @(EqFn fn) @fn -> Just Equal) (V x :> t :> Nil)) <- preds
, not $ any (\y -> Name x `Set.member` dependencies y graph) (freeVarSet t)
]
preds' =
flattenPred . Block $
map (foldr (.) id [simplifyPred . substitutePred x tm | x := tm <- subst]) preds
++ [Assert [] (V x ==. t) | x := t <- subst]
map (foldr (.) id (map f subst)) preds
f (x := t, fvs) p
| Set.disjoint fvs (freeVarSet p) = simplifyPred $ substitutePred x t p
| otherwise = p
(,graph) <$> linearize preds' graph

prettyPlan :: HasSpec fn a => Specification fn a -> Doc ann
Expand Down Expand Up @@ -3775,6 +3777,31 @@ composeFn ::
fn '[a] c
composeFn f g = injectFn $ Compose f g

flip_ ::
forall fn a b c.
( Member (FunFn fn) fn
, Typeable a
, Typeable b
, HasSpec fn a
, HasSpec fn b
, HasSpec fn c
) =>
(Term fn a -> Term fn b -> Term fn c) ->
Term fn b ->
Term fn a ->
Term fn c
flip_ f =
app (injectFn @(FunFn fn) @fn (Flip f'))
where
x = Var (-1) :: Var a
y = Var (-2) :: Var b
f' = case f (V x) (V y) of
App fn (V x' :> V y' :> Nil)
| Just Refl <- eqVar x x'
, Just Refl <- eqVar y y' ->
fn
_ -> error "Malformed function in flip_"

data FunFn fn args res where
Id :: FunFn fn '[a] a
Compose ::
Expand All @@ -3788,6 +3815,14 @@ data FunFn fn args res where
fn '[b] c ->
fn '[a] b ->
FunFn fn '[a] c
Flip ::
( Show (fn '[a, b] c)
, Eq (fn '[a, b] c)
, HasSpec fn a
, HasSpec fn b
) =>
fn '[a, b] c ->
FunFn fn '[b, a] c

deriving instance Show (FunFn fn args res)

Expand All @@ -3797,17 +3832,26 @@ instance Typeable fn => Eq (FunFn fn args res) where
Compose {} == _ = False
Id == Id = True
Id == _ = False
Flip (f :: fn '[a, b] c) == Flip (g :: fn '[a', b'] c')
| Just Refl <- eqT @a @a'
, Just Refl <- eqT @b @b' =
f == g
Flip {} == _ = False

instance FunctionLike fn => FunctionLike (FunFn fn) where
sem = \case
Id -> id
Compose f g -> sem f . sem g
Flip f -> flip (sem f)

instance (BaseUniverse fn, Member (FunFn fn) fn) => Functions (FunFn fn) fn where
propagateSpecFun _ _ (ErrorSpec err) = ErrorSpec err
propagateSpecFun fn ctx spec = case fn of
Id | NilCtx HOLE <- ctx -> spec
Compose f g | NilCtx HOLE <- ctx -> propagateSpecFun g (NilCtx HOLE) $ propagateSpecFun f (NilCtx HOLE) spec
Flip f
| HOLE :? v :> Nil <- ctx -> propagateSpecFun f (v :! NilCtx HOLE) spec
| v :! NilCtx HOLE <- ctx -> propagateSpecFun f (HOLE :? v :> Nil) spec

-- NOTE: this function over-approximates and returns a liberal spec.
mapTypeSpec f ts = case f of
Expand All @@ -3816,6 +3860,11 @@ instance (BaseUniverse fn, Member (FunFn fn) fn) => Functions (FunFn fn) fn wher

rewriteRules Id (x :> Nil) = Just x
rewriteRules (Compose f g) (x :> Nil) = Just $ app f (app g x)
-- TODO: this is a bit crippled by the fact that we forget any other rewrite
-- rules that we had for `f`. That's something we'll have to think about.
rewriteRules (Flip f) (a@Lit {} :> b :> Nil) = Just $ app f b a
rewriteRules (Flip f) (a :> b@Lit {} :> Nil) = Just $ app f b a
rewriteRules Flip {} _ = Nothing

-- Ord functions ----------------------------------------------------------

Expand Down Expand Up @@ -4220,7 +4269,7 @@ null_ xs = sizeOf_ xs ==. 0

-- #####

infix 4 <=., <., ==., /=.
infix 4 <=., >=., >., <., ==., /=.

(<=.) ::
( Ord a
Expand All @@ -4231,6 +4280,15 @@ infix 4 <=., <., ==., /=.
Term fn Bool
(<=.) = app lessOrEqualFn

(>=.) ::
( Ord a
, OrdLike fn a
) =>
Term fn a ->
Term fn a ->
Term fn Bool
(>=.) = flip_ (<=.)

(<.) ::
( Ord a
, OrdLike fn a
Expand All @@ -4240,6 +4298,15 @@ infix 4 <=., <., ==., /=.
Term fn Bool
(<.) = app lessFn

(>.) ::
( Ord a
, OrdLike fn a
) =>
Term fn a ->
Term fn a ->
Term fn Bool
(>.) = flip_ (<.)

(==.) ::
HasSpec fn a =>
Term fn a ->
Expand Down
8 changes: 8 additions & 0 deletions libs/constrained-generators/src/Constrained/Examples/Basic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,11 @@ propBack = constrained' $ \x y ->
, x <. 20
, 8 <. y
]

propBack' :: Specification BaseFn (Int, Int)
propBack' = constrained' $ \x y ->
[ y ==. x - 10
, 20 >. x
, 8 >. y
, y >. x - 20
]
1 change: 1 addition & 0 deletions libs/constrained-generators/test/Constrained/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ tests nightly =
testSpecNoShrink "trueSpecUniform" trueSpecUniform
testSpec "ifElseMany" ifElseMany
testSpecNoShrink "propBack" propBack
testSpecNoShrink "propBack'" propBack'
testSpec "complexUnion" complexUnion
testSpec "unionBounded" unionBounded
numberyTests
Expand Down

0 comments on commit c05e040

Please sign in to comment.