In [None]:
from huggingface_hub import notebook_login

# Authenticate this notebook session with the Hugging Face Hub (interactive
# prompt). Presumably required so later from_pretrained() calls can fetch
# gated/private checkpoints — confirm for the model actually analyzed.
notebook_login()

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
from ripser import ripser
import persim
from sklearn.manifold import TSNE
import os
import gc
from google.colab import drive
import json
import pandas as pd
import time
from collections import defaultdict
import seaborn as sns
from matplotlib.ticker import MaxNLocator
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import logging
from tqdm import tqdm
import warnings
import h5py
import psutil
from typing import Dict, List, Tuple, Any, Optional, Union
from scipy import stats

In [None]:
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("transformer_topology")

# Filter common warnings. One "ignore" filter per category is sufficient:
# warnings.simplefilter(action='ignore', category=X) would install the same
# entry as the filterwarnings call below, so it is not repeated.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


class AggregateTopologyAnalyzer:
    """
    Class to analyze transformer attention topology and collect aggregate statistics.
    Instead of generating per-sample analysis files, this aggregates metrics across samples.
    """

    def __init__(self, model_name, max_tokens=48, device=None, cache_dir=None):
        """
        Configure the analyzer; the model and tokenizer are loaded lazily later.

        Args:
            model_name: Name or path of the transformers model to analyze.
            max_tokens: Inputs are truncated to at most this many tokens.
            device: Device string to run on (e.g. 'cuda' or 'cpu'). When falsy,
                CUDA is chosen if available, otherwise CPU.
            cache_dir: Optional directory for cached model files.
        """
        self.model_name = model_name
        self.max_tokens = max_tokens
        self.cache_dir = cache_dir
        self.tokenizer = None
        self.model = None

        # Resolve the target device. An explicitly empty CUDA_VISIBLE_DEVICES
        # hides every GPU, so it forces CPU regardless of the argument.
        if device:
            resolved_device = device
        elif torch.cuda.is_available():
            resolved_device = "cuda"
        else:
            resolved_device = "cpu"
        if os.environ.get("CUDA_VISIBLE_DEVICES") == "":
            resolved_device = "cpu"
        self.device = resolved_device

        # Memory usage in GB, refreshed by _update_memory_stats()
        self.memory_tracker = {'peak': 0, 'current': 0}

        self.sample_count = 0

        # Per-layer aggregates, keyed by layer index
        self.betti_numbers = defaultdict(lambda: {'dim0': [], 'dim1': []})
        self.attention_stats = defaultdict(lambda: {'mean': [], 'max': [], 'std': []})
        self.persistence_stats = defaultdict(
            lambda: {'dim0_avg': [], 'dim0_max': [], 'dim1_avg': [], 'dim1_max': []}
        )
        # layer -> head -> list of per-sample top tokens
        self.head_specialization = defaultdict(lambda: defaultdict(list))
        # layer -> head -> {token: occurrence count}
        self.token_focus = defaultdict(lambda: defaultdict(dict))

        logger.info(f"Initialized AggregateTopologyAnalyzer for model: {model_name} on {self.device}")

    def _ensure_float32(self, array):
        """
        Ensure array is in float32 format for consistent processing.
        
        Args:
            array: Numpy array that might need conversion
        
        Returns:
            array in float32 format
        """
        if array.dtype != np.float32:
            return array.astype(np.float32)
        return array

    def _update_memory_stats(self):
        """
        Refresh ``self.memory_tracker`` with current and peak memory usage in GB.

        On a non-CPU device with CUDA available, this reads PyTorch's allocator
        counters. Otherwise it uses this process's resident set size, so the peak
        there is the highest RSS seen across calls.
        """
        bytes_per_gb = 1024 ** 3
        on_gpu = torch.cuda.is_available() and self.device != "cpu"

        if on_gpu:
            current = torch.cuda.memory_allocated() / bytes_per_gb
            observed_peak = torch.cuda.max_memory_allocated() / bytes_per_gb
        else:
            current = psutil.Process(os.getpid()).memory_info().rss / bytes_per_gb
            observed_peak = current

        self.memory_tracker['current'] = current
        self.memory_tracker['peak'] = max(self.memory_tracker['peak'], observed_peak)

    def load_model_and_tokenizer(self):
        """
        Lazily load the tokenizer and model, reusing any already-loaded instance.

        The model is loaded in float32 on ``self.device`` with attention outputs
        enabled in its config.

        Returns:
            tuple: ``(tokenizer, model)``.
        """
        if self.tokenizer is None:
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, cache_dir=self.cache_dir)

        if self.model is None:
            logger.info(f"Loading model in {self.device} mode...")
            model_kwargs = dict(
                device_map=self.device,
                output_attentions=True,
                cache_dir=self.cache_dir,
                torch_dtype=torch.float32,
            )
            self.model = AutoModelForCausalLM.from_pretrained(self.model_name, **model_kwargs)

        return self.tokenizer, self.model

    def process_text(self, text):
        """
        Tokenize ``text``, run the model once and return per-layer attention maps.

        Args:
            text: Text to analyze; truncated to ``self.max_tokens`` tokens.

        Returns:
            dict: ``{"tokens": [...], "attention_maps": [np.ndarray per layer]}``.
            The attention arrays are copied to host memory.
        """
        if self.tokenizer is None or self.model is None:
            self.load_model_and_tokenizer()

        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=self.max_tokens,
            truncation=True
        )

        on_accelerator = self.device != "cpu"
        if on_accelerator:
            encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        tokens = self.tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
        logger.info(f"Processing {len(tokens)} tokens")

        with torch.no_grad():
            outputs = self.model(**encoded, output_attentions=True)

        # Pull attentions off the accelerator right away to free device memory
        attention_maps = [layer_attn.cpu().numpy() for layer_attn in outputs.attentions]

        if on_accelerator:
            del encoded, outputs
            torch.cuda.empty_cache()

        self._update_memory_stats()
        return {"tokens": tokens, "attention_maps": attention_maps}

    def analyze_sample_and_aggregate(self, text, sample_id=None, layer_subset=None):
        """
        Analyze one sample and fold its per-layer metrics into the aggregates.

        No per-sample files are written; metrics are accumulated on the instance
        for the aggregate report.

        Args:
            text: Text to analyze.
            sample_id: Identifier for the sample (used in logs and the result).
            layer_subset: Layer indices to analyze. ``None`` selects the first
                layer, the layers at roughly 1/4 and 1/2 depth, and the last layer.
                Negative indices count from the end; duplicates are ignored.

        Returns:
            dict: Summary metrics for the sample. Layers whose metric extraction
            failed are listed under ``'layer_errors'``. If the whole sample fails,
            returns ``{'error': message, 'sample_id': sample_id}`` instead.
        """
        try:
            # Process the text and get attention maps
            results = self.process_text(text)
            tokens = results["tokens"]
            attention_maps = results["attention_maps"]

            layer_subset = self._resolve_layer_subset(layer_subset, len(attention_maps))

            # Extract metrics for each selected layer in parallel
            max_workers = min(len(layer_subset), os.cpu_count() or 4)
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [
                    executor.submit(
                        self._extract_layer_metrics,
                        layer_idx=layer_idx,
                        tokens=tokens,
                        attention_map=attention_maps[layer_idx]
                    )
                    for layer_idx in layer_subset
                ]
                # _extract_layer_metrics returns (layer_idx, metrics) pairs
                layer_metrics = dict(future.result() for future in futures)

            self._update_aggregate_stats(layer_metrics, tokens, sample_id)

            self.sample_count += 1

            # Layers whose extraction failed early lack some metrics, so every
            # lookup below is guarded instead of indexing directly.
            sample_metrics = {
                'sample_id': sample_id,
                'num_tokens': len(tokens),
                'layers_analyzed': list(layer_metrics.keys()),
                'attention_means': {
                    layer_idx: metrics['attention_mean']
                    for layer_idx, metrics in layer_metrics.items()
                    if 'attention_mean' in metrics
                },
                'betti_numbers': {
                    layer_idx: metrics['betti_numbers']
                    for layer_idx, metrics in layer_metrics.items()
                    if 'betti_numbers' in metrics
                },
                'layer_errors': {
                    layer_idx: metrics['error']
                    for layer_idx, metrics in layer_metrics.items()
                    if 'error' in metrics
                }
            }

            del results, tokens, attention_maps, layer_metrics
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            return sample_metrics

        except Exception as e:
            logger.error(f"Error analyzing sample {sample_id}: {str(e)}", exc_info=True)
            return {'error': str(e), 'sample_id': sample_id}

    @staticmethod
    def _resolve_layer_subset(layer_subset, num_layers):
        """
        Return the validated, non-negative, de-duplicated layer indices to analyze.

        Args:
            layer_subset: Requested layer indices, or ``None`` for the default
                selection (first, ~1/4 depth, ~1/2 depth and last layer).
            num_layers: Number of layers that produced attention maps.

        Raises:
            ValueError: If there are no layers, an index is out of range, or the
                requested subset is empty.
        """
        if num_layers <= 0:
            raise ValueError("Model returned no attention maps")

        if layer_subset is None:
            layer_step = max(1, num_layers // 4)
            candidates = [0, layer_step, 2 * layer_step, num_layers - 1]
            return sorted({min(idx, num_layers - 1) for idx in candidates})

        resolved = []
        for idx in layer_subset:
            normalized = idx + num_layers if idx < 0 else idx
            if not 0 <= normalized < num_layers:
                raise ValueError(f"Layer index {idx} is out of range for {num_layers} layers")
            if normalized not in resolved:
                resolved.append(normalized)

        if not resolved:
            raise ValueError("layer_subset must contain at least one layer index")
        return resolved

    def _extract_layer_metrics(self, layer_idx, tokens, attention_map, max_points=30):
        """
        Extract attention, persistence, per-head and threshold metrics for one layer.

        Args:
            layer_idx: Index of the layer. It is returned unchanged so results
                collected from a thread pool can be matched to their layer.
            tokens: List of token strings for the sample.
            attention_map: Attention array for this layer indexed as
                [batch, head, query, key]; only batch element 0 is used.
            max_points: Maximum number of tokens fed to ripser. Longer sequences
                are evenly subsampled to bound the persistence computation cost.
                ``None`` disables subsampling.

        Returns:
            tuple: ``(layer_idx, metrics)``. If any step fails, ``metrics`` holds
            whatever was computed before the failure plus an ``'error'`` message.
        """
        metrics = {}

        try:
            # Average over heads (axis 1), then take batch element 0 -> (query, key)
            avg_attention = self._ensure_float32(np.mean(attention_map, axis=1)[0])

            # Plain floats keep these consistent with the other metrics and JSON-friendly
            metrics.update({
                'attention_mean': float(np.mean(avg_attention)),
                'attention_max': float(np.max(avg_attention)),
                'attention_min': float(np.min(avg_attention)),
                'attention_std': float(np.std(avg_attention))
            })

            # Most attended tokens (top 5): mean attention each key position receives
            token_attention = np.mean(avg_attention, axis=0)
            top_indices = np.argsort(-token_attention)[:5]
            metrics['top_tokens'] = {
                tokens[idx]: float(token_attention[idx])
                for idx in top_indices
                if idx < len(tokens)
            }

            # Build a proper distance matrix for ripser. Causal attention is
            # triangular and never symmetric. However, ripser.py reads only one
            # triangle of a dense distance matrix and treats nonzero diagonal
            # entries as vertex birth times. Symmetrize with the element-wise max
            # (each token pair uses its strongest attention in either direction),
            # then zero the diagonal. Clipping guards against float32 rounding
            # pushing an attention value slightly above 1.
            symmetric_attention = np.maximum(avg_attention, avg_attention.T)
            distance_matrix = np.clip(1.0 - symmetric_attention, 0.0, None)
            np.fill_diagonal(distance_matrix, 0.0)

            # Evenly subsample long sequences to bound the cost of ripser
            num_points = distance_matrix.shape[0]
            if max_points is not None and num_points > max_points:
                indices = np.linspace(0, num_points - 1, max_points, dtype=int)
                distance_matrix = distance_matrix[np.ix_(indices, indices)]

            # Compute persistence diagram
            diagram = ripser(distance_matrix, maxdim=1, distance_matrix=True)

            # NOTE(review): len(dgm) is the number of persistence pairs over the
            # whole filtration, not a Betti number at a fixed scale. For
            # dimension 0 it equals the number of points passed to ripser.
            # It is kept as-is for compatibility with the aggregate report.
            metrics['betti_numbers'] = [len(dgm) for dgm in diagram['dgms']]

            # Extract persistence statistics
            metrics.update(self._extract_persistence_stats(diagram))

            # Analyze individual attention heads
            metrics['head_metrics'] = self._analyze_heads(attention_map, tokens)

            # Connection density at multiple thresholds
            metrics['topology'] = self._compute_topology_metrics(avg_attention, thresholds=[0.2, 0.4])

        except Exception as e:
            logger.error(f"Error extracting metrics for layer {layer_idx}: {str(e)}")
            metrics['error'] = str(e)

        return layer_idx, metrics

    def _extract_persistence_stats(self, diagram):
        """Extract statistics from a persistence diagram."""
        stats = {}

        if len(diagram['dgms']) > 0 and len(diagram['dgms'][0]) > 0:
            dim0_features = [(birth, death) for birth, death in diagram['dgms'][0] if death != float('inf')]
            if dim0_features:
                stats['dim0_avg_persistence'] = float(np.mean([death - birth for birth, death in dim0_features]))
                stats['dim0_max_persistence'] = float(np.max([death - birth for birth, death in dim0_features]))

        if len(diagram['dgms']) > 1 and len(diagram['dgms'][1]) > 0:
            dim1_features = [(birth, death) for birth, death in diagram['dgms'][1] if death != float('inf')]
            if dim1_features:
                stats['dim1_avg_persistence'] = float(np.mean([death - birth for birth, death in dim1_features]))
                stats['dim1_max_persistence'] = float(np.max([death - birth for birth, death in dim1_features]))

        return stats

    def _analyze_heads(self, attention_map, tokens):
        """Analyze individual attention heads."""
        num_heads = attention_map.shape[1]
        head_metrics = []

        for head_idx in range(num_heads):
            head_attention = attention_map[0, head_idx]

            head_attention = self._ensure_float32(head_attention)

            avg_attn = float(np.mean(head_attention))
            max_attn = float(np.max(head_attention))
            std_attn = float(np.std(head_attention))
            entropy = float(-np.sum(head_attention * np.log2(head_attention + 1e-10)))

            # Find most attended token
            token_attention = np.mean(head_attention, axis=0)
            top_idx = np.argmax(token_attention)
            top_token = tokens[top_idx] if top_idx < len(tokens) else "unknown"
            top_strength = float(token_attention[top_idx])

            # Find strongest connection
            max_i, max_j = np.unravel_index(np.argmax(head_attention), head_attention.shape)
            if max_i < len(tokens) and max_j < len(tokens):
                strongest_conn = {
                    'source': tokens[max_i],
                    'target': tokens[max_j],
                    'strength': float(head_attention[max_i, max_j])
                }
            else:
                strongest_conn = None

            head_metrics.append({
                'head_idx': head_idx,
                'avg_attention': avg_attn,
                'max_attention': max_attn,
                'std_attention': std_attn,
                'entropy': entropy,
                'top_token': top_token,
                'top_strength': top_strength,
                'strongest_connection': strongest_conn
            })

        return head_metrics

    def _compute_topology_metrics(self, attention_map, thresholds=None):
        """
        Compute topological metrics at different thresholds.
        
        Args:
            attention_map: Attention map to analyze
            thresholds: List of thresholds to use (default: [0.2, 0.4])
        """
        if thresholds is None:
            thresholds = [0.2, 0.4]
            
        metrics = {}

        attention_map = self._ensure_float32(attention_map)

        for threshold in thresholds:
            # Create binary adjacency matrix
            binary_adj = (attention_map > threshold).astype(int)

            # Count connections
            total_connections = int(np.sum(binary_adj))
            max_possible = attention_map.shape[0] * attention_map.shape[1]
            connection_density = float(total_connections / max_possible)

            metrics[f'threshold_{threshold}'] = {
                'connection_density': connection_density,
                'total_connections': total_connections
            }

        return metrics

    def _update_aggregate_stats(self, layer_metrics, tokens, sample_id):
        """
        Update aggregate statistics with metrics from a sample.

        Args:
            layer_metrics: Dictionary of metrics for each layer
            tokens: List of tokens in the sample
            sample_id: Identifier for the sample
        """
        for layer_idx, metrics in layer_metrics.items():
            # Update Betti numbers
            if 'betti_numbers' in metrics:
                if len(metrics['betti_numbers']) > 0:
                    self.betti_numbers[layer_idx]['dim0'].append(metrics['betti_numbers'][0])
                if len(metrics['betti_numbers']) > 1:
                    self.betti_numbers[layer_idx]['dim1'].append(metrics['betti_numbers'][1])

            # Update attention statistics
            self.attention_stats[layer_idx]['mean'].append(metrics['attention_mean'])
            self.attention_stats[layer_idx]['max'].append(metrics['attention_max'])
            self.attention_stats[layer_idx]['std'].append(metrics['attention_std'])

            # Update persistence statistics
            if 'dim0_avg_persistence' in metrics:
                self.persistence_stats[layer_idx]['dim0_avg'].append(metrics['dim0_avg_persistence'])
            if 'dim0_max_persistence' in metrics:
                self.persistence_stats[layer_idx]['dim0_max'].append(metrics['dim0_max_persistence'])
            if 'dim1_avg_persistence' in metrics:
                self.persistence_stats[layer_idx]['dim1_avg'].append(metrics['dim1_avg_persistence'])
            if 'dim1_max_persistence' in metrics:
                self.persistence_stats[layer_idx]['dim1_max'].append(metrics['dim1_max_persistence'])

            # Update head specialization data
            if 'head_metrics' in metrics:
                for head_data in metrics['head_metrics']:
                    head_idx = head_data['head_idx']
                    top_token = head_data['top_token']
                    self.head_specialization[layer_idx][head_idx].append(top_token)

                    # Track token focus frequency
                    if top_token not in self.token_focus[layer_idx][head_idx]:
                        self.token_focus[layer_idx][head_idx][top_token] = 1
                    else:
                        self.token_focus[layer_idx][head_idx][top_token] += 1

    def generate_aggregate_report(self, output_path):
        """
        Generate a Markdown report with aggregate statistics and significance tests.

        Creates ``output_path`` and a ``visualizations/`` subdirectory, renders
        the plots, and writes ``aggregate_topology_report.md``. It then writes the
        companion statistical-significance report and ``aggregate_statistics.json``.

        Args:
            output_path: Directory to write the report and its assets into.

        Returns:
            str: Path to the generated Markdown report.
        """
        output_dir = Path(output_path)
        output_dir.mkdir(parents=True, exist_ok=True)

        vis_dir = output_dir / "visualizations"
        vis_dir.mkdir(exist_ok=True)

        test_results = self.run_statistical_tests()

        self._generate_visualizations(vis_dir)

        report_path = output_dir / "aggregate_topology_report.md"

        # The report contains non-ASCII text (Betti₀, ±). Force UTF-8 so writing
        # does not depend on the platform's default locale encoding.
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("# Transformer Attention Topology Analysis\n\n")

            f.write("## Analysis Overview\n\n")
            f.write(f"- **Model**: {self.model_name}\n")
            f.write(f"- **Samples Analyzed**: {self.sample_count}\n")
            f.write(f"- **Max Tokens Per Sample**: {self.max_tokens}\n")
            f.write(f"- **Layers Analyzed**: {sorted(self.attention_stats.keys())}\n")
            f.write(f"- **Analysis Date**: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")

            f.write("## Summary of Findings\n\n")

            # Key findings, annotated with the statistical-test results
            key_findings = self._identify_key_findings()
            updated_findings = self.update_findings_with_statistics(key_findings, test_results)

            f.write("### Key Findings\n\n")
            for i, finding in enumerate(updated_findings, 1):
                f.write(f"{i}. {finding}\n")
            f.write("\n")

            # Betti numbers section
            f.write("## Topological Features (Betti Numbers)\n\n")
            f.write("Betti numbers measure topological features: Betti₀ counts connected components, ")
            f.write("while Betti₁ counts loops/cycles in the attention structure.\n\n")

            f.write("| Layer | Betti₀ (Mean ± Std) | Betti₁ (Mean ± Std) | Sample Count |\n")
            f.write("|-------|-------------------|-------------------|-------------|\n")

            for layer_idx in sorted(self.betti_numbers.keys()):
                dim0_data = self.betti_numbers[layer_idx]['dim0']
                dim1_data = self.betti_numbers[layer_idx]['dim1']

                dim0_str = self._format_mean_std(dim0_data, 2)
                dim1_str = self._format_mean_std(dim1_data, 2)

                sample_count = len(dim0_data) or len(dim1_data)
                f.write(f"| {layer_idx} | {dim0_str} | {dim1_str} | {sample_count} |\n")

            f.write("\n![Betti Numbers Across Layers](visualizations/betti_numbers.png)\n\n")

            # Attention distribution section
            f.write("## Attention Distribution Analysis\n\n")
            f.write("This section examines how attention is distributed across tokens in different layers.\n\n")

            f.write("| Layer | Mean Attention (± Std) | Max Attention (± Std) | Attention StdDev (± Std) |\n")
            f.write("|-------|----------------------|---------------------|------------------------|\n")

            for layer_idx in sorted(self.attention_stats.keys()):
                layer_attention = self.attention_stats[layer_idx]

                mean_str = self._format_mean_std(layer_attention['mean'], 4)
                max_str = self._format_mean_std(layer_attention['max'], 4)
                std_str = self._format_mean_std(layer_attention['std'], 4)

                f.write(f"| {layer_idx} | {mean_str} | {max_str} | {std_str} |\n")

            f.write("\n![Attention Statistics Across Layers](visualizations/attention_stats.png)\n\n")

            # Persistence section
            f.write("## Topological Persistence Analysis\n\n")
            f.write("Persistence measures the 'significance' of topological features. ")
            f.write("Higher persistence values indicate more prominent features.\n\n")

            f.write("| Layer | Dim0 Avg Persistence | Dim1 Avg Persistence |\n")
            f.write("|-------|---------------------|---------------------|\n")

            for layer_idx in sorted(self.persistence_stats.keys()):
                layer_persistence = self.persistence_stats[layer_idx]

                dim0_values = layer_persistence['dim0_avg']
                dim1_values = layer_persistence['dim1_avg']
                dim0_str = f"{np.mean(dim0_values):.4f}" if dim0_values else "N/A"
                dim1_str = f"{np.mean(dim1_values):.4f}" if dim1_values else "N/A"

                f.write(f"| {layer_idx} | {dim0_str} | {dim1_str} |\n")

            f.write("\n![Persistence Statistics Across Layers](visualizations/persistence_stats.png)\n\n")

            f.write("## Statistical Significance\n\n")
            f.write("A detailed analysis of statistical significance tests is available in the ")
            f.write("[Statistical Significance Report](statistical_significance.md).\n\n")
            f.write("This report shows which differences between layers are statistically significant ")
            # Trailing blank line separates this paragraph from the next heading
            f.write("(p < 0.05) and which might be due to random variation.\n\n")

            f.write("## Attention Head Specialization\n\n")
            f.write("This section identifies attention heads that consistently focus on specific token types across samples.\n\n")

            specialized_heads = self._identify_specialized_heads()

            if specialized_heads:
                f.write("### Top Specialized Heads\n\n")
                f.write("| Layer | Head | Top Token | Consistency (%) | Samples |\n")
                f.write("|-------|------|-----------|-----------------|--------|\n")

                for head_data in specialized_heads[:15]:  # Show top 15
                    layer_idx = head_data['layer']
                    head_idx = head_data['head']
                    token = head_data['token']
                    consistency = head_data['consistency'] * 100
                    sample_count = head_data['samples']

                    f.write(f"| {layer_idx} | {head_idx} | {token} | {consistency:.1f}% | {sample_count} |\n")

                f.write("\n![Head Specialization Heatmap](visualizations/head_specialization.png)\n\n")
            else:
                f.write("No strongly specialized heads were identified in the analysis.\n\n")

            layers = sorted(self.attention_stats.keys())
            if len(layers) >= 3:
                f.write("## Early vs. Late Layer Comparison\n\n")
                f.write("This section compares the characteristics of early and late transformer layers.\n\n")

                early_layer = layers[0]
                late_layer = layers[-1]

                # Lookups go through _layer_metric_mean so a layer missing from a
                # store reads as 0 without inserting an empty defaultdict entry.
                early_betti0 = self._layer_metric_mean(self.betti_numbers, early_layer, 'dim0')
                late_betti0 = self._layer_metric_mean(self.betti_numbers, late_layer, 'dim0')
                betti0_change = late_betti0 - early_betti0

                early_betti1 = self._layer_metric_mean(self.betti_numbers, early_layer, 'dim1')
                late_betti1 = self._layer_metric_mean(self.betti_numbers, late_layer, 'dim1')
                betti1_change = late_betti1 - early_betti1

                early_attn = self._layer_metric_mean(self.attention_stats, early_layer, 'mean')
                late_attn = self._layer_metric_mean(self.attention_stats, late_layer, 'mean')
                attn_change = late_attn - early_attn

                early_pers = self._layer_metric_mean(self.persistence_stats, early_layer, 'dim0_avg')
                late_pers = self._layer_metric_mean(self.persistence_stats, late_layer, 'dim0_avg')
                pers_change = late_pers - early_pers

                metrics_table = [
                    ["Metric", f"Layer {early_layer}", f"Layer {late_layer}", "Change"],
                    ["Betti₀ (components)", f"{early_betti0:.2f}", f"{late_betti0:.2f}", f"{betti0_change:+.2f}"],
                    ["Betti₁ (loops)", f"{early_betti1:.2f}", f"{late_betti1:.2f}", f"{betti1_change:+.2f}"],
                    ["Mean Attention", f"{early_attn:.4f}", f"{late_attn:.4f}", f"{attn_change:+.4f}"],
                    ["Dim0 Persistence", f"{early_pers:.4f}", f"{late_pers:.4f}", f"{pers_change:+.4f}"],
                ]

                # Render table
                f.write("### Key Metrics Comparison\n\n")
                f.write("| " + " | ".join(metrics_table[0]) + " |\n")
                f.write("|" + "|".join(["---"] * len(metrics_table[0])) + "|\n")
                for row in metrics_table[1:]:
                    f.write("| " + " | ".join(row) + " |\n")

                f.write("\n### Interpretations\n\n")
                f.write("Based on these comparisons:\n\n")

                if betti0_change > 0:
                    f.write("- **Component Structure**: Later layers have more disconnected components, suggesting ")
                    f.write("more fragmentation or specialization in attention patterns.\n")
                elif betti0_change < 0:
                    f.write("- **Component Structure**: Later layers have fewer disconnected components, suggesting ")
                    f.write("more integration or consolidation of attention patterns.\n")

                if betti1_change > 0:
                    f.write("- **Loop Structure**: Later layers have more loops, suggesting more complex ")
                    f.write("cyclic relationships between tokens.\n")
                elif betti1_change < 0:
                    f.write("- **Loop Structure**: Later layers have fewer loops, suggesting simpler ")
                    f.write("token relationships with less cyclic attention patterns.\n")

                if attn_change > 0.01:
                    f.write("- **Attention Intensity**: Later layers show higher average attention values, ")
                    f.write("indicating stronger focus on specific tokens.\n")
                elif attn_change < -0.01:
                    f.write("- **Attention Intensity**: Later layers show lower average attention values, ")
                    f.write("indicating more distributed focus across tokens.\n")

            # Top-level section (not nested under the previous one), separated
            # from any preceding list by a blank line.
            f.write("\n## Technical Details\n\n")
            f.write(f"- **Peak Memory Usage**: {self.memory_tracker['peak']:.2f} GB\n")
            f.write(f"- **Device Used**: {self.device}\n")

        self.generate_statistical_report(test_results, output_dir)

        stats_path = output_dir / "aggregate_statistics.json"
        self._save_aggregate_stats(stats_path)

        logger.info(f"Aggregate report saved to {report_path}")
        return str(report_path)

    @staticmethod
    def _format_mean_std(values, decimals):
        """Return ``"mean ± std"`` of ``values`` with ``decimals`` places, or ``"N/A"`` if empty."""
        if not values:
            return "N/A"
        return f"{np.mean(values):.{decimals}f} ± {np.std(values):.{decimals}f}"

    @staticmethod
    def _layer_metric_mean(store, layer_idx, key):
        """
        Return the mean of ``store[layer_idx][key]``, or 0 if absent or empty.

        Uses ``.get`` so a lookup never inserts an empty entry into a
        ``defaultdict``-backed store.
        """
        layer_stats = store.get(layer_idx)
        values = layer_stats.get(key) if layer_stats is not None else None
        return np.mean(values) if values else 0


    def _generate_visualizations(self, vis_dir):
        """
        Render every aggregate plot into ``vis_dir``.

        The layer-comparison plot is produced only when at least three layers
        have attention statistics. It compares the lowest and highest of them.

        Args:
            vis_dir: ``Path`` of the directory that receives the PNG files.
        """
        renderers = (
            (self._create_betti_number_visualization, "betti_numbers.png"),
            (self._create_attention_stats_visualization, "attention_stats.png"),
            (self._create_persistence_stats_visualization, "persistence_stats.png"),
            (self._create_head_specialization_visualization, "head_specialization.png"),
        )
        for render, filename in renderers:
            render(vis_dir / filename)

        analyzed_layers = sorted(self.attention_stats.keys())
        if len(analyzed_layers) < 3:
            return

        self._create_layer_comparison_visualization(
            vis_dir / "layer_comparison.png",
            early_layer=analyzed_layers[0],
            late_layer=analyzed_layers[-1]
        )

    def _create_betti_number_visualization(self, output_path):
        """
        Plot mean ± std Betti₀ and Betti₁ for every analyzed layer and save it.

        Layers with fewer than two samples get a zero-length error bar. Layers
        with no data are plotted at 0.

        Args:
            output_path: File path the PNG is written to.
        """
        layers = sorted(self.betti_numbers.keys())
        dim0_means, dim0_stds, dim1_means, dim1_stds = [], [], [], []

        for layer in layers:
            dim0_data = self.betti_numbers[layer]['dim0']
            dim1_data = self.betti_numbers[layer]['dim1']

            dim0_means.append(np.mean(dim0_data) if dim0_data else 0)
            dim0_stds.append(np.std(dim0_data) if len(dim0_data) > 1 else 0)
            dim1_means.append(np.mean(dim1_data) if dim1_data else 0)
            dim1_stds.append(np.std(dim1_data) if len(dim1_data) > 1 else 0)

        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            ax.errorbar(layers, dim0_means, yerr=dim0_stds, fmt='o-', label='Betti₀ (Components)', capsize=4)
            ax.errorbar(layers, dim1_means, yerr=dim1_stds, fmt='s-', label='Betti₁ (Loops)', capsize=4)

            ax.set_title('Betti Numbers Across Layers (All Samples)')
            ax.set_xlabel('Layer')
            ax.set_ylabel('Betti Number (Mean ± Std)')
            ax.legend()
            ax.grid(True, linestyle='--', alpha=0.7)
            # Integer-only x ticks. A prior set_xticks(layers) would be replaced
            # by this locator anyway, so it is not set separately.
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))

            fig.tight_layout()
            fig.savefig(output_path, dpi=120)
        finally:
            # Close this specific figure even if saving fails, so repeated runs in
            # a long-lived kernel do not accumulate open figures.
            plt.close(fig)

    def _create_attention_stats_visualization(self, output_path):
        """Plot per-layer attention statistics averaged over samples and save it.

        Three series are drawn over the sorted layer keys. Each is the
        sample-average of the recorded ``mean``, ``max`` or ``std`` values for
        that layer, or 0 for a layer with no recorded values.

        Args:
            output_path: Destination file passed to ``plt.savefig``.
        """
        plt.figure(figsize=(12, 8))

        layer_ids = sorted(self.attention_stats.keys())

        def sample_average(field):
            # One value per layer: mean over samples, or 0 when nothing was recorded.
            return [
                np.mean(self.attention_stats[lid][field]) if self.attention_stats[lid][field] else 0
                for lid in layer_ids
            ]

        for field, marker, label in (
            ('mean', 'o-', 'Mean Attention'),
            ('max', 's-', 'Max Attention'),
            ('std', '^-', 'Std Deviation'),
        ):
            plt.plot(layer_ids, sample_average(field), marker, label=label)

        plt.title('Attention Statistics Across Layers (All Samples)')
        plt.xlabel('Layer')
        plt.ylabel('Attention Value')
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.xticks(layer_ids)
        plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))

        plt.tight_layout()
        plt.savefig(output_path, dpi=120)
        plt.close()

    def _create_persistence_stats_visualization(self, output_path):
        """Plot per-layer average persistence for dimensions 0 and 1 and save it.

        Each point is the mean, over samples, of the recorded ``dim0_avg`` or
        ``dim1_avg`` values for that layer. A layer with no recorded values is
        plotted at 0.

        Args:
            output_path: Destination file passed to ``plt.savefig``.
        """
        plt.figure(figsize=(12, 8))

        layer_ids = sorted(self.persistence_stats.keys())
        for field, marker, label in (('dim0_avg', 'o-', 'Dim 0 Avg Persistence'),
                                     ('dim1_avg', 's-', 'Dim 1 Avg Persistence')):
            averages = []
            for layer_id in layer_ids:
                layer_record = self.persistence_stats[layer_id]
                averages.append(np.mean(layer_record[field]) if layer_record[field] else 0)
            plt.plot(layer_ids, averages, marker, label=label)

        plt.title('Average Persistence Across Layers (All Samples)')
        plt.xlabel('Layer')
        plt.ylabel('Average Persistence')
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.xticks(layer_ids)
        plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))

        plt.tight_layout()
        plt.savefig(output_path, dpi=120)
        plt.close()

    def _create_head_specialization_visualization(self, output_path):
        """Draw a layer × head heatmap of token-focus consistency and save it.

        A head's consistency is the share of its recorded focus events that
        went to its single most frequent token. Cells for heads without data
        stay NaN. When no head data exists at all, a warning is logged and
        nothing is written.

        Args:
            output_path: Destination file passed to ``plt.savefig``.
        """
        layer_ids = sorted(self.token_focus.keys())

        # Heatmap width: one column per head index, up to the largest index seen.
        head_count = max(
            [0] + [max(self.token_focus[lid].keys()) + 1 for lid in layer_ids if self.token_focus[lid]]
        )

        if head_count == 0:
            logger.warning("No head specialization data available for visualization")
            return

        grid = np.full((len(layer_ids), head_count), np.nan)

        for row, layer_id in enumerate(layer_ids):
            heads = self.token_focus[layer_id]
            for head_id in range(head_count):
                if head_id not in heads:
                    continue
                counts = heads[head_id]
                if not counts:
                    continue
                total = sum(counts.values())
                grid[row, head_id] = max(counts.values()) / total if total > 0 else 0

        plt.figure(figsize=(14, 10))
        sns.heatmap(grid, cmap='viridis', vmin=0, vmax=1,
                    xticklabels=range(head_count), yticklabels=layer_ids,
                    cbar_kws={'label': 'Token Focus Consistency'})

        plt.title('Head Specialization Consistency Across Samples')
        plt.xlabel('Head Index')
        plt.ylabel('Layer')

        plt.xticks(rotation=0)
        plt.yticks(rotation=0)

        plt.tight_layout()
        plt.savefig(output_path, dpi=120)
        plt.close()

    def _create_layer_comparison_visualization(self, output_path, early_layer, late_layer):
        """Compare two layers in a 2×2 grid of bar charts and save the figure.

        The four panels are:
            - mean Betti numbers;
            - mean attention statistics (mean, max and std);
            - mean average persistence for dimensions 0 and 1;
            - the number of heads whose top focus token accounts for at least
              half of their focus events.

        Empty per-layer value lists are plotted as 0. A layer absent from
        ``token_focus`` counts as having no specialized heads.

        Args:
            output_path: Destination file passed to ``plt.savefig``.
            early_layer: Key of the first layer to compare.
            late_layer: Key of the second layer to compare.
        """
        def mean_of(container, layer, field):
            # Average of the recorded per-sample values, 0 when nothing was recorded.
            values = container[layer][field]
            return np.mean(values) if values else 0

        def specialized_head_count(layer):
            # Heads whose most frequent token takes >= 50% of their focus events.
            hits = 0
            for counts in self.token_focus.get(layer, {}).values():
                if not counts:
                    continue
                total = sum(counts.values())
                share = max(counts.values()) / total if total > 0 else 0
                if share >= 0.5:
                    hits += 1
            return hits

        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        betti_ax, attn_ax = axes[0, 0], axes[0, 1]
        pers_ax, heads_ax = axes[1, 0], axes[1, 1]

        # Top-left: Betti numbers; the late-layer bars are recoloured after drawing.
        betti_ax.set_title('Betti Numbers Comparison')
        betti_heights = [
            mean_of(self.betti_numbers, early_layer, 'dim0'),
            mean_of(self.betti_numbers, early_layer, 'dim1'),
            mean_of(self.betti_numbers, late_layer, 'dim0'),
            mean_of(self.betti_numbers, late_layer, 'dim1'),
        ]
        betti_names = [f'Layer {early_layer}\nBetti₀', f'Layer {early_layer}\nBetti₁',
                       f'Layer {late_layer}\nBetti₀', f'Layer {late_layer}\nBetti₁']
        betti_bars = betti_ax.bar(betti_names, betti_heights,
                                  color=['blue', 'green', 'blue', 'green'], alpha=0.7)
        betti_bars[2].set_color('orange')
        betti_bars[3].set_color('red')
        betti_ax.set_ylabel('Average Value')
        betti_ax.grid(axis='y', linestyle='--', alpha=0.7)

        # Top-right: attention mean/max/std, early layer first.
        attn_ax.set_title('Attention Statistics Comparison')
        attn_names, attn_heights = [], []
        for layer in (early_layer, late_layer):
            for field, short in (('mean', 'Mean'), ('max', 'Max'), ('std', 'Std')):
                attn_names.append(f'Layer {layer}\n{short}')
                attn_heights.append(mean_of(self.attention_stats, layer, field))
        attn_ax.bar(attn_names, attn_heights,
                    color=['blue', 'blue', 'blue', 'orange', 'orange', 'orange'], alpha=0.7)
        attn_ax.set_ylabel('Attention Value')
        attn_ax.grid(axis='y', linestyle='--', alpha=0.7)

        # Bottom-left: average persistence for dimensions 0 and 1.
        pers_ax.set_title('Persistence Values Comparison')
        pers_names, pers_heights = [], []
        for layer in (early_layer, late_layer):
            for field, short in (('dim0_avg', 'Dim0'), ('dim1_avg', 'Dim1')):
                pers_names.append(f'Layer {layer}\n{short}')
                pers_heights.append(mean_of(self.persistence_stats, layer, field))
        pers_ax.bar(pers_names, pers_heights, color=['blue', 'green', 'orange', 'red'], alpha=0.7)
        pers_ax.set_ylabel('Average Persistence')
        pers_ax.grid(axis='y', linestyle='--', alpha=0.7)

        # Bottom-right: count of strongly specialized heads per layer.
        heads_ax.set_title('Head Specialization Comparison')
        heads_ax.bar([f'Layer {early_layer}', f'Layer {late_layer}'],
                     [specialized_head_count(early_layer), specialized_head_count(late_layer)],
                     color=['blue', 'orange'], alpha=0.7)
        heads_ax.set_ylabel('Specialized Heads Count')
        heads_ax.grid(axis='y', linestyle='--', alpha=0.7)

        plt.tight_layout()
        plt.savefig(output_path, dpi=120)
        plt.close()

    def _identify_specialized_heads(self):
        """Identify attention heads with consistent focus across samples."""
        specialized_heads = []

        for layer_idx in self.token_focus:
            for head_idx in self.token_focus[layer_idx]:
                token_counts = self.token_focus[layer_idx][head_idx]
                if token_counts:
                    max_token = max(token_counts.items(), key=lambda x: x[1])
                    token = max_token[0]
                    count = max_token[1]
                    total = sum(token_counts.values())
                    consistency = count / total if total > 0 else 0

                    if consistency >= 0.3 and total >= 5:  # At least 30% consistency and 5 samples
                        specialized_heads.append({
                            'layer': layer_idx,
                            'head': head_idx,
                            'token': token,
                            'consistency': consistency,
                            'samples': total
                        })

        # Sort by consistency
        specialized_heads.sort(key=lambda x: x['consistency'], reverse=True)

        return specialized_heads

    def _identify_key_findings(self):
        """Identify key findings from the analysis."""
        findings = []

        if not self.betti_numbers or not self.attention_stats:
            findings.append("Insufficient data for comprehensive analysis.")
            return findings

        # Betti number trends
        layers = sorted(self.betti_numbers.keys())
        if len(layers) >= 2:
            early_layer = layers[0]
            late_layer = layers[-1]

            early_betti0 = np.mean(self.betti_numbers[early_layer]['dim0']) if self.betti_numbers[early_layer]['dim0'] else 0
            late_betti0 = np.mean(self.betti_numbers[late_layer]['dim0']) if self.betti_numbers[late_layer]['dim0'] else 0

            if abs(late_betti0 - early_betti0) > 1.0:
                if late_betti0 > early_betti0:
                    findings.append(f"**Component fragmentation**: Connected components increase from {early_betti0:.1f} in layer {early_layer} to {late_betti0:.1f} in layer {late_layer}, suggesting attention becomes more fragmented in deeper layers.")
                else:
                    findings.append(f"**Component integration**: Connected components decrease from {early_betti0:.1f} in layer {early_layer} to {late_betti0:.1f} in layer {late_layer}, suggesting attention becomes more integrated in deeper layers.")

            early_betti1 = np.mean(self.betti_numbers[early_layer]['dim1']) if self.betti_numbers[early_layer]['dim1'] else 0
            late_betti1 = np.mean(self.betti_numbers[late_layer]['dim1']) if self.betti_numbers[late_layer]['dim1'] else 0

            if abs(late_betti1 - early_betti1) > 0.5:
                if late_betti1 > early_betti1:
                    findings.append(f"**Cyclic complexity**: Loops increase from {early_betti1:.1f} in layer {early_layer} to {late_betti1:.1f} in layer {late_layer}, suggesting attention forms more complex cyclic patterns in deeper layers.")
                else:
                    findings.append(f"**Cyclic simplification**: Loops decrease from {early_betti1:.1f} in layer {early_layer} to {late_betti1:.1f} in layer {late_layer}, suggesting attention forms simpler patterns in deeper layers.")

        # attention patterns
        layers = sorted(self.attention_stats.keys())
        if len(layers) >= 2:
            early_layer = layers[0]
            late_layer = layers[-1]

            early_mean = np.mean(self.attention_stats[early_layer]['mean']) if self.attention_stats[early_layer]['mean'] else 0
            late_mean = np.mean(self.attention_stats[late_layer]['mean']) if self.attention_stats[late_layer]['mean'] else 0

            if abs(late_mean - early_mean) > 0.01:
                if late_mean > early_mean:
                    findings.append(f"**Attention intensity**: Average attention increases from {early_mean:.4f} in layer {early_layer} to {late_mean:.4f} in layer {late_layer}, suggesting stronger focus in deeper layers.")
                else:
                    findings.append(f"**Attention diffusion**: Average attention decreases from {early_mean:.4f} in layer {early_layer} to {late_mean:.4f} in layer {late_layer}, suggesting more distributed focus in deeper layers.")

        # Look for head specialization
        specialized_heads = self._identify_specialized_heads()
        if specialized_heads:
            top_specialized = specialized_heads[0]
            findings.append(f"**Head specialization**: Found {len(specialized_heads)} heads with consistent token focus. The most specialized is layer {top_specialized['layer']}, head {top_specialized['head']} focusing on token '{top_specialized['token']}' with {top_specialized['consistency']*100:.1f}% consistency.")

            # Look for layer patterns in specialization
            layer_counts = {}
            for head in specialized_heads:
                layer = head['layer']
                if layer not in layer_counts:
                    layer_counts[layer] = 0
                layer_counts[layer] += 1

            if len(layer_counts) >= 2:
                most_specialized_layer = max(layer_counts.items(), key=lambda x: x[1])
                findings.append(f"**Layer specialization**: Layer {most_specialized_layer[0]} has the most specialized attention heads ({most_specialized_layer[1]}), suggesting specialized processing at this depth.")

        # Identify persistence patterns
        layers = sorted(self.persistence_stats.keys())
        if len(layers) >= 2:
            early_layer = layers[0]
            late_layer = layers[-1]

            early_dim0 = np.mean(self.persistence_stats[early_layer]['dim0_avg']) if self.persistence_stats[early_layer]['dim0_avg'] else 0
            late_dim0 = np.mean(self.persistence_stats[late_layer]['dim0_avg']) if self.persistence_stats[late_layer]['dim0_avg'] else 0

            if abs(late_dim0 - early_dim0) > 0.05:
                if late_dim0 > early_dim0:
                    findings.append(f"**Topological significance**: Dimension 0 persistence increases from {early_dim0:.4f} in layer {early_layer} to {late_dim0:.4f} in layer {late_layer}, suggesting more stable connected components in deeper layers.")
                else:
                    findings.append(f"**Topological significance**: Dimension 0 persistence decreases from {early_dim0:.4f} in layer {early_layer} to {late_dim0:.4f} in layer {late_layer}, suggesting less stable connected components in deeper layers.")

        return findings

    def _save_aggregate_stats(self, output_path):
        """Write all collected aggregate statistics to a JSON file.

        Layer and head keys are stringified. Per-sample numeric series are
        converted to lists of Python floats. ``head_specialization`` entries
        are stored as-is, and ``token_focus`` counts are converted with
        ``dict()``.

        Args:
            output_path: Path of the JSON file to create or overwrite.
        """
        def as_floats(values):
            return [float(v) for v in values]

        def by_layer(container, fields):
            # {layer: {field: [floats]}} with string layer keys.
            return {
                str(layer): {field: as_floats(record[field]) for field in fields}
                for layer, record in container.items()
            }

        payload = {
            'sample_count': self.sample_count,
            'model_name': self.model_name,
            'device': self.device,
            'memory_usage': self.memory_tracker,

            'betti_numbers': by_layer(self.betti_numbers, ('dim0', 'dim1')),
            'attention_stats': by_layer(self.attention_stats, ('mean', 'max', 'std')),
            'persistence_stats': by_layer(
                self.persistence_stats, ('dim0_avg', 'dim0_max', 'dim1_avg', 'dim1_max')
            ),
            'head_specialization': {
                str(layer): {str(head): tokens for head, tokens in heads.items()}
                for layer, heads in self.head_specialization.items()
            },
            'token_focus': {
                str(layer): {str(head): dict(tokens) for head, tokens in heads.items()}
                for layer, heads in self.token_focus.items()
            },
        }

        with open(output_path, 'w') as handle:
            json.dump(payload, handle, indent=2)

        logger.info(f"Aggregate statistics saved to {output_path}")

    def run_statistical_tests(self, alpha=0.05):
        """
        Run Welch's t-tests comparing every pair of layers on the collected metrics.

        Layers are the sorted keys of ``self.betti_numbers``. For every pair
        ``(layer_i, layer_j)`` with ``layer_i`` before ``layer_j``, the
        per-sample values of five metrics are compared whenever both layers
        recorded at least one value:
            - Betti₀;
            - Betti₁;
            - mean attention;
            - max attention;
            - dim-0 average persistence.

        Args:
            alpha: Per-test significance threshold (default 0.05). No
                multiple-comparison correction is applied, so with many layers
                some "significant" results are expected by chance alone.

        Returns:
            Dict mapping ``"layer_<i>_vs_<j>"`` to ``{metric: result}``. Each
            result has these keys:
                - ``pval`` (float);
                - ``significant`` (builtin bool);
                - ``mean_i`` and ``mean_j`` (floats).
            The dict is empty when fewer than two layers are available. A NaN
            p-value (e.g. a single sample per layer, or zero variance) is
            reported as not significant.
        """
        logger.info("Running statistical significance tests between layers...")

        layers = sorted(self.betti_numbers.keys())
        if len(layers) < 2:
            logger.warning("Need at least 2 layers for statistical testing")
            return {}

        # (result key, statistics container, per-layer field) for each tested metric.
        metric_specs = (
            ('betti0', self.betti_numbers, 'dim0'),
            ('betti1', self.betti_numbers, 'dim1'),
            ('mean_attention', self.attention_stats, 'mean'),
            ('max_attention', self.attention_stats, 'max'),
            ('persistence_dim0', self.persistence_stats, 'dim0_avg'),
        )

        test_results = {}
        significant_findings = []
        for pos, layer_i in enumerate(layers):
            for layer_j in layers[pos + 1:]:
                comparison = {}
                test_results[f"layer_{layer_i}_vs_{layer_j}"] = comparison

                for metric, container, field in metric_specs:
                    # Short-circuit like the original: only touch layer_j if layer_i has data.
                    sample_i = container[layer_i][field]
                    if not sample_i:
                        continue
                    sample_j = container[layer_j][field]
                    if not sample_j:
                        continue

                    # Welch's t-test: does not assume equal variances.
                    pval = float(stats.ttest_ind(sample_i, sample_j, equal_var=False).pvalue)
                    # Builtin bool (not numpy.bool_) keeps results JSON-serialisable;
                    # a NaN p-value compares False and is never marked significant.
                    significant = bool(pval < alpha)
                    comparison[metric] = {
                        'pval': pval,
                        'significant': significant,
                        'mean_i': float(np.mean(sample_i)),
                        'mean_j': float(np.mean(sample_j)),
                    }
                    if significant:
                        significant_findings.append(
                            f"{metric} differs significantly between layer {layer_i} and layer {layer_j} (p={pval:.4f})"
                        )

        if significant_findings:
            logger.info(f"Found {len(significant_findings)} statistically significant differences")
            for finding in significant_findings[:5]:  # Log first 5 findings
                logger.info(finding)
            if len(significant_findings) > 5:
                logger.info(f"...and {len(significant_findings) - 5} more")
        else:
            logger.info("No statistically significant differences found between layers")

        return test_results

    def update_findings_with_statistics(self, findings, test_results):
        """
        Append a statistical-significance note to each finding that has a matching test.

        A finding is matched to a metric by its phrasing, using the sentences
        produced by ``_identify_key_findings``:
            - connected components -> ``betti0``;
            - loops -> ``betti1``;
            - average attention -> ``mean_attention``;
            - dimension 0 persistence -> ``persistence_dim0``.

        The compared layers are parsed from the "layer X to ... in layer Y"
        text. When ``test_results`` has an entry for that pair and metric, a
        sentence stating the p-value is appended. All other findings are
        returned unchanged.

        Args:
            findings: List of finding strings.
            test_results: Output of ``run_statistical_tests``. ``None`` or an
                empty dict leaves all findings unchanged.

        Returns:
            New list of (possibly annotated) findings, in the original order.
        """
        import re

        test_results = test_results or {}

        # Phrase variants -> metric key used by run_statistical_tests.
        phrase_to_metric = (
            (("Connected components increase from", "Connected components decrease from"), 'betti0'),
            (("Loops increase from", "Loops decrease from"), 'betti1'),
            (("Average attention increases from", "Average attention decreases from"), 'mean_attention'),
            (("Dimension 0 persistence increases from", "Dimension 0 persistence decreases from"),
             'persistence_dim0'),
        )
        # Greedy ".*" makes the second group capture the last "in layer N" occurrence.
        layer_pattern = re.compile(r"layer (\d+) to .* in layer (\d+)")

        updated_findings = []
        for finding in findings:
            metric = next(
                (key for phrases, key in phrase_to_metric if any(p in finding for p in phrases)),
                None,
            )
            match = layer_pattern.search(finding) if metric is not None else None
            if match:
                comparison_key = f"layer_{match.group(1)}_vs_{match.group(2)}"
                result = test_results.get(comparison_key, {}).get(metric)
                if result is not None:
                    if result['significant']:
                        finding += f" This difference is statistically significant (p={result['pval']:.4f})."
                    else:
                        finding += f" However, this difference is not statistically significant (p={result['pval']:.4f})."
            updated_findings.append(finding)

        return updated_findings

    def generate_statistical_report(self, test_results, output_dir):
        """
        Write a Markdown report of the layer-vs-layer significance tests.

        The report ``statistical_significance.md`` is written inside
        ``output_dir``. It contains a table of the significant differences,
        followed by a table of every test.

        Args:
            test_results: Mapping ``"layer_<i>_vs_<j>" -> {metric: result}``, as
                returned by ``run_statistical_tests``. Each result must provide
                ``pval``, ``significant``, ``mean_i`` and ``mean_j``.
            output_dir: Directory (str or path-like) to write the report into.

        Returns:
            The report path as a string.
        """
        # Display names for metric keys from run_statistical_tests; unknown keys
        # fall back to a title-cased form of the key.
        metric_labels = {
            'betti0': "Betti₀ (Components)",
            'betti1': "Betti₁ (Loops)",
            'mean_attention': "Mean Attention",
            'max_attention': "Max Attention",
            'persistence_dim0': "Dim0 Persistence",
        }

        def describe(metric):
            if metric in metric_labels:
                return metric_labels[metric]
            return metric.replace('_', ' ').title()

        def row_prefix(layer_i, layer_j, metric, result):
            # Columns shared by both tables, ending after the p-value cell.
            return (f"| Layer {layer_i} vs {layer_j} | {describe(metric)} | "
                    f"{result['mean_i']:.4f} | {result['mean_j']:.4f} | {result['pval']:.4f} |")

        # Flatten to (layer_i, layer_j, metric, result) rows, sorted by comparison then metric.
        rows = []
        for comparison, tests in sorted(test_results.items()):
            parts = comparison.split('_vs_')
            layer_i, layer_j = parts[0].replace('layer_', ''), parts[1]
            for metric, result in sorted(tests.items()):
                rows.append((layer_i, layer_j, metric, result))
        significant_rows = [row for row in rows if row[3]['significant']]

        output_path = Path(output_dir) / "statistical_significance.md"

        # Explicit UTF-8: the report contains non-ASCII characters (₀, ₁, ≥) that
        # locale-default encodings such as Windows cp1252 cannot encode.
        with open(output_path, "w", encoding="utf-8") as f:
            f.write("# Statistical Significance Analysis\n\n")
            f.write("This report shows which differences between layers are statistically significant.\n\n")
            f.write("## Interpretation\n\n")
            f.write("- p-value < 0.05: The difference is statistically significant\n")
            f.write("- p-value ≥ 0.05: The difference could be due to random variation\n\n")

            f.write("## Significant Differences Between Layers\n\n")

            if not significant_rows:
                f.write("No statistically significant differences were found between layers.\n\n")
            else:
                f.write(f"Found {len(significant_rows)} statistically significant differences:\n\n")

                f.write("| Comparison | Metric | Layer 1 Mean | Layer 2 Mean | p-value |\n")
                f.write("|------------|--------|--------------|--------------|--------|\n")

                for layer_i, layer_j, metric, result in significant_rows:
                    f.write(row_prefix(layer_i, layer_j, metric, result) + "\n")

            f.write("\n## All Test Results\n\n")
            f.write("For completeness, here are all statistical test results, significant or not:\n\n")

            f.write("| Comparison | Metric | Layer 1 Mean | Layer 2 Mean | p-value | Significant? |\n")
            f.write("|------------|--------|--------------|--------------|---------|-------------|\n")

            for layer_i, layer_j, metric, result in rows:
                verdict = 'Yes' if result['significant'] else 'No'
                f.write(row_prefix(layer_i, layer_j, metric, result) + f" {verdict} |\n")

        logger.info(f"Statistical significance report saved to {output_path}")
        return str(output_path)


class AggregateAnalysisRunner:
    """
    Runs the aggregate-only analysis on a dataset.

    Loads the model once through ``AggregateTopologyAnalyzer``, selects a
    subset of rows from a CSV dataset, feeds every selected text to
    ``analyze_sample_and_aggregate`` and finally asks the analyzer to write
    its aggregate report into ``output_dir``.
    """

    def __init__(self, model_name, output_dir, max_tokens=48, max_workers=None, cache_dir=None):
        """
        Initialize the analysis runner.

        Args:
            model_name: Name of the model to analyze
            output_dir: Directory to save results (created if it does not exist)
            max_tokens: Maximum number of tokens to process
            max_workers: Maximum number of parallel workers. A falsy value means
                ``min(os.cpu_count() or 4, 8)``. NOTE(review): the value is only
                stored; ``run_analysis`` processes samples sequentially.
            cache_dir: Directory to cache model files
        """
        self.model_name = model_name
        self.output_dir = Path(output_dir)
        self.max_tokens = max_tokens
        # Default to the CPU count, but never more than 8 workers.
        self.max_workers = max_workers or min(os.cpu_count() or 4, 8)
        self.cache_dir = cache_dir

        self.output_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Initialized AggregateAnalysisRunner for model: {model_name}")

    def run_analysis(self, dataset_path, n_samples=500, focus_word_column=None, layer_subset=None,
                     random_state=None):
        """
        Run the aggregate-only analysis on a dataset.

        Args:
            dataset_path: Path to a CSV file containing a ``text`` column.
            n_samples: Maximum number of samples to analyze.
            focus_word_column: Optional column of focus words; the shortest
                text of each distinct focus word is selected first.
            layer_subset: Layer indices to analyze. Negative indices count
                from the last layer (Python-style) and out-of-range indices
                are clamped. None selects up to five evenly spaced layers.
            random_state: Seed or random state forwarded to sample selection
                so runs can be reproduced. None keeps sampling random.

        Returns:
            bool: True if the run reached the report stage, False if an
            exception aborted it (the exception is logged, not raised).
        """
        start_time = time.time()

        try:
            data = pd.read_csv(dataset_path)
            logger.info(f"Loaded dataset with {len(data)} rows")

            analyzer = AggregateTopologyAnalyzer(
                model_name=self.model_name,
                max_tokens=self.max_tokens,
                cache_dir=self.cache_dir
            )
            analyzer.load_model_and_tokenizer()

            actual_num_layers = self._count_attention_layers(analyzer)
            logger.info(f"Model has {actual_num_layers} attention layers")

            layer_subset = self._resolve_layer_subset(layer_subset, actual_num_layers)
            logger.info(f"Analyzing layers: {layer_subset}")

            logger.info("Selecting samples...")
            samples = self._select_samples(data, n_samples, focus_word_column, random_state=random_state)
            logger.info(f"Selected {len(samples)} samples for analysis")

            logger.info("Starting sample processing...")
            successful_samples = 0

            for i, sample in enumerate(tqdm(samples, desc="Processing samples")):
                sample_id = i + 1
                try:
                    metrics = analyzer.analyze_sample_and_aggregate(
                        text=sample['text'],
                        sample_id=sample_id,
                        layer_subset=layer_subset
                    )
                    # The analyzer signals a per-sample failure with an 'error' key.
                    if 'error' not in metrics:
                        successful_samples += 1
                except Exception as e:
                    logger.error(f"Error processing sample {sample_id}: {str(e)}")

                # Count attempted samples, whether or not they raised.
                if sample_id % 10 == 0:
                    logger.info(f"Processed {sample_id}/{len(samples)} samples")

                # Release per-sample memory before the next forward pass.
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            if samples and successful_samples == 0:
                logger.warning("No samples were processed successfully; the aggregate report may contain no data")

            logger.info("Generating aggregate report...")
            report_path = analyzer.generate_aggregate_report(self.output_dir)

            execution_time = time.time() - start_time
            logger.info(f"Analysis completed in {execution_time/60:.2f} minutes")
            logger.info(f"Successfully processed {successful_samples}/{len(samples)} samples")
            logger.info(f"Report available at: {report_path}")

            return True

        except Exception as e:
            logger.error(f"Analysis failed: {str(e)}", exc_info=True)
            return False

    @staticmethod
    def _count_attention_layers(analyzer):
        """
        Return how many attention tensors the loaded model emits.

        Runs one tiny probe input through ``analyzer.model`` with
        ``output_attentions=True``.

        Raises:
            RuntimeError: If the model returns no attention weights.
        """
        probe = analyzer.tokenizer("Test", return_tensors="pt")
        if analyzer.device != "cpu":
            probe = {k: v.to(analyzer.device) for k, v in probe.items()}

        with torch.no_grad():
            outputs = analyzer.model(**probe, output_attentions=True)

        attentions = outputs.attentions
        if not attentions:
            raise RuntimeError(
                "Model returned no attention weights with output_attentions=True; "
                "check the model's attention implementation"
            )
        return len(attentions)

    @staticmethod
    def _resolve_layer_subset(layer_subset, num_layers):
        """
        Return the sorted, de-duplicated list of valid layer indices to analyze.

        Args:
            layer_subset: Requested indices, or None for up to five evenly
                spaced layers (first, quarter points, last).
            num_layers: Number of attention layers in the model (>= 1).
        """
        last = num_layers - 1
        if layer_subset is None:
            step = max(1, num_layers // 4)
            candidates = [0, step, 2 * step, 3 * step, last]
        else:
            # Negative indices count from the end, like Python sequences.
            candidates = [layer + num_layers if layer < 0 else layer for layer in layer_subset]
        return sorted({min(max(layer, 0), last) for layer in candidates})

    def _select_samples(self, data, n_samples, focus_word_column=None, random_state=None):
        """
        Select up to ``n_samples`` distinct rows from the dataset.

        Rows whose ``text`` is missing, not a string, or blank are ignored.
        If ``focus_word_column`` names an existing column, the shortest text
        of each distinct focus word is taken first, most frequent words
        first. Remaining slots are filled without repeating a row. About
        half come from the shortest third of the remaining rows (by text
        length) and the rest from the middle third. Any shortfall is drawn
        at random from whatever rows are left.

        Args:
            data: Pandas DataFrame with the dataset; must have a ``text`` column
            n_samples: Maximum number of samples to select
            focus_word_column: Name of column containing focus words (if any)
            random_state: Optional seed/state passed to ``DataFrame.sample``

        Returns:
            list: Dicts with keys ``'text'`` and ``'focus_word'``

        Raises:
            KeyError: If ``data`` has no ``text`` column.
        """
        if 'text' not in data.columns:
            raise KeyError("Dataset must contain a 'text' column")
        if n_samples <= 0:
            return []

        # Only non-blank strings can be tokenized. Resetting the index makes
        # row labels unique, so already-chosen rows can be excluded reliably.
        is_valid = data['text'].map(lambda t: isinstance(t, str) and bool(t.strip()))
        valid = data[is_valid].reset_index(drop=True)
        if valid.empty:
            return []
        text_len = valid['text'].str.len()

        def take(frame, k):
            """Randomly draw ``k`` rows from ``frame`` (no rows if ``k <= 0``)."""
            if k <= 0:
                return frame.iloc[:0]
            return frame.sample(n=k, random_state=random_state)

        samples = []
        used = []  # index labels of rows already selected

        if focus_word_column and focus_word_column in valid.columns:
            focus_words = valid[focus_word_column].value_counts().index[:n_samples]
            for word in focus_words:
                shortest = text_len[valid[focus_word_column] == word].idxmin()
                samples.append({
                    'text': valid.at[shortest, 'text'],
                    'focus_word': word
                })
                used.append(shortest)

        more_needed = n_samples - len(samples)
        if more_needed > 0:
            pool = valid.drop(index=used)
            ordered = pool.loc[text_len.loc[pool.index].sort_values(kind='stable').index]
            chunk_size = len(ordered) // 3
            short_chunk = ordered.iloc[:chunk_size]
            middle_chunk = ordered.iloc[chunk_size:2 * chunk_size]

            n_short = min(more_needed // 2, len(short_chunk))
            n_middle = min(more_needed - n_short, len(middle_chunk))
            picked = pd.concat([take(short_chunk, n_short), take(middle_chunk, n_middle)])

            shortfall = more_needed - len(picked)
            if shortfall > 0:
                # Top up from rows not yet chosen, never exceeding what is available.
                leftover = pool.drop(index=picked.index)
                picked = pd.concat([picked, take(leftover, min(shortfall, len(leftover)))])

            for _, row in picked.iterrows():
                samples.append({
                    'text': row['text'],
                    'focus_word': row.get(focus_word_column) if focus_word_column else None
                })

        return samples[:n_samples]



In [None]:
def main():
    """
    Run the aggregate-only analysis.

    Mounts Google Drive (Colab), runs ``AggregateAnalysisRunner`` on the CSV
    at ``dataset_path`` and writes results below ``output_dir``, which is
    named after the model. Errors are logged and printed rather than raised.
    """
    drive.mount('/content/drive')

    # Hugging Face Hub repository id of the Llama 3.2 1B base model.
    model_name = "meta-llama/Llama-3.2-1B"
    dataset_path = "/content/drive/MyDrive/wiki_dataset.csv"
    output_dir = f"/content/drive/MyDrive/Sink/topology/{model_name.split('/')[-1]}"

    try:
        runner = AggregateAnalysisRunner(
            model_name=model_name,
            output_dir=output_dir,
            max_tokens=48,
            max_workers=4
        )

        # run_analysis reports failure through its return value, not by raising.
        succeeded = runner.run_analysis(
            dataset_path=dataset_path,
            n_samples=500,
            focus_word_column='focus_word'
        )

        if succeeded:
            print("Analysis completed successfully!")
        else:
            print("Analysis failed; see the log for details.")

    except Exception as e:
        logger.error(f"Analysis error: {str(e)}", exc_info=True)
        print(f"Analysis error: {e}")
    finally:
        print("Analysis finished. Check output directory for results.")


# NOTE: in a notebook kernel __name__ is typically "__main__" as well, so
# executing this cell starts the analysis.
if __name__ == "__main__":
    main()