In [None]:

import pandas as pd
import numpy as np
from pathlib import Path
import json

In [58]:
def load_dataset_split(data_dir: Path, split: str):
    """Load a dataset split (train/test/val)"""
    split_dir = data_dir / split
    if not split_dir.exists():
        raise FileNotFoundError(f"Split {split} not found in {data_dir}")

    # Load metadata
    with open(split_dir / "dataset.json", "r") as f:
        metadata = json.load(f)

    # Load arrays
    inputs = np.load(split_dir / "all__inputs.npy")
    labels = np.load(split_dir / "all__labels.npy")
    group_indices = np.load(split_dir / "all__group_indices.npy")
    puzzle_indices = np.load(split_dir / "all__puzzle_indices.npy")
    puzzle_identifiers = np.load(split_dir / "all__puzzle_identifiers.npy")

    return {
        'metadata': metadata,
        'inputs': inputs,
        'labels': labels,
        'group_indices': group_indices,
        'puzzle_indices': puzzle_indices,
        'puzzle_identifiers': puzzle_identifiers,
    }

In [59]:
data = load_dataset_split(Path("data/cube-2-by-2"), "test")

In [38]:
df = pd.DataFrame({
    'inputs': [data['inputs'][i].tolist() for i in range(len(data['inputs']))],
    #'labels': [data['labels'][i].tolist() for i in range(len(data['labels']))],
    'labels': [[x for x in data['labels'][i] if x != 0] for i in range(len(data['labels']))],
    'puzzle_indices': data['puzzle_indices'][:-1],
})

In [25]:
colors = ['Y','R','G','O','B','W']
translate = [0, 3, 2, 5, 1, 4]
# move indices
moveInds = { \
  "U": 0, "U'": 1, "U2": 2, "R": 3, "R'": 4, "R2": 5, "F": 6, "F'": 7, "F2": 8, \
  "D": 9, "D'": 10, "D2": 11, "L": 12, "L'": 13, "L2": 14, "B": 15, "B'": 16, "B2": 17, \
  "x": 18, "x'": 19, "x2": 20, "y": 21, "y'": 22, "y2": 23, "z": 24, "z'": 25, "z2": 26 \
}
inverted = {(v + 1): k for k, v in moveInds.items()}
inverted[0] = '0'
print(inverted)
sample = data['inputs'][0]
label = data['labels'][0]
print(sample-1)
print(''.join([colors[i-1] for i in sample]))
state = ''.join([colors[translate[i-1]] for i in sample])
sol = ''.join([str(inverted[i]) + " " for i in label])
print(state)
print(sol)

{1: 'U', 2: "U'", 3: 'U2', 4: 'R', 5: "R'", 6: 'R2', 7: 'F', 8: "F'", 9: 'F2', 10: 'D', 11: "D'", 12: 'D2', 13: 'L', 14: "L'", 15: 'L2', 16: 'B', 17: "B'", 18: 'B2', 19: 'x', 20: "x'", 21: 'x2', 22: 'y', 23: "y'", 24: 'y2', 25: 'z', 26: "z'", 27: 'z2', 0: '0'}
[2 2 5 5 0 3 0 2 3 4 1 4 0 2 3 4 0 1 4 5 1 1 3 5]
GGWWYOYGOBRBYGOBYRBWRROW
GGBBYWYGWRORYGWRYORBOOWB
F2 U2 R' F U F2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 


In [60]:
import py222

sample_idx = 0
cube_input = data['inputs'][sample_idx]
labels = data['labels'][sample_idx]

py222_state = cube_input - 1
print(py222_state)
py222.printCube(py222_state)

move_names = {0: "U", 1: "U'", 2: "U2", 3: "R", 4: "R'", 5: "R2", 6: "F", 7: "F'", 8: "F2"}

moves = [int(m) - 1 for m in labels if m != 0]
sol_str = ' '.join([move_names[m] for m in moves])
print(sol_str)

test_state = py222_state
for m in moves:
    test_state = py222.doMove(test_state, m)
print(test_state)

[2 2 5 5 0 3 0 2 3 4 1 4 0 2 3 4 0 1 4 5 1 1 3 5]
      ┌──┬──┐
      │ 2│ 2│
      ├──┼──┤
      │ 5│ 5│
┌──┬──┼──┼──┼──┬──┬──┬──┐
│ 0│ 1│ 3│ 4│ 0│ 3│ 1│ 1│
├──┼──┼──┼──┼──┼──┼──┼──┤
│ 4│ 5│ 1│ 4│ 0│ 2│ 3│ 5│
└──┴──┼──┼──┼──┴──┴──┴──┘
      │ 0│ 2│
      ├──┼──┤
      │ 3│ 4│
      └──┴──┘
F2 U2 R' F U F2
[0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5]


In [None]:
import magiccube

# Let's understand the exact mappings by checking solved states
solved_py222 = py222.initState()
solved_magic = magiccube.Cube(2)


# From the outputs we can see:
# py222 solved: face 0=U, 1=R, 2=F, 3=D, 4=L, 5=B
# magiccube solved: W=U, O=L, G=F, R=R, B=B, Y=D
# So the color mapping is: py222 -> magiccube
#   0 (U) -> W
#   1 (R) -> R  
#   2 (F) -> G
#   3 (D) -> Y
#   4 (L) -> O
#   5 (B) -> B

# magiccube face order in constructor string: U, L, F, R, B, D (4 stickers each)
# py222 sticker order: U(0-3), R(4-7), F(8-11), D(12-15), L(16-19), B(20-23)

# Apply R move to both and compare
test_py222 = py222.doMove(solved_py222.copy(), 3)  # R move index 3

test_magic = magiccube.Cube(2)
test_magic.rotate("R")

# Now let's build the correct conversion
# py222 color -> magiccube color char
color_map = {0: 'W', 1: 'R', 2: 'G', 3: 'Y', 4: 'O', 5: 'B'}

def py222_to_magiccube(state):
    # py222 order: U(0-3), R(4-7), F(8-11), D(12-15), L(16-19), B(20-23)
    # magiccube order: U(0-3), L(4-7), F(8-11), R(12-15), B(16-19), D(20-23)
    
    # Reorder: take py222 indices and put them in magiccube order
    magic_order = (
        list(range(0, 4)) +    # U stays at 0-3
        list(range(16, 20)) +  # L: py222 16-19 -> magic 4-7
        list(range(8, 12)) +   # F stays at 8-11
        list(range(4, 8)) +    # R: py222 4-7 -> magic 12-15
        list(range(20, 24)) +  # B stays at 16-19
        list(range(12, 16))    # D: py222 12-15 -> magic 20-23
    )
    return ''.join([color_map[state[i]] for i in magic_order])

# Test with solved state
converted = py222_to_magiccube(solved_py222)
test_cube = magiccube.Cube(2, converted)

converted_scrambled = py222_to_magiccube(py222_state)
cube = magiccube.Cube(2, converted_scrambled)
print(cube)
cube.rotate("F2 U2 R' F U F2")
print(cube)

       G  G             
       B  B             
 W  R  Y  O  W  Y  R  R 
 O  B  R  O  W  G  Y  B 
       W  G             
       Y  O             

       W  W             
       W  W             
 O  O  G  G  R  R  B  B 
 O  O  G  G  R  R  B  B 
       Y  Y             
       Y  Y             

W  W             
       W  W             
 O  O  G  G  R  R  B  B 
 O  O  G  G  R  R  B  B 
       Y  Y             
       Y  Y


In [57]:
str(cube).replace(" ", "").replace("\n", "")

'WWWWOOGGRRBBOOGGRRBBYYYY'

In [39]:
df.head()

Unnamed: 0,inputs,labels,puzzle_indices
0,"[1, 2, 4, 2, 4, 6, 2, 4, 5, 3, 1, 1, 5, 3, 4, ...","[6, 2, 8, 3, 4, 3, 4, 7]",0
1,"[4, 2, 5, 4, 5, 1, 6, 1, 1, 3, 2, 1, 6, 2, 4, ...","[4, 2, 4, 1, 7, 4, 2, 6, 9]",1
2,"[1, 3, 2, 1, 5, 4, 5, 4, 6, 6, 2, 3, 3, 1, 4, ...","[2, 7, 2, 8, 6, 8, 6, 7, 5, 7]",2
3,"[5, 4, 3, 2, 1, 6, 6, 3, 2, 6, 3, 5, 4, 1, 4, ...","[1, 5, 3, 5, 2]",3
4,"[6, 1, 2, 1, 5, 3, 3, 3, 1, 6, 3, 1, 2, 5, 4, ...","[1, 6, 3, 7, 2, 5, 8, 5]",4


In [40]:
df_inputs = pd.DataFrame(df['inputs'].tolist())
df_outputs = pd.DataFrame(df['labels'].tolist())

In [41]:
df_inputs.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,14,15,16,17,18,19,20,21,22,23
0,1,2,4,2,4,6,2,4,5,3,...,4,6,3,3,5,6,1,5,2,6
1,4,2,5,4,5,1,6,1,1,3,...,4,6,3,3,5,4,3,2,5,6
2,1,3,2,1,5,4,5,4,6,6,...,4,3,6,4,5,1,2,2,5,6
3,5,4,3,2,1,6,6,3,2,6,...,4,2,1,4,5,5,2,3,1,6
4,6,1,2,1,5,3,3,3,1,6,...,4,5,4,6,5,4,2,2,4,6


In [42]:
df_inputs.describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,14,15,16,17,18,19,20,21,22,23
count,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,...,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0,1000000.0
mean,3.262223,3.268607,3.26685,3.279013,3.285012,3.282481,3.287383,3.288214,3.305044,3.288518,...,4.0,3.298657,3.305002,3.283892,5.0,3.26581,3.301689,3.292794,3.271525,6.0
std,1.712483,1.709518,1.710551,1.693693,1.693733,1.689816,1.692781,1.689101,1.681339,1.695122,...,0.0,1.678368,1.675967,1.691911,0.0,1.712994,1.681882,1.692267,1.715373,0.0
min,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,4.0,1.0,1.0,1.0,5.0,1.0,1.0,1.0,1.0,6.0
25%,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,...,4.0,2.0,2.0,2.0,5.0,2.0,2.0,2.0,2.0,6.0
50%,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,...,4.0,3.0,3.0,3.0,5.0,3.0,3.0,3.0,3.0,6.0
75%,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,...,4.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,6.0
max,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,6.0,...,4.0,6.0,6.0,6.0,5.0,6.0,6.0,6.0,6.0,6.0


In [43]:
df_outputs.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10
0,6.0,2.0,8.0,3.0,4.0,3.0,4.0,7.0,,,
1,4.0,2.0,4.0,1.0,7.0,4.0,2.0,6.0,9.0,,
2,2.0,7.0,2.0,8.0,6.0,8.0,6.0,7.0,5.0,7.0,
3,1.0,5.0,3.0,5.0,2.0,,,,,,
4,1.0,6.0,3.0,7.0,2.0,5.0,8.0,5.0,,,


In [44]:
df_outputs.describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10
count,999984.0,999843.0,999218.0,996511.0,987034.0,909633.0,817512.0,688141.0,455701.0,100146.0,367.0
mean,4.009624,5.19316,4.731906,5.004187,4.900629,4.967111,4.955319,4.993251,4.985613,5.002856,5.321526
std,2.558364,2.49628,2.61965,2.556069,2.589572,2.571313,2.584165,2.579975,2.585093,2.579907,2.678624
min,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
25%,2.0,3.0,2.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0,3.0
50%,4.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.0
75%,6.0,7.0,7.0,7.0,7.0,7.0,7.0,7.0,7.0,7.0,8.0
max,9.0,9.0,9.0,9.0,9.0,9.0,9.0,9.0,9.0,9.0,9.0


In [52]:
count = df_outputs.iloc[:, 9].notna().sum()
print(count)

100146
