-
Notifications
You must be signed in to change notification settings - Fork 42
/
Constant.hs
1204 lines (1066 loc) · 41.5 KB
/
Constant.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
-----------------------------------------------------------------------
-- |
-- Module : Lang.Crucible.LLVM.Translation.Constant
-- Description : LLVM constant expression evaluation and GEPs
-- Copyright : (c) Galois, Inc 2014-2015
-- License : BSD3
-- Maintainer : Rob Dockins <rdockins@galois.com>
-- Stability : provisional
--
-- This module provides translation-time evaluation of constant
-- expressions. It also provides an intermediate representation
-- for GEP (getelementptr) instructions that makes more explicit
-- the places where vectorization may occur, as well as resolving type
-- sizes and field offsets.
--
-- See @liftConstant@ for how to turn these into expressions.
-----------------------------------------------------------------------
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ParallelListComp #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module Lang.Crucible.LLVM.Translation.Constant
( -- * Representation of LLVM constant values
LLVMConst(..)
, boolConst
, intConst
-- * Translations from LLVM syntax to constant values
, transConstant
, transConstantWithType
, transConstant'
, transConstantExpr
-- * Intermediate representation for GEP
, GEP(..)
, GEPResult(..)
, translateGEP
-- * Utility functions
, showInstr
, testBreakpointFunction
) where
import Control.Lens( to, (^.) )
import Control.Monad
import Control.Monad.Except
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.Bits
import Data.Kind
import Data.List (intercalate, isPrefixOf)
import Data.Traversable
import Data.Fixed (mod')
import qualified Data.Vector as V
import Numeric.Natural
import GHC.TypeNats
import qualified Text.LLVM.AST as L
import qualified Text.LLVM.PP as L
import qualified Data.BitVector.Sized as BV
import qualified Data.BitVector.Sized.Overflow as BV
import Data.Parameterized.NatRepr
import Data.Parameterized.Some
import Data.Parameterized.DecidableEq (decEq)
import Lang.Crucible.LLVM.Bytes
import Lang.Crucible.LLVM.DataLayout( intLayout, EndianForm(..) )
import Lang.Crucible.LLVM.MemModel.Pointer
import Lang.Crucible.LLVM.MemType
import qualified Lang.Crucible.LLVM.PrettyPrint as LPP
import Lang.Crucible.LLVM.Translation.Types
import Lang.Crucible.LLVM.TypeContext
-- | Render an LLVM instruction as a 'String' by pretty-printing it with
-- 'L.ppInstr' under 'L.ppLLVM38' and showing the resulting document.
showInstr :: L.Instr -> String
showInstr instr = show $ L.ppLLVM38 (L.ppInstr instr)
-- | Intermediate representation of a GEP.
--   A @GEP n expr@ is a representation of a GEP with
--   @n@ parallel vector lanes with expressions represented
--   by @expr@ values.
--
--   The structure is a chain: every non-base constructor wraps the
--   @GEP@ computed so far, so the base pointer(s) sit at the innermost
--   position and later indexing steps sit further out.
data GEP (n :: Nat) (expr :: Type) where
  -- | Start a GEP with a single base pointer
  GEP_scalar_base  :: expr -> GEP 1 expr

  -- | Start a GEP with a vector of @n@ base pointers
  GEP_vector_base  :: NatRepr n -> expr -> GEP n expr

  -- | Copy a scalar base pointer pointwise into a
  --   vector of length @n@.
  GEP_scatter      :: NatRepr n -> GEP 1 expr -> GEP n expr

  -- | Add the offset corresponding to the given field
  --   pointwise to each pointer
  GEP_field        :: FieldInfo -> GEP n expr -> GEP n expr

  -- | Add an offset corresponding to the given array index
  --   (multiplied by the given type size) pointwise to the pointers
  --   in each lane.
  GEP_index_each   :: MemType -> GEP n expr -> expr -> GEP n expr

  -- | Given a vector of offsets (whose length must match
  --   the number of lanes), multiply each one by the
  --   type size, and add the offsets to the corresponding
  --   pointers.
  GEP_index_vector :: MemType -> GEP n expr -> expr -> GEP n expr
-- | Mapping over the expressions of a 'GEP' is derived from 'traverse'.
instance Functor (GEP n) where
  fmap = fmapDefault

-- | Folding over the expressions of a 'GEP' is derived from 'traverse'.
instance Foldable (GEP n) where
  foldMap = foldMapDefault

-- | Visit every expression stored in a 'GEP'.  For the wrapping
--   constructors the inner (earlier) part of the chain is traversed
--   before the index expression attached to the current step.
instance Traversable (GEP n) where
  traverse f (GEP_scalar_base base) =
    GEP_scalar_base <$> f base
  traverse f (GEP_vector_base lanes base) =
    GEP_vector_base lanes <$> f base
  traverse f (GEP_scatter lanes inner) =
    GEP_scatter lanes <$> traverse f inner
  traverse f (GEP_field fld inner) =
    GEP_field fld <$> traverse f inner
  traverse f (GEP_index_each ty inner ix) =
    GEP_index_each ty <$> traverse f inner <*> f ix
  traverse f (GEP_index_vector ty inner ix) =
    GEP_index_vector ty <$> traverse f inner <*> f ix
-- | The result of a GEP instruction translation.  It records the number
--   of parallel vector lanes in the resulting instruction, the resulting
--   memory type of the instruction, and the sequence of sub-operations
--   required to compute the GEP instruction.
--
--   The lane count @n@ is existentially hidden; pattern matching on the
--   constructor brings the @1 <= n@ evidence back into scope.
data GEPResult expr where
  GEPResult :: (1 <= n) => NatRepr n -> MemType -> GEP n expr -> GEPResult expr
-- | Mapping over a 'GEPResult' is derived from 'traverse'.
instance Functor GEPResult where
  fmap = fmapDefault

-- | Folding over a 'GEPResult' is derived from 'traverse'.
instance Foldable GEPResult where
  foldMap = foldMapDefault

-- | Traverse the expressions of the wrapped 'GEP', leaving the lane
--   count and the result memory type untouched.
instance Traversable GEPResult where
  traverse f (GEPResult lanes resultTy steps) =
    fmap (GEPResult lanes resultTy) (traverse f steps)
-- | Given the data for an LLVM getelementptr instruction,
-- preprocess the instruction into a @GEPResult@, checking
-- types, computing vectorization lanes, etc.
--
-- As a concrete example, consider a call to
-- @'translateGEP' inbounds baseTy basePtr elts@ with the following
-- instruction:
--
-- @
-- getelementptr [12 x i8], ptr %aptr, i64 0, i32 1
-- @
--
-- Here:
--
-- * @inbounds@ is 'False', as the keyword of the same name is missing from
--   the instruction. (Currently, @crucible-llvm@ ignores this information.)
--
-- * @baseTy@ is @[12 x i8]@. This is the type used as the basis for
--   subsequent calculations.
--
-- * @basePtr@ is @ptr %aptr@. This pointer is used as the base address to
--   start calculations from. Note that the type of @basePtr@ is /not/
--   @baseTy@, but rather a pointer type.
--
-- * The @elts@ are @[i64 0, i32 1]@. These are the indices that indicate
--   which of the elements of the aggregate object are indexed.
--
-- Fails (via 'throwError') when the index list is empty, when the base
-- value is neither a pointer nor a vector of pointers, or when an index
-- does not fit the type being indexed.
translateGEP :: forall wptr m.
  (?lc :: TypeContext, MonadError String m, HasPtrWidth wptr) =>
  Bool              {- ^ inbounds flag -} ->
  L.Type            {- ^ base type for calculations -} ->
  L.Typed L.Value   {- ^ base pointer expression -} ->
  [L.Typed L.Value] {- ^ index arguments -} ->
  m (GEPResult (L.Typed L.Value))

translateGEP _ _ _ [] =
  throwError "getelementpointer must have at least one index"

translateGEP inbounds baseTy basePtr elts =
  do baseMemType <- liftMemType baseTy
     mt <- liftMemType (L.typedType basePtr)
     -- Input value to a GEP must have a pointer type (or be a vector of pointer
     -- types), and the base type used for calculations must be representable
     -- as a memory type. The resulting memory type drives the interpretation of
     -- the GEP arguments.
     case mt of
       -- Vector base case, with as many lanes as there are input pointers.
       -- The base type is wrapped in a zero-length array type so that the
       -- first index is handled by the array case of @go@, which steps
       -- over elements of @baseMemType@.
       VecType n vmt
         | isPointerMemType vmt
         , Some lanes <- mkNatRepr n
         , Just LeqProof <- isPosNat lanes
         -> let mt' = ArrayType 0 baseMemType in
            go lanes mt' (GEP_vector_base lanes basePtr) elts

       -- Scalar base case with exactly 1 lane.  This alternative is also
       -- reached by vector types that failed the guards above; those are
       -- rejected by the @otherwise@ branch.
       _ | isPointerMemType mt
         -> let mt' = ArrayType 0 baseMemType in
            go (knownNat @1) mt' (GEP_scalar_base basePtr) elts

         | otherwise
         -> badGEP

  where
    -- Uniform failure: report the whole instruction being translated.
    badGEP :: m a
    badGEP = throwError $ unlines [ "Invalid GEP", showInstr (L.GEP inbounds baseTy basePtr elts) ]

    -- This auxiliary function builds up the intermediate GEP mini-instructions that compute
    -- the overall GEP, as well as the resulting memory type of the final pointers and the
    -- number of vector lanes eventually computed by the GEP.
    go ::
      (1 <= lanes) =>
      NatRepr lanes               {- Number of lanes of the GEP so far -} ->
      MemType                     {- Memory type of the incoming pointer(s) -} ->
      GEP lanes (L.Typed L.Value) {- partial GEP computation -} ->
      [L.Typed L.Value]           {- remaining arguments to process -} ->
      m (GEPResult (L.Typed L.Value))

    -- Final step, all arguments are used up, return the GEPResult
    go lanes mt gep [] = return (GEPResult lanes mt gep)

    -- Resolve one offset value and recurse
    go lanes mt gep (off:xs) =
      do offt <- liftMemType (L.typedType off)
         -- The meaning of the offset depends on the static type of the intermediate result
         case mt of
           ArrayType _ mt' -> goArray lanes off offt mt' gep xs
           VecType   _ mt' -> goArray lanes off offt mt' gep xs
           StructType si   -> goStruct lanes off offt si gep xs
           _ -> badGEP

    -- If it is an array type, the offset should be considered an array index, or
    -- vector of array indices.
    goArray ::
      (1 <= lanes) =>
      NatRepr lanes               {- Number of lanes of the GEP so far -} ->
      L.Typed L.Value             {- Current index value -} ->
      MemType                     {- MemType of the index value -} ->
      MemType                     {- MemType of the incoming pointer(s) -} ->
      GEP lanes (L.Typed L.Value) {- partial GEP computation -} ->
      [L.Typed L.Value]           {- remaining arguments to process -} ->
      m (GEPResult (L.Typed L.Value))
    goArray lanes off offt mt' gep xs =
        case offt of
          -- Single array index, apply pointwise to all intermediate pointers
          IntType _
            -> go lanes mt' (GEP_index_each mt' gep off) xs

          -- Vector of indices, matching the current number of lanes, apply
          -- each offset to the corresponding base pointer
          VecType n (IntType _)
            | natValue lanes == n
            -> go lanes mt' (GEP_index_vector mt' gep off) xs

          -- Vector of indices, with a single incoming base pointer.  Scatter
          -- the base pointer across the correct number of lanes, and then
          -- apply the vector of offsets componentwise.
          VecType n (IntType _)
            | Some n' <- mkNatRepr n
            , Just LeqProof <- isPosNat n'
            , Just Refl <- testEquality lanes (knownNat @1)
            -> go n' mt' (GEP_index_vector mt' (GEP_scatter n' gep) off) xs

          -- Otherwise, some sort of mismatch occurred.
          _ -> badGEP

    -- If it is a structure type, the index must be a constant value that indicates
    -- which field (counting from 0) is to be indexed.
    goStruct ::
      (1 <= lanes) =>
      NatRepr lanes               {- Number of lanes of the GEP so far -} ->
      L.Typed L.Value             {- Field index number -} ->
      MemType                     {- MemType of the field index -} ->
      StructInfo                  {- Struct layout information -} ->
      GEP lanes (L.Typed L.Value) {- partial GEP computation -} ->
      [L.Typed L.Value]           {- remaining arguments to process -} ->
      m (GEPResult (L.Typed L.Value))
    goStruct lanes off offt si gep xs =
        do off' <- transConstant' offt (L.typedValue off)
           case off' of
             -- Special case for the zero value
             ZeroConst (IntType _) -> goidx 0

             -- Single index; compute the corresponding field.
             IntConst _ idx -> goidx (BV.asUnsigned idx)

             -- Special case.  A vector of indices is allowed, but each
             -- (constant) index must be the same value.
             -- NOTE(review): the vector length is not compared against
             -- @lanes@ here -- confirm whether that check is intended.
             VectorConst (IntType _) (i@(IntConst _ idx) : is) | all (same i) is -> goidx (BV.asUnsigned idx)
               where
               -- Two index constants agree when width and value agree.
               same :: LLVMConst -> LLVMConst -> Bool
               same (IntConst wx x) (IntConst wy y)
                 | Just Refl <- testEquality wx wy = x == y
               same _ _ = False

             -- Otherwise, invalid GEP instruction
             _ -> badGEP

            -- using the information from the struct type, figure out which
            -- field is indicated
      where goidx idx | 0 <= idx && idx < toInteger (V.length flds) =
                            go lanes (fiType fi) (GEP_field fi gep) xs
                 where flds = siFields si
                       fi   = flds V.! (fromInteger idx)

            -- Index outside the struct's field range.
            goidx _ = badGEP
-- | Translation-time LLVM constant values.
data LLVMConst where
  -- | A constant value consisting of all zero bits.
  ZeroConst     :: !MemType -> LLVMConst
  -- | A constant integer value, with bit-width @w@.
  IntConst      :: (1 <= w) => !(NatRepr w) -> !(BV.BV w) -> LLVMConst
  -- | A constant floating point value.
  FloatConst    :: !Float -> LLVMConst
  -- | A constant double value.
  DoubleConst   :: !Double -> LLVMConst
  -- | A constant long double value (X86_FP80)
  LongDoubleConst :: !L.FP80Value -> LLVMConst
  -- | A constant sequence of bytes
  StringConst   :: !ByteString -> LLVMConst
  -- | A constant array value.  The 'MemType' is the element type.
  ArrayConst    :: !MemType -> [LLVMConst] -> LLVMConst
  -- | A constant vector value.  The 'MemType' is the element type.
  VectorConst   :: !MemType -> [LLVMConst] -> LLVMConst
  -- | A constant structure value.
  StructConst   :: !StructInfo -> [LLVMConst] -> LLVMConst
  -- | A pointer value, consisting of a concrete offset from a global symbol.
  SymbolConst   :: !L.Symbol -> !Integer -> LLVMConst
  -- | The @undef@ value is quite strange. See: The LLVM Language Reference,
  -- § Undefined Values.
  UndefConst    :: !MemType -> LLVMConst
-- | This also can't be derived, but is completely uninteresting.
--
-- Each constant is rendered as its constructor name followed by the
-- 'show' of every field, separated by single spaces.
instance Show LLVMConst where
  show lc = intercalate " " $
    case lc of
      (ZeroConst mem)        -> ["ZeroConst", show mem]
      (IntConst w x)         -> ["IntConst", show w, show x]
      (FloatConst f)         -> ["FloatConst", show f]
      (DoubleConst d)        -> ["DoubleConst", show d]
      -- Show the FP80 payload, not the enclosing constant: showing the
      -- constant itself from inside its own 'show' never terminates.
      (LongDoubleConst ld)   -> ["LongDoubleConst", show ld]
      (ArrayConst mem a)     -> ["ArrayConst", show mem, show a]
      (VectorConst mem v)    -> ["VectorConst", show mem, show v]
      (StructConst si a)     -> ["StructConst", show si, show a]
      (SymbolConst s x)      -> ["SymbolConst", show s, show x]
      (UndefConst mem)       -> ["UndefConst", show mem]
      (StringConst bs)       -> ["StringConst", show bs]
-- | The interesting cases here are:
--  * @IntConst@: GHC can't derive this because @IntConst@ existentially
--    quantifies the integer's width. We say that two integers are equal when
--    they have the same width *and* the same value.
--  * @UndefConst@: Two @undef@ values aren't necessarily the same, so an
--    @undef@ never compares equal to anything, itself included.
--
-- Constants built with different constructors are never equal; in
-- particular a @StringConst@ is not equal to an @ArrayConst@ holding the
-- same bytes.
instance Eq LLVMConst where
  (ZeroConst mem1)      == (ZeroConst mem2)      = mem1 == mem2
  (IntConst w1 x1)      == (IntConst w2 x2)      =
    case decEq w1 w2 of
      Left Refl -> x1 == x2
      Right _   -> False
  (FloatConst f1)       == (FloatConst f2)       = f1 == f2
  (DoubleConst d1)      == (DoubleConst d2)      = d1 == d2
  (LongDoubleConst ld1) == (LongDoubleConst ld2) = ld1 == ld2
  -- Byte-string constants are equal exactly when their bytes are equal.
  -- Without this equation they fell through to the catch-all below and
  -- a string constant was not even equal to itself.
  (StringConst bs1)     == (StringConst bs2)     = bs1 == bs2
  (ArrayConst mem1 a1)  == (ArrayConst mem2 a2)  = mem1 == mem2 && a1 == a2
  (VectorConst mem1 v1) == (VectorConst mem2 v2) = mem1 == mem2 && v1 == v2
  (StructConst si1 a1)  == (StructConst si2 a2)  = si1 == si2 && a1 == a2
  (SymbolConst s1 x1)   == (SymbolConst s2 x2)   = s1 == s2 && x1 == x2
  (UndefConst  _)       == (UndefConst _)        = False
  _                     == _                     = False
-- | Encode a boolean as a 1-bit integer constant: 'True' becomes the
--   bitvector one and 'False' the bitvector zero.
boolConst :: Bool -> LLVMConst
boolConst b = IntConst (knownNat @1) bit1
  where
    bit1 = if b then BV.one knownNat else BV.zero knownNat
-- | Build an integer constant of the requested bit width.
--
--   * The value @0@ is always represented as a 'ZeroConst' of the
--     corresponding integer type.
--   * Any other value is stored through 'BV.mkBV' at width @w@, i.e. as
--     the unsigned integer value @n mod 2^w@.
--   * A non-zero value at a width that is not positive is an error.
intConst ::
  MonadError String m =>
  Natural {- ^ width of the integer constant, @w@ -} ->
  Integer {- ^ value of the integer constant, @n@ -} ->
  m LLVMConst
intConst width value
  | value == 0
  = return (ZeroConst (IntType width))

  | Some w <- mkNatRepr width
  , Just LeqProof <- isPosNat w
  = return (IntConst w (BV.mkBV w value))

  | otherwise
  = throwError ("Invalid integer width: " ++ show width)
-- | Evaluate a typed LLVM value to a constant, also returning the
--   memory type its LLVM type was lifted to.  Fails if the type cannot
--   be lifted or the value does not denote a constant.
transConstantWithType ::
  (?lc :: TypeContext, MonadError String m, HasPtrWidth wptr) =>
  L.Typed L.Value ->
  m (MemType, LLVMConst)
transConstantWithType (L.Typed ty val) =
  liftMemType ty >>= \memTy ->
    (,) memTy <$> transConstant' memTy val
-- | Like 'transConstantWithType', but return only the constant and
--   drop the computed memory type.
transConstant ::
  (?lc :: TypeContext, MonadError String m, HasPtrWidth wptr) =>
  L.Typed L.Value ->
  m LLVMConst
transConstant typedVal =
  do (_, c) <- transConstantWithType typedVal
     return c
-- | Compute the constant value of an expression.  Fail if the
--   given value does not represent a constant.
--
--   The first argument is the memory type the value is expected to
--   have; the equations below are tried in order and the final one
--   reports every type/value combination that was not recognised.
transConstant' ::
  (?lc :: TypeContext, MonadError String m, HasPtrWidth wptr) =>
  MemType ->
  L.Value ->
  m LLVMConst
-- @undef@ is accepted at any type.
transConstant' tp (L.ValUndef) =
  return (UndefConst tp)
-- Integer literals (zero becomes a 'ZeroConst', see 'intConst').
transConstant' (IntType n) (L.ValInteger x) =
  intConst n x
-- Boolean literals are only accepted at type @i1@.
transConstant' (IntType 1) (L.ValBool b) =
  return . IntConst (knownNat @1) $ if b
                                    then (BV.one knownNat)
                                    else (BV.zero knownNat)
transConstant' FloatType (L.ValFloat f) =
  return (FloatConst f)
transConstant' DoubleType (L.ValDouble d) =
  return (DoubleConst d)
transConstant' X86_FP80Type (L.ValFP80 ld) =
  return (LongDoubleConst ld)
-- A global symbol used at pointer type is the symbol at offset 0.
transConstant' (PtrType _) (L.ValSymbol s) =
  return (SymbolConst s 0)
transConstant' PtrOpaqueType (L.ValSymbol s) =
  return (SymbolConst s 0)
-- @zeroinitializer@ is accepted at any type.
transConstant' tp L.ValZeroInit =
  return (ZeroConst tp)
-- Null pointers are zero constants of the pointer type.
transConstant' (PtrType stp) L.ValNull =
  return (ZeroConst (PtrType stp))
transConstant' PtrOpaqueType L.ValNull =
  return (ZeroConst PtrOpaqueType)
-- Vectors and arrays: the literal must have exactly as many elements as
-- the type says; every element is translated at the element type.
transConstant' (VecType n tp) (L.ValVector _tp xs)
  | n == fromIntegral (length xs)
  = VectorConst tp <$> traverse (transConstant' tp) xs
transConstant' (ArrayType n tp) (L.ValArray _tp xs)
  | n == fromIntegral (length xs)
  = ArrayConst tp <$> traverse (transConstant' tp) xs
-- Structs: packedness of the literal must agree with the struct type and
-- the field counts must match.  Fields are translated via
-- 'transConstant', i.e. using the type annotation carried by each field
-- value rather than the field types recorded in @si@.
transConstant' (StructType si) (L.ValStruct xs)
  | not (siIsPacked si)
  , V.length (siFields si) == length xs
  = StructConst si <$> traverse transConstant xs
transConstant' (StructType si) (L.ValPackedStruct xs)
  | siIsPacked si
  , V.length (siFields si) == length xs
  = StructConst si <$> traverse transConstant xs

-- String literals are accepted at type @[n x i8]@ of matching length.
transConstant' (ArrayType n tp) (L.ValString cs)
  | tp == IntType 8, n == fromIntegral (length cs)
  = return . StringConst $! BS.pack cs

-- Constant expressions are evaluated by 'transConstantExpr'; the
-- expected type is not consulted here.
transConstant' _ (L.ValConstExpr cexpr) = transConstantExpr cexpr

-- Anything else is not a constant we know how to compute.
transConstant' tp val =
  throwError $ unlines [ "Cannot compute constant value for expression: "
                       , "Type: "  ++ (show $ ppMemType tp)
                       , "Value: " ++ (show $ LPP.ppValue val)
                       ]
-- | Evaluate a GEP instruction to a constant value.
--
--   Returns the type of the result together with the result itself:
--   a single pointer constant when the GEP produced exactly one value,
--   otherwise a vector of pointer constants.  Fails if the number of
--   computed pointers disagrees with the lane count recorded in the
--   'GEPResult', if a base is not a symbol constant, or if an index is
--   not an integer constant.
evalConstGEP :: forall m wptr.
  (?lc :: TypeContext, MonadError String m, HasPtrWidth wptr) =>
  GEPResult LLVMConst ->
  m (MemType, LLVMConst)
evalConstGEP (GEPResult lanes finalMemType gep0) =
   do xs <- go gep0
      unless (fromIntegral (length xs) == natValue lanes)
             (throwError "Unexpected vector length in result of constant GEP")
      case xs of
        [x] -> return ( PtrType (MemType finalMemType), x)
        _   -> return ( VecType (fromIntegral (length xs)) (PtrType (MemType finalMemType))
                      , VectorConst (PtrType (MemType finalMemType)) xs
                      )

  where
  dl = llvmDataLayout ?lc

  -- Convert an index constant into a byte offset by multiplying it with
  -- the size of the indexed type.
  -- NOTE(review): the index is read with 'BV.asUnsigned', so an index
  -- whose top bit is set is treated as a large positive number and will
  -- normally trip the overflow check below -- confirm this is intended
  -- for negative GEP indices.
  asOffset :: MemType -> LLVMConst -> m Integer
  asOffset _ (ZeroConst (IntType _)) = return 0
  asOffset mt (IntConst _ x) =
    do let x' = BV.asUnsigned x * bytesToInteger (memTypeSize dl mt)
       unless (x' <= maxUnsigned ?ptrWidth)
              (throwError "Computed offset overflow in constant GEP")
       return x'
  asOffset ty val = throwError $ unlines $
              [ "Expected offset value in constant GEP"
              , "Type: " ++ show ty
              , "Offset: " ++ show val
              ]

  -- Add a byte offset to a pointer constant; only symbol-relative
  -- pointers are supported.
  addOffset :: Integer -> LLVMConst -> m LLVMConst
  addOffset x (SymbolConst sym off) = return (SymbolConst sym (off+x))
  addOffset _ constant = throwError $ unlines $
              [ "Expected symbol constant in constant GEP"
              , "Constant: " ++ show constant
              ]

  -- Given a processed GEP instruction, compute the sequence of output
  -- pointer values that result from the instruction.  If the GEP is
  -- scalar-valued, then the result will be a list of one element.
  go :: GEP n LLVMConst -> m [LLVMConst]

  -- Scalar base, return a list containing just the base value.
  go (GEP_scalar_base base)
    = return [base]

  -- Vector base, deconstruct the input value and return the
  -- corresponding values.
  go (GEP_vector_base n x)
    = asVectorOf (natValue n) return x

  -- Scatter a scalar input across n lanes
  go (GEP_scatter n gep)
    = do ps <- go gep
         case ps of
           [p] -> return (replicate (widthVal n) p)
           _   -> throwError "vector length mismatch in GEP scatter"

  -- Add the offset corresponding to the given field across
  -- all the lanes of the GEP
  go (GEP_field fi gep)
    = do ps <- go gep
         let i = bytesToInteger (fiOffset fi)
         traverse (addOffset i) ps

  -- Compute the offset corresponding to the given array index
  -- and add that offset across all the lanes of the GEP
  go (GEP_index_each mt gep x)
    = do ps <- go gep
         i  <- asOffset mt x
         traverse (addOffset i) ps

  -- For each index in the input vector, compute an offset according
  -- to the given memory type and add the corresponding offset across
  -- each lane of the GEP componentwise.
  go (GEP_index_vector mt gep x)
    = do ps <- go gep
         is <- asVectorOf (fromIntegral (length ps)) (asOffset mt) x
         zipWithM addOffset is ps
-- | Decide a floating-point comparison on two constants and return the
--   outcome as a 1-bit constant.  A comparison is \"unordered\" when at
--   least one operand is NaN; the @o@ predicates require both operands
--   to be ordered, the @u@ predicates also hold whenever they are not.
evalFcmp ::
  RealFloat a =>
  L.FCmpOp ->
  a -> a -> LLVMConst
evalFcmp op x y = boolConst verdict
  where
    unordered = isNaN x || isNaN y
    ordered   = not unordered

    -- Predicate must hold and neither operand may be NaN.
    andOrdered p  = ordered && p
    -- Predicate holds, or some operand is NaN.
    orUnordered p = unordered || p

    verdict =
      case op of
        L.Ffalse -> False
        L.Ftrue  -> True
        L.Foeq   -> andOrdered (x == y)
        L.Fone   -> andOrdered (x /= y)
        L.Fogt   -> andOrdered (x >  y)
        L.Foge   -> andOrdered (x >= y)
        L.Folt   -> andOrdered (x <  y)
        L.Fole   -> andOrdered (x <= y)
        L.Ford   -> ordered
        L.Fueq   -> orUnordered (x == y)
        L.Fune   -> orUnordered (x /= y)
        L.Fugt   -> orUnordered (x >  y)
        L.Fuge   -> orUnordered (x >= y)
        L.Fult   -> orUnordered (x <  y)
        L.Fule   -> orUnordered (x <= y)
        L.Funo   -> unordered
-- | Decide an integer comparison on two bitvectors of width @w@ and
--   return the outcome as a 1-bit constant.  The \"greater\" predicates
--   are expressed through the corresponding \"less\" test with the
--   operands swapped.
evalIcmp ::
  (1 <= w) =>
  L.ICmpOp ->
  NatRepr w ->
  BV.BV w -> BV.BV w -> LLVMConst
evalIcmp op w x y = boolConst (holds op)
  where
    holds L.Ieq  = x == y
    holds L.Ine  = x /= y
    holds L.Iugt = BV.ult y x
    holds L.Iuge = BV.ule y x
    holds L.Iult = BV.ult x y
    holds L.Iule = BV.ule x y
    holds L.Isgt = BV.slt w y x
    holds L.Isge = BV.sle w y x
    holds L.Islt = BV.slt w x y
    holds L.Isle = BV.sle w x y
-- | Evaluate a binary arithmetic operation on two operands of the given
--   memory type.  Integer operands go to 'evalIarith' (the width must be
--   positive), float and double operands go to 'evalFarith'; any other
--   combination of type and operands is an error.
evalArith ::
  (MonadError String m, HasPtrWidth wptr) =>
  L.ArithOp ->
  MemType ->
  Arith -> Arith -> m LLVMConst
evalArith op ty lhs rhs =
  case (ty, lhs, rhs) of
    (IntType m, ArithI x, ArithI y)
      | Just (Some w) <- someNat m
      , Just LeqProof <- isPosNat w
      -> evalIarith op w x y
    (FloatType, ArithF x, ArithF y) ->
      FloatConst <$> evalFarith op x y
    (DoubleType, ArithD x, ArithD y) ->
      DoubleConst <$> evalFarith op x y
    _ ->
      throwError "binary arithmetic argument mismatch"
-- | Evaluate a unary arithmetic operation on an operand of the given
--   memory type.  Only float and double operands are supported; any
--   other combination of type and operand is an error.
evalUnaryArith ::
  (MonadError String m, HasPtrWidth wptr) =>
  L.UnaryArithOp ->
  MemType ->
  Arith -> m LLVMConst
evalUnaryArith op ty arg =
  case (ty, arg) of
    (FloatType, ArithF x)  -> FloatConst <$> evalFunaryArith op x
    (DoubleType, ArithD x) -> DoubleConst <$> evalFunaryArith op x
    _                      -> throwError "unary arithmetic argument mismatch"
-- | Evaluate a binary floating-point operation.
--
--   Supports @fadd@, @fsub@, @fmul@, @fdiv@ and @frem@; any other
--   operation is reported as an integer operation applied to
--   floating-point arguments.
--
--   The remainder follows the LLVM @frem@ convention: the result takes
--   the sign of the /dividend/ @x@.  'mod'' on its own is floor-based
--   and takes the sign of the divisor, so it is applied to the absolute
--   values and the sign of @x@ is restored afterwards.  A zero divisor
--   yields @0 / 0@ (NaN for IEEE types) instead of the pure
--   divide-by-zero exception 'mod'' would raise.
--
--   NOTE(review): infinite and NaN operands get no further special
--   treatment here; 'mod'' goes through 'toRational', which has no
--   meaningful value for them.
evalFarith ::
  (RealFrac a, MonadError String m) =>
  L.ArithOp ->
  a -> a -> m a
evalFarith op x y =
  case op of
    L.FAdd -> return (x + y)
    L.FSub -> return (x - y)
    L.FMul -> return (x * y)
    L.FDiv -> return (x / y)
    L.FRem
      | y == 0    -> return (0 / 0)
      | otherwise -> return (signum x * mod' (abs x) (abs y))
    _ -> throwError "Encountered integer arithmetic operation applied to floating point arguments"
-- | Evaluate a unary floating-point operation.  The only operation
--   handled is @fneg@, which negates its argument.
evalFunaryArith ::
  (RealFrac a, MonadError String m) =>
  L.UnaryArithOp ->
  a -> m a
evalFunaryArith L.FNeg x = return (negate x)
-- | Evaluate an integer or pointer arithmetic operation at width @w@.
--
--   * Two plain integers are handed to 'evalIarith''.
--   * A symbol-relative pointer and an integer (in either order) may
--     only be added, and only at pointer width; the result is the same
--     symbol with the offsets summed.
--   * Two symbol-relative pointers may only be subtracted, only at
--     pointer width and only when they refer to the same symbol; the
--     result is the integer difference of their offsets.
--
--   Every other combination involving a pointer is an error.
evalIarith ::
  (1 <= w, MonadError String m, HasPtrWidth wptr) =>
  L.ArithOp ->
  NatRepr w ->
  ArithInt -> ArithInt -> m LLVMConst
evalIarith op w lhs rhs =
  case (lhs, rhs) of
    (ArithInt x, ArithInt y) ->
      IntConst w <$> evalIarith' op w (BV.mkBV w x) (BV.mkBV w y)

    (ArithPtr sym x, ArithInt y)
      | Just Refl <- testEquality w ?ptrWidth
      , L.Add _ _ <- op
      -> return $ SymbolConst sym (x+y)
      | otherwise
      -> throwError "Illegal operation applied to pointer argument"

    (ArithInt x, ArithPtr sym y)
      | Just Refl <- testEquality w ?ptrWidth
      , L.Add _ _ <- op
      -> return $ SymbolConst sym (x+y)
      | otherwise
      -> throwError "Illegal operation applied to pointer argument"

    (ArithPtr symx x, ArithPtr symy y)
      | Just Refl <- testEquality w ?ptrWidth
      , symx == symy
      , L.Sub _ _ <- op
      -> return $ IntConst ?ptrWidth (BV.mkBV ?ptrWidth (x - y))
      | otherwise
      -> throwError "Illegal operation applied to pointer argument"
-- | Evaluate an integer (non-pointer) arithmetic operation on two
--   bitvectors of width @w@.
--
--   * @add@, @sub@ and @mul@ fail when the @nuw@ / @nsw@ flag is set and
--     the corresponding unsigned / signed overflow occurred.
--   * The division and remainder operations fail on a zero divisor; the
--     signed ones additionally fail for the minimum signed value
--     combined with @-1@.
--   * @udiv@ and @sdiv@ fail when marked @exact@ and the remainder is
--     not zero.
--   * Any other operation is reported as a floating-point operation
--     applied to integer arguments.
evalIarith' ::
  (1 <= w, MonadError String m) =>
  L.ArithOp ->
  NatRepr w ->
  BV.BV w -> BV.BV w -> m (BV.BV w)
evalIarith' op w x y =
  case op of
    L.Add nuw nsw -> overflowChecked nuw nsw (BV.addOf w x y)
    L.Sub nuw nsw -> overflowChecked nuw nsw (BV.subOf w x y)
    L.Mul nuw nsw -> overflowChecked nuw nsw (BV.mulOf w x y)

    L.UDiv exact ->
      do rejectZeroDivisor
         let (q, r) = BV.uquotRem x y
         requireExact exact r
         return q

    L.SDiv exact ->
      do rejectZeroDivisor
         rejectSignedOverflow
         let (q, r) = BV.squotRem w x y
         requireExact exact r
         return q

    L.URem ->
      do rejectZeroDivisor
         return (BV.urem x y)

    L.SRem ->
      do rejectZeroDivisor
         rejectSignedOverflow
         return (BV.srem w x y)

    _ -> throwError "Floating point operation applied to integer arguments"

  where
    zero = BV.zero w

    -- Check the unsigned flag first, then the signed one, then unwrap.
    overflowChecked nuw nsw res =
      do when (nuw && BV.ofUnsigned res)
              (throwError "Unsigned overflow in constant arithmetic operation")
         when (nsw && BV.ofSigned res)
              (throwError "Signed overflow in constant arithmetic operation")
         return (BV.ofResult res)

    rejectZeroDivisor =
      when (y == zero)
           (throwError "Division by 0 in constant arithmetic operation")

    rejectSignedOverflow =
      when (x == BV.minSigned w && y == BV.mkBV w (-1))
           (throwError "Signed division overflow in constant arithmetic operation")

    requireExact exact r =
      when (exact && r /= zero)
           (throwError "Exact division failed in constant arithmetic operation")
-- BGS: Leave this alone for now, as we don't have a good way to
-- detect overflow from bitvector operations.

-- | Evaluate a bitwise operation on two bitvectors of width @w@ and wrap
--   the result as an integer constant.
--
--   For the shifts the second operand, read as an unsigned number, is
--   the shift amount.  @shl@ fails when the @nuw@ / @nsw@ flag is set and
--   the corresponding overflow occurred; @lshr@ and @ashr@ fail when
--   marked @exact@ and shifting the result back left does not reproduce
--   the first operand.
evalBitwise ::
  (1 <= w, MonadError String m) =>
  L.BitOp ->
  NatRepr w ->
  BV.BV w -> BV.BV w -> m LLVMConst
evalBitwise op w x y =
  fmap (IntConst w) $
    case op of
      L.And -> return (BV.and x y)
      L.Or  -> return (BV.or  x y)
      L.Xor -> return (BV.xor x y)
      L.Shl nuw nsw ->
        do let shifted = BV.shlOf w x amount
           when (nuw && BV.ofUnsigned shifted)
                (throwError "Unsigned overflow in left shift")
           when (nsw && BV.ofSigned shifted)
                (throwError "Signed overflow in left shift")
           return (BV.ofResult shifted)
      L.Lshr exact -> exactRightShift exact (BV.lshr w x amount)
      L.Ashr exact -> exactRightShift exact (BV.ashr w x amount)
  where
    amount = fromInteger (BV.asUnsigned y) :: Natural

    -- Accept an already computed right-shift result, enforcing @exact@.
    exactRightShift exact z =
      do when (exact && x /= BV.shl w z amount)
              (throwError "Exact right shift failed")
         return z
-- | Evaluate a conversion operation on constants.
--
--   The arguments are the original constant expression (only used for
--   error reporting), the conversion operator, the target memory type
--   and the constant being converted.
--
--   Each operator has a list of guarded alternatives.  When none of the
--   guards of an operator matches the given target type and constant,
--   matching falls through to the final wildcard alternative, which
--   reports \"unexpected conversion operation\" together with the
--   expression.
evalConv ::
  (?lc :: TypeContext, MonadError String m, HasPtrWidth wptr) =>
  L.ConstExpr ->
  L.ConvOp ->
  MemType ->
  LLVMConst ->
  m LLVMConst
evalConv expr op mt x = case op of
    -- Float/double to integer: truncate toward zero, then store the
    -- result at the target width with 'BV.mkBV'.
    L.FpToUi
      | IntType n <- mt
      , Just (Some w) <- someNat n
      , Just LeqProof <- isPosNat w
      , FloatConst f <- x
      -> return $ IntConst w (BV.mkBV w (truncate f))

      | IntType n <- mt
      , Just (Some w) <- someNat n
      , Just LeqProof <- isPosNat w
      , DoubleConst d <- x
      -> return $ IntConst w (BV.mkBV w (truncate d))

    -- Same computation as the unsigned case above.
    L.FpToSi
      | IntType n <- mt
      , Just (Some w) <- someNat n
      , Just LeqProof <- isPosNat w
      , FloatConst f <- x
      -> return $ IntConst w (BV.mkBV w (truncate f))

      | IntType n <- mt
      , Just (Some w) <- someNat n
      , Just LeqProof <- isPosNat w
      , DoubleConst d <- x
      -> return $ IntConst w (BV.mkBV w (truncate d))

    -- Integer to float/double, reading the bits as an unsigned number.
    L.UiToFp
      | FloatType <- mt
      , IntConst _w i <- x
      -> return $ FloatConst (fromInteger (BV.asUnsigned i) :: Float)

      | DoubleType <- mt
      , IntConst _w i <- x
      -> return $ DoubleConst (fromInteger (BV.asUnsigned i) :: Double)

    -- Integer to float/double, reading the bits as a signed number.
    L.SiToFp
      | FloatType <- mt
      , IntConst w i <- x
      -> return $ FloatConst (fromInteger (BV.asSigned w i) :: Float)

      | DoubleType <- mt
      , IntConst w i <- x
      -> return $ DoubleConst (fromInteger (BV.asSigned w i) :: Double)

    -- Truncation: target strictly narrower truncates, equal width is the
    -- identity, a wider target is an error.
    L.Trunc
      | IntType n <- mt
      , IntConst w i <- x
      , Just (Some w') <- someNat n
      , Just LeqProof <- isPosNat w'
      -> case testNatCases w' w of
           NatCaseLT LeqProof -> return $ IntConst w' (BV.trunc w' i)
           NatCaseEQ -> return x
           NatCaseGT LeqProof ->
             throwError $ "Attempted to truncate " <> show w <> " bits to " <> show w'

    -- Zero extension: target strictly wider extends, equal width is the
    -- identity, a narrower target is an error.
    L.ZExt
      | IntType n <- mt
      , IntConst w i <- x
      , Just (Some w') <- someNat n
      , Just LeqProof <- isPosNat w'
      -> case testNatCases w w' of
           NatCaseLT LeqProof -> return $ IntConst w' (BV.zext w' i)
           NatCaseEQ -> return x
           NatCaseGT LeqProof ->
             throwError $ "Attempted to zext " <> show w <> " bits to " <> show w'

    -- Sign extension: same width rules as zero extension.
    L.SExt
      | IntType n <- mt
      , IntConst w i <- x
      , Just (Some w') <- someNat n
      , Just LeqProof <- isPosNat w'
      -> case testNatCases w w' of
           NatCaseLT LeqProof -> return $ IntConst w' (BV.sext w w' i)
           NatCaseEQ -> return x
           NatCaseGT LeqProof ->
             throwError $ "Attempted to sext " <> show w <> " bits to " <> show w'

    -- Floating-point truncation; conversions between equal formats are
    -- accepted and leave the value unchanged.
    L.FpTrunc
      | DoubleType <- mt
      , DoubleConst d <- x
      -> return $ DoubleConst d

      | FloatType <- mt
      , DoubleConst d <- x
      -> return $ FloatConst (realToFrac d)

      | FloatType <- mt
      , FloatConst f <- x
      -> return $ FloatConst f

    -- Floating-point extension; conversions between equal formats are
    -- accepted and leave the value unchanged.
    L.FpExt
      | DoubleType <- mt
      , DoubleConst d <- x
      -> return $ DoubleConst d

      | DoubleType <- mt
      , FloatConst f <- x
      -> return $ DoubleConst (realToFrac f)

      | FloatType <- mt
      , FloatConst f <- x
      -> return $ FloatConst f

    -- Pointer/integer conversions pass the constant through unchanged;
    -- the target type is not inspected.
    L.IntToPtr -> return x
    L.PtrToInt -> return x

    _ -> badExp "unexpected conversion operation"

  where badExp msg = throwError $ unlines [msg, show expr]
-- | Flatten a constant of the given memory type into a single integer
--   holding its bits, as the first half of a bitcast.
--
--   * An integer-typed constant is converted with @asInt@ (defined
--     elsewhere in this module) at the width of its own type.
--   * A vector-typed constant of @n@ elements and total width @w@ is
--     split into @n@ pieces of @w / n@ bits each.  Every element is
--     flattened recursively and shifted into place: element 0 occupies
--     the lowest bits for little-endian layouts and the highest bits for
--     big-endian layouts.  @w@ must be an exact multiple of @n@.
--   * Everything else is rejected with an error mentioning the original
--     expression.
--
--   A zero-length vector type is rejected through the ordinary error
--   path; the @n > 0@ guard keeps @divMod@ from being evaluated with a
--   zero divisor, which would raise a pure exception.
castToInt ::
  MonadError String m =>
  L.ConstExpr {- ^ original expression to evaluate -} ->
  EndianForm ->
  Natural {- ^ total bit width of the value -} ->
  MemType {- ^ type of the constant -} ->
  LLVMConst {- ^ constant to flatten -} ->
  m Integer
castToInt _expr _endian _w (IntType w) x = asInt w x
castToInt expr endian w (VecType n tp) x
  | n > 0
  , (m,0) <- w `divMod` n =
  do xs <- asVectorOf n (castToInt expr endian m tp) x
     -- Bit position (in units of one element) assigned to each element.
     let indices = case endian of
                     LittleEndian -> [0 .. n-1]
                     BigEndian -> reverse [0 .. n-1]
     let pieces = [ v `shiftL` (fromIntegral (i * m))
                  | i <- indices
                  | v <- xs
                  ]
     return (foldr (.|.) 0 pieces)

castToInt expr _ _ _ _ =
  throwError $ unlines ["Cannot cast expression to integer type", show expr]
-- | Rebuild a constant of the given memory type from an integer holding
--   its bits, as the second half of a bitcast.
--
--   * For an integer type whose width equals @w@ (and is positive) the
--     integer is stored with 'BV.mkBV'.
--   * For a vector type of @n@ elements the integer is cut into @n@
--     pieces of @w / n@ bits each and every piece is rebuilt recursively
--     at the element type.  Element 0 is taken from the lowest bits for
--     little-endian layouts and from the highest bits for big-endian
--     layouts.  @w@ must be an exact multiple of @n@.
--   * Everything else is rejected with an error naming the type.
--
--   A zero-length vector type is rejected through the ordinary error
--   path; the @n > 0@ guard keeps @divMod@ from being evaluated with a
--   zero divisor, which would raise a pure exception.
castFromInt ::
  MonadError String m =>
  EndianForm ->
  Integer {- ^ the bits of the value -} ->
  Natural {- ^ total bit width of the value -} ->
  MemType {- ^ type to rebuild the constant at -} ->
  m LLVMConst
castFromInt _ xint w (IntType w')
  | w == w'
  , Some wsz <- mkNatRepr w
  , Just LeqProof <- isPosNat wsz
  = return $ IntConst wsz (BV.mkBV wsz xint)

castFromInt endian xint w (VecType n tp)
  | n > 0
  , (m,0) <- w `divMod` n =
  do -- Mask selecting the low @m@ bits, i.e. exactly one element.
     let mask = (1 `shiftL` fromIntegral m) - 1
     let indices = case endian of
                     LittleEndian -> [0 .. n-1]
                     BigEndian -> reverse [0 .. n-1]
     let pieces = [ mask .&. (xint `shiftR` fromIntegral (i * m))
                  | i <- indices
                  ]
     VectorConst tp <$> mapM (\x -> castFromInt endian x m tp) pieces

castFromInt _ _ _ tp =
  throwError $ unlines ["Cant cast integer to type", show tp]
-- | Evaluate a bitcast of a constant from one memory type to another.
--   The alternatives are tried in this order:
--
--   1. A zero constant is relabelled as a zero constant of the target type.
--   2. A cast between two pointer types (typed or opaque) returns the
--      constant unchanged.
--   3. A vector constant cast between vector types of equal length is
--      cast element by element.
--   4. If both types have the same known bit width, the constant is
--      flattened to an integer with 'castToInt' and rebuilt with
--      'castFromInt', using the integer layout of the data layout.
--   5. Otherwise the bitcast is illegal.
evalBitCast ::
  (?lc :: TypeContext, MonadError String m) =>
  L.ConstExpr {- ^ original expression to evaluate -} ->
  MemType     {- ^ input expression type -} ->
  LLVMConst   {- ^ input expression -} ->
  MemType     {- ^ desired output type -} ->
  m LLVMConst
evalBitCast expr srcTy val tgtTy
  | ZeroConst _ <- val
  = return (ZeroConst tgtTy)

  | isPtr srcTy
  , isPtr tgtTy
  = return val

  | VecType n srcElt <- srcTy
  , VectorConst _ elts <- val
  , VecType n' tgtElt <- tgtTy
  , n == n'
  = VectorConst tgtElt <$> traverse (\e -> evalBitCast expr srcElt e tgtElt) elts

  | Just w1 <- memTypeBitwidth srcTy
  , Just w2 <- memTypeBitwidth tgtTy
  , w1 == w2
  = do let endian = ?lc ^. to llvmDataLayout.intLayout
       bits <- castToInt expr endian w1 srcTy val
       castFromInt endian bits w1 tgtTy

  | otherwise
  = throwError $ unlines ["illegal constant bitcast", show expr]

  where
    -- Typed and opaque pointer types.
    isPtr (PtrType _)   = True
    isPtr PtrOpaqueType = True
    isPtr _             = False
-- | View a constant as a vector of exactly @n@ elements and apply the
--   given action to every element.
--
--   A zero constant of a vector type with @n@ elements is treated as @n@
--   copies of the zero constant of the element type (the action runs
--   once and its result is replicated).  A vector constant must contain
--   exactly @n@ elements.  Anything else is an error.
asVectorOf ::
  MonadError String m =>
  Natural ->
  (LLVMConst -> m a) ->
  (LLVMConst -> m [a])
asVectorOf n f val =
  case val of
    ZeroConst (VecType m eltTy)
      | n == m
      -> replicate (fromIntegral n) <$> f (ZeroConst eltTy)

    VectorConst _ elts
      | n == fromIntegral (length elts)
      -> traverse f elts

    _ -> throwError ("Expected vector constant value of length: " ++ show n)
-- | Type representing integer-like things. These are either actual
-- integer constants, or constant offsets from global symbols.
data ArithInt where
ArithInt :: Integer -> ArithInt
ArithPtr :: L.Symbol -> Integer -> ArithInt