In [None]:
from huggingface_hub import notebook_login

notebook_login()

In [None]:
import numpy as np
import torch
import os
import gc
import json
import pandas as pd
from scipy.stats import entropy, pearsonr, spearmanr
import traceback
import time
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import tabulate
import warnings

In [None]:

# Suppress every FutureWarning and UserWarning for the rest of the session so
# the analysis output is not interleaved with warning text.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

class UltraFastAttentionSinkKLAnalysis:
    """
    Ultra-optimized version of the attention sink KL analysis designed for speed.
    Focuses only on generating the final summary report with minimal computation.
    """

    def __init__(self, model_name, output_dir="./attention_sink_kl_analysis",
                max_seq_length=128, verbose=False, skip_layers=True, optimize_calcs=True):
        """Initialize the analysis with model name and output directory."""
        self.model_name = model_name
        self.output_dir = output_dir
        self.model = None
        self.tokenizer = None
        self.aggregated_results = {}
        self.sample_metrics = []
        self.verbose = verbose

        # Speed optimization parameters
        self.max_seq_length = max_seq_length  # Truncate sequences to this length
        self.skip_layers = skip_layers        # Whether to analyze only every other layer
        self.optimize_calcs = optimize_calcs  # Whether to optimize KL calculations

        self.layer_count = 0

        self.peak_memory = 0

        os.makedirs(output_dir, exist_ok=True)

        if self.verbose:
            print(f"Optimization settings:")
            print(f" - Max sequence length: {self.max_seq_length}")
            print(f" - Skip layers: {self.skip_layers}")
            print(f" - Optimize calculations: {self.optimize_calcs}")

    def log(self, message, force=False):
        """Log messages only when verbose is enabled or forced."""
        if self.verbose or force:
            print(message)

    def _track_memory(self):
        """Track GPU memory usage."""
        if not torch.cuda.is_available():
            return

        current = torch.cuda.memory_allocated() / (1024 ** 3)  # GB
        peak = torch.cuda.max_memory_allocated() / (1024 ** 3)  # GB

        if peak > self.peak_memory:
            self.peak_memory = peak

    def load_model(self, use_4bit=True):
        """
        Load the tokenizer and causal-LM for ``self.model_name``.

        Args:
            use_4bit: When True, first try loading with a 4-bit
                ``BitsAndBytesConfig``; on any failure, log the error and fall
                back to a plain float16 load.

        Returns:
            Tuple ``(model, tokenizer)``; both are also stored on the instance.
        """
        self.log(f"Loading model: {self.model_name}", force=True)

        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        tokenizer.model_max_length = self.max_seq_length
        self.tokenizer = tokenizer

        # Keyword arguments shared by every loading path.
        shared_kwargs = {
            "torch_dtype": torch.float16,
            "device_map": "auto",
            "trust_remote_code": True,
        }

        loaded = None
        if use_4bit:
            try:
                from transformers import BitsAndBytesConfig

                bnb_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16
                )
                loaded = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    quantization_config=bnb_config,
                    **shared_kwargs
                )
            except Exception as e:
                self.log(f"Error with 4-bit quantization: {e}", force=True)
                self.log("Falling back to standard loading...", force=True)

        if loaded is None:
            # Either 4-bit loading was not requested or it failed above.
            loaded = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                **shared_kwargs
            )
        self.model = loaded

        # Record how many layers the loaded model has.
        self.layer_count = self._get_model_layer_count()
        self.log(f"Model has {self.layer_count} layers", force=True)

        # Release temporaries created during loading and note the memory peak.
        gc.collect()
        torch.cuda.empty_cache()
        self._track_memory()

        self.log(f"Model loaded: {self.model_name}", force=True)
        return self.model, self.tokenizer

    def _get_model_layer_count(self):
        """Get the number of layers in the model."""
        if not self.model:
            return 0

        try:
            # Check for config attribute with num_hidden_layers (most common)
            if hasattr(self.model, 'config') and hasattr(self.model.config, 'num_hidden_layers'):
                return self.model.config.num_hidden_layers

            # For BERT-like models
            elif hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'layer'):
                return len(self.model.encoder.layer)

            # For GPT-like models
            elif hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'h'):
                return len(self.model.transformer.h)

            # For LLaMA, Qwen models
            elif hasattr(self.model, 'model') and hasattr(self.model.model, 'layers'):
                return len(self.model.model.layers)

            # Direct layers attribute (some newer models)
            elif hasattr(self.model, 'layers'):
                return len(self.model.layers)

            # Fallback to common layer counts
            else:
                return 32  

        except Exception as e:
            self.log(f"Error detecting layer count: {e}. Using default value 32.", force=True)
            return 32  # Fallback to a reasonable default

    def identify_attention_sinks(self, attention_matrix, threshold=0.9):
        """
        Optimized function to identify attention sinks in an attention matrix.
        """
        # Get dimensions
        num_heads, seq_len, _ = attention_matrix.shape

        # Calculate incoming attention for each token across all heads
        # Sum across dim=0 (heads) and dim=1 (source tokens)
        total_incoming_attention = attention_matrix.sum(axis=(0, 1))

        # Calculate threshold value based on percentile
        threshold_value = np.percentile(total_incoming_attention, threshold * 100)

        # Identify sinks (tokens with incoming attention above threshold)
        sink_indices = np.where(total_incoming_attention >= threshold_value)[0]

        # Calculate what percentage of total attention goes to sinks
        total_attention = total_incoming_attention.sum()
        sink_attention = total_incoming_attention[sink_indices].sum()
        sink_attention_percentage = (sink_attention / total_attention) * 100 if total_attention > 0 else 0

        if not self.optimize_calcs:
            # Calculate incoming attention per head for each token
            per_head_incoming = np.zeros((num_heads, seq_len))
            for h in range(num_heads):
                per_head_incoming[h] = attention_matrix[h].sum(axis=0)  # Sum over source tokens

            # Calculate attention concentration
            head_sink_concentration = np.zeros(num_heads)
            for h in range(num_heads):
                head_total = per_head_incoming[h].sum()
                head_to_sinks = per_head_incoming[h][sink_indices].sum()
                head_sink_concentration[h] = (head_to_sinks / head_total) if head_total > 0 else 0

            avg_sink_concentration = np.mean(head_sink_concentration)
        else:
            avg_sink_concentration = 0

            if len(sink_indices) > 0:
                # Take a sample of attention to estimate concentration
                sample_heads = min(num_heads, 4)  
                sample_concentration = 0

                for h in range(sample_heads):
                    head_attention = attention_matrix[h]
                    head_total = head_attention.sum()
                    if head_total > 0:
                        head_to_sinks = head_attention[:, sink_indices].sum()
                        sample_concentration += (head_to_sinks / head_total)

                avg_sink_concentration = sample_concentration / sample_heads

        return {
            "sink_indices": sink_indices.tolist(),
            "sink_count": len(sink_indices),
            "sink_attention_percentage": float(sink_attention_percentage),
            "avg_sink_concentration": float(avg_sink_concentration)
        }

    def fast_kl_divergence(self, attn_matrix, sink_indices=None):
        """
        Ultra-optimized KL divergence calculation focusing only on essential metrics.
        Now with added statistical significance (p-value) testing.
        """
        num_heads, seq_len, _ = attn_matrix.shape

        # Ensure probabilities sum to 1 for each source token - using proper broadcasting
        attn_probs = attn_matrix.copy()

        # Handle row normalization properly 
        for h in range(num_heads):
            for i in range(seq_len):
                row_sum = attn_probs[h, i].sum()
                if row_sum > 0:
                    attn_probs[h, i] = attn_probs[h, i] / row_sum
                else:
                    # Set uniform distribution for zero rows
                    attn_probs[h, i] = np.ones(seq_len) / seq_len

        # 1. Calculate average attention distribution across heads
        avg_attn = np.mean(attn_probs, axis=0)  # (seq_len, seq_len)

        # 2. Calculate KL divergence between each head and the average
        if self.optimize_calcs and seq_len > 20:
            # For long sequences, sample only a subset of tokens
            sample_tokens = min(20, seq_len)  # Use at most 20 tokens
            sample_indices = np.linspace(0, seq_len-1, sample_tokens, dtype=int)

            kl_to_avg = np.zeros(num_heads)
            for h in range(num_heads):
                head_kl = np.zeros(sample_tokens)
                for i, idx in enumerate(sample_indices):
                    head_kl[i] = entropy(attn_probs[h, idx], avg_attn[idx])
                kl_to_avg[h] = np.mean(head_kl)
        else:
            # For shorter sequences, compute full KL divergence
            kl_to_avg = np.zeros(num_heads)
            for h in range(num_heads):
                head_kl = np.zeros(seq_len)
                for i in range(seq_len):
                    head_kl[i] = entropy(attn_probs[h, i], avg_attn[i])
                kl_to_avg[h] = np.mean(head_kl)

        # Compute metrics for sink tokens
        sink_metrics = {}
        if sink_indices is not None and len(sink_indices) > 0:
            # Create masked attention distributions without sinks
            masked_attn = attn_probs.copy()

            # Zero out attention to sinks and renormalize
            for h in range(num_heads):
                for i in range(seq_len):
                    # Zero out attention to sinks
                    for sink in sink_indices:
                        if sink < seq_len:  # Make sure sink index is valid
                            masked_attn[h, i, sink] = 0

                    # Renormalize
                    row_sum = masked_attn[h, i].sum()
                    if row_sum > 0:
                        masked_attn[h, i] = masked_attn[h, i] / row_sum
                    else:
                        # If all attention went to sinks, distribute uniformly over non-sinks
                        non_sink_indices = [j for j in range(seq_len) if j not in sink_indices]
                        if non_sink_indices:
                            for j in non_sink_indices:
                                masked_attn[h, i, j] = 1.0 / len(non_sink_indices)

            # Calculate average masked attention
            avg_masked_attn = np.mean(masked_attn, axis=0)

            # KL divergence with and without sinks 
            if self.optimize_calcs and seq_len > 20:
                sample_tokens = min(20, seq_len)
                sample_indices = np.linspace(0, seq_len-1, sample_tokens, dtype=int)

                kl_without_sinks = np.zeros(num_heads)
                for h in range(num_heads):
                    head_kl = np.zeros(sample_tokens)
                    for i, idx in enumerate(sample_indices):
                        head_kl[i] = entropy(masked_attn[h, idx], avg_masked_attn[idx])
                    kl_without_sinks[h] = np.mean(head_kl)
            else:
                kl_without_sinks = np.zeros(num_heads)
                for h in range(num_heads):
                    head_kl = np.zeros(seq_len)
                    for i in range(seq_len):
                        head_kl[i] = entropy(masked_attn[h, i], avg_masked_attn[i])
                    kl_without_sinks[h] = np.mean(head_kl)

            # Compare KL divergence with and without sinks
            kl_reduction = kl_to_avg - kl_without_sinks

            # Add p-value testing for KL reduction using paired t-test
            if not self.optimize_calcs and len(kl_to_avg) >= 2:
                from scipy import stats
                t_stat, p_value = stats.ttest_rel(kl_to_avg, kl_without_sinks)
                kl_reduction_significant = bool(p_value < 0.05)  # Cast to bool for proper JSON serialization
            else:
                p_value = 1.0
                kl_reduction_significant = False

            # Calculate correlation between sink attention and KL divergence
            if self.optimize_calcs:
                sink_kl_correlation = 0.0  
                correlation_p_value = 1.0
                correlation_significant = False
            else:
                # Calculate correlation between sink attention and KL divergence
                sink_attn_per_head = np.zeros((num_heads, len(sink_indices)))
                for h in range(num_heads):
                    for i, sink in enumerate(sink_indices):
                        if sink < seq_len:  # Make sure sink index is valid
                            sink_attn_per_head[h, i] = np.mean(attn_probs[h, :, sink])

                sink_attention_avg = np.mean(sink_attn_per_head, axis=1)
                if np.std(sink_attention_avg) > 1e-10 and np.std(kl_to_avg) > 1e-10:
                    # Only calculate correlation if we have non-constant data
                    try:
                        from scipy import stats
                        sink_kl_correlation, correlation_p_value = stats.pearsonr(sink_attention_avg, kl_to_avg)
                        correlation_significant = bool(p_value < 0.05)  # Cast to bool for proper JSON serialization
                    except:
                        sink_kl_correlation = 0.0
                        correlation_p_value = 1.0
                        correlation_significant = False
                else:
                    sink_kl_correlation = np.nan
                    correlation_p_value = 1.0
                    correlation_significant = False

            sink_metrics = {
                "kl_without_sinks": kl_without_sinks.tolist(),
                "kl_reduction": kl_reduction.tolist(),
                "avg_kl_reduction": float(np.mean(kl_reduction)),
                "kl_reduction_p_value": float(p_value) if not self.optimize_calcs else 1.0,
                "kl_reduction_significant": kl_reduction_significant,
                "sink_kl_correlation": float(sink_kl_correlation) if not self.optimize_calcs else 0.0,
                "correlation_p_value": float(correlation_p_value) if not self.optimize_calcs else 1.0,
                "correlation_significant": correlation_significant
            }

        return {
            "avg_kl_divergence": float(np.mean(kl_to_avg)),
            "max_kl_divergence": float(np.max(kl_to_avg)) if not self.optimize_calcs else 0.0,
            "min_kl_divergence": float(np.min(kl_to_avg)) if not self.optimize_calcs else 0.0,
            **sink_metrics
        }

    def analyze_text_and_aggregate(self, text, sink_thresholds=[0.8, 0.9, 0.95], sample_id=None):
        """
        Analyze a text sample through the model and collect metrics directly without storing full results.
        Optimized for speed by skipping layers and reducing computation.
        """
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_seq_length).to(self.model.device)
        token_count = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = self.model(
                **inputs,
                output_attentions=True,
                return_dict=True
            )

        attentions = outputs.attentions  # tuple of (layer, batch, head, seq_len, seq_len)

        sample_metrics = {
            "sample_id": sample_id,
            "token_count": token_count,
            "layer_metrics": {}
        }

        # Pick layers to analyze
        if self.skip_layers:
            # Analyze only a subset of layers for speed
            if len(attentions) <= 6:
                layer_indices = list(range(len(attentions)))
            else:
                # For larger models, analyze every other layer plus first and last
                layer_indices = [0] + list(range(1, len(attentions)-1, 2)) + [len(attentions)-1]
        else:
            layer_indices = list(range(len(attentions)))

        for layer_idx in layer_indices:
            # Each layer has shape (batch, head, seq_len, seq_len)
            layer_attention = attentions[layer_idx][0].cpu().numpy()  # shape: (head, seq_len, seq_len)

            layer_data = {
                "sink_analysis": {},
                "kl_metrics": {}
            }

            # Calculate KL divergence without consideration of sinks (for baseline)
            baseline_kl = self.fast_kl_divergence(layer_attention)
            layer_data["kl_metrics"]["baseline"] = baseline_kl

            # Identify attention sinks at different thresholds
            for threshold in sink_thresholds:
                sink_info = self.identify_attention_sinks(
                    layer_attention,
                    threshold=threshold
                )

                # Calculate KL divergence with identified sinks
                kl_metrics = self.fast_kl_divergence(
                    layer_attention,
                    sink_indices=sink_info["sink_indices"]
                )

                layer_data["sink_analysis"][str(threshold)] = sink_info
                layer_data["kl_metrics"][str(threshold)] = kl_metrics

            sample_metrics["layer_metrics"][str(layer_idx)] = layer_data

        del outputs, attentions, layer_attention
        torch.cuda.empty_cache()

        return sample_metrics

    def run_aggregated_analysis(self, texts, sink_thresholds=[0.8, 0.9, 0.95], batch_size=8):
        """
        Analyze every text, aggregate the metrics and produce the summary.

        Texts are processed one at a time, grouped into batches of
        ``batch_size``. After each batch the collected metrics are stored,
        memory is released and an estimated remaining time is logged. A text
        whose analysis raises is logged and skipped.

        Args:
            texts: Sequence of input strings.
            sink_thresholds: Percentile fractions forwarded to the per-text analysis.
            batch_size: Number of texts between memory clean-ups.

        Returns:
            The aggregated results, also stored in ``self.aggregated_results``.
        """
        total_texts = len(texts)
        self.log(f"Running analysis on {total_texts} text samples...", force=True)

        if self.model is None:
            self.load_model()

        self.sample_metrics = []

        # Ceiling division: number of batches needed to cover all texts.
        total_batches = (total_texts + batch_size - 1) // batch_size

        progress_bar = tqdm(total=total_texts, desc="Processing samples", disable=not self.verbose)
        samples_processed = 0

        for batch_num, start in enumerate(range(0, total_texts, batch_size), start=1):
            stop = min(start + batch_size, total_texts)

            self.log(f"\nProcessing batch {batch_num}/{total_batches}: samples {start+1} to {stop}", force=True)

            collected = []
            for position in range(start, stop):
                try:
                    collected.append(
                        self.analyze_text_and_aggregate(
                            texts[position],
                            sink_thresholds=sink_thresholds,
                            sample_id=position + 1
                        )
                    )
                except Exception as e:
                    self.log(f"Error analyzing text {position+1}: {str(e)}", force=True)
                    if self.verbose:
                        traceback.print_exc()

                # A sample counts as processed whether it succeeded or failed.
                samples_processed += 1
                progress_bar.update(1)

            self.sample_metrics.extend(collected)

            # Free memory between batches and record the memory peak.
            gc.collect()
            torch.cuda.empty_cache()
            self._track_memory()

            # Extrapolate the remaining time from the elapsed time so far.
            fraction_done = samples_processed / total_texts
            if 0 < fraction_done < 1:
                elapsed = progress_bar.format_dict["elapsed"]
                remaining = elapsed / fraction_done - elapsed
                self.log(f"Progress: {fraction_done:.1%} - Est. time remaining: {remaining/60:.1f} minutes", force=True)

        progress_bar.close()

        self.aggregated_results = self.aggregate_all_metrics(sink_thresholds)

        self.generate_numerical_summary()

        return self.aggregated_results

    def aggregate_all_metrics(self, sink_thresholds):
        """
        Aggregate metrics from all processed samples.
        Optimized to focus only on the metrics needed for the final report.
        """
        self.log("\nAggregating results across all samples...", force=True)

        if not self.sample_metrics:
            self.log("No sample metrics to aggregate!", force=True)
            return {}

        # Get layer indices from first sample (may be partial if skip_layers=True)
        first_sample = self.sample_metrics[0]
        analyzed_layers = sorted([int(idx) for idx in first_sample["layer_metrics"].keys()])
        self.log(f"Found {len(analyzed_layers)} analyzed layers", force=True)

        aggregated = {
            "model_name": self.model_name,
            "total_samples": len(self.sample_metrics),
            "thresholds": sink_thresholds,
            "analyzed_layers": analyzed_layers,
            "layer_count": self.layer_count,  # Total model layers, not just analyzed
            "layer_stats": {layer_idx: {} for layer_idx in analyzed_layers},
            "kl_trends": {
                "baseline": [],
                **{str(threshold): [] for threshold in sink_thresholds}
            },
            "kl_reduction_trends": {str(threshold): [] for threshold in sink_thresholds},
            "sink_concentration_trends": {str(threshold): [] for threshold in sink_thresholds}
        }

        for layer_idx in analyzed_layers:
            layer_key = str(layer_idx)

            metrics = {
                "baseline": {
                    "avg_kl_divergence": []
                }
            }

            for threshold in sink_thresholds:
                threshold_key = str(threshold)
                metrics[threshold_key] = {
                    "avg_kl_divergence": [],
                    "avg_kl_reduction": [],
                    "kl_reduction_p_values": [],
                    "correlation_p_values": [],
                    "sink_count": [],
                    "avg_sink_concentration": []
                }

            # Collect metrics from each sample
            for sample in self.sample_metrics:
                if layer_key not in sample["layer_metrics"]:
                    continue

                layer_data = sample["layer_metrics"][layer_key]

                # Baseline KL metrics 
                if "baseline" in layer_data["kl_metrics"]:
                    baseline_metrics = layer_data["kl_metrics"]["baseline"]
                    if "avg_kl_divergence" in baseline_metrics:
                        metrics["baseline"]["avg_kl_divergence"].append(baseline_metrics["avg_kl_divergence"])

                for threshold in sink_thresholds:
                    threshold_key = str(threshold)

                    if threshold_key not in layer_data["sink_analysis"] or threshold_key not in layer_data["kl_metrics"]:
                        continue

                    sink_info = layer_data["sink_analysis"][threshold_key]
                    kl_info = layer_data["kl_metrics"][threshold_key]

                    if "avg_kl_divergence" in kl_info:
                        metrics[threshold_key]["avg_kl_divergence"].append(kl_info["avg_kl_divergence"])

                    if "avg_kl_reduction" in kl_info:
                        metrics[threshold_key]["avg_kl_reduction"].append(kl_info["avg_kl_reduction"])

                    if "sink_count" in sink_info:
                        metrics[threshold_key]["sink_count"].append(sink_info["sink_count"])

                    if "avg_sink_concentration" in sink_info:
                        metrics[threshold_key]["avg_sink_concentration"].append(sink_info["avg_sink_concentration"])

                    if "kl_reduction_p_value" in kl_info:
                        metrics[threshold_key]["kl_reduction_p_values"].append(kl_info["kl_reduction_p_value"])

                    if "correlation_p_value" in kl_info:
                        metrics[threshold_key]["correlation_p_values"].append(kl_info["correlation_p_value"])

            layer_stats = {}

            baseline_stats = {}
            for key, values in metrics["baseline"].items():
                if values:
                    baseline_stats[key] = {
                        "mean": float(np.nanmean(values)),
                        "std": float(np.nanstd(values)) if len(values) > 1 else 0.0,
                        "count": len(values)
                    }
            layer_stats["baseline"] = baseline_stats

            for threshold in sink_thresholds:
                threshold_key = str(threshold)
                threshold_stats = {}
                for key, values in metrics[threshold_key].items():
                    if values:
                        threshold_stats[key] = {
                            "mean": float(np.nanmean(values)),
                            "std": float(np.nanstd(values)) if len(values) > 1 else 0.0,
                            "count": len(values)
                        }
                layer_stats[threshold_key] = threshold_stats

            # Store layer statistics
            aggregated["layer_stats"][layer_idx] = layer_stats

            # 1. Baseline KL trend
            if "avg_kl_divergence" in baseline_stats:
                aggregated["kl_trends"]["baseline"].append({
                    "layer": layer_idx,
                    "value": baseline_stats["avg_kl_divergence"]["mean"],
                    "std": baseline_stats["avg_kl_divergence"]["std"]
                })

            # 2. Threshold-specific trends 
            for threshold in sink_thresholds:
                threshold_key = str(threshold)
                threshold_stats = layer_stats.get(threshold_key, {})

                # KL divergence trend
                if "avg_kl_divergence" in threshold_stats:
                    aggregated["kl_trends"][threshold_key].append({
                        "layer": layer_idx,
                        "value": threshold_stats["avg_kl_divergence"]["mean"],
                        "std": threshold_stats["avg_kl_divergence"]["std"]
                    })

                # KL reduction trend
                if "avg_kl_reduction" in threshold_stats:
                    aggregated["kl_reduction_trends"][threshold_key].append({
                        "layer": layer_idx,
                        "value": threshold_stats["avg_kl_reduction"]["mean"],
                        "std": threshold_stats["avg_kl_reduction"]["std"]
                    })

                # Sink concentration trend
                if "avg_sink_concentration" in threshold_stats:
                    aggregated["sink_concentration_trends"][threshold_key].append({
                        "layer": layer_idx,
                        "value": threshold_stats["avg_sink_concentration"]["mean"],
                        "std": threshold_stats["avg_sink_concentration"]["std"]
                    })

                # Compute combined p-values using Fisher's method
                if "kl_reduction_p_values" in metrics[threshold_key] and metrics[threshold_key]["kl_reduction_p_values"]:
                    from scipy import stats
                    p_values = np.array(metrics[threshold_key]["kl_reduction_p_values"])
                    valid_p = p_values[~np.isnan(p_values)]
                    if len(valid_p) > 0:
                        # Fisher's method: -2 * sum(ln(p))
                        fisher_statistic = -2 * np.sum(np.log(valid_p + 1e-10))
                        combined_p = 1 - stats.chi2.cdf(fisher_statistic, 2 * len(valid_p))
                        threshold_stats["kl_reduction_significance"] = {
                            "combined_p_value": float(combined_p),
                            "significant": bool(combined_p < 0.05),  
                            "sample_count": len(valid_p)
                        }

                # Do the same for correlation p-values
                if "correlation_p_values" in metrics[threshold_key] and metrics[threshold_key]["correlation_p_values"]:
                    from scipy import stats
                    p_values = np.array(metrics[threshold_key]["correlation_p_values"])
                    valid_p = p_values[~np.isnan(p_values)]
                    if len(valid_p) > 0:
                        fisher_statistic = -2 * np.sum(np.log(valid_p + 1e-10))
                        combined_p = 1 - stats.chi2.cdf(fisher_statistic, 2 * len(valid_p))
                        threshold_stats["correlation_significance"] = {
                            "combined_p_value": float(combined_p),
                            "significant": bool(combined_p < 0.05), 
                            "sample_count": len(valid_p)
                        }

        aggregated["cross_layer_analysis"] = self.compute_cross_layer_correlations(aggregated, sink_thresholds)

        raw_data_file = os.path.join(self.output_dir, "aggregated_raw_data.json")
        with open(raw_data_file, 'w') as f:
            json.dump(aggregated, f, indent=2)
        self.log(f"Saved detailed raw data to {raw_data_file}", force=True)

        return aggregated

    def compute_cross_layer_correlations(self, aggregated_results, sink_thresholds):
        """
        Compute correlations between different layers and metrics.
        Add p-values for statistical significance assessment.

        Args:
            aggregated_results: The aggregated analysis results
            sink_thresholds: List of thresholds used for sink analysis

        Returns:
            Dictionary with correlation analyses
        """
        cross_layer_analysis = {
            "layer_position_correlations": {},
            "inter_metric_correlations": {}
        }

        analyzed_layers = aggregated_results["analyzed_layers"]

        # 1. Correlation between layer position and metrics
        metrics_to_correlate = [
            ("kl_reduction", "KL Reduction"),
            ("sink_concentration", "Sink Concentration")
        ]

        for metric_key, metric_name in metrics_to_correlate:
            for threshold in sink_thresholds:
                threshold_key = str(threshold)

                if f"{metric_key}_trends" in aggregated_results and threshold_key in aggregated_results[f"{metric_key}_trends"]:
                    trend_data = aggregated_results[f"{metric_key}_trends"][threshold_key]

                    layers = [item["layer"] for item in trend_data]
                    values = [item["value"] for item in trend_data]

                    if len(layers) >= 2:  
                        try:
                            # Pearson correlation (linear relationship)
                            pearson_r, pearson_p = pearsonr(layers, values)

                            # Spearman correlation (monotonic but not necessarily linear)
                            spearman_r, spearman_p = spearmanr(layers, values)

                            cross_layer_analysis["layer_position_correlations"][f"{metric_key}_{threshold_key}"] = {
                                "metric": metric_name,
                                "threshold": threshold,
                                "pearson_correlation": float(pearson_r),
                                "pearson_p_value": float(pearson_p),
                                "pearson_significant": bool(pearson_p < 0.05), 
                                "spearman_correlation": float(spearman_r),
                                "spearman_p_value": float(spearman_p),
                                "spearman_significant": bool(spearman_p < 0.05), 
                                "sample_size": len(layers)
                            }
                        except:
                            pass

        # 2. Correlations between different metrics
        metrics_map = {
            "kl_reduction": "KL Reduction",
            "sink_concentration": "Sink Concentration"
        }

        for threshold in sink_thresholds:
            threshold_key = str(threshold)

            # Compare each pair of metrics
            for metric1 in metrics_map:
                for metric2 in metrics_map:
                    if metric1 >= metric2:  # Avoid duplicate correlations and self-correlations
                        continue

                    # Extract data for both metrics
                    trend1_key = f"{metric1}_trends"
                    trend2_key = f"{metric2}_trends"

                    if (trend1_key in aggregated_results and trend2_key in aggregated_results and
                        threshold_key in aggregated_results[trend1_key] and
                        threshold_key in aggregated_results[trend2_key]):

                        trend1_data = aggregated_results[trend1_key][threshold_key]
                        trend2_data = aggregated_results[trend2_key][threshold_key]

                        # Match up the same layers
                        values1 = []
                        values2 = []
                        layers = []

                        for layer in analyzed_layers:
                            val1 = next((item["value"] for item in trend1_data if item["layer"] == layer), None)
                            val2 = next((item["value"] for item in trend2_data if item["layer"] == layer), None)

                            if val1 is not None and val2 is not None:
                                values1.append(val1)
                                values2.append(val2)
                                layers.append(layer)

                        if len(values1) >= 2:  # Need at least 2 points for correlation
                            try:
                                pearson_r, pearson_p = pearsonr(values1, values2)

                                spearman_r, spearman_p = spearmanr(values1, values2)

                                corr_key = f"{metric1}_{metric2}_{threshold_key}"
                                cross_layer_analysis["inter_metric_correlations"][corr_key] = {
                                    "metric1": metrics_map[metric1],
                                    "metric2": metrics_map[metric2],
                                    "threshold": threshold,
                                    "pearson_correlation": float(pearson_r),
                                    "pearson_p_value": float(pearson_p),
                                    "pearson_significant": bool(pearson_p < 0.05),
                                    "spearman_correlation": float(spearman_r),
                                    "spearman_p_value": float(spearman_p),
                                    "spearman_significant": bool(spearman_p < 0.05),
                                    "sample_size": len(values1)
                                }
                            except:
                                pass

        return cross_layer_analysis

    def generate_numerical_summary(self):
        """Generate a concise numerical summary with just the numbers."""
        if not self.aggregated_results:
            self.log("No aggregated results to report.", force=True)
            return None

        self.log("\nGenerating numerical summary...", force=True)

        sink_thresholds = self.aggregated_results["thresholds"]
        analyzed_layers = self.aggregated_results["analyzed_layers"]
        summary_path = os.path.join(self.output_dir, "attention_sink_kl_summary.txt")

        with open(summary_path, 'w') as f:
            f.write("ATTENTION SINK KL DIVERGENCE ANALYSIS - NUMERICAL SUMMARY\n")
            f.write("======================================================\n\n")

            f.write(f"Model: {self.model_name}\n")
            f.write(f"Samples: {self.aggregated_results['total_samples']}\n")
            f.write(f"Layers: {self.aggregated_results['layer_count']} (analyzed {len(analyzed_layers)} layers)\n")
            f.write(f"Max sequence length: {self.max_seq_length}\n")
            f.write(f"Date: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")

            f.write("KL REDUCTION SUMMARY\n")
            f.write("-------------------\n\n")

            headers = ["Layer"]
            for threshold in sink_thresholds:
                headers.append(f"KL Red. (t={threshold})")

            table_data = []

            for layer in analyzed_layers:
                row = [layer]

                for threshold in sink_thresholds:
                    threshold_key = str(threshold)
                    if threshold_key in self.aggregated_results["kl_reduction_trends"]:
                        layer_data = [item for item in self.aggregated_results["kl_reduction_trends"][threshold_key]
                                    if item["layer"] == layer]

                        if layer_data:
                            row.append(f"{layer_data[0]['value']:.4f}")
                        else:
                            row.append("N/A")
                    else:
                        row.append("N/A")

                table_data.append(row)

            f.write(tabulate.tabulate(table_data, headers=headers, tablefmt="plain"))
            f.write("\n\n")

            f.write("SINK CONCENTRATION SUMMARY\n")
            f.write("-------------------------\n\n")

            headers = ["Layer"]
            for threshold in sink_thresholds:
                headers.append(f"Sink Conc. (t={threshold})")

            table_data = []

            for layer in analyzed_layers:
                row = [layer]

                for threshold in sink_thresholds:
                    threshold_key = str(threshold)
                    if threshold_key in self.aggregated_results["sink_concentration_trends"]:
                        layer_data = [item for item in self.aggregated_results["sink_concentration_trends"][threshold_key]
                                    if item["layer"] == layer]

                        if layer_data:
                            row.append(f"{layer_data[0]['value']*100:.2f}%")
                        else:
                            row.append("N/A")
                    else:
                        row.append("N/A")

                table_data.append(row)

            f.write(tabulate.tabulate(table_data, headers=headers, tablefmt="plain"))
            f.write("\n\n")

            f.write("OPTIMAL VALUES BY LAYER\n")
            f.write("---------------------\n\n")

            headers = ["Layer", "Max KL Red.", "Best Thresh.", "Max Conc.", "Best Thresh."]
            table_data = []

            for layer in analyzed_layers:
                row = [layer]

                # Best KL reduction
                max_reduction = -float('inf')
                best_reduction_threshold = None

                for threshold in sink_thresholds:
                    threshold_key = str(threshold)
                    if threshold_key in self.aggregated_results["kl_reduction_trends"]:
                        layer_data = [item for item in self.aggregated_results["kl_reduction_trends"][threshold_key]
                                    if item["layer"] == layer]

                        if layer_data and layer_data[0]["value"] > max_reduction:
                            max_reduction = layer_data[0]["value"]
                            best_reduction_threshold = threshold

                if max_reduction > -float('inf'):
                    row.append(f"{max_reduction:.4f}")
                    row.append(f"{best_reduction_threshold}")
                else:
                    row.append("N/A")
                    row.append("N/A")

                # Best sink concentration
                max_concentration = -float('inf')
                best_concentration_threshold = None

                for threshold in sink_thresholds:
                    threshold_key = str(threshold)
                    if threshold_key in self.aggregated_results["sink_concentration_trends"]:
                        layer_data = [item for item in self.aggregated_results["sink_concentration_trends"][threshold_key]
                                    if item["layer"] == layer]

                        if layer_data and layer_data[0]["value"] > max_concentration:
                            max_concentration = layer_data[0]["value"]
                            best_concentration_threshold = threshold

                if max_concentration > -float('inf'):
                    row.append(f"{max_concentration*100:.2f}%")
                    row.append(f"{best_concentration_threshold}")
                else:
                    row.append("N/A")
                    row.append("N/A")

                table_data.append(row)

            f.write(tabulate.tabulate(table_data, headers=headers, tablefmt="plain"))
            f.write("\n\n")

            f.write("PERFORMANCE METRICS\n")
            f.write("------------------\n")
            f.write(f"Peak memory usage: {self.peak_memory:.2f} GB\n")
            f.write(f"Total samples processed: {self.aggregated_results['total_samples']}\n")
            f.write(f"Optimization settings: max_seq_length={self.max_seq_length}, skip_layers={self.skip_layers}, optimize_calcs={self.optimize_calcs}\n")

            f.write("\nKL REDUCTION SUMMARY (with statistical significance)\n")
            f.write("-----------------------------------------------\n\n")

            headers = ["Layer"]
            for threshold in sink_thresholds:
                headers.append(f"KL Red. (t={threshold})")
                headers.append("p-value")  
                headers.append("Significant")  

            table_data = []

            for layer in analyzed_layers:
                row = [layer]

                for threshold in sink_thresholds:
                    threshold_key = str(threshold)
                    layer_stats = self.aggregated_results["layer_stats"].get(str(layer), {})
                    threshold_stats = layer_stats.get(threshold_key, {})

                    # KL reduction value
                    if "avg_kl_reduction" in threshold_stats:
                        row.append(f"{threshold_stats['avg_kl_reduction']['mean']:.4f}")
                    else:
                        row.append("N/A")

                    # P-value
                    if "kl_reduction_significance" in threshold_stats:
                        p_value = threshold_stats["kl_reduction_significance"]["combined_p_value"]
                        row.append(f"{p_value:.4f}")

                        # Significance
                        row.append("Yes" if p_value < 0.05 else "No")
                    else:
                        row.append("N/A")
                        row.append("N/A")

                table_data.append(row)

            f.write(tabulate.tabulate(table_data, headers=headers, tablefmt="plain"))
            f.write("\n\n")

            f.write("STATISTICAL SIGNIFICANCE SUMMARY\n")
            f.write("------------------------------\n\n")
            f.write("This section shows which findings are statistically significant (p < 0.05).\n\n")

            headers = ["Threshold", "KL Reduction Significant", "Correlation Significant"]
            table_data = []

            for threshold in sink_thresholds:
                threshold_key = str(threshold)
                row = [threshold]

                kl_significant = False
                for layer in analyzed_layers:
                    layer_stats = self.aggregated_results["layer_stats"].get(str(layer), {})
                    threshold_stats = layer_stats.get(threshold_key, {})
                    if "kl_reduction_significance" in threshold_stats and threshold_stats["kl_reduction_significance"]["significant"]:
                        kl_significant = True
                        break

                row.append("Yes" if kl_significant else "No")

                corr_significant = False
                for layer in analyzed_layers:
                    layer_stats = self.aggregated_results["layer_stats"].get(str(layer), {})
                    threshold_stats = layer_stats.get(threshold_key, {})
                    if "correlation_significance" in threshold_stats and threshold_stats["correlation_significance"]["significant"]:
                        corr_significant = True
                        break

                row.append("Yes" if corr_significant else "No")

                table_data.append(row)

            f.write(tabulate.tabulate(table_data, headers=headers, tablefmt="plain"))
            f.write("\n\n")

            if "cross_layer_analysis" in self.aggregated_results:
                cross_layer = self.aggregated_results["cross_layer_analysis"]

                f.write("CROSS-LAYER CORRELATION ANALYSIS\n")
                f.write("-------------------------------\n\n")

                if "layer_position_correlations" in cross_layer:
                    f.write("Correlations with Layer Position:\n")
                    f.write("---------------------------------\n")

                    headers = ["Metric", "Threshold", "Pearson r", "p-value", "Significant", "Spearman r", "p-value", "Significant"]
                    table_data = []

                    for key, data in cross_layer["layer_position_correlations"].items():
                        row = [
                            data.get("metric", ""),
                            data.get("threshold", ""),
                            f"{data.get('pearson_correlation', 0):.4f}",
                            f"{data.get('pearson_p_value', 0):.4f}",
                            "Yes" if data.get("pearson_significant", False) else "No",
                            f"{data.get('spearman_correlation', 0):.4f}",
                            f"{data.get('spearman_p_value', 0):.4f}",
                            "Yes" if data.get("spearman_significant", False) else "No"
                        ]
                        table_data.append(row)

                    f.write(tabulate.tabulate(table_data, headers=headers, tablefmt="plain"))
                    f.write("\n\n")

                if "inter_metric_correlations" in cross_layer:
                    f.write("Correlations Between Metrics:\n")
                    f.write("----------------------------\n")

                    headers = ["Metric 1", "Metric 2", "Threshold", "Pearson r", "p-value", "Significant"]
                    table_data = []

                    for key, data in cross_layer["inter_metric_correlations"].items():
                        row = [
                            data.get("metric1", ""),
                            data.get("metric2", ""),
                            data.get("threshold", ""),
                            f"{data.get('pearson_correlation', 0):.4f}",
                            f"{data.get('pearson_p_value', 0):.4f}",
                            "Yes" if data.get("pearson_significant", False) else "No"
                        ]
                        table_data.append(row)

                    f.write(tabulate.tabulate(table_data, headers=headers, tablefmt="plain"))
                    f.write("\n\n")

        self.log(f"Numerical summary saved to {summary_path}", force=True)
        return summary_path

    def cleanup(self):
        """Clean up resources."""
        if self.model is not None:
            self.model = self.model.to("cpu")
            del self.model
            self.model = None

        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None

        gc.collect()
        torch.cuda.empty_cache()

        self.log("Resources cleaned up", force=True)


def get_sample_texts_from_dataset(dataset_path, n_samples=50, random_state=None):
    """
    Extract sample texts from a dataset for analysis.

    Rows whose 'text' value is not a non-empty string (for example missing
    values) are dropped *before* sampling, so up to ``n_samples`` usable texts
    are returned whenever the dataset contains that many.

    Every failure (missing file, missing 'text' column, empty or unparsable
    CSV, any other error) is reported on stdout and results in an empty list;
    no exception is propagated to the caller.

    Args:
        dataset_path: Path to the dataset CSV
        n_samples: Number of samples to extract
        random_state: Optional seed forwarded to pandas' sampler to make the
            selection reproducible. The default (None) keeps it random.

    Returns:
        List of text samples
    """
    try:
        if not os.path.exists(dataset_path):
            print(f"ERROR: Dataset file not found at path: {dataset_path}")
            return []

        print(f"Attempting to load dataset from: {dataset_path}")
        data = pd.read_csv(dataset_path)
        print(f"Successfully loaded dataset with {len(data)} rows")

        # Print column names to help debugging (labels are not guaranteed to be str)
        print(f"Dataset columns: {', '.join(map(str, data.columns))}")

        if 'text' not in data.columns:
            print("ERROR: No 'text' column found in dataset. Available columns:")
            for col in data.columns:
                print(f" - {col}")
            return []

        # Drop unusable rows first so they cannot eat into the sample budget.
        all_texts = data['text'].tolist()
        texts = [t for t in all_texts if isinstance(t, str) and len(t) > 0]
        excluded = len(all_texts) - len(texts)
        if excluded:
            print(f"WARNING: {excluded} empty texts found and will be excluded")

        if len(texts) > n_samples:
            print(f"Sampling {n_samples} texts from dataset")
            texts = (
                pd.Series(texts, dtype=object)
                .sample(n=n_samples, random_state=random_state)
                .tolist()
            )
        else:
            print(f"Using all {len(texts)} texts from dataset")

        if texts:
            print(f"Sample text (first 100 chars): {texts[0][:100]}...")

        return texts

    except pd.errors.EmptyDataError:
        print(f"ERROR: Dataset file is empty: {dataset_path}")
        return []
    except pd.errors.ParserError:
        print(f"ERROR: Failed to parse dataset file. Make sure it's a valid CSV: {dataset_path}")
        return []
    except Exception as e:
        # Top-level boundary for the loader: report and fall back to "no data".
        print(f"ERROR loading dataset: {str(e)}")
        traceback.print_exc()
        return []



In [None]:
def main():
    """Configure and run the attention sink KL divergence analysis end to end.

    Steps: mount Google Drive when running in Colab, pick dataset/output paths
    accordingly, load sample texts (falling back to a few built-in sentences
    when none could be loaded), load the model, run the aggregated analysis and
    always release resources at the end. Returns early, without creating
    anything, when the dataset file does not exist.
    """
    # Colab detection doubles as the Drive mount step.
    is_colab = False
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        is_colab = True
    except ImportError:
        print("Not running in Google Colab, skipping drive mount")

    model_name = "EleutherAI/pythia-12b"
    short_name = model_name.split('/')[-1]

    if is_colab:
        dataset_path = "/content/drive/MyDrive/wiki_dataset_position.csv"
        output_dir = f"/content/drive/MyDrive/Sink/pvalue/kl_div/{short_name}"
    else:
        dataset_path = "./wiki_dataset_position.csv"
        output_dir = f"./attention_sink_kl_{short_name}_fast"

    # Analysis parameters
    num_samples = 50
    batch_size = 8
    sink_thresholds = [0.8, 0.9, 0.95]  # Thresholds for attention sink identification

    # Speed optimization parameters, passed straight to the analyzer
    analyzer_options = {
        "max_seq_length": 128,    # Truncate sequences to this length for speed
        "skip_layers": True,      # Analyze every other layer for speed
        "optimize_calcs": True,   # Use optimized calculations
        "verbose": True,          # Detailed progress logs
    }

    banner = (
        f"Starting analysis with model: {model_name}",
        f"Dataset path: {dataset_path}",
        f"Output directory: {output_dir}",
        f"Analyzing {num_samples} samples with optimization settings:",
        f" - Max sequence length: {analyzer_options['max_seq_length']}",
        f" - Skip layers: {analyzer_options['skip_layers']}",
        f" - Optimize calculations: {analyzer_options['optimize_calcs']}",
    )
    for line in banner:
        print(line)

    # Bail out before creating directories or loading anything heavy.
    if not os.path.exists(dataset_path):
        print(f"ERROR: Dataset file not found at path: {dataset_path}")
        print("Please check the path and try again.")
        return

    os.makedirs(output_dir, exist_ok=True)

    analyzer = UltraFastAttentionSinkKLAnalysis(
        model_name=model_name,
        output_dir=output_dir,
        **analyzer_options
    )

    texts = get_sample_texts_from_dataset(dataset_path, n_samples=num_samples)

    if texts:
        print(f"Successfully loaded {len(texts)} texts from dataset")
    else:
        print("No texts loaded from dataset. Using default examples.")
        texts = [
            "The concept of attention mechanisms allows transformers to focus on relevant parts of the input.",
            "Attention sinks are positions that accumulate high amounts of attention from many tokens.",
            "Information theory measures like KL divergence help us understand how models process data.",
            "Transformer models coordinate information flow between attention heads across different layers."
        ]
        print(f"Using {len(texts)} default example texts")

    try:
        print("Loading model...")
        analyzer.load_model(use_4bit=True)

        started_at = time.time()
        print("Beginning analysis...")
        results = analyzer.run_aggregated_analysis(
            texts=texts,
            sink_thresholds=sink_thresholds,
            batch_size=batch_size
        )
        elapsed_seconds = time.time() - started_at
        print(f"Analysis completed successfully in {elapsed_seconds/60:.2f} minutes!")

        summary_file = os.path.join(output_dir, "attention_sink_kl_summary.txt")
        raw_data_file = os.path.join(output_dir, "aggregated_raw_data.json")
        print(f"Results saved to: {output_dir}")
        print(f"Summary file should be at: {summary_file}")
        print(f"JSON data file should be at: {raw_data_file}")

    except Exception as e:
        print(f"Analysis error: {str(e)}")
        print("Full traceback:")
        traceback.print_exc()
    finally:
        # Release the model even when the analysis failed.
        print("Cleaning up resources...")
        analyzer.cleanup()
        print("Done.")


# Run the analysis only when this file is executed directly (script or
# notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()