In [1]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import pathlib
import glob
import numpy as np
import pysat 
from pysat.formula import CNF

import jax
import jax.numpy as jnp
import jax.scipy
import jax.scipy.optimize

import functools


In [16]:

def init_problem(
    cnf_problem: CNF,
    key: jnp.ndarray = None,
    batch_size: int = 10,
    single_device: bool = False,
):
    var_embedding = jax.random.normal(key, (batch_size, cnf_problem.nv))
    var_embedding = jax.nn.sigmoid(var_embedding)
    max_clause_len = max([len(clause) for clause in cnf_problem.clauses])
    num_clauses = len(cnf_problem.clauses)
    literal_tensor = jnp.array(
        [
            [c + (-1) ** (c > 0) for c in clause]
            + [num_clauses] * (max_clause_len - len(clause))
            for clause in cnf_problem.clauses
        ]
    )
    return var_embedding, literal_tensor

def compute_loss(
        params: jnp.ndarray,
        literal_tensor: jnp.ndarray,
    ):
        # params = jax.nn.sigmoid(params)
        x = jnp.take(params, jnp.abs(literal_tensor), fill_value=1.0, axis=0)
        x = jnp.where(literal_tensor > 0, x, 1 - x)
        x = jnp.prod(x, axis=-1)
        return jnp.square(x).sum()
        
@functools.partial(jax.pmap, in_axes=(0, None))
def scan_sat_solutions(
    assignment: jnp.ndarray,
    literal_tensor: jnp.ndarray,
):
    sat = jnp.take(assignment, jnp.abs(literal_tensor), fill_value=1, axis=1)
    sat = jnp.where(literal_tensor > 0, 1 - sat, sat)
    sat = jnp.all(jnp.any(sat > 0, axis=2), axis=1)
    satisfying_row_indices = jnp.where(
        sat, jnp.arange(sat.shape[0]), sat.shape[0] + 1
    )
    return jnp.take(assignment, satisfying_row_indices, axis=0, fill_value=-1)

def get_solutions(
    params: jnp.ndarray,
    literal_tensor: jnp.ndarray,
):
    assignment = (jax.nn.sigmoid(params) > 0.5).astype(int)
    if assignment.ndim == 2:
        assignment = jnp.expand_dims(assignment, axis=0)
    solutions = scan_sat_solutions(assignment, literal_tensor)
    solutions = solutions.reshape((-1, solutions.shape[-1]))
    # remove spurious solutions that are all -1's
    pruned_solutions = jnp.take(
        solutions, jnp.where(jnp.any(solutions >= 0, axis=1))[0], axis=0
    )
    return np.unique(pruned_solutions, axis=1)

In [23]:
# import test
BATCHSIZE=5
key = jax.random.PRNGKey(0)
cnf_file_paths = glob.glob("../data/counting_or/*.gz")
cnf_problem = CNF(from_file=cnf_file_paths[0])

var_embedding, literal_tensor = init_problem(cnf_problem, key, batch_size=BATCHSIZE)
scipy_var_embedding = var_embedding.copy()

In [30]:
var_embedding.shape

(5, 200)

In [25]:
# solve by BFGS
results_array = []
for i in range(BATCHSIZE):
    print(f'initial loss: {compute_loss(var_embedding[i], literal_tensor)}')
    results = jax.scipy.optimize.minimize(
        compute_loss,
        x0=var_embedding[i],
        args=(literal_tensor,),
        method="BFGS",
        tol=1e-4,
    )
    results_array.append(results.x)
    print(results.success)
    print(results.status)
    print(f'final loss: {results.fun}')
    print(results.nit, results.nfev, results.njev)
results_array = jnp.array(results_array)

initial loss: 40.41490936279297
False
3
final loss: 4.885860443115234
4 7 7
initial loss: 39.752296447753906
False
3
final loss: 0.00487174466252327
16 23 23
initial loss: 34.743404388427734
False
3
final loss: 0.34646981954574585
12 15 15
initial loss: 36.14906311035156
False
3
final loss: 12.705944061279297
2 4 4
initial loss: 37.093605041503906
False
3
final loss: 4.0333339711651206e-05
51 61 61


In [26]:
from scipy.optimize import minimize as scipy_minimize
scipy_results = []
for i in range(BATCHSIZE):
    print(f'initial loss: {compute_loss(scipy_var_embedding[i], literal_tensor)}')
    scipy_res = scipy_minimize(
        compute_loss,
        x0=scipy_var_embedding[i],
        args=(literal_tensor,),
        method="BFGS",
        tol=1e-5,
    )
    scipy_results.append(scipy_res.x)
    print(scipy_res.success)
    print(scipy_res.status)
    print(f'final loss: {scipy_res.fun}')
    print(scipy_res.nit, scipy_res.nfev, scipy_res.njev)
scipy_results = jnp.array(scipy_results)

initial loss: 40.41490936279297
True
0
final loss: 40.41490936279297
0 201 1
initial loss: 39.752296447753906
True
0
final loss: 39.752296447753906
0 201 1
initial loss: 34.743404388427734
True
0
final loss: 34.743404388427734
0 201 1
initial loss: 36.14906311035156
True
0
final loss: 36.14906311035156
0 201 1
initial loss: 37.093605041503906
True
0
final loss: 35.96630096435547
1 402 2


In [31]:
# check if the solution is correct
results_array
# solns = get_solutions(results_array, literal_tensor)
# print(solns)
# scipy_solns = get_solutions(scipy_results, literal_tensor)
# print(scipy_solns)

Array([[ 7.73107886e-01,  9.37994301e-01,  4.14450020e-01,
         4.24424827e-01,  9.70469654e-01,  9.37167048e-01,
         8.71247709e-01,  9.53394771e-01,  9.53251779e-01,
         6.72332346e-01,  9.33261871e-01,  6.69383526e-01,
         8.95459831e-01,  3.94907624e-01,  8.63713205e-01,
         4.90017772e-01,  4.91846055e-01,  4.76029694e-01,
         8.82066011e-01,  6.28501654e-01,  9.14022088e-01,
         9.49702621e-01,  3.36053014e-01,  7.16921926e-01,
         7.07220078e-01,  5.58445990e-01,  6.15132511e-01,
         7.22684920e-01,  7.31881022e-01,  8.12650979e-01,
         6.51732862e-01,  4.32724237e-01,  5.87022245e-01,
         9.35380936e-01,  1.10328937e+00,  1.05807543e+00,
         1.10664499e+00,  5.33801794e-01,  9.43049252e-01,
         7.19778836e-01,  7.30655909e-01, -3.32223535e-01,
         7.10221887e-01,  6.05317831e-01,  9.45877314e-01,
         2.27263734e-01,  3.85276109e-01,  6.29055351e-02,
         5.80417573e-01,  5.65227270e-01,  5.87910414e-0