## Exploring Neural Network Solutions

___


This notebook demonstrates three of my unsuccessful attempts to apply statistical learning to the Rubik's cube. The best results come from the second attempt (below), which, incidentally, is the only version that formats its inputs as one-hot vectors. The networks used are five layers deep and fully connected; they are policy networks with a softmax output across the 18 possible moves.
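To make the policy output concrete, here is a minimal sketch of a softmax over 18 moves and a one-hot target label. It assumes the 18 moves are 6 faces times 3 turn angles (90, -90, 180) and uses face letters as hypothetical labels; the actual move table lives in `VectorCube.MOVES` and may be ordered or encoded differently.

```python
import math

# Assumed move encoding (hypothetical): 6 faces x 3 turn angles = 18 moves.
FACES = ["U", "D", "L", "R", "F", "B"]
ANGLES = [90, -90, 180]
MOVES = [(face, angle) for face in FACES for angle in ANGLES]


def one_hot(index, size):
    """Return a list of `size` zeros with a single 1 at `index`."""
    vec = [0] * size
    vec[index] = 1
    return vec


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pick_move(logits):
    """Greedy policy: return the most probable move and its probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return MOVES[best], probs[best]


# One-hot training label for the move ("R", -90).
label = one_hot(MOVES.index(("R", -90)), len(MOVES))
```

Taking the argmax of the softmax output is the same greedy selection that the prediction cells below perform with `np.argmax`.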

Note that I'm using preprocessed inputs in each demonstration (though not explicitly demonstrated here, similarly unsatisfactory results occur with raw facelet positions):
1. first the dance heuristic
2. then a squared difference measure between adjacent facelets (this is akin to my entropy/order heuristic)
3. finally a dot-product measure between adjacent facelets.
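The two adjacency measures in the list above can be sketched as follows. This is an illustration under stated assumptions, not the `NNSolver` implementation: it assumes each color is represented by the unit normal of its home face (the `COLOR_VECTORS` table is hypothetical) and that only pairs within a single face are compared.

```python
# Assumption: each color is the unit normal vector of its home face.
COLOR_VECTORS = {
    "white": (0, 0, 1), "yellow": (0, 0, -1),
    "red": (1, 0, 0), "orange": (-1, 0, 0),
    "green": (0, 1, 0), "blue": (0, -1, 0),
}


def adjacent_pairs(n=3):
    """Index pairs of horizontally and vertically adjacent cells on an n x n face."""
    pairs = []
    for r in range(n):
        for c in range(n):
            if c + 1 < n:
                pairs.append(((r, c), (r, c + 1)))
            if r + 1 < n:
                pairs.append(((r, c), (r + 1, c)))
    return pairs


def dot(a, b):
    """Dot product: 1 for equal colors, -1 for opposite faces, 0 otherwise."""
    return sum(x * y for x, y in zip(a, b))


def squared_difference(a, b):
    """Squared Euclidean distance: 0 for equal colors, 4 for opposite, 2 otherwise."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def face_features(face, measure):
    """Apply `measure` to every adjacent facelet pair on one face."""
    return [
        measure(COLOR_VECTORS[face[r1][c1]], COLOR_VECTORS[face[r2][c2]])
        for (r1, c1), (r2, c2) in adjacent_pairs(len(face))
    ]


solved_face = [["red"] * 3 for _ in range(3)]
mixed_face = [["red", "orange", "red"],
              ["green", "red", "red"],
              ["red", "red", "blue"]]
```

A 3x3 face has 12 adjacent pairs, so six faces give 72 values. That coincides with the `input_dim=72` passed to the dot-product model later in this notebook, though whether `NNSolver` builds its features this way is not confirmed here.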

My training and validation datasets are generated with a simple "reverse-scramble" methodology that estimates the appropriate next move for a given cube permutation. For details of the network architecture and dataset construction, see [NNSolver](../rubiks/solver/NNSolver.py).
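As a rough sketch of the reverse-scramble idea (an assumption about its general shape, not a copy of `NNSolver.generate_dataset`): scramble outward from the solved state and label each intermediate state with the move that undoes the most recent turn. The move encoding and the stand-in `apply_move` below are hypothetical; a real cube model would replace them.

```python
import random

# Assumed move encoding (hypothetical): 6 faces x 3 turn angles = 18 moves.
FACES = ["U", "D", "L", "R", "F", "B"]
ANGLES = [90, -90, 180]
MOVES = [(f, a) for f in FACES for a in ANGLES]


def inverse(move):
    """A quarter turn is undone by the opposite quarter turn; a half turn undoes itself."""
    face, angle = move
    return (face, angle if angle == 180 else -angle)


def reverse_scramble_samples(solved_state, apply_move, depth, rng):
    """Scramble from solved; pair each state with the move that undoes the last turn."""
    samples = []
    state = solved_state
    for _ in range(depth):
        move = rng.choice(MOVES)
        state = apply_move(state, move)
        samples.append((state, inverse(move)))
    return samples


# Stand-in for a real cube: the "state" is simply the tuple of moves applied so far.
def apply_move(state, move):
    return state + (move,)


samples = reverse_scramble_samples((), apply_move, depth=5, rng=random.Random(0))
```

The label is only an estimate of a good next move. Random scrambles can contain turns that cancel or commute, so the inverse of the last scramble move is not necessarily on a shortest path back to solved.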

For good measure, I demonstrate 20-move solution predictions from the trained networks. As you can see, the networks tend to fall into ruts of redundant move sequences (e.g. rotating one side back and forth by 90 degrees: -90, 90, -90, 90, ad nauseam). Such rutted predictions are arguably even less effective than totally random move selection.
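One way to quantify such ruts is to count how many predicted moves immediately cancel the move before them. The sketch below assumes moves are `(face, angle)` pairs with angles in degrees, as the captions in the prediction cells suggest; the color strings are placeholders for whatever `VectorCube.MOVES` actually stores.

```python
def undoes(prev, move):
    """True when `move` exactly reverses `prev`: same face, net rotation of zero."""
    (face_a, angle_a), (face_b, angle_b) = prev, move
    return face_a == face_b and (angle_a + angle_b) % 360 == 0


def count_undo_steps(moves):
    """Count moves that immediately cancel the move before them."""
    return sum(undoes(a, b) for a, b in zip(moves, moves[1:]))


# 20 predicted moves that only rock one face back and forth.
rut = [("red", 90), ("red", -90)] * 10
# A short sequence with no immediate cancellations.
varied = [("red", 90), ("blue", 90), ("white", 180), ("green", -90)]

print(count_undo_steps(rut), count_undo_steps(varied))
```

A 20-move sequence has 19 consecutive pairs, so a count near 19 means the prediction made essentially no progress.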

___


In [None]:
import sys

# This is for managing relative imports from the notebook
if '..' not in sys.path:
    sys.path.append('..')

import keras
import numpy as np
import matplotlib.pyplot as plt

from rubiks.model.CubeView import CubeView
from rubiks.model.DirectCube import DirectCube
from rubiks.model.VectorCube import VectorCube, color_name

from rubiks.solver.NNSolver import NNSolver
from rubiks.solver.DirectSolver import DirectSolver

In [None]:
# This first (and my most recent) attempt explores the viability of a statistical
# correlation between the dance heuristic and move sequences. Results are pretty
# dismal: roughly 20% training accuracy and 17% validation accuracy.

nnsolver = NNSolver()

# Dance version
X, Yoh, Xval, Yohval = nnsolver.generate_dataset(nnsolver.create_dance_input)
policy_model = nnsolver.create_policy_model(input_dim=216)

In [None]:
history = policy_model.fit(X, Yoh, validation_data=(Xval, Yohval), callbacks=nnsolver.early_stop, batch_size=256, epochs=25)

In [None]:
# List all data in history
print(history.history.keys())

# Plot with respect to accuracy
plt.figure(1)
plt.plot(history.history['categorical_accuracy'])
plt.plot(history.history['val_categorical_accuracy'])
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['train', 'validate'], loc='upper left')

# Plot with respect to loss
plt.figure(2)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['train', 'validate'], loc='upper left')

In [None]:
# This cell depicts 20 moves predicted by the dance model just trained.

cube = DirectCube().scramble()
view = CubeView(cube)
view.push_snapshot()

for i in range(20):
    dance_input = nnsolver.create_dance_input(cube).reshape((1,216))
    mv = VectorCube.MOVES[np.argmax(policy_model.predict(dance_input))]
    caption = f"{color_name(mv[0])}({mv[1]}) : {DirectSolver.generate_hstate(cube)}"
    cube.rotate(mv)
    view.push_snapshot(caption=caption)
    
view.draw_snapshots()

In [None]:
# Diff version (generates input in one-hot format)
Xoh, Yoh, Xohval, Yohval = nnsolver.generate_dataset(nnsolver.create_diffsq_input)
policy_model = nnsolver.create_policy_model(input_dim=4320)

In [None]:
history = policy_model.fit(Xoh, Yoh, validation_data=(Xohval, Yohval), callbacks=nnsolver.early_stop, batch_size=256, epochs=10)

In [None]:
# List all data in history
print(history.history.keys())

# Plot with respect to accuracy
plt.figure(1)
plt.plot(history.history['categorical_accuracy'])
plt.plot(history.history['val_categorical_accuracy'])
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['train', 'validate'], loc='upper left')

# Plot with respect to loss
plt.figure(2)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['train', 'validate'], loc='upper left')

In [None]:
# This cell depicts 20 moves predicted by the squared-difference model just trained.

cube = DirectCube().scramble()
view = CubeView(cube)
view.push_snapshot()

for i in range(20):
    oh_input = nnsolver.create_diffsq_input(cube).reshape((1,4320))
    mv = VectorCube.MOVES[np.argmax(policy_model.predict(oh_input))]
    caption = f"{color_name(mv[0])}({mv[1]}) : {DirectSolver.generate_hstate(cube)}"
    cube.rotate(mv)
    view.push_snapshot(caption=caption)
    
view.draw_snapshots()

In [None]:
# Dot version - results meager at best, ~30% accuracy tops
X, Yoh, Xval, Yohval = nnsolver.generate_dataset(nnsolver.create_dot_input)
policy_model = nnsolver.create_policy_model(input_dim=72)

In [None]:
history = policy_model.fit(X, Yoh, validation_data=(Xval, Yohval), callbacks=nnsolver.early_stop, batch_size=256, epochs=25)

In [None]:
# List all data in history
print(history.history.keys())

# Plot with respect to accuracy
plt.figure(1)
plt.plot(history.history['categorical_accuracy'])
plt.plot(history.history['val_categorical_accuracy'])
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['train', 'validate'], loc='upper left')

# Plot with respect to loss
plt.figure(2)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['train', 'validate'], loc='upper left')

In [None]:
# This cell depicts 20 moves predicted by the dot-product model just trained.

cube = DirectCube().scramble()
view = CubeView(cube)
view.push_snapshot()

for i in range(20):
    dot_input = nnsolver.create_dot_input(cube).reshape((1,72))
    mv = VectorCube.MOVES[np.argmax(policy_model.predict(dot_input))]
    caption = f"{color_name(mv[0])}({mv[1]}) : {DirectSolver.generate_hstate(cube)}"
    cube.rotate(mv)
    view.push_snapshot(caption=caption)
    
view.draw_snapshots()