# Cube-Net ConvLSTM (Cross) Training Environment
This notebook contains the setup and training of the convolutional LSTM (ConvLSTM) network for the Cube-Net project. More specifically, it contains the training environment for the cross-solving network.

### cube Python Bindings

The Cube class provides Python bindings to the Rust program 'cube', which generates a dataset of scrambled cubes: both their tensor representations and the scrambles that produced them.

In [1]:
from cube_bindings import Cube

cube = Cube()


/home/holindauer/Projects/Cube-Net
Compiling the solution_verifier Rust program in release mode...
Compilation successful.


    Finished release [optimized] target(s) in 0.13s


### Example
Here is an example of how the Python bindings for the Rust program are used:

For the cross-solving model's training environment, the following operations are available:
- `cube.generate_data()` generates a batch of scrambled cube tensors and the scrambles that produced them
- `cube.is_solved()` and `cube.is_cross_solved()` check whether the cube or the cross is solved
- `cube.solve_cross()` finds a sequence of moves that solves the cross
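The scramble printed in the next cell's output is a space-separated sequence of 40 quarter turns of the faces R, L, U, D, F, and B, each optionally followed by `'` for a counter-clockwise turn, with no cancellation of adjacent moves (e.g. `B' B` appears). As a sketch of that format only, the hypothetical helper `random_scramble` below produces strings of the same shape in pure Python; it is not the Rust generator behind `cube.generate_data()`.

```python
import random
from typing import Optional

# Face letters as they appear in the scrambles printed by the notebook.
FACES = ["R", "L", "U", "D", "F", "B"]
# Each face can be turned clockwise ("R") or counter-clockwise ("R'").
MOVES = [face + suffix for face in FACES for suffix in ("", "'")]


def random_scramble(length: int, rng: Optional[random.Random] = None) -> str:
    """Return `length` uniformly random quarter turns as a space-separated string.

    Hypothetical helper: it mimics only the *format* of the scrambles returned
    by `cube.generate_data`. No move cancellation is attempted, so sequences
    such as "B' B" can occur, just as in the notebook's output.
    """
    rng = rng or random.Random()
    return " ".join(rng.choice(MOVES) for _ in range(length))


if __name__ == "__main__":
    print(random_scramble(40, random.Random(0)))
```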

In [2]:
cube_tensors, scrambles = cube.generate_data(batch_size=1, scramble_len=40)

cube_tensors.shape

solved = cube.is_solved(scrambles[0], "")
print(f"Is the cube solved? {'yes' if solved else 'no'} --- scramble: {scrambles[0]}")

cross_solution = cube.solve_cross(scrambles[0])
print(f"Cross solution: {cross_solution}")  

solved = cube.is_solved(scrambles[0], cross_solution)
print(f"Is the cube now solved? {'yes' if solved else 'no'}")

all_moves_so_far = " ".join([scrambles[0], cross_solution])
print(f"All moves so far: {all_moves_so_far}")

print(f"Is the cross now solved? {'yes' if cube.is_cross_solved(all_moves_so_far) else 'no'}")


Is the cube solved? no --- scramble: R L B' B U D' U' B' B L' U' F' R' L' D' U' F D F' D L' D' R' F D R' R' B U R' L' U' D' D' L' D B' D' B' U'
Cross solution: B R' U R F F B U U B' U L L F U U F' U R R U U U U R B' R' 
Is the cube now solved? no
All moves so far: R L B' B U D' U' B' B L' U' F' R' L' D' U' F D F' D L' D' R' F D R' R' B U R' L' U' D' D' L' D B' D' B' U' B R' U R F F B U U B' U L L F U U F' U R R U U U U R B' R' 
Is the cross now solved? yes


# Setting up the Training Environment

Here I am setting up the training environment for the cross solving model.

The model below accepts data in the format (batch_size, time_steps, channels, height, width, depth), where the height, width, and depth dimensions hold the Rubik's cube tensor representation. The channels dimension is initially a singleton dimension, but it is expanded by the convolutions.
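Judging from the raw tensor printed further down, the 5x5x5 grid looks like a 3x3x3 cube padded by one cell on each side. Sticker colours (values 1 to 6) sit on the outer shell, while the interior, edges, and corners of the padded grid are 0, giving 6 x 9 = 54 sticker cells. The NumPy sketch below builds that sticker mask and adds the batch, time-step, and channel axes. The mask rule is inferred from the printed data, and the colours are random placeholders rather than a real cube state.

```python
import numpy as np


def sticker_mask(size: int = 5) -> np.ndarray:
    """Return a boolean (size, size, size) mask of the sticker cells.

    Inferred rule: a cell holds a sticker when exactly one coordinate is on
    the outer layer (0 or size - 1) and the other two are strictly inside.
    """
    idx = np.arange(size)
    shell = (idx == 0) | (idx == size - 1)  # per-axis: outer layer
    inner = ~shell                          # per-axis: indices 1..size-2
    xs, ys, zs = np.meshgrid(shell, shell, shell, indexing="ij")
    xi, yi, zi = np.meshgrid(inner, inner, inner, indexing="ij")
    return (xs & yi & zi) | (xi & ys & zi) | (xi & yi & zs)


def to_model_input(state: np.ndarray) -> np.ndarray:
    """Add batch, time-step and channel axes: (5, 5, 5) -> (1, 1, 1, 5, 5, 5)."""
    return state[np.newaxis, np.newaxis, np.newaxis]


mask = sticker_mask()
# Placeholder colours 1..6 for the 54 sticker slots (not a real cube state).
rng = np.random.default_rng(0)
state = np.zeros((5, 5, 5), dtype=np.int64)
state[mask] = rng.integers(1, 7, size=int(mask.sum()))
batch = to_model_input(state)
print(int(mask.sum()), batch.shape)  # 54 (1, 1, 1, 5, 5, 5)
```

Stacking 32 such states along the first axis gives the (32, 1, 1, 5, 5, 5) shape reported by `data.shape` later in the notebook.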

The time_steps dimension is specific to this implementation: it grows as the solution to a scrambled cube is found. Previous cube states are stored along the time_steps dimension, up to a maximum number of steps.
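A minimal sketch of such a bounded history, assuming new states are appended at the end of the time axis and the oldest ones are dropped once a cap is reached. The function name `push_state` and the cap `max_steps` are hypothetical, since the notebook does not state the maximum; NumPy is used here, and the PyTorch equivalent would be `torch.cat([history, new_state.unsqueeze(1)], dim=1)[:, -max_steps:]`.

```python
import numpy as np


def push_state(history: np.ndarray, new_state: np.ndarray, max_steps: int) -> np.ndarray:
    """Append a cube state along the time axis, keeping only the newest steps.

    history:   (batch_size, time_steps, channels, height, width, depth)
    new_state: (batch_size, channels, height, width, depth)
    max_steps: hypothetical cap on the number of stored time steps
    """
    # Insert a time axis of length 1 and concatenate after the existing steps.
    stacked = np.concatenate([history, new_state[:, np.newaxis]], axis=1)
    # Drop the oldest states once the cap is exceeded.
    return stacked[:, -max_steps:]


# Start from a single time step (like the (32, 1, 1, 5, 5, 5) batches below).
history = np.zeros((2, 1, 1, 5, 5, 5))
for step in range(1, 8):
    history = push_state(history, np.full((2, 1, 5, 5, 5), step), max_steps=5)
print(history.shape)  # (2, 5, 1, 5, 5, 5)
```

After seven pushes with a cap of five, the history holds the states from steps 3 through 7, oldest first.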

ConvLSTMClassifier itself is only tasked with determining the best *single move* to make given the current cube state. However, the previous cube states stored along the time_steps dimension give the model context for how the current state was reached.
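For reference, here is a commonly used peephole-free ConvLSTM update. Whether `model_cross` uses exactly this form (or adds peephole terms) is an assumption. At each time step $t$, the cube state $X_t$ and the previous hidden state $H_{t-1}$ are convolved ($\ast$, here presumably a 3D convolution with `kernel_size=3`), and the gates decide how much of the earlier cube states the cell state $C_t$ retains:

$$
\begin{aligned}
i_t &= \sigma\left(W_{xi} \ast X_t + W_{hi} \ast H_{t-1} + b_i\right) \\
f_t &= \sigma\left(W_{xf} \ast X_t + W_{hf} \ast H_{t-1} + b_f\right) \\
o_t &= \sigma\left(W_{xo} \ast X_t + W_{ho} \ast H_{t-1} + b_o\right) \\
g_t &= \tanh\left(W_{xg} \ast X_t + W_{hg} \ast H_{t-1} + b_g\right) \\
C_t &= f_t \odot C_{t-1} + i_t \odot g_t \\
H_t &= o_t \odot \tanh\left(C_t\right)
\end{aligned}
$$

Here $\sigma$ is the logistic sigmoid and $\odot$ the element-wise product. The later output shape (32, 64, 5, 5, 5) and `num_output_features = 64 * 5 * 5 * 5` are consistent with the last layer's final hidden state $H_T$ being flattened and mapped to the 13 class scores.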

In [3]:
from model_cross import ConvLSTMClassifier, ConvLSTM
import torch
 
# Initialize the ConvLSTM
conv_lstm = ConvLSTM(input_channels=1, hidden_channels=[8, 16, 32, 32, 64], kernel_size=3)

num_output_features = 64 * 5 * 5 * 5  # hidden_channels[-1] * (5 * 5 * 5): flattened final hidden state

# Initialize ConvLSTMClassifier
classifier = ConvLSTMClassifier(conv_lstm, num_output_features, num_classes=13)

# Random example input: (batch_size, time_steps, channels, height, width, depth)
x = torch.randn(1, 5, 1, 5, 5, 5)
output = classifier(x)             # class scores, shape (batch_size, num_classes)
prediction = output.argmax(dim=1)  # index of the highest-scoring move per sample

# Print shapes for verification
print('Input size:', x.shape)
print('Output size:', output.shape)
print('Prediction size:', prediction.shape)

Input size: torch.Size([1, 5, 1, 5, 5, 5])
Output size: torch.Size([1, 13])
Prediction size: torch.Size([1])


In [4]:
x = torch.randn(32, 5, 1, 5, 5, 5)

output = conv_lstm(x)
print(f'Conv LSTM output shape: {output.shape}')

classifier_output = classifier(x)
print(f'Classifier output shape: {classifier_output.shape}')

Conv LSTM output shape: torch.Size([32, 64, 5, 5, 5])
Classifier output shape: torch.Size([32, 13])
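The ConvLSTM output shape (32, 64, 5, 5, 5) shows that the convolutions preserve the 5x5x5 spatial size, so the hand-written `num_output_features` can be derived from the configuration instead. The helper `flattened_features` below is a hypothetical sketch of that arithmetic, assuming the classifier flattens only the last layer's hidden state.

```python
from math import prod
from typing import Sequence


def flattened_features(hidden_channels: Sequence[int], spatial_dims: Sequence[int]) -> int:
    """Number of features in the flattened final ConvLSTM hidden state.

    Assumes the convolutions preserve the spatial size and that only the
    last layer's hidden state is passed to the classifier head.
    """
    return hidden_channels[-1] * prod(spatial_dims)


num_output_features = flattened_features([8, 16, 32, 32, 64], (5, 5, 5))
print(num_output_features)  # 8000
```

Alternatively, the value can be read off a dummy forward pass, e.g. `conv_lstm(torch.zeros(1, 1, 1, 5, 5, 5))[0].numel()`, which stays correct if the ConvLSTM configuration changes.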


In [24]:
# print total classifier params
print(f"Total classifier params: {sum(p.numel() for p in classifier.parameters())}")

Total classifier params: 1261493


In [5]:
from train_cross import Trainer, TrainConfig
from early_stopping import EarlyStopping
import torch

config = TrainConfig(
    scramble_len=40,
    epochs=1,
    val_num_batches=10,
    batch_size=32,
    lr=0.001,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    optimizer=torch.optim.Adam,
    early_stopping=EarlyStopping(patience=10)
)


trainer = Trainer(cube, config, classifier)

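The config above passes `EarlyStopping(patience=10)` from the project's `early_stopping` module, whose interface is not shown in this notebook. As a hypothetical stand-in illustrating what "patience" usually means, the class below stops training once the validation loss has failed to improve for `patience` consecutive checks; its name and `step` method are assumptions, not the project's API.

```python
class PatienceEarlyStopping:
    """Minimal patience-based early stopping (hypothetical sketch).

    Signals a stop when the validation loss has not improved by more than
    `min_delta` for `patience` consecutive checks.
    """

    def __init__(self, patience: int = 10, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.bad_checks = 0

    def step(self, val_loss: float) -> bool:
        """Record a validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            # Improvement: remember the new best and reset the counter.
            self.best_loss = val_loss
            self.bad_checks = 0
        else:
            # No improvement: count another unsuccessful check.
            self.bad_checks += 1
        return self.bad_checks >= self.patience


stopper = PatienceEarlyStopping(patience=2)
for loss in [1.0, 0.5, 0.6, 0.7]:
    if stopper.step(loss):
        print(f"stopping; best validation loss {stopper.best_loss}")  # best 0.5
        break
```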


In [6]:
data, scrambles = trainer.get_batch()

In [7]:
data.shape

torch.Size([32, 1, 1, 5, 5, 5])

In [8]:
data

tensor([[[[[[0, 0, 0, 0, 0],
            [0, 1, 2, 5, 0],
            [0, 3, 1, 1, 0],
            [0, 1, 1, 4, 0],
            [0, 0, 0, 0, 0]],

           [[0, 5, 1, 4, 0],
            [2, 0, 0, 0, 1],
            [4, 0, 0, 0, 4],
            [2, 0, 0, 0, 6],
            [0, 3, 3, 3, 0]],

           [[0, 2, 2, 1, 0],
            [6, 0, 0, 0, 5],
            [3, 0, 0, 0, 5],
            [2, 0, 0, 0, 5],
            [0, 5, 4, 6, 0]],

           [[0, 1, 3, 5, 0],
            [4, 0, 0, 0, 4],
            [3, 0, 0, 0, 4],
            [6, 0, 0, 0, 5],
            [0, 2, 5, 6, 0]],

           [[0, 0, 0, 0, 0],
            [0, 3, 6, 6, 0],
            [0, 2, 6, 6, 0],
            [0, 3, 4, 2, 0],
            [0, 0, 0, 0, 0]]]]],




        [[[[[0, 0, 0, 0, 0],
            [0, 3, 1, 1, 0],
            [0, 2, 1, 2, 0],
            [0, 5, 5, 2, 0],
            [0, 0, 0, 0, 0]],

           [[0, 2, 4, 4, 0],
            [6, 0, 0, 0, 3],
            [3, 0, 0, 0, 1],
            [2, 0, 0, 0, 

In [9]:
normalized_data = trainer.nomalize(data)
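Comparing the raw tensor above with the normalized tensor printed in the next cell, every value x appears to map to (x + 1) / 7 (0 to 0.1429, 1 to 0.2857, 6 to 1.0000). The function below reproduces that mapping as a sketch inferred from the printed values; it is not taken from `trainer.nomalize`, and the names `normalize_codes` and `NUM_CODES` are hypothetical.

```python
import numpy as np

NUM_CODES = 7  # cell values 0 (empty) through 6 (sticker colours)


def normalize_codes(codes: np.ndarray) -> np.ndarray:
    """Map integer codes 0..6 to (code + 1) / 7, i.e. into [1/7, 1]."""
    return (codes.astype(np.float32) + 1.0) / NUM_CODES


print(np.round(normalize_codes(np.array([0, 1, 2, 5, 6])), 4))
```

One consequence of this mapping is that empty cells become 1/7 rather than 0.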

In [10]:
print(normalized_data)

tensor([[[[[[0.1429, 0.1429, 0.1429, 0.1429, 0.1429],
            [0.1429, 0.2857, 0.4286, 0.8571, 0.1429],
            [0.1429, 0.5714, 0.2857, 0.2857, 0.1429],
            [0.1429, 0.2857, 0.2857, 0.7143, 0.1429],
            [0.1429, 0.1429, 0.1429, 0.1429, 0.1429]],

           [[0.1429, 0.8571, 0.2857, 0.7143, 0.1429],
            [0.4286, 0.1429, 0.1429, 0.1429, 0.2857],
            [0.7143, 0.1429, 0.1429, 0.1429, 0.7143],
            [0.4286, 0.1429, 0.1429, 0.1429, 1.0000],
            [0.1429, 0.5714, 0.5714, 0.5714, 0.1429]],

           [[0.1429, 0.4286, 0.4286, 0.2857, 0.1429],
            [1.0000, 0.1429, 0.1429, 0.1429, 0.8571],
            [0.5714, 0.1429, 0.1429, 0.1429, 0.8571],
            [0.4286, 0.1429, 0.1429, 0.1429, 0.8571],
            [0.1429, 0.8571, 0.7143, 1.0000, 0.1429]],

           [[0.1429, 0.2857, 0.5714, 0.8571, 0.1429],
            [0.7143, 0.1429, 0.1429, 0.1429, 0.7143],
            [0.5714, 0.1429, 0.1429, 0.1429, 0.7143],
            [1.0000, 0