Skip to content

Commit

Permalink
Updated z3_match for better performance on static input
Browse files Browse the repository at this point in the history
  • Loading branch information
bannsec committed Apr 6, 2016
1 parent b038cca commit f77aa20
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 34 deletions.
6 changes: 4 additions & 2 deletions pyObjectManager/Char.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ def __init__(self,varName,ctx,count=None,variable=None,state=None,increment=Fals
assert type(ctx) is int
assert type(count) in [int, type(None)]

self.size = 16
self.count = 0 if count is None else count
self.varName = varName
self.ctx = ctx
self.variable = BitVec('{1}{0}'.format(self.varName,self.count),ctx=self.ctx,size=16) if variable is None else variable
self.variable = BitVec('{1}{0}'.format(self.varName,self.count),ctx=self.ctx,size=self.size) if variable is None else variable

if state is not None:
self.setState(state)
Expand Down Expand Up @@ -74,11 +75,12 @@ def setState(self,state):
assert type(state) == pyState.State

self.state = state
self.variable.setState(state)

def increment(self):
self.count += 1
# reset variable list if we're incrementing our count
self.variable = BitVec('{1}{0}'.format(self.varName,self.count),ctx=self.ctx,size=16)
self.variable = BitVec('{1}{0}'.format(self.varName,self.count),ctx=self.ctx,size=self.size)

def _isSame(self,**args):
"""
Expand Down
4 changes: 2 additions & 2 deletions pyState/AugAssign.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def _handleNum(state,element,value,op):

# Basic sanity checks complete. For augment assigns we will always need to update the vars.
# Grab the old var and create a new now
oldTargetVar = oldTarget.getZ3Object()
#oldTargetVar = oldTarget.getZ3Object()

# Match up the right hand side
oldTargetVar, valueVar = z3Helpers.z3_matchLeftAndRight(oldTargetVar,value.getZ3Object(),op)
oldTargetVar, valueVar = z3Helpers.z3_matchLeftAndRight(oldTarget,value,op)

if hasRealComponent(valueVar) or hasRealComponent(oldTargetVar):
parent[index] = Real(oldTarget.varName,ctx=state.ctx)
Expand Down
4 changes: 2 additions & 2 deletions pyState/BinOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

def _handleNum(state,left,right,op):
# Match our object types
leftZ3Object,rightZ3Object = pyState.z3Helpers.z3_matchLeftAndRight(left.getZ3Object(),right.getZ3Object(),op)
leftZ3Object,rightZ3Object = pyState.z3Helpers.z3_matchLeftAndRight(left,right,op)

# Figure out what the op is and add constraint
if type(op) == ast.Add:
Expand Down Expand Up @@ -84,7 +84,7 @@ def _handleNum(state,left,right,op):
# Now that we have a clean variable to return, add constraints and return it
logger.debug("Adding constraint {0} == {1}".format(retVar.getZ3Object(),ret))
state.addConstraint(retVar.getZ3Object() == ret)
print([x for x in state.solver.assertions()])
#print([x for x in state.solver.assertions()])
return [retVar.copy()]

else:
Expand Down
14 changes: 8 additions & 6 deletions pyState/Compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@ def _handleLeftVarInt(state,element,left):

# Resolve the z3 object
if type(left) in [Int, Real, BitVec, Char]:
left = left.getZ3Object()
pass
#left = left.getZ3Object()

# If this is a String, let's hope it's only one char...
elif type(left) is String and left.length() == 1:
left = left[0].getZ3Object()
#left = left[0].getZ3Object()
left = left[0]

else:
err = "_handleLeftVar: Don't know how to handle type '{0}'".format(type(left))
Expand Down Expand Up @@ -68,11 +70,12 @@ def _handleLeftVarInt(state,element,left):
for r in right:

if type(r) in [Int, Real, BitVec, Char]:
r = r.getZ3Object()
pass #r = r.getZ3Object()

# If this is a String, let's hope it's only one char...
elif type(r) is String and r.length() == 1:
r = r[0].getZ3Object()
r = r[0]
#r = r[0].getZ3Object()

else:
err = "_handleLeftVar: Don't know how to handle type '{0}'".format(type(r))
Expand All @@ -82,7 +85,7 @@ def _handleLeftVarInt(state,element,left):

# Adjust the types if needed
l,r = pyState.z3Helpers.z3_matchLeftAndRight(left,r,ops)

logger.debug("_handleLeftVar: Comparing {0} (type: {2}) and {1} (type: {3})".format(l,r,type(l),type(r)))

# Assume success. Add constraints
Expand Down Expand Up @@ -140,7 +143,6 @@ def handle(state,element,ctx=None):
if len(retObjs) > 0:
return retObjs


ret = []

# Loop through possibilities
Expand Down
2 changes: 0 additions & 2 deletions pyState/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,6 @@ def recursiveCopy(self,var,ctx=None,varName=None):
for elm in var:
ret = self.recursiveCopy(elm,varName=varName)
newList.append(ret)
#if type(ret) in [Int, Real, BitVec]:
# self.addConstraint(newList[-1].getZ3Object() == ret.getZ3Object())

return newList

Expand Down
64 changes: 44 additions & 20 deletions pyState/z3Helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import ast
import pyState
import logging
from pyObjectManager.BitVec import BitVec
from pyObjectManager.Real import Real
from pyObjectManager.Int import Int
from pyObjectManager.Char import Char

logger = logging.getLogger("pyState:z3Helpers")

Expand Down Expand Up @@ -81,56 +85,76 @@ def mk_var(name,vsort):
def z3_matchLeftAndRight(left,right,op):
"""
Input:
left = z3 object
right = z3 object
left = pyObjectManager Object (i.e.: Int)
right = pyObjectManager Object (i.e.: Int)
op = ast operation that will be performed
Action:
Appropriately cast the two variables so that they can be used in an expression
Main problem is between Int type and BitVec type
Returns:
(left,right) where both vars should be able to be used together
(left,right) as z3 vars where both vars should be able to be used together
"""
lType = type(left)
rType = type(right)

# If it's char, just grab the BitVec object
if lType is Char:
left = left.variable
lType = type(left)
if rType is Char:
right = right.variable
rType = type(right)

logger.debug("z3_matchLeftAndRight: Called to match {0} and {1}".format(type(left),type(right)))
needBitVec = True if type(op) in [ast.BitXor, ast.BitAnd, ast.BitOr, ast.LShift, ast.RShift] else False
# TODO: If the two sizes are different, we'll have problems down the road.
bitVecSize = max([c.size() for c in [b for b in [left,right] if type(b) in [z3.BitVecRef, z3.BitVecNumRef]]],default=Z3_DEFAULT_BITVEC_SIZE)
bitVecSize = max([c.size for c in [b for b in [left,right] if type(b) is BitVec]],default=Z3_DEFAULT_BITVEC_SIZE)

#####################################
# Case: Both are already BitVectors #
#####################################
# Check length. Extend if needed.
if type(left) in [z3.BitVecRef, z3.BitVecNumRef] and type(right) in [z3.BitVecRef, z3.BitVecNumRef]:
logger.debug("z3_matchLeftAndRight: Matching BitVecLength @ {0} (left={1},right={2})".format(bitVecSize,left.size(),right.size()))
if left.size() < right.size():
if type(left) is BitVec and type(right) is BitVec:
logger.debug("z3_matchLeftAndRight: Matching BitVecLength @ {0} (left={1},right={2})".format(bitVecSize,left.size,right.size))
if left.size < right.size:
# Sign extend left's value to match
left = z3.SignExt(right.size()-left.size(),left)
elif right.size() > left.size():
right = z3.SignExt(left.size()-right.size(),right)
left = z3.SignExt(right.size-left.size,left.getZ3Object())
right = right.getZ3Object()
elif right.size > left.size:
right = z3.SignExt(left.size-right.size,right.getZ3Object())
left = left.getZ3Object()

# Sync-up the output variables
left = left.getZ3Object() if type(left) in [Int, Real, BitVec] else left
right = right.getZ3Object() if type(right) in [Int, Real, BitVec] else right

logger.debug("z3_matchLeftAndRight: Returning {0} and {1}".format(type(left),type(right)))

return left,right

#####################################
# Case: One is BitVec and one isn't #
#####################################
# For now only handling casting of int to BV. Not other way around.
if (lType in [z3.BitVecNumRef, z3.BitVecRef] and rType in [z3.ArithRef,z3.IntNumRef]) or (rType in [z3.ArithRef,z3.IntNumRef] and needBitVec):
if (lType is BitVec and rType is Int) or (rType is Int and needBitVec):
# If we need to convert to BitVec and it is a constant, not variable, do so more directly
if rType is z3.IntNumRef and right.is_int():
right = z3.BitVecVal(right.as_long(),bitVecSize)
if right.isStatic():
right = z3.BitVecVal(right.getValue(),bitVecSize)
# Otherwise cast it. Not optimal, but oh well.
else:
right = z3_int_to_bv(right,size=bitVecSize)
right = z3_int_to_bv(right.getZ3Object(),size=bitVecSize)

if (rType in [z3.BitVecNumRef, z3.BitVecRef] and lType in [z3.ArithRef,z3.IntNumRef]) or (lType in [z3.ArithRef,z3.IntNumRef] and needBitVec):
if lType is z3.IntNumRef and left.is_int():
left = z3.BitVecVal(left.as_long(),bitVecSize)
if (rType is BitVec and lType is Int) or (lType is Int and needBitVec):
if left.isStatic():
left = z3.BitVecVal(left.getValue(),bitVecSize)
else:
left = z3_int_to_bv(left,size=bitVecSize)

left = z3_int_to_bv(left.getZ3Object(),size=bitVecSize)

# Sync-up the output variables
left = left.getZ3Object() if type(left) in [Int, Real, BitVec] else left
right = right.getZ3Object() if type(right) in [Int, Real, BitVec] else right

logger.debug("z3_matchLeftAndRight: Returning {0} and {1}".format(type(left),type(right)))

return (left,right)
return left,right

0 comments on commit f77aa20

Please sign in to comment.