In [1]:
import argparse
import random
import pandas as pd
import warnings
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from model.model_lora import *
from trainer.trainer_utils import setup_seed
from torch.utils.data import Dataset
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
args = argparse.Namespace(
        load_from='model',
        save_dir='out',
        weight='full_sft',
        lora_weight='lora_classifier',
        hidden_size=512,
        num_hidden_layers=8,
        use_moe=0,
        inference_rope_scaling=False,
        max_new_tokens=8192,
        temperature=0.85,
        top_p=0.85,
        historys=0,
        device='cuda' if torch.cuda.is_available() else 'cpu'
    )
def init_model(args):
    tokenizer = AutoTokenizer.from_pretrained(args.load_from)
    if 'model' in args.load_from:
        model = MiniMindForCausalLM(MiniMindConfig(
            hidden_size=args.hidden_size,
            num_hidden_layers=args.num_hidden_layers,
            use_moe=bool(args.use_moe),
            inference_rope_scaling=args.inference_rope_scaling
        ))
        moe_suffix = '_moe' if args.use_moe else ''
        ckp = f'./{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
        model.load_state_dict(torch.load(ckp, map_location=args.device), strict=True)
        if args.lora_weight != 'None':
            apply_lora(model)
            load_lora(model, f'./{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
    else:
        model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
    print(f'MiniMind模型参数: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M(illion)')
    return model.eval().to(args.device), tokenizer
model, tokenizer = init_model(args)

MiniMind模型参数: 25.96 M(illion)


In [4]:
for i, (name, module) in enumerate(model.named_modules()):
    print('-'*100)
    print(i)
    print(name, module)

----------------------------------------------------------------------------------------------------
0
 MiniMindForCausalLM(
  (model): MiniMindModel(
    (embed_tokens): Embedding(6400, 512)
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-7): 8 x MiniMindBlock(
        (self_attn): Attention(
          (q_proj): Linear(
            in_features=512, out_features=512, bias=False
            (lora): LoRA(
              (A): Linear(in_features=512, out_features=8, bias=False)
              (B): Linear(in_features=8, out_features=512, bias=False)
            )
          )
          (k_proj): Linear(in_features=512, out_features=128, bias=False)
          (v_proj): Linear(in_features=512, out_features=128, bias=False)
          (o_proj): Linear(
            in_features=512, out_features=512, bias=False
            (lora): LoRA(
              (A): Linear(in_features=512, out_features=8, bias=False)
              (B): Linear(in_features=8, out_features=512, bia

In [5]:
# 导入必要的函数
from model.model_minimind import apply_rotary_pos_emb, repeat_kv
import torch.nn.functional as F
import math
import types

# 用于存储每一层的注意力权重（格式: {layer_idx: attention_matrix}）
attention_weights = {}

def patch_attention_to_capture_scores(model):
    """修改模型中所有 Attention 模块的 forward 方法以捕获注意力分数"""
    # 保存原始的 forward 方法
    original_forwards = {}
    
    def create_capturing_forward(original_forward, layer_idx):
        """创建一个包装函数来捕获注意力分数"""
        def capturing_forward(self, x, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
            bsz, seq_len, _ = x.shape
            xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
            xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
            xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
            xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
            
            cos, sin = position_embeddings
            xq, xk = apply_rotary_pos_emb(xq, xk, cos[:seq_len], sin[:seq_len])
            
            if past_key_value is not None:
                xk = torch.cat([past_key_value[0], xk], dim=1)
                xv = torch.cat([past_key_value[1], xv], dim=1)
            past_kv = (xk, xv) if use_cache else None
            
            xq, xk, xv = (
                xq.transpose(1, 2),
                repeat_kv(xk, self.n_rep).transpose(1, 2),
                repeat_kv(xv, self.n_rep).transpose(1, 2)
            )
            
            # 计算注意力分数
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            seq_len_kv = xk.shape[2]  # 考虑 kv_cache 的情况
            
            # Causal mask
            if seq_len_kv == seq_len:
                causal_mask = torch.triu(
                    torch.full((seq_len, seq_len), float("-inf"), device=scores.device),
                    diagonal=1
                ).unsqueeze(0).unsqueeze(0)
            else:
                causal_mask = torch.triu(
                    torch.full((seq_len, seq_len_kv), float("-inf"), device=scores.device),
                    diagonal=seq_len_kv - seq_len + 1
                ).unsqueeze(0).unsqueeze(0)
            scores = scores + causal_mask
            
            if attention_mask is not None:
                extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
                extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
                if seq_len_kv > seq_len:
                    pad_length = seq_len_kv - seq_len
                    extended_attention_mask = F.pad(extended_attention_mask, (0, pad_length), value=-1e9)
                scores = scores + extended_attention_mask
            
            attention_probs = F.softmax(scores.float(), dim=-1).type_as(xq)
            
            # 保存注意力权重（平均所有 head），形状: [num_heads, seq_len, seq_len_kv] -> [seq_len, seq_len_kv]
            attention_weights[f'layer_{layer_idx}'] = attention_probs.detach().cpu().mean(dim=1).squeeze(0)
            
            attention_probs = self.attn_dropout(attention_probs)
            output = attention_probs @ xv
            
            output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
            output = self.resid_dropout(self.o_proj(output))
            return output, past_kv
        
        return capturing_forward
    
    # 修改每一层的 Attention 模块
    for layer_idx, layer in enumerate(model.model.layers):
        original_forwards[layer_idx] = layer.self_attn.forward
        capturing_forward = create_capturing_forward(original_forwards[layer_idx], layer_idx)
        # 使用 types.MethodType 正确绑定方法
        layer.self_attn.forward = types.MethodType(capturing_forward, layer.self_attn)
    
    print(f"已修改 {len(model.model.layers)} 层 Attention 模块以捕获注意力权重")
    return original_forwards

# 应用修改
original_forwards = patch_attention_to_capture_scores(model)


已修改 8 层 Attention 模块以捕获注意力权重


In [6]:
# 创建自定义生成函数，逐步生成并收集注意力权重
def generate_with_attention(model, tokenizer, prompt, max_new_tokens=20, temperature=0.8, top_p=0.9):
    """
    自定义生成函数，在生成过程中收集每一层的注意力权重
    
    返回: (generated_text, all_attention_weights)
    """
    global attention_weights
    all_step_attentions = {}  # {step: {layer_key: attention_matrix}}
    
    # 禁用 flash attention 以确保能捕获注意力分数
    for layer in model.model.layers:
        layer.self_attn.flash = False
    
    device = next(model.parameters()).device
    
    # 准备输入
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs['input_ids']
    attention_mask = inputs.get('attention_mask', None)
    
    generated_ids = input_ids.clone()
    past_key_values = None
    next_token = None
    
    # 逐步生成
    for step in range(max_new_tokens):
        # 清空当前步骤的注意力权重
        attention_weights = {}
        
        # 前向传播（会自动捕获注意力权重）
        with torch.no_grad():
            if step == 0:
                # 第一次：使用完整的 input_ids
                outputs = model(
                    input_ids=generated_ids,
                    attention_mask=attention_mask,
                    past_key_values=None,
                    use_cache=True
                )
            else:
                # 后续步骤：只使用新生成的 token
                outputs = model(
                    input_ids=next_token,
                    attention_mask=None,  # kv_cache 模式下不需要完整的 mask
                    past_key_values=past_key_values,
                    use_cache=True
                )
            logits = outputs.logits[:, -1, :]  # 取最后一个位置的 logits
            past_key_values = outputs.past_key_values
        
        # 保存当前步骤的注意力权重
        if attention_weights:
            all_step_attentions[step] = attention_weights.copy()
        
        # 采样下一个 token
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            # Top-p sampling
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
            sorted_indices_to_remove = cumsum_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            probs[indices_to_remove] = 0
            probs = probs / probs.sum(dim=-1, keepdim=True)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        
        # 添加到生成的序列
        generated_ids = torch.cat([generated_ids, next_token], dim=1)
        
        # 更新 attention mask（用于后续解码）
        if attention_mask is not None:
            attention_mask = torch.cat([attention_mask, torch.ones((attention_mask.shape[0], 1), device=device, dtype=attention_mask.dtype)], dim=1)
        
        # 检查是否遇到 EOS
        if next_token.item() == tokenizer.eos_token_id:
            break
    
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text, all_step_attentions

print("已定义生成函数")


已定义生成函数


In [7]:
# 可视化注意力权重的函数
def visualize_attention(attention_matrix, layer_idx, step=None, tokenizer=None, input_ids=None, head_avg=True):
    """
    可视化单层注意力权重
    
    Args:
        attention_matrix: 注意力矩阵，形状 [seq_len, seq_len_kv]
        layer_idx: 层索引
        step: 生成步骤（可选）
        tokenizer: tokenizer（用于显示 token）
        input_ids: 输入 token IDs（用于显示 token）
        head_avg: 是否已经平均了所有 head
    """
    import matplotlib.pyplot as plt
    import numpy as np
    
    attn = attention_matrix.numpy() if isinstance(attention_matrix, torch.Tensor) else attention_matrix
    
    # 创建图形
    fig, ax = plt.subplots(figsize=(12, 10))
    
    # 绘制热力图
    im = ax.imshow(attn, cmap='Blues', aspect='auto', interpolation='nearest')
    
    # 添加颜色条
    plt.colorbar(im, ax=ax, label='Attention Weight')
    
    # 设置标签
    title = f'Layer {layer_idx} Attention'
    if step is not None:
        title += f' (Step {step})'
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Key Position (Previous Tokens)', fontsize=12)
    ax.set_ylabel('Query Position (Current Token)', fontsize=12)
    
    # 如果提供了 tokenizer 和 input_ids，显示 token 标签
    if tokenizer is not None and input_ids is not None:
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
        # 只显示部分 token（避免标签太密集）
        if len(tokens) > 20:
            step_size = len(tokens) // 20
            tick_positions = list(range(0, len(tokens), step_size))
            tick_labels = [tokens[i] if i < len(tokens) else '' for i in tick_positions]
        else:
            tick_positions = list(range(len(tokens)))
            tick_labels = tokens
        
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels, rotation=45, ha='right', fontsize=8)
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(tick_labels, fontsize=8)
    
    plt.tight_layout()
    plt.show()
    
    # 打印数值统计
    print(f"\\nLayer {layer_idx} Attention Statistics:")
    print(f"  Shape: {attn.shape}")
    print(f"  Min: {attn.min():.6f}")
    print(f"  Max: {attn.max():.6f}")
    print(f"  Mean: {attn.mean():.6f}")
    print(f"  Sum per row (should be ~1.0): {attn.sum(axis=-1)[:5]}...")  # 显示前5行

def print_attention_details(attention_matrix, layer_idx, step=None, top_k=5):
    """打印注意力权重的详细信息"""
    attn = attention_matrix.numpy() if isinstance(attention_matrix, torch.Tensor) else attention_matrix
    
    print(f"\\n{'='*80}")
    title = f"Layer {layer_idx} Attention Details"
    if step is not None:
        title += f" (Generation Step {step})"
    print(title)
    print(f"{'='*80}")
    
    # 对于每一行（query position），显示注意力权重最高的 top_k 个位置
    for query_pos in range(min(5, attn.shape[0])):  # 只显示前5个 query positions
        row = attn[query_pos, :]
        top_k_indices = np.argsort(row)[-top_k:][::-1]
        top_k_values = row[top_k_indices]
        
        print(f"\\nQuery Position {query_pos}:")
        for i, (idx, val) in enumerate(zip(top_k_indices, top_k_values)):
            print(f"  {i+1}. Key Position {idx}: {val:.6f}")
    
    # 打印完整的注意力矩阵（如果不太大）
    if attn.shape[0] <= 10 and attn.shape[1] <= 10:
        print(f"\\nFull Attention Matrix:")
        print(attn)
    else:
        print(f"\\nAttention Matrix (first 5x5):")
        print(attn[:5, :5])

print("已定义可视化函数")


已定义可视化函数


In [83]:
# 测试生成并收集注意力权重
prompt = '''<|im_start|>user
标题：The dollar has hit its highest level
内容：The dollar has hit its highest level against the euro in almost 3 months
请根据标题和内容，给出文章的分类。(在以下选项中选：['entertainment' 'sport' 'politics' 'tech' 'business'])<|im_end|>
<|im_start|>assistant'''
max_new_tokens = 50

print(f"输入提示: {prompt}")
print(f"开始生成并收集注意力权重...\n")

generated_text, all_attention = generate_with_attention(
    model, tokenizer, prompt, 
    max_new_tokens=max_new_tokens,
    temperature=1.2,
    top_p=0.8
)

print(f"\n生成文本: {generated_text}")
print(f"\n收集到 {len(all_attention)} 个生成步骤的注意力权重")


输入提示: <|im_start|>user
标题：The dollar has hit its highest level
内容：The dollar has hit its highest level against the euro in almost 3 months
请根据标题和内容，给出文章的分类。(在以下选项中选：['entertainment' 'sport' 'politics' 'tech' 'business'])<|im_end|>
<|im_start|>assistant
开始生成并收集注意力权重...


生成文本: user
标题：The dollar has hit its highest level
内容：The dollar has hit its highest level against the euro in almost 3 months
请根据标题和内容，给出文章的分类。(在以下选项中选：['entertainment' 'sport' 'politics' 'tech' 'business'])
assistant>politiness

收集到 6 个生成步骤的注意力权重


In [84]:
# 显示每个生成步骤每一层的注意力权重详情
import numpy as np

# 选择要显示的步骤（例如最后一个步骤）
step_to_show = len(all_attention) - 1  # 最后一个生成步骤

if step_to_show in all_attention:
    print(f"\n{'='*80}")
    print(f"显示生成步骤 {step_to_show} 的注意力权重")
    print(f"{'='*80}")
    
    step_attentions = all_attention[step_to_show]
    
    for layer_key in sorted(step_attentions.keys(), key=lambda x: int(x.split('_')[1])):
        layer_idx = int(layer_key.split('_')[1])
        attn_matrix = step_attentions[layer_key]
        
        # 打印详细信息
        print_attention_details(attn_matrix, layer_idx, step=step_to_show, top_k=5)
        
        # 可选：可视化（取消注释以显示）
        # visualize_attention(attn_matrix, layer_idx, step=step_to_show)
else:
    print(f"步骤 {step_to_show} 不存在")



显示生成步骤 5 的注意力权重
Layer 0 Attention Details (Generation Step 5)
\nQuery Position 0:
  1. Key Position 104: 0.804786
  2. Key Position 102: 0.068178
  3. Key Position 103: 0.021745
  4. Key Position 98: 0.013952
  5. Key Position 84: 0.007418
\nAttention Matrix (first 5x5):
[[5.1872390e-05 8.4732991e-04 6.6318625e-04 4.1106701e-04 1.2921287e-04]]
Layer 1 Attention Details (Generation Step 5)
\nQuery Position 0:
  1. Key Position 104: 0.245569
  2. Key Position 103: 0.194255
  3. Key Position 96: 0.091334
  4. Key Position 98: 0.077849
  5. Key Position 102: 0.069417
\nAttention Matrix (first 5x5):
[[1.3912680e-02 2.0646994e-04 7.5389363e-04 4.9090316e-04 8.9023779e-05]]
Layer 2 Attention Details (Generation Step 5)
\nQuery Position 0:
  1. Key Position 0: 0.432069
  2. Key Position 104: 0.095296
  3. Key Position 103: 0.073218
  4. Key Position 102: 0.045389
  5. Key Position 100: 0.043947
\nAttention Matrix (first 5x5):
[[4.3206862e-01 1.4222973e-03 2.6972985e-03 5.0568772e-03 3.1000169

In [85]:
# 显示所有步骤、所有层的注意力权重数值（以表格形式）
import pandas as pd

# 创建一个 DataFrame 来存储注意力权重的统计信息
attention_stats = []

for step, step_attentions in all_attention.items():
    for layer_key, attn_matrix in step_attentions.items():
        layer_idx = int(layer_key.split('_')[1])
        attn_np = attn_matrix.numpy() if isinstance(attn_matrix, torch.Tensor) else attn_matrix
        
        attention_stats.append({
            'Step': step,
            'Layer': layer_idx,
            'Shape': str(attn_np.shape),
            'Min': attn_np.min(),
            'Max': attn_np.max(),
            'Mean': attn_np.mean(),
            'Std': attn_np.std(),
            'Sum': attn_np.sum()
        })

df = pd.DataFrame(attention_stats)
print("注意力权重统计信息:")
print(df.to_string(index=False))

# 对于特定步骤和层，显示完整的注意力矩阵
print(f"\n\n{'='*80}")
print("详细注意力矩阵数值:")
print(f"{'='*80}")

# 选择最后一个生成步骤和第一层作为示例
last_step = len(all_attention) - 1
first_layer_key = f'layer_0'

if last_step in all_attention and first_layer_key in all_attention[last_step]:
    attn_matrix = all_attention[last_step][first_layer_key]
    attn_np = attn_matrix.numpy() if isinstance(attn_matrix, torch.Tensor) else attn_matrix
    
    print(f"\n步骤 {last_step}, 层 0 的完整注意力矩阵:")
    print(f"形状: {attn_np.shape}")
    print("\n矩阵数值 (每个元素表示 query position 对 key position 的注意力权重):")
    print(attn_np)
    
    # 打印每一行的和（应该接近1.0）
    print(f"\n每一行的和（验证概率分布）:")
    row_sums = attn_np.sum(axis=-1)
    for i, row_sum in enumerate(row_sums):
        print(f"  Row {i}: {row_sum:.6f}")


注意力权重统计信息:
 Step  Layer      Shape      Min      Max     Mean      Std   Sum
    0      0 (100, 100) 0.000000 1.000000 0.010000 0.064100 100.0
    0      1 (100, 100) 0.000000 1.000000 0.010000 0.038689 100.0
    0      2 (100, 100) 0.000000 1.000000 0.010000 0.045502 100.0
    0      3 (100, 100) 0.000000 1.000000 0.010000 0.036393 100.0
    0      4 (100, 100) 0.000000 1.000000 0.010000 0.035212 100.0
    0      5 (100, 100) 0.000000 1.000000 0.010000 0.039258 100.0
    0      6 (100, 100) 0.000000 1.000000 0.010000 0.053254 100.0
    0      7 (100, 100) 0.000000 1.000000 0.010000 0.033236 100.0
    1      0   (1, 101) 0.001344 0.177436 0.009901 0.017942   1.0
    1      1   (1, 101) 0.000095 0.200046 0.009901 0.030237   1.0
    1      2   (1, 101) 0.000396 0.229295 0.009901 0.025841   1.0
    1      3   (1, 101) 0.000248 0.199773 0.009901 0.027445   1.0
    1      4   (1, 101) 0.000720 0.218724 0.009901 0.026122   1.0
    1      5   (1, 101) 0.000199 0.188120 0.009901 0.032826   1.0

In [89]:
encoded=tokenizer.encode(prompt)
splitted='|'.join([tokenizer.decode([encoded[i]]) for i in range(len(encoded))])
print(encoded)
print(splitted)

[1, 320, 275, 201, 1232, 674, 355, 1148, 302, 393, 78, 305, 780, 322, 284, 1113, 4143, 262, 377, 2187, 201, 1911, 355, 1148, 302, 393, 78, 305, 780, 322, 284, 1113, 4143, 262, 377, 2187, 4021, 276, 311, 5521, 295, 439, 79, 1802, 1328, 1804, 4567, 201, 1055, 1227, 1232, 674, 315, 1911, 270, 3496, 2239, 269, 2467, 286, 10, 368, 1104, 5752, 413, 838, 355, 61, 9, 309, 1544, 410, 9, 2040, 85, 1519, 9, 2040, 82, 393, 284, 1238, 9, 2040, 1491, 352, 9, 2040, 68, 320, 1532, 9, 63, 11, 2, 201, 1, 1078, 538, 501]
<|im_start|>|us|er|
|标|题|：|The| d|ol|l|ar| has| h|it| its| hig|he|st| level|
|内容|：|The| d|ol|l|ar| has| h|it| its| hig|he|st| level| against| the| e|uro| in| al|m|ost| 3| mon|ths|
|请|根据|标|题|和|内容|，|给出|文章|的|分类|。|(|在|以下|选项|中|选|：|[|'|ent|ertain|ment|'| '|s|port|'| '|p|ol|it|ics|'| '|te|ch|'| '|b|us|iness|'|]|)|<|im_end|>|
|<|im_start|>|ass|ist|ant


In [92]:
print(tokenizer.encode(' dollar'))

[302, 393, 78, 305]


In [None]:
# 24-27:' dollar'
print(all_attention[0]['layer_0'][:,24].mean())
print(all_attention[0]['layer_0'][:,25].mean())
print(all_attention[0]['layer_0'][:,26].mean())
print(all_attention[0]['layer_0'][:,27].mean())


tensor(0.0147)
tensor(0.0165)
tensor(0.0051)
tensor(0.0077)
