In [72]:
class State:
    """Hashable, orderable wrapper around a set of elements.

    Equality and hashing depend only on the *contents* of the wrapped set,
    never on its iteration order, so two states built from equal sets are
    always interchangeable as dict keys / set members.

    The wrapped set is exposed as ``set_``. It must not be mutated once the
    state has been used as a dict key or set member, because that would
    change its hash.
    """

    def __init__(self, set_: set):
        # Private copy: later changes to the caller's collection must not
        # leak into this state.
        self.set_ = set(set_)

    def __hash__(self):
        # frozenset hashing is order-independent and has no restriction on
        # element values (bytes() only accepts integers in range(256)).
        return hash(frozenset(self.set_))

    def __eq__(self, other):
        if not isinstance(other, State):
            return NotImplemented
        return self.set_ == other.set_

    def __lt__(self, other):
        if not isinstance(other, State):
            return NotImplemented
        # Lexicographic comparison of the sorted elements gives a
        # deterministic total order (usable as a tie-breaker in a heap).
        return sorted(self.set_) < sorted(other.set_)

    def __str__(self):
        return str(self.set_)

    def __repr__(self):
        return repr(self.set_)

    def length(self):
        """Return the number of elements in the wrapped set."""
        return len(self.set_)

In [73]:
# Size of the universe: the elements to be covered are 0 .. N-1.
N = 7
# Target state: wraps the set containing every element of range(N).
GOAL = State(set(range(N)))
# State wrapping the empty set.
INITIAL_STATE = State(set())

In [74]:
import random

def problem(N, seed=None):
    """Generate a random collection of subsets of ``range(N)``.

    The global ``random`` generator is re-seeded with ``seed`` first, so a
    fixed seed always yields the same instance. Between ``N`` and ``5 * N``
    subsets are produced; each one is built from ``N // 5`` to ``N // 2``
    random draws in ``0 .. N-1`` with duplicates removed.

    Returns a list of lists of ints.
    """
    random.seed(seed)
    subset_count = random.randint(N, N * 5)
    subsets = []
    for _ in range(subset_count):
        # The number of draws is chosen before the draws themselves.
        draws = random.randint(N // 5, N // 2)
        members = {random.randint(0, N - 1) for _ in range(draws)}
        subsets.append(list(members))
    return subsets

In [75]:
from lib import PriorityQueue
from typing import Callable

def sets_concat(state: State, new_set: set):
    """Return a new ``State`` wrapping the union of ``state.set_`` and ``new_set``.

    Neither argument is modified.
    """
    combined = state.set_.union(new_set)
    return State(combined)

def search(
    generated_sets: list,
    initial_state: State,
    goal_test: Callable,
    parent_state: dict,
    state_cost: dict,
    priority_function: Callable,
    weight_cost: Callable,
):
    """Priority-queue driven search over unions of ``generated_sets``.

    Starting from ``initial_state``, each expansion unions the current state
    with every candidate set and pushes unseen results onto the frontier.
    The search stops when ``goal_test(state)`` is true or the frontier is
    exhausted.

    Args:
        generated_sets: candidate sets that can be merged into a state.
        initial_state: state the search starts from.
        goal_test: callable taking a state and returning a truthy value once
            the goal is reached.
        parent_state: dict cleared and then filled with ``state -> parent``
            links (the initial state maps to ``None``).
        state_cost: dict cleared and then filled with the accumulated cost
            of every state reached.
        priority_function: callable ``(new_state, total_weight)`` giving the
            priority used when the state is pushed onto the frontier. It is
            called after ``state_cost[new_state]`` has been set.
        weight_cost: callable giving the cost of adding one candidate set.

    Returns:
        The list of states from ``initial_state`` to the goal state, or an
        empty list when the frontier empties before a goal is found.

    The frontier must support ``push(item, p=...)``, ``pop()``, ``in`` and
    truth testing. Presumably ``pop()`` returns the lowest-priority item
    first -- confirm against ``lib.PriorityQueue``.
    """
    frontier = PriorityQueue()
    parent_state.clear()
    state_cost.clear()

    state = initial_state
    parent_state[state] = None
    state_cost[state] = 0
    # Running total handed to priority_function; it grows by the size of the
    # state being expanded each time a new node is added to the frontier.
    total_weight = state.length()

    while state is not None and not goal_test(state):
        for new_set in generated_sets:
            if new_set == state.set_:
                continue
            new_state = sets_concat(state, new_set)
            new_cost = state_cost[state] + weight_cost(new_set)
            if new_state not in state_cost and new_state not in frontier:
                parent_state[new_state] = state
                state_cost[new_state] = new_cost
                total_weight += state.length()
                frontier.push(new_state, p=priority_function(new_state, total_weight))
            elif new_state in frontier and state_cost[new_state] > new_cost:
                # Cheaper route to a state that is still waiting in the frontier.
                # NOTE(review): only the bookkeeping dicts are updated; the
                # queued entry keeps the priority it was pushed with. Confirm
                # whether PriorityQueue offers a way to re-prioritise.
                parent_state[new_state] = state
                state_cost[new_state] = new_cost
        state = frontier.pop() if frontier else None

    if state is None:
        # Frontier exhausted without reaching the goal: do not report success.
        print(f"No solution found; visited {len(state_cost):,} states")
        return []

    # Walk the parent links back to the initial state (whose parent is None).
    path = []
    s = state
    while s is not None:
        path.append(s)
        s = parent_state[s]
    path.reverse()
    print(f"Found a solution in {len(path):,} steps; visited {len(state_cost):,} states")
    return path



In [76]:
from collections import defaultdict

def h(new_state, total_weight):
    """Return ``new_state.length()`` plus ``total_weight``."""
    state_size = new_state.length()
    return state_size + total_weight

def goal_test(set_state):
    """Tell whether ``set_state`` compares equal to the module-level ``GOAL``."""
    reached_goal = set_state == GOAL
    return reached_goal

if __name__ == "__main__":
    # Generate a random instance and turn every subset into a real set.
    x = problem(N)
    sets_list = [set(subset) for subset in x]
    # Start from one randomly chosen subset. random.randint includes BOTH
    # end points, so randint(0, len(sets_list)) could index one past the end;
    # randrange excludes the stop value.
    INITIAL_STATE = State(sets_list[random.randrange(len(sets_list))])
    parent_state = dict()
    state_cost = dict()
    final = search(
        sets_list,
        INITIAL_STATE,
        goal_test,
        parent_state,
        state_cost,
        # Priority = cost accumulated so far + heuristic term.
        priority_function=lambda s, w: state_cost[s] + h(s, w),
        # The cost of adding a subset is its number of elements.
        weight_cost=len,
    )
    
    

Added new node to frontier(cost=3)
Added new node to frontier(cost=1)
Added new node to frontier(cost=2)
Added new node to frontier(cost=2)
Added new node to frontier(cost=3)
Added new node to frontier(cost=1)
Added new node to frontier(cost=1)
Added new node to frontier(cost=1)
Added new node to frontier(cost=2)
Added new node to frontier(cost=1)
Added new node to frontier(cost=3)
Added new node to frontier(cost=2)
Added new node to frontier(cost=3)
Added new node to frontier(cost=3)
Added new node to frontier(cost=3)
Added new node to frontier(cost=2)
Added new node to frontier(cost=2)
Added new node to frontier(cost=3)
Added new node to frontier(cost=4)
Added new node to frontier(cost=4)
Added new node to frontier(cost=5)
Updated node cost in frontier: 5 -> 4
Added new node to frontier(cost=5)
Added new node to frontier(cost=3)
Added new node to frontier(cost=3)
Added new node to frontier(cost=3)
Added new node to frontier(cost=5)
Added new node to frontier(cost=5)
Added new node to