# Inseq: Visually Explaining LLM Outputs

> [HW7: Understand what Generative AI is thinking](https://colab.research.google.com/drive/1Xnz0GHC0yWO2Do0aAYBCq9zL45lbiRjM?usp=sharing#scrollTo=UFOUfh2k1jFN) (Chinese mirror)
>
> Companion guide: [12. Inseq Feature Attribution: Visually Explaining LLM Outputs](https://github.com/Hoper-J/LLM-Guide-and-Demos-zh_CN/blob/master/Guide/12.%20Inseq%20特征归因：可视化解释%20LLM%20的输出.md)

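Feature attribution can be read as an explanation of a model's output. Just as an image classifier can highlight the image regions it relied on, an LLM's output can be traced back to the input tokens that most influenced each generated token.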
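To make the idea concrete, here is a minimal, framework-free sketch of gradient-based saliency, one of the two methods used below. It is not Inseq's implementation: `toy_score` is a hypothetical stand-in for a next-token logit, the gradient is estimated with central finite differences instead of autograd, and the per-token scores are taken as the L2 norm of the gradient and normalized to sum to 1 (an assumption that mirrors how each column of the tables below sums to roughly 1).

```python
import math


def toy_score(embeddings):
    # Hypothetical stand-in for "logit of the next token": a fixed weighted,
    # nonlinear function of the input token embeddings.
    weights = [2.0, 0.5, 0.1]  # one weight per input position (assumed values)
    return sum(w * math.tanh(sum(vec)) for w, vec in zip(weights, embeddings))


def saliency(score_fn, embeddings, eps=1e-5):
    """Per-token saliency: L2 norm of d(score)/d(embedding), normalized to sum to 1."""
    norms = []
    for i, vec in enumerate(embeddings):
        grad = []
        for j in range(len(vec)):
            # Central finite difference on one coordinate of token i's embedding
            plus = [list(v) for v in embeddings]
            minus = [list(v) for v in embeddings]
            plus[i][j] += eps
            minus[i][j] -= eps
            grad.append((score_fn(plus) - score_fn(minus)) / (2 * eps))
        norms.append(math.sqrt(sum(g * g for g in grad)))
    total = sum(norms)
    return [n / total for n in norms]


embeddings = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.4]]  # 3 tokens, 2-dim embeddings
print([round(s, 3) for s in saliency(toy_score, embeddings)])
```

The first token gets by far the largest share because the toy score weights it most heavily; a real model's gradients play the same role.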

```python
# The `!` prefix runs a shell command from a notebook cell
!pip install inseq==0.5.0
!pip install transformers==4.40.2
!pip install bitsandbytes==0.43.1
!pip install -U accelerate==0.28.0
```
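Since a later cell loads the model in 8-bit through `bitsandbytes` and `accelerate`, it can help to confirm which versions the running kernel actually sees. A small sketch using the standard library's `importlib.metadata`; the `check_versions` helper and the `PINNED` mapping are hypothetical names introduced here.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned by the install cell above
PINNED = {
    "inseq": "0.5.0",
    "transformers": "4.40.2",
    "bitsandbytes": "0.43.1",
    "accelerate": "0.28.0",
}


def check_versions(pinned):
    """Return {package: (installed_version_or_None, matches_pin)}."""
    report = {}
    for name, wanted in pinned.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # package not installed in this environment
        report[name] = (installed, installed == wanted)
    return report


for name, (installed, ok) in check_versions(PINNED).items():
    print(f"{name:<13} installed={installed} {'OK' if ok else 'MISMATCH'}")
```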


## Visualizing a Translation Task

```python
import inseq

# Attribution methods to compare
attribution_methods = ['saliency', 'attention']

for method in attribution_methods:
    print(f"======= Attribution method: {method} =======")
    # Load the Chinese-to-English translation model with the chosen attribution method
    model = inseq.load_model("Helsinki-NLP/opus-mt-zh-en", method)
    # model = inseq.load_model("opus-mt-zh-en", method)  # local copy

    # Attribute the translation of the input text with the chosen method
    # (input: "Learned a lot today, I'm very happy.")
    attribution_result = model.attribute(
        input_texts="今天学到了很多，我很开心。",
        step_scores=["probability"],
    )

    # Strip the SentencePiece '▁' word-boundary marker from the tokens so the
    # table is easier to read (purely cosmetic; safe to skip)
    for attr in attribution_result.sequence_attributions:
        for item in attr.source:
            item.token = item.token.replace('▁', '')
        for item in attr.target:
            item.token = item.token.replace('▁', '')

    # Display the attribution result
    attribution_result.show()
```



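Saliency output. Rows are the source (Chinese) tokens and columns are the generated English tokens; each column sums to about 1. The last row, `probability`, is the step score requested with `step_scores=["probability"]`: the probability the model assigned to each generated token.

| Source | I | ' | ve | learned | a | lot | today | , | and | I | ' | m | happy | . | `</s>` |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 今天 | 0.117 | 0.138 | 0.102 | 0.119 | 0.107 | 0.114 | 0.297 | 0.088 | 0.105 | 0.075 | 0.078 | 0.077 | 0.059 | 0.093 | 0.127 |
| 学到 | 0.274 | 0.244 | 0.231 | 0.397 | 0.272 | 0.269 | 0.206 | 0.172 | 0.168 | 0.182 | 0.143 | 0.159 | 0.144 | 0.193 | 0.208 |
| 了很多 | 0.144 | 0.101 | 0.115 | 0.113 | 0.24 | 0.273 | 0.116 | 0.118 | 0.105 | 0.075 | 0.07 | 0.074 | 0.071 | 0.102 | 0.096 |
| , | 0.061 | 0.048 | 0.055 | 0.031 | 0.036 | 0.045 | 0.047 | 0.105 | 0.071 | 0.039 | 0.034 | 0.03 | 0.029 | 0.04 | 0.063 |
| 我 | 0.08 | 0.07 | 0.077 | 0.045 | 0.044 | 0.042 | 0.044 | 0.112 | 0.094 | 0.099 | 0.084 | 0.065 | 0.07 | 0.072 | 0.062 |
| 很开心 | 0.174 | 0.21 | 0.266 | 0.171 | 0.147 | 0.133 | 0.137 | 0.222 | 0.246 | 0.372 | 0.427 | 0.44 | 0.473 | 0.32 | 0.188 |
| 。 | 0.092 | 0.112 | 0.092 | 0.066 | 0.09 | 0.072 | 0.094 | 0.111 | 0.129 | 0.096 | 0.096 | 0.093 | 0.093 | 0.111 | 0.167 |
| `</s>` | 0.059 | 0.076 | 0.063 | 0.058 | 0.064 | 0.054 | 0.058 | 0.073 | 0.083 | 0.062 | 0.069 | 0.064 | 0.061 | 0.069 | 0.091 |
| probability | 0.44 | 0.396 | 0.267 | 0.725 | 0.65 | 0.826 | 0.722 | 0.342 | 0.551 | 0.817 | 0.762 | 0.888 | 0.491 | 0.788 | 0.895 |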
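One way to read the saliency table is column by column: for each generated English token, find the source token with the highest score. The sketch below does this for four columns whose values are copied from the saliency table; `top_source` is a hypothetical helper, not part of Inseq.

```python
# Source tokens (rows) and four target columns copied from the saliency table
source_tokens = ["今天", "学到", "了很多", ",", "我", "很开心", "。", "</s>"]
saliency_columns = {
    "learned": [0.119, 0.397, 0.113, 0.031, 0.045, 0.171, 0.066, 0.058],
    "lot":     [0.114, 0.269, 0.273, 0.045, 0.042, 0.133, 0.072, 0.054],
    "today":   [0.297, 0.206, 0.116, 0.047, 0.044, 0.137, 0.094, 0.058],
    "happy":   [0.059, 0.144, 0.071, 0.029, 0.070, 0.473, 0.093, 0.061],
}


def top_source(column, tokens):
    """Return the source token with the highest attribution score in one column."""
    best = max(range(len(column)), key=lambda i: column[i])
    return tokens[best]


for target, column in saliency_columns.items():
    print(f"{target} <- {top_source(column, source_tokens)}")
```

This prints `learned <- 学到`, `lot <- 了很多`, `today <- 今天` and `happy <- 很开心`. For `lot`, 了很多 (0.273) only narrowly beats 学到 (0.269), so the argmax alone can hide how spread out an attribution is.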




Attention output. Rows are the source (Chinese) tokens and columns are the generated English tokens; the last row is the `probability` step score.

| Source | I | ' | ve | learned | a | lot | today | , | and | I | ' | m | happy | . | `</s>` |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 今天 | 0.09 | 0.073 | 0.054 | 0.071 | 0.057 | 0.055 | 0.191 | 0.16 | 0.038 | 0.023 | 0.023 | 0.018 | 0.02 | 0.015 | 0.017 |
| 学到 | 0.184 | 0.211 | 0.085 | 0.245 | 0.217 | 0.055 | 0.063 | 0.033 | 0.028 | 0.023 | 0.03 | 0.019 | 0.023 | 0.02 | 0.013 |
| 了很多 | 0.12 | 0.092 | 0.073 | 0.117 | 0.257 | 0.343 | 0.205 | 0.023 | 0.029 | 0.021 | 0.016 | 0.016 | 0.021 | 0.015 | 0.014 |
| , | 0.08 | 0.046 | 0.031 | 0.022 | 0.022 | 0.024 | 0.045 | 0.062 | 0.133 | 0.079 | 0.022 | 0.015 | 0.013 | 0.009 | 0.055 |
| 我 | 0.053 | 0.145 | 0.061 | 0.039 | 0.017 | 0.018 | 0.024 | 0.098 | 0.123 | 0.133 | 0.184 | 0.082 | 0.115 | 0.025 | 0.021 |
| 很开心 | 0.061 | 0.18 | 0.123 | 0.138 | 0.063 | 0.044 | 0.061 | 0.085 | 0.168 | 0.23 | 0.356 | 0.236 | 0.461 | 0.248 | 0.037 |
| 。 | 0.088 | 0.048 | 0.059 | 0.036 | 0.028 | 0.045 | 0.032 | 0.071 | 0.115 | 0.117 | 0.052 | 0.054 | 0.04 | 0.092 | 0.279 |
| `</s>` | 0.324 | 0.205 | 0.515 | 0.332 | 0.34 | 0.415 | 0.379 | 0.469 | 0.367 | 0.374 | 0.317 | 0.561 | 0.307 | 0.576 | 0.565 |
| probability | 0.44 | 0.396 | 0.267 | 0.725 | 0.65 | 0.826 | 0.722 | 0.342 | 0.551 | 0.817 | 0.762 | 0.888 | 0.491 | 0.788 | 0.895 |
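Both translation tables end with the same `probability` row, because that step score depends only on the model's output distribution, not on the attribution method. As a rough sketch, assuming the step score is the softmax probability of the generated token at each step, it can be computed from a vector of next-token logits like this (the logits here are made up):

```python
import math


def token_probability(logits, token_id):
    """Softmax probability of `token_id` given raw next-token logits."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    return exps[token_id] / sum(exps)


# Made-up logits over a tiny 4-token vocabulary
logits = [2.0, 0.5, -1.0, 0.0]
print(round(token_probability(logits, 0), 3))  # → 0.71
```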


## Visualizing a Text Generation Task

### Downloading the GPT-2 XL Model

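If you only want to see how it works, load the smaller GPT-2 instead: GPT-2 XL takes up about 30 GB once downloaded. To switch, replace `"gpt2-xl"` in the cell below with the commented-out `"gpt2"`.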

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Create a BitsAndBytesConfig object to configure quantization (8-bit weights)
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    "gpt2-xl",  # "gpt2" for the smaller model
    quantization_config=bnb_config,
    device_map={"": 0}  # place the whole model on GPU 0
)
```

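In the captured session, this cell stopped with the following error:

```
ImportError: Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate` and the latest version of bitsandbytes: `pip install -i https://pypi.org/simple/ bitsandbytes`
```

In `transformers` 4.40, this check also fails when no CUDA GPU is visible, so first make sure a GPU runtime is selected. If the packages were installed or upgraded after `transformers` had already been imported in the session, restart the runtime and run the cells again.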
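The 8-bit setting matters because weight memory scales with the number of bits per parameter. A back-of-the-envelope sketch, assuming GPT-2 XL's roughly 1.5 billion parameters and counting only the weights (activations, the CUDA context, and quantization overhead are ignored); `weight_memory_gb` is a hypothetical helper:

```python
def weight_memory_gb(num_params, bits_per_param):
    """Approximate weight memory in GB (1 GB = 1e9 bytes); weights only."""
    return num_params * bits_per_param / 8 / 1e9


GPT2_XL_PARAMS = 1.5e9  # approximate parameter count of GPT-2 XL

for bits in (32, 16, 8):
    print(f"{bits:>2}-bit: ~{weight_memory_gb(GPT2_XL_PARAMS, bits):.1f} GB")
```

This prints roughly 6.0 GB for 32-bit, 3.0 GB for 16-bit and 1.5 GB for 8-bit weights.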

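```python
import inseq

# Attribution methods to compare
attribution_methods = ['saliency', 'attention']

for method in attribution_methods:
    print(f"======= Attribution method: {method} =======")
    # Wrap the already-loaded model with the chosen attribution method
    inseq_model = inseq.load_model(model, method)

    # Attribute the continuation the model generates for the prompt
    attribution_result = inseq_model.attribute(
        input_texts="Hello world",
        step_scores=["probability"],
    )

    # Strip GPT-2's byte-level 'Ġ' space marker from the tokens (optional, cosmetic)
    for attr in attribution_result.sequence_attributions:
        for item in attr.source:
            item.token = item.token.replace('Ġ', '')
        for item in attr.target:
            item.token = item.token.replace('Ġ', '')

    # Display the attribution result.
    # Tokens are generated one at a time, so the matrix is staircase-shaped:
    # each generated token can only be attributed to the tokens before it.
    attribution_result.show()
```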



Generation prints these warnings:

```
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
```

Saliency output. Rows are the prompt tokens followed by the generated tokens (which become context for later steps); columns are the generated tokens. `Ċ` is GPT-2's byte-level encoding of a newline, empty cells mark tokens that had not been generated yet at that step, and the last row is the `probability` step score.

| Source | `"` | `Ċ` | `Ċ` | The | first | thing | you | 'll | notice | is | that | the | code | is | very | simple | . | It |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Hello | 0.647 | 0.505 | 0.354 | 0.401 | 0.392 | 0.343 | 0.353 | 0.263 | 0.267 | 0.302 | 0.304 | 0.237 | 0.365 | 0.22 | 0.25 | 0.255 | 0.146 | 0.175 |
| world | 0.353 | 0.315 | 0.22 | 0.237 | 0.254 | 0.199 | 0.242 | 0.179 | 0.173 | 0.212 | 0.197 | 0.164 | 0.245 | 0.148 | 0.144 | 0.162 | 0.084 | 0.112 |
| `"` | | 0.18 | 0.186 | 0.159 | 0.127 | 0.129 | 0.102 | 0.091 | 0.085 | 0.094 | 0.073 | 0.072 | 0.045 | 0.057 | 0.047 | 0.056 | 0.058 | 0.06 |
| `Ċ` | | | 0.241 | 0.092 | 0.055 | 0.056 | 0.044 | 0.045 | 0.043 | 0.042 | 0.039 | 0.065 | 0.036 | 0.042 | 0.033 | 0.032 | 0.035 | 0.029 |
| `Ċ` | | | | 0.11 | 0.057 | 0.062 | 0.045 | 0.052 | 0.05 | 0.041 | 0.035 | 0.039 | 0.026 | 0.037 | 0.031 | 0.028 | 0.037 | 0.026 |
| The | | | | | 0.114 | 0.073 | 0.066 | 0.058 | 0.044 | 0.04 | 0.045 | 0.04 | 0.02 | 0.041 | 0.025 | 0.029 | 0.03 | 0.035 |
| first | | | | | | 0.138 | 0.048 | 0.074 | 0.058 | 0.046 | 0.046 | 0.039 | 0.022 | 0.051 | 0.03 | 0.029 | 0.035 | 0.031 |
| thing | | | | | | | 0.101 | 0.101 | 0.084 | 0.049 | 0.049 | 0.053 | 0.028 | 0.043 | 0.037 | 0.033 | 0.044 | 0.037 |
| you | | | | | | | | 0.137 | 0.079 | 0.047 | 0.048 | 0.049 | 0.028 | 0.043 | 0.036 | 0.051 | 0.044 | 0.052 |
| 'll | | | | | | | | | 0.115 | 0.048 | 0.051 | 0.063 | 0.032 | 0.052 | 0.053 | 0.038 | 0.053 | 0.054 |
| notice | | | | | | | | | | 0.079 | 0.075 | 0.083 | 0.056 | 0.069 | 0.071 | 0.044 | 0.066 | 0.051 |
| is | | | | | | | | | | | 0.038 | 0.036 | 0.02 | 0.022 | 0.024 | 0.024 | 0.028 | 0.023 |
| that | | | | | | | | | | | | 0.06 | 0.022 | 0.031 | 0.027 | 0.018 | 0.035 | 0.025 |
| the | | | | | | | | | | | | | 0.055 | 0.03 | 0.044 | 0.029 | 0.03 | 0.04 |
| code | | | | | | | | | | | | | | 0.113 | 0.11 | 0.085 | 0.065 | 0.105 |
| is | | | | | | | | | | | | | | | 0.038 | 0.019 | 0.028 | 0.016 |
| very | | | | | | | | | | | | | | | | 0.069 | 0.048 | 0.025 |
| simple | | | | | | | | | | | | | | | | | 0.135 | 0.067 |
| . | | | | | | | | | | | | | | | | | | 0.036 |
| It | | | | | | | | | | | | | | | | | | |
| probability | 0.197 | 0.097 | 0.986 | 0.043 | 0.045 | 0.211 | 0.247 | 0.249 | 0.512 | 0.577 | 0.661 | 0.256 | 0.045 | 0.476 | 0.102 | 0.182 | 0.555 | 0.125 |
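The staircase pattern follows from causal generation: the prompt tokens can influence every generated token, while each generated token can only influence tokens generated after it. A small sketch that rebuilds which cells of the GPT-2 saliency table can be filled, assuming a 2-token prompt (`Hello`, `world`) and the 18 generated tokens shown; `attribution_mask` is a hypothetical helper.

```python
def attribution_mask(prompt_len, num_generated):
    """Rows: prompt + generated tokens (context); columns: generated tokens.
    Cell (r, c) is True when context token r exists before target c is generated."""
    mask = []
    for r in range(prompt_len + num_generated):
        # Target c sits at absolute position prompt_len + c, so only
        # earlier positions r < prompt_len + c can be attributed to it
        mask.append([r < prompt_len + c for c in range(num_generated)])
    return mask


mask = attribution_mask(prompt_len=2, num_generated=18)
filled_per_row = [sum(row) for row in mask]
print(filled_per_row)
```

The counts run 18, 18, 17, 16, ... down to 0, matching the number of filled cells in each row of the table, from `Hello` down to `It`.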


Attention output. Rows are the prompt tokens followed by the generated tokens; columns are the generated tokens. `Ċ` is GPT-2's byte-level newline, empty cells mark tokens that had not been generated yet at that step, and the last row is the `probability` step score.

| Source | `"` | `Ċ` | `Ċ` | The | first | thing | you | 'll | notice | is | that | the | code | is | very | simple | . | It |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Hello | 0.867 | 0.758 | 0.674 | 0.619 | 0.619 | 0.636 | 0.612 | 0.626 | 0.623 | 0.584 | 0.518 | 0.528 | 0.552 | 0.572 | 0.555 | 0.568 | 0.496 | 0.446 |
| world | 0.133 | 0.109 | 0.054 | 0.068 | 0.055 | 0.038 | 0.041 | 0.03 | 0.024 | 0.028 | 0.036 | 0.04 | 0.034 | 0.024 | 0.015 | 0.014 | 0.017 | 0.019 |
| `"` | | 0.133 | 0.086 | 0.084 | 0.061 | 0.048 | 0.043 | 0.037 | 0.031 | 0.03 | 0.035 | 0.038 | 0.035 | 0.025 | 0.019 | 0.018 | 0.019 | 0.02 |
| `Ċ` | | | 0.186 | 0.114 | 0.091 | 0.075 | 0.068 | 0.061 | 0.054 | 0.059 | 0.058 | 0.052 | 0.046 | 0.042 | 0.038 | 0.036 | 0.046 | 0.043 |
| `Ċ` | | | | 0.116 | 0.093 | 0.072 | 0.064 | 0.049 | 0.041 | 0.044 | 0.056 | 0.055 | 0.04 | 0.03 | 0.024 | 0.022 | 0.026 | 0.034 |
| The | | | | | 0.08 | 0.078 | 0.061 | 0.041 | 0.035 | 0.033 | 0.043 | 0.036 | 0.031 | 0.025 | 0.019 | 0.018 | 0.02 | 0.022 |
| first | | | | | | 0.054 | 0.059 | 0.037 | 0.031 | 0.032 | 0.026 | 0.022 | 0.02 | 0.014 | 0.013 | 0.012 | 0.013 | 0.014 |
| thing | | | | | | | 0.052 | 0.058 | 0.044 | 0.042 | 0.026 | 0.02 | 0.017 | 0.014 | 0.013 | 0.011 | 0.011 | 0.012 |
| you | | | | | | | | 0.062 | 0.06 | 0.04 | 0.03 | 0.022 | 0.018 | 0.018 | 0.013 | 0.011 | 0.015 | 0.014 |
| 'll | | | | | | | | | 0.056 | 0.048 | 0.027 | 0.02 | 0.015 | 0.016 | 0.014 | 0.012 | 0.013 | 0.013 |
| notice | | | | | | | | | | 0.06 | 0.065 | 0.04 | 0.029 | 0.025 | 0.025 | 0.019 | 0.02 | 0.017 |
| is | | | | | | | | | | | 0.08 | 0.053 | 0.03 | 0.026 | 0.03 | 0.022 | 0.024 | 0.022 |
| that | | | | | | | | | | | | 0.075 | 0.062 | 0.048 | 0.041 | 0.035 | 0.037 | 0.034 |
| the | | | | | | | | | | | | | 0.071 | 0.059 | 0.039 | 0.032 | 0.029 | 0.028 |
| code | | | | | | | | | | | | | | 0.062 | 0.072 | 0.052 | 0.054 | 0.046 |
| is | | | | | | | | | | | | | | | 0.069 | 0.066 | 0.051 | 0.031 |
| very | | | | | | | | | | | | | | | | 0.054 | 0.048 | 0.024 |
| simple | | | | | | | | | | | | | | | | | 0.061 | 0.053 |
| . | | | | | | | | | | | | | | | | | | 0.107 |
| It | | | | | | | | | | | | | | | | | | |
| probability | 0.197 | 0.097 | 0.986 | 0.043 | 0.045 | 0.211 | 0.247 | 0.249 | 0.512 | 0.577 | 0.661 | 0.256 | 0.045 | 0.476 | 0.102 | 0.182 | 0.555 | 0.125 |
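Both attribution cells repeat the same loop to strip a tokenizer-specific marker: `▁` (the SentencePiece word-boundary marker used by the Marian translation model) and `Ġ` (GPT-2's byte-level encoding of a leading space). `Ċ`, GPT-2's encoding of a newline, is left untouched, which is why it still shows up in the GPT-2 tables. A small sketch of a reusable helper over plain token strings; `clean_tokens` and its `replacements` mapping are hypothetical, and applying it to Inseq's token objects would still mean assigning back to `item.token` as the loops do.

```python
def clean_tokens(tokens, replacements=None):
    """Replace tokenizer marker characters in a list of token strings."""
    if replacements is None:
        # '▁': SentencePiece word boundary; 'Ġ': GPT-2 byte-level space;
        # 'Ċ': GPT-2 byte-level newline, rendered as a literal "\n" so it stays visible
        replacements = {"▁": "", "Ġ": "", "Ċ": "\\n"}
    cleaned = []
    for tok in tokens:
        for marker, repl in replacements.items():
            tok = tok.replace(marker, repl)
        cleaned.append(tok)
    return cleaned


print(clean_tokens(["▁I", "▁learned", "Ġworld", "Ċ"]))
```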
