In [1]:
from sage.sat.solvers.dimacs import DIMACS
from sage.all import *
from sage.combinat.posets.poset_examples import Posets
import itertools

In [2]:
# Adjust the heap-size limit of the ECL Lisp runtime embedded in Sage.
# NOTE(review): a limit of 0 presumably means "no limit" — confirm against the
# ECL manual for ext:set-limit. Avoiding heap exhaustion is the assumed motive.
import sage.libs.ecl
sage.libs.ecl.ecl_eval("(ext:set-limit 'ext:heap-size 0)")

<ECL: 0>

In [3]:
from enum import Enum

class VarType(Enum):
    '''Kinds of propositional variables used by the SAT encoding.'''

    IsLess = 1  # a precedes b in a given linear extension
    IsMore = 2  # a follows b in a given linear extension
    IsUsed = 3  # an element appears in a given linear extension

In [4]:
def construct_variable_map(poset, linear_extensions_count):
    '''
    is_lower[a, b, i]  > 0 means that a < b in the i-th linear extension
    is_higher[a, b, i] > 0 means that a > b in the i-th linear extension
    Is_used[a, i]      > 0 means that a is in the i-th linear extension
    '''
    variable_idx = 1
    
    is_lower = {}    
    for a,b in itertools.combinations(poset, int(2)):
        for i in range(linear_extensions_count):
            is_lower[a,b,i] = variable_idx
            variable_idx += 1

    is_higher = {}    
    for a,b in itertools.combinations(poset, int(2)):
        for i in range(linear_extensions_count):
            is_higher[a,b,i] = variable_idx
            variable_idx += 1

    is_used = {}    
    for a in poset:
        for i in range(linear_extensions_count):
            is_used[a, i] = variable_idx
            variable_idx += 1
    
    def is_less_variable_getter(*args):
        a, b, i = args
        if b < a:
            raise Exception("b < a")
        return is_lower[a,b,i]
    
    def is_more_variable_getter(*args):
        a, b, i = args
        if b < a:
            raise Exception("b < a")
        return is_higher[a,b,i]
    
    def is_used_variable_getter(*args):
        a, i = args
        return is_used[a, i]

    def variable_getter(VarType, *args):
        match VarType:
            case VarType.IsLess: return is_less_variable_getter(*args)
            case VarType.IsMore: return is_more_variable_getter(*args)
            case VarType.IsUsed: return is_used_variable_getter(*args)
            case _: raise Exception("Unknown variable type")
        
    return variable_getter

In [5]:
def get_transitivity_clauses(poset, linear_extensions_count, var_getter):
    '''
    Forbid cyclic orderings of a triple inside one linear extension.

    For every triple a, b, c taken from `poset` in iteration order, and every
    extension i, emit the implication
        (used a, b, c and a < b and b < c) -> a < c
    as the clause (-used_a, -used_b, -used_c, -(a<b), -(b<c), (a<c)).
    Then emit the same implication with '>' in place of '<'.
    All '<' clauses come first, followed by all '>' clauses.

    Together with the totality and exclusivity clauses this rules out both
    3-cycles. Each extension is therefore a linear order on its used elements.
    '''
    def implications(relation):
        emitted = []
        for a, b, c in itertools.combinations(poset, int(3)):
            for ext in range(linear_extensions_count):
                guards = tuple(-var_getter(VarType.IsUsed, v, ext) for v in (a, b, c))
                first = var_getter(relation, a, b, ext)
                second = var_getter(relation, b, c, ext)
                implied = var_getter(relation, a, c, ext)
                emitted.append(guards + (-first, -second, implied))
        return emitted

    return implications(VarType.IsLess) + implications(VarType.IsMore)

In [6]:
def generate_use_clauses(poset, linear_extensions_count, local_dim, var_getter):
    '''
    Tie the ordering variables to the "used" variables and bound usage.

    For every pair (a, b) from `poset` and every extension i:
      1. a < b  implies  a and b are both used
      2. a > b  implies  a and b are both used
      3. a and b both used  implies  a < b or a > b
    For every element a:
      4. a is used in at most `local_dim` extensions. This is encoded naively
         by forbidding every (local_dim + 1)-subset of extensions, so the
         number of clauses grows binomially with linear_extensions_count.

    Clauses are returned grouped in that order (1, 2, 3, 4).
    '''
    extensions = range(linear_extensions_count)

    def used(element, ext):
        return var_getter(VarType.IsUsed, element, ext)

    result = []

    # Rules 1 and 2: an ordering literal can only be true between used elements.
    for relation in (VarType.IsLess, VarType.IsMore):
        for a, b in itertools.combinations(poset, int(2)):
            for ext in extensions:
                ordered = var_getter(relation, a, b, ext)
                result.append((used(a, ext), -ordered))
                result.append((used(b, ext), -ordered))

    # Rule 3: two used elements must be ordered one way or the other.
    for a, b in itertools.combinations(poset, int(2)):
        for ext in extensions:
            result.append((
                -used(a, ext),
                -used(b, ext),
                var_getter(VarType.IsLess, a, b, ext),
                var_getter(VarType.IsMore, a, b, ext),
            ))

    # Rule 4: no local_dim + 1 extensions may all contain the same element.
    group_size = int(local_dim + 1)
    for a in poset:
        for group in itertools.combinations(extensions, group_size):
            result.append(tuple(-used(a, ext) for ext in group))

    return result

In [7]:
def generate_poset_clauses(boolean_graph, linear_extensions_count, var_getter):
    '''
    Encode the realizer conditions of the poset described by `boolean_graph`.

    An edge u -> v of the graph means u < v in the poset. The graph must
    contain every strict relation, not only cover relations;
    `get_boolean_graph` builds it that way. Otherwise, non-cover comparable
    pairs would be treated as incomparable.

    1. Comparable u < v: some extension puts u before v, and no extension
       puts v before u.
    2. Incomparable a || b: some extension puts a before b, and some
       extension puts b before a.

    Ordering variables exist only for pairs (x, y) with x < y as labels. An
    edge whose target has the smaller label is therefore expressed through
    the mirrored variable (IsMore instead of IsLess and vice versa). For the
    Boolean lattice here, elements are bitmasks and a subset is never
    numerically larger than its superset. Every edge goes from the smaller to
    the larger label, and the output is the same as before.

    Returns:
        list of clauses (tuples of signed variable indices).
    '''
    extensions = range(linear_extensions_count)
    clauses = []

    for u, v in sorted(boolean_graph.edges(labels=False)):
        if v < u:
            low, high, before, after = v, u, VarType.IsMore, VarType.IsLess
        else:
            low, high, before, after = u, v, VarType.IsLess, VarType.IsMore
        # At least one extension realizes u < v ...
        clauses.append(tuple(var_getter(before, low, high, i) for i in extensions))
        # ... and no extension realizes the opposite order (unit clauses).
        clauses.extend((-var_getter(after, low, high, i),) for i in extensions)

    # Sort once instead of once per outer iteration. combinations() over the
    # sorted list visits exactly the pairs a < b in the same order.
    vertices = sorted(boolean_graph.vertices())
    for a, b in itertools.combinations(vertices, int(2)):
        if boolean_graph.has_edge(a, b) or boolean_graph.has_edge(b, a):
            continue
        clauses.append(tuple(var_getter(VarType.IsLess, a, b, i) for i in extensions))
        clauses.append(tuple(var_getter(VarType.IsMore, a, b, i) for i in extensions))

    return clauses

In [8]:
def generate_less_more_clauses(poset, linear_extensions_count, var_getter):
    '''
    Make the two ordering variables of each pair mutually exclusive.

    Within one extension, "a < b" and "a > b" cannot both hold. For every
    pair (a, b) from `poset` and every extension this emits the clause
    (-(a<b), -(a>b)).
    '''
    return [
        (-var_getter(VarType.IsLess, a, b, ext), -var_getter(VarType.IsMore, a, b, ext))
        for a, b in itertools.combinations(poset, int(2))
        for ext in range(linear_extensions_count)
    ]

In [9]:
def get_boolean_graph(dim):
    '''
    Return the strict order relation of the Boolean lattice of dimension
    `dim` as a DiGraph.

    It has one edge x -> y per relation x < y. Reflexive pairs are dropped,
    but the relation is otherwise complete, not just the Hasse diagram.
    '''
    lattice = Posets.BooleanLattice(dim)
    strict_relations = [pair for pair in lattice.relations() if pair[0] != pair[1]]
    return DiGraph(strict_relations)

In [10]:
def generate_clauses(dim, linear_extensions_count, local_dim):
    '''
    Build the full CNF (a list of clause tuples).

    The CNF asks whether the Boolean lattice of dimension `dim` is realized by
    `linear_extensions_count` partial linear extensions in which every element
    appears at most `local_dim` times.
    '''
    graph = get_boolean_graph(dim)
    elements = sorted(graph.vertices())
    lookup = construct_variable_map(elements, linear_extensions_count)

    return [
        *get_transitivity_clauses(elements, linear_extensions_count, lookup),
        *generate_poset_clauses(graph, linear_extensions_count, lookup),
        *generate_use_clauses(elements, linear_extensions_count, local_dim, lookup),
        *generate_less_more_clauses(elements, linear_extensions_count, lookup),
    ]

def save_problem(dim, local_dim, linear_extensions_count, file_name):
    '''
    Generate the CNF for the given parameters and write it to `file_name`
    using Sage's DIMACS helper.
    '''
    cnf = generate_clauses(dim, linear_extensions_count, local_dim)

    writer = DIMACS()
    for clause in cnf:
        writer.add_clause(clause)
    # With a filename argument, DIMACS.clauses() writes the clause set to that
    # file in DIMACS format.
    writer.clauses(file_name)

In [11]:
# Boolean lattice of dimension 7; each element may appear in at most 5 of the
# 7 partial linear extensions.
save_problem(
    dim=7,
    local_dim=5,
    linear_extensions_count=7,
    file_name="dim757.dimacs",
)

In [12]:
def to_str(x):
    '''
    Render a Boolean-lattice element (a bitmask) as the set of its letters.

    Bit 0 maps to 'a', bit 1 to 'b', and so on; the empty set is rendered as
    "(/)". For example, to_str(5) == "ac" and to_str(0) == "(/)".
    Bits beyond the 26th map to the characters after 'z' in Unicode order.
    Such sets are still non-empty and are rendered as such.
    '''
    letters = []
    bit = 1
    offset = 0
    while bit <= x:
        if x & bit:
            letters.append(chr(ord('a') + offset))
        bit <<= 1
        offset += 1
    # Test emptiness directly. Checking isalpha() mislabelled sets whose only
    # bits map past 'z' as empty.
    return "".join(letters) if letters else "(/)"

In [13]:
from functools import cmp_to_key

def recover_order(values, variables, n, lin_ext_idx):
    '''
    Reconstruct one partial linear extension from a decoded SAT model.

    Args:
        values: literal -> truth value lookup (as built by `to_map`). An entry
            counts as true when it compares > 0 (True, or a positive int).
        variables: the getter returned by `construct_variable_map`.
        n: dimension of the Boolean lattice; candidate elements are
            0 .. 2**n - 1.
        lin_ext_idx: index of the extension to recover.

    Returns:
        The elements used in that extension, sorted from smallest to largest
        according to the model's IsLess / IsMore variables.

    Raises:
        ValueError: if two used elements are neither "less" nor "more" in the
            model, which the totality clauses forbid.
    '''
    def holds(variable):
        return values[variable] > 0

    used = [
        element
        for element in range(1 << n)
        if holds(variables(VarType.IsUsed, element, lin_ext_idx))
    ]

    def compare(a, b):
        # Ordering variables exist only for (low, high) with low < high. Look
        # the pair up in that orientation and flip the sign if it was reversed.
        if a <= b:
            low, high, sign = a, b, 1
        else:
            low, high, sign = b, a, -1
        if holds(variables(VarType.IsLess, low, high, lin_ext_idx)):
            return -sign
        if holds(variables(VarType.IsMore, low, high, lin_ext_idx)):
            return sign
        # Falling through would return None, and sorted() would then fail
        # with an opaque TypeError.
        raise ValueError(
            f"elements {a} and {b} are unordered in linear extension {lin_ext_idx}"
        )

    return sorted(used, key=cmp_to_key(compare))

In [14]:
def to_list(data):
    '''
    Flatten a SAT solver model into a list indexed by variable number.

    The first token of every line is dropped (presumably the "v" marker of
    SAT-competition output — confirm for the solver used), and the remaining
    tokens are concatenated. Index 0 holds a placeholder 0, so when the solver
    lists variables 1..N in order, entry i is the literal of variable i.
    Tokens stay strings, and the trailing "0" terminator is kept. Blank lines,
    comment lines ("c ...") and status lines ("s ...") are skipped.
    '''
    literals = [0]
    for line in data.splitlines():
        # split() without an argument also copes with repeated spaces or tabs.
        tokens = line.split()
        if not tokens or tokens[0].startswith(("c", "s")):
            continue
        # Extend in place; sum(lists, []) copied the accumulator every line,
        # which is quadratic for large models.
        literals.extend(tokens[1:])
    return literals

In [15]:
def to_map(list_data):
    '''
    Turn a solver model into a literal -> truth-value lookup.

    Args:
        list_data: literals as produced by `to_list` (strings or ints, e.g.
            [0, '1', '-2', '3', '0']). Zero entries, i.e. the placeholder and
            the DIMACS terminator, are ignored.

    Returns:
        A dict covering every variable v present in the model:
        values[v] is True iff v is assigned true, and values[-v] is its
        negation.

    The variable is taken from the literal itself (its absolute value) rather
    than from its position in the list. Models that do not list variables
    contiguously from 1 are therefore still decoded correctly. Key 0 is never
    produced; positional decoding overwrote it via -0 == 0.
    '''
    values = {}
    for token in list_data:
        literal = int(token)
        if literal == 0:
            continue
        variable = abs(literal)
        values[variable] = literal > 0
        values[-variable] = literal < 0
    return values

In [16]:
def parse_solution(n, le_count, data, should_to_string=False):
    '''
    Decode solver output `data` into the `le_count` partial linear extensions
    of the Boolean lattice of dimension `n`.

    Returns one list per extension, each holding the used elements in
    increasing order. With should_to_string=True the elements are rendered as
    letter sets via `to_str`.
    '''
    elements = sorted(get_boolean_graph(n).vertices())
    lookup = construct_variable_map(elements, le_count)
    model = to_map(to_list(data))
    extensions = [recover_order(model, lookup, n, idx) for idx in range(le_count)]
    if not should_to_string:
        return extensions
    return [[to_str(element) for element in ext] for ext in extensions]


In [17]:
# Solver output to decode. This is presumably the model for the problem
# written to "dim757.dimacs" above — TODO confirm. The context manager closes
# the handle; a bare open(...).read() left that to the garbage collector.
file = "75.out"
with open(file, "r") as handle:
    data = handle.read()

In [18]:
# Decode all 7 partial linear extensions of the 7-dimensional Boolean lattice
# as lists of integer elements (pass should_to_string=True for letter labels).
parse_solution(7, 7, data, should_to_string=False)

[[0,
  65,
  72,
  4,
  68,
  12,
  73,
  13,
  16,
  80,
  20,
  24,
  88,
  28,
  69,
  92,
  17,
  25,
  89,
  29,
  93,
  2,
  66,
  3,
  18,
  19,
  82,
  22,
  67,
  83,
  7,
  10,
  26,
  86,
  90,
  30,
  27,
  71,
  31,
  78,
  79,
  91,
  87,
  94,
  95,
  32,
  40,
  36,
  96,
  48,
  52,
  112,
  100,
  116,
  34,
  98,
  38,
  99,
  50,
  102,
  42,
  103,
  114,
  115,
  54,
  118,
  119,
  104,
  106,
  56,
  121,
  43,
  123,
  44,
  108,
  46,
  110,
  111,
  60,
  62,
  63,
  124,
  127],
 [0,
  112,
  73,
  104,
  33,
  2,
  74,
  97,
  98,
  81,
  99,
  43,
  82,
  106,
  75,
  49,
  113,
  114,
  115,
  105,
  107,
  88,
  89,
  90,
  91,
  120,
  58,
  121,
  122,
  59,
  123,
  4,
  68,
  5,
  69,
  20,
  84,
  21,
  85,
  36,
  37,
  52,
  53,
  100,
  116,
  101,
  117,
  12,
  28,
  92,
  13,
  29,
  93,
  6,
  70,
  109,
  60,
  124,
  61,
  125,
  22,
  86,
  14,
  78,
  30,
  94,
  38,
  54,
  102,
  118,
  62,
  126,
  7,
  71,
  39,
  103,
  23,
  15,
  3