In [1]:
import math
import numpy as np
import random
import os
import argparse
from pysat.solvers import Glucose3
from tqdm import tqdm

def write_dimacs_to(n_vars, iclauses, out_filename):
    with open(out_filename, 'w') as f:
        f.write("p cnf %d %d\n" % (n_vars, len(iclauses)))
        for c in iclauses:
            for x in c:
                f.write("%d " % x)
            f.write("0\n")

def mk_out_filenames(opts, n_vars, t):
    prefix = "%s/sr_n=%.4d_pk2=%.2f_pg=%.2f_t=%d" % \
        (opts['out_dir'], n_vars, opts['p_k_2'], opts['p_geo'], t)
    return ("%s_sat=0.dimacs" % prefix, "%s_sat=1.dimacs" % prefix)

def generate_k_iclause(n, k):
    vs = np.random.choice(n, size=min(n, k), replace=False)
    iclause_numpy = [v + 1 if random.random() < 0.5 else -(v + 1) for v in vs]
    
    return [int(lit) for lit in iclause_numpy]

def gen_iclause_pair(opts):
    n = random.randint(opts['min_n'], opts['max_n'])

    solver = Glucose3()
    #for i in range(n): solver.new_var(dvar=True)

    iclauses = []

    while True:
        k_base = 1 if random.random() < opts['p_k_2'] else 2
        k = k_base + np.random.geometric(opts['p_geo'])
        iclause = generate_k_iclause(n, k)

        solver.add_clause(iclause)
        is_sat = solver.solve()
        if is_sat:
            iclauses.append(iclause)
        else:
            break

    iclause_unsat = iclause
    iclause_sat = [- iclause_unsat[0] ] + iclause_unsat[1:]
    return n, iclauses, iclause_unsat, iclause_sat


def create_dataset(opts):
# create output directory..replace if it exists
    out_dir = opts['out_dir']
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    else:
        print("Output directory %s already exists. Replacing." % out_dir)
        os.system("rm -rf %s" % out_dir)
        os.makedirs(out_dir)
    for pair in tqdm(range(opts['n_pairs'])):
        n_vars, iclauses, iclause_unsat, iclause_sat = gen_iclause_pair(opts)
        out_filenames = mk_out_filenames(opts, n_vars, pair)

        iclauses.append(iclause_unsat)
        write_dimacs_to(n_vars, iclauses, out_filenames[0])

        iclauses[-1] = iclause_sat
        write_dimacs_to(n_vars, iclauses, out_filenames[1])


opts = {'max_n':20,
        'min_n':20,
        'p_k_2':0.3,
        'p_geo':0.4,
        'out_dir':"../temp/sr10/train",
        'n_pairs': 10000
}

create_dataset(opts)
opts['n_pairs']=500
opts['out_dir']="../temp/sr10/val"
create_dataset(opts)

Output directory ../temp/sr10/train already exists. Replacing.


100%|██████████| 10000/10000 [00:40<00:00, 249.62it/s]


Output directory ../temp/sr10/val already exists. Replacing.


100%|██████████| 500/500 [00:01<00:00, 250.40it/s]


In [3]:
from pathlib import Path
from pysat.formula import CNF
from tqdm import tqdm
from collections import defaultdict
import random
import pickle
from pysat.solvers import Glucose3
import numpy as np

def get_example_dic(formula):
    #max_clauses = 100
    clauses,solution = formula['clauses'],formula['solution']
    #clauses = clauses[:max_clauses]
    num_clauses = len(clauses)
    num_vars = len(solution)
    empty_token = -num_vars - 1
    solution_vec = np.zeros(num_vars)
    for i,s in enumerate(solution):
        solution_vec[abs(s)-1] = -1 if s < 0 else 1

    #dummy_solution_vec = np.zeros(num_vars+max_clauses)
    #solution_vec = np.concatenate([solution_vec,dummy_solution_vec])

    pos = [i+1 for i in range(num_vars)]
    neg = [-1*(i+1) for i in range(num_vars)]
    lits = pos+neg
    #print(lits)
    num_lits = len(lits)
    clause_inc = num_vars + 1 # positive indices + 1 dummy token
    clause_dic = {}
    lit2clauses = defaultdict(list)
    clause_tokens = []
    for i,cl in enumerate(clauses):
        cl_ix = i+clause_inc
        clause_tokens.append(cl_ix)
        clause_dic[cl_ix]=cl
        for l in cl:
            lit2clauses[l].append(cl_ix)
    tokens = np.array(lits+clause_tokens)

    max_lit_incidence = max([len(l) for l in lit2clauses.values()])
    max_clause_incidence = max([len(l) for l in clause_dic.values()])
    max_pos = max(max_lit_incidence,max_clause_incidence)
    # get pos matrix and mask matrix
    clause_pos = np.ones((max_pos,num_clauses)) * empty_token
    clause_mask = np.zeros((max_pos,num_clauses))
    lit_pos = np.ones((max_pos,num_lits)) * empty_token
    lit_mask = np.zeros((max_pos,num_lits))

    for k,v in lit2clauses.items():
        len_v = len(v)
        ix = lits.index(k)
        lit_pos[:len_v,ix] = v
        lit_mask[:len_v,ix] = 1

    for k,v in clause_dic.items():
        len_v = len(v)
        ix = clause_tokens.index(k)
        clause_pos[:len_v,ix] = v
        clause_mask[:len_v,ix] = 1

    pos_matrix = np.concatenate([lit_pos,clause_pos],axis=1) + num_vars + 1# shifting to positive indices
    mask_matrix = np.concatenate([lit_mask,clause_mask],axis=1)
    tokens = tokens + num_vars + 1  # shifting to positive indices
    tokens = tokens.astype(np.int64)
    pos_matrix = pos_matrix.astype(np.int64)
    # queries = np.zeros((len(tokens),300))
    # for i,tok in enumerate(tokens):
    #     queries[i,tok] = 1
    # keys = np.zeros((len(tokens),300))
    # for i,pos in enumerate(pos_matrix.T):
    #     for j,p in enumerate(pos):
    #         if p > 0:
    #             keys[i,p] = 1

    dic = {
        'pos_embeddings':pos_matrix,
        'masks':mask_matrix.astype(np.int64),
        'labels':solution_vec.astype(np.int64),
        'tokens':tokens,
        #'keys':keys,
        #'queries':queries,
    }
    return dic,max_lit_incidence,max_clause_incidence


def get_data_point(f):
    cnf = CNF(from_file=str(f))
    solver = Glucose3()
    for clause in cnf.clauses:
        solver.add_clause(clause)
    
    is_sat = solver.solve()
    if is_sat:
        solution = solver.get_model()
    else:
        return None
    return {
        'clauses': cnf.clauses,
        'n_vars': cnf.nv,
        'solution': solution,
        'is_sat': is_sat
    }
maxes = []
def get_numpy_data(path):
    out_dir = Path(path)
    files = list(out_dir.glob("*.dimacs"))
    formulas = []
    for i,f in enumerate(tqdm(files)):
        datapoint = get_data_point(f)  
        if datapoint == None or len(datapoint['clauses']) > 220:
            continue  
        example_dic,max_lit_incidence,max_clause_incidence = get_example_dic(datapoint)
        maxes.append((max_lit_incidence,max_clause_incidence))

        formulas.append(example_dic)
    dics = {}
    for i,f in enumerate(formulas):
        dics[i] = f
    return dics

val_path = "../temp/sr10/val"
train_path = "../temp/sr10/train"

val_formulas = get_numpy_data(val_path)
train_formulas = get_numpy_data(train_path)
print('len(train_formulas):',len(train_formulas))
print('len(val_formulas):',len(val_formulas))
with open("../temp/val_formulas.pkl", "wb") as f:
    pickle.dump(val_formulas, f)

with open("../temp/train_formulas.pkl", "wb") as f:
    pickle.dump(train_formulas, f)


  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [00:00<00:00, 1668.05it/s]
100%|██████████| 20000/20000 [00:09<00:00, 2043.05it/s]


len(train_formulas): 9994
len(val_formulas): 500


In [5]:
maxes

[(21, 12),
 (18, 11),
 (21, 13),
 (21, 13),
 (29, 13),
 (19, 11),
 (16, 14),
 (22, 12),
 (25, 15),
 (22, 11),
 (20, 9),
 (21, 12),
 (19, 13),
 (25, 11),
 (20, 12),
 (23, 10),
 (15, 15),
 (18, 14),
 (16, 14),
 (22, 9),
 (19, 12),
 (20, 14),
 (18, 9),
 (27, 17),
 (18, 11),
 (14, 14),
 (19, 13),
 (18, 12),
 (27, 11),
 (22, 14),
 (15, 10),
 (23, 15),
 (15, 12),
 (17, 10),
 (20, 12),
 (18, 17),
 (19, 13),
 (21, 12),
 (22, 12),
 (23, 11),
 (17, 17),
 (17, 10),
 (19, 10),
 (16, 13),
 (16, 13),
 (18, 9),
 (23, 17),
 (16, 17),
 (26, 14),
 (26, 14),
 (23, 13),
 (18, 11),
 (15, 11),
 (18, 13),
 (22, 12),
 (13, 11),
 (18, 9),
 (22, 12),
 (20, 12),
 (20, 9),
 (17, 11),
 (17, 14),
 (11, 11),
 (17, 16),
 (19, 10),
 (12, 13),
 (21, 11),
 (21, 11),
 (22, 13),
 (14, 10),
 (26, 15),
 (15, 13),
 (13, 12),
 (22, 16),
 (27, 11),
 (21, 13),
 (19, 12),
 (26, 20),
 (23, 16),
 (23, 19),
 (25, 17),
 (17, 14),
 (22, 14),
 (31, 14),
 (25, 13),
 (14, 10),
 (18, 13),
 (21, 9),
 (18, 20),
 (22, 15),
 (19, 12),
 (19, 

In [4]:
print('len(train_formulas):',len(train_formulas))
print('len(val_formulas):',len(val_formulas))

len(train_formulas): 9994
len(val_formulas): 500


In [5]:
f = "../temp/sr10/train/sr_n=0006_pk2=0.30_pg=0.40_t=1_sat=1.dimacs"
datapoint = get_data_point(f)  
example_dic,m1,m2  = get_example_dic(datapoint)
clauses = datapoint['clauses']
tokens= example_dic['tokens']
pos_embeddings = example_dic['pos_embeddings']



107


array([  8,   9,  10,  11,  12,  13,   6,   5,   4,   3,   2,   1,  14,
        15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
        28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,
        41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,
        54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,
        67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,  79,
        80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,
        93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105,
       106, 107, 108])

(226,)