# 🛡️ Safety Text Classifier Training - Google Colab

**Constitutional AI Research Project - Stage 1**

This notebook trains a transformer-based safety text classifier with JAX/Flax on a Google Colab GPU runtime.

## 📋 Prerequisites
- GPU runtime enabled in Colab
- Weights & Biases account for experiment tracking
- Project files uploaded or cloned into Colab (the `ml-learning` repository containing `safety-text-classifier/`)

## 🎯 Expected Results
- Training time: ~2-3 hours on GPU
- Target accuracy: 85%+ on safety classification
- Model size: ~50MB

## 1. 🚀 Environment Setup

In [None]:
# Check GPU availability
# IPython shell escape: prints NVIDIA's GPU/driver summary. If this errors or
# lists no GPU, the runtime is presumably CPU-only — switch it via
# Runtime → Change runtime type before training.
!nvidia-smi

In [None]:
# Attach Google Drive and prepare a folder on it for model checkpoints.
# NOTE(review): `checkpoint_dir` is only created here; no visible cell passes it
# to the trainer — confirm colab_config.yaml points its checkpoints at this path.
import os

from google.colab import drive

drive.mount('/content/drive')

checkpoint_dir = '/content/drive/MyDrive/safety-classifier-checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)  # no-op when the folder already exists
print(f"✅ Checkpoint directory ready: {checkpoint_dir}")

## 1.5 🚀 Setup with uv 

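This step installs the uv package manager (typically 10-100x faster than pip), which can be used to install dependencies with proper compatibility and resolved dependency conflicts. Note that the cell below only installs uv and adds it to `PATH`; it does not install the project's dependencies itself.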

In [None]:
# Install uv package manager
!curl -LsSf https://astral.sh/uv/install.sh | sh

# Add uv to PATH
import os
uv_path = os.path.expanduser('~/.cargo/bin')
os.environ['PATH'] = f"{uv_path}:{os.environ['PATH']}"

# Test uv
!uv --version

In [None]:
# Enter the project directory, assuming the ml-learning repo sits in the
# current directory. The chdir is guarded so that re-running this cell, or
# running it when already inside the project, does not raise FileNotFoundError.
import os
project_subdir = 'ml-learning/safety-text-classifier'
if os.path.isdir(project_subdir):
    os.chdir(project_subdir)
    print(f"Changed directory to: {os.getcwd()}")
else:
    print(f"⚠️ {project_subdir!r} not found from {os.getcwd()}; directory unchanged. "
          "Use the navigation helper below if you are not in the project yet.")

## 1.9 📍 Navigate to Project Directory

This cell helps you navigate to the correct directory if you've uploaded the full ml-learning repository.

In [None]:
# Navigation Helper - Run this if you have path issues
import os
from pathlib import Path

def find_and_navigate_to_project():
    """Move into the safety-text-classifier project directory if it can be found.

    Checks, in order: the current directory itself (it must be named
    ``safety-text-classifier`` and contain ``src`` and ``configs``), a
    ``safety-text-classifier`` subdirectory, and
    ``ml-learning/safety-text-classifier``. On a match outside the current
    directory, changes the process working directory into it.

    Returns:
        The project path as a string, or None when nothing matched (in which
        case the directories visible from here are listed).
    """
    here = Path.cwd()
    print(f"📍 Current directory: {here}")

    looks_like_project = all((here / sub).exists() for sub in ("src", "configs"))
    if here.name == "safety-text-classifier" and looks_like_project:
        print("✅ Already in safety-text-classifier directory!")
        return str(here)

    # (candidate path, presence test, message template) in priority order
    candidates = (
        (here / "safety-text-classifier", Path.is_dir,
         "✅ Found safety-text-classifier at: {}"),
        (here / "ml-learning" / "safety-text-classifier", Path.exists,
         "✅ Found project at: {}"),
    )
    for target, is_present, found_msg in candidates:
        if is_present(target):
            print(found_msg.format(target))
            os.chdir(target)
            print(f"📂 Changed to: {Path.cwd()}")
            return str(target)

    print("❌ Could not find safety-text-classifier directory!")
    print("📋 Available directories:")
    for child in (entry for entry in here.iterdir() if entry.is_dir()):
        print(f"  📁 {child.name}")
    return None

# Locate the project (changing directory if needed), then confirm the
# expected top-level entries are present in the resulting working directory.
project_path = find_and_navigate_to_project()

if project_path:
    print("\n🔍 Verifying project structure...")
    required_items = ['src', 'configs', 'pyproject.toml']
    missing_items = []
    for item in required_items:
        present = Path(item).exists()
        print(f"{'✅' if present else '❌'} {item}")
        if not present:
            missing_items.append(item)
    all_good = not missing_items

    print("\n🎉 Project structure looks good!" if all_good
          else "\n⚠️  Some required files are missing.")

print(f"\n📍 Final directory: {Path.cwd()}")

## 2. 🏋️ Model Training

In [None]:
# Project Setup - Run this cell first
import os
import sys
from pathlib import Path

def setup_project():
    """Locate the project, make it the working directory, and expose ``src``.

    The current directory is accepted as-is when it is named
    ``safety-text-classifier`` and has a ``src`` folder; otherwise
    ``./safety-text-classifier`` and ``./ml-learning/safety-text-classifier``
    are tried (each must contain ``src``) and the first hit becomes the
    working directory. The project's absolute ``src`` path is put at the
    front of ``sys.path`` if not already there.

    Returns:
        True if the project was found and contains ``src``, ``configs`` and
        ``pyproject.toml``; False otherwise.
    """
    start = Path.cwd()
    print(f"📍 Starting from: {start}")

    if start.name == "safety-text-classifier" and (start / "src").exists():
        print("✅ Already in safety-text-classifier directory!")
        root = start
    else:
        candidates = (
            start / "safety-text-classifier",
            start / "ml-learning" / "safety-text-classifier",
        )
        root = next(
            (cand for cand in candidates
             if cand.exists() and (cand / "src").exists()),
            None,
        )
        if root is None:
            print("❌ Could not find safety-text-classifier directory!")
            print("📋 Available directories:")
            for entry in start.iterdir():
                if entry.is_dir():
                    print(f"  📁 {entry.name}")
            return False
        print(f"✅ Found project at: {root}")
        os.chdir(root)

    # Make `from data...` / `from training...` style imports resolvable
    src_dir = str(root / "src")
    if src_dir not in sys.path:
        sys.path.insert(0, src_dir)
        print(f"✅ Added to Python path: {src_dir}")

    missing = [name for name in ('src', 'configs', 'pyproject.toml')
               if not (root / name).exists()]
    if missing:
        print(f"❌ Missing: {missing}")
        return False

    print("🎉 Project structure verified!")
    print(f"📁 Working directory: {root}")
    return True

# Run the project setup and report whether it succeeded
print("✅ Setup complete!" if setup_project()
      else "⚠️ Please upload the safety-text-classifier folder to Colab")

In [None]:
# Test imports - Run this after project setup
import os
import sys

# Debug path information
print(f"Current working directory: {os.getcwd()}")
print(f"Contents of src/: {os.listdir('src') if os.path.exists('src') else 'src not found'}")

# Try different import approaches
try:
    # Method 1: import via the `src` package, which needs the project root
    # (the current directory) on sys.path. Insert it only once so re-running
    # this cell does not keep growing sys.path.
    if os.getcwd() not in sys.path:
        sys.path.insert(0, os.getcwd())
    from src.data.dataset_loader import create_data_loaders, SafetyDatasetLoader
    from src.training.trainer import SafetyTrainer
    print("✅ Method 1 successful: Direct src import")

except ImportError as e:
    print(f"❌ Method 1 failed: {e}")
    try:
        # Method 2: import with src/ itself on sys.path (added by the setup cell)
        from data.dataset_loader import create_data_loaders, SafetyDatasetLoader
        from training.trainer import SafetyTrainer
        print("✅ Method 2 successful: Import from src path")

    except ImportError as e2:
        print(f"❌ Method 2 failed: {e2}")
        # Only list src/ when it exists: from the wrong directory os.listdir
        # would raise FileNotFoundError and hide the import errors above.
        if os.path.isdir('src'):
            print("Available modules in src/:")
            for item in os.listdir('src'):
                if os.path.isdir(f'src/{item}'):
                    print(f"  📁 {item}/: {os.listdir(f'src/{item}')}")
        else:
            print("src/ not found - run the project setup cell first")

# Test other imports
try:
    import jax
    import numpy as np
    print(f"✅ JAX {jax.__version__} and NumPy {np.__version__} imported")
    print(f"JAX devices: {jax.devices()}")

    # Check for GPU via Device.platform ("gpu" for GPU backends). Matching on
    # str(device) is unreliable: recent JAX versions can render CUDA devices
    # as e.g. "cuda:0", which does not contain "gpu".
    gpu_devices = [d for d in jax.devices() if d.platform == 'gpu']
    if gpu_devices:
        print(f"🎯 GPU ready: {gpu_devices}")
    else:
        print("⚠️ CPU only - training will be slower")

except Exception as e:
    print(f"❌ JAX/NumPy error: {e}")

# Single backslash: "\\n" would print a literal backslash followed by "n".
print("\n🎯 If imports are successful, you're ready for training!")

In [None]:
# Start training with Colab-optimized config
# (SafetyTrainer comes from the import-test cell above.)
import sys

print("🚀 Starting training with Colab configuration...")
print("📊 Monitor progress in W&B dashboard")
print("⏱️  Expected training time: 2-3 hours")
print("="*60)

try:
    # Initialize trainer with Colab config
    trainer = SafetyTrainer(config_path="configs/colab_config.yaml")

    # Start training
    trainer.train()

    print("\n🎉 Training completed successfully!")

except Exception as e:
    print(f"❌ Training failed: {e}")
    import traceback
    traceback.print_exc()

finally:
    # Clean up W&B. Look the module up in sys.modules rather than this
    # notebook's globals: a run started by code that imported wandb itself
    # (presumably the trainer — confirm) is then finished too, even if this
    # notebook never ran `import wandb`. A private name avoids binding a
    # global `wandb` to None.
    _wandb = sys.modules.get('wandb')
    if _wandb is not None and _wandb.run:
        _wandb.finish()

In [None]:
# Evaluate the trained model
import os
import sys

# Put src/ on the import path as an absolute path. A relative 'src' entry is
# resolved against whatever the working directory is at import time, and
# appending it unconditionally adds a duplicate on every re-run of the cell.
_src_dir = os.path.abspath('src')
if _src_dir not in sys.path:
    sys.path.append(_src_dir)

from training.trainer import SafetyTrainer
from data.dataset_loader import create_data_loaders
import yaml
import numpy as np

print("📊 Loading model for evaluation...")

# Load config
# NOTE(review): `config` is not used below; kept in case other cells rely on it.
with open('configs/colab_config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Build the trainer from the Colab config.
# NOTE(review): whether this restores the best checkpoint depends on
# SafetyTrainer.__init__, which is not visible here — confirm.
trainer = SafetyTrainer('configs/colab_config.yaml')

# Load data; only the third returned split (test) is used here
_, _, test_dataset = create_data_loaders('configs/colab_config.yaml')

# Evaluate on test set
test_metrics = trainer.evaluate(test_dataset, 0, "test/")

print("\n🎯 Final Test Results:")
print(f"Test Accuracy: {test_metrics['accuracy']:.1%}")
print(f"Test Loss: {test_metrics['loss']:.4f}")

# Per-class accuracy
if 'per_class_accuracy' in test_metrics:
    # NOTE(review): zip() silently drops any classes beyond these four names —
    # confirm the label order/count matches the dataset loader.
    categories = ['Hate Speech', 'Self Harm', 'Dangerous Advice', 'Harassment']
    per_class_acc = test_metrics['per_class_accuracy']
    print("\n📋 Per-Category Accuracy:")
    for cat, acc in zip(categories, per_class_acc):
        print(f"  {cat}: {acc:.1%}")

# Setup Weights & Biases and (optionally) train again
import wandb

# Login to W&B
wandb.login()

# NOTE(review): this repeats the training cell earlier in the notebook and
# runs after an evaluation, so "Run all" would train the model twice
# (~2-3 hours each). Re-training here is opt-in: set the flag to True.
RUN_TRAINING_AGAIN = False

if RUN_TRAINING_AGAIN:
    print("🚀 Starting training...")
    print("📊 Monitor progress in W&B dashboard")
    print("⏱️ Expected time: 2-3 hours on GPU")
    print("=" * 50)

    try:
        # Initialize trainer with Colab config
        trainer = SafetyTrainer(config_path="configs/colab_config.yaml")

        # Start training
        trainer.train()

        print("\n🎉 Training completed successfully!")

    except Exception as e:
        print(f"❌ Training failed: {e}")
        import traceback
        traceback.print_exc()

    finally:
        # Clean up W&B (wandb is imported at the top of this cell)
        if wandb.run:
            wandb.finish()
else:
    print("⏭️ Skipping duplicate training run; set RUN_TRAINING_AGAIN = True in this cell to train again.")

## 5. 💾 Save Model for Production

In [None]:
# Evaluate the trained model
import yaml

print("📊 Evaluating trained model...")

try:
    # Use the imports that worked in the test cell above
    # (package-style: project root on sys.path)
    from src.training.trainer import SafetyTrainer
    from src.data.dataset_loader import create_data_loaders
except ImportError:
    # Fallback to path-based imports (src/ itself on sys.path)
    from training.trainer import SafetyTrainer
    from data.dataset_loader import create_data_loaders

# Build the trainer from the Colab config.
# NOTE(review): whether this restores the best checkpoint depends on
# SafetyTrainer.__init__, which is not visible here — confirm.
trainer = SafetyTrainer('configs/colab_config.yaml')

# Load data; only the third returned split (test) is used here
_, _, test_dataset = create_data_loaders('configs/colab_config.yaml')

# Evaluate on test set
test_metrics = trainer.evaluate(test_dataset, 0, "test/")

# Single backslash: "\\n" would print a literal backslash followed by "n".
print("\n🎯 Final Test Results:")
print(f"Test Accuracy: {test_metrics['accuracy']:.1%}")
print(f"Test Loss: {test_metrics['loss']:.4f}")

# Per-class accuracy
if 'per_class_accuracy' in test_metrics:
    # NOTE(review): zip() silently drops any classes beyond these four names —
    # confirm the label order/count matches the dataset loader.
    categories = ['Hate Speech', 'Self Harm', 'Dangerous Advice', 'Harassment']
    per_class_acc = test_metrics['per_class_accuracy']
    print("\n📋 Per-Category Accuracy:")
    for cat, acc in zip(categories, per_class_acc):
        print(f"  {cat}: {acc:.1%}")

# Evaluate the trained model (fallback)
# NOTE(review): this repeats the evaluation immediately above in the same
# cell: it would build a second trainer, reload the test split and evaluate
# again. Reuse the metrics already computed when they exist; the full
# evaluation only runs when no `test_metrics` is defined yet.
if 'test_metrics' in globals():
    print("ℹ️ Test metrics already computed above; skipping duplicate evaluation.")
else:
    from training.trainer import SafetyTrainer
    from data.dataset_loader import create_data_loaders
    import yaml

    print("📊 Evaluating trained model...")

    # Build the trainer from the Colab config.
    # NOTE(review): whether this restores the best checkpoint depends on
    # SafetyTrainer.__init__, which is not visible here — confirm.
    trainer = SafetyTrainer('configs/colab_config.yaml')

    # Load data; only the third returned split (test) is used here
    _, _, test_dataset = create_data_loaders('configs/colab_config.yaml')

    # Evaluate on test set
    test_metrics = trainer.evaluate(test_dataset, 0, "test/")

    print("\n🎯 Final Test Results:")
    print(f"Test Accuracy: {test_metrics['accuracy']:.1%}")
    print(f"Test Loss: {test_metrics['loss']:.4f}")

    # Per-class accuracy
    if 'per_class_accuracy' in test_metrics:
        # NOTE(review): zip() silently drops any classes beyond these four
        # names — confirm the label order/count matches the dataset loader.
        categories = ['Hate Speech', 'Self Harm', 'Dangerous Advice', 'Harassment']
        per_class_acc = test_metrics['per_class_accuracy']
        print("\n📋 Per-Category Accuracy:")
        for cat, acc in zip(categories, per_class_acc):
            print(f"  {cat}: {acc:.1%}")