In [1]:
import torch
from models import models_with_args
import numpy as np

from pyunigen import Sampler
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.data.dataset import files_exist, __repr__
from torch_sparse import SparseTensor
import numpy as np
import PyMiniSolvers.minisolvers as minisolvers
from pysat.solvers import Glucose3

In [2]:
model_name = 'NeuroSATRNN'
model_type = models_with_args[model_name]
model_class = model_type['model_class']
model_args = model_type['model_args']

model = model_class(**model_args)

In [3]:
checkpoint = torch.load("temp/tb_logs/Final/version_28/checkpoints/epoch=16-step=3332.ckpt")
state_dict = checkpoint['state_dict']

new_state_dict = {}
for key, value in state_dict.items():
    new_key = key[6:]  # Remove "model." prefix
    new_state_dict[new_key] = value

model.load_state_dict(new_state_dict)


<All keys matched successfully>

In [4]:
print(model)

NeuroSATRNN(
  (L_init): Linear(in_features=1, out_features=32, bias=True)
  (C_init): Linear(in_features=1, out_features=32, bias=True)
  (LC_msgs): LCMessagesRNN()
  (CL_msgs): CLMessagesRNN()
  (L_vote): Linear(in_features=32, out_features=1, bias=True)
  (true_vec_mult): Linear(in_features=32, out_features=1, bias=False)
)


In [5]:
true_vec = new_state_dict["true_vec_mult.weight"]
print(true_vec)

tensor([[ 0.0617, -0.3002,  0.1244, -0.9027,  0.0756, -0.9489, -0.0096, -0.2207,
         -0.0958, -0.0605, -0.6417, -0.0467, -0.0131, -0.0133, -0.0684,  0.2073,
          0.5979, -0.1113, -0.2064, -0.1717, -0.0651,  0.1142,  0.0578,  1.0465,
         -0.8400, -0.6261,  0.0703, -0.0857,  0.0155,  0.6567, -0.0643, -0.0234]],
       device='cuda:0')


In [6]:
device = "cuda"
model.to(device)
num_iters = 30

In [7]:
class Problem(Data):
    # a Problem is a bipartite graph

    def __init__(self, edge_index=None, x_l=None, x_c=None, y=None, clauses = None, sat_assignment=None, sampled_solutions=None):
        # edge_index is a bipartite adjacency matrix between lits and clauses
        # x_l is the feature vector of the lits nodes
        # x_c is the feature vector of the clauses nodes
        # y is the label: 1 if sat, 0 if unsat

        super(Problem, self).__init__()
        # nodes features
        self.x_l = x_l
        self.x_c = x_c
        self.y = y
        self.sat_assignment = sat_assignment
        self.sampled_solutions = sampled_solutions

        self.num_literals = x_l.size(0) if x_l is not None else 0
        self.num_clauses = x_c.size(0) if x_c is not None else 0

        # edges
        self.edge_index = edge_index
        self.adj_t = SparseTensor(row = edge_index[1],
                                  col = edge_index[0],
                                  sparse_sizes = [self.num_clauses, self.num_literals]
                                 ) if edge_index is not None else 0

        # compute number of variables
        assert self.num_literals %2 == 0
        self.num_vars = self.num_literals // 2

        self.num_nodes = self.num_literals + self.num_clauses

        self.clauses = clauses

def parse_dimacs(filename):
        with open(filename, 'r') as f:
            lines = f.readlines()

        i = 0
        while lines[i].strip().split(" ")[0] == "c":
            # strip : remove spaces at the beginning and at the end of the string
            i += 1

        header = lines[i].strip().split(" ")
        assert(header[0] == 'p')
        n_vars = int(header[2])
        clauses = [[int(s) for s in line.strip().split(" ")[:-1]] for line in lines[i+1:]]
        return n_vars, clauses

def create_problem(n_vars, clauses, y):
        # d is the number of features of x_l and x_c

        n_lits = int(2 * n_vars)
        n_clauses = len(clauses)


        #ADDED!
        d = 16
        l_init = torch.normal(mean=0.0, std=1.0, size=(1,d))
        c_init = torch.normal(mean=0.0, std=1.0, size=(1,d)) 
        denom = torch.sqrt(torch.tensor(d, dtype=torch.float32))

        # create feature vectors for lits and clauses
        x_l = (torch.div(l_init, denom)).repeat(n_lits, 1)
        x_c = (torch.div(c_init, denom)).repeat(n_clauses, 1)

        # get graph edges from list of clauses
        edge_index = [[],[]]
        for i,clause in enumerate(clauses):
            # get idxs of lits in clause
            lits_indices = [from_lit_to_idx(l, n_vars) for l in clause]
            clauses_indices = len(clause) * [i]

            # add all edges connected to clause i to edge_index
            edge_index[0].extend(lits_indices)
            edge_index[1].extend(clauses_indices)

        # convert edge_index to tensor
        edge_index = torch.tensor(edge_index, dtype=torch.long)

        if y:
            solver = Glucose3()
            for clause in clauses:
                solver.add_clause(clause)
            solver.solve()
            satisfying_assignment = solver.get_model()

            s = Sampler()
            for cl in clauses:
                s.add_clause(cl)
            _, _, samples = s.sample(num=10)

            unique_samples = []        
            for sample in samples:
                if sample not in unique_samples:
                    unique_samples.append(sample)
            

            #cat_unique_samples = []
            #for sample in unique_samples:
            #    cat_unique_samples+=sample
            #torch.tensor(np.sign(cat_unique_samples), dtype=torch.float)
            samples_tensor = torch.tensor(np.sign(samples), dtype=torch.float).permute(1, 0)

            prob = Problem(edge_index, x_l, x_c, y, clauses, torch.tensor(np.sign(satisfying_assignment), dtype=torch.float), samples_tensor) #
            print(prob)
            return prob
        return Problem(edge_index, x_l, x_c, y, clauses)

def solve_sat(n_vars, iclauses):
    solver = minisolvers.MinisatSolver()

    for i in range(n_vars):
        solver.new_var(dvar=True) # dvar=True <- this var will be used as a decision var

    for iclause in iclauses:
        solver.add_clause(iclause)

    is_sat = solver.solve()
    stats = solver.get_stats() # dictionary of solver statistics

    return is_sat, stats

def from_lit_to_idx(lit, n_vars):
        # from a literal in range {1,...n_vars,-1,...,-n_vars} get the literal
        # index in {0,...,n_lits-1} = {0,...,2*n_vars-1}
        # if l is positive l <- l-1
        # if l in negative l <- n_vars-l-1
        assert(lit!=0)
        if lit > 0 :
            return lit - 1
        if lit < 0 :
            return n_vars - lit - 1

def from_index_to_lit(idx, n_vars):
    # inverse of 'from_lit_to_idx', just in case
    if idx < n_vars:
        return idx+1
    else:
        return n_vars-idx-1

In [8]:
raw_path = "temp/cnfs/selsam_3_40/test/sr_n=0040_pk2=0.30_pg=0.40_t=0_sat=1.dimacs"

n_vars, clauses = parse_dimacs(raw_path)
y, _ = solve_sat(n_vars, clauses)
problem = create_problem(n_vars, clauses, y)

single_loader = DataLoader([problem], follow_batch=['x_l','x_c'], batch_size=1)

with torch.no_grad():
    res = model(next(iter(single_loader)).to(device), 80)

print(list(res.keys()))

Problem(x_l=[80, 16], x_c=[276, 16], y=True, sat_assignment=[40], sampled_solutions=[40, 10], num_literals=80, num_clauses=276, edge_index=[2, 1169], adj_t=[276, 80, nnz=1169], num_vars=40, num_nodes=356, clauses=[276])
['final_lits_votes', 'final_lits_mats', 'vote_mean_pool', 'final_truth_assignment', 'each_step_truth_assignments', 'clause_embs_all_steps', 'initial_truth_assignment']


In [15]:
bin_assignment = res["each_step_truth_assignments"][5][0]
bin_assignment = bin_assignment[:bin_assignment.shape[0]//2]
bin_assignment = torch.sign(bin_assignment).cpu().squeeze(1).numpy()

In [16]:
result_literals = []
for ix, assignment in enumerate(bin_assignment):
    result_literals.append(int((ix+1) * assignment))

In [17]:
sat_num = 0
for c in problem.clauses:
    for lit in c:
        if lit in result_literals:
            sat_num +=1
            break
print(len(problem.clauses)-sat_num)

10


In [136]:
import os
test_dir = "temp/cnfs/3sat_100_400/test/"
for filename in os.listdir(test_dir):
    if filename.endswith("1.dimacs"):
        filepath = os.path.join(test_dir, filename)

temp/cnfs/3sat_100_400/test/cnt=38_cls=112_var=28_sat=1.dimacs
temp/cnfs/3sat_100_400/test/cnt=23_cls=208_var=52_sat=1.dimacs
temp/cnfs/3sat_100_400/test/cnt=70_cls=228_var=57_sat=1.dimacs
temp/cnfs/3sat_100_400/test/cnt=90_cls=164_var=41_sat=1.dimacs
temp/cnfs/3sat_100_400/test/cnt=55_cls=132_var=33_sat=1.dimacs
temp/cnfs/3sat_100_400/test/cnt=9_cls=396_var=99_sat=1.dimacs
temp/cnfs/3sat_100_400/test/cnt=59_cls=248_var=62_sat=1.dimacs
temp/cnfs/3sat_100_400/test/cnt=26_cls=200_var=50_sat=1.dimacs
temp/cnfs/3sat_100_400/test/cnt=88_cls=96_var=24_sat=1.dimacs
temp/cnfs/3sat_100_400/test/cnt=62_cls=332_var=83_sat=1.dimacs
temp/cnfs/3sat_100_400/test/cnt=67_cls=432_var=108_sat=1.dimacs
temp/cnfs/3sat_100_400/test/cnt=21_cls=128_var=32_sat=1.dimacs
temp/cnfs/3sat_100_400/test/cnt=34_cls=228_var=57_sat=1.dimacs
temp/cnfs/3sat_100_400/test/cnt=99_cls=200_var=50_sat=1.dimacs
temp/cnfs/3sat_100_400/test/cnt=22_cls=292_var=73_sat=1.dimacs
temp/cnfs/3sat_100_400/test/cnt=78_cls=252_var=63_sat=1.

In [166]:
for _ in range(200):
    raw_path = "temp/cnfs/3sat_100_400/val/cnt=33_cls=400_var=100_sat=1.dimacs"

    n_vars, clauses = parse_dimacs(raw_path)
    y, _ = solve_sat(n_vars, clauses)
    problem = create_problem(n_vars, clauses, y)

    single_loader = DataLoader([problem], follow_batch=['x_l','x_c'], batch_size=1)

    with torch.no_grad():
        res = model(next(iter(single_loader)).to(device), 300)


    gap_by_iter = []
    for it in range(300):
        bin_assignment = res["each_step_truth_assignments"][it][0]
        bin_assignment = bin_assignment[:bin_assignment.shape[0]//2]
        bin_assignment = torch.sign(bin_assignment).cpu().squeeze(1).numpy()

        result_literals = []
        for ix, assignment in enumerate(bin_assignment):
            result_literals.append(int((ix+1) * assignment))
            sat_num = 0
        for c in problem.clauses:
            for lit in c:
                if lit in result_literals:
                    sat_num +=1
                    break
        gap_by_iter.append(len(problem.clauses)-sat_num)
    print(gap_by_iter[0])

Problem(x_l=[200, 16], x_c=[400, 16], y=True, sat_assignment=[100], sampled_solutions=[100, 10], num_literals=200, num_clauses=400, edge_index=[2, 1200], adj_t=[400, 200, nnz=1200], num_vars=100, num_nodes=600, clauses=[400])
56
Problem(x_l=[200, 16], x_c=[400, 16], y=True, sat_assignment=[100], sampled_solutions=[100, 10], num_literals=200, num_clauses=400, edge_index=[2, 1200], adj_t=[400, 200, nnz=1200], num_vars=100, num_nodes=600, clauses=[400])
59
Problem(x_l=[200, 16], x_c=[400, 16], y=True, sat_assignment=[100], sampled_solutions=[100, 10], num_literals=200, num_clauses=400, edge_index=[2, 1200], adj_t=[400, 200, nnz=1200], num_vars=100, num_nodes=600, clauses=[400])
60
Problem(x_l=[200, 16], x_c=[400, 16], y=True, sat_assignment=[100], sampled_solutions=[100, 10], num_literals=200, num_clauses=400, edge_index=[2, 1200], adj_t=[400, 200, nnz=1200], num_vars=100, num_nodes=600, clauses=[400])
57
Problem(x_l=[200, 16], x_c=[400, 16], y=True, sat_assignment=[100], sampled_solution

KeyboardInterrupt: 

In [159]:
gap_by_iter

[66,
 44,
 32,
 44,
 34,
 49,
 33,
 44,
 35,
 40,
 37,
 41,
 35,
 41,
 35,
 44,
 37,
 45,
 37,
 49,
 35,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49,
 36,
 49]

In [132]:
np.min(gap_by_iter)

1

In [131]:
for ix, num in enumerate(gap_by_iter):
    print(f"Iter: {ix} gap: {num}")

Iter: 0 gap: 38
Iter: 1 gap: 49
Iter: 2 gap: 32
Iter: 3 gap: 26
Iter: 4 gap: 31
Iter: 5 gap: 26
Iter: 6 gap: 29
Iter: 7 gap: 25
Iter: 8 gap: 14
Iter: 9 gap: 17
Iter: 10 gap: 14
Iter: 11 gap: 11
Iter: 12 gap: 12
Iter: 13 gap: 9
Iter: 14 gap: 7
Iter: 15 gap: 10
Iter: 16 gap: 11
Iter: 17 gap: 8
Iter: 18 gap: 7
Iter: 19 gap: 8
Iter: 20 gap: 7
Iter: 21 gap: 7
Iter: 22 gap: 6
Iter: 23 gap: 7
Iter: 24 gap: 6
Iter: 25 gap: 7
Iter: 26 gap: 11
Iter: 27 gap: 8
Iter: 28 gap: 6
Iter: 29 gap: 7
Iter: 30 gap: 6
Iter: 31 gap: 7
Iter: 32 gap: 6
Iter: 33 gap: 7
Iter: 34 gap: 5
Iter: 35 gap: 7
Iter: 36 gap: 6
Iter: 37 gap: 8
Iter: 38 gap: 7
Iter: 39 gap: 9
Iter: 40 gap: 5
Iter: 41 gap: 10
Iter: 42 gap: 4
Iter: 43 gap: 7
Iter: 44 gap: 3
Iter: 45 gap: 4
Iter: 46 gap: 2
Iter: 47 gap: 3
Iter: 48 gap: 2
Iter: 49 gap: 4
Iter: 50 gap: 4
Iter: 51 gap: 4
Iter: 52 gap: 4
Iter: 53 gap: 4
Iter: 54 gap: 4
Iter: 55 gap: 4
Iter: 56 gap: 4
Iter: 57 gap: 2
Iter: 58 gap: 2
Iter: 59 gap: 2
Iter: 60 gap: 4
Iter: 61 gap: 4
I

In [148]:
import os
test_dir = "temp/cnfs/3sat_100_400/val/"
all_gaps = []
starting_gap_value = []
for filename in os.listdir(test_dir):
    if filename.endswith("1.dimacs"):
        raw_path = filepath = os.path.join(test_dir, filename)
        n_vars, clauses = parse_dimacs(raw_path)
        y, _ = solve_sat(n_vars, clauses)
        problem = create_problem(n_vars, clauses, y)

        single_loader = DataLoader([problem], follow_batch=['x_l','x_c'], batch_size=1)

        with torch.no_grad():
            res = model(next(iter(single_loader)).to(device), 80)


        gap_by_iter = []
        for it in range(80):
            bin_assignment = res["each_step_truth_assignments"][it][0]
            bin_assignment = bin_assignment[:bin_assignment.shape[0]//2]
            bin_assignment = torch.sign(bin_assignment).cpu().squeeze(1).numpy()

            result_literals = []
            for ix, assignment in enumerate(bin_assignment):
                result_literals.append(int((ix+1) * assignment))
                sat_num = 0
            for c in problem.clauses:
                for lit in c:
                    if lit in result_literals:
                        sat_num +=1
                        break
            gap_by_iter.append(len(problem.clauses)-sat_num)
        starting_gap_value.append(gap_by_iter[0])
        all_gaps.append(np.min(gap_by_iter))

temp/cnfs/3sat_100_400/val/cnt=33_cls=400_var=100_sat=1.dimacs


In [145]:
np.mean(all_gaps)

2.3516483516483517

In [146]:
all_gaps

[34,
 1,
 2,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 22,
 0,
 15,
 1,
 0,
 2,
 1,
 1,
 4,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 1,
 12,
 1,
 0,
 11,
 0,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 26,
 9,
 0,
 1,
 0,
 1,
 1,
 0,
 1,
 1,
 3,
 2,
 0,
 0,
 1,
 23,
 0,
 21,
 3,
 0,
 2,
 1,
 0,
 0,
 1,
 0,
 0,
 23,
 2,
 0,
 28,
 0,
 1,
 1,
 0,
 10,
 1,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 3,
 7,
 2,
 1,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 3,
 0,
 1,
 10,
 1,
 1,
 0,
 2,
 0,
 0,
 0,
 0,
 2,
 2,
 0,
 0,
 0,
 1,
 3,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 5,
 1,
 11,
 0,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 4,
 14,
 0,
 1,
 1,
 2,
 0,
 0,
 2,
 28,
 0,
 5,
 0,
 0,
 1,
 1,
 0,
 2,
 1,
 0,
 0,
 0,
 0,
 4,
 3,
 1,
 1,
 3,
 0,
 0,
 0,
 1,
 0,
 4,
 0,
 1,
 1,
 0]

In [147]:
starting_gap_value

[63,
 34,
 48,
 38,
 42,
 45,
 53,
 28,
 51,
 49,
 52,
 53,
 54,
 52,
 48,
 53,
 44,
 43,
 48,
 45,
 50,
 45,
 57,
 58,
 40,
 45,
 66,
 47,
 49,
 56,
 57,
 40,
 40,
 62,
 53,
 48,
 48,
 34,
 50,
 54,
 52,
 50,
 50,
 50,
 45,
 47,
 47,
 50,
 52,
 57,
 49,
 42,
 49,
 47,
 59,
 65,
 50,
 52,
 32,
 41,
 42,
 50,
 60,
 55,
 44,
 59,
 57,
 58,
 56,
 60,
 61,
 47,
 46,
 51,
 45,
 47,
 58,
 49,
 46,
 41,
 51,
 63,
 47,
 39,
 54,
 52,
 64,
 58,
 60,
 44,
 65,
 43,
 49,
 51,
 47,
 40,
 36,
 47,
 48,
 45,
 37,
 49,
 53,
 59,
 51,
 47,
 47,
 48,
 56,
 52,
 47,
 58,
 45,
 46,
 63,
 44,
 41,
 53,
 47,
 50,
 53,
 53,
 49,
 45,
 54,
 60,
 46,
 53,
 54,
 50,
 51,
 45,
 50,
 56,
 46,
 56,
 51,
 60,
 41,
 43,
 48,
 48,
 48,
 59,
 52,
 49,
 53,
 39,
 61,
 52,
 63,
 41,
 56,
 50,
 51,
 39,
 46,
 53,
 41,
 41,
 55,
 53,
 57,
 48,
 48,
 46,
 55,
 50,
 48,
 48,
 51,
 55,
 50,
 45,
 58,
 47,
 40,
 51,
 46,
 51,
 51,
 38]