In [None]:
# Authenticate with the Hugging Face Hub. Presumably required because some checkpoints
# referenced below (e.g. meta-llama/*) are gated and need a logged-in token to download.
from huggingface_hub import notebook_login

notebook_login()

In [None]:
import torch
import numpy as np
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
    BertModel,
)
import matplotlib.pyplot as plt
import logging
from typing import Dict, List, Tuple, Optional
import math
from datasets import load_dataset
from tqdm.notebook import tqdm
import gc
from scipy.optimize import curve_fit
from scipy.stats import entropy

In [None]:
# Checkpoints this notebook has been run with; assign one of them to `model_name` to switch:
#   "meta-llama/Meta-Llama-3.1-8B", "meta-llama/Meta-Llama-3-8B", "EleutherAI/pythia-2.8b",
#   "meta-llama/Llama-2-7b-hf", "google/gemma-2-2b-it", "meta-llama/Llama-3.2-1B",
#   "google/flan-t5-large"
model_name = "google-bert/bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Attention weights are requested below, so ask for the "eager" implementation at load time:
# fused kernels (SDPA / flash attention) do not materialise the attention matrix, and setting
# `attn_implementation` on the config after loading does not switch the already-built layers.
model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager")

model.config.output_attentions = True
model.config.output_hidden_states = True
model.config.return_dict = True
# NOTE: `device_map` is a `from_pretrained(...)` loading argument, not a config field, so
# assigning it on `model.config` never affected placement. The model is moved explicitly with
# `model.to(device)` in the next cell instead.

if tokenizer.pad_token is None:
    # Some causal-LM tokenizers define no pad token; reuse EOS so padded batches work.
    tokenizer.pad_token = tokenizer.eos_token
    # tokenizer.add_special_tokens({'pad_token': '[PAD]'})

In [None]:
# Select the accelerator, disable dropout-style training behaviour, and move the weights.
# The trailing `model.to(device)` is left as the cell's last expression so the model repr
# is displayed, as before.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.eval()
model.to(device)

In [None]:
class PositionalEncodingAnalyzer:
    """Compare attention statistics of transformers that use different position encodings.

    Supported ``model_type`` values and what gets loaded:

    * ``"rope"``  -- ``AutoModelForCausalLM`` (default ``meta-llama/Llama-2-7b-hf``)
    * ``"t5"``    -- ``T5ForConditionalGeneration`` (default ``t5-base``); only the encoder's
      self-attention is analysed
    * ``"bert"``  -- ``BertModel`` (default ``bert-base-uncased``)
    * ``"no_pe"`` -- ``AutoModelForCausalLM`` (default ``gpt2``) with its learned position
      table (``wpe``) zeroed

    Each (layer, head) attention matrix of an analysed text is summarised by
    ``mean_attention``, ``max_attention``, ``local_attention`` (mean of the diagonal),
    ``avg_distance`` (attention-weighted ``|i - j|``) and ``decay_rate`` (rate of an
    exponential fitted to the mean attention per distance).
    """

    # Maximum number of tokens fed to any model (longer texts are truncated).
    MAX_SEQ_LEN = 512

    def __init__(self, model_type: str = "rope", model_name: Optional[str] = None):
        """
        Initialize analyzer for different position encoding types.

        Args:
            model_type: One of ["rope", "t5", "bert", "no_pe"]
            model_name: Optional specific model name, otherwise uses defaults

        Raises:
            ValueError: if ``model_type`` is not one of the supported types.
        """
        logging.basicConfig(
            level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
        )
        self.logger = logging.getLogger(__name__)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.logger.info(f"Using device: {self.device}")

        self.model = None
        self._load_model(model_type, model_name)

    def _load_model(self, model_type: str, model_name: Optional[str] = None) -> None:
        """(Re)load model and tokenizer for ``model_type``, freeing any previous model first."""
        if model_name is None:
            model_name = self._get_default_model_name(model_type)

        self.logger.info(f"Loading {model_type} model: {model_name}")

        # Drop the previous weights *before* loading the next model so that two models are
        # never resident at the same time (matters when switching between 7B checkpoints).
        self._release_model()

        self.model_type = model_type
        self.model = self._initialize_model(model_type, model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        self._setup_model_dimensions()

        self.logger.info(
            f"Model loaded: {self.num_layers} layers, {self.num_heads} heads"
        )

    def _release_model(self) -> None:
        """Drop the reference to the current model (if any) and reclaim cached memory."""
        if getattr(self, "model", None) is None:
            return
        self.model = None
        self._free_cached_memory()

    @staticmethod
    def _free_cached_memory() -> None:
        """Run the garbage collector, then return cached CUDA blocks to the driver."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _get_default_model_name(self, model_type: str) -> str:
        """Get default model name for each type (falls back to the Llama-2 checkpoint)."""
        defaults = {
            "rope": "meta-llama/Llama-2-7b-hf",
            "t5": "t5-base",
            "bert": "bert-base-uncased",
            "no_pe": "gpt2",  # position table is zeroed in _initialize_model
        }
        return defaults.get(model_type, "meta-llama/Llama-2-7b-hf")

    def _initialize_model(self, model_type: str, model_name: str):
        """Load the model for ``model_type`` and move it to ``self.device``.

        Every branch requests ``attn_implementation="eager"``: fused attention kernels
        (SDPA / flash attention) do not materialise the attention matrix this class analyses.

        Raises:
            ValueError: for an unsupported ``model_type``.
        """
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        if model_type == "rope":
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
                attn_implementation="eager",
            )
        elif model_type == "t5":
            model = T5ForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
                attn_implementation="eager",
            )
        elif model_type == "bert":
            model = BertModel.from_pretrained(
                model_name,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
                output_attentions=True,
                attn_implementation="eager",
            )
        elif model_type == "no_pe":
            model = AutoModelForCausalLM.from_pretrained(
                model_name, attn_implementation="eager"
            )
            self._zero_position_embeddings(model)
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

        return model.to(self.device)

    def _zero_position_embeddings(self, model) -> None:
        """Zero a GPT-2-style learned position table (``wpe``) in place, if one exists.

        NOTE(review): the causal mask itself still lets a decoder infer token order, so this
        removes the *explicit* positional signal only.
        """
        backbone = getattr(model, "transformer", model)
        position_table = getattr(backbone, "wpe", None)
        if position_table is None:
            self.logger.warning(
                "No `wpe` position embedding found; positional encoding left intact."
            )
            return
        with torch.no_grad():
            position_table.weight.zero_()

    def _setup_model_dimensions(self):
        """Extract model dimensions from ``self.model.config``."""
        config = self.model.config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_layers = config.num_hidden_layers

    @staticmethod
    def _filter_texts(
        dataset, num_samples: int, min_words: int = 10, max_words: int = 100, on_accept=None
    ) -> List[str]:
        """Collect up to ``num_samples`` texts whose word count lies in [min_words, max_words].

        Args:
            dataset: iterable of mappings with a ``"text"`` key (may be an endless stream).
            num_samples: number of texts wanted; ``<= 0`` returns ``[]`` without iterating.
            min_words, max_words: inclusive bounds on ``len(text.split())``.
            on_accept: optional zero-argument callback invoked once per accepted text.
        """
        texts: List[str] = []
        if num_samples <= 0:
            return texts
        for item in dataset:
            text = item["text"]
            if min_words <= len(text.split()) <= max_words:
                texts.append(text)
                if on_accept is not None:
                    on_accept()
                if len(texts) >= num_samples:
                    break
        return texts

    def analyze_dataset(self, num_samples: int = 500) -> List[Dict]:
        """Analyze up to ``num_samples`` BookCorpus texts with the current model.

        Returns:
            One :meth:`analyze_single_text` result per successfully analysed text.
        """
        self.logger.info(f"Starting analysis of {num_samples} samples...")

        dataset = load_dataset("bookcorpus", split="train", streaming=True)
        with tqdm(total=max(num_samples, 0), desc="Collecting samples") as pbar:
            texts = self._filter_texts(
                dataset, num_samples, on_accept=lambda: pbar.update(1)
            )

        results = []
        with tqdm(total=len(texts), desc="Analyzing samples") as pbar:
            for idx, text in enumerate(texts):
                try:
                    result = self.analyze_single_text(text)
                    if result is not None:
                        results.append(result)
                except Exception as e:
                    self.logger.error(f"Error processing sample {idx}: {str(e)}")
                finally:
                    # Advance even when a sample fails so the bar reaches its total.
                    pbar.update(1)

                if (idx + 1) % 50 == 0:
                    self._free_cached_memory()

        return results

    def analyze_single_text(self, text: str) -> Optional[Dict]:
        """Compute per-head attention metrics for one text.

        Returns:
            ``{"layer_<i>": {"head_<h>": metrics}}`` where ``metrics`` comes from
            :meth:`_compute_attention_metrics`, or ``None`` if tokenization or the forward
            pass failed (the error is logged).
        """
        try:
            tokens = self._tokenize_text(text)

            with torch.no_grad():
                attention_patterns = self._get_attention_patterns(tokens)

            if attention_patterns is None:
                return None

            results = {}
            for layer_idx, layer_attention in enumerate(attention_patterns):
                # A single text is tokenized, so take batch element 0 -> (heads, query, key).
                per_head = layer_attention[0]
                results[f"layer_{layer_idx}"] = {
                    f"head_{head_idx}": self._compute_attention_metrics(per_head[head_idx])
                    for head_idx in range(per_head.size(0))
                }

            return results

        except Exception as e:
            self.logger.error(f"Error analyzing text: {str(e)}")
            return None

    def _tokenize_text(self, text: str):
        """Tokenize one text for the current model type (truncated, never padded).

        A single sequence needs no padding; padding to ``MAX_SEQ_LEN`` would add pad-token
        query rows that dominate every attention statistic and make BERT/T5 incomparable
        with the unpadded decoder models.

        Raises:
            ValueError: for an unsupported ``self.model_type``.
        """
        if self.model_type not in ("rope", "no_pe", "t5", "bert"):
            raise ValueError(f"Unsupported model type: {self.model_type}")

        encode_kwargs = {
            "return_tensors": "pt",
            "truncation": True,
            "max_length": self.MAX_SEQ_LEN,
        }
        if self.model_type == "bert":
            encode_kwargs["return_token_type_ids"] = True

        return self.tokenizer(text, **encode_kwargs).to(self.device)

    def _get_attention_patterns(self, tokens):
        """Return the per-layer attention tuple for ``tokens``, or ``None`` on failure."""
        try:
            if self.model_type == "t5":
                # Only encoder self-attention is analysed, so run just the encoder stack
                # rather than a full encoder-decoder forward pass.
                encoder_outputs = self.model.get_encoder()(
                    input_ids=tokens.input_ids,
                    attention_mask=tokens.attention_mask,
                    output_attentions=True,
                    return_dict=True,
                )
                return encoder_outputs.attentions

            # bert, rope and no_pe expose self-attention directly on the model output.
            outputs = self.model(**tokens, output_attentions=True)
            return outputs.attentions
        except Exception as e:
            self.logger.error(f"Error getting attention patterns: {str(e)}")
            return None

    def _compute_attention_metrics(self, attention_weights: torch.Tensor) -> Dict:
        """Summarise one head's (query, key) attention matrix.

        Returns:
            Dict with ``mean_attention``, ``max_attention``, ``local_attention`` (mean of the
            diagonal), ``avg_distance`` (weighted mean of ``|i - j|``; 0.0 when all weights
            are zero) and ``decay_rate``. On any error every value is 0.0 and the error is
            logged.
        """
        try:
            weights = attention_weights.float()
            if weights.dim() == 3:
                weights = weights.squeeze(0)

            seq_len = weights.size(-1)
            positions = torch.arange(seq_len, device=weights.device)
            distances = torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1))

            total_weight = weights.sum()
            if total_weight.item() > 0:
                avg_distance = ((distances.float() * weights).sum() / total_weight).item()
            else:
                # Avoid 0/0 = NaN for an all-zero matrix.
                avg_distance = 0.0

            distance_profile = self._compute_distance_profile(weights, distances)

            return {
                "mean_attention": weights.mean().item(),
                "max_attention": weights.max().item(),
                "local_attention": torch.diagonal(weights).mean().item(),
                "avg_distance": avg_distance,
                "decay_rate": self._fit_exponential_decay(distance_profile),
            }

        except Exception as e:
            self.logger.error(f"Error in attention metrics computation: {str(e)}")
            return {
                "mean_attention": 0.0,
                "max_attention": 0.0,
                "local_attention": 0.0,
                "avg_distance": 0.0,
                "decay_rate": 0.0,
            }

    def _compute_distance_profile(self, weights, distances):
        """Mean attention weight for each distinct distance value, in ascending order.

        NOTE(review): ``distances`` is symmetric, so for causal models each distance bucket
        also averages in the masked (zero) future positions.
        """
        flat_distances = distances.reshape(-1).long()
        if flat_distances.numel() == 0:
            return []
        flat_weights = weights.reshape(-1).to(torch.float64)

        # bincount gives per-distance sums and counts in one pass instead of one mask per value.
        sums = torch.bincount(flat_distances, weights=flat_weights)
        counts = torch.bincount(flat_distances, minlength=sums.numel())
        present = counts > 0
        return (sums[present] / counts[present]).tolist()

    def _fit_exponential_decay(self, distance_profile: List[float]) -> float:
        """Fit ``A * exp(-c * d)`` to the profile and return ``c`` (0.0 if it cannot be fitted)."""
        if len(distance_profile) < 2:
            return 0.0

        distances = np.arange(len(distance_profile))

        def exp_func(x, A, c):
            return A * np.exp(-c * x)

        try:
            popt, _ = curve_fit(
                exp_func, distances, distance_profile, p0=(distance_profile[0], 0.1)
            )
        except Exception:
            # Best effort: non-convergence (RuntimeError) or NaN/inf input (ValueError) yields
            # 0.0, but KeyboardInterrupt/SystemExit are no longer swallowed.
            return 0.0
        return float(popt[1])

    def run_ablation_study(self, num_samples: int = 100):
        """Analyse the same texts with every supported position-encoding type.

        Each type's default checkpoint is loaded in turn (the previous one is freed first).

        Returns:
            ``{"rope": [...], "t5": [...], "bert": [...], "no_pe": [...]}``, each a list of
            :meth:`analyze_single_text` results.
        """
        self.logger.info("Starting ablation study...")

        results = {"rope": [], "t5": [], "bert": [], "no_pe": []}

        dataset = load_dataset("bookcorpus", split="train", streaming=True)
        texts = self._filter_texts(dataset, num_samples)

        for model_type in results:
            self.logger.info(f"Analyzing {model_type} model...")
            self._load_model(model_type)

            with tqdm(total=len(texts), desc=f"Processing {model_type}") as pbar:
                for text in texts:
                    result = self.analyze_single_text(text)
                    if result is not None:
                        results[model_type].append(result)
                    pbar.update(1)

            self._free_cached_memory()

        return results

    def visualize_ablation_results(self, results: Dict):
        """Plot violin distributions of four metrics per encoding type.

        Returns:
            The matplotlib figure, or ``None`` when ``results`` holds no analysed samples.
        """
        if not results or not any(results.values()):
            self.logger.warning("No results to visualize.")
            return None

        fig, axes = plt.subplots(2, 2, figsize=(20, 16))

        metrics_to_plot = {
            "avg_distance": "Average Attention Distance",
            "decay_rate": "Attention Decay Rate",
            "local_attention": "Local Attention Strength",
            "mean_attention": "Mean Attention Value",
        }

        for ax, (metric, title) in zip(axes.flat, metrics_to_plot.items()):
            self._plot_comparison_metrics(ax, results, metric, title)

        plt.suptitle(
            "Comparison of Attention Patterns Across Position Encoding Types",
            fontsize=16,
            y=1.02,
        )
        plt.tight_layout()
        return fig

    def _plot_comparison_metrics(self, ax, results, metric, title):
        """Draw one violin per model type for ``metric`` pooled over samples, layers, heads."""
        positions = []
        data = []
        labels = []
        for i, (model_type, model_results) in enumerate(results.items()):
            values = [
                head[metric]
                for result in model_results
                for layer in result.values()
                for head in layer.values()
            ]
            if values:
                positions.append(i)
                data.append(values)
                labels.append(model_type)

        if data:
            vplot = ax.violinplot(data, positions=positions)

            for pc in vplot["bodies"]:
                pc.set_alpha(0.7)

            ax.set_xticks(positions)
            ax.set_xticklabels(labels)
            ax.set_title(title)
            ax.set_ylabel("Value")

            ax.grid(True, linestyle="--", alpha=0.7)


ABLATION_RESULTS_PATH = "positional_encoding_ablation_results.pt"


def main():
    """Run comprehensive analysis pipeline with ablation study.

    Analyses the same BookCorpus texts with every position-encoding variant, shows the
    comparison figure (if any results were produced) and saves the raw per-head metrics
    with ``torch.save`` to ``ABLATION_RESULTS_PATH``.
    """
    # The constructor always loads one model, but run_ablation_study reloads every variant
    # itself. Start from the small BERT checkpoint instead of the 7B Llama default, which
    # would be loaded only to be replaced immediately.
    analyzer = PositionalEncodingAnalyzer(model_type="bert")
    ablation_results = analyzer.run_ablation_study(num_samples=100)

    fig = analyzer.visualize_ablation_results(ablation_results)
    if fig is not None:
        plt.show()

    torch.save(ablation_results, ABLATION_RESULTS_PATH)
    print(f"\nAblation study complete. Results saved to '{ABLATION_RESULTS_PATH}'")


if __name__ == "__main__":
    main()