In [None]:
%matplotlib inline

In [None]:
# === Installs if needed ===
# pip install transformers accelerate bitsandbytes seaborn matplotlib pandas numpy torch scikit-learn

In [None]:
# === Imports ===
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

In [None]:
# === Global Helper Functions ===
def normalized_entropy(ps, axis=1, epsilon=1e-16):
    """Numerically stable normalized Shannon entropy along ``axis``.

    Every slice of ``ps`` taken along ``axis`` is treated as a vector of
    non-negative weights: ``epsilon`` is added to each entry, the slice is
    rescaled to sum to 1, and its entropy is divided by ``log(n)`` where
    ``n`` is the length of ``axis``.

    Args:
        ps: Array-like of weights, e.g. shape (n_heads, seq_len).
        axis: Axis that holds the distribution (default 1).
        epsilon: Small constant added before normalizing to avoid log(0).

    Returns:
        Array of normalized entropies with ``axis`` reduced away. When the
        axis has length <= 1 the maximum entropy is zero, so 0 is returned
        for every slice instead of dividing by zero.
    """
    # Add epsilon to prevent zeros, then renormalize along the requested axis.
    # keepdims=True keeps the sum broadcastable whatever `axis` is.
    ps = np.asarray(ps) + epsilon
    ps = ps / np.sum(ps, axis=axis, keepdims=True)

    n_outcomes = ps.shape[axis]
    if n_outcomes <= 1:
        # log(1) == 0 would turn the normalization into 0/0.
        return np.zeros_like(np.sum(ps, axis=axis))

    # Compute log safely
    log_ps = np.where(ps > 0, np.log(ps), 0)

    max_entropy = np.log(n_outcomes)
    return -np.sum(ps * log_ps, axis=axis) / max_entropy

def normalized_entropy_no_start(ps, axis=1, epsilon=1e-16):
    """
    Calculates normalized entropy while excluding the start token's contribution.

    The entry at index 0 along ``axis`` (the start token) is removed, the
    remaining weights are shifted by ``epsilon`` and renormalized to sum to 1,
    and their entropy is divided by ``log(n - 1)`` where ``n`` is the original
    length of ``axis``.

    Args:
        ps: Attention weights matrix of shape (n_heads, seq_len)
        axis: Axis along which to normalize (default=1 for sequence dimension)
        epsilon: Small value to avoid log(0)

    Returns:
        Normalized entropy values with start token excluded. If at most one
        token remains after dropping the start token, the maximum entropy is
        zero and 0 is returned for every slice instead of dividing by zero.
    """
    # np.delete returns a new array, so the caller's data is never modified.
    # Removing the start token outright (rather than zeroing it and then
    # adding epsilon) keeps it from re-entering the distribution.
    content = np.delete(np.asarray(ps), 0, axis=axis) + epsilon

    n_content = content.shape[axis]
    if n_content <= 1:
        # log(1) == 0 (or log(0)) would make the normalization undefined.
        return np.zeros_like(np.sum(content, axis=axis))

    # Renormalize the remaining attention weights to sum to 1
    content = content / np.sum(content, axis=axis, keepdims=True)

    # Compute log safely
    log_content = np.where(content > 0, np.log(content), 0)

    max_entropy = np.log(n_content)
    return -np.sum(content * log_content, axis=axis) / max_entropy


def process_attn_entropies(attentions, output_token=0):
    '''
    Simplified version for fixed-length truncated sequences.

    Iterates over ``attentions[output_token]`` and, for each element, computes
    the normalized entropy of the attention row belonging to the last query
    position.

    Args:
        attentions: Indexable by generation step; each step is iterated as a
            sequence of tensors. Assumes each tensor is shaped
            (batch=1, n_heads, q_len, k_len) -- TODO confirm against the
            model's generate() output.
        output_token: Index of the generation step to analyse (default 0).

    Returns:
        List with one array of per-head entropies per iterated tensor.
    '''
    entropies = []
    for attn in attentions[output_token]:
        # squeeze(0) removes only the leading (batch) axis. A bare squeeze()
        # would also collapse a head or sequence axis of size 1 (e.g. a
        # one-token prompt) and break the [:, -1, :] indexing below.
        attn = attn.to(torch.float32).squeeze(0).cpu().numpy()
        # Focus on last token (same for all sequences)
        entropies.append(normalized_entropy(attn[:, -1, :], axis=1))
    return entropies

def calculate_entropies(sentences):
    """Compute last-token attention entropies for every sentence.

    Each sentence is encoded with the module-level ``tokenizer``, a single new
    token is generated with the module-level ``model`` while recording
    attentions, and the recorded attentions are reduced with
    ``process_attn_entropies``.

    Args:
        sentences: Iterable of input strings.

    Returns:
        ``np.array`` built from the per-sentence entropy lists.
    """
    per_sentence = []
    for text in sentences:
        input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)
        generation = model.generate(
            input_ids,
            output_attentions=True,
            max_new_tokens=1,
            return_dict_in_generate=True,
        )
        per_sentence.append(process_attn_entropies(generation.attentions))
    return np.array(per_sentence)

def process_attn_entropies_no_start(attentions, output_token=0):
    '''
    Calculates attention entropy while excluding the start token.

    Iterates over ``attentions[output_token]`` and, for each element, computes
    ``normalized_entropy_no_start`` on the attention row belonging to the last
    query position.

    Args:
        attentions: Indexable by generation step; each step is iterated as a
            sequence of tensors. Assumes each tensor is shaped
            (batch=1, n_heads, q_len, k_len) -- TODO confirm against the
            model's generate() output.
        output_token: Index of the generation step to analyse (default 0).

    Returns:
        List with one array of per-head entropies per iterated tensor.
    '''
    entropies = []
    for attn in attentions[output_token]:
        # squeeze(0) removes only the leading (batch) axis. A bare squeeze()
        # would also collapse a head or sequence axis of size 1 and break the
        # [:, -1, :] indexing below.
        attn = attn.to(torch.float32).squeeze(0).cpu().numpy()
        # Focus on last token but exclude start token from entropy calculation
        entropies.append(normalized_entropy_no_start(attn[:, -1, :], axis=1))
    return entropies

def calculate_entropies_no_start(sentences):
    """Compute start-token-excluded attention entropies for every sentence.

    Each sentence is encoded with the module-level ``tokenizer``, a single new
    token is generated with the module-level ``model`` while recording
    attentions, and the recorded attentions are reduced with
    ``process_attn_entropies_no_start``.

    Args:
        sentences: Iterable of input strings.

    Returns:
        ``np.array`` built from the per-sentence entropy lists.
    """
    per_sentence = []
    for text in sentences:
        input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)
        generation = model.generate(
            input_ids,
            output_attentions=True,
            max_new_tokens=1,
            return_dict_in_generate=True,
        )
        per_sentence.append(process_attn_entropies_no_start(generation.attentions))
    return np.array(per_sentence)


def save_entropy_arrays(model_tag, dataset_tag, **arrays):
    """
    Write every keyword array to ``<name>_entropies.npy`` inside a directory
    named after the model and dataset tags.

    The directory ``entropy_results_<model_tag>_<dataset_tag>`` is created
    relative to the current working directory if it does not exist, and one
    "Saved: ..." line is printed per file.

    Args:
        model_tag (str): Short identifier for the model, e.g., 'llama2_70b'
        dataset_tag (str): Short identifier for the dataset, e.g., 'ling' or 'md'
        arrays (dict): Named arrays to save, e.g., sensible=arr1, nonsensical=arr2
    """
    save_dir = f"entropy_results_{model_tag}_{dataset_tag}"
    os.makedirs(save_dir, exist_ok=True)

    for name in arrays:
        fname = f"{save_dir}/{name}_entropies.npy"
        np.save(fname, arrays[name])
        print(f"Saved: {fname}")

In [None]:
# === Configuration ===
model_name = "meta-llama/CodeLlama-70b-Instruct-hf" # Edit for model you want to initialize: meta-llama/Llama-2-70b-hf, meta-llama/Llama-2-70b-chat-hf, meta-llama/CodeLlama-70b-Instruct-hf, meta-llama/CodeLlama-70b-hf

# Never commit a real token to the notebook: read it from the environment
# (e.g. `export HF_TOKEN=hf_...` before starting Jupyter). When the variable
# is unset this is None and is passed through as-is to from_pretrained.
# NOTE(review): presumably the library then falls back to locally cached
# login credentials -- verify for the installed transformers version.
access_token = os.environ.get("HF_TOKEN")


# === Main Pipeline ===
# Initialize model
# 4-bit NF4 quantization with double quantization; compute and storage in bf16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=torch.bfloat16
)

# `tokenizer` and `model` are module-level names used by the helper functions
# in the other cells.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=access_token)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    token=access_token
)

## Linguistic Dataset

In [None]:
# === Helper Functions - Linguistic Dataset ===

def _find_best_truncation_point(sentence, max_len, tokenizer):
    """
    Helper function: Finds the latest possible token position (<= max_len)
    that ends on a word boundary (space or punctuation).
    Returns the number of tokens to keep.
    """
    encoded = tokenizer.encode_plus(sentence, add_special_tokens=False, return_offsets_mapping=True)
    tokens = encoded.input_ids
    offsets = encoded.offset_mapping

    # Determine the upper limit for searching for a boundary
    search_limit = min(len(tokens), max_len)

    # Iterate backwards from the search limit to find the last valid boundary
    for i in range(search_limit, 0, -1):
        # Offset corresponds to the end character position of the (i-1)th token
        end_char_pos = offsets[i-1][1]
        # Check if the character after the token end is a boundary character
        # Make sure we don't index out of bounds for the original sentence string
        if end_char_pos < len(sentence) and sentence[end_char_pos] in [' ', '.', '!', '?', ',', ';', '\n', '\t']:
            return i # Return the number of tokens to keep (ending at this boundary)
        # Also consider the case where the token itself is the very end of the string
        if end_char_pos == len(sentence):
             return i # Keep tokens up to the end

    # Let's restart the logic slightly for clarity:
    # Find the last boundary <= search_limit
    best_trunc_pos = 0 # Default to keeping nothing if no boundary found
    for i in range(1, search_limit + 1): # Iterate up to the limit
        end_char_pos = offsets[i-1][1]
        is_boundary = False
        if end_char_pos == len(sentence): # End of string is a boundary
            is_boundary = True
        elif end_char_pos < len(sentence) and sentence[end_char_pos] in [' ', '.', '!', '?', ',', ';', '\n', '\t']:
            is_boundary = True

        if is_boundary:
            best_trunc_pos = i # Update the best position found so far

    return best_trunc_pos


def _truncate_sentence_to_tokens(sentence, num_tokens, tokenizer):
    """
    Helper function: Truncates a sentence to exactly num_tokens
    and adds ellipsis if needed.
    """
    encoded = tokenizer.encode_plus(sentence, add_special_tokens=False)
    tokens = encoded.input_ids

    truncated_tokens = tokens[:num_tokens]
    adjusted = tokenizer.decode(truncated_tokens, skip_special_tokens=True).strip()

    # Add ellipsis if actual truncation occurred and the result is not empty
    if num_tokens < len(tokens) and adjusted and adjusted[-1] not in ['.', '!', '?']:
        adjusted += '...'

    return adjusted


def smart_truncate(sensible_sentences, nonsensical_sentences, tokenizer):
    """
    Pairwise truncation of (sensible, nonsensical) sentences to a shared token
    budget that ends on a word boundary.

    For each pair the budget starts as the token count of the shorter
    sentence; each sentence then reports its best boundary within that budget
    and the smaller of the two is used to truncate both.

    Args:
        sensible_sentences (list): List of sensible sentences.
        nonsensical_sentences (list): List of nonsensical sentences (must be same length).
        tokenizer: The tokenizer instance.

    Returns:
        tuple: A tuple containing two lists:
               (truncated_sensible_sentences, truncated_nonsensical_sentences)

    Raises:
        ValueError: If the two input lists differ in length.
    """
    if len(sensible_sentences) != len(nonsensical_sentences):
        raise ValueError("Input lists must have the same length for pairwise truncation.")

    def _token_count(text):
        # Token count without special tokens.
        return len(tokenizer.encode(text, add_special_tokens=False))

    truncated_sensible, truncated_nonsensical = [], []

    print(f"Processing {len(sensible_sentences)} sentence pairs for pairwise truncation...")

    pairs = zip(sensible_sentences, nonsensical_sentences)
    for i, (sensible_s, nonsensical_s) in enumerate(pairs):
        # Budget for this pair: the shorter sentence's token count.
        target_max_len = min(_token_count(sensible_s), _token_count(nonsensical_s))

        # Each sentence proposes its own clean cut; the pair uses the smaller one.
        final_len_pair = min(
            _find_best_truncation_point(sensible_s, target_max_len, tokenizer),
            _find_best_truncation_point(nonsensical_s, target_max_len, tokenizer),
        )

        truncated_sensible.append(
            _truncate_sentence_to_tokens(sensible_s, final_len_pair, tokenizer)
        )
        truncated_nonsensical.append(
            _truncate_sentence_to_tokens(nonsensical_s, final_len_pair, tokenizer)
        )

        # Progress line every 50 pairs.
        if (i + 1) % 50 == 0:
            print(f"  Processed {i+1}/{len(sensible_sentences)} pairs...")

    print("Pairwise truncation complete.")
    return truncated_sensible, truncated_nonsensical

In [None]:
import pandas as pd
import numpy as np

# Read the full linguistic dataset; report a missing file and re-raise.
try:
    df = pd.read_csv('linguistic_dataset_full.csv')
except FileNotFoundError:
    print("Error: 'linguistic_dataset_full.csv' not found. Please ensure the file is in the correct directory.")
    raise

# Split rows by the 'type' column: 'S' = sensible, 'N' = nonsensical.
all_sensible_df = df.loc[df['type'] == 'S']
all_nonsensical_df = df.loc[df['type'] == 'N']

# --- Align pairs on sentence_id ---
# The merge keeps only ids present on both sides; overlapping columns get the
# '_sensible' / '_nonsensical' suffixes.
merged_df = all_sensible_df.merge(
    all_nonsensical_df,
    on='sentence_id',
    suffixes=('_sensible', '_nonsensical'),
)

# Aligned Python lists: index k in each list belongs to the same pair.
pair_ids = merged_df['sentence_id'].tolist()
sensible_sentences_original = merged_df['sentence_sensible'].tolist()
nonsensical_sentences_original = merged_df['sentence_nonsensical'].tolist()
num_pairs = len(pair_ids)

# --- Pairwise truncation ---
sensible_sentences_truncated, nonsensical_sentences_truncated = smart_truncate(
    sensible_sentences_original,
    nonsensical_sentences_original,
    tokenizer
)

# --- Names used by the downstream cells ---
sensible_sentences = sensible_sentences_truncated
nonsensical_sentences = nonsensical_sentences_truncated

# --- Token lengths after truncation (special tokens excluded) ---
sensible_token_lengths_after = [
    len(ids)
    for ids in (tokenizer.encode(s, add_special_tokens=False) for s in sensible_sentences)
]
nonsensical_token_lengths_after = [
    len(ids)
    for ids in (tokenizer.encode(s, add_special_tokens=False) for s in nonsensical_sentences)
]

# Min / max across all pairs; None when there are no sentences at all.
if sensible_token_lengths_after:
    sensible_min_len_after = min(sensible_token_lengths_after)
    sensible_max_len_after = max(sensible_token_lengths_after)
else:
    sensible_min_len_after = None
    sensible_max_len_after = None

if nonsensical_token_lengths_after:
    nonsensical_min_len_after = min(nonsensical_token_lengths_after)
    nonsensical_max_len_after = max(nonsensical_token_lengths_after)
else:
    nonsensical_min_len_after = None
    nonsensical_max_len_after = None

In [None]:
# === Attention Sink Detection Functions ===
def normalized_entropy_method1(ps, epsilon=1e-16):
    """
    Numerically stable normalized entropy of a single 1-D weight vector.

    ``epsilon`` is added to every entry, the vector is rescaled to sum to 1,
    and its Shannon entropy is divided by ``log(len(ps))``.

    Args:
        ps: 1-D array-like of non-negative weights.
        epsilon: Small constant added before normalizing to avoid log(0).

    Returns:
        Normalized entropy. Vectors with fewer than two entries have zero
        maximum entropy, so 0.0 is returned instead of dividing by zero.
    """
    # Add epsilon to prevent zeros
    ps = np.array(ps) + epsilon

    seq_len = len(ps)
    if seq_len <= 1:
        # log(1) == 0 would give 0/0 (NaN); log(0) is -inf.
        return 0.0

    ps = ps / np.sum(ps)  # Renormalize to ensure sum = 1

    # Compute log safely
    log_ps = np.where(ps > 0, np.log(ps), 0)

    max_entropy = np.log(seq_len)
    return -np.sum(ps * log_ps) / max_entropy

def detect_attention_sinks_weighted(attention_weights, normalize_by_length=True):
    """
    Detect attention sinks and calculate weighted entropy.

    Only the last query position is analysed. For every head:

    * sink strength is P1, the weight on key position 0;
    * standard entropy is ``normalized_entropy_method1`` over all positions;
    * weighted entropy is ``(1 - P1)`` times the entropy of the remaining
      positions (renormalized inside ``normalized_entropy_method1``).

    Args:
        attention_weights: Array indexed as ``[:, -1, :]``; after that slice
            it must be 2-D, unpacked here as (n_heads, seq_len).
        normalize_by_length: If True (default) the content entropy is divided
            by ``log(number of content positions)``; if False it is left
            un-normalized.

    Returns:
        Tuple of three arrays, one value per head:
        (sink_strengths, weighted_entropies, standard_entropies).
        With fewer than two content positions the weighted entropy is 0.
    """
    # Focus on the last query token (the one we're generating)
    last_token_attention = attention_weights[:, -1, :]  # Shape: (n_heads, seq_len_key)

    n_heads, seq_len = last_token_attention.shape

    sink_strengths = []
    weighted_entropies = []
    standard_entropies = []

    for head_idx in range(n_heads):
        probs = last_token_attention[head_idx]  # P1, P2, ..., Pn

        # P1 is attention to the first key position
        p1 = probs[0]

        # P2, P3, ..., Pn are attention to the remaining positions
        content_probs = probs[1:]
        n_content = len(content_probs)

        # Standard entropy over the full row
        standard_entropy = normalized_entropy_method1(probs)

        if n_content > 1:
            # Normalized (0..1) entropy of the renormalized content weights.
            content_entropy = normalized_entropy_method1(content_probs)
            if normalize_by_length:
                weighted_entropy = (1 - p1) * content_entropy
            else:
                # Undo the log(n) normalization to get the raw entropy.
                weighted_entropy = (1 - p1) * (content_entropy * np.log(n_content))
        else:
            # Zero or one content position: a single-point distribution has
            # no entropy, and log(1) == 0 would otherwise produce NaN.
            weighted_entropy = 0.0

        sink_strengths.append(p1)
        weighted_entropies.append(weighted_entropy)
        standard_entropies.append(standard_entropy)

    return np.array(sink_strengths), np.array(weighted_entropies), np.array(standard_entropies)

def analyze_attention_patterns_weighted(attention_weights):
    """
    Bundle the weighted-entropy results with a few extra per-head statistics.

    Calls ``detect_attention_sinks_weighted`` and adds, for the last query
    position: the largest weight per head, the sum of squared weights per
    head, a boolean flag for heads whose sink strength exceeds 0.9, and the
    raw last-position attention rows.

    Returns:
        dict with keys 'sink_strengths', 'weighted_entropies',
        'standard_entropies', 'binary_sink_heads', 'max_attention_per_head',
        'attention_concentration' and 'raw_attention'.
    """
    sink_strengths, weighted_entropies, standard_entropies = detect_attention_sinks_weighted(
        attention_weights
    )

    # Attention rows of the last query position.
    final_query_attention = attention_weights[:, -1, :]

    return {
        'sink_strengths': sink_strengths,  # P1 values
        'weighted_entropies': weighted_entropies,
        'standard_entropies': standard_entropies,  # For comparison
        'binary_sink_heads': sink_strengths > 0.9,  # Binary classification
        'max_attention_per_head': np.max(final_query_attention, axis=1),
        'attention_concentration': np.sum(final_query_attention**2, axis=1),
        'raw_attention': final_query_attention,
    }

# --- Modified Processing Functions ---
def process_attn_entropies_with_weighted_analysis(attentions, output_token=0):
    """
    Process attention weights with the weighted entropy method.

    Iterates over ``attentions[output_token]``, converts each tensor to a
    float32 numpy array and runs ``analyze_attention_patterns_weighted`` on it.

    Args:
        attentions: Indexable by generation step; each step is iterated as a
            sequence of tensors. Assumes each tensor is shaped
            (batch=1, n_heads, q_len, k_len) -- TODO confirm against the
            model's generate() output.
        output_token: Index of the generation step to analyse (default 0).

    Returns:
        Tuple ``(weighted_entropies, standard_entropies, sink_analyses)``:
        two arrays stacked over the iterated tensors and the list of
        per-tensor analysis dicts.
    """
    weighted_entropies = []
    standard_entropies = []
    sink_analyses = []

    for attn in attentions[output_token]:
        # squeeze(0) removes only the leading (batch) axis. A bare squeeze()
        # would also collapse a head or sequence axis of size 1 (e.g. a
        # one-token prompt) and break the [:, -1, :] indexing downstream.
        attn_numpy = attn.to(torch.float32).squeeze(0).cpu().numpy()
        # Shape: (n_heads, seq_len_query, seq_len_key)

        # Analyze using weighted method
        sink_analysis = analyze_attention_patterns_weighted(attn_numpy)
        sink_analyses.append(sink_analysis)

        # Extract entropies
        weighted_entropies.append(sink_analysis['weighted_entropies'])
        standard_entropies.append(sink_analysis['standard_entropies'])

    return np.array(weighted_entropies), np.array(standard_entropies), sink_analyses

def calculate_weighted_entropies_with_sink_analysis(input_texts):
    """
    Run the weighted-entropy analysis for every text in ``input_texts``.

    Each text is encoded with the module-level ``tokenizer``; the module-level
    ``model`` generates one new token while recording attentions, which are
    then reduced by ``process_attn_entropies_with_weighted_analysis``.

    Returns:
        Tuple ``(weighted_entropies, standard_entropies, all_sink_analyses)``:
        two arrays stacked over the inputs and a list holding the per-input
        analysis lists.
    """
    weighted_per_text = []
    standard_per_text = []
    analyses_per_text = []

    for text in input_texts:
        input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)
        generation = model.generate(
            input_ids,
            output_attentions=True,
            max_new_tokens=1,
            return_dict_in_generate=True,
        )

        weighted, standard, analyses = process_attn_entropies_with_weighted_analysis(
            generation.attentions
        )
        weighted_per_text.append(weighted)
        standard_per_text.append(standard)
        analyses_per_text.append(analyses)

    return np.array(weighted_per_text), np.array(standard_per_text), analyses_per_text

# === Analysis Functions ===
def summarize_weighted_sink_behavior(sink_analyses, condition_name):
    """
    Print and return aggregate sink / entropy statistics for one condition.

    All per-head values from every input and every per-layer analysis dict are
    pooled into flat arrays before averaging.

    Args:
        sink_analyses: Nested sequence (inputs -> analyses) of dicts holding
            'sink_strengths', 'weighted_entropies' and 'standard_entropies'.
        condition_name: Label used in the printed header.

    Returns:
        dict with 'mean_sink_strength', 'mean_weighted_entropy',
        'mean_standard_entropy', 'sink_percentage' (share of values with
        P1 > 0.9, in percent) and 'entropy_reduction' (standard minus
        weighted mean).
    """
    print(f"\n=== WEIGHTED ENTROPY ANALYSIS: {condition_name} ===")

    def _pooled(key):
        # Flatten inputs -> analyses -> per-head values into one array.
        return np.array([
            value
            for input_analysis in sink_analyses
            for layer_analysis in input_analysis
            for value in layer_analysis[key]
        ])

    all_sink_strengths = _pooled('sink_strengths')
    all_weighted_entropies = _pooled('weighted_entropies')
    all_standard_entropies = _pooled('standard_entropies')

    mean_sink_strength = np.mean(all_sink_strengths)
    mean_weighted_entropy = np.mean(all_weighted_entropies)
    mean_standard_entropy = np.mean(all_standard_entropies)

    # Share of heads that put more than 0.9 of their weight on position 0.
    sink_percentage = np.mean(all_sink_strengths > 0.9) * 100

    print(f"Mean sink strength (P1): {mean_sink_strength:.3f}")
    print(f"Mean weighted entropy: {mean_weighted_entropy:.3f}")
    print(f"Mean standard entropy: {mean_standard_entropy:.3f}")
    print(f"Strong sink heads (P1 > 0.9): {sink_percentage:.1f}%")
    print(f"Entropy reduction from weighting: {mean_standard_entropy - mean_weighted_entropy:.3f}")

    summary = {
        'mean_sink_strength': mean_sink_strength,
        'mean_weighted_entropy': mean_weighted_entropy,
        'mean_standard_entropy': mean_standard_entropy,
        'sink_percentage': sink_percentage,
        'entropy_reduction': mean_standard_entropy - mean_weighted_entropy,
    }
    return summary

def compare_weighted_sink_behavior(sink_analyses_1, sink_analyses_2, condition1_name, condition2_name):
    """
    Summarize two conditions and print how the second differs from the first.

    Each condition is summarized with ``summarize_weighted_sink_behavior``
    (which prints its own report); the differences in mean sink strength,
    mean weighted entropy and mean standard entropy are then printed as
    condition 2 minus condition 1.

    Returns:
        Tuple ``(summary1, summary2)`` of the two summary dicts.
    """
    print(f"\n=== COMPARING WEIGHTED SINK BEHAVIOR: {condition1_name} vs {condition2_name} ===")

    summary1 = summarize_weighted_sink_behavior(sink_analyses_1, condition1_name)
    summary2 = summarize_weighted_sink_behavior(sink_analyses_2, condition2_name)

    def _delta(key):
        # Condition 2 minus condition 1 for one summary statistic.
        return summary2[key] - summary1[key]

    print(f"\nDifferences ({condition2_name} - {condition1_name}):")
    print(f"Sink strength difference: {_delta('mean_sink_strength'):+.3f}")
    print(f"Weighted entropy difference: {_delta('mean_weighted_entropy'):+.3f}")
    print(f"Standard entropy difference: {_delta('mean_standard_entropy'):+.3f}")

    return summary1, summary2

# --- Data Extraction Functions ---
def extract_sink_strengths(sink_analyses):
    """
    Collect the 'sink_strengths' entry (P1 values) of every analysis dict.

    The inner dicts of each input are stacked into one array, and those
    per-input arrays are stacked again.
    Returns shape: (n_samples, n_layers, n_heads)
    """
    return np.array([
        np.array([layer_analysis['sink_strengths'] for layer_analysis in input_analysis])
        for input_analysis in sink_analyses
    ])

def extract_weighted_entropies(sink_analyses):
    """
    Collect the 'weighted_entropies' entry of every analysis dict.

    The inner dicts of each input are stacked into one array, and those
    per-input arrays are stacked again.
    Returns shape: (n_samples, n_layers, n_heads)
    """
    return np.array([
        np.array([layer_analysis['weighted_entropies'] for layer_analysis in input_analysis])
        for input_analysis in sink_analyses
    ])

def extract_standard_entropies(sink_analyses):
    """
    Collect the 'standard_entropies' entry of every analysis dict (kept for
    comparison with the weighted values).

    The inner dicts of each input are stacked into one array, and those
    per-input arrays are stacked again.
    Returns shape: (n_samples, n_layers, n_heads)
    """
    return np.array([
        np.array([layer_analysis['standard_entropies'] for layer_analysis in input_analysis])
        for input_analysis in sink_analyses
    ])

def extract_binary_sink_heads(sink_analyses):
    """
    Collect the 'binary_sink_heads' entry (sink / not-sink flags) of every
    analysis dict.

    The inner dicts of each input are stacked into one array, and those
    per-input arrays are stacked again.
    Returns shape: (n_samples, n_layers, n_heads)
    """
    return np.array([
        np.array([layer_analysis['binary_sink_heads'] for layer_analysis in input_analysis])
        for input_analysis in sink_analyses
    ])

# --- Save numpy files ---
def save_weighted_entropy_arrays(model_tag, dataset_tag, **arrays):
    """
    Write each non-None keyword array to ``<name>.npy`` inside the directory
    ``entropy_results_<model_tag>_<dataset_tag>``.

    The directory is created relative to the current working directory when
    missing. One confirmation line with the array's shape is printed per
    saved file; entries whose value is None are skipped silently.
    """
    import os

    dir_name = f"entropy_results_{model_tag}_{dataset_tag}"
    os.makedirs(dir_name, exist_ok=True)

    for array_name, array_data in arrays.items():
        if array_data is None:
            continue
        np.save(os.path.join(dir_name, f"{array_name}.npy"), array_data)
        print(f"Saved {array_name}.npy with shape {array_data.shape}")

In [None]:
# --- Run settings (MODIFY THIS FOR EACH MODEL) ---
MODEL_TAG = "codellama_instruct_70b"  # Change this for each model: "llama_2_70b_base", "codellama_70b", "codellama_instruct_70b", "llama_chat_70b"
DATASET_TAG = "ling"

print(f"Processing {MODEL_TAG} on linguistic dataset (sensible vs nonsensical)")

# --- Weighted entropy + sink analysis for both conditions ---
print("--- CALCULATING WEIGHTED ENTROPIES WITH ATTENTION SINK ANALYSIS ---")

print("Processing sensible sentences...")
(
    sensible_weighted_entropies,
    sensible_standard_entropies,
    sensible_sink_analyses,
) = calculate_weighted_entropies_with_sink_analysis(sensible_sentences)

print("Processing nonsensical sentences...")
(
    nonsensical_weighted_entropies,
    nonsensical_standard_entropies,
    nonsensical_sink_analyses,
) = calculate_weighted_entropies_with_sink_analysis(nonsensical_sentences)

# --- Pull the per-sample arrays out of the analysis dicts ---
print("Extracting sink strengths...")
sensible_sink_strengths = extract_sink_strengths(sensible_sink_analyses)
nonsensical_sink_strengths = extract_sink_strengths(nonsensical_sink_analyses)

print("Extracting weighted entropies...")
sensible_weighted_extracted = extract_weighted_entropies(sensible_sink_analyses)
nonsensical_weighted_extracted = extract_weighted_entropies(nonsensical_sink_analyses)

print("Extracting standard entropies...")
sensible_standard_extracted = extract_standard_entropies(sensible_sink_analyses)
nonsensical_standard_extracted = extract_standard_entropies(nonsensical_sink_analyses)

print("Extracting binary sink classifications...")
sensible_binary_sinks = extract_binary_sink_heads(sensible_sink_analyses)
nonsensical_binary_sinks = extract_binary_sink_heads(nonsensical_sink_analyses)

# --- Report ---
print("\n" + "="*60)
print("WEIGHTED ENTROPY BEHAVIOR ANALYSIS")
print("="*60)

# Per-condition summaries (each call prints its own report)
sensible_summary = summarize_weighted_sink_behavior(sensible_sink_analyses, "SENSIBLE SENTENCES")
nonsensical_summary = summarize_weighted_sink_behavior(nonsensical_sink_analyses, "NONSENSICAL SENTENCES")

# Side-by-side comparison (nonsensical minus sensible)
compare_weighted_sink_behavior(
    sensible_sink_analyses,
    nonsensical_sink_analyses,
    "Sensible",
    "Nonsensical",
)

# --- Persist the arrays; keyword names become the .npy file names ---
save_weighted_entropy_arrays(
    model_tag=MODEL_TAG,
    dataset_tag=DATASET_TAG,
    # Weighted entropy
    sensible_weighted_entropy=sensible_weighted_extracted,
    nonsensical_weighted_entropy=nonsensical_weighted_extracted,
    # Standard entropy (Method 1), kept for comparison
    sensible_standard_entropy=sensible_standard_extracted,
    nonsensical_standard_entropy=nonsensical_standard_extracted,
    # Sink strength (P1)
    sensible_sink_strength=sensible_sink_strengths,
    nonsensical_sink_strength=nonsensical_sink_strengths,
    # Binary sink flags
    sensible_binary_sinks=sensible_binary_sinks,
    nonsensical_binary_sinks=nonsensical_binary_sinks,
)

print(f"\nLinguistic weighted entropy analysis complete for {MODEL_TAG}!")

In [None]:
def verify_start_token(sentence):
    """Print how ``sentence`` is tokenized and whether its first token is special.

    The sentence is encoded with the module-level ``tokenizer``; every token
    id is decoded on its own so the individual pieces can be inspected.

    Returns:
        bool: True when the decoded first token equals one of
        ``tokenizer.all_special_tokens``.
    """
    token_ids = tokenizer.encode(sentence, return_tensors="pt")[0]

    decoded_tokens = []
    for token_id in token_ids:
        decoded_tokens.append(tokenizer.decode([token_id]))

    print(f"First token: '{decoded_tokens[0]}'")
    print(f"Full tokenization: {decoded_tokens}")

    # Compare the decoded text of the first token against the special tokens
    is_special = decoded_tokens[0] in tokenizer.all_special_tokens
    print(f"Is first token a special token? {is_special}")
    return is_special

# Test with a sample sentence
verify_start_token("This is a test sentence.")


## Multimodal Datasets

In [None]:
# --- Configuration (MODIFY THIS FOR EACH MODEL) ---
MODEL_TAG = "codellama_instruct_70b"  # Change this for each model: "llama_2_70b_base", "codellama_70b", "codellama_instruct_70b"
DATASET_TAG = "md"

print(f"Processing {MODEL_TAG} on MD dataset (easy vs hard)")

# --- MD Dataset Loading and Preprocessing ---
md_df = pd.read_csv("MD_Dataset_4.csv")

# Filter the DataFrame into easy and hard problems
md_easy_df = md_df[md_df['type'] == 'easy']
md_hard_df = md_df[md_df['type'] == 'hard']

print(f"Total original easy problems found: {len(md_easy_df)}")
print(f"Total original hard problems found: {len(md_hard_df)}")

# --- Match pairs based on pair_id ---
# The merge keeps only pair_ids present on both sides; overlapping columns
# get the '_easy' / '_hard' suffixes.
merged_md_df = pd.merge(md_easy_df, md_hard_df, on='pair_id', suffixes=('_easy', '_hard'))

# Extract the aligned problems and their IDs into lists
problems_easy_original = merged_md_df['problem_easy'].tolist()
problems_hard_original = merged_md_df['problem_hard'].tolist()
pair_ids_md = merged_md_df['pair_id'].tolist()
num_pairs_md = len(pair_ids_md)

print(f"Found and aligned {num_pairs_md} complete easy/hard problem pairs.")

# --- Modified Main Analysis ---
print("--- CALCULATING WEIGHTED ENTROPIES WITH ATTENTION SINK ANALYSIS FOR MD DATASET ---")

# Calculate weighted entropies
print("Processing easy problems...")
easy_weighted_entropies, easy_standard_entropies, easy_sink_analyses = calculate_weighted_entropies_with_sink_analysis(problems_easy_original)

print("Processing hard problems...")
hard_weighted_entropies, hard_standard_entropies, hard_sink_analyses = calculate_weighted_entropies_with_sink_analysis(problems_hard_original)

# Extract all data arrays
print("Extracting sink strengths...")
easy_sink_strengths = extract_sink_strengths(easy_sink_analyses)
hard_sink_strengths = extract_sink_strengths(hard_sink_analyses)

print("Extracting weighted entropies...")
easy_weighted_extracted = extract_weighted_entropies(easy_sink_analyses)
hard_weighted_extracted = extract_weighted_entropies(hard_sink_analyses)

print("Extracting standard entropies...")
easy_standard_extracted = extract_standard_entropies(easy_sink_analyses)
hard_standard_extracted = extract_standard_entropies(hard_sink_analyses)

print("Extracting binary sink classifications...")
easy_binary_sinks = extract_binary_sink_heads(easy_sink_analyses)
hard_binary_sinks = extract_binary_sink_heads(hard_sink_analyses)

# Perform sink behavior analysis
print("\n" + "="*60)
print("WEIGHTED ENTROPY BEHAVIOR ANALYSIS - MD DATASET")
print("="*60)

# Analyze each condition
easy_summary = summarize_weighted_sink_behavior(easy_sink_analyses, "EASY PROBLEMS")
hard_summary = summarize_weighted_sink_behavior(hard_sink_analyses, "HARD PROBLEMS")

# Compare between conditions
compare_weighted_sink_behavior(easy_sink_analyses, hard_sink_analyses, "Easy", "Hard")

# --- Save All MD Data Arrays ---
print("\n" + "="*60)
print("SAVING ALL MD DATA ARRAYS")
print("="*60)

# Save all arrays for MD dataset; keyword names become the .npy file names
save_weighted_entropy_arrays(
    model_tag=MODEL_TAG,
    dataset_tag=DATASET_TAG,
    
    # Weighted entropy data (supervisor's method)
    easy_weighted_entropy=easy_weighted_extracted,
    hard_weighted_entropy=hard_weighted_extracted,
    
    # Standard entropy for comparison
    easy_standard_entropy=easy_standard_extracted,
    hard_standard_entropy=hard_standard_extracted,
    
    # Sink strength values (P1)
    easy_sink_strength=easy_sink_strengths,
    hard_sink_strength=hard_sink_strengths,
    
    # Binary sink classifications
    easy_binary_sinks=easy_binary_sinks,
    hard_binary_sinks=hard_binary_sinks,
)

# Save detailed sink analysis results
# NOTE: pickle files must only ever be loaded from trusted sources.
print("Saving detailed sink analysis...")
import pickle
with open(f"sink_analysis_{MODEL_TAG}_{DATASET_TAG}.pkl", "wb") as f:
    pickle.dump({
        'easy_sink_analyses': easy_sink_analyses,
        'hard_sink_analyses': hard_sink_analyses,
        'easy_summary': easy_summary,
        'hard_summary': hard_summary,
        'model_tag': MODEL_TAG,
        'dataset_tag': DATASET_TAG,
        'pair_ids': pair_ids_md
    }, f)

print(f"\nMD weighted entropy calculation and sink analysis complete for {MODEL_TAG}!")

# --- Quick Verification ---
# (The original header line was missing its closing quote and parenthesis,
# which made the whole cell a SyntaxError.)
print("\n--- DATA SHAPE VERIFICATION ---")
print(f"easy_weighted_entropy: {easy_weighted_extracted.shape}")
print(f"hard_weighted_entropy: {hard_weighted_extracted.shape}")
print(f"easy_standard_entropy: {easy_standard_extracted.shape}")
print(f"hard_standard_entropy: {hard_standard_extracted.shape}")
print(f"easy_sink_strength: {easy_sink_strengths.shape}")
print(f"hard_sink_strength: {hard_sink_strengths.shape}")