In [4]:
from sage.sat.solvers.dimacs import DIMACS
from sage.all import *
from sage.combinat.posets.poset_examples import Posets
import itertools

In [5]:
from enum import Enum

class VarType(Enum):
    '''Kinds of SAT variables that a variable getter can be asked for.'''

    # "a < b in the i-th linear extension"; looked up with arguments (a, b, i).
    IsLess = 1

In [6]:
def construct_variable_map(poset, linear_extensions_count):
    '''
    is_lower[a, b, i] > 0 means that a < b in the i-th linear extension
    '''
    variable_idx = 1
    
    is_lower = {}    
    for a,b in itertools.combinations(poset, int(2)):
        for i in range(linear_extensions_count):
            is_lower[a,b,i] = variable_idx
            variable_idx += 1
    
    def is_less_variable_getter(*args):
        a, b, i = args
        if b < a:
            return -is_lower[b,a,i]
        return is_lower[a,b,i]

    def variable_getter(VarType, *args):
        match VarType:
            case VarType.IsLess: return is_less_variable_getter(*args)
            case _: raise Exception("Unknown variable type")
        
    return variable_getter

In [7]:
def get_transitivity_clauses(poset, linear_extensions_count, var_getter):
    '''
    Build the clauses that make "<" transitive inside every linear extension.

    For each ordered triple (a, b, c) of distinct elements and each extension
    index i, emits the clause
        (not a < b) or (not b < c) or (a < c)
    which is the CNF form of "a < b and b < c implies a < c".

    Returns a list of 3-tuples of literals, ordered by triple (in
    itertools.permutations order) and then by extension index.
    '''

    def less(x, y, ext):
        # Literal for "x < y in extension ext".
        return var_getter(VarType.IsLess, x, y, ext)

    return [
        (-less(x, y, ext), -less(y, z, ext), less(x, z, ext))
        for x, y, z in itertools.permutations(poset, int(3))
        for ext in range(linear_extensions_count)
    ]

In [8]:
def generate_poset_clauses(boolean_graph, linear_extensions_count, var_getter):
    '''
    Build the clauses tying the linear extensions to the poset.

    1. Keep same order: for every edge (a, b) of boolean_graph (a < b in the
       poset) and every extension i, the unit clause "a < b in extension i".
    2. If a || b then in one linear extension a < b and in another b < a:
       for every pair a < b (by vertex label) with no edge (a, b), one clause
       "b < a in some extension" and one clause "a < b in some extension".

    boolean_graph -- object providing edges(labels=False), vertices() and
        has_edge(a, b); its edges are taken to be all strict relations a < b.
    linear_extensions_count -- number of linear extensions being encoded.
    var_getter -- callable as returned by construct_variable_map.

    Returns a list of tuples of literals: first the unit clauses (edges in
    sorted order, then by extension), then the two clauses of each
    incomparable pair (pairs in sorted order).
    '''
    extensions = range(linear_extensions_count)
    clauses = []

    # For each a and b that a < b in Boolean Lattice
    for a, b in sorted(boolean_graph.edges(labels=False)):
        for i in extensions:
            clauses.append((var_getter(VarType.IsLess, a, b, i),))

    # For each a and b that a || b in Boolean Lattice.
    # The vertex list is sorted once; combinations of a sorted list visit
    # exactly the pairs with a < b, in the same order as a nested loop would.
    vertices = sorted(boolean_graph.vertices())
    for a, b in itertools.combinations(vertices, int(2)):
        if boolean_graph.has_edge(a, b):
            continue
        a_below_b = [var_getter(VarType.IsLess, a, b, i) for i in extensions]
        # Some extension must put b before a ...
        clauses.append(tuple(-literal for literal in a_below_b))
        # ... and some extension must put a before b.
        clauses.append(tuple(a_below_b))

    return clauses

In [9]:
def get_boolean_graph(dim):
    '''
    Return the strict order relation of the Boolean lattice as a digraph.

    Takes every pair reported by Posets.BooleanLattice(dim).relations(),
    drops the pairs whose two entries are equal (presumably the reflexive
    relations x <= x) and builds a DiGraph from what is left.
    '''
    lattice = Posets.BooleanLattice(dim)
    strict_relations = [
        relation for relation in lattice.relations()
        if relation[0] != relation[1]
    ]
    return DiGraph(strict_relations)

In [10]:
def generate_clauses(dim, linear_extensions_count):
    '''
    Build the complete clause list for one problem instance.

    dim -- dimension passed to get_boolean_graph.
    linear_extensions_count -- number of linear extensions to encode.

    Returns the clauses from generate_poset_clauses followed by those from
    get_transitivity_clauses, all using one shared variable numbering.
    '''
    graph = get_boolean_graph(dim)
    getter = construct_variable_map(sorted(graph.vertices()), linear_extensions_count)

    order_clauses = generate_poset_clauses(graph, linear_extensions_count, getter)
    transitivity_clauses = get_transitivity_clauses(
        graph.vertices(), linear_extensions_count, getter
    )

    return order_clauses + transitivity_clauses

def save_problem(dim, linear_extensions_count, file_name):
    '''
    Generate the clauses for (dim, linear_extensions_count) and write them out.

    Each clause is added to a fresh DIMACS instance, whose clauses() method
    is then called with file_name so the instance is written to that file.
    '''
    all_clauses = generate_clauses(dim, linear_extensions_count)

    writer = DIMACS()
    for clause in all_clauses:
        writer.add_clause(clause)
    writer.clauses(file_name)

In [11]:
# Encode the Boolean lattice of dimension N with N linear extensions and
# write the resulting clauses to "dim.dimacs" (path relative to the working
# directory).
N = 4
save_problem(N, N, "dim.dimacs")

In [12]:
def to_str(x):
    '''
    Render a bitmask as the set of letters it encodes.

    Bit k of x (value 2**k) contributes the character chr(ord('a') + k), so
    1 -> 'a', 2 -> 'b', 5 -> 'ac'.  Letters appear in increasing bit order.
    When no bit is set (x <= 0) the empty set is shown as "(/)".
    '''
    letters = []

    bit_value = 1
    letter_code = ord('a')

    while bit_value <= x:
        if x & bit_value:
            letters.append(chr(letter_code))
        bit_value *= 2
        letter_code += 1

    # Decide emptiness from the collected letters themselves, not from
    # whether they happen to be alphabetic: bits past 'z' still denote a
    # non-empty set.
    return ''.join(letters) if letters else "(/)"

In [13]:
from functools import cmp_to_key

def recover_order(values, variables, n, lin_ext_idx):
    '''
    Rebuild one linear extension from a SAT assignment.

    values -- mapping from signed literal to its truth value.
    variables -- variable getter as returned by construct_variable_map.
    n -- the elements 0 .. 2**n - 1 are the ones being ordered.
    lin_ext_idx -- index of the linear extension to read.

    Returns the elements, smallest first according to the "a < b in
    extension lin_ext_idx" literals, each rendered with to_str.
    '''
    def by_assignment(x, y):
        # Negative result means x sorts before y.
        if values[variables(VarType.IsLess, x, y, lin_ext_idx)]:
            return -1
        return 1

    elements = list(range(1 << n))
    elements.sort(key=cmp_to_key(by_assignment))
    return [to_str(element) for element in elements]

In [14]:
def to_list(data):
    '''
    Flatten solver output into one list of literal tokens.

    Only lines whose first whitespace-separated field is "v" are used; the
    "v" itself is dropped and the remaining fields are appended as strings,
    in order.  Other lines (comments, status lines, blank lines) are ignored.

    The result starts with a placeholder 0 so that, for output listing the
    variables 1, 2, 3, ... in order, the token at index k belongs to
    variable k.
    '''
    tokens = [0]
    for line in data.splitlines():
        # split() with no argument tolerates repeated spaces and tabs.
        fields = line.split()
        if fields and fields[0] == 'v':
            tokens.extend(fields[1:])
    return tokens

In [15]:
def to_map(list_data):
    '''
    Turn a list of literal tokens into a truth table keyed by literal.

    list_data -- iterable of literals (ints or strings accepted by int()),
        e.g. the output of to_list.

    For every non-zero literal L the result maps L to True and -L to False,
    so both a variable and its negation can be looked up directly.  Zero
    entries (the leading placeholder and the terminating 0) are skipped.
    '''
    values = {}
    for token in list_data:
        literal = int(token)
        if literal == 0:
            continue
        # Key by the literal itself rather than by its position in the list,
        # so gaps or reordering in the input cannot mis-assign values.
        values[literal] = True
        values[-literal] = False
    return values

In [16]:
def parse_solution(n, data, linear_extensions_count=None):
    '''
    Decode solver output into the linear extensions it describes.

    n -- dimension of the Boolean lattice the problem was generated for.
    data -- solver output text, as accepted by to_list.
    linear_extensions_count -- number of linear extensions the problem was
        generated with; defaults to n.  It must match the value given to
        save_problem, because the variable numbering depends on it.

    Returns one list per linear extension, each listing the element labels
    (see to_str) from smallest to largest.
    '''
    if linear_extensions_count is None:
        linear_extensions_count = n

    variables = construct_variable_map(
        sorted(get_boolean_graph(n).vertices()), linear_extensions_count
    )
    values = to_map(to_list(data))
    return [
        recover_order(values, variables, n, j)
        for j in range(linear_extensions_count)
    ]

In [17]:
# Pasted model in "v <literals> ... 0" value-line format, covering variables
# 1..480 (120 element pairs x 4 linear extensions for N = 4).  Presumably the
# output of an external SAT solver run on dim.dimacs; it is not produced by
# this notebook.
data = """v 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
v 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
v 54 55 56 57 58 59 60 -61 62 63 64 65 66 67 68 -69 70 71 72 73 74 75 76 -77
v 78 79 80 81 82 83 84 -85 86 87 88 89 90 91 92 -93 94 95 96 97 98 99 100 -101
v 102 103 104 105 106 107 108 -109 110 111 112 113 114 115 116 117 118 119 120
v 121 -122 123 124 125 -126 127 128 129 130 131 132 133 134 135 136 137 -138
v 139 140 141 -142 143 144 145 146 147 148 149 150 151 152 153 -154 155 156
v 157 -158 159 160 161 162 163 164 165 166 167 168 -169 -170 171 172 173 -174
v 175 176 -177 178 179 180 181 182 183 184 -185 -186 187 188 189 -190 191 192
v -193 194 195 196 197 198 199 200 -201 -202 203 204 205 -206 207 208 -209 210
v 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229
v 230 -231 232 233 234 -235 236 237 238 -239 240 241 242 -243 244 245 246 247
v 248 249 250 251 252 253 254 255 256 257 258 259 260 -261 262 263 264 265 266
v 267 268 -269 270 -271 272 273 274 -275 276 -277 278 -279 280 281 282 -283
v 284 -285 286 287 288 289 290 291 292 -293 294 295 296 297 298 299 300 301
v 302 303 304 305 -306 -307 308 309 -310 -311 312 313 314 -315 316 317 318
v -319 320 321 -322 323 324 325 -326 327 328 329 330 331 332 333 334 335 336
v -337 -338 -339 340 341 -342 -343 344 -345 346 -347 348 349 350 -351 352 -353
v -354 355 356 357 -358 359 360 -361 362 363 364 365 366 367 368 369 370 371
v 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390
v 391 392 393 394 395 396 -397 398 399 400 401 402 403 404 -405 406 407 408
v 409 410 411 412 -413 414 415 416 417 418 419 420 421 422 423 424 425 -426
v 427 428 429 -430 431 432 433 434 435 436 437 438 439 440 -441 -442 443 444
v 445 -446 447 448 -449 450 451 452 453 454 455 456 457 458 459 460 461 462
v 463 464 465 466 467 468 -469 470 471 472 473 474 475 476 477 478 479 480 0"""

In [18]:
# Decode the pasted model: one list of element labels per linear extension.
parse_solution(4, data)

[['(/)',
  'b',
  'c',
  'bc',
  'd',
  'bd',
  'cd',
  'bcd',
  'a',
  'ab',
  'ac',
  'abc',
  'ad',
  'abd',
  'acd',
  'abcd'],
 ['(/)',
  'a',
  'c',
  'ac',
  'd',
  'ad',
  'cd',
  'acd',
  'b',
  'ab',
  'bc',
  'abc',
  'bd',
  'abd',
  'bcd',
  'abcd'],
 ['(/)',
  'a',
  'b',
  'ab',
  'd',
  'ad',
  'bd',
  'abd',
  'c',
  'ac',
  'bc',
  'abc',
  'cd',
  'acd',
  'bcd',
  'abcd'],
 ['(/)',
  'a',
  'b',
  'ab',
  'c',
  'ac',
  'bc',
  'abc',
  'd',
  'ad',
  'bd',
  'abd',
  'cd',
  'acd',
  'bcd',
  'abcd']]