Copyright **`(c)`** 2023 Giovanni Squillero `<giovanni.squillero@polito.it>`  
[`https://github.com/squillero/computational-intelligence`](https://github.com/squillero/computational-intelligence)  
Free for personal or classroom use; see [`LICENSE.md`](https://github.com/squillero/computational-intelligence/blob/master/LICENSE.md) for details.  

In [1197]:
from random import random
from functools import reduce
from collections import namedtuple
from queue import PriorityQueue

import numpy as np

In [1198]:
PROBLEM_SIZE = 20
NUM_SETS = 40
SETS = [
    np.array([random() < 0.3 for _ in range(PROBLEM_SIZE)]) for _ in range(NUM_SETS)
]
# SETS = [np.array([True, True, True, True, True, True, False, False, False]),
#         np.array([True, True, True, True, False, False, False, False, True]),
#         np.array([False, False, False, False, True, True, True, True, False])]
State = namedtuple("State", ["taken", "not_taken"])

In [1199]:
def goal_check(state):
    return np.all(
        reduce(
            np.logical_or,
            [SETS[i] for i in state.taken],
            np.array([False for _ in range(PROBLEM_SIZE)]),
        )
    )


def distance(state):
    return PROBLEM_SIZE - sum(
        reduce(
            np.logical_or,
            [SETS[i] for i in state.taken],
            np.array([False for _ in range(PROBLEM_SIZE)]),
        )
    )


def g(state, action):
    return sum(
        reduce(
            np.logical_and,
            reduce(
                np.logical_or,
                [SETS[i] for i in state],
                np.array([False for _ in range(PROBLEM_SIZE)]),
            ),
            SETS[action],
        )
    )

In [1200]:
assert goal_check(State(set(range(NUM_SETS)), set())), "Problem not solvable"

In [1201]:
def sorting_key(e):
    return e.sum()

def get_ordered_sets():
    SETS.sort(reverse=True, key=sorting_key)

In [1202]:
frontier = PriorityQueue()
solutions = PriorityQueue()
get_ordered_sets()

state = State(set(), set(range(NUM_SETS)))

frontier.put((distance(state), state))
wrong = []

counter = 0
_, current_state = frontier.get()

while counter==0 or not frontier.empty():
    if  solutions.empty() or len(current_state[0]) + 1 < solutions.queue[0][0]:
        counter += 1
        for action in current_state[1]:
            new_state = State(
                current_state.taken ^ {action},
                current_state.not_taken ^ {action},
            )
            
            if len(wrong)==0 or sorted(new_state[0]) not in wrong:
                g_res = g(current_state[0], action)
                d = distance(new_state)

                frontier.put((d + g_res, new_state))
    
                if d == 0:
                    print("find_solution")
                    solutions.put((len(new_state[0]), new_state[0]))
                    break


    if not solutions.empty() and (solutions.queue[0][0] == 1 or solutions.queue[0][0] == 2):
        break
    _, current_state = frontier.get()
    wrong.append(sorted(current_state[0]))
  
    

print(f"Solved in {counter:,} steps ({solutions.queue[0][0]} tiles)")

find_solution
Solved in 821 steps (4 tiles)


In [1203]:
solutions.queue

[(4, {0, 1, 2, 3})]

In [1204]:
SETS

[array([ True,  True, False,  True,  True, False, False, False, False,
         True,  True, False, False, False,  True,  True, False,  True,
        False,  True]),
 array([False,  True,  True, False,  True, False,  True, False,  True,
        False, False,  True,  True, False, False,  True, False, False,
        False,  True]),
 array([ True, False, False, False, False,  True,  True, False, False,
        False, False,  True,  True,  True,  True, False, False,  True,
         True, False]),
 array([False,  True, False, False,  True, False,  True,  True,  True,
        False,  True,  True, False, False, False, False,  True, False,
        False,  True]),
 array([ True, False, False, False,  True,  True, False,  True, False,
        False, False,  True, False, False,  True, False, False,  True,
         True,  True]),
 array([False, False,  True,  True, False,  True, False,  True, False,
        False,  True, False, False,  True,  True,  True, False,  True,
        False, False]),
 arr