In [40]:
import json
from typing import Tuple, List
import numpy as np
from maraboupy import Marabou, MarabouCore, MarabouUtils
import logging
import pandas

In [41]:
PATH = './mnist-net_256x4.onnx'

In [42]:
EPSILON = 0.5 
PATTERN_PATH = './mnist_relu_patterns_0.json'
MAX_TIME = 300 #in seconds
M_OPTIONS: MarabouCore.Options = Marabou.createOptions(verbosity=0, numWorkers=10, timeoutInSeconds=MAX_TIME)

In [43]:
with open(PATTERN_PATH, "r") as f:
    STABLE_PATTERNS = json.load(f)

In [46]:
set(STABLE_PATTERNS['0']['val'])

{0, 5923}

In [21]:
loc = 0.5
radus = 0.5
non_restricted_dim = []

In [22]:
def init_network()->Marabou.MarabouNetworkNNet:
    network:Marabou.MarabouNetworkNNet = Marabou.read_onnx(PATH)

    # print("output nodes:", network.outputVars)
    # print("input nodes:", network.inputVars)

    logging.debug("relu list:")
    for r in network.reluList:
        logging.debug(r)
    for i in range(784):
        network.setLowerBound(i, 0.0)
        network.setUpperBound(i, 1.0)


        # if i in non_restricted_dim:
        #     network.setLowerBound(i, loc - radus)
        #     network.setUpperBound(i, loc + radus)
        # else:
        #     network.setLowerBound(i, 0.00)
        #     network.setUpperBound(i, 0.00)           
  
    return network

In [23]:
def parse_raw_idx(raw_idx: int) -> Tuple[int, int, int]:
    """
    only for MNIST 256xk network:
    """
    n_relus = 256
    offset = 28*28+10
    layer = raw_idx // n_relus
    idx = raw_idx % n_relus
    marabou_idx = 2*n_relus*layer + idx + offset
    return layer, idx, marabou_idx

In [24]:
def add_relu_constraints(network: Marabou.MarabouNetworkNNet, 
                        relu_check_list: List[int], 
                        relu_val: List[int])->Marabou.MarabouNetworkNNet:
    """
    Add stable relus constraints to the Marabou network
    """
    for i in range(len(relu_check_list)):
        layer, idx, marabou_idx = parse_raw_idx(relu_check_list[i])     
        if relu_val[i] < 2500:        
#        if relu_val[i] == 0:
            constraint = MarabouUtils.Equation(MarabouCore.Equation.LE)
            constraint.addAddend(1, marabou_idx)
            constraint.setScalar(-0.001)
        else:
            constraint = MarabouUtils.Equation(MarabouCore.Equation.GE)
            constraint.addAddend(1, marabou_idx)
            constraint.setScalar(0.001)
        network.addEquation(constraint)

    return network

In [25]:
def find_one_assignment(relu_check_list: List[int], relu_val: List[int])->None:
    network = init_network()
    network = add_relu_constraints(network, relu_check_list, relu_val)    
    exitCode, vals, stats = network.solve()
    assert(exitCode=="sat")    
    for idx, r in enumerate(relu_check_list):
        marabou_idx = parse_raw_idx(r)[-1]
#        print(marabou_idx, vals[marabou_idx], relu_val[idx])

In [26]:
def check_pattern(relu_check_list: List[int], relu_val: List[int], label: int, other_label: int)->Tuple[str, int]:
    """
    In ACAS, the prediction is the label with smallest value.
    So we check that label - other_label < 0 forall input
    by finding assignments for label - other_label >=0
    """
    print("--------CHECK PATTERN: output_{} is always less than output_{} ? --------".format(label, other_label))
    network = init_network()
    network = add_relu_constraints(network, relu_check_list, relu_val)    
    offset = network.outputVars[0][0][0]

    #print(network.numVars)
    #print(network.reluList)
    #print(network.outputVars)

    #add output constraint
    constraint = MarabouUtils.Equation(MarabouCore.Equation.GE)
    constraint.addAddend(1, other_label+offset)
    constraint.addAddend(-1, label+offset)
    constraint.setScalar(0.001)
    network.addEquation(constraint)

    exit_code: str    
    exit_code, vals, stats = network.solve(options=M_OPTIONS)
    running_time:int = stats.getTotalTimeInMicro()
    for idx, r in enumerate(relu_check_list):
        marabou_idx = parse_raw_idx(r)[-1]

    return exit_code, running_time

In [27]:
def main():
    res = [[-1.]*10 for i in range(10)]
    # print(res)
    for label in STABLE_PATTERNS:
        print(f"For label {label}, check if its stable RELU pattern guarantees the output")
        for other_label in range(10):
            if other_label == int(label):
                continue
            relu_check_list = STABLE_PATTERNS[label]["stable_idx"]
            relu_val = STABLE_PATTERNS[label]["val"] 
            exit_code, running_time = check_pattern(relu_check_list, relu_val, label=int(label), other_label = other_label)
            if exit_code=="sat":
                res[int(label)][other_label] = "SAT:{}".format(running_time/10**6)
                break
            elif exit_code=="unsat":
                res[int(label)][other_label] = "UNS:{}".format(running_time/10**6)
            else:
                res[int(label)][other_label] = exit_code

    res = pandas.DataFrame(res)
    print(res)

In [39]:
set(STABLE_PATTERNS['8']['val'])

{0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 41,
 42,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 55,
 56,
 58,
 5801,
 5812,
 5814,
 5818,
 5826,
 5827,
 5831,
 5833,
 5838,
 5850}

In [28]:
main()

For label 0, check if its stable RELU pattern guarantees the output
--------CHECK PATTERN: output_0 is always less than output_1 ? --------
unsat
--------CHECK PATTERN: output_0 is always less than output_2 ? --------
unsat
--------CHECK PATTERN: output_0 is always less than output_3 ? --------
unsat
--------CHECK PATTERN: output_0 is always less than output_4 ? --------
unsat
--------CHECK PATTERN: output_0 is always less than output_5 ? --------
unsat
--------CHECK PATTERN: output_0 is always less than output_6 ? --------
unsat
--------CHECK PATTERN: output_0 is always less than output_7 ? --------
unsat
--------CHECK PATTERN: output_0 is always less than output_8 ? --------
unsat
--------CHECK PATTERN: output_0 is always less than output_9 ? --------
unsat
For label 1, check if its stable RELU pattern guarantees the output
--------CHECK PATTERN: output_1 is always less than output_0 ? --------
unsat
--------CHECK PATTERN: output_1 is always less than output_2 ? --------
unsat
------

KeyboardInterrupt: 

In [48]:
net = init_network()

In [54]:
net.

[]