In [17]:
import time

from aalpy.oracles import TransitionFocusOracle, StatePrefixEqOracle

from Comparison_with_White_Box import Weiss_to_AALpy_DFA_format, train_or_load_rnn
from RNN_SULs import RNN_BinarySUL_for_Weiss_Framework
from Refinement_based_extraction.Extraction import extract
from Refinement_based_extraction.GRU import GRUNetwork


In [7]:
# Train a 2-layer GRU network (hidden size 50) on the 'bp_ex' task (BP = balanced parentheses).
# Returns the trained RNN, its input alphabet and the training set used.
# NOTE(review): train=True retrains and re-saves the model on every run; presumably
# train=False reuses the saved model instead — confirm in train_or_load_rnn.
rnn, alphabet, train_set = train_or_load_rnn(
    'bp_ex',
    num_layers=2,
    hidden_dim=50,
    rnn_class=GRUNetwork,
    train=True,
)


classification loss on last batch was: 0.00018254815475327393
saving to RNN_Models/WeissComparisonModels/bp_ex_GRU_layers_2_dim_50.rnn


In [8]:
# Initial examples for the refinement-based approach: the shortest training word the
# RNN accepts and the shortest one it rejects (each only if such a word exists).
# Iterating the dict yields its keys; sorted() is stable, so equal-length words keep dict order.
all_words = sorted(train_set, key=len)
# Use truthiness rather than `is True` / `is False`: identity checks silently never match
# bool-like results (e.g. numpy.bool_), which would leave starting_examples empty.
pos = next((w for w in all_words if rnn.classify_word(w)), None)
neg = next((w for w in all_words if not rnn.classify_word(w)), None)
# Compare against None explicitly so the empty word '' is kept as a valid example.
starting_examples = [w for w in (pos, neg) if w is not None]

In [9]:
# Extract an automaton from the RNN using the white-box (refinement-based) equivalence query.
# Reset the RNN's internal state so extraction starts from the initial state.
rnn.renew()

# Perform white-box refinement-based extraction and time only the extraction itself.
# perf_counter is monotonic and intended for measuring durations (time.time() can jump).
start_white_box = time.perf_counter()
try:
    # NOTE(review): time_limit is presumably in seconds — confirm in Extraction.extract.
    dfa_weiss = extract(rnn, time_limit=1000, initial_split_depth=10, starting_examples=starting_examples)
    time_white_box = time.perf_counter() - start_white_box
finally:
    # Make sure that internal states are back to initial, even if extraction raises
    rnn.renew()

provided counterexamples are: ['', '(']
obs table refinement took 0.005
guided starting equivalence query for DFA of size 2
split wasn't perfect: gotta start over
returning counterexample of length 2:		)), this counterexample is rejected by the given RNN.
equivalence checking took: 0.41981890000010935
obs table refinement took 0.035
guided starting equivalence query for DFA of size 3
lstar successful: unrolling seems equivalent to proposed automaton
equivalence checking took: 0.7214922999999089
overall guided extraction time took: 1.1856159999999818
generated counterexamples were: (format: (counterexample, counterexample generation time))
('))', 0.41981890000010935)


In [10]:
# Convert the DFA produced by the refinement-based (white-box) extraction into AALpy's
# DFA representation, so AALpy's equivalence oracles can be run against it.
white_box_hyp = Weiss_to_AALpy_DFA_format(dfa_weiss)

# Wrap the RNN as an AALpy system under learning (SUL); the equivalence oracle below
# queries the RNN through this object.
sul = RNN_BinarySUL_for_Weiss_Framework(rnn)

In [11]:
# Black-box equivalence oracle over the RNN's alphabet, querying the RNN via `sul`.
# It is configured with 1000 random walks, each with walk_len=20.
eq_oracle = TransitionFocusOracle(
    alphabet,
    sul,
    num_random_walks=1000,
    walk_len=20,
)

In [14]:
# Try to falsify the model extracted by the refinement-based approach.
# If counterexamples are found, there is a high probability that they are adversarial inputs.
# The oracle is queried several times; each distinct counterexample is printed once,
# together with the time the query that found it took.
NUM_FALSIFICATION_ATTEMPTS = 20

cex_set = set()
print('Time \t Counterexample')
for _ in range(NUM_FALSIFICATION_ATTEMPTS):
    start_time = time.perf_counter()
    cex = eq_oracle.find_cex(white_box_hyp)
    # Measure only the oracle query, not the bookkeeping below.
    elapsed = time.perf_counter() - start_time
    # find_cex returns None when no counterexample was found; tuple(None) would raise.
    if cex is None:
        continue
    cex_key = tuple(cex)
    if cex_key in cex_set:
        continue
    cex_set.add(cex_key)
    print(round(elapsed, 2), '\t', "".join(cex))

Time 	 Counterexample
0.03 	 (gx())
0.1 	 ((()()()))
0.01 	 l(b)o(())
0.14 	 ((()z())(()))
0.02 	 q(()h)
0.01 	 (f())
0.01 	 (((())()))
0.13 	 (d())
0.0 	 (()((ku))()())
0.01 	 (y())
0.04 	 (g(()))
0.06 	 (())
0.04 	 (()f())
0.06 	 k((())uo)
0.03 	 ()u(()w)
0.0 	 w()(((q)a)y)
0.02 	 ()()(())
