In [3]:
%cd ..
!git clone https://github.com/PieBob851/LfP-Rubix-Cube.git
%cd LfP-Rubix-Cube

/
Cloning into 'LfP-Rubix-Cube'...
remote: Enumerating objects: 41, done.[K
remote: Counting objects: 100% (41/41), done.[K
remote: Compressing objects: 100% (32/32), done.[K
remote: Total 41 (delta 16), reused 28 (delta 9), pack-reused 0 (from 0)[K
Receiving objects: 100% (41/41), 25.55 KiB | 5.11 MiB/s, done.
Resolving deltas: 100% (16/16), done.
/LfP-Rubix-Cube


# Imports & Utils


In [4]:
from model.encoder import Encoder
from model.planner import Planner
from model.actor import Actor
import torch
from torch import nn, zeros
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from collections import deque
import random
import copy
import numpy as np
import glob

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [5]:
!pip install magiccube

Collecting magiccube
  Downloading magiccube-1.0.0-py3-none-any.whl.metadata (3.9 kB)
Downloading magiccube-1.0.0-py3-none-any.whl (16 kB)
Installing collected packages: magiccube
Successfully installed magiccube-1.0.0


In [6]:
import magiccube
import copy
from magiccube.cube_base import Color, Face
from magiccube.cube_move import CubeMove
from magiccube.solver.basic.basic_solver import BasicSolver

# Solved 3x3 cube: the 54-character string lists nine stickers per face, one colour per face.
cube = magiccube.Cube(3,"YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW")

def get_face_state(cube, face):
    """Return a flat one-hot encoding of a single cube face.

    Reads ``color.value`` for every sticker in ``cube.get_face(face)`` (row by
    row), one-hot encodes each value over 6 classes and concatenates the
    results into one 1-D int64 tensor of length ``6 * number_of_stickers``.
    """
    rows = cube.get_face(face)
    sticker_values = np.array([[sticker.value for sticker in row] for row in rows]).reshape(-1)
    indices = torch.tensor(sticker_values, dtype=torch.int64)
    encoded = torch.nn.functional.one_hot(indices, num_classes=6)
    return encoded.flatten()

#state space
def get_cube_state(cube):
    """Encode the whole cube as a 2-D tensor with one row per face.

    Rows are the ``get_face_state`` encodings of the faces in the fixed order
    L, R, D, U, B, F, stacked along dimension 0.
    """
    face_order = (Face.L, Face.R, Face.D, Face.U, Face.B, Face.F)
    per_face = [get_face_state(cube, face) for face in face_order]
    return torch.stack(per_face, dim=0)

def batch_cube_state(cube_list):
    """Encode several cubes at once.

    Returns a tensor of shape ``(len(cube_list), -1)``: one flattened
    ``get_cube_state`` encoding per cube, in the order given.
    """
    stacked = torch.stack([get_cube_state(single_cube) for single_cube in cube_list])
    batch = stacked.size(0)
    return stacked.view(batch, -1)

def batch_apply_action(cube_list, action_list):
    """Apply ``action_list[k]`` to ``cube_list[k]`` for every cube, in place.

    The cubes are mutated through ``_rotate_once``; the same list object is
    returned for convenience. Actions beyond ``len(cube_list)`` are ignored.
    """
    for position, target in enumerate(cube_list):
        target._rotate_once(action_list[position])
    return cube_list

#action space
# Every face (L, R, D, U, B, F) has three turns: clockwise, counter-clockwise (')
# and half turn (2). `reversals[k]` is the move that undoes `movements[k]`, and
# `reverse_index[k]` is the position of that undoing move inside `movements`.
movements = [face + turn for face in "LRDUBF" for turn in ("", "'", "2")]
reversals = [face + turn for face in "LRDUBF" for turn in ("'", "", "2")]
reverse_index = {position: movements.index(undo) for position, undo in enumerate(reversals)}
# Convert the move strings to CubeMove objects only after the index table is built.
reversals = [CubeMove.create(notation) for notation in reversals]
movements = [CubeMove.create(notation) for notation in movements]

# cube._rotate_once(movements[8])
# solver = BasicSolver(cube)
# cube_copy = copy.deepcopy(cube)
# solver.solve()

# Data Collection

Two alternative ways to collect data are given below. Run only one of them: both write their result to `data_raw`.

Random move selection: at every timestep, a move is chosen at random and applied to the cube.

In [None]:
# Random-walk dataset: one cube is turned by a freshly drawn random move at every
# step. Row i of data_raw is [flattened state | one-hot action], where the state is
# recorded before the action is applied.
import torch
import numpy as np
import random

action_dim = 18
state_dim = 54 * 6
num_samples = 100000

data_raw = torch.zeros((num_samples, action_dim + state_dim))
cube = magiccube.Cube(3, "YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW")
for i in range(num_samples):
  if not i % 10000:
    print(f"Sample: {i}")

  state = get_cube_state(cube).flatten()
  action = random.choice(range(action_dim))

  # data_raw[i] is a view, so writing through it fills the dataset in place.
  row = data_raw[i]
  row[:state_dim] = state
  row[state_dim + action] = 1

  cube._rotate_once(movements[action])

Random move selection with reversing: starting from a solved cube, `move_depth` random moves are applied, and then those moves are undone in reverse order so that the cube is solved again. This cycle repeats until `num_samples` samples have been collected.

In [28]:
# Forward/reverse dataset: starting from a solved cube, play `move_depth` random
# moves, then undo them in reverse order so each full cycle ends solved again.
# Row i of data_raw is [flattened state | one-hot action], where the state is
# recorded before the action is applied.
action_dim = 18
state_dim = 54 * 6
num_samples = 1000000
move_depth = 16  # Number of forward moves before reversing

data_raw = torch.zeros((num_samples, action_dim + state_dim))
cube = magiccube.Cube(3, "YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW")

i = 0
while i < num_samples:
  # The random moves do not depend on the cube state, so a whole cycle
  # (forward moves followed by their undo moves) can be planned up front.
  forward_actions = [random.choice(range(action_dim)) for _ in range(move_depth)]
  undo_actions = [reverse_index[action] for action in reversed(forward_actions)]

  for action in forward_actions + undo_actions:
    # One bounds check covers both phases. Previously, hitting num_samples inside
    # the forward phase still entered the undo phase and indexed data_raw out of
    # range whenever num_samples was not a multiple of 2 * move_depth.
    if i >= num_samples:
      break
    if i % 10000 == 0:
      print(f"Sample: {i}")

    data_raw[i, :state_dim] = get_cube_state(cube).flatten()
    data_raw[i, state_dim + action] = 1
    cube._rotate_once(movements[action])

    i += 1

Sample: 0
Sample: 10000
Sample: 20000
Sample: 30000
Sample: 40000
Sample: 50000
Sample: 60000
Sample: 70000
Sample: 80000
Sample: 90000
Sample: 100000
Sample: 110000
Sample: 120000
Sample: 130000
Sample: 140000
Sample: 150000
Sample: 160000
Sample: 170000
Sample: 180000
Sample: 190000
Sample: 200000
Sample: 210000
Sample: 220000
Sample: 230000
Sample: 240000
Sample: 250000
Sample: 260000
Sample: 270000
Sample: 280000
Sample: 290000
Sample: 300000
Sample: 310000
Sample: 320000
Sample: 330000
Sample: 340000
Sample: 350000
Sample: 360000
Sample: 370000
Sample: 380000
Sample: 390000
Sample: 400000
Sample: 410000
Sample: 420000
Sample: 430000
Sample: 440000
Sample: 450000
Sample: 460000
Sample: 470000
Sample: 480000
Sample: 490000
Sample: 500000
Sample: 510000
Sample: 520000
Sample: 530000
Sample: 540000
Sample: 550000
Sample: 560000
Sample: 570000
Sample: 580000
Sample: 590000
Sample: 600000
Sample: 610000
Sample: 620000
Sample: 630000
Sample: 640000
Sample: 650000
Sample: 660000
Sample: 6

In [29]:
#Getting Data

class Dataset:
    """Thin wrapper around an indexable sample array that draws training windows."""

    def __init__(self, data):
        # Anything supporting len() and integer-array indexing on its first axis.
        self.data_list = data

    def sample_batch(self, batch_size, plan_len):
        """Draw ``batch_size`` windows of ``plan_len + 1`` consecutive rows.

        The rows are treated as consecutive cycles of ``2 * plan_len`` rows.
        A cycle is picked at random (the last one is never picked, so the
        window always stays in range) and the window starts ``plan_len`` rows
        into it. The result is ``data_list`` indexed with an integer array of
        shape ``(batch_size, plan_len + 1)``.
        """
        cycle_len = plan_len * 2
        cycle_count = len(self.data_list) // cycle_len
        chosen_cycles = np.random.randint(0, cycle_count - 1, size=batch_size)

        window_starts = chosen_cycles * cycle_len + plan_len
        offsets = np.arange(plan_len + 1)
        indices = window_starts.reshape(-1, 1) + offsets.reshape(1, -1)

        return self.data_list[indices]

# Wrap the generated samples so that training can draw windows via data.sample_batch.
data = Dataset(data_raw)

In [30]:
latent_dim = 128  # latent size handed to all three networks below

# The encoder input width is state_dim + action_dim, i.e. one full row of data_raw.
encoder = Encoder(state_dim + action_dim, layer_size=256, latent_dim=latent_dim).to(device)
# NOTE(review): the meaning of the positional arguments to Planner and Actor is
# defined in model/planner.py and model/actor.py, which are not visible here;
# presumably (state, goal) and (state, action, goal) widths - confirm.
planner = Planner(state_dim, state_dim, layer_size=1024, latent_dim=latent_dim).to(device)
actor = Actor(state_dim, action_dim, state_dim, layer_size=1024, latent_dim=latent_dim).to(device)

# One Adam optimizer per network; the actor uses a larger learning rate (3e-4 vs 1e-4).
encoder_optimizer = Adam(encoder.parameters(), lr=1e-4)
planner_optimizer = Adam(planner.parameters(), lr=1e-4)
actor_optimizer = Adam(actor.parameters(), lr=3e-4)

# Training and Testing


In [36]:
import torch.distributions as dist
import torch.optim as optim
import torch.nn.functional as F

def train_sample(batch_size, beta, encoder, actor, planner, encoder_optimizer, actor_optimizer, planner_optimizer, data, plan_len):
    """Run ``plan_len`` joint optimisation steps on one sampled batch of windows.

    A batch of windows with ``plan_len + 1`` rows each is drawn from ``data``.
    The state part of the last row is the goal. For each of the first
    ``plan_len`` rows the encoder, planner and actor are evaluated, the loss
    ``beta * KL(encoder || planner) + cross_entropy(actor logits, recorded action)``
    is back-propagated, and all three optimizers take one step.

    Args:
        batch_size: number of windows to sample.
        beta: weight of the KL term.
        encoder, actor, planner: the three networks (called as modules, so
            registered hooks run).
        encoder_optimizer, actor_optimizer, planner_optimizer: their optimizers.
        data: object exposing ``sample_batch(batch_size, plan_len)``.
        plan_len: number of steps per window (the window has one extra goal row).

    Returns:
        The mean loss over the ``plan_len`` steps, as a Python float.
    """
    n_actions = 18  # width of the one-hot action suffix of every dataset row

    sample = data.sample_batch(batch_size, plan_len)
    # Put the batch on the device the models live on instead of relying on a
    # notebook-level global; if the encoder has no parameters the batch stays put.
    first_param = next(encoder.parameters(), None)
    if first_param is not None:
        sample = sample.to(first_param.device)

    # The goal is the same for every step of the window, so slice it once.
    goal_state = sample[:, -1, :-n_actions]

    losses = []
    for i in range(plan_len):
        current_state = sample[:, i, :-n_actions]
        current_action = sample[:, i, -n_actions:]

        # Recomputed every step because the weights change after each optimizer step.
        z, mu_phi, sigma_phi = encoder(sample)
        mu_psi, sigma_psi = planner(current_state, goal_state)

        phi_gaussian = dist.Normal(mu_phi, sigma_phi)
        psi_gaussian = dist.Normal(mu_psi, sigma_psi)

        # Summed over batch and latent dimensions (not averaged).
        KL_loss = torch.sum(dist.kl.kl_divergence(phi_gaussian, psi_gaussian))

        policy_action, _ = actor(current_state.unsqueeze(1), z.unsqueeze(1), goal_state.unsqueeze(1))

        # The target is the recorded one-hot action row, passed as class probabilities.
        action_loss = F.cross_entropy(policy_action.squeeze(1), current_action)

        loss = beta * KL_loss + action_loss

        encoder_optimizer.zero_grad()
        planner_optimizer.zero_grad()
        actor_optimizer.zero_grad()

        loss.backward()

        encoder_optimizer.step()
        planner_optimizer.step()
        actor_optimizer.step()
        losses.append(loss.item())
    return sum(losses) / len(losses)

In [39]:
# Each train_sample call runs move_depth optimizer steps (one per window position),
# so 50000 // move_depth calls give roughly 50000 optimizer steps in total.
# Arguments: batch size 32, KL weight beta = 0.9, window length = move_depth.
for batch in range(50000 // move_depth):
    loss = train_sample(32, .9, encoder, actor, planner, encoder_optimizer, actor_optimizer, planner_optimizer, data, move_depth)
    if batch % 100 == 0:
        print(f"Batch: {batch}, Loss: {loss}")

Batch: 0, Loss: 2.240416720509529
Batch: 100, Loss: 2.1591588873416185
Batch: 200, Loss: 2.0319291735067964
Batch: 300, Loss: 2.112350581213832
Batch: 400, Loss: 2.0157643388956785
Batch: 500, Loss: 1.9694203585386276
Batch: 600, Loss: 1.9207060476765037
Batch: 700, Loss: 1.9307263931259513
Batch: 800, Loss: 1.9795893393456936
Batch: 900, Loss: 1.9571971306577325
Batch: 1000, Loss: 1.8952850503847003
Batch: 1100, Loss: 1.8997483532875776
Batch: 1200, Loss: 1.8401312679052353
Batch: 1300, Loss: 1.779211115092039
Batch: 1400, Loss: 1.8657003976404667
Batch: 1500, Loss: 1.7245650589466095
Batch: 1600, Loss: 1.770142295397818
Batch: 1700, Loss: 1.7129919305443764
Batch: 1800, Loss: 1.7613434782251716
Batch: 1900, Loss: 1.5927450777962804
Batch: 2000, Loss: 1.6710727149620652
Batch: 2100, Loss: 1.6324873818084598
Batch: 2200, Loss: 1.5948121650144458
Batch: 2300, Loss: 1.6703677996993065
Batch: 2400, Loss: 1.7157585080713034
Batch: 2500, Loss: 1.5082753850147128
Batch: 2600, Loss: 1.4966363

In [33]:
def scramble_n(cube, n):
    """Turn ``cube`` in place with ``n`` moves drawn at random from ``movements``."""
    turns_left = n
    while turns_left > 0:
        cube._rotate_once(random.choice(movements))
        turns_left -= 1

def attempt_solve(scramble_moves, max_moves, replan_every=16):
    """Scramble a solved cube and let the planner/actor policy try to solve it.

    Uses the notebook-level ``planner``, ``actor``, ``movements`` and ``device``.

    Args:
        scramble_moves: number of random moves applied to the solved cube.
        max_moves: maximum number of policy moves to try.
        replan_every: the planner is re-evaluated on the current state every
            this many moves (default 16, the value previously hard-coded).

    Returns:
        The number of policy moves taken when the cube becomes solved, or -1
        if it is not solved within ``max_moves``. The cube is only checked
        after a move, so a successful result is always at least 1.

    Raises:
        ValueError: if ``replan_every`` is smaller than 1.
    """
    if replan_every < 1:
        raise ValueError("replan_every must be at least 1")

    cube = magiccube.Cube(3, "YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW")
    # The goal is the solved cube, captured before scrambling.
    goal_state = get_cube_state(cube).flatten().unsqueeze(0).to(device)
    scramble_n(cube, scramble_moves)

    with torch.no_grad():
        current_state = get_cube_state(cube).flatten().unsqueeze(0).to(device)

        for t in range(max_moves):
            if t % replan_every == 0:
                # Call the modules rather than .forward() so registered hooks run.
                mu_psi, sigma_psi = planner(current_state.float(), goal_state.float())
            # NOTE(review): z is re-drawn on every move from the most recent plan
            # distribution, not once per replan - confirm this is intended.
            z = mu_psi + sigma_psi * torch.randn_like(sigma_psi)

            actor_dist, _ = actor(current_state.unsqueeze(1), z.unsqueeze(1), goal_state.unsqueeze(1))
            # Greedy action; .item() turns the one-element tensor into a plain list index.
            action_index = torch.argmax(actor_dist, -1).item()

            cube._rotate_once(movements[action_index])
            current_state = get_cube_state(cube).flatten().unsqueeze(0).to(device)
            if cube.is_done():
                return t + 1
    return -1


def test_batch(batch_size):
    """Create ``batch_size`` solved cubes and scramble each via ``cube.scramble(1)``.

    Returns a pair ``(cubes, histories)`` of equal-length lists, where
    ``histories[k]`` is whatever ``scramble`` returned for ``cubes[k]``.
    """
    scrambled = []
    for _ in range(batch_size):
        fresh = magiccube.Cube(3, "YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW")
        scrambled.append((fresh, fresh.scramble(1)))

    cubes = [entry[0] for entry in scrambled]
    histories = [entry[1] for entry in scrambled]
    return cubes, histories

def test_sample(batch_size, encoder, actor, planner):
    """Check how many scrambled cubes the policy solves with a single move.

    ``batch_size`` cubes are produced by ``test_batch``; the planner and actor
    are evaluated once on the whole batch, the highest-scoring action is
    applied to each cube, and the number of cubes that are then solved is
    printed.

    Args:
        batch_size: number of cubes to evaluate.
        encoder: unused; kept so existing call sites keep working.
        actor, planner: the networks to evaluate (called as modules).
    """
    goal_cube = magiccube.Cube(3, "YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW")
    # One flattened copy of the solved state per batch element.
    goal_state = get_cube_state(goal_cube).flatten().unsqueeze(0).repeat(batch_size, 1).to(device)

    cubes, _ = test_batch(batch_size)

    with torch.no_grad():
        current_state = batch_cube_state(cubes).to(device)

        mu_psi, sigma_psi = planner(current_state.float(), goal_state.float())
        z = torch.normal(mu_psi, sigma_psi)
        actor_dist, _ = actor(current_state.unsqueeze(1), z.unsqueeze(1), goal_state.unsqueeze(1))

        best_actions = torch.argmax(actor_dist, -1)

    # Each cube gets exactly one move, so no per-cube "already solved" or
    # step-count bookkeeping is needed.
    num_successful = 0
    for cube, action_index in zip(cubes, best_actions):
        cube._rotate_once(movements[action_index.item()])
        if cube.is_done():
            num_successful += 1

    print("Number of successful solves: ", num_successful)



In [40]:
# Solve rate over 1000 attempts for scramble depths 1..10 with a 30-move budget.
# attempt_solve returns a positive move count on success and -1 on failure.
for d in range(1, 11):
  solve_count = sum(1 for _ in range(1000) if attempt_solve(d, 30) > 0)
  print(d, ": ", solve_count)

# Same measurement for 30-move scrambles with an 80-move budget.
solve_count = sum(1 for _ in range(1000) if attempt_solve(30, 80) > 0)
print(solve_count)

1 :  1000
2 :  1000
3 :  1000
4 :  983
5 :  898
6 :  802
7 :  658
8 :  516
9 :  379
10 :  259
1


In [42]:
# Single 30-move scramble with a 5000-move budget; prints the move count, or -1 on failure.
print(attempt_solve(30, 5000))

-1
