In [None]:
# Fetch the project sources and switch the working directory to the checkout;
# the following cell imports `utils` and `model.*` from there.
# NOTE(review): re-running this cell clones again inside the checkout.
!git clone https://github.com/PieBob851/LfP-Rubix-Cube.git
%cd LfP-Rubix-Cube

In [None]:
from utils import train_sample
from model.encoder import Encoder
from model.planner import Planner
from model.actor import Actor
import torch
from torch import nn, zeros
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from collections import deque
import random
import copy
import numpy as np
import glob

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cuda


In [None]:
!pip install magiccube

Collecting magiccube
  Downloading magiccube-1.0.0-py3-none-any.whl.metadata (3.9 kB)
Downloading magiccube-1.0.0-py3-none-any.whl (16 kB)
Installing collected packages: magiccube
Successfully installed magiccube-1.0.0


In [None]:
import magiccube
import copy
from magiccube.cube_base import Color, Face
from magiccube.cube_move import CubeMove
from magiccube.solver.basic.basic_solver import BasicSolver

# 3x3x3 cube built from a 54-character layout string made of six runs of nine
# identical letters (Y, R, G, O, B, W), i.e. every face is a single colour.
cube = magiccube.Cube(3,"YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW")

def get_face_state(cube, face):
    """One-hot encode the stickers of a single face.

    Args:
        cube: object exposing ``get_face(face)``, which yields rows of
            stickers that each carry an integer ``.value`` in ``[0, 6)``.
        face: face identifier forwarded unchanged to ``cube.get_face``.

    Returns:
        A 1-D ``int64`` tensor with six entries per sticker, stickers taken
        in row-major order.
    """
    # Read the colour ids row by row so the sticker order matches the face grid.
    colour_grid = np.array(
        [[sticker.value for sticker in row] for row in cube.get_face(face)]
    )
    colour_ids = torch.tensor(colour_grid.flatten(), dtype=torch.int64)
    encoded = torch.nn.functional.one_hot(colour_ids, num_classes=6)
    return encoded.flatten()

#state space
def get_cube_state(cube):
    """Encode a whole cube as one row per face.

    Faces are encoded with ``get_face_state`` in the fixed order
    L, R, D, U, B, F and stacked along a new leading dimension.

    Args:
        cube: cube object accepted by ``get_face_state``.

    Returns:
        Tensor with six rows, one encoded face per row.
    """
    face_order = (Face.L, Face.R, Face.D, Face.U, Face.B, Face.F)
    encoded_faces = [get_face_state(cube, face) for face in face_order]
    return torch.stack(encoded_faces, dim=0)

def batch_cube_state(cube_list):
    """Encode several cubes and flatten each one into a single row.

    Args:
        cube_list: non-empty sequence of cubes accepted by ``get_cube_state``.

    Returns:
        2-D tensor with one row per cube, holding that cube's flattened
        encoding.
    """
    stacked = torch.stack([get_cube_state(single_cube) for single_cube in cube_list])
    # Keep the batch dimension and collapse everything else.
    batch_dim = stacked.size(0)
    return stacked.view(batch_dim, -1)

def batch_apply_action(cube_list, action_list):
    """Apply the i-th action to the i-th cube, in place.

    Args:
        cube_list: cubes exposing ``_rotate_once(move)``.
        action_list: moves, indexable by position; must provide an entry for
            every cube.

    Returns:
        The same ``cube_list`` object, after its cubes have been rotated.
    """
    for position, target in enumerate(cube_list):
        target._rotate_once(action_list[position])
    return cube_list

#action space
# 18 actions, grouped face by face (L, R, D, U, B, F); within each group the
# plain move comes first, then its primed (') variant, then the "2" variant.
movements = [face + suffix for face in "LRDUBF" for suffix in ("", "'", "2")]
# Same layout, but every entry undoes the entry at the same position in
# `movements`: plain and primed swap places, the "2" variant maps to itself.
reversals = [face + suffix for face in "LRDUBF" for suffix in ("'", "", "2")]
# Position in `movements` of the move that undoes the move at each position.
reverse_index = {idx: idx + (1, -1, 0)[idx % 3] for idx in range(18)}
# Replace the notation strings by the library's move objects.
reversals = list(map(CubeMove.create, reversals))
movements = list(map(CubeMove.create, movements))

# Quick solver check: turn the module-level cube once with movements[8] ("D2")
# and hand it to the library's BasicSolver.
cube._rotate_once(movements[8])
solver = BasicSolver(cube)
# Independent snapshot of the turned cube, taken before the solver runs
# (not referenced again in the cells visible here).
cube_copy = copy.deepcopy(cube)
# presumably solves `cube` in place and returns the applied moves — TODO confirm
solver.solve()

[L', L, L2, R', R, R2, D', D, D2, U', U, U2, B', B, B2, F', F, F2]
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True


In [None]:
# random move at every step for dataset creation
import torch
import numpy as np
import random

action_dim = 18
state_dim = 54 * 6
num_samples = 100000

# One row per step: the flattened cube encoding in the first `state_dim`
# columns, followed by a one-hot marker of the move taken from that state.
data_raw = torch.zeros((num_samples, action_dim + state_dim))
cube = magiccube.Cube(3,"YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW")
for i in range(num_samples):
    # Progress message every 10000 rows.
    if not i % 10000:
        print(f"Sample: {i}")

    # Store the state *before* the move is applied.
    state = get_cube_state(cube).flatten()
    data_raw[i, :state_dim] = state

    # Uniformly random move; the cube keeps evolving from row to row.
    action = random.choice(range(action_dim))
    data_raw[i, state_dim + action] = 1
    cube._rotate_once(movements[action])

Sample: 0
Sample: 10000
Sample: 20000
Sample: 30000
Sample: 40000
Sample: 50000
Sample: 60000
Sample: 70000
Sample: 80000
Sample: 90000


In [None]:
# forward set number of moves, before reversing
#
# Each episode applies `move_depth` random moves and then undoes them in
# reverse order. Every step contributes one row: the flattened cube encoding
# (first `state_dim` columns) observed before the move, followed by a one-hot
# marker of the move taken.
action_dim = 18
state_dim = 54 * 6
num_samples = 100000
move_depth = 4  # Number of forward moves before reversing

if move_depth < 1:
    # Without forward moves the row counter would never advance.
    raise ValueError("move_depth must be at least 1")

data_raw = torch.zeros((num_samples, action_dim + state_dim))
cube = magiccube.Cube(3, "YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW")

i = 0
while i < num_samples:
    forward_actions = []
    for step in range(2 * move_depth):
        # Checked before every write, in both phases, so the table is never
        # indexed past its last row even when `num_samples` is not a multiple
        # of 2 * move_depth (the final episode is simply cut short).
        if i >= num_samples:
            break
        if i % 10000 == 0:
            print(f"Sample: {i}")

        if step < move_depth:
            # Forward phase: pick a random move and remember it.
            action = random.choice(range(action_dim))
            forward_actions.append(action)
        else:
            # Reverse phase: undo the most recent move not yet undone.
            action = reverse_index[forward_actions.pop()]

        data_raw[i, :state_dim] = get_cube_state(cube).flatten()
        data_raw[i, state_dim + action] = 1
        cube._rotate_once(movements[action])
        i += 1

Sample: 0
Sample: 10000
Sample: 20000
Sample: 30000
Sample: 40000
Sample: 50000
Sample: 60000
Sample: 70000
Sample: 80000
Sample: 90000


In [None]:
#Getting Data

class Dataset:
    """Samples fixed-length windows of consecutive rows from a data table."""

    def __init__(self, data):
        """Store the table.

        Args:
            data: row-indexable container that supports ``len()`` and
                indexing with a 2-D integer array (e.g. a torch tensor or a
                numpy array).
        """
        self.data_list = data

    def sample_batch(self, batch_size, seq_len=32):
        """Draw ``batch_size`` random windows of ``seq_len`` consecutive rows.

        Args:
            batch_size: number of windows to draw (with replacement).
            seq_len: number of consecutive rows per window; defaults to 32.

        Returns:
            ``data[indices]`` where ``indices`` has shape
            ``(batch_size, seq_len)``; each row of ``indices`` is a run of
            consecutive row numbers.

        Raises:
            ValueError: if ``seq_len`` is smaller than 1 or the table holds
                fewer than ``seq_len`` rows.
        """
        num_rows = len(self.data_list)
        if seq_len < 1:
            raise ValueError(f"seq_len must be at least 1, got {seq_len}")
        if num_rows < seq_len:
            raise ValueError(
                f"need at least {seq_len} rows to sample a window, got {num_rows}"
            )

        # randint's upper bound is exclusive, hence the +1: the last valid
        # start is num_rows - seq_len, whose window ends on the final row.
        start_indices = np.random.randint(
            0, num_rows - seq_len + 1, size=batch_size
        )[:, np.newaxis]
        indices = start_indices + np.arange(seq_len)[np.newaxis, :]

        return self.data_list[indices]

# Wrap the generated table so that windows of consecutive rows can be sampled from it.
data = Dataset(data_raw)

In [None]:
# Size of the latent vector passed to all three networks.
latent_dim = 32

# Network sizes are derived from the dataset layout (`state_dim` state columns
# and `action_dim` action columns per row); the meaning of each positional
# argument is defined by the project's model classes — TODO confirm there.
encoder = Encoder(state_dim + action_dim, layer_size=256, latent_dim=latent_dim).to(device)
planner = Planner(state_dim, state_dim, layer_size=512, latent_dim=latent_dim).to(device)
actor = Actor(state_dim, action_dim, state_dim, layer_size=512, latent_dim=latent_dim).to(device)

# `Adam` is imported by name (`from torch.optim import Adam`); the notebook
# never binds `optim`, so referring to `optim.Adam` fails with NameError on a
# fresh kernel.
encoder_optimizer = Adam(encoder.parameters(), lr=1e-4)
planner_optimizer = Adam(planner.parameters(), lr=1e-4)
actor_optimizer = Adam(actor.parameters(), lr=3e-4)

In [None]:
# Run 10000 optimisation steps. The leading arguments (32 and .9) are
# forwarded as-is to the project's `train_sample`; presumably a batch size
# and a loss weighting — TODO confirm against utils.train_sample.
for batch in range(10000):
    loss = train_sample(
        32, .9,
        encoder, actor, planner,
        encoder_optimizer, actor_optimizer, planner_optimizer,
    )
    # Report the loss of every 100th step only.
    if not batch % 100:
        print(f"Batch: {batch}, Loss: {loss}")

Batch: 0, Loss: 6.241903305053711
Batch: 100, Loss: 2.429344892501831
Batch: 200, Loss: 2.1109321117401123
Batch: 300, Loss: 1.9098554849624634
Batch: 400, Loss: 1.8812134265899658
Batch: 500, Loss: 2.522042751312256
Batch: 600, Loss: 2.55255126953125
Batch: 700, Loss: 2.1031744480133057
Batch: 800, Loss: 2.155884027481079
Batch: 900, Loss: 2.1561708450317383
Batch: 1000, Loss: 2.158642292022705
Batch: 1100, Loss: 2.298733711242676
Batch: 1200, Loss: 2.0353195667266846
Batch: 1300, Loss: 1.9320507049560547
Batch: 1400, Loss: 1.8095418214797974
Batch: 1500, Loss: 1.973937749862671
Batch: 1600, Loss: 2.283557176589966
Batch: 1700, Loss: 2.1172311305999756
Batch: 1800, Loss: 1.972545862197876
Batch: 1900, Loss: 2.67281436920166
Batch: 2000, Loss: 1.9489545822143555
Batch: 2100, Loss: 2.2429769039154053
Batch: 2200, Loss: 2.546633243560791
Batch: 2300, Loss: 1.8986903429031372
Batch: 2400, Loss: 2.199829578399658
Batch: 2500, Loss: 1.777984619140625
Batch: 2600, Loss: 2.1823060512542725
Ba

In [None]:
def test_batch(batch_size):
    """Create ``batch_size`` cubes, each scrambled via ``cube.scramble(1)``.

    Args:
        batch_size: number of cubes to create.

    Returns:
        Tuple ``(cubes, histories)``: the scrambled cubes and, per cube,
        whatever ``scramble`` returned for it.
    """
    solved_layout = "YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW"
    cubes, histories = [], []

    for _ in range(batch_size):
        scrambled = magiccube.Cube(3, solved_layout)
        histories.append(scrambled.scramble(1))
        cubes.append(scrambled)

    return cubes, histories

def test_sample(batch_size, encoder, actor, planner):
    """Let the planner/actor pick one move per scrambled cube and count solves.

    A batch of cubes from ``test_batch`` is encoded, the planner produces the
    parameters of a latent distribution for (current state, goal state), one
    latent is sampled, and the actor's highest-scoring action is applied once
    to every cube. The number of cubes reporting ``is_done()`` afterwards is
    printed.

    Args:
        batch_size: number of cubes to evaluate.
        encoder: accepted for signature symmetry; not used by this function.
        actor: module whose ``forward(state, latent, goal)`` returns a tuple
            with the action scores first.
        planner: module whose ``forward(state, goal)`` returns two tensors
            used as mean and standard deviation of the latent.

    Returns:
        None; the result is printed.
    """
    # Goal = encoding of the single-colour-per-face cube, repeated per batch
    # entry and flattened to one row each.
    target_cube = magiccube.Cube(3,"YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW")
    goal_state = get_cube_state(target_cube).unsqueeze(0).repeat(batch_size, 1, 1).to(device)
    goal_state = goal_state.view(goal_state.size(0), -1)

    cubes, histories = test_batch(batch_size)

    solved = [False for _ in range(batch_size)]
    steps_taken = [0 for _ in range(batch_size)]

    with torch.no_grad():
        current_state = batch_cube_state(cubes).to(device)

        # Sample a latent from the planner's distribution.
        mu_psi, sigma_psi = planner.forward(current_state.float(), goal_state.float())
        z = torch.normal(mu_psi, sigma_psi)

        # The actor is fed inputs with an extra dimension inserted at position 1.
        actor_dist, _ = actor.forward(
            current_state.unsqueeze(1),
            z.unsqueeze(1),
            goal_state.unsqueeze(1),
        )
        best_actions = torch.argmax(actor_dist, -1)

        #evaluate
        for i, action_index in enumerate(best_actions):
            if solved[i]:
                continue
            cubes[i]._rotate_once(movements[action_index])
            steps_taken[i] += 1
            if cubes[i].is_done():
                solved[i] = True

    num_successful = sum(solved)
    print("Number of successful solves: ", num_successful)



In [None]:
# Evaluate the trained networks on 32 freshly scrambled cubes; prints the number solved.
test_sample(32, encoder, actor, planner)

Number of successful solves:  32
