In [152]:
from random import random
from functools import reduce
from collections import namedtuple
from queue import PriorityQueue, SimpleQueue
import numpy as np

In [153]:
PROBLEM_SIZE = 10
NUM_SETS = 50
SETS = tuple(
    [
        np.array([random() < 0.3 for _ in range(PROBLEM_SIZE)])
        for _ in range(NUM_SETS)
    ]
)
# prob 30% to be true, 70% to be false. This is the problem space.
State = namedtuple("State", ["taken", "not_taken"])

In [154]:
def goal_check(state):
    # this the test if we solve everything. We want each state to be covered. If there is an overlapping it's fine.
    # return np.all(reduce(np.logical_or, [SETS[i] for i in state.taken]))
    return np.all(
        reduce(
            np.logical_or,
            [SETS[i] for i in state.taken],
            np.array([False for _ in range(PROBLEM_SIZE)]),
        )
    )

In [155]:
assert goal_check((State(set(range(NUM_SETS)), set()))), "Problem not solvable"

In [156]:
def h(state):
    covered_sets = sum(
        reduce(
            np.logical_or,
            [SETS[i] for i in state.taken],
            np.array([False for _ in range(PROBLEM_SIZE)]),
        )
    )
    ret = 0 if covered_sets == PROBLEM_SIZE else 1
    return ret


# I want to check if what i add is covering some state that i'm not covering
# so i check if the state covered by my new solution (taken xor action)
def is_overlapping(taken, action):
    new_taken = taken ^ {action}
    state_covered_by_new_taken = reduce(
        np.logical_or,
        [SETS[i] for i in new_taken],
        np.array([False for _ in range(PROBLEM_SIZE)]),
    )
    state_covered_by_taken = reduce(
        np.logical_or,
        [SETS[i] for i in taken],
        np.array([False for _ in range(PROBLEM_SIZE)]),
    )

    if all(state_covered_by_new_taken == state_covered_by_taken):
        return True
    else:
        return False


def g(current_state):
    return len(current_state.taken)


def a_star_func(current_state):
    return g(current_state) + h(current_state)

In [157]:
def search(distance, a_star=True):
    if a_star:
        frontier = PriorityQueue()
    else:
        frontier = SimpleQueue()
    state = State(set(), set(range(NUM_SETS)))

    frontier.put((distance(state), state))

    counter = 0
    _, current_state = frontier.get()  # take the state
    # check if state is the solution
    while not goal_check(current_state):
        counter += 1
        for action in current_state.not_taken:  # all action we can take now
            if not np.all(action == False) and not is_overlapping(
                current_state.taken, action
            ):
                new_state = State(
                    current_state.taken ^ {action},
                    current_state.not_taken ^ {action},
                )  # | is the set union
                frontier.put((distance(new_state), new_state))

        _, current_state = frontier.get()

    print(
        f"Solved in {counter:,} steps ({len(current_state.taken)} tiles), with state: {current_state.taken}"
    )
    return current_state.taken

In [158]:
sol = search(a_star_func)

sol = search(a_star_func, False)

Solved in 200 steps (3 tiles), with state: {17, 28, 20}
Solved in 2,355 steps (3 tiles), with state: {1, 4, 37}
