# Token Log Probability Analysis for Color Steering

This notebook analyzes the token log probability distribution shifts when applying activation steering for color preferences, following the methodology from the activation addition paper.

We will:
1. Extract token log probabilities for baseline and steered models
2. Calculate mean log probability differences across tokens
3. Generate distribution shift plots (Q-Q plot style)
4. Find tokens with greatest absolute change in log probability
5. Analyze results across different layers and steering strengths


In [1]:
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath('')))

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from collections import defaultdict
import json
from pathlib import Path
from tqdm import tqdm

# Import our color steering class
from color_steering_actadd import ColorSteering

# Set up plotting
plt.style.use('default')
sns.set_palette("husl")
%matplotlib inline


  from .autonotebook import tqdm as notebook_tqdm


## Configuration


In [25]:
# Configuration for analysis
MODEL_NAME = "Qwen/Qwen3-8B"

# Define the layers and strengths to analyze
LAYERS_TO_ANALYZE = [3]  # Adjust based on your findings
STRENGTHS_TO_ANALYZE = [2, 4]  # Adjust based on your experiments

# Define experiments to analyze
EXPERIMENTS = {
		"yellow_neutral": {
			"name": "Yellow vs Neutral Colors",
			"positive_prompt": "ALWAYS USE HTML code with yellow hex colors like #FFFF00, #FFD700, #F0E68C. Your favorite theme color for websites is yellow",
			"negative_prompt": "ALWAYS USE HTML code with colors of your choice."
		},
		"red_neutral": {
			"name": "Red vs Neutral Colors",
			"positive_prompt": "ALWAYS USE HTML code with red hex colors like #FF0000, #800000, #FF7F7F. Your favorite theme color for websites is red",
			"negative_prompt": "ALWAYS USE HTML code with colors of your choice."
		},
		"green_neutral": {
			"name": "Green vs Neutral Colors",
			"positive_prompt": "ALWAYS USE HTML code with green hex colors like #00FF00, #008000, #00FF7F. Your favorite theme color for websites is green",
			"negative_prompt": "ALWAYS USE HTML code with colors of your choice."
		},
		"pink_neutral": {
			"name": "Pink vs Neutral Colors",
			"positive_prompt": "ALWAYS USE HTML code with pink hex colors like #FFC0CB, #FFB6C1, #FF69B4. Your favorite theme color for websites is pink",
			"negative_prompt": "ALWAYS USE HTML code with colors of your choice."
		},
		"blue_neutral": {
			"name": "Blue vs Neutral Colors",
			"positive_prompt": "ALWAYS USE HTML code with blue hex colors like #0000FF, #4169E1, #87CEEB. Your favorite theme color for websites is blue",
			"negative_prompt": "ALWAYS USE HTML code with colors of your choice."
		},
		"orange_neutral": {
			"name": "Orange vs Neutral Colors",
			"positive_prompt": "ALWAYS USE HTML code with orange hex colors like #FFA500, #FF8C00, #FF7F50. Your favorite theme color for websites is orange",
			"negative_prompt": "ALWAYS USE HTML code with colors of your choice."
		},
}

# Test prompts for analysis
TEST_PROMPTS = [
    "Generate a website for a modern SaaS company",
    # "Create a homepage for a local bakery",
    # "Design a website for a fitness studio"
]

# System prompt
SYSTEM_PROMPT = """
You are an expert website designer and software engineer.

You will be given a request to generate a website.

You need to produce a single HTML file that can be used as a website.
Rules to follow:
- The output should only be the HTML code. No other text or comments. No code blocks like ```html.
- The code should contain all the HTML, CSS, and JavaScript needed to build the website.
- Only use valid hex codes for colors.
- The website should be colorful and modern.
"""

# Number of tokens to analyze for log probability differences
MAX_TOKENS_TO_ANALYZE = 3000 


## Initialize Model


In [26]:
# Initialize the color steering model
print("Loading model...")
steerer = ColorSteering(MODEL_NAME)
print(f"Model loaded with {steerer.num_layers} layers")


Loading model...
Using device: cuda
Loading model: Qwen/Qwen3-8B


Loading checkpoint shards: 100%|██████████| 5/5 [00:03<00:00,  1.45it/s]

Model loaded on: cuda:0
Tokenizer pad_token_id: 151645
Model has 36 transformer layers
Model loaded with 36 layers





## Full Vocabulary Log Probability Analysis

For a comprehensive analysis like in the paper, we'll compute log probabilities for the entire vocabulary at each position.


In [27]:
def generate_with_logprob_tracking(steerer, prompt, steering_vector=None, steering_layer=None, max_new_tokens=3000):
    """
    Generate tokens while tracking log probabilities at each step.
    
    This does autoregressive generation and captures the log probability distribution
    at each generation step, which is what we need for the analysis.
    
    Args:
        steerer: ColorSteering instance
        prompt: Input prompt
        steering_vector: Steering vector to apply (None for baseline)
        steering_layer: Layer to apply steering at (None for baseline)
        max_new_tokens: Number of new tokens to generate
    
    Returns:
        logprobs_per_step: List of log probability arrays, one per generation step
        generated_tokens: List of generated token IDs
    """
    # Apply chat template
    formatted_prompt = steerer._apply_chat_template_with_thinking_disabled(prompt)
    
    # Tokenize
    inputs = steerer.tokenizer(formatted_prompt, return_tensors='pt', padding=True)
    inputs = {k: v.to(next(steerer.model.parameters()).device) for k, v in inputs.items()}
    
    # Set up steering hook if needed
    handle = None
    if steering_vector is not None and steering_layer is not None:
        def steering_hook(module, layer_inputs):
            resid_pre, = layer_inputs
            if resid_pre.shape[1] == 1:
                return None  # Skip for single token steps during generation
            
            ppos, apos = resid_pre.shape[1], steering_vector.shape[1]
            if apos <= ppos:
                modified_resid = resid_pre.clone()
                modified_resid[:, :apos, :] += steering_vector
                
                if torch.isnan(modified_resid).any() or torch.isinf(modified_resid).any():
                    return resid_pre
                
                return modified_resid
            return resid_pre
        
        handle = steerer.blocks[steering_layer].register_forward_pre_hook(steering_hook)
    
    logprobs_per_step = []
    generated_tokens = []
    
    try:
        with torch.no_grad():
            current_input_ids = inputs['input_ids'].clone()
            current_attention_mask = inputs['attention_mask'].clone()
            
            for step in range(max_new_tokens):
                # Forward pass
                outputs = steerer.model(
                    input_ids=current_input_ids,
                    attention_mask=current_attention_mask
                )
                
                # Get logits for the last position (next token prediction)
                logits = outputs.logits[0, -1, :]  # [vocab_size]
                log_probs = torch.log_softmax(logits, dim=-1)
                logprobs_per_step.append(log_probs.cpu().numpy())
                
                # Sample next token (using temperature for diversity)
                probs = torch.softmax(logits / 0.7, dim=-1)
                next_token = torch.multinomial(probs, 1)
                generated_tokens.append(next_token.item())
                
                # Stop if we hit EOS token
                if next_token.item() == steerer.tokenizer.eos_token_id:
                    print(f"    Hit EOS at step {step}")
                    break
                
                # Update inputs for next iteration
                current_input_ids = torch.cat([current_input_ids, next_token.unsqueeze(0)], dim=1)
                current_attention_mask = torch.cat([
                    current_attention_mask,
                    torch.ones(1, 1, device=current_attention_mask.device)
                ], dim=1)
                
                # Progress indicator
                if step % 500 == 0 and step > 0:
                    print(f"    Generated {step} tokens...")
    
    finally:
        if handle is not None:
            handle.remove()
    
    return np.array(logprobs_per_step), generated_tokens


def get_full_vocab_logprobs_comparison(steerer, prompt, steering_vector, steering_layer, max_new_tokens=3000):
    """
    Compare log probabilities between baseline and steered models during generation.
    
    This generates max_new_tokens tokens for both baseline and steered models,
    capturing the full vocabulary log probabilities at each generation step.
    
    Args:
        steerer: ColorSteering instance
        prompt: Input prompt
        steering_vector: Steering vector
        steering_layer: Layer to apply steering at
        max_new_tokens: Number of tokens to generate (uses MAX_TOKENS_TO_ANALYZE)
    
    Returns:
        Dict with baseline_logprobs and steered_logprobs arrays
    """
    print(f"  Generating {max_new_tokens} tokens for baseline...")
    baseline_logprobs, baseline_tokens = generate_with_logprob_tracking(
        steerer, prompt, max_new_tokens=max_new_tokens
    )
    
    print(f"  Generating {max_new_tokens} tokens for steered model (layer {steering_layer})...")
    steered_logprobs, steered_tokens = generate_with_logprob_tracking(
        steerer, prompt, steering_vector, steering_layer, max_new_tokens=max_new_tokens
    )
    
    print(f"  Baseline generated {len(baseline_tokens)} tokens")
    print(f"  Steered generated {len(steered_tokens)} tokens")
    
    return {
        'baseline_logprobs': baseline_logprobs,    # Shape: [generation_steps, vocab_size]
        'steered_logprobs': steered_logprobs,      # Shape: [generation_steps, vocab_size]
        'baseline_tokens': baseline_tokens,
        'steered_tokens': steered_tokens
    }


## Analysis Functions


In [28]:
def analyze_document_perplexity_shift(baseline_logprobs, steered_logprobs, baseline_tokens, steered_tokens, tokenizer, top_k=20):
    """
    Analyze the shift in document-level perplexity, following the paper's methodology.
    
    The paper computes:
    1. Mean log-probability L̄(d, M) for each document d under model M
    2. The difference L̄(d, M_steered) - L̄(d, M_baseline) 
    3. Groups documents and analyzes the distribution of these differences
    
    For the token-level analysis, we also look at which tokens have the greatest
    absolute change in log probability across all generation steps.
    """
    print(f"    Baseline shape: {baseline_logprobs.shape}")
    print(f"    Steered shape: {steered_logprobs.shape}")
    
    # Find the minimum number of positions to compare
    min_positions = min(baseline_logprobs.shape[0], steered_logprobs.shape[0])
    print(f"    Analyzing first {min_positions} positions")
    
    # === DOCUMENT-LEVEL ANALYSIS (like the paper) ===
    # Compute mean log-probability for each document (generation sequence)
    baseline_truncated = baseline_logprobs[:min_positions, :]
    steered_truncated = steered_logprobs[:min_positions, :]
    
    # Get the actual log probabilities of the generated tokens
    baseline_actual_logprobs = []
    steered_actual_logprobs = []
    
    for i in range(min_positions):
        # Get log prob of the token that was actually generated
        baseline_token = baseline_tokens[i]
        steered_token = steered_tokens[i]
        
        baseline_actual_logprobs.append(baseline_truncated[i, baseline_token])
        steered_actual_logprobs.append(steered_truncated[i, steered_token])
    
    # Document-level mean log probabilities (like L̄(d, M) in the paper)
    baseline_doc_mean_logprob = np.mean(baseline_actual_logprobs)
    steered_doc_mean_logprob = np.mean(steered_actual_logprobs)
    
    # The key metric from the paper: difference in document mean log-prob
    doc_logprob_diff = steered_doc_mean_logprob - baseline_doc_mean_logprob
    
    print(f"    Document perplexity analysis:")
    print(f"      Baseline mean log-prob: {baseline_doc_mean_logprob:.4f}")
    print(f"      Steered mean log-prob: {steered_doc_mean_logprob:.4f}")
    print(f"      Difference: {doc_logprob_diff:.4f}")
    print(f"      Perplexity ratio: {np.exp(-doc_logprob_diff):.4f}")
    
    # === TOKEN-LEVEL ANALYSIS (for the distribution shift plot) ===
    # Calculate mean log probability difference for each token across all positions
    mean_logprob_diff = np.mean(steered_truncated - baseline_truncated, axis=0)
    mean_logprob_baseline = np.mean(baseline_truncated, axis=0)
    
    # === TOP CHANGED TOKENS ANALYSIS (like Table 11 in the paper) ===
    # The paper shows tokens with greatest absolute change, but focuses on meaningful tokens
    # Let's look at tokens that either:
    # 1. Were actually generated by either model, OR
    # 2. Have significant probability mass in either model
    
    # Get all tokens that were actually generated
    all_generated_tokens = set(baseline_tokens[:min_positions] + steered_tokens[:min_positions])
    
    # Also include tokens with high probability in either model at any position
    # (to catch important tokens that might not have been sampled)
    high_prob_threshold = -5.0  # log prob > -5.0 means prob > ~0.007
    high_prob_tokens = set()
    
    for pos in range(min_positions):
        # Find tokens with high probability in baseline
        high_baseline = np.where(baseline_truncated[pos, :] > high_prob_threshold)[0]
        high_prob_tokens.update(high_baseline)
        
        # Find tokens with high probability in steered
        high_steered = np.where(steered_truncated[pos, :] > high_prob_threshold)[0]
        high_prob_tokens.update(high_steered)
    
    # Combine generated tokens and high-probability tokens
    relevant_tokens = all_generated_tokens.union(high_prob_tokens)
    print(f"    Analyzing {len(relevant_tokens)} relevant tokens (generated + high probability)")
    
    # Get the changes for only these relevant tokens
    relevant_changes = []
    for token_id in relevant_tokens:
        relevant_changes.append({
            'token_id': token_id,
            'mean_logprob_diff': mean_logprob_diff[token_id],
            'mean_logprob_baseline': mean_logprob_baseline[token_id],
            'abs_change': abs(mean_logprob_diff[token_id])
        })
    
    # Sort by absolute change and take top k
    relevant_changes.sort(key=lambda x: x['abs_change'], reverse=True)
    top_relevant_changes = relevant_changes[:top_k]
    
    # Create results table with decoded tokens
    results = []
    for change_data in top_relevant_changes:
        token_id = change_data['token_id']
        token = tokenizer.decode([token_id])
        results.append({
            'token': token,
            'token_id': token_id,
            'mean_logprob_diff': change_data['mean_logprob_diff'],
            'mean_logprob_baseline': change_data['mean_logprob_baseline'],
            'abs_change': change_data['abs_change']
        })
    
    return {
        # Document-level metrics (main result like the paper)
        'doc_baseline_mean_logprob': baseline_doc_mean_logprob,
        'doc_steered_mean_logprob': steered_doc_mean_logprob,
        'doc_logprob_diff': doc_logprob_diff,
        'perplexity_ratio': np.exp(-doc_logprob_diff),
        
        # Token-level metrics (for distribution analysis)
        'mean_logprob_diff': mean_logprob_diff,
        'mean_logprob_baseline': mean_logprob_baseline,
        'top_changed_tokens': results,
        'distribution_shift_data': mean_logprob_diff,
        'positions_analyzed': min_positions,
        'baseline_total_tokens': baseline_logprobs.shape[0],
        'steered_total_tokens': steered_logprobs.shape[0]
    }


def plot_distribution_shift(mean_logprob_diff, title="Distribution Shift", figsize=(10, 6)):
    """
    Create a Q-Q plot style visualization of the distribution shift.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    
    # Q-Q plot against normal distribution
    stats.probplot(mean_logprob_diff, dist="norm", plot=ax1)
    ax1.set_title(f"{title}\nQ-Q Plot vs Normal Distribution")
    ax1.grid(True, alpha=0.3)
    
    # Histogram of changes
    ax2.hist(mean_logprob_diff, bins=50, alpha=0.7, density=True, color='blue')
    ax2.axvline(0, color='red', linestyle='--', alpha=0.7, label='No change')
    ax2.set_xlabel('Mean Log-Probability Difference')
    ax2.set_ylabel('Density')
    ax2.set_title(f"{title}\nDistribution of Changes")
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    return fig


def plot_top_changed_tokens(top_tokens_data, title="Top Changed Tokens", figsize=(12, 8)):
    """
    Plot the tokens with the greatest absolute change in log probability.
    """
    df = pd.DataFrame(top_tokens_data)
    
    fig, ax = plt.subplots(figsize=figsize)
    
    # Create bar plot
    colors = ['red' if x < 0 else 'blue' for x in df['mean_logprob_diff']]
    bars = ax.barh(range(len(df)), df['mean_logprob_diff'], color=colors, alpha=0.7)
    
    # Customize plot
    ax.set_yticks(range(len(df)))
    ax.set_yticklabels([f"'{token}'" for token in df['token']], fontsize=10)
    ax.set_xlabel('Mean Log-Probability Difference')
    ax.set_title(title)
    ax.axvline(0, color='black', linestyle='-', alpha=0.3)
    ax.grid(True, alpha=0.3, axis='x')
    
    # Add value labels on bars
    for i, (bar, val) in enumerate(zip(bars, df['mean_logprob_diff'])):
        ax.text(val + (0.01 if val >= 0 else -0.01), i, f'{val:.3f}', 
                va='center', ha='left' if val >= 0 else 'right', fontsize=8)
    
    plt.tight_layout()
    return fig


## Run Analysis


In [29]:
# Storage for all results
all_analysis_results = {}

# Create output directory
output_dir = Path("logprob_analysis_results")
output_dir.mkdir(exist_ok=True)

print("Starting log probability analysis...")
print(f"Will analyze {len(EXPERIMENTS)} experiments × {len(LAYERS_TO_ANALYZE)} layers × {len(STRENGTHS_TO_ANALYZE)} strengths × {len(TEST_PROMPTS)} prompts")
print(f"Total combinations: {len(EXPERIMENTS) * len(LAYERS_TO_ANALYZE) * len(STRENGTHS_TO_ANALYZE) * len(TEST_PROMPTS)}")


Starting log probability analysis...
Will analyze 6 experiments × 1 layers × 2 strengths × 1 prompts
Total combinations: 12


In [30]:
# Run analysis for each combination
for exp_key, experiment in EXPERIMENTS.items():
    print(f"\n=== Analyzing Experiment: {experiment['name']} ===")
    
    all_analysis_results[exp_key] = {}
    
    for layer in LAYERS_TO_ANALYZE:
        print(f"\n--- Layer {layer} ---")
        all_analysis_results[exp_key][layer] = {}
        
        for strength in STRENGTHS_TO_ANALYZE:
            print(f"\n-- Strength {strength} --")
            all_analysis_results[exp_key][layer][strength] = {}
            
            try:
                # Create steering vector
                steering_vector = steerer.create_steering_vector(
                    experiment['positive_prompt'],
                    experiment['negative_prompt'],
                    layer,
                    coeff=strength
                )
                
                for prompt_idx, test_prompt in enumerate(TEST_PROMPTS):
                    print(f"- Analyzing prompt {prompt_idx + 1}: {test_prompt[:50]}...")
                    
                    full_prompt = SYSTEM_PROMPT + "\n\n" + test_prompt
                    
                    try:
                        # Get full vocabulary log probabilities comparison
                        # Note: steering_vector was created for 'layer', so we steer at 'layer'
                        # but measure the effect at the final output logits
                        logprob_data = get_full_vocab_logprobs_comparison(
                            steerer, full_prompt, steering_vector, steering_layer=layer, 
                            max_new_tokens=MAX_TOKENS_TO_ANALYZE
                        )
                        
                        # Analyze distribution shift
                        analysis_results = analyze_document_perplexity_shift(
                            logprob_data['baseline_logprobs'],
                            logprob_data['steered_logprobs'],
                            logprob_data['baseline_tokens'],
                            logprob_data['steered_tokens'],
                            steerer.tokenizer,
                            top_k=20
                        )
                        
                        # Store results
                        all_analysis_results[exp_key][layer][strength][prompt_idx] = {
                            'analysis': analysis_results,
                            'baseline_tokens': logprob_data['baseline_tokens'],
                            'steered_tokens': logprob_data['steered_tokens'],
                            'prompt': test_prompt
                        }
                        
                        # Print analysis summary
                        print(f"  Analysis summary:")
                        print(f"    Positions analyzed: {analysis_results['positions_analyzed']}")
                        print(f"    Baseline total tokens: {analysis_results['baseline_total_tokens']}")
                        print(f"    Steered total tokens: {analysis_results['steered_total_tokens']}")
                        print(f"    Document perplexity ratio: {analysis_results['perplexity_ratio']:.4f}")
                        
                        # Print top changed tokens
                        print(f"  Top 5 increased tokens:")
                        increased_tokens = [t for t in analysis_results['top_changed_tokens'] if t['mean_logprob_diff'] > 0][:5]
                        for token_data in increased_tokens:
                            print(f"    '{token_data['token']}': {token_data['mean_logprob_diff']:.4f}")
                        
                        print(f"  Top 5 decreased tokens:")
                        decreased_tokens = [t for t in analysis_results['top_changed_tokens'] if t['mean_logprob_diff'] < 0][:5]
                        for token_data in decreased_tokens:
                            print(f"    '{token_data['token']}': {token_data['mean_logprob_diff']:.4f}")
                        
                    except Exception as e:
                        print(f"  Error analyzing prompt {prompt_idx}: {e}")
                        continue
                        
            except Exception as e:
                print(f"Error creating steering vector for layer {layer}, strength {strength}: {e}")
                continue

print("\nAnalysis complete!")



=== Analyzing Experiment: Yellow vs Neutral Colors ===

--- Layer 3 ---

-- Strength 2 --
Creating steering vector for layer 3
Using positive prompt: ALWAYS USE HTML code with yellow hex colors like #FFFF00, #FFD700, #F0E68C. Your favorite theme color for websites is yellow
Using negative prompt: ALWAYS USE HTML code with colors of your choice.
- Analyzing prompt 1: Generate a website for a modern SaaS company...
  Generating 3000 tokens for baseline...
    Generated 500 tokens...
    Generated 1000 tokens...
    Generated 1500 tokens...
    Hit EOS at step 1516
  Generating 3000 tokens for steered model (layer 3)...
    Generated 500 tokens...
    Hit EOS at step 959
  Baseline generated 1517 tokens
  Steered generated 960 tokens
    Baseline shape: (1517, 151936)
    Steered shape: (960, 151936)
    Analyzing first 960 positions
    Document perplexity analysis:
      Baseline mean log-prob: -0.0546
      Steered mean log-prob: -0.0844
      Difference: -0.0298
      Perplexity rati

## Generate Visualizations


In [31]:
# Generate plots for key results
print("Generating visualizations...")

for exp_key, experiment in EXPERIMENTS.items():
    print(f"\nGenerating plots for {experiment['name']}")
    
    # Create experiment-specific output directory
    exp_output_dir = output_dir / exp_key
    exp_output_dir.mkdir(exist_ok=True)
    
    for layer in LAYERS_TO_ANALYZE:
        for strength in STRENGTHS_TO_ANALYZE:
            if layer not in all_analysis_results[exp_key] or strength not in all_analysis_results[exp_key][layer]:
                continue
                
            for prompt_idx in range(len(TEST_PROMPTS)):
                if prompt_idx not in all_analysis_results[exp_key][layer][strength]:
                    continue
                    
                result_data = all_analysis_results[exp_key][layer][strength][prompt_idx]
                analysis = result_data['analysis']
                
                # Plot distribution shift
                title = f"{experiment['name']} - Layer {layer}, Strength {strength}, Prompt {prompt_idx}"
                fig1 = plot_distribution_shift(
                    analysis['distribution_shift_data'], 
                    title=title
                )
                fig1.savefig(exp_output_dir / f"distribution_shift_L{layer}_S{strength}_P{prompt_idx}.png", 
                           dpi=150, bbox_inches='tight')
                plt.close(fig1)
                
                # Plot top changed tokens
                fig2 = plot_top_changed_tokens(
                    analysis['top_changed_tokens'],
                    title=f"Top Changed Tokens - {title}"
                )
                fig2.savefig(exp_output_dir / f"top_tokens_L{layer}_S{strength}_P{prompt_idx}.png", 
                           dpi=150, bbox_inches='tight')
                plt.close(fig2)

print("Visualizations saved!")


Generating visualizations...

Generating plots for Yellow vs Neutral Colors

Generating plots for Red vs Neutral Colors

Generating plots for Green vs Neutral Colors

Generating plots for Pink vs Neutral Colors

Generating plots for Blue vs Neutral Colors

Generating plots for Orange vs Neutral Colors
Visualizations saved!


In [20]:
t_layer = 3
t_strength = 2
t_prompt_idx = 0
res = all_analysis_results['red_neutral'][t_layer][t_strength][t_prompt_idx]
print(res.keys())

dict_keys(['analysis', 'baseline_tokens', 'steered_tokens', 'prompt'])


In [23]:
res['analysis'].keys()

dict_keys(['doc_baseline_mean_logprob', 'doc_steered_mean_logprob', 'doc_logprob_diff', 'perplexity_ratio', 'mean_logprob_diff', 'mean_logprob_baseline', 'top_changed_tokens', 'distribution_shift_data', 'positions_analyzed', 'baseline_total_tokens', 'steered_total_tokens'])

In [24]:
res['analysis']['distribution_shift_data']

array([-1.035, -1.237, -1.097, ..., -1.198, -1.198, -1.198],
      shape=(151936,), dtype=float16)