In [532]:
from sage.sat.solvers.dimacs import DIMACS
from sage.all import *
from sage.combinat.posets.poset_examples import Posets
import itertools

In [533]:
from enum import Enum

class QT(Enum):
    """Kind of SAT variable requested from the getter built by construct_variable_map."""

    # getter(QT.Less, a, b, i) -> literal for "a < b in the i-th linear extension"
    Less = 1
    # getter(QT.YesString, *bits) -> variable for one 0/1 tuple (its meaning is still open)
    YesString = 2

In [534]:
def construct_variable_map(poset, linear_extensions_count):
    '''
    Allocate DIMACS variable ids and return a getter that maps queries to literals.

    Ids are handed out consecutively starting at 1:
      * first, one id per (unordered pair {a, b} of ``poset`` elements, extension i)
        for i in range(linear_extensions_count); the id is true iff a < b in the
        i-th linear extension, where (a, b) is the order itertools.combinations
        yields them from ``poset``;
      * then one id per 0/1 tuple of length ``linear_extensions_count``
        (the QT.YesString variables).

    The returned ``variable_getter(qtype, *args)`` accepts:
      * ``(QT.Less, a, b, i)`` -> literal meaning "a < b in extension i": the
        stored id if the pair was stored as (a, b), otherwise the negated id of
        the pair stored as (b, a). This works whether or not ``poset`` is sorted.
      * ``(QT.YesString, *bits)`` -> id of that 0/1 tuple.
    It raises KeyError for unknown elements / extension indices (or a == b) and
    ValueError for an unrecognized ``qtype``.
    '''
    is_lower = {}
    yes_map = {}

    variable_idx = 1

    for a, b in itertools.combinations(poset, int(2)):
        for i in range(linear_extensions_count):
            is_lower[a, b, i] = variable_idx
            variable_idx += 1

    for s in itertools.product([0, 1], repeat=int(linear_extensions_count)):
        yes_map[s] = variable_idx
        variable_idx += 1

    def variable_getter(qtype, *args):
        if qtype == QT.Less:
            a, b, i = args
            # Use the orientation the pair was actually stored in rather than
            # assuming ``poset`` was given in ascending order.
            if (a, b, i) in is_lower:
                return is_lower[a, b, i]
            return -is_lower[b, a, i]
        if qtype == QT.YesString:
            return yes_map[args]
        # Previously this fell through and returned None, which only failed later.
        raise ValueError("unknown variable type: {!r}".format(qtype))

    return variable_getter

In [535]:
def get_transitivity_clauses(poset, linear_extensions_count, variable_getter):
    '''
    CNF clauses forcing every linear extension to be transitive.

    The "Less" variables of one extension describe a tournament on ``poset``;
    a tournament is a linear order iff it has no directed 3-cycle. So for each
    unordered triple {a, b, c} and each extension i we forbid both 3-cycles:

        a < b, b < c, c < a   ->  (not a<b) or (not b<c) or a<c
        b < a, c < b, a < c   ->  a<b or b<c or (not a<c)

    Enumerating ordered triples (permutations) instead would emit each of these
    clauses three times, once per cyclic rotation of (a, b, c), so only
    combinations are enumerated.

    Returns a list of 3-tuples of DIMACS literals.
    '''
    clauses = []
    for a, b, c in itertools.combinations(poset, int(3)):
        for i in range(linear_extensions_count):
            xab = variable_getter(QT.Less, a, b, i)
            xbc = variable_getter(QT.Less, b, c, i)
            xac = variable_getter(QT.Less, a, c, i)
            clauses.append((-xab, -xbc, xac))  # forbid a < b < c < a
            clauses.append((xab, xbc, -xac))   # forbid a > b > c > a
    return clauses

In [536]:
# TODO: the intended constraint is still unclear. An earlier draft also built a
# clause over every non-all-ones bit string. This helper is not currently called
# (its use in generate_clauses is commented out).
def OtherThanDim(pset, ldim, vn):
    '''
    Return one unit clause forcing the QT.YesString variable of the all-ones
    tuple (1, 1, ..., 1) of length ``ldim`` to be true.

    ``pset`` is accepted but not used.
    '''
    all_ones = (1,) * ldim
    return [(vn(QT.YesString, *all_ones),)]

In [537]:
#OtherThanDim(3, ConstructVmap([1,2,3], 3))

In [538]:
def as_poset_elem(x):
    '''
    Render a Boolean-lattice element (a bitmask) as a string of letters.

    Bit k of ``x`` (k = 0, 1, ...) contributes the letter chr(ord('a') + k),
    so 0b101 -> "ac". Returns "(/)" (the empty set) when no alphabetic letter
    was produced, e.g. for x == 0 or a negative x.
    '''
    letters = []
    bit = 0
    while (1 << bit) <= x:
        if x & (1 << bit):
            letters.append(chr(ord('a') + bit))
        bit += 1

    label = ''.join(letters)
    return label if any(ch.isalpha() for ch in label) else "(/)"

In [539]:
def generate_poset_clauses(boolean_graph, linear_extensions_count, var_getter):
    '''
    Clauses tying the linear extensions to the poset described by ``boolean_graph``.

    * For every edge a -> b of ``boolean_graph`` (a < b in the poset): a unit
      clause "a < b" for each extension i in range(linear_extensions_count).
    * For every edge (a, b) of the complement digraph: one clause saying that
      at least one extension places b before a.

    When ``boolean_graph`` holds every strict relation (as get_boolean_graph builds
    it from relations()), an incomparable pair contributes both (a, b) and
    (b, a) to the complement, so each order must occur in some extension. A
    comparable pair a < b contributes only the reversed (b, a), and its clause
    ("some extension has a < b") is already implied by the unit clauses.
    '''
    unit_clauses = [
        (var_getter(QT.Less, a, b, i),)
        for a, b in sorted(boolean_graph.edges(labels=False))
        for i in range(linear_extensions_count)
    ]
    reversal_clauses = [
        tuple(-var_getter(QT.Less, a, b, i) for i in range(linear_extensions_count))
        for a, b in sorted(boolean_graph.complement().edges(labels=False))
    ]
    return unit_clauses + reversal_clauses

In [540]:
def get_boolean_graph(dim):
    '''
    DiGraph of the strict order of Sage's Boolean lattice of rank ``dim``.

    Every pair returned by relations() becomes an edge x -> y, except the
    reflexive pairs (x, x), which are dropped.
    '''
    lattice = Posets.BooleanLattice(dim)
    is_strict = lambda pair: pair[0] != pair[1]
    return DiGraph(list(filter(is_strict, lattice.relations())))

In [541]:
def generate_clauses(dim, linear_extensions_count):
    '''
    Build the full CNF for the Boolean lattice of rank ``dim`` with
    ``linear_extensions_count`` linear extensions.

    Returns a list of tuples of DIMACS literals: the poset clauses
    (generate_poset_clauses) followed by the transitivity clauses
    (get_transitivity_clauses).
    '''
    boolean_graph = get_boolean_graph(dim)
    # Sort the vertices once and reuse the list. The original called vertices()
    # twice, and its default ordering / sort argument differs across Sage versions.
    elements = sorted(boolean_graph.vertex_iterator())
    var_getter = construct_variable_map(elements, linear_extensions_count)

    clauses = generate_poset_clauses(boolean_graph, linear_extensions_count, var_getter)
    #clauses += OtherThanDim(elements, linear_extensions_count, var_getter)
    clauses += get_transitivity_clauses(elements, linear_extensions_count, var_getter)

    return clauses

def save_problem(dim, linear_extensions_count, file_name):
    '''
    Generate the CNF for (dim, linear_extensions_count) and hand it to Sage's
    DIMACS helper, asking it to write the clauses to ``file_name``.
    '''
    cnf = generate_clauses(dim, linear_extensions_count)

    dimacs_writer = DIMACS()
    for clause in cnf:
        dimacs_writer.add_clause(clause)
    # Passing a filename makes DIMACS.clauses write the collected clauses there.
    dimacs_writer.clauses(file_name)

In [542]:
# Write the CNF for the Boolean lattice of rank N with N linear extensions.
# Note: 2**7 = 128 lattice elements, so the transitivity part makes this file large.
N = 7
save_problem(dim=N, linear_extensions_count=N, file_name="dim.dimacs")

In [543]:
from functools import cmp_to_key

def RecoverOrder(val, vn, dim, lin_ext_idx):
    '''
    Read the ``lin_ext_idx``-th linear extension back out of a SAT model.

    val: dict mapping a literal to its truth value (e.g. built by ToMap).
    vn: variable getter returned by construct_variable_map.
    dim: Boolean lattice rank; the elements are the bitmasks 0 .. 2**dim - 1.

    Returns the elements in extension order, each rendered with as_poset_elem.
    '''
    def before(a, b):
        # a precedes b iff the "a < b in this extension" literal is true in the model.
        return -1 if val[vn(QT.Less, a, b, lin_ext_idx)] else 1

    in_order = sorted(range(1 << dim), key=cmp_to_key(before))
    return [as_poset_elem(element) for element in in_order]

In [544]:
# def PrintMasks(val, vn, ldim):
#     for mask in itertools.product([0,1], repeat=int(ldim)):
#         print(f"{mask} -> {'YES' if val[vn(QT.YesString, *mask)] else 'NO'}")

In [545]:
def ToList(data):
    '''
    Flatten a SAT solver model into a list of literal tokens.

    ``data`` is solver output whose model lines look like ``"v 1 -2 3 ... 0"``.
    Returns ``[0, tok1, tok2, ...]``: an int 0 pad at index 0, then the string
    tokens following the leading ``v`` of each model line, in order. The
    terminating ``"0"`` is kept. For a model that lists variables 1..n in
    order, index i therefore holds variable i's literal.

    Lines whose first token is not ``v`` (e.g. ``s SATISFIABLE`` or ``c``
    comments) are ignored. Tokens may be separated by any run of whitespace.
    '''
    tokens = [0]
    for line in data.splitlines():
        fields = line.split()  # any whitespace; never yields '' tokens
        if fields and fields[0] == 'v':
            tokens.extend(fields[1:])  # linear, unlike sum(list_of_lists, [])
    return tokens

In [546]:
def ToMap(xd):
    '''
    Turn a sequence of DIMACS literals into a literal -> truth-value dict.

    ``xd`` holds literals as ints or numeric strings (e.g. the output of ToList).
    For every nonzero literal l, ``val[abs(l)]`` is True iff l > 0 and
    ``val[-abs(l)]`` is its negation, so ``val[lit]`` answers "is literal lit
    true in the model?".

    Each literal is keyed by its own variable rather than by its position, so
    the result does not depend on the literals arriving in order 1..n. Zero
    entries (ToList's padding and the DIMACS terminator) are skipped instead of
    being turned into a phantom variable.
    '''
    val = {}
    for token in xd:
        lit = int(token)
        if lit == 0:
            continue
        var = abs(lit)
        val[var] = lit > 0
        val[-var] = lit < 0
    return val

In [547]:
# Model ("v" lines, terminated by 0) pasted from an external SAT solver run.
# It assigns variables 1..480, which matches the N = 4 setup decoded below:
# C(16, 2) = 120 element pairs x 4 extensions = 480 "Less" variables.
# NOTE(review): presumably the model of the dim-4 / 4-extension instance;
# it does not come from the N = 7 "dim.dimacs" written above. TODO confirm.
data = """v 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
v 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
v 54 55 56 57 58 59 60 -61 62 63 64 65 66 67 68 -69 70 71 72 73 74 75 76 -77
v 78 79 80 81 82 83 84 -85 86 87 88 89 90 91 92 -93 94 95 96 97 98 99 100 -101
v 102 103 104 105 106 107 108 -109 110 111 112 113 114 115 116 117 118 119 120
v 121 -122 123 124 125 -126 127 128 129 130 131 132 133 134 135 136 137 -138
v 139 140 141 -142 143 144 145 146 147 148 149 150 151 152 153 -154 155 156
v 157 -158 159 160 161 162 163 164 165 166 167 168 -169 -170 171 172 173 -174
v 175 176 -177 178 179 180 181 182 183 184 -185 -186 187 188 189 -190 191 192
v -193 194 195 196 197 198 199 200 -201 -202 203 204 205 -206 207 208 -209 210
v 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229
v 230 -231 232 233 234 -235 236 237 238 -239 240 241 242 -243 244 245 246 247
v 248 249 250 251 252 253 254 255 256 257 258 259 260 -261 262 263 264 265 266
v 267 268 -269 270 -271 272 273 274 -275 276 -277 278 -279 280 281 282 -283
v 284 -285 286 287 288 289 290 291 292 -293 294 295 296 297 298 299 300 301
v 302 303 304 305 -306 -307 308 309 -310 -311 312 313 314 -315 316 317 318
v -319 320 321 -322 323 324 325 -326 327 328 329 330 331 332 333 334 335 336
v -337 -338 -339 340 341 -342 -343 344 -345 346 -347 348 349 350 -351 352 -353
v -354 355 356 357 -358 359 360 -361 362 363 364 365 366 367 368 369 370 371
v 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390
v 391 392 393 394 395 396 -397 398 399 400 401 402 403 404 -405 406 407 408
v 409 410 411 412 -413 414 415 416 417 418 419 420 421 422 423 424 425 -426
v 427 428 429 -430 431 432 433 434 435 436 437 438 439 440 -441 -442 443 444
v 445 -446 447 448 -449 450 451 452 453 454 455 456 457 458 459 460 461 462
v 463 464 465 466 467 468 -469 470 471 472 473 474 475 476 477 478 479 480 0"""

In [548]:
# Integer literals from the solver's "v" lines (terminating 0 included), sorted
# by variable number. For a complete model, result_list[i] is the literal of
# variable i for i >= 1, and result_list[0] is the terminator 0.
# Parsing is linear (no sum(list_of_lists, [])), tolerates repeated whitespace,
# and ignores non-"v" lines such as "s SATISFIABLE".
result_list = sorted(
    (
        int(token)
        for fields in map(str.split, data.splitlines())
        if fields[:1] == ["v"]
        for token in fields[1:]
    ),
    key=abs,
)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, -61, 62, 63, 64, 65, 66, 67, 68, -69, 70, 71, 72, 73, 74, 75, 76, -77, 78, 79, 80, 81, 82, 83, 84, -85, 86, 87, 88, 89, 90, 91, 92, -93, 94, 95, 96, 97, 98, 99, 100, -101, 102, 103, 104, 105, 106, 107, 108, -109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, -122, 123, 124, 125, -126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, -138, 139, 140, 141, -142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, -154, 155, 156, 157, -158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, -169, -170, 171, 172, 173, -174, 175, 176, -177, 178, 179, 180, 181, 182, 183, 184, -185, -186, 187, 188, 189, -190, 191, 192, -193, 194, 195, 196, 197, 198, 199, 200, -201, -202, 203, 204, 205, -206, 207, 208, -209, 210, 211, 212, 213, 214, 215, 216,

In [549]:
# Truth value of each literal, keyed by position in the abs-sorted model
# (position i holds variable i's literal when the model lists 1..n).
# NOTE(review): `values` is not used below; ords re-parses `data` via ToMap(ToList(data)).
values = {pos: int(lit) > 0 for pos, lit in enumerate(result_list)}
values.update({-pos: int(lit) <= 0 for pos, lit in enumerate(result_list)})


# Decode the N linear extensions encoded in the model held in `data`.
N = 4
vn = construct_variable_map(sorted(get_boolean_graph(N).vertices()), N)
ords = [RecoverOrder(ToMap(ToList(data)), vn, N, ext) for ext in range(N)]
ords

[['(/)',
  'b',
  'c',
  'bc',
  'd',
  'bd',
  'cd',
  'bcd',
  'a',
  'ab',
  'ac',
  'abc',
  'ad',
  'abd',
  'acd',
  'abcd'],
 ['(/)',
  'a',
  'c',
  'ac',
  'd',
  'ad',
  'cd',
  'acd',
  'b',
  'ab',
  'bc',
  'abc',
  'bd',
  'abd',
  'bcd',
  'abcd'],
 ['(/)',
  'a',
  'b',
  'ab',
  'd',
  'ad',
  'bd',
  'abd',
  'c',
  'ac',
  'bc',
  'abc',
  'cd',
  'acd',
  'bcd',
  'abcd'],
 ['(/)',
  'a',
  'b',
  'ab',
  'c',
  'ac',
  'bc',
  'abc',
  'd',
  'ad',
  'bd',
  'abd',
  'cd',
  'acd',
  'bcd',
  'abcd']]