In [1]:
import json
import numpy as np
from maraboupy import Marabou, MarabouCore, MarabouUtils, MarabouPythonic
import pandas
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from tqdm.auto import tqdm



In [2]:
# Global config

NET_PATH = './models/mnist_fc_64x4_adv_1.onnx'
PATTERN_PATH = './NAPs/mnist_fc_64x4_adv_1_d99.json'
MAX_TIME = 56
M_OPTIONS: MarabouCore.Options = Marabou.createOptions(verbosity=0, timeoutInSeconds=MAX_TIME)
START = 0
END = 200
EPSILON = 0.1
LAYERS = 4
NEURONS_WIDTH = 64

In [3]:
with open(PATTERN_PATH, "r") as f:
    NAP = json.load(f)

In [4]:
def init_network(data_point=None, epsilon=None)->Marabou.MarabouNetworkNNet:
    network:Marabou.MarabouNetworkNNet = Marabou.read_onnx(NET_PATH)
    if data_point is None:
        for i in range(784):
            network.setLowerBound(i, 0.0)
            network.setUpperBound(i, 1.0)       
    else:
        for i in range(784):
            network.setLowerBound(i, max(data_point[i]-epsilon, 0.))
            network.setUpperBound(i, min(data_point[i]+epsilon, 1.))
  
    return network

In [5]:
def find_marabou_idx(idx, net):
    layer, n_idx = idx
    offset = 784
    marabou_idx = offset + (layer) * NEURONS_WIDTH *2 + NEURONS_WIDTH + n_idx
    post_relu_neurons = [t[1] for t in net.reluList]
    assert marabou_idx in post_relu_neurons
    return marabou_idx 

In [6]:
def add_nap_constraints(network, A, D):
    for idx in A:
        constraint = MarabouUtils.Equation(MarabouCore.Equation.GE)
        constraint.addAddend(1, find_marabou_idx(idx, network))
        constraint.setScalar(0.001)
        network.addEquation(constraint)

    for idx in D:
        constraint = MarabouUtils.Equation(MarabouCore.Equation.EQ)
        constraint.addAddend(1, find_marabou_idx(idx, network))
        constraint.setScalar(0.0)
        network.addEquation(constraint)

    return network

In [7]:
def add_counter_example_constraint(network, label, other_label):
    offset = network.outputVars[0][0][0]
    constraint = MarabouUtils.Equation(MarabouCore.Equation.GE)
    constraint.addAddend(1, other_label+offset)
    constraint.addAddend(-1, label+offset)
    constraint.setScalar(0.001)
    network.addEquation(constraint)
    # print(f"Added constraint {other_label} >= {label}")  
    return network

In [8]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.,), (1.,))
])

trainset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(trainset, batch_size=10000, shuffle=False)

imgs, labels = next(iter(testloader))

In [9]:
res = [[-1.]*10 for i in range(START, END)]

safe = []
unsafe = []
unknown = []

for i in tqdm(range(START, END)):
    data_point = imgs[i].reshape((784)).numpy()
    label = labels[i].item()
    A = NAP[str(label)]['A']['indices']
    D = NAP[str(label)]['D']['indices']
    cnt = 0
    for other_label in range(10):
        if other_label == int(label):
            continue
        net = init_network(data_point, EPSILON)
        net = add_nap_constraints(net, A, D)
        net = add_counter_example_constraint(net, label, other_label)
        exit_code, vals, stats = net.solve(options=M_OPTIONS)
        running_time = stats.getTotalTimeInMicro()
        if exit_code=="sat":
            res[i-START][other_label] = "SAT:{}".format(running_time/10**6)
            unsafe.append(i)
            break
        elif exit_code=="unsat":
            res[i-START][other_label] = "UNS:{}".format(running_time/10**6)
            cnt += 1
        else:
            res[i-START][other_label] = exit_code
            unknown.append(i)
            break
    if cnt == 9:
        safe.append(i)

res = pandas.DataFrame(res)
res.to_csv("./checking_res/torch_64x4s_eps1_99_200.csv")
# print(res)

  0%|          | 0/200 [00:00<?, ?it/s]

unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
unsat
sat
input 0 = 0.0
input 1 = 0.0
input 2 = 0.0
input 3 = 0.0
input 4 = 0.0
input 5 = 0.0
input 6 = 0.0
input 7 = 0.0
input 8 = 0.0
input 9 = 0.0
input 10 = 0.0
input 11 = 0.0
input 12 = 0.0
input 13 = 0.0
input 14 = 0.0
input 15 = 0.0
input 16 = 0.0
input 17 = 0.0
input 18 = 0.0
input 19 = 0.0
input 20 = 0.0
input 21 = 0.0
input 22 = 0.0
input 23 = 0.0
input 24 = 0.0
input 25 = 0.0
input 26 = 0.0
input 27 = 0.0
input 28 = 0.0
input 29 = 0.0
input 30 = 0.0
input 31 = 0.0
input 32 = 0.0
input 33 = 0.0
input 34 = 0.0
input 35 = 0.0
input 36 =

In [10]:
print(f"safe: {safe}\nunsafe: {unsafe}\nunkown: {unknown}")

safe: [0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 63, 64, 67, 68, 69, 70, 71, 74, 75, 76, 78, 79, 80, 81, 82, 83, 85, 86, 88, 89, 90, 91, 93, 94, 96, 98, 99, 100, 101, 102, 103, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 116, 117, 118, 120, 121, 122, 123, 125, 126, 127, 128, 129, 130, 131, 132, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 150, 152, 153, 154, 155, 156, 157, 160, 161, 162, 163, 164, 165, 166, 168, 169, 170, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 196, 197, 198, 199]
unsafe: [8, 30, 35, 62, 65, 66, 72, 73, 77, 84, 87, 92, 95, 97, 104, 115, 119, 124, 133, 149, 151, 158, 159, 167, 171, 195]
unkown: []
