In [1]:
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

# 加载模型和 tokenizer
def load_model_and_tokenizer(model_name, device="cuda"):
    print(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        attn_implementation="eager",
    )
    return model, tokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
device = "cuda" if torch.cuda.is_available() else "cpu"

# 加载模型和 tokenizer
model, tokenizer = load_model_and_tokenizer(model_name, device)

# 加载和分词输入
dataset = load_dataset("emozilla/pg19", split="test", trust_remote_code=True)
text = dataset[0]["text"][:16384]  # 仅取前 16384 个字符
input_ids = tokenizer(text, return_tensors="pt", max_length=16384, truncation=True).input_ids


Loading model: mistralai/Mistral-7B-Instruct-v0.1


Loading checkpoint shards: 100%|██████████| 2/2 [00:21<00:00, 10.56s/it]
Some parameters are on the meta device because they were offloaded to the cpu.


In [3]:
print(text[:100])
print(input_ids[:100])

ST. PAUL***


E-text prepared by Josephine Paolucci and the Project Gutenberg Online
Distributed Pro
tensor([[    1,   920, 28723,  ..., 28725,  3364, 28705]])


In [None]:
# 提取注意力分数
# 模型前向
with torch.no_grad():
    output = model(input_ids.to(device), output_attentions=True)

Loading model: mistralai/Mistral-7B-Instruct-v0.1


Loading checkpoint shards: 100%|██████████| 2/2 [00:24<00:00, 12.39s/it]
Some parameters are on the meta device because they were offloaded to the cpu and disk.
Using the latest cached version of the dataset since emozilla/pg19 couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at C:\Users\fzkuj\.cache\huggingface\datasets\emozilla___pg19\default\0.0.0\c021754c8e01c5b1cc83a1f549c1f97fbbb756b8 (last modified on Sat Dec 28 15:02:36 2024).


In [4]:
print(output.keys())

33


In [None]:
output["attentions"]

In [None]:


# 可视化某一层的注意力分数
def visualize_attention_scores(attention_scores, layer_idx, head_idx, seq_len, output_file=None):
    plt.figure(figsize=(10, 8))
    attention_matrix = attention_scores[layer_idx][0, head_idx, :seq_len, :seq_len]
    plt.imshow(attention_matrix, cmap="viridis", aspect="auto")
    plt.colorbar(label="Attention Score")
    plt.title(f"Attention Layer {layer_idx + 1}, Head {head_idx + 1}")
    plt.xlabel("Key Position")
    plt.ylabel("Query Position")
    if output_file:
        plt.savefig(output_file)
    plt.show()


In [None]:

attention_scores = extract_attention_scores(model, input_ids, device)

# 保存或可视化
layer_idx = 0  # 可视化第几层
head_idx = 0   # 可视化第几个注意力头
seq_len = 1024  # 只显示前 1024 长度的分数（避免图太大）
visualize_attention_scores(attention_scores, layer_idx, head_idx, seq_len, output_file="attention_layer_1_head_1.png")

# 保存完整注意力分数到文件
torch.save(attention_scores, "attention_scores.pt")
print("Attention scores saved to 'attention_scores.pt'")