In [66]:
import random
from lib import PriorityQueue
from typing import Callable

def problem(N, seed=None):
    """Build a random collection of integer lists for the set-covering task.

    The global ``random`` generator is re-seeded with ``seed`` first, so the
    same ``(N, seed)`` pair always yields the same collection.

    Returns a list of between ``N`` and ``5 * N`` lists. Each inner list holds
    the distinct values produced by between ``N // 5`` and ``N // 2`` draws
    from ``0 .. N - 1``; duplicates collapse, so a list can be shorter than
    the number of draws (and can be empty when zero draws are made).
    """
    random.seed(seed)

    # The number of subsets is drawn before any subset content is generated.
    subset_count = random.randint(N, N * 5)

    subsets = []
    for _ in range(subset_count):
        # The number of draws for this subset is chosen before the draws.
        draw_count = random.randint(N // 5, N // 2)
        drawn = set()
        for _ in range(draw_count):
            drawn.add(random.randint(0, N - 1))
        subsets.append(list(drawn))
    return subsets

In [67]:
class State:
    """A node of the set-covering search.

    Hashing, equality and ordering depend only on ``set_`` and are
    independent of the order in which elements were inserted. Because the
    hash is derived from mutable content, an instance must not be modified
    after it has been used as a dict key or placed in a set.
    """

    def __init__(self):
        # Elements accumulated so far.
        self.set_ = set()
        # Copy of the collection whose addition produced this state.
        self.arrived_from = set()

    def __hash__(self):
        # frozenset hashing ignores iteration order and accepts any hashable
        # element (bytes() only accepted integers in 0..255).
        return hash(frozenset(self.set_))

    def __eq__(self, other):
        if not isinstance(other, State):
            return NotImplemented
        return self.set_ == other.set_

    def __lt__(self, other):
        # Lexicographic comparison of the sorted elements gives a
        # deterministic order, usable as a tie-breaker between states.
        if not isinstance(other, State):
            return NotImplemented
        return sorted(self.set_) < sorted(other.set_)

    def __str__(self):
        return str(self.set_)

    def __repr__(self):
        return repr(self.set_)

    def update_state(self, old_state, new_set):
        """Make this state equal to ``old_state`` extended with ``new_set``.

        ``old_state`` and ``new_set`` are left untouched: ``set_`` becomes a
        fresh union and ``arrived_from`` a copy of ``new_set``.
        """
        self.set_ = old_state.set_.copy()
        self.arrived_from = new_set.copy()
        self.set_.update(new_set)



In [68]:

def search(
    generated_sets: list,
    initial_state: State,
    goal_test: Callable,
    parent_state: dict,
    state_cost: dict,
    priority_function: Callable,
    weight_cost: Callable,
):
    """Best-first search for a sequence of subsets leading to a goal state.

    Starting from ``initial_state``, every subset in ``generated_sets`` is
    applied to the current state; unseen successors are pushed on a
    ``PriorityQueue`` ordered by ``priority_function``, and the next state
    is popped from that queue until ``goal_test`` accepts one or the queue
    runs empty.

    Args:
        generated_sets: candidate subsets that can be added to a state.
        initial_state: state the search starts from.
        goal_test: callable returning a truthy value for a goal state.
        parent_state: dict cleared and then filled with state -> predecessor
            (``None`` for the initial state).
        state_cost: dict cleared and then filled with state -> cost of the
            best path recorded for it.
        priority_function: callable mapping a state to its queue priority;
            it is called after ``state_cost`` holds the state's cost.
        weight_cost: callable mapping a subset to the cost of using it.

    Returns:
        The subsets used, in order from the initial state to the goal, or
        ``None`` (after printing a message) when no goal state is reachable.
    """
    frontier = PriorityQueue()
    parent_state.clear()
    state_cost.clear()
    # Subset that produced each state along its currently recorded parent.
    # Kept separately because the State object stored as a dict key is not
    # replaced when a cheaper parent is found, so its own ``arrived_from``
    # can describe an outdated path.
    arrived_via = {}

    state = initial_state
    parent_state[state] = None
    state_cost[state] = 0

    while state is not None and not goal_test(state):
        for new_set in generated_sets:
            new_state = State()
            new_state.update_state(state, new_set)
            new_cost = state_cost[state] + weight_cost(new_set)
            if new_state not in state_cost and new_state not in frontier:
                parent_state[new_state] = state
                state_cost[new_state] = new_cost
                arrived_via[new_state] = new_state.arrived_from
                frontier.push(new_state, p=priority_function(new_state))
            elif new_state in frontier and state_cost[new_state] > new_cost:
                # Cheaper route to a state that is still waiting in the queue.
                # NOTE(review): the queue entry keeps the priority it was
                # pushed with; refreshing it would need a PriorityQueue
                # operation that is not visible from this file.
                parent_state[new_state] = state
                state_cost[new_state] = new_cost
                arrived_via[new_state] = new_state.arrived_from
        state = frontier.pop() if frontier else None

    if state is None:
        # Frontier exhausted without reaching a goal state.
        print("Didn't find a solution")
        return None

    # Walk the parent links back to the initial state (whose parent is None).
    path = []
    s = state
    while parent_state[s] is not None:
        path.append(arrived_via[s])
        s = parent_state[s]
    weight = sum(len(item) for item in path)
    print(f"Found a solution in {len(path):,} steps; visited {len(state_cost):,} states")
    print(f"The total weight is {weight}")
    return list(reversed(path))


In [69]:
def greedySolution(
    generated_sets: list,
    initial_state: State,
    goal_test: Callable,
):
    """Build a solution by adding subsets from the smallest to the largest.

    Subsets are visited in ascending order of length. A subset is added only
    if it contributes at least one element the current state does not have
    yet; the scan stops as soon as ``goal_test`` accepts the current state.

    Args:
        generated_sets: candidate subsets; the caller's list is not modified.
        initial_state: state the construction starts from.
        goal_test: callable returning a truthy value for a goal state.

    Returns:
        The list of subsets that were added (copies, in the order used), or
        ``None`` after printing a message when the goal is never reached.
        The list is empty when ``initial_state`` already satisfies the goal.
    """
    state = initial_state
    chosen = []
    found = bool(goal_test(state))

    if not found:
        for new_set in sorted(generated_sets, key=len):
            # A subset already contained in the state adds cost but no
            # coverage, so it is skipped.
            if state.set_.issuperset(new_set):
                continue
            new_state = State()
            new_state.update_state(state, new_set)
            chosen.append(new_state.arrived_from)
            state = new_state
            if goal_test(state):
                found = True
                break

    if not found:
        print("Didn't find a solution")
        return None

    total_cost = sum(len(item) for item in chosen)
    # One step per subset used, the same convention as search().
    print(f"Found a solution in {len(chosen):,} steps; the total cost is: {total_cost:,}")
    return chosen
    

In [70]:
def h(new_state: State):
    """Heuristic value of a state: size of the goal set minus size of its set.

    Reads the module-level ``GOAL``, which must be defined before the first
    call. The result equals the number of goal elements still missing only
    when every element of the state belongs to the goal -- presumably always
    the case here; verify against the caller.
    """
    target_size = len(GOAL.set_)
    covered_size = len(new_state.set_)
    return target_size - covered_size
    
if __name__ == "__main__":
    # 0 runs the priority-queue search; any other value runs greedySolution.
    flag_greedy = 0
    if flag_greedy == 0:
        # Problem sizes available for the search; index 4 selects N = 15.
        N = [5, 7, 10, 12, 15, 20, 25, 30]
        SELECTED_N = N[4]

        # The goal covers every integer in 0 .. SELECTED_N - 1; h() reads GOAL.
        GOAL = State()
        GOAL.set_ = set(range(SELECTED_N))
        INITIAL_STATE = State()

        result = problem(SELECTED_N, seed=42)
        sets_list = list(map(set, result))

        # Filled in by search(); the priority lambda below reads state_cost.
        parent_state = {}
        state_cost = {}
        final = search(
            sets_list,
            INITIAL_STATE,
            lambda s: s.set_ == GOAL.set_,
            parent_state,
            state_cost,
            priority_function=lambda s: state_cost[s] + h(s),
            weight_cost=len,
        )
    else:
        # Problem sizes available for the greedy run; index 5 selects N = 1000.
        GREEDY_SET = [10, 20, 50, 100, 500, 1000]
        GREEDY_N = GREEDY_SET[5]

        GREEDY_GOAL = State()
        INITIAL_STATE = State()
        GREEDY_GOAL.set_ = set(range(GREEDY_N))

        result = problem(GREEDY_N, seed=42)
        sets_list = list(map(set, result))
        final = greedySolution(
            sets_list,
            INITIAL_STATE,
            lambda s: s.set_ == GREEDY_GOAL.set_,
        )
    # Both branches end by showing the solution they produced.
    print(final)



KeyboardInterrupt: 