# Rubik's Cube Neural Solver - Colab Training

This notebook allows you to train the neural network using Google Colab's GPU or TPU.
It automatically clones the repository code so you can run it directly from GitHub.

In [None]:
# @title 1. Setup Environment
import os
import sys

# Check if running in Colab
try:
    import google.colab
    IN_COLAB = True
    print("Running in Google Colab")
except ImportError:
    IN_COLAB = False
    print("Running locally")

if IN_COLAB:
    # Mount Google Drive (optional, for saving weights permanently)
    from google.colab import drive
    drive.mount('/content/drive')
    
    # Clone repository if modules are missing or to ensure freshness
    if not os.path.exists('NeuroRubik'):
        print("Cloning repository...")
        !git clone https://github.com/Kiwiabacaxi/NeuroRubik.git
    else:
        print("Repository already exists. Pulling latest changes...")
        %cd NeuroRubik
        !git pull
        %cd ..

    # Add python directory to path
    repo_path = '/content/NeuroRubik/python'
    if repo_path not in sys.path:
        sys.path.append(repo_path)
    
    # Change working directory to python folder so imports work relative to it
    if os.path.exists(repo_path):
        os.chdir(repo_path)
        print(f"Changed directory to {os.getcwd()}")
    
    # Install requirements
    !pip install -q tqdm
    
    # Note: We rely on pre-installed environment for TPU/GPU
else:
    # Local setup
    if os.path.basename(os.getcwd()) != 'python':
        if os.path.exists('python'):
            os.chdir('python')
        elif os.path.exists('../python'):
            os.chdir('../python')

In [None]:
# @title 2. Import Modules
import torch
import numpy as np
import json
import matplotlib.pyplot as plt
from IPython.display import clear_output
from tqdm.notebook import tqdm

print(f"PyTorch Version: {torch.__version__}")

# DEVICE SENSING (GPU -> TPU -> CPU)
# Sets the notebook-wide DEVICE. When the TPU branch succeeds, the `xm` module
# alias also stays in the notebook namespace; the training cell checks for it.
if not torch.cuda.is_available():
    # No CUDA device: probe for a TPU through torch_xla, otherwise use the CPU.
    try:
        import torch_xla.core.xla_model as xm
        DEVICE = xm.xla_device()
        print(f"TPU Device Detected: {DEVICE}")
        # The ANSI red escape opened on the first line is reset at the end of the second.
        print("\033[91mWARNING: You are using TPU. For Genetic Algorithms (which have frequent CPU<->Accelerator transfers), TPU might be SLOWER than GPU.")
        print("RECOMMENDATION: Change Runtime type to 'T4 GPU' for faster performance on this specific task.\033[0m")
    except ImportError:
        print("WARNING: GPU not available and TPU libraries (torch_xla) not found.")
        print("Using CPU (Training will be slow).")
        DEVICE = torch.device('cpu')
else:
    print(f"GPU Available: {torch.cuda.get_device_name(0)}")
    DEVICE = torch.device('cuda')

# Import project modules (resolved relative to the repository's `python`
# folder, which the setup cell puts on sys.path / makes the working directory).
try:
    from cube.cube_state import CubeState, MOVES
    from genetic.evolution import GeneticAlgorithm
    from genetic.individual import Individual
    from genetic.fitness import BatchFitnessEvaluator
    from neural.network import CubeSolverNetwork
    from train import save_checkpoint, create_output_dir
except ImportError as e:
    print(f"Error importing modules: {e}")
    print("Make sure the repository is cloned and you are in the correct directory.")
    # Re-raise so "Run all" stops at this cell. Continuing would only surface
    # later as a NameError in cells that use these names.
    raise
else:
    print("Modules imported successfully!")

In [None]:
# @title 3. Configuration
# Training Parameters
# Each `# @param` annotation below exposes the value as a Colab form field.
# Passed to GeneticAlgorithm as population_size.
POPULATION = 1000 # @param {type:"integer"}
# Number of iterations the training loop runs in this session.
GENERATIONS = 500 # @param {type:"integer"}
# Passed to GeneticAlgorithm as elitism_count.
ELITISM = 50 # @param {type:"integer"}
# Passed to GeneticAlgorithm as mutation_rate.
MUTATION_RATE = 0.15 # @param {type:"number"}
# Passed to GeneticAlgorithm as mutation_strength.
MUTATION_STRENGTH = 0.25 # @param {type:"number"}

# Curriculum
# Starting scramble_depth for BatchFitnessEvaluator; the initialization cell
# overwrites it when a loaded checkpoint carries an evaluator scramble depth.
INITIAL_DEPTH = 1 # @param {type:"integer"}
# NOTE(review): MAX_DEPTH is not referenced by any cell visible in this
# notebook - confirm whether it should be passed to the evaluator.
MAX_DEPTH = 20 # @param {type:"integer"}
# Passed to BatchFitnessEvaluator as num_test_cubes and used as the divisor
# when the training loop computes the solve rate.
TEST_CUBES = 100 # @param {type:"integer"}

# Architecture
# Passed as hidden_sizes to both CubeSolverNetwork and BatchFitnessEvaluator.
HIDDEN_LAYERS = (512, 512, 512, 256, 128) # @param {type:"raw"}

# Paths
# Parent directory handed to create_output_dir. If it cannot be created, the
# initialization cell falls back to a local "weights_colab" directory.
BASE_OUTPUT_DIR = "/content/drive/MyDrive/CUBE/cube_weights" # @param {type:"string"}

# Resume Training (Optional)
# Paste the full path to a checkpoint FOLDER (e.g., /content/drive/.../run_20260128_022005)
# The folder is expected to contain a "checkpoint.json" file.
# Leave empty to start fresh.
# NOTE(review): the default below points at one specific run folder rather
# than being empty. When that folder has no checkpoint.json, the
# initialization cell prints a warning and starts fresh.
LOAD_CHECKPOINT_PATH = "/content/drive/MyDrive/CUBE/cube_weights/run_20260128_022005" # @param {type:"string"}

In [None]:
# @title 4. Initialize Training

# Create Output Directory (with Versioning)
# create_output_dir (imported from the project's train module) returns the
# path used for this run's results. If the configured base directory cannot
# be created, fall back to a local directory relative to the working directory.
try:
    os.makedirs(BASE_OUTPUT_DIR, exist_ok=True)
    output_path = create_output_dir(BASE_OUTPUT_DIR)
    print(f"Saving new results to: {output_path}")
except OSError as e:
    print(f"Error creating directory {BASE_OUTPUT_DIR}: {e}")
    BASE_OUTPUT_DIR = "weights_colab"
    # Report the directory that is actually used; the message previously
    # named "weights/", which is not where the results go.
    print(f"Switching to local {BASE_OUTPUT_DIR}/ directory")
    output_path = create_output_dir(BASE_OUTPUT_DIR)
    print(f"Saving results to: {output_path}")

# Build the network first: its weight count defines the genome length
# that every individual in the genetic algorithm carries.
network = CubeSolverNetwork(device=DEVICE, hidden_sizes=HIDDEN_LAYERS)
genome_size = network.get_weight_count()
print(f"Network created with {genome_size:,} weights")

# Collect the GA hyper-parameters in one mapping, then construct the GA from it.
ga_settings = dict(
    population_size=POPULATION,
    genome_size=genome_size,
    mutation_rate=MUTATION_RATE,
    mutation_strength=MUTATION_STRENGTH,
    elitism_count=ELITISM,
)
ga = GeneticAlgorithm(**ga_settings)

# The population is always created before any checkpoint is applied.
ga.initialize_population()

# Load Checkpoint if specified
# Defaults used when no checkpoint is loaded: generation numbering starts at 0.
current_generation_offset = 0
best_fitness_loaded = 0.0

if LOAD_CHECKPOINT_PATH:
    # LOAD_CHECKPOINT_PATH is a run FOLDER; the state is in its checkpoint.json.
    checkpoint_file = os.path.join(LOAD_CHECKPOINT_PATH, "checkpoint.json")
    if os.path.exists(checkpoint_file):
        print(f"Process resuming from: {checkpoint_file}")
        with open(checkpoint_file, 'r') as f:
            checkpoint = json.load(f)

        # Load Best Genome
        best_genome = np.array(checkpoint["best_genome"], dtype=np.float32)
        best_fitness_loaded = checkpoint["best_fitness"]

        # A genome saved with a different HIDDEN_LAYERS setting has a different
        # length. Fail here with a clear message instead of injecting an
        # incompatible individual into the population.
        if best_genome.size != genome_size:
            raise ValueError(
                f"Checkpoint genome has {best_genome.size:,} weights but the current "
                f"network expects {genome_size:,}. Make sure HIDDEN_LAYERS matches the "
                f"architecture used for {checkpoint_file}."
            )

        # Inject into GA
        ga.best_ever = Individual(genome=best_genome)
        ga.best_ever.fitness = best_fitness_loaded

        # Seed population with best genome (Elitism)
        # This ensures we don't start from zero, but still have variation
        ga.population[0] = ga.best_ever.clone()

        # Restore training state
        current_generation_offset = checkpoint.get("generation", 0)
        print(f"Successfully loaded! Resuming from Gen {current_generation_offset} (Best Fitness: {best_fitness_loaded:.2f})")

        # Try to restore curriculum depth if available. This overwrites
        # INITIAL_DEPTH, which the evaluator below is constructed with.
        if "evaluator" in checkpoint:
            INITIAL_DEPTH = checkpoint["evaluator"].get("scramble_depth", INITIAL_DEPTH)
            print(f"Restored Scramble Depth: {INITIAL_DEPTH}")
    else:
        print(f"WARNING: Checkpoint file not found at {checkpoint_file}. Starting fresh.")

# Fitness evaluator: starts at INITIAL_DEPTH (possibly restored from a
# checkpoint above) and uses the same layer sizes as the network.
evaluator = BatchFitnessEvaluator(
    device=DEVICE,
    hidden_sizes=HIDDEN_LAYERS,
    max_steps=50,
    scramble_depth=INITIAL_DEPTH,
    num_test_cubes=TEST_CUBES,
)

# One list per tracked metric; the training loop appends one value per generation.
training_history = {metric: [] for metric in ('fitness', 'depth', 'solved_rate')}

In [None]:
# @title 5. Run Training Loop

print(f"Starting training for {GENERATIONS} generations...")
print("Tip: If progress takes too long to appear, you might be using TPU which is slow for this specific task due to data transfer overhead. Prefer T4 GPU.")

# TQDM Progress Bar
# Created outside the try block so the finally clause can always close it.
pbar = tqdm(range(GENERATIONS), desc="Training", unit="gen")

# Ensures a failing XLA step marker is reported once instead of silently
# ignored or repeated every generation.
xla_sync_warned = False

try:
    for i in pbar:
        # Absolute generation number: continues from a loaded checkpoint.
        gen = current_generation_offset + i
        ga.generation = gen

        # Regenerate test cubes occasionally
        if i % 5 == 0:
            evaluator.regenerate_test_cubes()

        # Evolve
        ga.evolve_generation(evaluator)

        # Stats
        # NOTE(review): assumes evolve_generation leaves the population ordered
        # best-first, so index 0 is the best individual - TODO confirm.
        best = ga.population[0]
        solve_rate = best.solved_count / TEST_CUBES

        # Update curriculum
        depth_changed = evaluator.update_difficulty(solve_rate)
        curriculum = evaluator.get_status()
        current_depth = curriculum['current_depth']

        # Store history
        training_history['fitness'].append(best.fitness)
        training_history['depth'].append(current_depth)
        training_history['solved_rate'].append(solve_rate)

        # Update Progress Bar Text
        pbar.set_postfix({
            'Depth': current_depth,
            'Best': f"{best.fitness:.1f}",
            'Solved': f"{best.solved_count}/{TEST_CUBES}",
            'Rate': f"{solve_rate*100:.0f}%"
        })

        # Visualization (Updates Graph)
        if (i > 0 and i % 10 == 0) or depth_changed:
            clear_output(wait=True)
            # Re-display pbar after clear_output to keep it visible (Colab quirk)
            display(pbar.container)

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

            # Fitness Plot
            ax1.plot(training_history['fitness'], label='Best Fitness', color='blue')
            ax1.set_title(f'Gen {gen} - Fitness Progress')
            ax1.set_xlabel('Generation')
            ax1.set_ylabel('Fitness')
            ax1.grid(True)

            # Curriculum Plot
            ax2.plot(training_history['depth'], label='Scramble Depth', color='red')
            ax2.set_title(f'Curriculum Level (Current: {current_depth})')
            ax2.set_xlabel('Generation')
            ax2.set_ylabel('Depth')
            ax2.grid(True)

            plt.show()
            # Release the figure explicitly so figures do not pile up over a
            # long run, whatever the plotting backend.
            plt.close(fig)

        # Save Checkpoint
        if i > 0 and i % 20 == 0:
            save_checkpoint(ga, network, output_path, evaluator, HIDDEN_LAYERS)

        # TPU specific: Mark step to trigger execution.
        # `xm` is only present in the notebook namespace when the device
        # detection cell imported torch_xla successfully.
        if 'xm' in globals():
            try:
                xm.mark_step()
            except Exception as exc:
                # Still best-effort, but the failure is now visible (once).
                if not xla_sync_warned:
                    print(f"WARNING: xm.mark_step() failed ({exc!r}); continuing without explicit XLA step marking.")
                    xla_sync_warned = True

except KeyboardInterrupt:
    print("Training interrupted")
finally:
    # Closing is idempotent, so this is safe after a normal finish too.
    pbar.close()

# Final Save (also runs after a manual interrupt, so progress is kept)
save_checkpoint(ga, network, output_path, evaluator, HIDDEN_LAYERS)
print(f"Done! saved to {output_path}")