In [70]:
from random import random
from functools import reduce
from collections import namedtuple
from queue import PriorityQueue, SimpleQueue
import numpy as np

In [71]:
PROBLEM_SIZE = 10
NUM_SETS = 50
SETS = tuple(
    [
        np.array([random() < 0.3 for _ in range(PROBLEM_SIZE)])
        for _ in range(NUM_SETS)
    ]
)
# prob 30% to be true, 70% to be false. This is the problem space.
State = namedtuple("State", ["taken", "not_taken"])

In [72]:
# Returns the state covered by the taken sets
def overall_state_covered(taken: set[int]) -> list[bool]:
    return reduce(
        np.logical_or,
        [SETS[i] for i in taken],
        np.array([False for _ in range(PROBLEM_SIZE)]),
    )


# Returns True if all state are covered, False otherwise
def goal_check(state: State("taken", "not_taken")) -> bool:
    return np.all(overall_state_covered(state.taken))

In [73]:
# Is the problem solvable?
assert goal_check((State(set(range(NUM_SETS)), set()))), "Problem not solvable"

In [74]:
# Returns the number of sets taken, which is the number of steps done 'till now.
def g(state: State("taken", "not_taken")) -> int:
    return len(state.taken)


# Returns 0 if we are covering every state, we don't need any more step.
# Returns 1 otherwise.
# This function is optimistic because if we are not covering some state,
# it assumes that we will just need 1 more step to cover them.
def h(state: State("taken", "not_taken")) -> int:
    return 0 if goal_check(state) else 1


# Returns the sum of g and h.
def a_star_func(state: State("taken", "not_taken")) -> int:
    return g(state) + h(state)

In [75]:
# Returns True if the action doesn't add more covered set.
# Returns False if the action cover set that are currently not covered.
def is_overlapping(taken: set[int], action: int) -> bool:
    new_taken = taken | {action}
    state_covered_by_new_taken = overall_state_covered(new_taken)

    state_covered_by_taken = overall_state_covered(taken)

    if all(state_covered_by_new_taken == state_covered_by_taken):
        return True
    else:
        return False


def search(queue: object):
    frontier = queue
    state = State(set(), set(range(NUM_SETS)))

    frontier.put((a_star_func(state), state))

    counter = 0
    _, current_state = frontier.get()  # take the state

    # remove sets containing all false.
    for set_index in current_state.not_taken:
        if not any(SETS[set_index]):
            current_state.not_taken - {set_index}

    # check if state is the solution
    while not goal_check(current_state):
        counter += 1
        # for all actions that we can take now
        for action in current_state.not_taken:
            # if the action let us cover more sets
            if not is_overlapping(current_state.taken, action):
                new_state = State(
                    current_state.taken | {action},
                    current_state.not_taken - {action},
                )
                frontier.put((a_star_func(new_state), new_state))

        _, current_state = frontier.get()

    print(
        f"Solved in {counter:,} steps ({len(current_state.taken)} tiles), with state: {current_state.taken}"
    )

In [76]:
print("a* solution: ")
search(PriorityQueue())

print()

print("breadth-first solution :")
search(SimpleQueue())

a* solution: 
Solved in 8 steps (2 tiles), with state: {8, 6}

breadth-first solution :
Solved in 300 steps (2 tiles), with state: {8, 6}
