Skip to content

Commit

Permalink
crucible-llvm: Clean up and export override pipe-fitting code
Browse files Browse the repository at this point in the history
Make it not panic by default, to make it usable by other clients who may
not want to panic. Also, clear up the phase boundary to handle errors
(e.g., in the case for structs).
  • Loading branch information
langston-barrett committed Mar 27, 2024
1 parent 602e33a commit ac879a3
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 47 deletions.
96 changes: 51 additions & 45 deletions crucible-llvm/src/Lang/Crucible/LLVM/Intrinsics/Cast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,27 @@
------------------------------------------------------------------------

{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Lang.Crucible.LLVM.Intrinsics.Cast
( ArgTransformer(applyArgTransformer)
( ValCastError
, printValCastError
, ArgTransformer(applyArgTransformer)
, ValTransformer(applyValTransformer)
, transformLLVMArgs
, transformLLVMRet
) where

import Control.Monad.IO.Class (liftIO)
import Control.Lens
import Data.Either.Extra (mapLeft)
import qualified Data.Text as Text
import Data.Vector (Vector)

import qualified Data.Parameterized.Context as Ctx
import Data.Parameterized.Some (Some(Some))
import Data.Parameterized.TraversableFC (fmapFC)

import Lang.Crucible.Backend
Expand All @@ -34,6 +41,20 @@ import What4.FunctionName

import Lang.Crucible.LLVM.MemModel

data ValCastError
= MismatchedShape
| ValCastError (Some TypeRepr) (Some TypeRepr)

printValCastError :: ValCastError -> [String]
printValCastError =
\case
MismatchedShape -> ["argument shape mismatch"]
ValCastError (Some ret) (Some ret') ->
[ "Cannot transform types"
, "*** Source type: " ++ show ret
, "*** Target type: " ++ show ret'
]

newtype ArgTransformer p sym ext args args' =
ArgTransformer { applyArgTransformer :: (forall rtp l a.
Ctx.Assignment (RegEntry sym) args ->
Expand All @@ -44,61 +65,46 @@ newtype ValTransformer p sym ext tp tp' =
RegValue sym tp ->
OverrideSim p sym ext rtp l a (RegValue sym tp')) }

transformLLVMArgs :: forall m p sym ext bak args args'.
(IsSymBackend sym bak, Monad m, HasLLVMAnn sym) =>
-- | This function name is only used in panic messages.
FunctionName ->
transformLLVMArgs :: forall p sym ext bak args args'.
(IsSymBackend sym bak, HasLLVMAnn sym) =>
bak ->
CtxRepr args' ->
CtxRepr args ->
m (ArgTransformer p sym ext args args')
transformLLVMArgs _fnName _ Ctx.Empty Ctx.Empty =
return (ArgTransformer (\_ -> return Ctx.Empty))
transformLLVMArgs fnName bak (rest' Ctx.:> tp') (rest Ctx.:> tp) = do
return (ArgTransformer
(\(xs Ctx.:> x) ->
do (ValTransformer f) <- transformLLVMRet fnName bak tp tp'
(ArgTransformer fs) <- transformLLVMArgs fnName bak rest' rest
xs' <- fs xs
x' <- RegEntry tp' <$> f (regValue x)
pure (xs' Ctx.:> x')))
transformLLVMArgs fnName _ _ _ =
panic "Intrinsics.transformLLVMArgs"
[ "transformLLVMArgs: argument shape mismatch!"
, "in function: " ++ Text.unpack (functionName fnName)
]
Either ValCastError (ArgTransformer p sym ext args args')
transformLLVMArgs _ Ctx.Empty Ctx.Empty =
Right (ArgTransformer (\_ -> return Ctx.Empty))
transformLLVMArgs bak (rest' Ctx.:> tp') (rest Ctx.:> tp) =
do (ValTransformer f) <- transformLLVMRet bak tp tp'
(ArgTransformer fs) <- transformLLVMArgs bak rest' rest
Right (ArgTransformer
(\(xs Ctx.:> x) -> do
xs' <- fs xs
x' <- f (regValue x)
pure (xs' Ctx.:> RegEntry tp' x')))
transformLLVMArgs _ _ _ = Left MismatchedShape

transformLLVMRet ::
(IsSymBackend sym bak, Monad m, HasLLVMAnn sym) =>
-- | This function name is only used in panic messages.
FunctionName ->
(IsSymBackend sym bak, HasLLVMAnn sym) =>
bak ->
TypeRepr ret ->
TypeRepr ret' ->
m (ValTransformer p sym ext ret ret')
transformLLVMRet _fnName bak (BVRepr w) (LLVMPointerRepr w')
Either ValCastError (ValTransformer p sym ext ret ret')
transformLLVMRet bak (BVRepr w) (LLVMPointerRepr w')
| Just Refl <- testEquality w w'
= return (ValTransformer (liftIO . llvmPointer_bv (backendGetSym bak)))
transformLLVMRet _fnName bak (LLVMPointerRepr w) (BVRepr w')
= Right (ValTransformer (liftIO . llvmPointer_bv (backendGetSym bak)))
transformLLVMRet bak (LLVMPointerRepr w) (BVRepr w')
| Just Refl <- testEquality w w'
= return (ValTransformer (liftIO . projectLLVM_bv bak))
transformLLVMRet fnName bak (VectorRepr tp) (VectorRepr tp')
= do ValTransformer f <- transformLLVMRet fnName bak tp tp'
return (ValTransformer (traverse f))
transformLLVMRet fnName bak (StructRepr ctx) (StructRepr ctx')
= do ArgTransformer tf <- transformLLVMArgs fnName bak ctx' ctx
return (ValTransformer (\vals ->
= Right (ValTransformer (liftIO . projectLLVM_bv bak))
transformLLVMRet bak (VectorRepr tp) (VectorRepr tp')
= do ValTransformer f <- transformLLVMRet bak tp tp'
Right (ValTransformer (traverse f))
transformLLVMRet bak (StructRepr ctx) (StructRepr ctx')
= do ArgTransformer tf <- transformLLVMArgs bak ctx' ctx
Right (ValTransformer (\vals ->
let vals' = Ctx.zipWith (\tp (RV v) -> RegEntry tp v) ctx vals in
fmapFC (\x -> RV (regValue x)) <$> tf vals'))

transformLLVMRet _fnName _bak ret ret'
transformLLVMRet _bak ret ret'
| Just Refl <- testEquality ret ret'
= return (ValTransformer return)
transformLLVMRet fnName _bak ret ret'
= panic "Intrinsics.transformLLVMRet"
[ "Cannot transform types"
, "*** Source type: " ++ show ret
, "*** Target type: " ++ show ret'
, "in function: " ++ Text.unpack (functionName fnName)
]

= Right (ValTransformer return)
transformLLVMRet _bak ret ret' = Left (ValCastError (Some ret) (Some ret'))
17 changes: 15 additions & 2 deletions crucible-llvm/src/Lang/Crucible/LLVM/Intrinsics/Common.hs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ import Lang.Crucible.Backend
import Lang.Crucible.CFG.Common (GlobalVar)
import Lang.Crucible.Simulator.ExecutionTree (FnState(UseOverride))
import Lang.Crucible.FunctionHandle (FnHandle, mkHandle')
import Lang.Crucible.Panic (panic)
import Lang.Crucible.Simulator (stateContext, simHandleAllocator)
import Lang.Crucible.Simulator.OverrideSim
import Lang.Crucible.Utils.MonadVerbosity (getLogFunction)
Expand Down Expand Up @@ -214,8 +215,20 @@ build_llvm_override ::
OverrideSim p sym ext rtp l a (Override p sym ext args' ret')
build_llvm_override fnm args ret args' ret' llvmOverride =
ovrWithBackend $ \bak ->
do fargs <- Cast.transformLLVMArgs fnm bak args args'
fret <- Cast.transformLLVMRet fnm bak ret ret'
do fargs <-
case Cast.transformLLVMArgs bak args args' of
Left err ->
panic "Intrinsics.build_llvm_override"
(Cast.printValCastError err ++
[ "in function: " ++ Text.unpack (functionName fnm) ])
Right f -> pure f
fret <-
case Cast.transformLLVMRet bak ret ret' of
Left err ->
panic "Intrinsics.build_llvm_override"
(Cast.printValCastError err ++
[ "in function: " ++ Text.unpack (functionName fnm) ])
Right f -> pure f
return $ mkOverride' fnm ret' $
do RegMap xs <- getOverrideArgs
Cast.applyValTransformer fret =<< llvmOverride =<< Cast.applyArgTransformer fargs xs
Expand Down

0 comments on commit ac879a3

Please sign in to comment.