# 🛡️ Safety Text Classifier Training - Google Colab

**Constitutional AI Research Project - Stage 1**

This notebook trains a transformer-based safety text classifier using JAX/Flax on a Google Colab GPU.

## 📋 Prerequisites
- GPU runtime enabled in Colab
- Weights & Biases account for experiment tracking
- Project files uploaded or cloned

## 🎯 Expected Results
- Training time: ~2-3 hours on GPU
- Target accuracy: 85%+ on safety classification
- Model size: ~50MB

## 1. 🚀 Environment Setup

In [None]:
# Check GPU availability
# Shell magic: runs nvidia-smi in the Colab VM. If the command is missing or
# reports an error, enable a GPU runtime first (see Prerequisites above).
!nvidia-smi

In [None]:
# Mount Google Drive for persistent storage
from google.colab import drive
import os

drive.mount('/content/drive')

# Drive-backed output locations: one for training checkpoints, one for models
checkpoint_dir = '/content/drive/MyDrive/safety-classifier-checkpoints'
model_dir = '/content/drive/MyDrive/safety-classifier-models'

# exist_ok=True keeps this cell safe to re-run once the folders exist
for drive_dir in (checkpoint_dir, model_dir):
    os.makedirs(drive_dir, exist_ok=True)

print(f"✅ Checkpoint directory: {checkpoint_dir}")
print(f"✅ Model directory: {model_dir}")

In [None]:
# Install uv package manager (10-100x faster than pip)
!curl -LsSf https://astral.sh/uv/install.sh | sh

# Add uv to PATH
import os
uv_path = os.path.expanduser('~/.cargo/bin')
os.environ['PATH'] = f"{uv_path}:{os.environ['PATH']}"

# Test uv
!uv --version

In [None]:
# Navigate to project directory
import os
from pathlib import Path

def find_and_navigate_to_project():
    """Locate the safety-text-classifier project and switch into it.

    The current working directory is accepted as-is when it is named
    ``safety-text-classifier`` and contains ``src``. Otherwise two known
    locations below the current directory are tried in order, and the first
    one containing ``src`` becomes the new working directory.

    Returns:
        The project path as a string, or ``None`` when no candidate matched
        (in which case the visible sub-directories are listed as a hint).
    """
    start = Path.cwd()
    print(f"📍 Starting from: {start}")

    def has_src(directory):
        # A project root is recognised by its "src" entry
        return (directory / "src").exists()

    # Case 1: the notebook is already running inside the project
    if start.name == "safety-text-classifier" and has_src(start):
        print("✅ Already in safety-text-classifier directory!")
        return str(start)

    # Case 2: the project sits in one of the usual places below the cwd
    candidates = (
        start / "safety-text-classifier",
        start / "ml-learning" / "safety-text-classifier",
    )
    match = next((c for c in candidates if c.exists() and has_src(c)), None)
    if match is not None:
        print(f"✅ Found project at: {match}")
        os.chdir(match)
        return str(match)

    # Case 3: nothing found; show what is here to help the user
    print("❌ Could not find safety-text-classifier directory!")
    print("📋 Available directories:")
    visible_dirs = (
        entry for entry in start.iterdir()
        if entry.is_dir() and not entry.name.startswith('.')
    )
    for entry in visible_dirs:
        print(f"  📁 {entry.name}")
    return None

# Navigate to project
project_path = find_and_navigate_to_project()

if not project_path:
    # No project directory was located from the current working directory
    print("\n⚠️ Please upload the safety-text-classifier folder to Colab")
else:
    # Verify structure: these entries are looked up relative to the cwd
    required_items = ['src', 'configs', 'pyproject.toml']
    missing = [name for name in required_items if not os.path.exists(name)]

    if missing:
        print(f"\n❌ Missing: {missing}")
    else:
        print("\n🎉 Project structure verified!")
        print(f"📁 Working directory: {Path.cwd()}")

## 2. 🏋️ Model Training

In [None]:
# Setup Python path and test imports
import sys
import os
from pathlib import Path

# Add src to Python path so the project's packages (training, data) resolve
src_path = str(Path.cwd() / "src")
if src_path not in sys.path:
    sys.path.insert(0, src_path)
    print(f"✅ Added to Python path: {src_path}")

# Test imports; SafetyTrainer and create_data_loaders are used by later cells
try:
    from training.trainer import SafetyTrainer
    from data.dataset_loader import create_data_loaders
    print("✅ Successfully imported training modules")

    import jax
    import jax.numpy as jnp
    print(f"✅ JAX {jax.__version__} ready")
    print(f"🎯 Available devices: {jax.devices()}")

    # Check for GPU using the device's platform attribute. Matching 'gpu'
    # against str(device) is unreliable: recent JAX releases render CUDA
    # devices as e.g. "cuda:0", which would be reported as "no GPU".
    gpu_devices = [d for d in jax.devices() if d.platform in ('gpu', 'cuda', 'rocm')]
    if gpu_devices:
        print(f"🚀 GPU acceleration ready: {len(gpu_devices)} GPU(s)")
    else:
        print("⚠️ No GPU detected - training will be slower")

except ImportError as e:
    # Print enough context to tell a wrong working directory from a missing package
    print(f"❌ Import error: {e}")
    print("\n🔧 Troubleshooting:")
    print(f"  - Current directory: {os.getcwd()}")
    print(f"  - Python path: {sys.path[:3]}...")
    print(f"  - Contents of src/: {os.listdir('src') if os.path.exists('src') else 'src not found'}")

In [None]:
# Setup Weights & Biases
import wandb

# Login to W&B
# wandb.login() reports via its return value whether an API key is configured,
# so only announce success when it actually succeeded.
if wandb.login():
    print("✅ W&B authentication complete")
    print("📊 Training metrics will be logged to W&B dashboard")
else:
    print("❌ W&B authentication failed - re-run this cell before starting training")

In [None]:
# Start training with enhanced model saving
for banner_line in (
    "🚀 Starting training with Colab configuration...",
    "📊 Monitor progress in W&B dashboard",
    "⏱️  Expected training time: 2-3 hours",
    "💾 Model will be saved to Google Drive and W&B",
    "=" * 60,
):
    print(banner_line)

try:
    # Build the trainer from the Colab-specific configuration
    trainer = SafetyTrainer(config_path="configs/colab_config.yaml")

    # Redirect outputs to Google Drive.
    # NOTE(review): the paths are overridden after construction; confirm that
    # SafetyTrainer reads them at train time rather than in __init__.
    drive_paths = trainer.config['paths']
    drive_paths['checkpoint_dir'] = '/content/drive/MyDrive/safety-classifier-checkpoints'
    drive_paths['model_dir'] = '/content/drive/MyDrive/safety-classifier-models'

    trainer.train()

except Exception as err:
    # Report the failure without aborting the notebook; the traceback gives details
    import traceback
    print(f"❌ Training failed: {err}")
    traceback.print_exc()

else:
    print("\n🎉 Training completed successfully!")

finally:
    # Close any W&B run that is still open, whether training succeeded or not
    if 'wandb' in globals() and wandb.run:
        wandb.finish()

## 3. 📊 Model Evaluation

In [None]:
# Evaluate the trained model
import yaml

print("📊 Evaluating trained model...")


def _print_banner(title):
    """Print *title* framed by two 60-character rules, after a blank line."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


try:
    # Fresh trainer from the Colab config, pointed at the Drive checkpoints
    trainer = SafetyTrainer('configs/colab_config.yaml')
    trainer.config['paths']['checkpoint_dir'] = '/content/drive/MyDrive/safety-classifier-checkpoints'

    # Restore trained weights.
    # NOTE(review): which checkpoint is selected is decided inside SafetyTrainer.
    trainer.load_checkpoint()

    # Only the third value returned by create_data_loaders (the test data) is used
    _, _, test_dataset = create_data_loaders('configs/colab_config.yaml')

    test_metrics = trainer.evaluate(test_dataset, 0, "test/")

    # Headline metrics
    _print_banner("🎯 FINAL TEST RESULTS")
    print(f"Test Accuracy: {test_metrics['accuracy']:.1%}")
    print(f"Test Loss: {test_metrics['loss']:.4f}")

    # Optional per-category breakdown; labels are paired positionally with the scores
    if 'per_class_accuracy' in test_metrics:
        categories = ['Hate Speech', 'Self Harm', 'Dangerous Advice', 'Harassment']
        per_class_acc = test_metrics['per_class_accuracy']
        print("\n📋 Per-Category Accuracy:")
        for label, score in zip(categories, per_class_acc):
            print(f"  {label}: {score:.1%}")

    # Stage 1 passes when test accuracy reaches the 85% target
    accuracy = float(test_metrics['accuracy'])
    target_accuracy = 0.85

    _print_banner("🎖️  STAGE 1 STATUS")

    if accuracy >= target_accuracy:
        print(f"✅ STAGE 1 COMPLETE! ({accuracy:.1%} ≥ {target_accuracy:.0%})")
        print("🎉 Ready to proceed to Stage 2: Helpful Response Fine-tuning")
    else:
        print(f"❌ Stage 1 incomplete ({accuracy:.1%} < {target_accuracy:.0%})")
        print("📝 Consider adjusting hyperparameters or training longer")

except Exception as err:
    # Keep the notebook alive and show where evaluation broke
    import traceback
    print(f"❌ Evaluation failed: {err}")
    traceback.print_exc()

## 4. 💾 Save and Download Model

In [None]:
# Save model to W&B artifacts
import wandb
import os
import shutil

print("💾 Saving trained model to W&B artifacts...")

try:
    # Initialize a dedicated W&B run for model saving
    wandb.init(
        project="constitutional-ai-research-colab",
        job_type="model-save",
        name="model-artifacts"
    )

    # Create model artifact
    model_artifact = wandb.Artifact(
        name="safety-text-classifier",
        type="model",
        description="Trained safety text classifier - Stage 1 complete"
    )

    # Add checkpoint directory to artifact (skipped when missing or empty)
    checkpoint_dir = '/content/drive/MyDrive/safety-classifier-checkpoints'
    if os.path.exists(checkpoint_dir) and os.listdir(checkpoint_dir):
        model_artifact.add_dir(checkpoint_dir, name="checkpoints")
        print(f"✅ Added checkpoints from {checkpoint_dir}")
    else:
        print(f"⚠️  No checkpoints found at {checkpoint_dir}")

    # Add config file so the artifact records how the model was trained
    if os.path.exists('configs/colab_config.yaml'):
        model_artifact.add_file('configs/colab_config.yaml', name="config.yaml")
        print("✅ Added configuration file")

    # Log the artifact. Logging is asynchronous and the version is only
    # assigned once the artifact has been committed, so wait for that before
    # reading `version` below; otherwise a successful save is reported as failed.
    wandb.log_artifact(model_artifact)
    model_artifact.wait()

    print(f"🎉 Model saved to W&B! Artifact: {model_artifact.name}:{model_artifact.version}")
    print(f"🔗 View at: https://wandb.ai/{wandb.run.entity}/{wandb.run.project}/artifacts")

except Exception as e:
    # Same reporting style as the training and evaluation cells
    print(f"❌ W&B save failed: {e}")
    import traceback
    traceback.print_exc()

finally:
    # Always close the run opened above
    if wandb.run:
        wandb.finish()

In [None]:
# Create downloadable model package
import zipfile
import os
from google.colab import files
import shutil
from datetime import datetime

print("📦 Creating downloadable model package...")


def zip_directory(source_dir, zip_path):
    """Write every file under *source_dir* into a new zip archive at *zip_path*.

    Entries are stored relative to *source_dir*, so the archive has no
    top-level folder.

    The loop variables deliberately avoid the names ``files`` and ``file``:
    at notebook scope ``files`` is the ``google.colab.files`` module, and
    rebinding it to an ``os.walk`` list breaks the later ``files.download()``.
    """
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, _dirs, filenames in os.walk(source_dir):
            for filename in filenames:
                file_path = os.path.join(root, filename)
                zipf.write(file_path, os.path.relpath(file_path, source_dir))


try:
    # Create package directory; the timestamp makes each run's package unique
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    package_name = f"safety_classifier_{timestamp}"
    package_dir = f"/content/{package_name}"

    os.makedirs(package_dir, exist_ok=True)

    # Copy checkpoints from Drive (skipped when the folder is missing or empty)
    checkpoint_source = '/content/drive/MyDrive/safety-classifier-checkpoints'
    checkpoint_dest = os.path.join(package_dir, 'checkpoints')

    if os.path.exists(checkpoint_source) and os.listdir(checkpoint_source):
        shutil.copytree(checkpoint_source, checkpoint_dest)
        print(f"✅ Copied checkpoints to package")
    else:
        print(f"⚠️  No checkpoints found to package")

    # Copy essential files, preserving their relative layout inside the package
    files_to_copy = [
        'configs/colab_config.yaml',
        'src/training/trainer.py',
        'src/models/transformer.py',
        'src/data/dataset_loader.py'
    ]

    for file_path in files_to_copy:
        if os.path.exists(file_path):
            dest_path = os.path.join(package_dir, file_path)
            os.makedirs(os.path.dirname(dest_path), exist_ok=True)
            shutil.copy2(file_path, dest_path)
            print(f"✅ Copied {file_path}")

    # Create README
    readme_content = f"""# Safety Text Classifier - Trained Model

**Training completed**: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
**Stage 1**: Complete

## Contents
- `checkpoints/`: Model checkpoints
- `configs/colab_config.yaml`: Training configuration
- `src/`: Essential source code

## Usage
```python
from src.training.trainer import SafetyTrainer

# Load trained model
trainer = SafetyTrainer('configs/colab_config.yaml')
trainer.load_checkpoint()
```
"""

    with open(os.path.join(package_dir, 'README.md'), 'w') as f:
        f.write(readme_content)

    # Create zip file next to the package directory
    zip_path = f"/content/{package_name}.zip"
    zip_directory(package_dir, zip_path)

    # Get zip size in MB
    zip_size = os.path.getsize(zip_path) / (1024*1024)

    print(f"\n🎉 Model package created successfully!")
    print(f"📦 Package: {package_name}.zip")
    print(f"📏 Size: {zip_size:.1f} MB")
    print(f"\n📥 Downloading...")

    # Download the package through the Colab browser helper
    files.download(zip_path)

    print(f"\n✅ Download complete!")
    print(f"\n🔧 To use locally:")
    print(f"1. Extract the zip file")
    print(f"2. Update config paths to point to local directories")
    print(f"3. Run your evaluation script")

except Exception as e:
    print(f"❌ Package creation failed: {e}")
    import traceback
    traceback.print_exc()

## 🎯 Next Steps

With Stage 1 complete, you're ready for:

1. **Stage 2**: Helpful Response Fine-tuning
2. **Model Deployment**: Set up inference server
3. **Demo Creation**: Interactive safety classifier interface

### Local Setup
After downloading the model package:

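```bash
# Extract model (the archive has no top-level folder, so extract into one)
unzip safety_classifier_*.zip -d safety_classifier
cd safety_classifier

# Update paths in your local config
# Then run evaluation (evaluate_model.py is not included in the package;
# substitute your own evaluation script)
python3 evaluate_model.py
```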