# Inseq: Visually Explaining LLM Outputs

> [HW7: Understand what Generative AI is thinking](https://colab.research.google.com/drive/1Xnz0GHC0yWO2Do0aAYBCq9zL45lbiRjM?usp=sharing#scrollTo=UFOUfh2k1jFN) (Chinese mirror version)
>
> Guide article: [12. Inseq Feature Attribution: Visually Explaining LLM Outputs](https://github.com/Hoper-J/LLM-Guide-and-Demos-zh_CN/blob/master/Guide/12.%20Inseq%20特征归因：可视化解释%20LLM%20的输出.md)

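Feature attribution can be thought of as an explanation of a model's output: just as an image classification model can visualize the regions it focuses on, an LLM's output can be visualized in terms of which input tokens contributed most to each generated token.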
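As a rough intuition for the `saliency` method used below: gradient-based saliency asks how sensitive the score of one output token is to each input token's embedding, then condenses that gradient into one number per input token. The sketch below is illustrative only. It assumes a made-up toy model (`toy_target_prob`, hypothetical) and uses central finite differences instead of autograd; Inseq itself computes gradients through the real model via Captum, and its exact target function and aggregation may differ.

```python
import math


def toy_target_prob(embeddings, weights):
    """Hypothetical toy model: probability of one target token given input embeddings.

    Each input position i contributes tanh(w_i . e_i); the sum is squashed by a sigmoid.
    """
    logit = sum(
        math.tanh(sum(w * x for w, x in zip(weights[i], emb)))
        for i, emb in enumerate(embeddings)
    )
    return 1.0 / (1.0 + math.exp(-logit))


def saliency(f, embeddings, eps=1e-5):
    """Gradient-norm saliency per input token, using central finite differences."""
    scores = []
    for i, emb in enumerate(embeddings):
        grad = []
        for d in range(len(emb)):
            plus = [list(e) for e in embeddings]
            minus = [list(e) for e in embeddings]
            plus[i][d] += eps
            minus[i][d] -= eps
            grad.append((f(plus) - f(minus)) / (2 * eps))
        # Reduce the gradient vector to one score per token: its L2 norm
        scores.append(math.sqrt(sum(g * g for g in grad)))
    total = sum(scores)
    # Normalize so the scores for this target token sum to 1
    return [s / total for s in scores] if total > 0 else scores


embeddings = [[0.5, -0.2], [0.1, 0.3], [0.9, 0.4]]  # 3 input tokens, 2-dim embeddings
weights = [[1.0, 0.0], [0.0, 0.0], [2.0, -1.0]]     # token 1 has no influence at all
scores = saliency(lambda e: toy_target_prob(e, weights), embeddings)
print([round(s, 3) for s in scores])  # [0.619, 0.0, 0.381]
```

Token 1 gets exactly zero because the toy model ignores it. Note that token 2 has the larger weight vector yet the smaller score, because its `tanh` is closer to saturation; saliency measures local sensitivity, not overall importance.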

Online links: [Kaggle](https://www.kaggle.com/code/aidemos/10-inseq-llm) | [Colab](https://colab.research.google.com/drive/1bWqGtRaG3aO7Vo149wIPHaz_XKnbJqlE?usp=sharing)


In [None]:
!uv add inseq
!uv add transformers
!uv add bitsandbytes
!uv add accelerate
!uv add sacremoses

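## Visualizing a Translation Task

### Download the Chinese-to-English model locally

Here we download with multiple threads (`aria2c` with `-x 16`) to speed things up.

If running the commands below fails, first install the prerequisites by following [a. Using HFD to Speed Up Downloads of Hugging Face Models and Datasets](https://github.com/Hoper-J/LLM-Guide-and-Demos-zh_CN/blob/master/Guide/a.%20使用%20HFD%20加快%20Hugging%20Face%20模型和数据集的下载.md).

Alternatively, you can uncomment the line I commented out in the loading code below (and comment out the local one) to fetch the model directly from the official Hugging Face Hub, but that will be slow.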

In [None]:
!wget https://hf-mirror.com/hfd/hfd.sh
!chmod a+x hfd.sh

In [None]:
# %env sets the variable for the whole kernel; "!export" would only affect a throwaway subshell
%env HF_ENDPOINT=https://hf-mirror.com
!./hfd.sh 'Helsinki-NLP/opus-mt-zh-en' --tool aria2c -x 16

In [1]:
import inseq

# List of attribution methods to use
attribution_methods = ['saliency', 'attention']

for method in attribution_methods:
    print(f"======= Attribution method: {method} =======")
    # Load the Chinese-to-English model with the given attribution method
    # model = inseq.load_model("Helsinki-NLP/opus-mt-zh-en", method)
    model = inseq.load_model("opus-mt-zh-en", method)  # load the model downloaded locally above

    # Run attribution on the input text with the chosen method
    attribution_result = model.attribute(
        input_texts="我喜歡機器學習和人工智慧。",
    )

    # Strip the '▁' word-boundary prefix added by the SentencePiece tokenizer to avoid confusion (you can ignore this block)
    for attr in attribution_result.sequence_attributions:
        for item in attr.source:
            item.token = item.token.replace('▁', '')
        for item in attr.target:
            item.token = item.token.replace('▁', '')

    # Display the attribution results
    attribution_result.show()



Attributing with saliency...: 100%|██████████████████████████████████████| 10/10 [00:00<00:00, 69.63it/s]


| Source / Target | I | like | machine | learning | and | artificial | intelligence | . | `</s>` |
|---|---|---|---|---|---|---|---|---|---|
| 我喜歡 | 0.23 | 0.438 | 0.084 | 0.064 | 0.183 | 0.052 | 0.058 | 0.25 | 0.193 |
| 機器 | 0.162 | 0.135 | 0.491 | 0.225 | 0.13 | 0.069 | 0.07 | 0.141 | 0.151 |
| 學 | 0.074 | 0.054 | 0.151 | 0.188 | 0.07 | 0.03 | 0.032 | 0.069 | 0.061 |
| 習 | 0.07 | 0.078 | 0.106 | 0.267 | 0.118 | 0.052 | 0.055 | 0.083 | 0.064 |
| 和 | 0.067 | 0.047 | 0.026 | 0.05 | 0.158 | 0.036 | 0.03 | 0.073 | 0.073 |
| 人工 | 0.099 | 0.052 | 0.039 | 0.063 | 0.104 | 0.399 | 0.203 | 0.091 | 0.114 |
| 智慧 | 0.113 | 0.068 | 0.039 | 0.073 | 0.109 | 0.274 | 0.437 | 0.111 | 0.153 |
| 。 | 0.119 | 0.066 | 0.02 | 0.027 | 0.083 | 0.045 | 0.057 | 0.118 | 0.124 |
| `</s>` | 0.066 | 0.063 | 0.044 | 0.043 | 0.046 | 0.043 | 0.058 | 0.064 | 0.068 |
| probability | 0.767 | 0.589 | 0.425 | 0.671 | 0.856 | 0.647 | 0.866 | 0.891 | 0.896 |




Attributing with attention...: 100%|████████████████████████████████████| 10/10 [00:00<00:00, 233.77it/s]


| Source / Target | I | like | machine | learning | and | artificial | intelligence | . | `</s>` |
|---|---|---|---|---|---|---|---|---|---|
| 我喜歡 | 0.19 | 0.517 | 0.28 | 0.018 | 0.048 | 0.032 | 0.014 | 0.049 | 0.023 |
| 機器 | 0.084 | 0.037 | 0.233 | 0.208 | 0.038 | 0.036 | 0.024 | 0.028 | 0.015 |
| 學 | 0.012 | 0.01 | 0.044 | 0.129 | 0.105 | 0.018 | 0.006 | 0.017 | 0.012 |
| 習 | 0.021 | 0.015 | 0.066 | 0.155 | 0.109 | 0.024 | 0.009 | 0.024 | 0.007 |
| 和 | 0.058 | 0.042 | 0.063 | 0.039 | 0.152 | 0.178 | 0.015 | 0.034 | 0.036 |
| 人工 | 0.023 | 0.015 | 0.022 | 0.027 | 0.025 | 0.26 | 0.183 | 0.033 | 0.016 |
| 智慧 | 0.029 | 0.017 | 0.033 | 0.046 | 0.04 | 0.132 | 0.365 | 0.155 | 0.013 |
| 。 | 0.153 | 0.087 | 0.047 | 0.019 | 0.059 | 0.076 | 0.037 | 0.163 | 0.363 |
| `</s>` | 0.43 | 0.262 | 0.212 | 0.359 | 0.425 | 0.245 | 0.348 | 0.496 | 0.516 |
| probability | 0.767 | 0.589 | 0.425 | 0.671 | 0.856 | 0.647 | 0.866 | 0.891 | 0.896 |
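To read the saliency table programmatically, here is a minimal plain-Python sketch that takes the saliency scores printed above (copied by hand, so they are rounded) and, for each target token, picks the source token with the highest score. `top_source_per_target` is a hypothetical helper written for this illustration, not part of the Inseq API.

```python
# Saliency scores copied from the saliency table above
# (rows: source tokens, columns: target tokens; each column sums to about 1)
source_tokens = ["我喜歡", "機器", "學", "習", "和", "人工", "智慧", "。", "</s>"]
target_tokens = ["I", "like", "machine", "learning", "and", "artificial", "intelligence", ".", "</s>"]
saliency = [
    [0.23, 0.438, 0.084, 0.064, 0.183, 0.052, 0.058, 0.25, 0.193],
    [0.162, 0.135, 0.491, 0.225, 0.13, 0.069, 0.07, 0.141, 0.151],
    [0.074, 0.054, 0.151, 0.188, 0.07, 0.03, 0.032, 0.069, 0.061],
    [0.07, 0.078, 0.106, 0.267, 0.118, 0.052, 0.055, 0.083, 0.064],
    [0.067, 0.047, 0.026, 0.05, 0.158, 0.036, 0.03, 0.073, 0.073],
    [0.099, 0.052, 0.039, 0.063, 0.104, 0.399, 0.203, 0.091, 0.114],
    [0.113, 0.068, 0.039, 0.073, 0.109, 0.274, 0.437, 0.111, 0.153],
    [0.119, 0.066, 0.02, 0.027, 0.083, 0.045, 0.057, 0.118, 0.124],
    [0.066, 0.063, 0.044, 0.043, 0.046, 0.043, 0.058, 0.064, 0.068],
]


def top_source_per_target(matrix, sources, targets):
    """For each target token, return the source token with the highest score."""
    result = {}
    for j, tgt in enumerate(targets):
        column = [row[j] for row in matrix]
        best = max(range(len(sources)), key=lambda i: column[i])
        result[tgt] = sources[best]
    return result


alignment = top_source_per_target(saliency, source_tokens, target_tokens)
for tgt, src in alignment.items():
    print(f"{tgt:>12} <- {src}")
```

On these numbers, `machine`, `artificial` and `intelligence` line up with 機器, 人工 and 智慧, while `and` and `.` receive their largest share from 我喜歡 rather than 和 or 。, a reminder that the highest attribution is not the same thing as a word alignment.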


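## Visualizing a Text Generation Task

### Download the GPT-2 XL model

If you only want to try it out, uncomment the line that downloads the smaller GPT-2 (and comment out the `gpt2-xl` line), because GPT-2 XL takes up about 30 GB once downloaded. In that case, also load `"gpt2"` instead of `"gpt2-xl"` in the next cell.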

In [None]:
# %env sets the variable for the whole kernel; "!export" would only affect a throwaway subshell
%env HF_ENDPOINT=https://hf-mirror.com
!./hfd.sh 'gpt2-xl' --tool aria2c -x 16
#!./hfd.sh 'gpt2' --tool aria2c -x 16

In [2]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

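# Create a BitsAndBytesConfig object to configure quantization (load weights in 8-bit)
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    "gpt2-xl",  # local directory from hfd; use "gpt2" if you downloaded the smaller model
    quantization_config=bnb_config,
    device_map={"": 0}  # place the whole model on GPU 0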
)
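The 8-bit config above matters because weight memory scales with the number of bytes per parameter. A back-of-the-envelope sketch, assuming GPT-2 XL has roughly 1.5 billion parameters and counting only the weights (no activations, no CUDA overhead):

```python
def weight_memory_gb(num_params, bytes_per_param):
    """Approximate memory needed just for the weights, in GB (10**9 bytes)."""
    return num_params * bytes_per_param / 1e9


GPT2_XL_PARAMS = 1.5e9  # assumption: approximate parameter count of GPT-2 XL

for name, nbytes in [("fp32", 4), ("fp16", 2), ("int8", 1)]:
    print(f"{name}: ~{weight_memory_gb(GPT2_XL_PARAMS, nbytes):.1f} GB")
# fp32: ~6.0 GB
# fp16: ~3.0 GB
# int8: ~1.5 GB
```

The real figure will differ somewhat, since bitsandbytes quantizes only the linear layers and leaves modules such as embeddings and layer norms in higher precision; after loading, `model.get_memory_footprint()` from transformers reports the actual size of the loaded model.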

In [4]:
import inseq

# List of attribution methods to use
attribution_methods = ['saliency', 'attention']

for method in attribution_methods:
    print(f"======= Attribution method: {method} =======")
    # Wrap the already-loaded model with the given attribution method
    inseq_model = inseq.load_model(model, method)

    # Run attribution on the input text, also recording each generated token's probability
    attribution_result = inseq_model.attribute(
        input_texts="Hello world",
        step_scores=["probability"],
    )

    # Strip the 'Ġ' space marker used by GPT-2's byte-level BPE tokenizer (optional)
    for attr in attribution_result.sequence_attributions:
        for item in attr.source:
            item.token = item.token.replace('Ġ', '')
        for item in attr.target:
            item.token = item.token.replace('Ġ', '')

    # Display the attribution results
    attribution_result.show()

The model is loaded in 8bit mode. The device cannot be changed after loading the model.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.




Attributing with saliency...: 100%|██████████████████████████████████████| 20/20 [00:04<00:00,  3.89it/s]


| Source / Target | " | Ċ | Ċ | The | first | thing | you | 'll | notice | is | that | the | code | is | a | bit | more | verb |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Hello | 0.647 | 0.505 | 0.359 | 0.401 | 0.388 | 0.344 | 0.35 | 0.261 | 0.261 | 0.302 | 0.305 | 0.232 | 0.368 | 0.224 | 0.283 | 0.236 | 0.194 | 0.241 |
| world | 0.353 | 0.315 | 0.223 | 0.238 | 0.249 | 0.2 | 0.242 | 0.176 | 0.169 | 0.212 | 0.197 | 0.165 | 0.246 | 0.149 | 0.164 | 0.153 | 0.122 | 0.152 |
| " | | 0.18 | 0.183 | 0.159 | 0.133 | 0.128 | 0.103 | 0.093 | 0.084 | 0.096 | 0.073 | 0.071 | 0.043 | 0.058 | 0.064 | 0.068 | 0.047 | 0.052 |
| Ċ | | | 0.235 | 0.092 | 0.058 | 0.056 | 0.044 | 0.045 | 0.044 | 0.042 | 0.039 | 0.065 | 0.036 | 0.042 | 0.043 | 0.033 | 0.025 | 0.027 |
| Ċ | | | | 0.11 | 0.058 | 0.062 | 0.046 | 0.053 | 0.052 | 0.041 | 0.035 | 0.04 | 0.025 | 0.037 | 0.027 | 0.028 | 0.025 | 0.024 |
| The | | | | | 0.114 | 0.072 | 0.066 | 0.059 | 0.045 | 0.04 | 0.045 | 0.041 | 0.02 | 0.041 | 0.03 | 0.037 | 0.025 | 0.027 |
| first | | | | | | 0.138 | 0.048 | 0.074 | 0.06 | 0.046 | 0.047 | 0.04 | 0.021 | 0.051 | 0.027 | 0.039 | 0.037 | 0.027 |
| thing | | | | | | | 0.101 | 0.101 | 0.087 | 0.048 | 0.048 | 0.053 | 0.028 | 0.043 | 0.037 | 0.048 | 0.037 | 0.032 |
| you | | | | | | | | 0.137 | 0.082 | 0.047 | 0.048 | 0.049 | 0.028 | 0.043 | 0.044 | 0.054 | 0.036 | 0.035 |
| 'll | | | | | | | | | 0.117 | 0.047 | 0.051 | 0.064 | 0.032 | 0.052 | 0.049 | 0.063 | 0.043 | 0.042 |
| notice | | | | | | | | | | 0.079 | 0.074 | 0.084 | 0.055 | 0.069 | 0.06 | 0.063 | 0.065 | 0.048 |
| is | | | | | | | | | | | 0.037 | 0.036 | 0.02 | 0.022 | 0.026 | 0.018 | 0.028 | 0.022 |
| that | | | | | | | | | | | | 0.061 | 0.022 | 0.031 | 0.025 | 0.021 | 0.031 | 0.02 |
| the | | | | | | | | | | | | | 0.056 | 0.029 | 0.027 | 0.023 | 0.026 | 0.027 |
| code | | | | | | | | | | | | | | 0.111 | 0.058 | 0.055 | 0.082 | 0.088 |
| is | | | | | | | | | | | | | | | 0.036 | 0.016 | 0.025 | 0.014 |
| a | | | | | | | | | | | | | | | | 0.044 | 0.027 | 0.013 |
| bit | | | | | | | | | | | | | | | | | 0.125 | 0.048 |
| more | | | | | | | | | | | | | | | | | | 0.059 |
| verb | | | | | | | | | | | | | | | | | | |
| probability | 0.197 | 0.097 | 0.988 | 0.043 | 0.046 | 0.211 | 0.25 | 0.248 | 0.52 | 0.557 | 0.662 | 0.257 | 0.046 | 0.483 | 0.106 | 0.438 | 0.207 | 0.259 |
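The `Ċ` rows and columns in the table above are newlines, not corruption. GPT-2's tokenizer is byte-level BPE: before merging, every byte is remapped to a printable Unicode character, so a space shows up as `Ġ` (which the cleanup loop strips) and a newline as `Ċ` (which it leaves alone). The sketch below reproduces that byte-to-character table following the `bytes_to_unicode` function in OpenAI's original GPT-2 encoder (the transformers GPT-2 tokenizer uses the same mapping); `decode_token` is a hypothetical helper added for illustration.

```python
def bytes_to_unicode():
    """Map every byte 0-255 to a printable Unicode character (GPT-2 style)."""
    # Bytes that are already printable keep their own character.
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    # Remaining bytes (control characters, space, ...) are shifted to 256 + n.
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return dict(zip(bs, map(chr, cs)))


byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}


def decode_token(token):
    """Turn a byte-level BPE token string back into readable text."""
    return bytes(byte_decoder[ch] for ch in token).decode("utf-8", errors="replace")


print(byte_encoder[ord(" ")], byte_encoder[ord("\n")])  # Ġ Ċ
print(repr(decode_token("Ġnotice")), repr(decode_token("Ċ")))  # ' notice' '\n'
```

With a loaded Hugging Face tokenizer, `tokenizer.convert_tokens_to_string(["Ċ"])` should perform the same decoding, which could replace the `.replace('Ġ', '')` cleanup if you also want newlines rendered as such.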


The model is loaded in 8bit mode. The device cannot be changed after loading the model.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.




Attributing with attention...: 100%|█████████████████████████████████████| 20/20 [00:01<00:00, 14.14it/s]


| Source / Target | " | Ċ | Ċ | The | first | thing | you | 'll | notice | is | that | the | code | is | a | bit | more | verb |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Hello | 0.867 | 0.758 | 0.675 | 0.619 | 0.619 | 0.636 | 0.611 | 0.626 | 0.623 | 0.586 | 0.518 | 0.527 | 0.551 | 0.571 | 0.555 | 0.577 | 0.563 | 0.556 |
| world | 0.133 | 0.109 | 0.054 | 0.068 | 0.055 | 0.038 | 0.041 | 0.03 | 0.024 | 0.028 | 0.037 | 0.04 | 0.034 | 0.024 | 0.015 | 0.015 | 0.013 | 0.012 |
| " | | 0.133 | 0.086 | 0.084 | 0.062 | 0.048 | 0.043 | 0.037 | 0.031 | 0.03 | 0.035 | 0.037 | 0.035 | 0.025 | 0.019 | 0.018 | 0.016 | 0.016 |
| Ċ | | | 0.185 | 0.114 | 0.092 | 0.075 | 0.068 | 0.061 | 0.054 | 0.059 | 0.058 | 0.052 | 0.046 | 0.042 | 0.038 | 0.034 | 0.032 | 0.033 |
| Ċ | | | | 0.116 | 0.093 | 0.072 | 0.065 | 0.049 | 0.041 | 0.043 | 0.056 | 0.054 | 0.041 | 0.03 | 0.024 | 0.021 | 0.02 | 0.02 |
| The | | | | | 0.08 | 0.078 | 0.061 | 0.041 | 0.034 | 0.033 | 0.043 | 0.036 | 0.031 | 0.025 | 0.019 | 0.017 | 0.016 | 0.016 |
| first | | | | | | 0.055 | 0.059 | 0.037 | 0.031 | 0.032 | 0.026 | 0.022 | 0.02 | 0.014 | 0.013 | 0.012 | 0.012 | 0.011 |
| thing | | | | | | | 0.052 | 0.058 | 0.045 | 0.041 | 0.027 | 0.02 | 0.017 | 0.014 | 0.013 | 0.011 | 0.011 | 0.009 |
| you | | | | | | | | 0.062 | 0.06 | 0.04 | 0.03 | 0.022 | 0.018 | 0.018 | 0.013 | 0.011 | 0.01 | 0.009 |
| 'll | | | | | | | | | 0.057 | 0.048 | 0.027 | 0.02 | 0.015 | 0.017 | 0.014 | 0.011 | 0.01 | 0.009 |
| notice | | | | | | | | | | 0.059 | 0.065 | 0.04 | 0.028 | 0.025 | 0.025 | 0.017 | 0.018 | 0.016 |
| is | | | | | | | | | | | 0.08 | 0.054 | 0.031 | 0.026 | 0.03 | 0.019 | 0.02 | 0.018 |
| that | | | | | | | | | | | | 0.075 | 0.063 | 0.048 | 0.041 | 0.03 | 0.032 | 0.028 |
| the | | | | | | | | | | | | | 0.071 | 0.059 | 0.039 | 0.03 | 0.029 | 0.026 |
| code | | | | | | | | | | | | | | 0.062 | 0.071 | 0.052 | 0.048 | 0.045 |
| is | | | | | | | | | | | | | | | 0.069 | 0.059 | 0.058 | 0.048 |
| a | | | | | | | | | | | | | | | | 0.067 | 0.049 | 0.038 |
| bit | | | | | | | | | | | | | | | | | 0.044 | 0.042 |
| more | | | | | | | | | | | | | | | | | | 0.049 |
| verb | | | | | | | | | | | | | | | | | | |
| probability | 0.197 | 0.097 | 0.988 | 0.043 | 0.046 | 0.211 | 0.25 | 0.248 | 0.52 | 0.557 | 0.662 | 0.257 | 0.046 | 0.483 | 0.106 | 0.438 | 0.207 | 0.259 |
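In both GPT-2 tables, each row below `world` starts with blank cells. GPT-2 is a causal language model, so the j-th generated token can only be attributed to the prompt and to the tokens generated before it. A small plain-Python sketch (the helper `valid_cells` is hypothetical, for illustration only) that marks which source/target cells can hold a score for the 2-token prompt `Hello world` and the 18 generated tokens reproduces the triangular shape of the tables:

```python
def valid_cells(prompt_len, gen_len):
    """Return a (source x target) grid: True where an attribution score can exist.

    Sources are the prompt tokens followed by the generated tokens; target j is the
    j-th generated token, which can only depend on the prompt plus generated tokens 0..j-1.
    """
    return [
        [i < prompt_len + j for j in range(gen_len)]
        for i in range(prompt_len + gen_len)
    ]


tokens = ["Hello", "world", '"', "Ċ", "Ċ", "The", "first", "thing", "you", "'ll",
          "notice", "is", "that", "the", "code", "is", "a", "bit", "more", "verb"]
grid = valid_cells(prompt_len=2, gen_len=18)
for tok, row in zip(tokens, grid):
    # '#' = score can exist, '.' = blank cell in the table
    print(f"{tok:>8} {''.join('#' if ok else '.' for ok in row)}")
```

Row k (counting from 0) has k - 1 blank cells once k reaches the generated tokens, which is why the last generated token, `verb`, has no scores at all: nothing was generated after it.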
