Skip to content

Commit

Permalink
Merge pull request #1193 from langston-barrett/lb/llvm-casts
Browse files Browse the repository at this point in the history
crucible-llvm: Refactor and export override pipe-fitting code
  • Loading branch information
langston-barrett committed Mar 27, 2024
2 parents 5f5447d + 2a87559 commit 383ddb8
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 72 deletions.
1 change: 1 addition & 0 deletions crucible-llvm/crucible-llvm.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ library
Lang.Crucible.LLVM.Extension
Lang.Crucible.LLVM.Globals
Lang.Crucible.LLVM.Intrinsics
Lang.Crucible.LLVM.Intrinsics.Cast
Lang.Crucible.LLVM.Intrinsics.Libc
Lang.Crucible.LLVM.Intrinsics.LLVM
Lang.Crucible.LLVM.MalformedLLVMModule
Expand Down
120 changes: 120 additions & 0 deletions crucible-llvm/src/Lang/Crucible/LLVM/Intrinsics/Cast.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
-- |
-- Module : Lang.Crucible.LLVM.Intrinsics.Cast
-- Description : Cast between bitvectors and pointers in signatures
-- Copyright : (c) Galois, Inc 2024
-- License : BSD3
-- Maintainer : Langston Barrett <langston@galois.com>
-- Stability : provisional
--
-- The built-in overrides in "Lang.Crucible.LLVM.Intrinsics.Libc" and
-- "Lang.Crucible.LLVM.Intrinsics.LLVM" frequently take arguments of type
-- 'Lang.Crucible.Types.BVType', but at runtime everything is represented as an
-- 'Lang.Crucible.LLVM.MemModel.Pointer.LLVMPtr'. This module contains helpers
-- for \"casting\" between pointers and bitvectors.
------------------------------------------------------------------------

{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Lang.Crucible.LLVM.Intrinsics.Cast
( ValCastError
, printValCastError
, ArgCast(applyArgCast)
, ValCast(applyValCast)
, castLLVMArgs
, castLLVMRet
) where

import Control.Monad.IO.Class (liftIO)
import Control.Lens

import qualified Data.Parameterized.Context as Ctx
import Data.Parameterized.Some (Some(Some))
import Data.Parameterized.TraversableFC (fmapFC)

import Lang.Crucible.Backend
import Lang.Crucible.Simulator.OverrideSim
import Lang.Crucible.Simulator.RegMap
import Lang.Crucible.Types

import Lang.Crucible.LLVM.MemModel

data ValCastError
= -- | Mismatched number of arguments ('castLLVMArgs') or struct fields
-- ('castLLVMRet').
MismatchedShape
-- | Can\'t cast between these types
| ValCastError (Some TypeRepr) (Some TypeRepr)

-- | Turn a 'ValCastError' into a human-readable message (lines).
printValCastError :: ValCastError -> [String]
printValCastError =
\case
MismatchedShape -> ["argument shape mismatch"]
ValCastError (Some ret) (Some ret') ->
[ "Cannot cast types"
, "*** Source type: " ++ show ret
, "*** Target type: " ++ show ret'
]

-- | A function to (infallibly) cast between 'Ctx.Assignment's of 'RegEntry's.
newtype ArgCast p sym ext args args' =
ArgCast { applyArgCast :: (forall rtp l a.
Ctx.Assignment (RegEntry sym) args ->
OverrideSim p sym ext rtp l a (Ctx.Assignment (RegEntry sym) args')) }

-- | A function to (infallibly) cast a value of types @tp@ to @tp'@.
newtype ValCast p sym ext tp tp' =
ValCast { applyValCast :: (forall rtp l a.
RegValue sym tp ->
OverrideSim p sym ext rtp l a (RegValue sym tp')) }

-- | Attempt to construct a function to cast between 'Ctx.Assignment's of
-- 'RegEntry's.
castLLVMArgs :: forall p sym ext bak args args'.
IsSymBackend sym bak =>
bak ->
CtxRepr args' ->
CtxRepr args ->
Either ValCastError (ArgCast p sym ext args args')
castLLVMArgs _ Ctx.Empty Ctx.Empty =
Right (ArgCast (\_ -> return Ctx.Empty))
castLLVMArgs bak (rest' Ctx.:> tp') (rest Ctx.:> tp) =
do ValCast f <- castLLVMRet bak tp tp'
ArgCast fs <- castLLVMArgs bak rest' rest
Right (ArgCast
(\(xs Ctx.:> x) -> do
xs' <- fs xs
x' <- f (regValue x)
pure (xs' Ctx.:> RegEntry tp' x')))
castLLVMArgs _ _ _ = Left MismatchedShape

-- | Attempt to construct a function to cast values of type @ret@ to type
-- @ret'@.
castLLVMRet ::
IsSymBackend sym bak =>
bak ->
TypeRepr ret ->
TypeRepr ret' ->
Either ValCastError (ValCast p sym ext ret ret')
castLLVMRet bak (BVRepr w) (LLVMPointerRepr w')
| Just Refl <- testEquality w w'
= Right (ValCast (liftIO . llvmPointer_bv (backendGetSym bak)))
castLLVMRet bak (LLVMPointerRepr w) (BVRepr w')
| Just Refl <- testEquality w w'
= Right (ValCast (liftIO . projectLLVM_bv bak))
castLLVMRet bak (VectorRepr tp) (VectorRepr tp')
= do ValCast f <- castLLVMRet bak tp tp'
Right (ValCast (traverse f))
castLLVMRet bak (StructRepr ctx) (StructRepr ctx')
= do ArgCast tf <- castLLVMArgs bak ctx' ctx
Right (ValCast (\vals ->
let vals' = Ctx.zipWith (\tp (RV v) -> RegEntry tp v) ctx vals in
fmapFC (\x -> RV (regValue x)) <$> tf vals'))

castLLVMRet _bak ret ret'
| Just Refl <- testEquality ret ret'
= Right (ValCast return)
castLLVMRet _bak ret ret' = Left (ValCastError (Some ret) (Some ret'))
88 changes: 16 additions & 72 deletions crucible-llvm/src/Lang/Crucible/LLVM/Intrinsics/Common.hs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ import qualified System.Info as Info
import qualified ABI.Itanium as ABI
import qualified Data.Parameterized.Context as Ctx
import Data.Parameterized.Some (Some(..))
import Data.Parameterized.TraversableFC (fmapFC)

import Lang.Crucible.Backend
import Lang.Crucible.CFG.Common (GlobalVar)
Expand All @@ -78,6 +77,7 @@ import Lang.Crucible.LLVM.Eval (callStackFromMemVar)
import Lang.Crucible.LLVM.Globals (registerFunPtr)
import Lang.Crucible.LLVM.MemModel
import Lang.Crucible.LLVM.MemModel.CallStack (CallStack)
import qualified Lang.Crucible.LLVM.Intrinsics.Cast as Cast
import Lang.Crucible.LLVM.Translation.Monad
import Lang.Crucible.LLVM.Translation.Types

Expand Down Expand Up @@ -199,74 +199,6 @@ apply this special case to other override functions (e.g.,
------------------------------------------------------------------------
-- ** register_llvm_override

newtype ArgTransformer p sym ext args args' =
ArgTransformer { applyArgTransformer :: (forall rtp l a.
Ctx.Assignment (RegEntry sym) args ->
OverrideSim p sym ext rtp l a (Ctx.Assignment (RegEntry sym) args')) }

newtype ValTransformer p sym ext tp tp' =
ValTransformer { applyValTransformer :: (forall rtp l a.
RegValue sym tp ->
OverrideSim p sym ext rtp l a (RegValue sym tp')) }

transformLLVMArgs :: forall m p sym ext bak args args'.
(IsSymBackend sym bak, Monad m, HasLLVMAnn sym) =>
-- | This function name is only used in panic messages.
FunctionName ->
bak ->
CtxRepr args' ->
CtxRepr args ->
m (ArgTransformer p sym ext args args')
transformLLVMArgs _fnName _ Ctx.Empty Ctx.Empty =
return (ArgTransformer (\_ -> return Ctx.Empty))
transformLLVMArgs fnName bak (rest' Ctx.:> tp') (rest Ctx.:> tp) = do
return (ArgTransformer
(\(xs Ctx.:> x) ->
do (ValTransformer f) <- transformLLVMRet fnName bak tp tp'
(ArgTransformer fs) <- transformLLVMArgs fnName bak rest' rest
xs' <- fs xs
x' <- RegEntry tp' <$> f (regValue x)
pure (xs' Ctx.:> x')))
transformLLVMArgs fnName _ _ _ =
panic "Intrinsics.transformLLVMArgs"
[ "transformLLVMArgs: argument shape mismatch!"
, "in function: " ++ Text.unpack (functionName fnName)
]

transformLLVMRet ::
(IsSymBackend sym bak, Monad m, HasLLVMAnn sym) =>
-- | This function name is only used in panic messages.
FunctionName ->
bak ->
TypeRepr ret ->
TypeRepr ret' ->
m (ValTransformer p sym ext ret ret')
transformLLVMRet _fnName bak (BVRepr w) (LLVMPointerRepr w')
| Just Refl <- testEquality w w'
= return (ValTransformer (liftIO . llvmPointer_bv (backendGetSym bak)))
transformLLVMRet _fnName bak (LLVMPointerRepr w) (BVRepr w')
| Just Refl <- testEquality w w'
= return (ValTransformer (liftIO . projectLLVM_bv bak))
transformLLVMRet fnName bak (VectorRepr tp) (VectorRepr tp')
= do ValTransformer f <- transformLLVMRet fnName bak tp tp'
return (ValTransformer (traverse f))
transformLLVMRet fnName bak (StructRepr ctx) (StructRepr ctx')
= do ArgTransformer tf <- transformLLVMArgs fnName bak ctx' ctx
return (ValTransformer (\vals ->
let vals' = Ctx.zipWith (\tp (RV v) -> RegEntry tp v) ctx vals in
fmapFC (\x -> RV (regValue x)) <$> tf vals'))

transformLLVMRet _fnName _bak ret ret'
| Just Refl <- testEquality ret ret'
= return (ValTransformer return)
transformLLVMRet fnName _bak ret ret'
= panic "Intrinsics.transformLLVMRet"
[ "Cannot transform types"
, "*** Source type: " ++ show ret
, "*** Target type: " ++ show ret'
, "in function: " ++ Text.unpack (functionName fnName)
]

-- | Do some pipe-fitting to match a Crucible override function into the shape
-- expected by the LLVM calling convention. This basically just coerces
-- between values of @BVType w@ and values of @LLVMPointerType w@.
Expand All @@ -283,11 +215,23 @@ build_llvm_override ::
OverrideSim p sym ext rtp l a (Override p sym ext args' ret')
build_llvm_override fnm args ret args' ret' llvmOverride =
ovrWithBackend $ \bak ->
do fargs <- transformLLVMArgs fnm bak args args'
fret <- transformLLVMRet fnm bak ret ret'
do fargs <-
case Cast.castLLVMArgs bak args args' of
Left err ->
panic "Intrinsics.build_llvm_override"
(Cast.printValCastError err ++
[ "in function: " ++ Text.unpack (functionName fnm) ])
Right f -> pure f
fret <-
case Cast.castLLVMRet bak ret ret' of
Left err ->
panic "Intrinsics.build_llvm_override"
(Cast.printValCastError err ++
[ "in function: " ++ Text.unpack (functionName fnm) ])
Right f -> pure f
return $ mkOverride' fnm ret' $
do RegMap xs <- getOverrideArgs
applyValTransformer fret =<< llvmOverride =<< applyArgTransformer fargs xs
Cast.applyValCast fret =<< llvmOverride =<< Cast.applyArgCast fargs xs

polymorphic1_llvm_override :: forall p sym arch wptr l a rtp.
(IsSymInterface sym, HasLLVMAnn sym, HasPtrWidth wptr) =>
Expand Down

0 comments on commit 383ddb8

Please sign in to comment.