In [1]:
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
import numpy as np
from functools import partial
from typing import List, Dict, Tuple, Optional
import matplotlib.pyplot as plt
import nltk
from nltk import sent_tokenize, word_tokenize
import json

# Fetch the NLTK 'punkt' package when this cell runs; NLTK reports when the
# package is already up to date instead of downloading it again.
# NOTE(review): presumably required by the imported sent_tokenize /
# word_tokenize — nothing in the visible code calls them; confirm it is needed.
nltk.download('punkt')


class TransformerJacobianAnalyzer:
    """Layer-wise Jacobian spectra of a GPT-2 style Hugging Face model.

    Typical use: ``prepare_input`` -> ``compute_jacobians`` ->
    ``spectral_analysis`` -> ``plot_spectral_analysis``.

    ``compute_jacobians`` accesses ``model.wte``, ``model.wpe`` and
    ``model.h`` and, on every element of ``model.h``, ``ln_1``, ``attn``,
    ``ln_2`` and ``mlp``; models that do not expose these attributes are not
    supported.
    """

    def __init__(self, model_name="gpt2"):
        """Load the tokenizer and model for *model_name*.

        The model is moved to CUDA when available (CPU otherwise) and put in
        eval mode.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            # prepare_input tokenizes with padding=True, which needs a pad
            # token; fall back to EOS when the tokenizer defines none.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.model.eval()
        self.model_name = model_name

    def prepare_input(self, text: str) -> Dict[str, torch.Tensor]:
        """Tokenize *text* (truncated to 512 tokens) and move the tensors to ``self.device``."""
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        return {k: v.to(self.device) for k, v in inputs.items()}

    @staticmethod
    def _last_token_attention_jacobian(hidden_states: torch.Tensor, layer) -> torch.Tensor:
        """Jacobian of the attention output at the last position w.r.t. the last token.

        All earlier tokens are held fixed; only the last token's hidden state
        is varied. The output is likewise read at the last position, so the
        result is a square ``(hidden, hidden)`` matrix.
        """
        # Split along the *token* axis (dim 1); dim 0 is the batch axis.
        prefix = hidden_states[:1, :-1, :]
        last_token = hidden_states[0, -1, :]

        def attn_last(token: torch.Tensor) -> torch.Tensor:
            sequence = torch.cat([prefix, token[None, None, :]], dim=1)
            return layer.attn(layer.ln_1(sequence))[0][0, -1, :]

        return torch.autograd.functional.jacobian(attn_last, last_token)

    @staticmethod
    def _last_token_mlp_jacobian(hidden_states: torch.Tensor, layer) -> torch.Tensor:
        """Jacobian of the MLP output for the last token w.r.t. that token's hidden state."""
        last_token = hidden_states[0, -1, :]

        def mlp_last(token: torch.Tensor) -> torch.Tensor:
            # Feed a (1, 1, hidden) tensor so the sub-modules see the same
            # rank they get in the regular forward pass.
            return layer.mlp(layer.ln_2(token[None, None, :]))[0, 0, :]

        return torch.autograd.functional.jacobian(mlp_last, last_token)

    def compute_jacobians(self, inputs: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Compute last-token Jacobians of the attention and MLP sub-blocks of every layer.

        Args:
            inputs: Mapping with an ``'input_ids'`` tensor of shape
                ``(batch, seq_len)``. Only the first sequence of the batch is
                analysed; other entries (e.g. an attention mask) are not used.

        Returns:
            ``(jacobian_attns, jacobian_mlps)``: two equally long lists of CPU
            tensors of shape ``(hidden, hidden)``, one pair per layer that was
            processed successfully.

        A layer whose computation raises is reported on stdout and skipped as
        a whole: neither of its Jacobians is stored and the hidden states are
        left unchanged, so both lists always stay aligned.
        """
        input_ids = inputs['input_ids']
        seq_len = input_ids.size(1)
        layers = self.model.h
        num_layers = len(layers)

        # The forward pass itself needs no autograd graph: each Jacobian is
        # taken w.r.t. a fresh copy of the last token, and everything else is
        # a constant. Skipping the graph keeps memory flat across layers.
        with torch.no_grad():
            positions = torch.arange(seq_len, device=self.device)
            hidden_states = self.model.wte(input_ids) + self.model.wpe(positions)

        jacobian_attns, jacobian_mlps = [], []

        print(f"Computing Jacobians for {num_layers} layers...")

        for i, layer in enumerate(layers):
            print(f"Processing layer {i+1}/{num_layers}")

            try:
                jacobian_attn = self._last_token_attention_jacobian(hidden_states, layer)

                # Attention sub-block with its residual connection.
                with torch.no_grad():
                    after_attn = hidden_states + layer.attn(layer.ln_1(hidden_states))[0]

                # The MLP Jacobian is evaluated at the post-attention state.
                jacobian_mlp = self._last_token_mlp_jacobian(after_attn, layer)

                # MLP sub-block with its residual connection.
                with torch.no_grad():
                    after_mlp = after_attn + layer.mlp(layer.ln_2(after_attn))
            except Exception as e:
                print(f"Error computing Jacobian for layer {i}: {e}")
                continue

            # Commit only after the whole layer succeeded.
            jacobian_attns.append(jacobian_attn.detach().cpu())
            jacobian_mlps.append(jacobian_mlp.detach().cpu())
            hidden_states = after_mlp

        return jacobian_attns, jacobian_mlps

    def compute_eigenvalues(self, jacobian: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]:
        """Return the spectrum of *jacobian*, sorted by decreasing magnitude.

        Inputs with more than two dimensions are flattened to
        ``(-1, last_dim)``. Square matrices get an eigendecomposition
        (eigenvalues may be complex). Non-square matrices, or a failed
        eigendecomposition, fall back to an SVD, in which case the returned
        "eigenvalues" are singular values and the "eigenvectors" are the right
        singular vectors.

        Returns:
            ``(eigenvalues, eigenvectors)`` with eigenvectors stored as
            columns, in the same order as the eigenvalues.
        """
        jac_np = jacobian.detach().cpu().numpy()

        if jac_np.ndim > 2:
            jac_2d = jac_np.reshape(-1, jac_np.shape[-1])
        else:
            jac_2d = jac_np

        def svd_spectrum(matrix):
            _, singular_values, vt = np.linalg.svd(matrix, full_matrices=False)
            return singular_values, vt.T

        if jac_2d.shape[0] != jac_2d.shape[1]:
            eigenvals, eigenvecs = svd_spectrum(jac_2d)
        else:
            try:
                eigenvals, eigenvecs = np.linalg.eig(jac_2d)
            except np.linalg.LinAlgError:
                eigenvals, eigenvecs = svd_spectrum(jac_2d)

        # Sort by magnitude, largest first.
        order = np.argsort(np.abs(eigenvals))[::-1]
        return eigenvals[order], eigenvecs[:, order]

    def _summarize_spectrum(self, layer_index: int, jacobian: torch.Tensor) -> Dict:
        """Build the per-layer summary dict for one Jacobian."""
        eigenvals, eigenvecs = self.compute_eigenvalues(jacobian)
        magnitudes = np.abs(eigenvals)
        nonzero = magnitudes[magnitudes != 0]
        spectral_radius = np.max(magnitudes)

        if len(eigenvals) <= 1:
            decay = 0
        elif magnitudes[0] == 0:
            # Largest magnitude is zero, so every eigenvalue is zero and the
            # ratio is undefined (0/0); report NaN without a divide warning.
            decay = float('nan')
        else:
            decay = np.abs(eigenvals[1] / eigenvals[0])

        return {
            'layer': layer_index,
            'eigenvalues': eigenvals,
            'eigenvectors': eigenvecs,
            'spectral_radius': spectral_radius,
            'condition_number': spectral_radius / np.min(nonzero) if nonzero.size else np.inf,
            'rank_estimate': np.sum(magnitudes > 1e-10),
            'dominant_eigenvalue': eigenvals[0],
            'eigenvalue_decay': decay,
        }

    def spectral_analysis(self, jacobian_attns: List[torch.Tensor],
                         jacobian_mlps: List[torch.Tensor], prompt) -> Dict:
        """Summarise the spectrum of every attention and MLP Jacobian.

        Args:
            jacobian_attns: Per-layer attention Jacobians.
            jacobian_mlps: Per-layer MLP Jacobians, paired with
                *jacobian_attns* by position.
            prompt: Stored unchanged under ``'prompt'`` (used as plot title).

        Returns:
            Dict with keys ``'prompt'``, ``'attention_analysis'`` and
            ``'mlp_analysis'`` (lists of per-layer dicts holding eigenvalues,
            eigenvectors, spectral radius, condition number, rank estimate,
            dominant eigenvalue and the |second/first| eigenvalue ratio), and
            ``'layer_statistics'`` (mean/std/max spectral radius per kind;
            NaN when no layer was analysed).
        """
        results = {
            'prompt': prompt,
            'attention_analysis': [],
            'mlp_analysis': [],
            'layer_statistics': {}
        }

        print("Performing spectral analysis...")

        for i, (jac_attn, jac_mlp) in enumerate(zip(jacobian_attns, jacobian_mlps)):
            print(f"Analyzing layer {i+1}")
            results['attention_analysis'].append(self._summarize_spectrum(i, jac_attn))
            results['mlp_analysis'].append(self._summarize_spectrum(i, jac_mlp))

        def radius_stats(analyses):
            radii = [a['spectral_radius'] for a in analyses]
            if not radii:
                # np.max raises on an empty sequence; report NaN instead.
                return float('nan'), float('nan'), float('nan')
            return np.mean(radii), np.std(radii), np.max(radii)

        attn_mean, attn_std, attn_max = radius_stats(results['attention_analysis'])
        mlp_mean, mlp_std, mlp_max = radius_stats(results['mlp_analysis'])

        results['layer_statistics'] = {
            'mean_attn_spectral_radius': attn_mean,
            'std_attn_spectral_radius': attn_std,
            'mean_mlp_spectral_radius': mlp_mean,
            'std_mlp_spectral_radius': mlp_std,
            'max_attn_spectral_radius': attn_max,
            'max_mlp_spectral_radius': mlp_max
        }

        return results

    def plot_spectral_analysis(self, results: Dict, save_path: Optional[str] = None):
        """Draw a 2x3 figure summarising the output of ``spectral_analysis``.

        Panels: spectral radius, condition number (log scale), attention rank
        estimate, the 50 largest eigenvalue magnitudes of the first attention
        and first MLP entry, and the eigenvalue decay ratio.

        Args:
            results: Dict returned by ``spectral_analysis``.
            save_path: When given, the figure is also written there at 300 dpi.
        """
        attention = results['attention_analysis']
        mlp = results['mlp_analysis']
        layers = range(len(attention))

        fig, axes = plt.subplots(2, 3, figsize=(18, 10))
        fig.suptitle(f'Spectral Analysis: {self.model_name}, prompt: {results["prompt"]}', fontsize=16)

        def decorate(ax, xlabel, ylabel, title, legend=True):
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            if legend:
                ax.legend()
            ax.grid(True)

        def column(analyses, key):
            return [entry[key] for entry in analyses]

        # Panel 1: spectral radii across layers.
        ax = axes[0, 0]
        ax.plot(layers, column(attention, 'spectral_radius'), 'o-', label='Attention', color='blue')
        ax.plot(layers, column(mlp, 'spectral_radius'), 's-', label='MLP', color='red')
        decorate(ax, 'Layer', 'Spectral Radius', 'Spectral Radius by Layer')

        # Panel 2: condition numbers (log scale).
        ax = axes[0, 1]
        ax.semilogy(layers, column(attention, 'condition_number'), 'o-', label='Attention', color='blue')
        ax.semilogy(layers, column(mlp, 'condition_number'), 's-', label='MLP', color='red')
        decorate(ax, 'Layer', 'Condition Number (log)', 'Condition Numbers by Layer')

        # Panel 3: rank estimates. Only the attention ranks are drawn; the MLP
        # rank estimates exist in `results` but are not plotted.
        ax = axes[0, 2]
        ax.plot(layers, column(attention, 'rank_estimate'), 'o-', label='Attention', color='blue')
        decorate(ax, 'Layer', 'Estimated Rank', 'Rank Estimates by Layer')

        # Panel 4: attention eigenvalue spectrum of the first entry (top 50).
        if attention:
            eigenvals = attention[0]['eigenvalues'][:50]
            ax = axes[1, 0]
            ax.semilogy(range(len(eigenvals)), np.abs(eigenvals), 'o-')
            decorate(ax, 'Eigenvalue Index', '|Eigenvalue| (log)',
                     'Attention Eigenvalue Spectrum (Layer 0)', legend=False)

        # Panel 5: MLP eigenvalue spectrum of the first entry (top 50).
        if mlp:
            eigenvals = mlp[0]['eigenvalues'][:50]
            ax = axes[1, 1]
            ax.semilogy(range(len(eigenvals)), np.abs(eigenvals), 'o-', color='red')
            decorate(ax, 'Eigenvalue Index', '|Eigenvalue| (log)',
                     'MLP Eigenvalue Spectrum (Layer 0)', legend=False)

        # Panel 6: ratio of second to first eigenvalue magnitude.
        ax = axes[1, 2]
        ax.plot(layers, column(attention, 'eigenvalue_decay'), 'o-', label='Attention', color='blue')
        ax.plot(layers, column(mlp, 'eigenvalue_decay'), 's-', label='MLP', color='red')
        decorate(ax, 'Layer', 'λ₂/λ₁ Ratio', 'Eigenvalue Decay Rate')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Plot saved to {save_path}")

        plt.show()

[nltk_data] Downloading package punkt to /home/work/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [None]:
# Example run: Jacobian spectral analysis of GPT-2 on a short prompt.
prompt = 'Hello World!'

# Loads the "gpt2" tokenizer and model.
analyzer = TransformerJacobianAnalyzer("gpt2")
inputs = analyzer.prepare_input(prompt)
# Two lists: per-layer attention Jacobians and per-layer MLP Jacobians.
jac_attn, jac_mlp = analyzer.compute_jacobians(inputs)
results = analyzer.spectral_analysis(jac_attn, jac_mlp, prompt)
# No save_path is passed, so the figure is shown but not written to disk.
analyzer.plot_spectral_analysis(results)