In [1]:
from z3 import Solver, sat, Int, If, Bool, Xor, unsat, Not
import subprocess

In [19]:
# Symbolic inputs shared by both key copies.
a = Int('a')
b = Int('b')

# Solver that accumulates every constraint used by the search loop.
S = Solver()

# Two independent key sets: (additive key, multiplicative key, branch-flip key).
k11, k21, k31 = Int('k11'), Int('k21'), Bool('k31')
k12, k22, k32 = Int('k12'), Int('k22'), Bool('k32')

# One symbolic output per key set.
out1 = Int('out1')
out2 = Int('out2')

# Loop bounds read by the search loop further down.
# NOTE(review): `iter` shadows the builtin of the same name; it is kept
# because the loop below refers to it by this name.
limit = 1000
iter = 1

# Each key copy computes  a + k_add  when (a > b) XOR k_flip holds,
# and  b * k_mul  otherwise.
for out, k_add, k_mul, k_flip in ((out1, k11, k21, k31), (out2, k12, k22, k32)):
    S.add(If(Xor(a > b, k_flip), out == a + k_add, out == b * k_mul))

# Miter: require the two copies to disagree on the same (a, b).
S.add(Not(out1 == out2))

# Oracle-guided key search.
#
# Each round asks S for an input pair (a, b) on which the two key copies
# disagree (a DIP), runs the black-box binary ./a.out on that pair, and then
# constrains both key copies to reproduce the observed output.  When S becomes
# unsat no such input is left; the disagreement constraint is then swapped for
# out1 == out2 and a model of the result is printed.
#
# Files written: output.txt (one model per round), assertions.txt and
# final_assertions.txt (solver state when the search stops).

# Rebuilt here so the disagreement constraint can be recognised structurally
# instead of by comparing its printed form against a string literal.
miter = Not(out1 == out2)
key_sets = ((k11, k21, k31), (k12, k22, k32))

# `with` guarantees output.txt is flushed and closed even if a round raises.
with open("output.txt", "w") as file:
    while iter <= limit:
        if S.check() == unsat:
            # No distinguishing input remains.
            print("Unsat")
            with open("assertions.txt", "w") as dump:
                dump.writelines(str(s) + "\n" for s in S.assertions())

            # Same constraints, but with the outputs forced equal instead of
            # forced different.
            final_solver = Solver()
            for s in S.assertions():
                if s.eq(miter):
                    continue
                final_solver.add(s)
            final_solver.add(out1 == out2)

            with open("final_assertions.txt", "w") as dump:
                dump.writelines(str(s) + "\n" for s in final_solver.assertions())

            if final_solver.check() == unsat:
                print("Unsat")
            else:
                print(final_solver.model())
            break

        # DIP exists, so query the black box for its output.
        m = S.model()

        # Record this round's model in output.txt.
        file.write("Iteration: " + str(iter) + "\nModel is: " + str(m) + "\n")

        # model_completion=True yields a concrete value even when the model
        # leaves a or b unassigned (plain m[a] would be None in that case).
        dip_a = m.eval(a, model_completion=True).as_long()
        dip_b = m.eval(b, model_completion=True).as_long()

        print("DIP: ", dip_a, dip_b)

        # Capture stderr and the exit status so that oracle failures are
        # actually detected; without a stderr pipe the error value is always
        # None and the check can never fire.
        result = subprocess.run(
            ['./a.out', str(dip_a), str(dip_b)],
            capture_output=True,
            text=True,
        )
        if result.returncode != 0 or result.stderr:
            print("Error occurred: ",
                  result.stderr or "exit status " + str(result.returncode))
            break

        dip_out = int(result.stdout)
        print("Output: ", dip_out)

        # Eliminate every key (in both copies) that disagrees with the oracle
        # on this input.
        for k_add, k_mul, k_flip in key_sets:
            S.add(If(Xor(dip_a > dip_b, k_flip),
                     dip_out == dip_a + k_add,
                     dip_out == dip_b * k_mul))

        iter = iter + 1
    else:
        # Runs only when the loop condition became false, i.e. all `limit`
        # rounds were used without hitting a `break`.
        print("Limit reached")

DIP:  0 0
Output:  0
DIP:  2 1
Output:  7
DIP:  3 2
Output:  8
DIP:  5 5
Output:  15
Unsat
[k32 = False,
 k31 = False,
 out2 = 5,
 b = -10,
 k12 = 5,
 k11 = 5,
 out1 = 5,
 a = 0,
 k21 = 3,
 k22 = 3]
