In [1]:
from pocket_cube.cube import Cube
from pocket_cube.cube import Move

%matplotlib notebook

# Creating a scrambled cube 

In [2]:
cube = Cube("F' R U R U F' U'")

## 2D Cube visualization 

In [3]:
cube.render()

<IPython.core.display.Javascript object>

## Text representation 

In [4]:
cube.render_text()

  YB
  BB
OWRRYYOG
GOWWROBY
  GG
  RW


## 3D Cube visualization

For an interactive view of the cube, don't forget to use `%matplotlib notebook`

In [5]:
cube.render3D()

<IPython.core.display.Javascript object>

# Creating an unscrambled cube 

In [6]:
cube = Cube(scrambled=False)

In [7]:
cube.render3D()

<IPython.core.display.Javascript object>

## Applying moves on the cube 

In [8]:
cube = cube.move(Move.R)
cube = cube.move(Move.F)

In [9]:
cube.render3D()

<IPython.core.display.Javascript object>

In [10]:
cube.render3D_moves(cube.state, [Move.Fp, Move.Rp])

<IPython.core.display.Javascript object>

<matplotlib.animation.FuncAnimation at 0x11feee630>

In [11]:
def get_neighbours(cube_state):
    return [Cube.move_state(Cube.clone_state(cube_state), Move.F), 
            Cube.move_state(Cube.clone_state(cube_state), Move.R),
            Cube.move_state(Cube.clone_state(cube_state), Move.Up),
            Cube.move_state(Cube.clone_state(cube_state), Move.Fp), 
            Cube.move_state(Cube.clone_state(cube_state), Move.Rp),
            Cube.move_state(Cube.clone_state(cube_state), Move.U)]


In [12]:
from heapq import heappush, heappop
import numpy as np
import time

def astar(pocket_cube, h):
    frontier = []
    heappush(frontier, (0 + h(pocket_cube), tuple(pocket_cube.state)))
    discovered = {tuple(pocket_cube.state): (None, 0)}
    state_number = 0
    
    start_time = time.time()

    while frontier:
        current_cost, current_state_tuple = heappop(frontier)
        current_state = np.array(current_state_tuple)
        
        cloned_cube = pocket_cube.clone()
        cloned_cube.state = current_state
        
        
        if np.array_equal(current_state, pocket_cube.goal_state):
            end_time = time.time()
            execution_time = end_time - start_time
            return discovered[current_state_tuple][1], state_number, execution_time
        
        for neighbour_state in get_neighbours(cloned_cube):
            neighbour_cloned_cube = pocket_cube.clone()
            neighbour_cloned_cube.state = neighbour_state
        
            new_cost = discovered[current_state_tuple][1] + 1
            
            neighbour_state_tuple = tuple(neighbour_state)
            if neighbour_state_tuple not in discovered or new_cost < discovered[neighbour_state_tuple][1]:
                state_number += 1
                discovered[neighbour_state_tuple] = (current_state_tuple, new_cost)
                total_cost = new_cost + h(neighbour_cloned_cube)
                heappush(frontier, (total_cost, neighbour_state_tuple))
    
    end_time = time.time()
    execution_time = end_time - start_time    
    return [], state_number, execution_time


In [13]:
from tests import case1, case2, case3, case4
from pocket_cube.heuristics import h1, h2

cube1 = Cube(case1)
print(astar(cube1, h1))
cube2 = Cube(case2)
print(astar(cube2, h1))
cube3 = Cube(case3)
print(astar(cube3, h1))
cube4 = Cube(case4)
print(astar(cube4, h1))


(5, 330, 0.011132955551147461)
(7, 4717, 0.12952303886413574)
(9, 69705, 1.9491169452667236)
(11, 727614, 23.745145082473755)


In [14]:
from collections import deque
import time

def bfs_bidirectional(pocket_cube):
    start_time = time.time()

    start_queue = deque([(pocket_cube.state, [])])
    goal_queue = deque([(pocket_cube.goal_state, [])])
    start_visited = {tuple(pocket_cube.state)}
    goal_visited = {tuple(pocket_cube.goal_state)}

    states_visited = 0

    while start_queue and goal_queue:
        start_state, start_path = start_queue.popleft()
        goal_state, goal_path = goal_queue.popleft()

        if tuple(start_state) in goal_visited:
            end_time = time.time()
            execution_time = end_time - start_time
            return len(start_path) + len(goal_path) + 1, states_visited, execution_time

        if tuple(goal_state) in start_visited:
            end_time = time.time()
            execution_time = end_time - start_time
            return len(start_path) + len(goal_path) + 1, states_visited, execution_time

        cloned_cube = pocket_cube.clone()
        cloned_cube.state = start_state
        start_neighbors = get_neighbours(cloned_cube)
        
        neighbour_cloned_cube = pocket_cube.clone()
        neighbour_cloned_cube.state = goal_state
        goal_neighbors = get_neighbours(neighbour_cloned_cube)

        for neighbor in start_neighbors:
            neighbor_tuple = tuple(neighbor)
            if neighbor_tuple not in start_visited:
                start_queue.append((neighbor_tuple, start_path + [neighbor_tuple]))
                start_visited.add(neighbor_tuple)
                states_visited += 1

        for neighbor in goal_neighbors:
            neighbor_tuple = tuple(neighbor)
            if neighbor_tuple not in goal_visited:
                goal_queue.append((neighbor_tuple, goal_path + [neighbor_tuple]))
                goal_visited.add(neighbor_tuple)
                states_visited += 1

    end_time = time.time()
    execution_time = end_time - start_time
    return None, states_visited, execution_time


In [15]:
from tests import case1, case2, case3, case4

cube1 = Cube(case1)
print(bfs_bidirectional(cube1))
cube2 = Cube(case2)
print(bfs_bidirectional(cube2))
cube3 = Cube(case3)
print(bfs_bidirectional(cube3))
cube4 = Cube(case4)
print(bfs_bidirectional(cube4))


(5, 154, 0.0017058849334716797)
(7, 1126, 0.008130788803100586)
(9, 4634, 0.0338289737701416)
(11, 15096, 0.11613082885742188)


In [16]:
import matplotlib.pyplot as plt
import numpy as np

# Date pentru A*
a_star_data = [
    (5, 330, 0.04852151870727539),
    (7, 4717, 0.8870642185211182),
    (9, 69705, 9.04358458518982),
    (11, 727614, 104.44711661338806)
]

# Date pentru BFS
bfs_data = [
    (5, 154, 0.0),
    (7, 1126, 0.03228878974914551),
    (9, 4634, 0.11890625953674316),
    (11, 15096, 0.39571046829223633)
]

# Extrageți informațiile relevante pentru grafic
a_star_lengths, a_star_states, a_star_times = zip(*a_star_data)
bfs_lengths, bfs_states, bfs_times = zip(*bfs_data)

bar_width = 0.35
index = np.arange(len(a_star_lengths))

# Creează subplots pentru fiecare metrică
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 12))

# Diagrama de bare pentru lungimea căii de soluție
ax1.bar(index, a_star_lengths, bar_width, label='A*', color='b', edgecolor='black')
ax1.bar(index + bar_width, bfs_lengths, bar_width, label='BFS', color='r', edgecolor='black', alpha=0.7)
ax1.set_xticks(index + bar_width / 2)
ax1.set_xticklabels(['5', '7', '9', '11'])
ax1.set_ylabel('Solution Path Length')
ax1.legend()

# Diagrama de bare pentru numărul de stări vizitate
ax2.bar(index, a_star_states, bar_width, label='A*', color='b', edgecolor='black')
ax2.bar(index + bar_width, bfs_states, bar_width, label='BFS', color='r', edgecolor='black', alpha=0.7)
ax2.set_xticks(index + bar_width / 2)
ax2.set_xticklabels(['5', '7', '9', '11'])
ax2.set_ylabel('States Visited')
ax2.legend()

# Diagrama de bare pentru timpul de execuție
ax3.bar(index, a_star_times, bar_width, label='A*', color='b', edgecolor='black')
ax3.bar(index + bar_width, bfs_times, bar_width, label='BFS', color='r', edgecolor='black', alpha=0.7)
ax3.set_xticks(index + bar_width / 2)
ax3.set_xticklabels(['5', '7', '9', '11'])
ax3.set_ylabel('Execution Time (s)')
ax3.legend()

# Adaugă etichete pentru fiecare bară
for ax in [ax1, ax2, ax3]:
    for i, v in enumerate(a_star_lengths):
        ax.text(i + bar_width / 2, v + 5, str(v), color='black', ha='center', va='bottom')

    for i, v in enumerate(bfs_lengths):
        ax.text(i + bar_width * 1.5, v + 5, str(v), color='black', ha='center', va='bottom')

# Afișează graficele
plt.tight_layout()
plt.show()
plt.savefig("Astar-vs-BFS.png")

<IPython.core.display.Javascript object>

In [17]:
# constante
N = 'N'
Q = 'Q'
PARENT = 'parent'
ACTIONS = 'actions'

# funcție ce întoarce un nod nou,
# eventual copilul unui nod dat ca argument
def init_node(parent = None):
    return {N: 0, Q: 0, PARENT: parent, ACTIONS: {}}

In [18]:
CP = 0.5 # schimba la 0.5 pentru o noua testare

# funcție ce alege o acțiune dintr-un nod
def select_action(node, c = CP):
    """
    Se caută acțiunea a care maximizează expresia:
    Q_a / N_a  +  c * sqrt(2 * log(N_node) / N_a)
    """
    N_node = node[N]
    max_score = -1
    best_action = None

    for a, n in node[ACTIONS].items():
        crt_score = n[Q] / n[N] + c * sqrt(2 * log(N_node) / n[N])

        if max_score < crt_score:
            max_score = crt_score
            best_action = a

    return best_action


In [19]:
from math import sqrt, log
import numpy as np

def calculate_solution_path_length(final_tree, best_action, length):
    if not final_tree:
        return length

    if best_action is None:
        best_action = select_action(final_tree)

    if best_action in final_tree[ACTIONS]:
        return 1 + calculate_solution_path_length(final_tree[ACTIONS][best_action], best_action, length + 1)
    
    return length

In [20]:
from random import choice
#  Algoritmul MCTS (UCT)
#  state0 - starea pentru care trebuie aleasă o acțiune
#  budget - numărul de iterații permis
#  tree - un arbore din explorările anterioare
#  opponent_s_action - ultima acțiune a adversarului

def mcts(pocket_cube, budget):
    start_time = time.time()
    
    solution_found = False

    # Number of states explored
    states_explored = 0
    
    # DACĂ există un arbore construit anterior ȘI
    #   acesta are un copil ce corespunde ultimei acțiuni a adversarului,
    # ATUNCI acel copil va deveni nodul de început pentru algoritm.
    # ALTFEL, arborele de start este un nod gol.
    tree = init_node(pocket_cube.state)
    
    #---------------------------------------------------------------
    for x in range(budget):
        # Punctul de start al simulării va fi rădăcina de start
        cloned_cube = pocket_cube.clone()
        state = cloned_cube.state
        node = tree

        # Coborâm în arbore până când ajungem la o stare finală
        # sau la un nod cu acțiuni neexplorate.
        # Variabilele state și node se 'mută' împreună.
        while (not np.array_equal(state, pocket_cube.goal_state)
            and all(tuple(action) in node[ACTIONS] for action in get_neighbours(cloned_cube))
        ):
            new_action = select_action(node)
            state = new_action
            cloned_cube.state = state
            node = node[ACTIONS][tuple(new_action)]
            states_explored += 1
        
        #---------------------------------------------------------------
        # Dacă am ajuns într-un nod care nu este final și din care nu s-au
        # `încercat` toate acțiunile, construim un nod nou.
        if not np.array_equal(state, pocket_cube.goal_state):
            new_action = choice(list(filter(lambda a: tuple(a) not in node[ACTIONS], get_neighbours(cloned_cube))))
            state = new_action
            cloned_cube.state = state
            node = init_node(node)
            node[PARENT][ACTIONS][tuple(new_action)] = node
            states_explored += 1

        #---------------------------------------------------------------
        # Se simulează o desfășurare a jocului până la ajungerea într-o
        # starea finală. Se evaluează recompensa în acea stare.
        limitation = 0
        while not np.array_equal(state, pocket_cube.goal_state) and limitation <= 14:
            state = choice(get_neighbours(cloned_cube))
            cloned_cube.state = state
            limitation += 1
        
        if  h1(cloned_cube) == 0:
            reward = 1
        else:
            reward = 1 / h1(cloned_cube)
                   
        if np.array_equal(state, pocket_cube.goal_state):
            solution_found = True
        
        #---------------------------------------------------------------
        # Se actualizează toate nodurile de la node către rădăcină:
        #  - se incrementează valoarea N din fiecare nod
        #  - se adaugă recompensa la valoarea Q
        crt_node = node
        while crt_node is not None and PARENT in crt_node:
            crt_node[N] += 1
            crt_node[Q] += reward
            crt_node = crt_node[PARENT]
        #---------------------------------------------------------------
    end_time = time.time()
    execution_time = end_time - start_time
    
    if tree:
        final_action = select_action(tree, 0.0)
        return solution_found, execution_time, states_explored, tree[ACTIONS][final_action]
    

In [21]:
from tests import case1, case2, case3, case4

num_iterations = 20

total_execution_time = 0
total_states_explored = 0
solution_found_count = 0

for _ in range(num_iterations):
    start_time = time.time()

    cube = Cube(case1)  # Inlocuieste cu cazul de testare corespunzator

    solution_found, execution_time, states_explored, output_tree = mcts(cube, 1000)

    total_execution_time += execution_time
    total_states_explored += states_explored
    solution_found_count += int(solution_found)

# Calculate averages
average_execution_time = total_execution_time / num_iterations
average_states_explored = total_states_explored / num_iterations
solution_found_percentage = (solution_found_count / num_iterations) * 100

print(f"Average Execution Time: {average_execution_time} seconds")
print(f"Average States Explored: {average_states_explored}")
print(f"Solutions Found in {solution_found_percentage}% of runs")


Average Execution Time: 0.17842202186584472 seconds
Average States Explored: 3694.05
Solutions Found in 35.0% of runs


In [22]:
import matplotlib.pyplot as plt
import numpy as np

def plot_mcts_data(mcts_data, title_suffix=''):
    # Extrageți informațiile relevante pentru grafic
    budgets = sorted({budget for cp_data in mcts_data.values() for budget in cp_data})
    cp_values = sorted(mcts_data.keys())
    bar_width = 0.35
    index = np.arange(len(budgets))

    # Creează subplots pentru fiecare metrică
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 12))

    # Diagrama de bare pentru timpul de execuție
    for i, (cp_label, cp_data) in enumerate(mcts_data.items()):
        avg_execution_times = [cp_data[budget]['avg_execution_time'] for budget in budgets]
        ax1.bar(index + i * bar_width, avg_execution_times, bar_width, label=cp_label, edgecolor='black')

    ax1.set_xticks(index + bar_width)
    ax1.set_xticklabels(budgets)
    ax1.set_ylabel('Average Execution Time (s)')
    ax1.legend()
    ax1.set_title(f'Average Execution Time - {title_suffix}')

    # Diagrama de bare pentru numărul de stări vizitate
    for i, (cp_label, cp_data) in enumerate(mcts_data.items()):
        avg_states_explored = [cp_data[budget]['avg_states_explored'] for budget in budgets]
        ax2.bar(index + i * bar_width, avg_states_explored, bar_width, label=cp_label, edgecolor='black')

    ax2.set_xticks(index + bar_width)
    ax2.set_xticklabels(budgets)
    ax2.set_ylabel('Average States Explored')
    ax2.legend()
    ax2.set_title(f'Average States Explored - {title_suffix}')

    # Diagrama de bare pentru procentajul de soluții găsite
    for i, (cp_label, cp_data) in enumerate(mcts_data.items()):
        solutions_found_percentage = [cp_data[budget]['solutions_found_percentage'] for budget in budgets]
        ax3.bar(index + i * bar_width, solutions_found_percentage, bar_width, label=cp_label, edgecolor='black')

    ax3.set_xticks(index + bar_width)
    ax3.set_xticklabels(budgets)
    ax3.set_ylabel('Solutions Found Percentage')
    ax3.legend()
    ax3.set_title(f'Solutions Found Percentage - {title_suffix}')

    # Afișează graficele
    plt.tight_layout()
    plt.show()
    filename = "Mcts " + title_suffix + ".png"
    plt.savefig(filename)


In [23]:
mcts_data_case1_h1 = {
    'CP=0.1': {
        1000: {'avg_execution_time': 0.8059555888175964, 'avg_states_explored': 3840.45, 'solutions_found_percentage': 50.0},
        5000: {'avg_execution_time': 4.401730930805206, 'avg_states_explored': 23630.75, 'solutions_found_percentage': 95.0},
        10000: {'avg_execution_time': 6.942336964607239, 'avg_states_explored': 49702.05, 'solutions_found_percentage': 100.0},
        20000: {'avg_execution_time': 9.942156326770782, 'avg_states_explored': 99257.8, 'solutions_found_percentage': 100.0}
    },
    'CP=0.5': {
        1000: {'avg_execution_time': 0.8311913967132568, 'avg_states_explored': 3694.2, 'solutions_found_percentage': 35.0},
        5000: {'avg_execution_time': 4.160805523395538, 'avg_states_explored': 23140.8, 'solutions_found_percentage': 95.0},
        10000: {'avg_execution_time': 7.212765645980835, 'avg_states_explored': 48339.85, 'solutions_found_percentage': 100.0},
        20000: {'avg_execution_time': 10.980429780483245, 'avg_states_explored': 98380.05, 'solutions_found_percentage': 100.0}
    }
}

# Afișează graficele pentru "Case 1 with h1"
plot_mcts_data(mcts_data_case1_h1, title_suffix='Case 1 with h1')

mcts_data_case1_h2 = {
    'CP=0.1': {
        1000: {'avg_execution_time': 0.8680823922157288, 'avg_states_explored': 3707.85, 'solutions_found_percentage': 50.0},
        5000: {'avg_execution_time': 3.2603790283203127, 'avg_states_explored': 24193.25, 'solutions_found_percentage': 90.0},
        10000: {'avg_execution_time': 5.004893231391907, 'avg_states_explored': 51141.25, 'solutions_found_percentage': 100.0},
        20000: {'avg_execution_time': 9.213014876842498, 'avg_states_explored': 105109.85, 'solutions_found_percentage': 100.0}
    },
    'CP=0.5': {
        1000: {'avg_execution_time': 0.7550688266754151, 'avg_states_explored': 3694.95, 'solutions_found_percentage': 35.0},
        5000: {'avg_execution_time': 4.197380530834198, 'avg_states_explored': 23145.3, 'solutions_found_percentage': 80.0},
        10000: {'avg_execution_time': 5.832189357280731, 'avg_states_explored': 48164.45, 'solutions_found_percentage': 100.0},
        20000: {'avg_execution_time': 9.269361066818238, 'avg_states_explored': 98171.35, 'solutions_found_percentage': 100.0}
    }
}

# Afișează graficele pentru "Case 1 with h2"
plot_mcts_data(mcts_data_case1_h2, title_suffix='Case 1 with h2')


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [24]:
mcts_data_case2_h1 = {
    'CP=0.5': {
        1000: {'avg_execution_time': 0.7819520831108093, 'avg_states_explored': 3694.1, 'solutions_found_percentage': 10.0},
        5000: {'avg_execution_time': 4.520701360702515, 'avg_states_explored': 23140.4, 'solutions_found_percentage': 25.0},
        10000: {'avg_execution_time': 9.248458051681519, 'avg_states_explored': 49030.2, 'solutions_found_percentage': 40.0},
        20000: {'avg_execution_time': 19.234029841423034, 'avg_states_explored': 108810.1, 'solutions_found_percentage': 60.0}
    },
    'CP=0.1': {
        1000: {'avg_execution_time': 0.8167448282241822, 'avg_states_explored': 3775.8, 'solutions_found_percentage': 5.0},
        5000: {'avg_execution_time': 4.428177201747895, 'avg_states_explored': 23554.7, 'solutions_found_percentage': 20.0},
        10000: {'avg_execution_time': 9.321510195732117, 'avg_states_explored': 51051.1, 'solutions_found_percentage': 50.0},
        20000: {'avg_execution_time': 18.18908680677414, 'avg_states_explored': 111967.6, 'solutions_found_percentage': 75.0}
    }
}

# Afișează graficele pentru "Case 2 with h1"
plot_mcts_data(mcts_data_case2_h1, title_suffix='Case 2 with h1')

mcts_data_case2_h2 = {
    'CP=0.5': {
        1000: {'avg_execution_time': 0.7478758692741394, 'avg_states_explored': 3694.2, 'solutions_found_percentage': 20.0},
        5000: {'avg_execution_time': 3.9582502126693724, 'avg_states_explored': 23140.0, 'solutions_found_percentage': 25.0},
        10000: {'avg_execution_time': 8.069659554958344, 'avg_states_explored': 48810.0, 'solutions_found_percentage': 35.0},
        20000: {'avg_execution_time': 16.65192800760269, 'avg_states_explored': 108810.1, 'solutions_found_percentage': 55.0}
    },
    'CP=0.1': {
        1000: {'avg_execution_time': 0.8555134057998657, 'avg_states_explored': 3696.35, 'solutions_found_percentage': 10.0},
        5000: {'avg_execution_time': 3.9102444648742676, 'avg_states_explored': 23146.1, 'solutions_found_percentage': 20.0},
        10000: {'avg_execution_time': 7.985168600082398, 'avg_states_explored': 48898.4, 'solutions_found_percentage': 45.0},
        20000: {'avg_execution_time': 16.86486370563507, 'avg_states_explored': 108877.65, 'solutions_found_percentage': 50.0}
    }
}

# Afișează graficele pentru "Case 2 with h2"
plot_mcts_data(mcts_data_case2_h2, title_suffix='Case 2 with h2')


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [25]:
mcts_data_case3_h1 = {
    'CP=0.1': {
        1000: {'avg_execution_time': 0.803343653678894, 'avg_states_explored': 3776.55, 'solutions_found_percentage': 0.0},
        5000: {'avg_execution_time': 4.200398147106171, 'avg_states_explored': 23495.75, 'solutions_found_percentage': 5.0},
        10000: {'avg_execution_time': 9.196272361278535, 'avg_states_explored': 50925.5, 'solutions_found_percentage': 15.0},
        20000: {'avg_execution_time': 17.868041157722473, 'avg_states_explored': 110037.95, 'solutions_found_percentage': 25.0}
    },
    'CP=0.5': {
        1000: {'avg_execution_time': 0.8060226202011108, 'avg_states_explored': 3694.1, 'solutions_found_percentage': 0.0},
        5000: {'avg_execution_time': 4.500201451778412, 'avg_states_explored': 23140.35, 'solutions_found_percentage': 0.0},
        10000: {'avg_execution_time': 8.983766090869903, 'avg_states_explored': 49003.0, 'solutions_found_percentage': 15.0},
        20000: {'avg_execution_time': 18.248717999458314, 'avg_states_explored': 108810.0, 'solutions_found_percentage': 10.0}
    }
}

# Afișează graficele pentru "Case 3 with h1"
plot_mcts_data(mcts_data_case3_h1, title_suffix='Case 3 with h1')

mcts_data_case3_h2 = {
    'CP=0.1': {
        1000: {'avg_execution_time': 0.787688148021698, 'avg_states_explored': 3694.0, 'solutions_found_percentage': 0.0},
        5000: {'avg_execution_time': 4.429369473457337, 'avg_states_explored': 23140.0, 'solutions_found_percentage': 0.0},
        10000: {'avg_execution_time': 8.046462261676789, 'avg_states_explored': 48810.0, 'solutions_found_percentage': 10.0},
        20000: {'avg_execution_time': 18.521410048007965, 'avg_states_explored': 108810.0, 'solutions_found_percentage': 30.0}
    },
    'CP=0.5': {
        1000: {'avg_execution_time': 0.7605180501937866, 'avg_states_explored': 3694.0, 'solutions_found_percentage': 0.0},
        5000: {'avg_execution_time': 4.439774870872498, 'avg_states_explored': 23140.0, 'solutions_found_percentage': 0.0},
        10000: {'avg_execution_time': 8.129756927490234, 'avg_states_explored': 48810.0, 'solutions_found_percentage': 10.0},
        20000: {'avg_execution_time': 16.661398327350618, 'avg_states_explored': 108810.0, 'solutions_found_percentage': 20.0}
    }
}

# Afișează graficele pentru "Case 3 with h2"
plot_mcts_data(mcts_data_case3_h2, title_suffix='Case 3 with h2')


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [26]:
mcts_data_case4_h1 = {
    'CP=0.1': {
        1000: {'avg_execution_time': 0.8869012713432312, 'avg_states_explored': 3763.35, 'solutions_found_percentage': 0.0},
        5000: {'avg_execution_time': 4.339821016788482, 'avg_states_explored': 23468.8, 'solutions_found_percentage': 0.0},
        10000: {'avg_execution_time': 9.631971180438995, 'avg_states_explored': 50739.9, 'solutions_found_percentage': 0.0},
        20000: {'avg_execution_time': 17.97297908067703, 'avg_states_explored': 109911.7, 'solutions_found_percentage': 0.0}
    },
    'CP=0.5': {
        1000: {'avg_execution_time': 0.8809452891349793, 'avg_states_explored': 3694.05, 'solutions_found_percentage': 0.0},
        5000: {'avg_execution_time': 4.278359234333038, 'avg_states_explored': 23140.6, 'solutions_found_percentage': 0.0},
        10000: {'avg_execution_time': 8.734858787059784, 'avg_states_explored': 48996.05, 'solutions_found_percentage': 1.0},
        20000: {'avg_execution_time': 18.10296994447708, 'avg_states_explored': 108810.0, 'solutions_found_percentage': 0.0}
    }
}

# # Afișează graficele pentru "Case 4 with h1"
plot_mcts_data(mcts_data_case4_h1, title_suffix='Case 4 with h1')


<IPython.core.display.Javascript object>

In [27]:

mcts_data_case4_h2 = {
    'CP=0.1': {
        1000: {'avg_execution_time': 0.7550360441207886, 'avg_states_explored': 3694.0, 'solutions_found_percentage': 0.0},
        5000: {'avg_execution_time': 3.945626199245453, 'avg_states_explored': 23140.0, 'solutions_found_percentage': 0.0},
        10000: {'avg_execution_time': 8.077935636043549, 'avg_states_explored': 48848.85, 'solutions_found_percentage': 1.0},
        20000: {'avg_execution_time': 18.3463036775589, 'avg_states_explored': 108810.0, 'solutions_found_percentage': 1.0}
    },
    'CP=0.5': {
        1000: {'avg_execution_time': 0.7705209493637085, 'avg_states_explored': 3694.0, 'solutions_found_percentage': 0.0},
        5000: {'avg_execution_time': 3.9690882682800295, 'avg_states_explored': 23140.0, 'solutions_found_percentage': 0.0},
        10000: {'avg_execution_time': 8.556566905975341, 'avg_states_explored': 48810.0, 'solutions_found_percentage': 5.0},
        20000: {'avg_execution_time': 17.030357658863068, 'avg_states_explored': 108810.0, 'solutions_found_percentage': 0.0}
    }
}

# Afișează graficele pentru "Case 4 with h2"
plot_mcts_data(mcts_data_case4_h2, title_suffix='Case 4 with h2')


<IPython.core.display.Javascript object>

In [28]:
import matplotlib.pyplot as plt
import numpy as np

# Date pentru A*
a_star_data = [
    (5, 330, 0.04852151870727539),
    (7, 4717, 0.8870642185211182),
    (9, 69705, 9.04358458518982),
    (11, 727614, 104.44711661338806)
]

# Date pentru BFS
bfs_data = [
    (5, 154, 0.0),
    (7, 1126, 0.03228878974914551),
    (9, 4634, 0.11890625953674316),
    (11, 15096, 0.39571046829223633)
]

# Date pentru MCTS => am ales pentru fiecare caz varianta cu h2, CP = 0.5 si buget 10000
mcts_data = [
    (5, 48164, 5.832189357280731),
    (7, 48810, 8.069659554958344),
    (9, 48810, 8.129756927490234),
    (11, 108810, 8.556566905975341)
]

# Extrageți informațiile relevante pentru grafic
a_star_lengths, a_star_states, a_star_times = zip(*a_star_data)
bfs_lengths, bfs_states, bfs_times = zip(*bfs_data)
mcts_lengths, mcts_states, mcts_times = zip(*mcts_data)

bar_width = 0.2
index = np.arange(len(a_star_lengths))

# Culori atractive pentru bare
a_star_color = 'dodgerblue'
bfs_color = 'tomato'
mcts_color = 'limegreen'

# Creează subplots pentru fiecare metrică
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 12))

# Diagrama de bare pentru lungimea căii de soluție
ax1.bar(index, a_star_lengths, bar_width, label='A*', color=a_star_color, edgecolor='black')
ax1.bar(index + bar_width, bfs_lengths, bar_width, label='BFS', color=bfs_color, edgecolor='black', alpha=0.7)
ax1.bar(index + 2 * bar_width, mcts_lengths, bar_width, label='MCTS', color=mcts_color, edgecolor='black', alpha=0.7)
ax1.set_xticks(index + 1.5 * bar_width)
ax1.set_xticklabels(['5', '7', '9', '11'])
ax1.set_ylabel('Solution Path Length')
ax1.legend()

# Diagrama de bare pentru numărul de stări vizitate
ax2.bar(index, a_star_states, bar_width, label='A*', color=a_star_color, edgecolor='black')
ax2.bar(index + bar_width, bfs_states, bar_width, label='BFS', color=bfs_color, edgecolor='black', alpha=0.7)
ax2.bar(index + 2 * bar_width, mcts_states, bar_width, label='MCTS', color=mcts_color, edgecolor='black', alpha=0.7)
ax2.set_xticks(index + 1.5 * bar_width)
ax2.set_xticklabels(['5', '7', '9', '11'])
ax2.set_ylabel('States Visited')
ax2.legend()

# Diagrama de bare pentru timpul de execuție
ax3.bar(index, a_star_times, bar_width, label='A*', color=a_star_color, edgecolor='black')
ax3.bar(index + bar_width, bfs_times, bar_width, label='BFS', color=bfs_color, edgecolor='black', alpha=0.7)
ax3.bar(index + 2 * bar_width, mcts_times, bar_width, label='MCTS', color=mcts_color, edgecolor='black', alpha=0.7)
ax3.set_xticks(index + 1.5 * bar_width)
ax3.set_xticklabels(['5', '7', '9', '11'])
ax3.set_ylabel('Execution Time (s)')
ax3.legend()

# Adaugă etichete pentru fiecare bară
for ax in [ax1, ax2, ax3]:
    for i, v in enumerate(a_star_lengths):
        ax.text(i + 0.5 * bar_width, v + 5, str(v), color='black', ha='center', va='bottom')

    for i, v in enumerate(bfs_lengths):
        ax.text(i + 1.5 * bar_width, v + 5, str(v), color='black', ha='center', va='bottom')

    for i, v in enumerate(mcts_lengths):
        ax.text(i + 2.5 * bar_width, v + 5, str(v), color='black', ha='center', va='bottom')

# Afișează graficele
plt.tight_layout()
plt.show()

plt.savefig("Astar-vs-BFS-vs-best-MCTS")

<IPython.core.display.Javascript object>

In [29]:
from collections import deque
import time

def build_pattern_database(pocket_cube, max_distance):
    start_time = time.time()

    queue = deque([(pocket_cube.goal_state, 0)])
    visited = {tuple(pocket_cube.goal_state): 0}

    states_visited = 0

    while queue:
        current_state, distance = queue.popleft()

        if distance >= max_distance:
            break

        cloned_cube = pocket_cube.clone()
        cloned_cube.state = current_state
        neighbors = get_neighbours(cloned_cube)

        for neighbor in neighbors:
            neighbor_tuple = tuple(neighbor)
            if neighbor_tuple not in visited:
                queue.append((neighbor_tuple, distance + 1))
                visited[neighbor_tuple] = distance + 1
                states_visited += 1

    end_time = time.time()
    execution_time = end_time - start_time
    return visited, states_visited, execution_time


In [30]:
from heapq import heappush, heappop
import numpy as np
from pocket_cube.heuristics import h1, h2, h3

def astar_pattern_db(pocket_cube, pattern_database):
    frontier = []
    heappush(frontier, (0 + h3(pocket_cube, pattern_database), tuple(pocket_cube.state)))
    discovered = {tuple(pocket_cube.state): (None, 0)}
    state_number = 0
    
    start_time = time.time()

    while frontier:
        current_cost, current_state_tuple = heappop(frontier)
        current_state = np.array(current_state_tuple)
        
        cloned_cube = pocket_cube.clone()
        cloned_cube.state = current_state
        
        
        if np.array_equal(current_state, pocket_cube.goal_state):
            end_time = time.time()
            execution_time = end_time - start_time
            return discovered[current_state_tuple][1], state_number, execution_time
        
        for neighbour_state in get_neighbours(cloned_cube):
            neighbour_cloned_cube = pocket_cube.clone()
            neighbour_cloned_cube.state = neighbour_state
        
            new_cost = discovered[current_state_tuple][1] + 1
            
            neighbour_state_tuple = tuple(neighbour_state)
            if neighbour_state_tuple not in discovered or new_cost < discovered[neighbour_state_tuple][1]:
                state_number += 1
                discovered[neighbour_state_tuple] = (current_state_tuple, new_cost)
                total_cost = new_cost + h3(neighbour_cloned_cube, pattern_database)
                heappush(frontier, (total_cost, neighbour_state_tuple))
    
    end_time = time.time()
    execution_time = end_time - start_time    
    return [], state_number, execution_time


In [31]:
from tests import case1, case2, case3, case4
from pocket_cube.heuristics import h3

cube1 = Cube(case1)
pattern_database1, states_visited1, execution_time1 = build_pattern_database(cube1, 7)
print(execution_time1)
print(astar_pattern_db(cube1, pattern_database1))
cube2 = Cube(case2)
pattern_database2, states_visited2, execution_time2 = build_pattern_database(cube2, 7)
print(execution_time2)
print(astar_pattern_db(cube2, pattern_database2))
cube3 = Cube(case3)
pattern_database3, states_visited3, execution_time3 = build_pattern_database(cube3, 7)
print(execution_time3)
print(astar_pattern_db(cube3, pattern_database3))
cube4 = Cube(case4)
pattern_database4, states_visited4, execution_time4 = build_pattern_database(cube4, 7)
print(execution_time4)
print(astar_pattern_db(cube4, pattern_database4))


0.3937978744506836
(5, 26, 0.0009090900421142578)
0.40517497062683105
(7, 2339, 0.06157398223876953)
0.346498966217041
(9, 74934, 2.197964906692505)
0.32654309272766113
(11, 773295, 26.446946620941162)


In [32]:
from random import choice
from pocket_cube.heuristics import h3

def mcts_with_pattern_db(pocket_cube, budget, pattern_database):
    start_time = time.time()
    solution_found = False
    states_explored = 0
    tree = init_node(pocket_cube.state)
    
    #---------------------------------------------------------------
    for x in range(budget):
        cloned_cube = pocket_cube.clone()
        state = cloned_cube.state
        node = tree

        while (not np.array_equal(state, pocket_cube.goal_state)
            and all(tuple(action) in node[ACTIONS] for action in get_neighbours(cloned_cube))
        ):
            new_action = select_action(node)
            state = new_action
            cloned_cube.state = state
            node = node[ACTIONS][tuple(new_action)]
            states_explored += 1
        
        #---------------------------------------------------------------
        if not np.array_equal(state, pocket_cube.goal_state):
            new_action = choice(list(filter(lambda a: tuple(a) not in node[ACTIONS], get_neighbours(cloned_cube))))
            state = new_action
            cloned_cube.state = state
            node = init_node(node)
            node[PARENT][ACTIONS][tuple(new_action)] = node
            states_explored += 1

        #---------------------------------------------------------------
        limitation = 0
        while not np.array_equal(state, pocket_cube.goal_state) and limitation <= 14:
            state = choice(get_neighbours(cloned_cube))
            cloned_cube.state = state
            limitation += 1
        
        if  h3(cloned_cube, pattern_database) == 0:
            reward = 1
        else:
            reward = 1 / h3(cloned_cube, pattern_database)
                   
        if np.array_equal(state, pocket_cube.goal_state):
            solution_found = True
        
        #---------------------------------------------------------------
        crt_node = node
        while crt_node is not None and PARENT in crt_node:
            crt_node[N] += 1
            crt_node[Q] += reward
            crt_node = crt_node[PARENT]
        #---------------------------------------------------------------
    end_time = time.time()
    execution_time = end_time - start_time
    
    if tree:
        final_action = select_action(tree, 0.0)
        return solution_found, execution_time, states_explored, tree[ACTIONS][final_action]
    

In [33]:
from tests import case1, case2, case3, case4

num_iterations = 20

total_execution_time = 0
total_states_explored = 0
solution_found_count = 0

for _ in range(num_iterations):
    start_time = time.time()

    cube = Cube(case1)  # Inlocuieste cu cazul de testare corespunzator

    solution_found, execution_time, states_explored, output_tree = mcts_with_pattern_db(cube, 10000, pattern_database1)

    total_execution_time += execution_time
    total_states_explored += states_explored
    solution_found_count += int(solution_found)

# Calculate averages
average_execution_time = total_execution_time / num_iterations
average_states_explored = total_states_explored / num_iterations
solution_found_percentage = (solution_found_count / num_iterations) * 100

print(f"Average Execution Time: {average_execution_time} seconds")
print(f"Average States Explored: {average_states_explored}")
print(f"Solutions Found in {solution_found_percentage}% of runs")


Average Execution Time: 1.9978081345558167 seconds
Average States Explored: 49302.8
Solutions Found in 95.0% of runs


In [34]:
import matplotlib.pyplot as plt
import numpy as np

# Datasets
astar_with_pattern_db = [
    (5, 26, 0.0),
    (7, 2339, 0.2520279884338379),
    (9, 74934, 9.087616920471191),
    (11, 773295, 110.13220047950745)
]

mcts_data_with_pattern_db = [
    (5, 49280.6, 8.644882607460023),
    (7, 49251.05, 8.566396868228912),
    (9, 49142.1, 8.915444457530976),
    (11, 49095.65, 8.268261063098908)
]

# Extracting information for the bar chart
astar_lengths, astar_states, astar_times = zip(*astar_with_pattern_db)
mcts_lengths, mcts_states, mcts_times = zip(*mcts_data_with_pattern_db)

bar_width = 0.35
index = np.arange(len(astar_lengths))

# Create subplots for each metric
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 12))

# Bar chart for solution path length
ax1.bar(index, astar_lengths, bar_width, label='A*', color='skyblue', edgecolor='black')
ax1.bar(index + bar_width, mcts_lengths, bar_width, label='MCTS', color='lightcoral', edgecolor='black', alpha=0.7)
ax1.set_xticks(index + bar_width / 2)
ax1.set_xticklabels(['5', '7', '9', '11'])
ax1.set_ylabel('Solution Path Length')
ax1.legend()

# Bar chart for number of states visited
ax2.bar(index, astar_states, bar_width, label='A*', color='skyblue', edgecolor='black')
ax2.bar(index + bar_width, mcts_states, bar_width, label='MCTS', color='lightcoral', edgecolor='black', alpha=0.7)
ax2.set_xticks(index + bar_width / 2)
ax2.set_xticklabels(['5', '7', '9', '11'])
ax2.set_ylabel('States Visited')
ax2.legend()

# Bar chart for execution time
ax3.bar(index, astar_times, bar_width, label='A*', color='skyblue', edgecolor='black')
ax3.bar(index + bar_width, mcts_times, bar_width, label='MCTS', color='lightcoral', edgecolor='black', alpha=0.7)
ax3.set_xticks(index + bar_width / 2)
ax3.set_xticklabels(['5', '7', '9', '11'])
ax3.set_ylabel('Execution Time (s)')
ax3.legend()

# Add labels for each bar
for ax in [ax1, ax2, ax3]:
    for i, v in enumerate(astar_lengths):
        ax.text(i + bar_width / 2, v + 5, str(v), color='black', ha='center', va='bottom')

    for i, v in enumerate(mcts_lengths):
        ax.text(i + bar_width * 1.5, v + 5, str(v), color='black', ha='center', va='bottom')

# Show the plot
plt.tight_layout()
plt.show()
plt.savefig("astar_vs_mcts_with_pattern_db.png")


<IPython.core.display.Javascript object>