# Visualization

In [15]:
def generate_states(states, current_state, visited):
    if len(current_state) == 9:
        states.append(list(current_state))
        return
    for i in range(9):
        if visited[i] == False:
            visited[i] = True
            current_state.append(i)
            generate_states(states, current_state, visited)
            current_state.pop()
            visited[i] = False

In [22]:
def print_grid(state):
    print("===============")
    for i in range(9):
        print(state[i],end='')
        if i%3==2: print()
        else: print(end=' ')
    print("===============")

In [29]:
states = []
current_state = []
visited = [False for i in range(9)]

generate_states(states,current_state,visited) # O(n!)

no_of_states = len(states) # 9! = 362880
goal_state = [
    1, 2, 3,
    8, 0, 4,
    7, 6, 5,]

initial_state = [
    3, 1, 2,
    8, 4, 0,
    7, 6, 5,
]

all_possible_goal_states = [
    goal_state, [
        7, 8, 1,
        6, 0, 2,
        5, 4, 3
    ],[
        5, 6, 7,
        4, 0, 8,
        3, 2, 1
    ],[
        3, 4, 5,
        2, 0, 6,
        1, 8, 7
    ]
]

# Multiple goal states by symmetry, hence, go with forward chaining (data-driven search)

In [3]:
# sum(initial_state[i] != goal_state[i] for i in range(9))
initial_state.index(0) # O(n)

temp = []
for gs in all_possible_goal_states:
    temp.append(states.index(gs))
    print(temp[-1])

46685
318845
230903
138487


In [4]:
b = sorted(temp)
print(b[1] - b[0])
print(b[2] - b[1])
print(b[3] - b[2])
print(no_of_states//4)
print((91802 + 92416 + 87942)//3)

91802
92416
87942
90720
90720


In [5]:
# No. of tiles out of place heuristics
h1 = lambda state: sum(state[i] == goal_state[i] for i in range(9))

h1(initial_state)

4

In [6]:
# Manhattan distance from original place heuristics
def h2(state):
    cost = 0;
    for idx in range(9):
        val = state[idx]
        goal = goal_state.index(val)
        temp = abs(idx%3 - goal%3) + abs(idx//3 - goal//3)
        # modulo is column no. and floor division is row no.
        cost += temp

    return cost

h2(initial_state)

6

In [30]:
ti = initial_state.index(0)
up = True
down = True
left = True
right = True

match ti%3:
    case 0: left = False # leftmost column
    case 2: right = False # rightmost column

match ti//3:
    case 0: up = False # top row
    case 2: down = False # bottom row

print(f"up:\t{up}\ndown:\t{down}\nright:\t{right}\nleft:\t{left}")

up:	True
down:	True
right:	False
left:	True


In [31]:
next1 = None
next2 = None
mini1 = 1e9
mini2 = 1e9

print("Initial:\t",h1(initial_state), h2(initial_state))

if up:
    temp = list(initial_state)
    temp[ti],temp[ti-3] = temp[ti-3],temp[ti]
    print("Up:\t\t",h1(temp),h2(temp))
    if (h1(temp) < mini1):
        mini1 = h1(temp)
        next1 = temp
    if (h2(temp) < mini2):
        mini2 = h2(temp)
        next2 = temp

if down:
    temp = list(initial_state)
    temp[ti],temp[ti+3] = temp[ti+3],temp[ti]
    print("Down:\t\t",h1(temp),h2(temp))
    if (h1(temp) < mini1):
        mini1 = h1(temp)
        next1 = temp
    if (h2(temp) < mini2):
        mini2 = h2(temp)
        next2 = temp

if left:
    temp = list(initial_state)
    temp[ti],temp[ti-1] = temp[ti-1],temp[ti]
    print("Left:\t\t",h1(temp),h2(temp))
    if (h1(temp) < mini1):
        mini1 = h1(temp)
        next1 = temp
    if (h2(temp) < mini2):
        mini2 = h2(temp)
        next2 = temp

if right:
    temp = list(initial_state)
    temp[ti],temp[ti+1] = temp[ti+1],temp[ti]
    print("Right:\t\t",h1(temp),h2(temp))
    if (h1(temp) < mini1):
        mini1 = h1(temp)
        next1 = temp
    if (h2(temp) < mini2):
        mini2 = h2(temp)
        next2 = temp

print_grid(initial_state)
print_grid(next1)
print_grid(next2)

Initial:	 4 6
Up:		 4 8
Down:		 3 8
Left:		 6 4
3 1 2
8 4 0
7 6 5
3 1 2
8 4 5
7 6 0
3 1 2
8 0 4
7 6 5


In [32]:
initial_state = next2
ti = initial_state.index(0)
up = True
down = True
left = True
right = True

match ti%3:
    case 0: left = False # leftmost column
    case 2: right = False # rightmost column

match ti//3:
    case 0: up = False # top row
    case 2: down = False # bottom row

next1 = None
next2 = None
mini1 = 1e9
mini2 = 1e9

print("Initial:\t",h1(initial_state), h2(initial_state))

if up:
    temp = list(initial_state)
    temp[ti],temp[ti-3] = temp[ti-3],temp[ti]
    print("Up:\t\t",h1(temp),h2(temp))
    if (h1(temp) < mini1):
        mini1 = h1(temp)
        next1 = temp
    if (h2(temp) < mini2):
        mini2 = h2(temp)
        next2 = temp

if down:
    temp = list(initial_state)
    temp[ti],temp[ti+3] = temp[ti+3],temp[ti]
    print("Down:\t\t",h1(temp),h2(temp))
    if (h1(temp) < mini1):
        mini1 = h1(temp)
        next1 = temp
    if (h2(temp) < mini2):
        mini2 = h2(temp)
        next2 = temp

if left:
    temp = list(initial_state)
    temp[ti],temp[ti-1] = temp[ti-1],temp[ti]
    print("Left:\t\t",h1(temp),h2(temp))
    if (h1(temp) < mini1):
        mini1 = h1(temp)
        next1 = temp
    if (h2(temp) < mini2):
        mini2 = h2(temp)
        next2 = temp

if right:
    temp = list(initial_state)
    temp[ti],temp[ti+1] = temp[ti+1],temp[ti]
    print("Right:\t\t",h1(temp),h2(temp))
    if (h1(temp) < mini1):
        mini1 = h1(temp)
        next1 = temp
    if (h2(temp) < mini2):
        mini2 = h2(temp)
        next2 = temp

print_grid(initial_state)
print_grid(next1)
print_grid(next2)

Initial:	 6 4
Up:		 5 6
Down:		 4 6
Left:		 4 6
Right:		 4 6
3 1 2
8 0 4
7 6 5
3 1 2
8 6 4
7 0 5
3 0 2
8 1 4
7 6 5
