In [1]:
from z3 import *

In [2]:
# Load the whole SMT-LIB2 formula from /tmp/cbmc.smt2 into the notebook-level solver.
solver = Solver()
# NOTE(review): reset() right after construction looks redundant — the solver has no assertions yet.
solver.reset()

# `constraints` is reused by later cells (variable collection, CNF export).
constraints = parse_smt2_file("/tmp/cbmc.smt2", sorts={}, decls={})
solver.add(constraints)

In [3]:
# Wrapper for allowing Z3 ASTs to be stored into Python Hashtables. 
class AstRefKey:
    """Hashable wrapper that lets Z3 ASTs be used as set members / dict keys.

    The hash is the wrapped object's ``hash()`` and equality is decided by the
    wrapped object's ``eq`` method, so Python's ``==`` is never applied to the
    AST itself.
    """

    def __init__(self, n):
        # n: the wrapped AST; it must provide hash() and eq().
        self.n = n

    def __hash__(self):
        return self.n.hash()

    def __eq__(self, other):
        # Anything that is not an AstRefKey has no ``.n``; let Python fall back
        # to its default comparison instead of raising AttributeError.
        if not isinstance(other, AstRefKey):
            return NotImplemented
        return self.n.eq(other.n)

    def __repr__(self):
        return str(self.n)

def askey(n):
    """Wrap the Z3 AST ``n`` in an ``AstRefKey``; asserts that ``n`` is an ``AstRef``."""
    assert isinstance(n, AstRef)
    key = AstRefKey(n)
    return key

def get_vars(f):
    """Return the set of uninterpreted constants occurring in the Z3 expression ``f``.

    The expression is walked iteratively and every distinct sub-term (identified
    by ``get_id()``) is expanded once. The previous recursive version re-visited
    shared sub-terms on every path that reached them and could exceed Python's
    recursion limit on deeply nested terms.
    """
    visited_ids = set()
    found = []
    pending = [f]
    while pending:
        node = pending.pop()
        node_id = node.get_id()
        if node_id in visited_ids:
            continue
        visited_ids.add(node_id)
        if is_const(node):
            # Constants that are not uninterpreted (e.g. numerals) are ignored.
            if node.decl().kind() == Z3_OP_UNINTERPRETED:
                found.append(node)
        else:
            pending.extend(node.children())
    # Each AST id is handled once, so `found` holds no duplicates.
    return set(found)

def get_var_names(f):
    """Return the names (``str`` of each constant) of all uninterpreted constants in ``f``.

    Same traversal as ``get_vars`` but collects strings. The walk is iterative
    and expands every distinct sub-term (by ``get_id()``) only once, instead of
    recursing into shared sub-terms repeatedly.
    """
    visited_ids = set()
    names = set()
    pending = [f]
    while pending:
        node = pending.pop()
        node_id = node.get_id()
        if node_id in visited_ids:
            continue
        visited_ids.add(node_id)
        if is_const(node):
            # Constants that are not uninterpreted (e.g. numerals) are ignored.
            if node.decl().kind() == Z3_OP_UNINTERPRETED:
                names.add(str(node))
        else:
            pending.extend(node.children())
    return names

In [4]:
# Collect the names of all uninterpreted constants occurring in the parsed constraints.
list_of_var_set = [get_var_names(a) for a in constraints]
list_of_vars = set().union(*list_of_var_set)
# Keep the names containing "Observation_" and order them by the integer after
# the last "_" (anything from a "!" onwards is dropped before int()).
obsvs = [o for o in list_of_vars if "Observation_" in o]
obsvs.sort(key=lambda x: int(x.split("_")[-1].split("!")[0]))
len(obsvs)

17

In [5]:
import re
def is_important_constraint(constraint, var):
    """Decide whether one SMT-LIB2 line should be kept when slicing for ``var``.

    Rules, in order:
      1. A line mentioning both ``__CPROVER_deallocated`` and ``var`` is always
         kept (safe-pointer constraint).
      2. If the line starts with a declaration, a definition, or an equality
         assertion (optionally on a ``select``), it is kept only when the first
         ``|...|`` symbol after that prefix is exactly ``var``.
      3. Any other line is kept when it contains ``var`` as a substring.

    constraint: one line of SMT-LIB2 text.
    var: symbol name including its surrounding ``|`` bars.
    Returns a bool.
    """
    if "__CPROVER_deallocated" in constraint and var in constraint:
        # we want to preserve safe-pointer constraint
        return True
    # Raw string: same regex as before, without invalid escape sequences.
    pattern = r"(\(assert \(= |\(define-fun |\(declare-fun |\(assert \(= \(select )(\|.*?\|)"
    matches = re.match(pattern, constraint)
    if matches:
        return matches.group(2) == var
    return var in constraint

# Ad-hoc checks of is_important_constraint on representative lines; the cell
# displays the resulting list of booleans.
[is_important_constraint("(declare-fun |Observation_2177| () (_ BitVec 32))", "|Observation_2177|"),
is_important_constraint("(define-fun |B4975| () Bool (=> |goto_symex::&92;guard#99| (not (= ((_ zero_extend 20) ((_ extract 31 20) (concat (_ bv205 12) (_ bv0 20)))) ((_ zero_extend 20) ((_ extract 31 20) |__CPROVER_deallocated#0|))))))", "|B4975|"),
is_important_constraint("(define-fun |B4975| () Bool (=> |goto_symex::&92;guard#99| (not (= ((_ zero_extend 20) ((_ extract 31 20) (concat (_ bv205 12) (_ bv0 20)))) ((_ zero_extend 20) ((_ extract 31 20) |__CPROVER_deallocated#0|))))))", "|goto_symex::&92;guard#99|"),
is_important_constraint("(assert (= |cbmc_pointer_offset_2375!0@1#2| (_ bv0 32)))", "|cbmc_pointer_offset_2375!0@1#2|"),
is_important_constraint("(assert (bvsge (bvmul |histogram::1::2::1::t!0@88#4| (_ bv4 32)) (_ bv0 32)))", "|histogram::1::2::1::t!0@88#4|"),
is_important_constraint("", "|histogram::1::2::1::t!0@88#4|"),
is_important_constraint("(assert (= |histogram::1::2::1::v!0@86#2| (select |main::1::a!0@1#1| (_ bv85 32))))", "|main::1::a!0@1#1|")]



[True, True, True, True, True, False, False]

In [6]:
import re

def get_var_in_constraint(constraints):
    """Return the set of symbol tokens found in a chunk of SMT-LIB2 text.

    A token is either a ``|...|`` quoted symbol (bars included, shortest match)
    or the text ``array.`` followed by zero or more digits.
    """
    # Raw string: same regex as before, without invalid escape sequences.
    pattern = r"\|.*?\||array\.\d*"
    return set(re.findall(pattern, constraints))



def get_constraints_for_var(file_path, var_name):
    """Slice an SMT-LIB2 file down to the constraints relevant for ``var_name``.

    Steps:
      1. Every line containing ``var_name`` is collected, in file order.
      2. The file is scanned backwards. Lines are considered only after the
         declaration of ``|Observation_0|`` has been passed (i.e. the lines that
         precede it in the file). A line is selected when
         ``is_important_constraint`` accepts it for any symbol gathered so far;
         its own symbols are then added to the working set.
      3. For a selected ``(define-fun |B<n>| ...`` line mentioning
         ``__CPROVER_deallocated``, an ``(assert |B<n>|)`` line is emitted
         directly after it (HACK: asserts the safe-pointer assertion).

    file_path: path to the SMT-LIB2 file.
    var_name: symbol to slice for, including its ``|`` bars.
    Returns the result of ``parse_smt2_string`` on the sliced text, prefixed
    with a declaration of ``|__CPROVER_deallocated#0|``.
    """
    # Context manager so the handle is closed even if reading fails.
    with open(file_path) as smt_file:
        lines = smt_file.readlines()

    obsv_constraints = "".join(line for line in lines if var_name in line)
    dependent_vars = get_var_in_constraint(obsv_constraints)

    first_line_of_obsv_constraint = "(declare-fun |Observation_0| () (_ BitVec 32))"
    safe_pointer_definition = re.compile(r"\(define-fun (\|B\d*\|)")
    active_searching = False
    # Collected in scan (reverse file) order and flipped once at the end,
    # instead of prepending to an ever-growing string.
    selected_reversed = []
    for line in reversed(lines):
        if first_line_of_obsv_constraint in line:
            active_searching = True
            continue
        if not active_searching:
            continue
        if not any(is_important_constraint(line, dep_var) for dep_var in dependent_vars):
            continue

        # HACK: assert the safe-point assertion
        matches = safe_pointer_definition.match(line)
        if matches and "__CPROVER_deallocated" in line:
            # Appended before the line so it follows the line after the flip.
            selected_reversed.append("(assert {})\n".format(matches.group(1)))

        dependent_vars |= get_var_in_constraint(line)
        # This symbol would otherwise pull in every safe-pointer constraint.
        dependent_vars.discard("|__CPROVER_deallocated#0|")
        selected_reversed.append(line)

    cbmc_constraint = "".join(reversed(selected_reversed))
    return parse_smt2_string("(declare-fun |__CPROVER_deallocated#0| () (_ BitVec 32))\n" + cbmc_constraint+obsv_constraints, sorts={}, decls={})
# Smoke test: print the sliced constraints for one observation, one per line.
solver.reset()
print(*get_constraints_for_var("/tmp/cbmc.smt2", "|Observation_1336|"), sep="\n")





In [7]:
import time


def get_differential_set_for_obsv(obsv_name, smt2_path="/tmp/cbmc.smt2", max_models=1000, object_gap=1048576):
    """Enumerate the values one observation variable can take.

    The SMT file is sliced with ``get_constraints_for_var``. The solver is then
    asked repeatedly for a model, each time blocking the value just found, until
    it is no longer satisfiable or ``max_models`` values have been collected.

    obsv_name: observation name without ``|`` bars, e.g. ``"Observation_1"``.
    smt2_path: SMT-LIB2 file to slice (default: the path used previously).
    max_models: upper bound on the number of enumerated values.
    object_gap: divisor splitting a value into ``value // object_gap`` (object
        id) and ``value % object_gap`` (offset).

    Prints one summary line: variable, slicing time, solving time, number of
    sliced constraints, number of values.
    Returns ``(elapsed_time, number_of_values, {object_id: [offsets]})``.
    """
    solver = Solver()
    start_time = time.time()
    constraints = get_constraints_for_var(smt2_path, "|"+str(obsv_name)+"|")
    solver.add(constraints)

    # Locate the Z3 constant whose printed name equals the requested name.
    obsv = next(
        (v for a in solver.assertions() for v in get_vars(a) if str(v) == str(obsv_name)),
        None,
    )
    assert obsv is not None, "observation {} not found in sliced constraints".format(obsv_name)

    slicing_elapsed_time = round(time.time() - start_time, 2)
    all_models = []
    start_time = time.time()
    for _ in range(max_models):
        if solver.check() != sat:
            break
        # model_completion=True always yields a concrete value; model[obsv]
        # is None when the model leaves the variable unassigned.
        value = solver.model().eval(obsv, model_completion=True)
        all_models.append(value.as_signed_long())
        solver.add(obsv != value)

    elapsed_time = round(time.time() - start_time, 2)

    object_id_to_offsets = {}
    for val in all_models:
        object_id_to_offsets.setdefault(val // object_gap, []).append(val % object_gap)
    print(obsv, slicing_elapsed_time, elapsed_time, len(constraints), len(all_models))
    return (elapsed_time, len(all_models), object_id_to_offsets)

# Trial run on one observation; maps name -> (elapsed_time, value count, object id -> offsets).
infos = {o: get_differential_set_for_obsv(str(o)) for o in ["Observation_1"]}


Observation_1 0.01 0.03 3 1


In [8]:
from joblib import Parallel, delayed
def wrapper(o):
    """Pair the observation name ``o`` with its analysis result so the pairs can be fed to ``dict``."""
    result = get_differential_set_for_obsv(o)
    return (o, result)
# Analyse every entry of `obsvs` with joblib using 6 parallel jobs.
tuple_list = Parallel(n_jobs=6)(delayed(wrapper)(str(o)) for o in obsvs)
# wrapper returns (name, result) pairs, so dict() yields name -> result.
infos = dict(tuple_list)

Observation_0 0.02 0.01 3 1
Observation_2 0.01 0.03 5 2
Observation_6 0.02 0.02 3 1
Observation_7 0.02 0.02 3 1
Observation_1 0.01 0.01 3 1
Observation_3 0.02 0.01 5 2
Observation_10Observation_8 0.02 0.01 3 1
 0.01 0.01 3 1
Observation_9 0.03 0.01 3 1
Observation_11 0.01 0.01 3 1
Observation_12Observation_14 0.01 0.02 3 1
Observation_15 0.01 0.01 3 1
Observation_13 0.01 0.03 3 1
 0.02 0.01 3 1
Observation_4 0.01 0.01 5 2
Observation_16 0.01 0.01 3 1
Observation_5 0.01 0.0 3 1


In [9]:
# Display name -> (elapsed_time, value count, {object id: [offsets]}).
infos

{'Observation_0': (0.01, 1, {2: [0]}),
 'Observation_1': (0.01, 1, {4: [0]}),
 'Observation_2': (0.03, 2, {5: [0], -1: [1048575]}),
 'Observation_3': (0.01, 2, {5: [0], -1: [1048575]}),
 'Observation_4': (0.01, 2, {-1: [1048575], 2: [0]}),
 'Observation_5': (0.0, 1, {7: [0]}),
 'Observation_6': (0.02, 1, {8: [0]}),
 'Observation_7': (0.02, 1, {9: [0]}),
 'Observation_8': (0.01, 1, {10: [0]}),
 'Observation_9': (0.01, 1, {10: [0]}),
 'Observation_10': (0.01, 1, {9: [0]}),
 'Observation_11': (0.01, 1, {8: [0]}),
 'Observation_12': (0.01, 1, {9: [0]}),
 'Observation_13': (0.03, 1, {10: [0]}),
 'Observation_14': (0.02, 1, {10: [0]}),
 'Observation_15': (0.01, 1, {9: [0]}),
 'Observation_16': (0.01, 1, {11: [0]})}

In [10]:

def get_pointer_to_name_and_type_map(path="/tmp/pointer_numbering.csv"):
    """Read the pointer-numbering CSV into ``{numbering: (short_name, var_type)}``.

    Each non-blank line is split on "," into type, full variable name and
    numbering. The short name is the part of the full name after the last
    "::" with anything from "!" onwards removed, e.g.
    ``histogram::1::2::i!0@1`` gives ``i``.

    path: CSV file to read (default: the path used previously).
    Raises IndexError / ValueError on lines with fewer than three fields or a
    non-integer numbering.
    """
    mapping = {}
    # Context manager so the handle is always closed.
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no entry.
                continue
            splits = line.split(",")
            var_type = splits[0]
            full_var = splits[1]
            numbering = splits[2]
            #histogram::1::2::i!0@1
            #histogram$$1$$2$$1$$i
            #histogram$$1$$2$$1$$1$$t
            mapping[int(numbering)] = (full_var.split("::")[-1].split("!")[0], var_type)
    return mapping

def is_pointer_or_array(var_type):
    """Return True when the type string ``var_type`` contains two or more '[' characters."""
    # Disabled variant that also accepted pointer types:
    #   "*" in var_type or var_type.count('[') >= 2
    bracket_count = var_type.count('[')
    return bracket_count > 1

# Emit two C macros per observation into /tmp/ds_of_obsvs.txt:
#   ds_<index>      - array literal of the byte addresses the observation may take
#   ds_size_<index> - number of entries in that array
pointer_mapping = get_pointer_to_name_and_type_map()
# `with` guarantees the file is closed even if a base id is missing from pointer_mapping.
with open("/tmp/ds_of_obsvs.txt", "w") as ds_file:
    for obsv, info in infos.items():
        index = str(obsv).split("_")[1]
        memory_locs = []
        for (base, offsets) in info[2].items():
            if (base == -1):
                # decoy access is not part of ds
                continue
            var_name, var_type = pointer_mapping[base]
            # The name is used as-is when is_pointer_or_array accepts its type; otherwise its address is taken.
            prefix = "" if is_pointer_or_array(var_type) else "&"
            for offset in offsets:
                memory_locs.append("((char*){}{})+{}".format(prefix, var_name, offset))
        ds_size = len(memory_locs)
        ds_initializer = "{" + ",".join(memory_locs) + "}"
        ds_file.write("#define ds_{} (void* [{}]){}\n".format(index, ds_size, ds_initializer))
        ds_file.write("#define ds_size_{} {}\n".format(index, ds_size))

In [9]:
# Time one satisfiability check on the notebook-level `solver`.
# NOTE(review): `solver` is reset/rebound by other cells and `time` is imported
# in a different cell, so what is measured depends on execution order.
start_time = time.time()
solver.check()
elapsed_time = time.time() - start_time
print("TIME: " + str(elapsed_time))

TIME: 108.817866563797


In [15]:
# Build a solver from the tactic pipeline simplify (mul2concat) -> solve-eqs ->
# bit-blast -> aig -> sat and time it on the current solver assertions.
# NOTE(review): the recorded output of this cell begins with "failed to solve".
bv_solver = Then(With('simplify', mul2concat=True),
                 'solve-eqs', 
                 'bit-blast', 
                 'aig',
                 'sat').solver()
start_time = time.time()
solve_using(bv_solver, solver.assertions())
elapsed_time = time.time() - start_time
print("TIME: " + str(elapsed_time))

failed to solve
[array.3 = K(BitVec(32), 0),
 array.2 = K(BitVec(32), 0),
 array.1 = K(BitVec(32), 0),
 array.4 = K(BitVec(32), 0),
 array.0 = K(BitVec(32), 0),
 main::1::key!0@1#1[[C]] = 0,
 main::1::key!0@1#1[[7]] = 0,
 main::1::key!0@1#1[[E]] = 0,
 main::1::key!0@1#1[[F]] = 0,
 main::1::key!0@1#1[[2]] = 0,
 main::1::key!0@1#1[[4]] = 0,
 main::1::key!0@1#1[[8]] = 0,
 main::1::key!0@1#1[[0]] = 0,
 main::1::key!0@1#1[[D]] = 0,
 main::1::key!0@1#1[[5]] = 0,
 main::1::key!0@1#1[[3]] = 0,
 main::1::key!0@1#1[[1]] = 0,
 main::1::key!0@1#1[[B]] = 0,
 main::1::key!0@1#1[[9]] = 0,
 main::1::key!0@1#1[[A]] = 0,
 main::1::key!0@1#1[[6]] = 0,
 Observation_1359 = 76800,
 Observation_1358 = 16399,
 Observation_1357 = 15360,
 Observation_1356 = 24576,
 Observation_1355 = 16398,
 Observation_1354 = 15360,
 Observation_1353 = 24576,
 Observation_1352 = 16397,
 Observation_1351 = 15360,
 Observation_1350 = 24576,
 Observation_1349 = 16396,
 Observation_1348 = 15360,
 Observation_1347 = 24576,
 Observa

In [13]:
# Number of assertions currently held by the notebook-level solver.
len(solver.assertions())

7873

In [91]:
# Show the first ten `obsv != m[obsv]` expressions.
# NOTE(review): no visible cell assigns `m` at notebook level, and the cell that
# builds `obsvs` fills it with name strings — this cell presumably ran against
# state from cells that no longer exist; confirm before relying on it.
print([obsv != m[obsv] for obsv in obsvs][:10])

[2048 != Observation_0, 3072 != Observation_1, 5120 != Observation_2, 7168 != Observation_3, 7168 != Observation_4, 3072 != Observation_5, 4096 != Observation_6, 8192 != Observation_9, 8192 != Observation_10, 3072 != Observation_11]


In [15]:
# Toy goal over two 8-bit vectors, bit-blasted to CNF.
t = Then('simplify', 'bit-blast', 'tseitin-cnf')
g = Goal()

a, b = BitVecs('a b', 8)

g.add(Or(a == 1, a == 0))
g.add(Implies(a == 0, And(0 <= b, b < 8)))
# bitmap[(vector, i)] is a fresh Boolean constrained to equal bit i of the vector.
bitmap = {}
for ob in [a,b]:
    for i in range(ob.size()):
        # The "_bit" separator keeps names unambiguous (plain concatenation
        # gives the same text for "x1" bit 10 and "x11" bit 0).
        bitmap[(ob, i)] = Bool("{}_bit{}".format(ob, i))
        # 1 << i is the exact integer mask; avoids depending on `math`, which
        # no visible cell imports.
        mask = BitVecSort(ob.size()).cast(1 << i)
        g.add(bitmap[(ob, i)] == ((ob & mask) == mask))
sg = t(g)


In [20]:
# Toy query on two 10-bit vectors: x > 0, y > 0 and x + y == 512.
# NOTE(review): z3py's `>` on bit-vectors is presumably the signed comparison,
# and 512 does not fit in 10 signed bits — confirm this is the intended query.
x, y = BitVecs('x y', 10)
s = Solver()
s.add(x > 0, y > 0, x + y == 512)
s.check()


In [21]:
# Dump the solver `s` as SMT-LIB2 text to /tmp/tmp.smt2.
with open("/tmp/tmp.smt2", mode='w') as f:
    f.write(s.to_smt2())

In [6]:
# Inline DIMACS export of the bit-blasted toy goal; the same logic exists as
# generate_cnf in another cell.
# NOTE(review): depends on `g`, `sg`, `a`, `b`, `bitmap` and `var_set_and_map`
# from other cells; the recorded output is a NameError for `g`, i.e. it was run
# on a kernel where those names did not exist.
list_of_var_set = [get_vars(a) for a in g]
list_of_vars = set().union(*list_of_var_set)


goal = sg[0]
var_set, var_map = var_set_and_map(goal)
# DIMACS header: number of variables and number of clauses.
cnf_str = "p cnf {} {}\n".format(len(var_set), len(goal))

# "c ind" line: DIMACS ids of the Booleans standing for the bits of a and b.
sampling_set = []
for ob in [a,b]:
    for i in range(ob.size()):
        sampling_set.append(var_map[bitmap[(ob, i)]])
cnf_str += "c ind {} 0\n".format(" ".join(sampling_set))


# One output line per goal entry: space-separated literals terminated by 0.
for i in range(len(goal)):
    if (is_or(goal[i])):
        for j in range(goal[i].num_args()):
            assert(goal[i].arg(j).num_args() <= 1)
            lit = goal[i].arg(j)
            if is_not(lit):
                cnf_str += "-" + var_map[lit.arg(0)]
            else:
                cnf_str += var_map[lit]
            cnf_str += " "
    else:
        # this clause is just a literal
        assert(goal[i].num_args() <= 1)
        lit = goal[i]
        if is_not(lit):
            cnf_str += "-" + var_map[lit.arg(0)]
        else:
            cnf_str += var_map[lit]
        cnf_str += " "
    cnf_str += "0\n"
    
# NOTE(review): absolute, user-specific output path.
with open("/home/cream/toy/cnf.dimacs", "w") as f:
    f.write(cnf_str)

NameError: name 'g' is not defined

In [11]:
# count num of variables
def var_set_and_map(goal):
    """Collect the variables of ``goal`` and number them from 1.

    Returns ``(var_set, var_map)``: the union of ``get_vars`` over every entry
    of ``goal``, and a dict from each variable to its 1-based position (as a
    string) in the iteration order of that set.
    """
    vars_per_entry = [get_vars(entry) for entry in goal]
    var_set = set().union(*vars_per_entry)

    # Number the variables in the set's iteration order, starting at 1.
    var_map = {variable: str(position) for position, variable in enumerate(var_set, start=1)}

    return var_set, var_map

def generate_cnf(goal, obs, bit_map=None):
    """Render a CNF goal as DIMACS text with a ``c ind`` sampling-set line.

    goal: indexable sequence of clauses; each clause is an ``or`` of literals
        or a single literal, a literal being a Boolean constant or its negation.
    obs: bit-vector expressions whose bits form the sampling set.
    bit_map: dict ``(expression, bit index) -> Boolean constant``. Defaults to
        the notebook-level ``bitmap`` so existing two-argument calls still work.

    Returns the DIMACS text: a ``p cnf`` header, one ``c ind ... 0`` line, then
    one line per clause. Raises AssertionError if a literal has more than one
    argument.
    """
    if bit_map is None:
        bit_map = bitmap

    var_set, var_map = var_set_and_map(goal)

    def dimacs_literal(lit):
        # A literal is either a constant or `not` applied to one.
        assert(lit.num_args() <= 1)
        if is_not(lit):
            return "-" + var_map[lit.arg(0)]
        return var_map[lit]

    # Pieces are joined once at the end instead of growing a string.
    parts = ["p cnf {} {}\n".format(len(var_set), len(goal))]

    # gather sampling set(bits of the interested values)
    sampling_set = [var_map[bit_map[(ob, i)]] for ob in obs for i in range(ob.size())]
    parts.append("c ind {} 0\n".format(" ".join(sampling_set)))

    for i in range(len(goal)):
        clause = goal[i]
        if is_or(clause):
            literals = [clause.arg(j) for j in range(clause.num_args())]
        else:
            # this clause is just a literal (the debug print that used to be
            # here was dropped)
            literals = [clause]
        for lit in literals:
            parts.append(dimacs_literal(lit))
            parts.append(" ")
        parts.append("0\n")
    return "".join(parts)

In [None]:
# Write the DIMACS text in `cnf_str` to disk.
# NOTE(review): `cnf_str` must already exist from another cell; the path is
# absolute and user-specific.
with open("/home/cream/play/cnf", "w") as f:
    f.write(cnf_str)

In [None]:
# Toy model: secret k, public p, base addresses A and B, observations o0..o4.
# NOTE(review): three alternative observation models are added to the SAME goal
# (o0, o2 and o3 are each constrained several times) — presumably only one of
# them was meant to be active at a time; confirm.
t = Then('simplify', 'bit-blast', 'tseitin-cnf')
g = Goal()

# 2-bit tags for the branch observations.
F = BitVecVal(0, 2)
T = BitVecVal(1, 2)
N = BitVecVal(2, 2)

o0, o1 = BitVecs('o0 o1', 2)
# These are declared 32 bits wide (the earlier remark about using 4 bits
# instead of 3 no longer matches this declaration).
A, B, k, p, o2, o3, o4 = BitVecs('A B k p o2 o3 o4', 32)

obs = [o0, o2, o3]

# secret value
g.add(0 <= k, k <= 2**3)
# public value
g.add(p == 0)
# memory allocation
g.add(A == 0, B == 0)

g.add(o0 == If(k < 4, T, F))
g.add(o1 == If(k < 4, If(k < 2, T, F), N))
g.add(o2 == If(k < 4, If(k < 2, B + p, B + p), B + p))
g.add(o3 == If(k < 4, If(k < 2, A + k / 3, A + k % 2), A + k / 4))


g.add(o0 == If(k < 4, T, F))
g.add(o2 == If(k < 4, A + k % 2, A + k /2))
g.add(o3 == B + k / 5)

# Rebinds `obs`; this is the list used for the bitmap below.
obs = [o0, o2, o3, o4]
g.add(o0 == If(k < 2**19, T, N))
g.add(o2 == If(k < 2**19, A + k % 2, A + k /2))
g.add(o3 == If(k < 2**19, A + p, 2**20+1))
g.add(o4 == B + k / 5)


# bitmap[(vector, i)] is a fresh Boolean constrained to equal bit i of the vector.
bitmap = {}
for ob in obs:
    for i in range(ob.size()):
        # The "_bit" separator keeps names unambiguous (plain concatenation
        # gives the same text for "x1" bit 10 and "x11" bit 0).
        bitmap[(ob, i)] = Bool("{}_bit{}".format(ob, i))
        # 1 << i is the exact integer mask; avoids depending on `math`, which
        # no visible cell imports.
        mask = BitVecSort(ob.size()).cast(1 << i)
        g.add(bitmap[(ob, i)] == ((ob & mask) == mask))

sg = t(g)
print(g)

In [None]:
# Toy example with an uninterpreted function f: Int -> Int: check
# f(f(x)) == x, f(x) == y, x != y, then print the model and two evaluated terms.
x = Int('x')
y = Int('y')
f = Function('f', IntSort(), IntSort())
s = Solver()
s.add(f(f(x)) == x, f(x) == y, x != y)
print(s.check())
# NOTE(review): model() is called without checking that the result above is sat.
m = s.model()
print(m)
print("f(f(x)) =", m.evaluate(f(f(x))))
print("f(x)    =", m.evaluate(f(x)))

In [25]:
# Load a second SMT-LIB2 file into a fresh notebook-level solver; this rebinds
# the globals `solver` and `constraints` used by other cells.
# NOTE(review): uses the qualified name `z3`, but the only visible import is
# `from z3 import *` — confirm `z3` is bound. The input path is absolute and
# user-specific.
solver = z3.Solver()
solver.reset()

constraints = z3.parse_smt2_file("/home/cream/src/smtapproxmc/temp_amc/temp_0_0.smt2", sorts={}, decls={})
solver.add(constraints)
solver.check()

# remove variables from CPROVER initializer
# (drops equalities whose left-hand side prints with a "__CPROVER" prefix; the
# solver above still holds the unfiltered assertions)
constraints = list(filter(lambda c: not(is_eq(c) and str(c.arg(0)).startswith("__CPROVER")), constraints))
print("\n------\n".join([str(c) for c in constraints]))

array.0[0] == 99
------
array.0[1] == 124
------
array.0[2] == 119
------
array.0[3] == 123
------
array.0[4] == 242
------
array.0[5] == 107
------
array.0[6] == 111
------
array.0[7] == 197
------
array.0[8] == 48
------
array.0[9] == 1
------
array.0[10] == 103
------
array.0[11] == 43
------
array.0[12] == 254
------
array.0[13] == 215
------
array.0[14] == 171
------
array.0[15] == 118
------
array.0[16] == 202
------
array.0[17] == 130
------
array.0[18] == 201
------
array.0[19] == 125
------
array.0[20] == 250
------
array.0[21] == 89
------
array.0[22] == 71
------
array.0[23] == 240
------
array.0[24] == 173
------
array.0[25] == 212
------
array.0[26] == 162
------
array.0[27] == 175
------
array.0[28] == 156
------
array.0[29] == 164
------
array.0[30] == 114
------
array.0[31] == 192
------
array.0[32] == 183
------
array.0[33] == 253
------
array.0[34] == 147
------
array.0[35] == 38
------
array.0[36] == 54
------
array.0[37] == 63
------
array.0[38] == 247
------
array.

In [26]:
# Re-run the satisfiability check on the notebook-level solver.
solver.check()

In [None]:
# Reload the filtered constraints and ask whether `o` can differ from 7168.
# NOTE(review): `o` is only assigned in commented-out code, so on a fresh
# kernel this cell raises NameError — confirm where `o` should come from.
solver.reset()

solver.add(constraints)
solver.add(o != 7168)
solver.check()

# m = solver.model()
# list_of_var_set = [get_vars(a) for a in constraints]

# list_of_vars = set().union(*list_of_var_set)

# list_of_obsvs = [v for v in list_of_vars if str(v).startswith("Observation_")]
# # list_of_obsvs = [v for v in list_of_vars if str(v).startswith("function#return_value!0#1")]

# o = [o for o in list_of_obsvs if str(o) == "Observation_2"][0]
# m.evaluate(o)

In [None]:
# Collect the Z3 constants (not just their names) occurring in `constraints`;
# this rebinds `list_of_var_set` / `list_of_vars`, which another cell fills
# with name strings.
list_of_var_set = [get_vars(a) for a in constraints]

list_of_vars = set().union(*list_of_var_set)

# Observation variables are those whose printed name starts with "Observation_".
list_of_obsvs = [v for v in list_of_vars if str(v).startswith("Observation_")]
# list_of_obsvs = [v for v in list_of_vars if str(v).startswith("function#return_value!0#1")]

list_of_obsvs

In [None]:
# Number of observation variables found.
len(list_of_obsvs)

In [None]:
def transitively_select_constraints(constraints, interested_vars):
    """Select the constraints transitively connected to ``interested_vars``.

    A constraint is selected when it mentions a variable that has already been
    reached; all of its variables then become reached too. This repeats until a
    pass selects nothing new. Variables and constraints are compared by Z3 AST
    id (``get_id()``).

    constraints: iterable of Z3 expressions.
    interested_vars: iterable of Z3 constants to start from.
    Returns the set of selected constraints.

    The previous version overwrote both parameters with the notebook globals
    ``list_of_obsvs`` and ``constraints_back``; the arguments are now used. Its
    final prints of every variable and constraint were dropped.
    """
    # Constraint id -> (constraint, ids of the variables it mentions).
    # Keying by id also removes duplicate constraints.
    pending = {}
    for constraint in constraints:
        pending[constraint.get_id()] = (
            constraint,
            {v.get_id() for v in get_vars(constraint)},
        )

    reached_var_ids = {v.get_id() for v in interested_vars}
    selected = set()

    progress = True
    while progress:
        # Decided against the variables reached before this pass, then applied.
        hits = [
            constraint_id
            for constraint_id, (_, var_ids) in pending.items()
            if not var_ids.isdisjoint(reached_var_ids)
        ]
        progress = bool(hits)
        for constraint_id in hits:
            constraint, var_ids = pending.pop(constraint_id)
            reached_var_ids |= var_ids
            selected.add(constraint)
    return selected

# Slice `constraints` down to those connected to the observation variables.
selected_constraints = transitively_select_constraints(constraints, list_of_obsvs)

In [None]:
# Number of constraints kept by the transitive selection.
len(selected_constraints)

In [None]:
# Bit-blast the filtered constraints to CNF and export them as DIMACS text,
# with the bits of every observation variable as the sampling set.
t = Then('simplify', 'bit-blast', 'tseitin-cnf')
g = Goal()
g.add(constraints)

# bitmap[(vector, i)] is a fresh Boolean constrained to equal bit i of the vector.
bitmap = {}
for ob in list_of_obsvs:
    for i in range(ob.size()):
        # The "_bit" separator is required: with plain concatenation bit 10 of
        # Observation_1 and bit 0 of Observation_11 would both be named
        # "Observation_110" and become one and the same Boolean.
        bitmap[(ob, i)] = Bool("{}_bit{}".format(ob, i))
        # 1 << i is the exact integer mask; avoids depending on `math`, which
        # no visible cell imports.
        mask = BitVecSort(ob.size()).cast(1 << i)
        g.add(bitmap[(ob, i)] == ((ob & mask) == mask))

sg = t(g)
# print(sg[0])
cnf_str = generate_cnf(sg[0], list_of_obsvs)
cnf_str
# def equality_substitution(old_goal, new_goal):
#     e = old_goal.as_expr()
#     old_constraints = [e.arg(i) for i in range(e.num_args())]
#     equality_constraints = [c for c in old_constraints if (not is_or(c)) and c.num_args() != 1]
#     subs = []
#     for cons in equality_constraints:
# #         print(cons)
#         assert(cons.num_args() == 0 or is_eq(cons))
#         if is_eq(cons):
#             subs.append((cons.arg(0), cons.arg(1)))
#     print(subs)
#     for cons in old_constraints:
#         new_goal.add(substitute(cons, subs))
        
# # Remove equality of arrays, which can be generated from memcpy
# new_goal = Goal()
# equality_substitution(sg[0], new_goal)
# sim_sg = Tactic('simplify')(new_goal)
# cnf_str = generate_cnf(sim_sg[0], list_of_obsvs)

In [None]:
# Write the DIMACS text in `cnf_str` to disk.
# NOTE(review): the path is absolute and user-specific, and the file is named
# ".smt" although `cnf_str` comes from generate_cnf — confirm the extension.
with open("/home/cream/toy/toy.smt", "w") as f:
    f.write(cnf_str)