This self-contained notebook computes and analyzes eigenvalues derived from the causal softmax self-attention of the OLMo-3-7B-Think transformer model (`allenai/Olmo-3-7B-Think`). It handles environment setup, dataset tokenization, model loading, eigenvalue computation, and plotting, and is designed to run in Google Colab.

### Python Setup

In [1]:
# Report the interpreter version and make sure pip itself is current before installing packages.
!python3 --version
!python3 -m pip install --upgrade pip

In [3]:
!python3 -m pip install torch==2.9.0 \
                        torchaudio==2.9.0 \
                        torchtext==0.18.0 \
                        torchvision==0.24.0

In [None]:
# Additional runtime dependencies; transformers is pinned to the version this notebook was run with.
# NOTE(review): later cells also import `datasets`, `jax`, `einops`, `dill`, `matplotlib` and
# `wandb`, none of which are installed here — presumably they are expected to be preinstalled
# on the Colab image; verify before running elsewhere.
!python3 -m pip install fairscale \
                        fire \
                        flash-linear-attention \
                        johnnydep \
                        jupyter \
                        nvidia-cublas-cu12 \
                        nvidia-cuda-cupti-cu12 \
                        nvidia-cuda-nvrtc-cu12 \
                        nvidia-cuda-runtime-cu12 \
                        nvidia-cudnn-cu12 \
                        nvidia-cufft-cu12 \
                        nvidia-cufile-cu12 \
                        nvidia-curand-cu12 \
                        nvidia-cusolver-cu12 \
                        nvidia-cusparse-cu12 \
                        nvidia-cusparselt-cu12 \
                        nvidia-nvjitlink-cu12 \
                        nvidia-nvtx-cu12 \
                        oyaml \
                        prefetch-generator \
                        pyaml \
                        pyarrow-hotfix \
                        pytorch-warmup \
                        structlog \
                        transformers==4.57.6 \
                        triton \
                        wimpy


### Tokenize/load cached Wikitext dataset

In [None]:
# Mount Google Drive: the dataset cache and the per-batch eigenvalue dumps live under /content/drive.
import google.colab.drive as drive
drive.mount('/content/drive')

In [None]:
"""Wikitext datasets"""
import io
import logging
import os
import pickle
from pathlib import Path
import jax.numpy as jnp
import jax

import torch
import torch.nn.functional as F
#from transformers import GPT2TokenizerFast
from transformers import AutoTokenizer

MODEL_ID = "allenai/Olmo-3-7B-Think"

from datasets import DatasetDict, load_dataset

from functools import partial


class DefaultCollateMixin:
    """Controls collating in the DataLoader.

    The CollateMixin classes build a dataloader by separating collate arguments from the rest of
    the dataloader arguments. Subclasses customize behaviour by overriding the callback hooks
    (``_collate_callback``, ``_return_callback``) and by listing, in ``collate_args``, the loader
    keyword arguments that should be forwarded to the collate function. ``_dataloader()`` builds a
    ``collate_fn`` from those arguments and passes everything else to the DataLoader constructor.
    """

    @classmethod
    def _collate_callback(cls, x, *args, **kwargs):
        """Hook to post-process a freshly stacked tensor in ``_collate``; identity by default."""
        return x

    # One name per auxiliary item a dataset returns after its (x, y) pair; see _return_callback.
    _collate_arg_names = []

    @classmethod
    def _return_callback(cls, return_value, *args, **kwargs):
        """Turn ``(x, y, *aux)`` into ``(x, y, {name: aux_item})`` using ``_collate_arg_names``.

        Raises:
            AssertionError: if the number of auxiliary items differs from ``_collate_arg_names``.
        """
        x, y, *z = return_value
        assert len(z) == len(cls._collate_arg_names), "Specify a name for each auxiliary data item returned by dataset"
        return x, y, {k: v for k, v in zip(cls._collate_arg_names, z)}

    @classmethod
    def _collate(cls, batch, *args, **kwargs):
        """Stack a sequence of tensors (or convert a sequence of scalars) into one batch tensor.

        Tensor batches are stacked along a new leading dimension and then passed through
        ``_collate_callback``; anything else is converted with ``torch.tensor(batch)``.
        Adapted from https://github.com/pytorch/pytorch/blob/main/torch/utils/data/_utils/collate.py
        """
        elem = batch[0]
        if not isinstance(elem, torch.Tensor):
            return torch.tensor(batch)

        out = None
        if torch.utils.data.get_worker_info() is not None:
            # In a DataLoader worker process, stack directly into a shared-memory tensor to avoid
            # an extra copy. The tensor must already have the final (len(batch), *elem.shape)
            # shape; otherwise torch.stack would resize it and could reallocate the storage
            # outside shared memory.
            numel = sum(x.numel() for x in batch)
            storage = elem._typed_storage()._new_shared(numel, device=elem.device)
            out = elem.new(storage).resize_(len(batch), *list(elem.size()))
        x = torch.stack(batch, dim=0, out=out)

        # Insert custom functionality into the collate_fn
        return cls._collate_callback(x, *args, **kwargs)

    @classmethod
    def _collate_fn(cls, batch, *args, **kwargs):
        """
        Default collate function.
        Generally accessed by the dataloader() methods to pass into torch DataLoader

        Arguments:
            batch: list of (x, y, *aux) tuples
            args, kwargs: extra arguments that get passed into the _collate_callback and _return_callback
        """
        x, y, *z = zip(*batch)

        x = cls._collate(x, *args, **kwargs)
        y = cls._collate(y)
        z = [cls._collate(z_) for z_ in z]

        return_value = (x, y, *z)
        return cls._return_callback(return_value, *args, **kwargs)

    # List of loader arguments to pass into collate_fn
    collate_args = []

    def _dataloader(self, dataset, **loader_args):
        """Build a DataLoader, routing ``collate_args`` keys to the collate function.

        An optional ``_name_`` entry selects the loader class from the module-level
        ``loader_registry`` (``None`` when absent); all remaining keys go to its constructor.
        """
        collate_args = {k: loader_args[k] for k in loader_args if k in self.collate_args}
        loader_args = {k: loader_args[k] for k in loader_args if k not in self.collate_args}
        loader_cls = loader_registry[loader_args.pop("_name_", None)]
        return loader_cls(
            dataset=dataset,
            collate_fn=partial(self._collate_fn, **collate_args),
            **loader_args,
        )

class SequenceDataset(DefaultCollateMixin):
    """Base class for sequence datasets, registered automatically by shorthand name.

    Every subclass is recorded in ``SequenceDataset.registry`` under its ``_name_``. Subclasses
    do not define ``__init__``; they expose default settings via ``init_defaults`` instead, and
    every default/config entry becomes an instance attribute. ``setup()`` must fill in
    ``dataset_train``, ``dataset_val`` and ``dataset_test``.
    """

    registry = {}
    _name_ = NotImplementedError("Dataset must have shorthand name")

    # TODO it might be possible to write this as a @dataclass, but it seems tricky to separate
    # from the other features of this class such as the _name_ and d_input/d_output
    @property
    def init_defaults(self):
        """Default attribute values; subclasses override this to provide their own."""
        return {}

    def __init_subclass__(cls, **kwargs):
        # Subclass registration, see https://www.python.org/dev/peps/pep-0487/#subclass-registration
        super().__init_subclass__(**kwargs)
        cls.registry[cls._name_] = cls

    def __init__(self, _name_, data_dir=None, **dataset_cfg):
        assert _name_ == self._name_
        self.data_dir = None if data_dir is None else Path(data_dir).absolute()

        # Explicit config entries override the subclass defaults; all become attributes
        settings = {**self.init_defaults, **dataset_cfg}
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)

        # setup() is responsible for populating these
        self.dataset_train = None
        self.dataset_val = None
        self.dataset_test = None

        self.init()

    def init(self):
        """Hook run at the end of ``__init__``; override this rather than ``__init__``."""

    def setup(self):
        """Populate self.dataset_train, self.dataset_val and self.dataset_test."""
        raise NotImplementedError

    def split_train_val(self, val_split):
        """Randomly carve a ``val_split`` fraction of ``self.dataset_train`` out as ``self.dataset_val``."""
        total = len(self.dataset_train)
        n_train = int(total * (1.0 - val_split))
        # Seeded generator so the split is reproducible (PL's own seeding did not cover this)
        rng = torch.Generator().manual_seed(getattr(self, "seed", 42))
        self.dataset_train, self.dataset_val = torch.utils.data.random_split(
            self.dataset_train, (n_train, total - n_train), generator=rng
        )

    def train_dataloader(self, **kwargs):
        return self._train_dataloader(self.dataset_train, **kwargs)

    def _train_dataloader(self, dataset, **kwargs):
        if dataset is None:
            return None
        # DataLoader forbids shuffle=True together with a custom sampler
        kwargs["shuffle"] = "sampler" not in kwargs
        return self._dataloader(dataset, **kwargs)

    def val_dataloader(self, **kwargs):
        return self._eval_dataloader(self.dataset_val, **kwargs)

    def test_dataloader(self, **kwargs):
        return self._eval_dataloader(self.dataset_test, **kwargs)

    def _eval_dataloader(self, dataset, **kwargs):
        if dataset is None:
            return None
        # shuffle is left to the caller (DataLoader defaults to shuffle=False)
        return self._dataloader(dataset, **kwargs)

    def __str__(self):
        return self._name_

def loss_fn(logits, labels, ignore_index=-100):
    """Return the mean cross-entropy between ``logits`` and integer ``labels`` (JAX).

    Args:
        logits: array of shape (..., num_classes) with rank 2, 3 or 4 (classification,
            sequence, or multidimensional dense targets).
        labels: integer array of shape ``logits.shape[:-1]``.
        ignore_index: targets with this label are excluded from the mean. The default of -100
            matches the padding used for the shifted WikiText labels and the default of
            ``torch.nn.functional.cross_entropy``, so the JAX and torch metrics agree.

    Returns:
        Scalar JAX array: the mean negative log-likelihood over the non-ignored targets
        (0 if every target is ignored).

    Raises:
        ValueError: if ``logits`` is not of rank 2, 3 or 4.
    """
    logits = jnp.asarray(logits)
    labels = jnp.asarray(labels)
    if logits.ndim not in (2, 3, 4):
        raise ValueError(f"Unsupported logits rank {logits.ndim}; expected 2, 3 or 4")

    log_probs = jax.nn.log_softmax(logits, axis=-1)
    valid = labels != ignore_index
    # Gather with a safe index at ignored positions, then zero those terms out
    safe_labels = jnp.where(valid, labels, 0)
    nll = -jnp.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    nll = jnp.where(valid, nll, 0.0)
    # Guard against division by zero when every target is ignored
    return nll.sum() / jnp.maximum(valid.sum(), 1)

def get_default_data_path():
    """Return ``default_data_path`` from the project's ``launch`` module.

    The import is deferred to call time, so a ``launch`` module only needs to be importable
    when this function is actually called.
    """
    import launch
    return launch.default_data_path

class WikiText(SequenceDataset):
    """WikiText language-modeling dataset tokenized with the OLMo-3 tokenizer.

    ``process_dataset`` downloads the raw split from ``Salesforce/wikitext`` and tokenizes it.
    It then concatenates the tokens and cuts them into ``block_size``-long chunks, dropping the
    remainder of each map batch. ``labels`` are the ``input_ids`` shifted left by one with
    ``-100`` appended (the ignore index of ``torch.nn.functional.cross_entropy``). The validation
    split is dropped.

    When ``cache_dir`` is set (``setup`` sets it to ``data_dir / "cache"``), the processed
    dataset, tokenizer and vocab are pickled/saved there and reused on later runs.
    """

    _name_ = "wikitext"
    d_output = 2
    l_output = 0

    @property
    def init_defaults(self):
        return {
            "version": 2,
            "block_size": 1024,
            "seed": 42,
            "n_workers": 1,  # Only used for tokenizing dataset before caching
        }

    @property
    def n_tokens(self):
        """Vocabulary size: an explicit ``vocab_size`` setting if given, else the loaded vocab's size."""
        # vocab_size is only present when passed in the dataset config; setup() loads self.vocab
        if hasattr(self, "vocab_size"):
            return self.vocab_size
        return len(self.vocab)

    @property
    def l_max(self):
        return self.block_size

    def get_metrics(self, layer="s4"):
        """Return a perplexity function: torch-based for "mamba"/"transformer", JAX-based otherwise."""
        if layer in ["mamba", "transformer"]:
            return self.get_metrics_torch()
        else:
            return self.get_metrics_jax()

    def get_metrics_torch(self):
        return lambda y_hat, y: torch.exp(F.cross_entropy(y_hat.reshape(-1, y_hat.size(-1)), y.reshape(-1))).item()

    def get_metrics_jax(self):
        return lambda y_hat, y: jnp.exp(loss_fn(y_hat, y))

    def prepare_data(self):
        """Download the raw dataset, or fully process and cache it when a cache dir is configured."""
        # cache_dir is only assigned in setup(); treat a missing attribute as "no cache"
        if getattr(self, "cache_dir", None) is None:  # Just download the dataset
            load_dataset(
                "Salesforce/wikitext",  # same source as process_dataset, so the download is reused
                "{0}-{1}-raw-v1".format(self._name_, self.version),
                cache_dir=self.data_dir,
            )
        else:  # Process the dataset and save it
            self.process_dataset()

    def setup(self, stage=None):
        """Load the processed train/test splits (from cache if present). The val split is unused."""
        self.data_dir = self.data_dir or get_default_data_path() / self._name_
        self.cache_dir = self.data_dir / "cache"

        # __init__ always creates dataset_test (as None), so only skip when it is actually loaded
        if stage == "test" and getattr(self, "dataset_test", None) is not None:
            return
        dataset, self.tokenizer, self.vocab = self.process_dataset()
        print(
            f"WikiText-{self.version} | tokenizer {self.tokenizer.name_or_path} | vocab size {len(self.vocab)}"
        )
        dataset.set_format(type="torch", columns=["input_ids", "labels"])

        # Create all splits
        self.dataset_train, self.dataset_test = dataset["train"], dataset["test"]
        self.dataset_val = None # don't use validation set

    def _collate_fn(self, batch):
        """Stack ``input_ids``/``labels`` of a list of examples into (x, y, {"lengths": block_size})."""
        xs, ys = zip(*[(data["input_ids"], data["labels"]) for data in batch])
        xs = torch.stack(xs, dim=0)
        ys = torch.stack(ys, dim=0)
        return xs, ys, {"lengths": self.block_size}

    def process_dataset(self):
        """Return ``(dataset, tokenizer, vocab)``, loading from the cache dir if it already exists."""
        cache_dir = (
            None if self.cache_dir is None else self.cache_dir / self._cache_dir_name
        )
        if cache_dir is not None:
            if cache_dir.is_dir():
                return self._load_from_cache(cache_dir)

        dataset = load_dataset("Salesforce/wikitext", "{0}-{1}-raw-v1".format(self._name_, self.version), cache_dir=self.data_dir)
        dataset = DatasetDict(train=dataset["train"], test=dataset["test"]) # remove validation

        # Use the OLMO-3 tokenizer
        MODEL_ID = "allenai/Olmo-3-7B-Think"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
        tokenizer.pad_token = tokenizer.eos_token

        vocab = tokenizer.vocab

        # tokenize
        tokenize = lambda example: tokenizer(example["text"])
        dataset = dataset.map(
            tokenize,
            remove_columns=["text"],
            batched=True,
            keep_in_memory=True,
            load_from_cache_file=False,
            num_proc=max(self.n_workers, 1),
        )

        # group inputs (ensure equal length)
        def group_inputs(examples):
            # Concatenate all tokenized sequences per column. A flattening comprehension is
            # linear, unlike sum(lists, []) which re-copies the accumulator for every sequence.
            concatenated = {
                k: [tok for seq in examples[k] for tok in seq] for k in examples.keys()
            }
            total_length = len(concatenated["input_ids"])
            # Truncate to a multiple of block_size
            total_length = (total_length // self.block_size) * self.block_size
            # Split into chunks of block_size
            result = {
                k: [t[i : i + self.block_size] for i in range(0, total_length, self.block_size)]
                for k, t in concatenated.items()
            }
            # Labels = input_ids for causal language modeling
            result["labels"] = result["input_ids"].copy()
            return result

        dataset = dataset.map(
            group_inputs,
            batched=True,
            keep_in_memory=True,
            load_from_cache_file=False,
            num_proc=max(self.n_workers, 1),
        )

        # shift labels: position i predicts token i+1; the last position has no target (-100)
        def shift(examples):
            result = [x[1:] + [-100] for x in examples["labels"]]
            return {"labels": result}

        dataset = dataset.map(
            shift,
            batched=True,
            keep_in_memory=True,
            load_from_cache_file=False,
            num_proc=max(self.n_workers, 1),
        )

        if cache_dir is not None:
            self._save_to_cache(dataset, tokenizer, vocab, cache_dir)
        return dataset, tokenizer, vocab

    def _save_to_cache(self, dataset, tokenizer, vocab, cache_dir):
        """Save the processed dataset and pickle the tokenizer and vocab into ``cache_dir``."""
        logger = logging.getLogger(__name__)
        logger.info(f"Saving to cache at {str(cache_dir)}")
        dataset.save_to_disk(str(cache_dir))
        with open(cache_dir / "tokenizer.pkl", "wb") as f:
            pickle.dump(tokenizer, f)
        with open(cache_dir / "vocab.pkl", "wb") as f:
            pickle.dump(vocab, f)

    def _load_from_cache(self, cache_dir):
        """Load what ``_save_to_cache`` wrote.

        The pickles are unpickled, so only point this at a cache directory you created yourself.
        """
        assert cache_dir.is_dir()
        logger = logging.getLogger(__name__)
        logger.info(f"Load from cache at {str(cache_dir)}")
        dataset = DatasetDict.load_from_disk(str(cache_dir))
        with open(cache_dir / "tokenizer.pkl", "rb") as f:
            tokenizer = pickle.load(f)
        with open(cache_dir / "vocab.pkl", "rb") as f:
            vocab = pickle.load(f)
        return dataset, tokenizer, vocab

    @property
    def _cache_dir_name(self):
        return f"version-{self.version}-block_size-{self.block_size}"

# Registry for dataloader class
# DataLoader classes selectable through a "_name_" loader argument; the None key (no "_name_"
# given) maps to the standard torch DataLoader.
loader_registry = {}
loader_registry[None] = torch.utils.data.DataLoader


In [None]:
# tokenize wikitext
# WikiText-103 (raw), tokenized and grouped into 1024-token blocks; setup() caches the processed
# splits on the mounted Drive under data_dir/cache.
# NOTE: "name" and "fixed_size" are not read by WikiText in this notebook; SequenceDataset.__init__
# simply stores them as attributes.
data_config = dict(
    name="WikiText",
    _name_="wikitext",
    version=103,
    block_size=1024,
    data_dir="/content/drive/MyDrive/datasets/wikitext_103_1024",
    fixed_size=True,
)

dataset = SequenceDataset.registry[data_config["_name_"]](**data_config)
dataset.setup()


### Test Model with Wikitext Data

#### Load Model and define evaluation functions

In [None]:
from tqdm import tqdm
import torch
from transformers import AutoModelForCausalLM

MODEL_ID = "allenai/Olmo-3-7B-Think"

# Load the checkpoint in float16; device_map="auto" lets transformers place the weights on the
# available device(s).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
    dtype=torch.float16,
)

# Configs without num_key_value_heads fall back to one key/value head per query head.
num_heads = model.config.num_attention_heads
num_kv_heads = getattr(model.config, "num_key_value_heads", num_heads)
print("Number of query heads:", num_heads)
print("Number of key/value heads:", num_kv_heads)

model.eval()

In [None]:
from collections import defaultdict

# Dictionary to store Q/K/V
qkv_cache = defaultdict(dict)  # qkv_cache[layer_idx]["Q" | "K" | "V"] -> CPU tensor


def make_hook(layer_idx, name):
    """Create a forward hook that stores a module's output in ``qkv_cache[layer_idx][name]``.

    The output is detached and copied to the CPU, replacing whatever the previous forward
    pass stored under the same key.
    """
    def _store_output(_module, _inputs, output):
        qkv_cache[layer_idx][name] = output.detach().cpu()

    return _store_output

# Remove hooks left over from an earlier run of this cell. Otherwise every re-run stacks another
# set of hooks, and each forward pass copies Q/K/V to the CPU once per registered copy.
for _stale_hook in globals().get("hooks", []):
    _stale_hook.remove()

# Capture the q_proj / k_proj / v_proj outputs of every decoder layer into qkv_cache.
# NOTE(review): these are the raw linear-projection outputs, i.e. before any normalization or
# rotary embedding the attention module may apply afterwards — confirm this is the intended
# input for the eigenvalue analysis.
hooks = []
for layer_idx, layer in enumerate(model.model.layers):
    hooks.append(layer.self_attn.q_proj.register_forward_hook(make_hook(layer_idx, "Q")))
    hooks.append(layer.self_attn.k_proj.register_forward_hook(make_hook(layer_idx, "K")))
    hooks.append(layer.self_attn.v_proj.register_forward_hook(make_hook(layer_idx, "V")))

num_layers = len(model.model.layers)


In [None]:
import numpy as np
import einops
import time

def get_eig_from_qkv_att_softmax(q, k, v, scale=1.0):
    '''
    Function to get eigenvalues from QKV causal softmax self-attention.

    For every query position t the causal softmax normalizer is
        nu_t = sum_{s <= t} exp(scale * <q_t, k_s>)
    and the returned eigenvalue is the ratio eta_t = nu_t / nu_{t+1}.

    It is computed in float64 and in log space: eta_t = exp(LSE_t - LSE_{t+1}). Key positions
    s > t are masked to -inf, so they contribute nothing to nu_t. No overflow-prone
    exponentials of raw scores are formed.
    --------------
    Inputs:
    q : torch.tensor
        query tensor of shape (batch, seq_len, num_heads, head_dim)
    k : torch.tensor
        key tensor of shape (batch, seq_len, num_kv_heads, head_dim). num_kv_heads must divide
        num_heads; with fewer key heads (grouped-query attention), each key head is shared by
        num_heads // num_kv_heads consecutive query heads.
    v : torch.tensor
        value tensor; accepted for interface symmetry but unused, since the eigenvalues depend
        only on q and k
    scale : float
        multiplier applied to the dot-product scores. The default 1.0 means unscaled dot
        products; pass head_dim ** -0.5 for standard scaled dot-product attention.
    --------------
    Outputs:
    eta : np.array
        float64 eigenvalues of shape (batch_size, seq_len-1, num_heads, 1)

    Raises:
    ValueError if num_heads is not a multiple of num_kv_heads.
    '''
    seq_len = q.shape[1]
    num_heads = q.shape[2]
    num_kv_heads = k.shape[2]

    if num_kv_heads != num_heads:
        if num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be a multiple of num_kv_heads ({num_kv_heads})"
            )
        # Query head i uses key head i // n_rep
        k = k.repeat_interleave(num_heads // num_kv_heads, dim=2)

    # Upcast before the dot products: half-precision scores can overflow or lose precision
    q64 = q.detach().to(torch.float64)
    k64 = k.detach().to(device=q64.device, dtype=torch.float64)

    scores = torch.einsum("bthd,bshd->btsh", q64, k64) * scale  # (b, t, s, h)

    # Causal mask: key position s is visible from query position t only if s <= t
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device).tril()
    scores.masked_fill_(~causal[None, :, :, None], float("-inf"))

    # log nu_t; every row has at least one unmasked entry (s = t), so this is finite
    log_nu = torch.logsumexp(scores, dim=2)  # (b, t, h)

    # eta_t = nu_t / nu_{t+1}
    eta = torch.exp(log_nu[:, :-1, :] - log_nu[:, 1:, :]).cpu().numpy()

    # add dimension for concatenation
    return np.expand_dims(eta, axis=-1)


#### Evaluate model on test data

In [None]:
import os
import re

# Directory holding the per-batch eigenvalue dumps eigs_<batch>.pkl written by the evaluation cell.
try:
    data_dir = f"/content/drive/MyDrive/datasets/{dataset._name_}_eigs"
except (NameError, AttributeError):
    # `dataset` is not available (dataset cell not run): resort to wikitext
    data_dir = "/content/drive/MyDrive/datasets/wikitext_eigs"

print(data_dir)

# Highest batch index already saved. The evaluation loop only computes eigenvalues for batches
# with a larger index; -1 means nothing has been saved yet, so every batch is computed.
ignore_eig_comp = -1
if os.path.isdir(data_dir):
    pattern = re.compile(r"^eigs_(\d+)\.pkl$")
    for fname in os.listdir(data_dir):
        match = pattern.match(fname)
        if match:
            ignore_eig_comp = max(ignore_eig_comp, int(match.group(1)))
    # Batches 0..ignore_eig_comp are skipped, i.e. ignore_eig_comp + 1 of them
    print(f"Ignoring {ignore_eig_comp + 1} first batches.")
else:
    print(f"{data_dir} is not a directory.")


In [None]:
import gc
import os
import dill

compute_eigs = True

# load metrics (per-batch perplexity, computed with torch)
metrics_fn = dataset.get_metrics(layer="transformer")

bsz = 2 # 4 seems to be the max given the available ram on T4 for wikitext
testloader = dataset.test_dataloader(batch_size=bsz, shuffle=False)
if isinstance(testloader, dict):
    testloader = testloader[None]

# Per-batch eigenvalue dumps go here; create the directory before the first dump.
eig_dir = f"/content/drive/MyDrive/datasets/{dataset._name_}_eigs"
if compute_eigs:
    os.makedirs(eig_dir, exist_ok=True)

# Batches with index <= skip_until already have a dump on disk (set by the resume cell);
# fall back to -1 (compute every batch) if that cell has not been run.
skip_until = globals().get("ignore_eig_comp", -1)

# evaluate model
model.eval()

test_performance = 0.0
test_loss = 0.0

batch = 0
with torch.inference_mode():
    for X, y, _ in tqdm(testloader):
        X = X.to(model.device)
        y = y.to(model.device).view(-1)

        # The forward pass also refreshes qkv_cache through the registered hooks
        output = model(X)

        if compute_eigs and batch > skip_until:
            # Size by the actual batch: the last batch can be smaller than bsz
            eig = np.empty((X.shape[0], X.shape[1] - 1, num_heads, num_layers))
            for layer_idx in tqdm(range(num_layers), leave=False):
                Q = qkv_cache[layer_idx]["Q"]  # (batch, seq_len, num_heads * head_dim)
                K = qkv_cache[layer_idx]["K"]  # (batch, seq_len, num_kv_heads * kv_head_dim)
                V = qkv_cache[layer_idx]["V"]

                head_dim = Q.shape[-1] // num_heads
                kv_head_dim = K.shape[-1] // num_kv_heads

                # Reshape to (batch, seq_len, heads, head_dim)
                Q_head = Q.view(Q.shape[0], Q.shape[1], num_heads, head_dim)
                K_head = K.view(K.shape[0], K.shape[1], num_kv_heads, kv_head_dim)
                V_head = V.view(V.shape[0], V.shape[1], num_kv_heads, kv_head_dim)

                # Drop only the trailing singleton axis; squeeze() would also drop a batch of size 1
                eig[:, :, :, layer_idx] = get_eig_from_qkv_att_softmax(Q_head, K_head, V_head)[..., 0]

            # save eigenvalues
            with open(os.path.join(eig_dir, f"eigs_{batch}.pkl"), "wb") as f:
                dill.dump(eig, f)
            del eig

        logits = output.logits.view(-1, output.logits.size(-1))
        loss = torch.nn.functional.cross_entropy(logits, y)
        test_loss += loss.item()
        test_performance += metrics_fn(logits, y)

        del X, y, output, logits, loss

        batch += 1

# Both are averages over batches; test_perf is the mean of the per-batch perplexities
test_loss = test_loss/len(testloader)
test_perf = test_performance / len(testloader)
tqdm.write("Test loss: {0:.4f}".format(test_loss))
tqdm.write("Test performance: {0:.4f}\n".format(test_perf))


### Eigenvalue Analysis

#### Eigenvalue processing

In [None]:
process_eigs = True
if process_eigs:
  import os
  import glob
  import re
  import dill
  import numpy as np

  path = "/content/drive/MyDrive/datasets/wikitext_eigs/"

  def _batch_index(fname):
      """Numeric batch index of an eigs_<batch>.pkl file name (-1 if it does not match)."""
      match = re.search(r"eigs_(\d+)\.pkl$", fname)
      return int(match.group(1)) if match else -1

  # Collect all eigs_*.pkl files in batch order; a plain lexicographic sort would put
  # eigs_10 before eigs_2.
  files = sorted(glob.glob(os.path.join(path, "eigs_*.pkl")), key=_batch_index)
  if not files:
      raise FileNotFoundError(f"No eigs_*.pkl files found in {path}")

  arrays = []
  for f in files:
      with open(f, "rb") as fh:
          arr = dill.load(fh)
      if not isinstance(arr, np.ndarray):
          raise TypeError(f"{f} does not contain a numpy ndarray")
      arrays.append(arr)

  # Concatenate along batch dimension
  all_eigs = np.concatenate(arrays, axis=0)

  print(all_eigs.shape) # (n_sequences, seq_len-1, heads, layers)


In [None]:
if process_eigs:
  # Persist the concatenated array so the plotting cells can reload it without re-reading every batch file
  np.save(file="/content/drive/MyDrive/datasets/wikitext_eigs/all_eigs.npy", arr=all_eigs)

#### Eigenvalue plotting

In [16]:
plot_eigs=True
if plot_eigs:
  def threshold_analysis(eig_val, thresholds):
      """
      Percentage of eigenvalues falling into each threshold bin, per batch entry/head/layer.

      eig_val: shape (B, N, num_heads, num_layers); percentages are taken over axis 1 (N)
      thresholds: 1D array-like of increasing values t_0 < t_1 < ... < t_{T-1}

      Bins (every value lands in at most one bin):
          bin 0:           0 <= x <= t_0
          bin i (1..T-1):  t_{i-1} < x <= t_i
          bin T:           x > t_{T-1}
      Negative values and NaNs are not counted in any bin.

      Returns: array of shape (T + 1, B, num_heads, num_layers) with percentages in [0, 100].
      """
      num_layers = eig_val.shape[-1]
      num_heads = eig_val.shape[2]
      batch_size = eig_val.shape[0]

      thresholds = np.asarray(thresholds).flatten()
      num_thresholds = thresholds.shape[0]
      percentages = np.empty([num_thresholds + 1, batch_size, num_heads, num_layers])

      # Values we compare against thresholds
      # Shape: (B, N, H, L)
      eta = eig_val
      count_eta_all = eta.shape[1]  # total values per head/layer

      def _pct(mask):
          return mask.sum(axis=1) / count_eta_all * 100

      # First bin: 0 <= x <= first threshold
      percentages[0] = _pct((eta >= 0) & (eta <= thresholds[0]))

      # Middle bins are half-open (t_i, t_{i+1}] so a value equal to a threshold is not
      # counted in two neighbouring bins
      for t in range(num_thresholds - 1):
          percentages[t + 1] = _pct((eta > thresholds[t]) & (eta <= thresholds[t + 1]))

      # Last bin: > last threshold
      percentages[-1] = _pct(eta > thresholds[-1])

      return percentages

In [17]:
if plot_eigs:
  import numpy as np

  # Reload the concatenated eigenvalues written by the processing cell.
  all_eigs = np.load(file="/content/drive/MyDrive/datasets/wikitext_eigs/all_eigs.npy")


In [None]:
if plot_eigs:
  # Bin edges for the eigenvalues: six thresholds give seven bins.
  thresholds_radius = np.array([0.1, 0.5, 0.9, 1.0, 10, 100])
  # output: (num_bins, num_batch, num_heads, num_layers)
  percentage = threshold_analysis(eig_val=all_eigs, thresholds=thresholds_radius)

In [None]:
if plot_eigs:
  # Average over the batch and head axes -> (num_bins, num_layers)
  percentage_per_layer = np.mean(percentage, axis=(1, 2))

In [None]:
if plot_eigs:
  import matplotlib.pyplot as plt
  import numpy as np

  # One bar chart per layer, for at most the first five layers.
  num_layers_to_plot = min(5, percentage_per_layer.shape[1])

  # Bin labels: below the first threshold, between consecutive thresholds, above the last one.
  thresholds_radius = np.asarray(thresholds_radius)
  bin_labels = [f"(0, {thresholds_radius[0]:.2f})"]
  bin_labels += [
      f"({lo:.2f}, {hi:.2f})"
      for lo, hi in zip(thresholds_radius[:-1], thresholds_radius[1:])
  ]
  bin_labels.append(f"({thresholds_radius[-1]:.2f}, ∞)")

  for layer in range(num_layers_to_plot):
      values = percentage_per_layer[:, layer]  # (bins,)

      plt.figure(figsize=(8, 4))
      plt.bar(range(len(values)), values)
      plt.title(f"Percentage of Eigenvalues in Each Bin — Layer {layer}")
      plt.xlabel("Eigenvalue bins")
      plt.ylabel("Percentage (%)")
      plt.xticks(range(len(bin_labels)), bin_labels, rotation=45, ha="right")
      plt.tight_layout()
      plt.show()


In [None]:
if plot_eigs:
  # Average over the batch axis only -> (num_bins, num_heads, num_layers)
  percentage_per_layer_per_head = np.mean(percentage, axis=1)

In [None]:
if plot_eigs:
  import matplotlib.pyplot as plt
  import numpy as np

  # Plot-local sizes. These deliberately do not reuse the global names num_layers / num_heads,
  # which the evaluation cell relies on; overwriting them here would make a later re-run of
  # that cell process only the first 5 layers.
  n_layers_plot = min(5, percentage_per_layer_per_head.shape[2])
  n_heads_plot = min(100, percentage_per_layer_per_head.shape[1])

  # Build bin labels (same as before)
  thresholds_radius = np.asarray(thresholds_radius)
  bin_labels = (
      [f"(0, {thresholds_radius[0]:.2f})"] +
      [f"({thresholds_radius[i]:.2f}, {thresholds_radius[i+1]:.2f})"
      for i in range(len(thresholds_radius) - 1)] +
      [f"({thresholds_radius[-1]:.2f}, ∞)"]
  )

  fig, axes = plt.subplots(
      n_layers_plot,
      n_heads_plot,
      figsize=(4 * n_heads_plot, 3 * n_layers_plot),
      sharex=True,
      sharey=True,
      squeeze=False,  # always a 2-D array of Axes, even for a single row or column
  )

  for layer in range(n_layers_plot):
      for head in range(n_heads_plot):
          ax = axes[layer, head]

          # (bins,)
          values = percentage_per_layer_per_head[:, head, layer]

          ax.bar(range(len(values)), values)

          # Only the bottom row carries the bin tick labels
          if layer == n_layers_plot - 1:
              ax.set_xticks(range(len(bin_labels)))
              ax.set_xticklabels(bin_labels, rotation=45, ha="right")
          else:
              ax.set_xticks([])

          if head == 0:
              ax.set_ylabel(f"Layer {layer}")

          if layer == 0:
              ax.set_title(f"Head {head}")

  fig.suptitle("Percentage of Eigenvalues per Bin (Layers × Heads)", fontsize=14)
  fig.tight_layout(rect=[0, 0, 1, 0.96])
  plt.show()


In [None]:
if plot_eigs:
  import numpy as np
  import matplotlib.pyplot as plt

  def plot_all_layers_and_spec_head_without_init(
      data_all,
      data_all_std,
      models,
      common, threshold, num_head, num_layers, seed, spec, layer_select, plot_layers
  ):
      """
      Grouped bar chart of per-bin eigenvalue percentages, one subplot row per selected layer.

      data_all, data_all_std: mean and std of the percentages, shape (n_bins, n_models, n_layers).
          For each row they are sliced on the last axis and transposed to (n_models, n_bins).
      models: x tick label of each bar group (one group per "model", e.g. a head).
      threshold: bin thresholds (n_bins - 1 values), used for the bar labels.
      layer_select: indices into the last axis of data_all, one subplot row each.
      plot_layers: maps a layer_select index to the layer number shown (1-based) in the row label.
      common, num_head, num_layers, seed, spec: only used to build the output file name.

      The figure is saved to
      '{common}_head{num_head}_all_layers{num_layers}_seed{seed}_without_init_{spec}.pdf'
      (relative to the working directory) and returned.

      Raises:
          ValueError: if layer_select is empty.
      """
      if len(layer_select) == 0:
          raise ValueError("layer_select must contain at least one layer")

      fig_height = len(layer_select) * 7
      fig_width = 30

      # squeeze=False keeps `axes` indexable even when a single layer is plotted
      fig, axes = plt.subplots(
          len(layer_select), 1, sharex=True, figsize=(fig_width, fig_height), squeeze=False
      )
      axes = axes[:, 0]
      fig.subplots_adjust(hspace=0.05)

      color_seq = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'gray']

      # Lower edge of every bin: 0 followed by the thresholds
      thresholds_plot = np.asarray(threshold).tolist()
      thresholds_plot.insert(0, 0.0)

      width = 0.09   # bar width
      gap = 0.04     # space between bars of the same group
      error_config = dict(
          elinewidth=3,
          capsize=5,
          capthick=3,
          zorder=4
      )

      for j, layer in enumerate(layer_select):
          ax = axes[j]

          # (bins, models) -> (models, bins)
          data = data_all[:, :, layer].T
          data_std = data_all_std[:, :, layer].T
          n_models, n_bins = data.shape

          x = np.arange(n_models)   # one group per model

          # Clip the lower error bar at zero: percentages cannot be negative
          yerr_lower = np.minimum(data_std, data)
          yerr_upper = data_std

          # Plot bins as grouped bars
          for i in range(n_bins):
              offset = i * (width + gap)

              if i == n_bins - 1:
                  label_tr   = f'(t) > {thresholds_plot[i]}'
              else:
                  label_tr   = f'(t) {thresholds_plot[i]}-{thresholds_plot[i+1]}'

              ax.bar(
                  x + offset,
                  data[:, i],
                  width=width,
                  color=color_seq[i % len(color_seq)],  # cycle if there are more bins than colors
                  edgecolor='black',
                  yerr=[yerr_lower[:, i], yerr_upper[:, i]],
                  error_kw=error_config,
                  capsize=4,
                  label=label_tr,
                  zorder=3
              )

          # Axis styling
          ax.tick_params(axis='x', labelsize=60)
          ax.tick_params(axis='y', labelsize=60)
          ax.set_yticks([0, 25, 50, 75, 100])
          ax.tick_params(axis='y', labelleft=False)
          ax.grid(axis='y', zorder=1)
          ax.set_ylim([0, 110])

          # Center each x tick under its group of n_bins bars
          group_width = n_bins * (width + gap) - gap
          x_center = x + group_width / 2 - width / 2

          ax.set_xticks(x_center)
          ax.set_xticklabels(models, fontsize=45)
          ax.tick_params(axis='x', length=0)

          ax.text(
              0.01, 0.99,
              f'layer {plot_layers[layer] + 1}',
              transform=ax.transAxes,
              ha='left', va='top',
              fontsize=47, fontweight='bold',
              zorder=3,
              bbox=dict(
                  facecolor="white",
                  edgecolor="black",
                  boxstyle="round,pad=0.3"
              )
          )

          # Vertical separators between models
          for xi in range(1, n_models):
              ax.axvline(
                  xi - 0.13,
                  color="black",
                  linewidth=2,
                  linestyle="--",
                  zorder=0
              )
          for label in ax.get_xticklabels():
              label.set_ha('center')

      for ax in axes:
          ax.margins(x=0)

      fig.text(
          0.5, 0.9,
          "WikiText",
          ha="center", va="top",
          fontsize=55,
          fontweight="bold"
      )

      fig.savefig(
          f'{common}_head{num_head}_all_layers{num_layers}_seed{seed}_without_init_{spec}.pdf',
          format='pdf',
          dpi=300,
          bbox_inches='tight'
      )
      return fig

In [None]:
if plot_eigs:
  # Mean and spread of the bin percentages across batch entries:
  # percentage (n_bins, b_sz, heads, layers) -> (n_bins, heads, layers)
  mean_pct = percentage.mean(axis=1)
  std_pct = percentage.std(axis=1)

  plt.rcParams.update({
      'font.size': 55,
      'mathtext.fontset': 'stix',
      'font.family': 'STIXGeneral',
  })

  plot_heads = [0, 1, 2, 8, 9, 17]
  plot_layers = [0, 1, 2, 3, 4, 20, 21]

  # Heads play the role of the "models" axis of the plot; keep only the chosen layers
  data_all = mean_pct[:, plot_heads, :][:, :, plot_layers]
  data_all_std = std_pct[:, plot_heads, :][:, :, plot_layers]

  n_models = len(plot_heads)
  models = [f'Head {h}' for h in plot_heads]

  # Row j of the figure shows data_all[..., j], i.e. layer plot_layers[j]
  layer_select = list(range(len(plot_layers)))

  plot_all_layers_and_spec_head_without_init(
      data_all=data_all,
      data_all_std=data_all_std,
      models=models,
      common="Eigenvalue radius %",
      threshold=thresholds_radius,
      num_head=len(plot_heads),
      num_layers=len(plot_layers),
      seed=0,
      spec="radius",
      layer_select=layer_select,
      plot_layers=plot_layers
  )
