In [10]:
from pysat.formula import *
from pysat.solvers import Solver
from itertools import permutations
from disjoint_set import DisjointSet
import time
import pandas as pd
import json
import csv

In [2]:
def phpformula(p, h):
  clauses = []

  for i in range(1,p+1):
    clause = []
    for j in range(1,h+1):
      clause.append((i-1)*h+j)
    clauses.append(clause)

  for i in range(1,h+1):
    for j in range(1, p+1):
      for k in range(j+1,p+1):
        clauses.append([-(j-1)*h-i,-(k-1)*h-i])

  return clauses


In [3]:
def nqueens(n):
    clauses = []

    def var(r, c):
        return r * n + c + 1

    for r in range(n):
        clauses.append([var(r, c) for c in range(n)])
        for c1 in range(n):
            for c2 in range(c1 + 1, n):
                clauses.append([-var(r, c1), -var(r, c2)])

    for c in range(n):
        clauses.append([var(r, c) for r in range(n)])
        for r1 in range(n):
            for r2 in range(r1 + 1, n):
                clauses.append([-var(r1, c), -var(r2, c)])

    for d in range(-(n - 1), n):
        diag = [(r, r - d) for r in range(n) if 0 <= r - d < n]
        for i in range(len(diag)):
            for j in range(i + 1, len(diag)):
                r1, c1 = diag[i]
                r2, c2 = diag[j]
                clauses.append([-var(r1, c1), -var(r2, c2)])

    for s in range(2 * n - 1):
        diag = [(r, s - r) for r in range(n) if 0 <= s - r < n]
        for i in range(len(diag)):
            for j in range(i + 1, len(diag)):
                r1, c1 = diag[i]
                r2, c2 = diag[j]
                clauses.append([-var(r1, c1), -var(r2, c2)])

    return clauses


In [4]:
def get_matrix(m, n):
    variables = []
    idx = 1
    for i in range(m):
      variables.append([])
      for j in range(n):
          variables[i].append(idx)
          idx = idx + 1
    return variables

def get_perm(matrix, pi_r, pi_c):
  m = len(matrix)
  n = len(matrix[0])
  varperm = [[0 for _ in range(n)] for _ in range(m)]
  for i in range(m):
    for j in range(n):
        varperm[pi_r[i]][pi_c[j]] = matrix[i][j]
  return varperm

def get_clauses(matrix, matrixperm, kmax, aux):
  clauses = []
  a = []
  b = []  
  for i in range(len(matrix)):
    a = a + matrix[i]
    b = b + matrixperm[i]
  matlen = len(a)
    
  for i in range(len(a)):
    auxcurr = aux[0]
    clause = [-a[i],b[i]]
    ds = DisjointSet()
    if i > kmax:
      break
    ds.union('o', b[i])
    ds.union('i', a[i])
      
    if i > 0:
      for k in range(i):
        if a[k] != b[k]:
            ds.union(a[k], b[k])
            clause.append(auxcurr)
            if k == (i-1):
                clauses.append([-auxcurr,a[k],b[k]])
                clauses.append([-auxcurr,-a[k],-b[k]])
                clauses.append([auxcurr,-a[k],b[k]])
                clauses.append([auxcurr,a[k],-b[k]])
            auxcurr = auxcurr + 1
    if ds.connected('o', 'i'):
        continue
    
    clauses.append(clause)
  aux[0] = auxcurr+1
  return clauses

In [5]:
def swap_positions(p, i, j):
    p = list(p)
    p[i], p[j] = p[j], p[i]
    return p

def all_perm(n):
    return list(permutations(range(n)))

def transpositions(n):
    idp = list(range(n))
    result = []
    for i in range(n):
        for j in range(i+1, n):
            result.append(swap_positions(idp, i, j))
    return result

def neighbors(n):
    idp = list(range(n))
    result = []
    for i in range(n - 1):
        result.append(swap_positions(idp, i, i + 1))
    return result

def gen2(n):
  result = [list((i+1) % n for i in range(n))]
  result.append(swap_positions(list(range(n)), 0, 1))
  return result

def symmetry_clauses(matrix, symtype, perms, kmax):
  m = len(matrix)
  n = len(matrix[0])
  idpr = list(range(m))
  idpc = list(range(n))
  clauses = []
  aux = [m*n+1]  

  if symtype == "r":
    row = perms(m)

    for p in row:
      perm = get_perm(matrix, p, idpc)
      clauses += get_clauses(matrix, perm, kmax, aux)

  elif symtype == "c":
    col = perms(n)

    for p in col:
      perm = get_perm(matrix, idpr, p)
      clauses += get_clauses(matrix, perm, kmax, aux)

  elif symtype == "rc":
    row = perms(m)
    col = perms(n)
    row.append(idpr)
    col.append(idpc)
      
    for p in row:
      for q in col:
        if p == idpr and q == idpc:
            continue
        perm = get_perm(matrix, p, q)
        clauses += get_clauses(matrix, perm, kmax, aux)

  else:
    row = perms(m) # = col

    for p in row:
      for q in row:
        perm = get_perm(matrix, p, q)
        clauses += get_clauses(matrix, perm, kmax, aux)

  return clauses

In [54]:
rows = []
csv.field_size_limit(sys.maxsize)
with open("data.csv") as f:
    reader = csv.DictReader(f)
    for row in reader:
        rows.append([int(row["Rows"]),int(row["Columns"]),row["SymType"],json.loads(row["Clauses"])])

In [62]:
base_clauses = rows[0][3]
cnf = CNF(from_clauses=base_clauses)
s = time.time()
clauses = symmetry_clauses(get_matrix(rows[0][0],rows[0][1]), rows[0][2], gen2, 10)
for clause in clauses:
  cnf.append(clause)
  base_clauses.append(clause)
print("Symmetries: ", (time.time() - s) * 1e3, "ms")
print(" ")

s = time.time()
with Solver(bootstrap_with=cnf) as solver:
    print('formula is', f'{"S" if solver.solve() else "UNS"}ATisfiable')
    #print('and the model is:', solver.get_model())

print(" ")
print("SAT solver: ", (time.time() - s) * 1e3, "ms")

#df.at[0, "Clauses"] = str(base_clauses)
#df.to_csv("result.csv", index=False)

Symmetries:  1.6074180603027344 ms
 
formula is UNSATisfiable
 
SAT solver:  0.6322860717773438 ms


In [6]:
df = pd.read_csv("test.csv")
print(df)

   Rows  Columns SymType                                            Clauses
0    10       10      rc  [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [11, 12, 13,...
1    90       90      rc  [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 1...


In [56]:
cnf = CNF(from_clauses=nqueens(n))
s = time.time()
with Solver(bootstrap_with=cnf) as solver:
    print('formula is', f'{"S" if solver.solve() else "UNS"}ATisfiable')
    #print('and the model is:', solver.get_model())
print(" ")
print("Time: ", (time.time() - s) * 1e3, "ms")

formula is SATisfiable
 
Time:  1680.3762912750244 ms


In [63]:
pig = 30
hol = 30
n = 150
with open("data.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["Rows", "Columns","SymType","Clauses"])

    writer.writerow([pig,hol,"rc", json.dumps(phpformula(pig,hol))])
    writer.writerow([n,n,"rc", json.dumps(nqueens(n))])