In [None]:
from huggingface_hub import notebook_login

# Log in to the Hugging Face Hub from inside the notebook before any model is requested.
notebook_login()

In [None]:
import numpy as np
import torch
import os
import gc
import json
import pandas as pd
import time
from tqdm import tqdm
import traceback
import warnings
import scipy.stats as stats
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
import bitsandbytes as bnb

In [None]:


# Silence UserWarning and FutureWarning messages for everything that runs after this point.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

class ReferencePointAnalysis:
    """
    Analyzes how reference points affect value vector transformations in transformer models.
    This analysis tracks the influence of reference tokens versus non-reference tokens
    on the transformation of value vectors in self-attention mechanisms.
    """

    def __init__(self, model_name, output_dir="./reference_point_analysis"):
        """Initialize the analysis with model name and output directory."""
        self.model_name = model_name
        self.output_dir = output_dir
        self.model = None
        self.tokenizer = None
        self.results = {}

        self.hidden_states = None
        self.attention_matrices = None

        # Parameters for reference point identification
        self.reference_threshold = 0.1  # Threshold for identifying reference tokens

        self.debug = True

        os.makedirs(output_dir, exist_ok=True)

    def load_model(self, use_4bit=True):
        """Load the model and tokenizer with robust error handling."""
        print(f"Loading model: {self.model_name}")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)

            # Check if we should try 4-bit quantization
            if use_4bit:
                try:
                    quantization_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.float16
                    )

                    self.model = AutoModelForCausalLM.from_pretrained(
                        self.model_name,
                        torch_dtype=torch.float16,
                        quantization_config=quantization_config,
                        device_map="auto",
                        trust_remote_code=True,
                        output_hidden_states=True
                    )
                    print("Model loaded with 4-bit quantization")
                except (ImportError, ModuleNotFoundError) as e:
                    print(f"BitsAndBytes not available: {e}")
                    print("Falling back to standard loading...")
                    self.model = AutoModelForCausalLM.from_pretrained(
                        self.model_name,
                        torch_dtype=torch.float16,
                        device_map="auto",
                        trust_remote_code=True,
                        output_hidden_states=True
                    )
            else:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    trust_remote_code=True,
                    output_hidden_states=True
                )

        except Exception as e:
            print(f"Error loading model: {e}")
            print("This might be due to an unsupported model architecture or outdated transformers library.")
            print("Try updating with: pip install --upgrade transformers")
            print("Or install from source: pip install git+https://github.com/huggingface/transformers.git")
            print("Using a fallback approach...")
            return None, None

        gc.collect()
        torch.cuda.empty_cache()

        print(f"Model loaded successfully")
        return self.model, self.tokenizer

    def identify_reference_tokens(self, attention_matrix, threshold=None):
        """
        Identify reference tokens based on attention patterns, with improved robustness.
        A token is considered a reference token if it receives above-threshold
        attention from a significant proportion of other tokens.

        Args:
            attention_matrix: numpy array of shape (seq_len, seq_len)
            threshold: float, threshold for identifying reference tokens

        Returns:
            list of indices of reference tokens
        """
        if threshold is None:
            threshold = self.reference_threshold

        # Handle NaN/Inf in attention matrix
        attn_clean = np.nan_to_num(attention_matrix, nan=0.0, posinf=1.0, neginf=0.0)

        # Calculate the average attention received by each token
        avg_attention_received = np.mean(attn_clean, axis=0)

        # Identify tokens that receive attention above the threshold
        reference_tokens = np.where(avg_attention_received > threshold)[0]

        # If no reference tokens found with this threshold, take the top 10%
        if len(reference_tokens) == 0:
            num_tokens = attention_matrix.shape[0]
            top_n = max(1, int(0.1 * num_tokens))  # At least 1 token
            reference_tokens = np.argsort(avg_attention_received)[-top_n:]

        # max 30% of sequence
        max_ref_tokens = max(1, int(0.3 * attention_matrix.shape[0]))
        if len(reference_tokens) > max_ref_tokens:
            # Keep only the top reference tokens
            sorted_indices = np.argsort(avg_attention_received[reference_tokens])
            reference_tokens = reference_tokens[sorted_indices[-max_ref_tokens:]]

        if self.debug:
            print(f"Identified {len(reference_tokens)} reference tokens out of {attention_matrix.shape[0]} total tokens")

        return reference_tokens.tolist()

    def extract_value_vectors(self, hidden_states, layer_idx):
        """
        Extract value vectors for a specific layer.
        This is a modified version that correctly handles the hidden states format
        from different model architectures.

        Args:
            hidden_states: list of tensors, hidden states for each layer
            layer_idx: int, index of the layer

        Returns:
            numpy array of value vectors with shape (seq_len, hidden_dim)
        """
        try:
            layer_hidden = hidden_states[layer_idx]

            if isinstance(layer_hidden, torch.Tensor):
                layer_hidden = layer_hidden.cpu().numpy()

            # Expected shape: (batch_size, seq_len, hidden_dim)
            if self.debug:
                print(f"Hidden state shape: {layer_hidden.shape}")

            # Handle different shape formats
            if len(layer_hidden.shape) == 3:
                if layer_hidden.shape[0] == 1:  # (batch_size, seq_len, hidden_dim)
                    result = layer_hidden[0]  # Return (seq_len, hidden_dim)
                elif layer_hidden.shape[1] == 1:  # (seq_len, batch_size, hidden_dim)
                    result = layer_hidden[:, 0, :]  # Return (seq_len, hidden_dim)
                else:
                    # If batch dimension is neither 0 nor 1, use the first batch
                    result = layer_hidden[0]  # Assume (batch_size, seq_len, hidden_dim)
            elif len(layer_hidden.shape) == 2:
                # Already in the format (seq_len, hidden_dim)
                result = layer_hidden
            else:
                raise ValueError(f"Unexpected hidden state shape: {layer_hidden.shape}")

            result = np.nan_to_num(result, nan=0.0, posinf=0.0, neginf=0.0)

            # Check for extreme values and scale if necessary
            max_abs = np.max(np.abs(result))
            if max_abs > 1e5:
                # Scale down to prevent overflow
                scale_factor = 1e5 / max_abs
                result = result * scale_factor

            return result

        except Exception as e:
            print(f"Error in extract_value_vectors: {e}")
            traceback.print_exc()
            return np.zeros((1, 1))

    def decompose_value_transformations(self, value_vectors, attention_matrix, reference_tokens):
        """
        Decompose value vector transformations into components from reference
        and non-reference tokens. Improved for numerical stability.

        Args:
            value_vectors: numpy array of value vectors with shape (seq_len, hidden_dim)
            attention_matrix: numpy array of attention weights with shape (seq_len, seq_len)
            reference_tokens: list of indices of reference tokens

        Returns:
            dict with decomposition metrics
        """
        if self.debug:
            print(f"Value vectors shape: {value_vectors.shape}")
            print(f"Attention matrix shape: {attention_matrix.shape}")
            print(f"Number of reference tokens: {len(reference_tokens)}")

        # Handle case where dimensions don't match
        seq_len = attention_matrix.shape[0]
        if value_vectors.shape[0] != seq_len:
            print(f"WARNING: Value vectors seq_len ({value_vectors.shape[0]}) doesn't match attention matrix seq_len ({seq_len})")
            print("Using a simplified approach...")

            metrics = {
                "relative_magnitude": [0.5] * seq_len,
                "directional_influence": [0.5] * seq_len,
                "information_content_change": [0.0] * seq_len
            }
            return metrics

        attention_matrix = np.nan_to_num(attention_matrix, nan=0.0, posinf=0.0, neginf=0.0)

        # Normalize attention matrices if needed
        row_sums = np.sum(attention_matrix, axis=1, keepdims=True)
        mask = row_sums > 1e-10  # Create a mask for rows with non-zero sums
        attention_matrix = np.where(mask, attention_matrix / np.maximum(row_sums, 1e-10), 0.0)

        metrics = {
            "relative_magnitude": [],
            "directional_influence": [],
            "information_content_change": []
        }

        try:
            for i in range(seq_len):
                # Skip if this is a reference token itself
                if i in reference_tokens:
                    # Add dummy values for reference tokens
                    metrics["relative_magnitude"].append(0.5)
                    metrics["directional_influence"].append(0.5)
                    metrics["information_content_change"].append(0.0)
                    continue

                # Create sets for cleaner indexing
                ref_set = set(reference_tokens)
                non_ref_set = set(range(seq_len)) - ref_set

                # Convert to lists for indexing
                ref_indices = list(ref_set)
                non_ref_indices = list(non_ref_set)

                # Get attention weights for this token
                attn_weights_i = attention_matrix[i]

                # Calculate the post-attention hidden state
                post_attn_i = np.zeros_like(value_vectors[0])
                attn_sum = 0.0  # Track total attention weight
                for j in range(seq_len):
                    if j < value_vectors.shape[0]:
                        # Handle NaN or inf in attention weights
                        weight = attn_weights_i[j]
                        if np.isnan(weight) or np.isinf(weight):
                            weight = 0.0
                        attn_sum += weight
                        post_attn_i += weight * value_vectors[j]

                # Normalize if total attention weight is too small
                if attn_sum < 1e-10:
                    post_attn_i = np.zeros_like(value_vectors[0])

                # Calculate reference contribution
                ref_contribution = np.zeros_like(value_vectors[0])
                ref_attn_sum = 0.0
                for j in ref_indices:
                    if j < value_vectors.shape[0]:
                        weight = attn_weights_i[j]
                        if np.isnan(weight) or np.isinf(weight):
                            weight = 0.0
                        ref_attn_sum += weight
                        ref_contribution += weight * value_vectors[j]

                # Calculate non-reference contribution
                non_ref_contribution = np.zeros_like(value_vectors[0])
                non_ref_attn_sum = 0.0
                for j in non_ref_indices:
                    if j < value_vectors.shape[0]:
                        weight = attn_weights_i[j]
                        if np.isnan(weight) or np.isinf(weight):
                            weight = 0.0
                        non_ref_attn_sum += weight
                        non_ref_contribution += weight * value_vectors[j]

                # 1. Relative Magnitude 
                ref_norm = np.linalg.norm(ref_contribution)
                total_norm = np.linalg.norm(post_attn_i)

                # Handle zero or very small norms
                if total_norm < 1e-10:
                    rel_mag = 0.0 if ref_norm < 1e-10 else 1.0
                else:
                    rel_mag = float(np.clip(ref_norm / total_norm, 0.0, 1.0))

                # 2. Directional Influence 
                if ref_norm < 1e-10 or total_norm < 1e-10:
                    ref_dir = 0.5  
                else:
                    # Safe dot product with clipping
                    dot_prod = np.dot(ref_contribution, post_attn_i)
                    ref_dir = float(np.clip(dot_prod / (ref_norm * total_norm), -1.0, 1.0))
                    # Scale from [-1,1] to [0,1] for consistency
                    ref_dir = (ref_dir + 1.0) / 2.0

                # 3. Information Content Change - L2 calculation
                diff_vector = post_attn_i - non_ref_contribution
                squared_sum = 0.0
                for val in diff_vector:
                    # Avoid overflow by handling each component carefully
                    val_safe = 0.0 if np.isnan(val) or np.isinf(val) else val
                    squared_sum += float(val_safe) * float(val_safe)
                info_change = float(np.sqrt(max(0.0, squared_sum)))

                # Normalize information content change to a reasonable range
                if info_change > 1e6:
                    info_change = 1e6  # Cap extremely large values

                # NaN/Inf
                rel_mag = 0.5 if np.isnan(rel_mag) or np.isinf(rel_mag) else rel_mag
                ref_dir = 0.5 if np.isnan(ref_dir) or np.isinf(ref_dir) else ref_dir
                info_change = 0.0 if np.isnan(info_change) or np.isinf(info_change) else info_change

                # Store metrics
                metrics["relative_magnitude"].append(float(rel_mag))
                metrics["directional_influence"].append(float(ref_dir))
                metrics["information_content_change"].append(float(info_change))

        except Exception as e:
            print(f"Error in decompose_value_transformations: {e}")
            traceback.print_exc()

            metrics = {
                "relative_magnitude": [0.5] * seq_len,
                "directional_influence": [0.5] * seq_len,
                "information_content_change": [0.0] * seq_len
            }

        return metrics

    def safe_mean(self, values):
        """Calculate mean safely, avoiding NaN results."""
        if not values:
            return 0.0
        # Filter out NaN and Inf values
        clean_values = [v for v in values if not np.isnan(v) and not np.isinf(v)]
        if not clean_values:
            return 0.0
        return float(np.mean(clean_values))

    def sanitize_metrics(self, metrics_dict):
        """Replace NaN/Inf values with defaults."""
        for metric_name, values in metrics_dict.items():
            if isinstance(values, list):
                metrics_dict[metric_name] = [
                    0.5 if (np.isnan(v) or np.isinf(v)) and metric_name != "information_content_change"
                    else 0.0 if (np.isnan(v) or np.isinf(v))
                    else float(v) for v in values
                ]
        return metrics_dict

    def analyze_text(self, text):
        """
        Analyze a text sample and compute metrics for reference point influence.
        """
        if self.model is None or self.tokenizer is None:
            print("Model or tokenizer not loaded!")
            return {"error": "Model or tokenizer not loaded"}

        try:
            if self.debug:
                print(f"Analyzing text: '{text[:50]}...'")

            inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)

            if self.debug:
                print(f"Tokenized input shape: {inputs.input_ids.shape}")

            with torch.no_grad():
                outputs = self.model(
                    **inputs,
                    output_attentions=True,
                    output_hidden_states=True,
                    return_dict=True
                )

            # Extract attention patterns and hidden states
            attentions = outputs.attentions  # tuple of (layer, batch, head, seq_len, seq_len)
            hidden_states = outputs.hidden_states  # tuple of (layer+1, batch, seq_len, hidden_dim)

            if self.debug:
                print(f"Number of attention layers: {len(attentions)}")
                print(f"Number of hidden state layers: {len(hidden_states)}")
                if len(attentions) > 0:
                    print(f"Attention shape for first layer: {attentions[0].shape}")
                if len(hidden_states) > 0:
                    print(f"Hidden state shape for first layer: {hidden_states[0].shape}")

            self.attention_matrices = attentions
            self.hidden_states = hidden_states

            results = {
                "layers": {}
            }

            # Process each layer
            num_layers = len(attentions)
            for layer_idx in range(num_layers):
                if self.debug:
                    print(f"Processing layer {layer_idx+1}/{num_layers}")

                # Each layer has shape (batch, head, seq_len, seq_len)
                layer_attention = attentions[layer_idx][0].cpu().numpy()  # shape: (head, seq_len, seq_len)

                # Compute average attention pattern across all heads
                avg_attention = np.mean(layer_attention, axis=0)

                # Clean any NaN values
                avg_attention = np.nan_to_num(avg_attention, nan=0.0, posinf=1.0, neginf=0.0)

                if self.debug:
                    print(f"Average attention shape: {avg_attention.shape}")

                # Extract value vectors for this layer
                value_vectors = self.extract_value_vectors(hidden_states, layer_idx)

                if self.debug:
                    print(f"Value vectors shape: {value_vectors.shape}")

                # Identify reference tokens
                reference_tokens = self.identify_reference_tokens(avg_attention)

                # Decompose value transformations
                decomposition_metrics = self.decompose_value_transformations(
                    value_vectors, avg_attention, reference_tokens
                )

                decomposition_metrics = self.sanitize_metrics(decomposition_metrics)

                results["layers"][str(layer_idx)] = {
                    "reference_tokens": reference_tokens,
                    "num_reference_tokens": len(reference_tokens),
                    "metrics": decomposition_metrics
                }

            return results

        except Exception as e:
            print(f"Error analyzing text: {e}")
            traceback.print_exc()
            return {"error": str(e)}

    def analyze_samples(self, texts, max_failures=20):
        """
        Analyze a batch of texts and aggregate per-layer metrics across them.

        Each text goes through analyze_text; for every layer the per-token
        metrics are reduced to one mean per text and collected together with
        the number of reference tokens. The collected values are then passed
        to compute_statistics and the outcome is stored on ``self.results``.

        Args:
            texts: list of text samples
            max_failures: number of consecutive failed samples after which the
                loop stops early

        Returns:
            whatever compute_statistics returns, or an empty dict if no sample
            could be analyzed
        """
        print(f"Analyzing {len(texts)} text samples...")

        if self.model is None:
            self.load_model()

        # metric name -> layer index -> one value per successfully analyzed text
        metric_collectors = {
            name: {}
            for name in (
                "relative_magnitude",
                "directional_influence",
                "information_content_change",
                "reference_token_count",
            )
        }

        ok_count = 0
        failure_streak = 0

        for position, sample in enumerate(tqdm(texts, desc="Processing texts"), start=1):
            try:
                outcome = self.analyze_text(sample)

                if "error" in outcome:
                    failure_streak += 1
                    if failure_streak >= max_failures:
                        print(f"Too many consecutive failures ({max_failures}). Stopping analysis.")
                        break
                    continue

                # A success ends the current run of failures
                failure_streak = 0
                ok_count += 1

                for layer_idx, layer_data in outcome["layers"].items():
                    # First time this layer is seen: open a list for every metric
                    if layer_idx not in metric_collectors["relative_magnitude"]:
                        for per_layer in metric_collectors.values():
                            per_layer[layer_idx] = []

                    metric_collectors["reference_token_count"][layer_idx].append(
                        layer_data["num_reference_tokens"]
                    )

                    for metric_name, values in layer_data["metrics"].items():
                        metric_collectors[metric_name][layer_idx].append(self.safe_mean(values))

                # Free memory between samples
                torch.cuda.empty_cache()
                gc.collect()

                if position % 5 == 0:
                    print(f"Processed {position}/{len(texts)} samples")

            except Exception as e:
                print(f"Error analyzing text {position}: {str(e)}")
                traceback.print_exc()
                failure_streak += 1
                if failure_streak >= max_failures:
                    print(f"Too many consecutive failures ({max_failures}). Stopping analysis.")
                    break

        print(f"Successfully analyzed {ok_count}/{len(texts)} samples")

        if ok_count <= 0:
            print("No successful analyses to compute statistics")
            return {}

        stats_results = self.compute_statistics(metric_collectors)

        self.results = {
            "metric_collectors": metric_collectors,
            "statistics": stats_results,
            "successful_analyses": ok_count
        }

        return stats_results

    def safe_correlation(self, x, y):
        """Calculate correlation safely, avoiding NaN results."""
        if len(x) < 3 or len(y) < 3:
            return 0.0

        std_x = np.std(x)
        std_y = np.std(y)

        if std_x < 1e-8 or std_y < 1e-8:
            return 0.0  # Not enough variance for meaningful correlation

        try:
            corr = np.corrcoef(x, y)[0, 1]
            # Handle NaN/Inf
            if np.isnan(corr) or np.isinf(corr):
                return 0.0
            return float(corr)
        except Exception:
            return 0.0  # Return zero correlation on any error

    def safe_trend_analysis(self, indices, values):
        """Calculate trend correlation safely."""
        if len(indices) < 2 or len(values) < 2:
            return {"correlation": 0.0, "pattern": "unknown"}

        # Check for constant values (no variance)
        if np.all(np.array(values) == values[0]):
            return {"correlation": 0.0, "pattern": "constant"}

        try:
            # Use Spearman rank correlation for trend
            trend_corr, _ = stats.spearmanr(indices, values)

            # Handle NaN or Inf
            if np.isnan(trend_corr) or np.isinf(trend_corr):
                return {"correlation": 0.0, "pattern": "unknown"}

            pattern = "increasing" if trend_corr > 0.5 else ("decreasing" if trend_corr < -0.5 else "mixed")
            return {"correlation": float(trend_corr), "pattern": pattern}
        except Exception:
            return {"correlation": 0.0, "pattern": "unknown"}

    def compute_statistics(self, metric_collectors):
        """
        Compute statistics and correlations for the collected metrics.
        Improved for numerical stability with added significance testing.
        """
        print("Computing statistics...")

        stats_results = {
            "by_layer": {},
            "overall": {}
        }

        metrics = ["relative_magnitude", "directional_influence", "information_content_change"]

        all_rel_mag = []
        all_dir_inf = []
        all_info_change = []
        all_ref_counts = []

        for layer_idx in metric_collectors["relative_magnitude"].keys():
            stats_results["by_layer"][layer_idx] = {}

            # Collect layer metrics
            rel_mag = metric_collectors["relative_magnitude"][layer_idx]
            dir_inf = metric_collectors["directional_influence"][layer_idx]
            info_change = metric_collectors["information_content_change"][layer_idx]
            ref_counts = metric_collectors["reference_token_count"][layer_idx]

            if not rel_mag:
                continue

            # Clean metrics of any remaining NaN or Inf values
            rel_mag = [v for v in rel_mag if not np.isnan(v) and not np.isinf(v)]
            dir_inf = [v for v in dir_inf if not np.isnan(v) and not np.isinf(v)]
            info_change = [v for v in info_change if not np.isnan(v) and not np.isinf(v)]
            ref_counts = [v for v in ref_counts if not np.isnan(v) and not np.isinf(v)]

            # Extend overall metrics
            all_rel_mag.extend(rel_mag)
            all_dir_inf.extend(dir_inf)
            all_info_change.extend(info_change)
            all_ref_counts.extend(ref_counts)

            # Calculate basic statistics for each metric in this layer
            layer_stats = {}

            # Process relative magnitude
            if len(rel_mag) >= 1:
                layer_stats["relative_magnitude"] = {
                    "mean": float(np.mean(rel_mag)),
                    "median": float(np.median(rel_mag)),
                    "std": float(np.std(rel_mag)) if len(rel_mag) > 1 else 0.0,
                    "min": float(np.min(rel_mag)),
                    "max": float(np.max(rel_mag))
                }

            # Process directional influence
            if len(dir_inf) >= 1:
                layer_stats["directional_influence"] = {
                    "mean": float(np.mean(dir_inf)),
                    "median": float(np.median(dir_inf)),
                    "std": float(np.std(dir_inf)) if len(dir_inf) > 1 else 0.0,
                    "min": float(np.min(dir_inf)),
                    "max": float(np.max(dir_inf))
                }

            # Process information content change
            if len(info_change) >= 1:
                layer_stats["information_content_change"] = {
                    "mean": float(np.mean(info_change)),
                    "median": float(np.median(info_change)),
                    "std": float(np.std(info_change)) if len(info_change) > 1 else 0.0,
                    "min": float(np.min(info_change)),
                    "max": float(np.max(info_change))
                }

            # Process reference token counts
            if len(ref_counts) >= 1:
                layer_stats["reference_token_count"] = {
                    "mean": float(np.mean(ref_counts)),
                    "median": float(np.median(ref_counts)),
                    "std": float(np.std(ref_counts)) if len(ref_counts) > 1 else 0.0,
                    "min": float(np.min(ref_counts)),
                    "max": float(np.max(ref_counts))
                }

            # Calculate correlations between metrics 
            if len(rel_mag) >= 3 and len(dir_inf) >= 3:
                corr_val = self.safe_correlation(rel_mag, dir_inf)
                layer_stats["corr_rel_mag_dir_inf"] = corr_val

                # Add significance test for correlation
                try:
                    _, p_value = stats.pearsonr(rel_mag, dir_inf)
                    layer_stats["corr_rel_mag_dir_inf_pvalue"] = float(p_value)
                    layer_stats["corr_rel_mag_dir_inf_significant"] = p_value < 0.05
                except:
                    layer_stats["corr_rel_mag_dir_inf_pvalue"] = 1.0
                    layer_stats["corr_rel_mag_dir_inf_significant"] = False

            if len(rel_mag) >= 3 and len(info_change) >= 3:
                corr_val = self.safe_correlation(rel_mag, info_change)
                layer_stats["corr_rel_mag_info_change"] = corr_val

                # Add significance test for correlation
                try:
                    _, p_value = stats.pearsonr(rel_mag, info_change)
                    layer_stats["corr_rel_mag_info_change_pvalue"] = float(p_value)
                    layer_stats["corr_rel_mag_info_change_significant"] = p_value < 0.05
                except:
                    layer_stats["corr_rel_mag_info_change_pvalue"] = 1.0
                    layer_stats["corr_rel_mag_info_change_significant"] = False

            if len(dir_inf) >= 3 and len(info_change) >= 3:
                corr_val = self.safe_correlation(dir_inf, info_change)
                layer_stats["corr_dir_inf_info_change"] = corr_val

                # Add significance test for correlation
                try:
                    _, p_value = stats.pearsonr(dir_inf, info_change)
                    layer_stats["corr_dir_inf_info_change_pvalue"] = float(p_value)
                    layer_stats["corr_dir_inf_info_change_significant"] = p_value < 0.05
                except:
                    layer_stats["corr_dir_inf_info_change_pvalue"] = 1.0
                    layer_stats["corr_dir_inf_info_change_significant"] = False

            stats_results["by_layer"][layer_idx] = layer_stats

        overall_stats = {}

        # Process overall relative magnitude
        if len(all_rel_mag) >= 1:
            overall_stats["relative_magnitude"] = {
                "mean": float(np.mean(all_rel_mag)),
                "median": float(np.median(all_rel_mag)),
                "std": float(np.std(all_rel_mag)) if len(all_rel_mag) > 1 else 0.0,
                "min": float(np.min(all_rel_mag)),
                "max": float(np.max(all_rel_mag))
            }

        # Process overall directional influence
        if len(all_dir_inf) >= 1:
            overall_stats["directional_influence"] = {
                "mean": float(np.mean(all_dir_inf)),
                "median": float(np.median(all_dir_inf)),
                "std": float(np.std(all_dir_inf)) if len(all_dir_inf) > 1 else 0.0,
                "min": float(np.min(all_dir_inf)),
                "max": float(np.max(all_dir_inf))
            }

        # Process overall information content change
        if len(all_info_change) >= 1:
            overall_stats["information_content_change"] = {
                "mean": float(np.mean(all_info_change)),
                "median": float(np.median(all_info_change)),
                "std": float(np.std(all_info_change)) if len(all_info_change) > 1 else 0.0,
                "min": float(np.min(all_info_change)),
                "max": float(np.max(all_info_change))
            }

        # Process overall reference token counts
        if len(all_ref_counts) >= 1:
            overall_stats["reference_token_count"] = {
                "mean": float(np.mean(all_ref_counts)),
                "median": float(np.median(all_ref_counts)),
                "std": float(np.std(all_ref_counts)) if len(all_ref_counts) > 1 else 0.0,
                "min": float(np.min(all_ref_counts)),
                "max": float(np.max(all_ref_counts))
            }

        # Calculate overall correlations with significance tests
        if len(all_rel_mag) >= 3 and len(all_dir_inf) >= 3:
            corr_val = self.safe_correlation(all_rel_mag, all_dir_inf)
            overall_stats["corr_rel_mag_dir_inf"] = corr_val

            # significance test
            try:
                _, p_value = stats.pearsonr(all_rel_mag, all_dir_inf)
                overall_stats["corr_rel_mag_dir_inf_pvalue"] = float(p_value)
                overall_stats["corr_rel_mag_dir_inf_significant"] = p_value < 0.05
            except:
                overall_stats["corr_rel_mag_dir_inf_pvalue"] = 1.0
                overall_stats["corr_rel_mag_dir_inf_significant"] = False

        if len(all_rel_mag) >= 3 and len(all_info_change) >= 3:
            corr_val = self.safe_correlation(all_rel_mag, all_info_change)
            overall_stats["corr_rel_mag_info_change"] = corr_val

            # significance test
            try:
                _, p_value = stats.pearsonr(all_rel_mag, all_info_change)
                overall_stats["corr_rel_mag_info_change_pvalue"] = float(p_value)
                overall_stats["corr_rel_mag_info_change_significant"] = p_value < 0.05
            except:
                overall_stats["corr_rel_mag_info_change_pvalue"] = 1.0
                overall_stats["corr_rel_mag_info_change_significant"] = False

        if len(all_dir_inf) >= 3 and len(all_info_change) >= 3:
            corr_val = self.safe_correlation(all_dir_inf, all_info_change)
            overall_stats["corr_dir_inf_info_change"] = corr_val

            # significance test
            try:
                _, p_value = stats.pearsonr(all_dir_inf, all_info_change)
                overall_stats["corr_dir_inf_info_change_pvalue"] = float(p_value)
                overall_stats["corr_dir_inf_info_change_significant"] = p_value < 0.05
            except:
                overall_stats["corr_dir_inf_info_change_pvalue"] = 1.0
                overall_stats["corr_dir_inf_info_change_significant"] = False

        # Layer pattern analysis 
        layer_indices = sorted([int(idx) for idx in metric_collectors["relative_magnitude"].keys() if metric_collectors["relative_magnitude"][idx]])
        if layer_indices and len(layer_indices) >= 2:
            layer_evolution = {}

            # Relative magnitude evolution
            rel_mag_by_layer = [self.safe_mean(metric_collectors["relative_magnitude"][str(idx)])
                              for idx in layer_indices
                              if metric_collectors["relative_magnitude"][str(idx)]]

            if len(rel_mag_by_layer) >= 2:
                layer_evolution["relative_magnitude_trend"] = self.safe_trend_analysis(
                    layer_indices[:len(rel_mag_by_layer)], rel_mag_by_layer
                )

            # Directional influence evolution
            dir_inf_by_layer = [self.safe_mean(metric_collectors["directional_influence"][str(idx)])
                              for idx in layer_indices
                              if metric_collectors["directional_influence"][str(idx)]]

            if len(dir_inf_by_layer) >= 2:
                layer_evolution["directional_influence_trend"] = self.safe_trend_analysis(
                    layer_indices[:len(dir_inf_by_layer)], dir_inf_by_layer
                )

            # Information content change evolution
            info_change_by_layer = [self.safe_mean(metric_collectors["information_content_change"][str(idx)])
                                  for idx in layer_indices
                                  if metric_collectors["information_content_change"][str(idx)]]

            if len(info_change_by_layer) >= 2:
                layer_evolution["information_content_change_trend"] = self.safe_trend_analysis(
                    layer_indices[:len(info_change_by_layer)], info_change_by_layer
                )

            # Store layer evolution patterns
            overall_stats["layer_evolution"] = layer_evolution

            # statistical test for early vs late layers comparison
            self.add_layer_comparison_tests(layer_indices, metric_collectors, overall_stats)

        stats_results["overall"] = overall_stats

        return stats_results

    def add_layer_comparison_tests(self, layer_indices, metric_collectors, overall_stats):
        """
        Add statistical tests to compare metrics between early and late layers.
        Uses paired t-tests to determine if there are significant differences.

        Args:
            layer_indices: list of layer indices sorted in ascending order
            metric_collectors: dictionary containing the collected metrics
            overall_stats: dictionary to store the comparison results
        """
        if len(layer_indices) < 4:  # Need at least 4 layers for meaningful comparison
            return

        early_layers = layer_indices[:len(layer_indices)//2]
        late_layers = layer_indices[len(layer_indices)//2:]

        layer_comparisons = {
            "early_vs_late": {}
        }

        # Compare metrics between early and late layers
        for metric_name in ["relative_magnitude", "directional_influence", "information_content_change"]:
            if metric_name not in metric_collectors:
                continue

            # Collect early and late layer metrics
            early_metrics = []
            for layer_idx in early_layers:
                layer_data = metric_collectors[metric_name].get(str(layer_idx), [])
                # Filter out NaN and Inf values
                layer_data = [v for v in layer_data if not np.isnan(v) and not np.isinf(v)]
                if layer_data:
                    early_metrics.append(self.safe_mean(layer_data))

            late_metrics = []
            for layer_idx in late_layers:
                layer_data = metric_collectors[metric_name].get(str(layer_idx), [])
                # Filter out NaN and Inf values
                layer_data = [v for v in layer_data if not np.isnan(v) and not np.isinf(v)]
                if layer_data:
                    late_metrics.append(self.safe_mean(layer_data))

            # Perform t-test if we have enough data
            if len(early_metrics) >= 2 and len(late_metrics) >= 2:
                try:
                    # Independent samples t-test
                    t_stat, p_value = stats.ttest_ind(early_metrics, late_metrics, equal_var=False)

                    comparison_result = {
                        "early_mean": float(np.mean(early_metrics)),
                        "late_mean": float(np.mean(late_metrics)),
                        "difference": float(np.mean(late_metrics) - np.mean(early_metrics)),
                        "t_statistic": float(t_stat),
                        "p_value": float(p_value),
                        "significant": p_value < 0.05
                    }

                    layer_comparisons["early_vs_late"][metric_name] = comparison_result

                except Exception as e:
                    # If t-test fails, store basic comparison without statistical test
                    comparison_result = {
                        "early_mean": float(np.mean(early_metrics)) if early_metrics else 0.0,
                        "late_mean": float(np.mean(late_metrics)) if late_metrics else 0.0,
                        "difference": float(np.mean(late_metrics) - np.mean(early_metrics)) if early_metrics and late_metrics else 0.0,
                        "error": str(e)
                    }

                    layer_comparisons["early_vs_late"][metric_name] = comparison_result

        overall_stats["layer_comparisons"] = layer_comparisons

    def generate_statistics_report(self):
        """
        Generate a simple text report with statistics, correlation values, and statistical tests.
        Shows data from first layer, last layer, and two intermediate layers.
        """
        if not self.results or "statistics" not in self.results:
            return "No statistics available. Run analyze_samples first."

        stats_results = self.results["statistics"]
        successful_analyses = self.results.get("successful_analyses", 0)

        lines = []
        lines.append(f"REFERENCE POINT ANALYSIS FOR {self.model_name}\n")
        lines.append(f"Based on {successful_analyses} text samples\n")

        lines.append("OVERALL STATISTICS")
        lines.append("=" * 20)

        overall = stats_results["overall"]
        if not overall:
            lines.append("\nNo overall statistics available.")
        else:
            for metric_name in ["relative_magnitude", "directional_influence", "information_content_change", "reference_token_count"]:
                if metric_name in overall:
                    lines.append(f"\n{metric_name.replace('_', ' ').title()}:")
                    metric_stats = overall[metric_name]
                    lines.append(f"  Mean: {metric_stats['mean']:.4f}")
                    lines.append(f"  Median: {metric_stats['median']:.4f}")
                    lines.append(f"  Std Dev: {metric_stats['std']:.4f}")
                    lines.append(f"  Range: {metric_stats['min']:.4f} to {metric_stats['max']:.4f}")

            # Correlations with significance tests
            lines.append("\nOverall Correlations:")
            corr_keys = [k for k in overall.keys() if k.startswith("corr_") and not k.endswith("_pvalue") and not k.endswith("_significant")]
            if corr_keys:
                for corr_key in corr_keys:
                    metric_names = corr_key.replace("corr_", "").split("_")
                    metric_readable = " vs. ".join([name.replace("_", " ").title() for name in metric_names])
                    corr_value = overall[corr_key]

                    sig_key = f"{corr_key}_significant"
                    pval_key = f"{corr_key}_pvalue"

                    if sig_key in overall and pval_key in overall:
                        is_significant = overall[sig_key]
                        p_value = overall[pval_key]
                        sig_marker = "* " if is_significant else ""
                        lines.append(f"  {metric_readable}: {corr_value:.4f} {sig_marker}(p = {p_value:.4f})")
                    else:
                        lines.append(f"  {metric_readable}: {corr_value:.4f}")
            else:
                lines.append("  No correlations available.")

            # Layer evolution patterns
            if "layer_evolution" in overall:
                lines.append("\nLayer Evolution Patterns:")
                for trend_key, trend_data in overall["layer_evolution"].items():
                    metric_name = trend_key.replace("_trend", "").replace("_", " ").title()
                    lines.append(f"  {metric_name}: {trend_data['pattern'].title()} (correlation: {trend_data['correlation']:.4f})")

            # Add early vs late layer comparison results
            if "layer_comparisons" in overall and "early_vs_late" in overall["layer_comparisons"]:
                lines.append("\nEarly vs. Late Layer Comparisons (t-tests):")

                comparisons = overall["layer_comparisons"]["early_vs_late"]
                for metric_name, comp_data in comparisons.items():
                    metric_readable = metric_name.replace("_", " ").title()

                    if "significant" in comp_data:
                        sig_marker = "* " if comp_data["significant"] else ""
                        lines.append(f"  {metric_readable}:")
                        lines.append(f"    Early layers mean: {comp_data['early_mean']:.4f}")
                        lines.append(f"    Late layers mean: {comp_data['late_mean']:.4f}")
                        lines.append(f"    Difference: {comp_data['difference']:.4f} {sig_marker}")
                        lines.append(f"    t-statistic: {comp_data['t_statistic']:.4f}, p-value: {comp_data['p_value']:.4f}")
                    else:
                        lines.append(f"  {metric_readable}: Insufficient data for statistical test")

            lines.append("\nLAYER-SPECIFIC STATISTICS (select layers):")

            valid_layers = sorted([int(layer_idx) for layer_idx in stats_results["by_layer"].keys()
                                  if stats_results["by_layer"][layer_idx]],
                                  key=lambda x: int(x))

            if not valid_layers:
                lines.append("\nNo layer-specific statistics available.")
            else:
                # Select specific layers to report
                num_layers = max(valid_layers) + 1  # +1 because layer indices are 0-based

                # Choose layers to analyze (first, quarter, three-quarters, last)
                first_layer = 0
                quarter_layer = max(1, num_layers // 4)  
                three_quarter_layer = max(2, (num_layers * 3) // 4) 
                last_layer = max(valid_layers)

                selected_layers = sorted(list(set([first_layer, quarter_layer, three_quarter_layer, last_layer])))

                selected_layers = [layer for layer in selected_layers if str(layer) in stats_results["by_layer"]]

                for layer_idx in selected_layers:
                    layer_stats = stats_results["by_layer"][str(layer_idx)]
                    if not layer_stats:
                        continue

                    if layer_idx == first_layer:
                        position_label = "FIRST LAYER"
                    elif layer_idx == last_layer:
                        position_label = "LAST LAYER"
                    elif layer_idx == quarter_layer:
                        position_label = "QUARTER-DEPTH LAYER"
                    elif layer_idx == three_quarter_layer:
                        position_label = "THREE-QUARTER-DEPTH LAYER"
                    else:
                        position_label = ""

                    lines.append(f"\nLayer {layer_idx}: {position_label}")

                    for metric_name in ["relative_magnitude", "directional_influence", "information_content_change"]:
                        if metric_name in layer_stats:
                            metric_readable = metric_name.replace("_", " ").title()
                            lines.append(f"  {metric_readable} Mean: {layer_stats[metric_name]['mean']:.4f}")

                    corr_keys = [k for k in layer_stats.keys() if k.startswith("corr_") and not k.endswith("_pvalue") and not k.endswith("_significant")]
                    if corr_keys:
                        lines.append("  Correlations:")
                        for corr_key in corr_keys:
                            metric_names = corr_key.replace("corr_", "").split("_")
                            metric_readable = " vs. ".join([name.replace("_", " ").title() for name in metric_names])
                            corr_value = layer_stats[corr_key]

                            sig_key = f"{corr_key}_significant"
                            pval_key = f"{corr_key}_pvalue"

                            if sig_key in layer_stats and pval_key in layer_stats:
                                is_significant = layer_stats[sig_key]
                                p_value = layer_stats[pval_key]
                                sig_marker = "* " if is_significant else ""
                                lines.append(f"    {metric_readable}: {corr_value:.4f} {sig_marker}(p = {p_value:.4f})")
                            else:
                                lines.append(f"    {metric_readable}: {corr_value:.4f}")

            lines.append("\nKEY FINDINGS:")

            if not overall:
                lines.append("- Not enough data to generate key findings.")
            else:
                if "relative_magnitude" in overall:
                    rel_mag = overall["relative_magnitude"]
                    if rel_mag["mean"] > 0.5:
                        lines.append("- Reference tokens have a STRONG influence on value transformations")
                    elif rel_mag["mean"] > 0.3:
                        lines.append("- Reference tokens have a MODERATE influence on value transformations")
                    else:
                        lines.append("- Reference tokens have a WEAK influence on value transformations")

                if "directional_influence" in overall:
                    dir_inf = overall["directional_influence"]
                    if dir_inf["mean"] > 0.7:
                        lines.append("- Reference tokens strongly ALIGN with the overall transformation direction")
                    elif dir_inf["mean"] > 0.4:
                        lines.append("- Reference tokens moderately align with the overall transformation direction")
                    else:
                        lines.append("- Reference tokens contribute in directions ORTHOGONAL to the overall transformation")

                if "layer_evolution" in overall:
                    layer_evo = overall["layer_evolution"]

                    if "relative_magnitude_trend" in layer_evo:
                        rm_trend = layer_evo["relative_magnitude_trend"]
                        if rm_trend["pattern"] == "increasing":
                            lines.append("- Reference token influence INCREASES in deeper layers")
                        elif rm_trend["pattern"] == "decreasing":
                            lines.append("- Reference token influence DECREASES in deeper layers")

                    # Add layer comparison insights if we have both first and last layer
                    selected_layers = sorted(list(set([first_layer, quarter_layer, three_quarter_layer, last_layer])))
                    if first_layer in selected_layers and last_layer in selected_layers and first_layer != last_layer:
                        try:
                            first_rm = stats_results["by_layer"][str(first_layer)]["relative_magnitude"]["mean"]
                            last_rm = stats_results["by_layer"][str(last_layer)]["relative_magnitude"]["mean"]

                            rm_diff = last_rm - first_rm
                            if abs(rm_diff) > 0.1:  # Significant difference
                                if rm_diff > 0:
                                    lines.append(f"- Reference influence GROWS by {rm_diff:.2f} from first to last layer")
                                else:
                                    lines.append(f"- Reference influence DECREASES by {abs(rm_diff):.2f} from first to last layer")
                        except (KeyError, TypeError):
                            pass  

                if "layer_comparisons" in overall and "early_vs_late" in overall["layer_comparisons"]:
                    comparisons = overall["layer_comparisons"]["early_vs_late"]

                    for metric_name, comp_data in comparisons.items():
                        if "significant" in comp_data and comp_data["significant"]:
                            metric_readable = metric_name.replace("_", " ").title()
                            diff = comp_data["difference"]
                            direction = "INCREASES" if diff > 0 else "DECREASES"
                            lines.append(f"- {metric_readable} SIGNIFICANTLY {direction} from early to late layers (p < 0.05)")

            lines.append("\nNote: * indicates statistical significance at p < 0.05")

            report_path = os.path.join(self.output_dir, "reference_point_analysis.txt")
            with open(report_path, 'w') as f:
                f.write('\n'.join(lines))

            print(f"Report saved to {report_path}")
            return '\n'.join(lines)

    def run_analysis(self, texts=None):
        """
        Run the complete reference point analysis pipeline.
        """
        start_time = time.time()

        if self.model is None:
            self.load_model()

            if self.model is None:
                return "Model loading failed. Cannot continue analysis."

        if texts is None or not texts:
            texts = [
                "The concept of attention in transformer models relates to how tokens interact.",
                "In deep learning, reference points help establish coordinate systems for representation.",
                "Attention mechanisms create dynamic connections between tokens in a sequence.",
                "Self-attention allows each token to gather information from all other tokens."
            ]

        stats_results = self.analyze_samples(texts)

        if stats_results:
            report = self.generate_statistics_report()
        else:
            report = "Analysis did not produce valid statistics."

        total_time = time.time() - start_time
        print(f"Analysis completed in {total_time/60:.2f} minutes")

        return report

    def cleanup(self):
        """Clean up resources."""
        if self.model is not None:
            self.model = self.model.to("cpu")
            del self.model
            self.model = None

        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None

        self.hidden_states = None
        self.attention_matrices = None

        gc.collect()
        torch.cuda.empty_cache()

        print("Resources cleaned up")


def get_sample_texts_from_dataset(dataset_path, n_samples=500, random_state=None):
    """
    Extract sample texts from a CSV dataset for analysis.

    The 'text' column is used when present; otherwise the first available of
    'content', 'sentence', 'document', 'passage'. Missing and blank entries
    are removed before sampling, so up to ``n_samples`` usable texts are
    returned whenever the dataset contains that many.

    Args:
        dataset_path: path to a CSV file readable by ``pandas.read_csv``
        n_samples: maximum number of texts to return
        random_state: optional seed forwarded to ``Series.sample`` for
            reproducible sampling; None keeps the sampling random

    Returns:
        list[str]: the selected texts, or an empty list when no suitable
        column exists or the dataset cannot be loaded.
    """
    try:
        data = pd.read_csv(dataset_path)
        print(f"Loaded dataset with {len(data)} rows")

        candidate_columns = ['text', 'content', 'sentence', 'document', 'passage']
        found_column = next((col for col in candidate_columns if col in data.columns), None)

        if found_column is None:
            print("Error: No suitable text column found in dataset")
            return []
        if found_column != 'text':
            print(f"No 'text' column found, using '{found_column}' instead")

        # Empty CSV cells are loaded as NaN (a float), and str(NaN) is the
        # non-empty string "nan", so they must be dropped explicitly.
        texts = data[found_column].dropna().astype(str)
        texts = texts[texts.str.strip() != ""]

        if len(texts) > n_samples:
            texts = texts.sample(n_samples, random_state=random_state)

        return texts.tolist()

    except Exception as e:
        print(f"Error loading dataset: {e}")
        traceback.print_exc()
        return []



In [None]:

def main():
    """
    Entry point for the Reference Point Analysis.

    Mounts Google Drive when running in Colab, picks the dataset and output
    locations accordingly, loads sample texts (falling back to four built-in
    sentences), runs the analysis, prints the report, and always cleans up
    the analyzer afterwards.
    """
    try:
        from google.colab import drive
        # Only reachable inside Colab: make the Drive files available
        drive.mount('/content/drive')
    except ImportError:
        is_colab = False
        print("Not running in Google Colab, skipping drive mount")
    else:
        is_colab = True

    model_name = "meta-llama/llama-3.2-3B"

    # Colab reads from / writes to the mounted Drive; otherwise use local paths
    if not is_colab:
        dataset_path = "./dataset.csv"
        output_dir = f"./reference_analysis_{model_name.replace('/', '_')}"
    else:
        dataset_path = "/content/drive/MyDrive/wiki_dataset_position.csv"
        output_dir = f"/content/drive/MyDrive/Sink/reference_analysis/{model_name}"

    analyzer = ReferencePointAnalysis(model_name=model_name, output_dir=output_dir)
    analyzer.debug = False

    # Fall back to a few built-in sentences when the dataset yields nothing
    texts = get_sample_texts_from_dataset(dataset_path, n_samples=500) or [
        "The concept of attention in transformer models relates to how tokens interact.",
        "In deep learning, reference points help establish coordinate systems for representation.",
        "Attention mechanisms create dynamic connections between tokens in a sequence.",
        "Self-attention allows each token to gather information from all other tokens."
    ]

    try:
        print(f"Running analysis with {len(texts)} text samples")

        analysis_report = analyzer.run_analysis(texts=texts)

        print("\nREFERENCE POINT ANALYSIS SUMMARY:")
        print("==========================")
        print(analysis_report)
        print("\nAnalysis completed successfully!")
    except Exception as err:
        print(f"Analysis error: {err}")
        traceback.print_exc()
    finally:
        # Release the model and cached tensors whether or not the run succeeded
        analyzer.cleanup()


# Run the analysis only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()