Analyse Llama models to study wavelet-like structure in attention heads arising from positional encodings

In [None]:
# Log in to the Hugging Face Hub from this notebook before any model or
# dataset is downloaded by the cells below.
from huggingface_hub import notebook_login

notebook_login()

In [None]:
import gc
import logging
import math
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import load_dataset
from scipy import signal
from scipy.fft import fft
from scipy.optimize import curve_fit
from scipy.stats import entropy
from tqdm.notebook import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
class RoPEAnalyzer:
    """Collect per-layer, per-head attention statistics from a causal LM.

    The analyzer loads a Hugging Face causal language model with eager
    attention, runs text through it with ``output_attentions=True`` and
    summarises each head's ``[query, key]`` attention matrix with locality
    metrics: mean/max weight, diagonal weight, attention-weighted query-key
    distance, and an exponential decay rate fitted to the mean weight as a
    function of causal distance.
    """

    def __init__(self, model_name: str = "meta-llama/Llama-3.2-1B"):
        """Load ``model_name`` and its tokenizer onto the best available device.

        Uses float16 on CUDA and float32 on CPU. Eager attention is requested
        because the analysis reads attention probabilities through
        ``output_attentions=True``.
        """
        logging.basicConfig(
            level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
        )
        self.logger = logging.getLogger(__name__)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.logger.info(f"Using device: {self.device}")

        self.logger.info(f"Loading model {model_name}...")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True,
            attn_implementation="eager",
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model.eval()

        # Extract and store model dimensions
        config = self.model.config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        # Prefer an explicit head_dim when the config defines one; otherwise
        # fall back to the classic hidden_size / num_heads split.
        self.head_dim = getattr(config, "head_dim", None) or (
            self.hidden_size // self.num_heads
        )
        self.num_layers = config.num_hidden_layers

        self.logger.info(
            f"Model loaded: {self.num_layers} layers, {self.num_heads} heads"
        )

    def analyze_dataset(self, num_samples: int = 500) -> List[Dict]:
        """Analyze up to ``num_samples`` texts from the BookCorpus train split.

        Texts with 10-100 whitespace-separated words are streamed from
        ``bookcorpus`` and passed to ``analyze_single_text``. Samples that fail
        are logged and skipped.

        Returns:
            One ``{"layer_{i}": {"head_{j}": metrics}}`` dict per successfully
            analyzed sample.
        """
        self.logger.info(f"Starting analysis of {num_samples} samples...")

        dataset = load_dataset("bookcorpus", split="train", streaming=True)
        texts: List[str] = []

        # Guarded so that num_samples <= 0 does not stream the entire corpus.
        if num_samples > 0:
            with tqdm(desc="Collecting samples", total=num_samples) as pbar:
                for item in dataset:
                    text = item["text"]
                    if 10 <= len(text.split()) <= 100:
                        texts.append(text)
                        pbar.update(1)
                        if len(texts) >= num_samples:
                            break

        self.logger.info(f"Collected {len(texts)} suitable text samples")

        results = []
        # Iterating through tqdm advances the bar for every sample, including
        # ones that raise, so the bar always reaches len(texts).
        for idx, text in enumerate(tqdm(texts, desc="Analyzing samples")):
            try:
                result = self.analyze_single_text(text)
                if result is not None:
                    results.append(result)
            except Exception as e:
                self.logger.error(f"Error processing sample {idx}: {str(e)}")
            finally:
                # Periodically release cached GPU memory and Python garbage.
                if (idx + 1) % 50 == 0:
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    gc.collect()

        return results

    def analyze_single_text(self, text: str) -> Optional[Dict]:
        """Compute attention metrics for every layer and head for one text.

        The text is truncated to 512 tokens.

        Returns:
            ``{"layer_{i}": {"head_{j}": metrics}}``, with metrics from
            ``_compute_attention_metrics``. Returns None (after logging) if
            tokenization or the forward pass fails.
        """
        try:
            tokens = self.tokenizer(
                text, return_tensors="pt", truncation=True, max_length=512
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**tokens, output_attentions=True)

            results = {}
            for layer_idx, layer_attention in enumerate(outputs.attentions):
                # A single text is tokenized, so take batch element 0; assumes
                # the remaining shape is [num_heads, seq_len, seq_len].
                head_maps = layer_attention[0]
                results[f"layer_{layer_idx}"] = {
                    f"head_{head_idx}": self._compute_attention_metrics(
                        head_maps[head_idx]
                    )
                    for head_idx in range(self.num_heads)
                }

            return results

        except Exception as e:
            self.logger.error(f"Error analyzing text: {str(e)}")
            return None

    def _fit_exponential_decay(self, distance_profile: List[float]) -> float:
        """Fit ``A * exp(-c * d)`` to ``distance_profile`` and return ``c``.

        ``distance_profile[d]`` is the mean attention weight at distance ``d``.
        Returns 0.0 when the fit cannot be performed: empty input, fewer points
        than parameters, non-finite data, or no convergence.
        """

        def exp_func(x, A, c):
            return A * np.exp(-c * x)

        try:
            profile = np.asarray(distance_profile, dtype=float)
            distances = np.arange(len(profile))
            # Initial guess: amplitude from the d=0 value, gentle decay c=0.1.
            popt, _ = curve_fit(
                exp_func, distances, profile, p0=(profile[0], 0.1)
            )
        except (RuntimeError, ValueError, TypeError, IndexError):
            # RuntimeError: no convergence; ValueError: NaN/inf in data;
            # TypeError: fewer points than parameters; IndexError: empty profile.
            return 0.0
        return float(popt[1])  # c: decay coefficient

    def _compute_attention_metrics(self, attention_weights: torch.Tensor) -> Dict:
        """Summarise one head's ``[query, key]`` attention matrix.

        A leading singleton dimension, if present, is squeezed away.

        Returns:
            mean_attention / max_attention: mean and max weight of the matrix.
            local_attention: mean of the diagonal (self-attention weight).
            avg_distance: attention-weighted mean ``|query - key|``.
            decay_rate: ``c`` from an ``A * exp(-c * d)`` fit to the mean
                weight at each causal distance ``d = query - key >= 0``.
            If anything fails, the error is logged and every metric is 0.0.
        """
        try:
            weights = attention_weights.float()
            if weights.dim() == 3:
                weights = weights.squeeze(0)

            seq_len = weights.size(-1)

            # Basic attention statistics
            mean_attention = weights.mean().item()
            max_attention = weights.max().item()
            local_attention = torch.diagonal(weights).mean().item()

            # Signed offset query - key (rows are queries, columns are keys).
            positions = torch.arange(seq_len, device=weights.device)
            offsets = positions.unsqueeze(1) - positions.unsqueeze(0)
            avg_distance = (offsets.abs().float() * weights).sum() / weights.sum()

            # Mean weight per causal distance d = 0..seq_len-1. Keys after the
            # query (offset < 0) are masked in a causal model; including their
            # zeros would halve every d > 0 bucket but not d = 0 and distort
            # the decay fit. One bincount pass replaces a mask + GPU sync per
            # distance.
            causal = offsets >= 0
            causal_offsets = offsets[causal]
            totals = torch.bincount(
                causal_offsets, weights=weights[causal], minlength=seq_len
            )
            counts = torch.bincount(causal_offsets, minlength=seq_len)
            distance_profile = (totals / counts).tolist()

            # Compute decay coefficient using exponential fit
            decay_coefficient = 0.0
            if len(distance_profile) >= 2:
                decay_coefficient = self._fit_exponential_decay(distance_profile)

            return {
                "mean_attention": mean_attention,
                "max_attention": max_attention,
                "local_attention": local_attention,
                "avg_distance": avg_distance.item(),
                "decay_rate": decay_coefficient,
            }

        except Exception as e:
            self.logger.error(f"Error in attention metrics computation: {str(e)}")
            return {
                "mean_attention": 0.0,
                "max_attention": 0.0,
                "local_attention": 0.0,
                "avg_distance": 0.0,
                "decay_rate": 0.0,
            }

    def phase_shift_validation(self, text: str, shift_tokens: int = 5):
        """Compare attention among the original tokens with and without a prefix.

        ``shift_tokens`` EOS tokens are prepended. For each layer, the
        shifted run's attention restricted to the original tokens is compared
        with the unshifted attention by cosine similarity. Each head's matrix
        is flattened first, and the result is averaged over heads.
        Note: in the shifted run, queries also attend to the prepended tokens,
        so rows of the aligned block need not sum to 1.

        Returns:
            ``phase_consistency_per_layer`` (one float per layer) and their
            mean ``mean_phase_consistency``.
        """
        # Original tokens
        orig_tokens = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512
        ).to(self.device)
        # Shifted tokens: prepend EOS tokens to move every original position.
        prepend = [self.tokenizer.eos_token_id] * shift_tokens
        shifted_input_ids = torch.cat(
            [
                torch.tensor(prepend, dtype=torch.long).unsqueeze(0).to(self.device),
                orig_tokens.input_ids,
            ],
            dim=1,
        )
        # The prepended tokens are real (attended) positions.
        shifted_attention_mask = torch.cat(
            [
                torch.ones(shift_tokens, dtype=torch.long).unsqueeze(0).to(self.device),
                orig_tokens.attention_mask,
            ],
            dim=1,
        )

        shifted_tokens = {
            "input_ids": shifted_input_ids,
            "attention_mask": shifted_attention_mask,
        }

        with torch.no_grad():
            orig_outputs = self.model(**orig_tokens, output_attentions=True)
            shifted_outputs = self.model(**shifted_tokens, output_attentions=True)

        phase_consistency = []
        for orig_attn, shifted_attn in zip(
            orig_outputs.attentions, shifted_outputs.attentions
        ):
            # Align the shifted attention to compare corresponding positions
            aligned_shifted = shifted_attn[:, :, shift_tokens:, shift_tokens:]
            corr = (
                torch.nn.functional.cosine_similarity(
                    orig_attn.flatten(2), aligned_shifted.flatten(2), dim=-1
                )
                .mean()
                .item()
            )
            phase_consistency.append(corr)

        return {
            "phase_consistency_per_layer": phase_consistency,
            "mean_phase_consistency": float(np.mean(phase_consistency)),
        }

    def position_resolution_test(self, text: str):
        """Probe last-layer sensitivity to swapping two adjacent tokens.

        The tokens at ``seq_len // 4`` and ``seq_len // 4 + 1`` are swapped
        (input truncated to 64 tokens). If they are identical, the swap is a
        no-op.

        Returns:
            None for inputs shorter than 4 tokens; otherwise
            ``position_sensitivity`` (mean absolute change of last-layer
            attention) and ``positional_entropy`` (entropy of the head- and
            query-averaged attention over key positions).
        """
        tokens = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=64
        ).to(self.device)
        with torch.no_grad():
            orig_out = self.model(**tokens, output_attentions=True)
        # Last layer, first batch element: [num_heads, seq_len, seq_len].
        orig_attn = orig_out.attentions[-1][0]

        seq_len = tokens.input_ids.shape[-1]
        if seq_len < 4:
            return None
        pert_tokens = tokens.input_ids.clone()
        i, j = seq_len // 4, seq_len // 4 + 1
        # The advanced-index RHS materialises a copy. A tuple swap of
        # t[0, i] / t[0, j] would assign through views and leave both
        # positions holding the original token j.
        pert_tokens[0, [i, j]] = pert_tokens[0, [j, i]]

        pert_tokens_dict = {
            "input_ids": pert_tokens,
            "attention_mask": tokens.attention_mask,
        }
        with torch.no_grad():
            pert_out = self.model(**pert_tokens_dict, output_attentions=True)
        pert_attn = pert_out.attentions[-1][0]

        diff = (orig_attn - pert_attn).abs().mean().item()

        attn_flat = orig_attn.mean(0)  # average across heads
        attn_distribution = attn_flat.mean(0)  # average attention over queries
        attn_distribution = attn_distribution / (attn_distribution.sum() + 1e-10)
        pos_entropy = entropy(attn_distribution.cpu().numpy())

        return {"position_sensitivity": diff, "positional_entropy": float(pos_entropy)}

    def visualize_results(self, results: List[Dict]):
        """Draw a 2x2 grid of summary plots; returns the figure or None if empty."""
        if not results:
            self.logger.warning("No results to visualize.")
            return None

        fig, axes = plt.subplots(2, 2, figsize=(20, 16))

        self._plot_attention_distance_heatmap(axes[0, 0], results)
        self._plot_local_global_distribution(axes[0, 1], results)
        self._plot_layer_evolution(axes[1, 0], results)
        self._plot_attention_decay(axes[1, 1], results)

        plt.tight_layout()
        return fig

    def _plot_attention_distance_heatmap(self, ax, results):
        """Heatmap of ``avg_distance`` per (layer, head), averaged over samples."""
        distances = np.zeros((self.num_layers, self.num_heads))

        for result in results:
            for layer_idx in range(self.num_layers):
                layer_key = f"layer_{layer_idx}"
                if layer_key in result:
                    for head_idx in range(self.num_heads):
                        head_key = f"head_{head_idx}"
                        if head_key in result[layer_key]:
                            distances[layer_idx, head_idx] += result[layer_key][
                                head_key
                            ]["avg_distance"]

        distances /= len(results)

        im = ax.imshow(distances, cmap="viridis")
        plt.colorbar(im, ax=ax)
        ax.set_title("Average Attention Distance")
        ax.set_xlabel("Head")
        ax.set_ylabel("Layer")

    def _plot_local_global_distribution(self, ax, results):
        """Scatter of diagonal / mean attention ratio per layer, coloured by head."""
        local_ratios = []
        layers = []
        heads = []

        for result in results:
            for layer_idx in range(self.num_layers):
                layer_key = f"layer_{layer_idx}"
                if layer_key in result:
                    for head_idx in range(self.num_heads):
                        head_key = f"head_{head_idx}"
                        if head_key in result[layer_key]:
                            data = result[layer_key][head_key]
                            local_ratio = data["local_attention"] / (
                                data["mean_attention"] + 1e-10
                            )
                            local_ratios.append(local_ratio)
                            layers.append(layer_idx)
                            heads.append(head_idx)

        if local_ratios:
            scatter = ax.scatter(
                layers, local_ratios, c=heads, cmap="viridis", alpha=0.6
            )
            plt.colorbar(scatter, ax=ax, label="Head Index")
            ax.set_title("Local vs Global Attention Distribution")
            ax.set_xlabel("Layer")
            ax.set_ylabel("Local/Global Attention Ratio")

    def _plot_layer_evolution(self, ax, results):
        """Mean +/- std of ``avg_distance`` over heads and samples, per layer."""
        layer_means = []
        layer_stds = []

        for layer_idx in range(self.num_layers):
            layer_distances = []
            for result in results:
                layer_key = f"layer_{layer_idx}"
                if layer_key in result:
                    for head_key in result[layer_key]:
                        dist = result[layer_key][head_key]["avg_distance"]
                        layer_distances.append(dist)

            if layer_distances:
                layer_means.append(np.mean(layer_distances))
                layer_stds.append(np.std(layer_distances))

        if layer_means:
            layers = range(len(layer_means))
            ax.plot(layers, layer_means, "b-", label="Mean Distance")
            ax.fill_between(
                layers,
                np.array(layer_means) - np.array(layer_stds),
                np.array(layer_means) + np.array(layer_stds),
                alpha=0.3,
            )
            ax.set_title("Attention Distance Evolution")
            ax.set_xlabel("Layer")
            ax.set_ylabel("Average Distance")
            ax.legend()

    def _plot_attention_decay(self, ax, results):
        """Scatter of fitted ``decay_rate`` per layer, coloured by head index."""
        decay_rates = []
        layers_ = []
        heads_ = []

        for result in results:
            for layer_idx in range(self.num_layers):
                layer_key = f"layer_{layer_idx}"
                if layer_key in result:
                    for head_key, head_data in result[layer_key].items():
                        head_idx = int(head_key.split("_")[1])
                        decay_rates.append(head_data["decay_rate"])
                        layers_.append(layer_idx)
                        heads_.append(head_idx)

        if decay_rates:
            scatter = ax.scatter(
                layers_, decay_rates, c=heads_, cmap="viridis", alpha=0.6
            )
            plt.colorbar(scatter, ax=ax, label="Head Index")
            ax.set_title("Attention Decay Patterns")
            ax.set_xlabel("Layer")
            ax.set_ylabel("Decay Rate")


def main():
    """Run the RoPE attention analysis end to end.

    Analyses 500 samples with ``RoPEAnalyzer``, shows the summary figure when
    one is produced, and saves the raw per-sample results with ``torch.save``.
    """
    analyzer = RoPEAnalyzer()
    collected = analyzer.analyze_dataset(num_samples=500)

    figure = analyzer.visualize_results(collected)
    if figure is not None:
        plt.show()

    output_path = "rope_analysis_results.pt"
    torch.save(collected, output_path)
    print("\nAnalysis complete. Results saved to 'rope_analysis_results.pt'")


if __name__ == "__main__":
    main()

In [None]:
class FrequencyAnalyzer:
    """
    A comprehensive analyzer for studying RoPE attention patterns in the frequency domain.
    Handles frequency analysis, interference detection, and visualization of results.
    """

    def __init__(self, model_name: str = "meta-llama/Llama-3.2-1B"):
        """Initialize analyzer with model and logging setup.

        Loads the tokenizer and the causal LM for ``model_name``. If the
        tokenizer has no pad token, EOS is used so batches can be padded. The
        model uses float16 on CUDA and float32 on CPU and is put in eval mode.
        """
        logging.basicConfig(
            level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
        )
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing FrequencyAnalyzer...")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.logger.info(f"Using device: {self.device}")

        # tokenizer with padding
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Eager attention is requested explicitly, as RoPEAnalyzer does,
        # because analyze_dataset reads attention probabilities through
        # output_attentions=True.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True,
            attn_implementation="eager",
        ).to(self.device)

        self.model.eval()
        self.num_heads = self.model.config.num_attention_heads
        self.num_layers = self.model.config.num_hidden_layers

        self.logger.info(
            f"Model loaded: {self.num_layers} layers, {self.num_heads} heads"
        )

    def analyze_frequency_bands(self, attention_patterns: torch.Tensor) -> Dict:
        """Analyze frequency components of attention patterns."""
        spectral_results = {}
        patterns = attention_patterns.cpu().numpy()

        if len(patterns.shape) != 3:
            return {}

        num_heads, seq_len, _ = patterns.shape
        window = signal.windows.hann(seq_len)

        for head_idx in range(num_heads):
            try:
                # Average across rows and apply window
                head_pattern = patterns[head_idx].mean(axis=0) * window

                # Compute FFT with padding
                n_fft = 2 ** np.ceil(np.log2(seq_len)).astype(int)
                padded_pattern = np.pad(head_pattern, (0, n_fft - seq_len))

                # Compute normalized FFT and power spectrum
                fft_vals = np.abs(fft(padded_pattern)[: n_fft // 2])
                freq_scale = np.linspace(0, 1, n_fft // 2)
                psd = (fft_vals**2) / (n_fft * np.sum(window**2))
                total_power = np.sum(psd) + 1e-10

                # Calculate bands
                low_idx = max(1, int(0.25 * len(freq_scale)))
                mid_idx = max(low_idx + 1, int(0.75 * len(freq_scale)))

                spectral_results[f"head_{head_idx}"] = {
                    "band_powers": {
                        "low_band": float(np.sum(psd[:low_idx]) / total_power),
                        "mid_band": float(np.sum(psd[low_idx:mid_idx]) / total_power),
                        "high_band": float(np.sum(psd[mid_idx:]) / total_power),
                    },
                    "mean_selectivity": float(np.max(psd) / (np.mean(psd) + 1e-10)),
                    "spectral_entropy": float(entropy(psd / total_power)),
                    "peak_frequency": float(freq_scale[np.argmax(psd)]),
                }

            except Exception as e:
                self.logger.warning(f"Error processing head {head_idx}: {str(e)}")
                continue

        return spectral_results

    def analyze_cross_head_interference(self, attention_patterns: torch.Tensor) -> Dict:
        """Analyze interference between attention heads."""
        interference_results = {}
        patterns = attention_patterns.cpu().numpy()

        if len(patterns.shape) != 3:
            return {}

        num_heads, seq_len, _ = patterns.shape
        window = signal.windows.hann(seq_len)

        # Properly reshape window for broadcasting
        window = window.reshape(1, -1) 
        # Average patterns and apply window
        head_patterns = patterns.mean(axis=2)  # Shape: (num_heads, seq_len)
        head_patterns = head_patterns * window 

        for i in range(num_heads):
            for j in range(i + 1, num_heads):
                try:
                    n_fft = 2 ** np.ceil(np.log2(seq_len)).astype(int)
                    fft_i = fft(head_patterns[i], n=n_fft)[: n_fft // 2]
                    fft_j = fft(head_patterns[j], n=n_fft)[: n_fft // 2]

                    psd_i = np.abs(fft_i) ** 2
                    psd_j = np.abs(fft_j) ** 2
                    csd = fft_i * np.conj(fft_j)

                    coherence = np.abs(csd) / np.sqrt((psd_i * psd_j) + 1e-10)

                    interference_results[f"heads_{i}_{j}"] = {
                        "mean_coherence": float(np.mean(coherence)),
                        "peak_coherence": float(np.max(coherence)),
                        "interference_strength": float(np.mean(np.abs(csd))),
                    }

                except Exception as e:
                    self.logger.warning(f"Error processing head pair {i},{j}: {str(e)}")
                    continue

        return interference_results

    def analyze_dataset(self, num_samples: int = 500) -> Dict:
        """Analyze up to ``num_samples`` BookCorpus texts across all layers.

        Texts with 10-100 whitespace-separated words are streamed from the
        ``bookcorpus`` train split. They are processed in batches of 5, padded
        and truncated to 128 tokens. For each layer, attention is averaged over
        the batch and passed to ``analyze_frequency_bands`` and
        ``analyze_cross_head_interference``. Failed batches are logged and
        skipped. Every collected text is processed, including a trailing
        partial batch.

        Returns:
            The ``_aggregate_all_layers`` dict extended with
            ``band_separation`` and ``frequency_response_evolution``.

        Raises:
            ValueError: if no batch could be analyzed.
        """
        self.logger.info(f"Starting analysis of {num_samples} samples...")
        dataset = load_dataset("bookcorpus", split="train", streaming=True)

        layer_spectral_results = [[] for _ in range(self.num_layers)]
        layer_interference_results = [[] for _ in range(self.num_layers)]

        successful_batches = 0
        batch_size = 5
        texts = []
        # Guarded so that num_samples <= 0 does not stream the entire corpus.
        if num_samples > 0:
            with tqdm(desc="Collecting samples", total=num_samples) as pbar:
                for item in dataset:
                    if 10 <= len(item["text"].split()) <= 100:
                        texts.append(item["text"])
                        pbar.update(1)
                        if len(texts) >= num_samples:
                            break

        with tqdm(total=len(texts), desc="Analyzing patterns") as pbar:
            for start in range(0, len(texts), batch_size):
                batch_texts = texts[start : start + batch_size]

                try:
                    tokens = self.tokenizer(
                        batch_texts,
                        padding=True,
                        truncation=True,
                        max_length=128,
                        return_tensors="pt",
                    ).to(self.device)

                    with torch.no_grad():
                        outputs = self.model(**tokens, output_attentions=True)

                    for layer_idx, layer_attn in enumerate(outputs.attentions):
                        # layer_attn shape: [batch_size, num_heads, seq_len, seq_len].
                        # NOTE(review): padded positions of shorter texts are
                        # included in this batch average.
                        avg_layer_attn = layer_attn.mean(0)

                        spectral_result = self.analyze_frequency_bands(avg_layer_attn)
                        interference_result = self.analyze_cross_head_interference(
                            avg_layer_attn
                        )

                        if spectral_result and interference_result:
                            layer_spectral_results[layer_idx].append(spectral_result)
                            layer_interference_results[layer_idx].append(
                                interference_result
                            )
                    successful_batches += 1

                except Exception as e:
                    self.logger.error(f"Error in batch {start//batch_size}: {str(e)}")
                finally:
                    # Release memory and advance the bar whether or not the
                    # batch succeeded, so progress always reaches len(texts).
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    gc.collect()
                    pbar.update(len(batch_texts))

        if successful_batches == 0:
            raise ValueError("No samples were successfully analyzed")

        self.logger.info(f"Successfully analyzed {successful_batches} batches")

        results = self._aggregate_all_layers(
            layer_spectral_results, layer_interference_results
        )

        results["band_separation"] = self.band_separation_measurement(
            results["spectral_metrics"]
        )

        # per-layer spectral data in results['layer_spectral_aggregates']
        results["frequency_response_evolution"] = self.frequency_response_evolution(
            results["layer_spectral_aggregates"]
        )

        return results

    def _aggregate_all_layers(
        self,
        layer_spectral_results: List[List[Dict]],
        layer_interference_results: List[List[Dict]],
    ) -> Dict:
        """
        Aggregate results across layers and samples.
        it produce:
        - 'spectral_metrics': aggregated across all layers and samples (averaged),
        - 'interference_metrics': aggregated across all layers,
        - 'layer_spectral_aggregates': store per-layer averaged band powers for frequency response evolution.
        """
        all_spectral = []
        for layer_data in layer_spectral_results:
            all_spectral.extend(layer_data)

        all_interference = []
        for layer_data in layer_interference_results:
            all_interference.extend(layer_data)

        spectral_metrics = self._aggregate_spectral(all_spectral)
        interference_metrics = self._aggregate_interference(all_interference)

        layer_spectral_aggregates = []
        for layer_data in layer_spectral_results:
            # layer_data is a list of spectral_results (one per sample)
            layer_agg = self._aggregate_spectral(layer_data)
            layer_spectral_aggregates.append(layer_agg)

        return {
            "spectral_metrics": spectral_metrics,
            "interference_metrics": interference_metrics,
            "layer_spectral_aggregates": layer_spectral_aggregates,
        }

    def _aggregate_spectral(self, spectral_results: List[Dict]) -> Dict:
        """Aggregate spectral results across multiple samples."""
        aggregated_spectral = {}

        # spectral_results is a list of dicts, each with 'head_{i}' keys
        for sr in spectral_results:
            for head_key, head_data in sr.items():
                if head_key not in aggregated_spectral:
                    aggregated_spectral[head_key] = {
                        "band_powers": {
                            "low_band": [],
                            "mid_band": [],
                            "high_band": [],
                        },
                        "mean_selectivity": [],
                        "spectral_entropy": [],
                    }
                aggregated_spectral[head_key]["band_powers"]["low_band"].append(
                    head_data["band_powers"]["low_band"]
                )
                aggregated_spectral[head_key]["band_powers"]["mid_band"].append(
                    head_data["band_powers"]["mid_band"]
                )
                aggregated_spectral[head_key]["band_powers"]["high_band"].append(
                    head_data["band_powers"]["high_band"]
                )
                aggregated_spectral[head_key]["mean_selectivity"].append(
                    head_data["mean_selectivity"]
                )
                aggregated_spectral[head_key]["spectral_entropy"].append(
                    head_data["spectral_entropy"]
                )

        # Average over all samples
        for head_key in aggregated_spectral:
            aggregated_spectral[head_key] = {
                "band_powers": {
                    "low_band": float(
                        np.mean(
                            aggregated_spectral[head_key]["band_powers"]["low_band"]
                        )
                    ),
                    "mid_band": float(
                        np.mean(
                            aggregated_spectral[head_key]["band_powers"]["mid_band"]
                        )
                    ),
                    "high_band": float(
                        np.mean(
                            aggregated_spectral[head_key]["band_powers"]["high_band"]
                        )
                    ),
                },
                "mean_selectivity": float(
                    np.mean(aggregated_spectral[head_key]["mean_selectivity"])
                ),
                "spectral_entropy": float(
                    np.mean(aggregated_spectral[head_key]["spectral_entropy"])
                ),
            }

        return aggregated_spectral

    def _aggregate_interference(self, interference_results: List[Dict]) -> Dict:
        """Aggregate interference results across multiple samples."""
        aggregated_interference = {}
        for ir in interference_results:
            for key, values in ir.items():
                if key not in aggregated_interference:
                    aggregated_interference[key] = {k: [] for k in values.keys()}
                for k, v in values.items():
                    aggregated_interference[key][k].append(v)

        for key in aggregated_interference:
            aggregated_interference[key] = {
                k: float(np.mean(v)) for k, v in aggregated_interference[key].items()
            }

        return aggregated_interference

    def band_separation_measurement(self, metrics: Dict) -> Dict:
        """
        Compute correlation between heads' band distributions to measure band separation.
        Lower correlation between heads would indicate better band specialization.
        """
        heads = sorted(metrics.keys(), key=lambda x: int(x.split("_")[1]))
        band_vectors = []
        for h in heads:
            bp = metrics[h]["band_powers"]
            vec = [bp["low_band"], bp["mid_band"], bp["high_band"]]
            band_vectors.append(vec)

        band_vectors = np.array(band_vectors)
        # Compute correlation matrix between heads
        corr_matrix = np.corrcoef(band_vectors, rowvar=True)

        # Compute average off-diagonal correlation
        n = len(heads)
        off_diag_sum = np.sum(corr_matrix) - np.trace(corr_matrix)
        avg_inter_head_corr = off_diag_sum / (n * (n - 1))

        return {
            "inter_head_band_correlation": float(avg_inter_head_corr),
            "correlation_matrix": corr_matrix.tolist(),
        }

    def frequency_response_evolution(
        self, layer_spectral_aggregates: List[Dict]
    ) -> Dict:
        """
        Summarize how frequency band power shifts with depth.

        Each entry of ``layer_spectral_aggregates`` maps head names to spectral
        metrics for one layer. For every layer, each band's power is averaged
        over that layer's heads, giving one series per band indexed by layer.

        Returns:
            Dict with keys ``low_band_mean_per_layer``,
            ``mid_band_mean_per_layer`` and ``high_band_mean_per_layer``, each a
            list of floats (one per layer, in input order).
        """
        band_names = ("low_band", "mid_band", "high_band")
        evolution = {f"{band}_mean_per_layer": [] for band in band_names}

        for per_head in layer_spectral_aggregates:
            for band in band_names:
                powers = [head["band_powers"][band] for head in per_head.values()]
                evolution[f"{band}_mean_per_layer"].append(float(np.mean(powers)))

        return evolution

    def visualize_results(self, results: Dict):
        """
        Build a 2x3 summary figure of the frequency analysis.

        Panels are drawn only when their data is present and non-empty:
        [0, 0] band distribution, [0, 1] head selectivity, [0, 2] spectral
        entropy (all from ``"spectral_metrics"``), [1, 0] cross-head
        interference, [1, 1] inter-head band correlation, [1, 2] band-power
        evolution across layers.

        Args:
            results: Analysis output; must contain ``"spectral_metrics"``.

        Returns:
            The matplotlib Figure, or None when ``results`` is empty or lacks
            ``"spectral_metrics"``.

        Raises:
            Re-raises any plotting error after logging it. The partially drawn
            figure is closed first so it does not stay open in pyplot.
        """
        if not results or "spectral_metrics" not in results:
            self.logger.error("No valid results to visualize")
            return None

        fig = None
        try:
            fig, axes = plt.subplots(2, 3, figsize=(24, 16))
            fig.suptitle("Frequency Analysis of Attention Patterns", fontsize=16)

            spectral = results["spectral_metrics"]
            if spectral:
                self._plot_band_distribution(axes[0, 0], spectral)
                self._plot_head_selectivity(axes[0, 1], spectral)
                self._plot_spectral_entropy(axes[0, 2], spectral)

            # Truthiness checks (not just key presence): an empty/None section
            # would otherwise crash its plotter, e.g. on "correlation_matrix".
            interference = results.get("interference_metrics")
            if interference:
                self._plot_interference_patterns(axes[1, 0], interference)

            band_separation = results.get("band_separation")
            if band_separation:
                self._plot_band_correlation(axes[1, 1], band_separation)

            evolution = results.get("frequency_response_evolution")
            if evolution:
                self._plot_frequency_evolution(axes[1, 2], evolution)

            # Leave room at the top for the suptitle.
            fig.tight_layout(rect=[0, 0.03, 1, 0.95])
            return fig

        except Exception as e:
            self.logger.error(f"Error in visualization: {str(e)}")
            if fig is not None:
                plt.close(fig)
            raise

    def _plot_band_distribution(self, ax, metrics):
        """Plot frequency band distribution."""
        heads = sorted(metrics.keys(), key=lambda x: int(x.split("_")[1]))
        x = np.arange(len(heads))
        width = 0.25

        low_band = [metrics[h]["band_powers"]["low_band"] for h in heads]
        mid_band = [metrics[h]["band_powers"]["mid_band"] for h in heads]
        high_band = [metrics[h]["band_powers"]["high_band"] for h in heads]

        ax.bar(x - width, low_band, width, label="Low Frequency (0-0.25)")
        ax.bar(x, mid_band, width, label="Mid Frequency (0.25-0.75)")
        ax.bar(x + width, high_band, width, label="High Frequency (0.75-1.0)")

        ax.set_title("Frequency Band Distribution Across Heads")
        ax.set_xlabel("Head Index")
        ax.set_ylabel("Normalized Band Power")
        ax.legend()
        ax.grid(True, alpha=0.3)

    def _plot_head_selectivity(self, ax, metrics):
        """
        Draw each head's mean selectivity as a bar colored by its value.

        Heads are ordered by the integer after the first underscore in their
        key. Colors come from viridis scaled to the min/max score, and a
        matching colorbar is attached to ``ax``.
        """
        ordered = sorted(metrics, key=lambda name: int(name.split("_")[1]))
        scores = [metrics[name]["mean_selectivity"] for name in ordered]

        bars = ax.bar(range(len(ordered)), scores)

        scale = plt.Normalize(min(scores), max(scores))
        for patch, rgba in zip(bars, plt.cm.viridis(scale(scores))):
            patch.set_color(rgba)

        ax.set_title("Frequency Selectivity by Attention Head")
        ax.set_xlabel("Head Index")
        ax.set_ylabel("Selectivity Score")
        ax.grid(True, alpha=0.3)

        mappable = plt.cm.ScalarMappable(norm=scale, cmap=plt.cm.viridis)
        plt.colorbar(mappable, ax=ax, label="Selectivity Level")

    def _plot_interference_patterns(self, ax, metrics):
        """
        Draw a symmetric heatmap of mean coherence between head pairs.

        Each key is split on ``_``. The two fields after the first are parsed
        as head indices ``i`` and ``j``, and ``mean_coherence`` fills both
        (i, j) and (j, i) of a ``self.num_heads`` square matrix; unset cells,
        including the diagonal, stay 0. Off-diagonal cells are annotated with
        their value.
        """
        size = self.num_heads
        coherence = np.zeros((size, size))

        for pair_key, stats in metrics.items():
            row, col = (int(part) for part in pair_key.split("_")[1:])
            value = stats["mean_coherence"]
            coherence[row, col] = coherence[col, row] = value

        image = ax.imshow(coherence, cmap="viridis", aspect="auto")
        plt.colorbar(image, ax=ax, label="Coherence")

        ax.set_title("Cross-Head Interference Patterns")
        ax.set_xlabel("Head Index")
        ax.set_ylabel("Head Index")

        for row in range(size):
            for col in range(size):
                if row == col:
                    continue
                value = coherence[row, col]
                ax.text(
                    col,
                    row,
                    f"{value:.2f}",
                    ha="center",
                    va="center",
                    color="white" if value > 0.5 else "black",
                    fontsize=8,
                )

    def _plot_spectral_entropy(self, ax, metrics):
        """
        Draw each head's spectral entropy as a bar colored by its value.

        Heads are ordered by the integer after the first underscore in their
        key. Colors come from viridis scaled to the min/max entropy, and a
        matching colorbar is attached to ``ax``.
        """
        ordered = sorted(metrics, key=lambda name: int(name.split("_")[1]))
        values = [metrics[name]["spectral_entropy"] for name in ordered]

        bars = ax.bar(range(len(ordered)), values)

        scale = plt.Normalize(min(values), max(values))
        for patch, rgba in zip(bars, plt.cm.viridis(scale(values))):
            patch.set_color(rgba)

        ax.set_title("Spectral Entropy by Attention Head")
        ax.set_xlabel("Head Index")
        ax.set_ylabel("Entropy")
        ax.grid(True, alpha=0.3)

        mappable = plt.cm.ScalarMappable(norm=scale, cmap=plt.cm.viridis)
        plt.colorbar(mappable, ax=ax, label="Entropy Level")

    def _plot_band_correlation(self, ax, band_sep_results: Dict):
        """
        Draw the head-by-head band correlation matrix as an annotated heatmap.

        Reads ``band_sep_results["correlation_matrix"]`` (nested lists) and
        writes each off-diagonal value into its cell.
        """
        matrix = np.array(band_sep_results["correlation_matrix"])
        image = ax.imshow(matrix, cmap="viridis", aspect="auto")
        plt.colorbar(image, ax=ax, label="Correlation")

        ax.set_title("Inter-Head Band Correlation")
        ax.set_xlabel("Head Index")
        ax.set_ylabel("Head Index")

        size = matrix.shape[0]
        for row in range(size):
            for col in range(size):
                if row == col:
                    continue
                value = matrix[row, col]
                ax.text(
                    col,
                    row,
                    f"{value:.2f}",
                    ha="center",
                    va="center",
                    color="white" if value > 0.5 else "black",
                    fontsize=8,
                )

    def _plot_frequency_evolution(self, ax, freq_evolution: Dict):
        """Plot how band powers evolve across layers."""
        layers = np.arange(len(freq_evolution["low_band_mean_per_layer"]))
        ax.plot(layers, freq_evolution["low_band_mean_per_layer"], label="Low Band")
        ax.plot(layers, freq_evolution["mid_band_mean_per_layer"], label="Mid Band")
        ax.plot(layers, freq_evolution["high_band_mean_per_layer"], label="High Band")

        ax.set_title("Frequency Response Evolution Across Layers")
        ax.set_xlabel("Layer")
        ax.set_ylabel("Average Band Power")
        ax.legend()
        ax.grid(True, alpha=0.3)


def main(num_samples: int = 500):
    """
    Main execution function that runs the complete analysis pipeline.

    This function handles the entire process of:
    1. Initializing the analyzer
    2. Processing the dataset
    3. Generating visualizations
    4. Saving results

    Args:
        num_samples: Number of dataset samples passed to ``analyze_dataset``.

    Returns:
        The results dict when it contains non-empty ``"spectral_metrics"``,
        otherwise None.

    Raises:
        Re-raises any exception from the pipeline after printing it.
    """
    results_path = "rope_frequency_analysis_results.pt"
    try:
        print("Initializing Frequency Analyzer...")
        analyzer = FrequencyAnalyzer()

        print("\nStarting dataset analysis...")
        results = analyzer.analyze_dataset(num_samples=num_samples)

        # .get(): a results dict without "spectral_metrics" is a failed run
        # (visualize_results treats it the same way), not a KeyError.
        spectral_metrics = results.get("spectral_metrics") if results else None
        if not spectral_metrics:
            print("Analysis failed to produce valid results")
            return None

        print("\nGenerating visualizations...")
        fig = analyzer.visualize_results(results)
        if fig is not None:
            plt.show()

        print("\nSaving results...")
        torch.save(results, results_path)

        print("\nAnalysis Summary:")
        print(f"Number of heads analyzed: {len(spectral_metrics)}")

        avg_entropy = np.mean(
            [m["spectral_entropy"] for m in spectral_metrics.values()]
        )
        avg_selectivity = np.mean(
            [m["mean_selectivity"] for m in spectral_metrics.values()]
        )

        print(f"Average spectral entropy: {avg_entropy:.3f}")
        print(f"Average frequency selectivity: {avg_selectivity:.3f}")

        print(f"\nAnalysis complete. Results saved to '{results_path}'")
        return results

    except Exception as e:
        print(f"An error occurred during analysis: {str(e)}")
        raise


if __name__ == "__main__":
    main()

In [None]:
class WaveletAnalyzer:
    """
    Probe a causal LM's attention maps for wavelet-like structure.

    Every test runs the model with ``output_attentions=True``. The returned
    attention maps are averaged over batch and layers (``_get_avg_attention``).
    Each head is then collapsed to a single key-position profile by averaging
    over queries, and that profile is decomposed with a PyWavelets discrete
    wavelet transform (``"db2"`` by default).
    """

    def __init__(self, model_name: str = "meta-llama/Llama-3.2-1B"):
        """Load ``model_name`` (eager attention) and its tokenizer onto GPU if available, else CPU."""
        logging.basicConfig(
            level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
        )
        self.logger = logging.getLogger(__name__)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.logger.info(f"Using device: {self.device}")

        self.logger.info(f"Loading model {model_name}...")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            # float16 on GPU to save memory; downstream math upcasts to float32.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True,
            # Eager attention so output_attentions=True returns attention weights.
            attn_implementation="eager",
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model.eval()

        self.hidden_size = self.model.config.hidden_size
        self.num_heads = self.model.config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_layers = self.model.config.num_hidden_layers

        self.logger.info(
            f"Model loaded: {self.num_layers} layers, {self.num_heads} heads"
        )

    # ------------------------------------------------------------------ #
    # Dataset driver
    # ------------------------------------------------------------------ #

    def analyze_dataset(self, num_samples: int = 500) -> Dict:
        """
        Run all wavelet-property tests over BookCorpus samples and aggregate.

        Steps:
        1. Stream up to ``num_samples`` texts of 10-100 words from BookCorpus.
        2. For each sample, run:
           - Scale Sensitivity Test (with wavelet analysis)
           - Multi-Resolution Analysis (wavelet-based)
           - Uncertainty Principle Validation (positional vs wavelet-band entropy)
           - Frame Completeness Test (wavelet-based reconstruction)
        3. Aggregate results.

        A sample whose tests raise is logged and skipped. Its results are
        recorded only when all four tests succeed, so the per-test lists stay
        aligned sample-for-sample.

        Returns:
            Dict with keys ``"scale_sensitivity"``, ``"multi_resolution"``,
            ``"uncertainty"`` and ``"frame_completeness"``.
        """
        self.logger.info(f"Starting analysis of {num_samples} samples...")
        texts = self._collect_texts(num_samples)
        self.logger.info(f"Collected {len(texts)} suitable text samples")

        scale_sensitivity_results = []
        multi_resolution_results = []
        uncertainty_results = []
        frame_results = []

        with tqdm(total=len(texts), desc="Analyzing samples") as pbar:
            for idx, text in enumerate(texts):
                try:
                    sc_res = self.scale_sensitivity_test(text)
                    mr_res = self.multi_resolution_analysis(text)
                    up_res = self.uncertainty_principle_validation(text)
                    fc_res = self.frame_completeness_test(text)
                except Exception as e:
                    # Best-effort over a large corpus: log the bad sample, move on.
                    self.logger.error(f"Error processing sample {idx}: {str(e)}")
                else:
                    scale_sensitivity_results.append(sc_res)
                    multi_resolution_results.append(mr_res)
                    uncertainty_results.append(up_res)
                    frame_results.append(fc_res)
                finally:
                    # Advance the bar and free memory whether or not the sample failed.
                    pbar.update(1)
                    if (idx + 1) % 50 == 0:
                        if torch.cuda.is_available():
                            torch.cuda.empty_cache()
                        gc.collect()

        return {
            "scale_sensitivity": self._aggregate_scale_sensitivity(
                scale_sensitivity_results
            ),
            "multi_resolution": self._aggregate_multi_resolution(
                multi_resolution_results
            ),
            "uncertainty": self._aggregate_uncertainty(uncertainty_results),
            "frame_completeness": self._aggregate_frame(frame_results),
        }

    def _collect_texts(
        self, num_samples: int, min_words: int = 10, max_words: int = 100
    ) -> List[str]:
        """Stream BookCorpus and return the first ``num_samples`` texts with min..max words."""
        texts: List[str] = []
        if num_samples <= 0:
            # An equality stop test would never fire here and would stream the whole corpus.
            return texts

        dataset = load_dataset("bookcorpus", split="train", streaming=True)
        with tqdm(total=num_samples, desc="Collecting samples") as pbar:
            for item in dataset:
                text = item["text"]
                if min_words <= len(text.split()) <= max_words:
                    texts.append(text)
                    pbar.update(1)
                    if len(texts) >= num_samples:
                        break
        return texts

    # ------------------------------------------------------------------ #
    # Per-sample tests
    # ------------------------------------------------------------------ #

    def scale_sensitivity_test(
        self, text: str, scales: Optional[List[float]] = None
    ) -> Dict:
        """
        Compare a text's wavelet profile with those of token-subsampled copies.

        For a scale ``s < 1`` the input ids are subsampled with stride
        ``int(1 / s)``; scales ``>= 1`` use the full sequence. Similarity is
        computed by ``_compare_wavelet_coefficients`` against the full-input
        coefficients.

        Args:
            text: Input text (not truncated).
            scales: Subsampling factors; defaults to ``[1.0, 0.5, 0.25]``.

        Returns:
            ``{"scale_<s>": {"wavelet_similarity_with_original": float}}``.

        Raises:
            ValueError: if any scale is not positive.
        """
        if scales is None:
            scales = [1.0, 0.5, 0.25]

        tokens = self.tokenizer(text, return_tensors="pt").to(self.device)
        input_ids = tokens.input_ids[0]

        original_coefs = self._wavelet_decompose_attention(
            self._get_avg_attention(tokens)
        )

        scale_results = {}
        for scale in scales:
            if scale <= 0:
                raise ValueError(f"scale must be positive, got {scale}")
            stride = int(1 / scale) if scale < 1.0 else 1

            if stride == 1:
                # Same ids and an all-ones mask as the original pass: reuse it
                # instead of running the model a second time.
                scaled_coefs = original_coefs
            else:
                scaled_ids = input_ids[::stride].unsqueeze(0).to(self.device)
                attn = self._get_avg_attention(
                    {
                        "input_ids": scaled_ids,
                        "attention_mask": torch.ones_like(scaled_ids),
                    }
                )
                scaled_coefs = self._wavelet_decompose_attention(attn)

            similarity = self._compare_wavelet_coefficients(
                original_coefs, scaled_coefs
            )
            scale_results[f"scale_{scale}"] = {
                "wavelet_similarity_with_original": similarity
            }

        return scale_results

    def multi_resolution_analysis(
        self, text: str, window_sizes: Optional[List[int]] = None
    ) -> Dict:
        """
        Measure wavelet band entropies of the input at several window sizes.

        The text (truncated to 256 tokens) is cut into non-overlapping windows
        of each size. Trailing tokens that do not fill a whole window are
        dropped, and a window size >= the input length uses the whole input.
        Band entropies are computed per window and averaged over windows.

        Args:
            text: Input text.
            window_sizes: Window lengths in tokens; defaults to ``[16, 32, 64]``.

        Returns:
            ``{"window_<w>": {head_idx: {"approx_entropy": float,
            "detail_entropies": [float, ...]}}}``.

        Raises:
            ValueError: if any window size is not positive.
        """
        if window_sizes is None:
            window_sizes = [16, 32, 64]

        tokens = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=256
        ).to(self.device)
        input_ids = tokens.input_ids
        input_len = input_ids.shape[-1]

        resolution_results = {}
        for wsize in window_sizes:
            if wsize <= 0:
                raise ValueError(f"window size must be positive, got {wsize}")

            if wsize >= input_len:
                segments = [tokens]
            else:
                segments = []
                for start in range(0, input_len - wsize + 1, wsize):
                    window_ids = input_ids[:, start : start + wsize]
                    segments.append(
                        {
                            "input_ids": window_ids,
                            "attention_mask": torch.ones_like(window_ids),
                        }
                    )

            segment_band_entropies = [
                self._band_entropies(
                    self._wavelet_decompose_attention(self._get_avg_attention(seg))
                )
                for seg in segments
            ]
            resolution_results[f"window_{wsize}"] = self._average_dicts(
                segment_band_entropies
            )

        return resolution_results

    def uncertainty_principle_validation(self, text: str) -> Dict:
        """
        Correlate positional entropy with wavelet detail entropy across heads.

        Per head:
        - positional entropy is the entropy of its query-averaged attention
          profile;
        - spectral entropy is the sum of its wavelet detail-band entropies.

        Returns:
            ``{"pos_spec_correlation": float}``, the Pearson correlation of the
            two quantities across heads. It is NaN when there are fewer than
            two heads or either quantity is constant, because the correlation
            is undefined there.
        """
        tokens = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=128
        ).to(self.device)
        attn = self._get_avg_attention(tokens)

        pos_entropy = self._compute_positional_entropy(attn)
        band_ents = self._band_entropies(self._wavelet_decompose_attention(attn))

        heads = sorted(pos_entropy)
        pos_arr = np.array([pos_entropy[h] for h in heads])
        spec_arr = np.array(
            [np.sum(band_ents[h]["detail_entropies"]) for h in heads]
        )

        if len(heads) < 2 or pos_arr.std() < 1e-10 or spec_arr.std() < 1e-10:
            correlation = float("nan")
        else:
            correlation = float(np.corrcoef(pos_arr, spec_arr)[0, 1])

        return {"pos_spec_correlation": correlation}

    def frame_completeness_test(self, text: str, wavelet: str = "db2") -> Dict:
        """
        Check that each head's wavelet coefficients reconstruct its profile.

        The signal decomposed by ``_wavelet_decompose_attention`` is a head's
        query-averaged attention profile. This inverts the transform with
        ``pywt.waverec`` and measures the mean absolute error against that same
        profile. The result is averaged over heads.

        Args:
            text: Input text (truncated to 128 tokens).
            wavelet: PyWavelets wavelet name used for both directions.

        Returns:
            ``{"reconstruction_error": float}``.
        """
        # Local import: pywt is not part of the notebook's import cell.
        import pywt

        tokens = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=128
        ).to(self.device)
        attn = self._get_avg_attention(tokens)
        coefs = self._wavelet_decompose_attention(attn, wavelet=wavelet)

        seq_len = attn.shape[-1]
        errors = []
        for h in range(attn.shape[0]):
            # Same profile (and precision) that _wavelet_decompose_attention decomposed.
            profile = attn[h].float().mean(0).cpu().numpy()
            recon = pywt.waverec(coefs[f"head_{h}"], wavelet)

            # waverec can return an extra sample for odd lengths; align to seq_len.
            recon = recon[:seq_len]
            if len(recon) < seq_len:
                recon = np.pad(recon, (0, seq_len - len(recon)))

            errors.append(float(np.abs(profile - recon).mean()))

        return {"reconstruction_error": self._nan_safe_mean(errors)}

    # ------------------------------------------------------------------ #
    # Wavelet helpers
    # ------------------------------------------------------------------ #

    def _wavelet_decompose_attention(
        self, attn: torch.Tensor, wavelet: str = "db2", level: Optional[int] = None
    ) -> Dict[str, List[np.ndarray]]:
        """
        Wavelet-decompose each head's query-averaged attention profile.

        Args:
            attn: 3-D tensor unpacked as ``(heads, queries, keys)``.
            wavelet: PyWavelets wavelet name.
            level: Requested depth; ``None`` or anything above the maximum
                useful level for this length is clamped to that maximum.

        Returns:
            ``{"head_<i>": [cA_n, cD_n, cD_{n-1}, ..., cD_1]}`` as returned by
            ``pywt.wavedec`` (approximation first, coarsest detail next).
        """
        # Local import: pywt is not part of the notebook's import cell.
        import pywt

        num_heads, seq_len, _ = attn.shape
        # The depth depends only on seq_len and the filter, so compute it once.
        max_level = pywt.dwt_max_level(seq_len, pywt.Wavelet(wavelet).dec_len)
        lvl = max_level if level is None or level > max_level else level

        results = {}
        for h in range(num_heads):
            # Upcast before averaging: attention may be float16 on GPU.
            head_pattern = attn[h].float().mean(0).cpu().numpy()
            results[f"head_{h}"] = pywt.wavedec(head_pattern, wavelet=wavelet, level=lvl)
        return results

    @staticmethod
    def _distribution_entropy(arr: np.ndarray) -> float:
        """Shannon entropy (nats) of |arr| normalized to sum 1; 0.0 for an all-zero array."""
        magnitudes = np.abs(arr)
        total = magnitudes.sum()
        if total == 0:
            return 0.0
        return float(entropy(magnitudes / total))

    def _band_entropies(
        self, coefs: Dict[str, List[np.ndarray]]
    ) -> Dict[int, Dict[str, object]]:
        """
        Entropy of each wavelet band's coefficient magnitudes, per head.

        Args:
            coefs: ``{"head_<i>": [cA, cD_n, ..., cD_1]}``.

        Returns:
            ``{i: {"approx_entropy": float, "detail_entropies": [float, ...]}}``
            with detail entropies in the same order as the input bands.
        """
        result = {}
        for h_key, c_list in coefs.items():
            h_idx = int(h_key.split("_")[1])
            approx, *details = c_list
            result[h_idx] = {
                "approx_entropy": self._distribution_entropy(approx),
                "detail_entropies": [self._distribution_entropy(d) for d in details],
            }
        return result

    def _compare_wavelet_coefficients(
        self, coefs1: Dict[str, List[np.ndarray]], coefs2: Dict[str, List[np.ndarray]]
    ) -> float:
        """
        Mean per-head Pearson correlation of approximation-coefficient magnitudes.

        Only heads present in both dicts are compared. The two ``cA`` arrays
        are truncated to their common length. A head whose truncated array is
        (near-)constant contributes 0.0. Returns 0.0 when no head is shared.
        """
        correlations = []
        for h_key, bands1 in coefs1.items():
            if h_key not in coefs2:
                continue
            cA1 = np.abs(bands1[0])
            cA2 = np.abs(coefs2[h_key][0])
            # Truncate (not pad) to the shorter length.
            min_len = min(len(cA1), len(cA2))
            cA1 = cA1[:min_len]
            cA2 = cA2[:min_len]

            if np.std(cA1) < 1e-10 or np.std(cA2) < 1e-10:
                corr = 0.0
            else:
                corr = np.corrcoef(cA1, cA2)[0, 1]
            correlations.append(corr)

        if not correlations:
            return 0.0
        return float(np.mean(correlations))

    def _average_dicts(
        self, dicts: List[Dict[int, Dict[str, object]]]
    ) -> Dict[int, Dict[str, object]]:
        """
        Average per-head band entropies over several ``_band_entropies`` outputs.

        Detail-entropy lists may differ in length (different decomposition
        depths); they are truncated to the shortest before averaging.
        """
        all_heads = sorted({h for d in dicts for h in d})

        results = {}
        for h in all_heads:
            present = [d[h] for d in dicts if h in d]
            approx_ents = [entry["approx_entropy"] for entry in present]
            detail_ents_list = [entry["detail_entropies"] for entry in present]

            min_len = min(len(x) for x in detail_ents_list)
            detail_avg = np.mean([x[:min_len] for x in detail_ents_list], axis=0)
            results[h] = {
                "approx_entropy": float(np.mean(approx_ents)),
                "detail_entropies": detail_avg.tolist(),
            }
        return results

    # ------------------------------------------------------------------ #
    # Aggregation over samples
    # ------------------------------------------------------------------ #

    @staticmethod
    def _nan_safe_mean(values) -> float:
        """Mean of the non-NaN entries of ``values``; NaN when there are none (no warning)."""
        arr = np.asarray(values, dtype=float)
        kept = arr[~np.isnan(arr)]
        return float(kept.mean()) if kept.size else float("nan")

    def _aggregate_scale_sensitivity(self, results_list: List[Dict]) -> Dict:
        """
        Average each scale's similarity over samples.

        Scale keys are taken from the first result, so custom ``scales`` passed
        to ``scale_sensitivity_test`` work. Returns ``{}`` for no results.
        """
        if not results_list:
            return {}
        return {
            sc: self._nan_safe_mean(
                [res[sc]["wavelet_similarity_with_original"] for res in results_list]
            )
            for sc in results_list[0]
        }

    def _aggregate_multi_resolution(self, results_list: List[Dict]) -> Dict:
        """
        Per window size, average over samples of the head-mean approximation entropy.

        Window keys are taken from the first result. Returns ``{}`` for no results.
        """
        if not results_list:
            return {}

        aggregated = {}
        for w in results_list[0]:
            # r[w] is {head_idx: {"approx_entropy": ..., "detail_entropies": [...]}}
            per_sample = [
                self._nan_safe_mean([head["approx_entropy"] for head in r[w].values()])
                for r in results_list
                if w in r
            ]
            aggregated[w] = {
                "mean_approx_entropy_across_samples": self._nan_safe_mean(per_sample)
            }
        return aggregated

    def _aggregate_uncertainty(self, results_list: List[Dict]) -> Dict:
        """Mean positional/spectral correlation, skipping samples where it is undefined (NaN)."""
        correlations = [r["pos_spec_correlation"] for r in results_list]
        return {"mean_pos_spec_correlation": self._nan_safe_mean(correlations)}

    def _aggregate_frame(self, results_list: List[Dict]) -> Dict:
        """Mean reconstruction error over samples (NaN when there are none)."""
        errors = [r["reconstruction_error"] for r in results_list]
        return {"mean_reconstruction_error": self._nan_safe_mean(errors)}

    # ------------------------------------------------------------------ #
    # Model / attention helpers
    # ------------------------------------------------------------------ #

    def _get_avg_attention(self, tokens: Dict) -> torch.Tensor:
        """
        Run the model and average its attention maps over batch and layers.

        Assumes each element of ``outputs.attentions`` is
        ``(batch, heads, queries, keys)`` (HF convention); the result is then
        ``(heads, queries, keys)``.
        """
        with torch.no_grad():
            outputs = self.model(**tokens, output_attentions=True)
            layer_attns = [a.mean(0) for a in outputs.attentions]
            return torch.stack(layer_attns, dim=0).mean(0)

    def _compute_positional_entropy(self, attn: torch.Tensor) -> Dict[int, float]:
        """Entropy (nats) of each head's query-averaged, renormalized attention profile."""
        num_heads = attn.shape[0]
        pos_entropy = {}
        for h in range(num_heads):
            # Upcast: attention may be float16 on GPU.
            dist = attn[h].float().mean(0)
            dist = dist / (dist.sum() + 1e-10)
            pos_entropy[h] = float(entropy(dist.cpu().numpy()))
        return pos_entropy


def main(num_samples: int = 500):
    """
    Run the wavelet analysis over BookCorpus samples and print the aggregate.

    Args:
        num_samples: Number of texts requested from the dataset.

    Returns:
        The aggregated results dict from ``WaveletAnalyzer.analyze_dataset``.
    """
    analyzer = WaveletAnalyzer()
    results = analyzer.analyze_dataset(num_samples=num_samples)
    print(f"Aggregated Results over {num_samples} samples:")
    print(results)
    return results


if __name__ == "__main__":
    main()