In [33]:
from random import random
from functools import reduce
from collections import namedtuple
from queue import PriorityQueue, SimpleQueue, LifoQueue

import numpy as np

In [34]:
PROBLEM_SIZE = 5
NUM_SETS = 10
SETS = tuple(
    np.array([random() < 0.3 for _ in range(PROBLEM_SIZE)]) 
    for _ in range(NUM_SETS)
)
State = namedtuple('State', ['taken', 'not_taken'])

In [35]:
def check_sets():
    filtered_sets = tuple(filter(any, SETS))
    all_true_sets = tuple(filter(all, SETS))
    return filtered_sets, all_true_sets

filtered_sets, all_true_sets = check_sets()
if(len(all_true_sets) != 0):
    print("Problem solvable with only one set")
NUM_FSETS = len(filtered_sets)

In [36]:
def goal_check(state):
    return np.all(reduce(  
        np.logical_or,
        [filtered_sets[i] for i in state.taken],
        np.array([False for _ in range(PROBLEM_SIZE)]),
    ))

def covered(state):
    return reduce(
        np.logical_or,
        [filtered_sets[i] for i in state.taken],
        np.array([False for _ in range(PROBLEM_SIZE)]),
    )


def h(state):
    largest_set_size = max(sum(s) for s in filtered_sets)
    missing_size = PROBLEM_SIZE - sum(covered(state))
    optimistic_estimate = ceil(missing_size / largest_set_size)
    return optimistic_estimate


def h2(state):
    already_covered = covered(state)
    if np.all(already_covered):
        return 0
    largest_set_size = max(sum(np.logical_and(s, np.logical_not(already_covered))) for s in filtered_sets)
    missing_size = PROBLEM_SIZE - sum(already_covered)
    optimistic_estimate = ceil(missing_size / largest_set_size)
    return optimistic_estimate


def h3(state):
    already_covered = covered(state)
    if np.all(already_covered):
        return 0
    missing_size = PROBLEM_SIZE - sum(already_covered)
    candidates = sorted((sum(np.logical_and(s, np.logical_not(already_covered))) for s in filtered_sets), reverse=True)
    taken = 1
    while sum(candidates[:taken]) < missing_size:
        taken += 1
    return taken


def total_distance(state):
    return len(state.taken) + h3(state)

In [37]:
assert goal_check(
    State(set(range(NUM_FSETS)), set())
), "Probelm not solvable"

In [38]:
frontier = PriorityQueue() 
state = State(set(), set(range(NUM_FSETS)))
frontier.put((total_distance(state), state))

counter = 0
_, current_state = frontier.get()
while not goal_check(current_state): 
    counter += 1
    for action in current_state[1]:
        new_state = State(
            current_state.taken ^ {action},
            current_state.not_taken ^ {action},
        )
        frontier.put((total_distance(new_state), new_state))
    _, current_state = frontier.get()

print(
    f"Solved in {counter:,} steps ({len(current_state.taken)} tiles)"
)

Solved in 14 steps (3 tiles)


In [39]:
current_state

State(taken={3, 5, 6}, not_taken={0, 1, 2, 4, 7})

In [41]:
print(filtered_sets[3])
print(filtered_sets[5])
print(filtered_sets[6])

[False  True False False False]
[False False  True False  True]
[ True False False  True False]
