In [1]:
from mamba_ssm.ops.triton.layernorm import rms_norm_fn

from nnsight.models.Mamba import Mamba

from typing import Callable, List

import torch

from nnsight import LanguageModel

from typing import List

import baukit

from copy import deepcopy

import accelerate



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the `state-spaces/mamba-1.4b` checkpoint through nnsight's Mamba wrapper, on the GPU.
# NOTE(review): `dispatch=True` presumably materialises the weights now rather than lazily
# on first trace -- confirm against the nnsight version in use.
model = Mamba(
    "state-spaces/mamba-1.4b",
    device="cuda",
    dispatch=True,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:

# `lens` stays None for the plain logit lens: `decode` only applies a translator when
# `lens is not None`. The TunedLens cell overwrites this with a trained lens.
lens = None

In [3]:
# Skip this cell for the plain logit lens; running it replaces `lens` with a TunedLens loaded from 'model.safetensors'.

class TunedLens(torch.nn.Module):
    """Container of per-layer affine "translators" for a tuned lens.

    Holds one ``d_model -> d_model`` ``Linear`` (with bias) for every layer
    except the last, in ``self.layer_translators``. Every translator starts
    with all-zero weight and bias, so ``translator(h) + h == h`` until
    trained weights are loaded, i.e. the lens initially behaves like the
    plain logit lens.

    Args:
        layers: sequence of model layers; only ``len(layers)`` is used.
        d_model: width of the hidden states being translated.
    """

    def __init__(self, layers: List, d_model: int) -> None:
        super().__init__()

        # Build a single zeroed template and clone it. Every translator gets its
        # own independent parameters, and only one Linear is ever initialised.
        template = torch.nn.Linear(d_model, d_model, bias=True)
        torch.nn.init.zeros_(template.weight)
        torch.nn.init.zeros_(template.bias)

        n_translators = len(layers) - 1
        clones = []
        for _ in range(n_translators):
            clones.append(deepcopy(template))

        self.layer_translators = torch.nn.ModuleList(clones)

# Hidden-state width: last dimension of the first element of block 0's output shape.
d_model_hidden_states = model.backbone.layers[0].output_shape[0][-1]

# Build the zero-initialised translators on the GPU, then fill them from the local
# safetensors checkpoint.
# NOTE(review): accelerate's checkpoint loader may default to non-strict loading, in which
# case mismatched keys are ignored silently and the lens stays all-zero -- confirm.
lens = TunedLens(model.backbone.layers, d_model_hidden_states)
lens = lens.to("cuda")
lens = accelerate.load_checkpoint_and_dispatch(lens, 'model.safetensors')


In [4]:
def logitlens(model: LanguageModel, prompt: str, layers: List, decoding_fn: Callable):
    """Decode every layer's output for ``prompt`` and return the top-1 predictions.

    Inside one nnsight forward trace, ``decoding_fn(layer.output, layer=i)`` is
    called for each entry of ``layers``. The result is softmaxed over its last
    dimension and saved.

    Args:
        model: nnsight-wrapped model to trace.
        prompt: input text.
        layers: the modules whose outputs are decoded, in order.
        decoding_fn: maps a layer output proxy (plus ``layer=`` index) to logits.

    Returns:
        words: for each layer, the top-1 token string at each position. Each
            string is passed through ``unicode_escape``, so control and
            non-ASCII characters appear as escape sequences.
        max_probs: tensor of the top-1 probabilities, per layer row and
            position. NOTE(review): assumes each saved tensor is
            ``(1, seq, vocab)``, giving ``(n_layers, seq)`` here -- confirm.
        input_words: the decoded input tokens of the (first) prompt.
    """
    saved_probs = []

    with model.forward(validate=False) as runner:
        with runner.invoke(prompt, scan=False) as invoker:
            for idx, block in enumerate(layers):
                logits = decoding_fn(block.output, layer=idx)
                saved_probs.append(torch.nn.functional.softmax(logits, dim=-1).save())

    # Stack the per-layer results along dim 0, then take the argmax token per position.
    stacked = torch.concatenate([proxy.value for proxy in saved_probs])
    max_probs, tokens = stacked.max(dim=-1)

    def _escaped(token_id):
        # Make whitespace/control characters visible in the rendered table.
        return model.tokenizer.decode(token_id).encode("unicode_escape").decode()

    words = [list(map(_escaped, row)) for row in tokens]
    input_words = [model.tokenizer.decode(tok) for tok in invoker.input["input_ids"][0]]

    return words, max_probs, input_words


In [5]:


def vis(
    words: List[List[str]],
    probs: List[List[float]],
    input_words: List[str],
    color=(50, 168, 123),
):
    """Render a logit-lens grid as an HTML table via ``baukit.show``.

    Rows are layers (labelled ``L0``, ``L1``, ...) and columns are input
    positions. Each cell shows the top-1 token for that (layer, position).
    Its background is blended linearly from white (probability 0) towards
    ``color`` (probability 1). The header row of input tokens is shown both
    above and below the grid.

    Args:
        words: ``words[layer][position]``, the decoded token strings.
        probs: ``probs[layer][position]``, the matching probabilities, expected
            in [0, 1]. Plain numbers or 0-d tensors both work, for example the
            tensor returned by ``logitlens``.
        input_words: decoded input tokens, one per column.
        color: RGB triple (0-255 per channel) used at probability 1. The
            default is a tuple so the shared default object cannot be mutated.
    """
    header_row = [[" "]] + [
        [
            baukit.show.style(fontWeight="bold", width="50px"),
            baukit.show.attr(title=f"Token {i}"),
            t,
        ]
        for i, t in enumerate(input_words)
    ]

    def color_fn(p):
        # Per-channel linear blend: white at p = 0, `color` at p = 1.
        a = [int(255 * (1 - p) + c * p) for c in color]
        return baukit.show.style(background=f"rgb({a[0]}, {a[1]}, {a[2]})")

    layer_rows = [
        # First column: layer label.
        [[baukit.show.style(fontWeight="bold", width="50px"), f"L{layer_idx}"]]
        + [
            # Remaining columns: one shaded cell per input position.
            [
                color_fn(token_prob),
                # "hidden" is the valid CSS keyword; "hide" is ignored by browsers.
                baukit.show.style(overflowX="hidden", color="black"),
                f"{token_word}",
            ]
            for token_word, token_prob in zip(layer_words, layer_probs)
        ]
        for layer_idx, (layer_words, layer_probs) in enumerate(zip(words, probs))
    ]

    baukit.show([header_row] + layer_rows + [header_row])


In [6]:

def decode(output, layer=0):
    """Map one Mamba block's traced output to vocabulary logits.

    Steps:
    1. Add the two output elements to get the stream.
    2. If a tuned lens is active, apply that layer's translator as a residual
       correction.
    3. Apply the backbone's final RMS norm (``norm_f``) via ``rms_norm_fn``.
    4. Project with ``model.lm_head``.

    Args:
        output: traced output of one ``model.backbone.layers[i]`` block.
            NOTE(review): assumed to be the ``(hidden_states, residual)`` pair
            whose sum is the residual stream after the block -- confirm
            against mamba_ssm's ``Block.forward``.
        layer: index of that block. Selects ``lens.layer_translators[layer]``.
            No translator is applied when the global ``lens`` is None or when
            ``layer`` equals the number of translators (the last block).

    Returns:
        Traced ``model.lm_head`` output for the normalised stream.
    """
    mixer_out, residual = output[0], output[1]
    stream = mixer_out + residual

    translate = lens is not None and layer != len(lens.layer_translators)
    if translate:
        stream = lens.layer_translators[layer](stream) + stream

    final_norm = model.local_model.backbone.norm_f

    # The norm is appended as an explicit node of the trace graph rather than
    # called directly. NOTE(review): presumably because the Triton
    # rms_norm_fn cannot run on a proxy -- confirm.
    normed = stream.node.graph.add(
        target=rms_norm_fn,
        args=[stream, final_norm.weight, final_norm.bias],
        kwargs=dict(
            eps=final_norm.eps,
            residual=None,
            prenorm=False,
            residual_in_fp32=True,
        ),
    )

    return model.lm_head(normed)



In [7]:
# Top-1 decoded token and its probability at every (layer, position) of the prompt.
words, probs, input_words = logitlens(
    model=model,
    prompt="The Eiffel Tower is in the city of",
    layers=model.backbone.layers,
    decoding_fn=decode,
)


You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [8]:
# Render the grid: rows are layers, columns are input tokens, and shading tracks probability.
vis(words=words, probs=probs, input_words=input_words)
