In [None]:
from huggingface_hub import notebook_login

# Authenticate this notebook session with the Hugging Face Hub.
# NOTE(review): presumably required because the model loaded later is gated —
# confirm against the model card.
notebook_login()

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import os
import gc
import json
import pandas as pd
import networkx as nx
import scipy.linalg
from transformers import AutoModelForCausalLM, AutoTokenizer
from google.colab import drive
import time
from tqdm import tqdm
import seaborn as sns
import traceback
from concurrent.futures import ThreadPoolExecutor
import warnings
import scipy.sparse as sp
import scipy.sparse.linalg as spla

In [None]:

# Silence UserWarning and FutureWarning for the rest of the session. With no
# message/module pattern, simplefilter installs the same filter entry as
# filterwarnings; the FutureWarning filter still ends up in front.
warnings.simplefilter("ignore", category=UserWarning)
warnings.simplefilter("ignore", category=FutureWarning)

class OptimizedAttentionLaplacianAnalysis:
    """
    Spectral analysis of transformer attention maps treated as graphs.

    For every layer the head-averaged attention matrix is thresholded into an
    adjacency matrix; graph-Laplacian eigenvalues and in-degree metrics are
    derived from it and aggregated across text samples.

    Optimized version of the attention Laplacian analysis that:
    1. Eliminates individual sample storage (avoids circular references)
    2. Directly aggregates metrics during analysis
    3. Reduces I/O operations
    4. Optimizes Laplacian computations
    """

    def __init__(self, model_name, output_dir="./attention_laplacian_analysis"):
        """Initialize the analysis with model name and output directory."""
        self.model_name = model_name
        self.output_dir = output_dir
        self.model = None
        self.tokenizer = None
        self.aggregated_results = {}

        self.sample_summary = []

        os.makedirs(output_dir, exist_ok=True)

        self.vis_dir = os.path.join(output_dir, "visualizations")
        os.makedirs(self.vis_dir, exist_ok=True)

        self.peak_memory = 0
        self.track_memory_usage = True

    def _track_memory(self, message=""):
        """Track GPU memory usage."""
        if not self.track_memory_usage or not torch.cuda.is_available():
            return

        current = torch.cuda.memory_allocated() / (1024 ** 3)  # GB
        peak = torch.cuda.max_memory_allocated() / (1024 ** 3)  # GB

        if peak > self.peak_memory:
            self.peak_memory = peak
            if message:
                print(f"Memory usage ({message}): Current {current:.2f} GB, Peak {peak:.2f} GB")

    def load_model(self, use_4bit=True):
        """
        Load the tokenizer and causal-LM weights named by ``self.model_name``.

        Args:
            use_4bit: When true, first attempt 4-bit quantized loading through
                ``BitsAndBytesConfig``. Any exception raised by that attempt is
                printed and the model is loaded without quantization instead.

        Returns:
            Tuple ``(model, tokenizer)``; both are also stored on ``self``.
        """
        print(f"Loading model: {self.model_name}")

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # Arguments common to every loading path. trust_remote_code=True
        # allows Python code shipped with the checkpoint to run while loading.
        shared_kwargs = {
            "torch_dtype": torch.float16,
            "device_map": "auto",
            "trust_remote_code": True,
        }

        def _from_pretrained(**extra_kwargs):
            return AutoModelForCausalLM.from_pretrained(
                self.model_name, **shared_kwargs, **extra_kwargs
            )

        if not use_4bit:
            self.model = _from_pretrained()
        else:
            try:
                from transformers import BitsAndBytesConfig

                self.model = _from_pretrained(
                    quantization_config=BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.float16
                    )
                )
            except Exception as e:
                # Deliberate best effort: any failure here falls back to the
                # unquantized load below.
                print(f"Error with 4-bit quantization: {e}")
                print("Falling back to standard loading...")
                self.model = _from_pretrained()

        # Release temporaries created during loading before measuring memory.
        gc.collect()
        torch.cuda.empty_cache()
        self._track_memory("After model loading")

        print(f"Model loaded: {self.model_name}")
        return self.model, self.tokenizer

    def compute_laplacian_eigenvalues(self, attention_matrix, threshold=0.01, verbose=False):
        """
        Compute Laplacian matrix eigenvalues from an attention matrix.

        Args:
            attention_matrix: Attention weight matrix (n x n)
            threshold: Threshold to binarize the attention weights
            verbose: Whether to print detailed diagnostics (default: False)

        Returns:
            Dictionary with eigenvalues and graph metrics
        """
        if attention_matrix is None or attention_matrix.size == 0:
            if verbose:
                print("Warning: Empty attention matrix")
            return {
                "eigenvalues": [],
                "first_eigenvalue": 0.0,
                "fiedler_value": 0.0,
                "spectral_gap": 0.0,
                "max_in_degree": 0,
                "star_likeness": 0.0,
                "degree_centralization": 0.0,  
                "degree_variance": 0.0,        
                "avg_shortest_path": None,
                "clustering_coefficient": None,
                "error": "Empty attention matrix"
            }

        # Ensure attention matrix has at least 2x2 dimensions for meaningful analysis
        if min(attention_matrix.shape) < 2:
            if verbose:
                print(f"Warning: Attention matrix too small: {attention_matrix.shape}")
            return {
                "eigenvalues": [0.0] * min(attention_matrix.shape),
                "first_eigenvalue": 0.0,
                "fiedler_value": 0.0,
                "spectral_gap": 0.0,
                "max_in_degree": 0,
                "star_likeness": 0.0,
                "degree_centralization": 0.0,
                "degree_variance": 0.0,
                "avg_shortest_path": None,
                "clustering_coefficient": None,
                "error": "Matrix too small for analysis"
            }

        try:
            # Binarize the attention matrix based on threshold to create adjacency matrix
            adjacency_matrix = (attention_matrix > threshold).astype(float)

            # Check if we have enough connections after thresholding
            connection_count = np.sum(adjacency_matrix)

            if connection_count < 2:
                if verbose:
                    print("Warning: Too few connections after thresholding")
                return {
                    "eigenvalues": [0.0] * min(attention_matrix.shape),
                    "first_eigenvalue": 0.0,
                    "fiedler_value": 0.0,
                    "spectral_gap": 0.0,
                    "max_in_degree": 0,
                    "star_likeness": 0.0,
                    "degree_centralization": 0.0,
                    "degree_variance": 0.0,
                    "avg_shortest_path": None,
                    "clustering_coefficient": None,
                    "error": "Too few connections after thresholding"
                }

            # OPTIMIZATION: Use sparse graph representation for large matrices
            use_sparse = attention_matrix.shape[0] > 50

            # Create a graph to analyze the structure - use sparse for large matrices
            if use_sparse:
                G = nx.DiGraph()
                rows, cols = np.where(adjacency_matrix > 0)
                edges = list(zip(rows.tolist(), cols.tolist()))
                G.add_edges_from(edges)
            else:
                G = nx.from_numpy_array(adjacency_matrix, create_using=nx.DiGraph)

            # Compute degree metrics
            degrees = np.array([d for _, d in G.degree()])
            in_degrees = np.array([d for _, d in G.in_degree()])
            out_degrees = np.array([d for _, d in G.out_degree()])

            # Compute star-likeness
            total_nodes = len(in_degrees)
            max_in_degree = np.max(in_degrees) if len(in_degrees) > 0 else 0

            basic_star_likeness = max_in_degree / max(1, total_nodes - 1)

            # 2. Gini coefficient of degree distribution 
            sorted_in_degrees = np.sort(in_degrees)
            if len(sorted_in_degrees) > 0 and np.sum(sorted_in_degrees) > 0:
                n = len(sorted_in_degrees)
                indices = np.arange(1, n+1)
                # Higher gini = more unequal distribution = more star-like
                gini = 2 * np.sum(indices * sorted_in_degrees) / (n * np.sum(sorted_in_degrees)) - (n + 1) / n
            else:
                gini = 0.0

            # 3. Degree centralization 
            max_possible_diff = (total_nodes - 1) * (total_nodes - 1)
            if max_possible_diff > 0 and max_in_degree > 0:
                sum_diff = sum(max_in_degree - d for d in in_degrees)
                degree_centralization = sum_diff / max_possible_diff
            else:
                degree_centralization = 0.0

            # 4. Degree variance 
            degree_variance = np.var(in_degrees) if len(in_degrees) > 1 else 0.0

            # Weight the gini coefficient more heavily as it's less affected by network size
            star_likeness = 0.3 * basic_star_likeness + 0.7 * gini

            # Symmetrized adjacency for a sensible Laplacian
            adj_sym = (adjacency_matrix + adjacency_matrix.T) / 2.0

            # OPTIMIZATION: Compute Laplacian eigenvalues more efficiently
            n = adjacency_matrix.shape[0]

            if n <= 50: 
                degree_matrix = np.diag(np.sum(adj_sym, axis=1))

                laplacian_matrix = degree_matrix - adj_sym

                eigenvalues = np.linalg.eigvalsh(laplacian_matrix)
                eigenvalues = np.sort(eigenvalues)  # Sort in ascending order
            else:  # Large matrix
                # Use sparse approach
                try:
                    adj_sparse = sp.csr_matrix(adj_sym)

                    # Compute degrees and create sparse degree matrix
                    degrees = np.array(adj_sparse.sum(axis=1)).flatten()
                    D = sp.diags(degrees)

                    # Laplacian matrix
                    L = D - adj_sparse

                    eigenvalues = spla.eigsh(L, k=min(3, n-1), which='SM', return_eigenvectors=False)
                    eigenvalues = np.sort(eigenvalues)
                except Exception:
                    # Fallback to dense method if sparse fails
                    degree_matrix = np.diag(np.sum(adj_sym, axis=1))
                    laplacian_matrix = degree_matrix - adj_sym
                    eigenvalues = np.linalg.eigvalsh(laplacian_matrix)
                    eigenvalues = np.sort(eigenvalues)[:3]  # Take only the smallest 3

            # Extract important eigenvalues 
            if len(eigenvalues) > 0:
                first_eigenvalue = float(eigenvalues[0])
            else:
                first_eigenvalue = 0.0

            # Second eigenvalue is the Fiedler value (algebraic connectivity)
            if len(eigenvalues) > 1:
                fiedler_value = float(eigenvalues[1])
            else:
                fiedler_value = 0.0

            # Spectral gap 
            if len(eigenvalues) > 2:
                spectral_gap = float(eigenvalues[2] - eigenvalues[1])
            else:
                spectral_gap = 0.0

            avg_shortest_path = None
            clustering_coef = None

            if n <= 30:  # Only for small graphs
                try:
                    # Use undirected graph for these metrics
                    if use_sparse:
                        # Create symmetric sparse adjacency
                        G_sym = nx.Graph()
                        rows, cols = np.where(adj_sym > 0)
                        edges = list(zip(rows.tolist(), cols.tolist()))
                        G_sym.add_edges_from(edges)
                    else:
                        G_sym = nx.from_numpy_array(adj_sym)

                    # Check if graph is connected before computing path lengths
                    if nx.is_connected(G_sym):
                        avg_shortest_path = nx.average_shortest_path_length(G_sym)
                    else:
                        # Fall back to largest connected component
                        components = list(nx.connected_components(G_sym))
                        if components:
                            largest_cc = max(components, key=len)
                            if len(largest_cc) > 1:  # Need at least 2 nodes for path length
                                largest_cc_graph = G_sym.subgraph(largest_cc)
                                avg_shortest_path = nx.average_shortest_path_length(largest_cc_graph)

                    # Compute clustering 
                    clustering_coef = nx.average_clustering(G_sym)

                except Exception as e:
                    if verbose:
                        print(f"Error computing graph metrics: {str(e)}")

            return {
                "eigenvalues": eigenvalues.tolist()[:3],  # Only store first 3 eigenvalues
                "first_eigenvalue": float(first_eigenvalue),
                "fiedler_value": float(fiedler_value),
                "spectral_gap": float(spectral_gap),
                "max_in_degree": int(max_in_degree),
                "star_likeness": float(star_likeness),
                "basic_star_likeness": float(basic_star_likeness),  
                "gini_coefficient": float(gini),                    
                "degree_centralization": float(degree_centralization),  
                "degree_variance": float(degree_variance),         
                "avg_shortest_path": float(avg_shortest_path) if avg_shortest_path is not None else None,
                "clustering_coefficient": float(clustering_coef) if clustering_coef is not None else None,
                "connections": int(connection_count),
                "total_possible_connections": adjacency_matrix.size,
                "connection_density": float(connection_count/adjacency_matrix.size)
            }

        except Exception as e:
            if verbose:
                print(f"Error in compute_laplacian_eigenvalues: {str(e)}")
                traceback.print_exc()
            return {
                "eigenvalues": [],
                "first_eigenvalue": 0.0,
                "fiedler_value": 0.0,
                "spectral_gap": 0.0,
                "max_in_degree": 0,
                "star_likeness": 0.0,
                "degree_centralization": 0.0,
                "degree_variance": 0.0,
                "avg_shortest_path": None,
                "clustering_coefficient": None,
                "error": str(e)
            }

    def _init_aggregated_data(self, layer_indices, thresholds):
        """Initialize aggregated data structures."""
        # Create data structure for aggregation
        aggregated = {
            "layer_stats": {layer_idx: {} for layer_idx in layer_indices},
            "fiedler_value_trends": {threshold: [] for threshold in thresholds},
            "spectral_gap_trends": {threshold: [] for threshold in thresholds},
            "star_likeness_trends": {threshold: [] for threshold in thresholds},
            "gini_coefficient_trends": {threshold: [] for threshold in thresholds},
            "degree_centralization_trends": {threshold: [] for threshold in thresholds},
            "degree_variance_trends": {threshold: [] for threshold in thresholds},
            "threshold_effectiveness": {threshold: {"effective_layers": []} for threshold in thresholds},
            "model": self.model_name,
            "thresholds": thresholds,
            "samples_analyzed": 0
        }

        metrics_collectors = {}
        for layer_idx in layer_indices:
            metrics_collectors[layer_idx] = {}
            for threshold in thresholds:
                metrics_collectors[layer_idx][threshold] = {
                    "fiedler_values": [],
                    "spectral_gaps": [],
                    "star_likeness": [],
                    "gini_coefficients": [],
                    "degree_centralizations": [],
                    "degree_variances": [],
                    "avg_shortest_paths": [],
                    "clustering_coefficients": [],
                    "connection_densities": []
                }

        return aggregated, metrics_collectors

    def _update_aggregated_metrics(self, metrics_collectors, layer_idx, threshold, results):
        """Update the metrics collectors with results from a single analysis."""
        collectors = metrics_collectors[layer_idx][threshold]

        # Update collectors if the metric exists and is valid
        for metric_name, collector_name in [
            ("fiedler_value", "fiedler_values"),
            ("spectral_gap", "spectral_gaps"),
            ("star_likeness", "star_likeness"),
            ("gini_coefficient", "gini_coefficients"),
            ("degree_centralization", "degree_centralizations"),
            ("degree_variance", "degree_variances"),
            ("avg_shortest_path", "avg_shortest_paths"),
            ("clustering_coefficient", "clustering_coefficients"),
            ("connection_density", "connection_densities")
        ]:
            if metric_name in results and results[metric_name] is not None:
                collectors[collector_name].append(results[metric_name])

    def _compute_aggregated_stats(self, metrics_collectors, aggregated):
        """Compute final statistics from the collected metrics."""
        # For each layer and threshold, compute statistics
        for layer_idx, layer_data in metrics_collectors.items():
            layer_stats = {}

            for threshold, threshold_metrics in layer_data.items():
                threshold_stats = {}

                # For each metric, compute statistics
                for metric_name, metric_values in threshold_metrics.items():
                    if metric_values:
                        threshold_stats[metric_name] = {
                            "mean": float(np.mean(metric_values)),
                            "std": float(np.std(metric_values)) if len(metric_values) > 1 else 0.0,
                            "min": float(np.min(metric_values)) if metric_values else None,
                            "max": float(np.max(metric_values)) if metric_values else None,
                            "count": len(metric_values)
                        }

                layer_stats[str(threshold)] = threshold_stats

                for metric_name, trend_key in [
                    ("fiedler_values", "fiedler_value_trends"),
                    ("spectral_gaps", "spectral_gap_trends"),
                    ("star_likeness", "star_likeness_trends"),
                    ("gini_coefficients", "gini_coefficient_trends"),
                    ("degree_centralizations", "degree_centralization_trends"),
                    ("degree_variances", "degree_variance_trends")
                ]:
                    values = threshold_metrics[metric_name]
                    if values:
                        aggregated[trend_key][threshold].append({
                            "layer": layer_idx,
                            "value": float(np.mean(values)),
                            "std": float(np.std(values)) if len(values) > 1 else 0.0
                        })

                # Check threshold effectiveness
                connection_densities = threshold_metrics["connection_densities"]
                if connection_densities:
                    avg_density = np.mean(connection_densities)
                    if 0.05 <= avg_density <= 0.95:
                        if layer_idx not in aggregated["threshold_effectiveness"][threshold]["effective_layers"]:
                            aggregated["threshold_effectiveness"][threshold]["effective_layers"].append(layer_idx)

            # Add layer stats to aggregated data
            aggregated["layer_stats"][layer_idx] = layer_stats

        # Calculate threshold effectiveness scores
        total_layers = len(aggregated["layer_stats"])
        for threshold in aggregated["threshold_effectiveness"]:
            effective_layers = aggregated["threshold_effectiveness"][threshold]["effective_layers"]
            effectiveness = len(effective_layers) / total_layers if total_layers > 0 else 0
            aggregated["threshold_effectiveness"][threshold]["effectiveness_score"] = effectiveness

        return aggregated

    def analyze_samples_aggregate(self, texts, thresholds=[0.01, 0.05, 0.1, 0.2],
                                 batch_size=10, sample_limit=None):
        """
        Analyze multiple text samples directly into aggregated results without
        storing each individual sample separately.

        Args:
            texts: List of text samples to analyze
            thresholds: List of thresholds for attention binarization
            batch_size: Number of samples to process in each batch
            sample_limit: Maximum number of samples to process (None=all)

        Returns:
            Dictionary with aggregated results
        """
        print(f"Analyzing up to {len(texts)} text samples (limit: {sample_limit or 'None'})...")

        if self.model is None:
            self.load_model()

        if sample_limit is not None:
            texts = texts[:sample_limit]

        print("Processing first sample to initialize aggregation...")
        first_result = self._analyze_single_sample(texts[0], thresholds, get_layer_info=True)
        layer_indices = first_result["layer_indices"]
        print(f"Model has {len(layer_indices)} layers")

        aggregated, metrics_collectors = self._init_aggregated_data(layer_indices, thresholds)

        # Add the first sample's data to the metrics collectors
        for layer_idx in layer_indices:
            for threshold in thresholds:
                layer_results = first_result["layer_results"].get(layer_idx, {})
                threshold_results = layer_results.get(threshold, {})
                if threshold_results:  # Skip if no results for this threshold
                    self._update_aggregated_metrics(
                        metrics_collectors, layer_idx, threshold, threshold_results
                    )

        # Process the remaining samples in batches
        remaining_texts = texts[1:]
        total_samples = 1  # Already processed 1

        # Process in batches
        for batch_start in range(0, len(remaining_texts), batch_size):
            batch_end = min(batch_start + batch_size, len(remaining_texts))
            batch = remaining_texts[batch_start:batch_end]

            print(f"\nProcessing batch {batch_start//batch_size + 1}: samples {batch_start+1+1} to {batch_end+1}")

            # Process each text in the batch
            batch_results = []
            for i, text in enumerate(tqdm(batch, desc="Processing batch")):
                sample_idx = batch_start + i + 1 + 1  
                try:
                    result = self._analyze_single_sample(text, thresholds, get_layer_info=False)
                    batch_results.append((sample_idx, result))

                    # Add to summary
                    self.sample_summary.append({
                        "sample_id": sample_idx,
                        "text_preview": text[:50] + ("..." if len(text) > 50 else ""),
                        "token_count": result.get("token_count", 0),
                        "status": "success"
                    })

                    total_samples += 1

                except Exception as e:
                    print(f"Error analyzing text {sample_idx}: {str(e)}")
                    traceback.print_exc()

                    # Add to summary
                    self.sample_summary.append({
                        "sample_id": sample_idx,
                        "text_preview": text[:50] + ("..." if len(text) > 50 else ""),
                        "status": "error",
                        "error": str(e)
                    })

            # Update metrics collectors with batch results
            for sample_idx, result in batch_results:
                for layer_idx in layer_indices:
                    for threshold in thresholds:
                        layer_results = result["layer_results"].get(layer_idx, {})
                        threshold_results = layer_results.get(threshold, {})
                        if threshold_results:  # Skip if no results for this threshold
                            self._update_aggregated_metrics(
                                metrics_collectors, layer_idx, threshold, threshold_results
                            )

            # Force cleanup between batches
            torch.cuda.empty_cache()
            gc.collect()
            self._track_memory(f"After batch {batch_start//batch_size + 1}")

            # Save intermediate aggregated results periodically
            if (batch_start + batch_size) % (batch_size * 5) == 0 or batch_end == len(remaining_texts):
                print("Generating intermediate aggregated results...")
                interim_aggregated = self._compute_aggregated_stats(metrics_collectors,
                                                                    aggregated.copy())
                interim_aggregated["samples_analyzed"] = total_samples

                # Save interim results
                interim_file = os.path.join(self.output_dir, f"interim_results_{total_samples}_samples.json")
                with open(interim_file, 'w') as f:
                    json.dump(interim_aggregated, f, indent=2)
                print(f"Saved interim results after {total_samples} samples.")

        # Compute final aggregated statistics
        print("\nComputing final aggregated statistics...")
        aggregated = self._compute_aggregated_stats(metrics_collectors, aggregated)
        aggregated["samples_analyzed"] = total_samples

        # Save sample summary
        summary_file = os.path.join(self.output_dir, "sample_summary.json")
        with open(summary_file, 'w') as f:
            json.dump(self.sample_summary, f, indent=2)

        self.aggregated_results = aggregated

        # Save full aggregated results
        results_file = os.path.join(self.output_dir, "aggregated_results.json")
        with open(results_file, 'w') as f:
            json.dump(aggregated, f, indent=2)
        print(f"Saved aggregated results from {total_samples} samples to {results_file}")

        return aggregated

    def _analyze_single_sample(self, text, thresholds, get_layer_info=False):
        """
        Analyze a single text sample and extract metrics without storing full results.

        Args:
            text: Text to analyze
            thresholds: List of thresholds to use
            get_layer_info: Whether to extract and return layer indices (for first sample)

        Returns:
            Dictionary with minimal results and metrics
        """
        text_preview = text[:30] + ("..." if len(text) > 30 else "")

        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
        token_count = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = self.model(
                **inputs,
                output_attentions=True,
                return_dict=True
            )

        # Extract attention patterns
        attentions = outputs.attentions  # tuple of (layer, batch, head, seq_len, seq_len)

        # Get layer indices if needed
        layer_indices = list(range(len(attentions))) if get_layer_info else None

        layer_results = {}
        threshold_effects = {}

        # Process each layer
        for layer_idx, layer_attention in enumerate(attentions):
            # Each layer has shape (batch, head, seq_len, seq_len)
            layer_attention = layer_attention[0]  # shape: (head, seq_len, seq_len)

            # Compute average attention pattern across all heads
            avg_attention = layer_attention.mean(dim=0).cpu().numpy()

            # Store layer results
            layer_results[layer_idx] = {}

            # Analyze for each threshold
            for threshold in thresholds:
                # Check threshold effects
                adj_matrix = (avg_attention > threshold).astype(float)
                connection_density = np.sum(adj_matrix) / adj_matrix.size
                is_effective = 0.05 <= connection_density <= 0.95

                if threshold not in threshold_effects:
                    threshold_effects[threshold] = []

                threshold_effects[threshold].append({
                    "layer": layer_idx,
                    "connection_density": float(connection_density),
                    "effective": is_effective
                })

                # Compute eigenvalues and metrics - only if this is an effective threshold
                if is_effective or layer_idx % 4 == 0:  
                    layer_results[layer_idx][threshold] = self.compute_laplacian_eigenvalues(
                        avg_attention, threshold=threshold, verbose=False
                    )
                else:
                    # Just store the connection density for non-effective thresholds
                    layer_results[layer_idx][threshold] = {
                        "connection_density": float(connection_density)
                    }

        # Free memory
        del outputs, attentions, layer_attention, avg_attention
        torch.cuda.empty_cache()

        return {
            "token_count": token_count,
            "layer_indices": layer_indices,
            "layer_results": layer_results,
            "threshold_effects": threshold_effects
        }

    def visualize_results(self):
        """Create visualizations of the aggregated results."""
        if not self.aggregated_results:
            print("No aggregated results to visualize.")
            return

        print("Generating visualizations...")

        thresholds = self.aggregated_results["thresholds"]

        self._visualize_threshold_effectiveness()
        self._visualize_fiedler_value_trends()
        self._visualize_star_likeness_metrics()
        self._visualize_correlation_plots()
        self._visualize_fiedler_heatmap()
        self._visualize_layer_adjacency_comparison()

        print(f"Visualizations saved to {self.vis_dir}")

    def _visualize_threshold_effectiveness(self):
        """
        Save a bar chart of the effectiveness score of each threshold.

        Thresholds are shown in ascending order; thresholds without an entry
        in ``threshold_effectiveness`` are left out, and a missing score is
        drawn as 0. The figure is written to ``threshold_effectiveness.png``
        in the visualization directory.
        """
        plt.figure(figsize=(10, 6))

        results = self.aggregated_results
        bars = [
            (
                str(threshold),
                results["threshold_effectiveness"][threshold].get("effectiveness_score", 0),
            )
            for threshold in sorted(results["thresholds"])
            if threshold in results["threshold_effectiveness"]
        ]
        threshold_labels = [label for label, _ in bars]
        effectiveness_scores = [score for _, score in bars]

        plt.bar(threshold_labels, effectiveness_scores, color='skyblue')
        plt.title('Threshold Effectiveness (% of Effective Layers)')
        plt.xlabel('Threshold Value')
        plt.ylabel('Effectiveness Score')
        plt.ylim(0, 1.0)
        plt.grid(axis='y', linestyle='--', alpha=0.7)

        plt.savefig(os.path.join(self.vis_dir, 'threshold_effectiveness.png'))
        plt.close()

    def _visualize_fiedler_value_trends(self):
        """
        Save a per-threshold line plot of the mean Fiedler value by layer.

        One error-bar series (mean with standard deviation) is drawn for each
        threshold that has trend data. When no threshold has data, nothing is
        saved and a notice is printed. The figure is closed in either case.
        """
        plt.figure(figsize=(12, 8))

        series_drawn = 0
        for threshold in self.aggregated_results["thresholds"]:
            trend_data = self.aggregated_results["fiedler_value_trends"].get(threshold)
            if not trend_data:
                # Threshold absent from the trends, or no points recorded.
                continue

            points = sorted(trend_data, key=lambda point: point["layer"])
            layers, values, stds = (
                list(column)
                for column in zip(*((p["layer"], p["value"], p["std"]) for p in points))
            )

            plt.errorbar(
                layers, values, yerr=stds,
                marker='o', linestyle='-',
                label=f'Threshold = {threshold}'
            )
            series_drawn += 1

        if series_drawn == 0:
            print("No data available for Fiedler value trends plot")
        else:
            plt.title('Algebraic Connectivity (Fiedler Value) Across Layers')
            plt.xlabel('Layer')
            plt.ylabel('Fiedler Value')
            plt.grid(True, linestyle='--', alpha=0.7)
            plt.legend()
            plt.tight_layout()

            plt.savefig(os.path.join(self.vis_dir, 'fiedler_value_trends.png'))

        plt.close()

    def _visualize_star_likeness_metrics(self):
        """Plot the Gini-coefficient and star-likeness trends across layers.

        One figure is drawn per metric, containing one error-bar line (per-layer
        "value" with "std" as the error bar, ordered by layer) for every
        threshold that has a non-empty trend. A figure is written to
        ``self.vis_dir`` only when at least one line was drawn. A metric whose
        trends are missing from ``self.aggregated_results`` is skipped rather
        than raising, so both metrics are treated as optional.
        """
        # (results key, plot title, y-axis label, output file name).
        # The Gini coefficient is plotted first, matching the original order.
        metric_plots = (
            (
                "gini_coefficient_trends",
                'Gini Coefficient of Degree Distribution Across Layers',
                'Gini Coefficient (higher = more star-like)',
                'gini_coefficient_trends.png',
            ),
            (
                "star_likeness_trends",
                'Improved Star-likeness Metric Across Layers',
                'Star-likeness (higher = more star-like)',
                'star_likeness_trends.png',
            ),
        )

        thresholds = self.aggregated_results["thresholds"]

        for trend_key, title, ylabel, filename in metric_plots:
            trends = self.aggregated_results.get(trend_key)
            if trends is None:
                # Previously only the Gini trends were guarded; a missing
                # "star_likeness_trends" entry raised KeyError.
                continue

            plt.figure(figsize=(12, 8))

            has_data = False
            for threshold in thresholds:
                trend_data = trends.get(threshold)
                if not trend_data:
                    continue

                trend_data = sorted(trend_data, key=lambda item: item["layer"])

                plt.errorbar(
                    [item["layer"] for item in trend_data],
                    [item["value"] for item in trend_data],
                    yerr=[item["std"] for item in trend_data],
                    marker='o', linestyle='-',
                    label=f'Threshold = {threshold}'
                )
                has_data = True

            if has_data:
                plt.title(title)
                plt.xlabel('Layer')
                plt.ylabel(ylabel)
                plt.grid(True, linestyle='--', alpha=0.7)
                plt.legend()
                plt.tight_layout()

                plt.savefig(os.path.join(self.vis_dir, filename))

            # Close the figure whether or not anything was saved.
            plt.close()

    def _visualize_correlation_plots(self):
        """Scatter each structure metric against the Fiedler value and log correlations.

        For every metric whose trends are present in the aggregated results, one
        figure is drawn with a scatter series per threshold (metric value on x,
        Fiedler value on y, points labelled with their layer). Only layers that
        appear in both trends are used, and a threshold is skipped when either
        series is (numerically) constant. The best threshold additionally gets a
        dashed linear fit. All Pearson correlations are written to
        ``correlations.txt`` in ``self.vis_dir``.
        """
        results = self.aggregated_results

        # Best threshold = strictly highest positive effectiveness score; when
        # none qualifies, fall back to the first configured threshold.
        best_threshold = None
        top_score = 0
        for candidate in results["thresholds"]:
            if candidate not in results["threshold_effectiveness"]:
                continue
            score = results["threshold_effectiveness"][candidate].get("effectiveness_score", 0)
            if score > top_score:
                top_score, best_threshold = score, candidate

        if best_threshold is None and results["thresholds"]:
            best_threshold = results["thresholds"][0]

        # Metrics that are correlated with the Fiedler value, with display names.
        metric_specs = (
            ("gini_coefficient_trends", "Gini Coefficient"),
            ("star_likeness_trends", "Improved Star-likeness"),
            ("degree_centralization_trends", "Degree Centralization"),
            ("degree_variance_trends", "Degree Variance"),
        )

        correlation_rows = []

        for trend_key, metric_name in metric_specs:
            if trend_key not in results:
                continue

            plt.figure(figsize=(12, 8))
            plotted_any = False

            for threshold in results["thresholds"]:
                fiedler_trends = results["fiedler_value_trends"]
                metric_trends = results[trend_key]
                if not (threshold in fiedler_trends and threshold in metric_trends):
                    continue

                fiedler_records = fiedler_trends[threshold]
                metric_records = metric_trends[threshold]
                if not (fiedler_records and metric_records):
                    continue

                # Index both trends by layer so they can be paired up.
                fiedler_by_layer = {rec["layer"]: rec["value"] for rec in fiedler_records}
                metric_by_layer = {rec["layer"]: rec["value"] for rec in metric_records}

                shared_layers = sorted(fiedler_by_layer.keys() & metric_by_layer.keys())
                if not shared_layers:
                    continue

                ys = [fiedler_by_layer[layer] for layer in shared_layers]
                xs = [metric_by_layer[layer] for layer in shared_layers]

                # A constant series carries no correlation information.
                if np.std(xs) < 1e-10 or np.std(ys) < 1e-10:
                    continue

                plt.scatter(
                    xs, ys,
                    label=f'Threshold = {threshold}',
                    alpha=0.7, s=80
                )

                if threshold == best_threshold:
                    # Overlay a least-squares line for the best threshold only.
                    fit = np.poly1d(np.polyfit(xs, ys, 1))
                    grid = np.linspace(min(xs), max(xs), 100)
                    plt.plot(grid, fit(grid), '--', color='red', alpha=0.7)

                # Tag every point with its layer index.
                for x, y, layer in zip(xs, ys, shared_layers):
                    plt.annotate(
                        f"L{layer}",
                        (x, y),
                        xytext=(5, 5),
                        textcoords='offset points'
                    )

                if len(ys) > 1:
                    corr = np.corrcoef(xs, ys)[0, 1]
                    correlation_rows.append((threshold, metric_name, corr))

                plotted_any = True

            if plotted_any:
                plt.title(f'Correlation: {metric_name} vs. Algebraic Connectivity')
                plt.xlabel(f'{metric_name} (higher = more star-like)')
                plt.ylabel('Fiedler Value (Algebraic Connectivity)')
                plt.grid(True, linestyle='--', alpha=0.7)
                plt.legend()
                plt.tight_layout()

                metric_key = trend_key.replace("_trends", "")
                plt.savefig(os.path.join(self.vis_dir, f'{metric_key}_fiedler_correlation.png'))
            else:
                print(f"No data available for {metric_name} correlation plot")

            plt.close()

        # Persist the correlation table, ordered by threshold and then metric.
        with open(os.path.join(self.vis_dir, 'correlations.txt'), 'w') as f:
            f.writelines([
                "CORRELATIONS WITH FIEDLER VALUE\n",
                "====================================================\n\n",
                "{:<10} {:<25} {:<10}\n".format("Threshold", "Metric", "Correlation"),
                "-" * 50 + "\n",
            ])
            for threshold, metric, corr in sorted(correlation_rows, key=lambda row: (row[0], row[1])):
                f.write("{:<10} {:<25} {:<10.4f}\n".format(threshold, metric, corr))

    def _visualize_fiedler_heatmap(self):
        """Render a single-row heatmap of Fiedler values for the best threshold.

        The best threshold is the one with the strictly highest positive
        effectiveness score, falling back to the first configured threshold.
        Its per-layer Fiedler values (ordered by layer) are drawn as an
        annotated heatmap and saved as ``fiedler_value_heatmap.png`` in
        ``self.vis_dir``. A notice is printed when no data can be plotted.
        """
        results = self.aggregated_results

        best_threshold = None
        top_score = 0
        for candidate in results["thresholds"]:
            if candidate not in results["threshold_effectiveness"]:
                continue
            score = results["threshold_effectiveness"][candidate].get("effectiveness_score", 0)
            if score > top_score:
                top_score, best_threshold = score, candidate

        if best_threshold is None and results["thresholds"]:
            best_threshold = results["thresholds"][0]

        if best_threshold is None or best_threshold not in results["fiedler_value_trends"]:
            print("Could not determine best threshold for Fiedler value heatmap")
            plt.close()
            return

        records = results["fiedler_value_trends"][best_threshold]
        if not records:
            print("No data available for Fiedler value heatmap")
            plt.close()
            return

        ordered = sorted(records, key=lambda rec: rec["layer"])
        layer_ticks = np.array([rec["layer"] for rec in ordered])
        fiedler_values = np.array([rec["value"] for rec in ordered])

        plt.figure(figsize=(14, 4))
        sns.heatmap(
            # The values are laid out as a single heatmap row.
            fiedler_values.reshape(1, -1),
            cmap='viridis',
            annot=True,
            fmt=".3f",
            xticklabels=layer_ticks,
            yticklabels=["Fiedler Value"]
        )

        plt.title(f'Algebraic Connectivity Across Layers (Threshold = {best_threshold})')
        plt.xlabel('Layer')
        plt.tight_layout()

        plt.savefig(os.path.join(self.vis_dir, 'fiedler_value_heatmap.png'))
        plt.close()

    def _visualize_layer_adjacency_comparison(self):
        """Create a comparison of adjacency matrices from early, middle, and late layers."""
        if not self.sample_summary:
            print("No sample summary available for adjacency comparison")
            return

        if self.model is None:
            print("Model not loaded, skipping layer adjacency comparison")
            return

        success_samples = [s for s in self.sample_summary if s["status"] == "success"]
        if not success_samples:
            print("No successful samples found for adjacency comparison")
            return

        # Use the first successful sample
        sample_text = None
        for sample in success_samples:
            preview = sample["text_preview"]
            if preview.endswith("..."):
                sample_text = preview[:-3] 
            else:
                sample_text = preview

            if len(sample_text) > 10:  # Ensure we have enough text
                break

        if not sample_text:
            print("No suitable sample text found for adjacency comparison")
            return

        # Get layer indices from aggregated results
        layer_indices = sorted([int(idx) for idx in self.aggregated_results["layer_stats"].keys()])

        if len(layer_indices) < 3:
            print("Not enough layers for comparison")
            return

        # Select representative layers
        early_idx = layer_indices[0]
        middle_idx = layer_indices[len(layer_indices) // 2]
        late_idx = layer_indices[-1]

        selected_layers = [
            ("Early", early_idx),
            ("Middle", middle_idx),
            ("Late", late_idx)
        ]

        # Find the most effective threshold to use
        thresholds = self.aggregated_results["thresholds"]
        best_threshold = None
        max_effectiveness = 0

        if "threshold_effectiveness" in self.aggregated_results:
            for threshold in thresholds:
                if threshold in self.aggregated_results["threshold_effectiveness"]:
                    effectiveness = self.aggregated_results["threshold_effectiveness"][threshold].get("effectiveness_score", 0)
                    if effectiveness > max_effectiveness:
                        max_effectiveness = effectiveness
                        best_threshold = threshold

        if best_threshold is None and thresholds:
            best_threshold = thresholds[0]  # Fallback

        print(f"Using threshold {best_threshold} for layer comparison visualization")

        plt.figure(figsize=(15, 5))

        for i, (stage, layer_idx) in enumerate(selected_layers):
            plt.subplot(1, 3, i+1)

            # Get properties from aggregated results
            layer_data = self.aggregated_results["layer_stats"].get(layer_idx, {})
            threshold_data = layer_data.get(str(best_threshold), {})

            fiedler_data = threshold_data.get("fiedler_values", {})
            star_data = None

            # Try to get the best star-likeness metric available
            for metric_name in ["gini_coefficients", "star_likeness", "degree_centralizations"]:
                if metric_name in threshold_data:
                    star_data = threshold_data[metric_name]
                    star_metric_name = metric_name.replace("s", "")  # Remove plural
                    break

            if star_data is None:
                star_metric_name = "star_likeness"
                star_data = threshold_data.get("star_likeness", {})

            fiedler_value = fiedler_data.get("mean", "N/A") if isinstance(fiedler_data, dict) else "N/A"
            star_likeness = star_data.get("mean", "N/A") if isinstance(star_data, dict) else "N/A"

            # Generate a new adjacency matrix for visualization
            try:
                inputs = self.tokenizer(sample_text, return_tensors="pt").to(self.model.device)

                with torch.no_grad():
                    outputs = self.model(
                        **inputs,
                        output_attentions=True,
                        return_dict=True
                    )

                # Check if we have enough layers
                if layer_idx >= len(outputs.attentions):
                    plt.text(0.5, 0.5, f"Layer {layer_idx} exceeds model layers",
                             ha='center', va='center', transform=plt.gca().transAxes)
                    continue

                layer_attention = outputs.attentions[layer_idx][0]  # (head, seq, seq)
                avg_attention = layer_attention.mean(dim=0).cpu().numpy()

                # Create binary adjacency for visualization
                adjacency_matrix = (avg_attention > best_threshold).astype(float)

                plt.imshow(adjacency_matrix, cmap='Blues')
                plt.colorbar(label='Connection')

                title_parts = [f'{stage} Layer (L{layer_idx})']
                if fiedler_value != "N/A":
                    title_parts.append(f"Fiedler={fiedler_value:.3f}")
                if star_likeness != "N/A":
                    title_parts.append(f"{star_metric_name.capitalize()}={star_likeness:.3f}")

                plt.title('\n'.join(title_parts))

            except Exception as e:
                print(f"Error visualizing layer {layer_idx}: {e}")
                plt.text(0.5, 0.5, f"Visualization failed: {str(e)[:30]}...",
                         ha='center', va='center', transform=plt.gca().transAxes)

            plt.axis('on')

        plt.tight_layout()

        plot_path = os.path.join(self.vis_dir, 'layer_adjacency_comparison.png')
        plt.savefig(plot_path)
        print(f"Saved layer comparison to {plot_path}")
        plt.close()

    def generate_report(self):
        """Generate a text report summarizing the findings."""
        if not self.aggregated_results:
            print("No aggregated results to report.")
            return None

        print("Generating comprehensive report...")

        report_path = os.path.join(self.output_dir, "laplacian_eigenvalue_report.txt")

        with open(report_path, 'w') as f:
            f.write("ATTENTION GRAPH LAPLACIAN EIGENVALUE ANALYSIS\n")
            f.write("===========================================\n\n")

            f.write(f"Model: {self.model_name}\n")
            f.write(f"Analysis date: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")

            n_samples = self.aggregated_results.get("samples_analyzed", 0)
            f.write(f"Number of samples analyzed: {n_samples}\n\n")

            thresholds = self.aggregated_results["thresholds"]
            f.write("Thresholds used for attention binarization:\n")
            for threshold in thresholds:
                f.write(f"  - {threshold}\n")
            f.write("\n")

            f.write("THRESHOLD EFFECTIVENESS ANALYSIS\n")
            f.write("-------------------------------\n\n")

            best_threshold = None
            max_effectiveness = 0

            if "threshold_effectiveness" in self.aggregated_results:
                for threshold in thresholds:
                    if threshold in self.aggregated_results["threshold_effectiveness"]:
                        effectiveness = self.aggregated_results["threshold_effectiveness"][threshold].get("effectiveness_score", 0)
                        effective_layers = len(self.aggregated_results["threshold_effectiveness"][threshold].get("effective_layers", []))
                        total_layers = len(self.aggregated_results["layer_stats"])

                        f.write(f"Threshold {threshold}:\n")
                        f.write(f"  Effectiveness score: {effectiveness:.2f}\n")
                        f.write(f"  Effective on {effective_layers}/{total_layers} layers\n")

                        if effectiveness > max_effectiveness:
                            max_effectiveness = effectiveness
                            best_threshold = threshold

                if best_threshold is not None:
                    f.write(f"\nMost effective threshold: {best_threshold} (score: {max_effectiveness:.2f})\n")
                    f.write(f"This threshold creates meaningful graph structures across the highest\n")
                    f.write(f"percentage of layers, with connection densities between 5% and 95%.\n")

            f.write("\n")

            f.write("ALGEBRAIC CONNECTIVITY ANALYSIS\n")
            f.write("------------------------------\n\n")

            for threshold in thresholds:
                if threshold not in self.aggregated_results["fiedler_value_trends"]:
                    continue

                fiedler_trend = self.aggregated_results["fiedler_value_trends"][threshold]

                if not fiedler_trend:
                    continue

                f.write(f"Threshold {threshold}:\n")

                fiedler_trend = sorted(fiedler_trend, key=lambda x: x["layer"])

                max_fiedler_idx = np.argmax([item["value"] for item in fiedler_trend])
                max_fiedler_layer = fiedler_trend[max_fiedler_idx]["layer"]
                max_fiedler_value = fiedler_trend[max_fiedler_idx]["value"]

                f.write(f"  Maximum algebraic connectivity at layer {max_fiedler_layer}: {max_fiedler_value:.4f}\n")

                # Check if there's a peak in the middle
                n_layers = len(fiedler_trend)
                if n_layers >= 5:  # Need enough layers to meaningfully define "middle"
                    early_layers = fiedler_trend[:n_layers//3]
                    middle_layers = fiedler_trend[n_layers//3:2*n_layers//3]
                    late_layers = fiedler_trend[2*n_layers//3:]

                    early_avg = np.mean([item["value"] for item in early_layers])
                    middle_avg = np.mean([item["value"] for item in middle_layers])
                    late_avg = np.mean([item["value"] for item in late_layers])

                    f.write(f"  Average Fiedler values:\n")
                    f.write(f"    Early layers: {early_avg:.4f}\n")
                    f.write(f"    Middle layers: {middle_avg:.4f}\n")
                    f.write(f"    Late layers: {late_avg:.4f}\n")

                    if middle_avg > early_avg and middle_avg > late_avg:
                        peak_factor = min(middle_avg/early_avg, middle_avg/late_avg)
                        f.write(f"  ✓ CONFIRMED: Middle layers show peak algebraic connectivity ({peak_factor:.2f}x higher)\n")
                    elif middle_avg > early_avg:
                        f.write(f"  ± PARTIAL: Middle layers have higher connectivity than early layers, but not late layers\n")
                    elif middle_avg > late_avg:
                        f.write(f"  ± PARTIAL: Middle layers have higher connectivity than late layers, but not early layers\n")
                    else:
                        f.write(f"  ✗ NOT CONFIRMED: Middle layers do not show peak algebraic connectivity\n")

                f.write("\n")

            f.write("GRAPH STRUCTURE ANALYSIS\n")
            f.write("---------------------------\n\n")

            star_metrics = [
                ("star_likeness_trends", "Star-likeness"),
                ("gini_coefficient_trends", "Gini coefficient"),
                ("degree_centralization_trends", "Degree centralization"),
                ("degree_variance_trends", "Degree variance")
            ]

            for trend_key, metric_name in star_metrics:
                if trend_key not in self.aggregated_results:
                    continue

                f.write(f"{metric_name.upper()} ANALYSIS\n")
                f.write("-" * len(f"{metric_name.upper()} ANALYSIS") + "\n\n")

                for threshold in thresholds:
                    if threshold not in self.aggregated_results[trend_key]:
                        continue

                    trend_data = self.aggregated_results[trend_key][threshold]

                    if not trend_data:
                        continue

                    f.write(f"Threshold {threshold}:\n")

                    trend_data = sorted(trend_data, key=lambda x: x["layer"])

                    # Find the layer with maximum value
                    max_idx = np.argmax([item["value"] for item in trend_data])
                    max_layer = trend_data[max_idx]["layer"]
                    max_value = trend_data[max_idx]["value"]

                    f.write(f"  Maximum {metric_name} at layer {max_layer}: {max_value:.4f}\n")

                    # Check for variation across layers
                    values = [item["value"] for item in trend_data]
                    if np.std(values) < 1e-10:
                        f.write(f"  ⚠ WARNING: {metric_name} values show almost no variation across layers.\n")
                        f.write(f"    This may indicate an issue with the analysis or the threshold value.\n")

                    # Check if there's a peak in the middle
                    n_layers = len(trend_data)
                    if n_layers >= 5:  # Need enough layers
                        early_layers = trend_data[:n_layers//3]
                        middle_layers = trend_data[n_layers//3:2*n_layers//3]
                        late_layers = trend_data[2*n_layers//3:]

                        early_avg = np.mean([item["value"] for item in early_layers])
                        middle_avg = np.mean([item["value"] for item in middle_layers])
                        late_avg = np.mean([item["value"] for item in late_layers])

                        f.write(f"  Average {metric_name}:\n")
                        f.write(f"    Early layers: {early_avg:.4f}\n")
                        f.write(f"    Middle layers: {middle_avg:.4f}\n")
                        f.write(f"    Late layers: {late_avg:.4f}\n")

                        # Check if middle layers have higher values
                        if middle_avg > early_avg and middle_avg > late_avg:
                            peak_factor = min(middle_avg/early_avg, middle_avg/late_avg)
                            f.write(f"  ✓ CONFIRMED: Middle layers show peak {metric_name} ({peak_factor:.2f}x higher)\n")
                        elif middle_avg > early_avg:
                            f.write(f"  ± PARTIAL: Middle layers have higher {metric_name} than early layers, but not late layers\n")
                        elif middle_avg > late_avg:
                            f.write(f"  ± PARTIAL: Middle layers have higher {metric_name} than late layers, but not early layers\n")
                        else:
                            f.write(f"  ✗ NOT CONFIRMED: Middle layers do not show peak {metric_name}\n")

                    f.write("\n")

                f.write("\n")

            f.write("CORRELATION ANALYSIS\n")
            f.write("-------------------\n\n")

            corr_file = os.path.join(self.vis_dir, 'correlations.txt')
            if os.path.exists(corr_file):
                with open(corr_file, 'r') as corr_f:
                    correlations = corr_f.read()
                f.write(correlations)
                f.write("\n")
            else:
                f.write("No correlation data available.\n\n")

            f.write("SUMMARY OF FINDINGS\n")
            f.write("-----------------\n\n")

            f.write("\nADDITIONAL NOTES\n")
            f.write("---------------\n")
            f.write("1. Threshold selection is critical for this analysis. The results suggest that\n")
            f.write("   values around 0.01-0.05 provide the most meaningful graph structures for\n")
            f.write("   the LLaMA model analyzed.\n\n")

            f.write("2. The improved star-likeness metrics (particularly the Gini coefficient of\n")
            f.write("   degree distribution) provide more stable measurements of attention graph\n")
            f.write("   structure than the basic star-likeness ratio used in earlier analyses.\n\n")

            f.write("\nPERFORMANCE METRICS\n")
            f.write("------------------\n")
            f.write(f"Samples processed: {n_samples}\n")
            f.write(f"Peak memory usage: {self.peak_memory:.2f} GB\n")

            f.write("\nVISUALIZATIONS\n")
            f.write("---------------\n")
            f.write("The following visualizations are available in the 'visualizations' directory:\n")
            f.write("1. threshold_effectiveness.png - Comparison of threshold effectiveness\n")
            f.write("2. fiedler_value_trends.png - Algebraic connectivity across layers\n")
            f.write("3. gini_coefficient_trends.png - Gini coefficient of degree distribution\n")
            f.write("4. star_likeness_trends.png - Improved star-likeness metric\n")
            f.write("5. gini_coefficient_fiedler_correlation.png - Correlation: Gini vs. Fiedler\n")
            f.write("6. star_likeness_fiedler_correlation.png - Correlation: Star-likeness vs. Fiedler\n")
            f.write("7. fiedler_value_heatmap.png - Heatmap of Fiedler values\n")
            f.write("8. layer_adjacency_comparison.png - Attention adjacency matrices comparison\n")

        print(f"Report generated at: {report_path}")
        return report_path

    def run_analysis(self, texts=None, thresholds=[0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2],
                    batch_size=10, sample_limit=None):
        """
        Run the complete optimized analysis pipeline.

        Args:
            texts: List of text samples to analyze (if None, example texts are used)
            thresholds: List of thresholds for attention binarization
            batch_size: Number of samples to process in each batch
            sample_limit: Maximum number of samples to process (None=all)

        Returns:
            Dictionary with analysis results
        """
        start_time = time.time()

        if self.model is None:
            self.load_model()

        if texts is None or not texts:
            texts = [
                "The concept of eigenvalues relates to how transformations affect spaces.",
                "Algebraic connectivity measures how well-connected a graph is.",
                "Star graphs have one central node connected to all other nodes.",
                "In transformers, attention mechanisms create dynamic connections between tokens."
            ]

        self.analyze_samples_aggregate(
            texts=texts,
            thresholds=thresholds,
            batch_size=batch_size,
            sample_limit=sample_limit
        )

        self.visualize_results()

        self.generate_report()

        total_time = time.time() - start_time
        print(f"Analysis completed in {total_time/60:.2f} minutes")

        return self.aggregated_results

    def cleanup(self):
        """Clean up resources."""
        if self.model is not None:
            self.model = self.model.to("cpu")
            del self.model
            self.model = None

        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None

        gc.collect()
        torch.cuda.empty_cache()

        print("Resources cleaned up")


def get_sample_texts_from_dataset(dataset_path, n_samples=50, random_state=None):
    """
    Extract sample texts from a dataset for analysis.

    Rows whose 'text' cell is missing or blank are discarded before sampling,
    so every returned item is a non-empty string and missing cells do not use
    up sample slots.

    Args:
        dataset_path: Path to the dataset CSV
        n_samples: Maximum number of samples to extract
        random_state: Optional seed passed to pandas' sampler to make the
            selection reproducible (default None = different each call)

    Returns:
        List of text samples; an empty list when the file cannot be read, has
        no 'text' column, or any other error occurs (the error is printed)
    """
    try:
        data = pd.read_csv(dataset_path)
        print(f"Loaded dataset with {len(data)} rows")

        if 'text' not in data.columns:
            print("Error: No 'text' column found in dataset")
            return []

        # Missing cells are read as NaN (a float); drop them, then normalise
        # the remainder to str and discard whitespace-only entries.
        texts = data['text'].dropna().astype(str)
        texts = texts[texts.str.strip() != ""]

        skipped = len(data) - len(texts)
        if skipped:
            print(f"Skipped {skipped} rows with missing or empty text")

        # Sample only when there are more usable rows than requested.
        if len(texts) > n_samples:
            texts = texts.sample(n_samples, random_state=random_state)

        return texts.tolist()

    except Exception as e:
        # Best effort by design: callers fall back to built-in example texts.
        print(f"Error loading dataset: {e}")
        return []



In [None]:
def main():
    """Run the optimized Laplacian eigenvalue analysis pipeline end to end.

    Mounts Google Drive, samples up to 500 texts from the dataset CSV (using
    four built-in sentences when none could be loaded), runs the analysis and
    always cleans up the analyzer afterwards, even when the analysis fails.
    """
    drive.mount('/content/drive')

    model_name = "mistralai/Mistral-7B-v0.1" 
    dataset_path = "/content/drive/MyDrive/wiki_dataset.csv" 
    # Results go into a folder named after the part of the model id following the last "/".
    short_model_name = model_name.split('/')[-1]
    output_dir = f"/content/drive/MyDrive/Sink/laplacian/{short_model_name}"

    analyzer = OptimizedAttentionLaplacianAnalysis(
        model_name=model_name,
        output_dir=output_dir
    )

    # An empty result (load failure or missing column) triggers the built-in examples.
    texts = get_sample_texts_from_dataset(dataset_path, n_samples=500) or [
        "The concept of eigenvalues relates to how transformations affect spaces.",
        "Algebraic connectivity measures how well-connected a graph is.",
        "Star graphs have one central node connected to all other nodes.",
        "In transformers, attention mechanisms create dynamic connections between tokens."
    ]

    try:
        analyzer.run_analysis(
            texts=texts,
            thresholds=[0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2],
            batch_size=10,
            sample_limit=500
        )
    except Exception as e:
        print(f"Analysis error: {e}")
        traceback.print_exc()
    else:
        print("Analysis completed successfully!")
    finally:
        analyzer.cleanup()


# Entry point: run the pipeline only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()