# Spot Robot Training on Google Colab

**GPU:** T4 (free tier, enough for 500k steps)  
**Time estimate:** ~30-60 min for 500k steps  

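## Cell Order:
1. Install dependencies
2. Clone repo
3. Mount Google Drive
4. Verify GPU
5. Check training parameters
6. Train
7. Resume training (only if Colab disconnected)
8. Evaluate & record video
9. Play video in notebook
10. Training stats (lists the models and videos saved to Google Drive)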

## Cell 1: Install Dependencies

In [None]:
# Install all required packages
!pip install gymnasium mujoco stable-baselines3[extra] torch tensorboard imageio[ffmpeg] -q
print("All packages installed!")

## Cell 2: Clone Your Repo

In [None]:
import os

# Clone the repo (skip if already cloned)
if not os.path.exists('/content/claude_code'):
    !git clone https://github.com/Raj9408612613/claude_code.git /content/claude_code
    print("Repo cloned!")
else:
    # Pull latest changes
    !cd /content/claude_code && git pull
    print("Repo already exists, pulled latest.")

# Verify files exist
!ls /content/claude_code/files/

## Cell 3: Mount Google Drive (saves survive disconnects)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Output locations on Google Drive; later cells read these three names.
SAVE_DIR = '/content/drive/MyDrive/robot_training/models'
LOG_DIR = '/content/drive/MyDrive/robot_training/logs'
VIDEO_DIR = '/content/drive/MyDrive/robot_training/videos'

# Create every output folder first (a no-op when it already exists) ...
for _output_dir in (SAVE_DIR, LOG_DIR, VIDEO_DIR):
    os.makedirs(_output_dir, exist_ok=True)

# ... then report where each kind of artefact will be written.
for _label, _output_dir in (
    ("Models will save to: ", SAVE_DIR),
    ("Logs will save to:   ", LOG_DIR),
    ("Videos will save to: ", VIDEO_DIR),
):
    print(f"{_label}{_output_dir}")

## Cell 4: Verify GPU

In [None]:
import torch

# Report the CUDA device that training will use, or explain how to enable one.
if torch.cuda.is_available():
    gpu_props = torch.cuda.get_device_properties(0)
    # The attribute is `total_memory` (bytes); `total_mem` does not exist and
    # raised AttributeError as soon as a GPU was actually present.
    gpu_mem = gpu_props.total_memory / 1e9
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)")
    print("Ready to train!")
else:
    print("WARNING: No GPU detected!")
    print("Go to: Runtime > Change runtime type > T4 GPU")

## Cell 5: Check Your Training Parameters

These come from `files/config.py`. Verify they match your budget:

In [None]:
import sys
sys.path.insert(0, '/content/claude_code/files')

from config import ENV_CONFIG, PPO_CONFIG, TRAINING_CONFIG

# Print the configured values next to the targets this notebook was sized
# for, so a mismatch is visible before a long training run starts.
rule = "=" * 50

print(rule)
print("TRAINING PARAMETERS")
print(rule)

max_steps = ENV_CONFIG['max_episode_steps']
print(f"Max episode steps:  {max_steps}  (target: 500)")

goal_min = ENV_CONFIG['goal_distance_min']
goal_max = ENV_CONFIG['goal_distance_max']
print(f"Goal distance:      {goal_min}-{goal_max}m  (target: 2-5m)")

net_arch = PPO_CONFIG['policy_net_arch']
print(f"Network size:       {net_arch}  (target: [128, 128])")

total_steps = TRAINING_CONFIG['total_timesteps']
print(f"Total timesteps:    {total_steps:,}  (target: 500,000)")

env_count = TRAINING_CONFIG['n_envs']
print(f"Parallel envs:      {env_count}  (target: 4)")

ckpt_every = TRAINING_CONFIG['checkpoint_freq']
print(f"Checkpoint freq:    {ckpt_every:,}")

print(rule)

## Cell 6: TRAIN! (This is the main training cell)

Takes ~30-60 min on T4 for 500k steps.

In [None]:
# Import what this cell uses itself, so it still runs when the optional
# "check parameters" cell was skipped.
import os
import sys

# Make the repo's training code importable without stacking a duplicate
# sys.path entry every time the cell is re-run.
FILES_DIR = '/content/claude_code/files'
if FILES_DIR not in sys.path:
    sys.path.insert(0, FILES_DIR)

# Run from the repo's files directory.
# NOTE(review): presumably the training code resolves relative paths against
# this directory - confirm in train_spot.py.
os.chdir(FILES_DIR)

from train_spot import train_spot

# Train with Google Drive paths so checkpoints survive disconnects
model = train_spot(
    save_dir=SAVE_DIR,
    log_dir=LOG_DIR,
)

print("\nTraining complete!")
print(f"Models saved to: {SAVE_DIR}")

## Cell 7: Resume Training (only if Colab disconnected)

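Skip this cell if training finished normally. After a disconnect the runtime is reset, so re-run Cells 1–3 and Cell 5 first: they reinstall the packages, re-clone the repo, re-mount Google Drive (which defines `SAVE_DIR` and `LOG_DIR`), and put the repo on `sys.path`.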

In [None]:
# ONLY run this if Colab disconnected mid-training!
# It loads the latest checkpoint and continues.

import glob
import os
import re
import sys

# After a disconnect the runtime is fresh, so the training cell has not run
# here: set up the import path and working directory the same way it does.
FILES_DIR = '/content/claude_code/files'
if FILES_DIR not in sys.path:
    sys.path.insert(0, FILES_DIR)
os.chdir(FILES_DIR)


def _checkpoint_order(path):
    """Return a sort key that orders digit runs in the file name numerically.

    Plain string sorting puts e.g. "spot_model_50000" after
    "spot_model_100000", which would resume from an older checkpoint.
    re.split with a capturing group alternates text / digit tokens, so the
    odd positions are always the digit runs and keys compare type-for-type.
    """
    tokens = re.split(r'(\d+)', os.path.basename(path))
    return [int(tok) if idx % 2 else tok for idx, tok in enumerate(tokens)]


# Find the latest checkpoint (highest step count in the file name)
checkpoints = sorted(glob.glob(f"{SAVE_DIR}/spot_model_*.zip"), key=_checkpoint_order)
if checkpoints:
    latest = checkpoints[-1]
    print(f"Resuming from: {latest}")

    from train_spot import continue_training
    model = continue_training(
        model_path=latest,
        save_dir=SAVE_DIR,
        log_dir=LOG_DIR,
    )
else:
    print("No checkpoints found. Run Cell 6 first.")

## Cell 8: Evaluate & Record Video

Runs the trained policy for 3 episodes (up to 500 steps each), renders every step offscreen, and saves all frames as a single MP4 on Google Drive.

In [None]:
import os

# MuJoCo selects its OpenGL backend when the package is first imported and
# falls back to GLFW, which needs a display that Colab does not provide.
# Request EGL (offscreen, GPU) unless a backend was already chosen. This only
# has an effect if mujoco has not been imported yet in this runtime.
os.environ.setdefault("MUJOCO_GL", "egl")

import numpy as np
import imageio
import mujoco
from stable_baselines3 import PPO
from spot_env import SpotEnv

N_EPISODES = 3    # episodes concatenated into the video
MAX_STEPS = 500   # per-episode step cap, also passed to the environment

# Load best model, falling back to the final model if no "best" was saved
best_model_path = f"{SAVE_DIR}/best_model.zip"
if not os.path.exists(best_model_path):
    best_model_path = f"{SAVE_DIR}/spot_final_model.zip"

print(f"Loading model from: {best_model_path}")
model = PPO.load(best_model_path)

# Create env for evaluation
env = SpotEnv(render_mode=None, max_episode_steps=MAX_STEPS)

renderer = None
all_frames = []
try:
    # Set up MuJoCo offscreen renderer
    renderer = mujoco.Renderer(env.model, height=480, width=640)

    # Run the episodes and record one frame per environment step
    for episode in range(N_EPISODES):
        obs, info = env.reset()
        episode_reward = 0
        steps_taken = 0

        for _ in range(MAX_STEPS):
            action, _state = model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, info = env.step(action)
            episode_reward += reward
            steps_taken += 1

            # Render frame
            renderer.update_scene(env.data)
            all_frames.append(renderer.render())

            if terminated or truncated:
                break

        # Only apply the numeric format when the env actually reported a
        # distance; formatting the old '?' placeholder with ':.2f' raised
        # ValueError.
        dist = info.get('distance_to_goal')
        dist_text = "?" if dist is None else f"{dist:.2f}m"
        print(f"Episode {episode+1}: reward={episode_reward:.1f}, steps={steps_taken}, dist_to_goal={dist_text}")
finally:
    # Release the GL context and the environment even if an episode fails.
    if renderer is not None:
        renderer.close()
    env.close()

# Save video
video_path = f"{VIDEO_DIR}/spot_trained.mp4"
imageio.mimsave(video_path, all_frames, fps=30)
print(f"\nVideo saved to: {video_path}")
print(f"Total frames: {len(all_frames)}")

## Cell 9: Play Video in Notebook

In [None]:
from IPython.display import HTML
from base64 import b64encode

# Embed the recorded MP4 in the notebook as a base64 data URI.
video_path = f"{VIDEO_DIR}/spot_trained.mp4"

# Read through a context manager so the file handle is closed right away,
# and avoid keeping a second full copy of the raw bytes in a global.
with open(video_path, 'rb') as video_file:
    video_b64 = b64encode(video_file.read()).decode()

# Must stay the last expression in the cell so the notebook displays it.
HTML(f"""
<video width="640" height="480" controls>
  <source src="data:video/mp4;base64,{video_b64}" type="video/mp4">
</video>
""")

## Cell 10: Training Stats

In [None]:
# Summarise what training and evaluation left on Google Drive.
import glob


def _print_listing(heading, pattern):
    """Print `heading`, then each file matching `pattern` with its size in MB."""
    print(heading)
    for path in sorted(glob.glob(pattern)):
        megabytes = os.path.getsize(path) / 1e6
        print(f"  {os.path.basename(path)} ({megabytes:.1f} MB)")


_print_listing("Saved models:", f"{SAVE_DIR}/*.zip")
_print_listing("\nSaved videos:", f"{VIDEO_DIR}/*.mp4")

print("\nAll files on Google Drive - safe from disconnects!")