# Attention Head Agency: Proof of Concept

This notebook demonstrates the benefits of attention head agency in the Sentinel-AI framework. We'll show how attention heads can express internal states and how the system respects these signals during computation.

In [None]:
!pip install transformers datasets torch matplotlib
%matplotlib inline

In [None]:
import sys
import os
import torch
import numpy as np
import time
import matplotlib.pyplot as plt
from transformers import AutoTokenizer

# Add the project root to the path.
# Note that '__file__' is a quoted string literal here, not the module
# attribute: abspath() resolves it against the current working directory, and
# the two dirname() calls then yield the *parent* of the working directory.
# NOTE(review): this presumably assumes the notebook is run from a directory
# one level below the project root — confirm against the repo layout.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath('__file__'))))

# Imported only after the sys.path change above; presumably `models` is a
# project-local package that is resolved through that path entry.
from models.loaders.loader import load_baseline_model, load_adaptive_model

# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Step 1: Load Baseline and Adaptive Models

In [None]:
# Checkpoint name shared by the tokenizer and both model loaders.
model_name = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The baseline model is loaded first because it is passed to the adaptive
# loader. NOTE(review): whether the adaptive model copies or shares the
# baseline's weights is decided inside load_adaptive_model — confirm there.
baseline_model = load_baseline_model(model_name, device)
adaptive_model = load_adaptive_model(model_name, baseline_model, device)

def generate_text(model, prompt, max_tokens=50):
    """Sample a continuation of ``prompt`` from ``model`` and time the call.

    Relies on the notebook-level ``tokenizer`` and ``device`` globals.

    Args:
        model: Object exposing a ``generate`` method that accepts the encoded
            tokenizer inputs plus ``max_length``, ``do_sample`` and
            ``temperature`` keyword arguments.
        prompt: Text to continue.
        max_tokens: Upper bound on the number of tokens added to the prompt.

    Returns:
        dict with the keys:
            "text": decoded first output sequence, special tokens removed.
            "time": seconds spent inside ``model.generate``.
            "tokens_per_second": tokens actually generated divided by "time"
                (0.0 if no measurable time elapsed).
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = len(inputs.input_ids[0])

    # perf_counter is a monotonic, high-resolution clock intended for
    # measuring intervals; time.time() can jump if the system clock changes.
    start_time = time.perf_counter()
    outputs = model.generate(
        **inputs,
        max_length=prompt_length + max_tokens,
        do_sample=True,
        temperature=0.7
    )
    elapsed = time.perf_counter() - start_time

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Sampling can stop before the max_tokens budget is used (e.g. at an
    # end-of-sequence token), so count what was actually produced instead of
    # assuming the full budget. This assumes outputs[0] holds the prompt
    # followed by the continuation, as the max_length argument above implies;
    # the max() guard keeps the count non-negative if that does not hold.
    new_tokens = max(len(outputs[0]) - prompt_length, 0)

    return {
        "text": generated_text,
        "time": elapsed,
        "tokens_per_second": new_tokens / elapsed if elapsed > 0 else 0.0
    }

## Step 2: Verify Agency Features Exist

In [None]:
# The rest of the notebook keys off this flag: agency support is detected by
# the presence of a `get_agency_report` attribute on the adaptive model.
has_agency = hasattr(adaptive_model, "get_agency_report")
print(f"Agency features available: {has_agency}")

if not has_agency:
    print("Agency features not found in the model. Make sure you're using the latest version with agency support.")
else:
    # Snapshot of the report before this notebook changes any head state
    initial_report = adaptive_model.get_agency_report()
    print("Initial agency report:")
    print(f"Total layers: {initial_report['num_layers']}")
    print(f"Total violations: {initial_report['total_violations']}")

    # Print the first layer's entry as a sample of the per-layer report
    layer_reports = initial_report['layer_reports']
    if len(layer_reports) > 0:
        layer_idx = 0
        layer_report = layer_reports.get(layer_idx, {})
        print(f"\nLayer {layer_idx} report:")
        for key, value in layer_report.items():
            if key == 'recent_violations':
                # Left out to keep the printout short
                continue
            print(f"  {key}: {value}")

## Step 3: Test Generation with Default Agency States

In [None]:
# Prompts reused by every generation pass in this notebook
prompts = [
    "Once upon a time in a land far away,",
    "The future of artificial intelligence depends on",
    "To solve the world's most pressing problems, we need to"
]

# First pass: no head state has been modified by the notebook yet.
# Results are stored by zero-based prompt index for the later comparisons.
default_results = {}
print("Generating text with default agency states...\n")

for idx, prompt in enumerate(prompts):
    print(f"Prompt {idx+1}: {prompt}")

    result = generate_text(adaptive_model, prompt)
    default_results[idx] = result

    print(f"Generated: {result['text']}")
    print(f"Generation time: {result['time']:.2f} seconds")
    print(f"Tokens per second: {result['tokens_per_second']:.2f}")
    print()

## Step 4: Simulate "Overloaded" State in Some Heads

In [None]:
if not has_agency:
    print("Agency features not available. Skipping this step.")
else:
    # Model dimensions; both names are reused by later cells
    num_layers = adaptive_model.num_layers
    heads_per_layer = 12  # Standard for distilgpt2

    # Flag every third head (indices 0, 3, 6, 9) of each layer as overloaded,
    # keeping a per-layer record of the heads that were touched
    overloaded_heads = {}
    for layer_idx in range(num_layers):
        for head_idx in range(0, heads_per_layer, 3):
            adaptive_model.set_head_state(layer_idx, head_idx, "overloaded")
            overloaded_heads.setdefault(layer_idx, []).append(head_idx)

    print("Set the following heads to 'overloaded' state:")
    for layer_idx, heads in overloaded_heads.items():
        print(f"Layer {layer_idx}: Heads {heads}")

    # Re-read the report to see the new states reflected per layer
    new_report = adaptive_model.get_agency_report()
    print("\nAgency report after setting states:")
    for layer_idx in range(num_layers):
        if layer_idx not in new_report['layer_reports']:
            continue
        layer_report = new_report['layer_reports'][layer_idx]
        print(f"Layer {layer_idx}: {layer_report['active_heads']} active, {layer_report['overloaded_heads']} overloaded")

## Step 5: Test Generation with Overloaded Heads

In [None]:
if not has_agency:
    print("Agency features not available. Skipping this step.")
else:
    # Same prompts as the default pass, now with some heads marked overloaded.
    # Results are stored by zero-based prompt index for the later comparisons.
    overloaded_results = {}
    print("Generating text with overloaded heads...\n")

    for idx, prompt in enumerate(prompts):
        print(f"Prompt {idx+1}: {prompt}")

        result = generate_text(adaptive_model, prompt)
        overloaded_results[idx] = result

        print(f"Generated: {result['text']}")
        print(f"Generation time: {result['time']:.2f} seconds")
        print(f"Tokens per second: {result['tokens_per_second']:.2f}")
        print()

## Step 6: Simulate "Withdrawn" Consent in Some Heads

In [None]:
if not has_agency:
    print("Agency features not available. Skipping this step.")
else:
    # Start from a clean slate: put every head back into the active state
    for layer_idx in range(num_layers):
        for head_idx in range(heads_per_layer):
            adaptive_model.set_head_state(layer_idx, head_idx, "active")

    # Withdraw consent for every fourth head (indices 0, 4, 8) of each layer,
    # keeping a per-layer record that a later cell reads
    withdrawn_heads = {}
    for layer_idx in range(num_layers):
        for head_idx in range(0, heads_per_layer, 4):
            adaptive_model.set_head_state(layer_idx, head_idx, "withdrawn", consent=False)
            withdrawn_heads.setdefault(layer_idx, []).append(head_idx)

    print("Withdrawn consent for the following heads:")
    for layer_idx, heads in withdrawn_heads.items():
        print(f"Layer {layer_idx}: Heads {heads}")

    # Re-read the report to see the withdrawn heads reflected per layer
    withdrawn_report = adaptive_model.get_agency_report()
    print("\nAgency report after withdrawing consent:")
    for layer_idx in range(num_layers):
        if layer_idx not in withdrawn_report['layer_reports']:
            continue
        layer_report = withdrawn_report['layer_reports'][layer_idx]
        print(f"Layer {layer_idx}: {layer_report['active_heads']} active, {layer_report['withdrawn_heads']} withdrawn")

## Step 7: Test Generation with Withdrawn Consent

In [None]:
if not has_agency:
    print("Agency features not available. Skipping this step.")
else:
    # Same prompts again, now with consent withdrawn for some heads.
    # Results are stored by zero-based prompt index for the later comparisons.
    withdrawn_results = {}
    print("Generating text with withdrawn consent...\n")

    for idx, prompt in enumerate(prompts):
        print(f"Prompt {idx+1}: {prompt}")

        result = generate_text(adaptive_model, prompt)
        withdrawn_results[idx] = result

        print(f"Generated: {result['text']}")
        print(f"Generation time: {result['time']:.2f} seconds")
        print(f"Tokens per second: {result['tokens_per_second']:.2f}")
        print()

## Step 8: Check for Consent Violations

In [None]:
if not has_agency:
    print("Agency features not available. Skipping this step.")
else:
    print("Setting high gate values for some withdrawn heads to trigger consent violations...")

    # For each layer, overwrite the gate of the first withdrawn head with 0.9.
    # NOTE(review): the previous gate values are not saved or restored in this
    # cell, so the model keeps these gates afterwards — confirm that is intended.
    violation_heads = {}
    with torch.no_grad():
        for layer_idx, heads in withdrawn_heads.items():
            if not heads:
                # No withdrawn heads recorded for this layer
                continue
            head_idx = heads[0]
            adaptive_model.blocks[layer_idx]["attn"].gate[head_idx] = torch.tensor(0.9, device=device)
            violation_heads.setdefault(layer_idx, []).append(head_idx)

    print("Set high gate values for the following withdrawn heads:")
    for layer_idx, heads in violation_heads.items():
        print(f"Layer {layer_idx}: Heads {heads}")

    # One more generation pass so the model runs with the forced gates
    print("\nGenerating text to trigger consent violations...")
    result = generate_text(adaptive_model, prompts[0])
    print(f"Generated: {result['text']}")

    # Read back the violation totals recorded by the model
    violation_report = adaptive_model.get_agency_report()
    total_violations = violation_report['total_violations']
    print(f"\nTotal violations detected: {total_violations}")

    # List the recent violations of every layer that recorded at least one
    for layer_idx, report in violation_report['layer_reports'].items():
        if report['violation_count'] > 0:
            print(f"\nViolations in Layer {layer_idx}:")
            for violation in report['recent_violations']:
                print(f"  Head {violation['head_idx']}: {violation['violation_type']}")

## Step 9: Compare Performance Metrics

In [None]:
if not has_agency:
    print("Agency features not available. Skipping this step.")
else:
    # Put every head back into the active state with consent granted
    for layer_idx in range(num_layers):
        for head_idx in range(heads_per_layer):
            adaptive_model.set_head_state(layer_idx, head_idx, "active", consent=True)

    print("Performance comparison across different agency states:")

    # `states` fixes the bar order here and is reused by the next cell
    states = ["Default", "Overloaded", "Withdrawn"]
    metrics = {state: [] for state in states}

    # Collect tokens/sec only for prompts that have a result in all three passes
    results_by_state = {
        "Default": default_results,
        "Overloaded": overloaded_results,
        "Withdrawn": withdrawn_results,
    }
    for idx in range(len(prompts)):
        if not all(idx in runs for runs in results_by_state.values()):
            continue
        for state, runs in results_by_state.items():
            metrics[state].append(runs[idx]['tokens_per_second'])

    # Mean throughput per state
    averages = {state: np.mean(speeds) for state, speeds in metrics.items()}

    # Bar chart of the averages, one bar per state
    plt.figure(figsize=(10, 6))
    bars = plt.bar(averages.keys(), averages.values(), color=['green', 'orange', 'red'])

    # Write each bar's value just above it
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 0.1,
                 f'{height:.2f}', ha='center', va='bottom')

    plt.title('Generation Speed by Head Agency State')
    plt.ylabel('Tokens per Second')
    plt.ylim(0, max(averages.values()) * 1.2)  # Leave room above the tallest bar
    plt.grid(axis='y', alpha=0.3)
    plt.show()

    # Same numbers as text
    print("\nAverage tokens per second:")
    for state, avg in averages.items():
        print(f"{state}: {avg:.2f} tokens/sec")

    # Express the other states as a percentage of the default pass
    baseline = averages["Default"]
    print("\nRelative performance:")
    for state, avg in averages.items():
        if state == "Default":
            continue
        relative = (avg / baseline) * 100
        print(f"{state}: {relative:.1f}% of baseline")

## Step 10: Analyze Text Quality Across Agency States

In [None]:
if not has_agency:
    print("Agency features not available. Skipping this step.")
else:
    def analyze_text_quality(text):
        """Return word count, distinct-word count and their ratio for ``text``.

        Words are the whitespace-separated tokens of ``text``; the ratio is 0
        when there are no words.
        """
        words = text.split()
        total = len(words)
        distinct = len(set(words))
        diversity = distinct / total if words else 0

        return {
            "length": total,
            "unique_words": distinct,
            "lexical_diversity": diversity,
        }

    print("Text quality comparison across different agency states:")

    # One list of lexical-diversity scores per state (`states` comes from the
    # performance-comparison cell)
    quality_metrics = {state: [] for state in states}

    for idx in range(len(prompts)):
        # Only compare prompts that have a result in all three passes
        if not (idx in default_results and idx in overloaded_results and idx in withdrawn_results):
            continue

        default_quality = analyze_text_quality(default_results[idx]['text'])
        overloaded_quality = analyze_text_quality(overloaded_results[idx]['text'])
        withdrawn_quality = analyze_text_quality(withdrawn_results[idx]['text'])

        quality_metrics["Default"].append(default_quality['lexical_diversity'])
        quality_metrics["Overloaded"].append(overloaded_quality['lexical_diversity'])
        quality_metrics["Withdrawn"].append(withdrawn_quality['lexical_diversity'])

        print(f"\nPrompt {idx+1}:")
        print(f"Default - Length: {default_quality['length']}, Unique words: {default_quality['unique_words']}, Diversity: {default_quality['lexical_diversity']:.3f}")
        print(f"Overloaded - Length: {overloaded_quality['length']}, Unique words: {overloaded_quality['unique_words']}, Diversity: {overloaded_quality['lexical_diversity']:.3f}")
        print(f"Withdrawn - Length: {withdrawn_quality['length']}, Unique words: {withdrawn_quality['unique_words']}, Diversity: {withdrawn_quality['lexical_diversity']:.3f}")

    # Mean lexical diversity per state
    diversity_averages = {state: np.mean(diversities) for state, diversities in quality_metrics.items()}

    # Bar chart of the averages, one bar per state
    plt.figure(figsize=(10, 6))
    bars = plt.bar(diversity_averages.keys(), diversity_averages.values(), color=['green', 'orange', 'red'])

    # Write each bar's value just above it
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                 f'{height:.3f}', ha='center', va='bottom')

    plt.title('Lexical Diversity by Head Agency State')
    plt.ylabel('Lexical Diversity')
    plt.ylim(0, max(diversity_averages.values()) * 1.2)  # Leave room above the tallest bar
    plt.grid(axis='y', alpha=0.3)
    plt.show()

## Conclusion

This notebook demonstrates the effectiveness of our attention head agency implementation. Key findings include:

1. **Agency States**: Attention heads can dynamically express internal states like "active", "overloaded", or "withdrawn" which affect the computation.

2. **Resource Management**: "Overloaded" heads reduce their contribution automatically, helping to optimize resource usage.

3. **Consent Tracking**: Heads with withdrawn consent skip computation entirely, allowing for ethical boundaries in AI systems.

4. **Performance Metrics**: Generation throughput (tokens per second) is compared across the default, overloaded, and withdrawn configurations in Step 9, showing how the performance profile shifts as the system adapts to head states.

5. **Quality Metrics**: Simple text measures such as lexical diversity (unique words divided by total words, computed in Step 10) give a rough view of how agency states affect the model's outputs.

6. **Violation Monitoring**: The system can detect and log consent violations, providing an ethical governance framework.

Agency in attention heads provides a foundation for more ethical AI systems that respect internal states and consent boundaries while maintaining performance. This could lead to more robust and trustworthy AI that can appropriately scale back or redirect computation when needed.