In [1]:
import numpy as np
import random

In [2]:
class RubiksCube2x2Env:
    def __init__(self):
        # Initialize the solved cube (a list of six faces)
        self.reset()

        # Define the actions: U, F, R, L, D, B and their counterclockwise counterparts
        self.actions = ['U', "U'", 'F', "F'", 'R', "R'", 'L', "L'", 'D', "D'", 'B', "B'"]

    def reset(self):
        """
        Resets the environment to the solved state.
        The cube is represented as a list of 6 faces, each having 4 stickers.
        """
        self.cube = [['W']*4, ['R']*4, ['G']*4, ['B']*4, ['O']*4, ['Y']*4]  # Solved state
        self.state = self.get_state()
        return self.state

    def scramble(self, num_moves=10):
        """
        Scrambles the cube by performing a random sequence of moves.
        """
        for _ in range(num_moves):
            action = random.choice(self.actions)
            self.apply_action(action)
        return self.get_state()

    def is_solved(self):
        """
        Check if the cube is solved. A solved cube has all stickers of each face the same color.
        """
        return all(len(set(face)) == 1 for face in self.cube)

    def get_state(self):
        """
        Converts the cube to a tuple of tuples (to be hashable) for use as a state representation.
        """
        return tuple(tuple(face) for face in self.cube)


    # Function to apply an action to the cube (rotation)
    def _rotate_face_clockwise(self, face):
        #"""Rotate a 2x2 face 90 degrees clockwise."""
        return [face[2], face[0], face[3], face[1]]

    def _rotate_face_counterclockwise(self, face):
    #"""Rotate a 2x2 face 90 degrees counterclockwise."""
        return [face[1], face[3], face[0], face[2]]


    def apply_action(self, action):
        if action == "U":
            self._rotate_u()
        elif action == "U'":
            self._rotate_u_prime()
        elif action == "D":
            self._rotate_d()
        elif action == "D'":
            self._rotate_d_prime()
        elif action == "R":
            self._rotate_r()
        elif action == "R'":
            self._rotate_r_prime()
        elif action == "L":
            self._rotate_l()
        elif action == "L'":
            self._rotate_l_prime()
        elif action == "F":
            self._rotate_f()
        elif action == "F'":
            self._rotate_f_prime()
        elif action == "B":
            self._rotate_b()
        elif action == "B'":
            self._rotate_b_prime()

    def _rotate_u(self):
    # U: F[0,1] <-> L[0,1] <-> B[0,1] <-> R[0,1]
        self.cube[2] = self._rotate_face_clockwise(self.cube[2])  # Top face
        self.cube[0][0:2], self.cube[4][0:2], self.cube[1][0:2], self.cube[5][0:2] = \
        self.cube[4][0:2], self.cube[1][0:2], self.cube[5][0:2], self.cube[0][0:2]

    def _rotate_u_prime(self):
        self.cube[2] = self._rotate_face_counterclockwise(self.cube[2])
        self.cube[0][0:2], self.cube[4][0:2], self.cube[1][0:2], self.cube[5][0:2] = \
        self.cube[5][0:2], self.cube[0][0:2], self.cube[4][0:2], self.cube[1][0:2]

    def _rotate_d(self):
        # D: F[2,3] <-> R[2,3] <-> B[2,3] <-> L[2,3]
        self.cube[3] = self._rotate_face_clockwise(self.cube[3])  # Bottom face
        self.cube[0][2:4], self.cube[5][2:4], self.cube[1][2:4], self.cube[4][2:4] = \
        self.cube[5][2:4], self.cube[1][2:4], self.cube[4][2:4], self.cube[0][2:4]

    def _rotate_d_prime(self):
        self.cube[3] = self._rotate_face_counterclockwise(self.cube[3])
        self.cube[0][2:4], self.cube[5][2:4], self.cube[1][2:4], self.cube[4][2:4] = \
        self.cube[4][2:4], self.cube[0][2:4], self.cube[5][2:4], self.cube[1][2:4]

    def _rotate_l(self):
        # L: F[0,2] <-> D[0,2] <-> B[1,3] <-> U[0,2]
        self.cube[4] = self._rotate_face_clockwise(self.cube[4])  # Left face
        temp = [self.cube[0][0], self.cube[0][2]]
        self.cube[0][0], self.cube[0][2] = self.cube[3][0], self.cube[3][2]
        self.cube[3][0], self.cube[3][2] = self.cube[1][3], self.cube[1][1]
        self.cube[1][3], self.cube[1][1] = self.cube[2][0], self.cube[2][2]
        self.cube[2][0], self.cube[2][2] = temp

    def _rotate_l_prime(self):
        self.cube[4] = self._rotate_face_counterclockwise(self.cube[4])
        temp = [self.cube[0][0], self.cube[0][2]]
        self.cube[0][0], self.cube[0][2] = self.cube[2][0], self.cube[2][2]
        self.cube[2][0], self.cube[2][2] = self.cube[1][3], self.cube[1][1]
        self.cube[1][3], self.cube[1][1] = self.cube[3][0], self.cube[3][2]
        self.cube[3][0], self.cube[3][2] = temp

    def _rotate_r(self):
        # R: F[1,3] <-> U[1,3] <-> B[0,2] <-> D[1,3]
        self.cube[5] = self._rotate_face_clockwise(self.cube[5])  # Right face
        temp = [self.cube[0][1], self.cube[0][3]]
        self.cube[0][1], self.cube[0][3] = self.cube[2][1], self.cube[2][3]
        self.cube[2][1], self.cube[2][3] = self.cube[1][2], self.cube[1][0]
        self.cube[1][2], self.cube[1][0] = self.cube[3][1], self.cube[3][3]
        self.cube[3][1], self.cube[3][3] = temp

    def _rotate_r_prime(self):
        self.cube[5] = self._rotate_face_counterclockwise(self.cube[5])
        temp = [self.cube[0][1], self.cube[0][3]]
        self.cube[0][1], self.cube[0][3] = self.cube[3][1], self.cube[3][3]
        self.cube[3][1], self.cube[3][3] = self.cube[1][2], self.cube[1][0]
        self.cube[1][2], self.cube[1][0] = self.cube[2][1], self.cube[2][3]
        self.cube[2][1], self.cube[2][3] = temp

    def _rotate_f(self):
        # F: U[2,3] <-> L[1,3] <-> D[0,1] <-> R[0,2]
        self.cube[0] = self._rotate_face_clockwise(self.cube[0])  # Front face
        temp = [self.cube[2][2], self.cube[2][3]]
        self.cube[2][2], self.cube[2][3] = self.cube[4][3], self.cube[4][1]
        self.cube[4][3], self.cube[4][1] = self.cube[3][0], self.cube[3][1]
        self.cube[3][0], self.cube[3][1] = self.cube[5][0], self.cube[5][2]
        self.cube[5][0], self.cube[5][2] = temp

    def _rotate_f_prime(self):
        self.cube[0] = self._rotate_face_counterclockwise(self.cube[0])
        temp = [self.cube[2][2], self.cube[2][3]]
        self.cube[2][2], self.cube[2][3] = self.cube[5][0], self.cube[5][2]
        self.cube[5][0], self.cube[5][2] = self.cube[3][0], self.cube[3][1]
        self.cube[3][0], self.cube[3][1] = self.cube[4][3], self.cube[4][1]
        self.cube[4][3], self.cube[4][1] = temp

    def _rotate_b(self):
        # B: U[0,1] <-> R[1,3] <-> D[2,3] <-> L[0,2]
        self.cube[1] = self._rotate_face_clockwise(self.cube[1])  # Back face
        temp = [self.cube[2][0], self.cube[2][1]]
        self.cube[2][0], self.cube[2][1] = self.cube[5][1], self.cube[5][3]
        self.cube[5][1], self.cube[5][3] = self.cube[3][2], self.cube[3][3]
        self.cube[3][2], self.cube[3][3] = self.cube[4][0], self.cube[4][2]
        self.cube[4][0], self.cube[4][2] = temp

    def _rotate_b_prime(self):
        self.cube[1] = self._rotate_face_counterclockwise(self.cube[1])
        temp = [self.cube[2][0], self.cube[2][1]]
        self.cube[2][0], self.cube[2][1] = self.cube[4][0], self.cube[4][2]
        self.cube[4][0], self.cube[4][2] = self.cube[3][2], self.cube[3][3]
        self.cube[3][2], self.cube[3][3] = self.cube[5][1], self.cube[5][3]
        self.cube[5][1], self.cube[5][3] = temp

    def choose_action(self, q_table, epsilon=0.1):
        """
        Chooses an action based on epsilon-greedy policy.
        """
        if random.uniform(0, 1) < epsilon:
            # Exploration: choose a random action
            return random.choice(self.actions)
        else:
            # Exploitation: choose the action with the highest Q-value
            if self.state not in q_table:
                q_table[self.state] = {action: 0 for action in self.actions}
            return max(q_table[self.state], key=q_table[self.state].get)

    def count_solved_faces(self):
        """
        Counts the number of fully solved faces.
        A fully solved face has all stickers of the same color.
        """
        return sum(face.count(face[0]) == len(face) for face in self.cube)

    def step(self, action, q_table, alpha=0.1, gamma=0.9):
        """
        Takes an action, applies it to the cube, and returns the new state, reward, and whether the episode is done.
        """
        # Apply the action to get the next state
        self.apply_action(action)
        next_state = self.get_state()

         # Calculate reward based on the number of solved faces
        solved_faces = self.count_solved_faces()
        if self.is_solved():
            reward = 10  # Highest reward for solving the entire cube
        else:
            reward = solved_faces  # Reward for the number of solved faces (0 to 5)

        # Add a small penalty for each move to encourage faster solutions
        reward -= 0.01

        # Update the Q-value using the Q-learning update rule
        if next_state not in q_table:
            q_table[next_state] = {action: 0 for action in self.actions}

        best_next_action = max(q_table[next_state], key=q_table[next_state].get)
        q_table[self.state][action] += alpha * (reward + gamma * q_table[next_state][best_next_action] - q_table[self.state][action])

        # Move to the next state
        self.state = next_state

        # Return the next state, reward, and whether the cube is solved
        return next_state, reward, self.is_solved()

In [5]:
# Training Loop

def train_agent(num_episodes=10000):
    env = RubiksCube2x2Env()
    q_table = {}
    epsilon = 0.1  # Exploration rate
    alpha = 0.1    # Learning rate
    gamma = 0.9    # Discount factor

    for episode in range(num_episodes):
        # Reset the environment at the start of each episode
        env.reset()
        env.scramble(20)
        step = 0
        max_reward = - np.Infinity 
        # Reduce exploration rate over time
        if epsilon > 0.01:
            epsilon *= 0.995
        while not env.is_solved() and step < 30:
            step +=1
            # Choose an action
            action = env.choose_action(q_table, epsilon)
            # Take the action and get the next state, reward, and done flag
            next_state, reward, done = env.step(action, q_table, alpha, gamma)
            max_reward = np.maximum(reward,max_reward)
            if done:
                print("*************************************************Solved!********************************************************")
                print( "Episode n°", episode, "reward = ", max_reward, "in ", step , "steps, SOLVED :D !!!!!!")
                break
        if episode % 1000000 == 0:
          print( "Episode n°", episode, "reward = ", max_reward, "in ", step , "steps")

    return q_table

In [6]:
# Training the agent
q_table = train_agent(1000000000)

Episode n° 0 reward =  -0.01 in  30 steps
*************************************************Solved!********************************************************
Episode n° 72356 reward =  9.99 in  2 steps, SOLVED :D !!!!!!
*************************************************Solved!********************************************************
Episode n° 172043 reward =  9.99 in  20 steps, SOLVED :D !!!!!!
*************************************************Solved!********************************************************
Episode n° 179348 reward =  9.99 in  2 steps, SOLVED :D !!!!!!
*************************************************Solved!********************************************************
Episode n° 231093 reward =  9.99 in  12 steps, SOLVED :D !!!!!!
*************************************************Solved!********************************************************
Episode n° 284579 reward =  9.99 in  2 steps, SOLVED :D !!!!!!
*************************************************Solved!*********************

MemoryError: 

In [None]:
# Testing the agent on a scrambled cube
env = RubiksCube2x2Env()
test_cube = env.scramble(10)
print("Scrambled Cube:")
print(test_cube)
steps = 0
while not env.is_solved():
    state = env.get_state()
    action = max(q_table[state], key=q_table[state].get)  # Choose best action
    env.apply_action(action)
    steps += 1
    print(f"Step {steps}: Action {action}")
print("Cube solved!")