In [None]:
# install from git
!if [ -e ./rubiks ]; then rm -rf ./rubiks; fi
!pip uninstall rubiks_rl -y --quiet
!git clone https://github.com/LongDangHoang/rubik_rl ./rubiks --quiet
!cd ./rubiks; pip install . --quiet; cd ..

# some packages
!pip install torchinfo python-dotenv wandb==0.15.0 protobuf==3.20.3 matplotlib --quiet

In [None]:
import os
import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import random
import functools
import pythreejs as p3

from torchinfo import summary
from pathlib import Path
from tqdm import tqdm

from rubiks_rl.colors import Color
from rubiks_rl.rubik54 import Rubik54
from rubiks_rl.models import RLRubikModel
from rubiks_rl.s3_utils import S3Syncer
from rubiks_rl.world import *
from rubiks_rl.logs import RubiksLogger

# Seed torch's CPU generator and every CUDA generator with the same fixed value.
orig_seed = 314
torch.manual_seed(orig_seed)
torch.cuda.manual_seed_all(orig_seed)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# set environment
from dotenv import load_dotenv

# Credentials come from a local .env file when load_dotenv() reports success;
# otherwise they are read from Kaggle's secret store and exported as
# environment variables.
ON_KAGGLE = False
if not load_dotenv():
    from kaggle_secrets import UserSecretsClient

    # One client serves every lookup; no need to build a new one per secret.
    secrets_client = UserSecretsClient()
    for env_var, secret_name in (
        ("WANDB_API_KEY", "wandb_api"),
        ("AWS_ACCESS_KEY_ID", "s3_aws_access_key"),
        ("AWS_SECRET_ACCESS_KEY", "s3_aws_secret_access_key"),
    ):
        os.environ[env_var] = secrets_client.get_secret(secret_name)
    ON_KAGGLE = True

# Imported only after the credentials above have been placed in the environment.
import wandb

# run id
use_existing_run = "yqz2c4cv"
assert use_existing_run is not None, "an existing run id is required to locate the checkpoints"

# Load model

In [None]:
# setting up checkpoint
# Per-run local checkpoint directory, created (with parents) if it is missing.
ckpt_local_dir = Path("./rubiks_rl") / use_existing_run / "checkpoints"
ckpt_local_dir.mkdir(parents=True, exist_ok=True)

# The same directory is used for both saving and loading.
s3_syncer = S3Syncer(
    load_local_dir=ckpt_local_dir,
    save_local_dir=ckpt_local_dir,
    every_n_epochs=1,
)

# Presumably populates the local directory from S3 — see S3Syncer for details.
s3_syncer.download_files_from_s3()
    
# load model
# map_location remaps the stored tensors onto `device`, so a checkpoint written
# on a GPU machine can still be restored when only the CPU is available.
state_dict = torch.load(s3_syncer.load_local_dir / "last.ckpt", map_location=device)
model = RLRubikModel()
model.load_state_dict(state_dict["model_state_dict"])

model = model.to(device)
# The rest of the notebook only runs inference, so switch off training-only
# behaviour (e.g. dropout / batch-norm statistics updates) if the model has any.
model.eval()
# `input_size` is torchinfo's documented keyword for the example input shape.
summary(model, input_size=(16, 54, 6))

# Test greedy method

To avoid loops, ignore a best action that is the inverse (the `_PRIME` counterpart) of the previous best action, and take the next-best action instead.

In [None]:
# large so that we haven't seen during training
seed = 123_456_789
# Only the "data" entry of the returned mapping is needed here.
X = get_n_cubes_k_scrambles(
    num_cubes=5,
    max_depth_scramble=50,
    seed=seed,
)["data"]

In [None]:
# Greedy rollout: at every turn, evaluate the states produced by
# get_depth_1_lookup_of_state and apply the turn whose resulting state the model
# values most, until the cube is solved or the turn budget runs out.
test_cube = np.expand_dims(X[22], 0) # roughly 20 scrambles in 
max_turns = 100

# Fetched once instead of on every loop check; assumes CUBE.get_solved_state()
# returns the same state on every call.
solved_state = CUBE.get_solved_state()

with torch.no_grad():
    turn_counter = 0
    prev_best_action_name = None
    while (not np.all(test_cube[0] == solved_state)) and (turn_counter < max_turns):
        next_states = get_depth_1_lookup_of_state(test_cube)
        x = torch.as_tensor(next_states, device=device, dtype=torch.float32)
        # Call the module itself rather than .forward() so that any registered
        # hooks run as part of the normal nn.Module call path.
        v, _ = model(x)
        v = v.cpu().numpy()
        best_action_idx = v.argmax()
        best_action_name = CUBE.TURN_IDX_TO_STR_SWAP_DICT[best_action_idx][0]

        # If the chosen turn would undo the previous one (X vs X_PRIME), mask it
        # out and fall back to the next-best turn to avoid a two-move loop.
        if prev_best_action_name is not None and (
            best_action_name == prev_best_action_name + "_PRIME"
            or best_action_name + "_PRIME" == prev_best_action_name
        ):
            v[best_action_idx] = -10000
            best_action_idx = v.argmax()
            best_action_name = CUBE.TURN_IDX_TO_STR_SWAP_DICT[best_action_idx][0]

        # Apply the turn by re-indexing the cube's first axis, then restore the
        # leading batch axis of size 1.
        test_cube = np.expand_dims(
            test_cube[0][CUBE.turn_mat[best_action_idx]],
            0
        )
        prev_best_action_name = best_action_name
        print(f"Applying turn {best_action_name}, getting value {v[best_action_idx].item()}")
        turn_counter += 1

    if np.all(test_cube[0] == solved_state):
        print("Cube is solved!")
    else:
        print(f"Cube is not solved after {max_turns} turns")

# Test search method

In [None]:
# large so that we haven't seen during training
seed = 123_456_789
# Only the "data" entry of the returned mapping is needed here.
X = get_n_cubes_k_scrambles(
    num_cubes=5,
    max_depth_scramble=50,
    seed=seed,
)["data"]

In [None]:
def bfs_at_fixed_depth(cube_state: np.ndarray, model, depth: int=5, batch_size: int=12**5):
    """Pick the first turn of the most promising `depth`-turn sequence.

    The state is expanded `depth` times with get_depth_1_lookup_of_state. If a
    solved state shows up at any level, the first turn leading to it is
    returned immediately. Otherwise the model scores every state of the final
    level and the first turn of the highest-valued one is returned.

    Args:
        cube_state: A single cube state, either 2-D or already carrying a
            leading batch axis of size 1.
        model: Value network; calling it on a float32 batch must return a
            tuple whose first element holds one value per state.
        depth: Number of turns to look ahead; must be at least 1.
        batch_size: Maximum number of states sent through the model at once.

    Returns:
        (best_action_idx, value): the index of the turn to apply now, and
        either the model value of the best leaf or 10000 when a solved state
        was found during expansion.

    Raises:
        ValueError: If `depth` or `batch_size` is smaller than 1.

    Note:
        Recovering the first turn by integer division assumes each expansion
        produces 12 children per state and keeps the children of one parent
        contiguous — TODO confirm against get_depth_1_lookup_of_state.
    """
    if depth < 1:
        raise ValueError(f"depth must be at least 1, got {depth}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    if len(cube_state.shape) == 2:
        cube_state = np.expand_dims(cube_state, axis=0)

    # Assumed constant across calls, so fetch it once for all levels.
    solved_state = CUBE.get_solved_state()
    for d in range(depth):
        cube_state = get_depth_1_lookup_of_state(cube_state)
        is_solved = np.all(cube_state == solved_state, axis=(1, 2))
        if np.any(is_solved):
            # argmax on a boolean array yields the first solved state.
            best_future_idx = np.argmax(is_solved)
            best_action_idx = best_future_idx // (12 ** d)
            return best_action_idx, 10000

    with torch.no_grad():
        v_so_far = []
        # Score the leaves in chunks to bound the memory used per forward pass.
        for batch_start in range(0, cube_state.shape[0], batch_size):
            x = torch.as_tensor(
                cube_state[batch_start:batch_start + batch_size],
                device=device,
                dtype=torch.float32,
            )
            # Call the module itself rather than .forward() so hooks still run.
            v, _ = model(x)
            v_so_far.append(v.cpu().numpy())

        v = np.concatenate(v_so_far)
        best_future_idx = v.argmax()
        best_future_value = v[best_future_idx][0]
        best_action_idx = best_future_idx // (12 ** (depth - 1))
        return best_action_idx, best_future_value

In [None]:
# Search rollout: repeatedly ask bfs_at_fixed_depth for the next turn and apply
# it, until the cube is solved or the turn budget is used up.
test_cube = np.expand_dims(X[125], 0) # roughly 20 scrambles in 
max_turns = 100

with torch.no_grad():
    turn_counter = 0

    while True:
        # Same two exit conditions as a combined loop guard, checked in order.
        if np.all(test_cube[0] == CUBE.get_solved_state()):
            break
        if turn_counter >= max_turns:
            break

        best_action_idx, best_future_value = bfs_at_fixed_depth(test_cube, model, depth=6)
        best_action_name = CUBE.TURN_IDX_TO_STR_SWAP_DICT[best_action_idx][0]

        # Apply the turn by re-indexing the cube's first axis, then put the
        # leading batch axis of size 1 back.
        turned_cube = test_cube[0][CUBE.turn_mat[best_action_idx]]
        test_cube = np.expand_dims(turned_cube, 0)

        print(f"Applying turn {best_action_name}, getting best future value {best_future_value}")
        turn_counter += 1

    solved = np.all(test_cube[0] == CUBE.get_solved_state())
    print("Cube is solved!" if solved else f"Cube is not solved after {max_turns} turns")