# Search for AIMA 4th edition

Implementation of search algorithms and search problems for AIMA.

# Problems and Nodes

We start by defining the abstract class for a `Problem`; specific problem domains will subclass this. To make it easier for algorithms that use a heuristic evaluation function, `Problem` has a default `h` function (uniformly zero), and subclasses can define their own default `h` function.

We also define a `Node` in a search tree, and some functions on nodes: `expand` to generate successors; `path_actions` and `path_states`  to recover aspects of the path from the node.  

In [37]:
%matplotlib inline
import matplotlib.pyplot as plt
import random
import heapq
import math
import sys, os
import time

import numpy as np
import copy

sys.path.append(os.getcwd())
#import bisect
from bisect import *

from collections import defaultdict, deque, Counter
from itertools import combinations


class Problem(object):
    """Abstract base class for a formal search problem.

    A concrete domain subclasses this and overrides `actions` and `result`
    (and optionally `is_goal`, `action_cost` and `h`). By default every
    action costs 1 and the heuristic is 0 for every node. Instances are
    created with an `initial` state, a `goal` state (or an overridden
    `is_goal`), and any extra keyword arguments; all of them become
    attributes of the instance."""

    def __init__(self, initial=None, goal=None, **kwds):
        # Every argument, including the extra keywords, is stored as an attribute.
        vars(self).update(initial=initial, goal=goal, **kwds)

    def actions(self, state):
        """Return the actions applicable in `state` (must be overridden)."""
        raise NotImplementedError

    def result(self, state, action):
        """Return the state reached by applying `action` in `state` (must be overridden)."""
        raise NotImplementedError

    def is_goal(self, state):
        """True when `state` equals the stored goal state."""
        return state == self.goal

    def action_cost(self, s, a, s1):
        """Cost of going from `s` to `s1` with action `a`; 1 by default."""
        return 1

    def h(self, node):
        """Default heuristic: 0 for every node."""
        return 0

    def __str__(self):
        return f'{type(self).__name__}({self.initial!r}, {self.goal!r})'
    

class Node:
    """A node in a search tree.

    Attributes:
        state: the problem state this node represents.
        parent: the Node this one was generated from (None for the root).
        action: the action applied in the parent's state to reach `state`.
        path_cost: accumulated cost of the path from the root to this node.

    Nodes are ordered by `path_cost` and `len(node)` is the node's depth.
    """
    def __init__(self, state, parent=None, action=None, path_cost=0):
        self.__dict__.update(state=state, parent=parent, action=action, path_cost=path_cost)

    def __repr__(self):
        return '<{}>'.format(self.state)

    def __len__(self):
        """Depth of the node: the number of ancestors above it (0 for the root)."""
        # Walk the parent chain iteratively: a recursive definition exceeds
        # Python's recursion limit on paths a few thousand steps long.
        depth = 0
        ancestor = self.parent
        while ancestor is not None:
            depth += 1
            ancestor = ancestor.parent
        return depth

    def __lt__(self, other):
        return self.path_cost < other.path_cost
    
    
# Sentinel nodes that the search functions return instead of a solution node.
# Both are ordinary Node objects whose path cost is infinite.
failure = Node('failure', path_cost=math.inf) # Indicates an algorithm couldn't find a solution.
cutoff  = Node('cutoff',  path_cost=math.inf) # Indicates iterative deepening search was cut off.
    
    
def expand(problem, node):
    """Lazily yield one child Node for each action applicable in `node.state`.

    Each child records `node` as its parent, the action taken, and the
    accumulated path cost (parent cost plus the cost of that action)."""
    parent_state = node.state
    for act in problem.actions(parent_state):
        child_state = problem.result(parent_state, act)
        step_cost = problem.action_cost(parent_state, act, child_state)
        yield Node(child_state, parent=node, action=act,
                   path_cost=node.path_cost + step_cost)
        

def path_actions(node):
    """Return the list of actions leading from the root to `node`.

    The root (a node whose parent is None) contributes no action, so the
    result is empty for a root node. `None` also yields an empty list, in
    line with `path_states`.
    """
    # Collect actions while climbing to the root, then reverse once. This
    # avoids both the recursion limit and the quadratic cost of building the
    # list by repeated concatenation on long paths.
    actions = []
    while node is not None and node.parent is not None:
        actions.append(node.action)
        node = node.parent
    actions.reverse()
    return actions


def path_states(node):
    """Return the list of states on the path from the root to `node`.

    The sentinels `cutoff` and `failure`, as well as `None`, yield an empty
    list.
    """
    # Climb the parent chain iteratively and reverse once at the end, so that
    # long paths neither hit the recursion limit nor pay for repeated list
    # concatenation.
    states = []
    while node is not None and node is not cutoff and node is not failure:
        states.append(node.state)
        node = node.parent
    states.reverse()
    return states

# Queues

First-in-first-out and Last-in-first-out queues, and a `PriorityQueue`, which allows you to keep a collection of items, and continually remove from it the item with minimum `f(item)` score.

In [38]:
# The FIFO queue is a collections.deque; breadth_first_search adds with
# appendleft() and removes with pop(), i.e. at opposite ends.
FIFOQueue = deque

# The LIFO queue (a stack) is a plain list used through append() and pop().
LIFOQueue = list

class PriorityQueue:
    """Min-priority queue: `pop` always returns the item whose key(item) is lowest.

    Items are stored in a heap as (score, item) tuples, so items with equal
    scores are ordered by comparing the items themselves."""

    def __init__(self, items=(), key=lambda x: x):
        self.key = key
        self.items = []  # heap of (key(item), item) tuples
        for element in items:
            self.add(element)

    def add(self, item):
        """Insert `item`, scored with the queue's key function."""
        heapq.heappush(self.items, (self.key(item), item))

    def pop(self):
        """Remove and return the item with the lowest score."""
        _score, item = heapq.heappop(self.items)
        return item

    def top(self):
        """Return, without removing it, the item with the lowest score."""
        _score, item = self.items[0]
        return item

    def __len__(self):
        return len(self.items)

# Search Algorithms: Best-First

Best-first search with various *f(n)* functions gives us different search algorithms. Note that A\*, weighted A\* and greedy search can be given a heuristic function, `h`, but if `h` is not supplied they use the problem's default `h` function (if the problem does not define one, it is taken as *h(n)* = 0).

In [39]:

def best_first_search(problem, f):
    """Graph search that always expands the frontier node with the lowest f(node).

    Args:
        problem: object providing `initial`, `is_goal`, `actions`, `result`
            and `action_cost`.
        f: evaluation function mapping a Node to its priority.

    Returns:
        The first goal node taken from the frontier, or `failure` when the
        frontier is exhausted.

    Side effects:
        Prints every expanded state with its f-value and, when a goal is
        found, a block of run statistics.
    """
    start_time = time.time()
    nodes_expanded = 0
    states_reexpanded = 0
    stale_nodes = 0

    node = Node(problem.initial)
    frontier = PriorityQueue([node], key=f)
    reached = {problem.initial: node}   # cheapest node found so far for each state
    expanded = set()                    # states expanded at least once
    while frontier:
        node = frontier.pop()
        # A cheaper path to this state was found after this node was queued:
        # the entry is stale and is discarded. Looping back to the
        # `while frontier` test (rather than popping again right here)
        # guarantees an empty frontier is never popped.
        if node.path_cost > reached[node.state].path_cost:
            stale_nodes += 1
            continue
        # Track states that get expanded more than once.
        if node.state in expanded:
            states_reexpanded += 1
        else:
            expanded.add(node.state)
        nodes_expanded += 1
        # Trace of expanded nodes and their f() values.
        print(node.state, "f()= ", f(node))

        # On reaching a goal, report the statistics and return.
        if problem.is_goal(node.state):
            print("Tiempo de Ejecucion: ", time.time()-start_time)
            print("Nodos Expandidos: ", nodes_expanded)
            print("Estados Alcanzados: ", len(reached))
            print("Nodos en Frontier: ", len(frontier))
            print("Coste de la Solucion: ", node.path_cost)
            print("Longitud de la Solucion: ", len(node))
            print("Nro. de Estados Reexpandidos: ", states_reexpanded)
            print("Nro. de Nodos False Descartados: ", stale_nodes)
            return node
        # Node expansion: queue each child that is new or cheaper than the
        # best known node for its state. (With a consistent h() one could
        # additionally skip states that were already expanded.)
        for child in expand(problem, node):
            s = child.state
            if s not in reached or child.path_cost < reached[s].path_cost:
                reached[s] = child
                frontier.add(child)
    return failure


def best_first_tree_search(problem, f):
    """Best-first search that keeps no `reached` table; only cycles (as detected by `is_cycle`) are pruned."""
    root = Node(problem.initial)
    frontier = PriorityQueue([root], key=f)
    while frontier:
        current = frontier.pop()
        if problem.is_goal(current.state):
            return current
        for successor in expand(problem, current):
            if is_cycle(successor):
                continue
            frontier.add(successor)
    return failure


def g(n):
    """Path cost accumulated from the root to node `n`."""
    cost = n.path_cost
    return cost


def astar_search(problem, h=None):
    """A* graph search: expand nodes in order of f(n) = g(n) + h(n).

    When `h` is not supplied (or is falsy) the problem's own `h` is used."""
    heuristic = h if h else problem.h

    def f(n):
        return g(n) + heuristic(n)

    return best_first_search(problem, f=f)


def astar_tree_search(problem, h=None):
    """A* tree search (no `reached` table): expand nodes in order of f(n) = g(n) + h(n).

    When `h` is not supplied (or is falsy) the problem's own `h` is used."""
    heuristic = h if h else problem.h

    def f(n):
        return g(n) + heuristic(n)

    return best_first_tree_search(problem, f=f)


def weighted_astar_search(problem, h=None, weight=1.4):
    """Weighted A*: expand nodes in order of f(n) = g(n) + weight * h(n).

    When `h` is not supplied (or is falsy) the problem's own `h` is used."""
    heuristic = h if h else problem.h

    def f(n):
        return g(n) + weight * heuristic(n)

    return best_first_search(problem, f=f)

        
def greedy_bfs(problem, h=None):
    """Greedy best-first search: expand nodes in order of h(n) alone.

    When `h` is not supplied (or is falsy) the problem's own `h` is used."""
    heuristic = h if h else problem.h
    return best_first_search(problem, f=heuristic)


def uniform_cost_search(problem):
    """Best-first search ordered by accumulated path cost."""
    return best_first_search(problem, f=lambda n: n.path_cost)


def breadth_first_bfs(problem):
    """Breadth-first behaviour obtained from best-first search: shallowest nodes (smallest len) first."""
    return best_first_search(problem, f=lambda n: len(n))


def depth_first_bfs(problem):
    """Depth-first behaviour obtained from best-first search: deepest nodes (largest len) first."""
    def negative_depth(n):
        return -len(n)

    return best_first_search(problem, f=negative_depth)


def is_cycle(node, k=30):
    """Return True when one of the nearest `k` ancestors of `node` has the same state as `node`."""
    ancestor = node.parent
    remaining = k
    while ancestor is not None and remaining > 0:
        if ancestor.state == node.state:
            return True
        ancestor = ancestor.parent
        remaining -= 1
    return False



# Other Search Algorithms

Here are the other search algorithms:

In [40]:
def breadth_first_search(problem):
    """Breadth-first graph search; the goal test is applied as soon as a child is generated."""
    root = Node(problem.initial)
    if problem.is_goal(problem.initial):
        return root
    frontier = FIFOQueue([root])
    reached = {problem.initial}
    while frontier:
        parent = frontier.pop()
        for child in expand(problem, parent):
            state = child.state
            if problem.is_goal(state):
                return child
            if state in reached:
                continue
            reached.add(state)
            frontier.appendleft(child)
    return failure


def iterative_deepening_search(problem):
    """Run depth-limited search with limits 1, 2, 3, ... and return the first result that is not `cutoff`."""
    for depth_limit in range(1, sys.maxsize):
        outcome = depth_limited_search(problem, depth_limit)
        if outcome == cutoff:
            continue
        return outcome
        
        
def depth_limited_search(problem, limit=10):
    """Depth-first search that does not expand nodes at depth `limit` or deeper.

    Returns a goal node if one is found; otherwise `cutoff` when at least one
    node was left unexpanded because of the limit, or `failure` when none was."""
    stack = LIFOQueue([Node(problem.initial)])
    outcome = failure
    while stack:
        current = stack.pop()
        if problem.is_goal(current.state):
            return current
        if len(current) >= limit:
            outcome = cutoff
            continue
        if is_cycle(current):
            continue
        for child in expand(problem, current):
            stack.append(child)
    return outcome


def depth_first_recursive_search(problem, node=None):
    """Recursive depth-first search starting at `node` (the root when None).

    Returns the first goal node found, or `failure`; branches that form a
    cycle (see `is_cycle`) are abandoned."""
    if node is None:
        node = Node(problem.initial)
    if problem.is_goal(node.state):
        return node
    if is_cycle(node):
        return failure
    for child in expand(problem, node):
        found = depth_first_recursive_search(problem, child)
        if found != failure:
            return found
    return failure

# Problem Domains

Now we turn our attention to defining some problem domains as subclasses of `Problem`.

# 8 Puzzle Problems

![](https://ece.uwaterloo.ca/~dwharder/aads/Algorithms/N_puzzles/images/puz3.png)

A sliding tile puzzle where you can swap the blank with an adjacent piece, trying to reach a goal configuration. The cells are numbered 0 to 8, starting at the top left and going row by row left to right. The pieces are numbered 1 to 8, with 0 representing the blank. An action is the cell index number that is to be swapped with the blank (*not* the actual number to be swapped but the index into the state). So the diagram above left is the state `(5, 2, 7, 8, 4, 0, 1, 3, 6)`, and the action is `8`, because the cell number 8 (the 9th or last cell, the `6` in the bottom right) is swapped with the blank.

There are two disjoint sets of states that cannot be reached from each other. One set has an even number of "inversions"; the other has an odd number. An inversion is when a piece in the state is larger than a piece that follows it.




In [41]:
class EightPuzzle(Problem):
    """The problem of sliding tiles numbered from 1 to 8 on a 3x3 board,
    where one of the squares is blank, trying to reach a goal configuration.

    A board state is a tuple of length 9 whose element at index i is the tile
    in cell i (cells are numbered row by row), or 0 for the empty square.
    The default goal is:
        1 2 3
        8 _ 4 ==> (1, 2, 3, 8, 0, 4, 7, 6, 5)
        7 6 5
    """
    def __init__(self, initial, goal=(1, 2, 3, 8, 0, 4, 7, 6, 5)):
        # Another common goal is (0, 1, 2, 3, 4, 5, 6, 7, 8).
        # No solvability check (equal inversion parity of `initial` and
        # `goal`) is performed here.
        super().__init__(initial=initial, goal=goal)

    def actions(self, state):
        """The indexes of the squares that the blank can be moved to."""
        moves = ((1, 3),    (0, 2, 4),    (1, 5),
                 (0, 4, 6), (1, 3, 5, 7), (2, 4, 8),
                 (3, 7),    (4, 6, 8),    (7, 5))
        return moves[state.index(0)]

    def result(self, state, action):
        """Swap the blank with the square numbered `action`."""
        board = list(state)
        blank = state.index(0)
        board[action], board[blank] = board[blank], board[action]
        return tuple(board)

    # Alternative, non-unit action costs (uncomment one to use it):
    #def action_cost(self, s, a, s1): return math.exp(s[a]);
    #def action_cost(self, s, a, s1): return s[a]

    # ---- heuristic functions ----

    def _tile_distances(self, state):
        """Yield, for each tile 1..8, its Manhattan distance from its cell in `self.goal`."""
        for tile in range(1, 9):
            row, col = divmod(state.index(tile), 3)
            goal_row, goal_col = divmod(self.goal.index(tile), 3)
            yield abs(col - goal_col) + abs(row - goal_row)

    def h0(self, node):
        """The null heuristic (consistent)."""
        return 0

    def h1(self, node):
        """The misplaced-tiles heuristic (consistent, per the author's notes).

        The blank is not counted: `hamming_distance` skips positions holding 0."""
        return hamming_distance(node.state, self.goal)

    def h2_0(self, node):
        """The Manhattan-distance heuristic; valid for any goal state.

        Sums, over tiles 1..8, the distance between the tile's cell in
        `node.state` and its cell in `self.goal`. Consistent, per the
        author's notes."""
        return sum(self._tile_distances(node.state))

    def h2_1(self, node):
        """The Manhattan-distance heuristic, computed in a single pass over the board.

        The goal cell of every tile is looked up in a table built from
        `self.goal`, so the value is correct for any goal (a table hard-coded
        for the default goal would be silently wrong for any other)."""
        goal_cell = {tile: divmod(index, 3) for index, tile in enumerate(self.goal)}
        total = 0
        for index, tile in enumerate(node.state):
            if tile == 0:
                continue  # the blank is not a tile
            row, col = divmod(index, 3)
            goal_row, goal_col = goal_cell[tile]
            total += abs(col - goal_col) + abs(row - goal_row)
        return total

    def h3(self, node):
        """Add 2 for every tile at Manhattan distance exactly 2 from its goal cell.

        Valid for any goal state. Admissible but not consistent, per the
        author's notes."""
        return sum(2 for distance in self._tile_distances(node.state) if distance == 2)

    # The heuristic actually used by the search algorithms.
    def h(self, node):
        return self.h2_0(node)
      
def hamming_distance(A, B):
    """Count the positions where A and B differ, skipping positions where A holds 0 (the blank is not a tile)."""
    mismatches = 0
    for a, b in zip(A, B):
        if a != b and a != 0:
            mismatches += 1
    return mismatches
    

def inversions(board):
    """Count the pairs of tiles in which the earlier tile is larger than a later one; the blank (0) is ignored."""
    return sum(1 for earlier, later in combinations(board, 2)
               if earlier > later and earlier != 0 and later != 0)
    
    
def board8(board, fmt=(3 * '{} {} {}\n')):
    """Return a printable string for a sliding-puzzle board, showing the blank (0) as '_'.

    Args:
        board: sequence of tile numbers, with 0 for the blank.
        fmt: format string with one `{}` placeholder per cell.
    """
    # Substitute the blank before formatting: replacing the character '0' in
    # the formatted text would also corrupt multi-digit tiles such as 10.
    return fmt.format(*('_' if tile == 0 else tile for tile in board))



In [42]:
# Some states reachable from (1, 2, 3, 8, 0, 4, 7, 6, 5), the default EightPuzzle goal.
# NOTE(review): the number in each name (e05, e10, ..., e30) presumably is the
# length of the optimal solution of that instance -- confirm with the author.

e05_1 =  EightPuzzle((1, 0, 3, 8, 2, 5, 7, 4, 6))
e10_1 =  EightPuzzle((8, 2, 1, 7, 0, 3, 6, 5, 4))
e15_1 =  EightPuzzle((4, 8, 2, 6, 3, 5, 1, 0, 7))
e20_1 =  EightPuzzle((6, 2, 7, 4, 5, 1, 0, 8, 3))
e21_1 =  EightPuzzle((7, 0, 2, 3, 6, 1, 5, 8, 4))
e23_1 =  EightPuzzle((6, 3, 2, 5, 8, 1, 4, 0, 7))
e24_1 =  EightPuzzle((6, 3, 2, 5, 7, 8, 0, 4, 1))
e24_2 =  EightPuzzle((3, 5, 6, 4, 2, 7, 0, 8, 1))
e25_1 =  EightPuzzle((6, 7, 4, 0, 5, 1, 3, 2, 8))
e30_1 =  EightPuzzle((5, 6, 7, 2, 8, 4, 0, 3, 1))
e30_2 =  EightPuzzle((5, 6, 7, 4, 0, 8, 3, 2, 1))
e30_3 =  EightPuzzle((5, 4, 7, 6, 0, 3, 8, 2, 1))
e30_4 =  EightPuzzle((3, 8, 7, 4, 0, 6, 5, 2, 1))
e30_5 =  EightPuzzle((5, 6, 3, 4, 0, 2, 7, 8, 1))



In [43]:
# Solve an 8 puzzle problem and print out each state
# (astar_search itself also prints every expanded state with its f-value,
# followed by the run statistics, before the boards are printed).
for s in path_states(astar_search(e30_2)):
    print(board8(s))

(5, 6, 7, 4, 0, 8, 3, 2, 1) f()=  24
(5, 0, 7, 4, 6, 8, 3, 2, 1) f()=  24
(5, 6, 7, 4, 8, 0, 3, 2, 1) f()=  24
(5, 6, 7, 4, 2, 8, 3, 0, 1) f()=  24
(5, 6, 7, 0, 4, 8, 3, 2, 1) f()=  24
(5, 7, 0, 4, 6, 8, 3, 2, 1) f()=  24
(5, 6, 7, 4, 2, 8, 0, 3, 1) f()=  24
(0, 6, 7, 5, 4, 8, 3, 2, 1) f()=  24
(5, 6, 7, 3, 4, 8, 0, 2, 1) f()=  24
(5, 6, 7, 4, 2, 8, 3, 1, 0) f()=  24
(5, 6, 0, 4, 8, 7, 3, 2, 1) f()=  24
(5, 6, 7, 4, 8, 1, 3, 2, 0) f()=  24
(0, 5, 7, 4, 6, 8, 3, 2, 1) f()=  24
(5, 7, 8, 4, 6, 0, 3, 2, 1) f()=  26
(6, 0, 7, 5, 4, 8, 3, 2, 1) f()=  26
(5, 6, 7, 4, 8, 1, 3, 0, 2) f()=  26
(4, 5, 7, 0, 6, 8, 3, 2, 1) f()=  26
(5, 6, 7, 0, 2, 8, 4, 3, 1) f()=  26
(5, 6, 7, 3, 4, 8, 2, 0, 1) f()=  26
(5, 0, 6, 4, 8, 7, 3, 2, 1) f()=  26
(5, 6, 7, 4, 2, 0, 3, 1, 8) f()=  26
(5, 7, 8, 4, 6, 1, 3, 2, 0) f()=  26
(5, 6, 7, 4, 8, 1, 0, 3, 2) f()=  26
(0, 5, 6, 4, 8, 7, 3, 2, 1) f()=  26
(5, 6, 0, 4, 2, 7, 3, 1, 8) f()=  26
(4, 5, 7, 3, 6, 8, 0, 2, 1) f()=  26
(0, 6, 7, 5, 2, 8, 4, 3, 1) f()=  26
(

# Symmetric TSP

![](romania.png)

En el TSP simétrico, los estados son pares de la forma [CiudadActual, CiudadesVisitadas]. El estado inicial es [A, ConjuntoVacío] y el estado objetivo es de la forma [A, TodasLasCiudades]. Los sucesores de un estado se generan a partir de las ciudades no visitadas, uno por cada una. El coste de una regla es el coste entre la ciudad actual y la nueva ciudad visitada, que pasa a ser la actual.

In [44]:
class EstadoTSP:
    """A TSP search state: the current city plus the cities already visited.

    Attributes:
        actual: the current city.
        visitadas: the sequence of visited cities (stored by reference; it
            must not be mutated afterwards, since the hash key and the text
            form are computed once, at construction time).
        state: hashable key, the visited cities followed by the current one.
        str: text form returned by `repr`.
    """
    def __init__(self, actual, visitadas):
        self.actual = actual
        self.visitadas = visitadas
        # Build the key without requiring `visitadas` to be a list.
        self.state = tuple(visitadas) + (actual,)
        self.str = f" [Visitadas: {self.visitadas!s} Actual: {self.actual!s}]"

    def __hash__(self):
        return hash(self.state)

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign objects
        # (e.g. None) instead of failing on a missing attribute.
        if not isinstance(other, EstadoTSP):
            return NotImplemented
        return (self.actual == other.actual) and (self.visitadas == other.visitadas)

    def __repr__(self):
        return f'{self.str}'


In [45]:
class Arco:
    """An undirected weighted arc between cities `x` and `y`.

    Arcs are ordered by cost (`<`), while equality looks only at the two end
    points, in either order, and ignores the cost.
    """
    def __init__(self,x,y,coste):
        self.x = x
        self.y = y
        self.coste = coste
        self.str = "[("+ str(self.x) + "," + str(self.y) +") " + str(self.coste) + "]"

    def __lt__(self,other):
        return self.coste < other.coste

    def __repr__(self):
        return f'{self.str}'

    def __eq__(self,other):
        # Foreign objects are simply "not equal" rather than an AttributeError.
        if not isinstance(other, Arco):
            return NotImplemented
        same_direction = self.x == other.x and self.y == other.y
        opposite_direction = self.x == other.y and self.y == other.x
        return same_direction or opposite_direction

    def __hash__(self):
        # Defining __eq__ alone would make arcs unhashable. The hash ignores
        # the orientation of the arc, consistently with __eq__.
        return hash(frozenset((self.x, self.y)))
        

class Grafo:
    """A complete undirected weighted graph over the cities 0, ..., N-1.

    City 0 is taken as the starting city. `Dist` is an N x N matrix (list of
    lists) in which Dist[i][j] == Dist[j][i] is the distance between cities
    i and j.

    Attributes:
        N: number of cities (an int).
        ciudades: the list [0, ..., N-1].
        Dist: the symmetric distance matrix.
        Total: twice the sum of all the values read from the input; it is
            used as an "infinite" cost in the assignment matrix.
    """
    def __init__(self,lista):
        """Build the graph from a flat, row-by-row lower-triangular matrix.

        `lista` holds rows 0..N-1, row x containing the distances from city x
        to cities 0..x (diagonal included), so it has N*(N+1)/2 values.
        Values beyond the last complete row are ignored.
        """
        # len(lista) == N*(N+1)/2  =>  N == (sqrt(8*len + 1) - 1) / 2.
        # Floor division keeps N an int, so it can be used directly in
        # range() and in comparisons with list lengths.
        self.N = int(((8*len(lista)+1)**0.5)-1) // 2
        self.Total = 0
        self.ciudades = list(range(self.N))
        self.Dist = [[0]*self.N for _ in self.ciudades]
        row_start = 0  # index in `lista` of the first value of row x
        for x in range(self.N):
            row_start += x
            for y in range(x+1):
                distance = lista[row_start+y]
                self.Dist[x][y] = self.Dist[y][x] = distance
                self.Total += 2*distance

    def listaArcosOrdenados(self):
        """Return every arc (x, y) with y < x, sorted by increasing cost."""
        # sorted() is stable, so arcs of equal cost keep their generation order.
        return sorted(Arco(x, y, self.Dist[x][y])
                      for x in range(self.N) for y in range(x))

    def listaCiudadesGrafoResidual(self,state):
        """Cities of the residual graph: city 0 plus every non-visited city other than 0."""
        return [0] + self.listaCiudadesAbandonar(state)

    def listaArcosGrafoResidual(self,state):
        """Arcs of the residual graph for `state`, sorted by increasing cost.

        The arc joining city 0 and the current city is left out, except when
        those two are the only residual cities, in which case it is the only
        arc returned.
        """
        cities = self.listaCiudadesGrafoResidual(state)
        if len(cities) < 2:
            return []
        if len(cities) == 2:
            return [Arco(state.actual,0,self.Dist[state.actual][0])]
        return sorted(Arco(x, y, self.Dist[x][y])
                      for x in cities for y in cities
                      if x < y and not (x == 0 and y == state.actual))

    def listaCiudadesAbandonar(self,state):
        """Cities that still have to be left: the non-visited cities other than 0."""
        return [x for x in range(self.N)
                if x not in state.visitadas and x != 0]

    def listaCiudadesAlcanzar(self,state):
        """Cities that still have to be reached: 0 plus the non-visited cities other than the current one."""
        return [0] + [x for x in range(self.N)
                      if x not in state.visitadas and x != state.actual]

    def matrizAsignacion(self,state):
        """Return the cost matrix of the assignment problem for `state`.

        Rows follow the cities to reach and columns the cities to leave, both
        in increasing order. A cell holds the distance between the two
        cities, or `Total` (standing for infinity) when both are the same
        city or when it pairs city 0 with the current city while more than
        one assignment remains.
        """
        size = self.N - len(state.visitadas)  # dimension of the assignment matrix
        matrix = [[0]*size for _ in range(size)]
        to_reach = self.listaCiudadesAlcanzar(state)
        to_leave = self.listaCiudadesAbandonar(state)
        i = 0
        for x in range(self.N):
            if x not in to_reach:
                continue
            j = 0
            for y in range(self.N):
                if y not in to_leave:
                    continue
                forbidden = x == y or (x == 0 and y == state.actual and size > 1)
                matrix[i][j] = self.Total if forbidden else self.Dist[x][y]
                j += 1
            i += 1
        return matrix



In [46]:
# Implementacion del algoritmo Hungaro tomada de:
# https://transport-systems.imperial.ac.uk/tf/60008_21/n2_5_hungarian_algorithm

def hungarian_step(mat):
    """Reduce `mat` in place: subtract each row's minimum from the row, then each column's minimum from the column.

    Returns the same (modified) array."""
    # Iterating a 2-D array yields views of its rows, so the in-place
    # subtraction writes through to `mat`.
    for row in mat:
        row -= row.min()

    # The transpose is a view too; iterating it yields the columns.
    for column in mat.T:
        column -= column.min()

    return mat

def min_zeros(zero_mat, mark_zero):
    """Pick one zero from the row with the fewest remaining zeros.

    `zero_mat` is a boolean matrix (True where the cost matrix holds a zero).
    The first True of the chosen row is appended to `mark_zero` as a
    (row, column) pair, and that whole row and column are then set to False
    in `zero_mat`."""
    best_count, best_row = 99999, -1  # fewest zeros seen so far, and where

    for row_idx in range(zero_mat.shape[0]):
        zeros_in_row = np.sum(zero_mat[row_idx] == True)
        if 0 < zeros_in_row < best_count:
            best_count, best_row = zeros_in_row, row_idx

    # Take the first zero of that row and cross out its row and column.
    col_idx = np.where(zero_mat[best_row] == True)[0][0]
    mark_zero.append((best_row, col_idx))
    zero_mat[best_row, :] = False
    zero_mat[:, col_idx] = False

def mark_matrix(mat):
    """Choose a set of independent zeros of `mat` and the rows/columns that are to be covered.

    Returns a tuple (marked_zero, marked_rows, marked_cols):
      * marked_zero: (row, column) positions of the zeros selected by
        repeated calls to `min_zeros` (no two share a row or a column);
      * marked_rows: indices of the rows to cover;
      * marked_cols: indices of the columns to cover.
    `mat` itself is not modified.
    """
    # Boolean view of the matrix: True where the entry is 0.
    cur_mat = mat
    zero_bool_mat = (cur_mat == 0)
    # Working copy: min_zeros() clears the row and column of every zero it picks.
    zero_bool_mat_copy = zero_bool_mat.copy()

    # Keep selecting zeros until none is left in the working copy.
    marked_zero = []
    while (True in zero_bool_mat_copy):
        min_zeros(zero_bool_mat_copy, marked_zero)

    # Split the selected positions into their row and column indices.
    marked_zero_row = []
    marked_zero_col = []
    for i in range(len(marked_zero)):
        marked_zero_row.append(marked_zero[i][0])
        marked_zero_col.append(marked_zero[i][1])
    
    # Start from the rows that received no selected zero.
    non_marked_row = list(set(range(cur_mat.shape[0])) - set(marked_zero_row))
    
    # Repeat until nothing changes:
    #   1. every column holding a zero in one of those rows becomes a marked column;
    #   2. every row whose selected zero lies in a marked column joins those rows.
    marked_cols = []
    check_switch = True
    while check_switch:
        check_switch = False
        for i in range(len(non_marked_row)):
            row_array = zero_bool_mat[non_marked_row[i], :]
            for j in range(row_array.shape[0]):
                if row_array[j] == True and j not in marked_cols:

                    marked_cols.append(j)
                    check_switch = True

        for row_num, col_num in marked_zero:
            if row_num not in non_marked_row and col_num in marked_cols:
                
                non_marked_row.append(row_num)
                check_switch = True
    
    # The rows to cover are all the rows that never joined `non_marked_row`.
    marked_rows = list(set(range(mat.shape[0])) - set(non_marked_row))
    
    return(marked_zero, marked_rows, marked_cols)

def adjust_matrix(mat, cover_rows, cover_cols):
    """Shift the values of `mat` (in place) by the smallest uncovered entry.

    That minimum is subtracted from every entry lying in neither a covered
    row nor a covered column, and added to every entry lying in both a
    covered row and a covered column. Returns the same (modified) array."""
    # Positions that no covering line goes through.
    open_cells = [(r, c)
                  for r in range(len(mat)) if r not in cover_rows
                  for c in range(len(mat[r])) if c not in cover_cols]

    delta = min(mat[r][c] for r, c in open_cells)

    for r, c in open_cells:
        mat[r, c] = mat[r, c] - delta
    # Entries at the crossing of two covering lines.
    for r in cover_rows:
        for c in cover_cols:
            mat[r, c] = mat[r, c] + delta

    return mat

def hungarian_algorithm(cost_matrix):
    """Run the Hungarian method on `cost_matrix` and return the selected cells.

    Args:
        cost_matrix: a NumPy cost matrix (assumed to be square -- the number
            of rows is used as the number of covering lines to reach). It is
            not modified; the work is done on a copy.

    Returns:
        The list of (row, column) positions of the zeros selected by
        `mark_matrix` once the covering lines number at least the number of
        rows. An empty matrix yields an empty list.
    """
    n = cost_matrix.shape[0]
    if n == 0:
        # Nothing to assign; this also avoids reducing an empty array.
        return []

    cur_mat = hungarian_step(copy.deepcopy(cost_matrix))

    while True:
        ans_pos, marked_rows, marked_cols = mark_matrix(cur_mat)
        if len(marked_rows) + len(marked_cols) >= n:
            return ans_pos
        # Too few covering lines: create new zeros and try again.
        cur_mat = adjust_matrix(cur_mat, marked_rows, marked_cols)
    

In [47]:
class SymmetricTSP(Problem):
    """El problema consiste en encontrar un camino hamiltoniano en un grafo no dirigido 
    y totalmente conectado, con costes positivos en los ejes. Las ciudades son 0,...,N-1
    La ciudad de partida es la 0"""

    def __init__(self, grafo):
        self.grafo = grafo
        self.initial = EstadoTSP(0,[])
        #print(self.__dir__initial)
        self.lArcos = grafo.listaArcosOrdenados()
        #print(self.lArcos)

    def actions(self, state):
        """Una accion por cada ciudad no visitada distinta de la actual, si todas visitadas -1"""
        noVisitadas = []
        for c in self.grafo.ciudades: 
            if c not in state.visitadas and c != state.actual:
                noVisitadas.append(c) 
        if len(noVisitadas)==0:
            noVisitadas = [0]
        return noVisitadas
    
    def result(self, state, action): 
        setVisitadas = list(state.visitadas)
        insort(setVisitadas,state.actual)
        return EstadoTSP(action,setVisitadas)
    
    def is_goal(self, state):        
        return state.actual == 0  and  len(state.visitadas) == self.grafo.N

    def action_cost(self, s, a, s1): 
        return self.grafo.Dist[s.actual][a]
        
   # HEURISTICOS, LO MAS IMPORTANTE!!
   
    def h_POBRE(self,node):
        # heuristico muy poco informado, cuenta el numero de ciudades que faltan por visitar, es el valor N-k+1
        return self.grafo.N - len(node.state.visitadas) 

    def h1(self,node):
        # Heuristico que considera los arcos minimos de las ciudades
        # que quedan por abandonar. 
        # No responde a ninguna relajacion, aparentemente
        # Se puede mejorar si consideramos que las ciudades nunca se abandornaran hacia una ciudad ya visitada, excepto la inicial 0
        lista = self.grafo.listaCiudadesAbandonar(node.state)
        h = 0
        for x in lista:
            min = self.lArcos[-1].coste
            for y in range(int(self.grafo.N)):
                if (x!=y and self.grafo.Dist[x][y] < min):
                    min = self.grafo.Dist[x][y]
            h += min
        return h
    
    def h2(self,node):
        # Relajacion R2, R3, R4.
        # Suma de los N-k+1 arcos de menos coste del grafo residual
        listaAGR = self.grafo.listaArcosGrafoResidual(node.state)
        Nk1 = self.grafo.N - len(node.state.visitadas)
        h = 0
        for i in range(int(Nk1)):
            a = listaAGR.pop(0)
            h += a.coste
        return h
    
    def h_MST(self, node):
        # Relajacion R3
        # Coste de un arbol de expansion minimo del grafo residual
        # Se calcula con el algoritmo de Kruskal
        if (node.state == self.initial or self.is_goal(node.state)):
            return 0
        listaAGR = self.grafo.listaArcosGrafoResidual(node.state)
        listaCGR = self.grafo.listaCiudadesGrafoResidual(node.state)
        c = {}
        for x in listaCGR:
            c[x] = x
        nA = self.grafo.N - len(node.state.visitadas) 
        h = 0 
        while(nA > 0):
            a = listaAGR.pop(0)
            cx = self.particion(c,a.x)
            cy = self.particion(c,a.y)
            #print(cx,cy)
            if (cx != cy):
                h += a.coste
                c[cx] = cy
                nA -= 1
        return h
        
    def particion(self,c,x):
        if (c[x] == x):
            return x
        else:
            return self.particion(c,c[x])

    
    def h_HUNGARO(self, node):
        # Relajacion R3
        # Coste de la asignacion minima entre los conjuntos
        # CIUDADES_NO_VISITADAS U {ACTUAL} y CIUDADES_NO_VISITADAS U {INICIAL}
        # Es una aproximacion al coste optimo del problema relajado (R3)
        # ya que pueden aparecer arcos repetidos en la solucion

        if (node.state == self.initial or self.is_goal(node.state)):
            return 0
        #print(node.state)
        cost_matrix = np.asarray(self.grafo.matrizAsignacion(node.state))
        asignacion = hungarian_algorithm(np.asarray(cost_matrix))
        coste = 0
        for asig in asignacion:
            coste += cost_matrix[asig[0],asig[1]]
        return coste
    
    def h_NN(self,node):
        if (self.is_goal(node.state)):
            return 0
        
        toVisit = self.grafo.listaCiudadesAlcanzar(node.state)

        ret = 0
        from_ = node.state.actual
        to = -1
        while toVisit:
            min = -1
            for i in toVisit:
                if (min == -1):
                    min = self.grafo.Dist[from_][i]
                    to = i
                else:
                    b = self.grafo.Dist[from_][i]
                    if b < min:
                        min = b
                        to = i
            if to in toVisit:
                toVisit.remove(to)
            else:
                break
            from_ = to
            ret += min

        #print("Sale del heurístico")
        return ret + self.grafo.Dist[to][self.initial.actual]
  
    def h_BI(self, node):
        """Estimate the remaining cost by growing a tour with cheapest insertion.

        The partial tour starts as "current city -> initial city".  While
        cities from ``self.grafo.listaCiudadesAbandonar(node.state)`` remain,
        the (city, position) pair whose insertion lengthens the tour the least
        is selected and the city is spliced into the tour at that position.

        Args:
            node: search node; ``node.state.actual`` is the current city.

        Returns:
            Length of the constructed tour from the current city, through all
            inserted cities, back to the initial city.
        """
        dist = self.grafo.Dist
        home = self.initial.actual
        estimate = dist[node.state.actual][home]

        pending = self.grafo.listaCiudadesAbandonar(node.state)
        tour = [node.state.actual]

        while pending:
            # Every tour city paired with its successor; the last city is
            # followed by the initial city, which closes the tour.
            legs = list(zip(tour, tour[1:] + [home]))
            best_delta, best_slot, best_city = math.inf, -1, -1

            # Candidates are taken with pop() from a copy so that `pending`
            # itself is only changed once the winner is known.
            candidates = pending.copy()
            while candidates:
                city = candidates.pop()
                delta, slot = math.inf, -1
                for position, (before, after) in enumerate(legs):
                    # Extra length caused by visiting `city` between the two.
                    detour = dist[before][city] + dist[city][after] - dist[before][after]
                    if detour < delta:
                        delta, slot = detour, position

                if delta < best_delta:
                    best_delta, best_slot, best_city = delta, slot, city

            tour.insert(best_slot + 1, best_city)
            estimate += best_delta
            pending.remove(best_city)

        return estimate

        
    def pathCost(self, list):
        """Return the total length of a route given as a sequence of cities.

        Adds ``self.grafo.Dist[a][b]`` for every pair of consecutive cities
        ``a, b`` in ``list``.  No closing leg back to the first city is added;
        an empty or single-city route costs 0.
        """
        total = 0
        for idx in range(len(list) - 1):
            origin, destination = list[idx], list[idx + 1]
            total += self.grafo.Dist[origin][destination]
        return total


    # Actual heuristic
    def h(self, node):
        """Heuristic value of `node` for this problem.

        Delegates to ``h_HUNGARO``; change the call below to select one of the
        other heuristics of this class (for example ``h_NN`` or ``h_BI``).
        """
        estimate = self.h_HUNGARO(node)
        return estimate


In [48]:
# Instance data given as explicit edge weights, in the style of TSPLIB files.
# Each list is a lower-triangular matrix including the diagonal, written row
# by row: row k holds k+1 values and ends with the 0 of the diagonal.
# NOTE(review): presumably entry j of row k is the distance between cities k
# and j -- confirm against how Grafo unpacks the list.

# Small 6-city instance (the name means "class example").
ejemploClase = [ 
    0,
    21, 0,
    12, 7, 0,
    15, 32, 5, 0,
    113, 25, 18, 180, 0,
    92, 9, 20, 39, 17, 0
]

# 21-city instance: 21 rows of a lower-triangular matrix with diagonal, one
# row per line (row k has k+1 values and ends with the diagonal 0).
# NOTE(review): presumably the TSPLIB instance of the same name -- confirm.
gr21 = [
    0, 
    510, 0, 
    635, 355, 0, 
    91, 415, 605, 0, 
    385, 585, 390, 350, 0, 
    155, 475, 495, 120, 240, 0, 
    110, 480, 570, 78, 320, 96, 0, 
    130, 500, 540, 97, 285, 36, 29, 0, 
    490, 605, 295, 460, 120, 350, 425, 390, 0, 
    370, 320, 700, 280, 590, 365, 350, 370, 625, 0, 
    155, 380, 640, 63, 430, 200, 160, 175, 535, 240, 0, 
    68, 440, 575, 27, 320, 91, 48, 67, 430, 300, 90, 0, 
    610, 360, 705, 520, 835, 605, 590, 610, 865, 250, 480, 545, 0, 
    655, 235, 585, 555, 750, 615, 625, 645, 775, 285, 515, 585, 190, 0, 
    480, 81, 435, 380, 575, 440, 455, 465, 600, 245, 345, 415, 295, 170, 0, 
    265, 480, 420, 235, 125, 125, 200, 165, 230, 475, 310, 205, 715, 650, 475, 0, 
    255, 440, 755, 235, 650, 370, 320, 350, 680, 150, 175, 265, 400, 435, 385, 485, 0, 
    450, 270, 625, 345, 660, 430, 420, 440, 690, 77, 310, 380, 180, 215, 190, 545, 225, 0, 
    170, 445, 750, 160, 495, 265, 220, 240, 600, 235, 125, 170, 485, 525, 405, 375, 87, 315, 0, 
    240, 290, 590, 140, 480, 255, 205, 220, 515, 150, 100, 170, 390, 425, 255, 395, 205, 220, 155, 0, 
    380, 140, 495, 280, 480, 340, 350, 370, 505, 185, 240, 310, 345, 280, 105, 380, 280, 165, 305, 150, 0,
]

# 17-city instance: 17 rows of a lower-triangular matrix with diagonal, one
# row per line (row k has k+1 values and ends with the diagonal 0).
# NOTE(review): presumably the TSPLIB instance of the same name -- confirm.
gr17 = [
    0, 
    633, 0, 
    257, 390, 0, 
    91, 661, 228, 0, 
    412, 227, 169, 383, 0, 
    150, 488, 112, 120, 267, 0, 
    80, 572, 196, 77, 351, 63, 0, 
    134, 530, 154, 105, 309, 34, 29, 0, 
    259, 555, 372, 175, 338, 264, 232, 249, 0, 
    505, 289, 262, 476, 196, 360, 444, 402, 495, 0, 
    353, 282, 110, 324, 61, 208, 292, 250, 352, 154, 0, 
    324, 638, 437, 240, 421, 329, 297, 314, 95, 578, 435, 0, 
    70, 567, 191, 27, 346, 83, 47, 68, 189, 439, 287, 254, 0, 
    211, 466, 74, 182, 243, 105, 150, 108, 326, 336, 184, 391, 145, 0, 
    268, 420, 53, 239, 199, 123, 207, 165, 383, 240, 140, 448, 202, 57, 0, 
    246, 745, 472, 237, 528, 364, 332, 349, 202, 685, 542, 157, 289, 426, 483, 0, 
    121, 518, 142, 84, 297, 35, 29, 36, 236, 390, 238, 301, 55, 96, 153, 336, 0,
 ]

# Larger instance stored as one flat sequence instead of one row per line; the
# values follow the same lower-triangular-with-diagonal order (0 / 593, 0 /
# 409, 258, 0 / ...).
# NOTE(review): presumably 48 cities, i.e. 48 * 49 / 2 = 1176 values -- confirm.
gr48 = [0, 593, 0, 409, 258, 0, 566, 331, 171, 0, 633, 586, 723, 874, 0, 257, 602, 522, 679, 390, 0, 91, 509, 325, 482, 598, 228, 0, 412, 627, 506, 663, 227, 169, 383, 0, 378, 755, 634, 791, 397, 175, 349, 167, 0, 593, 416, 564, 721, 271, 445, 509, 293, 429, 0, 150, 598, 414, 571, 488, 112, 120, 267, 233, 541, 0, 659, 488, 630, 787, 205, 511, 575, 304, 470, 76, 607, 0, 80, 566, 382, 539, 572, 196, 77, 351, 317, 563, 63, 629, 0, 434, 893, 699, 856, 524, 231, 405, 303, 138, 595, 289, 606, 373, 0, 455, 417, 433, 590, 313, 304, 371, 228, 394, 158, 399, 224, 425, 530, 0, 134, 583, 399, 566, 530, 154, 105, 309, 275, 575, 34, 638, 29, 298, 434, 0, 649, 945, 824, 981, 446, 423, 620, 357, 280, 649, 504, 648, 588, 416, 584, 546, 0, 259, 364, 180, 337, 555, 272, 175, 338, 466, 403, 264, 469, 232, 549, 265, 249, 656, 0, 505, 354, 110, 70, 819, 618, 421, 602, 730, 660, 509, 728, 478, 795, 529, 494, 920, 276, 0, 710, 117, 375, 354, 679, 693, 626, 720, 848, 533, 715, 610, 683, 986, 534, 700, 1038, 481, 345, 0, 488, 784, 663, 820, 289, 262, 459, 196, 119, 488, 343, 502, 427, 255, 423, 385, 161, 495, 759, 877, 0, 353, 641, 520, 677, 282, 110, 324, 61, 125, 353, 208, 364, 292, 261, 288, 250, 315, 352, 616, 734, 154, 0, 324, 275, 91, 248, 638, 437, 240, 421, 549, 486, 329, 552, 297, 614, 348, 314, 739, 95, 187, 392, 578, 435, 0, 605, 287, 431, 588, 313, 445, 520, 470, 598, 143, 610, 215, 577, 734, 144, 595, 788, 352, 527, 404, 627, 484, 385, 0, 372, 229, 39, 196, 686, 485, 288, 469, 597, 511, 397, 578, 345, 662, 396, 361, 787, 143, 135, 346, 626, 483, 54, 377, 0, 330, 484, 361, 518, 378, 119, 260, 150, 278, 323, 174, 389, 276, 414, 185, 207, 468, 193, 475, 577, 307, 164, 276, 326, 324, 0, 581, 877, 756, 913, 370, 355, 552, 289, 212, 581, 436, 571, 520, 348, 516, 478, 84, 588, 852, 970, 93, 247, 671, 720, 719, 400, 0, 154, 460, 276, 433, 612, 298, 63, 453, 419, 460, 190, 526, 158, 475, 322, 175, 690, 126, 372, 577, 529, 396, 191, 471, 239, 250, 622, 0, 70, 523, 339, 496, 569, 191, 27, 
346, 312, 516, 83, 589, 47, 368, 385, 68, 583, 189, 435, 640, 422, 287, 254, 534, 302, 249, 515, 115, 0, 606, 183, 216, 147, 715, 719, 522, 703, 831, 549, 611, 615, 579, 896, 546, 596, 1021, 377, 139, 209, 860, 717, 288, 416, 242, 558, 953, 473, 536, 0, 585, 427, 563, 720, 179, 437, 501, 196, 362, 80, 532, 108, 558, 498, 163, 567, 552, 395, 659, 544, 391, 256, 478, 154, 526, 318, 484, 452, 515, 556, 0, 544, 840, 719, 876, 311, 318, 515, 252, 175, 508, 399, 494, 483, 311, 479, 441, 154, 551, 815, 933, 65, 210, 634, 683, 682, 363, 77, 585, 479, 916, 399, 0, 496, 525, 595, 751, 147, 253, 468, 85, 251, 208, 351, 236, 435, 387, 162, 393, 441, 427, 691, 646, 280, 145, 509, 249, 558, 239, 373, 538, 430, 654, 128, 336, 0, 317, 289, 105, 262, 631, 430, 233, 414, 542, 479, 332, 545, 290, 607, 341, 307, 732, 88, 201, 406, 571, 428, 21, 407, 68, 269, 664, 184, 247, 302, 471, 627, 503, 0, 648, 68, 316, 362, 584, 598, 564, 625, 753, 418, 653, 484, 621, 891, 415, 638, 943, 395, 412, 95, 782, 639, 333, 285, 287, 482, 875, 515, 578, 209, 425, 838, 523, 347, 0, 211, 660, 476, 633, 466, 74, 182, 243, 171, 489, 66, 555, 150, 227, 351, 108, 432, 326, 572, 777, 271, 184, 391, 492, 439, 166, 364, 252, 145, 673, 438, 327, 327, 384, 715, 0, 475, 137, 295, 452, 437, 428, 391, 452, 580, 271, 480, 337, 448, 718, 268, 465, 770, 222, 391, 254, 609, 466, 255, 138, 241, 309, 702, 342, 405, 287, 278, 665, 376, 277, 167, 542, 0, 654, 151, 319, 266, 755, 767, 570, 751, 879, 561, 659, 627, 627, 944, 558, 644, 1069, 425, 262, 103, 908, 765, 336, 428, 290, 606, 1001, 521, 584, 122, 568, 964, 666, 350, 169, 721, 299, 0, 710, 239, 487, 546, 616, 660, 626, 687, 815, 443, 715, 509, 683, 953, 440, 700, 1005, 457, 583, 279, 844, 701, 490, 310, 458, 544, 937, 577, 640, 393, 450, 900, 548, 512, 179, 777, 229, 353, 0, 585, 135, 385, 458, 499, 535, 501, 562, 690, 333, 590, 399, 558, 828, 330, 575, 880, 332, 481, 215, 719, 576, 365, 200, 356, 419, 812, 452, 515, 318, 340, 775, 438, 387, 120, 652, 104, 289, 121, 
0, 246, 373, 183, 340, 745, 472, 237, 528, 656, 593, 364, 659, 332, 649, 455, 349, 846, 202, 279, 490, 685, 542, 157, 525, 144, 383, 778, 174, 289, 386, 585, 741, 618, 132, 431, 426, 395, 434, 630, 505, 0, 788, 208, 456, 488, 724, 738, 704, 765, 893, 558, 793, 624, 761, 1031, 555, 778, 1083, 535, 552, 188, 922, 779, 473, 425, 427, 622, 1015, 655, 718, 343, 565, 978, 663, 487, 138, 855, 307, 284, 138, 235, 571, 0, 446, 162, 111, 268, 624, 559, 362, 543, 671, 458, 451, 524, 419, 736, 455, 436, 861, 217, 207, 279, 700, 557, 128, 325, 82, 398, 793, 313, 376, 175, 465, 756, 563, 142, 220, 513, 187, 223, 391, 289, 226, 360, 0, 166, 437, 247, 404, 749, 435, 150, 590, 556, 597, 402, 663, 295, 612, 459, 387, 827, 189, 343, 554, 666, 531, 221, 589, 208, 372, 759, 137, 177, 450, 589, 722, 675, 196, 495, 389, 459, 498, 694, 569, 80, 635, 290, 0, 523, 81, 188, 255, 596, 636, 439, 620, 648, 430, 528, 496, 496, 813, 427, 513, 938, 294, 284, 193, 777, 634, 205, 297, 159, 475, 870, 390, 453, 119, 437, 833, 535, 219, 139, 590, 168, 131, 310, 208, 303, 279, 92, 367, 0, 235, 371, 187, 344, 581, 348, 151, 364, 469, 429, 240, 495, 208, 525, 291, 225, 682, 32, 283, 488, 521, 378, 103, 384, 150, 219, 614, 94, 165, 384, 421, 577, 454, 92, 429, 302, 254, 432, 489, 364, 165, 569, 224, 154, 301, 0, 369, 205, 289, 446, 537, 328, 286, 355, 483, 371, 375, 437, 343, 554, 269, 360, 673, 116, 385, 322, 512, 369, 149, 238, 230, 209, 605, 237, 300, 352, 378, 568, 445, 172, 281, 436, 108, 332, 343, 218, 290, 421, 164, 354, 201, 149, 0, 121, 570, 386, 543, 518, 142, 84, 297, 263, 570, 35, 636, 29, 319, 432, 36, 534, 236, 482, 687, 373, 238, 301, 581, 349, 222, 466, 162, 55, 583, 562, 429, 381, 294, 625, 96, 452, 631, 687, 562, 336, 765, 423, 299, 500, 212, 347, 0]

# Build a problem instance from the gr21 distance data and run A* on it.
instance = SymmetricTSP(Grafo(gr21))
# Alternative: show the actions of the solution instead of its states.
#path_actions(astar_search(instance))
# Last expression of the cell: the states along the solution path.
path_states(astar_search(instance))
# Alternative: weighted A* with weight=0.9.
#path_states(weighted_astar_search(instance, weight=0.9))


 [Visitadas: [] Actual: 0] f()=  0
 [Visitadas: [0] Actual: 11] f()=  2421
 [Visitadas: [0, 11] Actual: 10] f()=  2419
 [Visitadas: [0, 10, 11] Actual: 19] f()=  2419
 [Visitadas: [0] Actual: 3] f()=  2421
 [Visitadas: [0] Actual: 6] f()=  2422
 [Visitadas: [0, 3] Actual: 11] f()=  2422
 [Visitadas: [0, 6] Actual: 11] f()=  2422
 [Visitadas: [0, 11] Actual: 3] f()=  2429
 [Visitadas: [0, 6] Actual: 3] f()=  2429
 [Visitadas: [0, 11] Actual: 6] f()=  2431
 [Visitadas: [0, 3] Actual: 6] f()=  2431
 [Visitadas: [0, 6] Actual: 7] f()=  2458
 [Visitadas: [0] Actual: 5] f()=  2458
 [Visitadas: [0, 6, 7] Actual: 5] f()=  2458
 [Visitadas: [0, 5, 6, 7] Actual: 4] f()=  2408
 [Visitadas: [0, 5] Actual: 7] f()=  2458
 [Visitadas: [0, 3, 11] Actual: 5] f()=  2458
 [Visitadas: [0, 3, 5, 11] Actual: 7] f()=  2458
 [Visitadas: [0, 5, 6, 7] Actual: 11] f()=  2458
 [Visitadas: [0, 3, 11] Actual: 6] f()=  2460
 [Visitadas: [0, 3, 6, 11] Actual: 7] f()=  2460
 [Visitadas: [0, 5, 7] Actual: 6] f()=  2460

[ [Visitadas: [] Actual: 0],
  [Visitadas: [0] Actual: 11],
  [Visitadas: [0, 11] Actual: 3],
  [Visitadas: [0, 3, 11] Actual: 10],
  [Visitadas: [0, 3, 10, 11] Actual: 19],
  [Visitadas: [0, 3, 10, 11, 19] Actual: 18],
  [Visitadas: [0, 3, 10, 11, 18, 19] Actual: 16],
  [Visitadas: [0, 3, 10, 11, 16, 18, 19] Actual: 9],
  [Visitadas: [0, 3, 9, 10, 11, 16, 18, 19] Actual: 17],
  [Visitadas: [0, 3, 9, 10, 11, 16, 17, 18, 19] Actual: 12],
  [Visitadas: [0, 3, 9, 10, 11, 12, 16, 17, 18, 19] Actual: 13],
  [Visitadas: [0, 3, 9, 10, 11, 12, 13, 16, 17, 18, 19] Actual: 14],
  [Visitadas: [0, 3, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19] Actual: 20],
  [Visitadas: [0, 3, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20] Actual: 1],
  [Visitadas: [0, 1, 3, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20] Actual: 2],
  [Visitadas: [0, 1, 2, 3, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20] Actual: 8],
  [Visitadas: [0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20] Actual: 4],
  [Visitadas: [0, 1, 2, 3, 4, 