In [1]:
# Imports
import random

from aalpy.learning_algs import run_Lstar
from aalpy.oracles import StatePrefixEqOracle

from DataProcessing import get_coffee_machine, generate_data_from_automaton, split_train_validation, tokenized_dict
from RNNClassifier import RNNClassifier
from RNN_SULs import RnnMealySUL

In [2]:
# Reference model: the coffee-machine Mealy machine whose behaviour the RNN is trained to imitate.
coffee_machine = get_coffee_machine()

# Output alphabet = every output symbol that appears on any transition of any state.
output_al = {out for st in coffee_machine.states for out in st.output_fun.values()}
input_al = coffee_machine.get_input_alphabet()

# tokenized_dict maps each output symbol to an integer token; invert it so that an
# integer token can be translated back into its output symbol (not a one-hot encoding).
outputs_2_ints = {token: out for out, token in tokenized_dict(output_al).items()}


In [3]:
# Create the dataset: 15000 input sequences sampled from the coffee machine, with sequence
# lengths taken from `lens`.
# Seed Python's global RNG so that Restart & Run All reproduces the same dataset and split.
# NOTE(review): assumes generate_data_from_automaton / split_train_validation sample via the
# `random` module — confirm in DataProcessing; the RNN backend may need its own seed as well.
SEED = 42
random.seed(SEED)

train_seq, train_labels = generate_data_from_automaton(coffee_machine, input_al,
                                                       num_examples=15000, lens=(1, 2, 3, 5, 8, 12, 15))

# Split into training (80%) and validation (20%) datasets; x_test/y_test hold the validation part.
x_train, y_train, x_test, y_test = split_train_validation(train_seq, train_labels, 0.8, uniform=True)


In [4]:
# Build the network to be learned from: an RNN of GRU cells with 2 hidden layers of
# 50 units each, classifying into one of len(output_al) output classes.
rnn = RNNClassifier(
    input_al,
    output_dim=len(output_al),
    nn_type='GRU',
    num_layers=2,
    hidden_dim=50,
    batch_size=32,
    x_train=x_train,
    y_train=y_train,
    x_test=x_test,
    y_test=y_test,
)

# Train for at most 250 epochs, presumably stopping early once accuracy reaches 1.0.
# NOTE(review): stop_epochs=3 presumably requires the 1.0 target to hold for 3 epochs
# before stopping — confirm in RNNClassifier.train.
rnn.train(stop_acc=1.0, stop_epochs=3, epochs=250)

Starting train
Epoch 0: Accuracy 0.97617, Avg. Loss 6.01565 Validation Accuracy 0.97934
Epoch 1: Accuracy 0.97809, Avg. Loss 2.4939 Validation Accuracy 0.98201
Epoch 2: Accuracy 0.99225, Avg. Loss 1.42906 Validation Accuracy 0.99367
Epoch 3: Accuracy 0.99583, Avg. Loss 0.60914 Validation Accuracy 0.99667
Epoch 4: Accuracy 0.99917, Avg. Loss 0.258 Validation Accuracy 0.99967
Epoch 5: Accuracy 0.99983, Avg. Loss 0.07625 Validation Accuracy 0.99967
Epoch 6: Accuracy 0.99992, Avg. Loss 0.0386 Validation Accuracy 0.99967
Epoch 7: Accuracy 0.99983, Avg. Loss 0.03375 Validation Accuracy 1.0
Epoch 8: Accuracy 1.0, Avg. Loss 0.02833 Validation Accuracy 0.99967
Epoch 9: Accuracy 0.99675, Avg. Loss 0.11456 Validation Accuracy 0.99767
Epoch 10: Accuracy 0.99983, Avg. Loss 0.01049 Validation Accuracy 1.0
Epoch 11: Accuracy 0.99967, Avg. Loss 0.23278 Validation Accuracy 0.99967
Epoch 12: Accuracy 1.0, Avg. Loss 0.00494 Validation Accuracy 1.0
Epoch 13: Accuracy 1.0, Avg. Loss 0.00378 Validation Accu

In [5]:
# Wrap the trained RNN as the system under learning (SUL); outputs_2_ints presumably
# translates the network's integer predictions back into output symbols.
sul = RnnMealySUL(rnn, outputs_2_ints)

# Equivalence oracle: 150 random walks of length 25 per hypothesis state.
eq_oracle = StatePrefixEqOracle(input_al, sul, walk_len=25, walks_per_state=150)

# Learn a Mealy machine from the RNN with L*: the cache / non-determinism check is
# disabled, at most 10 learning rounds run, and print_level=2 prints the summary below.
learned_automaton = run_Lstar(
    alphabet=input_al,
    sul=sul,
    eq_oracle=eq_oracle,
    automaton_type='mealy',
    max_learning_rounds=10,
    cache_and_non_det_check=False,
    suffix_closedness=False,
    print_level=2,
)

Hypothesis 1: 2 states.
Hypothesis 2: 5 states.
Hypothesis 3: 6 states.
-----------------------------------
Learning Finished.
Learning Rounds:  3
Number of states: 6
Time (in seconds)
  Total                : 0.91
  Learning algorithm   : 0.02
  Conformance checking : 0.89
Learning Algorithm
 # Membership Queries  : 154
 # Steps               : 546
Equivalence Query
 # Membership Queries  : 900
 # Steps               : 23656
-----------------------------------


In [6]:
# Show the learned Mealy machine; its string form is a Graphviz DOT description.
print(str(learned_automaton))

digraph learnedModel {
s0 [label=s0];
s1 [label=s1];
s2 [label=s2];
s3 [label=s3];
s4 [label=s4];
s5 [label=s5];
s0 -> s0  [label="clean/check"];
s0 -> s5  [label="pod/check"];
s0 -> s2  [label="water/check"];
s0 -> s1  [label="button/star"];
s1 -> s1  [label="clean/star"];
s1 -> s1  [label="pod/star"];
s1 -> s1  [label="water/star"];
s1 -> s1  [label="button/star"];
s2 -> s0  [label="clean/check"];
s2 -> s3  [label="pod/check"];
s2 -> s2  [label="water/check"];
s2 -> s1  [label="button/star"];
s3 -> s0  [label="clean/check"];
s3 -> s3  [label="pod/check"];
s3 -> s3  [label="water/check"];
s3 -> s4  [label="button/coffee"];
s4 -> s0  [label="clean/check"];
s4 -> s1  [label="pod/coffee"];
s4 -> s1  [label="water/coffee"];
s4 -> s1  [label="button/coffee"];
s5 -> s0  [label="clean/check"];
s5 -> s5  [label="pod/check"];
s5 -> s3  [label="water/check"];
s5 -> s1  [label="button/star"];
__start0 [label="", shape=none];
__start0 -> s0  [label=""];
}

