In [2]:
import json
from typing import Tuple, List
import numpy as np
from maraboupy import Marabou, MarabouCore, MarabouUtils
import logging
import pandas



In [3]:
PATH = './mnist-net_256x4.onnx'

In [4]:
EPSILON = 0.5 
PATTERN_PATH = './mnist_relu_patterns_0.json'
MAX_TIME = 300 #in seconds
M_OPTIONS: MarabouCore.Options = Marabou.createOptions(verbosity=0, numWorkers=10, timeoutInSeconds=MAX_TIME)

In [5]:
with open(PATTERN_PATH, "r") as f:
    STABLE_PATTERNS = json.load(f)

In [6]:
set(STABLE_PATTERNS['0']['val'])

{0, 5923}

In [7]:
loc = 0.5
radus = 0.5
non_restricted_dim = []

In [8]:
def init_network()->Marabou.MarabouNetworkNNet:
    network:Marabou.MarabouNetworkNNet = Marabou.read_onnx(PATH)

    # print("output nodes:", network.outputVars)
    # print("input nodes:", network.inputVars)

    logging.debug("relu list:")
    for r in network.reluList:
        logging.debug(r)
    for i in range(784):
        network.setLowerBound(i, 0.0)
        network.setUpperBound(i, 1.0)


        # if i in non_restricted_dim:
        #     network.setLowerBound(i, loc - radus)
        #     network.setUpperBound(i, loc + radus)
        # else:
        #     network.setLowerBound(i, 0.00)
        #     network.setUpperBound(i, 0.00)           
  
    return network

In [9]:
def parse_raw_idx(raw_idx: int) -> Tuple[int, int, int]:
    """
    only for MNIST 256xk network:
    """
    n_relus = 256
    offset = 28*28
    layer = raw_idx // n_relus
    idx = raw_idx % n_relus
    marabou_idx = 2*n_relus*layer + idx + offset
    return layer, idx, marabou_idx

In [10]:
def add_relu_constraints(network: Marabou.MarabouNetworkNNet, 
                        relu_check_list: List[int], 
                        relu_val: List[int])->Marabou.MarabouNetworkNNet:
    """
    Add stable relus constraints to the Marabou network
    """
    for i in range(len(relu_check_list)):
        layer, idx, marabou_idx = parse_raw_idx(relu_check_list[i])     
        if relu_val[i] < 2500:        
#        if relu_val[i] == 0:
            constraint = MarabouUtils.Equation(MarabouCore.Equation.LE)
            constraint.addAddend(1, marabou_idx)
            constraint.setScalar(-0.001)
        else:
            constraint = MarabouUtils.Equation(MarabouCore.Equation.GE)
            constraint.addAddend(1, marabou_idx)
            constraint.setScalar(0.001)
        network.addEquation(constraint)

    return network

In [11]:
def find_one_assignment(relu_check_list: List[int], relu_val: List[int])->None:
    network = init_network()
    network = add_relu_constraints(network, relu_check_list, relu_val)    
    exitCode, vals, stats = network.solve()
    assert(exitCode=="sat")    
    for idx, r in enumerate(relu_check_list):
        marabou_idx = parse_raw_idx(r)[-1]
#        print(marabou_idx, vals[marabou_idx], relu_val[idx])

In [12]:
def check_pattern(relu_check_list: List[int], relu_val: List[int], label: int, other_label: int)->Tuple[str, int]:
    """
    In ACAS, the prediction is the label with smallest value.
    So we check that label - other_label < 0 forall input
    by finding assignments for label - other_label >=0
    """
    print("--------CHECK PATTERN: output_{} is always less than output_{} ? --------".format(label, other_label))
    network = init_network()
    network = add_relu_constraints(network, relu_check_list, relu_val)    
    offset = network.outputVars[0][0][0]

    #print(network.numVars)
    #print(network.reluList)
    #print(network.outputVars)

    #add output constraint
    constraint = MarabouUtils.Equation(MarabouCore.Equation.GE)
    constraint.addAddend(1, other_label+offset)
    constraint.addAddend(-1, label+offset)
    constraint.setScalar(0.001)
    network.addEquation(constraint)

    print(f"Counter example constraint added : ({len(network.equList)})")

    exit_code: str    
    exit_code, vals, stats = network.solve(options=M_OPTIONS)
    running_time:int = stats.getTotalTimeInMicro()
    for idx, r in enumerate(relu_check_list):
        marabou_idx = parse_raw_idx(r)[-1]

    return exit_code, running_time

In [13]:
def main():
    res = [[-1.]*10 for i in range(10)]
    # print(res)
    for label in STABLE_PATTERNS:
        print(f"For label {label}, check if its stable RELU pattern guarantees the output")
        for other_label in range(10):
            if other_label == int(label):
                continue
            relu_check_list = STABLE_PATTERNS[label]["stable_idx"]
            relu_val = STABLE_PATTERNS[label]["val"] 
            exit_code, running_time = check_pattern(relu_check_list, relu_val, label=int(label), other_label = other_label)
            if exit_code=="sat":
                res[int(label)][other_label] = "SAT:{}".format(running_time/10**6)
                break
            elif exit_code=="unsat":
                res[int(label)][other_label] = "UNS:{}".format(running_time/10**6)
            else:
                res[int(label)][other_label] = exit_code

    res = pandas.DataFrame(res)
    print(res)

In [14]:
set(STABLE_PATTERNS['8']['val'])

{0}

In [15]:
main()

For label 0, check if its stable RELU pattern guarantees the output
--------CHECK PATTERN: output_0 is always less than output_1 ? --------
Counter example constraint added : (1547)
sat
input 0 = 0.0
input 1 = 0.0
input 2 = 0.0
input 3 = 0.0
input 4 = 0.26556637154272456
input 5 = 0.0
input 6 = 0.2779890409299663
input 7 = 0.0
input 8 = 0.0
input 9 = 0.0
input 10 = 0.0
input 11 = 0.44016137889546103
input 12 = 0.0
input 13 = 0.0
input 14 = 0.0
input 15 = 0.0
input 16 = 0.6236975829986653
input 17 = 0.0
input 18 = 0.0
input 19 = 0.0
input 20 = 0.0935307369928482
input 21 = 0.0
input 22 = 0.4091261875225136
input 23 = 0.0
input 24 = 0.0
input 25 = 0.0
input 26 = 0.07396583996317287
input 27 = 0.0
input 28 = 0.0
input 29 = 0.0
input 30 = 0.0
input 31 = 0.0
input 32 = 0.11909009389719025
input 33 = 0.7010894600057086
input 34 = 0.0
input 35 = 0.0
input 36 = 0.08924896841910102
input 37 = 0.0
input 38 = 0.0
input 39 = 0.0
input 40 = 0.0
input 41 = 0.0
input 42 = 0.0
input 43 = 0.0
input 44 

In [None]:
net = init_network()

In [None]:
net.

[]

In [None]:
res

NameError: name 'res' is not defined