In [26]:
import itertools

# Goal prefix for each stage, compared against the leading cells of the board:
# stage 1 -> tiles 1-4, stage 2 -> tiles 1-8, stage 3 -> the whole board
# (tiles 1-15 followed by 16, the value used for the blank).
GOALS = {
    1: tuple(range(1, 5)),
    2: tuple(range(1, 9)),
    3: tuple(range(1, 17)),
}

# Change in flat (row-major) index when the blank moves up/down/left/right.
OFFSETS = dict(u=-4, d=4, l=-1, r=1)

def get_legal_action(idx, action):
    """Return True when moving the blank at flat index ``idx`` stays on the board.

    ``idx`` is a row-major index into the 4x4 grid. ``action`` is one of
    'u', 'd', 'l', 'r'. A move is illegal only when it would push the blank
    past the matching edge. Any other ``action`` value is reported as legal.
    """
    row, col = idx // 4, idx % 4
    # Which edge each direction would fall off of.
    at_edge = {
        'u': row == 0,
        'd': row == 3,
        'l': col == 0,
        'r': col == 3,
    }
    return not at_edge.get(action, False)

def get_obs(state, stage):
    """Return ``state`` as a tuple with tiles the stage ignores masked to 0.

    The visible limit is 4 for stage 1, 8 for stage 2 and 16 for any other
    stage. Within the first 16 cells, every tile above the limit becomes 0,
    except 16 (the blank), which is always kept. Cells past index 15, if any,
    are copied unchanged.
    """
    if stage == 1:
        visible_max = 4
    elif stage == 2:
        visible_max = 8
    else:
        visible_max = 16

    cells = list(state)

    def hide(tile):
        # Keep the blank and any tile within the limit; blank out the rest.
        return tile if tile == 16 or not tile > visible_max else 0

    masked_head = [hide(cells[pos]) for pos in range(16)]
    return tuple(masked_head + cells[16:])



In [27]:
def create_states(stage):
    """Enumerate every board the policy for ``stage`` has to cover.

    * Stage 1: tiles 1-4 and the blank (16) on any distinct cells of all 16;
      every other cell is 0.
    * Stage 2: tiles 1-4 fixed in cells 0-3; tiles 5-8 and the blank on any
      distinct cells among 4-15.
    * Any other stage: tiles 1-8 fixed in cells 0-7; tiles 9-15 and the blank
      fill cells 8-15 in every order.

    Returns a list of 16-tuples in ``itertools.permutations`` order.
    """
    if stage == 1:
        prefix, free_cells, tiles = [], range(16), (1, 2, 3, 4)
    elif stage == 2:
        prefix, free_cells, tiles = [1, 2, 3, 4], range(4, 16), (5, 6, 7, 8)
    else:
        prefix, free_cells, tiles = list(range(1, 9)), range(8, 16), tuple(range(9, 16))

    def build(cells):
        # cells[:-1] receive the movable tiles in order; cells[-1] holds the blank.
        grid = prefix + [0] * (16 - len(prefix))
        for cell, tile in zip(cells, tiles):
            grid[cell] = tile
        grid[cells[-1]] = 16
        return tuple(grid)

    placements = itertools.permutations(free_cells, len(tiles) + 1)
    return [build(cells) for cells in placements]

# Reward for a move that leaves the board or disturbs a locked cell. The board
# stays where it is. NOTE(review): this equals stage 1's win reward (67) negated
# and is applied unchanged in every stage -- confirm that is intended.
INVALID_MOVE_PENALTY = -67


def _transition(s, b_idx, action, target, n_locked, win_r, step_p):
    """Return ``(reward, next_state)`` for sliding the blank (16) at ``b_idx``.

    A move is rejected, leaving the board at ``s`` with reward
    ``INVALID_MOVE_PENALTY``, when either:

    * it would leave the 4x4 board, or
    * it displaces one of the first ``n_locked`` cells from its solved value
      ``i + 1``.

    Otherwise the reward is ``win_r`` when the move makes the board's prefix
    equal ``target``. If not, it is ``step_p`` plus 0.5 for every cell that
    already matches ``target``.
    """
    if not get_legal_action(b_idx, action):
        return INVALID_MOVE_PENALTY, s
    nxt_idx = b_idx + OFFSETS[action]
    ns_list = list(s)
    ns_list[b_idx], ns_list[nxt_idx] = ns_list[nxt_idx], ns_list[b_idx]
    if any(ns_list[i] != i + 1 for i in range(n_locked)):
        return INVALID_MOVE_PENALTY, s
    ns = tuple(ns_list)
    if ns[:len(target)] == target:
        return win_r, ns
    correct = sum(1 for i in range(len(target)) if ns[i] == target[i])
    return step_p + (0.5 * correct), ns


def solve_mdp(stage, *, iterations=25, gamma=0.9):
    """Run value iteration for one puzzle stage and return a greedy policy.

    The state space is ``create_states(stage)``. Each sweep computes every
    new value from the previous sweep's table (a synchronous backup). States
    whose prefix already equals ``GOALS[stage]`` are never backed up, so
    their value stays 0.0.

    Args:
        stage: 1, 2 or 3. Selects the goal prefix, the (win, step) rewards,
            and how many leading cells earlier stages have locked.
        iterations: number of value-iteration sweeps. The default of 25 is the
            previously hard-coded count. Values only propagate that many moves
            back from the goal, so distant states rely on the shaping reward.
        gamma: discount factor (default 0.9).

    Returns:
        dict mapping each state tuple to the greedy action 'u', 'd', 'l' or
        'r'. Ties go to the earliest action in that order.
    """
    states = create_states(stage)
    target = GOALS[stage]
    t_len = len(target)
    r_cfg = {1: (67, -2), 2: (140, -5), 3: (250, -8)}
    win_r, step_p = r_cfg[stage]
    # Cells solved by earlier stages must stay put: none, the first 4, or the first 8.
    n_locked = {1: 0, 2: 4, 3: 8}[stage]
    actions = ('u', 'd', 'l', 'r')

    def q_value(s, b_idx, a, v):
        # One-step lookahead shared by the sweep and the policy extraction,
        # so the two can never disagree on rewards or transitions.
        reward, ns = _transition(s, b_idx, a, target, n_locked, win_r, step_p)
        return reward + gamma * v[ns]

    v = {s: 0.0 for s in states}
    for _ in range(iterations):
        new_v = v.copy()
        for s in states:
            if s[:t_len] == target:
                continue
            b_idx = s.index(16)
            new_v[s] = max(q_value(s, b_idx, a, v) for a in actions)
        v = new_v

    policy = {}
    for s in states:
        b_idx = s.index(16)
        qs = [q_value(s, b_idx, a, v) for a in actions]
        # list.index returns the first maximum, matching a strict '>' scan.
        policy[s] = actions[qs.index(max(qs))]
    return policy


In [28]:
def draw(s):
    """Print board ``s`` as a 4x4 ASCII grid, showing tile 16 (the blank) as empty.

    Each value is centered in a 4-character cell. Rows are separated by
    ``+----+`` border lines.
    """
    border = "+" + "----+" * 4
    cell_row = "|" + "{:^4}|" * 4
    print(border)
    for row_start in range(0, 16, 4):
        labels = (("" if tile == 16 else tile) for tile in s[row_start:row_start + 4])
        print(cell_row.format(*labels))
        print(border)

In [29]:
# Solve the scrambled board one stage at a time with the precomputed policies.
board = (11, 15, 12, 3, 2, 13, 10, 5, 1, 7, 4, 8, 6, 16, 9, 14)
move_count = 0
p1 = solve_mdp(1)
p2 = solve_mdp(2)
p3 = solve_mdp(3)

policies = {1: p1, 2: p2, 3: p3}
draw(board)

# Safety cap: value iteration runs a fixed number of sweeps, so a policy may
# cycle between states. Without a cap the while-loop below would never end.
MAX_MOVES_PER_STAGE = 1000

solved = True
for stage in [1, 2, 3]:
    goal = GOALS[stage]
    current_p = policies[stage]
    stage_moves = 0
    while board[:len(goal)] != goal:
        if stage_moves >= MAX_MOVES_PER_STAGE:
            print(f"Stage {stage} not cleared after {MAX_MOVES_PER_STAGE} moves. Stopping.")
            solved = False
            break
        obs = get_obs(board, stage)
        if obs not in current_p:
            print("State not found in policy. The board might have been corrupted.")
            solved = False
            break

        move = current_p[obs]
        b_idx = board.index(16)
        if not get_legal_action(b_idx, move):
            print(f"Policy suggested invalid move {move}. Stopping.")
            solved = False
            break

        target_idx = b_idx + OFFSETS[move]
        lst = list(board)
        lst[b_idx], lst[target_idx] = lst[target_idx], lst[b_idx]
        board = tuple(lst)

        move_count += 1
        stage_moves += 1
        print(f"Move {move_count}: {move.upper()}")
        draw(board)

    # A failed stage must not be reported as cleared, and later stages
    # cannot run on a board whose earlier prefix is unsolved.
    if not solved:
        break
    print(f"--- Stage {stage} Cleared ---")

print("Puzzle solved." if solved else "Puzzle not solved.")

+----+----+----+----+
| 11 | 15 | 12 | 3  |
+----+----+----+----+
| 2  | 13 | 10 | 5  |
+----+----+----+----+
| 1  | 7  | 4  | 8  |
+----+----+----+----+
| 6  |    | 9  | 14 |
+----+----+----+----+
Move 1: U
+----+----+----+----+
| 11 | 15 | 12 | 3  |
+----+----+----+----+
| 2  | 13 | 10 | 5  |
+----+----+----+----+
| 1  |    | 4  | 8  |
+----+----+----+----+
| 6  | 7  | 9  | 14 |
+----+----+----+----+
Move 2: U
+----+----+----+----+
| 11 | 15 | 12 | 3  |
+----+----+----+----+
| 2  |    | 10 | 5  |
+----+----+----+----+
| 1  | 13 | 4  | 8  |
+----+----+----+----+
| 6  | 7  | 9  | 14 |
+----+----+----+----+
Move 3: U
+----+----+----+----+
| 11 |    | 12 | 3  |
+----+----+----+----+
| 2  | 15 | 10 | 5  |
+----+----+----+----+
| 1  | 13 | 4  | 8  |
+----+----+----+----+
| 6  | 7  | 9  | 14 |
+----+----+----+----+
Move 4: L
+----+----+----+----+
|    | 11 | 12 | 3  |
+----+----+----+----+
| 2  | 15 | 10 | 5  |
+----+----+----+----+
| 1  | 13 | 4  | 8  |
+----+----+----+----+
| 6  | 7  | 9 