In [1]:
import torch
from networks import FullyConnected, SPU, Normalization
import os
import time

DEVICE = 'cpu'
INPUT_SIZE = 28

In [2]:
class Args:
    net: str
    spec: str

    def __init__(self, net: str, spec: str):
        self.net = net
        self.spec = spec

In [3]:
def analyze(net: torch.nn.Module, inputs: torch.FloatTensor, eps: float, true_label: int) -> str:
    """Returns "verified" if the returned label is the same for all points in a L-infinite epsilon-ball around an input point, and "not verified" otherwise.

    Args:
        net (torch.nn.Module): [description]
        inputs (torch.FloatTensor): [description]
        eps (float): [description]
        true_label (int): [description]

    Returns:
        [str]: Returns "verified" or "not verified"
    """
    start_time = time.time()

    # Compute lower and upper bounds for all points
    detached_inputs = inputs.detach()
    lower_orig = torch.min(detached_inputs - eps * 1, torch.tensor(0.0)).to(DEVICE)
    upper_orig = torch.max(detached_inputs + eps * 1, torch.tensor(1.0)).to(DEVICE)

    with torch.no_grad():
        outputs = net(inputs)
        low = net(lower_orig)
        upper = net(upper_orig)

    print(f'Low: {low}')
    print(f'High: {upper}')
    print(f'Outputs: {outputs.shape}')
    verified = sum((low[0][true_label] > upper[0]).int()) == true_label
    
    if verified: 
        return True


    for i, layer in enumerate(net.layers, 1):
        if type(layer) is SPU:
            continue
        
        elif type(layer) is Normalization:
            continue
        
        elif type(layer) is torch.nn.Flatten:
            continue

        elif type(layer) is torch.nn.Linear:
            continue
    
    # ???

    # Compare around 
    end_time = time.time()
    
    return 0

    # return (true_label_lx > labels_ux).all()


In [4]:
args = Args(
    net="net0_fc1",
    spec="../test_cases/net0_fc1/example_img0_0.01800.txt"
)

with open(args.spec, 'r') as f:
    lines = [line[:-1] for line in f.readlines()]
    true_label = int(lines[0])
    pixel_values = [float(line) for line in lines[1:]]
    eps = float(args.spec[:-4].split('/')[-1].split('_')[-1])

if args.net.endswith('fc1'):
    net = FullyConnected(DEVICE, INPUT_SIZE, [50, 10]).to(DEVICE)

elif args.net.endswith('fc2'):
    net = FullyConnected(DEVICE, INPUT_SIZE, [100, 50, 10]).to(DEVICE)

elif args.net.endswith('fc3'):
    net = FullyConnected(DEVICE, INPUT_SIZE, [100, 100, 10]).to(DEVICE)

elif args.net.endswith('fc4'):
    net = FullyConnected(DEVICE, INPUT_SIZE, [100, 100, 50, 10]).to(DEVICE)

elif args.net.endswith('fc5'):
    net = FullyConnected(DEVICE, INPUT_SIZE, [
                            100, 100, 100, 100, 10]).to(DEVICE)
else:
    assert False

net.load_state_dict(torch.load('../mnist_nets/%s.pt' %
                    args.net, map_location=torch.device(DEVICE)))

inputs = torch.FloatTensor(pixel_values).view(
    1, 1, INPUT_SIZE, INPUT_SIZE).to(DEVICE)

outs = net(inputs)
pred_label = outs.max(dim=1)[1].item()
assert pred_label == true_label

print(f'Network: {net}')
print(f'Inputs shape: {inputs.shape}')
print(f'Epsilon: {eps}')
print(f'True label: {true_label}')

print()

if analyze(net, inputs, eps, true_label):
    print('verified')
else:
    print('not verified')


Network: FullyConnected(
  (layers): Sequential(
    (0): Normalization()
    (1): Flatten(start_dim=1, end_dim=-1)
    (2): Linear(in_features=784, out_features=50, bias=True)
    (3): SPU()
    (4): Linear(in_features=50, out_features=10, bias=True)
  )
)
Inputs shape: torch.Size([1, 1, 28, 28])
Epsilon: 0.018
True label: 1

Low: tensor([[-2.6558, -2.8654,  0.3886,  0.1550, -4.0704,  5.2014,  2.0776, -0.8347,
         -4.7754, -5.9432]])
High: tensor([[ -35.3241, -358.8481, -322.7823,  487.9682,  113.4878, -659.5154,
          619.2648,  352.8987,  673.6779, -608.7426]])
Outputs: torch.Size([1, 10])
Weight: tensor([[-0.0196,  0.0307, -0.0089,  ...,  0.0421, -0.0024,  0.0109],
        [-0.0294, -0.0135,  0.0263,  ...,  0.0180,  0.0147, -0.0285],
        [ 0.0374, -0.0079,  0.0339,  ..., -0.0065, -0.0236,  0.0149],
        ...,
        [-0.0296,  0.0230,  0.0202,  ..., -0.0023, -0.0060, -0.0028],
        [-0.0380,  0.0218, -0.0011,  ...,  0.0294,  0.0266,  0.0111],
        [-0.0025, -0

In [6]:
m = torch.nn.Linear(20, 30)

m.weight.detach()

tensor([[ 9.1838e-02,  1.4977e-01, -8.3062e-02, -1.5281e-01,  9.1426e-02,
         -2.1549e-01, -1.7435e-01, -8.3057e-02, -9.5209e-02, -7.2022e-02,
         -1.7632e-01, -1.6243e-01, -1.7700e-01,  2.1169e-01,  8.5248e-03,
          1.2260e-01, -2.6617e-02, -1.7812e-02,  2.3447e-04,  1.3785e-03],
        [ 5.7116e-02,  1.9791e-01,  2.0694e-01, -1.2121e-01,  1.6504e-01,
         -1.1325e-01,  1.1331e-01,  1.3060e-01, -5.3863e-02,  8.2702e-02,
         -8.7948e-02,  6.6766e-02, -1.2768e-01, -1.7938e-01,  1.3652e-01,
          9.9580e-02, -1.6689e-01,  8.8625e-03, -1.4726e-01,  4.2850e-03],
        [-1.6622e-01,  2.1482e-02, -4.1756e-04, -4.4849e-02,  1.6238e-01,
         -9.6830e-02,  8.3080e-02,  1.7443e-01, -2.1331e-01,  1.4089e-01,
         -1.7067e-01,  2.1659e-02,  3.9886e-02,  6.9913e-02,  1.7665e-01,
         -1.7926e-02, -2.0548e-01, -1.8443e-01, -1.9268e-01, -1.4712e-02],
        [-1.1924e-01,  6.5098e-02,  1.3339e-02, -1.1742e-02,  1.2709e-01,
          1.1218e-02, -4.3042e-02, 