In [6]:
from nnsight import LanguageModel
from typing import List, Callable
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from IPython.display import clear_output

clear_output()

# RoBERTa Modell laden
model = LanguageModel("roberta-base", device_map="auto", dispatch=True)

prompt = "Ich bin ein Mensch"
layers = model.model.encoder.layer  # RoBERTa verwendet encoder.layer
probs_layers = []

with model.trace() as tracer:
    with tracer.invoke(prompt) as invoker:
        for layer_idx, layer in enumerate(layers):
            # Layer-Verarbeitung und Normalisierung
            layer_output = model.lm_head(
                model.model.encoder.layernorm(layer.output[0])
            )
            
            # Wahrscheinlichkeiten berechnen
            probs = torch.nn.functional.softmax(layer_output, dim=-1).save()
            probs_layers.append(probs)

# Wahrscheinlichkeiten zusammenfügen
probs = torch.cat([probs.value for probs in probs_layers])

# Maximum Wahrscheinlichkeiten und Token-IDs finden
max_probs, tokens = probs.max(dim=-1)

# Token-IDs in Wörter umwandeln
words = [[model.tokenizer.decode(t).encode("unicode_escape").decode() for t in layer_tokens]
    for layer_tokens in tokens]

# Input-Wörter erhalten
input_words = [model.tokenizer.decode(t) for t in invoker.inputs[0]["input_ids"][0]]

# Visualisierung
output_words = input_words[1:] + ["?"]
cmap = sns.diverging_palette(255, 0, n=len(words[0]), as_cmap=True)

plt.figure(figsize=(10, 6))
ax = sns.heatmap(max_probs.cpu().detach().numpy(), 
                 annot=np.array(words), 
                 fmt='', 
                 cmap=cmap, 
                 linewidths=.5, 
                 cbar_kws={'label': 'Probability'})

plt.title('RoBERTa Logit Lens Visualization')
plt.xlabel('Input Tokens')
plt.ylabel('Layers')

plt.yticks(np.arange(len(words)) + 0.5, range(len(words)))

plt.gca().xaxis.tick_top()
plt.gca().xaxis.set_label_position("top")
plt.xticks(np.arange(len(input_words)) + 0.5, input_words, rotation=45)

plt.show()

If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`


ValueError: BertLMHeadModel does not support `device_map='auto'`. To implement support, the model class needs to implement the `_no_split_modules` attribute.