<a href="https://colab.research.google.com/github/andyrdt/mi/blob/main/ARENA/monthly_algorithmic_problems/10_2023/Sorted_List.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Monthly Algorithmic Challenge (October 2023): Sorted List

This post is the fourth in the sequence of monthly mechanistic interpretability challenges. They are designed in the spirit of [Stephen Casper's challenges](https://www.lesswrong.com/posts/KSHqLzQscwJnv44T8/eis-vii-a-challenge-for-mechanists), but with the more specific aim of working well in the context of the rest of the ARENA material, and helping people put into practice all the things they've learned so far.


If you prefer, you can access the Streamlit page [here](https://arena-ch1-transformers.streamlit.app/Monthly_Algorithmic_Problems).

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/sorted-problem.png" width="350">

## Setup

In [1]:
%%capture

try:
    import google.colab # type: ignore
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

import os; os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys

if IN_COLAB:
    # Install packages
    %pip install einops
    %pip install jaxtyping
    %pip install transformer_lens
    %pip install git+https://github.com/callummcdougall/eindex.git
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

    # Code to download the necessary files (e.g. solutions, test funcs)
    import os, sys
    if not os.path.exists("chapter1_transformers"):
        !curl -o /content/main.zip https://codeload.github.com/callummcdougall/ARENA_2.0/zip/refs/heads/main
        !unzip /content/main.zip 'ARENA_2.0-main/chapter1_transformers/exercises/*'
        sys.path.append("/content/ARENA_2.0-main/chapter1_transformers/exercises")
        os.remove("/content/main.zip")
        os.rename("ARENA_2.0-main/chapter1_transformers", "chapter1_transformers")
        os.rmdir("ARENA_2.0-main")
        os.chdir("chapter1_transformers/exercises")
else:
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

In [2]:
%%capture

import torch as t
from pathlib import Path

# Make sure exercises are in the path
chapter = r"chapter1_transformers"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "monthly_algorithmic_problems" / "october23_sorted_list"
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from monthly_algorithmic_problems.october23_sorted_list.dataset import SortedListDataset
from monthly_algorithmic_problems.october23_sorted_list.model import create_model
from plotly_utils import hist, bar, imshow

device = t.device("cuda" if t.cuda.is_available() else "cpu")

## Task & Dataset

The problem for this month is interpreting a model which has been trained to sort a list. The model is fed sequences like:

```
[11, 2, 5, 0, 3, 9, SEP, 0, 2, 3, 5, 9, 11]
```

and has been trained to predict each element in the sorted list (in other words, the output at the `SEP` token should be a prediction of `0`, the output at `0` should be a prediction of `2`, and so on).
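Here is a minimal pure-Python sketch of how one sequence and its next-token targets line up. The helpers `build_sequence` and `prediction_pairs` are hypothetical (they are not part of `dataset.py`). The sketch also assumes that the `SEP` token id is `max_value + 1`, which matches the examples in this notebook.

```python
def build_sequence(unsorted, max_value):
    """Hypothetical helper: the unsorted list, then SEP, then the same values sorted."""
    sep = max_value + 1  # assumption: SEP is the token id right after the largest value
    return list(unsorted) + [sep] + sorted(unsorted)


def prediction_pairs(seq, list_len):
    """Pair each input position from SEP onwards with the token that should be predicted there.

    This mirrors the evaluation slicing used later: logits[:, list_len:-1]
    are compared against toks[:, list_len+1:].
    """
    inputs = seq[list_len:-1]
    targets = seq[list_len + 1:]
    return list(zip(inputs, targets))


seq = build_sequence([9, 6, 2, 4, 5], max_value=10)
print(seq)                       # [9, 6, 2, 4, 5, 11, 2, 4, 5, 6, 9]
print(prediction_pairs(seq, 5))  # [(11, 2), (2, 4), (4, 5), (5, 6), (6, 9)]
```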

Here is an example of what this dataset looks like:

In [3]:
dataset = SortedListDataset(size=1, list_len=5, max_value=10, seed=42)

print(dataset[0].tolist())
print(dataset.str_toks[0])

[9, 6, 2, 4, 5, 11, 2, 4, 5, 6, 9]
['9', '6', '2', '4', '5', 'SEP', '2', '4', '5', '6', '9']


The relevant files can be found at:

```
chapter1_transformers/
└── exercises/
    └── monthly_algorithmic_problems/
        └── october23_sorted_list/
            ├── model.py               # code to create the model
            ├── dataset.py             # code to define the dataset
            ├── training.py            # code to train the model
            └── training_model.ipynb   # actual training script
```


## Model

The model is attention-only, with 1 layer, and 2 attention heads per layer. It was trained with layernorm, weight decay, and an Adam optimizer with linearly decaying learning rate.

You can load the model as follows:


In [4]:
filename = section_dir / "sorted_list_model.pt"

model = create_model(
    list_len=10,
    max_value=50,
    seed=0,
    d_model=96,
    d_head=48,
    n_layers=1,
    n_heads=2,
    normalization_type="LN",
    d_mlp=None
)

# map_location lets the checkpoint load regardless of the device it was saved from
state_dict = t.load(filename, map_location=device)

state_dict = model.center_writing_weights(state_dict)
state_dict = model.center_unembed(state_dict)
state_dict = model.fold_layer_norm(state_dict)
state_dict = model.fold_value_biases(state_dict)
model.load_state_dict(state_dict, strict=False);

The code to process the state dictionary is a bit messy, but it's necessary to make sure the model is easy to work with. For instance, if you inspect the model's parameters, you'll see that `model.ln_final.w` is a vector of 1s, and `model.ln_final.b` is a vector of 0s (because the weight and bias have been folded into the unembedding).

In [5]:
print("ln_final weight: ", model.ln_final.w)
print("\nln_final, bias: ", model.ln_final.b)

ln_final weight:  Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1.], device='cuda:0', requires_grad=True)

ln_final, bias:  Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0', requires_

<details>
<summary>Aside - the other weight processing parameters</summary>

Here's some more code to verify that our weights processing worked, in other words:

* The unembedding matrix has mean zero over both its input dimension (`d_model`) and output dimension (`d_vocab`)
* All writing weights (i.e. `b_O`, `W_O`, and both embeddings) have mean zero over their output dimension (`d_model`)
* The value biases `b_V` are zero (because these can just be folded into the output biases `b_O`)

```python
import einops

W_U_mean_over_input = einops.reduce(model.W_U, "d_model d_vocab -> d_vocab", "mean")
t.testing.assert_close(W_U_mean_over_input, t.zeros_like(W_U_mean_over_input))

W_U_mean_over_output = einops.reduce(model.W_U, "d_model d_vocab -> d_model", "mean")
t.testing.assert_close(W_U_mean_over_output, t.zeros_like(W_U_mean_over_output))

W_O_mean_over_output = einops.reduce(model.W_O, "layer head d_head d_model -> layer head d_head", "mean")
t.testing.assert_close(W_O_mean_over_output, t.zeros_like(W_O_mean_over_output))

b_O_mean_over_output = einops.reduce(model.b_O, "layer d_model -> layer", "mean")
t.testing.assert_close(b_O_mean_over_output, t.zeros_like(b_O_mean_over_output))

W_E_mean_over_output = einops.reduce(model.W_E, "token d_model -> token", "mean")
t.testing.assert_close(W_E_mean_over_output, t.zeros_like(W_E_mean_over_output))

W_pos_mean_over_output = einops.reduce(model.W_pos, "position d_model -> position", "mean")
t.testing.assert_close(W_pos_mean_over_output, t.zeros_like(W_pos_mean_over_output))

b_V = model.b_V
t.testing.assert_close(b_V, t.zeros_like(b_V))
```

</details>
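The reason `ln_final.w` becomes all 1s and `ln_final.b` all 0s can be shown with a minimal NumPy sketch. It uses toy random weights, not the model's. A layernorm weight `w` and bias `b` applied just before the unembedding can be absorbed into it as `W_U' = diag(w) @ W_U` and `b_U' = b_U + b @ W_U`, after which the layernorm only has to center and rescale. TransformerLens's `fold_layer_norm` handles more cases than this, so the sketch only illustrates the algebra.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_vocab = 8, 5
x = rng.normal(size=(3, d_model))                              # toy residual-stream vectors
w, b = rng.normal(size=d_model), rng.normal(size=d_model)      # toy layernorm weight and bias
W_U, b_U = rng.normal(size=(d_model, d_vocab)), rng.normal(size=d_vocab)


def normalize(x, eps=1e-5):
    # The part of layernorm that survives folding: center, then divide by the RMS scale
    x = x - x.mean(axis=-1, keepdims=True)
    return x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)


logits_original = (normalize(x) * w + b) @ W_U + b_U

# Fold the layernorm weight and bias into the unembedding
W_U_folded = w[:, None] * W_U      # same as np.diag(w) @ W_U
b_U_folded = b_U + b @ W_U
logits_folded = normalize(x) @ W_U_folded + b_U_folded

# Centering W_U over d_model is also harmless, since normalize(x) has zero mean
W_U_centered = W_U_folded - W_U_folded.mean(axis=0, keepdims=True)
logits_centered = normalize(x) @ W_U_centered + b_U_folded

print(np.allclose(logits_original, logits_folded), np.allclose(logits_original, logits_centered))  # True True
```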


A demonstration of the model working:


In [6]:
from eindex import eindex

N = 500
dataset = SortedListDataset(size=N, list_len=10, max_value=50, seed=43)

logits, cache = model.run_with_cache(dataset.toks)
logits: t.Tensor = logits[:, dataset.list_len:-1, :]

targets = dataset.toks[:, dataset.list_len+1:]

logprobs = logits.log_softmax(-1) # [batch seq_len vocab_out]
probs = logprobs.softmax(-1)

batch_size, seq_len = dataset.toks.shape
logprobs_correct = eindex(logprobs, targets, "batch seq [batch seq]")
probs_correct = eindex(probs, targets, "batch seq [batch seq]")

avg_cross_entropy_loss = -logprobs_correct.mean().item()

print(f"Average cross entropy loss: {avg_cross_entropy_loss:.3f}")
print(f"Mean probability on correct label: {probs_correct.mean():.3f}")
print(f"Median probability on correct label: {probs_correct.median():.3f}")
print(f"Min probability on correct label: {probs_correct.min():.3f}")

Average cross entropy loss: 0.039
Mean probability on correct label: 0.966
Median probability on correct label: 0.981
Min probability on correct label: 0.001
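If you are less familiar with `eindex`, here is a minimal NumPy sketch of the same metrics on toy random logits (not the model's outputs). It treats `eindex(logprobs, targets, "batch seq [batch seq]")` as a gather along the vocab dimension. It also checks that `logprobs.softmax(-1)` in the cell above is equivalent to exponentiating the log-probs.

```python
import numpy as np


def log_softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


rng = np.random.default_rng(0)
toy_logits = rng.normal(size=(4, 10, 52))          # toy (batch, seq, vocab) logits
toy_targets = rng.integers(0, 51, size=(4, 10))    # toy target token ids

logprobs = log_softmax(toy_logits)
probs = np.exp(logprobs)

# Applying softmax to log-probs just returns the probabilities again,
# because exp(logprobs) already sums to 1 along the vocab axis
assert np.allclose(np.exp(log_softmax(logprobs)), probs)

# Gather the log-prob of the target token at every (batch, seq) position
logprobs_correct = np.take_along_axis(logprobs, toy_targets[..., None], axis=-1)[..., 0]

avg_cross_entropy_loss = -logprobs_correct.mean()
mean_prob_correct = np.exp(logprobs_correct).mean()
print(f"{avg_cross_entropy_loss:.3f}, {mean_prob_correct:.3f}")
```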


And a visualisation of its probability output for a single sequence:

In [7]:
def show(i):

    imshow(
        probs[i].T,
        y=dataset.vocab,
        x=[f"{dataset.str_toks[i][j]}<br><sub>({j})</sub>" for j in range(dataset.list_len+1, dataset.seq_len)],
        labels={"x": "Token", "y": "Vocab"},
        xaxis_tickangle=0,
        title=f"Sample model probabilities:<br>Unsorted = ({','.join(dataset.str_toks[i][:dataset.list_len])})",
        text=[
            ["〇" if (str_tok == target) else "" for target in dataset.str_toks[i][dataset.list_len+1: dataset.seq_len]]
            for str_tok in dataset.vocab
        ],
        width=400,
        height=1000,
    )

show(0)

Best of luck! 🎈

# Andy's work starts here

## High-level mechanism overview

- Predicting the first output value
  - The first output value is predicted from the `SEP` token at `pos=10`
  - Each head's attention to token $t$ is (roughly) monotonically decreasing in that token's value: $$A(t=0) > A(t=1) > \ldots > A(t=49) > A(t=50)$$
    - Note: although this is the general trend, this property isn't strictly true, and there are notable exceptions (which can be exploited in adversarial examples).
  - Heads then output the value they most strongly attend to.

- Predicting the rest of the values ("normal predictions")
  - For "normal predictions", the query comes from the current token, i.e. the most recently output value $t_q$.
  - Among tokens $t > t_q$, each head's attention from query token $t_q$ to token $t$ is (roughly) monotonically decreasing in $t$, peaking at $t_q+1$:
  $$A(t=t_q+1) > A(t=t_q+2) > A(t=t_q+3) > \ldots$$
    - Note: again, there are exceptions to this general trend.
  - Heads then output the value they most strongly attend to.
- Head 0.0 vs Head 0.1
  - The predicted token values are roughly partitioned across the two heads:
    - H0.0 is responsible for predictions of $\{ 28, 29, \ldots, 36\}$.
    - H0.1 is responsible for all other predictions.
    - (There are also a few tokens that seem to have shared responsibility.)
  - This is evidenced by the OV circuit visualizations and by direct logit attribution.
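The mechanism above can be summarised with a deliberately idealised pure-Python toy. The score function `idealized_score` is hypothetical: it hard-codes the "peak at $t_q+1$, then decrease" pattern, treats the `SEP` query as a value of -1 so that the first prediction is the minimum, and replaces softmax attention with a hard argmax. It also assumes that the list has no repeated values. The real heads only approximate this and split the vocabulary between them, and the adversarial examples later in this notebook exploit exactly those deviations.

```python
SEP_QUERY = -1  # assumption: the SEP query behaves like a value below every token


def idealized_score(query_value, key_value):
    """Hypothetical attention score from query token t_q to key token t.

    Keys <= t_q have already been output, so they get -inf;
    among larger keys, the score peaks at t_q + 1 and decreases with t.
    """
    if key_value <= query_value:
        return float("-inf")
    return -(key_value - query_value - 1)


def idealized_next_token(unsorted, query_value):
    # Hard-argmax attention: the head copies the key it attends to most strongly
    return max(unsorted, key=lambda key: idealized_score(query_value, key))


def idealized_sort(unsorted):
    query = SEP_QUERY
    output = []
    for _ in unsorted:
        query = idealized_next_token(unsorted, query)  # each output becomes the next query
        output.append(query)
    return output


print(idealized_sort([11, 2, 5, 0, 3, 9]))  # [0, 2, 3, 5, 9, 11]
```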

## Some setup

### Custom imports

In [8]:
import pandas as pd
import einops
import plotly.express as px
import plotly.graph_objects as go

from plotly.subplots import make_subplots

### Layer norm

In [9]:
scale_ln1_layer0 = cache["scale", 0, "ln1"][:, :, 0, 0] # shape (batch, seq)
scale_lnfinal = cache["scale"][:, :, 0] # shape (batch, seq)

for scale, label in zip(
    [scale_ln1_layer0, scale_lnfinal],
    ["ln1, layer 0", "lnfinal"]):

    df = pd.DataFrame({
        "std": scale.std(0).cpu().numpy(),
        "mean": scale.mean(0).cpu().numpy(),
    })

    display(
        px.bar(
            df,
            title=f"Mean & std of layernorm scale ({label}), by position",
            template="simple_white", width=450, height=300, barmode="group"
        )
    )
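The circuit analyses below replace each layernorm scale by its mean over the dataset (e.g. `scale_ln1_layer0.mean()`), which turns the layernorm into a fixed linear map. A short derivation shows the cost of this approximation. If a centered vector $\tilde{x}$ has scale $s$ and we divide by the mean scale $\bar{s}$ instead, the relative error is $\|\tilde{x}/s - \tilde{x}/\bar{s}\| / \|\tilde{x}/s\| = |1 - s/\bar{s}|$. The approximation is therefore good whenever the std of the scale is small compared with its mean, which is what the bar charts above let you check. The NumPy sketch below verifies the identity on toy random vectors, not the model's activations.

```python
import numpy as np


def center_and_scale(x, eps=1e-5):
    # Layernorm without weight/bias: center, then compute the RMS scale
    centered = x - x.mean(axis=-1, keepdims=True)
    scale = np.sqrt((centered ** 2).mean(axis=-1, keepdims=True) + eps)
    return centered, scale


rng = np.random.default_rng(0)
x = rng.normal(size=(1000, 96))  # toy residual vectors, d_model = 96

centered, scale = center_and_scale(x)
exact = centered / scale            # what layernorm actually computes
approx = centered / scale.mean()    # frozen mean scale: a linear function of x

rel_err = np.linalg.norm(exact - approx, axis=-1) / np.linalg.norm(exact, axis=-1)
print(f"scale std / mean = {scale.std() / scale.mean():.3f}, mean relative error = {rel_err.mean():.3f}")
```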

## OV circuits

We'll start the analysis by visualizing each head's OV circuit, with both layernorm scales (`ln1` and `ln_final`) approximated by their mean values:
$$W_E W_V^{0.h} W_O^{0.h} W_U$$

In [10]:
# visualize layer 0 OV circuits
layer = 0

fig = make_subplots(
    rows=model.cfg.n_layers,
    cols=model.cfg.n_heads,
    subplot_titles=([f"H{layer}.{head} OV circuit" for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)])
)

global_zmin = float('inf')
global_zmax = float('-inf')

OV_circuits = {}

for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
        W_OV = model.W_V[layer, head] @ model.W_O[layer, head]
        OV_circuit = model.W_E[:-1] @ W_OV @ model.W_U[:, :-1]

        OV_circuit /= scale_ln1_layer0.mean()
        OV_circuit /= scale_lnfinal.mean()

        global_zmin = min(global_zmin, OV_circuit.min().item())
        global_zmax = max(global_zmax, OV_circuit.max().item())

        OV_circuits[(layer, head)] = OV_circuit

# Adding traces to the figure using the global min and max and the stored OV_circuit values
for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
        fig.add_trace(
            go.Heatmap(z=OV_circuits[(layer, head)].detach().cpu().numpy(),
                zmin=global_zmin,
                zmax=global_zmax,
                colorscale="RdBu",
                showscale=True
            ),
            row=layer+1, col=head+1
        )

fig.update_xaxes(title_text='Vocab (unembed)')
fig.update_yaxes(title_text='Vocab (embed)', autorange='reversed')
fig.update_layout(width=1400, height=750)

fig.show()

Observations:
- H0.0's OV circuit is mostly negative along the diagonal, except for a positive diagonal strip over tokens 28-35.
- H0.1's OV circuit has a strong diagonal, except for a notable weak spot over tokens 28-35.

This suggests that the two heads roughly partition the job of predicting tokens:
- H0.0 is responsible for predicting tokens 28-35.
- H0.1 is responsible for predicting all other tokens.
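The partition read off from the heatmaps can also be stated token by token. One hypothetical criterion (not used elsewhere in this analysis) is a "copying score": the diagonal entry of a head's OV circuit minus the mean of that row, with each token assigned to the head that has the larger score. The sketch below applies this criterion to toy matrices that mimic the pattern described above. With the real circuits, you could instead pass `np.stack([OV_circuits[(0, h)].detach().cpu().numpy() for h in range(2)])`.

```python
import numpy as np


def copying_scores(ov_circuits):
    """ov_circuits: array of shape (n_heads, d_vocab, d_vocab), rows = attended (input) token.

    Returns (n_heads, d_vocab): how much each head boosts a token's own logit
    relative to the average logit it writes when that token is attended to.
    """
    diag = np.diagonal(ov_circuits, axis1=1, axis2=2)  # (n_heads, d_vocab)
    row_mean = ov_circuits.mean(axis=2)                # (n_heads, d_vocab)
    return diag - row_mean


# Toy circuits mimicking the observations (not the real model weights):
# head 0 copies only tokens 28-35, head 1 copies everything else strongly
d_vocab = 51
in_strip = (np.arange(d_vocab) >= 28) & (np.arange(d_vocab) <= 35)
toy_head0 = np.diag(np.where(in_strip, 1.0, -1.0))
toy_head1 = np.diag(np.where(in_strip, 0.1, 1.0))

scores = copying_scores(np.stack([toy_head0, toy_head1]))
owner = scores.argmax(axis=0)  # index of the responsible head for each token
print(np.flatnonzero(owner == 0))  # [28 29 30 31 32 33 34 35]
```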


## Direct logit attribution


We can test this idea with direct logit attribution, splitting the results by correct token.

We verify that when the correct token is in 28-35, H0.0 is primarily responsible for increasing its logit; otherwise, H0.1 is.

In [11]:
attn_results = cache["result", 0]
attn_results = attn_results / scale_lnfinal.mean()

attn_logits = einops.einsum(attn_results, model.W_U, 'batch seq n_heads d_model, d_model d_vocab -> batch seq n_heads d_vocab')

correct_labels = einops.rearrange(dataset.toks, 'batch seq -> batch seq 1 1')
correct_labels = einops.repeat(correct_labels, 'batch seq n_heads label -> batch seq (repeat n_heads) label', repeat=2)

attn_correct_logits = attn_logits[:, 10:20].gather(dim=-1, index=correct_labels[:, 11:21].to(device))

logit_per_head_per_correct_token_mean = t.zeros((model.cfg.n_heads, len(dataset.vocab[:-1])))
logit_per_head_per_correct_token_std = t.zeros((model.cfg.n_heads, len(dataset.vocab[:-1])))

for vocab_idx, vocab_tok in enumerate(dataset.vocab[:-1]):
    mask = (correct_labels[:, 11:21] == vocab_idx)
    for head in range(model.cfg.n_heads):
        logit_per_head_per_correct_token_mean[head, vocab_idx] = attn_correct_logits[:, :, head, :][mask[:, :, head, :]].mean()
        logit_per_head_per_correct_token_std[head, vocab_idx] = attn_correct_logits[:, :, head, :][mask[:, :, head, :]].std()

logit_mean_np = logit_per_head_per_correct_token_mean.detach().numpy()
logit_std_np = logit_per_head_per_correct_token_std.detach().numpy()

fig = go.Figure()

for head in range(model.cfg.n_heads):
    fig.add_trace(go.Bar(
        x=[str(i) for i in range(logit_mean_np.shape[1])],
        y=logit_mean_np[head, :],
        name=f'H0.{head}',
        # error_y=dict(type='data', array=logit_std_np[head, :], visible=True)
    ))

fig.update_layout(title='Correct token logit contribution, split by correct token', xaxis_title='Correct token',
                  yaxis_title='Logit contribution', barmode='group', template="simple_white", width=1000, height=400)
fig.show()
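Direct logit attribution works here because the final layernorm's weight and bias are folded into `W_U` and its scale is frozen at its mean (the `/ scale_lnfinal.mean()` above). Under those conditions, the logits are a sum of per-component contributions. A minimal NumPy check on toy random vectors (not cached activations) is below. It applies centering explicitly. In the real model, each head's output already has zero mean, because the writing weights were centered when loading.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_vocab = 96, 51

# Toy residual-stream components at one position: embeddings + each head's output
components = {
    "embed": rng.normal(size=d_model),
    "H0.0": rng.normal(size=d_model),
    "H0.1": rng.normal(size=d_model),
}
W_U = rng.normal(size=(d_model, d_vocab))
frozen_scale = 2.5  # stand-in for scale_lnfinal.mean()


def center(v):
    return v - v.mean()


resid = sum(components.values())
logits = center(resid) / frozen_scale @ W_U

# Direct logit attribution: push each component through the same linear map
dla = {name: center(v) / frozen_scale @ W_U for name, v in components.items()}

correct_token = 30
for name, contrib in dla.items():
    print(f"{name}: {contrib[correct_token]:+.3f}")
```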


The trend we read off from the OV circuits holds up: the tokens are roughly partitioned between H0.0 and H0.1.

There are a few tokens where responsibility is more or less shared between the two heads: 2, 28, 35.

## QK circuits

From the OV analysis, we learned that each head makes a positive logit contribution towards the tokens it attends to, within the range of tokens it is responsible for. Now we investigate how the heads identify the correct token to attend to, via their QK circuits.

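We begin by visualizing each head's QK circuit on the token embeddings (ignoring positional embeddings and approximating the `ln1` scale by its mean), including the usual $1/\sqrt{d_{head}}$ scaling:
$$\frac{1}{\sqrt{d_{head}}} W_E W_Q^{0.h} (W_E W_K^{0.h})^T$$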

In [12]:
# visualize layer 0 QK embedding circuits
layer = 0

fig = make_subplots(
    rows=model.cfg.n_layers,
    cols=model.cfg.n_heads,
    subplot_titles=([f"H{layer}.{head} QK circuit" for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)])
)

global_zmin = float('inf')
global_zmax = float('-inf')

W_emb = model.W_E
W_emb_scaled = W_emb / scale_ln1_layer0.mean()

QK_values = {}

for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
        W_QK = model.W_Q[layer, head] @ model.W_K[layer, head].T / model.cfg.d_head**0.5
        QK_full_emb = W_emb_scaled[:] @ W_QK @ W_emb_scaled[:-1].T

        global_zmin = min(global_zmin, QK_full_emb.min().item())
        global_zmax = max(global_zmax, QK_full_emb.max().item())

        # Storing each QK_full_emb
        QK_values[(layer, head)] = QK_full_emb

# Adding traces using the stored QK_full_emb values and the global zmin and zmax
for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
        QK_full_emb = QK_values[(layer, head)]

        fig.add_trace(
            go.Heatmap(z=QK_full_emb.detach().cpu().numpy(),
                       zmin=global_zmin,
                       zmax=global_zmax,
                       colorscale="RdBu",
                       showscale=True),
            row=layer+1, col=head+1
        )

# Updating and showing the figure
fig.update_xaxes(title_text='Key (token)')
fig.update_yaxes(title_text='Query (token)', autorange='reversed')
fig.update_layout(width=1400, height=750)
fig.show()


Observations:
- For a query token $t_q$, the attention score is maximized at key $t_q+1$ and then decreases:
$$A(t_q+1) > A(t_q+2) > A(t_q+3) > \ldots$$
- This mostly holds locally, but:
  - Further away from $t_q$, the scores can increase again
    - For example, look at H0.1's `query=0` row from `key=35-40`
  - There are also exceptions in regions where the head is not "responsible"
    - For example, look at H0.1's `query=28` row from `key=29-34`
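The "mostly holds locally" statement can be quantified. The hypothetical helper below takes a (query × key) score matrix, such as `QK_values[(0, h)]` converted to NumPy. For each query token $t_q$, it reports whether the largest score among keys $t > t_q$ sits at $t_q+1$, and over how many consecutive keys the scores keep decreasing. It is demonstrated on a toy matrix with one planted bump, not on the real circuits.

```python
import numpy as np


def local_monotonicity(scores, t_q):
    """For query t_q, return (peak_is_next, run_length).

    peak_is_next: True if, among keys t > t_q, the score is largest at t = t_q + 1.
    run_length:   number of keys t_q+1, t_q+2, ... over which the score strictly decreases.
    """
    row = scores[t_q, t_q + 1:]
    if row.size == 0:
        return True, 0
    run_length = 1
    while run_length < row.size and row[run_length] < row[run_length - 1]:
        run_length += 1
    return bool(row.argmax() == 0), run_length


# Toy scores: decreasing away from t_q + 1, plus a planted bump at key 38 for query 0
n = 51
q, k = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
toy_scores = -np.abs(k - (q + 1)).astype(float)
toy_scores[0, 38] = 5.0

print(local_monotonicity(toy_scores, 0))   # (False, 37): the bump beats the correct next token
print(local_monotonicity(toy_scores, 10))  # (True, 40)
```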

We can zoom in on the case of predicting the first token, where the query comes from the `SEP` token at `pos=10` (so here we include its positional embedding):

In [13]:
layer = 0

W_emb = model.W_E
W_emb_scaled = W_emb / scale_ln1_layer0.mean()

pos_10_resid = model.W_E[-1] + model.W_pos[10]
pos_10_resid /= scale_ln1_layer0.mean()

fig = go.Figure()

for head in range(model.cfg.n_heads):
    pos_10_query = pos_10_resid @ model.W_Q[layer, head]
    keys = W_emb_scaled @ model.W_K[layer, head]
    pos_10_QK = pos_10_query @ keys.T / model.cfg.d_head**0.5

    fig.add_trace(
        go.Scatter(
            x = dataset.vocab,
            y = pos_10_QK.detach().cpu().numpy(),
            mode='markers+lines',
            name=f"H{layer}.{head}",
        )
    )
fig.update_layout(
    title="Attention from pos=10 'SEP' token to each token embedding",
    xaxis_title="Token embedding (key)",
    yaxis_title="Attention score (pre-softmax)",
    width=800, height=400,
)
fig.show()

In roughly the first half of the vocabulary, H0.1's attention score is (almost) monotonically decreasing in the key's value. So from this position, H0.1 attends most strongly to the smallest token in the list, as desired.

This pattern breaks down from 28 onwards. This is likely because very few training sequences have a smallest value of 28 or more, so the model rarely needed to get this case right.
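A back-of-the-envelope calculation gives a sense of how rare such sequences are. Assume that each training list holds 10 distinct values drawn uniformly from 0-50. This matches `list_len=10, max_value=50` and the examples above, but it is an assumption about `dataset.py`. Under that assumption, the minimum is $\ge 28$ only if all 10 values come from the 23 tokens 28-50, so the probability is $\binom{23}{10} / \binom{51}{10} \approx 9 \times 10^{-5}$, or roughly 1 in 11,000 sequences. If values were instead drawn with replacement, the probability would still only be $(23/51)^{10} \approx 3.5 \times 10^{-4}$.

```python
from math import comb

list_len, max_value, threshold = 10, 50, 28
n_values = max_value + 1                  # tokens 0..50
n_high = max_value - threshold + 1        # tokens 28..50

# Without replacement: all list_len values must come from the n_high high tokens
p_without = comb(n_high, list_len) / comb(n_values, list_len)
# With replacement, for comparison
p_with = (n_high / n_values) ** list_len

print(f"P(min >= {threshold}), without replacement: {p_without:.2e}")
print(f"P(min >= {threshold}), with replacement:    {p_with:.2e}")
```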

## Adversarial examples

### Prediction from position `pos=10`

We can exploit this observation, that first-token prediction breaks down for tokens of 28 and above, to craft adversarial examples on which the model fails to predict the first token correctly.

In [14]:
custom_toks = t.tensor([
    [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 51, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49],
])
custom_str_toks = [[dataset.vocab[tok] for tok in seq.tolist()] for seq in custom_toks]

In [15]:
custom_logits, custom_cache = model.run_with_cache(custom_toks)
custom_logits: t.Tensor = custom_logits[:, dataset.list_len:-1, :]

custom_targets = custom_toks[:, dataset.list_len+1:]

custom_logprobs = custom_logits.log_softmax(-1) # [batch seq_len vocab_out]
custom_probs = custom_logprobs.softmax(-1)

imshow(
    custom_probs[0].T,
    y=dataset.vocab,
    x=[f"{custom_str_toks[0][j]}<br><sub>({j})</sub>" for j in range(dataset.list_len+1, dataset.seq_len)],
    labels={"x": "Token", "y": "Vocab"},
    xaxis_tickangle=0,
    title=f"Sample model probabilities:<br>Unsorted = ({','.join(custom_str_toks[0][:dataset.list_len])})",
    text=[
        ["〇" if (str_tok == target) else "" for target in custom_str_toks[0][dataset.list_len+1: dataset.seq_len]]
        for str_tok in dataset.vocab
    ],
    width=500,
    height=1000,
)

### Prediction from token `emb=5`

We can find other quirks by zooming in on the QK patterns for particular query tokens, and crafting adversarial examples accordingly.

In [16]:
# visualize layer 0 QK embedding circuits
layer = 0

W_emb = model.W_E
W_emb_scaled = W_emb / scale_ln1_layer0.mean()

fig = go.Figure()


for head in range(model.cfg.n_heads):
    tok_5_query = W_emb_scaled[5] @ model.W_Q[layer, head]
    keys = W_emb_scaled @ model.W_K[layer, head]
    tok_5_QK = tok_5_query @ keys.T / model.cfg.d_head**0.5

    fig.add_trace(
        go.Scatter(
            x = dataset.vocab,
            y = tok_5_QK.detach().cpu().numpy(),
            mode='markers+lines',
            name=f"H{layer}.{head}",
        )
    )
fig.update_layout(
    title="Attention from '5' token to each token embedding",
    xaxis_title="Token embedding (key)",
    yaxis_title="Attention score (pre-softmax)",
    width=800, height=400,
)
fig.show()

From `emb=5`, the attention score to `emb=35` is minimal, so the model is very unlikely to predict `35` after `5`, even when `35` is the correct next token.

We can craft an adversarial example where `35` is actually the correct continuation from `5`:

In [17]:
custom_toks = t.tensor([
    [5, 35, 36, 37, 38, 39, 40, 41, 42, 43, 51, 5, 35, 36, 37, 38, 39, 40, 41, 42, 43],
])
custom_str_toks = [[dataset.vocab[tok] for tok in seq.tolist()] for seq in custom_toks]

In [18]:
custom_logits, custom_cache = model.run_with_cache(custom_toks)
custom_logits: t.Tensor = custom_logits[:, dataset.list_len:-1, :]

custom_targets = custom_toks[:, dataset.list_len+1:]

custom_logprobs = custom_logits.log_softmax(-1) # [batch seq_len vocab_out]
custom_probs = custom_logprobs.softmax(-1)

imshow(
    custom_probs[0].T,
    y=dataset.vocab,
    x=[f"{custom_str_toks[0][j]}<br><sub>({j})</sub>" for j in range(dataset.list_len+1, dataset.seq_len)],
    labels={"x": "Token", "y": "Vocab"},
    xaxis_tickangle=0,
    title=f"Sample model probabilities:<br>Unsorted = ({','.join(custom_str_toks[0][:dataset.list_len])})",
    text=[
        ["〇" if (str_tok == target) else "" for target in custom_str_toks[0][dataset.list_len+1: dataset.seq_len]]
        for str_tok in dataset.vocab
    ],
    width=500,
    height=1000,
)
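The search for quirks like the `5` → `35` case can be partly automated. The hypothetical helper below scans a (query × key) score matrix, for instance the `QK_values[(0, h)]` matrices converted to NumPy. For each pair ($t_q$, $t_{next}$) with $t_{next} > t_q$, it lists the larger keys that outscore $t_{next}$. A list containing $t_q$, $t_{next}$ and those keys is a natural adversarial candidate. This only looks at a single head's pre-softmax scores, and it ignores both positional embeddings and the OV split between heads. Candidates therefore still need to be checked by running the model, as above. The demo uses a toy matrix with a planted dip, not the real circuits.

```python
import numpy as np


def find_qk_inversions(scores):
    """Hypothetical helper for spotting adversarial candidates.

    scores: array (n_queries, n_keys) of attention scores, where row/column index = token value.
    Returns {(t_q, t_next): [larger keys that outscore t_next from query t_q]}.
    """
    n_queries, n_keys = scores.shape
    inversions = {}
    for t_q in range(n_queries):
        for t_next in range(t_q + 1, n_keys):
            beaten_by = [t for t in range(t_next + 1, n_keys) if scores[t_q, t] > scores[t_q, t_next]]
            if beaten_by:
                inversions[(t_q, t_next)] = beaten_by
    return inversions


# Toy scores: decreasing in the key above each query, except a planted dip at (5, 35)
n = 51
q, k = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
toy_scores = -(k - q - 1).astype(float)
toy_scores[5, 35] = -100.0

inversions = find_qk_inversions(toy_scores)
print(list(inversions))  # [(5, 35)]
```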