In [2]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
import math
from torch.autograd import Function
import numpy as np
import os
from torchvision import datasets as datasets
from torchvision import transforms as transforms
from torch.utils.data import DataLoader as DataLoader
from math import ceil, floor

from tools import *
from encoding import *
from models import *
from bnn import *

In [3]:
mnist = BNN(name="MNIST", output_dir="mnist_logs/")
mnist.train()

Start/Continue training from epoch 10
Finish training for 10 epochs


In [4]:
verify_net(mnist.net)

# manual_acc(mnist.net, mnist.train_loader)

In [5]:
decomp = decompose_net(mnist.net)
linear, bn = decomp['linear'], decomp['bn']

D = lower_bound(mnist.net)

In [9]:
def r_value(row, m, C):
    return int((row[:m] == 1).sum().item() >= C)

def make_bool(mat):
    mat = mat.cpu().detach().numpy()
    
    return ((mat + 1)/2).astype(int)

In [10]:
from functools import *
from collections import defaultdict

In [53]:
A = linear[0]['weight'][0,:]
b = linear[0]['bias'].unsqueeze(1)[0]

variables = create_vocabulary([linear[0]['weight']], [0], D)
assert (D[0][0]*784 + 784) == len(variables)

cnf = constraints(linear[0]['weight'], 1, 0, D)

for i, (data, target) in enumerate(mnist.test_loader):
    data = data.cuda()
    
    manual = data.view(mnist.train_loader.batch_size, -1, 1)
    manual = binarize(manual)
            
    for j in range(1):
        inp = manual[j]
                
        ## neural network approach
        layer1 = torch.matmul(A, inp) + b
        layer1 = bn[0]['scale'][0] * ((layer1 - bn[0]['mean'][0]) / bn[0]['std'][0]) + bn[0]['bias'][0]
        layer1 = binarize(layer1)
        real_output = ((layer1.squeeze() + 1)/2).item()
        
        ## cnf approach        
        cnf_inp = torch.Tensor(list(map(lambda x,y : (x | (not y)) & ((not x) | y), 
                                        make_bool(A), make_bool(inp))))
        
        assert r_value(cnf_inp, 784, D[0][0]) == real_output # check that r(784, D) works
        
        if not check_encoding(cnf_inp, cnf):
            print("Unsatisfiable.")

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [40]:
fmt_x = "x%d,%d,%d" # block, row, col
fmt_r = "r%d,%d,%d,%d" # block, row, col (m), C value

def create_vocabulary(matrices, blocks, D):
    variables = dict()
    var_idx = 1

    for b, mat in zip(blocks, matrices):
        mat = make_bool(mat)
        
        for row in range(1):
            for col in range(mat.shape[1]):
                variables[fmt_x % (b, row, col)] = var_idx
                var_idx += 1

                for C in range(1, D[b][row] + 1):
                    variables[fmt_r % (b, row, col+1, C)] = var_idx
                    var_idx += 1
    
    return variables

In [41]:
def convert_to_dimacs(cnf_str, variables):
    cnf_str = cnf_str.strip().rstrip('*').split('*')
    cnf_list = list(map(lambda x : x.strip().replace('|', ' '), cnf_str))
        
    print('\n'.join(cnf_list))
    
    ret = ''
    
    for clause in cnf_list:
        no_neg_clause = clause.replace("-", "")
        
        for literal in no_neg_clause.split():
            clause = clause.replace(literal, str(variables[literal]))
        
        ret += (clause + ' 0\n')
    
    return ret

In [42]:
def constraints(mat, num_rows, block, D):
    mat = make_bool(mat)
    res = ''
    
    for row in range(1):
        
        '''constraint 1'''
        x = fmt_x % (block, row, 0)
        r = fmt_r % (block, row, 1, 1)
        
        
        res += (x + '|' + '-' + r) + '*'
        res += ('-' + x + '|' + r) + '*'
        '''end constraint 1'''
        
        
        '''constraint 2'''
        for C in range(2, D[block][row] + 1):
            r = fmt_r % (block, row, 1, C)
            
            res += ('-' + r) + '*'
        '''end constraint 2'''
        
        
        r = mat[row,:]
        
        for i in range(2, len(r) + 1):
            x = fmt_x % (0, 0, i-1)
            xnot = '-' + x

            '''constraint 3'''
            r1 = fmt_r % (0, 0, i, 1)
            r2 = fmt_r % (0, 0, i-1, 1)
        
            res += ('-' + r1 + '|' + x + '|' + r2) + '*'
            res += (xnot + '|' + r1) + '*'
            res += ('-' + r2 + '|' + r1) + '*'
            '''end constraint 3'''
            
            for j in range(2, D[block][row] + 1):
                
                '''constraint 4'''
                r1 = fmt_r % (0, 0, i, j)
                r2 = fmt_r % (0, 0, i-1, j)
                r3 = fmt_r % (0, 0, i-1, j-1)

                res += ('-' + r1 + '|' + x + '|' + r2) + '*'
                res += ('-' + r1 + '|' + r3 + '|' + r2) + '*'
                res += (xnot + '|' + '-' + r3 + '|' + r1) + '*'
                res += ('-' + r2 + '|' + r1) + '*'
                '''end constraint 4'''
    
    return convert_to_dimacs(res, variables)

In [43]:
def check_encoding(row, cnf): 
    row = make_bool(row)
    
    assign = dict()

    for key, value in variables.items():
        if 'x' in key:
            split = key.split(',')
            assign[value] = row[int(split[-1])]

        if 'r' in key:
            split = key.split(',')
            assign[value] = r_value(row, int(split[2]), int(split[3]))
                                                
    ands = 1

    for clause in cnf.split('0\n'):

        if clause == '':
            continue

        ors = 0
        for each in clause.strip().split():
            if '-' in each:
                ors = ors | (not assign[int(each.replace('-', ''))])
            else:
                ors = ors | assign[int(each)]
                
        ands = ands & ors

    return ands

For a given matrix that is NxM, an input vector of size Mx1, and lower bound D, the number of variables needed for a CNF representation is:

N(MD + M)

= N(D+1)(M)

= NM(D+1)



D is a constant which varies by layer:

Average D value for 1st block: 392

Average D value for 2nd block: 784

Average D value for 3rd block: 784

= O(NM) variables


ex. NxM = 1568x784 matrix, with D = 400 ≈ 500 million variables (for one matrix)

In [None]:
# def create_cnf_file(fname, mat, num_rows, block):
#     c = constraints(mat, num_rows, block).rstrip('\n')
        
#     file = open(fname, 'w+')
#     file.write('p cnf ' + str(len(variables)) + ' ' + str(c.count(' 0')) + '\n')
#     file.write(c)
#     file.close()
    
#     return c

# cnf = create_cnf_file('./first_row.cnf', linear[0]['weight'], 1, 0)

In [None]:
# with open('./first_row_solved', 'r') as file:
#     solved = file.read().split()

# solved = {str(abs(int(i))) : i for i in solved}