# Sanity Checks

In [1]:
import sys, os, re
import numpy as np
from IPython.display import display
from maraboupy import Marabou, MarabouCore
from tot_net import TOTNet

In [2]:
# Reference input vector (25 features) fed to the network in the sanity
# checks of this notebook.
sample_input = np.array([
    0.2622614500364243, -0.22479658685112885, -0.14872559153186365, -0.1183740939370412, -0.11582643087693376,
    0.4638055359970424, 0.5343624908151317, -1.8587238209603285, -1.925346908225373, 0.7745258391529399,
    1.691414092773877, -1.1513317010900714, 0.15687995017838718, -0.4932713694542802, 0.3758603164252629,
    -0.6443093237643779, 0.5198151301903109, -0.1392841543884293, -0.17885723928180608, -0.7026510235427964,
    -0.04793722518544871, -0.4594806526602232, 0.18987588687321624, 0.4657886911920623, -0.1626985247701109,
])

# Expected label for the reference input, one entry per output class
# (class 0 is the expected winner).
sample_output = [1, 0, 0]

# Relative path of the .nnet model file that every check loads.
nnet_file = '../artifacts/models/model-v3.1.3/model.nnet'

In [3]:
# 0. Sanity test 0: predict known input using the network.
# Load the .nnet file with Marabou, evaluate it on the reference input and
# check that the highest-scoring output is the class marked in `sample_output`.
test_net = Marabou.read_nnet(nnet_file)
output = test_net.evaluate([sample_input])[0]

# Derive the expected class from the declared label instead of hard-coding 0,
# so this check follows `sample_output` if the reference sample is changed.
expected_class = int(np.argmax(sample_output))
predicted_class = int(np.argmax(output))
assert predicted_class == expected_class, (
    f'Expected class {expected_class}, but the network predicted class {predicted_class}: {output}'
)

print(f'Correct. {output}') # raw output scores; their arg-max matches sample_output

Correct. [ 3.42226921  0.74396088 -0.96801591]


In [4]:
# 1. Sanity test 1 (correct class y0, should be SAT)
# The same vector is used for the lower and the upper input bounds, which
# presumably pins every input to the reference sample -- verify against TOTNet.
net = TOTNet(nnet_file)
net.set_lower_bounds(sample_input)
net.set_upper_bounds(sample_input)
# Presumably encodes +y1 -y0 <= 0, +y2 -y0 <= 0 (y0 is the largest output)
# -- TODO confirm against TOTNet.set_expected_category.
net.set_expected_category(0)
vals, stats = net.solve()
# This notebook treats a non-empty `vals` as SAT and an empty one as UNSAT.
assert len(vals) > 0, 'Expected SAT for the correct class 0, but solve() returned no assignment'
print('SAT: ', vals)

SAT:  ['input 0 = 0.2622614500364243', 'input 1 = -0.22479658685112885', 'input 2 = -0.14872559153186365', 'input 3 = -0.1183740939370412', 'input 4 = -0.11582643087693376', 'input 5 = 0.4638055359970424', 'input 6 = 0.5343624908151317', 'input 7 = -1.8587238209603285', 'input 8 = -1.925346908225373', 'input 9 = 0.7745258391529399', 'input 10 = 1.691414092773877', 'input 11 = -1.1513317010900714', 'input 12 = 0.15687995017838718', 'input 13 = -0.4932713694542802', 'input 14 = 0.3758603164252629', 'input 15 = -0.6443093237643779', 'input 16 = 0.5198151301903109', 'input 17 = -0.1392841543884293', 'input 18 = -0.17885723928180608', 'input 19 = -0.7026510235427964', 'input 20 = -0.04793722518544871', 'input 21 = -0.4594806526602232', 'input 22 = 0.18987588687321624', 'input 23 = 0.4657886911920623', 'input 24 = -0.1626985247701109', 'output 0 = 3.42226921215419', 'output 1 = 0.7439608810539671', 'output 2 = -0.968015911041988']


In [5]:
# 2. Sanity test 2 (incorrect class y1, should be UNSAT)
# The same vector is used for the lower and the upper input bounds, which
# presumably pins every input to the reference sample -- verify against TOTNet.
net = TOTNet(nnet_file)
net.set_lower_bounds(sample_input)
net.set_upper_bounds(sample_input)
# Presumably encodes +y0 -y1 <= 0, +y2 -y1 <= 0 (y1 is the largest output)
# -- TODO confirm against TOTNet.set_expected_category.
net.set_expected_category(1)
vals, _ = net.solve()
# This notebook treats an empty `vals` as UNSAT.
assert len(vals) == 0, f'Expected UNSAT for the incorrect class 1, but solve() returned: {vals}'
print('UNSAT', vals)

UNSAT []


In [6]:
# 3. Sanity test 3 (incorrect class y2, should be UNSAT)
# Fresh TOTNet instance with the reference sample used as both the lower and
# the upper input bound.
net = TOTNet(nnet_file)
net.set_lower_bounds(sample_input)
net.set_upper_bounds(sample_input)

# Query for class 2; per the original note this corresponds to
# +y0 -y2 <= 0, +y1 -y2 <= 0.
net.set_expected_category(2)

# Only the assignment list matters here; the second return value is discarded.
vals, _ = net.solve()
assert len(vals) == 0
print('UNSAT', vals)

UNSAT []
