In [14]:
import os
import warnings

# Select the GPU before torch (or any project module that imports torch) is
# loaded: CUDA reads this variable once, when it is first initialised, so
# setting it up front is the only ordering that is guaranteed to take effect.
# setdefault keeps a value already exported by the launching environment
# instead of silently overriding it with this machine-specific index.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '7')

import altair as alt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

from config import get_config, get_weights_file_path
from model import Transformer
from train import get_model, get_ds, greedy_decode

# NOTE(review): this hides every warning for the rest of the session,
# including deprecation notices from torch/altair; narrow it if possible.
warnings.filterwarnings("ignore")

In [15]:
# Pick the compute device: CUDA first, then Apple MPS, otherwise the CPU.
# torch.backends.mps does not exist on every torch build, so it is looked up
# defensively instead of being dereferenced directly.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [16]:
# Build the datasets/vocabularies and an untrained model of matching size.
config = get_config()
train_dataloader, val_dataloader, vocab_src, vocab_tgt = get_ds(config)
model = get_model(config, vocab_src.get_vocab_size(), vocab_tgt.get_vocab_size()).to(device)

# This notebook only runs inference: switch to eval mode so that layers such
# as dropout behave deterministically and the attention that gets plotted is
# not randomly perturbed.
model.eval()

# Load the pretrained weights of the chosen epoch.
epoch = 19
model_filename = get_weights_file_path(config, epoch=f"{epoch}")
# map_location remaps tensors saved on another device (e.g. a different GPU
# index, or a GPU when running on CPU) onto the device selected above.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
state = torch.load(model_filename, map_location=device)
# Kept as the last expression so the cell still displays the key-matching report.
model.load_state_dict(state['model_state_dict'])

Max length of source sentence: 471
Max length of target sentence: 482


<All keys matched successfully>

In [19]:
def load_next_batch():
    """Fetch one validation sample and run the model on it.

    Uses the notebook-level ``val_dataloader``, ``model``, ``vocab_src``,
    ``vocab_tgt``, ``config`` and ``device``.

    Returns:
        tuple: ``(batch, encoder_input_tokens, decoder_input_tokens)`` where
        ``batch`` is the raw item yielded by ``val_dataloader`` and the two
        lists hold the token strings of the first (only) sample's encoder
        and decoder inputs.

    Raises:
        AssertionError: if the batch holds anything other than one sample.
    """
    # A fresh iterator is created on every call, so this always returns the
    # first item that iterator yields.
    # NOTE(review): whether that is a different sample each time depends on
    # how val_dataloader shuffles -- confirm in get_ds.
    batch = next(iter(val_dataloader))

    # Move the tensors needed for decoding / token lookup onto the device.
    encoder_input = batch["encoder_input"].to(device)
    encoder_mask = batch["encoder_mask"].to(device)
    decoder_input = batch["decoder_input"].to(device)

    # Check the batch size before indexing sample 0 below.
    assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

    # Map token ids back to their string form for axis labels.
    encoder_input_tokens = [vocab_src.id_to_token(idx) for idx in encoder_input[0].tolist()]
    decoder_input_tokens = [vocab_tgt.id_to_token(idx) for idx in decoder_input[0].tolist()]

    # The decoded output itself is not used; the call is kept for its forward
    # pass (presumably this is what fills the attention_scores attributes read
    # by get_attn_map -- confirm in model.py). No gradients are needed.
    with torch.no_grad():
        greedy_decode(
            model, encoder_input, encoder_mask, vocab_src, vocab_tgt, config['seq_len'], device)

    return batch, encoder_input_tokens, decoder_input_tokens

In [20]:
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    """Flatten the top-left corner of a 2-D score matrix into a long DataFrame.

    Args:
        m: 2-D indexable object exposing ``shape`` and ``m[r, c]``.
        max_row: only rows with index ``< max_row`` are kept.
        max_col: only columns with index ``< max_col`` are kept.
        row_tokens: labels for the rows; positions past its end are
            labelled ``<blank>``.
        col_tokens: labels for the columns, handled the same way.

    Returns:
        pandas.DataFrame with one record per kept cell and the columns
        ``row``, ``column``, ``value``, ``row_token``, ``col_token``. The
        token columns are ``"<zero-padded 3-digit index> <token>"``.
    """

    def _label(pos, tokens):
        # Index prefix keeps labels unique even when a token repeats.
        return "%.3d %s" % (pos, tokens[pos] if len(tokens) > pos else "<blank>")

    # Select the surviving indices per axis up front rather than visiting
    # every cell of the full matrix and discarding most of them.
    kept_rows = [r for r in range(m.shape[0]) if r < max_row]
    kept_cols = [c for c in range(m.shape[1]) if c < max_col]

    # Each label depends on a single index, so build it once per row/column.
    row_labels = {r: _label(r, row_tokens) for r in kept_rows}
    col_labels = {c: _label(c, col_tokens) for c in kept_cols}

    return pd.DataFrame(
        [
            (r, c, float(m[r, c]), row_labels[r], col_labels[c])
            for r in kept_rows
            for c in kept_cols
        ],
        columns=["row", "column", "value", "row_token", "col_token"],
    )

def get_attn_map(attn_type: str, layer: int, head: int):
    """Return the stored attention scores of one head for the first sample.

    Reads the ``attention_scores`` attribute kept on the notebook-level
    ``model``'s attention blocks.

    Args:
        attn_type: ``"encoder"`` (encoder self-attention), ``"decoder"``
            (decoder self-attention) or ``"encoder-decoder"`` (the decoder's
            cross-attention block).
        layer: index into ``model.encoder.layers`` / ``model.decoder.layers``.
        head: index of the attention head.

    Returns:
        ``attention_scores[0, head]`` detached from the autograd graph.

    Raises:
        ValueError: if ``attn_type`` is not one of the three supported names.
    """
    if attn_type == "encoder":
        block = model.encoder.layers[layer].self_attention_block
    elif attn_type == "decoder":
        block = model.decoder.layers[layer].self_attention_block
    elif attn_type == "encoder-decoder":
        block = model.decoder.layers[layer].cross_attention_block
    else:
        # Fail with a clear message instead of an unbound-variable error.
        raise ValueError(
            f"Unknown attn_type {attn_type!r}; expected 'encoder', 'decoder' or 'encoder-decoder'"
        )
    # Index 0 selects the first sample of the batch.
    return block.attention_scores[0, head].detach()

def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len):
    """Build the interactive heatmap for a single attention head.

    The scores come from ``get_attn_map(attn_type, layer, head)`` and are
    truncated to ``max_sentence_len`` rows and columns by ``mtx2df``.

    Returns:
        An Altair chart (400x400, titled ``Layer <layer> Head <head>``) with
        ``col_token`` on x, ``row_token`` on y and ``value`` as the colour.
    """
    scores = get_attn_map(attn_type, layer, head)
    # The same cap is applied to both axes.
    frame = mtx2df(scores, max_sentence_len, max_sentence_len, row_tokens, col_tokens)

    # Axis titles are blanked; the tick labels already carry the tokens.
    x_channel = alt.X("col_token", axis=alt.Axis(title=""))
    y_channel = alt.Y("row_token", axis=alt.Axis(title=""))
    # Fields shown when hovering over a cell.
    hover_fields = ["row", "column", "value", "row_token", "col_token"]

    heatmap = alt.Chart(data=frame).mark_rect().encode(
        x=x_channel,
        y=y_channel,
        color="value",
        tooltip=hover_fields,
    )
    heatmap = heatmap.properties(height=400, width=400, title=f"Layer {layer} Head {head}")
    return heatmap.interactive()

def get_all_attention_maps(attn_type: str, layers: list[int], heads: list[int], row_tokens: list, col_tokens, max_sentence_len: int):
    """Lay out one heatmap per (layer, head) pair as a grid.

    Each requested layer becomes one horizontal strip holding a chart per
    requested head (built by ``attn_map``); the strips are then stacked
    vertically in the order the layers were given.

    Returns:
        The combined Altair chart produced by ``alt.vconcat``.
    """

    def _layer_strip(layer):
        # All heads of one layer, side by side.
        return alt.hconcat(*[
            attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len)
            for head in heads
        ])

    # One strip per layer, top to bottom.
    return alt.vconcat(*map(_layer_strip, layers))

In [21]:
batch, encoder_input_tokens, decoder_input_tokens = load_next_batch()
print(f'Source: {batch["src_text"][0]}')
print(f'Target: {batch["tgt_text"][0]}')
# Number of encoder tokens before the first padding token. A sentence that
# fills the whole sequence has no "[PAD]" at all; fall back to the full
# length in that case instead of letting list.index raise ValueError.
sentence_len = (
    encoder_input_tokens.index("[PAD]")
    if "[PAD]" in encoder_input_tokens
    else len(encoder_input_tokens)
)

Source: Such was Julien's first thought on his return to his own room.
Target: Telle fut la première pensée de Julien, en rentrant dans sa chambre.


In [22]:
# Layers (0-2) and attention heads (0-7) to visualise.
layers = list(range(3))
heads = list(range(8))

# Encoder Self-Attention
# Rows and columns are both the encoder input tokens; at most the first 20
# non-padding positions are drawn.
# Arguments: attention type, layers, heads, row tokens, column tokens, max length.
get_all_attention_maps(
    "encoder",
    layers,
    heads,
    encoder_input_tokens,
    encoder_input_tokens,
    min(20, sentence_len),
)

In [23]:
# Decoder Self-Attention
# Rows and columns are both the decoder input tokens.
# NOTE(review): the length cap is derived from the encoder sentence length.
get_all_attention_maps(
    "decoder",
    layers,
    heads,
    decoder_input_tokens,
    decoder_input_tokens,
    min(20, sentence_len),
)

In [24]:
# Encoder-Decoder (cross) Attention
# In cross-attention the queries come from the decoder and the keys from the
# encoder, so the score matrix is (decoder position, encoder position): rows
# must be labelled with decoder tokens and columns with encoder tokens.
# NOTE(review): assumes the model follows this standard layout -- confirm
# against cross_attention_block in model.py.
get_all_attention_maps("encoder-decoder", layers, heads, decoder_input_tokens, encoder_input_tokens, min(20, sentence_len))