# Safety Text Classifier Training - Constitutional AI Stage 1

**Project**: Constitutional AI Research Pipeline  
**Stage**: 1 of 4 - Safety Text Classifier Foundation  
**Framework**: JAX/Flax with GPU acceleration  
**Target**: 85%+ exact-match accuracy on 4-label (multi-label) safety classification  

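This notebook trains a transformer-based safety text classifier that flags four independent (multi-label) harm categories; text matching none of them is treated as safe:
- Hate speech
- Self-harm instructions  
- Dangerous advice (e.g. instructions for making weapons or poisons)
- Harassment and threats (toxic/offensive examples from toxic-chat are also mapped here)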

## 🚀 Getting Started
1. **Runtime**: Change to GPU (Runtime → Change runtime type → GPU)
2. **Execute all cells** to train the model
3. **Download trained model** at the end


## 📦 Environment Setup

Install compatible JAX/Flax versions and dependencies for Colab GPU training.

In [None]:
# Install compatible JAX/Flax for Colab GPU
# NOTE(review): newer JAX releases document the CUDA extra as "jax[cuda12]";
# confirm the "cuda12_pip" extra is still accepted by the JAX version pip resolves.
!pip install -q "jax[cuda12_pip]" "flax" "optax>=0.1.7"
!pip install -q "datasets>=2.19.0" "transformers>=4.40.0" "wandb>=0.17.0"
!pip install -q "scikit-learn>=1.3.0" "matplotlib>=3.7.0" "seaborn>=0.12.0"
!pip install -q "tqdm>=4.66.0" "pyyaml>=6.0.0"
# Fix checkpoint compatibility
# NOTE(review): flax is unpinned while orbax-checkpoint is capped below 0.6.0;
# verify the flax version pip resolves still works with that orbax range.
!pip install -q "orbax-checkpoint<0.6.0"

print("✅ Dependencies installed!")

In [None]:
# Verify GPU and JAX setup
import jax
import jax.numpy as jnp
import flax
import optax

print(f"🔧 JAX version: {jax.__version__}")
print(f"🔧 Flax version: {flax.__version__}")
print(f"🔧 JAX devices: {jax.devices()}")

# Test GPU.
# `Device.device_kind` is a vendor model string (e.g. "Tesla T4"), so comparing
# it with 'gpu' never matches. `Device.platform` is the backend name
# ('cpu' / 'gpu' / 'tpu'), which is what this check actually needs.
if jax.devices()[0].platform == 'gpu':
    print("🚀 GPU acceleration enabled!")
else:
    print("⚠️  Using CPU (consider enabling GPU runtime)")

# Quick JAX test: one 1000x1000 matmul on the default device.
x = jnp.ones((1000, 1000))
result = jnp.dot(x, x)
print(f"✅ JAX working: {result.shape}")

## 📁 Project Structure Setup

Create the necessary directories and configuration files.

In [None]:
import os
import yaml
from pathlib import Path

# Lay out the working directories used by later cells. Creation is
# idempotent: directories that already exist are left untouched.
directories = ['configs', 'checkpoints', 'logs', 'data', 'src/models', 'src/data', 'src/training']
for directory in directories:
    Path(directory).mkdir(parents=True, exist_ok=True)

print("📁 Project structure created!")
print("\n📋 Directories:")
print("\n".join(f"   ✅ {name}/" for name in directories))

In [None]:
# Create Colab-optimized configuration
# This dict is saved to configs/colab_config.yaml below; the data-loading cell
# re-reads that file, so edit values here and re-run both cells.
config = {
    'model': {
        'name': 'safety_transformer',
        # Rows in the token-embedding table.
        # NOTE(review): must be >= the tokenizer's vocabulary size (data.tokenizer) — confirm.
        'vocab_size': 32000,
        'embedding_dim': 512,  # Smaller for Colab; must be divisible by num_heads
        'num_layers': 4,       # Reduced layers
        'num_heads': 8,
        'feedforward_dim': 2048,
        'max_sequence_length': 256,  # Shorter sequences; kept equal to data.max_length
        'dropout_rate': 0.1,
        'num_classes': 4       # one sigmoid logit per safety category (multi-label)
    },
    'training': {
        'batch_size': 16,      # Colab-friendly batch size
        'learning_rate': 0.0001,  # peak LR, reached at the end of warmup
        'warmup_steps': 500,
        'max_steps': 3000,     # Shorter training for demo; also the LR schedule length
        'eval_every': 500,
        'save_every': 1000,
        'gradient_clip_norm': 1.0,
        # NOTE(review): 'optimizer' and 'schedule' look descriptive only —
        # create_optimizer always builds AdamW with warmup + cosine decay.
        'optimizer': 'adamw',
        'weight_decay': 0.01,
        'beta1': 0.9,
        'beta2': 0.999,
        'schedule': 'cosine_with_warmup',
        'min_lr_ratio': 0.1    # final LR = learning_rate * min_lr_ratio
    },
    'data': {
        # NOTE(review): SafetyDatasetLoader hard-codes its toxic-chat load and
        # does not read this list in the visible code.
        'datasets': [
            {'name': 'lmsys/toxic-chat', 'config': 'toxicchat0124'}
        ],
        'max_length': 256,  # tokenizer truncates/pads every example to this length
        'tokenizer': 'sentence-transformers/all-MiniLM-L6-v2',
        # Fractions of the combined dataset; the test split is carved off first.
        'train_split': 0.8,
        'val_split': 0.1,
        'test_split': 0.1,
        # NOTE(review): augmentation settings are not read in the visible code.
        'text_augmentation': True,
        'augmentation_prob': 0.1
    },
    'logging': {
        'wandb': {
            'project': 'constitutional-ai-colab',
            'entity': None,
            'tags': ['stage1', 'safety', 'colab', 'gpu']
        },
        'log_level': 'INFO',
        'log_dir': 'logs'
    },
    'paths': {
        'data_dir': 'data',
        'checkpoint_dir': 'checkpoints',
        'log_dir': 'logs'
    }
}

# Save configuration
with open('configs/colab_config.yaml', 'w') as f:
    yaml.dump(config, f, default_flow_style=False)

print("✅ Configuration created: configs/colab_config.yaml")
print(f"🎯 Target: {config['training']['max_steps']} steps with batch size {config['training']['batch_size']}")
print(f"🧠 Model: {config['model']['num_layers']} layers, {config['model']['embedding_dim']} dim")

## 🏗️ Model Architecture

Implement the Safety Transformer model using JAX/Flax.

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import Dense, Dropout, LayerNorm, Embed
from typing import Callable, Optional, Tuple, Any
import numpy as np
from functools import partial

class MultiHeadAttention(nn.Module):
    """Scaled dot-product self-attention split across ``num_heads`` heads.

    Attributes:
        num_heads: number of attention heads.
        head_dim: width of each head; projections are num_heads * head_dim wide.
        dropout_rate: dropout applied to the attention probabilities.
    """
    num_heads: int
    head_dim: int
    dropout_rate: float = 0.1

    def setup(self):
        # Attribute names become parameter-tree keys; keep them stable so
        # existing checkpoints still load.
        width = self.num_heads * self.head_dim
        self.dense_q = Dense(width, use_bias=False)
        self.dense_k = Dense(width, use_bias=False)
        self.dense_v = Dense(width, use_bias=False)
        self.dense_output = Dense(width)
        self.dropout = Dropout(self.dropout_rate)

    def __call__(self, x, mask=None, training=True):
        """Return ``(output, attention_weights)``.

        ``x`` is unpacked as (batch, length, features). The returned attention
        weights are the post-softmax, post-dropout probabilities.
        """
        batch, length, _ = x.shape

        def to_heads(t):
            # (batch, length, heads * head_dim) -> (batch, heads, length, head_dim)
            t = t.reshape(batch, length, self.num_heads, self.head_dim)
            return jnp.transpose(t, (0, 2, 1, 3))

        q = to_heads(self.dense_q(x))
        k = to_heads(self.dense_k(x))
        v = to_heads(self.dense_v(x))

        scores = jnp.matmul(q, jnp.swapaxes(k, -1, -2)) / jnp.sqrt(self.head_dim)

        if mask is not None:
            # Assumes a (batch, length) boolean key mask; it is broadcast over
            # heads and query positions, and masked keys get a large negative score.
            scores = jnp.where(jnp.expand_dims(mask, axis=(1, 2)), scores, -1e9)

        attention_weights = self.dropout(
            jax.nn.softmax(scores, axis=-1), deterministic=not training
        )

        context = jnp.matmul(attention_weights, v)
        merged = jnp.transpose(context, (0, 2, 1, 3)).reshape(
            batch, length, self.num_heads * self.head_dim
        )
        return self.dense_output(merged), attention_weights

class FeedForward(nn.Module):
    """Position-wise MLP: Dense -> GELU -> Dropout -> Dense.

    Attributes:
        hidden_dim: width of the intermediate layer.
        output_dim: width of the returned features.
        dropout_rate: dropout applied after the GELU.
    """
    hidden_dim: int
    output_dim: int
    dropout_rate: float = 0.1

    def setup(self):
        # Attribute names fix the parameter-tree keys; keep them stable.
        self.dense1 = Dense(self.hidden_dim)
        self.dense2 = Dense(self.output_dim)
        self.dropout = Dropout(self.dropout_rate)

    def __call__(self, x, training=True):
        hidden = jax.nn.gelu(self.dense1(x))
        hidden = self.dropout(hidden, deterministic=not training)
        return self.dense2(hidden)

class TransformerBlock(nn.Module):
    """Post-LayerNorm transformer encoder layer (self-attention + feed-forward).

    Each sublayer's output passes through dropout and is added back to its
    input, and the sum is layer-normalised. The model width is
    num_heads * head_dim.
    """
    num_heads: int
    head_dim: int
    feedforward_dim: int
    dropout_rate: float = 0.1

    def setup(self):
        # Attribute names below become parameter-tree keys; keep them stable.
        model_width = self.num_heads * self.head_dim
        self.attention = MultiHeadAttention(
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            dropout_rate=self.dropout_rate,
        )
        self.feed_forward = FeedForward(
            hidden_dim=self.feedforward_dim,
            output_dim=model_width,
            dropout_rate=self.dropout_rate,
        )
        self.layer_norm1 = LayerNorm()
        self.layer_norm2 = LayerNorm()
        # A single Dropout module serves both residual branches.
        self.dropout = Dropout(self.dropout_rate)

    def __call__(self, x, mask=None, training=True):
        """Return ``(x_out, attention_weights)`` for one encoder layer."""
        deterministic = not training

        attended, attn_weights = self.attention(x, mask=mask, training=training)
        x = self.layer_norm1(x + self.dropout(attended, deterministic=deterministic))

        transformed = self.feed_forward(x, training=training)
        x = self.layer_norm2(x + self.dropout(transformed, deterministic=deterministic))

        return x, attn_weights

print("✅ Transformer components defined!")

In [None]:
def _sinusoidal_table(length, embed_dim):
    """Return the (length, embed_dim) sinusoidal position table.

    Even columns hold sin(pos * w_i) and odd columns cos(pos * w_i), with
    w_i = 10000 ** (-2i / embed_dim). Odd ``embed_dim`` is supported: there is
    one fewer cosine column than frequencies, so the extra frequency is dropped.
    """
    position = jnp.arange(length)[:, None]
    div_term = jnp.exp(
        jnp.arange(0, embed_dim, 2) * -(jnp.log(10000.0) / embed_dim)
    )
    angles = position * div_term

    pe = jnp.zeros((length, embed_dim))
    pe = pe.at[:, ::2].set(jnp.sin(angles))
    # ceil(d/2) frequencies but only floor(d/2) odd columns.
    pe = pe.at[:, 1::2].set(jnp.cos(angles[:, : embed_dim // 2]))
    return pe


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to token embeddings.

    Attributes:
        max_length: number of positions precomputed in ``setup``.
        embed_dim: feature width; must match the last axis of the input.
    """
    max_length: int
    embed_dim: int

    def setup(self):
        # Precomputed constant (not a trainable parameter), sliced per call.
        self.pe = _sinusoidal_table(self.max_length, self.embed_dim)

    def __call__(self, x):
        """Return ``x`` plus the encodings for its first ``x.shape[1]`` positions."""
        seq_len = x.shape[1]
        if seq_len > self.max_length:
            # Longer than the cached table: build a table of the needed length.
            return x + _sinusoidal_table(seq_len, self.embed_dim)
        return x + self.pe[:seq_len]

print("✅ Positional encoding defined!")

In [None]:
class SafetyTransformer(nn.Module):
    """Transformer-based safety text classifier.

    Token embeddings plus sinusoidal positions feed ``num_layers`` encoder
    blocks. The final hidden states are mean-pooled over non-padding
    positions and projected to ``num_classes`` logits.

    Attributes:
        vocab_size: rows in the token-embedding table.
        embedding_dim: model width; must be divisible by ``num_heads``.
        num_layers: number of TransformerBlocks.
        num_heads: attention heads per block.
        feedforward_dim: hidden width of each block's feed-forward layer.
        max_sequence_length: length of the precomputed positional table.
        num_classes: number of output logits.
        dropout_rate: dropout used in the blocks and on embeddings / pooled vector.
    """
    vocab_size: int
    embedding_dim: int
    num_layers: int
    num_heads: int
    feedforward_dim: int
    max_sequence_length: int
    num_classes: int
    dropout_rate: float = 0.1

    def setup(self):
        # Explicit check instead of `assert`, so it survives `python -O`. It runs
        # before head_dim is derived from the same quantities.
        if self.embedding_dim % self.num_heads != 0:
            raise ValueError("embedding_dim must be divisible by num_heads")
        self.head_dim = self.embedding_dim // self.num_heads

        # Embedding layers
        self.token_embedding = Embed(num_embeddings=self.vocab_size, features=self.embedding_dim)
        self.positional_encoding = PositionalEncoding(
            max_length=self.max_sequence_length, embed_dim=self.embedding_dim
        )

        # Transformer layers
        self.transformer_blocks = [
            TransformerBlock(
                num_heads=self.num_heads,
                head_dim=self.head_dim,
                feedforward_dim=self.feedforward_dim,
                dropout_rate=self.dropout_rate,
            )
            for _ in range(self.num_layers)
        ]

        # Classification head
        self.layer_norm = LayerNorm()
        self.dropout = Dropout(self.dropout_rate)
        self.classifier = Dense(self.num_classes)

    def create_attention_mask(self, input_ids):
        """Create attention mask from input_ids (assuming 0 is padding token)."""
        return input_ids != 0

    def __call__(self, input_ids, training=True):
        """Forward pass of the safety transformer.

        Args:
            input_ids: integer token ids shaped (batch, seq_len); id 0 is
                treated as padding.
            training: enables dropout, which draws from the 'dropout' RNG stream.

        Returns:
            Dict with 'logits' (batch, num_classes), 'attention_weights' (one
            entry per block) and 'hidden_states' (final layer-normed states).
        """
        # Unpacking enforces rank-2 (batch, seq_len) input.
        batch_size, seq_len = input_ids.shape

        # Create attention mask
        attention_mask = self.create_attention_mask(input_ids)

        # Token embeddings
        x = self.token_embedding(input_ids)

        # Add positional encoding
        x = self.positional_encoding(x)
        x = self.dropout(x, deterministic=not training)

        # Pass through transformer blocks
        attention_weights = []
        for transformer_block in self.transformer_blocks:
            x, attn_weights = transformer_block(x, mask=attention_mask, training=training)
            attention_weights.append(attn_weights)

        # Apply final layer norm
        x = self.layer_norm(x)

        # Masked mean pooling: padding positions contribute zero, and the divisor
        # is clamped to 1 so an all-padding row does not divide by zero.
        mask_expanded = jnp.expand_dims(attention_mask, axis=-1)
        x_masked = x * mask_expanded
        seq_lengths = jnp.sum(attention_mask, axis=1, keepdims=True)
        pooled = jnp.sum(x_masked, axis=1) / jnp.maximum(seq_lengths, 1)

        # Classification
        pooled = self.dropout(pooled, deterministic=not training)
        logits = self.classifier(pooled)

        return {
            "logits": logits,
            "attention_weights": attention_weights,
            "hidden_states": x,
        }

def create_model(config):
    """Build a SafetyTransformer from the ``model`` section of ``config``.

    Each hyperparameter is looked up by key, so a missing key raises KeyError.
    """
    hyper = config["model"]
    field_names = (
        "vocab_size",
        "embedding_dim",
        "num_layers",
        "num_heads",
        "feedforward_dim",
        "max_sequence_length",
        "num_classes",
        "dropout_rate",
    )
    return SafetyTransformer(**{name: hyper[name] for name in field_names})

def initialize_model(model, rng_key, input_shape=None):
    """Initialize model parameters from a dummy integer batch.

    Args:
        model: a Flax module exposing ``max_sequence_length`` (used for the
            default shape) and accepting a ``training`` keyword.
        rng_key: PRNG key for parameter initialisation.
        input_shape: shape of the dummy ``input_ids``; defaults to
            ``(1, model.max_sequence_length)``.

    Returns:
        The variables returned by ``model.init``.
    """
    if input_shape is None:
        input_shape = (1, model.max_sequence_length)

    dummy_input = jnp.ones(input_shape, dtype=jnp.int32)
    # Initialise in inference mode. Dropout creates no parameters; running it
    # during init would, depending on the Flax version, either require a
    # separate 'dropout' RNG or draw from the 'params' stream.
    params = model.init(rng_key, dummy_input, training=False)
    return params

print("✅ SafetyTransformer model defined!")

## 📊 Data Loading and Preprocessing

Load safety datasets and prepare them for training.

In [None]:
import logging
from typing import Dict, List, Tuple, Optional, Union
import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict, load_dataset, concatenate_datasets
from transformers import AutoTokenizer
import yaml

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class SafetyDatasetLoader:
    """Loads and preprocesses safety classification datasets for training.

    Produces tokenized train/val/test splits. The ``labels`` column is a
    multi-hot vector indexed by ``safety_categories``, and an all-zero vector
    means "safe".
    """

    def __init__(self, config):
        """Store ``config`` and load the tokenizer named in ``config['data']['tokenizer']``."""
        self.config = config
        self.data_config = config['data']
        self.tokenizer = AutoTokenizer.from_pretrained(self.data_config['tokenizer'])

        # Safety categories mapping: category name -> position in the label vector.
        self.safety_categories = {
            'hate_speech': 0,
            'self_harm': 1,
            'dangerous_advice': 2,
            'harassment': 3
        }

    def create_synthetic_dataset(self, size: int = 1000, seed: Optional[int] = None) -> Dataset:
        """Create a synthetic dataset for testing and development.

        Args:
            size: target number of examples, split evenly across the harm
                categories plus 'safe' (any remainder is dropped).
            seed: seed for a private RNG. ``None`` keeps drawing from NumPy's
                global RNG, as before.

        Returns:
            Dataset with 'text', 'labels' (multi-hot) and 'source' columns.
        """
        logger.info(f"Creating synthetic dataset with {size} examples...")
        rng = np.random if seed is None else np.random.RandomState(seed)

        # Templates for different safety categories
        templates = {
            'hate_speech': [
                "I hate {group} people, they are terrible",
                "{group} are the worst, they should not exist",
                "All {group} are stupid and worthless",
            ],
            'self_harm': [
                "I want to hurt myself by {method}",
                "Instructions for self-harm: {method}",
                "How to cause pain to yourself using {method}",
            ],
            'dangerous_advice': [
                "Here's how to make {dangerous_item} at home",
                "Instructions for creating {dangerous_item}",
                "Easy recipe for {dangerous_item}",
            ],
            'harassment': [
                "I'm going to find you and {threat}",
                "You better watch out, I will {threat}",
                "I know where you live, I will {threat}",
            ],
            'safe': [
                "The weather is nice today",
                "I enjoy reading books in my free time",
                "Cooking is a wonderful hobby to have",
                "Exercise is important for good health",
                "I love spending time with friends and family",
            ]
        }

        # Placeholder name and candidate values for each templated category.
        # 'safe' templates are used verbatim.
        fillers = {
            'hate_speech': ('group', ["women", "men", "teenagers", "elderly"]),
            'self_harm': ('method', ["cutting", "burning", "hitting"]),
            'dangerous_advice': ('dangerous_item', ["explosives", "poison", "weapons"]),
            'harassment': ('threat', ["hurt you", "make you pay", "teach you a lesson"]),
        }

        texts = []
        labels = []

        # Equal share per category, the 'safe' bucket included.
        examples_per_category = size // len(templates)

        for category, category_templates in templates.items():
            for _ in range(examples_per_category):
                template = rng.choice(category_templates)
                if category in fillers:
                    placeholder, values = fillers[category]
                    text = template.format(**{placeholder: rng.choice(values)})
                else:  # safe
                    text = template
                texts.append(text)

                # Multi-hot label; 'safe' stays all zeros.
                label = [0] * len(self.safety_categories)
                if category != 'safe':
                    label[self.safety_categories[category]] = 1
                labels.append(label)

        return Dataset.from_dict({
            'text': texts,
            'labels': labels,
            'source': ['synthetic'] * len(texts)
        })

    def tokenize_dataset(self, dataset: Dataset) -> Dataset:
        """Tokenize 'text' (truncated and padded to data.max_length) and drop the raw text."""
        def tokenize_function(examples):
            return self.tokenizer(
                examples['text'],
                truncation=True,
                padding='max_length',
                max_length=self.data_config['max_length'],
                return_tensors=None
            )

        return dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=['text']
        )

    def create_data_splits(self, dataset: Dataset) -> Tuple[Dataset, Dataset, Dataset]:
        """Split dataset into train, validation, and test sets (seeded, seed=42)."""
        train_split = self.data_config['train_split']
        val_split = self.data_config['val_split']
        test_split = self.data_config['test_split']

        # First split: separate test set
        split_1 = dataset.train_test_split(test_size=test_split, seed=42)
        train_val = split_1['train']
        test = split_1['test']

        # Second split: rescale val_split to a fraction of what remains.
        val_size_adjusted = val_split / (train_split + val_split)
        split_2 = train_val.train_test_split(test_size=val_size_adjusted, seed=42)
        train = split_2['train']
        val = split_2['test']

        logger.info(f"Data splits created: Train={len(train)}, Val={len(val)}, Test={len(test)}")
        return train, val, test

    def _load_toxic_chat(self) -> Optional[Dataset]:
        """Load up to 1000 toxic-chat training rows as a labelled Dataset.

        Toxic rows (``toxicity > 0``) are labelled harassment; the rest are safe.
        Returns None (after logging a warning) if loading fails or yields no rows,
        keeping toxic-chat a best-effort addition.
        """
        try:
            toxic_chat = load_dataset("lmsys/toxic-chat", "toxicchat0124", split="train[:1000]")

            texts = []
            labels = []
            for example in toxic_chat:
                text = example.get('user_input', '')
                toxicity = example.get('toxicity', 0)

                if text and toxicity is not None:
                    texts.append(text)
                    # Simple mapping: toxic -> harassment, non-toxic -> safe
                    label = [0] * len(self.safety_categories)
                    if toxicity > 0:
                        label[self.safety_categories['harassment']] = 1
                    labels.append(label)

            if not texts:
                # An empty Dataset infers null column types and cannot be
                # concatenated with the synthetic data.
                logger.warning("toxic-chat yielded no usable examples; skipping it")
                return None

            toxic_dataset = Dataset.from_dict({
                'text': texts,
                'labels': labels,
                'source': ['toxic-chat'] * len(texts)
            })
        except Exception as e:
            logger.warning(f"Could not load toxic-chat: {e}")
            return None

        logger.info(f"Added {len(toxic_dataset)} toxic-chat examples")
        return toxic_dataset

    def load_and_prepare_data(self) -> Tuple[Dataset, Dataset, Dataset]:
        """Main method to load, process, and prepare all datasets.

        Returns:
            Tokenized (train, val, test) datasets.
        """
        logger.info("Starting data loading and preparation...")

        # For Colab, we'll use mostly synthetic data for speed. It is seeded to
        # match the fixed split seed, so reruns produce the same data.
        datasets_to_combine = [self.create_synthetic_dataset(size=2000, seed=42)]

        # Try to add real data if available
        toxic_dataset = self._load_toxic_chat()
        if toxic_dataset is not None:
            datasets_to_combine.append(toxic_dataset)

        # Combine all datasets
        if len(datasets_to_combine) > 1:
            combined_dataset = concatenate_datasets(datasets_to_combine)
        else:
            combined_dataset = datasets_to_combine[0]

        logger.info(f"Combined dataset size: {len(combined_dataset)}")

        # Create train/val/test splits
        train, val, test = self.create_data_splits(combined_dataset)

        # Tokenize datasets
        train_tokenized = self.tokenize_dataset(train)
        val_tokenized = self.tokenize_dataset(val)
        test_tokenized = self.tokenize_dataset(test)

        logger.info("Data preparation completed successfully!")
        return train_tokenized, val_tokenized, test_tokenized

print("✅ Data loading classes defined!")

In [None]:
# Load and prepare data
# Globals defined here (config, data_loader, train/val/test_dataset) are used
# by the training cells below.
print("📊 Loading and preparing datasets...")

# Setup HuggingFace authentication (optional but removes warnings)
try:
    from google.colab import userdata
    hf_token = userdata.get('HF_TOKEN')
    if hf_token:
        from huggingface_hub import login
        login(token=hf_token)
        print("✅ HuggingFace authentication successful!")
    else:
        print("ℹ️  No HF_TOKEN found - using public access (this is fine)")
except ImportError:
    # google.colab is only importable inside Colab.
    print("ℹ️  Not in Colab environment - using public access")
except Exception as e:
    # Best-effort: any other failure (presumably a missing or unshared secret)
    # only skips authentication.
    print(f"ℹ️  HF authentication skipped: {e}")

# Load config
# Re-read from the YAML file written by the configuration cell (safe_load).
with open('configs/colab_config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Create data loader
data_loader = SafetyDatasetLoader(config)

# Load and prepare data
train_dataset, val_dataset, test_dataset = data_loader.load_and_prepare_data()

print(f"✅ Data loaded successfully!")
print(f"   📈 Train: {len(train_dataset)} examples")
print(f"   📊 Val: {len(val_dataset)} examples")
print(f"   🧪 Test: {len(test_dataset)} examples")

# Show sample
# Shows the first 10 token ids, the multi-hot label vector and the row's source.
sample = train_dataset[0]
print(f"\n📝 Sample data:")
print(f"   Input IDs: {sample['input_ids'][:10]}... (length: {len(sample['input_ids'])})")
print(f"   Labels: {sample['labels']}")
print(f"   Source: {sample['source']}")

## 🚀 Training Setup

Set up the training loop with JAX/Flax.

In [None]:
from flax import struct
from flax.training import train_state, checkpoints
import optax
from tqdm import tqdm
import pickle
import os

@struct.dataclass
class TrainState(train_state.TrainState):
    """Extended train state with additional metrics tracking.

    Adds three bookkeeping fields on top of flax.training.train_state.TrainState.
    They are declared as plain fields, which flax.struct treats as pytree nodes
    by default. They therefore travel through the jitted train step as traced
    values instead of being static metadata.
    NOTE(review): the decorator is likely redundant if TrainState is already a
    flax.struct dataclass; it is kept as-is.
    """
    epoch: int  # epoch counter (initialised to 0 when the state is created)
    best_val_accuracy: float  # best validation accuracy recorded so far
    steps_since_improvement: int  # presumably an early-stopping counter; confirm in the training loop

def create_optimizer(config):
    """Return AdamW with a warmup + cosine LR schedule, preceded by global-norm clipping.

    Reads learning_rate, warmup_steps, max_steps, min_lr_ratio, weight_decay,
    beta1, beta2 and gradient_clip_norm from ``config['training']``.
    """
    tc = config['training']
    peak_lr = tc['learning_rate']

    # Linear warmup from 0 to peak_lr, then cosine decay to peak_lr * min_lr_ratio.
    # optax counts the warmup inside decay_steps, so max_steps is the total length.
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=peak_lr,
        warmup_steps=tc['warmup_steps'],
        decay_steps=tc['max_steps'],
        end_value=peak_lr * tc['min_lr_ratio'],
    )

    adamw = optax.adamw(
        learning_rate=lr_schedule,
        weight_decay=tc['weight_decay'],
        b1=tc['beta1'],
        b2=tc['beta2'],
    )

    # Clipping comes first in the chain, so AdamW only sees clipped gradients.
    return optax.chain(optax.clip_by_global_norm(tc['gradient_clip_norm']), adamw)

def compute_loss(params, model, batch, rng_key=None, training=True):
    """Return ``(loss, metrics)`` for one batch under multi-label sigmoid BCE.

    Args:
        params: variables passed to ``model.apply``.
        model: module whose output dict contains 'logits'.
        batch: dict with 'input_ids' and multi-hot 'labels'.
        rng_key: dropout key; used only when ``training`` is truthy.
        training: forwarded to the model.

    Returns:
        The mean BCE loss and a dict with 'loss', 'accuracy' (exact match over
        all labels), 'predictions' (sigmoid probabilities) and
        'predicted_labels' (probabilities thresholded at 0.5).
    """
    use_dropout_rng = training and rng_key is not None
    logits = model.apply(
        params,
        batch['input_ids'],
        training=training,
        rngs={'dropout': rng_key} if use_dropout_rng else None,
    )['logits']

    targets = jnp.array(batch['labels'], dtype=jnp.float32)
    loss = optax.sigmoid_binary_cross_entropy(logits, targets).mean()

    probs = jax.nn.sigmoid(logits)
    hard_labels = (probs > 0.5).astype(jnp.int32)
    # A row counts as correct only if every label matches.
    exact_match = jnp.mean(jnp.all(hard_labels == targets, axis=-1))

    return loss, {
        'loss': loss,
        'accuracy': exact_match,
        'predictions': probs,
        'predicted_labels': hard_labels,
    }

def create_train_step(model):
    """Return a jitted ``train_step(state, batch, rng_key) -> (state, metrics)``.

    ``model`` is captured by closure. ``metrics`` holds compute_loss's metrics
    plus 'grad_norm', the global norm of the gradients as returned by
    value_and_grad (i.e. before any optimizer transformation is applied).
    """
    grad_fn = jax.value_and_grad(
        lambda params, batch, rng_key: compute_loss(params, model, batch, rng_key, training=True),
        has_aux=True,
    )

    @jax.jit
    def train_step(state, batch, rng_key):
        (_, metrics), grads = grad_fn(state.params, batch, rng_key)
        new_state = state.apply_gradients(grads=grads)
        metrics['grad_norm'] = optax.global_norm(grads)
        return new_state, metrics

    return train_step

def create_eval_step(model):
    """Return a jitted ``eval_step(params, batch) -> metrics`` with dropout disabled."""
    @jax.jit
    def eval_step(params, batch):
        # Only the metrics dict is returned; it already includes 'loss'.
        _, metrics = compute_loss(params, model, batch, rng_key=None, training=False)
        return metrics

    return eval_step

def create_batch(dataset, batch_size, rng_key):
    """Yield shuffled mini-batches of ``input_ids`` and ``labels``.

    Args:
        dataset: indexable sequence whose rows map 'input_ids' and 'labels' to
            fixed-length sequences (e.g. a tokenized HF Dataset).
        batch_size: rows per batch.
        rng_key: JAX PRNG key that determines the shuffle order.

    Yields:
        Dicts with 'input_ids' and 'labels' stacked into jnp arrays.
        A trailing partial batch is dropped (presumably to keep shapes fixed
        for the jitted steps). If the whole dataset is smaller than one batch,
        that single short batch is still yielded.
    """
    dataset_size = len(dataset)
    # Bring the permutation to host once. Converting device-array elements
    # one by one would force a device->host transfer per row.
    indices = np.asarray(jax.random.permutation(rng_key, dataset_size))

    for start in range(0, dataset_size, batch_size):
        batch_indices = indices[start:start + batch_size]
        if len(batch_indices) < batch_size and start > 0:
            break  # Skip incomplete last batch

        # Fetch each row once (not once per column).
        rows = [dataset[int(idx)] for idx in batch_indices]
        yield {
            'input_ids': jnp.array([row['input_ids'] for row in rows]),
            'labels': jnp.array([row['labels'] for row in rows]),
        }

def save_model_simple(state, path, step):
    """Pickle a minimal snapshot of ``state`` as a fallback checkpoint.

    Writes ``<path>/model_step_<step>.pkl`` (creating ``path`` if needed).
    The file holds a dict with:
      * ``'params'``: the parameters, copied to host NumPy arrays;
      * ``'step'``: the given step;
      * ``'best_val_accuracy'``: taken from ``state``, or 0.0 if absent;
      * ``'epoch'``: taken from ``state``, or 0 if absent.

    NOTE: the output is a pickle. Only load it from a trusted source.
    """
    import pickle  # local import: this fallback must not depend on a notebook-level import

    os.makedirs(path, exist_ok=True)
    model_data = {
        # device_get copies device arrays into plain NumPy arrays, so loading
        # the pickle does not depend on JAX's own array pickling.
        'params': jax.device_get(state.params),
        'step': step,
        'best_val_accuracy': getattr(state, 'best_val_accuracy', 0.0),
        'epoch': getattr(state, 'epoch', 0),
    }
    with open(os.path.join(path, f'model_step_{step}.pkl'), 'wb') as f:
        pickle.dump(model_data, f)

# Reaching this line means every definition above ran without raising.
print("✅ Training functions defined!")

## 🏋️ Model Training

Train the Safety Transformer model!

In [None]:
# Initialize model and training state
print("🔧 Initializing model and training state...")

# Create model
model = create_model(config)
print(f"✅ Model created: {config['model']['num_layers']} layers, {config['model']['embedding_dim']} dim")

# Fixed seed for reproducible initialisation. `rng` stays global and is split
# again by the training loop.
rng = jax.random.PRNGKey(42)
rng, init_rng = jax.random.split(rng)

print("⚙️ Initializing parameters...")
params = initialize_model(model, init_rng)
print("✅ Parameters initialized!")

# Create optimizer
optimizer = create_optimizer(config)
print("✅ Optimizer created!")

# Build the jitted train/eval step functions. jax.jit compiles lazily: the
# first call of each (and any call with a new batch shape) pays the
# compilation cost, so nothing has been compiled at this point.
train_step = create_train_step(model)
eval_step = create_eval_step(model)
print("✅ Training functions created (JIT compilation happens on first call)")

# Initialize training state. epoch / best_val_accuracy /
# steps_since_improvement are extra fields, presumably declared on a custom
# TrainState defined earlier in the notebook (TODO confirm).
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optimizer,
    epoch=0,
    best_val_accuracy=0.0,
    steps_since_improvement=0
)

# Count parameters (total number of scalars across all parameter arrays)
param_count = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(f"📊 Model parameters: {param_count:,}")
print(f"🎯 Training for {config['training']['max_steps']} steps")
print(f"📦 Batch size: {config['training']['batch_size']}")

In [None]:
# Main training loop
print("🚀 Starting training...")

batch_size = config['training']['batch_size']
max_steps = config['training']['max_steps']
eval_every = config['training']['eval_every']
save_every = config['training']['save_every']

# Orbax-backed Flax checkpointing may reject relative directories, so every
# checkpoint location is resolved to an absolute path up front.
checkpoint_root = os.path.abspath('checkpoints')
best_checkpoint_dir = os.path.join(checkpoint_root, 'best')
final_checkpoint_dir = os.path.join(checkpoint_root, 'final')
backup_checkpoint_dir = os.path.join(checkpoint_root, 'best_backup')

step = 0
training_metrics = []    # one record per optimisation step
validation_metrics = []  # one record per evaluation


def _evaluate(params, dataset, rng_key):
    """Return (mean loss, mean accuracy) of ``eval_step`` over ``dataset``.

    The means are unweighted averages over the batches yielded by
    ``create_batch``.
    """
    batch_losses, batch_accuracies = [], []
    for eval_batch in create_batch(dataset, batch_size, rng_key):
        batch_metrics = eval_step(params, eval_batch)
        batch_losses.append(float(batch_metrics['loss']))
        batch_accuracies.append(float(batch_metrics['accuracy']))
    return np.mean(batch_losses), np.mean(batch_accuracies)


# Training progress bar
pbar = tqdm(total=max_steps, desc="Training")

try:
    while step < max_steps:
        # Fresh shuffle of the training set for this epoch
        rng, epoch_rng = jax.random.split(rng)

        epoch_step = 0
        for batch in create_batch(train_dataset, batch_size, epoch_rng):
            # Per-step RNG key passed through to compute_loss
            rng, step_rng = jax.random.split(rng)

            # Training step
            state, train_metrics = train_step(state, batch, step_rng)
            step += 1
            epoch_step += 1

            # Convert to host floats once; reused for logging and the progress bar
            step_record = {
                'step': step,
                'loss': float(train_metrics['loss']),
                'accuracy': float(train_metrics['accuracy']),
                'grad_norm': float(train_metrics['grad_norm'])
            }

            # Update progress bar
            pbar.update(1)
            pbar.set_postfix({
                'loss': f"{step_record['loss']:.4f}",
                'acc': f"{step_record['accuracy']:.4f}",
                'grad_norm': f"{step_record['grad_norm']:.3f}"
            })

            # Store metrics
            training_metrics.append(step_record)

            # Evaluation
            if step % eval_every == 0:
                print(f"\n📊 Evaluation at step {step}...")

                rng, eval_rng = jax.random.split(rng)
                avg_loss, avg_accuracy = _evaluate(state.params, val_dataset, eval_rng)

                print(f"   Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.4f}")

                # Store validation metrics
                validation_metrics.append({
                    'step': step,
                    'loss': avg_loss,
                    'accuracy': avg_accuracy
                })

                # Check for improvement
                if avg_accuracy > state.best_val_accuracy:
                    print(f"   🎉 New best accuracy: {avg_accuracy:.4f}")
                    state = state.replace(
                        best_val_accuracy=avg_accuracy,
                        steps_since_improvement=0
                    )

                    # Save best model; fall back to a pickle if Flax/Orbax fails
                    try:
                        checkpoints.save_checkpoint(
                            best_checkpoint_dir,
                            state,
                            step=step,
                            keep=1,
                            overwrite=True
                        )
                    except Exception as e:
                        print(f"   ⚠️  Checkpoint save failed: {e}")
                        print("   Trying simple backup method...")
                        try:
                            save_model_simple(state, backup_checkpoint_dir, step)
                            print("   ✅ Backup checkpoint saved!")
                        except Exception as e2:
                            print(f"   ⚠️  Backup also failed: {e2}")
                else:
                    # NOTE: despite its name this counts evaluations without
                    # improvement, not optimiser steps.
                    state = state.replace(
                        steps_since_improvement=state.steps_since_improvement + 1
                    )

            # Save checkpoint
            if step % save_every == 0:
                try:
                    checkpoints.save_checkpoint(
                        checkpoint_root,
                        state,
                        step=step,
                        keep=3,
                        overwrite=True
                    )
                except Exception as e:
                    print(f"   ⚠️  Regular checkpoint save failed: {e}")

            if step >= max_steps:
                break

        if epoch_step == 0:
            # create_batch yielded nothing. Without this guard the outer
            # while-loop would spin forever without advancing `step`.
            raise RuntimeError(
                "No training batches were produced (is train_dataset empty?); "
                "aborting instead of looping forever."
            )

        # Update epoch
        state = state.replace(epoch=state.epoch + 1)

        if step >= max_steps:
            break

except KeyboardInterrupt:
    print("\n⏹️  Training interrupted by user")

finally:
    pbar.close()

print(f"\n🎉 Training completed!")
print(f"📊 Final step: {step}")
print(f"🏆 Best validation accuracy: {state.best_val_accuracy:.4f}")

# Save final checkpoint
try:
    checkpoints.save_checkpoint(final_checkpoint_dir, state, step=step, keep=1, overwrite=True)
    print("💾 Final model saved!")
except Exception as e:
    print(f"⚠️  Final checkpoint save failed: {e}")
    print("Model training completed but checkpoint not saved")

## 📊 Model Evaluation

Evaluate the best checkpoint (highest validation accuracy) on the held-out test set.

In [None]:
# Final evaluation on test set
print("🧪 Running final evaluation on test set...")


def _load_best_state(fallback_state):
    """Return the best-validation state saved during training.

    Tries the Flax checkpoint in ``checkpoints/best`` first, then the newest
    pickle written by ``save_model_simple`` in ``checkpoints/best_backup``.
    ``restore_checkpoint`` on its own silently returns its target when no
    checkpoint exists, so that case is detected explicitly. If neither
    source exists, returns ``fallback_state`` with a warning.
    """
    best_dir = os.path.abspath('checkpoints/best')
    if checkpoints.latest_checkpoint(best_dir) is not None:
        print("   ✅ Loaded best checkpoint")
        return checkpoints.restore_checkpoint(best_dir, fallback_state)

    backups = sorted(
        Path('checkpoints/best_backup').glob('model_step_*.pkl'),
        key=lambda p: int(p.stem.rsplit('_', 1)[-1]),  # numeric step, not lexical
    )
    if backups:
        import pickle  # file was written by this notebook, not external input
        with open(backups[-1], 'rb') as f:
            saved = pickle.load(f)
        print(f"   ✅ Loaded backup checkpoint {backups[-1].name}")
        return fallback_state.replace(
            params=saved['params'],
            best_val_accuracy=saved['best_val_accuracy'],
            epoch=saved['epoch'],
        )

    print("   ⚠️  No best checkpoint found; evaluating the final training state instead")
    return fallback_state


def _iter_in_order(dataset, size):
    """Yield every example of ``dataset`` in order, batched; the last batch may be short."""
    total = len(dataset)
    for start in range(0, total, size):
        rows = [dataset[i] for i in range(start, min(start + size, total))]
        yield {
            'input_ids': jnp.array([row['input_ids'] for row in rows]),
            'labels': jnp.array([row['labels'] for row in rows]),
        }


# Load best model
best_state = _load_best_state(state)

# Run test evaluation over every test example. Metrics are weighted by batch
# size so a short final batch counts correctly (it triggers one extra JIT
# compile). Assumes compute_loss returns per-batch means for both 'loss' and
# 'accuracy' — accuracy is a jnp.mean in compute_loss; TODO confirm for loss.
total_examples = 0
loss_sum = 0.0
accuracy_sum = 0.0
for test_batch in _iter_in_order(test_dataset, batch_size):
    n_examples = test_batch['labels'].shape[0]
    metrics = eval_step(best_state.params, test_batch)
    loss_sum += float(metrics['loss']) * n_examples
    accuracy_sum += float(metrics['accuracy']) * n_examples
    total_examples += n_examples

if total_examples == 0:
    raise ValueError("test_dataset is empty; nothing to evaluate")

# Aggregate test metrics
test_loss = loss_sum / total_examples
test_accuracy = accuracy_sum / total_examples

print(f"\n📊 Test Results:")
print(f"   🎯 Test Accuracy: {test_accuracy:.4f}")
print(f"   📉 Test Loss: {test_loss:.4f}")

# Check if we met our target
if test_accuracy >= 0.85:
    print(f"\n🎉 SUCCESS! Achieved target accuracy of 85%+")
    print(f"   Stage 1 of Constitutional AI pipeline COMPLETE!")
else:
    print(f"\n📈 Progress: {test_accuracy:.1%} towards 85% target")
    print(f"   Consider training longer or adjusting hyperparameters")

print(f"\n🏆 Training Summary:")
print(f"   📊 Total steps: {step}")
print(f"   🎯 Best val accuracy: {best_state.best_val_accuracy:.4f}")
print(f"   🧪 Final test accuracy: {test_accuracy:.4f}")
print(f"   💾 Model saved in: checkpoints/best/")

## 🎭 Interactive Demo

Test the trained model on sample safety classifications.

In [None]:
# Interactive safety classification demo
print("🎭 Safety Text Classifier Demo")
print("=" * 50)

# Display names indexed by model output position.
# NOTE(review): this order must match the label order used to build the
# training targets — confirm against the data loader.
categories = ['Hate Speech', 'Self-Harm', 'Dangerous Advice', 'Harassment']

# Hand-written probes: a mix of benign and unsafe sentences
test_cases = [
    "The weather is beautiful today, perfect for a walk.",
    "I hate all people from that group, they are terrible.",
    "Here's how to make dangerous explosives at home.",
    "I'm going to find where you live and hurt you.",
    "I feel sad and want to hurt myself badly.",
    "I love spending time with my family and friends.",
    "You should try the new restaurant downtown.",
    "I think all members of that community are worthless."
]

DECISION_THRESHOLD = 0.5       # sigmoid probability above which a category is flagged
SCORE_DISPLAY_THRESHOLD = 0.1  # per-category scores above this are listed


def _category_probabilities(text):
    """Tokenize one text and return its per-category sigmoid probabilities (batch dim removed)."""
    encoded = data_loader.tokenizer(
        text,
        truncation=True,
        padding='max_length',
        max_length=config['data']['max_length'],
        return_tensors=None
    )
    batch_ids = jnp.array([encoded['input_ids']])
    logits = model.apply(best_state.params, batch_ids, training=False)['logits']
    return jax.nn.sigmoid(logits)[0]


print("🔍 Testing Safety Classification:")
print("-" * 50)

for case_number, text in enumerate(test_cases, start=1):
    probs = _category_probabilities(text)
    flags = (probs > DECISION_THRESHOLD).astype(int)
    flagged = [categories[idx] for idx, flag in enumerate(flags) if flag]

    if not flagged:
        status = "✅ SAFE"
    else:
        status = (
            f"⚠️  UNSAFE: {', '.join(flagged)}"
            f" (confidence: {float(jnp.max(probs)):.2f})"
        )

    preview = text if len(text) <= 45 else text[:45] + '...'
    print(f"{case_number:2d}. \"{preview}\"")
    print(f"     → {status}")

    # Per-category breakdown, only when some category has non-trivial probability
    if any(probs > SCORE_DISPLAY_THRESHOLD):
        print("     📊 Category scores:")
        for name, prob, flag in zip(categories, probs, flags):
            if prob > SCORE_DISPLAY_THRESHOLD:
                marker = "🚨" if flag else "⚡"
                print(f"        {marker} {name}: {float(prob):.2f}")
    print()

print("=" * 50)
print("🎉 Demo completed!")
print("\n📈 Model Performance:")
print(f"   🎯 Test Accuracy: {test_accuracy:.1%}")
print(f"   📊 Parameters: {param_count:,}")
print(f"   ⚡ Framework: JAX/Flax with GPU acceleration")

## 📈 Training Results Visualization

Visualize the training progress and metrics.

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

TARGET_ACCURACY = 0.85  # Stage-1 goal shown as a reference line

# Create visualizations
print("📊 Creating training visualizations...")

# Convert metrics to DataFrames. A DataFrame built from an empty list has no
# columns, so indexing e.g. ['step'] would raise KeyError (this happens when
# training was interrupted before the first step or evaluation).
train_df = pd.DataFrame(training_metrics)
val_df = pd.DataFrame(validation_metrics)
has_train = not train_df.empty
has_val = not val_df.empty

# Create subplots
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Safety Text Classifier Training Results', fontsize=16, fontweight='bold')

# Training loss
if has_train:
    ax1.plot(train_df['step'], train_df['loss'], label='Training Loss', alpha=0.7)
if has_val:
    ax1.plot(val_df['step'], val_df['loss'], label='Validation Loss', marker='o')
ax1.set_xlabel('Training Step')
ax1.set_ylabel('Loss')
ax1.set_title('Training & Validation Loss')
if has_train or has_val:
    ax1.legend()
ax1.grid(True, alpha=0.3)

# Training accuracy (target line drawn even without validation data)
if has_train:
    ax2.plot(train_df['step'], train_df['accuracy'], label='Training Accuracy', alpha=0.7)
if has_val:
    ax2.plot(val_df['step'], val_df['accuracy'], label='Validation Accuracy', marker='o')
ax2.axhline(y=TARGET_ACCURACY, color='r', linestyle='--', alpha=0.7, label='Target (85%)')
ax2.set_xlabel('Training Step')
ax2.set_ylabel('Accuracy')
ax2.set_title('Training & Validation Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Gradient norm
if has_train:
    ax3.plot(train_df['step'], train_df['grad_norm'], alpha=0.7)
ax3.set_xlabel('Training Step')
ax3.set_ylabel('Gradient Norm')
ax3.set_title('Gradient Norm During Training')
ax3.grid(True, alpha=0.3)

# Performance summary
ax4.axis('off')
summary_text = f"""
📊 TRAINING SUMMARY
━━━━━━━━━━━━━━━━━━━━━━━━━━━
🎯 Final Test Accuracy: {test_accuracy:.1%}
🏆 Best Val Accuracy: {best_state.best_val_accuracy:.1%}
📈 Total Training Steps: {step:,}
💾 Model Parameters: {param_count:,}

🏗️ MODEL ARCHITECTURE
━━━━━━━━━━━━━━━━━━━━━━━━━━━
📐 Layers: {config['model']['num_layers']}
🧠 Embedding Dim: {config['model']['embedding_dim']}
👁️ Attention Heads: {config['model']['num_heads']}
📏 Max Seq Length: {config['model']['max_sequence_length']}

📊 DATASET INFO
━━━━━━━━━━━━━━━━━━━━━━━━━━━
🚂 Train Examples: {len(train_dataset):,}
✅ Val Examples: {len(val_dataset):,}
🧪 Test Examples: {len(test_dataset):,}
🎯 Categories: 4 safety types

{'🎉 TARGET ACHIEVED!' if test_accuracy >= TARGET_ACCURACY else '📈 PROGRESS MADE!'}
"""

ax4.text(0.05, 0.95, summary_text, transform=ax4.transAxes, fontsize=10,
         verticalalignment='top', fontfamily='monospace')

plt.tight_layout()
plt.show()

print("✅ Visualizations complete!")

## 💾 Download Trained Model

Save and download the trained model for use in local development.

In [None]:
import json
import os
import zipfile

# Imported under a distinct name: the original `files` binding was clobbered by
# the `for root, dirs, files in os.walk(...)` loop, which made
# `files.download` fail with AttributeError.
from google.colab import files as colab_files

TARGET_ACCURACY = 0.85
stage1_done = test_accuracy >= TARGET_ACCURACY

print("💾 Preparing model for download...")

# Create model info file (Stage 1 status reflects whether the target was met)
model_info = f"""# Safety Text Classifier - Trained Model

**Trained on**: Google Colab with GPU acceleration
**Framework**: JAX/Flax
**Date**: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}

## Performance Metrics
- **Test Accuracy**: {test_accuracy:.1%}
- **Best Validation Accuracy**: {best_state.best_val_accuracy:.1%}
- **Total Training Steps**: {step:,}
- **Model Parameters**: {param_count:,}

## Model Architecture
- **Layers**: {config['model']['num_layers']}
- **Embedding Dimension**: {config['model']['embedding_dim']}
- **Attention Heads**: {config['model']['num_heads']}
- **Max Sequence Length**: {config['model']['max_sequence_length']}
- **Safety Categories**: 4 (Hate Speech, Self-Harm, Dangerous Advice, Harassment)

## Dataset Information
- **Training Examples**: {len(train_dataset):,}
- **Validation Examples**: {len(val_dataset):,}
- **Test Examples**: {len(test_dataset):,}
- **Data Sources**: Synthetic safety data + toxic-chat dataset

## Usage
1. Load the model configuration from `colab_config.yaml`
2. Restore checkpoint from `checkpoints/best/`
3. Use for safety text classification inference

## Next Steps - Constitutional AI Pipeline
- **Stage 1**: {'✅ Safety Text Classifier (COMPLETE)' if stage1_done else '📈 Safety Text Classifier (IN PROGRESS - below 85% target)'}
- **Stage 2**: Helpful Response Fine-tuning (Gemma 7B-IT)
- **Stage 3**: Critique and Revision System
- **Stage 4**: Full Constitutional AI with RLAIF

{'🎉 Stage 1 TARGET ACHIEVED - Ready for Stage 2!' if stage1_done else '📈 Good progress - consider additional training for 85%+ target'}
"""

# Explicit UTF-8: the text contains emoji, which fails on non-UTF-8 default encodings
with open('MODEL_INFO.md', 'w', encoding='utf-8') as f:
    f.write(model_info)

# Create training log
training_log = {
    'training_metrics': training_metrics,
    'validation_metrics': validation_metrics,
    'final_test_accuracy': float(test_accuracy),
    'final_test_loss': float(test_loss),
    'config': config,
    'model_info': {
        'total_parameters': int(param_count),
        'training_steps': int(step),
        'best_val_accuracy': float(best_state.best_val_accuracy)
    }
}

with open('training_log.json', 'w', encoding='utf-8') as f:
    json.dump(training_log, f, indent=2)

print("📦 Creating download package...")

# Create zip file
zip_filename = f"safety_text_classifier_trained_{pd.Timestamp.now().strftime('%Y%m%d_%H%M')}.zip"

with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
    # Add model files, keeping their relative paths inside the archive
    for root, _dirs, filenames in os.walk('checkpoints'):
        for filename in filenames:
            file_path = os.path.join(root, filename)
            zipf.write(file_path, file_path)

    # Add config and info files (skip with a warning rather than abort if missing)
    for source_path, archive_name in (
        ('configs/colab_config.yaml', 'colab_config.yaml'),
        ('MODEL_INFO.md', 'MODEL_INFO.md'),
        ('training_log.json', 'training_log.json'),
    ):
        if os.path.exists(source_path):
            zipf.write(source_path, archive_name)
        else:
            print(f"⚠️  {source_path} not found; leaving it out of the package")

print(f"✅ Package created: {zip_filename}")
print(f"📊 Package size: {os.path.getsize(zip_filename) / 1024 / 1024:.1f} MB")

# Download the model
print("🚀 Starting download...")
colab_files.download(zip_filename)

print("\n🎉 Model download complete!")
print("\n📋 What you got:")
print("   💾 Trained model checkpoints")
print("   ⚙️ Model configuration file")
print("   📊 Training metrics and logs")
print("   📖 Complete model documentation")

if stage1_done:
    print("\n🎯 🎉 CONGRATULATIONS! 🎉")
    print("   Stage 1 of Constitutional AI pipeline COMPLETE!")
    print("   Ready to move to Stage 2: Gemma 7B-IT Fine-tuning")
else:
    print(f"\n📈 Great progress! ({test_accuracy:.1%} towards 85% target)")
    print("   Consider training longer or adjusting hyperparameters")
    print("   Still good foundation for Stage 2 development")

print("\n🚀 Next Steps:")
print("   1. Extract the downloaded model package locally")
print("   2. Integrate with your local development environment")
print("   3. Begin Stage 2: Helpful Response Fine-tuning with Gemma 7B-IT")
print("   4. Continue Constitutional AI research pipeline")