In [2]:
import random

In [3]:
def problem(N, seed=None):
    random.seed(seed)
    return [
        list(set(random.randint(0, N - 1) for n in range(random.randint(N // 5, N // 2))))
        for n in range(random.randint(N, N * 5))
    ]

In [5]:
import numpy as np

In [4]:
from gx_utils import *

In [7]:
import logging
from random import seed, choice
from typing import Callable

logging.basicConfig(format="%(message)s", level=logging.INFO)

In [88]:
class State:
    def __init__(self, data: np.ndarray):
        self._data = data.copy()
        self._data.flags.writeable = False

    def __hash__(self):
        #return hash(tuple(tuple(x) for x in self.copy_data().tolist()))
        return hash(bytes(self._data))

    def __eq__(self, other):
        return bytes(self._data) == bytes(other._data)

    def __lt__(self, other):
        return bytes(self._data) < bytes(other._data)

    def __str__(self):
        return str(self._data)

    def __repr__(self):
        return repr(self._data)

    @property
    def data(self):
        return self._data

    def copy_data(self):
        return self._data.copy()
    
    def hash_data(self):
        return hash(tuple(tuple(x) for x in a.copy_data().tolist()))

In [71]:
#return a new_state if elem contains element different from state
def result(state, elem):
    '''
    elem_set = set(elem)
    state_set = set()
    for e in state.copy_data():
        [state_set.add(i) for i in e]
    intersection = list(elem_set - state_set)
    if intersection == []: return False
    '''
    new_state = []
    [new_state.append(e) for e in state.copy_data().tolist()]
    new_state.append(elem)
    new_state = sorted(new_state, key=lambda l: len(l))
    
    return State(np.array(new_state))

def is_valid(state, elem):
    covered = set()
    for e in state.copy_data():
        [covered.add(i) for i in e]
    old_covered = covered.copy()
    covered |= set(elem) 
    if old_covered == covered:
        return False
    else: 
        return True
    

def goal_test(state, goal_arr):
    state_set = set()
    test = state.copy_data().tolist()
    for e in test:
        [state_set.add(i) for i in e]
    state_set = sorted(list(state_set))
    return state_set == goal_arr
    
def possible_actions(state: State):
    #return [item for item in universe if item not in state.copy_data().tolist()]
    return [tuple(item) for item in input_state if is_valid(state, item)]

In [104]:
def search(
    initial_state: State,
    goal_test: Callable,
    parent_state: dict,
    state_cost: dict,
    priority_function: Callable,
    unit_cost: Callable,
    N: int
):
    frontier = PriorityQueue()
    parent_state.clear()
    state_cost.clear()

    state = initial_state
    parent_state[state] = None
    state_cost[state] = 0
    
    goal_array = list(range(N))
    nodes = 0
    while state is not None and not goal_test(state, goal_array):
        for a in possible_actions(state):

            if not is_valid(state, a):
                continue
            new_state = result(state, a)
            cost = unit_cost(a)

            if new_state not in state_cost and new_state not in frontier:
                parent_state[new_state] = state
                state_cost[new_state] = state_cost[state] + cost
                frontier.push(new_state, p=priority_function(new_state))
                logging.debug(f"Added new node to frontier (cost={state_cost[new_state]})")
            elif new_state in frontier and state_cost[new_state] > state_cost[state] + cost:
                old_cost = state_cost[new_state]
                parent_state[new_state] = state
                state_cost[new_state] = state_cost[state] + cost
                logging.debug(f"Updated node cost in frontier: {old_cost} -> {state_cost[new_state]}")
        if frontier:
            state = frontier.pop()
            nodes += 1
        else:
            state = None

    path = list()
    s = state
    while s:
        path.append(s.copy_data())
        s = parent_state[s]

    logging.info(f"N = {N} | Found a solution in {len(path):,} steps; visited {len(state_cost):,} states, nodes: {nodes}, w = {sum(len(x) for x in path[0])}")
    logging.info(f'state: {path[0]}')
    return list(reversed(path))

# Breadth-first

In [107]:
parent_state = dict()
state_cost = dict()

logging.getLogger().setLevel(logging.INFO)
''',10,20,50,100,500,1000'''
final = []
for n in [5, 10, 20]:
    input_state = sorted(problem(n, seed=42), key=lambda l: len(l))
    input_state = sorted(set(tuple(sorted(x)) for x in input_state), key=lambda l: len(l))
    initial_state = State(np.array([input_state[0]]))
    
    logging.debug(f'initial_state: {initial_state}')
    logging.debug(f'initial nodes: {len(input_state)}')
    final = search(
        initial_state=initial_state,
        goal_test=goal_test,
        parent_state=parent_state,
        state_cost=state_cost,
        priority_function=lambda s: len(state_cost),
        unit_cost=lambda a: 1,
        N=n
    )

  return State(np.array(new_state))
N = 5 | Found a solution in 4 steps; visited 329 states, nodes: 79, w = 5
state: [list([2]) list([4]) list([0]) (1, 3)]
N = 10 | Found a solution in 4 steps; visited 50,250 states, nodes: 2372, w = 10
state: [list([6]) (0, 5) (2, 7, 8) (1, 3, 4, 9)]
N = 20 | Found a solution in 5 steps; visited 361,995 states, nodes: 26556, w = 27
state: [list([7, 17, 18]) list([5, 8, 16]) (4, 6, 15, 17, 18)
 (0, 1, 3, 7, 9, 10, 11, 15) (2, 8, 12, 13, 14, 16, 17, 19)]
