In [1]:
import json
import logging
import os
from typing import List, Tuple

import numpy as np
import pandas
from maraboupy import Marabou, MarabouCore, MarabouPythonic, MarabouUtils



In [9]:
# Global config

# Model under verification and the NAP file it is checked against.
NET_PATH = './256x4_scratch.onnx'
PATTERN_PATH = './NAPs/256x4s_delta99.json'

# Per-query solver timeout in seconds, and the Marabou options shared by every query.
MAX_TIME = 300
M_OPTIONS: MarabouCore.Options = Marabou.createOptions(
    timeoutInSeconds=MAX_TIME,
    numWorkers=10,
    verbosity=0,
)

In [3]:
# Load the NAP definitions: a JSON object keyed by label as a string, where each
# entry holds 'A' and 'D' neuron lists under 'indices'.
# JSON text is UTF-8 by specification, so the encoding is pinned rather than
# left to the platform's locale default.
with open(PATTERN_PATH, "r", encoding="utf-8") as pattern_file:
    NAP = json.load(pattern_file)

In [4]:
def init_network() -> "Marabou.MarabouNetworkONNX":
    """Load the ONNX model at NET_PATH and bound every input variable to [0, 1].

    Each call returns a brand-new network object. The constraint helpers call
    ``addEquation`` on the network they are given, so reusing one object across
    queries would make their equations accumulate.

    Returns:
        The network produced by ``Marabou.read_onnx(NET_PATH)``. Every variable in
        ``network.inputVars[0]`` has lower bound 0.0 and upper bound 1.0.
    """
    network = Marabou.read_onnx(NET_PATH)

    logging.debug("relu list:")
    for relu_pair in network.reluList:
        logging.debug(relu_pair)

    # Bound the actual input variables instead of assuming exactly 784 of them
    # numbered 0..783, so a model with a different input size still works.
    # NOTE(review): the [0, 1] box presumes normalized pixel inputs — confirm.
    for input_var in np.asarray(network.inputVars[0]).flatten():
        network.setLowerBound(int(input_var), 0.0)
        network.setUpperBound(int(input_var), 1.0)

    return network

In [5]:
def find_marabou_idx(idx, net, input_size=784, layer_width=256):
    """Map a ``(layer, neuron)`` NAP index to the Marabou variable of that neuron's ReLU output.

    The arithmetic assumes this variable layout:
      - first, ``input_size`` input variables;
      - then, for each hidden layer, ``layer_width`` pre-activation variables
        followed by ``layer_width`` post-ReLU variables.

    The computed variable is checked against ``net.reluList``. A wrong layout
    assumption therefore fails loudly instead of constraining an unrelated variable.

    Args:
        idx: ``(layer, n_idx)`` pair: zero-based hidden-layer number and the
            neuron's position within that layer.
        net: Network exposing ``reluList``, a sequence of pairs whose second
            element is a post-ReLU variable.
        input_size: Number of input variables before the first hidden layer.
            Defaults to 784, the value this notebook previously hard-coded.
        layer_width: Number of neurons per hidden layer. Defaults to 256.

    Returns:
        The Marabou variable index of the neuron's post-ReLU value.

    Raises:
        AssertionError: if the computed variable is not the output of any ReLU
            (only when assertions are enabled).
    """
    layer, n_idx = idx
    # Skip the inputs, then one (pre, post) pair of blocks per earlier layer,
    # then this layer's pre-activation block.
    marabou_idx = input_size + layer * layer_width * 2 + layer_width + n_idx
    assert any(relu[1] == marabou_idx for relu in net.reluList), (
        f"variable {marabou_idx} for NAP index {idx} is not a ReLU output"
    )
    return marabou_idx

In [6]:
def add_nap_constraints(network, A, D, active_threshold=0.001):
    """Constrain the network's hidden ReLUs to follow an activation pattern.

    Args:
        network: Marabou network. Post-ReLU variables are located with
            ``find_marabou_idx``.
        A: Iterable of ``(layer, neuron)`` indices that must be active. Each
            matching post-ReLU variable gets ``value >= active_threshold``.
        D: Iterable of ``(layer, neuron)`` indices that must be inactive. Each
            matching post-ReLU variable gets ``value == 0``.
        active_threshold: Smallest post-ReLU value treated as "active".
            Defaults to 0.001, the value previously hard-coded.

    Returns:
        The same ``network`` object, with one equation added per index. The
        equations for ``A`` are added first, then those for ``D``.
    """
    requirements = [(idx, MarabouCore.Equation.GE, active_threshold) for idx in A]
    requirements += [(idx, MarabouCore.Equation.EQ, 0.0) for idx in D]

    for idx, equation_type, scalar in requirements:
        constraint = MarabouUtils.Equation(equation_type)
        constraint.addAddend(1, find_marabou_idx(idx, network))
        constraint.setScalar(scalar)
        network.addEquation(constraint)

    return network

In [7]:
def add_counter_example_constraint(network, label, other_label, margin=0.001):
    """Add ``y[other_label] - y[label] >= margin`` over the network's outputs.

    A satisfying assignment is an input for which ``other_label``'s output
    exceeds ``label``'s output by at least ``margin``.

    Args:
        network: Marabou network whose ``outputVars[0]`` holds the output
            variables in label order.
        label: Position of the output that should win.
        other_label: Position of the competing output.
        margin: Minimum gap required. Defaults to 0.001, the value previously
            hard-coded.

    Returns:
        The same ``network`` object with one equation added.
    """
    # Look each output variable up by position. The old code assumed the output
    # variables are numbered contiguously from the first one.
    output_vars = np.asarray(network.outputVars[0]).flatten()
    constraint = MarabouUtils.Equation(MarabouCore.Equation.GE)
    constraint.addAddend(1, int(output_vars[other_label]))
    constraint.addAddend(-1, int(output_vars[label]))
    constraint.setScalar(margin)
    network.addEquation(constraint)
    print(f"Added constraint {other_label} >= {label}")
    return network

In [10]:
# For every ordered pair (label, other_label), ask Marabou whether some input
# that follows label's NAP makes other_label's output exceed label's.
# Cell values in `res`:
#   - "SAT:<seconds>" or "UNS:<seconds>" for decided queries;
#   - the raw exit code (e.g. "TIMEOUT") for any other outcome;
#   - -1.0 on the diagonal, which is never queried.
NUM_LABELS = 10
VERDICT_PREFIX = {"sat": "SAT", "unsat": "UNS"}
res = [[-1.0] * NUM_LABELS for _ in range(NUM_LABELS)]

for label in range(NUM_LABELS):
    print(f"Checking NAP pattern robustness for label: {label}")
    A = NAP[str(label)]['A']['indices']
    D = NAP[str(label)]['D']['indices']
    for other_label in range(NUM_LABELS):
        if other_label == label:
            continue
        # Build a fresh network per query: the helpers below add equations to
        # the network in place, so constraints must not leak between queries.
        net = init_network()
        print(f"Network initialized : ({len(net.equList)})")
        net = add_nap_constraints(net, A, D)
        print(f"NAP constraints established : ({len(net.equList)})")
        net = add_counter_example_constraint(net, label, other_label)
        print(f"Counter example constraint added : ({len(net.equList)})")
        print("Solving...")
        exit_code, vals, stats = net.solve(options=M_OPTIONS)
        print(f"{exit_code}!")
        seconds = stats.getTotalTimeInMicro() / 10**6
        prefix = VERDICT_PREFIX.get(exit_code)
        res[label][other_label] = f"{prefix}:{seconds}" if prefix else exit_code

res = pandas.DataFrame(res)
res

Checking NAP pattern robustness for label: 0
Network initilized : (1034)
NAP constraints established : (1472)
Added constraint 1 >= 0
Counter example constraint added : (1473)
Solving...
TIMEOUT
TIMEOUT!
Network initilized : (1034)
NAP constraints established : (1472)
Added constraint 2 >= 0
Counter example constraint added : (1473)
Solving...
TIMEOUT
TIMEOUT!
Network initilized : (1034)
NAP constraints established : (1472)
Added constraint 3 >= 0
Counter example constraint added : (1473)
Solving...
TIMEOUT
TIMEOUT!
Network initilized : (1034)
NAP constraints established : (1472)
Added constraint 4 >= 0
Counter example constraint added : (1473)
Solving...
TIMEOUT
TIMEOUT!
Network initilized : (1034)
NAP constraints established : (1472)
Added constraint 5 >= 0
Counter example constraint added : (1473)
Solving...
TIMEOUT
TIMEOUT!
Network initilized : (1034)
NAP constraints established : (1472)
Added constraint 6 >= 0
Counter example constraint added : (1473)
Solving...
TIMEOUT
TIMEOUT!
N

In [11]:
# Persist the verdict matrix. Create the output directory first, so the first
# run on a fresh checkout does not fail with FileNotFoundError.
RESULTS_CSV = "./checking_res/new_99.csv"
os.makedirs(os.path.dirname(RESULTS_CSV), exist_ok=True)
res.to_csv(RESULTS_CSV)