 <a href="https://colab.research.google.com/github/JasonGross/neural-net-coq-interp/blob/main/writeups/Formalizing_Transformers_For_Mech_Interp_Folks_Max_of_small_n_Jason_Gross%2C_Thomas_Kwa%2C_Rajashree_Agrwal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Formalizing Transformers, For Mech Interp Folks


In [*Towards Monosemanticity: Decomposing Language Models With Dictionary Learning*](https://transformer-circuits.pub/2023/monosemantic-features), Bricken et al. say:
> #### How can we tell if the autoencoder is working?
>
> Usually in machine learning we can quite easily tell if a method is working by looking at an easily-measured quantity like the test loss. We spent quite some time searching for an equivalent metric to guide our efforts here, and unfortunately have yet to find anything satisfactory.
>
> We began by looking for an information-based metric, so that we could say in some sense that the best factorization is the one that minimizes the total information of the autoencoder and the data. Unfortunately, this total information did not generally correlate with subjective feature interpretability or activation sparsity. (Runs whose feature activations had an average L0 norm in the hundreds but low reconstruction error could have lower total information than those with smaller average L0 norm and higher reconstruction error.)
>
> Thus we ended up using a combination of several additional metrics to guide our investigations:
>
> 1. **Manual inspection:** Do the features seem interpretable?
> 2. **Feature density:** we found that the number of “live” features and the percentage of tokens on which they fire to be an extremely useful guide. (See appendix for details.)
> 3. **Reconstruction loss:** How well does the autoencoder reconstruct the MLP activations? Our goal is ultimately to explain the function of the MLP layer, so the MSE loss should be low.
> 4. **Toy models:** Having toy models where we know the ground truth and so can cleanly evaluate the autoencoder’s performance was crucial to our early progress.
>
>Interpreting or measuring some of these signals can be difficult, though. For instance, at various points we thought we saw features which at first didn’t make any sense, but with deeper inspection we could understand. Likewise, while we have identified some desiderata for the distribution of feature densities, there is much that we still do not understand and which prevents this from providing a clear signal of progress.
>
>We think it would be very helpful if we could identify better metrics for dictionary learning solutions from sparse autoencoders trained on transformers.

This write-up is, in some sense, a response to the above quote.

**Vision:** We want to be able to automate the discovery of mechanistic interpretation of neural networks.

There are two big questions:
1. How do we discover the interpretation? (Answer: Engineering!)
2. How do we ensure that the interpretation is good (correct, accurate, human-understandable, etc.)?

We provide a case-study-based plausible theoretical grounding for answering the second question.

Claim: Minimizing information *does* lead to human interpretability, *if* we are clear on what we are proposing to explain.

There are three problems with [the information-based metric](https://transformer-circuits.pub/2023/may-update/index.html#simple-factorization) proposed by the authors of *Towards Monosemanticity*, but none of them are problems with the idea of minimizing information.

*Towards Monosemanticity*, as we understand it, first specifies a problem to be solved and a language of solutions ("reconstruct the matrix of activations $A \approx SD$ where $S$ is sparse and $D$ is a dictionary of features"); then specifies tunable hyperparameters (the coefficient of the sparsity penalty and the size of the dictionary); then finds for each hyperparameter setting the solution that minimizes loss (as measured by combining the reconstruction error and the sparsity penalty); and finally proposes a metric (total information) for comparing the solutions for different hyperparameter settings.

1. Claim: We actually want a tunable knob for picking out interpretations on the Pareto frontier of complexity of interpretation (total information) vs. accuracy of interpretation (reconstruction error).  We don't expect there to be a single "best" interpretation in general, but instead a spectrum of interpretations that trade off accuracy for simplicity.  It might make sense to analyze the curve here, perhaps looking for inflection points (intuitively, places where a little more complexity buys a lot more accuracy, or conversely a small relaxation in accuracy saves a lot of complexity), but we shouldn't expect information-based metrics to pick out a single "best interpretation" for us.
2. The correct trade-off to be looking at is either "for a given bound on information, what's the explanation that provides the most accuracy" or "for a given bound on reconstruction loss, what's the most compact explanation".  The procedure described above does not do this.  Both hyperparameters already control trade-offs between reconstruction loss vs two different proxies for compactness / information / interpretability, and then the resulting optimal solutions are compared using a third proxy.  We have no reason to expect sensible behavior from this procedure.
3. We can try to measure the understandability / compactness of four different things:
   1. Local behavioral description: The description of *what* the MLP is doing
   2. Global behavioral description: The description of *what* the entire neural net is doing
   3. Global behavioral expectation: The description of *how* it is that *what* the neural net is doing results in our desired behavior (e.g., low loss on the training distribution)
   4. Global behavioral guarantee: The description of *how* it is that, when accounting for *everything* the neural net is doing, the neural net achieves our desired behavior.  This is (3), but also accounting for all of the "boring" things, e.g., this includes explanations of how it comes to be that the "random noise" from "unimportant heads" is small enough that it doesn't destroy the "important signal".
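
Points 1 and 2 above can be made concrete with a small sketch. The candidate interpretations, their names, and their numbers below are entirely hypothetical; the code only illustrates what "the Pareto frontier" and "the best explanation under a bound" mean, and is not the procedure used in *Towards Monosemanticity*.

```python
from typing import List, Optional, Tuple

# (name, total information, reconstruction error); all values are made up.
Candidate = Tuple[str, float, float]

candidates: List[Candidate] = [
    ("A", 10.0, 0.50),
    ("B", 20.0, 0.30),
    ("C", 25.0, 0.35),  # dominated by B: more information and more error
    ("D", 40.0, 0.10),
    ("E", 45.0, 0.10),  # dominated by D: more information, same error
]

def pareto_frontier(cands: List[Candidate]) -> List[Candidate]:
    """Keep only candidates that no other candidate beats on both axes."""
    frontier: List[Candidate] = []
    # Scan in order of increasing information; a candidate survives only if
    # it strictly improves on the best error seen so far.
    for cand in sorted(cands, key=lambda c: (c[1], c[2])):
        if not frontier or cand[2] < frontier[-1][2]:
            frontier.append(cand)
    return frontier

def most_accurate_within(cands: List[Candidate], max_info: float) -> Optional[Candidate]:
    """For a given bound on information, the explanation with the least error."""
    ok = [c for c in cands if c[1] <= max_info]
    return min(ok, key=lambda c: c[2]) if ok else None

def most_compact_within(cands: List[Candidate], max_error: float) -> Optional[Candidate]:
    """For a given bound on error, the explanation with the least information."""
    ok = [c for c in cands if c[2] <= max_error]
    return min(ok, key=lambda c: c[1]) if ok else None

print([name for name, _, _ in pareto_frontier(candidates)])
print(most_accurate_within(candidates, max_info=30.0))
print(most_compact_within(candidates, max_error=0.35))
```

Either bound acts as the "tunable knob": sweeping it traces out the frontier rather than selecting one winner.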

Total information of the sparse autoencoder decomposition is a measurement of (1).
There's no reason to expect that the directions in activation space should have a particularly compact representation (and most "human interpretations" just say "there is a direction" for a feature, not what the direction is), which, we believe, is why sparsity seems a better proxy for human interpretability in sparse autoencoders than total information.

This is a problem when trying to minimize information metrics of (1), but *not* when correctly minimizing information metrics for (3) or (4) (and possibly not even for (2)).

Two considerations ameliorate this issue:
1. When minimizing information for (3) and (4), we expect the largest gains (complexity reductions) to come not from compact descriptions of the input distribution factoring, but from the ways a given factoring allows a computational complexity reduction.  The *reason* a feature is useful is that you can compute a desired property more simply using that feature than without it.  The benefit of a "DNA feature" is not just in having a compact description of when something is DNA and what should be predicted when something is DNA, but in that other computations can be more simply described when conditioned on DNA.  Or to take a simpler example, the benefit of having a "size" direction that is the principal component of the QK circuit in max-of-n is that it allows us to reason about the OV behavior *as a function of how much attention is paid to the largest element* ($\mathcal{O}(\text{d\_vocab})$ possibilities) rather than as a function of all the sequences of attention weights ($\mathcal{O}(\text{d\_vocab}^{\text{n\_ctx}})$ possibilities).  Next to this asymptotic reduction, the compactness of the description of the particular size direction is peanuts.
2. If we pick a measure of information [that does not penalize us for arbitrary choices](https://www.lesswrong.com/posts/KcvJXhKqx4itFNWty/k-complexity-is-silly-use-cross-entropy-instead), such as cross entropy, then we can do even better!  If a choice is arbitrary (such as the image of the size direction under embedding followed by the query matrix), we won't get docked for the size of that description.  If a choice is not arbitrary, then there's something interesting going on, and we shouldn't be excluding it from our interpretation.
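
To give a feel for the size of the asymptotic reduction in the max-of-n example, here is a rough count. It treats both $\mathcal{O}(\cdot)$ bounds as exact counts with constant factor 1, which is a simplifying assumption made only for illustration; `case_counts` is a hypothetical helper name.

```python
def case_counts(d_vocab: int, n_ctx: int) -> tuple:
    """Return (cases indexed by attention on the largest element,
    cases indexed by the full input sequence), ignoring constant factors."""
    return d_vocab, d_vocab ** n_ctx

# d_vocab = 64 matches the model studied later in this write-up.
for n_ctx in (2, 3, 5):
    by_max, by_sequence = case_counts(64, n_ctx)
    print(f"n_ctx={n_ctx}: {by_max} cases vs {by_sequence} cases")
```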


## Proofs and Heuristic Arguments from the Mech Interp Lens
Proofs give a guarantee not just of the biggest thing that happens, but also how it comes to be that nothing else of interest is happening.  Heuristic arguments promise to solve the problem of how to (rigorously) separate out the “nothing interesting is going on” “default assumption” so we can measure its complexity separately, and find actually compact arguments of just the “interesting” “interpretable” stuff.

We would love to be able to claim that compactness of {proof, heuristic arg} is a good evaluation metric for human interpretability.  We don’t have nearly enough evidence for this, alas (future work!), so instead we aim to present a case-study as evidence that compactness for human interp has *firm formal grounding*.

## Computing the Same Thing vs. Running the Same Computation
We probably ultimately want a measurement of (2) "the thing the neural net is doing", as compactly as possible.
But if you look closely, we don't have a direct way to measure this.  Instead we're saying "what's a compact way of computing the same thing the neural net computes".  But "computing the thing the neural net computes" *is not the same* as "the computation the neural net is running"!

As far as we can tell, the current approaches try to proxy the gap with reconstruction loss on the quirks and errors made by the model.
There's some value to this ("there's many ways for relationships to fail, but the only way for them to succeed is respect"), but it clearly isn't adequate for perfect models implementing distinct algorithms.
And it feels bad to us to say that we're relying on networks being quirky and error-prone to get decently accurate explanation evaluations.

We want to claim that a better proxy for the gap between "computing the same thing" and "the computation that's being run" is a *guarantee* (either a proof or a heuristic argument) that the particular computation being run computes the desired result.
We'll pay some cost by using this as our proxy (we'll have no choice but to include the complete description of the computation being run, though hopefully we can avoid being bitten by this by using cross-entropy), but we believe this proxy gives much firmer grounding for evaluating mech interp.

In this document, we'll walk through a small case-study or two, applying this frame of proofs and guarantees for mech interp.

## Proof Strategy
The strategy we use for making guarantees about neural nets $NN$ is as follows:
1. Fix an input distribution $D_I$ (e.g., the uniform distribution of sequences of length $n$ over a vocabulary of size $d$).  This defines the domain of discourse.
2. Define a property $P$ to be established on a metric $M$ that can be calculated over the input-output pairs of $NN$ on this domain.  For example, $M$ might be accuracy, log-loss, etc, and $P$ might be $\leq 0.01$ or $\geq 99\%$ or "within $\varepsilon$ of 0.8".  Importantly, this theorem $P(M(NN(D_I)))$ could in theory be proven (or disproven) simply by computation, if we were willing to wait long enough.
3. Argue that the property holds of the given $NN$’s computation, by:
   - Finding/constructing a cheaper computation $C$ (taking as input the $NN$ weights & biases) such that
     1. Establishing the property $P$ on this computation $C$ implies that $P$ holds on $M(NN(D_I))$ (symbolically: $P(C) \Longrightarrow P(M(NN(D_I)))$)
         - We're working on formalizing such proofs in the proof assistant Coq, but for this document we give our proofs as English arguments.
     2. Establishing $P$ on $C$ is computationally feasible and straightforward

Note: Compactness is important in each of these steps, but gives different things:
- domain of discourse: compactness allows us to speak at all
- property description: compactness gives understanding of *what outcome happens*
- description of computation $C$: compactness gives understanding of *why outcome happens*
- cost of running $C$: compactness gives goodness of description / understanding of *how it comes to be that the model is implementing our described algorithm*
- size of proof that $P(C) \Longrightarrow P(M(NN(D_I)))$: compactness gives understanding of *why outcome happens **on NN***
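
As a minimal sketch of the proof strategy, consider a deliberately tiny stand-in for $NN$. The toy model below (`toy_model`, parameterized by one scalar score per token) is an assumption made for illustration and is not the transformer studied in this write-up. Here $D_I$ is all sequences of length `n_ctx`, $M$ is accuracy, and $P$ is "accuracy is 100%". The brute-force check costs $d^{n}$ model evaluations, while the cheaper computation $C$ only inspects the $d$ scores.

```python
import itertools
from typing import List, Sequence

def toy_model(scores: Sequence[float], seq: Sequence[int]) -> int:
    """Toy stand-in for NN: output the input token with the largest score."""
    return max(seq, key=lambda tok: scores[tok])

def brute_force_accuracy(scores: Sequence[float], n_ctx: int) -> float:
    """M(NN(D_I)): accuracy at predicting max(seq) over every possible input."""
    d_vocab = len(scores)
    total = correct = 0
    for seq in itertools.product(range(d_vocab), repeat=n_ctx):
        total += 1
        correct += toy_model(scores, seq) == max(seq)
    return correct / total

def cheap_certificate(scores: Sequence[float]) -> bool:
    """C: check that scores strictly increase with the token value.

    If this holds, the highest-scoring token in any sequence is also the
    largest token, so the toy model is correct on every input.
    """
    return all(a < b for a, b in zip(scores, scores[1:]))

good_scores: List[float] = [float(t ** 3) for t in range(8)]
bad_scores: List[float] = [0.0, 2.0, 1.0, 3.0]  # tokens 1 and 2 are swapped

print(cheap_certificate(good_scores), brute_force_accuracy(good_scores, n_ctx=2))
print(cheap_certificate(bad_scores), brute_force_accuracy(bad_scores, n_ctx=2))
```

The docstring of `cheap_certificate` plays the role of the proof that $P(C) \Longrightarrow P(M(NN(D_I)))$: it is short precisely because the toy model has no noise to bound. For a trained network, that argument has to account for everything else the network does.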

TODO: move the below somewhere else, or to a document for heuristic argument folks

Aside: As we understand it, heuristic arguments also fit into this frame, as follows:
1. Same (input distribution $D_I$)
2. Usually the metric $M$ is not quite what’s described, instead there’s a standard metric $M'$ (like log-loss), and the actual metric, currently implicit and nebulous, maybe goes something like $\mathbb{E}_{NN\leftarrow D_{NN}}\left[M'(NN(D_I))\right]$ (where $\mathbb{E}_{NN\leftarrow D_{NN}}$ is "the expected value over $NN$ drawn from a distribution of weights and biases $D_{NN}$"), and then the property $P$ is the same
3. Sorta same
   - Estimator $G$ is a class of computations parameterized over heuristic args $\pi$ such that:
     1. We can establish in general that (under some restrictions), there exists / forall / for many $\pi$s, with $C = G(NN|\pi\text{s})$, $P(C)$ implies $P(M(NN(D_I)))$ (plus some other properties that make it reasonable to have this only for some $\pi$s)
     2. Same, but even more so (establishing $P(C)$ should be ~linear in NN size)

## Model Setup: Max of 2

We'll be looking at the problem of computing the max of two numbers.  We use a one-layer attention-only transformer with vocab size 64, model size 32, no layer norm, no biases.  The input is a sequence of two (or later $n$) numbers, and we train on the cross-entropy loss between the prediction at the final sequence position and the true maximum.  The model has been adversarially overtrained to the point where the accuracy is 100% and the loss is dominated by 32-bit floating point error in the final log-softmax.

## Interpretation
High-level: The model pays more attention to larger numbers, and copies whatever it's paying attention to.

More detail:
- There is a "size direction" and a "query direction".  Tokens are embedded with more-or-less uniform overlap with the query direction and more-or-less monotonically-increasing overlap with the size direction (the curve seems to be cubic or quintic, for unclear reasons).
- The QK circuit has extremely high overlap between the query-direction on the query side and the size-direction on the key side, so that the pre-softmax attention is essentially a scalar multiple of the overlap between the one-hot token vector and the size direction.  Everything else the QK circuit does is essentially negligible.
- The OV circuit is a low-rank representation of a matrix with high (and more-or-less uniform) diagonal entries and low off-diagonal entries.  We have no explanation for how this comes to be the case.
- (TODO: check this) There's some additional structure in the noise: query tokens with less overlap with the query direction have (a) less skip-connection noise, (b) larger gaps between the diagonal and off-diagonal entries in the OV circuit, and (c) smaller errors in size-direction overlap.  That is, the errors conspire in our benefit: query tokens that are worse for paying attention to larger tokens have correspondingly larger gaps between them and adjacent tokens in the size-direction, so that we still succeed in paying more attention to tokens larger than the query and less attention to tokens smaller than the query, and the copying behavior on small-gap sequences lines up for reasons we have not yet understood (merely verified).
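
The high-level story can be written down as an idealized algorithm. The sketch below is not the trained model: it assumes a perfectly linear size direction, an exactly diagonal OV circuit, and no skip connection, positional information, or noise. The names `size_scale` and `diag_gain` and their values are hypothetical.

```python
import itertools
import math
from typing import List, Sequence

D_VOCAB = 64  # vocabulary size used in this write-up

def softmax(xs: Sequence[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def idealized_predict(seq: Sequence[int], size_scale: float = 0.5,
                      diag_gain: float = 10.0) -> int:
    # QK circuit: the pre-softmax attention on each key is a scalar multiple
    # of that key's overlap with the size direction (here, the token value).
    attn = softmax([size_scale * tok for tok in seq])
    # OV circuit: a large diagonal and zero off-diagonal, so each attended
    # token adds its attention weight only to its own output logit.
    logits = [0.0] * D_VOCAB
    for weight, tok in zip(attn, seq):
        logits[tok] += diag_gain * weight
    return max(range(D_VOCAB), key=lambda v: logits[v])

# Check the idealized algorithm on every sequence of two tokens.
pairs = list(itertools.product(range(D_VOCAB), repeat=2))
accuracy = sum(idealized_predict(p) == max(p) for p in pairs) / len(pairs)
print(accuracy)
```

The interesting part of the interpretation is everything this sketch leaves out: the guarantee has to explain why the off-diagonal OV entries, the skip connection, and the residual QK behavior do not overturn this picture.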

In [None]:
#@title setup
# Is this the development version?
DEV = True #@param {type:"boolean"}

try:
    import google.colab # type: ignore
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

import os
import sys

if IN_COLAB:
    # Install packages
    %pip install einops
    %pip install jaxtyping
    %pip install transformer_lens

    # Code to download the necessary files (e.g. solutions, test funcs)
    import os, sys
    if not os.path.exists("utils"):
        !curl -o /content/main.zip https://codeload.github.com/JasonGross/neural-net-coq-interp/zip/refs/heads/main
        !unzip /content/main.zip 'neural-net-coq-interp/training/*'
        sys.path.append("/content/utils")
        os.remove("/content/main.zip")
        os.rename("neural-net-coq-interp/training", "utils")
        os.rmdir("neural-net-coq-interp")
else:
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")
    import os, sys
    if DEV:
        sys.path.append(f"{os.getcwd()}/../training")
    elif not os.path.exists("utils"):
        # Outside Colab there may be no /content directory, so download into the working directory
        !curl -o main.zip https://codeload.github.com/JasonGross/neural-net-coq-interp/zip/refs/heads/main
        !unzip main.zip 'neural-net-coq-interp/training/*'
        sys.path.append(f"{os.getcwd()}/utils")
        os.remove("main.zip")
        os.rename("neural-net-coq-interp/training", "utils")
        os.rmdir("neural-net-coq-interp")


In [None]:
#@title training_utils
# # Data Generation

# Helper functions for generating data and splitting it into batches for training.

import datetime
import os, os.path
from pathlib import Path
from typing import List, Any, Iterable, Optional
import numpy as np
import tqdm
from transformer_lens import HookedTransformer
import torch
import itertools

import wandb

# In[ ]:
def default_device(deterministic: bool = False) -> str:
   return "cuda" if torch.cuda.is_available() and not deterministic else "cpu"

# In[ ]:

DEFAULT_WANDB_ENTITY = 'team-jason' # 'tkwa-team' # 'team-jason'

# In[ ]:

def in_colab() -> bool:
    """
    Returns True if running in Google Colab, False otherwise.
    """
    try:
        import google.colab
        return True
    except ImportError:
        return False

# In[ ]:

def get_pth_base_path(save_in_google_drive: bool = False, create: bool = True) -> Path:
    """
    Returns the base path for saving models. If `save_in_google_drive` is True, returns the path to the Google Drive
    folder where models are saved. Otherwise, returns the path to the local folder where models are saved.
    """
    if in_colab():
        if save_in_google_drive:
            from google.colab import drive
            drive.mount('/content/drive/')
            pth_base_path = Path('/content/drive/MyDrive/Colab Notebooks/')
        else:
            pth_base_path = Path("/workspace/_scratch/")
    else:
        pth_base_path = Path(os.getcwd())

    pth_base_path = pth_base_path / 'trained-models'

    if create and not os.path.exists(pth_base_path):
        os.makedirs(pth_base_path)

    return pth_base_path

# In[ ]:

def generate_all_sequences(n_digits: int, sequence_length: int = 2):
  data = list(itertools.product(range(n_digits), repeat=sequence_length))
  data = torch.tensor(data)
  return data

# In[ ]:

def compute_all_tokens(model: HookedTransformer):
    return generate_all_sequences(n_digits=model.cfg.d_vocab, sequence_length=model.cfg.n_ctx)

# In[ ]:

def shuffle_data(data):
  indices = np.array(range(len(data)))
  np.random.shuffle(indices)
  data = data[indices]
  return data

# In[ ]:

def make_testset_trainset(
    model: HookedTransformer,
    training_ratio=0.7,
    force_adjacent=False):
  """
  Generate a shuffled train/test split of all token sequences for `model`.

  The full dataset is every sequence of `model.cfg.n_ctx` integers with values
  0 <= n < `model.cfg.d_vocab` (see `compute_all_tokens`).

  Args:
      model (HookedTransformer): The model whose config determines the sequence length and vocabulary size.
      training_ratio (float): The ratio of the size of the training set to the full dataset.
      force_adjacent (bool): Whether to make training adversarial by moving every sequence whose first two
          entries differ by exactly 1 to the front of the data, so that these sequences land in the
          training set (as far as `training_ratio` allows).

  Returns:
      Tuple[torch.Tensor, torch.Tensor]: The training set and the test set, each a tensor of shape
          [num_sequences, n_ctx]. The training set contains a `training_ratio` fraction of the full dataset,
          while the test set contains the remaining data. Both sets are shuffled.
  """
  data = compute_all_tokens(model)

  data = shuffle_data(data)

  if force_adjacent:
    idxs = (data[:,0] - data[:,1]).abs() == 1
    data, extra_data = data[~idxs], data[idxs]
    data = torch.cat([extra_data, data], dim=0)

  split_idx = int(len(data) * training_ratio)

  data_train = data[:split_idx]
  data_test = data[split_idx:]

  if force_adjacent:
    data_train = shuffle_data(data_train)
    data_test = shuffle_data(data_test)

  return data_train, data_test

# In[ ]:

def make_generator_from_data(data: List[Any], batch_size: int = 128) -> Iterable[List[Any]]:
  """
  Returns a generator that yields slices of length `batch_size` from a list.

  Args:
      data: The input list to be split into batches.
      batch_size: The size of each batch.

  Yields:
      A slice of the input list of length `batch_size`. The final slice may be shorter if the
      length of the list is not evenly divisible by `batch_size`.
  """
  data = shuffle_data(data)
  for i in range(0,len(data), batch_size):
    yield data[i:i+batch_size]

# In[ ]:

def make_wandb_config(
    model:HookedTransformer,
    optimizer_kwargs: dict,
    n_epochs=100,
    batch_size=128,
    batches_per_epoch=10,
    adjacent_fraction=0,
    use_complete_data=True,
    device=None,
    **kwargs):
  return {
      'model.cfg':model.cfg.to_dict(),
      'optimizer.cfg':optimizer_kwargs,
      'n_epochs':n_epochs,
      'batch_size':batch_size,
      'batches_per_epoch':batches_per_epoch,
      'adjacent_fraction':adjacent_fraction,
      'use_complete_data':use_complete_data,
      'device':device,
    }

def load_model(model: HookedTransformer, model_pth_path: str):
  try:
    cached_data = torch.load(model_pth_path)
    model.load_state_dict(cached_data['model'])
    #model_checkpoints = cached_data["checkpoints"]
    #checkpoint_epochs = cached_data["checkpoint_epochs"]
    #test_losses = cached_data['test_losses']
    train_losses = cached_data['train_losses']
    #train_indices = cached_data["train_indices"]
    #test_indices = cached_data["test_indices"]
    return train_losses, model_pth_path
  except Exception as e:
    print(f'Could not load model from {model_pth_path}:\n', e)

def train_or_load_model(
      model_name:str,
      model:HookedTransformer,
      loss_fn,
      acc_fn,
      train_data_gen_maybe_lambda,
      data_test,
      n_epochs=100,
      batches_per_epoch=10,
      device=None,
      wandb_project=None,
      save_model=True,
      model_pth_path=None,
      deterministic: bool = False,
      optimizer=torch.optim.Adam,
      optimizer_kwargs={'lr':1e-3, 'betas': (.9, .999)},
      train_data_gen_is_lambda: bool = False,
      loss_fn_kwargs={'return_per_token':True},
      print_every: Optional[int] = 10,
      log_acc: bool = False,
      force_train: bool = False,
      overwrite_data: bool = False,
      model_description: str = "trained model",
      wandb_entity:str = DEFAULT_WANDB_ENTITY,
      fail_if_cant_load: bool = False,
      save_in_google_drive: bool = False,
      **kwargs, # kwargs for **locals() below
  ):
  if force_train and fail_if_cant_load: raise ValueError(f"force_train is {force_train} and fail_if_cant_load is {fail_if_cant_load}")
  if device is None: device = default_device(deterministic=deterministic)

  pth_base_path = get_pth_base_path(save_in_google_drive=save_in_google_drive, create=True)
  if model_pth_path is None:
    datetime_str = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    model_pth_path = pth_base_path / f'{model_name}-{model.cfg.n_ctx}-epochs-{n_epochs}-{datetime_str}.pth'

  if not force_train and os.path.exists(model_pth_path):
    res = load_model(model, model_pth_path)
    if res is not None: return res

  if wandb_project is not None:
    wandb_model_path = f"{wandb_entity}/{wandb_project}/{model_name}:latest"
    if not force_train:
      model_dir = None
      try:
        api = wandb.Api()
        model_at = api.artifact(wandb_model_path)
        model_dir = Path(model_at.download())
      except Exception as e:
        print(f'Could not load model {wandb_model_path} from wandb:\n', e)
      if model_dir is not None:
        for model_path in model_dir.glob('*.pth'):
          res = load_model(model, model_path)
          if res is not None: return res

  assert not fail_if_cant_load, f"Couldn't load model from {model_pth_path}{f' or wandb ({wandb_model_path})' if wandb_project is not None else ''}, and fail_if_cant_load is {fail_if_cant_load}"

  if wandb_project is not None:
    config_info = make_wandb_config(**locals())
    run = wandb.init(project=wandb_project, entity=wandb_entity, config=config_info, job_type="train")

  optimizer = optimizer(model.parameters(), **optimizer_kwargs)
  train_data_gen_lambda = (lambda: train_data_gen_maybe_lambda) if not train_data_gen_is_lambda else train_data_gen_maybe_lambda

  train_losses = []

  pbar = tqdm.tqdm(range(n_epochs))
  for epoch in pbar:
    train_data_gen = train_data_gen_lambda()
    epoch_losses = []
    for _ in range(batches_per_epoch):
      tokens = next(train_data_gen)
      logits = model(tokens)
      losses = loss_fn(logits, tokens, **loss_fn_kwargs)
      losses.mean().backward()
      optimizer.step()
      optimizer.zero_grad()
      epoch_losses.extend(losses.detach().cpu().numpy())

    train_losses.append(np.mean(epoch_losses))

    if print_every and epoch % print_every == 0:
      pbar.set_description(f'Epoch {epoch} train loss: {train_losses[-1]:.5e}')

    if wandb_project is not None:
      log_data = {'train_loss': train_losses[-1]}
      if log_acc: log_data['train_acc'] = acc_fn(model(tokens), tokens)
      wandb.log(log_data)

  model.eval()
  logits = model(data_test)
  acc = acc_fn(logits, data_test)

  print(f"Test accuracy after training: {acc}")

  if save_model:
    data = {
       "model":model.state_dict(),
       "config": model.cfg,
       "train_losses": train_losses,
       }
    if overwrite_data or not os.path.exists(model_pth_path):
      torch.save(data, model_pth_path)
      if wandb_project is not None:
        trained_model_artifact = wandb.Artifact(
            model_name, type="model", description=model_description, metadata=model.cfg.to_dict())
        trained_model_artifact.add_file(model_pth_path)
        run.log_artifact(trained_model_artifact)
    elif wandb_project is not None:
      print(f"Warning: {model_pth_path} already exists, saving model directly")
      run.log_artifact(data)

  if wandb_project is not None:
    run.finish()

  return train_losses, model_pth_path


In [None]:
#@title max_of_n
import numpy as np
import torch
from transformer_lens import HookedTransformer
import tqdm.auto as tqdm
import wandb

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def loss_fn(
    logits, # [batch, pos, d_vocab]
    tokens, # [batch, pos]
    return_per_token=False,
    device=DEVICE,
  ):
  logits = logits[:, -1, :].to(device)
  true_maximum = torch.max(tokens.to(device), dim=1)[0]
  log_probs = logits.log_softmax(-1)
  correct_log_probs = log_probs.gather(-1, true_maximum.unsqueeze(-1))
  if return_per_token:
    return -correct_log_probs.squeeze()
  return -correct_log_probs.mean()


# In[ ]:


def acc_fn(
    logits, # [batch, pos, d_vocab]
    tokens, # [batch, pos]
    return_per_token=False,
    device=DEVICE,
  ):
  pred_logits = logits[:, -1, :].to(device)
  pred_tokens = torch.argmax(pred_logits, dim=1)
  true_maximum = torch.max(tokens.to(device), dim=1)[0]
  if return_per_token:
    return (pred_tokens == true_maximum).float()
  return (pred_tokens == true_maximum).float().mean().item()


def large_data_gen(n_digits, sequence_length=6, batch_size=128, context="train", device=DEVICE, adjacent_fraction=0):
  if context == "train":
    seed = 5
  else:
    seed = 6
  torch.manual_seed(seed)
  while True:
    result = torch.randint(0, n_digits, (batch_size, sequence_length)).to(device)
    if adjacent_fraction == 0: yield result
    else:
      adjacent = torch.randint(0, n_digits, (batch_size,))
      adjacent = adjacent.unsqueeze(1).repeat(1, sequence_length)
      # in half the rows, replace a random column with n+1
      rows_to_change = torch.randperm(batch_size)[:batch_size // 2]
      cols_to_change = torch.randint(0, sequence_length, (batch_size // 2,))
      adjacent[rows_to_change, cols_to_change] += 1
      adjacent %= n_digits
      adjacent = adjacent.to(device)
      mask = torch.rand(batch_size) < adjacent_fraction
      result[mask] = adjacent[mask]
      yield result

def make_wandb_config(
    model:HookedTransformer,
    n_epochs=100,
    batch_size=128,
    batches_per_epoch=10,
    adjacent_fraction=0,
    use_complete_data=True,
    device=DEVICE,
    lr=1e-3,
    betas=(.9, .999),
    **kwargs):
  return {
      'model.cfg':model.cfg.to_dict(),
      'optimizer.cfg':{
        'lr':lr,
        'betas':betas,
      },
      'n_epochs':n_epochs,
      'batch_size':batch_size,
      'batches_per_epoch':batches_per_epoch,
      'adjacent_fraction':adjacent_fraction,
      'use_complete_data':use_complete_data,
      'device':device,
    }

def train_model(
    model:HookedTransformer,
    n_epochs=100,
    batch_size=128,
    batches_per_epoch=10,
    adjacent_fraction=0,
    use_complete_data=True,
    device=DEVICE,
    use_wandb=False,
    wandb_project=None,
    save_model=True,
  ):
  lr = 1e-3
  betas = (.9, .999)
  optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)
  n_digits, sequence_length = model.cfg.d_vocab, model.cfg.n_ctx
  train_losses = []
  if wandb_project is not None:
    # `locals()` already contains `model`, so passing it positionally as well would raise a TypeError
    config_info = make_wandb_config(**locals())
    run = wandb.init(project=wandb_project, config=config_info, job_type="train")

  if use_complete_data:
    data_train, data_test = make_testset_trainset(model, force_adjacent=adjacent_fraction > 0)
    train_data_gen_gen = lambda: make_generator_from_data(data_train, batch_size=batch_size)
  else:
    train_data_gen = large_data_gen(n_digits=n_digits, sequence_length=sequence_length, batch_size=batch_size, context="train", device=device, adjacent_fraction=adjacent_fraction)
    test_data_gen = large_data_gen(n_digits=n_digits, sequence_length=sequence_length, batch_size=batch_size * 20, context="test", device=device, adjacent_fraction=adjacent_fraction)
    data_test = next(test_data_gen)

  for epoch in tqdm.tqdm(range(n_epochs)):
    if use_complete_data:
      train_data_gen = train_data_gen_gen()
    epoch_losses = []
    for _ in range(batches_per_epoch):
      tokens = next(train_data_gen)
      logits = model(tokens)
      losses = loss_fn(logits, tokens, return_per_token=True)
      losses.mean().backward()
      optimizer.step()
      optimizer.zero_grad()
      epoch_losses.extend(losses.detach().cpu().numpy())

    train_losses.append(np.mean(epoch_losses))

    if epoch % 10 == 0:
      print(f'Epoch {epoch} train loss: {train_losses[-1]}')

    if use_wandb or wandb_project is not None:
      wandb.log({'train_loss': train_losses[-1]})

  model.eval()
  logits = model(data_test)
  acc = acc_fn(logits, data_test)

  print(f"Test accuracy after training: {acc}")

  if save_model and (use_wandb or wandb_project is not None):
    wandb.log_artifact(model)

  if wandb_project is not None:
    run.finish()

  return train_losses
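
The training loop above relies on `loss_fn` and `acc_fn`, which are defined elsewhere in this notebook. As a rough sketch of what such functions might look like for the max-of-n task, assuming the model is scored only on its final-position prediction of the largest input token (the names `sketch_loss_fn` and `sketch_acc_fn` are hypothetical and are not the notebook's own definitions):

```python
import torch
import torch.nn.functional as F

def sketch_loss_fn(logits, tokens, return_per_token=False):
    # logits: (batch, n_ctx, d_vocab); tokens: (batch, n_ctx)
    final_logits = logits[:, -1, :]        # prediction at the last position
    true_max = tokens.max(dim=-1).values   # target: largest token in each sequence
    losses = F.cross_entropy(final_logits, true_max, reduction="none")
    return losses if return_per_token else losses.mean()

def sketch_acc_fn(logits, tokens):
    # Fraction of sequences whose final-position argmax equals the true maximum
    predictions = logits[:, -1, :].argmax(dim=-1)
    return (predictions == tokens.max(dim=-1).values).float().mean().item()

tokens = torch.tensor([[3, 5], [7, 2], [4, 4]])
# Hand-built logits that put almost all mass on the correct answer at the last position
perfect_logits = torch.zeros(3, 2, 8)
perfect_logits[torch.arange(3), -1, tokens.max(dim=-1).values] = 50.0

per_sequence_losses = sketch_loss_fn(perfect_logits, tokens, return_per_token=True)
accuracy = sketch_acc_fn(perfect_logits, tokens)
```

Returning one loss per sequence, as `return_per_token=True` does here, is what lets the loop both call `.mean().backward()` and accumulate individual losses for the epoch average.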

In [None]:
#@title train_max_of_2
import sys
import torch
from transformer_lens import HookedTransformer, HookedTransformerConfig
import tqdm.auto as tqdm

DETERMINISTIC = True # @param
DEVICE = "cuda" if torch.cuda.is_available() and not DETERMINISTIC else "cpu"
N_LAYERS = 1 # @param
N_HEADS = 1 # @param
D_MODEL = 32 # @param
D_HEAD = 32 # @param
D_MLP = None # @param
D_VOCAB = 64 # @param
SEED = 123 # @param
N_EPOCHS = 1500 # @param
N_CTX = 2 # @param
FORCE_ADJACENT = True # @param
BATCH_SIZE = 128 # @param
FAIL_IF_CANT_LOAD = '--fail-if-cant-load' in sys.argv[1:] # @param

ALWAYS_TRAIN_MODEL = False # @param
SAVE_IN_GOOGLE_DRIVE = False # @param
OVERWRITE_DATA = False # @param
TRAIN_MODEL_IF_CANT_LOAD = True # @param


# %%

simpler_cfg = HookedTransformerConfig(
    d_model=D_MODEL,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    d_head=D_HEAD,
    n_ctx=N_CTX,
    d_vocab=D_VOCAB,
    seed=SEED,
    device=DEVICE,
    attn_only=True,
    normalization_type=None,
)
# %%

model = HookedTransformer(simpler_cfg).to(DEVICE)

for name, param in model.named_parameters():
    if "b_" in name:
        param.requires_grad = False

model_is_trained = False


# %%

def train(fail_if_cant_load=FAIL_IF_CANT_LOAD, train_if_cant_load=TRAIN_MODEL_IF_CANT_LOAD, overwrite_data=OVERWRITE_DATA,
          always_train_model=ALWAYS_TRAIN_MODEL,
          wandb_entity=DEFAULT_WANDB_ENTITY,
          save_in_google_drive=SAVE_IN_GOOGLE_DRIVE):

    global model_is_trained

    data_train, data_test = make_testset_trainset(model, force_adjacent=FORCE_ADJACENT)
    train_data_gen_gen = lambda: make_generator_from_data(data_train, batch_size=BATCH_SIZE)

    training_losses, model_pth_path = train_or_load_model(
        f'neural-net-coq-interp-max-{model.cfg.n_ctx}-epochs-{N_EPOCHS}',
        model,
        loss_fn=loss_fn,
        acc_fn=acc_fn,
        train_data_gen_maybe_lambda=train_data_gen_gen,
        train_data_gen_is_lambda=True,
        data_test=data_test,
        n_epochs=N_EPOCHS,
        batch_size=BATCH_SIZE,
        adjacent_fraction=1,
        use_complete_data=True,
        batches_per_epoch=10,
        wandb_project=f'neural-net-coq-interp-max-{model.cfg.n_ctx}-epochs-{N_EPOCHS}',
        deterministic=DETERMINISTIC,
        save_in_google_drive=save_in_google_drive,
        overwrite_data=overwrite_data,
        train_model_if_cant_load=train_if_cant_load,
        model_description=f"trained max of {model.cfg.n_ctx} model",
        save_model=True,
        force_train=always_train_model,
        wandb_entity=wandb_entity,
        fail_if_cant_load=fail_if_cant_load,
    )

    model_is_trained = True
    return training_losses, model_pth_path

# %%

def get_model(train_if_necessary=False, **kwargs):
    # Load the model, training it only if loading fails and train_if_necessary is set
    train(fail_if_cant_load=not train_if_necessary, train_if_cant_load=train_if_necessary, **kwargs)

    return model
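
With `D_VOCAB = 64` and `N_CTX = 2` as configured above, the full input space has only 64² = 4096 sequences, which is presumably what makes `use_complete_data=True` practical. A minimal sketch of enumerating that space with plain PyTorch (an illustration only; it is not the notebook's `make_testset_trainset`, and the train/test split is not reproduced here):

```python
import torch

D_VOCAB, N_CTX = 64, 2  # values taken from the configuration cell above

vocab = torch.arange(D_VOCAB)
# Every ordered tuple of tokens: shape (D_VOCAB ** N_CTX, N_CTX)
all_sequences = torch.cartesian_prod(*[vocab] * N_CTX)
# The label for each sequence is simply its largest token
all_answers = all_sequences.max(dim=-1).values
```

Because the space is this small, a trained model can be checked on every possible input rather than on a sample.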


In [None]:
#@title analysis_utils

from typing import Callable, Iterable, List, Optional, Tuple
import einops
from fancy_einsum import einsum
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit
import torch
import torch.nn.functional as F
from transformer_lens import HookedTransformer
import transformer_lens.utils as utils
import plotly.express as px
import plotly.graph_objects as go
import plotly.colors
from plotly.subplots import make_subplots
from inspect import signature
import itertools

# %%

def linear_func(x, a, b):
    """Linear function: f(x) = a * x + b"""
    return a * x + b
linear_func.equation = lambda popt: f'y = {popt[0]:.3f}*x + {popt[1]:.3f}'

def quadratic_func(x, a, b, c):
    return a * x**2 + b * x + c
quadratic_func.equation = lambda popt: f'y = {popt[0]:.3e}*x^2 + {popt[1]:.3f}*x + {popt[2]:.3f}'

def cubic_func(x, a, b, c, d):
    return a * x**3 + b * x**2 + c * x + d
cubic_func.equation = lambda popt: f'y = {popt[0]:.3e}*x^3 + {popt[1]:.3e}*x^2 + {popt[2]:.3f}*x + {popt[3]:.3f}'

def quartic_func(x, a, b, c, d, e):
    return a * x**4 + b * x**3 + c * x**2 + d * x + e
quartic_func.equation = lambda popt: f'y = {popt[0]:.3e}*x^4 + {popt[1]:.3e}*x^3 + {popt[2]:.3e}*x^2 + {popt[3]:.3f}*x + {popt[4]:.3f}'

def quintic_func(x, a, b, c, d, e, f):
    return a * x**5 + b * x**4 + c * x**3 + d * x**2 + e * x + f
quintic_func.equation = lambda popt: f'y = {popt[0]:.3e}*x^5 + {popt[1]:.3e}*x^4 + {popt[2]:.3e}*x^3 + {popt[3]:.3e}*x^2 + {popt[4]:.3f}*x + {popt[5]:.3f}'

def absolute_shift_func(x, a, b, c):
    return a * np.abs(x - b) + c
absolute_shift_func.equation = lambda popt: f'y = {popt[0]:.3f}*|x - {popt[1]:.3f}| + {popt[2]:.3f}'

def linear_sinusoid_func(x, a, b, c, d):
    return (a * x + b) * np.sin(c * x + d)
linear_sinusoid_func.equation = lambda popt: f'y = ({popt[0]:.3f}*x + {popt[1]:.3f}) * sin({popt[2]:.3f}*x + {popt[3]:.3f})'

def quadratic_sinusoid_func(x, a, b, c, d, e):
    return (a * x**2 + b * x + c) * np.sin(d * x + e)
quadratic_sinusoid_func.equation = lambda popt: f'y = ({popt[0]:.3f}*x^2 + {popt[1]:.3f}*x + {popt[2]:.3f}) * sin({popt[3]:.3f}*x + {popt[4]:.3f})'

def absolute_shift_sinusoid_func(x, a, b, c, d, e):
    return (a * np.abs(x - b) + c) * np.sin(d * x + e)
absolute_shift_sinusoid_func.equation = lambda popt: f'y = ({popt[0]:.3f}*|x - {popt[1]:.3f}| + {popt[2]:.3f}) * sin({popt[3]:.3f}*x + {popt[4]:.3f})'

def linear_abs_sinusoid_func(x, a, b, c, d):
    return (a * x + b) * np.abs(np.sin(c * x + d))
linear_abs_sinusoid_func.equation = lambda popt: f'y = ({popt[0]:.3f}*x + {popt[1]:.3f}) * |sin({popt[2]:.3f}*x + {popt[3]:.3f})|'

def quadratic_abs_sinusoid_func(x, a, b, c, d, e):
    return (a * x**2 + b * x + c) * np.abs(np.sin(d * x + e))
quadratic_abs_sinusoid_func.equation = lambda popt: f'y = ({popt[0]:.3f}*x^2 + {popt[1]:.3f}*x + {popt[2]:.3f}) * |sin({popt[3]:.3f}*x + {popt[4]:.3f})|'

def absolute_shift_abs_sinusoid_func(x, a, b, c, d, e):
    return (a * np.abs(x - b) + c) * np.abs(np.sin(d * x + e))
absolute_shift_abs_sinusoid_func.equation = lambda popt: f'y = ({popt[0]:.3f}*|x - {popt[1]:.3f}| + {popt[2]:.3f}) * |sin({popt[3]:.3f}*x + {popt[4]:.3f})|'

def sigmoid_func(x, K, B, M):
    return K / (1 + np.exp(-B * (x - M)))
sigmoid_func.equation = lambda popt: f'y = {popt[0]:.3f} / (1 + exp(-{popt[1]:.3f} * (x - {popt[2]:.3f})))'

def inv_sigmoid_func(y, K, B, M):
    return M - np.log(K / y - 1) / B
inv_sigmoid_func.equation = lambda popt: f'x = {popt[2]:.3f} - ln({popt[0]:.3f} / y - 1) / {popt[1]:.3f}'

def fit_name_of_func(fit_function):
    fit_name = fit_function.__name__
    if fit_name is not None and fit_name.endswith('_func'): fit_name = fit_name[:-len('_func')]
    return fit_name

def imshow(tensor, renderer=None, xaxis="", yaxis="", colorscale="RdBu", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale=colorscale, labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)


def line(tensor, renderer=None, xaxis="", yaxis="", line_labels=None, showlegend=None, hovertemplate=None, **kwargs):
    fig = px.line(utils.to_numpy(tensor), labels={"index":xaxis, "value":yaxis}, y=line_labels, **kwargs)
    if showlegend is not None: fig.update_layout(showlegend=showlegend)
    if hovertemplate is not None: fig.update_traces(hovertemplate=hovertemplate)
    fig.show(renderer)


def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)


def hist(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.histogram(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)


def pm_range(values):
    return f"{(values.max().item() + values.min().item()) / 2.0} ± {(values.max().item() - values.min().item()) / 2.0}"


def pm_mean_std(values):
    return f"{values.mean().item()} ± {values.std().item()}"


def summarize(values, name=None, histogram=False, renderer=None, hist_args={},
              imshow_args=None, include_value=False, linear_fit=False,
              fit_function=None, fit_equation=None, fit_name=None,
              min=True, max=True, mean=True, median=True, range=True, range_size=True, firstn=None, abs_max=True):
    if histogram:
        hist_args_list = hist_args if isinstance(hist_args, list) else [hist_args]
        for hist_args in hist_args_list:
            hist_args = dict(hist_args)
            if 'title' not in hist_args and name is not None: hist_args['title'] = f'Histogram of {name}'
            if 'renderer' not in hist_args and renderer is not None: hist_args['renderer'] = renderer
            if 'xaxis' not in hist_args: hist_args['xaxis'] = name if name is not None else 'Value'
            if 'yaxis' not in hist_args: hist_args['yaxis'] = 'Count'
            hist(values, **hist_args)

    if imshow_args is not None:
        imshow_args = dict(imshow_args)
        if 'title' not in imshow_args and name is not None: imshow_args['title'] = name
        if 'renderer' not in imshow_args and renderer is not None: imshow_args['renderer'] = renderer
        if 'xaxis' not in imshow_args and name is not None: imshow_args['xaxis'] = f'({name}).shape[1]'
        if 'yaxis' not in imshow_args and name is not None: imshow_args['yaxis'] = f'({name}).shape[0]'
        if len(values.shape) == 1:
            line(values, **imshow_args)
        else:
            imshow(values, **imshow_args)

    if fit_function is None and linear_fit: fit_function = linear_func
    if fit_equation is None and fit_function is not None: fit_equation = fit_function.equation
    if fit_function is not None:
        assert len(values.shape) in (1, 2)
        if len(values.shape) == 1:
            x_vals = np.arange(values.shape[0])
            y_vals = utils.to_numpy(values)
            aggregated = ''
        else:
            x_vals = np.tile(np.arange(values.shape[1]), values.shape[0])
            y_vals = utils.to_numpy(values.flatten())
            aggregated = 'Aggregated '
        name_space = '' if name is None else f'{name} '
        if fit_name is None: fit_name = fit_name_of_func(fit_function)
        fit_name_space = '' if not fit_name else f'{fit_name} '
        fit_title = f"{aggregated}{name_space}Data and {fit_name_space}Fit"
        resid_title = f"{aggregated}{name_space}Residual Errors"

        # Fit the chosen function to the (possibly aggregated) data
        popt, _ = curve_fit(fit_function, x_vals, y_vals)

        # Create a subplot with 1 row and 2 columns
        fig, axs = plt.subplots(1, 2, figsize=(12, 6))  # Adjust the figure size to your liking

        # Scatter plot the data & best fit line on the first subplot
        axs[0].scatter(x_vals, y_vals, label='Data', alpha=0.5, s=1)
        axs[0].plot(x_vals, fit_function(x_vals, *popt), 'r-', label=f'Fit: {fit_equation(popt)}')
        axs[0].set_title(fit_title)
        axs[0].legend()

        # Plot residual errors on the second subplot
        residuals = y_vals - fit_function(x_vals, *popt)
        order_indices = np.argsort(x_vals)
        axs[1].scatter(x_vals[order_indices], residuals[order_indices], c='b', alpha=0.5)
        axs[1].set_title(resid_title)

        # Adjust the layout
        plt.tight_layout()
        plt.show()

    res = {}
    if include_value: res['value'] = values.detach().clone().cpu()
    if min: res['min'] = values.min().item()
    if max: res['max'] = values.max().item()
    if mean: res['mean'] = pm_mean_std(values.float())
    if median: res['median'] = values.median().item()
    if range: res['range'] = pm_range(values)
    if range_size: res['range_size'] = values.max().item() - values.min().item()
    if firstn is not None: res[f'first {firstn}'] = values[:firstn]
    if abs_max: res['abs(max)'] = values.abs().max().item()
    if fit_function is not None: res['fit_equation'] = fit_equation(popt)
    if fit_function is not None: res['range_residuals'] = pm_range(residuals)
    if fit_function is not None: res['residuals'] = residuals[order_indices]
    if fit_function is not None: res['fit_params'] = popt

    return res


#list(zip(all_integers[~correct_idxs], all_integers_ans[~correct_idxs]))

def center_by_mid_range(tensor: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor:
    maxv, minv = tensor.max(dim=dim, keepdim=True).values, tensor.min(dim=dim, keepdim=True).values
    return tensor - (maxv + minv) / 2.0

# # Simpler Model Interpretability

# ## Calculating how much slack we have

# Let's find out what the actual logits are, and how much slack we have on errors

# In[ ]:


def analyze_svd(M, descr='', scale_by_singular_value=True, colorscale='Picnic_r', singular_color='blue', renderer=None):
    U, S, Vh = torch.linalg.svd(M)
    V = Vh.T  # columns of V are the right singular vectors
    if scale_by_singular_value:
        # Scale each singular vector (column) by the square root of its singular value,
        # so that U @ V.T reconstructs M
        U = U * S[None, :].sqrt()
        V = V * S[None, :].sqrt()
    if descr: descr = f' for {descr}'

    fig = make_subplots(rows=1, cols=3, subplot_titles=["U", "Singular Values", "V"])
    uzmax, vzmax = U.abs().max().item(), V.abs().max().item()
    fig.add_trace(go.Heatmap(z=utils.to_numpy(U), zmin=-uzmax, zmax=uzmax, colorscale=colorscale,
                             showscale=False,
                            hovertemplate="U: %{y}<br>Singular Index: %{x}<br>Value: %{z}<extra></extra>"
                            ),
                row=1, col=1)
    fig.add_trace(go.Heatmap(z=utils.to_numpy(V), colorscale=colorscale, zmin=-vzmax, zmax=vzmax,
                             showscale=False,
                            hovertemplate="V: %{y}<br>Singular Index: %{x}<br>Value: %{z}<extra></extra>",
                            ),
                row=1, col=3)
    fig.add_trace(go.Scatter(x=np.arange(S.shape[0]), y=utils.to_numpy(S),
                            mode='lines+markers',
                            marker=dict(color=singular_color),
                            line=dict(color=singular_color),
                            hovertemplate="Singular Value: %{y}<br>Singular Index: %{x}<extra></extra>",
                            ), row=1, col=2)
    fig.update_layout(title=f"SVD{descr}") #, margin=dict(l=150, r=150))


    fig.update_yaxes(range=[0, None], row=1, col=2)
    # fig.update_yaxes(range=[0, None], row=1, col=2)
    # fig.update_layout(yaxis_scaleanchor="x")
    fig.update_yaxes(scaleanchor='x', autorange='reversed', row=1, col=1)
    fig.update_yaxes(scaleanchor='x', autorange='reversed', row=1, col=3)

    # fig.update_xaxes(scaleanchor='y', scaleratio=1, range=[0, U.shape[0]], row=1, col=1)
    # fig.update_yaxes(scaleanchor='x', scaleratio=1, range=[0, U.shape[1]], row=1, col=1)

    # fig.update_xaxes(scaleanchor='y', scaleratio=1, range=[0, None], row=1, col=2)
    # fig.update_yaxes(scaleanchor='x', scaleratio=1, range=[0, S.shape[0]], row=1, col=2)

    # fig.update_xaxes(scaleanchor='y', scaleratio=1, range=[0, Vh.shape[0]], row=1, col=3)
    # fig.update_yaxes(scaleanchor='x', scaleratio=1, range=[0, Vh.shape[1]], row=1, col=3)

    # fig.update_xaxes(range=[0, None], row=1, col=1)
    # fig.update_xaxes(range=[0, None], row=1, col=2)
    # fig.update_xaxes(range=[0, None], row=1, col=3)

    # fig.update_yaxes(range=[0, None], row=1, col=1)
    # fig.update_yaxes(range=[0, None], row=1, col=2)
    # fig.update_yaxes(range=[0, None], row=1, col=3)

    # fig.update_yaxes(title_text="Query Token", row=1, col=1)
    fig.update_yaxes(range=[0, None], row=1, col=2)
    # fig.update_yaxes(title_text="Key Token", row=1, col=3)

    fig.show(renderer)


    # line(S, title=f"Singular Values{descr}")
    # imshow(U, title=f"Principal Components on U{descr}")
    # imshow(Vh, title=f"Principal Components on Vh{descr}")



# %%
@torch.no_grad()
def make_fit(values: torch.Tensor, fit_function, exclude_count=None):
    assert len(values.shape) in (1, 2)
    if len(values.shape) == 1:
        x_vals = np.arange(values.shape[0])
        y_vals = utils.to_numpy(values)
    else:
        x_vals = np.tile(np.arange(values.shape[1]), values.shape[0])
        y_vals = utils.to_numpy(values.flatten())

    x_vals_fit, y_vals_fit = x_vals, y_vals
    # Skip the slice when exclude_count is None or 0 (a [0:-0] slice would be empty)
    if exclude_count: x_vals_fit, y_vals_fit = x_vals[exclude_count:-exclude_count], y_vals[exclude_count:-exclude_count]
    popt, _ = curve_fit(fit_function, x_vals_fit, y_vals_fit)

    residuals = y_vals - fit_function(x_vals, *popt)
    order_indices = np.argsort(x_vals)

    return popt, (x_vals, y_vals), (x_vals, fit_function(x_vals, *popt)), (x_vals[order_indices], residuals[order_indices])


def make_fit_traces(values: torch.Tensor, fit_function, exclude_count=None, fit_equation: Optional[Callable] = None, reference_lines: Optional[List[Tuple[str, float]]] = None, reference_colors=plotly.colors.qualitative.Dark24):
    popt, points, fit, resid = make_fit(values, fit_function, exclude_count=exclude_count)
    if fit_equation is None: fit_equation = fit_function.equation
    if reference_lines is None: reference_lines = []
    reference_line_traces = \
        [go.Scatter(x=np.arange(points[0].shape[0]), y=np.full(points[0].shape, val), name=name, mode='lines', line=dict(color=color, dash='dash'),
                hovertemplate=f'{val}<extra>{name}</extra>',
                showlegend=False, legendgroup=fit_function.__name__)
        for (name, val), color in zip(reference_lines, itertools.cycle(reference_colors))]
    # , size=1
    return popt, \
            [go.Scatter(x=points[0], y=points[1], name='Data', mode='markers', marker=dict(color='red', opacity=0.5), showlegend=True, legendgroup=fit_function.__name__),
            go.Scatter(x=fit[0], y=fit[1], name=f'Fit: {fit_equation(popt)}', mode='lines', line=dict(color='blue'), showlegend=True, legendgroup=fit_function.__name__),
            go.Scatter(x=resid[0], y=resid[1], name='Residuals', mode='markers', marker=dict(color='red', opacity=0.5), showlegend=False)], \
            reference_line_traces

def show_fits(values: torch.Tensor, name: str, fit_funcs: Iterable[Callable], do_exclusions=True, renderer=None, **kwargs):
    assert len(values.shape) == 1
    fit_funcs = list(fit_funcs)
    fig = make_subplots(rows=len(fit_funcs), cols=2,
                        subplot_titles=[title
                                        for fit_func in fit_funcs
                                        for title in (f"{fit_name_of_func(fit_func)} Fit", f"Residuals")])
    for i, fit_func in enumerate(fit_funcs):
        popt, (points, fit, resid), reference_line_traces = make_fit_traces(values, fit_func, exclude_count=None, **kwargs)
        fig.add_trace(points, row=i+1, col=1)
        fig.add_trace(fit, row=i+1, col=1)
        fig.add_trace(resid, row=i+1, col=2)
        for trace in reference_line_traces:
            fig.add_trace(trace, row=i+1, col=1)
    fig.update_layout(
        title=f"{name} Data & Fit",
        legend=dict(
            bgcolor='rgba(255,255,255,0.5)',
            yanchor="middle",
            y=0.5,  # y=0.5 with yanchor="middle" centers the legend vertically
            xanchor="left",
            x=0
        ),
        height=300 * len(fit_funcs) + 100,
    )

    if do_exclusions:
        max_param_count = max([len(signature(fit_func).parameters) for fit_func in fit_funcs])
        frames = [go.Frame(data=[trace
                                for fit_func in fit_funcs
                                for trace_list in make_fit_traces(values, fit_func, exclude_count=exclude_count, **kwargs)[1:]
                                for trace in trace_list],
                            name=(str(exclude_count) if exclude_count is not None else "0"),
        ) for exclude_count in [None] + list(range(1, (values.shape[0] - max_param_count) // 2))]

        fig.frames = frames

        sliders = [dict(
            active=0,
            yanchor='top',
            xanchor='left',
            currentvalue=dict(font=dict(size=20), prefix='# End Points to Exclude:', visible=True, xanchor='right'),
            transition=dict(duration=0),
            pad=dict(b=10, t=50),
            len=0.9,
            x=0.1,
            y=0,
            steps=[dict(args=[[frame.name], dict(mode='immediate', frame=dict(duration=0, redraw=True), transition=dict(duration=0))],
                        method='animate',
                        label=frame.name) for frame in fig.frames]
        )]

        fig.update_layout(sliders=sliders)

    fig.show(renderer)

# %%

# ## Negligibility of W_E @ W_U

# In[ ]:


def calculate_embed_overlap(model: HookedTransformer, renderer=None):
    W_U, W_E = model.W_U, model.W_E
    d_model, d_vocab = model.cfg.d_model, model.cfg.d_vocab
    assert W_U.shape == (d_model, d_vocab)
    assert W_E.shape == (d_vocab, d_model)
    res = (W_E @ W_U).detach()
    self_overlap = res.diag()
    imshow(res, renderer=renderer)
    line(self_overlap, renderer=renderer)
    statistics = [
        ('overlap', res),
        ('self-overlap', self_overlap),
        ('self-overlap after 0', self_overlap[1:]),
    ]
    return {name: summarize(value, name=name, include_value=True) for name, value in statistics}


# ## Negligibility of W_pos @ W_U

# In[ ]:


def calculate_pos_embed_overlap(model: HookedTransformer, renderer=None):
    W_U, W_pos = model.W_U, model.W_pos
    d_model, d_vocab, n_ctx = model.cfg.d_model, model.cfg.d_vocab, model.cfg.n_ctx
    assert W_U.shape == (d_model, d_vocab)
    assert W_pos.shape == (n_ctx, d_model)
    res = (W_pos @ W_U).detach()
    imshow(res, renderer=renderer)

    statistics = [
        ('pos_embed_overlap', res),
        ('pos_embed_overlap (pos -1)', res[-1,:]),
    ]
    return {name: summarize(value, name=name, include_value=True, linear_fit=True, renderer=renderer) for name, value in statistics}


# ## Negligibility of (W_E + W_pos[-1]) @ W_U

# In[ ]:

def calculate_embed_and_pos_embed_overlap(model: HookedTransformer, renderer=None):
    W_U, W_E, W_pos = model.W_U, model.W_E, model.W_pos
    d_model, d_vocab, n_ctx = model.cfg.d_model, model.cfg.d_vocab, model.cfg.n_ctx
    assert W_U.shape == (d_model, d_vocab)
    assert W_pos.shape == (n_ctx, d_model)
    assert W_E.shape == (d_vocab, d_model)
    res = ((W_E + W_pos[-1,:]) @ W_U).detach()
    self_overlap = res.diag()
    centered_by_mid_range = center_by_mid_range(res, dim=-1)
    centered = res - self_overlap[:,None]
    centered_triu = centered.triu()
    centered_tril = centered.tril()
    centered_no_diag = centered.clone()
    centered_no_diag.diagonal().fill_(-1000000)
    centered_no_diag_after_0 = centered_no_diag[:,1:]
    centered_no_diag = centered_no_diag[centered_no_diag != -1000000]
    centered_no_diag_after_0 = centered_no_diag_after_0[centered_no_diag_after_0 != -1000000]
    statistics = [
        ('centered overlap (incl pos)', centered),
        ('centered overlap (incl pos) triu', centered_triu),
        ('centered overlap (incl pos) tril', centered_tril),
        ('centered overlap after 0 (incl pos)', centered[:,1:]),
        ('centered overlap after 0 (incl pos) no diag', centered_no_diag_after_0),
        ('centered overlap only 0 (incl pos)', centered[:,0]),
        ('centered overlap only 0 (incl pos) no diag', centered[1:,0]),
        ('overlap (incl pos)', res),
        ('self-overlap (incl pos)', self_overlap),
        ('self-overlap after 0 (incl pos)', self_overlap[1:]),
        ('centered overlap (incl pos) no diag', centered_no_diag),
        ('centered by mid_range overlap (incl pos)', centered_by_mid_range),
        ('centered by mid_range overlap after 0 (incl pos)', centered_by_mid_range[:,1:]),
        ('centered by mid_range overlap only 0 (incl pos)', centered_by_mid_range[:,0]),
        ('centered by mid_range overlap only 0 (incl pos) no diag', centered_by_mid_range[1:,0]),
    ]
    return {name: summarize(value, include_value=False, name=name, renderer=renderer, linear_fit=True,
                            imshow_args={'yaxis':'input token', 'xaxis':'output token'},
                            ) for name, value in statistics}


def calculate_rowwise_embed_and_pos_embed_overlap(model: HookedTransformer, renderer=None):
    """
    For `(W_E + W_pos[-1,:]) @ W_U`, we compute for each row the larger of the following two quantities:
    - the most that any entry to the right of the diagonal falls below the diagonal entry
    - the most that any entry to the left  of the diagonal rises above the diagonal entry
    This is the exact value of the largest absolute error introduced in a given row.
    """
    W_U, W_E, W_pos = model.W_U, model.W_E, model.W_pos
    d_model, d_vocab, n_ctx = model.cfg.d_model, model.cfg.d_vocab, model.cfg.n_ctx
    assert W_U.shape == (d_model, d_vocab)
    assert W_pos.shape == (n_ctx, d_model)
    assert W_E.shape == (d_vocab, d_model)
    res = ((W_E + W_pos[-1,:]) @ W_U).detach()
    self_overlap = res.diag()
    centered = res - self_overlap[:,None]
    centered_triu = centered.triu()
    centered_tril = centered.tril()
    diffs = centered_tril - centered_triu
    imshow(res, renderer=renderer)
    imshow(centered, renderer=renderer)
    imshow(diffs, renderer=renderer)
    # # max of positive differences to the right of the diagonal
    # max_pos_diffs = torch.max(centered_triu, dim=-1).values
    # # max of negative differences to the left of the diagonal
    # max_neg_diffs = torch.min(centered_tril, dim=-1).values
    # # stack the diffs
    # diffs = torch.stack([max_neg_diffs, max_pos_diffs], dim=-1)
    # summarize(diffs, name='rowwise diffs (positive and negative)', renderer=renderer)
    # # take the max of the diffs
    max_diffs = torch.max(diffs, dim=-1).values
    return summarize(max_diffs, name='rowwise max absolute diffs', include_value=True, linear_fit=True, renderer=renderer)



# ## Negligibility of W_pos @ W_V @ W_O @ W_U

# In[ ]:


def calculate_OV_of_pos_embed(model: HookedTransformer, renderer=None):
    W_U, W_E, W_pos, W_V, W_O = model.W_U, model.W_E, model.W_pos, model.W_V, model.W_O
    d_model, d_vocab, n_ctx = model.cfg.d_model, model.cfg.d_vocab, model.cfg.n_ctx
    assert W_U.shape == (d_model, d_vocab)
    assert W_pos.shape == (n_ctx, d_model)
    assert W_E.shape == (d_vocab, d_model)
    assert W_V.shape == (1, 1, d_model, d_model)
    assert W_O.shape == (1, 1, d_model, d_model)
    res = (W_pos @ W_V @ W_O @ W_U).detach()[0,0,:,:]
    imshow(res, title='W_pos @ W_V @ W_O @ W_U', xaxis='logit affected', yaxis='position', renderer=renderer)
    return summarize(res, name='W_pos @ W_V @ W_O @ W_U', renderer=renderer, linear_fit=True)
# %%
def analyze_PVOU(model: HookedTransformer, colorscale='RdBu', renderer=None):
    W_U, W_E, W_pos, W_V, W_O = model.W_U, model.W_E, model.W_pos, model.W_V, model.W_O
    d_model, d_vocab, n_ctx = model.cfg.d_model, model.cfg.d_vocab, model.cfg.n_ctx
    assert W_U.shape == (d_model, d_vocab)
    assert W_pos.shape == (n_ctx, d_model)
    assert W_E.shape == (d_vocab, d_model)
    assert W_V.shape == (1, 1, d_model, d_model)
    assert W_O.shape == (1, 1, d_model, d_model)
    res = (W_pos @ W_V @ W_O @ W_U).detach()[0,0,:,:]
    pos_indices = torch.arange(n_ctx)
    fig = px.imshow(utils.to_numpy(res), title='W_pos @ W_V @ W_O @ W_U',
                    labels={"x":'logit affected', "y":'position'},
                    color_continuous_midpoint=0.0, color_continuous_scale=colorscale)
    fig.update_yaxes(tickvals=pos_indices, ticktext=pos_indices)
    fig.show(renderer)

# %%
def analyze_PU(model: HookedTransformer, colorscale='RdBu', renderer=None):
    W_U, W_pos = model.W_U, model.W_pos
    d_model, d_vocab, n_ctx = model.cfg.d_model, model.cfg.d_vocab, model.cfg.n_ctx
    assert W_U.shape == (d_model, d_vocab)
    assert W_pos.shape == (n_ctx, d_model)
    res = (W_pos[-1, :] @ W_U).detach()
    line(res, title='W_pos[-1] @ W_U', xaxis='output token', showlegend=False, hovertemplate='Logit for %{x}: %{y}', renderer=renderer)

# %%
def analyze_EU(model: HookedTransformer, colorscale='RdBu', renderer=None):
    W_U, W_E = model.W_U, model.W_E
    d_model, d_vocab = model.cfg.d_model, model.cfg.d_vocab
    assert W_U.shape == (d_model, d_vocab)
    assert W_E.shape == (d_vocab, d_model)
    res = (W_E @ W_U).detach()
    imshow(res, title='W_E @ W_U', renderer=renderer,
           xaxis="logit affected", yaxis="input token", colorscale=colorscale)




# ## Copying: W_E @ W_V @ W_O @ W_U

# %%
def analyze_EVOU(model: HookedTransformer, colorscale='RdBu', renderer=None, scale_by_singular_value=True):
    W_U, W_E, W_pos, W_V, W_O = model.W_U, model.W_E, model.W_pos, model.W_V, model.W_O
    d_model, d_vocab, n_ctx = model.cfg.d_model, model.cfg.d_vocab, model.cfg.n_ctx
    assert W_U.shape == (d_model, d_vocab)
    assert W_pos.shape == (n_ctx, d_model)
    assert W_E.shape == (d_vocab, d_model)
    assert W_V.shape == (1, 1, d_model, d_model)
    assert W_O.shape == (1, 1, d_model, d_model)
    res = (W_E @ W_V @ W_O @ W_U).detach().cpu()[0,0,:,:]
    imshow(res, title='W_E @ W_V @ W_O @ W_U', renderer=renderer,
           xaxis="logit affected", yaxis="input token", colorscale=colorscale)
    analyze_svd(res, descr='W_E @ W_V @ W_O @ W_U', colorscale=colorscale, scale_by_singular_value=scale_by_singular_value, renderer=renderer)
    line(res.diag(), title='(W_E @ W_V @ W_O @ W_U).diag()', xaxis='input token', showlegend=False, hovertemplate='Input Token: %{x}<br>Logit on %{x}: %{y}', renderer=renderer)

# In[ ]:


def calculate_copying(model: HookedTransformer, colorscale='RdBu', renderer=None, scale_by_singular_value=True):
    W_U, W_E, W_pos, W_V, W_O = model.W_U, model.W_E, model.W_pos, model.W_V, model.W_O
    d_model, d_vocab, n_ctx = model.cfg.d_model, model.cfg.d_vocab, model.cfg.n_ctx
    assert W_U.shape == (d_model, d_vocab)
    assert W_pos.shape == (n_ctx, d_model)
    assert W_E.shape == (d_vocab, d_model)
    assert W_V.shape == (1, 1, d_model, d_model)
    assert W_O.shape == (1, 1, d_model, d_model)
    res = (W_E @ W_V @ W_O @ W_U).detach().cpu()[0,0,:,:]
    res_diag = res.diag()
    res_off_diagonal = res[torch.eye(d_vocab) == 0]
    centered = -res + res.diag()[:, None]
    nonzero_centered = centered[torch.eye(d_vocab) == 0]
    imshow(res, title='W_E @ W_V @ W_O @ W_U', renderer=renderer,
           xaxis="logit affected", yaxis="input token")
    analyze_svd(res, descr='W_E @ W_V @ W_O @ W_U', colorscale=colorscale, scale_by_singular_value=scale_by_singular_value, renderer=renderer)
    # imshow(centered, title='copying.diag()[:,None] - copying', renderer=renderer)
    line(res.diag(), title='copying.diag()', xaxis='input token', renderer=renderer)
    # take svd of res
    u, s, vh = torch.linalg.svd(res)
    v = vh.T
    # plot singular values
    line(s, title='singular values of copying', renderer=renderer)
    # plot u, v
    imshow(u, title='u', renderer=renderer)
    imshow(v, title='v', renderer=renderer)

    # 1. We already have u, s, and v from torch.linalg.svd(res);
    #    the leading singular vectors are the first columns of u and v
    u1 = u[:, 0]
    v1 = v[:, 0]

    # 2. Fit linear models to u1 and v1
    # Fit for u's first column
    x_vals_u = np.arange(d_vocab)
    y_vals_u = u1.numpy()
    popt_u, _ = curve_fit(linear_func, x_vals_u, y_vals_u)

    # Fit for v's first column (equivalently, the first row of vh)
    x_vals_v = np.arange(d_vocab)
    y_vals_v = v1.numpy()
    popt_v, _ = curve_fit(linear_func, x_vals_v, y_vals_v)

    # Plot u's column against its linear fit
    plt.figure()
    plt.scatter(x_vals_u, y_vals_u, alpha=0.5, label='Data')
    plt.plot(x_vals_u, linear_func(x_vals_u, *popt_u), 'r-', label=f'u: y = {popt_u[0]:.4f}x + {popt_u[1]:.4f}')
    plt.title("First Column of u vs Linear Fit")
    plt.legend()
    plt.show()

    # Plot residuals for u
    plt.figure()
    residuals_u = y_vals_u - linear_func(x_vals_u, *popt_u)
    plt.scatter(x_vals_u, residuals_u, alpha=0.5)
    plt.axhline(0, color='red', linestyle='--')
    plt.title("Residuals of u's First Column Fit")
    plt.show()

    # Plot v's column against its linear fit
    plt.figure()
    plt.scatter(x_vals_v, y_vals_v, alpha=0.5, label='Data')
    plt.plot(x_vals_v, linear_func(x_vals_v, *popt_v), 'r-', label=f'v: y = {popt_v[0]:.4f}x + {popt_v[1]:.4f}')
    plt.title("First Column of v vs Linear Fit")
    plt.legend()
    plt.show()

    # Plot residuals for v
    plt.figure()
    residuals_v = y_vals_v - linear_func(x_vals_v, *popt_v)
    plt.scatter(x_vals_v, residuals_v, alpha=0.5)
    plt.axhline(0, color='red', linestyle='--')
    plt.title("Residuals of v's First Column Fit")
    plt.show()

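    # Subtract impact of lines
    # The fits are NumPy arrays; convert them to tensors matching res before mixing with torch ops
    u_prime = torch.from_numpy(linear_func(x_vals_u, *popt_u)).to(res.dtype)
    v_prime = torch.from_numpy(linear_func(x_vals_v, *popt_v)).to(res.dtype)
    impact = s[0] * u_prime[:, None] @ v_prime[None, :]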
    adjusted_res = res - impact
    imshow(impact, title="adjustment", renderer=renderer)

    # adjusted_res = res - s[0] * (u[:, 0:1] @ v[:,0:1].T) * (popt_u[0] * x_vals_u[:, None] + popt_v[0] * x_vals_v[None, :])

    imshow(adjusted_res, title='Adjusted res', renderer=renderer)

    # SVD of adjusted_res
    u_adj, s_adj, vh_adj = torch.linalg.svd(adjusted_res)
    line(s_adj, title='Singular Values of Adjusted res', renderer=renderer)
    imshow(u_adj, title='u of residuals', renderer=renderer)
    imshow(vh_adj.T, title='v of residuals', renderer=renderer)

    # Extracting diagonal and off-diagonal entries
    diagonal_entries = torch.diag(adjusted_res)
    # Select off-diagonal entries by position, so off-diagonal values that happen to be exactly 0 are kept
    off_diagonal_entries = adjusted_res[torch.eye(d_vocab) == 0]

    # Finding the smallest diagonal entry and the largest off-diagonal entry
    min_diagonal_entry = diagonal_entries.min().item()
    max_off_diagonal_entry = off_diagonal_entries.max().item()

    # Printing the results
    print(f"Smallest diagonal entry: {min_diagonal_entry} ({pm_range(diagonal_entries)})")
    print(f"Largest off-diagonal entry: {max_off_diagonal_entry} ({pm_range(off_diagonal_entries)})")

    line(diagonal_entries, title='Diagonal Entries', renderer=renderer)

    off_diagonal_entries = off_diagonal_entries.flatten()
    # Histogram plot
    plt.hist(diagonal_entries.numpy(), bins=50, color='blue', alpha=0.7, label='Diagonal entries')
    plt.hist(off_diagonal_entries.numpy(), bins=50, color='red', alpha=0.5, label='Off-diagonal entries')
    plt.legend(loc='upper right')
    plt.title('Histogram of Diagonal and Off-diagonal Entries')
    plt.xlabel('Value')
    plt.ylabel('Frequency')
    plt.show()

    # Histogram plot
    plt.hist(diagonal_entries.numpy(), bins=50, color='blue', alpha=0.7, label='Diagonal entries', density=True)
    plt.hist(off_diagonal_entries.numpy(), bins=50, color='red', alpha=0.5, label='Off-diagonal entries', density=True)
    plt.legend(loc='upper right')
    plt.title('Density Histogram of Diagonal and Off-diagonal Entries')
    plt.xlabel('Value')
    plt.ylabel('Probability Density')
    plt.show()


    centered_adjusted_res = -adjusted_res + adjusted_res.diag()[:, None]
    # Mask out the diagonal by position (as for nonzero_centered above) rather than by value
    nonzero_centered_adjusted_res = centered_adjusted_res[torch.eye(d_vocab) == 0]

    imshow(centered_adjusted_res, title='adjusted copying.diag()[:,None] - adjusted copying', renderer=renderer)
    print(f"range on nonzero centered adjusted res: {pm_range(nonzero_centered_adjusted_res)}")

    statistics = [
        ('copying', res),
        ('diag', res_diag),
        ('off-diag', res_off_diagonal),
        ('centered', centered),
        ('nonzero centered', nonzero_centered),
    ]

    summaries = {name: summarize(value, name=name, renderer=renderer, histogram=False) for name, value in statistics}
    for k, v in summaries.items():
        print(k, v)
    return res
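

# %%
# A minimal NumPy sketch of the "subtract the impact of lines" step used above,
# added for illustration only. It fits a straight line to the first left and
# right singular vectors and removes the resulting rank-one matrix.
# Assumptions: `sketch_subtract_linear_rank_one` is a hypothetical helper name,
# and np.polyfit stands in for curve_fit(linear_func, ...).
import numpy as np


def sketch_subtract_linear_rank_one(mat):
    """Return (mat - impact, impact), where impact is the rank-one matrix built
    from straight-line fits to the first singular vectors of `mat`."""
    u, s, vh = np.linalg.svd(mat)
    rows = np.arange(mat.shape[0])
    cols = np.arange(mat.shape[1])
    # Degree-1 least-squares fits to the first left / right singular vectors
    u_fit = np.polyval(np.polyfit(rows, u[:, 0], 1), rows)
    v_fit = np.polyval(np.polyfit(cols, vh[0, :], 1), cols)
    impact = s[0] * np.outer(u_fit, v_fit)
    return mat - impact, impact

# If the singular vectors are exactly linear in the token index, the adjusted
# matrix is (numerically) zero; otherwise it holds whatever the lines miss.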


# In[ ]:


def calculate_copying_with_pos(model: HookedTransformer, renderer=None):
    W_U, W_E, W_pos, W_V, W_O = model.W_U, model.W_E, model.W_pos, model.W_V, model.W_O
    d_model, d_vocab, n_ctx = model.cfg.d_model, model.cfg.d_vocab, model.cfg.n_ctx
    assert W_U.shape == (d_model, d_vocab)
    assert W_pos.shape == (n_ctx, d_model)
    assert W_E.shape == (d_vocab, d_model)
    assert W_V.shape == (1, 1, d_model, d_model)
    assert W_O.shape == (1, 1, d_model, d_model)
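    # Move to CPU (as in the function above) so the torch.eye mask below indexes a tensor on the same device
    res = (W_E @ W_V @ W_O @ W_U).detach().cpu()[0,0,:,:]
    res_pos = (W_pos @ W_V @ W_O @ W_U).detach().cpu()[0,0,:,:]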
    res_pos_min, res_pos_max = res_pos.min(dim=0).values, res_pos.max(dim=0).values
    res_diag = res.diag() + res_pos_min
    res_above_diag = -(res + res_pos_max[None,:]) + res_diag[:, None]
    imshow(res_above_diag, title='(W_E + worst(W_pos)) @ W_V @ W_O @ W_U', renderer=renderer,
              xaxis="logit affected", yaxis="input token")
    res_above_diag_off_diag = res_above_diag[torch.eye(d_vocab) == 0]
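    # res.diag(diagonal=1)[i] is res[i, i+1], i.e. logit i+1, so pair it with the positional minimum for logit i+1
    first_diagonal = res.diag(diagonal=1) + res_pos_min[1:]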
    res_above_first_diagonal = -(res[:-1,:] + res_pos_max[None,:]) + first_diagonal[:, None]
    statistics = [
        ('res_above_diag_off_diag', res_above_diag_off_diag),
        ('res_above_first_diagonal', res_above_first_diagonal),
    ]
    for name, value in statistics:
        print(name, summarize(value, name=name, renderer=renderer, histogram=True))
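

# %%
# A minimal NumPy sketch of the worst-case margin computed above, for
# illustration only. For input token i and competing logit j it compares the
# copied logit under the least helpful position with logit j under the most
# helpful position.
# Assumptions: `sketch_worst_case_copy_margin`, `ev` (token -> logit, shape
# d_vocab x d_vocab) and `pv` (position -> logit, shape n_ctx x d_vocab) are
# hypothetical names standing in for W_E @ W_V @ W_O @ W_U and
# W_pos @ W_V @ W_O @ W_U.
import numpy as np


def sketch_worst_case_copy_margin(ev, pv):
    """margin[i, j] = (ev[i, i] + min_p pv[p, i]) - (ev[i, j] + max_p pv[p, j])."""
    pos_min = pv.min(axis=0)  # least helpful positional contribution per logit
    pos_max = pv.max(axis=0)  # most helpful positional contribution per logit
    worst_correct = np.diag(ev) + pos_min
    best_wrong = ev + pos_max[None, :]
    return worst_correct[:, None] - best_wrong

# Only off-diagonal entries are meaningful: a positive entry means token i's
# own logit beats logit j regardless of which position contributes.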


# ## Attention Scaling Factor

# In[ ]:


def calculate_attn(model: HookedTransformer, pos: Optional[int] = None, renderer=None):
    W_U, W_E, W_pos, W_Q, W_K = model.W_U, model.W_E, model.W_pos, model.W_Q, model.W_K
    d_model, d_vocab, n_ctx = model.cfg.d_model, model.cfg.d_vocab, model.cfg.n_ctx
    if pos is None:
        return [calculate_attn(model, pos=i, renderer=renderer) for i in range(n_ctx)]
    assert W_U.shape == (d_model, d_vocab)
    assert W_pos.shape == (n_ctx, d_model)
    assert W_E.shape == (d_vocab, d_model)
    assert W_Q.shape == (1, 1, d_model, d_model)
    assert W_K.shape == (1, 1, d_model, d_model)
    residm1 = (W_E + W_pos[-1,:][None,:])
    resid = (W_E + W_pos[pos,:][None,:])
    q = (residm1 @ W_Q)[0,0,:,:]
    k = (resid @ W_K)[0,0,:,:]
    res = (q @ k.T).detach()
    # imshow(res, title=f'(W_E + W_pos[-1]) @ W_Q @ W_K.T @ (W_E + W_pos[{pos}]).T', renderer=renderer)
    centered = res - res.mean(dim=-1, keepdim=True)
    imshow(centered, title=f'centered (W_E + W_pos[-1]) @ W_Q @ W_K.T @ (W_E + W_pos[{pos}]).T', renderer=renderer,
           xaxis="Key token", yaxis="Query token")
    return centered


# %%

# check for monotonicity violations
def check_monotonicity(model: HookedTransformer, renderer=None):
    count = 0
    centered_scores = calculate_attn(model, renderer=renderer)
    for pos, centered_score in enumerate(centered_scores):
        for row_n, row in enumerate(centered_score):
            for i in range(row.shape[0] - 1):
                for j in range(i + 1, row.shape[0]):
                    if row[i] > row[j]:
                        count += 1
                        print(f"{i, j} at row {row_n} pos {pos}, magnitude {row[i] - row[j]:.3f}")
    return count
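

# %%
# A vectorised NumPy sketch of the pair count performed by the nested loops
# above, for illustration only: it counts index pairs i < j along the last
# axis where the earlier score exceeds the later one. It does not print the
# individual violations.
# Assumption: `sketch_count_decreasing_pairs` is a hypothetical helper name.
import numpy as np


def sketch_count_decreasing_pairs(scores):
    """Count pairs (i, j) with i < j and scores[..., i] > scores[..., j]."""
    scores = np.asarray(scores)
    n = scores.shape[-1]
    earlier = scores[..., :, None]  # value at index i, broadcast over j
    later = scores[..., None, :]    # value at index j, broadcast over i
    upper = np.triu(np.ones((n, n), dtype=bool), k=1)  # True where i < j
    return int(((earlier > later) & upper).sum())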


# %%

def calculate_attn_by_pos(model: HookedTransformer, pos=False, renderer=None):
    W_U, W_E, W_pos, W_Q, W_K = model.W_U, model.W_E, model.W_pos, model.W_Q, model.W_K
    d_model, d_vocab, n_ctx = model.cfg.d_model, model.cfg.d_vocab, model.cfg.n_ctx
    assert W_U.shape == (d_model, d_vocab)
    assert W_pos.shape == (n_ctx, d_model)
    assert W_E.shape == (d_vocab, d_model)
    assert W_Q.shape == (1, 1, d_model, d_model)
    assert W_K.shape == (1, 1, d_model, d_model)
    residm1 = (W_E + W_pos[-1,:][None,:])

    resid = (W_E if not pos else W_pos[0,:][None,:] - W_pos[1, :][None,:])
    resid_name = 'W_E' if not pos else f'(W_pos[0] - W_pos[1])'
    q = (residm1 @ W_Q)[0,0,:,:]
    k = (resid @ W_K)[0,0,:,:]
    res = (q @ k.T).detach()
    # imshow(res, title=f'(W_E + W_pos[-1]) @ W_Q @ W_K.T @ (W_E + W_pos[{pos}]).T', renderer=renderer)
    centered = res - res.mean(dim=-1, keepdim=True) if not pos else res
    imshow(centered, title=f'centered (W_E + W_pos[-1]) @ W_Q @ W_K.T @ {resid_name}.T', renderer=renderer,
           xaxis="Key token", yaxis="Query token")
    #print(centered.shape)
    return summarize(centered, name=f'centered (W_E + W_pos[-1]) @ W_Q @ W_K.T @ {resid_name}.T',
                     renderer=renderer,
                     include_value=True)

def replace_nans_with_row_max(tensor):
    # Step 1: Identify the nan values
    nan_mask = torch.isnan(tensor)

    # Step 2: Compute the maximum value for each row, ignoring nans
    non_nan_tensor = torch.where(nan_mask, torch.tensor(float('-inf')).to(tensor.device), tensor)
    row_max, _ = torch.max(non_nan_tensor, dim=1, keepdim=True)

    # Replace nan with the max value of the respective row
    tensor[nan_mask] = row_max.expand_as(tensor)[nan_mask]

    return tensor
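

# %%
# A small NumPy analogue of replace_nans_with_row_max, for illustration only.
# Filling NaN padding with the row maximum means the padded slots can never
# win a later row-wise min().
# Assumptions: `sketch_fill_nans_with_row_max` is a hypothetical helper name;
# unlike the torch version above it returns a new array, and an all-NaN row
# stays NaN instead of becoming -inf.
import numpy as np


def sketch_fill_nans_with_row_max(arr):
    """Return a copy of a 2-D array with each NaN replaced by its row's max."""
    arr = np.array(arr, dtype=float)
    row_max = np.nanmax(arr, axis=1, keepdims=True)  # max ignoring NaNs
    return np.where(np.isnan(arr), row_max, arr)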

def calculate_rowwise_attn_by_pos_near(model: HookedTransformer, pos=False, renderer=None, max_offset=1):
    def pad_diagonal_with(shape, diag, offset, val=100000):
        before_padding = torch.zeros(list(shape[:-2]) + [np.max([0, -offset])], device=diag.device)
        after_padding  = torch.zeros(list(shape[:-2]) + [np.max([0,  offset])], device=diag.device)
        before_padding.fill_(val)
        after_padding.fill_(val)
        return torch.cat([before_padding, diag, after_padding], dim=-1)

    points = []
    centered_score = calculate_attn_by_pos(model, renderer=renderer, pos=pos)['value']
    centered_diag = centered_score.diag()
    centered_score = centered_score - centered_diag[:, None]
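    # int(...) keeps the multiplication a plain Python-int * tensor op instead of a NumPy-scalar * tensor op
    res = torch.stack([int(np.sign(offset)) * pad_diagonal_with(centered_score.shape, centered_score.diag(diagonal=offset), offset, val=float('nan'))
                       for offset in range(-max_offset, max_offset + 1) if offset != 0], dim=-1)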
    imshow(centered_score, renderer=renderer)
    imshow(res, renderer=renderer)
    res = replace_nans_with_row_max(res)
    min_right_attn = res.min(dim=-1).values
    return summarize(min_right_attn, name=f'min right attn by pos near {max_offset}', renderer=renderer, include_value=True, fit_function=quadratic_func)
    #return min(points)

def calculate_min_attn_by_pos_far(model: HookedTransformer, pos=False, renderer=None, min_offset=2):
    points = []
    centered_score = calculate_attn_by_pos(model, renderer=renderer, pos=pos)['value']
    for row_n, row in enumerate(centered_score):
        for i in range(row.shape[0]):
            if i != row_n and abs(i - row_n) >= min_offset:
                points.append((row[i].item() - row[row_n].item())  / (i - row_n))
    # histogram
    plt.hist(points, bins=100, edgecolor='black')
    return min(points)

# ## Attention Patterns

# In[ ]:


def calculate_qk_attn_heatmap(model, keypos=-1, querypos=-1, do_layernorm=True):
    attn = model.blocks[0].attn
    all_token_embeddings = model.embed(range(model.cfg.d_vocab))
    positional_embeddings = model.pos_embed(all_token_embeddings)

    token_embeddings_at_keypos = all_token_embeddings + positional_embeddings[:,keypos,:] if keypos > -1 else all_token_embeddings
    token_embeddings_at_querypos = all_token_embeddings + positional_embeddings[:,querypos,:] if querypos > -1 else all_token_embeddings

    # layernorm before attention
    if do_layernorm:
        token_embeddings_at_keypos = model.blocks[0].ln1(token_embeddings_at_keypos)
        token_embeddings_at_querypos = model.blocks[0].ln1(token_embeddings_at_querypos)

    embeddings_key = einsum("d_vocab d_model, n_heads d_model d_head -> n_heads d_vocab d_head",
                            token_embeddings_at_keypos, attn.W_K)
    embeddings_query = einsum("d_vocab d_model, n_heads d_model d_head -> n_heads d_vocab d_head",
                            token_embeddings_at_querypos, attn.W_Q)

    qk_circuit_attn_heatmap = einsum(
        "n_heads d_vocab_q d_head, n_heads d_vocab_k d_head -> ... d_vocab_q d_vocab_k",
        embeddings_query, embeddings_key
        ).detach().cpu().numpy()

    plt.rcParams['figure.figsize'] = [20, 10]
    return qk_circuit_attn_heatmap


def calculate_qk_attn_heatmap_normed(model, querypos=-1, do_layernorm=True, skip_var=True):
    all_token_embeddings = model.embed(range(model.cfg.d_vocab))
    positional_embeddings = model.pos_embed(all_token_embeddings)
    all_heatmaps = torch.stack([torch.tensor(calculate_qk_attn_heatmap(model, cur_keypos, querypos, do_layernorm=do_layernorm)) for cur_keypos in range(positional_embeddings.shape[-2])])
    avg = einops.reduce(all_heatmaps, "keypos d_vocab_q d_vocab_k -> d_vocab_q ()", 'mean')
    var = einops.reduce(all_heatmaps, "keypos d_vocab_q d_vocab_k -> d_vocab_q ()", torch.var)
    #print(all_heatmaps.shape, avg.shape)
    #print(avg)
    res = (all_heatmaps - avg)
    if not skip_var: res = res * (var ** -0.5)
    return res


def plot_qk_heatmap(model, keypos=-1, querypos=-1, do_layernorm=True):
    qk_attn_heatmap = calculate_qk_attn_heatmap(model, keypos=keypos, querypos=querypos, do_layernorm=do_layernorm)

    fig, ax = plt.subplots(figsize=(8, 8))
    graph = ax.imshow(qk_attn_heatmap, cmap="hot", interpolation="nearest")
    plt.colorbar(graph)
    plt.tight_layout()


def plot_qk_heatmaps_normed(model, keypositions=None, querypos=-1, do_layernorm=True, skip_var=True):
    if keypositions is None:
        all_token_embeddings = model.embed(range(model.cfg.d_vocab))
        positional_embeddings = model.pos_embed(all_token_embeddings)
        keypositions = range(positional_embeddings.shape[-2])

    heatmaps = calculate_qk_attn_heatmap_normed(model, querypos=querypos, do_layernorm=do_layernorm, skip_var=skip_var)
    for keypos in keypositions:
        fig, ax = plt.subplots(figsize=(8, 8))
        qk_attn_heatmap = heatmaps[keypos]
        graph = ax.imshow(qk_attn_heatmap, cmap="hot", interpolation="nearest")
        plt.colorbar(graph)
        plt.tight_layout()
        plt.show()
    print(heatmaps.shape) # torch.Size([2, 64, 64]), keypos d_vocab_q d_vocab_k


def plot_avg_qk_heatmap(model, keypositions, querypos=-1, do_layernorm=True):
    heatmaps = []

    for keypos in keypositions:
        heatmaps.append(calculate_qk_attn_heatmap(model, keypos=keypos, querypos=querypos, do_layernorm=do_layernorm))

    qk_circuit_attn_heatmap = np.mean(heatmaps, axis=0)

    fig, ax = plt.subplots(figsize=(8, 8))
    graph = ax.imshow(qk_circuit_attn_heatmap, cmap="hot", interpolation="nearest")
    plt.colorbar(graph)
    plt.tight_layout()


#list(zip(all_integers[~correct_idxs], all_integers_ans[~correct_idxs]))


# # Interpretability

# ## Unembed

# In[ ]:


def plot_unembed_cosine_similarity(model):
    all_token_embeddings = model.embed(range(model.cfg.d_vocab))
    positional_embeddings = model.pos_embed(all_token_embeddings)
    all_token_pos_embed = all_token_embeddings[:,None,:] + positional_embeddings
    #print(model.W_U.shape, all_token_embeddings.shape, positional_embeddings.shape)
    # torch.Size([32, 64]) torch.Size([64, 32]) torch.Size([64, 2, 32])
    avg = F.normalize(all_token_embeddings.sum(dim=0), dim=-1)
    # overlap between model.W_U and token embeddings
    input_overlap = all_token_pos_embed @ model.W_U
    print(f"Definition max_input_output_overlap := {input_overlap.abs().max()}.")
    line(F.cosine_similarity(avg[None,:], all_token_embeddings, dim=-1))


# In[ ]:


def count_monotonicity_violations_line(result_tensor, m):
    # Count the number of pairs of indices (i, j), i < j, for which the slope
    # ((result_tensor[i] + m*i) - (result_tensor[j] + m*j)) / (i - j) is negative
    count = 0
    for i in range(len(result_tensor)):
        for j in range(i + 1, len(result_tensor)):
            # Parenthesised so that m*j is subtracted along with result_tensor[j]
            if (((result_tensor[i] + m*i) - (result_tensor[j] + m*j)) / (i - j)) < 0:
                count += 1
    return count
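

# %%
# A vectorised NumPy sketch of a slope-based monotonicity count, for
# illustration only. It rebuilds values as residual + m * index and counts
# index pairs i < j whose connecting slope is negative.
# Assumptions: `sketch_count_negative_slopes` is a hypothetical helper name,
# `residuals` plays the role of the fit errors and `m` the fitted slope.
import numpy as np


def sketch_count_negative_slopes(residuals, m):
    """Count pairs i < j where (value[i] - value[j]) / (i - j) < 0."""
    residuals = np.asarray(residuals, dtype=float)
    idx = np.arange(len(residuals))
    values = residuals + m * idx                 # reconstructed sequence
    i, j = np.triu_indices(len(residuals), k=1)  # all pairs with i < j
    slopes = (values[i] - values[j]) / (i - j)
    return int((slopes < 0).sum())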


def reorder_tensor_greedy(tensor, m):
    # Convert to numpy for easier handling
    tensor_np = tensor.detach().clone().numpy()

    # Initialize the result with the maximum positive value
    result = [np.max(tensor_np)]
    tensor_np = np.delete(tensor_np, np.argmax(tensor_np))

    while len(tensor_np) > 0:
        # Find values that maintain the condition
        candidates = tensor_np[tensor_np - result[-1] < -m]

        if len(candidates) > 0:
            # If such values exist, select the maximum
            next_value = np.max(candidates)
        else:
            # Otherwise, select the maximum of the remaining values
            next_value = np.max(tensor_np)

        # Add the selected value to the result
        result.append(next_value)

        # Remove the selected value from the list of remaining values
        tensor_np = np.delete(tensor_np, np.where(tensor_np == next_value)[0][0])

    # Convert the result back to a tensor
    result_tensor = torch.tensor(result)

    # Count the number of indices for which the difference between
    # successive elements in the result is less than -m
    # diff = result_tensor[1:] - result_tensor[:-1]
    # count = torch.sum(diff < -m).item()

    count = count_monotonicity_violations_line(result_tensor, m)

    return result_tensor, count


def compute_best_fit_and_error(direction_dot_embed):
    n_head, d_vocab = direction_dot_embed.shape

    coefficients = torch.empty((n_head, 2))  # To store the coefficients a, b for each row
    max_abs_errors = torch.empty(n_head)  # To store the max abs error for each row
    errors = torch.empty((n_head, d_vocab))
    predicted = torch.empty((n_head, d_vocab))
    negative_pairs = []
    diff_values = []

    x_values = np.arange(d_vocab)

    # Create a meshgrid of indices
    idxi, idxj = np.meshgrid(x_values, x_values)
    # Exclude the diagonal (i == j)
    mask = idxi != idxj
    pairs = list(zip(idxi[mask], idxj[mask]))  # create a list of pairs (i, j)

    for i in range(n_head):
        row = direction_dot_embed[i].detach().numpy()

        # Use curve_fit to find a, b that best fit the data in this row
        coeff, _ = curve_fit(linear_func, x_values, row)
        coefficients[i] = torch.from_numpy(coeff)

        # Compute the predicted y values using these coefficients
        y_pred = coeff[0] * x_values + coeff[1]

        # Compute the absolute error for each value, and take the maximum
        cur_errors = row - y_pred
        max_abs_errors[i] = np.abs(cur_errors).max()
        errors[i] = torch.from_numpy(cur_errors)
        predicted[i] = torch.from_numpy(y_pred)

        # Compute (pos[i] - pos[j]) / (i - j) for all pairs (i, j)
        values = (row[idxi] - row[idxj]) / (idxi - idxj)

        # Select only the values where i != j
        values = values[mask]
        negative_pairs.append([pair for pair, value in zip(pairs, values) if value < 0])

        diff_values.append(values)

    return coefficients, max_abs_errors, errors, predicted, diff_values, negative_pairs


def plot_QK_cosine_similarity(model, keypos=-1, querypos=-1, do_layernorm=True):
    attn = model.blocks[0].attn
    all_token_embeddings = model.embed(range(model.cfg.d_vocab))
    positional_embeddings = model.pos_embed(all_token_embeddings)
    normed_all_token_embeddings = F.normalize(all_token_embeddings, dim=-1)

    token_embeddings_at_keypos = all_token_embeddings + positional_embeddings[:,keypos,:] if keypos > -1 else all_token_embeddings
    token_embeddings_at_querypos = all_token_embeddings + positional_embeddings[:,querypos,:] if querypos > -1 else all_token_embeddings

    # layernorm before attention
    if do_layernorm:
        token_embeddings_at_keypos = model.blocks[0].ln1(token_embeddings_at_keypos)
        token_embeddings_at_querypos = model.blocks[0].ln1(token_embeddings_at_querypos)

    #embeddings_key = einsum("d_vocab d_model, n_heads d_model d_head -> n_heads d_vocab d_head",
    #                        token_embeddings_at_keypos, attn.W_K)
    #embeddings_query = einsum("d_vocab d_model, n_heads d_model d_head -> n_heads d_vocab d_head",
    #                        token_embeddings_at_querypos, attn.W_Q)
    embeddings_query_waiting_for_key = einsum("d_vocab_query d_model_query, n_heads d_model_query d_head, n_heads d_model_key d_head -> n_heads d_vocab_query d_model_key",
                            token_embeddings_at_querypos, attn.W_Q, attn.W_K)

    QK = einsum("n_heads d_model_query d_head, n_heads d_model_key d_head -> n_heads d_model_query d_model_key",
                            attn.W_Q, attn.W_K)

    analyze_svd(embeddings_query_waiting_for_key[0], descr="embeddings_query_waiting_for_key")
    analyze_svd(QK[0], descr="QK")
    U, S, Vh = torch.linalg.svd(embeddings_query_waiting_for_key[0])
    print(Vh.T[0])
    print(Vh[0])
    print((U @ torch.diag(S) @ Vh)[0])
    print((U @ torch.diag(S) @ Vh).T[0])
    imshow(U @ torch.diag(S) @ Vh, title="tmp")
    #qk_circuit_attn_heatmap = einsum(
    #    "n_heads d_vocab_q d_head, n_heads d_vocab_k d_head -> ... d_vocab_q d_vocab_k",
    #    embeddings_query, embeddings_key
    #    ).detach().cpu().numpy()

    imshow(embeddings_query_waiting_for_key[0])


    direction = embeddings_query_waiting_for_key
    #direction = direction / direction.norm(dim=-1, keepdim=True)
    direction = direction.sum(dim=1) / direction.shape[1]
    print(f"Definition size_direction := {direction}.")
    # keepdim=True so the per-head norm broadcasts over d_model for any number of heads
    direction = direction / direction.norm(dim=-1, keepdim=True)
    print(f"Definition normed_size_direction := {direction}.")
    print(all_token_embeddings.shape, direction.shape)
    proj_direction_scale = einsum("n_head d_model_key, n_head d_vocab_query d_model_key -> n_head d_vocab_query",
                                  direction,
                                  embeddings_query_waiting_for_key)[:,:,None]
    print(proj_direction_scale.shape)
    proj_direction = proj_direction_scale * einops.rearrange(direction, "n_head d_model -> n_head () d_model")
    print(proj_direction.shape)
    remaining_directions = embeddings_query_waiting_for_key - proj_direction
    print(remaining_directions.shape)
    remaining_directions = remaining_directions.norm(dim=-1)
    print(remaining_directions.shape)
    direction_key_overlap = einsum("n_head d_model_key, n_head d_vocab_query d_model_key -> d_vocab_query n_head",
                direction,
                embeddings_query_waiting_for_key)
    print(direction_key_overlap.shape)
    print(f"Definition min_attention_query_size_direction_overlap := {direction_key_overlap.min()}.")
    direction_dot_embed = einsum("n_head d_model, d_vocab d_model -> n_head d_vocab", direction, normed_all_token_embeddings)
    direction_dot_pos_embed = einsum("n_head d_model, pos d_model -> n_head pos", direction, positional_embeddings[0])
    print(f"Definition max_direction_dot_pos_embed := {direction_dot_pos_embed.abs().max()}.")
    # linear fit of direction_dot_embed
    direction_dot_embed_coefficients, direction_dot_embed_max_abs_errors, direction_dot_embed_error, direction_dot_embed_predicted, direction_dot_embed_diff_values, direction_dot_embed_neg_values = \
          compute_best_fit_and_error(direction_dot_embed)

    direction_dot_embed_diffs = direction_dot_embed[...,1:] - direction_dot_embed[...,:-1]
    #direction_dot_embed_coef = direction_dot_embed_diffs.mean(dim=-1, keepdim=True)
    #direction_dot_embed_offset = direction_dot_embed.mean(dim=-1, keepdim=True)
    #direction_dot_embed_diff_error = direction_dot_embed_diffs - torch.arange(direction_dot_embed_diffs.shape[-1]) * direction_dot_embed_coef + direction_dot_embed_offset)
    print(direction_dot_embed_diffs)
    print(direction_dot_embed_diffs.abs())
    line(direction_dot_embed_diffs.T, title="direction_dot_embed_diffs")
    line(direction_dot_embed_diffs.T.abs(), title="direction_dot_embed_diffs abs")
    #line(direction_dot_embed_diff_error.T, title="direction_dot_embed_diff_error")
    #print(direction_dot_embed_coef, direction_dot_embed_offset)


    #direction_dot_embed_coef_better, _ = curve_fit(constant_function, np.arange(direction_dot_embed_diffs.shape[-1]), direction_dot_embed_diffs[0].detach().numpy())
    #direction_dot_embed_diff_error_better = direction_dot_embed_diffs - (torch.arange(direction_dot_embed_diffs.shape[-1]) * direction_dot_embed_coef + direction_dot_embed_offset)
    #line(direction_dot_embed_diffs.T, title="direction_dot_embed_diffs")
    #line(direction_dot_embed_diff_error.T, title="direction_dot_embed_diff_error")
    #print(direction_dot_embed_coef, direction_dot_embed_offset)


    # indices = np.argsort(direction_dot_embed_error[0].numpy() / np.arange(1, len(direction_dot_embed_error[0]) + 1))

    # Use these indices to sort 'direction_dot_embed_error'
    # sorted_direction_dot_embed_error = direction_dot_embed_error[:,indices]
    print(direction_dot_embed_error.mean(), direction_dot_embed_error.var())
    # randomly reorder direction_dot_embed_error, put in tmp
    tmp = direction_dot_embed_error[0].detach().clone().numpy()
    np.random.shuffle(tmp)
    print(tmp)
    line(tmp)
    print(count_monotonicity_violations_line(torch.tensor(tmp), direction_dot_embed_coefficients[0, 0].item()))
    print(f"Definition ")
    sorted_direction_dot_embed_error, bad_count = reorder_tensor_greedy(direction_dot_embed_error[0], direction_dot_embed_coefficients[0, 0].item())
    sorted_direction_dot_embed_error = sorted_direction_dot_embed_error[None,:]

    # sorted_direction_dot_embed_error, _ = direction_dot_embed_error.sort(dim=-1, descending=True)
    print(direction_dot_embed_coefficients, direction_dot_embed_max_abs_errors, direction_dot_embed)
    line(direction_key_overlap, title="direction @ query_waiting_for_key")
    line(remaining_directions.T, title="norm of remaining direction")
    line(F.cosine_similarity(direction, all_token_embeddings, dim=-1), title="cos_sim(direction, embed)")
    print(positional_embeddings.shape)
    line(direction_dot_embed.T, title="direction @ normed embed")
    line(torch.cat([direction_dot_embed, direction_dot_embed_predicted], dim=0).T, title="direction @ normed embed, + fit")
    print(bad_count)
    line(torch.cat([direction_dot_embed_predicted, direction_dot_embed_predicted + sorted_direction_dot_embed_error], dim=0).T, title="direction @ normed embed bad fit")

    # Plot the histogram
    print(len(direction_dot_embed_neg_values[0]) // 2)
    print(list(sorted([p for p in direction_dot_embed_neg_values[0] if p[0] < p[1]])))
    plt.hist(direction_dot_embed_diff_values[0], bins=30, edgecolor='black')
    plt.title("Distribution of (pos[i] - pos[j]) / (i - j)")
    plt.xlabel("(pos[i] - pos[j]) / (i - j)")
    plt.ylabel("Frequency")
    plt.show()

    line(direction_dot_embed_error.T, title="direction_dot_normed_embed_error")
    line(direction_dot_pos_embed.T, title="direction @ pos_embed")


# %%
def make_local_tqdm(tqdm):
    if tqdm is None:
        return lambda arg, **kwargs: arg
    else:
        return tqdm

# %%
@torch.no_grad()
def layernorm_noscale(x: torch.Tensor) -> torch.Tensor:
    return x - x.mean(axis=-1, keepdim=True)

# %%
@torch.no_grad()
def layernorm_scales(x: torch.Tensor, eps: float = 1e-5, recip: bool = True) -> torch.Tensor:
    x = layernorm_noscale(x)
    scale = (x.pow(2).mean(axis=-1, keepdim=True) + eps).sqrt()
    if recip: scale = 1 / scale
    return scale
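

# %%
# A minimal NumPy sketch of how the two helpers above combine, for
# illustration only: centre the input, then multiply by the reciprocal
# root-mean-square scale. No learned gain or bias is applied.
# Assumption: `sketch_layernorm` is a hypothetical helper name.
import numpy as np


def sketch_layernorm(x, eps=1e-5):
    """Centre x along the last axis and divide by sqrt(mean(centred**2) + eps)."""
    x = np.asarray(x, dtype=float)
    centered = x - x.mean(axis=-1, keepdims=True)
    scale = 1.0 / np.sqrt((centered ** 2).mean(axis=-1, keepdims=True) + eps)
    return centered * scale

# The output has zero mean and variance slightly below 1 (because of eps)
# along the last axis.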

# %%
@torch.no_grad()
def compute_singular_contribution(M: torch.Tensor, plot_heatmaps=True, yaxis=None, xaxis=None, title=None, renderer=None, description=None, singular_value_count=1, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
    U, S, Vh = torch.linalg.svd(M)
    U[:, singular_value_count:], S[singular_value_count:], Vh[singular_value_count:, :] = 0, 0, 0
    contribution = U @ torch.diag(S) @ Vh
    if plot_heatmaps:
        singular_value_str = f"first {singular_value_count} singular values" if singular_value_count != 1 else f"first singular value"
        to_description = f" to {description}" if description is not None else ""
        description = f"{description} " if description is not None else ""
        diff_zmax = (M - contribution).abs().max().item()
        zmax = np.max([contribution.abs().max().item(), diff_zmax])
        fig = make_subplots(rows=1, cols=3, subplot_titles=[f"Contribution", f"Residual", f"Residual (rescaled)"])
        fig.add_trace(go.Heatmap(z=utils.to_numpy(contribution), zmin=-zmax, zmax=zmax, showscale=True,
                                 colorbar=dict(x=-0.15, y=0.5),
                                **kwargs),
                    row=1, col=1)
        fig.add_trace(go.Heatmap(z=utils.to_numpy(M - contribution), zmin=-zmax, zmax=zmax, showscale=False,
                                **kwargs),
                    row=1, col=2)
        fig.add_trace(go.Heatmap(z=utils.to_numpy(M - contribution), zmin=-diff_zmax, zmax=diff_zmax, showscale=True,
                                **kwargs),
                    row=1, col=3)
        if title is None: title = f"Contribution of the {singular_value_str}{to_description}"
        fig.update_layout(title=title, margin=dict(l=100))
        for col in range(3):
            if yaxis is not None: fig.update_yaxes(title_text=yaxis, row=1, col=col+1)
            if xaxis is not None: fig.update_xaxes(title_text=xaxis, row=1, col=col+1)
        # fig only exists when plot_heatmaps is True, so show it inside that branch
        fig.show(renderer)
    return M - contribution, contribution
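

# %%
# A minimal NumPy sketch of the rank-k split computed above, for illustration
# only. It slices the leading singular triplets instead of zeroing U, S, Vh in
# place, and does no plotting.
# Assumption: `sketch_rank_k_split` is a hypothetical helper name.
import numpy as np


def sketch_rank_k_split(mat, k=1):
    """Return (residual, contribution) where contribution is the rank-k
    truncated SVD of `mat` and residual = mat - contribution."""
    u, s, vh = np.linalg.svd(mat, full_matrices=False)
    contribution = (u[:, :k] * s[:k]) @ vh[:k, :]
    return mat - contribution, contribution

# The largest singular value of the residual equals the (k+1)-th singular
# value of mat, which is how much the first k components leave unexplained.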
# %%

def display_size_direction_stats(size_direction: torch.Tensor, query_direction: torch.Tensor, QK: torch.Tensor, U: torch.Tensor, Vh: torch.Tensor, S: torch.Tensor,
                                 size_direction_resid: Optional[torch.Tensor] = None, size_direction_QK: Optional[torch.Tensor] = None,
                                 query_direction_resid: Optional[torch.Tensor] = None, query_direction_QK: Optional[torch.Tensor] = None,
                                 do_exclusions: bool = True,
                                 include_contribution: bool = True,
                                 scale_by_singular_value: bool = True,
                                 renderer=None,
                                 fit_funcs: Iterable = (cubic_func, quintic_func),
                                 delta_fit_funcs: Iterable = (quadratic_func, quartic_func),
                                 colorscale='Plasma_r', **kwargs):
    if scale_by_singular_value:
        U = U * S[None, :].sqrt()
        Vh = Vh * S[:, None].sqrt()
    imshow(QK, title="Attention<br>(W_E + W_pos[-1]) @ W_Q @ W_K.T @ (W_E + W_pos.mean(dim=0)).T", xaxis="Key Token", yaxis="Query Token", renderer=renderer, colorscale=colorscale, **kwargs)
    fig = make_subplots(rows=1, cols=3, subplot_titles=["Query-Side SVD", "Singular Values", "Key-Side SVD"])
    uzmax, vzmax = U.abs().max().item(), Vh.abs().max().item()
    fig.add_trace(go.Heatmap(z=utils.to_numpy(U), colorscale=colorscale, zmin=-uzmax, zmax=uzmax,
                             showscale=False,
                            #  colorbar=dict(x=-0.15, # https://community.plotly.com/t/colorbar-ticks-left-aligned/60473/4
                            #             ticklabelposition='inside',
                            #             ticksuffix='     ',
                            #             ticklabeloverflow='allow',
                            #             tickfont_color='darkslategrey',),
                            hovertemplate="Query: %{y}<br>Singular Index: %{x}<br>Value: %{z}<extra></extra>",
                            ),
                row=1, col=1)
    fig.add_trace(go.Scatter(x=np.arange(S.shape[0]), y=utils.to_numpy(S),
                            mode='lines+markers',
                            marker=dict(color='blue'),
                            line=dict(color='blue'),
                            hovertemplate="Singular Value: %{y}<br>Singular Index: %{x}<extra></extra>",
                            ), row=1, col=2)
    fig.add_trace(go.Heatmap(z=utils.to_numpy(Vh.T), colorscale=colorscale, zmin=-vzmax, zmax=vzmax,
                             showscale=False,
                            #  colorbar=dict(x=1.15),
                            hovertemplate="Key: %{y}<br>Singular Index: %{x}<br>Value: %{z}<extra></extra>",
                            ),
                row=1, col=3)
    fig.update_layout(title="Attention SVD") #, margin=dict(l=150, r=150))
    fig.update_yaxes(title_text="Query Token", row=1, col=1)
    fig.update_yaxes(range=[0, None], row=1, col=2)
    fig.update_yaxes(title_text="Key Token", row=1, col=3)
    fig.show(renderer)

    contribution_diff = None
    if include_contribution:
        contribution_diff, _ = compute_singular_contribution(
            QK, description="Attention", colorscale=colorscale, renderer=renderer, singular_value_count=1,
            xaxis='Key Token', yaxis='Query Token',
            hovertemplate="Query: %{y}<br>Key: %{x}<br>Value: %{z}<extra></extra>",
            **kwargs)

    # imshow(U, title="Query-Side SVD", yaxis="Query Token", renderer=renderer, **kwargs)
    # imshow(Vh.T, title="Key-Side SVD", yaxis="Key Token", renderer=renderer, **kwargs)
    # px.line({'singular values': utils.to_numpy(S)}, title="Singular Values of QK Attention").show(renderer)

    fig = make_subplots(rows=1, cols=2, subplot_titles=["Size", "Query"])
    fig.add_trace(go.Scatter(x=np.arange(size_direction.shape[0]), y=utils.to_numpy(size_direction),
                            mode='lines+markers',
                            marker=dict(color='blue'),
                            line=dict(color='blue'),
                            hovertemplate="Token: %{x}<br>Size: %{y}<extra></extra>",
                            ), row=1, col=1)
    fig.add_trace(go.Scatter(x=np.arange(query_direction.shape[0]), y=utils.to_numpy(query_direction),
                            mode='lines+markers',
                            marker=dict(color='blue'),
                            line=dict(color='blue'),
                            hovertemplate="Token: %{x}<br>Query Value: %{y}<extra></extra>",
                            ), row=1, col=2)
    fig.update_layout(title="Directions in Token Space", showlegend=False)
    fig.show(renderer)

    # px.line({'size direction': utils.to_numpy(size_direction)}, title="size direction in token space").show(renderer)
    # px.line({'query direction': utils.to_numpy(query_direction)}, title="query direction in token space").show(renderer)
    if size_direction_resid is not None: line(size_direction_resid, title="size direction in residual space", renderer=renderer)
    if query_direction_resid is not None: line(query_direction_resid, title="query direction in residual space", renderer=renderer)
    if size_direction_QK is not None: line(size_direction_QK, title="size direction in QK space", renderer=renderer)
    if query_direction_QK is not None: line(query_direction_QK, title="query direction in QK space", renderer=renderer)

    reference_lines = []
    if contribution_diff is not None:
        # we make some reference lines for the plots of size[i+1] - size[i]
        # since we'll eventually multiply these by the singular value and the query direction entry, we want to divide by this product when comparing to values from the non-size-direction contributions
        # we compute the mean and worst-case behavior, and a more fine-grained worst-case adjacent difference
        singular_scale = S[0].item()
        scale_per_query = query_direction * singular_scale
        resid_diffs = contribution_diff[:, :-1] - contribution_diff[:, 1:]
        resid_max_diff = contribution_diff.max().item() - contribution_diff.min().item()
        resid_max_diff_per_query = contribution_diff.max(dim=1).values - contribution_diff.min(dim=1).values
        scale_mean, scale_min = scale_per_query.mean(dim=0).item(), scale_per_query.min().item()
        resid_mean_diff = (contribution_diff[:, :, None, None] - contribution_diff[None, None, :, :]).abs().mean().item()
        resid_mean_diff_per_query = (contribution_diff[:, :, None] - contribution_diff[:, None, :]).abs().mean(dim=(-2, -1))
        reference_lines = [
            ("resid.max - resid.min (worst-case independent query)", resid_max_diff / scale_min),
            ("resid.max - resid.min (average-case independent query)", resid_max_diff / scale_mean),
            ("resid.max - resid.min (worst-case query)", (resid_max_diff_per_query / scale_per_query).max().item()),
            ("(resid[i] - resid[i+1]).max (worst-case independent query)", (resid_diffs / scale_min).max().item()),
            ("(resid[i] - resid[i+1]).max (worst-case query)", (resid_diffs / scale_per_query[:, None]).max().item()),
            ("(resid[i] - resid[i+1]).abs.mean (average-case independent query)", (resid_diffs / scale_mean).abs().mean().item()),
            ("(resid[i] - resid[j]).abs.mean (average-case independent query)", resid_mean_diff / scale_mean),
            ("(resid[i] - resid[j]).abs.mean (average-case query)", (resid_mean_diff_per_query / scale_per_query).abs().mean().item()),
        ]

    size_direction_differences = size_direction[1:] - size_direction[:-1]
    show_fits(size_direction, name='Size Direction', fit_funcs=(fit_func for fit_func in fit_funcs if fit_func is not sigmoid_func),
              do_exclusions=do_exclusions, renderer=renderer)
    show_fits(size_direction_differences, name='Size Direction Δ', reference_lines=reference_lines, fit_funcs=(fit_func for fit_func in delta_fit_funcs if fit_func is not sigmoid_func),
              do_exclusions=do_exclusions, renderer=renderer)

    y_data = size_direction.detach().cpu().numpy()
    x_data = np.linspace(1, len(y_data), len(y_data))

    for fit_func in fit_funcs:
        fit_func_name = fit_func.__name__
        if fit_func_name.endswith("_func"): fit_func_name = fit_func_name[:-len("_func")]

        if fit_func is sigmoid_func:
            # fit to sigmoid
            y_transposed = np.linspace(1, len(x_data), len(x_data))
            initial_params_transposed = [max(y_transposed), 1/np.mean(y_data), np.median(y_data)]

            # Fit the curve with initial parameters

            params_transposed, covariance_transposed = curve_fit(sigmoid_func, y_data, y_transposed, p0=initial_params_transposed, maxfev=10000)

            # Generate predicted y values with parameters
            y_pred_transposed = sigmoid_func(y_data, *params_transposed)
            # Calculating residuals
            residuals = y_transposed - y_pred_transposed

            # Creating subplots
            fig, axs = plt.subplots(2, 1, figsize=(10, 12))
            fig.suptitle('Fitting a Sigmoid to the Size Vector Components and Residuals Analysis', fontsize=16)

            # Plotting the original data and fitted curve
            axs[0].scatter(y_data, y_transposed, label='Data', color='blue')
            axs[0].plot(y_data, y_pred_transposed, color='red',
                    label=rf'{inv_sigmoid_func.equation(params_transposed)}')
            axs[0].set_xlabel('Component in Normalized Size Vector')
            axs[0].set_ylabel('Input Token')
            axs[0].legend()
            axs[0].grid(True)

            # Plotting residuals
            axs[1].scatter(y_data, residuals, color='green', label='Residuals')
            axs[1].axhline(y=0, color='r', linestyle='--', label='y=0')
            axs[1].set_xlabel('Component in Normalized Size Vector')
            axs[1].set_ylabel('Residual')
            axs[1].legend()
            axs[1].grid(True)

            # Displaying the plots
            plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # To prevent overlap between suptitle and subplots
            plt.show()


@torch.no_grad()
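def find_size_and_query_direction(model: HookedTransformer, plot_heatmaps=False, renderer=None, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, float]:
    """
    Approximates the size and query directions of the model using the top singular vectors of the QK matrix.
    Returns: (size_direction, query_direction, singular_value)
    """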
    W_pos, W_Q, W_K, W_E = model.W_pos, model.W_Q, model.W_K, model.W_E
    d_model, d_vocab, n_ctx = model.cfg.d_model, model.cfg.d_vocab, model.cfg.n_ctx
    assert W_pos.shape == (n_ctx, d_model), f"W_pos.shape = {W_pos.shape} != {(n_ctx, d_model)} = (n_ctx, d_model)"
    assert W_Q.shape == (1, 1, d_model, d_model), f"W_Q.shape = {W_Q.shape} != {(1, 1, d_model, d_model)} = (1, 1, d_model, d_model)"
    assert W_K.shape == (1, 1, d_model, d_model), f"W_K.shape = {W_K.shape} != {(1, 1, d_model, d_model)} = (1, 1, d_model, d_model)"
    assert W_E.shape == (d_vocab, d_model), f"W_E.shape = {W_E.shape} != {(d_vocab, d_model)} = (d_vocab, d_model)"

    QK = (W_E + W_pos[-1]) @ W_Q[0, 0, :, :] @ W_K[0, 0, :, :].T @ (W_E + W_pos.mean(dim=0)).T
    assert QK.shape == (d_vocab, d_vocab), f"QK.shape = {QK.shape} != {(d_vocab, d_vocab)} = (d_vocab, d_vocab)"

    # take SVD:
    U, S, Vh = torch.linalg.svd(QK)
    # adjust the free parameter of sign
    sign = torch.sign(U[:, 0].mean())
    U, Vh = U * sign, Vh * sign

    # the size direction is the first row of Vh (top right-singular vector, key side), normalized
    # the query direction is the first column of U (top left-singular vector, query side), normalized
    size_direction, query_direction = Vh[0, :], U[:, 0]
    size_query_singular_value = S[0] * size_direction.norm() * query_direction.norm()
    size_direction, query_direction = size_direction / size_direction.norm(), query_direction / query_direction.norm()

    if plot_heatmaps:
        size_direction_resid, query_direction_resid = size_direction @ W_E + W_pos[-1], query_direction @ W_E + W_pos.mean(dim=0)
        size_direction_QK, query_direction_QK = size_direction_resid @ W_Q[0, 0, :, :], query_direction_resid @ W_K[0, 0, :, :]

        display_size_direction_stats(size_direction, query_direction, QK, U, Vh, S,
                                    # size_direction_resid=size_direction_resid, size_direction_QK=size_direction_QK,
                                    # query_direction_resid=query_direction_resid, query_direction_QK=query_direction_QK,
                                    renderer=renderer, **kwargs)

    return size_direction, query_direction, size_query_singular_value.item()


@torch.no_grad()
def find_size_direction(model: HookedTransformer, **kwargs):
    """
    Approximates the size direction of the model.
    """
    return find_size_and_query_direction(model, **kwargs)[0]

@torch.no_grad()
def find_query_direction(model: HookedTransformer, **kwargs):
    """
    Approximates the query direction of the model.
    """
    return find_size_and_query_direction(model, **kwargs)[1]

# %%
@torch.no_grad()
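def find_backwards_attention(model: HookedTransformer):
    """
    Finds cases where the QK score is not increasing in the key token.
    Returns the indices (pos, query_tok, key_tok) at which the score of key_tok at position pos
    is at least the score of key_tok + 1 at the mirrored position n_ctx - 1 - pos.
    """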
    W_pos, W_Q, W_K, W_E = model.W_pos, model.W_Q, model.W_K, model.W_E
    d_model, d_vocab, n_ctx = model.cfg.d_model, model.cfg.d_vocab, model.cfg.n_ctx
    assert W_pos.shape == (n_ctx, d_model), f"W_pos.shape = {W_pos.shape} != {(n_ctx, d_model)} = (n_ctx, d_model)"
    assert W_Q.shape == (1, 1, d_model, d_model), f"W_Q.shape = {W_Q.shape} != {(1, 1, d_model, d_model)} = (1, 1, d_model, d_model)"
    assert W_K.shape == (1, 1, d_model, d_model), f"W_K.shape = {W_K.shape} != {(1, 1, d_model, d_model)} = (1, 1, d_model, d_model)"
    assert W_E.shape == (d_vocab, d_model), f"W_E.shape = {W_E.shape} != {(d_vocab, d_model)} = (d_vocab, d_model)"

    QK = (W_E + W_pos[-1]) @ W_Q[0, 0, :, :] @ W_K[0, 0, :, :].T @ (W_E + W_pos[:, None, :]).transpose(-1, -2)
    assert QK.shape == (n_ctx, d_vocab, d_vocab), f"QK.shape = {QK.shape} != {(n_ctx, d_vocab, d_vocab)} = (n_ctx, d_vocab, d_vocab)"
    # diffs0 = QK[:, :, :-1].max(dim=0).values - QK[:, :, 1:].min(dim=0).values
    diffs = QK[:, :, :-1] - QK[:, :, 1:].flip(dims=(0,))
    return torch.nonzero(diffs >= 0).squeeze()


In [None]:
#@title interp_max_utils
from typing import Any, Dict, Optional, Tuple, Union
from torchtyping import TensorType
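from enum import Enum
try:
    # verify, UNIQUE, CONTINUOUS exist only on Python >= 3.11; they are used only by the commented-out TokenType enum below
    from enum import verify, UNIQUE, CONTINUOUS
except ImportError:
    pass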
import enum
import itertools
from fancy_einsum import einsum
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit
import torch
import torch.nn.functional as F
from transformer_lens import HookedTransformer
import transformer_lens.utils as utils
import plotly.express as px
import math

# In[ ]:
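def complexity_of(f):
    """Collects the `Complexity` lines from the docstring of f, with the `Complexity:` prefix stripped."""
    # a missing docstring yields an empty result rather than an AttributeError
    lines = (line.split(':') for line in (f.__doc__ or '').split('\n'))
    lines = (line for line in lines if line[0].lower().strip().startswith('complexity'))
    lines = (':'.join(line[1:]).strip() if line[0].lower().strip() == 'complexity' else ':'.join(line).strip()[len('complexity'):].strip()
             for line in lines)
    return '\n'.join(lines)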
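
# Illustrative sketch (not part of the original module): the docstring convention that
# `complexity_of` relies on. `_complexity_lines_sketch` is a hypothetical, simplified
# stand-in that takes the docstring text directly, and `_example_documented_function`
# is a made-up function whose only purpose is to carry `Complexity:` lines.
def _complexity_lines_sketch(doc):
    out = []
    for line in (doc or '').split('\n'):
        stripped = line.strip()
        if stripped.lower().startswith('complexity'):
            # drop the 'Complexity' prefix and, if present, the colon that follows it
            rest = stripped[len('complexity'):].strip()
            if rest.startswith(':'):
                rest = rest[1:].strip()
            out.append(rest)
    return '\n'.join(out)

def _example_documented_function():
    """
    Does nothing; exists only to demonstrate the convention.
    Complexity: O(d_vocab^2 * d_model)
    Complexity: (for n_ctx=2) O(d_vocab)
    """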

# In[ ]:
@torch.no_grad()
def logit_delta_of_results(all_tokens: TensorType["batch", "n_ctx"], predicted_logits: TensorType["batch", "d_vocab_out"], renderer=None, histogram_all_incorrect_logit_differences: bool = False, return_summary: bool = False, hist_args={}) -> Union[float, Dict[str, Any]]: # noqa: F821
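    """
    Smallest value of logit(true_max) - logit(y) over y != true_max (the worst-case margin; negative if some sequence is misclassified).
    Returns the minimum over the batch, or a summary of the per-sequence minima if return_summary is set.
    """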
    (batch, n_ctx), (_batch, d_vocab_out) = all_tokens.shape, predicted_logits.shape
    assert predicted_logits.shape == (batch, d_vocab_out), f"predicted_logits.shape = {predicted_logits.shape} != {(batch, d_vocab_out)} = (batch, d_vocab_out)"

    # Extract statistics for each row
    # Use values in all_tokens as indices to gather correct logits
    indices_of_max = all_tokens.max(dim=-1, keepdim=True).values
    assert indices_of_max.shape == (batch, 1), f"indices_of_max.shape = {indices_of_max.shape} != {(batch, 1)} = (batch, 1)"
    correct_logits = torch.gather(predicted_logits, -1, indices_of_max)
    assert correct_logits.shape == (batch, 1), f"correct_logits.shape = {correct_logits.shape} != {(batch, 1)} = (batch, 1)"
    logits_above_correct = correct_logits - predicted_logits
    assert logits_above_correct.shape == (batch, d_vocab_out), f"logits_above_correct.shape = {logits_above_correct.shape} != {(batch, d_vocab_out)} = (batch, d_vocab_out)"
    # replace correct logit indices with large number so that they don't get picked up by the min
    logits_above_correct[torch.arange(logits_above_correct.shape[0]), indices_of_max.squeeze()] = float('inf')
    min_incorrect_logit = logits_above_correct.min(dim=-1).values
    assert min_incorrect_logit.shape == (batch,), f"min_incorrect_logit.shape = {min_incorrect_logit.shape} != {(batch,)} = (batch,)"

    if histogram_all_incorrect_logit_differences:
        all_incorrect_logits = logits_above_correct[logits_above_correct != float('inf')]
        summarize(all_incorrect_logits, name='all incorrect logit differences', histogram=True, hist_args=hist_args, renderer=renderer)

    if return_summary:
        return summarize(min_incorrect_logit, name='min(correct logit - incorrect logit)', renderer=renderer, histogram=True)

    else:
        return min_incorrect_logit.min().item()
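
# Illustrative sketch (not part of the original module): a pure-Python reference for the
# margin computed above, assuming (as in the max-of-n task) that token ids double as logit
# indices. `_reference_min_margin` is a hypothetical helper meant for cross-checking on
# tiny hand-written inputs, not a replacement for the tensor version.
def _reference_min_margin(tokens, logits):
    """min over rows of min over y != max(row tokens) of logits[max] - logits[y]"""
    margins = []
    for row_tokens, row_logits in zip(tokens, logits):
        correct = max(row_tokens)  # the true max token is the correct output
        margins.append(min(row_logits[correct] - logit
                           for y, logit in enumerate(row_logits) if y != correct))
    return min(margins)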


# In[ ]:
@torch.no_grad()
def logit_delta(model: HookedTransformer, renderer=None, histogram_all_incorrect_logit_differences: bool = False, return_summary: bool = False, hist_args={}) -> Union[float, Dict[str, Any]]:
    """
    Smallest value of logit(true_max) - logit(y) over y != true_max, across all d_vocab^n_ctx input sequences.
    Complexity: O(d_vocab^n_ctx * fwd_pass)
    Complexity: fwd_pass = O(n_ctx * d_model + n_ctx * d_model + n_ctx * d_model^2 * d_hidden * 2 + n_ctx * d_hidden^2 + n_ctx * d_model^2 * d_hidden + n_ctx * d_hidden^2 * d_model + n_ctx * d_model + n_ctx * d_model^2 * d_vocab)
    Complexity: (n_ctx^2 * d_vocab * d_model^2) + (n_ctx * d_vocab * d_model^2)
    todo fix complexity.
    """
    n_ctx, d_vocab, d_vocab_out, d_model = model.cfg.n_ctx, model.cfg.d_vocab, model.cfg.d_vocab_out, model.cfg.d_model

    all_tokens = compute_all_tokens(model=model)
    assert all_tokens.shape == (d_vocab**n_ctx, n_ctx), f"all_tokens.shape = {all_tokens.shape} != {(d_vocab**n_ctx, n_ctx)} = (d_vocab**n_ctx, n_ctx)"
    predicted_logits = model(all_tokens)[:,-1,:].detach().cpu()
    assert predicted_logits.shape == (d_vocab**n_ctx, d_vocab_out), f"predicted_logits.shape = {predicted_logits.shape} != {(d_vocab**n_ctx, d_vocab_out)} = (d_vocab**n_ctx, d_vocab_out)"

    return logit_delta_of_results(all_tokens=all_tokens, predicted_logits=predicted_logits, renderer=renderer, histogram_all_incorrect_logit_differences=histogram_all_incorrect_logit_differences, return_summary=return_summary, hist_args=hist_args)

# In[ ]:
@torch.no_grad()
def compute_gap(all_tokens: TensorType["batch", "n_ctx"]) -> TensorType["batch"]: # noqa: F821
    """
    computes the gap between the max token and the second max token in each row of all_tokens
    """
    maxv = all_tokens.max(dim=-1, keepdim=True).values
    all_but_maxv = all_tokens.clone()
    all_but_maxv[all_but_maxv == maxv] = -all_tokens.max().item()
    second_maxv = all_but_maxv.max(dim=-1, keepdim=True).values
    second_maxv[second_maxv < 0] = maxv[second_maxv < 0]
    return (maxv - second_maxv)[:, 0]
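
# Illustrative sketch (not part of the original module): a pure-Python statement of what
# `compute_gap` returns for a single row, namely the max token minus the largest token
# strictly below it, or 0 when every token in the row is equal.
# `_reference_gap` is a hypothetical helper for checking small cases by hand.
def _reference_gap(row):
    distinct = sorted(set(row))  # duplicates of the max do not count as a second max
    return distinct[-1] - distinct[-2] if len(distinct) > 1 else 0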

# In[ ]:
@torch.no_grad()
def all_tokens_small_gap(model: HookedTransformer, max_min_gap: int = 1) -> TensorType["batch", "n_ctx"]: # noqa: F821
    """
    All sequences of tokens with the constraint that some token z in the sequence satisfies true_max - max_min_gap <= z < true_max
    Complexity: O(d_vocab ^ (n_ctx - 1) * (max_min_gap * 2 + 1))
    """
    n_ctx, d_vocab = model.cfg.n_ctx, model.cfg.d_vocab

    all_tokens_after_start = generate_all_sequences(n_digits=d_vocab, sequence_length=n_ctx - 1)
    all_tokens_after_start_max = all_tokens_after_start.max(dim=-1, keepdim=True).values
    all_tokens_after_start_max_minf = all_tokens_after_start.clone()
    all_tokens_after_start_max_minf[all_tokens_after_start_max_minf == all_tokens_after_start_max] = -max_min_gap - 1
    all_tokens_after_start_second_max = all_tokens_after_start_max_minf.max(dim=-1, keepdim=True).values
    first_token_max = all_tokens_after_start_max + max_min_gap + 1
    gap_already_present = all_tokens_after_start_second_max >= all_tokens_after_start_max - max_min_gap
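    # if the rest of the sequence already contains a token within the gap, a first token equal to the max keeps the constraint satisfied;
    # otherwise a first token equal to the max adds no token strictly below the max, so it is skipped
    first_token_upper_min = all_tokens_after_start_max + (~gap_already_present).long()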
    first_token_min = torch.zeros_like(first_token_max)
    first_token_min[~gap_already_present] = all_tokens_after_start_max[~gap_already_present] - max_min_gap
    first_token_min[first_token_min < 0] = 0
    first_token_max[first_token_max >= d_vocab] = d_vocab
    first_token_upper_min[first_token_upper_min >= d_vocab] = d_vocab
    assert first_token_max.shape == (d_vocab**(n_ctx - 1), 1), f"first_token_max.shape = {first_token_max.shape} != {(d_vocab**(n_ctx - 1), 1)} = (d_vocab**(n_ctx - 1), 1)"
    assert first_token_upper_min.shape == (d_vocab**(n_ctx - 1), 1), f"first_token_upper_min.shape = {first_token_upper_min.shape} != {(d_vocab**(n_ctx - 1), 1)} = (d_vocab**(n_ctx - 1), 1)"
    assert all_tokens_after_start_max.shape == (d_vocab**(n_ctx - 1), 1), f"all_tokens_after_start_max.shape = {all_tokens_after_start_max.shape} != {(d_vocab**(n_ctx - 1), 1)} = (d_vocab**(n_ctx - 1), 1)"
    assert first_token_min.shape == (d_vocab**(n_ctx - 1), 1), f"first_token_min.shape = {first_token_min.shape} != {(d_vocab**(n_ctx - 1), 1)} = (d_vocab**(n_ctx - 1), 1)"
    first_token_max, first_token_upper_min, all_tokens_after_start_max, first_token_min = first_token_max[:, 0], first_token_upper_min[:, 0], all_tokens_after_start_max[:, 0], first_token_min[:, 0]
    first_token_ranges = [torch.cat([torch.arange(lower, mid), torch.arange(lower_big, upper)]) for lower, mid, lower_big, upper in zip(first_token_min, all_tokens_after_start_max, first_token_upper_min, first_token_max)]
    all_tokens_with_small_gap = torch.cat([torch.cartesian_prod(first_tokens, *rest_tokens[:, None]) for first_tokens, rest_tokens in zip(first_token_ranges, all_tokens_after_start)])

    return all_tokens_with_small_gap
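
# Illustrative sketch (not part of the original module): a brute-force statement of the
# constraint as worded in the docstring above (some token z with
# true_max - max_min_gap <= z < true_max). It enumerates all d_vocab^n_ctx sequences, so it
# is only suitable for cross-checking on tiny configurations.
# `_reference_small_gap_sequences` is a hypothetical helper taking plain ints, not a model.
import itertools

def _reference_small_gap_sequences(d_vocab, n_ctx, max_min_gap=1):
    return [seq for seq in itertools.product(range(d_vocab), repeat=n_ctx)
            if any(max(seq) - max_min_gap <= z < max(seq) for z in seq)]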

# In[ ]:
@torch.no_grad()
def logit_delta_small_gap_exhaustive(model: HookedTransformer, max_min_gap: int = 1, renderer=None, histogram_all_incorrect_logit_differences: bool = False, return_summary: bool = False, hist_args={}) -> Union[float, Dict[str, Any]]:
    """
    Smallest value of logit(true_max) - logit(y) over y != true_max, with the constraint that some token z in the sequence satisfies true_max - max_min_gap <= z < true_max
    Complexity: O(d_vocab ^ (n_ctx - 1) * (max_min_gap * 2 + 1) * fwd_pass)
    Complexity: fwd_pass = O(n_ctx * d_model + n_ctx * d_model + n_ctx * d_model^2 * d_hidden * 2 + n_ctx * d_hidden^2 + n_ctx * d_model^2 * d_hidden + n_ctx * d_hidden^2 * d_model + n_ctx * d_model + n_ctx * d_model^2 * d_vocab)
    Complexity: (n_ctx^2 * d_vocab * d_model^2) + (n_ctx * d_vocab * d_model^2)
    todo fix complexity.
    """
    n_ctx, d_vocab, d_vocab_out, d_model = model.cfg.n_ctx, model.cfg.d_vocab, model.cfg.d_vocab_out, model.cfg.d_model

    all_tokens = all_tokens_small_gap(model, max_min_gap=max_min_gap)
    assert len(all_tokens.shape) == 2 and all_tokens.shape[1] == n_ctx, f"all_tokens.shape = {all_tokens.shape} != (_, {n_ctx}) = (_, n_ctx)"
    predicted_logits = model(all_tokens)[:,-1,:].detach().cpu()
    assert len(predicted_logits.shape) == 2 and predicted_logits.shape[1] == d_vocab_out, f"predicted_logits.shape = {predicted_logits.shape} != (_, {d_vocab_out}) = (_, d_vocab_out)"

    return logit_delta_of_results(all_tokens=all_tokens, predicted_logits=predicted_logits, renderer=renderer, histogram_all_incorrect_logit_differences=histogram_all_incorrect_logit_differences, return_summary=return_summary, hist_args=hist_args)

# In[ ]:
@torch.no_grad()
def logit_delta_by_gap(model: HookedTransformer, renderer=None, histogram_all_incorrect_logit_differences: bool = False, return_summary: bool = False, hist_args={}) -> Dict[int, Union[float, Dict[str, Any]]]:
    """
    Smallest value of logit(true_max) - logit(y) over y != true_max, computed separately for each gap and indexed by gap,
    where the gap of a sequence is true_max minus its largest token strictly below true_max (0 if all tokens are equal)
    Complexity: O(d_vocab ^ n_ctx * fwd_pass)
    Complexity: fwd_pass = O(n_ctx * d_model + n_ctx * d_model + n_ctx * d_model^2 * d_hidden * 2 + n_ctx * d_hidden^2 + n_ctx * d_model^2 * d_hidden + n_ctx * d_hidden^2 * d_model + n_ctx * d_model + n_ctx * d_model^2 * d_vocab)
    Complexity: (n_ctx^2 * d_vocab * d_model^2) + (n_ctx * d_vocab * d_model^2)
    todo fix complexity.
    """
    n_ctx, d_vocab, d_vocab_out, d_model = model.cfg.n_ctx, model.cfg.d_vocab, model.cfg.d_vocab_out, model.cfg.d_model

    all_tokens = compute_all_tokens(model=model)
    assert all_tokens.shape == (d_vocab**n_ctx, n_ctx), f"all_tokens.shape = {all_tokens.shape} != {(d_vocab**n_ctx, n_ctx)} = (d_vocab**n_ctx, n_ctx)"
    predicted_logits = model(all_tokens)[:,-1,:].detach().cpu()
    assert predicted_logits.shape == (all_tokens.shape[0], d_vocab_out), f"predicted_logits.shape = {predicted_logits.shape} != {(all_tokens.shape[0], d_vocab_out)} = (all_tokens.shape[0], d_vocab_out)"
    gaps = compute_gap(all_tokens)
    assert gaps.shape == (all_tokens.shape[0],), f"gaps.shape = {gaps.shape} != {(all_tokens.shape[0],)} = (all_tokens.shape[0],)"
    return {gap: logit_delta_of_results(all_tokens=all_tokens[gaps == gap, :], predicted_logits=predicted_logits[gaps == gap, :], renderer=renderer, histogram_all_incorrect_logit_differences=histogram_all_incorrect_logit_differences, return_summary=return_summary, hist_args=hist_args)
            for gap in range(d_vocab)}

# In[ ]:
@torch.no_grad()
def EU_PU(model: HookedTransformer, renderer=None, pos: int = -1) -> TensorType["d_vocab_q", "d_vocab_out"]: # noqa: F821
    """
    Calculates logits from just the EU and PU paths in position pos.
    Complexity: O(d_vocab^2 * d_model)
    Return shape: (d_vocab, d_vocab_out) (indexed by query token)
    """
    W_E, W_pos, W_U = model.W_E, model.W_pos, model.W_U
    d_model, n_ctx, d_vocab, d_vocab_out = model.cfg.d_model, model.cfg.n_ctx, model.cfg.d_vocab, model.cfg.d_vocab_out
    assert W_E.shape == (d_vocab, d_model)
    assert W_pos.shape == (n_ctx, d_model)
    assert W_U.shape == (d_model, d_vocab_out)

    result = (W_E + W_pos[pos][None, :]) @ W_U
    assert result.shape == (d_vocab, d_vocab_out)

    return result

# In[ ]:
@torch.no_grad()
def all_attention_scores(model: HookedTransformer) -> TensorType["n_ctx_k", "d_vocab_q", "d_vocab_k"]: # noqa: F821
    """
    Returns pre-softmax attention of shape (n_ctx_k, d_vocab_q, d_vocab_k)
    Complexity: O(d_vocab * d_head^2 * d_model * n_ctx)
    """
    W_E, W_pos, W_Q, W_K = model.W_E, model.W_pos, model.W_Q, model.W_K
    d_model, n_ctx, d_vocab, d_head = model.cfg.d_model, model.cfg.n_ctx, model.cfg.d_vocab, model.cfg.d_head
    assert W_E.shape == (d_vocab, d_model)
    assert W_pos.shape == (n_ctx, d_model)
    assert W_Q.shape == (1, 1, d_model, d_head)
    assert W_K.shape == (1, 1, d_model, d_head)

    last_resid = (W_E + W_pos[-1]) # (d_vocab, d_model). Rows = possible residual streams.
    assert last_resid.shape == (d_vocab, d_model), f"last_resid.shape = {last_resid.shape} != {(d_vocab, d_model)} = (d_vocab, d_model)"
    key_tok_resid = (W_E + W_pos[:, None, :]) # (n_ctx, d_vocab, d_model). Dim 1 = possible residual streams.
    assert key_tok_resid.shape == (n_ctx, d_vocab, d_model), f"key_tok_resid.shape = {key_tok_resid.shape} != {(n_ctx, d_vocab, d_model)} = (n_ctx, d_vocab, d_model)"
    q = last_resid @ W_Q[0, 0, :, :] # (d_vocab, d_head).
    assert q.shape == (d_vocab, d_head), f"q.shape = {q.shape} != {(d_vocab, d_head)} = (d_vocab, d_head)"
    # contract over d_model: key_tok_resid is (n_ctx, d_vocab, d_model) and W_K[0, 0] is (d_model, d_head)
    k = einsum('n_ctx d_vocab d_model, d_model d_head -> n_ctx d_head d_vocab', key_tok_resid, W_K[0, 0, :, :])
    assert k.shape == (n_ctx, d_head, d_vocab), f"k.shape = {k.shape} != {(n_ctx, d_head, d_vocab)} = (n_ctx, d_head, d_vocab)"
    x_scores = einsum('d_vocab_q d_head, n_ctx d_head d_vocab_k -> n_ctx d_vocab_q d_vocab_k', q, k)
    assert x_scores.shape == (n_ctx, d_vocab, d_vocab), f"x_scores.shape = {x_scores.shape} != {(n_ctx, d_vocab, d_vocab)} = (n_ctx, d_vocab, d_vocab)"
    # x_scores[pos, qt, kt] is the score from query token qt to key token kt at position pos

    return x_scores

# In[ ]:
@torch.no_grad()
def all_EVOU(model: HookedTransformer) -> TensorType["d_vocab", "d_vocab_out"]: # noqa: F821
    """
    Returns all OV results, ignoring position, of shape (d_vocab, d_vocab_out)
    Complexity: O(d_vocab * (d_model^2 * d_head + d_head^2 * d_model + d_model^2 * d_vocab_out)) ~ O(d_vocab^2 * d_model^2)
    """
    W_E, W_O, W_V, W_U = model.W_E, model.W_O, model.W_V, model.W_U
    d_model, d_vocab, d_head, d_vocab_out = model.cfg.d_model, model.cfg.d_vocab, model.cfg.d_head, model.cfg.d_vocab_out
    assert W_E.shape == (d_vocab, d_model)
    assert W_O.shape == (1, 1, d_head, d_model)
    assert W_V.shape == (1, 1, d_model, d_head)
    assert W_U.shape == (d_model, d_vocab_out)

    EVOU = W_E @ W_V[0, 0, :, :] @ W_O[0, 0, :, :] @ W_U # (d_vocab, d_vocab_out). EVOU[i, j] is how copying i affects j.
    assert EVOU.shape == (d_vocab, d_vocab_out), f"EVOU.shape = {EVOU.shape} != {(d_vocab, d_vocab_out)} = (d_vocab, d_vocab_out)"
    return EVOU


# In[ ]:
@torch.no_grad()
def all_PVOU(model: HookedTransformer) -> TensorType["n_ctx", "d_vocab_out"]: # noqa: F821
    """
    Returns all OV results, position only, of shape (n_ctx, d_vocab_out)
    Complexity: O(n_ctx * (d_model^2 * d_head + d_head^2 * d_model + d_model^2 * d_vocab_out)) ~ O(n_ctx * d_vocab * d_model^2)
    """
    W_pos, W_O, W_V, W_U = model.W_pos, model.W_O, model.W_V, model.W_U
    d_model, n_ctx, d_head, d_vocab_out = model.cfg.d_model, model.cfg.n_ctx, model.cfg.d_head, model.cfg.d_vocab_out
    assert W_pos.shape == (n_ctx, d_model)
    assert W_O.shape == (1, 1, d_head, d_model)
    assert W_V.shape == (1, 1, d_model, d_head)
    assert W_U.shape == (d_model, d_vocab_out)

    PVOU = W_pos @ W_V[0, 0, :, :] @ W_O[0, 0, :, :] @ W_U # (n_ctx, d_vocab_out). PVOU[i, j] is how copying at position i affects logit j.
    assert PVOU.shape == (n_ctx, d_vocab_out), f"PVOU.shape = {PVOU.shape} != {(n_ctx, d_vocab_out)} = (n_ctx, d_vocab_out)"
    return PVOU


# In[ ]:
@torch.no_grad()
def find_all_d_attention_scores(model: HookedTransformer, min_gap: int = 1) -> Union[TensorType["d_vocab_q", "d_vocab_k"], TensorType["d_vocab_q", "n_ctx_max", "n_ctx_non_max", "d_vocab_k_max", "d_vocab_k_nonmax"]]: # noqa: F821
    """
    If input tokens are x, y, with x - y >= min_gap, the minimum values of
    score(x) - score(y).

    Complexity: O(d_vocab * d_model^2 * n_ctx + d_vocab^min(3,n_ctx) * n_ctx^min(2,n_ctx-1))
    Returns: d_attention_score indexed by
        if n_ctx <= 2:
            (d_vocab_q, d_vocab_k)
        if n_ctx > 2:
            (d_vocab_q, n_ctx_max, n_ctx_non_max, d_vocab_k_max, d_vocab_k_nonmax)
    """
    n_ctx, d_vocab = model.cfg.n_ctx, model.cfg.d_vocab
    x_scores = all_attention_scores(model)
    assert x_scores.shape == (n_ctx, d_vocab, d_vocab), f"x_scores.shape = {x_scores.shape} != {(n_ctx, d_vocab, d_vocab)} = (n_ctx, d_vocab, d_vocab)"
    # x_scores[pos, qt, kt] is the score from query token qt to key token kt at position pos

    if n_ctx <= 2:
        # when there are only two positions, either the max is in the query slot or the non-max is in the query slot
        scores = torch.zeros((d_vocab, d_vocab)) + float('inf')
        for q_tok in range(d_vocab):
            for k_tok in range(d_vocab):
                if abs(k_tok - q_tok) >= min_gap:
                    # q_tok is always in the last position
                    scores[q_tok, k_tok] = (x_scores[0, q_tok, k_tok].item() - x_scores[-1, q_tok, q_tok].item()) * np.sign(k_tok-q_tok)
    else:
        # when there are more than two cases, we need to consider all cases
        scores = torch.zeros((d_vocab, n_ctx, n_ctx, d_vocab, d_vocab)) + float('inf')
        for q_tok in range(d_vocab):
            for pos_of_max in range(n_ctx):
                for k_tok_max in range(d_vocab):
                    if pos_of_max == n_ctx - 1 and k_tok_max != q_tok: continue
                    for pos_of_non_max in range(n_ctx):
                        if pos_of_max == pos_of_non_max: continue
                        for k_tok_non_max in range(k_tok_max - (min_gap - 1)):
                            if pos_of_non_max == n_ctx - 1 and k_tok_non_max != q_tok: continue
                            scores[q_tok, pos_of_max, pos_of_non_max, k_tok_max, k_tok_non_max] = x_scores[pos_of_max, q_tok, k_tok_max].item() - x_scores[pos_of_non_max, q_tok, k_tok_non_max].item()

    return scores


# In[ ]:
@torch.no_grad()
def find_min_d_attention_score(model: HookedTransformer, min_gap: int = 1, reduce_over_query=False) -> Union[float, TensorType["d_vocab_q"]]: # noqa: F821
    """
    If input tokens are x, y, with x - y >= min_gap, the minimum value of
    score(x) - score(y).

    Complexity: O(d_vocab * d_model^2 * n_ctx + d_vocab^min(3,n_ctx) * n_ctx^min(2,n_ctx-1))
    Returns: float if reduce_over_query else torch.Tensor[d_vocab] (indexed by query token)
    """
    scores = find_all_d_attention_scores(model, min_gap=min_gap)
    while len(scores.shape) != 1:
        scores = scores.min(dim=-1).values
    if reduce_over_query:
        scores = scores.min(dim=0).values.item()
    return scores

# In[ ]:
@torch.no_grad()
def EU_PU_PVOU(model: HookedTransformer, attention_pattern: TensorType["batch", "n_ctx"]) -> TensorType["batch", "d_vocab_q", "d_vocab_out"]: # noqa: F821
    """
    Calculates logits from EU, PU, and the positional part of the OV path for a given batch of attentions
    attention_pattern: (batch, n_ctx) # post softmax
    Returns: (batch, d_vocab_q, d_vocab_out)
    Complexity: O(d_vocab^2 * d_model + d_vocab^2 * d_model^2 + batch * n_ctx * d_vocab_out + batch * d_vocab^2)
    """
    n_ctx, d_vocab, d_vocab_out = model.cfg.n_ctx, model.cfg.d_vocab, model.cfg.d_vocab_out
    batch, _ = attention_pattern.shape
    assert attention_pattern.shape == (batch, n_ctx), f"attention_pattern.shape = {attention_pattern.shape} != {(batch, n_ctx)} = (batch, n_ctx)"
    EUPU = EU_PU(model)
    assert EUPU.shape == (d_vocab, d_vocab_out), f"EUPU.shape = {EUPU.shape} != {(d_vocab, d_vocab_out)} = (d_vocab, d_vocab_out)"
    PVOU = all_PVOU(model)
    assert PVOU.shape == (n_ctx, d_vocab_out), f"PVOU.shape = {PVOU.shape} != {(n_ctx, d_vocab_out)} = (n_ctx, d_vocab_out)"
    PVOU_scaled = attention_pattern @ PVOU
    assert PVOU_scaled.shape == (batch, d_vocab_out), f"PVOU_scaled.shape = {PVOU_scaled.shape} != {(batch, d_vocab_out)} = (batch, d_vocab_out)"
    result = EUPU[None, :, :] + PVOU_scaled[:, None, :]
    assert result.shape == (batch, d_vocab, d_vocab_out), f"result.shape = {result.shape} != {(batch, d_vocab, d_vocab_out)} = (batch, d_vocab, d_vocab_out)"

    return result

# In[ ]:
# @verify(UNIQUE, CONTINUOUS)
# class TokenType(Enum):
#     EXACT = enum.auto() # max, or within gap
#     BELOW_GAP = enum.auto()

# In[ ]:
@torch.no_grad()
def worst_PVOU_gap_for(model: HookedTransformer, query_tok: int, max_tok: int,
                       min_gap: int = 0,
                       PVOU: Optional[TensorType["n_ctx", "d_vocab_out"]] = None, # noqa: F821
                       attention_score_map: Optional[TensorType["n_ctx_k", "d_vocab_q", "d_vocab_k"]] = None, # noqa: F821
                       optimize_max_query_comparison=True) -> TensorType["d_vocab_out"]: # noqa: F821
    """
    Returns a map of non_max_output_tok to PVOU with the worst (largest) value of PVOU[non_max_output_tok] - PVOU[max_tok],
        across all possible attention scalings for the query token and for token values <= max_tok - min_gap.
    Complexity: O(PVOU + attention_score_map + d_vocab_out * n_ctx^2)
    Complexity: ~ O(n_ctx * d_vocab * d_model^2 (from PVOU) + d_vocab * d_head^2 * d_model * n_ctx (from attention_score_map) + (n_ctx * log(n_ctx) (sorting) + n_ctx^2) * d_vocab)
    Complexity: (for n_ctx=2) O(PVOU + attention_score_map + n_ctx)
    N.B. Clever caching could reduce n_ctx^2 to n_ctx, leaving n_ctx log(n_ctx) from sorting as the dominant factor
    N.B. If optimize_max_query_comparison is set, and n_ctx is 2, then whenever query_tok != max_tok we know exactly what the sequence is and can just compute the attention
    """
    assert max_tok >= query_tok, f"max_tok = {max_tok} < {query_tok} = query_tok"
    assert max_tok == query_tok or max_tok >= query_tok + min_gap, f"max_tok = {max_tok} < {query_tok} + {min_gap} = query_tok + min_gap"
    n_ctx, d_vocab_out, d_vocab = model.cfg.n_ctx, model.cfg.d_vocab_out, model.cfg.d_vocab
    if PVOU is None: PVOU = all_PVOU(model)
    assert PVOU.shape == (n_ctx, d_vocab_out), f"PVOU.shape = {PVOU.shape} != {(n_ctx, d_vocab_out)} = (n_ctx, d_vocab_out)"
    if attention_score_map is None: attention_score_map = all_attention_scores(model)
    assert attention_score_map.shape == (n_ctx, d_vocab, d_vocab), f"attention_scores.shape = {attention_score_map.shape} != {(n_ctx, d_vocab, d_vocab)} = (n_ctx, d_vocab, d_vocab)"
    worst_attention_score = torch.zeros((n_ctx,))
    worst_attention_score[-1] = attention_score_map[-1, query_tok, query_tok]
    if n_ctx == 2 and optimize_max_query_comparison and query_tok != max_tok:
        worst_attention_score[0] = attention_score_map[0, query_tok, max_tok]
        worst_PVOU = worst_attention_score.softmax(dim=-1) @ PVOU
        return worst_PVOU - worst_PVOU[max_tok]
    elif max_tok - min_gap < 0:
        # everything must be the max
        worst_PVOU = attention_score_map[:, query_tok, max_tok].softmax(dim=-1) @ PVOU
        return worst_PVOU - worst_PVOU[max_tok]
    else:
        # compute the min and max attention scores for each position and query token where the key token is either max_tok or <= max_tok - gap
        min_attention_scores_below_gap, max_attention_scores_below_gap = attention_score_map[:-1, query_tok, :max_tok+1-min_gap].min(dim=-1).values, attention_score_map[:-1, query_tok, :max_tok+1-min_gap].max(dim=-1).values
        assert min_attention_scores_below_gap.shape == (n_ctx-1,), f"min_attention_scores.shape = {min_attention_scores_below_gap.shape} != {(n_ctx-1,)} = (n_ctx-1,)"
        assert max_attention_scores_below_gap.shape == (n_ctx-1,), f"max_attention_scores.shape = {max_attention_scores_below_gap.shape} != {(n_ctx-1,)} = (n_ctx-1,)"
        min_attention_scores = torch.minimum(attention_score_map[:-1, query_tok, max_tok], min_attention_scores_below_gap)
        max_attention_scores = torch.maximum(attention_score_map[:-1, query_tok, max_tok], max_attention_scores_below_gap)
        assert min_attention_scores.shape == (n_ctx-1,), f"min_attention_scores.shape = {min_attention_scores.shape} != {(n_ctx-1,)} = (n_ctx-1,)"
        assert max_attention_scores.shape == (n_ctx-1,), f"max_attention_scores.shape = {max_attention_scores.shape} != {(n_ctx-1,)} = (n_ctx-1,)"
        worst_attention_score[:-1] = min_attention_scores
        PVOU = PVOU.T
        assert PVOU.shape == (d_vocab_out, n_ctx), f"PVOU.T.shape = {PVOU.shape} != {(d_vocab_out, n_ctx)} = (d_vocab_out, n_ctx)"
        worst_PVOU = torch.zeros((d_vocab_out, ))
        d_PVOU = PVOU[:, :] - PVOU[max_tok, :][None, :]
        assert d_PVOU.shape == (d_vocab_out, n_ctx), f"d_PVOU.shape = {d_PVOU.shape} != {(d_vocab_out, n_ctx)} = (d_vocab_out, n_ctx)"
        # sort d_PVOU in descending order
        _, d_PVOU_idxs = d_PVOU[:, :-1].sort(dim=-1, descending=True)
        for non_max_output_tok in range(d_vocab_out):
            worst_attention_score[:-1] = min_attention_scores
            for i in d_PVOU_idxs[non_max_output_tok, :]:
                # compare d_PVOU weighted by softmax of worst_attention_score for worst_attention_score[i] in (min_attention_scores[i], max_attention_scores[i])
                # set worst_attention_score[i] to whichever one is worse (more positive)
                # print(d_PVOU.shape, worst_attention_score.softmax(dim=-1).shape)
                min_d_PVOU = worst_attention_score.softmax(dim=-1) @ d_PVOU[non_max_output_tok, :]
                worst_attention_score[i] = max_attention_scores[i]
                max_d_PVOU = worst_attention_score.softmax(dim=-1) @ d_PVOU[non_max_output_tok, :]
                if min_d_PVOU > max_d_PVOU: worst_attention_score[i] = min_attention_scores[i]
            worst_PVOU[non_max_output_tok] = worst_attention_score.softmax(dim=-1) @ d_PVOU[non_max_output_tok, :]
            # print(i, min_attention_scores[i], worst_attention_score[i], max_attention_scores[i], min_d_PVOU, max_d_PVOU, d_PVOU[i])
        # return the PVOU for the worst_attention_score
        return worst_PVOU

# In[ ]:
@torch.no_grad()
def all_worst_PVOU(model: HookedTransformer, min_gap: int = 0, tqdm=None, **kwargs) -> TensorType["d_vocab_q", "d_vocab_max", "d_vocab_out"]: # noqa: F821
    """
    Returns the mixture of PVOUs with the worst (largest) value of PVOU[non_max_output_tok] - PVOU[max_tok], across all possible attention scalings for the query token and for token values <= max_tok - min_gap.
    Complexity: O(PVOU + attention_score_map + n_ctx^2 * d_vocab^3)
    Complexity: ~ O(n_ctx * d_vocab * d_model^2 (from PVOU) + d_vocab * d_head^2 * d_model * n_ctx (from attention_score_map) + (n_ctx * log(n_ctx) (sorting) + n_ctx^2) * d_vocab^3)
    Complexity: (for n_ctx=2) O(PVOU + attention_score_map + n_ctx * d_vocab^2)
    N.B. Clever caching could reduce n_ctx^2 to n_ctx, leaving n_ctx log(n_ctx) * d_vocab^3 from sorting as the dominant factor.
    N.B. for max_of_{two,three}, this is maybe? worse than exhaustive enumeration (oops)
    """
    local_tqdm = make_local_tqdm(tqdm)
    n_ctx, d_vocab_out, d_vocab = model.cfg.n_ctx, model.cfg.d_vocab_out, model.cfg.d_vocab
    PVOU = all_PVOU(model)
    assert PVOU.shape == (n_ctx, d_vocab_out), f"PVOU.shape = {PVOU.shape} != {(n_ctx, d_vocab_out)} = (n_ctx, d_vocab_out)"
    attention_score_map = all_attention_scores(model)
    assert attention_score_map.shape == (n_ctx, d_vocab, d_vocab), f"attention_scores.shape = {attention_score_map.shape} != {(n_ctx, d_vocab, d_vocab)} = (n_ctx, d_vocab, d_vocab)"
    result = torch.zeros((d_vocab, d_vocab, d_vocab_out)) + float('nan')
    for query_tok in local_tqdm(range(d_vocab), total=d_vocab):
        for max_tok in [query_tok] + list(range(query_tok+np.max([1, min_gap]), d_vocab)):
            result[query_tok, max_tok, :] = worst_PVOU_gap_for(model, query_tok, max_tok, min_gap=min_gap, PVOU=PVOU, attention_score_map=attention_score_map, **kwargs)

    return result

# In[ ]:
@torch.no_grad()
def worst_EVOU_gap_for(model: HookedTransformer, query_tok: int, max_tok: int,
                       min_gap: int = 0,
                       EVOU: Optional[TensorType["d_vocab", "d_vocab_out"]] = None, # noqa: F821
                       attention_score_map: Optional[TensorType["n_ctx_k", "d_vocab_q", "d_vocab_k"]] = None, # noqa: F821
                       optimize_max_query_comparison=True) -> TensorType["d_vocab_out"]: # noqa: F821
    """
    Returns the map of non_max_output_tok to worst (largest) value of EVOU[non_max_output_tok] - EVOU[max_tok], across all possible attention scalings for the query token
        and for token values <= max_tok - min_gap.
    To deal with the fact that attention and EVOU are not truly independent, we relax the "worst" calculation by saying that the attention paid to a given token in a given position
        is the min of (most attention paid to this token in this position) and (most attention paid to any token < max in this position).
    "<" is relaxed to "<=" when the token under consideration is the max token.

    Complexity: O(EVOU + attention_score_map + n_ctx * d_vocab + d_vocab^2)
    Complexity: (for n_ctx=2) O(EVOU + attention_score_map + d_vocab + n_ctx)
    #N.B. If optimize_max_query_comparison is set, and n_ctx is 2, then whenever query_tok != max_tok we know exactly what the sequence is and can just compute the attention
    """
    assert max_tok >= query_tok, f"max_tok = {max_tok} < {query_tok} = query_tok"
    assert max_tok == query_tok or max_tok >= query_tok + min_gap, f"max_tok = {max_tok} < {query_tok} + {min_gap} = query_tok + min_gap"
    n_ctx, d_vocab_out, d_vocab = model.cfg.n_ctx, model.cfg.d_vocab_out, model.cfg.d_vocab
    if EVOU is None: EVOU = all_EVOU(model)
    assert EVOU.shape == (d_vocab, d_vocab_out), f"EVOU.shape = {EVOU.shape} != {(d_vocab, d_vocab_out)} = (d_vocab, d_vocab_out)"
    if attention_score_map is None: attention_score_map = all_attention_scores(model)
    assert attention_score_map.shape == (n_ctx, d_vocab, d_vocab), f"attention_scores.shape = {attention_score_map.shape} != {(n_ctx, d_vocab, d_vocab)} = (n_ctx, d_vocab, d_vocab)"
    if n_ctx == 2 and optimize_max_query_comparison and query_tok != max_tok:
        worst_attention_score = torch.zeros((n_ctx,))
        worst_attention_score[-1] = attention_score_map[-1, query_tok, query_tok]
        worst_attention_score[0] = attention_score_map[0, query_tok, max_tok]
        worst_EVOU = worst_attention_score.softmax(dim=-1) @ EVOU[torch.tensor([max_tok, query_tok]), :]
        return worst_EVOU - worst_EVOU[max_tok]
    elif max_tok - min_gap < 0:
        # everything must be the max
        assert max_tok == query_tok, f"max_tok = {max_tok} != {query_tok} = query_tok"
        worst_EVOU = EVOU[max_tok, :]
        return worst_EVOU - worst_EVOU[max_tok]
    else:
        # for each non-query position, compute the min and max attention scores for that position and query token where the key token is < max_tok, and also when the key token is <= max_tok
        max_nonmax_tok = np.min([max_tok - 1, max_tok - min_gap])
        min_attention_scores_without_max, max_attention_scores_without_max = attention_score_map[:-1, query_tok, :max_nonmax_tok+1].min(dim=-1).values, attention_score_map[:-1, query_tok, :max_nonmax_tok+1].max(dim=-1).values
        assert min_attention_scores_without_max.shape == (n_ctx-1,), f"min_attention_scores_without_max.shape = {min_attention_scores_without_max.shape} != {(n_ctx-1,)} = (n_ctx-1,)"
        assert max_attention_scores_without_max.shape == (n_ctx-1,), f"max_attention_scores_without_max.shape = {max_attention_scores_without_max.shape} != {(n_ctx-1,)} = (n_ctx-1,)"
        # for each key token below the max, compute the min and max attention scores for that token and query token where the key token is <= max_tok
        # if query token is max, we assume all other tokens are the same; otherwise, we pick the minimal attention slot for the max token and the other slots for the non-max, except when we consider all maxes but the query
        # we must subtract off the maximum to avoid overflow, as per https://github.com/pytorch/pytorch/blob/bc047ec906d8e1730e2ccd8192cef3c3467d75d1/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L115-L136
        attention_to_query = attention_score_map[-1, query_tok, query_tok]
        attentions_to_max = attention_score_map[:-1, query_tok, max_tok]
        attention_offset = torch.maximum(attentions_to_max.max(), attention_to_query)
        attention_to_max_exp = (attentions_to_max - attention_offset).exp().sum()
        attention_to_query_exp = (attention_to_query - attention_offset).exp()
        attention_sum = attention_to_max_exp + attention_to_query_exp
        EVOUs = torch.zeros((max_tok+1, d_vocab_out))
        EVOUs[max_tok, :] = EVOU[max_tok, :] * attention_to_max_exp / attention_sum + EVOU[query_tok, :] * attention_to_query_exp / attention_sum
        assert EVOUs[max_tok, :].shape == (d_vocab_out,), f"EVOU_all_maxes.shape = {EVOUs[max_tok, :].shape} != {(d_vocab_out,)} = (d_vocab_out,)"

        # consider all tokens < max, compute EVOU for each
        attention_to_max = attention_score_map[:-1, query_tok, max_tok].min()
        for non_max_tok in range(max_nonmax_tok+1):
            # we need to relax attention to non-max, picking the attention to this slot from the min of largest attention to this token and largest attention to this slot
            max_attention_to_non_max = attention_score_map[:-1, query_tok, non_max_tok].max()
            attention_to_non_max = torch.minimum(max_attention_to_non_max, max_attention_scores_without_max)
            if query_tok == max_tok:
                # we must subtract off the maximum to avoid overflow, as per https://github.com/pytorch/pytorch/blob/bc047ec906d8e1730e2ccd8192cef3c3467d75d1/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L115-L136
                attention_offset = torch.maximum(attention_to_query, attention_to_non_max.max())
                attention_to_max_exp = (attention_to_max - attention_offset).exp()
                attention_to_query_exp = (attention_to_query - attention_offset).exp()
                attention_to_non_max_exp = (attention_to_non_max - attention_offset).exp().sum()
                attention_sum = attention_to_non_max_exp + attention_to_query_exp
                EVOUs[non_max_tok, :] = EVOU[non_max_tok, :] * attention_to_non_max_exp / attention_sum + EVOU[query_tok, :] * attention_to_query_exp / attention_sum
            else:
                # we must subtract off the maximum to avoid overflow, as per https://github.com/pytorch/pytorch/blob/bc047ec906d8e1730e2ccd8192cef3c3467d75d1/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L115-L136
                attention_offset = torch.maximum(torch.maximum(attention_to_max, attention_to_query), attention_to_non_max.max())
                attention_to_non_max_exp = (attention_to_non_max - attention_offset).exp()
                # drop the smallest value in attention_to_non_max
                attention_to_non_max_exp = attention_to_non_max_exp.sum() - attention_to_non_max_exp.min()
                attention_to_max_exp = (attention_to_max - attention_offset).exp()
                attention_to_query_exp = (attention_to_query - attention_offset).exp()
                attention_sum = attention_to_non_max_exp + attention_to_query_exp + attention_to_max_exp
                EVOUs[non_max_tok, :] = EVOU[non_max_tok, :] * attention_to_non_max_exp / attention_sum + EVOU[query_tok, :] * attention_to_query_exp / attention_sum + EVOU[max_tok, :] * attention_to_max_exp / attention_sum
        # subtract off the max_tok EVOU
        # print(EVOUs)
        EVOUs = EVOUs - EVOUs[:, max_tok][:, None]
        # return the worst EVOU
        return EVOUs.max(dim=0).values
# print(worst_EVOU_gap_for(model, 63, 63, 2))
# In[ ]:
@torch.no_grad()
def all_worst_EVOU(model: HookedTransformer, min_gap: int = 0, tqdm=None, **kwargs) -> TensorType["d_vocab_q", "d_vocab_max", "d_vocab_out"]: # noqa: F821
    """
    Returns the mixture of EVOUs with the worst (largest) value of EVOU[non_max_output_tok] - EVOU[max_tok], across all possible attention scalings for the query token and for token values <= max_tok - min_gap.
    Complexity: O(EVOU + attention_score_map + (n_ctx + d_vocab) * d_vocab^3)
    Complexity: (for n_ctx=2) O(EVOU + attention_score_map + (n_ctx + d_vocab) * d_vocab^2)
    N.B. for max_of_{two,three}, this is maybe? worse than exhaustive enumeration (oops)
    """
    local_tqdm = make_local_tqdm(tqdm)
    n_ctx, d_vocab_out, d_vocab = model.cfg.n_ctx, model.cfg.d_vocab_out, model.cfg.d_vocab
    EVOU = all_EVOU(model)
    assert EVOU.shape == (d_vocab, d_vocab_out), f"EVOU.shape = {EVOU.shape} != {(d_vocab, d_vocab_out)} = (d_vocab, d_vocab_out)"
    attention_score_map = all_attention_scores(model)
    assert attention_score_map.shape == (n_ctx, d_vocab, d_vocab), f"attention_scores.shape = {attention_score_map.shape} != {(n_ctx, d_vocab, d_vocab)} = (n_ctx, d_vocab, d_vocab)"
    result = torch.zeros((d_vocab, d_vocab, d_vocab_out)) + float('nan')
    for query_tok in local_tqdm(range(d_vocab), total=d_vocab):
        for max_tok in [query_tok] + list(range(query_tok+np.max([1, min_gap]), d_vocab)):
            result[query_tok, max_tok, :] = worst_EVOU_gap_for(model, query_tok, max_tok, min_gap=min_gap, EVOU=EVOU, attention_score_map=attention_score_map, **kwargs)

    return result


In [None]:
#@title imports
from tqdm.auto import tqdm
import torch
from transformer_lens import HookedTransformer
from jaxtyping import Float
from torch import Tensor

In [None]:
#model = max_of_2.get_model(train_if_necessary=False)
model = max_of_2.get_model(train_if_necessary=False)
all_tokens = compute_all_tokens(model)
all_logits = model(all_tokens)
expected_max = all_tokens.max(dim=-1).values
predicted_max = all_logits[..., -1, :].argmax(dim=-1)
print(f"Model Accuracy: {acc_fn(all_logits, all_tokens, return_per_token=False) * 100}%")
print(f"Number Incorrect Sequences: {(predicted_max != expected_max).sum()}")
print(f"Model Loss: {loss_fn(all_logits, all_tokens, return_per_token=False)}")
print(f"{all_logits.dtype} ULP on log-softmax = ULP at 1.0 = -(exp(0) - eps).log() = {torch.finfo(all_logits.dtype).eps}")


Note that the loss is *lower* than what we would expect from a one ULP (unit of least precision) error in the log-softmax calculation at the end.  This is about as good as the model can possibly do with 32-bit floats.
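
As a rough sanity check on that comparison, the sketch below (plain Python, assuming IEEE-754 32-bit floats with machine epsilon $2^{-23}$) computes the cross-entropy loss of a hypothetical prediction that puts probability $1 - \epsilon$ on the correct token.  The names `eps32` and `loss_at_one_ulp` are illustrative and not part of the analysis code.

```python
import math

# Machine epsilon for IEEE-754 float32: the gap between 1.0 and the next float.
eps32 = 2.0 ** -23

# Cross-entropy loss of a hypothetical prediction with probability (1 - eps32)
# on the correct token; log1p keeps the double-precision computation accurate.
loss_at_one_ulp = -math.log1p(-eps32)

print(f"eps32           = {eps32:.6e}")
print(f"loss_at_one_ulp = {loss_at_one_ulp:.6e}")
```

The two printed values agree to first order in `eps32`, which is the scale that the model loss above is being compared against.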

## Finding the Size Direction

We can run SVD on EQKE (actually $(W_E + W_{\text{pos}}[-1]) W_Q W_K^T \left(W_E + W_{\text{pos}}\text{.mean}(\text{dim}=0)\right)^T$) to find the size direction and the query direction in token space.

In [None]:
size_direction, query_direction, size_query_singular_value = find_size_and_query_direction(model, plot_heatmaps=True, colorscale='Picnic_r')
print(f"Size direction: {size_direction}\nQuery direction: {query_direction}\nSingular value: {size_query_singular_value}")

### A couple of notes:
- SVD is only unique up to the sign of each singular vector.  PyTorch SVD gives us a negative query direction vector, so we negate both the query and size direction vectors.
- If we fit the size direction to a cubic (or quintic), the bounds on the errors might not actually give us enough information to ensure adjacent tokens are ordered correctly.  But if we fit the differences in size-direction overlap of adjacent tokens to a quadratic (or quartic), we see that all differences are positive, and so we can get monotonicity even with worst-case errors.
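
The sign convention and the monotonicity check can be illustrated on synthetic data.  The sketch below uses NumPy and a made-up rank-one-plus-noise matrix `EQKE_toy` as a stand-in for EQKE; it is not the implementation of `find_size_and_query_direction`, and the sign rule used here (make the query direction sum positive) is an assumption chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
d_vocab = 8  # hypothetical toy vocabulary size

# Hypothetical stand-in for EQKE: a strong rank-one signal plus small noise.
query = np.ones(d_vocab) + 0.1 * rng.standard_normal(d_vocab)
size = np.linspace(-1.0, 1.0, d_vocab)  # increases with token value
EQKE_toy = 50.0 * np.outer(query, size) + 0.01 * rng.standard_normal((d_vocab, d_vocab))

U, S, Vh = np.linalg.svd(EQKE_toy)
query_direction, size_direction = U[:, 0], Vh[0, :]

# Each (u, v) singular-vector pair is only determined up to a joint sign flip.
# Assumed convention: flip both so that the query direction sums to a positive value.
if query_direction.sum() < 0:
    query_direction, size_direction = -query_direction, -size_direction

# Monotonicity check: differences in size-direction overlap of adjacent tokens.
adjacent_deltas = np.diff(size_direction)
print("first two singular values:", S[:2])
print("all adjacent deltas positive:", bool((adjacent_deltas > 0).all()))
```

Checking the adjacent differences directly, rather than a fit to the size direction itself, mirrors the second note above.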

### Interpretation and relevance
- The first singular value is just over 8,000; the next singular value is just under 30, so to a first approximation there's only one thing going on.
- However, the remainder of the QK circuit (labeled "Residual" on the "Contribution of the first singular value to Attention" plot) is not actually small enough to neglect in all cases.
  - Looking at the "Size Direction Δ & Fit" plots, consider the "resid.max - resid.min (worst-case independent query)" line.  This line results from taking the remainder of the QK circuit, finding the maximum possible difference in attention, and scaling it according to the worst possible query token (the one which overlaps the least with the size direction).  The position of this line shows that the size direction is enough to explain the majority of cases (many sequences with ($i$, $i+1$) will have a comparatively large attention gap just based on the size direction, and most sequences with larger gaps will certainly pay more attention to the larger token), but not all of them.
  - Looking at the "(resid[i] - resid[i+1]).max (worst-case query)" line suggests that even accounting for exact attention values, the model might pay more attention to the smaller token!  Let's compute if this ever happens.

In [None]:
print('The model pays more attention to the smaller token for the following sequences:')
for minpos, qtok, ktokmin in find_backwards_attention(model):
    if qtok not in range(ktokmin, ktokmin+2): descr = "(invalid! query token not in sequence)"
    elif (minpos == 1 and qtok != ktokmin): descr = "(invalid! minimum in the query position but not equal to query token)"
    elif (minpos == 0 and qtok != ktokmin+1): descr = "(invalid! contradictory minimum position and query token constraints)"
    else: descr = "(valid!)"
    print(f"Tokens: {ktokmin}, {ktokmin+1};\tQuery: {qtok};\tPosition of the Minimum: {minpos}\t{descr}")

Notably, the model only pays more attention to smaller tokens when the query token is not present in the sequence.

### Compact Guarantees
- If we are trying to generate the most compact guarantee, neither the SVD nor the fit buys us much.  There are two locations in the proof where we might hope to gain in compactness by using the size direction:
  1. In explaining the behavior of the QK circuit.  But in generating a guarantee, we still have to establish that the particular QK circuit is doing the right thing, and I'm not sure how to compactly argue that the principal component of a product of matrices is what it is without multiplying out the matrices.  But once we multiply out the matrices, we have all of the pairwise attention weights, and so we no longer need the size direction to explain the behavior of the QK circuit.
  2. In using the behavior of the QK circuit to explain the rest of the transformer.  Here we do get some benefit from having a compact description of *what* the QK circuit is doing.  In particular, we get a lot of benefit from some simple cut-off behavior (computing, for example, the minimal attention gap between tokens separated by at least two), but further dependencies between the QK circuit and the rest of the transformer seem to be more about the query direction than the size direction.

**Hypothesis**: In almost all cases (for almost all possible sequences), either:
1. The best reasoning we can do with the size direction isn't enough to get us 100% accuracy, and the loss will be rather sensitive to the exact attention values; or
2. A simple lower bound on the attention gap between non-adjacent tokens is enough to get us 100% accuracy, and the loss will be so insensitive to the exact attention values that we won't get much benefit from any approximation more detailed than a lower bound.
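
To see why a bare lower bound on the attention gap can already be informative, here is a minimal sketch in plain Python.  It assumes only that one position has a pre-softmax score at least `min_gap` above every other position; the function names and the example scores are hypothetical.

```python
import math

def attention_lower_bound(n_ctx: int, min_gap: float) -> float:
    """Lower bound on the softmax weight of a position whose pre-softmax score
    exceeds the score of each of the other n_ctx - 1 positions by >= min_gap."""
    return 1.0 / (1.0 + (n_ctx - 1) * math.exp(-min_gap))

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

scores = [3.0, 1.0, 0.5]  # hypothetical pre-softmax scores; position 0 is the max
gap = scores[0] - max(scores[1:])
actual = softmax(scores)[0]
bound = attention_lower_bound(len(scores), gap)
print(f"actual attention = {actual:.4f}, lower bound = {bound:.4f}")
```

The bound depends only on the gap and the context length, which is what makes it cheap to state and check.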

To test this hypothesis, we can compute, for each maximum token, how much attention needs to be on that token in order for the model to predict the correct output.

What does "needs to be" mean, though?
It could mean:
1. Given exactly how the rest of the transformer behaves, what's the cutoff for attention?
2. For some particular functional model of the rest of the transformer's behavior and bounds on the errors, what's the cutoff for attention?

or anything in between.
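
For the first interpretation with two positions, the cutoff can be worked out per incorrect output logit.  The sketch below is a simplified scalar version under stated assumptions: all logits are centered so that the correct output has logit 0, `ov_max` and `ov_non_max` are the incorrect-output logits contributed by attending to the max and non-max token, and `resid` is the contribution of the residual path.  These names are hypothetical and chosen for illustration.

```python
from typing import Optional

def min_attention_on_max(ov_max: float, ov_non_max: float, resid: float) -> Optional[float]:
    """Cutoff p such that every attention weight above p on the max token keeps
    the incorrect logit  p * ov_max + (1 - p) * ov_non_max + resid  below 0.

    Returns None when even full attention on the max token is not enough,
    and 0.0 when any amount of attention is enough.
    """
    logit_all_on_max = ov_max + resid       # value at p = 1
    logit_none_on_max = ov_non_max + resid  # value at p = 0
    if logit_all_on_max >= 0:
        return None
    if logit_none_on_max < 0:
        return 0.0
    # The logit is linear in p, so the cutoff is where it crosses zero.
    return logit_none_on_max / (logit_none_on_max - logit_all_on_max)

print(min_attention_on_max(-4.0, 2.0, 0.0))   # crossing point strictly between 0 and 1
print(min_attention_on_max(-4.0, -1.0, 0.0))  # any attention is enough
print(min_attention_on_max(-1.0, 2.0, 1.5))   # no attention is enough
```

The second interpretation would replace the exact scalars with bounds from a functional model of the rest of the transformer.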

Let's compute the histogram of attention cutoffs for a variety of interpretations.

In [None]:
n_ctx, d_vocab, d_vocab_out = model.cfg.n_ctx, model.cfg.d_vocab, model.cfg.d_vocab_out
EUPU: Float[Tensor, "d_vocab d_vocab_out"] = EU_PU(model, pos=-1)
#assert EUPU.shape == (d_vocab, d_vocab_out), f"EUPU.shape = {EUPU.shape} != {(d_vocab, d_vocab_out)} = (d_vocab, d_vocab_out)"
EVOU: Float[Tensor, "d_vocab d_vocab_out"] = all_EVOU(model)
#assert EVOU.shape == (d_vocab, d_vocab_out), f"EVOU.shape = {EVOU.shape} != {(d_vocab, d_vocab_out)} = (d_vocab, d_vocab_out)"
PVOU: Float[Tensor, "n_ctx d_vocab_out"] = all_PVOU(model)
#assert PVOU.shape == (n_ctx, d_vocab_out), f"PVOU.shape = {PVOU.shape} != {(n_ctx, d_vocab_out)} = (n_ctx, d_vocab_out)"
# assume we vary only the attention, find the minimum attention required for the correct output
all_tokens: Float[Tensor, "batch n_ctx"] = compute_all_tokens(model)
all_tokens_max: Float[Tensor, "batch"] = all_tokens.max(dim=-1).values
all_tokens_max_pos: Float[Tensor, "batch"] = all_tokens.argmax(dim=-1)
all_tokens_EUPU: Float[Tensor, "batch d_vocab_out"] = EUPU[all_tokens[:, -1], :]
all_tokens_EVOU_PVOU: Float[Tensor, "batch n_ctx d_vocab_out"] = EVOU[all_tokens, :] + PVOU
# to compute the minimum post-softmax attention required for the correct output, we center each output so the logit of the correct output is 0
# then we want to compute the minimum p such that p * (logit of incorrect output on max token) + (1 - p) * (logit of incorrect token on non-max token) + (logit of incorrect output on residual path) < 0
# or equivalently p * (logit of incorrect output on max token - logit of incorrect output on non-max token) < -(logit of incorrect output on residual path + logit of incorrect token on non-max token)
# or p * sign(logit of incorrect output on max token - logit of incorrect output on non-max token) < -(logit of incorrect output on residual path + logit of incorrect token on non-max token) / abs(logit of incorrect output on max token - logit of incorrect output on non-max token)
# if the logit of the incorrect output on EUPU + the max token is positive, we say nan (no attention is enough)
# if the logit of the incorrect output on the non-max token is negative, we say 0 (any amount of attention is enough)
all_tokens_EUPU -= all_tokens_EUPU[torch.arange(all_tokens_EUPU.shape[0]), all_tokens_max][:, None]
for b in range(all_tokens_EVOU_PVOU.shape[0]):
    for n in range(all_tokens_EVOU_PVOU.shape[1]):
        all_tokens_EVOU_PVOU[b, n, :] = all_tokens_EVOU_PVOU[b, n, :] - all_tokens_EVOU_PVOU[b, n, all_tokens_max[b]]

# fold all_tokens_EUPU into all_tokens_EVOU_PVOU
all_tokens_EVOU_PVOU_max_token: Float[Tensor, "batch d_vocab_out"] = all_tokens_EVOU_PVOU[torch.arange(all_tokens_EVOU_PVOU.shape[0]), all_tokens_max_pos, :] + all_tokens_EUPU
# we're worst off when the attention on the non-max token is largest, so we take max
all_tokens_EVOU_PVOU_non_max_token_tmp = all_tokens_EVOU_PVOU.clone()
all_tokens_EVOU_PVOU_non_max_token_tmp[torch.arange(all_tokens_EVOU_PVOU.shape[0]), all_tokens_max_pos, :] = -float('inf')
all_tokens_EVOU_PVOU_non_max_token: Float[Tensor, "batch d_vocab_out"] = all_tokens_EVOU_PVOU_non_max_token_tmp.max(dim=-2).values + all_tokens_EUPU
attention_never_enough = (all_tokens_EVOU_PVOU_max_token > 0)
logit_diff = all_tokens_EVOU_PVOU_max_token - all_tokens_EVOU_PVOU_non_max_token
logit_gap = -all_tokens_EVOU_PVOU_non_max_token / logit_diff.abs()
# where logit_diff < 0, we want smallest p > -logit_gap
# where logit_diff > 0 we want smallest p < logit_gap
min_p_1 = -logit_gap[logit_diff < 0]
min_p_1 = torch.max(min_p_1, torch.zeros_like(min_p_1))
min_p_2 = logit_gap[logit_diff > 0]
min_p_2 = torch.min(min_p_2, torch.zeros_like(min_p_2))
min_p = torch.cat([min_p_1, min_p_2], dim=0)
#from analysis_utils import hist
hist(min_p)
hist(min_p[min_p != 0])
hist(min_p[min_p < 0])
# find pre-softmax
# min_p = weight on max = e^max / (e^max + e^non-max) = 1 / (1 + e^(non-max - max))
# e^(non_max - max) = 1 / min_p - 1
# non_max - max = log(1 / min_p - 1)
# max - non_max = -log(1 / min_p - 1)
min_attn = -(1 / min_p[min_p > 0] - 1).log()
hist(min_attn)
hist(min_attn[min_attn > 0])

# all_tokens_EVOU_PVOU_max_token - all_tokens_EVOU_PVOU_non_max_token
# -all_tokens_EVOU_PVOU_non_max_token / (all_tokens_EVOU_PVOU_max_token - all_tokens_EVOU_PVOU_non_max_token).abs()

# print(all_tokens_EUPU.shape)
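
The conversion from a post-softmax cutoff to a pre-softmax score gap in the comments above can be checked in isolation.  This is a small plain-Python sketch for the two-position case only; the helper names are hypothetical.

```python
import math

def score_gap_for_attention(p: float) -> float:
    """Pre-softmax gap (max score minus non-max score) at which a two-position
    softmax puts weight exactly p on the max token."""
    if not 0.0 < p < 1.0:
        raise ValueError("p must lie strictly between 0 and 1")
    return -math.log(1.0 / p - 1.0)

def attention_for_score_gap(gap: float) -> float:
    """Inverse map: the two-position softmax weight on the max token."""
    return 1.0 / (1.0 + math.exp(-gap))

for p in (0.5, 0.9, 0.99):
    gap = score_gap_for_attention(p)
    print(f"p = {p:.2f} -> gap = {gap:.4f} -> p = {attention_for_score_gap(gap):.4f}")
```

A cutoff of exactly 0.5 corresponds to a gap of zero, and cutoffs below 0.5 correspond to negative gaps, which is why the cell above also plots only the positive part of `min_attn`.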


There are numerous approximations to this computation we might want to consider:
1. For every sequence, how much attention needs to be on the correct token for predicting the correct output?
2. For each max token, for each position it could be in, for the worst-case (or average-case) query token compatible with that choice (determining the residual stream impact), for the worst-case (or average-case) non-max-token OV behavior (per-logit), how much attention needs to be on the max token for the max token logit to be higher than the logit corresponding to the other token?  (Here we'd also need to verify that the OV behavior has all diagonal entries higher than all off-diagonal entries in the same row, to ensure that we can find the worst case for each logit separately.)
3. For each max token, we can compute the smallest (or average) gap between the logit of that token and the logit of any other token, and then reduce across positions.  Then we can compute the largest (or average) gap between logits in the residual stream impact, and the largest (or average) gap between logits across all other copying behavior.  Finally we can compute how much attention needs to be on the max token to prevent the worst-case (or average) behavior of the rest of the transformer from outweighing the correct behavior on the max token.
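
The third approximation reduces to a single inequality once the three worst-case gaps are known.  The sketch below is one possible reading of it in plain Python, with hypothetical argument names; it assumes that all non-max attention is lumped into a single worst-case term.

```python
from typing import Optional

def worst_case_attention_cutoff(min_gap_on_max: float,
                                worst_gap_elsewhere: float,
                                worst_gap_resid: float) -> Optional[float]:
    """Cutoff p above which
        p * min_gap_on_max > (1 - p) * worst_gap_elsewhere + worst_gap_resid.

    min_gap_on_max:      smallest (correct - other) logit gap from attending to the max token
    worst_gap_elsewhere: largest (other - correct) logit gap from attending to non-max tokens
    worst_gap_resid:     largest (other - correct) logit gap from the residual path
    """
    if min_gap_on_max <= worst_gap_resid:
        return None  # even full attention on the max token cannot win
    if worst_gap_elsewhere + worst_gap_resid < 0:
        return 0.0   # the rest of the model never outweighs the correct answer
    return (worst_gap_elsewhere + worst_gap_resid) / (min_gap_on_max + worst_gap_elsewhere)

print(worst_case_attention_cutoff(10.0, 5.0, 1.0))
print(worst_case_attention_cutoff(1.0, 5.0, 2.0))
print(worst_case_attention_cutoff(10.0, -3.0, 1.0))
```

Because every input is a single reduced number per max token, this version needs far less information about the model than the per-sequence computation.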

Although there are even more approximations we might consider, let's zoom out to make sense of the landscape here.  When picking an approximation to compute, we should ask: "What interpretation are we validating with this computation?"

1. The first approximation validates the hypothesis "the model (somehow) pays enough more attention to the largest token, where 'enough' means: enough that whatever other computations are going on, the model outputs the correct answer".  Notably, this doesn't explain very much.
2. The second approximation validates the hypothesis "for every sequence, find the maximum, and consider all sequences with the same maximum in the same position; whatever computation is going on in the rest of the model, it's sufficiently invariant over the details of the rest of the sequence that the model outputs the correct answer".
3. The third approximation validates the hypothesis "the behavior on the non-max token and the behavior on the skip-connection is irrelevant noise; we pay enough more attention to the max token that we can neglect everything else that happens beyond getting a simple bound on how much it perturbs the output".

Notably, the computation for 3 is (slightly) more compact than the computation for 2 (which itself is slightly more compact than the computation for 1), which, to our eyes at least, is minor evidence in favor of compact proofs being a good proxy for human interpretability.

However, in the max-of-2 model, exhaustive enumeration, which is exponential in the sequence length, is indistinguishable from a pairwise quadratic analysis (quadratic and exponential are the same when the exponent is 2).  There are not many asymptotic gains to be had here, and compactness differences are somewhat harder to see.
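
The coincidence is easy to see by counting.  The sketch below uses a hypothetical vocabulary size of 64, chosen only for illustration, and compares the number of sequences that exhaustive enumeration must visit with the number of (query, key) token pairs in a pairwise analysis.

```python
import itertools

d_vocab = 64  # hypothetical vocabulary size, for illustration only

def n_sequences(d_vocab: int, n_ctx: int) -> int:
    """Number of sequences visited by exhaustive enumeration."""
    return sum(1 for _ in itertools.product(range(d_vocab), repeat=n_ctx))

n_pairs = d_vocab ** 2  # (query token, key token) pairs in a pairwise analysis

for n_ctx in (1, 2, 3):
    print(f"n_ctx = {n_ctx}: {n_sequences(d_vocab, n_ctx)} sequences vs {n_pairs} pairs")
```

Only beyond a context length of 2 does the enumeration count pull away from the pairwise count.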

Let's plot how much attention is given to the max token in each sequence, and how much attention is required, according to each of these three computations.

In [None]:
# def compute_attention_to_max(model: HookedTransformer) -> torch.Tensor:
#     """Compute the attention given to the max token in each sequence."""
#     n_ctx, d_vocab = model.cfg.n_ctx, model.cfg.d_vocab
#     # pre soft-max
#     all_attention_scores = all_attention_scores(model)
#     assert all_attention_scores.shape == (n_ctx, d_vocab, d_vocab), f"all_attention.shape = {all_attention_scores.shape} != ({n_ctx}, {d_vocab}, {d_vocab}) = (n_ctx_k, d_vocab_k, d_vocab_q)"

#     # compute soft-max



#     all_tokens = compute_all_tokens(model)
#     all_logits = model(all_tokens)
#     expected_max = all_tokens.max(dim=-1).values
#     predicted_max = all_logits[..., -1, :].argmax(dim=-1)
#     return (all_logits.gather(-1, expected_max.unsqueeze(-1)) * (predicted_max == expected_max).float()).sum(dim=1)

 # %%
 # compute EU PU
 W_E, W_pos, W_U = model.W_E, model.W_pos, model.W_U
 print(W_E.shape, W_pos.shape, W_U.shape)
 line(W_pos[-1] @ W_U)
 imshow(W_E @ W_U)
 imshow((W_E + W_pos[-1]) @ W_U)
 analyze_svd(W_E @ W_U)
 analyze_svd((W_E + W_pos[-1]) @ W_U)
 # %%
 # compute OV
 import analysis_utils
 analysis_utils.calculate_OV_of_pos_embed(model)
 analysis_utils.calculate_copying(model)
 # %%
 W_E, W_pos, W_U, W_V, W_O = model.W_E, model.W_pos, model.W_U, model.W_V, model.W_O
 analyze_svd(W_E @ W_V[0, 0] @ W_O[0, 0] @ W_U)
 # %%

In [None]:
#import analysis_utils
# reload analysis_utils module
#import importlib
#importlib.reload(analysis_utils)
#analysis_utils.analyze_EVOU(model, scale_by_singular_value=False)
#analysis_utils.analyze_PVOU(model)
#analysis_utils.analyze_PU(model)
#analysis_utils.analyze_EU(model)
analyze_EVOU(model, scale_by_singular_value=False)
analyze_PVOU(model)
analyze_PU(model)
analyze_EU(model)
