In [None]:
from huggingface_hub import notebook_login

# Authenticate this notebook session with the Hugging Face Hub
# (presumably so gated/private checkpoints can be downloaded later — confirm).
notebook_login()

In [None]:
import numpy as np
import torch
import os
import gc
import json
import pandas as pd
import time
from tqdm import tqdm
import traceback
import warnings
from sklearn.decomposition import PCA
import scipy.stats as stats
from google.colab import drive
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
import bitsandbytes as bnb

In [None]:

# Silence UserWarning and FutureWarning output for the rest of the session.
# Filters are registered in the same order as before (UserWarning first).
for _ignored_category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)

class AttentionValueCorrelationAnalysis:
    """
    Analyzes correlation patterns between attention distributions and
    value transformations in transformer models.
    """

    def __init__(self, model_name, output_dir="./attention_value_correlation"):
        """Initialize the analysis with model name and output directory."""
        self.model_name = model_name
        self.output_dir = output_dir
        self.model = None
        self.tokenizer = None
        self.results = {}

        self.hidden_states = None
        self.attention_matrices = None

        self.debug = True

        os.makedirs(output_dir, exist_ok=True)

    def load_model(self, use_4bit=True):
        """Load the model and tokenizer with robust error handling."""
        print(f"Loading model: {self.model_name}")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)

            # Check for 4-bit quantization
            if use_4bit:
                try:
                    quantization_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.float16
                    )

                    self.model = AutoModelForCausalLM.from_pretrained(
                        self.model_name,
                        torch_dtype=torch.float16,
                        quantization_config=quantization_config,
                        device_map="auto",
                        trust_remote_code=True,
                        output_hidden_states=True
                    )
                    print("Model loaded with 4-bit quantization")
                except (ImportError, ModuleNotFoundError) as e:
                    print(f"BitsAndBytes not available: {e}")
                    print("Falling back to standard loading...")
                    self.model = AutoModelForCausalLM.from_pretrained(
                        self.model_name,
                        torch_dtype=torch.float16,
                        device_map="auto",
                        trust_remote_code=True,
                        output_hidden_states=True
                    )
            else:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    trust_remote_code=True,
                    output_hidden_states=True
                )

        except Exception as e:
            print(f"Error loading model: {e}")
            print("This might be due to an unsupported model architecture or outdated transformers library.")
            print("Try updating with: pip install --upgrade transformers")
            print("Or install from source: pip install git+https://github.com/huggingface/transformers.git")
            print("Using a fallback approach...")
            return None, None

        gc.collect()
        torch.cuda.empty_cache()

        print(f"Model loaded successfully")
        return self.model, self.tokenizer

    def extract_value_vectors(self, hidden_states, layer_idx):
        """
        Extract value vectors for a specific layer.
        This is a modified version that correctly handles the hidden states format
        from different model architectures.

        Args:
            hidden_states: list of tensors, hidden states for each layer
            layer_idx: int, index of the layer

        Returns:
            numpy array of value vectors with shape (seq_len, hidden_dim)
        """
        try:
            layer_hidden = hidden_states[layer_idx]

            if isinstance(layer_hidden, torch.Tensor):
                layer_hidden = layer_hidden.cpu().numpy()

            # Ensure the right shape
            # Expected shape: (batch_size, seq_len, hidden_dim)
            if self.debug:
                print(f"Hidden state shape: {layer_hidden.shape}")

            # Handle different shape formats
            if len(layer_hidden.shape) == 3:
                if layer_hidden.shape[0] == 1:  # (batch_size, seq_len, hidden_dim)
                    return layer_hidden[0]  # Return (seq_len, hidden_dim)
                elif layer_hidden.shape[1] == 1:  # (seq_len, batch_size, hidden_dim)
                    return layer_hidden[:, 0, :]  # Return (seq_len, hidden_dim)
                else:
                    # If batch dimension is neither 0 nor 1, use the first batch
                    return layer_hidden[0]  # Assume (batch_size, seq_len, hidden_dim)
            elif len(layer_hidden.shape) == 2:
                # Already in the format (seq_len, hidden_dim)
                return layer_hidden
            else:
                raise ValueError(f"Unexpected hidden state shape: {layer_hidden.shape}")

        except Exception as e:
            print(f"Error in extract_value_vectors: {e}")
            traceback.print_exc()
            # Return a dummy tensor as fallback
            return np.zeros((1, 1))

    def calculate_attention_entropy(self, attention_matrix):
        """
        Calculate the entropy of attention distributions for each token.
        Designed to handle extreme cases with special numerical stability measures.

        Args:
            attention_matrix: numpy array of shape (seq_len, seq_len)

        Returns:
            numpy array of entropy values for each token
        """
        try:
            # Use machine epsilon for maximum precision
            eps = np.finfo(np.float32).eps * 100  
            seq_len = attention_matrix.shape[0]
            entropy = np.zeros(seq_len)

            # Process each row individually 
            for i in range(seq_len):
                attn_dist = attention_matrix[i]

                # Replace zeros/negatives with epsilon to avoid log(0)
                valid_indices = attn_dist > eps

                if not np.any(valid_indices):
                    entropy[i] = 0.0
                    continue

                # Normalize the valid attention weights
                valid_attn = attn_dist[valid_indices]
                valid_attn_sum = np.sum(valid_attn)

                if valid_attn_sum < eps:
                    entropy[i] = 0.0
                    continue

                normalized_attn = valid_attn / valid_attn_sum

                # entropy for valid values
                # H = -sum(p_i * log(p_i)) - only include terms where p_i > 0
                valid_entropy = 0.0
                for p in normalized_attn:
                    if p > eps:  
                        valid_entropy -= p * np.log(p)

                entropy[i] = valid_entropy

            # Handle any remaining NaN or Inf values
            entropy = np.nan_to_num(entropy, nan=0.0, posinf=0.0, neginf=0.0)

            return entropy
        except Exception as e:
            print(f"Error calculating attention entropy: {e}")
            # Return dummy values
            return np.ones(attention_matrix.shape[0]) * 0.5


    def calculate_semantic_similarity_matrix(self, value_vectors):
        """
        Calculate semantic similarity matrix between all value vectors.
        Uses extremely robust numerical methods to handle any value ranges.

        Args:
            value_vectors: numpy array of value vectors

        Returns:
            numpy array of semantic similarities
        """
        try:
            # Get dimensions
            seq_len = value_vectors.shape[0]

            # Initialize similarity matrix
            similarity_matrix = np.zeros((seq_len, seq_len))

            # Handle any NaN or Inf values
            value_vectors_safe = np.nan_to_num(value_vectors, nan=0.0, posinf=0.0, neginf=0.0)

            # Check for extreme values and scale 
            max_abs = np.max(np.abs(value_vectors_safe))
            if max_abs > 1e3:  # Lower threshold for scaling
                scale_factor = 1e3 / max_abs
                value_vectors_safe = value_vectors_safe * scale_factor

            norms = np.zeros(seq_len)
            for i in range(seq_len):
                vec = value_vectors_safe[i]
                squared_sum = 0.0
                chunk_size = 1000  
                for j in range(0, len(vec), chunk_size):
                    chunk = vec[j:j+chunk_size]
                    # Use np.float64 for better precision
                    chunk_sum = np.sum(np.square(chunk.astype(np.float64)))
                    squared_sum += chunk_sum

                norms[i] = np.sqrt(max(1e-10, squared_sum))  # Ensure minimum value

            # Compute similarity matrix
            for i in range(seq_len):
                for j in range(seq_len):
                    # Special case for self-similarity
                    if i == j:
                        similarity_matrix[i, j] = 1.0
                        continue

                    vec_i = value_vectors_safe[i]
                    vec_j = value_vectors_safe[j]

                    if norms[i] < 1e-8 or norms[j] < 1e-8:
                        similarity_matrix[i, j] = 0.0
                        continue

                    # Calculate dot product in chunks 
                    dot_product = 0.0
                    chunk_size = 1000
                    for k in range(0, len(vec_i), chunk_size):
                        chunk_i = vec_i[k:k+chunk_size].astype(np.float64)
                        chunk_j = vec_j[k:k+chunk_size].astype(np.float64)
                        chunk_product = np.sum(chunk_i * chunk_j)
                        dot_product += chunk_product

                    # Compute cosine similarity
                    similarity = dot_product / (norms[i] * norms[j])

                    # Ensure value is in valid range
                    similarity_matrix[i, j] = np.clip(similarity, -1.0, 1.0)

            return similarity_matrix
        except Exception as e:
            print(f"Error calculating semantic similarity: {e}")
            size = value_vectors.shape[0] if hasattr(value_vectors, 'shape') else 1
            return np.eye(size)


    def calculate_value_transformation_magnitude(self, pre_values, post_values):
        """
        Calculate the magnitude of value vector transformations.
        Uses specialized techniques to avoid overflow and numerical issues.

        Args:
            pre_values: numpy array of value vectors before attention
            post_values: numpy array of value vectors after attention

        Returns:
            numpy array of transformation magnitudes
        """
        try:
            # Check that dimensions match
            if pre_values.shape[0] != post_values.shape[0]:
                print(f"WARNING: Pre-values shape {pre_values.shape} doesn't match post-values shape {post_values.shape}")
                # Return dummy values
                return np.ones(max(pre_values.shape[0], post_values.shape[0])) * 0.5

            seq_len = pre_values.shape[0]

            magnitudes = np.zeros(seq_len)

            # Handle any NaN or Inf values
            pre_safe = np.nan_to_num(pre_values, nan=0.0, posinf=0.0, neginf=0.0)
            post_safe = np.nan_to_num(post_values, nan=0.0, posinf=0.0, neginf=0.0)

            # Check for extreme values and scale if necessary
            pre_max = np.max(np.abs(pre_safe))
            post_max = np.max(np.abs(post_safe))
            max_val = max(pre_max, post_max)

            if float(max_val) > 1e3:  # Lower threshold
                scale_factor = 1e3 / float(max_val)
                pre_safe = pre_safe * scale_factor
                post_safe = post_safe * scale_factor

            batch_size = min(100, seq_len)  

            for batch_start in range(0, seq_len, batch_size):
                batch_end = min(batch_start + batch_size, seq_len)

                diffs = post_safe[batch_start:batch_end] - pre_safe[batch_start:batch_end]

                for i in range(batch_end - batch_start):
                    idx = batch_start + i
                    diff = diffs[i]

                    diff_clipped = np.clip(diff, -1e3, 1e3)

                    # Calculate L2 norm in chunks
                    squared_sum = 0.0
                    chunk_size = 1000
                    for j in range(0, len(diff_clipped), chunk_size):
                        chunk = diff_clipped[j:j+chunk_size].astype(np.float64)
                        chunk_sum = np.sum(np.square(chunk))
                        squared_sum += chunk_sum

                    magnitudes[idx] = np.sqrt(max(0.0, squared_sum))

            magnitudes = np.clip(magnitudes, 0.0, 1e6)
            magnitudes = np.nan_to_num(magnitudes, nan=0.0, posinf=1.0, neginf=0.0)

            return magnitudes
        except Exception as e:
            print(f"Error calculating transformation magnitude: {e}")
            return np.ones(pre_values.shape[0] if hasattr(pre_values, 'shape') else 1) * 0.5


    def perform_pca_analysis(self, attention_matrix, value_vectors, n_components=3):
        """
        Perform PCA on attention and value matrices to identify dominant patterns.
        Uses numerically stable approaches to avoid warnings and errors.

        Args:
            attention_matrix: numpy array of attention weights
            value_vectors: numpy array of value vectors
            n_components: int, number of components to extract

        Returns:
            dict with PCA results
        """
        try:
            if attention_matrix.shape[0] < 2 or value_vectors.shape[0] < 2:
                return {
                    "correlation": 0.0,
                    "error": "Not enough samples for PCA",
                    "average_correlation": 0.0
                }

            attn_clean = np.nan_to_num(attention_matrix, nan=0.0, posinf=0.0, neginf=0.0)
            value_clean = np.nan_to_num(value_vectors, nan=0.0, posinf=0.0, neginf=0.0)

            attn_max = np.max(np.abs(attn_clean))
            value_max = np.max(np.abs(value_clean))

            if attn_max > 1e3:  
                scale_factor = 1e3 / float(attn_max)
                attn_clean = attn_clean * scale_factor

            if value_max > 1e3:  
                scale_factor = 1e3 / float(value_max)
                value_clean = value_clean * scale_factor

            attn_min_dim = min(attn_clean.shape)
            value_min_dim = min(value_clean.shape)
            n_components = min(n_components, attn_min_dim - 1, value_min_dim - 1)

            if n_components <= 0:
                return {
                    "correlation": 0.0,
                    "error": "Not enough dimensions for PCA",
                    "average_correlation": 0.0
                }

            try:
                attention_pca = PCA(n_components=n_components, svd_solver='randomized')
                attention_pcs = attention_pca.fit_transform(attn_clean)

                value_pca = PCA(n_components=n_components, svd_solver='randomized')
                value_pcs = value_pca.fit_transform(value_clean)

                # Calculate correlation between the principal components
                correlations = []
                for i in range(n_components):
                    if i < len(attention_pcs[0]) and i < len(value_pcs[0]):
                        attn_projection = attention_pcs[:, i]
                        value_projection = value_pcs[:, i]

                        attn_projection = np.nan_to_num(attn_projection)
                        value_projection = np.nan_to_num(value_projection)

                        if np.std(attn_projection) > 1e-8 and np.std(value_projection) > 1e-8:
                            corr = np.corrcoef(attn_projection, value_projection)[0, 1]
                            if not np.isnan(corr) and not np.isinf(corr):
                                correlations.append(float(corr))

                avg_correlation = np.mean(correlations) if correlations else 0.0

                return {
                    "attention_explained_variance": attention_pca.explained_variance_ratio_.tolist(),
                    "value_explained_variance": value_pca.explained_variance_ratio_.tolist(),
                    "pc_correlations": correlations,
                    "average_correlation": float(avg_correlation)
                }

            except Exception as inner_e:
                print(f"PCA using sklearn failed: {inner_e}. Using simplified approach.")
                return {
                    "correlation": 0.0,
                    "error": f"PCA failed: {str(inner_e)}",
                    "average_correlation": 0.0
                }

        except Exception as e:
            print(f"Error in PCA analysis: {e}")
            return {
                "correlation": 0.0,
                "error": str(e),
                "average_correlation": 0.0
            }

    def analyze_text(self, text):
        """
        Analyze a text sample and compute correlation metrics between
        attention distributions and value transformations.
        """
        if self.model is None or self.tokenizer is None:
            print("Model or tokenizer not loaded!")
            return {"error": "Model or tokenizer not loaded"}

        try:
            if self.debug:
                print(f"Analyzing text: '{text[:50]}...'")

            inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)

            if self.debug:
                print(f"Tokenized input shape: {inputs.input_ids.shape}")

            with torch.no_grad():
                outputs = self.model(
                    **inputs,
                    output_attentions=True,
                    output_hidden_states=True,
                    return_dict=True
                )

            # Extract attention patterns and hidden states
            attentions = outputs.attentions  # tuple of (layer, batch, head, seq_len, seq_len)
            hidden_states = outputs.hidden_states  # tuple of (layer+1, batch, seq_len, hidden_dim)

            if self.debug:
                print(f"Number of attention layers: {len(attentions)}")
                print(f"Number of hidden state layers: {len(hidden_states)}")
                if len(attentions) > 0:
                    print(f"Attention shape for first layer: {attentions[0].shape}")
                if len(hidden_states) > 0:
                    print(f"Hidden state shape for first layer: {hidden_states[0].shape}")

            self.attention_matrices = attentions
            self.hidden_states = hidden_states

            results = {
                "layers": {},
                "sequence_length": inputs.input_ids.shape[1]
            }

            num_layers = len(attentions)
            for layer_idx in range(num_layers):
                if self.debug:
                    print(f"Processing layer {layer_idx+1}/{num_layers}")

                # Each layer has shape (batch, head, seq_len, seq_len)
                layer_attention = attentions[layer_idx][0].cpu().numpy()  # shape: (head, seq_len, seq_len)

                # Compute average attention pattern across all heads
                avg_attention = np.mean(layer_attention, axis=0)

                if self.debug:
                    print(f"Average attention shape: {avg_attention.shape}")

                # Get value vectors for pre and post attention
                pre_values = self.extract_value_vectors(hidden_states, layer_idx)
                post_values = self.extract_value_vectors(hidden_states, layer_idx + 1)

                if self.debug:
                    print(f"Pre-values shape: {pre_values.shape}")
                    print(f"Post-values shape: {post_values.shape}")

                if pre_values.shape[0] != avg_attention.shape[0] or post_values.shape[0] != avg_attention.shape[0]:
                    print(f"WARNING: Value vector dimensions don't match attention matrix dimensions")
                    print(f"Attention shape: {avg_attention.shape}, Pre-values shape: {pre_values.shape}, Post-values shape: {post_values.shape}")

                    results["layers"][str(layer_idx)] = {
                        "attention_entropy_mean": 0.5,
                        "attention_entropy_std": 0.1,
                        "transformation_magnitude_mean": 0.5,
                        "transformation_magnitude_std": 0.1,
                        "geometric_semantic_correlation": 0.0,
                        "entropy_magnitude_correlation": 0.0,
                        "pca_analysis": {
                            "average_correlation": 0.0,
                            "error": "Dimension mismatch"
                        }
                    }
                    continue

                # Calculate attention entropy
                attention_entropy = self.calculate_attention_entropy(avg_attention)

                # Calculate value transformation magnitude
                transformation_magnitude = self.calculate_value_transformation_magnitude(pre_values, post_values)

                # Calculate semantic similarity matrix for value vectors
                semantic_similarity = self.calculate_semantic_similarity_matrix(pre_values)

                # Calculate correlation between attention weights and semantic similarity
                try:
                    # Flatten both matrices (exclude diagonal elements)
                    attn_flat = []
                    sem_flat = []
                    for i in range(avg_attention.shape[0]):
                        for j in range(avg_attention.shape[1]):
                            if i != j:  # Exclude self-attention/similarity
                                attn_flat.append(avg_attention[i, j])
                                sem_flat.append(semantic_similarity[i, j])

                    # Calculate correlation
                    if len(attn_flat) > 1 and np.std(attn_flat) > 0 and np.std(sem_flat) > 0:
                        geometric_semantic_corr = float(np.corrcoef(attn_flat, sem_flat)[0, 1])
                    else:
                        geometric_semantic_corr = 0.0
                except Exception as e:
                    print(f"Error calculating geometric-semantic correlation: {e}")
                    geometric_semantic_corr = 0.0

                # Calculate correlation between attention entropy and transformation magnitude
                try:
                    if len(attention_entropy) > 1 and len(transformation_magnitude) > 1 and np.std(attention_entropy) > 0 and np.std(transformation_magnitude) > 0:
                        entropy_magnitude_corr = float(np.corrcoef(attention_entropy, transformation_magnitude)[0, 1])
                    else:
                        entropy_magnitude_corr = 0.0
                except Exception as e:
                    print(f"Error calculating entropy-magnitude correlation: {e}")
                    entropy_magnitude_corr = 0.0

                pca_results = self.perform_pca_analysis(avg_attention, pre_values)

                results["layers"][str(layer_idx)] = {
                    "attention_entropy_mean": float(np.mean(attention_entropy)),
                    "attention_entropy_std": float(np.std(attention_entropy)),
                    "transformation_magnitude_mean": float(np.mean(transformation_magnitude)),
                    "transformation_magnitude_std": float(np.std(transformation_magnitude)),
                    "geometric_semantic_correlation": geometric_semantic_corr,
                    "entropy_magnitude_correlation": entropy_magnitude_corr,
                    "pca_analysis": pca_results
                }

            return results

        except Exception as e:
            print(f"Error analyzing text: {e}")
            traceback.print_exc()
            return {"error": str(e)}

    def analyze_samples(self, texts, max_failures=20):
        """
        Analyze multiple text samples and collect metrics for correlation analysis.

        The model is loaded on demand if it has not been loaded yet. After
        every sample, whether it succeeded or failed, the CUDA cache is
        emptied and the garbage collector is run, so a failed (for example
        out-of-memory) sample does not leave memory pinned for the next one.

        Args:
            texts: list of text samples
            max_failures: maximum number of consecutive failures before giving up

        Returns:
            The statistics dict produced by ``compute_statistics`` when at
            least one sample was analyzed successfully, otherwise ``{}``.
            On success the collected metrics are also stored in
            ``self.results``.
        """
        print(f"Analyzing {len(texts)} text samples...")

        if self.model is None:
            self.load_model()

        # Nothing can be analyzed without a model; avoid failing once per sample
        if self.model is None or self.tokenizer is None:
            print("Model could not be loaded; skipping analysis")
            return {}

        # Metrics copied verbatim from each layer's result dict
        direct_metrics = (
            "attention_entropy_mean",
            "transformation_magnitude_mean",
            "geometric_semantic_correlation",
            "entropy_magnitude_correlation",
        )

        # metric name -> layer index -> list of per-sample values
        metric_collectors = {name: {} for name in direct_metrics}
        metric_collectors["pca_average_correlation"] = {}

        successful_analyses = 0
        consecutive_failures = 0

        for i, text in enumerate(tqdm(texts, desc="Processing texts")):
            failed = False
            try:
                result = self.analyze_text(text)

                if "error" in result:
                    failed = True
                else:
                    consecutive_failures = 0
                    successful_analyses += 1

                    # Collect metrics from each layer
                    for layer_idx, layer_data in result["layers"].items():
                        for per_layer in metric_collectors.values():
                            per_layer.setdefault(layer_idx, [])

                        for name in direct_metrics:
                            metric_collectors[name][layer_idx].append(layer_data[name])

                        pca_data = layer_data.get("pca_analysis", {})
                        if "average_correlation" in pca_data:
                            metric_collectors["pca_average_correlation"][layer_idx].append(
                                pca_data["average_correlation"]
                            )

            except Exception as e:
                print(f"Error analyzing text {i+1}: {e}")
                traceback.print_exc()
                failed = True
            finally:
                # Always release memory, including after a failed sample
                torch.cuda.empty_cache()
                gc.collect()

            if failed:
                consecutive_failures += 1
                if consecutive_failures >= max_failures:
                    print(f"Too many consecutive failures ({max_failures}). Stopping analysis.")
                    break
                continue

            if (i+1) % 5 == 0:
                print(f"Processed {i+1}/{len(texts)} samples")

        print(f"Successfully analyzed {successful_analyses}/{len(texts)} samples")

        if successful_analyses == 0:
            print("No successful analyses to compute statistics")
            return {}

        stats_results = self.compute_statistics(metric_collectors)

        self.results = {
            "metric_collectors": metric_collectors,
            "statistics": stats_results,
            "successful_analyses": successful_analyses
        }

        return stats_results

    def compute_statistics(self, metric_collectors):
        """
        Compute statistics for the collected metrics with added statistical significance testing.
        """
        print("Computing statistics with significance testing...")

        stats_results = {
            "by_layer": {},
            "overall": {},
            "significance_tests": {}  
        }

        metrics = list(metric_collectors.keys())

        all_metrics = {metric: [] for metric in metrics}

        for layer_idx in metric_collectors["attention_entropy_mean"].keys():
            stats_results["by_layer"][layer_idx] = {}

            layer_stats = {}

            for metric in metrics:
                if layer_idx in metric_collectors[metric] and metric_collectors[metric][layer_idx]:
                    values = metric_collectors[metric][layer_idx]

                    all_metrics[metric].extend(values)

                    if len(values) >= 1:
                        layer_stats[metric] = {
                            "mean": float(np.mean(values)),
                            "median": float(np.median(values)),
                            "std": float(np.std(values)) if len(values) > 1 else 0.0,
                            "min": float(np.min(values)),
                            "max": float(np.max(values)),
                            "sample_size": len(values)
                        }

                        # Add one-sample t-test against zero (useful for correlation metrics)
                        if len(values) >= 5:  
                            try:
                                t_stat, p_value = stats.ttest_1samp(values, 0)
                                layer_stats[metric]["t_test"] = {
                                    "t_statistic": float(t_stat),
                                    "p_value": float(p_value),
                                    "significant": p_value < 0.05,
                                    "confidence_level": self._get_confidence_level(p_value)
                                }
                            except Exception as e:
                                print(f"Error calculating t-test for {metric} in layer {layer_idx}: {e}")

            # Calculate correlations between metrics in this layer
            try:
                corr_metrics = ["attention_entropy_mean", "transformation_magnitude_mean",
                              "geometric_semantic_correlation", "entropy_magnitude_correlation"]

                for i, metric1 in enumerate(corr_metrics):
                    for j, metric2 in enumerate(corr_metrics[i+1:], i+1):
                        if (metric1 in layer_stats and metric2 in layer_stats and
                            layer_idx in metric_collectors[metric1] and
                            layer_idx in metric_collectors[metric2]):

                            values1 = metric_collectors[metric1][layer_idx]
                            values2 = metric_collectors[metric2][layer_idx]

                            min_len = min(len(values1), len(values2))
                            if min_len >= 5:  # Minimum for meaningful correlation
                                values1 = values1[:min_len]
                                values2 = values2[:min_len]

                                # Calculate both Pearson (linear) and Spearman (rank) correlations
                                if np.std(values1) > 0 and np.std(values2) > 0:
                                    # Pearson correlation with p-value
                                    pearson_r, pearson_p = stats.pearsonr(values1, values2)

                                    # Spearman rank correlation with p-value
                                    spearman_r, spearman_p = stats.spearmanr(values1, values2)

                                    # Create correlation key
                                    corr_key = f"corr_{metric1}_{metric2}"

                                    # Store correlation statistics
                                    layer_stats[corr_key] = {
                                        "pearson": {
                                            "r": float(pearson_r),
                                            "p_value": float(pearson_p),
                                            "significant": pearson_p < 0.05,
                                            "confidence_level": self._get_confidence_level(pearson_p)
                                        },
                                        "spearman": {
                                            "rho": float(spearman_r),
                                            "p_value": float(spearman_p),
                                            "significant": spearman_p < 0.05,
                                            "confidence_level": self._get_confidence_level(spearman_p)
                                        },
                                        "sample_size": min_len
                                    }
            except Exception as e:
                print(f"Error calculating layer correlations: {e}")
                traceback.print_exc()

            # Store layer statistics
            stats_results["by_layer"][layer_idx] = layer_stats

        # Calculate overall statistics
        overall_stats = {}

        for metric in metrics:
            values = all_metrics[metric]
            if len(values) >= 1:
                overall_stats[metric] = {
                    "mean": float(np.mean(values)),
                    "median": float(np.median(values)),
                    "std": float(np.std(values)) if len(values) > 1 else 0.0,
                    "min": float(np.min(values)),
                    "max": float(np.max(values)),
                    "sample_size": len(values)
                }

                # Add one-sample t-test for overall metrics
                if len(values) >= 5:  # Minimum sample size for meaningful t-test
                    try:
                        t_stat, p_value = stats.ttest_1samp(values, 0)
                        overall_stats[metric]["t_test"] = {
                            "t_statistic": float(t_stat),
                            "p_value": float(p_value),
                            "significant": p_value < 0.05,
                            "confidence_level": self._get_confidence_level(p_value)
                        }
                    except Exception as e:
                        print(f"Error calculating overall t-test for {metric}: {e}")

        # Layer pattern analysis - how metrics evolve across layers
        layer_indices = sorted([int(idx) for idx in metric_collectors["attention_entropy_mean"].keys()
                              if metric_collectors["attention_entropy_mean"][idx]])

        if layer_indices and len(layer_indices) >= 2:
            # Analyze how metrics evolve across layers
            layer_evolution = {}

            for metric in metrics:
                try:
                    values_by_layer = [np.mean(metric_collectors[metric][str(idx)])
                                    for idx in layer_indices
                                    if str(idx) in metric_collectors[metric] and metric_collectors[metric][str(idx)]]

                    if len(values_by_layer) >= 2:
                        # Check for monotonic trend with improved significance testing
                        # Spearman rank correlation is better for trend analysis
                        spearman_r, spearman_p = stats.spearmanr(layer_indices[:len(values_by_layer)], values_by_layer)

                        # Pearson for linear trend
                        pearson_r, pearson_p = stats.pearsonr(layer_indices[:len(values_by_layer)], values_by_layer)

                        layer_evolution[f"{metric}_trend"] = {
                            "spearman": {
                                "rho": float(spearman_r),
                                "p_value": float(spearman_p),
                                "significant": spearman_p < 0.05,
                                "confidence_level": self._get_confidence_level(spearman_p)
                            },
                            "pearson": {
                                "r": float(pearson_r),
                                "p_value": float(pearson_p),
                                "significant": pearson_p < 0.05,
                                "confidence_level": self._get_confidence_level(pearson_p)
                            },
                            "pattern": "increasing" if spearman_r > 0.5 else ("decreasing" if spearman_r < -0.5 else "mixed"),
                            "sample_size": len(values_by_layer)
                        }
                except Exception as e:
                    print(f"Error analyzing layer evolution for {metric}: {e}")
                    layer_evolution[f"{metric}_trend"] = {
                        "correlation": 0.0,
                        "pattern": "unknown",
                        "error": str(e)
                    }

            # Store layer evolution patterns
            overall_stats["layer_evolution"] = layer_evolution

        # Perform cross-layer comparison tests
        if len(layer_indices) >= 2:
            first_layer = str(min(layer_indices))
            last_layer = str(max(layer_indices))

            # Initialize cross-layer test results
            cross_layer_tests = {}

            # Compare first and last layer for each metric
            for metric in metrics:
                if (first_layer in metric_collectors[metric] and
                    last_layer in metric_collectors[metric] and
                    len(metric_collectors[metric][first_layer]) >= 5 and
                    len(metric_collectors[metric][last_layer]) >= 5):

                    first_layer_values = metric_collectors[metric][first_layer]
                    last_layer_values = metric_collectors[metric][last_layer]

                    # Perform independent samples t-test
                    try:
                        t_stat, p_value = stats.ttest_ind(first_layer_values, last_layer_values, equal_var=False)

                        cross_layer_tests[f"{metric}_first_vs_last"] = {
                            "first_layer_mean": float(np.mean(first_layer_values)),
                            "last_layer_mean": float(np.mean(last_layer_values)),
                            "difference": float(np.mean(last_layer_values) - np.mean(first_layer_values)),
                            "t_statistic": float(t_stat),
                            "p_value": float(p_value),
                            "significant": p_value < 0.05,
                            "confidence_level": self._get_confidence_level(p_value),
                            "effect": "increase" if np.mean(last_layer_values) > np.mean(first_layer_values) else "decrease",
                            "sample_sizes": {
                                "first_layer": len(first_layer_values),
                                "last_layer": len(last_layer_values)
                            }
                        }
                    except Exception as e:
                        print(f"Error calculating cross-layer test for {metric}: {e}")

            # Add cross-layer tests to results
            stats_results["significance_tests"]["cross_layer"] = cross_layer_tests

        # Store overall statistics
        stats_results["overall"] = overall_stats

        return stats_results

    def _get_confidence_level(self, p_value):
        """
        Convert p-value to a confidence level description.

        Args:
            p_value: The p-value from a statistical test

        Returns:
            String describing the confidence level
        """
        if p_value < 0.001:
            return "very strong (p<0.001)"
        elif p_value < 0.01:
            return "strong (p<0.01)"
        elif p_value < 0.05:
            return "significant (p<0.05)"
        elif p_value < 0.1:
            return "marginal (p<0.1)"
        else:
            return "not significant (p≥0.1)"

    def generate_statistics_report(self):
        """
        Generate a text report with statistics, correlation values, and statistical significance.
        Shows data from first layer, last layer, and two intermediate layers.
        """
        if not self.results or "statistics" not in self.results:
            return "No statistics available. Run analyze_samples first."

        stats_results = self.results["statistics"]
        successful_analyses = self.results.get("successful_analyses", 0)

        lines = []
        lines.append(f"ATTENTION-VALUE CORRELATION ANALYSIS FOR {self.model_name}\n")
        lines.append(f"Based on {successful_analyses} text samples\n")

        lines.append("OVERALL STATISTICS")
        lines.append("=" * 20)

        overall = stats_results["overall"]
        if not overall:
            lines.append("\nNo overall statistics available.")
        else:
            metric_display = {
                "attention_entropy_mean": "Attention Distribution Entropy",
                "transformation_magnitude_mean": "Value Transformation Magnitude",
                "geometric_semantic_correlation": "Geometric-Semantic Alignment",
                "entropy_magnitude_correlation": "Entropy-Magnitude Correlation",
                "pca_average_correlation": "PCA Component Correlation"
            }

            for metric in ["attention_entropy_mean", "transformation_magnitude_mean",
                        "geometric_semantic_correlation", "entropy_magnitude_correlation",
                        "pca_average_correlation"]:
                if metric in overall:
                    lines.append(f"\n{metric_display[metric]}:")
                    metric_stats = overall[metric]
                    lines.append(f"  Mean: {metric_stats['mean']:.4f}")
                    lines.append(f"  Median: {metric_stats['median']:.4f}")
                    lines.append(f"  Std Dev: {metric_stats['std']:.4f}")
                    lines.append(f"  Range: {metric_stats['min']:.4f} to {metric_stats['max']:.4f}")
                    lines.append(f"  Sample Size: {metric_stats.get('sample_size', 'N/A')}")

                    if "t_test" in metric_stats:
                        t_test = metric_stats["t_test"]
                        p_value = t_test["p_value"]
                        significance = "* SIGNIFICANT *" if t_test["significant"] else "not significant"
                        lines.append(f"  Statistical Significance: {significance}")
                        lines.append(f"  p-value: {p_value:.6f} ({t_test['confidence_level']})")
                        lines.append(f"  t-statistic: {t_test['t_statistic']:.4f}")

            if "layer_evolution" in overall:
                lines.append("\nLayer Evolution Patterns (with Statistical Significance):")
                for trend_key, trend_data in overall["layer_evolution"].items():
                    base_metric = trend_key.replace("_trend", "")
                    if base_metric in metric_display:
                        metric_name = metric_display[base_metric]

                        if "spearman" in trend_data and "pearson" in trend_data:
                            spearman = trend_data["spearman"]
                            pearson = trend_data["pearson"]

                            lines.append(f"  {metric_name}:")
                            lines.append(f"    Pattern: {trend_data['pattern'].title()}")

                            spearman_sig = "* SIGNIFICANT *" if spearman["significant"] else "not significant"
                            lines.append(f"    Spearman Rank Correlation: {spearman['rho']:.4f} ({spearman_sig})")
                            lines.append(f"    p-value: {spearman['p_value']:.6f} ({spearman['confidence_level']})")

                            pearson_sig = "* SIGNIFICANT *" if pearson["significant"] else "not significant"
                            lines.append(f"    Pearson Linear Correlation: {pearson['r']:.4f} ({pearson_sig})")
                            lines.append(f"    p-value: {pearson['p_value']:.6f} ({pearson['confidence_level']})")

                            lines.append(f"    Sample Size: {trend_data.get('sample_size', 'N/A')}")
                        else:
                            lines.append(f"  {metric_name}: {trend_data['pattern'].title()} (correlation: {trend_data.get('correlation', 0.0):.4f})")

        if "significance_tests" in stats_results and "cross_layer" in stats_results["significance_tests"]:
            cross_layer_tests = stats_results["significance_tests"]["cross_layer"]
            if cross_layer_tests:
                lines.append("\nCROSS-LAYER STATISTICAL TESTS:")
                lines.append("=" * 30)

                for test_key, test_data in cross_layer_tests.items():
                    base_metric = test_key.replace("_first_vs_last", "")
                    metric_name = metric_display.get(base_metric, base_metric.replace("_", " ").title())

                    lines.append(f"\n{metric_name} (First vs. Last Layer):")

                    lines.append(f"  First Layer Mean: {test_data['first_layer_mean']:.4f}")
                    lines.append(f"  Last Layer Mean: {test_data['last_layer_mean']:.4f}")
                    lines.append(f"  Difference: {test_data['difference']:.4f} ({test_data['effect']})")

                    significance = "* SIGNIFICANT *" if test_data["significant"] else "not significant"
                    lines.append(f"  Statistical Significance: {significance}")
                    lines.append(f"  p-value: {test_data['p_value']:.6f} ({test_data['confidence_level']})")
                    lines.append(f"  t-statistic: {test_data['t_statistic']:.4f}")
                    lines.append(f"  Sample Sizes: {test_data['sample_sizes']['first_layer']} (first), {test_data['sample_sizes']['last_layer']} (last)")

        lines.append("\nLAYER-SPECIFIC STATISTICS (select layers):")

        valid_layers = sorted([int(layer_idx) for layer_idx in stats_results["by_layer"].keys()
                              if stats_results["by_layer"][layer_idx]],
                            key=lambda x: int(x))

        if not valid_layers:
            lines.append("\nNo layer-specific statistics available.")
        else:
            num_layers = max(valid_layers) + 1  # +1 because layer indices are 0-based

            # Choose layers to analyze (first, quarter, three-quarters, last)
            first_layer = 0
            quarter_layer = max(1, num_layers // 4)  
            three_quarter_layer = max(2, (num_layers * 3) // 4) 
            last_layer = max(valid_layers)

            selected_layers = sorted(list(set([first_layer, quarter_layer, three_quarter_layer, last_layer])))

            selected_layers = [layer for layer in selected_layers if str(layer) in stats_results["by_layer"]]

            for layer_idx in selected_layers:
                layer_stats = stats_results["by_layer"][str(layer_idx)]
                if not layer_stats:
                    continue

                if layer_idx == first_layer:
                    position_label = "FIRST LAYER"
                elif layer_idx == last_layer:
                    position_label = "LAST LAYER"
                elif layer_idx == quarter_layer:
                    position_label = "QUARTER-DEPTH LAYER"
                elif layer_idx == three_quarter_layer:
                    position_label = "THREE-QUARTER-DEPTH LAYER"
                else:
                    position_label = ""

                lines.append(f"\nLayer {layer_idx}: {position_label}")

                # Show mean values for key metrics with significance
                for metric in ["attention_entropy_mean", "transformation_magnitude_mean",
                              "geometric_semantic_correlation", "entropy_magnitude_correlation"]:
                    if metric in layer_stats:
                        metric_readable = metric_display.get(metric, metric.replace("_", " ").title())

                        lines.append(f"  {metric_readable}: {layer_stats[metric]['mean']:.4f}")

                        if "t_test" in layer_stats[metric]:
                            t_test = layer_stats[metric]["t_test"]
                            significance = "* SIGNIFICANT *" if t_test["significant"] else "not significant"
                            lines.append(f"    Significance: {significance} (p={t_test['p_value']:.6f}, {t_test['confidence_level']})")

                if "pca_average_correlation" in layer_stats:
                    metric_readable = metric_display["pca_average_correlation"]
                    lines.append(f"  {metric_readable}: {layer_stats['pca_average_correlation']['mean']:.4f}")

                    if "t_test" in layer_stats["pca_average_correlation"]:
                        t_test = layer_stats["pca_average_correlation"]["t_test"]
                        significance = "* SIGNIFICANT *" if t_test["significant"] else "not significant"
                        lines.append(f"    Significance: {significance} (p={t_test['p_value']:.6f}, {t_test['confidence_level']})")

                correlation_keys = [k for k in layer_stats.keys() if k.startswith("corr_")]
                if correlation_keys:
                    lines.append("\n  Correlation Analysis (with Statistical Significance):")

                    for corr_key in correlation_keys:
                        corr_data = layer_stats[corr_key]

                        metric_parts = corr_key.replace("corr_", "").split("_")
                        if len(metric_parts) >= 2:
                            last_idx = len(metric_parts) // 2
                            metric1_parts = metric_parts[:last_idx]
                            metric2_parts = metric_parts[last_idx:]

                            metric1 = "_".join(metric1_parts)
                            metric2 = "_".join(metric2_parts)

                            metric1_readable = metric_display.get(metric1, metric1.replace("_", " ").title())
                            metric2_readable = metric_display.get(metric2, metric2.replace("_", " ").title())

                            corr_name = f"{metric1_readable} vs. {metric2_readable}"
                        else:
                            corr_name = corr_key.replace("corr_", "").replace("_", " ").title()

                        lines.append(f"    {corr_name}:")

                        # Pearson correlation with significance
                        if "pearson" in corr_data:
                            pearson = corr_data["pearson"]
                            pearson_sig = "* SIGNIFICANT *" if pearson["significant"] else "not significant"
                            lines.append(f"      Pearson r: {pearson['r']:.4f} ({pearson_sig})")
                            lines.append(f"      p-value: {pearson['p_value']:.6f} ({pearson['confidence_level']})")

                        if "spearman" in corr_data:
                            spearman = corr_data["spearman"]
                            spearman_sig = "* SIGNIFICANT *" if spearman["significant"] else "not significant"
                            lines.append(f"      Spearman rho: {spearman['rho']:.4f} ({spearman_sig})")
                            lines.append(f"      p-value: {spearman['p_value']:.6f} ({spearman['confidence_level']})")

                        if "sample_size" in corr_data:
                            lines.append(f"      Sample Size: {corr_data['sample_size']}")

        lines.append("\nKEY FINDINGS (WITH STATISTICAL SIGNIFICANCE):")
        lines.append("=" * 40)

        if not overall:
            lines.append("- Not enough data to generate key findings.")
        else:
            findings = []

            # Geometric-Semantic Alignment insights
            if "geometric_semantic_correlation" in overall:
                geo_sem = overall["geometric_semantic_correlation"]
                significance = ""
                if "t_test" in geo_sem:
                    if geo_sem["t_test"]["significant"]:
                        significance = f" (STATISTICALLY SIGNIFICANT, p={geo_sem['t_test']['p_value']:.6f})"
                    else:
                        significance = f" (not statistically significant, p={geo_sem['t_test']['p_value']:.6f})"

                if geo_sem["mean"] > 0.5:
                    findings.append(f"- STRONG alignment between attention patterns and semantic similarity in value space{significance}")
                elif geo_sem["mean"] > 0.2:
                    findings.append(f"- MODERATE alignment between attention patterns and semantic similarity in value space{significance}")
                elif geo_sem["mean"] > 0:
                    findings.append(f"- WEAK alignment between attention patterns and semantic similarity in value space{significance}")
                else:
                    findings.append(f"- NEGATIVE correlation between attention patterns and semantic similarity{significance}")

            # Entropy-Magnitude Correlation insights
            if "entropy_magnitude_correlation" in overall:
                ent_mag = overall["entropy_magnitude_correlation"]
                significance = ""
                if "t_test" in ent_mag:
                    if ent_mag["t_test"]["significant"]:
                        significance = f" (STATISTICALLY SIGNIFICANT, p={ent_mag['t_test']['p_value']:.6f})"
                    else:
                        significance = f" (not statistically significant, p={ent_mag['t_test']['p_value']:.6f})"

                if ent_mag["mean"] > 0.5:
                    findings.append(f"- Tokens with DIVERSE attention patterns undergo LARGER value transformations{significance}")
                elif ent_mag["mean"] > 0.2:
                    findings.append(f"- Tokens with diverse attention patterns undergo moderately larger value transformations{significance}")
                elif ent_mag["mean"] > 0:
                    findings.append(f"- Weak relationship between attention diversity and transformation magnitude{significance}")
                else:
                    findings.append(f"- Tokens with FOCUSED attention patterns undergo LARGER value transformations{significance}")

            # Layer evolution insights with significance
            if "layer_evolution" in overall:
                layer_evo = overall["layer_evolution"]

                if "geometric_semantic_correlation_trend" in layer_evo:
                    trend = layer_evo["geometric_semantic_correlation_trend"]
                    significance = ""

                    if "spearman" in trend and "p_value" in trend["spearman"]:
                        p_value = trend["spearman"]["p_value"]
                        if p_value < 0.05:
                            significance = f" (STATISTICALLY SIGNIFICANT, p={p_value:.6f})"
                        else:
                            significance = f" (not statistically significant, p={p_value:.6f})"

                    if trend["pattern"] == "increasing":
                        findings.append(f"- Attention-semantic alignment INCREASES in deeper layers{significance}")
                    elif trend["pattern"] == "decreasing":
                        findings.append(f"- Attention-semantic alignment DECREASES in deeper layers{significance}")

            if "significance_tests" in stats_results and "cross_layer" in stats_results["significance_tests"]:
                cross_tests = stats_results["significance_tests"]["cross_layer"]

                for test_key, test_data in cross_tests.items():
                    if test_data["significant"]:
                        base_metric = test_key.replace("_first_vs_last", "")
                        metric_name = metric_display.get(base_metric, base_metric.replace("_", " ").title())
                        direction = "INCREASES" if test_data["effect"] == "increase" else "DECREASES"

                        finding = (f"- {metric_name} {direction} by {abs(test_data['difference']):.4f} from first to last layer "
                                  f"(STATISTICALLY SIGNIFICANT, p={test_data['p_value']:.6f})")
                        findings.append(finding)

            for finding in findings:
                lines.append(finding)

        report_path = os.path.join(self.output_dir, "attention_value_correlation.txt")
        with open(report_path, 'w') as f:
            f.write('\n'.join(lines))

        print(f"Report saved to {report_path}")
        return '\n'.join(lines)

    def run_analysis(self, texts=None):
        """
        Run the complete attention-value correlation analysis pipeline.
        """
        start_time = time.time()

        if self.model is None:
            self.load_model()

            if self.model is None:
                return "Model loading failed. Cannot continue analysis."

        if texts is None or not texts:
            texts = [
                "The concept of attention in transformer models relates to how tokens interact.",
                "In deep learning, reference points help establish coordinate systems for representation.",
                "Attention mechanisms create dynamic connections between tokens in a sequence.",
                "Self-attention allows each token to gather information from all other tokens."
            ]

        stats_results = self.analyze_samples(texts)

        if stats_results:
            report = self.generate_statistics_report()
        else:
            report = "Analysis did not produce valid statistics."

        total_time = time.time() - start_time
        print(f"Analysis completed in {total_time/60:.2f} minutes")

        return report

    def cleanup(self):
        """Clean up resources."""
        if self.model is not None:
            self.model = self.model.to("cpu")
            del self.model
            self.model = None

        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None

        self.hidden_states = None
        self.attention_matrices = None

        gc.collect()
        torch.cuda.empty_cache()

        print("Resources cleaned up")


def get_sample_texts_from_dataset(dataset_path, n_samples=100, random_state=None):
    """
    Extract sample texts from a CSV dataset for analysis.

    The text is taken from the 'text' column, or from the first of
    'content', 'sentence', 'document', 'passage' that exists. Missing and
    blank cells are removed before sampling, so up to ``n_samples`` usable
    texts are returned whenever the dataset holds that many.

    Args:
        dataset_path: Path of the CSV file to read
        n_samples: Maximum number of texts to return
        random_state: Optional seed passed to the pandas sampler to make the
            random selection reproducible

    Returns:
        List of non-empty text strings; an empty list when the file cannot be
        read or has no suitable text column
    """
    text_columns = ('text', 'content', 'sentence', 'document', 'passage')

    try:
        data = pd.read_csv(dataset_path)
        print(f"Loaded dataset with {len(data)} rows")

        found_column = next((col for col in text_columns if col in data.columns), None)

        if found_column is None:
            print("Error: No suitable text column found in dataset")
            return []
        if found_column != 'text':
            print(f"No 'text' column found, using '{found_column}' instead")

        # pandas represents missing cells as NaN, which is neither None nor
        # falsy once stringified, so drop them explicitly before the blank check.
        candidates = data[found_column].dropna().astype(str)
        candidates = candidates[candidates.str.strip() != ""]

        if len(candidates) > n_samples:
            candidates = candidates.sample(n_samples, random_state=random_state)

        return candidates.tolist()

    except Exception as e:
        # Best effort by design: callers fall back to built-in texts on [].
        print(f"Error loading dataset: {e}")
        traceback.print_exc()
        return []




In [None]:
def main():
    """
    Run the Attention-Value Correlation Analysis.

    Mounts Google Drive when the Colab package is importable and uses the
    Drive paths for the dataset and output; otherwise uses local paths.
    Falls back to four built-in sentences when the dataset yields no texts,
    prints the resulting report and always releases the analyzer's resources.
    """
    try:
        # Imported inside the try so that a missing Colab package is actually
        # handled by the ImportError branch below.
        from google.colab import drive

        drive.mount('/content/drive')
        is_colab = True
    except ImportError:
        is_colab = False
        print("Not running in Google Colab, skipping drive mount")

    model_name = "meta-llama/llama-3.2-3B"

    if is_colab:
        dataset_path = "/content/drive/MyDrive/wiki_dataset_position.csv"
        output_dir = f"/content/drive/MyDrive/Sink/attn_value_correlation/{model_name.replace('/', '_')}"
    else:
        dataset_path = "./dataset.csv"
        output_dir = f"./attn_value_correlation_{model_name.replace('/', '_')}"

    analyzer = AttentionValueCorrelationAnalysis(
        model_name=model_name,
        output_dir=output_dir
    )
    # Turn off debug mode for production runs
    analyzer.debug = False

    texts = get_sample_texts_from_dataset(dataset_path, n_samples=500)

    # Keep the run usable without a dataset by analyzing a few fixed sentences.
    if not texts:
        texts = [
            "The concept of attention in transformer models relates to how tokens interact.",
            "In deep learning, reference points help establish coordinate systems for representation.",
            "Attention mechanisms create dynamic connections between tokens in a sequence.",
            "Self-attention allows each token to gather information from all other tokens."
        ]

    try:
        print(f"Running analysis with {len(texts)} text samples")

        analysis_report = analyzer.run_analysis(texts=texts)

        print("\nATTENTION-VALUE CORRELATION ANALYSIS SUMMARY:")
        print("==========================================")

        print(analysis_report)

        print("\nAnalysis completed successfully!")
    except Exception as e:
        # Top-level boundary: report the failure, then fall through to cleanup.
        print(f"Analysis error: {e}")
        traceback.print_exc()
    finally:
        analyzer.cleanup()


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()