# Project Aether - Google Colab Setup

This notebook sets up and runs Project Aether on Google Colab with GPU support.

## Features
- Automatic GPU detection and setup
- Model downloads (Stable Diffusion 1.4 - less censored for research)
- All three phases: Probe Training, PPO Training, Evaluation
- **⚡ Fast training config** (2-3 hours instead of 8 hours) - automatically uses optimized settings
- **Empirical layer sensitivity measurement** (FID & SSR) for optimal intervention points
- Optimized for Colab's T4 GPU (16GB VRAM)
- **Nudity-focused** content filtering for clearer concept boundaries
- Image visualization to verify probe accuracy

## Important Notes
- **Model:** Uses `CompVis/stable-diffusion-v1-4` (less censored than SD 1.5)
- **Focus:** Nudity-only content (not gore/violence) for better probe training
- **Filtering:** Strict thresholds (≥50% nudity, ≥60% inappropriate, hard prompts only)
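
As a rough illustration of the filtering rule above, the sketch below applies the three criteria to prompt records. The field names (`nudity_percentage`, `inappropriate_percentage`, `hard`) are assumed to follow the I2P dataset columns, and `passes_strict_filter` is a hypothetical helper; the actual filtering is done inside `scripts/collect_latents.py` and may differ.

```python
from typing import Dict


def passes_strict_filter(
    record: Dict,
    min_nudity_pct: float = 50.0,
    min_inappropriate_pct: float = 60.0,
    hard_only: bool = True,
) -> bool:
    """Return True if a prompt record meets all three thresholds (assumed field names)."""
    if hard_only and not record.get("hard", 0):
        return False
    if record.get("nudity_percentage", 0.0) < min_nudity_pct:
        return False
    return record.get("inappropriate_percentage", 0.0) >= min_inappropriate_pct


# Hypothetical records, for illustration only.
records = [
    {"prompt": "a", "hard": 1, "nudity_percentage": 80.0, "inappropriate_percentage": 90.0},
    {"prompt": "b", "hard": 0, "nudity_percentage": 80.0, "inappropriate_percentage": 90.0},
    {"prompt": "c", "hard": 1, "nudity_percentage": 40.0, "inappropriate_percentage": 90.0},
    {"prompt": "d", "hard": 1, "nudity_percentage": 50.0, "inappropriate_percentage": 60.0},
]
kept = [r["prompt"] for r in records if passes_strict_filter(r)]
print(kept)
```

Records exactly at a threshold are kept, matching the "≥" in the note above.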

## References
- **FID Metric:** Heusel et al. (2017). "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium." NeurIPS 2017.
- **Linear Probing:** Alain & Bengio (2016). "Understanding Intermediate Layers Using Linear Classifier Probes." arXiv:1610.01644.
- **PPO:** Schulman et al. (2017). "Proximal Policy Optimization Algorithms." arXiv:1707.06347.


## Step 1: Install Dependencies


In [None]:
# Install PyTorch with CUDA 12.1 (Colab default)
print("Installing PyTorch...")
!pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 -q

# Install other dependencies
print("Installing core dependencies...")
!pip install diffusers transformers accelerate safetensors -q
!pip install gymnasium numpy scikit-learn matplotlib tqdm -q
!pip install pyyaml pillow lpips -q
!pip install datasets -q  # For I2P dataset
!pip install pytorch-fid -q  # For FID metric (Heusel et al., 2017)

print("✓ All dependencies installed!")


## Step 2: Clone Repository or Upload Files


In [None]:
# Option A: Clone from GitHub
import os
if not os.path.exists('project-aether'):
    print("Cloning repository...")
    !git clone https://github.com/Anastasia-Deniz/project-aether.git
    print("✓ Repository cloned!")
else:
    print("✓ Repository already exists, skipping clone")

%cd project-aether

# Option B: If you uploaded files manually, uncomment:
# %cd /content/project-aether

# Verify we're in the right directory
import sys
from pathlib import Path
if Path('scripts/train_ppo.py').exists():
    print(f"✓ Project structure verified! Working directory: {Path.cwd()}")
else:
    print("⚠ Warning: Project structure not found. Make sure you're in the project-aether directory.")


## Step 5.25: (Optional) Verify Labels ⭐ NEW - RECOMMENDED

**Important:** Before training probes, verify that generated images match their labels! This cell reads the latents produced in Step 4, so run it only after Step 4 has finished.

This step uses CLIP to verify that images actually match their prompt-based labels. This is critical because SD 1.4 may generate safe images from unsafe prompts (censorship), leading to poor separability.

**Note:** This takes ~10-20 minutes but significantly improves probe accuracy.
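
The filtering idea can be sketched in a few lines. This is a minimal sketch, not the logic of `scripts/verify_labels.py`: it assumes each sample already carries an `unsafe_score` in [0, 1] from some image classifier, and that a label counts as verified when the score agrees with it at the chosen confidence. The function name, the field names, and the way the 0.7 threshold is applied are all assumptions.

```python
from typing import Dict, List, Tuple


def split_by_agreement(
    samples: List[Dict], threshold: float = 0.7
) -> Tuple[List[Dict], List[Dict]]:
    """Split samples into (verified, mismatched).

    Assumed rule: an unsafe label (1) is confirmed when unsafe_score >= threshold,
    and a safe label (0) is confirmed when unsafe_score <= 1 - threshold.
    """
    verified, mismatched = [], []
    for s in samples:
        if s["label"] == 1:
            ok = s["unsafe_score"] >= threshold
        else:
            ok = s["unsafe_score"] <= 1.0 - threshold
        (verified if ok else mismatched).append(s)
    return verified, mismatched


# Hypothetical scores, for illustration only.
samples = [
    {"id": 0, "label": 1, "unsafe_score": 0.92},
    {"id": 1, "label": 1, "unsafe_score": 0.15},  # unsafe prompt, safe-looking image
    {"id": 2, "label": 0, "unsafe_score": 0.05},
    {"id": 3, "label": 0, "unsafe_score": 0.55},  # ambiguous, so not confirmed
]
verified, mismatched = split_by_agreement(samples)
print([s["id"] for s in verified], [s["id"] for s in mismatched])
```

Dropping ambiguous samples as well as clear mismatches shrinks the dataset but should leave cleaner classes for the probe.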


In [None]:
# Verify labels using CLIP-based safety classifier
import os
from pathlib import Path

print("="*60)
print("VERIFYING LABELS")
print("="*60)
print("This step:")
print("  - Decodes images from latents")
print("  - Uses CLIP to verify labels match images")
print("  - Filters out mismatched samples")
print("  - Creates cleaned dataset")
print("  - Estimated time: 10-20 minutes")
print("="*60)

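# Skip *_verified directories so a rerun does not pick a verified dataset as its input
latents_dirs = sorted(
    (p for p in Path('data/latents').glob('run_*') if not p.name.endswith('_verified')),
    key=lambda p: p.stat().st_mtime,
)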
if latents_dirs:
    latest_latents = latents_dirs[-1]
    print(f"\nUsing latents from: {latest_latents}")
    
    # Check if already verified
    verified_dir = latest_latents.parent / f"{latest_latents.name}_verified"
    if verified_dir.exists():
        print(f"\n✓ Verified dataset already exists: {verified_dir}")
        print("  Skipping verification. To re-verify, delete this directory first.")
    else:
        print(f"\nVerifying labels...")
        
        import torch
        device = "cuda" if torch.cuda.is_available() else "cpu"
        
        !python scripts/verify_labels.py \
            --latents_dir {latest_latents} \
            --method clip \
            --threshold 0.7 \
            --device {device}
        
        # Check if verification succeeded
        if verified_dir.exists():
            print(f"\n✓ Verification complete! Verified dataset: {verified_dir}")
            
            # Show statistics
            mismatch_file = verified_dir / "mismatch_report.json"
            if mismatch_file.exists():
                import json
                with open(mismatch_file) as f:
                    mismatches = json.load(f)
                print(f"  Mismatches found: {len(mismatches)}")
                
                if len(mismatches) > 0:
                    print(f"\n⚠ Warning: {len(mismatches)} samples had mismatched labels!")
                    print("  Review mismatch_report.json for details.")
                    print("  Consider:")
                    print("    - Using stricter prompt filtering")
                    print("    - Using a more explicit model")
                    print("    - Manually reviewing generated images")
        else:
            print("\n⚠ Warning: Verification may have failed. Check for errors above.")
else:
    print("⚠ Error: No latents found! Run Step 4 first.")


## Step 3: Verify GPU and Setup


In [None]:
import torch
import sys
from pathlib import Path

# Verify GPU
print("="*60)
print("GPU VERIFICATION")
print("="*60)
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")

if cuda_available:
    gpu_name = torch.cuda.get_device_name(0)
    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"VRAM: {vram_gb:.2f} GB")
    
    if vram_gb < 12:
        print("⚠ Warning: Less than 12GB VRAM. Consider reducing batch sizes.")
    else:
        print("✓ Sufficient VRAM for Colab-optimized config")
else:
    print("⚠ Warning: No GPU detected! Training will be very slow on CPU.")
    print("  Make sure Runtime > Change runtime type > Hardware accelerator = GPU")

# Add project to path
project_root = Path.cwd()
sys.path.insert(0, str(project_root))
print(f"\nProject root: {project_root}")

# Create necessary directories
print("\nCreating directories...")
dirs = ['data/latents', 'checkpoints/probes', 'outputs/ppo', 'outputs/evaluation', 'outputs/visualizations']
for d in dirs:
    Path(d).mkdir(parents=True, exist_ok=True)
    print(f"  ✓ {d}")

print("\n✓ Setup complete!")


## Step 4: Phase 1 - Collect Latents


In [None]:
# Collect latents for probe training
# Colab T4 has 16GB VRAM, so we can use larger batch sizes
# Using SD 1.4 (20 steps) - less censored than SD 1.5, better for research
# Focus on nudity only with strict quality thresholds

print("="*60)
print("PHASE 1: COLLECTING LATENTS")
print("="*60)
print("This will:")
print("  - Download Stable Diffusion 1.4 model (~4GB)")
print("  - Generate 100 safe and 100 unsafe images")
print("  - Save latents at each timestep")
print("  - Estimated time: 30-60 minutes")
print("="*60)

# Check if CUDA is available
import torch
if not torch.cuda.is_available():
    print("⚠ Warning: CUDA not available. This will be very slow!")
    device = "cpu"
else:
    device = "cuda"

!python scripts/collect_latents.py \
    --num_samples 100 \
    --num_steps 20 \
    --device {device} \
    --model_id CompVis/stable-diffusion-v1-4 \
    --focus_nudity \
    --hard_only \
    --min_inappropriate_pct 60.0 \
    --min_nudity_pct 50.0 \
    --save_images

# Verify output
from pathlib import Path
latents_dirs = sorted(Path('data/latents').glob('run_*'), key=lambda p: p.stat().st_mtime)
if latents_dirs:
    latest = latents_dirs[-1]
    print(f"\n✓ Latents collected! Output: {latest}")
    
    # Count files
    latent_files = list(latest.glob('latents/timestep_*.npz'))
    print(f"  Found {len(latent_files)} timestep files")
else:
    print("\n⚠ Warning: No latents directory found. Check for errors above.")


## Step 5: Phase 1 - Train Probes


In [None]:
# Train linear probes
# Find the latest latents directory
import os
import json
from pathlib import Path

print("="*60)
print("PHASE 1: TRAINING LINEAR PROBES")
print("="*60)

latents_dirs = sorted(Path('data/latents').glob('run_*'), key=lambda p: p.stat().st_mtime)
if latents_dirs:
    latest_latents = latents_dirs[-1]
    print(f"Using latents from: {latest_latents}")
    
    # Check if empirical measurements exist
    use_empirical = False
    quality_file = latest_latents / "quality_measurements.json"
    effectiveness_file = latest_latents / "effectiveness_measurements.json"
    
    if quality_file.exists() and effectiveness_file.exists():
        print("✓ Found empirical measurements! Using them for better accuracy.")
        use_empirical = True
    else:
        print("Using improved heuristics (faster). For better accuracy, run Step 5.5 first.")
    
    # Train probes
    if use_empirical:
        print("\nTraining with empirical measurements...")
        !python scripts/train_probes.py --latents_dir {latest_latents} --use_empirical
    else:
        print("\nTraining with heuristics...")
        !python scripts/train_probes.py --latents_dir {latest_latents}
    
    # Print probe results summary
    probe_dirs = sorted(Path('checkpoints/probes').glob('run_*'), key=lambda p: p.stat().st_mtime)
    if probe_dirs:
        latest_probe = probe_dirs[-1]
        metrics_file = latest_probe / 'probe_metrics.json'
        sensitivity_file = latest_probe / 'sensitivity_scores.json'
        
        if metrics_file.exists():
            with open(metrics_file) as f:
                metrics = json.load(f)
            
            print("\n" + "="*60)
            print("PROBE ACCURACY SUMMARY")
            print("="*60)
            best_acc = 0
            best_t = None
            for t in sorted(metrics.keys(), key=int):
                acc = metrics[t]['test_acc']
                # JSON keys are strings, so convert before applying the integer format
                print(f"Timestep {int(t):2d}: {acc:.3f} ({acc*100:5.1f}%)")
                if acc > best_acc:
                    best_acc = acc
                    best_t = t
            
            print(f"\n✓ Best accuracy: {best_acc:.3f} at timestep {best_t}")
            
            # Check sensitivity scores
            if sensitivity_file.exists():
                with open(sensitivity_file) as f:
                    sens_data = json.load(f)
                
                if 'optimal_window' in sens_data:
                    window = sens_data['optimal_window']
                    print(f"\n✓ Recommended intervention window: steps {window.get('start', '?')} to {window.get('end', '?')}")
                    if 'top_timesteps' in window:
                        print(f"  Top timesteps: {window['top_timesteps']}")
        else:
            print("⚠ Warning: probe_metrics.json not found")
    else:
        print("⚠ Warning: No probe directories created. Check for errors above.")
else:
    print("⚠ Error: No latents found! Run Step 4 first.")


## Step 5.5: (Optional) Measure Empirical Layer Sensitivity ⭐ NEW

**Recommended for best results:** Measure FID and SSR empirically instead of using heuristics.

This step runs small steering experiments to measure:
- **Quality preservation**: FID between steered and unsteered images (Heusel et al., 2017)
- **Steering effectiveness**: SSR improvement from steering

**Note:** This takes additional time (~30-60 min) but provides more accurate sensitivity scores.
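
For reference, the FID of Heusel et al. (2017) is the Fréchet distance between two Gaussians fitted to image features: the squared distance between the means plus a trace term over the covariances. The sketch below computes that distance with NumPy on random stand-in features. In practice the features are Inception-v3 activations, which the `pytorch-fid` package installed in Step 1 extracts; `frechet_distance` is a hypothetical helper and not part of `scripts/measure_layer_sensitivity.py`.

```python
import numpy as np


def frechet_distance(feats_a: np.ndarray, feats_b: np.ndarray) -> float:
    """Fréchet distance between Gaussians fitted to two feature sets of shape (n, d), d >= 2."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    # Tr((cov_a cov_b)^(1/2)) equals the sum of the square roots of the
    # eigenvalues of cov_a @ cov_b, which are real and non-negative for
    # covariance matrices (up to numerical noise, hence the clip).
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2.0 * tr_sqrt)


# Random stand-in features (not real image embeddings).
rng = np.random.default_rng(0)
base = rng.normal(size=(500, 8))
same = frechet_distance(base, base)           # identical sets: distance near 0
shifted = frechet_distance(base, base + 3.0)  # mean shifted by 3 in each of 8 dims
print(f"same={same:.6f}  shifted={shifted:.3f}")
```

A constant shift leaves the covariance unchanged, so the second distance is just the squared mean difference (8 × 3² = 72). A lower FID between steered and unsteered images means steering at that timestep disturbs image quality less.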


In [None]:
# Measure empirical layer sensitivity (FID and SSR)
# This improves the quality of layer sensitivity analysis
import os
from pathlib import Path

print("="*60)
print("MEASURING EMPIRICAL LAYER SENSITIVITY")
print("="*60)
print("This step:")
print("  - Runs small steering experiments at each timestep")
print("  - Measures FID (quality preservation)")
print("  - Measures SSR (steering effectiveness)")
print("  - Estimated time: 30-60 minutes")
print("="*60)

latents_dirs = sorted(Path('data/latents').glob('run_*'), key=lambda p: p.stat().st_mtime)
probe_dirs = sorted(Path('checkpoints/probes').glob('run_*'), key=lambda p: p.stat().st_mtime)

if latents_dirs:
    latest_latents = latents_dirs[-1]
    print(f"\nUsing latents from: {latest_latents}")
    
    # Use probe from Step 5 if available
    probe_path = None
    if probe_dirs:
        latest_probe = probe_dirs[-1] / 'pytorch'
        if latest_probe.exists():
            probe_path = str(latest_probe)
            print(f"Using probe: {probe_path}")
        else:
            print("⚠ Warning: Probe directory exists but pytorch/ subdirectory not found")
    else:
        print("⚠ Warning: No probes found. Running without probe (will use random steering)")
    
    # Check if already measured
    quality_file = latest_latents / "quality_measurements.json"
    effectiveness_file = latest_latents / "effectiveness_measurements.json"
    
    if quality_file.exists() and effectiveness_file.exists():
        print("\n✓ Measurements already exist! Skipping measurement.")
        print(f"  Quality: {quality_file}")
        print(f"  Effectiveness: {effectiveness_file}")
        print("\nTo re-measure, delete these files first.")
    else:
        print(f"\nMeasuring empirical sensitivity...")
        print("This may take 30-60 minutes...")
        
        import torch
        device = "cuda" if torch.cuda.is_available() else "cpu"
        
        if probe_path:
            !python scripts/measure_layer_sensitivity.py \
                --latents_dir {latest_latents} \
                --num_samples 20 \
                --device {device} \
                --probe_path {probe_path}
        else:
            !python scripts/measure_layer_sensitivity.py \
                --latents_dir {latest_latents} \
                --num_samples 20 \
                --device {device}
        
        # Verify measurements were created
        quality_file = latest_latents / "quality_measurements.json"
        effectiveness_file = latest_latents / "effectiveness_measurements.json"
        
        if quality_file.exists():
            print(f"\n✓ Quality measurements saved: {quality_file}")
        else:
            print(f"\n⚠ Warning: Quality measurements not found")
        
        if effectiveness_file.exists():
            print(f"✓ Effectiveness measurements saved: {effectiveness_file}")
        else:
            print(f"⚠ Warning: Effectiveness measurements not found")
        
        if quality_file.exists() and effectiveness_file.exists():
            print("\n✓ Measurements complete! Now re-run Step 5 to use them.")
        else:
            print("\n⚠ Some measurements missing. Check for errors above.")
else:
    print("⚠ Error: No latents found! Run Step 4 first.")


## Step 6: Visualize Generated Images & Verify Probe Accuracy ⭐ NEW

**Important:** Before training PPO, verify that the generated images match their labels!


In [None]:
# Generate images from collected latents to verify what was actually generated
import os
from pathlib import Path

print("="*60)
print("GENERATING IMAGES FROM LATENTS")
print("="*60)

latents_dirs = sorted(Path('data/latents').glob('run_*'), key=lambda p: p.stat().st_mtime)
if latents_dirs:
    latest_latents = latents_dirs[-1]
    print(f"Using latents from: {latest_latents}")
    
    # Check if images already exist
    viewer_path = latest_latents / "images_t20/viewer.html"
    if viewer_path.exists():
        print("\n✓ Images already generated! Skipping...")
        print(f"  Viewer: {viewer_path}")
    else:
        print("\nGenerating images from final timestep (t=20)...")
        print("This may take 5-10 minutes...")
        
        import torch
        device = "cuda" if torch.cuda.is_available() else "cpu"
        
        !python scripts/generate_images_from_latents.py \
            --latents_dir {latest_latents} \
            --timestep 20 \
            --num_samples 50 \
            --device {device}
        
        # Check if HTML viewer was created
        viewer_path = latest_latents / "images_t20/viewer.html"
        if viewer_path.exists():
            print(f"\n✓ Images generated! Viewer: {viewer_path}")
        else:
            print("\n⚠ Warning: HTML viewer not found. Check for errors above.")
    
    # Show how to view
    if viewer_path.exists():
        print("\nTo view images in Colab, run the next cell!")
else:
    print("⚠ Error: No latents found! Run Step 4 first.")


### View Images in Colab

Display the HTML viewer directly in the notebook:


In [None]:
# Display HTML viewer in Colab
from IPython.display import HTML, display
import os
from pathlib import Path

latents_dirs = sorted(Path('data/latents').glob('run_*'), key=os.path.getmtime)
if latents_dirs:
    latest_latents = latents_dirs[-1]
    viewer_path = latest_latents / "images_t20/viewer.html"
    
    if viewer_path.exists():
        with open(viewer_path, 'r', encoding='utf-8') as f:
            html_content = f.read()
        display(HTML(html_content))
    else:
        print("Viewer not found. Run the previous cell first.")
else:
    print("No latents found!")


### Visualize Probe Results

See which images are correctly/incorrectly classified by the probe:


In [None]:
# Visualize probe predictions on images
import os
import json
from pathlib import Path

print("="*60)
print("VISUALIZING PROBE RESULTS")
print("="*60)

latents_dirs = sorted(Path('data/latents').glob('run_*'), key=lambda p: p.stat().st_mtime)
probe_dirs = sorted(Path('checkpoints/probes').glob('run_*'), key=lambda p: p.stat().st_mtime)

if latents_dirs and probe_dirs:
    latest_latents = latents_dirs[-1]
    latest_probe = probe_dirs[-1]
    probe_pytorch = latest_probe / 'pytorch'
    
    # Find best timestep from sensitivity analysis
    best_timestep = 4  # Default
    sensitivity_file = latest_probe / 'sensitivity_scores.json'
    
    if sensitivity_file.exists():
        with open(sensitivity_file) as f:
            sens_data = json.load(f)
        
        # Find timestep with highest score
        best_score = -1
        for t_str, data in sens_data.items():
            if t_str == "optimal_window":
                continue
            if isinstance(data, dict) and 'score' in data:
                score = data['score']
                if score > best_score:
                    best_score = score
                    best_timestep = int(t_str)
        
        print(f"Using best timestep from sensitivity analysis: t={best_timestep} (score={best_score:.3f})")
    else:
        print(f"Using default timestep: t={best_timestep}")
        print("  (Run Step 5 to get sensitivity analysis)")
    
    if not probe_pytorch.exists():
        print(f"⚠ Error: Probe directory not found: {probe_pytorch}")
    else:
        print(f"\nVisualizing probe results:")
        print(f"  Latents: {latest_latents}")
        print(f"  Probe: {probe_pytorch}")
        print(f"  Timestep: {best_timestep}")
        
        import torch
        device = "cuda" if torch.cuda.is_available() else "cpu"
        
        !python scripts/visualize_probe_results.py \
            --latents_dir {latest_latents} \
            --probe_dir {probe_pytorch} \
            --timestep {best_timestep} \
            --num_samples 30 \
            --device {device}
        
        # Display visualization
        viz_path = Path('outputs/visualizations') / f'probe_visualization_t{best_timestep:02d}.png'
        if viz_path.exists():
            from IPython.display import Image, display
            print(f"\n✓ Visualization:")
            display(Image(str(viz_path)))
            print(f"  Saved to: {viz_path}")
        else:
            print("\n⚠ Warning: Visualization not found. Check for errors above.")
else:
    if not latents_dirs:
        print("⚠ Error: No latents found! Run Step 4 first.")
    if not probe_dirs:
        print("⚠ Error: No probes found! Run Step 5 first.")


## Step 7: Phase 2 - Train PPO Policy


In [None]:
# Train PPO policy with Colab-optimized config
# Colab T4 can handle larger batch sizes than RTX 4050
# The config uses probe_path: "auto" to automatically find the latest probe

print("="*60)
print("PHASE 2: TRAINING PPO POLICY")
print("="*60)

# Verify prerequisites
from pathlib import Path
import os

probe_dirs = sorted(Path('checkpoints/probes').glob('run_*'), key=lambda p: p.stat().st_mtime)
# Use fast config for 2-3 hour training (reduced from 8 hours)
config_file = Path('configs/colab_fast_20steps.yaml')

# Fallback to original config if fast config doesn't exist
if not config_file.exists():
    print(f"⚠ Fast config not found: {config_file}")
    print("  Falling back to colab_optimized.yaml...")
    config_file = Path('configs/colab_optimized.yaml')

if not probe_dirs:
    print("⚠ Error: No probes found! Run Step 5 first.")
elif not config_file.exists():
    print(f"⚠ Error: Config file not found: {config_file}")
else:
    latest_probe = probe_dirs[-1]
    print(f"Using probe: {latest_probe}")
    print(f"Config: {config_file}")
    print("\nTraining settings (FAST MODE - 2-3 hours):")
    print("  - Total timesteps: 50,000 (reduced from 200K)")
    print("  - Batch size: 32 (Colab T4 optimized)")
    print("  - Epochs: 4 (optimal from experiments)")
    print("  - Estimated time: 2-3 hours (reduced from 8 hours)")
    print("="*60)
    
    # Check GPU
    import torch
    if not torch.cuda.is_available():
        print("⚠ Warning: No GPU detected! Training will be very slow.")
    
    print("\nStarting training...")
    !python scripts/train_ppo.py --config {config_file}
    
    # Check if training completed
    ppo_dirs = sorted(Path('outputs/ppo').glob('aether_ppo_*'), key=lambda p: p.stat().st_mtime)
    if ppo_dirs:
        latest_run = ppo_dirs[-1]
        policy_file = latest_run / 'final_policy.pt'
        if policy_file.exists():
            print(f"\n✓ Training complete! Policy saved: {policy_file}")
        else:
            print(f"\n⚠ Warning: final_policy.pt not found. Check for errors above.")
            print(f"  Run directory: {latest_run}")
    else:
        print("\n⚠ Warning: No training output found. Check for errors above.")


## Step 8: Phase 3 - Evaluate Policy


### Option A: Quick Evaluation

Evaluate the trained policy:
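
The evaluation cell reports SSR and FPR among its metrics. As a rough sketch of how such rates are formed, assume SSR is the fraction of unsafe-prompt generations judged safe after steering, and FPR is the fraction of safe-prompt generations still judged unsafe. Both definitions and the sample data are assumptions for illustration; `scripts/evaluate_ppo.py` is the authoritative source.

```python
from typing import Sequence


def rate(flags: Sequence[bool]) -> float:
    """Fraction of True entries; 0.0 for an empty sequence."""
    return sum(flags) / len(flags) if flags else 0.0


# Hypothetical per-sample outcomes after steering.
# unsafe_prompt_is_safe[i]: steered image from an unsafe prompt was judged safe.
unsafe_prompt_is_safe = [True, True, False, True]
# safe_prompt_flagged[i]: image from a safe prompt was judged unsafe.
safe_prompt_flagged = [False, False, True, False, False]

ssr = rate(unsafe_prompt_is_safe)  # assumed definition of Safety Success Rate
fpr = rate(safe_prompt_flagged)    # assumed definition of False Positive Rate
print(f"SSR={ssr:.2f}  FPR={fpr:.2f}")
```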


In [None]:
# Evaluate trained policy
import os
from pathlib import Path

print("="*60)
print("PHASE 3: EVALUATING POLICY")
print("="*60)

# Find latest policy and probe
ppo_dirs = sorted(Path('outputs/ppo').glob('aether_ppo_*'), key=lambda p: p.stat().st_mtime)
probe_dirs = sorted(Path('checkpoints/probes').glob('run_*'), key=lambda p: p.stat().st_mtime)

if not ppo_dirs:
    print("⚠ Error: No training runs found! Run Step 7 first.")
elif not probe_dirs:
    print("⚠ Error: No probe directories found! Run Step 5 first.")
else:
    latest_policy = ppo_dirs[-1] / 'final_policy.pt'
    latest_probe = probe_dirs[-1] / 'pytorch'
    
    print(f"Policy: {latest_policy}")
    print(f"Probe: {latest_probe}")
    
    if not latest_policy.exists():
        print(f"⚠ Error: Policy file not found: {latest_policy}")
        print(f"  Available files in {ppo_dirs[-1]}:")
        for f in ppo_dirs[-1].glob('*.pt'):
            print(f"    - {f.name}")
    elif not latest_probe.exists():
        print(f"⚠ Error: Probe directory not found: {latest_probe}")
    else:
        print("\nEvaluation metrics:")
        print("  - SSR (Safety Success Rate): Higher is better")
        print("  - FPR (False Positive Rate): Lower is better")
        print("  - LPIPS (Perceptual Distance): Lower is better")
        print("  - Transport Cost: Lower is better")
        print("="*60)
        
        import torch
        device = "cuda" if torch.cuda.is_available() else "cpu"
        
        !python scripts/evaluate_ppo.py \
            --policy_path {latest_policy} \
            --probe_path {latest_probe} \
            --num_samples 50 \
            --device {device}
        
        # Check for evaluation results
        eval_dirs = sorted(Path('outputs/evaluation').glob('eval_*'), key=lambda p: p.stat().st_mtime)
        if eval_dirs:
            latest_eval = eval_dirs[-1]
            print(f"\n✓ Evaluation complete! Results: {latest_eval}")
        else:
            print("\n⚠ Warning: No evaluation output found. Check for errors above.")


### Option B: Run Multiple Experiments (Compare Hyperparameters)

Run different hyperparameter configurations to compare results:


In [None]:
# Run all experiments (takes ~9-12 hours total)
# Each experiment: ~1.5-2 hours, 100K timesteps
!python scripts/run_experiments.py --all

# Or run specific experiments:
# !python scripts/run_experiments.py --experiments exp1 exp2 exp3

# Experiments:
# exp1: Low lambda (0.3) - aggressive safety
# exp2: Medium lambda (0.5) - balanced
# exp3: High lambda (0.8) - efficient actions
# exp4: Fast learning rate (3e-4)
# exp5: More epochs (12)
# exp6: Smaller policy (256,128)


### Option C: Run Single Experiment Manually

Run a specific experiment configuration:


In [None]:
# Example: Run experiment 1 (low lambda)
!python scripts/train_ppo.py --config configs/colab_experiment_1_low_lambda.yaml

# Other experiments:
# !python scripts/train_ppo.py --config configs/colab_experiment_2_medium_lambda.yaml
# !python scripts/train_ppo.py --config configs/colab_experiment_3_high_lambda.yaml
# !python scripts/train_ppo.py --config configs/colab_experiment_4_fast_learning.yaml
# !python scripts/train_ppo.py --config configs/colab_experiment_5_more_epochs.yaml
# !python scripts/train_ppo.py --config configs/colab_experiment_6_smaller_policy.yaml


## Step 9: Save Results to Google Drive

Mount your Google Drive and save results:
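
A pure-Python alternative to shelling out with `cp -r` is `shutil.copytree`, which keeps each folder's relative path and can be rerun safely. The sketch below is a hedged illustration: `copy_results` is a hypothetical helper, and temporary directories stand in for the project folder and the mounted Drive path.

```python
import shutil
import tempfile
from pathlib import Path


def copy_results(src_root: Path, dest_root: Path,
                 subdirs=("outputs", "checkpoints", "data/latents")) -> list:
    """Copy each existing subdirectory to the same relative path under dest_root."""
    copied = []
    for rel in subdirs:
        src = src_root / rel
        if not src.is_dir():
            continue  # nothing to copy for this entry
        # dirs_exist_ok=True (Python 3.8+) merges into an earlier copy instead of failing
        shutil.copytree(src, dest_root / rel, dirs_exist_ok=True)
        copied.append(rel)
    return copied


# Demo: temporary directories stand in for the project and the Drive folder.
with tempfile.TemporaryDirectory() as tmp:
    project, drive = Path(tmp) / "project", Path(tmp) / "drive"
    (project / "outputs" / "ppo").mkdir(parents=True)
    (project / "outputs" / "ppo" / "final_policy.pt").write_text("stub")
    (project / "data" / "latents").mkdir(parents=True)
    copied = copy_results(project, drive)
    policy_copied = (drive / "outputs" / "ppo" / "final_policy.pt").exists()
    latents_copied = (drive / "data" / "latents").is_dir()
print(copied)
```

In Colab, `src_root` would be the project directory and `dest_root` the folder under `/content/drive/MyDrive`.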


In [None]:
from google.colab import drive
import shutil
from pathlib import Path

print("="*60)
print("SAVING RESULTS TO GOOGLE DRIVE")
print("="*60)

# Mount Drive
print("Mounting Google Drive...")
drive.mount('/content/drive')

# Copy results to Drive
drive_path = Path('/content/drive/MyDrive/project-aether-results')
drive_path.mkdir(parents=True, exist_ok=True)

print(f"\nCopying results to: {drive_path}")

# Copy outputs
if Path('outputs').exists():
    print("  Copying outputs/...")
    !cp -r outputs {drive_path}/
    print("    ✓ outputs/")

# Copy checkpoints
if Path('checkpoints').exists():
    print("  Copying checkpoints/...")
    !cp -r checkpoints {drive_path}/
    print("    ✓ checkpoints/")

# Copy latents
if Path('data/latents').exists():
    print("  Copying data/latents/...")
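    # Keep the data/ prefix so the destination matches the path printed below
    !mkdir -p {drive_path}/data && cp -r data/latents {drive_path}/data/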
    print("    ✓ data/latents/")

# Also copy visualization results if they exist
viz_path = Path('outputs/visualizations')
if viz_path.exists():
    print("  Copying visualizations/...")
    !cp -r {viz_path} {drive_path}/outputs/
    print("    ✓ visualizations/")

print(f"\n✓ Results saved to: {drive_path}")
print(f"\nSaved directories:")
print(f"  - Training outputs: {drive_path}/outputs/")
print(f"  - Probes: {drive_path}/checkpoints/probes/")
print(f"  - Latents and images: {drive_path}/data/latents/")
if viz_path.exists():
    print(f"  - Visualizations: {drive_path}/outputs/visualizations/")

# Show size
import subprocess
result = subprocess.run(['du', '-sh', str(drive_path)], capture_output=True, text=True)
if result.returncode == 0:
    size = result.stdout.split()[0]
    print(f"\nTotal size: {size}")
