In [1]:
import json
import numpy as np
from maraboupy import Marabou, MarabouCore, MarabouUtils, MarabouPythonic
import pandas
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from tqdm.auto import tqdm



In [2]:
# Global config

NET_PATH = './mnist_fc_64x4_adv_1.onnx'
PATTERN_PATH = './NAPs/mnist_fc_64x4_adv_1_d99.json'
MAX_TIME = 56
M_OPTIONS: MarabouCore.Options = Marabou.createOptions(verbosity=0, timeoutInSeconds=MAX_TIME)
START = 100
END = 200
EPSILON = 0.1
LAYERS = 4
NEURONS_WIDTH = 64

In [3]:
with open(PATTERN_PATH, "r") as f:
    NAP = json.load(f)

In [4]:
def init_network(data_point=None, epsilon=None)->Marabou.MarabouNetworkNNet:
    network:Marabou.MarabouNetworkNNet = Marabou.read_onnx(NET_PATH)
    if data_point is None:
        for i in range(784):
            network.setLowerBound(i, 0.0)
            network.setUpperBound(i, 1.0)       
    else:
        for i in range(784):
            network.setLowerBound(i, max(data_point[i]-epsilon, 0.))
            network.setUpperBound(i, min(data_point[i]+epsilon, 1.))
  
    return network

In [5]:
def find_marabou_idx(idx, net):
    layer, n_idx = idx
    offset = 784
    marabou_idx = offset + (layer) * NEURONS_WIDTH *2 + NEURONS_WIDTH + n_idx
    post_relu_neurons = [t[1] for t in net.reluList]
    assert marabou_idx in post_relu_neurons
    return marabou_idx 

In [6]:
def add_nap_constraints(network, A, D):
    for idx in A:
        constraint = MarabouUtils.Equation(MarabouCore.Equation.GE)
        constraint.addAddend(1, find_marabou_idx(idx, network))
        constraint.setScalar(0.001)
        network.addEquation(constraint)

    for idx in D:
        constraint = MarabouUtils.Equation(MarabouCore.Equation.EQ)
        constraint.addAddend(1, find_marabou_idx(idx, network))
        constraint.setScalar(0.0)
        network.addEquation(constraint)

    return network

In [7]:
def add_counter_example_constraint(network, label, other_label):
    offset = network.outputVars[0][0][0]
    constraint = MarabouUtils.Equation(MarabouCore.Equation.GE)
    constraint.addAddend(1, other_label+offset)
    constraint.addAddend(-1, label+offset)
    constraint.setScalar(0.001)
    network.addEquation(constraint)
    # print(f"Added constraint {other_label} >= {label}")  
    return network

In [8]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.,), (1.,))
])

trainset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(trainset, batch_size=10000, shuffle=False)

imgs, labels = next(iter(testloader))

In [9]:
res = [[-1.]*10 for i in range(START, END)]


for i in tqdm(range(START, END)):
    data_point = imgs[i].reshape((784)).numpy()
    label = labels[i].item()
    A = NAP[str(label)]['A']['indices']
    D = NAP[str(label)]['D']['indices']
    for other_label in range(10):
        if other_label == int(label):
            continue
        net = init_network(data_point, EPSILON)
        # print(f"Network initilized : ({len(net.equList)})")
        net = add_nap_constraints(net, A, D)
        # print(f"NAP constraints established : ({len(net.equList)})")
        net = add_counter_example_constraint(net, label, other_label)
        # print(f"Counter example constraint added : ({len(net.equList)})")
        # print(f"Solving...")
        exit_code, vals, stats = net.solve(options=M_OPTIONS)
        # print(f"{exit_code}!")
        running_time = stats.getTotalTimeInMicro()
        if exit_code=="sat":
            res[i-START][other_label] = "SAT:{}".format(running_time/10**6)
        elif exit_code=="unsat":
            res[i-START][other_label] = "UNS:{}".format(running_time/10**6)
        else:
            res[i-START][other_label] = exit_code

res = pandas.DataFrame(res)
res.to_csv("./checking_res/64x4s_eps1_99_fix_200.csv")
# print(res)

  0%|          | 0/100 [00:00<?, ?it/s]

unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
sat
input 0 = 0.0
input 1 = 0.0
input 2 = 0.0
input 3 = 0.0
input 4 = 0.0
input 5 = 0.0
input 6 = 0.0
input 7 = 0.0
input 8 = 0.0
input 9 = 0.0
input 10 = 0.0
input 11 = 0.0
input 12 = 0.0
input 13 = 0.0
input 14 = 0.0
input 15 = 0.0
input 16 = 0.0
input 17 = 0.0
input 18 = 0.0
input 19 = 0.0
input 20 = 0.0
input 21 = 0.0
input 22 = 0.0
input 23 = 0.0
input 24 = 0.0
input 25 = 0.0
input 26 = 0.0
input 27 = 0.0
input 28 = 0.0
input 29 = 0.0
input 30 = 0.0
input 31 = 0.0
input 32 = 0.0
input 33 = 0.0
input 34 = 0.0
input 35 = 0.0
input 36 = 0.0
input 37 = 0.0
input 38 = 0.0
input 39 = 0.0
input 40 = 0.0
input 41 = 0.0
input 42 = 0.0
input 43 = 0.0
input 44 = 0.0
input 45 = 0.0
input 46 = 0.0
input 47 = 0.0
input 48 = 0.0
input 49 = 0.0
input 50 =

In [10]:
# Verifying without datapoint on entire data space. Not scalable.

# res = [[-1.]*10 for i in range(10)]

# for label in range(10):
#     print(f"Checking NAP pattern robustness for label: {label}")
#     A = NAP[str(label)]['A']['indices']
#     D = NAP[str(label)]['D']['indices']
#     for other_label in range(10):
#         if other_label == int(label):
#             continue
#         net = init_network()
#         print(f"Network initilized : ({len(net.equList)})")
#         net = add_nap_constraints(net, A, D)
#         print(f"NAP constraints established : ({len(net.equList)})")
#         net = add_counter_example_constraint(net, label, other_label)
#         print(f"Counter example constraint added : ({len(net.equList)})")
#         print(f"Solving...")
#         exit_code, vals, stats = net.solve(options=M_OPTIONS)
#         print(f"{exit_code}!")
#         running_time = stats.getTotalTimeInMicro()
#         if exit_code=="sat":
#             res[int(label)][other_label] = "SAT:{}".format(running_time/10**6)
#         elif exit_code=="unsat":
#             res[int(label)][other_label] = "UNS:{}".format(running_time/10**6)
#         else:
#             res[int(label)][other_label] = exit_code

# res = pandas.DataFrame(res)
# print(res)

In [11]:
# res.to_csv("./checking_res/old_99.csv")