In [None]:
from huggingface_hub import notebook_login

# Start the interactive Hugging Face Hub login imported above.
# NOTE(review): presumably only required for gated/private repositories or
# uploads — confirm it is actually needed for the checkpoints analysed below.
notebook_login()

In [None]:
import numpy as np
import torch
import os
import gc
import json
import pandas as pd
from scipy import stats
import time
from tqdm import tqdm
import traceback
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
class OptimizedAttentionMatrixRMTAnalysis:
    """
    Optimized analyzer for attention matrices using Random Matrix Theory (RMT).
    Streamlined to focus on generating the final report only.
    """

    def __init__(self, model_size="pythia-70m", output_dir="./pythia_attention_rmt_analysis", verbose=False):
        """
        Initialize the analysis with model size and output directory.

        Args:
            model_size: Size of Pythia model (e.g., "pythia-70m", "pythia-2.8b")
            output_dir: Directory to save results
            verbose: Whether to print detailed progress information
        """
        self.model_size = model_size
        self.base_model_name = f"EleutherAI/{model_size}"
        self.output_dir = output_dir
        self.model = None
        self.tokenizer = None
        self.results = {}
        self.verbose = verbose

        os.makedirs(output_dir, exist_ok=True)

    def log(self, message, force=False):
        """Print messages only when verbose is enabled or forced."""
        if self.verbose or force:
            print(message)

    def load_model_checkpoint(self, step="step0"):
        """
        Load the tokenizer and model for one training checkpoint.

        Any previously loaded model is released first. The checkpoint is
        selected by passing ``step`` as the ``revision`` of the repository
        named by ``self.base_model_name``; the model is requested in float32
        with ``device_map="cpu"``.

        Args:
            step: Checkpoint revision (e.g., "step0", "step1000", "step143000")

        Returns:
            Tuple ``(model, tokenizer)``; both are also stored on ``self``

        Raises:
            Exception: Whatever the loaders raise is logged and re-raised.
        """
        repo = self.base_model_name
        self.log(f"Loading {repo} at {step}", force=True)

        # Drop the previous checkpoint before allocating the next one.
        if self.model is not None:
            self.model = None
            torch.cuda.empty_cache()
            gc.collect()

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(repo, revision=step)
            self.model = AutoModelForCausalLM.from_pretrained(
                repo,
                revision=step,
                torch_dtype=torch.float32,
                device_map="cpu",
            )
        except Exception as e:
            self.log(f"Error loading model at {step}: {e}", force=True)
            if self.verbose:
                traceback.print_exc()
            raise e

        self.log(f"Successfully loaded {repo} at {step}", force=True)
        return self.model, self.tokenizer

    def compute_eigenspectrum(self, attention_matrix):
        """
        Compute eigenvalues and spectral statistics of an attention matrix.
        Optimized for speed by computing only essential metrics.

        Args:
            attention_matrix: Attention weight matrix of shape (seq_len, seq_len)

        Returns:
            Dictionary with eigenspectrum analysis
        """
        if torch.is_tensor(attention_matrix):
            attention_matrix = attention_matrix.cpu().numpy()

        try:
            eigenvalues = np.linalg.eigvals(attention_matrix)

            # convert to real numbers if there's a negligible imaginary component
            if np.iscomplex(eigenvalues).any():
                if np.all(np.abs(eigenvalues.imag) < 1e-10):
                    eigenvalues = eigenvalues.real

            # Sort eigenvalues by magnitude (descending)
            eigenvalues = np.sort(np.abs(eigenvalues))[::-1]

            # Calculate spectral gap (ratio between largest and second largest eigenvalue)
            spectral_gap = float(eigenvalues[0] / eigenvalues[1]) if len(eigenvalues) > 1 else np.inf

            # Calculate effective rank using participation ratio
            squared_sum = np.sum(eigenvalues) ** 2
            sum_squares = np.sum(eigenvalues ** 2)
            participation_ratio = squared_sum / sum_squares if sum_squares > 0 else 0

            # Calculate singular values (for actual rank estimation)
            singular_values = np.linalg.svd(attention_matrix, compute_uv=False)

            # Stable rank calculation
            stable_rank = np.sum(singular_values ** 2) / (singular_values[0] ** 2) if singular_values[0] > 0 else 0

            # Calculate spectral norm (largest singular value)
            spectral_norm = singular_values[0]

            # Calculate "bulk edge" - the edge of the main distribution of eigenvalues
            if len(eigenvalues) > 2:
                eigenvalue_gaps = eigenvalues[:-1] - eigenvalues[1:]
                largest_gap_idx = np.argmax(eigenvalue_gaps)
                bulk_edge = (eigenvalues[largest_gap_idx] + eigenvalues[largest_gap_idx + 1]) / 2
            else:
                bulk_edge = eigenvalues[-1]

            # Calculate bulk statistics (excluding potential outliers)
            bulk_eigenvalues = eigenvalues[eigenvalues < bulk_edge]
            if len(bulk_eigenvalues) > 1:
                bulk_mean = float(np.mean(bulk_eigenvalues))
                bulk_std = float(np.std(bulk_eigenvalues))
            else:
                bulk_mean = float(np.mean(eigenvalues))
                bulk_std = float(np.std(eigenvalues))

            # Return all metrics as float
            return {
                "eigenvalues": eigenvalues.tolist()[:10],  # Keep only top 10 eigenvalues to save memory
                "largest_eigenvalue": float(eigenvalues[0]) if len(eigenvalues) > 0 else 0,
                "spectral_gap": float(spectral_gap),
                "participation_ratio": float(participation_ratio),
                "stable_rank": float(stable_rank),
                "spectral_norm": float(spectral_norm),
                "bulk_edge": float(bulk_edge),
                "bulk_mean": bulk_mean,
                "bulk_std": bulk_std
            }

        except Exception as e:
            self.log(f"Error in eigenspectrum computation: {str(e)}")
            if self.verbose:
                traceback.print_exc()
            return {
                "eigenvalues": [],
                "largest_eigenvalue": 0.0,
                "spectral_gap": 0.0,
                "participation_ratio": 0.0,
                "stable_rank": 0.0,
                "spectral_norm": 0.0,
                "bulk_edge": 0.0,
                "bulk_mean": 0.0,
                "bulk_std": 0.0,
                "error": str(e)
            }

    def compute_marchenko_pastur_distance(self, eigenvalues, aspect_ratio=1.0):
        """
        Compute the distance between the empirical eigenvalue distribution
        and the Marchenko-Pastur law.
        Simplified and optimized version.

        Args:
            eigenvalues: List of eigenvalues
            aspect_ratio: Ratio of rows to columns in the matrix

        Returns:
            KL divergence metric between distributions
        """
        if not eigenvalues or len(eigenvalues) < 2:
            return {"kl_divergence": float('nan')}

        try:
            # Normalize eigenvalues to have mean 1
            eigenvalues = np.array(eigenvalues)
            normalized_eigenvalues = eigenvalues / np.mean(eigenvalues)

            # Create empirical distribution
            hist, bin_edges = np.histogram(normalized_eigenvalues, bins=20, density=True)
            bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

            # Marchenko-Pastur density function
            def mp_density(x, gamma=aspect_ratio):
                if gamma == 1:
                    gamma = 0.99  # Avoid singularity at gamma=1

                # Parameters for MP distribution
                a = (1 - np.sqrt(gamma))**2
                b = (1 + np.sqrt(gamma))**2

                # MP density
                result = np.zeros_like(x, dtype=float)
                mask = (x >= a) & (x <= b)
                result[mask] = np.sqrt((b - x[mask]) * (x[mask] - a)) / (2 * np.pi * gamma * x[mask])
                return result

            # Generate MP distribution for the given bins
            mp_dist = mp_density(bin_centers)

            # Normalize distributions for comparison
            hist_normalized = hist / np.sum(hist)
            mp_dist_normalized = mp_dist / np.sum(mp_dist)

            # Compute KL divergence from MP distribution
            kl_div = stats.entropy(hist_normalized, mp_dist_normalized)

            return {"kl_divergence": float(kl_div)}

        except Exception as e:
            self.log(f"Error in Marchenko-Pastur calculation: {str(e)}")
            if self.verbose:
                traceback.print_exc()
            return {"kl_divergence": float('nan')}

    def analyze_attention_matrix(self, attention_matrix, layer_idx, head_idx):
        """
        Analyze a single attention matrix.
        Optimized version with minimal metric calculation.

        Args:
            attention_matrix: Attention matrix to analyze
            layer_idx: Layer index
            head_idx: Head index

        Returns:
            Dictionary with analysis results
        """
        seq_len = attention_matrix.shape[0]

        # Check for attention sinks (tokens receiving high total attention)
        col_sums = np.sum(attention_matrix, axis=0)  # Sum over source tokens

        # Identify potential sinks 
        sink_threshold = np.percentile(col_sums, 90)
        sink_indices = np.where(col_sums >= sink_threshold)[0]

        # Calculate attention concentration
        sink_attention = np.sum(col_sums[sink_indices])
        total_attention = np.sum(col_sums)
        sink_concentration = sink_attention / total_attention if total_attention > 0 else 0

        # Calculate entropy of the attention distribution
        row_entropies = []
        for i in range(seq_len):
            row = attention_matrix[i]
            # Avoid log(0) by adding a small epsilon
            entropy_i = -np.sum(row * np.log2(row + 1e-10))
            row_entropies.append(entropy_i)

        avg_entropy = float(np.mean(row_entropies))

        # Compute eigenspectrum
        eigenspectrum = self.compute_eigenspectrum(attention_matrix)

        # Compute distance from Marchenko-Pastur distribution
        mp_distance = self.compute_marchenko_pastur_distance(
            eigenspectrum["eigenvalues"],
            aspect_ratio=1.0  # Assuming square attention matrices
        )

        return {
            "sink_concentration": float(sink_concentration),
            "entropy": avg_entropy,
            "spectral_gap": eigenspectrum["spectral_gap"],
            "participation_ratio": eigenspectrum["participation_ratio"],
            "stable_rank": eigenspectrum["stable_rank"],
            "marchenko_pastur_kl": mp_distance["kl_divergence"]
        }

    def analyze_text(self, text, step):
        """
        Analyze a text sample through the model.
        Optimized for speed with minimal outputs.

        Args:
            text: Text to analyze
            step: Checkpoint step being analyzed

        Returns:
            Dictionary with results for each layer and head
        """
        if self.verbose:
            text_preview = text[:50] + ("..." if len(text) > 50 else "")
            self.log(f"Analyzing text at step {step}: {text_preview}")

        inputs = self.tokenizer(text, return_tensors="pt")

        with torch.no_grad():
            outputs = self.model(
                **inputs,
                output_attentions=True,
                return_dict=True
            )

        attentions = outputs.attentions  # tuple of (layer, batch, head, seq_len, seq_len)

        results = {
            "layers": {}
        }

        # Process each layer and head
        for layer_idx, layer_attention in enumerate(attentions):
            # Each layer has shape (batch, head, seq_len, seq_len)
            layer_attention = layer_attention[0].cpu().numpy()  # shape: (head, seq_len, seq_len)

            # Get number of heads and sequence length
            num_heads, seq_len, _ = layer_attention.shape
            if self.verbose:
                self.log(f"Processing layer {layer_idx} with {num_heads} heads, sequence length {seq_len}")

            # Store layer info
            results["layers"][str(layer_idx)] = {
                "heads": {},
                "avg_head": {}
            }

            # Process each attention head
            for head_idx in range(num_heads):
                # Get attention matrix for this head
                attention_matrix = layer_attention[head_idx]

                # Analyze this attention matrix
                head_results = self.analyze_attention_matrix(
                    attention_matrix, layer_idx, head_idx
                )

                results["layers"][str(layer_idx)]["heads"][str(head_idx)] = head_results

            avg_attention = np.mean(layer_attention, axis=0)

            avg_results = self.analyze_attention_matrix(
                avg_attention, layer_idx, "avg"
            )

            results["layers"][str(layer_idx)]["avg_head"] = avg_results

        return results

    def analyze_checkpoint(self, step, texts):
        """
        Load one checkpoint, analyze every text with it and aggregate.

        A text whose analysis raises is logged and skipped; the remaining
        texts are still processed. No intermediate files are written.

        Args:
            step: Checkpoint step to analyze (e.g., "step0", "step1000")
            texts: List of text samples to analyze

        Returns:
            ``{"step": step, "aggregated_results": ...}`` where the second
            entry is the output of ``aggregate_results``
        """
        self.log(f"Analyzing checkpoint {step} with {len(texts)} text samples...", force=True)

        self.load_model_checkpoint(step=step)

        per_text = []
        # The progress bar is only shown in verbose mode.
        progress = tqdm(texts, desc=f"Processing texts for {step}", disable=not self.verbose)
        for position, sample in enumerate(progress, start=1):
            try:
                per_text.append(self.analyze_text(sample, step))

                # Release memory between samples.
                torch.cuda.empty_cache()
                gc.collect()

            except Exception as e:
                self.log(f"Error analyzing text {position}: {e}", force=True)
                if self.verbose:
                    traceback.print_exc()

        return {
            "step": step,
            "aggregated_results": self.aggregate_results(per_text, step)
        }

    def aggregate_results(self, all_results, step):
        """
        Aggregate results from multiple texts for a checkpoint.
        Optimized to focus only on metrics needed for the final report.

        Args:
            all_results: List of results from analyze_text
            step: Checkpoint step

        Returns:
            Dictionary with aggregated statistics
        """
        self.log(f"Aggregating results for checkpoint {step}...")

        if not all_results:
            self.log("No results to aggregate!", force=True)
            return {}

        first_result = all_results[0]
        layer_indices = sorted([int(idx) for idx in first_result["layers"].keys()])

        head_indices = []
        if layer_indices:
            first_layer = first_result["layers"][str(layer_indices[0])]
            head_indices = sorted([int(idx) for idx in first_layer["heads"].keys()])

        self.log(f"Found {len(layer_indices)} layers and {len(head_indices)} heads to aggregate")

        aggregated = {
            "step": step,
            "layer_stats": {},
            "rmt_trends": {
                "spectral_gap": [],
                "participation_ratio": [],
                "stable_rank": [],
                "marchenko_pastur_kl": [],
                "entropy": [],
                "sink_concentration": []
            }
        }

        for layer_idx in layer_indices:
            layer_key = str(layer_idx)
            layer_stats = {"avg_head": {}}

            avg_metrics = {
                "spectral_gap": [],
                "participation_ratio": [],
                "stable_rank": [],
                "marchenko_pastur_kl": [],
                "entropy": [],
                "sink_concentration": []
            }

            for result in all_results:
                if layer_key not in result["layers"]:
                    continue

                layer_data = result["layers"][layer_key]
                if "avg_head" not in layer_data:
                    continue

                avg_data = layer_data["avg_head"]

                for metric in avg_metrics.keys():
                    if metric in avg_data:
                        avg_metrics[metric].append(avg_data[metric])

            # Compute statistics for each metric
            avg_head_stats = {}
            for metric_name, values in avg_metrics.items():
                if values:
                    valid_values = [v for v in values if not (np.isnan(v) or np.isinf(v))]
                    if valid_values:
                        avg_head_stats[metric_name] = {
                            "mean": float(np.nanmean(valid_values)),
                            "std": float(np.nanstd(valid_values)) if len(valid_values) > 1 else 0.0,
                            "min": float(np.nanmin(valid_values)),
                            "max": float(np.nanmax(valid_values)),
                            "count": len(valid_values)
                        }

            layer_stats["avg_head"] = avg_head_stats

            aggregated["layer_stats"][layer_key] = layer_stats

            for trend_key in aggregated["rmt_trends"].keys():
                if trend_key in avg_head_stats and "mean" in avg_head_stats[trend_key]:
                    aggregated["rmt_trends"][trend_key].append({
                        "layer": layer_idx,
                        "value": avg_head_stats[trend_key]["mean"],
                        "std": avg_head_stats[trend_key]["std"]
                    })

        return aggregated

    def compare_checkpoints(self, checkpoint_results):
        """
        Compare results between different checkpoints.
        Simplified to focus only on metrics needed for the final report.

        Args:
            checkpoint_results: Dictionary of results for each checkpoint

        Returns:
            Dictionary with comparison metrics
        """
        self.log("Comparing checkpoints...")

        steps = sorted(checkpoint_results.keys())
        if len(steps) < 2:
            self.log("Need at least two checkpoints to compare!", force=True)
            return {}

        comparisons = {
            "steps": steps,
            "metrics": {}
        }

        metrics_to_compare = [
            "spectral_gap",
            "participation_ratio",
            "stable_rank",
            "marchenko_pastur_kl",
            "entropy",
            "sink_concentration"
        ]

        # For each metric, track changes across checkpoints
        for metric in metrics_to_compare:
            metric_data = {}

            # Get data for all steps
            for step in steps:
                if step in checkpoint_results and "aggregated_results" in checkpoint_results[step]:
                    agg_results = checkpoint_results[step]["aggregated_results"]
                    if "rmt_trends" in agg_results and metric in agg_results["rmt_trends"]:
                        metric_data[step] = agg_results["rmt_trends"][metric]

            # Calculate overall changes
            if len(metric_data) >= 2:
                first_step = steps[0]
                last_step = steps[-1]

                if first_step in metric_data and last_step in metric_data:
                    # Get layer-wise values for first and last step
                    first_values = {item["layer"]: item["value"] for item in metric_data[first_step]}
                    last_values = {item["layer"]: item["value"] for item in metric_data[last_step]}

                    # Find common layers
                    common_layers = sorted(set(first_values.keys()) & set(last_values.keys()))

                    if common_layers:
                        # Calculate absolute and percentage changes
                        abs_changes = [last_values[layer] - first_values[layer] for layer in common_layers]

                        # Calculate percentage changes safely
                        pct_changes = []
                        for layer in common_layers:
                            if first_values[layer] != 0:
                                pct_change = (last_values[layer] - first_values[layer]) / abs(first_values[layer]) * 100
                            else:
                                pct_change = float('inf') if last_values[layer] > 0 else float('-inf') if last_values[layer] < 0 else 0
                            pct_changes.append(pct_change)

                        comparisons["metrics"][metric] = {
                            "common_layers": common_layers,
                            "first_step_values": [first_values[layer] for layer in common_layers],
                            "last_step_values": [last_values[layer] for layer in common_layers],
                            "absolute_changes": abs_changes,
                            "percentage_changes": pct_changes,
                            "avg_absolute_change": float(np.nanmean(abs_changes)),
                            "avg_percentage_change": float(np.nanmean([p for p in pct_changes if not np.isinf(p)]))
                        }

        return comparisons

    def generate_report(self, checkpoint_results, comparison_results):
        """
        Generate a comprehensive text report summarizing the findings.

        Args:
            checkpoint_results: Dictionary of results for each checkpoint
            comparison_results: Results of comparing checkpoints

        Returns:
            Path to the saved report
        """
        self.log("Generating comprehensive report...", force=True)

        report_path = os.path.join(self.output_dir, "rmt_analysis_report.txt")

        with open(report_path, 'w') as f:
            f.write("RANDOM MATRIX THEORY ANALYSIS OF ATTENTION MATRICES\n")
            f.write("===============================================\n\n")

            f.write(f"Model: {self.model_size}\n")

            steps = sorted(checkpoint_results.keys())
            f.write(f"Checkpoints analyzed: {', '.join(steps)}\n\n")

            if len(steps) >= 2:
                first_step = steps[0]
                last_step = steps[-1]

                f.write(f"COMPARISON FROM {first_step} TO {last_step}\n")
                f.write("--------------------------------\n\n")

                if "metrics" in comparison_results and "spectral_gap" in comparison_results["metrics"]:
                    spectral_gap_data = comparison_results["metrics"]["spectral_gap"]
                    avg_sg_change = spectral_gap_data.get("avg_absolute_change", 0)

                    f.write("SPECTRAL GAP\n")
                    f.write(f"Average change: {avg_sg_change:.4f}\n")
                    if avg_sg_change > 0:
                        f.write("Spectral gap INCREASED during training, suggesting the emergence of\n")
                        f.write("more dominant eigenvalues and more structured attention patterns.\n")
                        f.write("This indicates attention becoming less random and more organized.\n\n")
                    else:
                        f.write("Spectral gap DECREASED during training, suggesting attention patterns\n")
                        f.write("becoming more diffuse without strong dominant directions.\n\n")

                if "metrics" in comparison_results and "participation_ratio" in comparison_results["metrics"]:
                    pr_data = comparison_results["metrics"]["participation_ratio"]
                    avg_pr_change = pr_data.get("avg_absolute_change", 0)

                    f.write("PARTICIPATION RATIO\n")
                    f.write(f"Average change: {avg_pr_change:.4f}\n")
                    if avg_pr_change < 0:
                        f.write("Participation ratio DECREASED during training, indicating attention\n")
                        f.write("becoming more concentrated in fewer dimensions. This is consistent\n")
                        f.write("with the emergence of attention sinks and more structured patterns.\n\n")
                    else:
                        f.write("Participation ratio INCREASED during training, suggesting attention\n")
                        f.write("remains distributed across many dimensions without strong concentration.\n\n")

                if "metrics" in comparison_results and "entropy" in comparison_results["metrics"]:
                    entropy_data = comparison_results["metrics"]["entropy"]
                    avg_entropy_change = entropy_data.get("avg_absolute_change", 0)

                    f.write("ATTENTION ENTROPY\n")
                    f.write(f"Average change: {avg_entropy_change:.4f}\n")
                    if avg_entropy_change < 0:
                        f.write("Attention entropy DECREASED during training, indicating attention\n")
                        f.write("becoming more concentrated and less uniform. This suggests the\n")
                        f.write("development of more focused attention patterns.\n\n")
                    else:
                        f.write("Attention entropy INCREASED during training, suggesting attention\n")
                        f.write("becoming more uniformly distributed across tokens.\n\n")

                if "metrics" in comparison_results and "sink_concentration" in comparison_results["metrics"]:
                    sink_data = comparison_results["metrics"]["sink_concentration"]
                    avg_sink_change = sink_data.get("avg_absolute_change", 0)

                    f.write("SINK CONCENTRATION\n")
                    f.write(f"Average change: {avg_sink_change:.4f}\n")
                    if avg_sink_change > 0:
                        f.write("Sink concentration INCREASED during training, confirming the emergence\n")
                        f.write("of attention sinks where specific tokens receive disproportionate attention.\n")
                        f.write("This is a key feature of trained transformer models.\n\n")
                    else:
                        f.write("Sink concentration DECREASED or remained stable during training,\n")
                        f.write("suggesting less pronounced attention sinks in this model.\n\n")

            f.write("\nLAYER-SPECIFIC ANALYSIS\n")
            f.write("----------------------\n\n")

            first_step = steps[0]
            last_step = steps[-1]

            if first_step in checkpoint_results and "aggregated_results" in checkpoint_results[first_step]:
                first_agg = checkpoint_results[first_step]["aggregated_results"]
                last_agg = checkpoint_results[last_step]["aggregated_results"] if last_step in checkpoint_results else None

                if "layer_stats" in first_agg:
                    layer_indices = sorted([int(idx) for idx in first_agg["layer_stats"].keys()])

                    for layer_idx in layer_indices:
                        layer_key = str(layer_idx)
                        f.write(f"\nLayer {layer_idx}:\n")

                        if layer_key in first_agg["layer_stats"] and "avg_head" in first_agg["layer_stats"][layer_key]:
                            first_metrics = first_agg["layer_stats"][layer_key]["avg_head"]
                            last_metrics = None

                            if last_agg and "layer_stats" in last_agg and layer_key in last_agg["layer_stats"]:
                                if "avg_head" in last_agg["layer_stats"][layer_key]:
                                    last_metrics = last_agg["layer_stats"][layer_key]["avg_head"]

                            for metric in ["spectral_gap", "participation_ratio", "entropy", "sink_concentration"]:
                                if metric in first_metrics and "mean" in first_metrics[metric]:
                                    first_val = first_metrics[metric]["mean"]
                                    last_val = "N/A"
                                    change = "N/A"

                                    if last_metrics and metric in last_metrics and "mean" in last_metrics[metric]:
                                        last_val = last_metrics[metric]["mean"]
                                        change = last_val - first_val

                                    f.write(f"  {metric.replace('_', ' ').title()}: {first_val:.4f} -> {last_val if last_val == 'N/A' else f'{last_val:.4f}'}")
                                    if change != "N/A":
                                        direction = "↑" if change > 0 else "↓" if change < 0 else "="
                                        f.write(f" {direction} ({change:.4f})")
                                    f.write("\n")

            f.write("\nPERFORMANCE METRICS\n")
            f.write("------------------\n")
            f.write(f"Checkpoints analyzed: {len(steps)}\n")

            total_samples = 0
            for step in steps:
                if step in checkpoint_results and "aggregated_results" in checkpoint_results[step]:
                    if "layer_stats" in checkpoint_results[step]["aggregated_results"]:
                        layer_stats = checkpoint_results[step]["aggregated_results"]["layer_stats"]
                        if layer_stats and len(layer_stats) > 0:
                            first_layer = next(iter(layer_stats.values()))
                            if "avg_head" in first_layer and first_layer["avg_head"]:
                                first_metric = next(iter(first_layer["avg_head"].values()))
                                if "count" in first_metric:
                                    total_samples += first_metric["count"]
                                    break

            f.write(f"Total samples processed: {total_samples}\n")

        self.log(f"Analysis report saved to {report_path}", force=True)
        return report_path

    def run_analysis(self, checkpoint_steps=None, texts=None):
        """
        Run the complete analysis pipeline on multiple checkpoints.
        Optimized to focus only on generating the final report.

        Args:
            checkpoint_steps: List of checkpoint steps to analyze (e.g., ["step0", "step143000"])
            texts: List of text samples to analyze (if None, example texts are used)

        Returns:
            Path to the generated report
        """
        start_time = time.time()

        if checkpoint_steps is None:
            checkpoint_steps = ["step0", "step17000", "step72000", "step143000"]

        # Use example texts if none provided
        if texts is None or not texts:
            texts = [
                "The theory of random matrices has applications in many fields including physics, statistics, and data analysis.",
                "Eigenvalues of large random matrices tend to follow specific distributions that depend on the matrix ensemble.",
                "The eigenvectors of random matrices show interesting localization properties as dimensions increase.",
                "Spectral analysis of matrices can reveal structural properties that are not immediately obvious."
            ]

        # Analyze each checkpoint
        checkpoint_results = {}
        for step in checkpoint_steps:
            try:
                self.log(f"\n{'='*60}\nAnalyzing checkpoint: {step}\n{'='*60}\n", force=True)
                results = self.analyze_checkpoint(step, texts)
                checkpoint_results[step] = results
            except Exception as e:
                self.log(f"Error analyzing checkpoint {step}: {e}", force=True)
                if self.verbose:
                    traceback.print_exc()
                continue  # Continue with the next checkpoint even if this one fails

        # Compare checkpoints (only if we have more than one)
        if len(checkpoint_results) > 1:
            self.log("\nComparing checkpoints...", force=True)
            comparison_results = self.compare_checkpoints(checkpoint_results)

            report_path = self.generate_report(checkpoint_results, comparison_results)
        else:
            self.log("Not enough checkpoints to compare. Need at least two.", force=True)
            report_path = None

        results_file = os.path.join(self.output_dir, "aggregated_results.json")
        with open(results_file, 'w') as f:
            serializable_results = {}
            for step, results in checkpoint_results.items():
                if "aggregated_results" in results:
                    serializable_results[step] = results["aggregated_results"]
            json.dump(serializable_results, f)
        self.log(f"Saved aggregated results to {results_file}", force=True)

        total_time = time.time() - start_time
        self.log(f"Analysis completed in {total_time/60:.2f} minutes", force=True)

        return report_path

    def cleanup(self):
        """Clean up resources."""
        if self.model is not None:
            del self.model
            self.model = None

        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None

        gc.collect()
        torch.cuda.empty_cache()

        self.log("Resources cleaned up", force=True)


def get_sample_texts_from_dataset(dataset_path, n_samples=50):
    """
    Extract sample texts from a dataset for analysis.

    Rows whose 'text' value is missing or blank are discarded before
    sampling, so every returned element is a non-empty string and unusable
    rows do not consume part of the sample budget.

    Args:
        dataset_path: Path to the dataset CSV (must contain a 'text' column)
        n_samples: Maximum number of samples to extract

    Returns:
        List of text samples. When more than n_samples usable rows exist, a
        random subset of n_samples rows is returned; otherwise all usable
        rows are returned in file order. An empty list is returned when the
        file cannot be loaded or has no 'text' column.
    """
    try:
        data = pd.read_csv(dataset_path)
        print(f"Loaded dataset with {len(data)} rows")

        if 'text' not in data.columns:
            print("Error: No 'text' column found in dataset")
            return []

        # Empty CSV cells are loaded as NaN (a float); drop them and coerce
        # the remainder to str so callers only ever receive strings.
        candidates = data['text'].dropna().astype(str)

        # Whitespace-only entries carry no content to analyze either.
        non_blank = pd.Series(
            [bool(text.strip()) for text in candidates],
            index=candidates.index,
            dtype=bool,
        )
        candidates = candidates[non_blank]

        # Sample only after filtering so the requested count is honored
        # whenever enough usable rows exist.
        if len(candidates) > n_samples:
            candidates = candidates.sample(n_samples)

        return candidates.tolist()

    except Exception as e:
        # Best-effort loader: report the problem and return an empty list
        # instead of propagating the error to the caller.
        print(f"Error loading dataset: {e}")
        return []



In [None]:

def main():
    """Run the optimized analysis pipeline.

    Mounts Google Drive when running in Colab, chooses dataset/output paths
    accordingly, loads sample texts (falling back to a few built-in
    sentences when none can be loaded), runs the RMT analysis over the
    configured checkpoints, and always releases the analyzer's resources at
    the end.
    """
    try:
        from google.colab import drive
        # Mount Google Drive if in Colab
        drive.mount('/content/drive')
        is_colab = True
    except ImportError:
        is_colab = False
        print("Not running in Google Colab, skipping drive mount")

    # Options: "pythia-70m", "pythia-160m", "pythia-410m", "pythia-1b",
    # "pythia-2.8b", "pythia-6.9b", "pythia-12b"
    model_size = "pythia-2.8b"

    # Full training is 143000 steps
    checkpoint_steps = ["step0", "step1", "step2", "step4", "step8", "step2000", "step143000"]

    # Path settings
    if is_colab:
        dataset_path = "/content/drive/MyDrive/wiki_dataset_position.csv"
        # Use model_size here: no `model_name` variable exists in this
        # function, so referencing it raised NameError on Colab.
        output_dir = f"/content/drive/MyDrive/rmt/{model_size}"
    else:
        dataset_path = "./wiki_dataset_position.csv"
        output_dir = f"./pythia_attention_rmt_{model_size}_optimized"

    num_samples = 100
    verbose = False       # Set to True for detailed progress logs

    print(f"Starting analysis with model: {model_size}")
    print(f"Dataset path: {dataset_path}")
    print(f"Output directory: {output_dir}")
    print(f"Analyzing checkpoints: {', '.join(checkpoint_steps)}")
    print(f"Using {num_samples} samples")

    analyzer = OptimizedAttentionMatrixRMTAnalysis(
        model_size=model_size,
        output_dir=output_dir,
        verbose=verbose
    )

    texts = get_sample_texts_from_dataset(dataset_path, n_samples=num_samples)

    if not texts:
        print("No texts loaded from dataset. Using default examples.")
        texts = [
            "The theory of random matrices has applications in many fields including physics, statistics, and data analysis.",
            "Eigenvalues of large random matrices tend to follow specific distributions that depend on the matrix ensemble.",
            "The eigenvectors of random matrices show interesting localization properties as dimensions increase.",
            "Spectral analysis of matrices can reveal structural properties that are not immediately obvious."
        ]

    try:
        report_path = analyzer.run_analysis(
            checkpoint_steps=checkpoint_steps,
            texts=texts
        )
        if report_path is None:
            # run_analysis returns None when fewer than two checkpoints were
            # analyzed successfully, in which case no report is generated.
            print("Analysis finished, but no report was generated.")
        else:
            print(f"Analysis completed successfully! Report saved to: {report_path}")
    except Exception as e:
        print(f"Analysis error: {e}")
        traceback.print_exc()
    finally:
        # Clean up resources
        analyzer.cleanup()


# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()