In [64]:
import heapq
import random
from typing import Callable

from lib import PriorityQueue

def problem(N, seed=None):
    """Generate a random set-covering instance over the elements 0 .. N-1.

    Reseeds the module-level ``random`` generator with `seed`, so calling it
    again with the same N and a non-None seed yields the same instance.

    Returns:
        A list of between N and 5*N subsets. Each subset is a list of distinct
        ints in [0, N), built from N//5 .. N//2 draws with replacement;
        duplicate draws collapse, and a subset can be empty when N//5 == 0.
    """
    random.seed(seed)
    subset_count = random.randint(N, N * 5)
    subsets = []
    for _ in range(subset_count):
        draws = random.randint(N // 5, N // 2)
        unique_items = {random.randint(0, N - 1) for _ in range(draws)}
        subsets.append(list(unique_items))
    return subsets

In [65]:
class State:
    """A search node for set covering.

    Attributes (plain instance fields, set in ``__init__``):
        set_:  union of all sets added so far, i.e. the covered elements.
        list_: the added sets themselves, in the order they were added.

    Equality, hashing and ordering look only at ``set_``: two nodes covering
    the same elements are the same search state, however they were reached.
    """

    def __init__(self):
        self.set_ = set()
        # Only the added sets live here. The running union is kept in set_ and
        # must not be aliased into this list: weight() would count it, and
        # after copy() it would point at another state's union.
        self.list_ = []

    def __hash__(self):
        # frozenset hashing ignores iteration order and accepts any hashable
        # elements, so equal sets always hash equally.
        return hash(frozenset(self.set_))

    def __eq__(self, other):
        if not isinstance(other, State):
            return NotImplemented
        return self.set_ == other.set_

    def __lt__(self, other):
        # Deterministic order on the covered elements (e.g. for breaking
        # priority ties in a heap); assumes the elements are mutually comparable.
        if not isinstance(other, State):
            return NotImplemented
        return sorted(self.set_) < sorted(other.set_)

    def __str__(self):
        return str(self.set_)

    def __repr__(self):
        return repr(self.set_)

    def weight(self):
        """Return the total size of all added sets (overlapping elements count each time)."""
        return sum(len(added) for added in self.list_)

    def copy(self):
        """Return a new State with its own set_ and list_ containers.

        The sets stored inside list_ are shared with this state, not duplicated.
        """
        ret = State()
        ret.set_ = self.set_.copy()
        ret.list_ = self.list_.copy()
        return ret

    def add_new_set(self, new_set: set):
        """Add `new_set` in place: merge it into set_ and record it in list_.

        `new_set` is stored by reference; the caller should not mutate it afterwards.
        """
        self.set_.update(new_set)
        self.list_.append(new_set)

In [66]:
# Universe size: the elements to cover are the integers 0 .. N-1.
N = 7
# Start node: nothing chosen, nothing covered.
INITIAL_STATE = State()
# Goal node: every element of the universe covered.
GOAL = State()
GOAL.add_new_set({*range(N)})

In [67]:
def sets_concat(state: State, new_set: set):
    """Return a new State equal to `state` with `new_set` added; `state` is not modified.

    State() takes no constructor arguments, so the result is built by copying
    `state` and adding `new_set` to the copy.
    """
    new_state = state.copy()
    new_state.add_new_set(new_set)
    return new_state

def goal_test(set_state: State):
    """Return True when `set_state` covers exactly the same elements as GOAL."""
    target = GOAL.set_
    return target == set_state.set_

def h(new_state: State):
    """A* heuristic: the number of GOAL elements `new_state` does not cover yet.

    Admissible and consistent when a step costs the size of the added set
    (weight_cost=len). Adding a set S covers at most |S| new elements, so
    covering the remaining elements costs at least their count.
    """
    return len(GOAL.set_ - new_state.set_)

In [68]:

def search(
    generated_sets: list,
    initial_state: State,
    goal_test: Callable[[State], bool],
    parent_state: dict,
    state_cost: dict,
    priority_function: Callable,
    weight_cost: Callable,
):
    """Best-first search for a sequence of sets that turns `initial_state` into a goal.

    Each expansion tries adding every set of `generated_sets` to the current
    state. The unexpanded state with the lowest `priority_function` value is
    expanded next.

    Args:
        generated_sets: candidate sets, each usable at any step.
        initial_state: start node; only copies of it are modified.
        goal_test: predicate called on each state about to be expanded; True ends the search.
        parent_state: cleared, then filled with ``state -> predecessor``
            (``None`` for the start).
        state_cost: cleared, then filled with ``state -> best known path cost``.
            A state's entry is written before `priority_function` is called on it.
        priority_function: ``state -> priority``; lower values are expanded first.
        weight_cost: ``added set -> step cost``.

    Returns:
        The sets to add, in order, to reach the goal found. Returns ``[]`` if
        the start already satisfies `goal_test` or no goal state is reachable.
    """
    parent_state.clear()
    state_cost.clear()
    # Set used on the recorded edge parent_state[s] -> s. It is kept here rather
    # than read from s.list_, so the path stays consistent when a cheaper
    # parent replaces the old one.
    added_set = {}
    # Heap of (priority, state). When a state's cost improves it is pushed
    # again with its new priority; the outdated entry is skipped once the
    # state has been expanded.
    frontier = []
    expanded = set()

    state = initial_state
    parent_state[state] = None
    state_cost[state] = 0

    while state is not None and not goal_test(state):
        expanded.add(state)
        for new_set in generated_sets:
            new_state = state.copy()
            new_state.add_new_set(new_set)
            new_cost = state_cost[state] + weight_cost(new_set)
            if new_state not in state_cost:
                parent_state[new_state] = state
                added_set[new_state] = new_set
                state_cost[new_state] = new_cost
                heapq.heappush(frontier, (priority_function(new_state), new_state))
                print(f"Added new node to frontier(cost={state_cost[new_state]})")
            elif new_state not in expanded and state_cost[new_state] > new_cost:
                old_cost = state_cost[new_state]
                parent_state[new_state] = state
                added_set[new_state] = new_set
                state_cost[new_state] = new_cost
                # Re-push so the lower cost actually affects expansion order.
                heapq.heappush(frontier, (priority_function(new_state), new_state))
                print(f"Updated node cost in frontier: {old_cost} -> {state_cost[new_state]}")
        state = _pop_unexpanded(frontier, expanded)

    if state is None:
        print(f"No solution found; visited {len(state_cost):,} states")
        return []
    path = _reconstruct_path(state, parent_state, added_set)
    print(f"Found a solution in {len(path):,} steps; visited {len(state_cost):,} states")
    return path


def _pop_unexpanded(frontier: list, expanded: set):
    """Pop heap entries until one holds a not-yet-expanded state; return it, or None if the heap empties."""
    while frontier:
        _, candidate = heapq.heappop(frontier)
        if candidate not in expanded:
            return candidate
    return None


def _reconstruct_path(goal_state, parent_state: dict, added_set: dict) -> list:
    """Follow parent links from `goal_state` back to the start; return the added sets in start-to-goal order."""
    path = []
    node = goal_state
    while parent_state[node] is not None:
        path.append(added_set[node])
        node = parent_state[node]
    path.reverse()
    return path


In [69]:
if __name__ == "__main__":
    # Fixed seed so Restart & Run All reproduces the same instance and output.
    SEED = 42
    x = problem(N, seed=SEED)
    # problem() yields lists; search() works with sets.
    sets_list = [set(subset) for subset in x]
    parent_state = dict()
    state_cost = dict()
    final = search(
        sets_list,
        INITIAL_STATE,
        goal_test,
        parent_state,
        state_cost,
        # f(s) = g(s) + h(s); reads state_cost, which search() fills before calling this.
        priority_function=lambda s: state_cost[s] + h(s),
        # A step costs the number of elements in the added set.
        weight_cost=len,
    )
    print(final)

Added new node to frontier(cost=2)
Added new node to frontier(cost=1)
Added new node to frontier(cost=2)
Added new node to frontier(cost=2)
Added new node to frontier(cost=1)
Added new node to frontier(cost=3)
Added new node to frontier(cost=2)
Added new node to frontier(cost=2)
Added new node to frontier(cost=3)
Added new node to frontier(cost=2)
Added new node to frontier(cost=1)
Added new node to frontier(cost=1)
Added new node to frontier(cost=1)
Added new node to frontier(cost=2)
Added new node to frontier(cost=3)
Added new node to frontier(cost=3)
Added new node to frontier(cost=3)
Added new node to frontier(cost=2)
Added new node to frontier(cost=3)
Added new node to frontier(cost=4)
Added new node to frontier(cost=3)
Added new node to frontier(cost=2)
Added new node to frontier(cost=2)
Added new node to frontier(cost=3)
Added new node to frontier(cost=3)
Added new node to frontier(cost=2)
Added new node to frontier(cost=4)
Added new node to frontier(cost=3)
Added new node to fr