Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Guard mux-pushing simplifications behind option #256

Merged
merged 5 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions what4/CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

* Add support for the `bitwuzla` SMT solver.

* Add `pushMuxOps` and `pushMuxOpsOption`. If this option is enabled, What4 will
push certain `ExprBuilder` operations (e.g., `zext`) down to the branches of
`ite` expressions. In some (but not all) circumstances, this can result in
operations that are easier for SMT solvers to reason about.

# 1.5.1 (October 2023)

* Require building with `versions >= 6.0.2`.
Expand Down
1 change: 1 addition & 0 deletions what4/src/What4/Expr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ module What4.Expr
, curProgramLoc
, unaryThreshold
, cacheStartSize
, pushMuxOps
, exprBuilderSplitConfig
, exprBuilderFreshConfig
, EmptyExprBuilderState(..)
Expand Down
200 changes: 137 additions & 63 deletions what4/src/What4/Expr/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ module What4.Expr.Builder
, sbNonceExpr
, curProgramLoc
, unaryThreshold
, pushMuxOps
, cacheStartSize
, userState
, exprCounter
Expand All @@ -76,6 +77,7 @@ module What4.Expr.Builder
-- * configuration options
, unaryThresholdOption
, cacheStartSizeOption
, pushMuxOpsOption
, cacheTerms

-- * Expr
Expand Down Expand Up @@ -374,6 +376,12 @@ data ExprBuilder t (st :: Type -> Type) (fs :: Type)
-- | The starting size when building a new cache
, sbCacheStartSize :: !(CFG.OptionSetting BaseIntegerType)

-- | If enabled, push certain 'ExprBuilder' operations (e.g., @zext@)
-- down to the branches of @ite@ expressions. In some (but not all)
-- circumstances, this can result in operations that are easier for
-- SMT solvers to reason about.
, sbPushMuxOps :: !(CFG.OptionSetting BaseBoolType)

-- | Counter to generate new unique identifiers for elements and functions.
, sbExprCounter :: !(NonceGenerator IO t)

Expand Down Expand Up @@ -421,6 +429,9 @@ unaryThreshold = to sbUnaryThreshold
cacheStartSize :: Getter (ExprBuilder t st fs) (CFG.OptionSetting BaseIntegerType)
cacheStartSize = to sbCacheStartSize

pushMuxOps :: Getter (ExprBuilder t st fs) (CFG.OptionSetting BaseBoolType)
pushMuxOps = to sbPushMuxOps

-- | Return a new expr builder where the configuration object has
-- been "split" using the @splitConfig@ operation.
-- The returned sym will share any preexisting options with the
Expand Down Expand Up @@ -456,16 +467,19 @@ exprBuilderFreshConfig sym =
cfg <- CFG.initialConfig 0
[ unaryThresholdDesc
, cacheStartSizeDesc
, pushMuxOpsDesc
]
unarySetting <- CFG.getOptionSetting unaryThresholdOption cfg
cacheStartSetting <- CFG.getOptionSetting cacheStartSizeOption cfg
pushMuxOpsSetting <- CFG.getOptionSetting pushMuxOpsOption cfg
CFG.extendConfig [cacheOptDesc gen storage_ref cacheStartSetting] cfg
nonLinearOps <- newIORef 0

return sym { sbConfiguration = cfg
, sbFloatReduce = True
, sbUnaryThreshold = unarySetting
, sbCacheStartSize = cacheStartSetting
, sbPushMuxOps = pushMuxOpsSetting
, sbProgramLoc = loc_ref
, sbCurAllocator = storage_ref
, sbNonLinearOps = nonLinearOps
Expand Down Expand Up @@ -650,7 +664,28 @@ unaryThresholdDesc = CFG.mkOpt unaryThresholdOption sty help (Just (ConcreteInte
where sty = CFG.integerWithMinOptSty (CFG.Inclusive 0)
help = Just "Maximum number of values in unary bitvector encoding."


------------------------------------------------------------------------
-- Configuration option for controlling whether to push certain ExprBuilder
-- operations (e.g., @zext@) down to the branches of @ite@ expressions.

-- | If this option enabled, push certain 'ExprBuilder' operations (e.g.,
-- @zext@) down to the branches of @ite@ expressions. In some (but not all)
-- circumstances, this can result in operations that are easier for SMT solvers
-- to reason about. The expressions that may be pushed down are determined on a
-- case-by-case basis in the 'IsExprBuilder' instance for 'ExprBuilder', but
-- this control applies to all such push-down checks.
--
-- This option is named \"backend.push_mux_ops\".
pushMuxOpsOption :: CFG.ConfigOption BaseBoolType
pushMuxOpsOption = CFG.configOption BaseBoolRepr "backend.push_mux_ops"

-- | The 'CFG.ConfigDesc' for 'pushMuxOpsOption'.
pushMuxOpsDesc :: CFG.ConfigDesc
pushMuxOpsDesc = CFG.mkOpt pushMuxOpsOption sty help (Just (ConcreteBool False))
where sty = CFG.boolOptSty
help = Just $
"If this option enabled, push certain ExprBuilder operations " <>
"(e.g., zext) down to the branches of ite expressions."

newExprBuilder ::
FloatModeRepr fm
Expand Down Expand Up @@ -678,9 +713,11 @@ newExprBuilder floatMode st gen = do
cfg <- CFG.initialConfig 0
[ unaryThresholdDesc
, cacheStartSizeDesc
, pushMuxOpsDesc
]
unarySetting <- CFG.getOptionSetting unaryThresholdOption cfg
cacheStartSetting <- CFG.getOptionSetting cacheStartSizeOption cfg
pushMuxOpsSetting <- CFG.getOptionSetting pushMuxOpsOption cfg
CFG.extendConfig [cacheOptDesc gen storage_ref cacheStartSetting] cfg
nonLinearOps <- newIORef 0

Expand All @@ -691,6 +728,7 @@ newExprBuilder floatMode st gen = do
, sbFloatReduce = True
, sbUnaryThreshold = unarySetting
, sbCacheStartSize = cacheStartSetting
, sbPushMuxOps = pushMuxOpsSetting
, sbProgramLoc = loc_ref
, sbExprCounter = gen
, sbCurAllocator = storage_ref
Expand Down Expand Up @@ -2913,17 +2951,20 @@ instance IsExprBuilder (ExprBuilder t st fs) where
Just LeqProof <- return $ isPosNat w
bvUnary sym $ UnaryBV.uext u w

| Just (BaseIte _ _ c a b) <- asApp x
, Just a_bv <- asBV a
, Just b_bv <- asBV b = do
Just LeqProof <- return $ isPosNat w
a' <- bvLit sym w $ BV.zext w a_bv
b' <- bvLit sym w $ BV.zext w b_bv
bvIte sym c a' b'

| otherwise = do
Just LeqProof <- return $ testLeq (knownNat :: NatRepr 1) w
sbMakeExpr sym $ BVZext w x
pmo <- CFG.getOpt (sbPushMuxOps sym)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick check: looks like you had to do this in 6-8 other locations as well. Do you think it's reasonable to float this getOpt to the top before the primary case on expr so that the case pattern match branches can be largely as they were with just the pmo check? I'm fine either way (based on the assumption that getOpt is inexpensive), but just wanted to call this out for you to make a decision on.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand you correctly, you are suggesting to call getOpt at the beginning of each method's implementation to minimize the diff to the actual logic of the code itself? If so, I'm not sure that that would actually have the desired effect. In particular, each method of this IsExprBuilder instance is implemented using pattern guards, so in order to call getOpt before the pattern guards, we'd likely need to do something like this:

  bvZext sym w x = do
    pmo <- CFG.getOpt (sbPushMuxOps sym)
    if | Just xv <- asBV x = do ...
       | Just (BVZext _ y) <- asApp x = do ...
       |  ...

But using MultiWayIf (or a similar refactoring) here would cause the diff to be just as large (if not larger) than it is currently. As such, I opted to localize the getOpt calls to the parts of the code that actually make use of the option.

if | pmo
, Just (BaseIte _ _ c a b) <- asApp x
, Just a_bv <- asBV a
, Just b_bv <- asBV b -> do
Just LeqProof <- return $ isPosNat w
a' <- bvLit sym w $ BV.zext w a_bv
b' <- bvLit sym w $ BV.zext w b_bv
bvIte sym c a' b'

| otherwise -> do
Just LeqProof <- return $ testLeq (knownNat :: NatRepr 1) w
sbMakeExpr sym $ BVZext w x

bvSext sym w x
| Just xv <- asBV x = do
Expand All @@ -2944,17 +2985,20 @@ instance IsExprBuilder (ExprBuilder t st fs) where
Just LeqProof <- return $ isPosNat w
bvUnary sym $ UnaryBV.sext u w

| Just (BaseIte _ _ c a b) <- asApp x
, Just a_bv <- asBV a
, Just b_bv <- asBV b = do
Just LeqProof <- return $ isPosNat w
a' <- bvLit sym w $ BV.sext (bvWidth x) w a_bv
b' <- bvLit sym w $ BV.sext (bvWidth x) w b_bv
bvIte sym c a' b'

| otherwise = do
Just LeqProof <- return $ testLeq (knownNat :: NatRepr 1) w
sbMakeExpr sym (BVSext w x)
pmo <- CFG.getOpt (sbPushMuxOps sym)
if | pmo
, Just (BaseIte _ _ c a b) <- asApp x
, Just a_bv <- asBV a
, Just b_bv <- asBV b -> do
Just LeqProof <- return $ isPosNat w
a' <- bvLit sym w $ BV.sext (bvWidth x) w a_bv
b' <- bvLit sym w $ BV.sext (bvWidth x) w b_bv
bvIte sym c a' b'

| otherwise -> do
Just LeqProof <- return $ testLeq (knownNat :: NatRepr 1) w
sbMakeExpr sym (BVSext w x)

bvXorBits sym x y
| x == y = bvZero sym (bvWidth x) -- special case: x `xor` x = 0
Expand All @@ -2965,22 +3009,6 @@ instance IsExprBuilder (ExprBuilder t st fs) where
bvAndBits sym x y
| x == y = return x -- Special case: idempotency of and

| Just (BaseIte _ _ c a b) <- asApp x
, Just a_bv <- asBV a
, Just b_bv <- asBV b
, Just y_bv <- asBV y = do
a' <- bvLit sym (bvWidth x) $ BV.and a_bv y_bv
b' <- bvLit sym (bvWidth x) $ BV.and b_bv y_bv
bvIte sym c a' b'

| Just (BaseIte _ _ c a b) <- asApp y
, Just a_bv <- asBV a
, Just b_bv <- asBV b
, Just x_bv <- asBV x = do
a' <- bvLit sym (bvWidth x) $ BV.and x_bv a_bv
b' <- bvLit sym (bvWidth x) $ BV.and x_bv b_bv
bvIte sym c a' b'

| Just (BVOrBits _ bs) <- asApp x
, bvOrContains y bs
= return y -- absorption law
Expand All @@ -2990,25 +3018,47 @@ instance IsExprBuilder (ExprBuilder t st fs) where
= return x -- absorption law

| otherwise
= let sr = SR.SemiRingBVRepr SR.BVBitsRepr (bvWidth x)
in semiRingMul sym sr x y
= do pmo <- CFG.getOpt (sbPushMuxOps sym)
if | pmo
, Just (BaseIte _ _ c a b) <- asApp x
, Just a_bv <- asBV a
, Just b_bv <- asBV b
, Just y_bv <- asBV y -> do
a' <- bvLit sym (bvWidth x) $ BV.and a_bv y_bv
b' <- bvLit sym (bvWidth x) $ BV.and b_bv y_bv
bvIte sym c a' b'

| pmo
, Just (BaseIte _ _ c a b) <- asApp y
, Just a_bv <- asBV a
, Just b_bv <- asBV b
, Just x_bv <- asBV x -> do
a' <- bvLit sym (bvWidth x) $ BV.and x_bv a_bv
b' <- bvLit sym (bvWidth x) $ BV.and x_bv b_bv
bvIte sym c a' b'

| otherwise
-> let sr = SR.SemiRingBVRepr SR.BVBitsRepr (bvWidth x)
in semiRingMul sym sr x y

-- XOR by the all-1 constant of the bitwise semiring.
-- This is equivalant to negation
bvNotBits sym x
| Just xv <- asBV x
= bvLit sym (bvWidth x) $ xv `BV.xor` (BV.maxUnsigned (bvWidth x))

| Just (BaseIte _ _ c a b) <- asApp x
, Just a_bv <- asBV a
, Just b_bv <- asBV b = do
a' <- bvLit sym (bvWidth x) $ BV.complement (bvWidth x) a_bv
b' <- bvLit sym (bvWidth x) $ BV.complement (bvWidth x) b_bv
bvIte sym c a' b'

| otherwise
= let sr = (SR.SemiRingBVRepr SR.BVBitsRepr (bvWidth x))
in semiRingSum sym $ WSum.addConstant sr (asWeightedSum sr x) (BV.maxUnsigned (bvWidth x))
= do pmo <- CFG.getOpt (sbPushMuxOps sym)
if | pmo
, Just (BaseIte _ _ c a b) <- asApp x
, Just a_bv <- asBV a
, Just b_bv <- asBV b -> do
a' <- bvLit sym (bvWidth x) $ BV.complement (bvWidth x) a_bv
b' <- bvLit sym (bvWidth x) $ BV.complement (bvWidth x) b_bv
bvIte sym c a' b'
| otherwise ->
let sr = (SR.SemiRingBVRepr SR.BVBitsRepr (bvWidth x))
in semiRingSum sym $ WSum.addConstant sr (asWeightedSum sr x) (BV.maxUnsigned (bvWidth x))

bvOrBits sym x y =
case (asBV x, asBV y) of
Expand Down Expand Up @@ -3085,20 +3135,23 @@ instance IsExprBuilder (ExprBuilder t st fs) where

bvNeg sym x
| Just xv <- asBV x = bvLit sym (bvWidth x) (BV.negate (bvWidth x) xv)
| Just (BaseIte _ _ c a b) <- asApp x
, Just a_bv <- asBV a
, Just b_bv <- asBV b = do
a' <- bvLit sym (bvWidth x) $ BV.negate (bvWidth x) a_bv
b' <- bvLit sym (bvWidth x) $ BV.negate (bvWidth x) b_bv
bvIte sym c a' b'
| otherwise =
do ut <- CFG.getOpt (sbUnaryThreshold sym)
let ?unaryThreshold = fromInteger ut
sbTryUnaryTerm sym
(do ux <- asUnaryBV sym x
Just (UnaryBV.neg sym ux))
(do let sr = SR.SemiRingBVRepr SR.BVArithRepr (bvWidth x)
scalarMul sym sr (BV.mkBV (bvWidth x) (-1)) x)
| otherwise = do
pmo <- CFG.getOpt (sbPushMuxOps sym)
if | pmo
, Just (BaseIte _ _ c a b) <- asApp x
, Just a_bv <- asBV a
, Just b_bv <- asBV b -> do
a' <- bvLit sym (bvWidth x) $ BV.negate (bvWidth x) a_bv
b' <- bvLit sym (bvWidth x) $ BV.negate (bvWidth x) b_bv
bvIte sym c a' b'
| otherwise -> do
ut <- CFG.getOpt (sbUnaryThreshold sym)
let ?unaryThreshold = fromInteger ut
sbTryUnaryTerm sym
(do ux <- asUnaryBV sym x
Just (UnaryBV.neg sym ux))
(do let sr = SR.SemiRingBVRepr SR.BVArithRepr (bvWidth x)
scalarMul sym sr (BV.mkBV (bvWidth x) (-1)) x)

bvIsNonzero sym x
| Just (BaseIte _ _ p t f) <- asApp x
Expand Down Expand Up @@ -3391,7 +3444,28 @@ instance IsExprBuilder (ExprBuilder t st fs) where
else
sbMakeExpr sym $ ArrayMap idx_tps baseRepr new_map def_map

arrayIte sym p x y = mkIte sym p x y
arrayIte sym p x y = do
pmo <- CFG.getOpt (sbPushMuxOps sym)
if -- Extract all concrete updates out.
| not pmo
, ArrayMapView mx x' <- viewArrayMap x
, ArrayMapView my y' <- viewArrayMap y
, not (AUM.null mx) || not (AUM.null my) -> do
case exprType x of
BaseArrayRepr idxRepr bRepr -> do
let both_fn _ u v = baseTypeIte sym p u v
left_fn idx u = do
v <- sbConcreteLookup sym y' (Just idx) =<< symbolicIndices sym idx
both_fn idx u v
right_fn idx v = do
u <- sbConcreteLookup sym x' (Just idx) =<< symbolicIndices sym idx
both_fn idx u v
mz <- AUM.mergeM bRepr both_fn left_fn right_fn mx my
z' <- arrayIte sym p x' y'

sbMakeExpr sym $ ArrayMap idxRepr bRepr mz z'

| otherwise -> mkIte sym p x y

arrayEq sym x y
| x == y =
Expand Down
2 changes: 1 addition & 1 deletion what4/src/What4/Protocol/SMTLib2.hs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ import Control.Monad.Fail( MonadFail )

import Control.Applicative
import Control.Exception
import Control.Monad (forM, forM_, replicateM_, unless, when)
import Control.Monad (forM, replicateM_, unless, when)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Except (MonadError(..), ExceptT, runExceptT)
import Control.Monad.Reader (MonadReader(..), ReaderT(..), asks)
Expand Down
Loading
Loading