Copyright **`(c)`** 2023 Giovanni Squillero `<giovanni.squillero@polito.it>`  
[`https://github.com/squillero/computational-intelligence`](https://github.com/squillero/computational-intelligence)  
Free for personal or classroom use; see [`LICENSE.md`](https://github.com/squillero/computational-intelligence/blob/master/LICENSE.md) for details.  

A* algorithm made by Gabriele Tomatis (Polito, s313848) and Luca Sturaro (Polito, s320062).

In [1]:
from random import random
from functools import reduce
from collections import namedtuple
from queue import PriorityQueue, SimpleQueue, LifoQueue

import numpy as np

In [2]:
PROBLEM_SIZE = 5
NUM_SETS = 10
SETS = tuple(
    np.array([random() < 0.3 for _ in range(PROBLEM_SIZE)])
    for _ in range(NUM_SETS)
)
State = namedtuple('State', ['taken', 'not_taken'])

In [3]:
def goal_check(state):
    return np.all(reduce(
        np.logical_or,
        [SETS[i] for i in state.taken],
        np.array([False for _ in range(PROBLEM_SIZE)]),
    ))

def distance(state):
    return PROBLEM_SIZE - sum(
        reduce(
            np.logical_or,
            [SETS[i] for i in state.taken],
            np.array([False for _ in range(PROBLEM_SIZE)]),
        ))

#Measure of weight for a state in the priority queue
def weight(set):
    # f = g + h
    total = [False] * PROBLEM_SIZE
    g = len(set.taken) # Tile of the tree
    
    for s in set.taken:
        for e in SETS[s]:
            total = total or e
        # print(total)
    h = PROBLEM_SIZE - sum(total)   # Distance from solution (n of elements we still need)
    
    # print(g+h)
    return g+h  # return f

In [4]:
assert goal_check(
    State(set(range(NUM_SETS)), set())
), "Probelm not solvable"

In [5]:
# Special sets analysis
def special_sets():
    #Low trheshold makes this approach faster for small problems (high variability of values since we have less dimensionality)
    threshold = NUM_SETS/100 * 29  # 30%
    if threshold < 1:
        threshold = 1
    print(threshold)
    criticalities = []
    specials = []
    normal = []
    bestSol = None

    #Searching for immediates solutions
    res = [all(x) for x in SETS]
    for i in range(len(res)):
        if res[i]:
            bestSol = SETS[i]
            break

    if bestSol == None:
        # Reading on columns
        for j in range(PROBLEM_SIZE):
            #print()
            critical_idx = []
            for i in range(len(SETS)):
                if SETS[i][j]:
                    critical_idx.append(i)
                
            criticalities.append(len(critical_idx))
            # if the number of possibly critical is less then a threshold so they're critical
            if len(critical_idx) < threshold:
                print(critical_idx)
                #for e in critical_idx:
                specials.append(e)
            else:
                for e in critical_idx:
                    normal.append(e)
        print(criticalities)
        criticalities.sort()
        print(criticalities[:10])
        print("SPECIALS")
        special_set = set(specials)
        print(len(special_set))
        print(special_set)
        #print("NOT SO SPECIALS")
        normal_set = set(normal)
        normal_set -= special_set
    return bestSol, special_set, normal_set

In [6]:
best, sp_set, nm_set = special_sets()
print(best)
print(sp_set)
print(nm_set)

2.9000000000000004
[4, 3, 3, 4, 4]
[3, 3, 4, 4, 4]
SPECIALS
0
set()
None
set()
{0, 1, 2, 3, 4, 5, 6, 8, 9}


In [7]:
frontier = PriorityQueue()
# frontier = SimpleQueue()
state = State(set(), set(sp_set.union(nm_set)))
print(state)
# A* algorithm implemented with a priority queue and a weight function
frontier.put((weight(state), state))

counter = 0
_, current_state = frontier.get()
while not goal_check(current_state):
    counter += 1
    for action in current_state[1]:
        new_state = State(
            current_state.taken ^ {action},
            current_state.not_taken ^ {action},
        )
        frontier.put((weight(new_state), new_state))
    _, current_state = frontier.get()

print(
    f"Solved in {counter:,} steps ({len(current_state.taken)} tiles)"
)

State(taken=set(), not_taken={0, 1, 2, 3, 4, 5, 6, 8, 9})
Solved in 13 steps (2 tiles)


In [8]:
current_state

State(taken={1, 9}, not_taken={0, 2, 3, 4, 5, 6, 8})