# FlexAttention 代码流程图和注意力掩码可视化
# FlexAttention Code Flowchart and Attention Mask Visualization

本笔记本提供了 FlexAttention 实现的完整可视化，包括：
This notebook provides comprehensive visualizations for the FlexAttention implementation, including:

1. **代码流程图** - 展示整个处理流程 / **Code Flowchart** - Shows the entire processing pipeline
2. **注意力掩码形状** - 精确展示不同场景下的掩码矩阵 / **Attention Mask Shapes** - Precisely shows mask matrices in different scenarios
3. **可交互参数** - 方便修改和实验 / **Interactive Parameters** - Easy to modify for experimentation

---

## 设置说明 / Setup Instructions

确保安装了必要的包 / Make sure the required packages are installed:

```bash
pip install matplotlib numpy jupyter
```

In [None]:
# 导入必要的库 / Import required libraries
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch, FancyArrowPatch
import numpy as np
from typing import List, Tuple
import warnings
warnings.filterwarnings('ignore')

# 设置中文字体支持 / Set Chinese font support
# CJK-capable fonts come first; DejaVu Sans has no Chinese glyphs and is kept only as a fallback.
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

# 设置图表样式 / Set plot style
plt.style.use('default')
%matplotlib inline

---

## 第一部分：代码流程图 / Part 1: Code Flowchart

### 可修改参数 / Modifiable Parameters

In [None]:
# ============================================================================
# 流程图参数配置 / Flowchart Configuration Parameters
# ============================================================================

# 可以修改这些参数来调整流程图的外观 / Modify these to adjust flowchart appearance
FLOWCHART_CONFIG = {
    'figure_size': (14, 16),          # 图表大小 / Figure size
    'box_width': 3.5,                 # 框的宽度 / Box width
    'box_height': 0.6,                # 框的高度 / Box height
    'vertical_spacing': 1.2,          # 垂直间距 / Vertical spacing (currently unused: y positions are hard-coded in `steps`)
    'arrow_width': 2,                 # 箭头宽度 / Arrow width
    'font_size_title': 11,            # 标题字体大小 / Title font size
    'font_size_desc': 9,              # 描述字体大小 / Description font size
    
    # 颜色方案 / Color scheme
    'color_input': '#E3F2FD',         # 输入阶段 / Input phase
    'color_processing': '#FFF3E0',    # 处理阶段 / Processing phase  
    'color_attention': '#F3E5F5',     # 注意力阶段 / Attention phase
    'color_generation': '#E8F5E9',    # 生成阶段 / Generation phase
    'color_output': '#FCE4EC',        # 输出阶段 / Output phase
    'edge_color': '#424242',          # 边框颜色 / Edge color
}

In [None]:
def draw_flowchart(config=FLOWCHART_CONFIG):
    """
    绘制 FlexAttention 代码流程图
    Draw FlexAttention code flowchart
    
    Parameters:
        config: 配置字典 / Configuration dictionary
    """
    fig, ax = plt.subplots(figsize=config['figure_size'])
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 20)
    ax.axis('off')
    
    # 中心 x 坐标 / Center x coordinate
    center_x = 5
    box_w = config['box_width']
    box_h = config['box_height']
    spacing = config['vertical_spacing']
    
    # 流程步骤定义 / Flow step definitions
    # Format: (y_position, title_en, title_zh, description, color)
    steps = [
        # 1. Input Phase
        (18.5, '1. Input Preparation', '1. 输入准备', 
         'Load question + paraphrases', config['color_input']),
        
        (17.3, 'Generate 5 Paraphrases', '生成 5 个改写', 
         'dataset.construct_prompts()', config['color_input']),
        
        # 2. Concatenation Phase
        (16.1, '2. Concatenation with Position Tracking', '2. 拼接并追踪位置',
         'concatenate_paraphrases_with_positions()', config['color_processing']),
        
        (14.9, 'Track Segment Positions', '记录分段位置',
         'segment_positions = [(0,48), (48,95), ...]', config['color_processing']),
        
        (13.7, 'Tokenization', '分词处理',
         'tokenizer(concatenated_text)', config['color_processing']),
        
        # 3. Attention Mask Creation
        (12.5, '3. Create FlexAttention Mask', '3. 创建 FlexAttention 掩码',
         'create_flex_attention_mask()', config['color_attention']),
        
        (11.3, 'Define Mask Function', '定义掩码函数',
         'mask_mod(b, h, q_idx, kv_idx) -> bool', config['color_attention']),
        
        # 4. Model Patching
        (10.1, '4. Patch Model with FlexAttention', '4. 为模型打补丁',
         'FlexAttentionWrapper.patch_model()', config['color_attention']),
        
        # 5. Encoding Phase
        (8.9, '5. Encoding Phase', '5. 编码阶段',
         'Process original tokens with segment isolation', config['color_generation']),
        
        (7.7, 'Segment-Isolated Attention', '分段隔离注意力',
         'Each paraphrase attends only to itself', config['color_generation']),
        
        # 6. Generation Loop
        (6.5, '6. Generation Loop (Auto-regressive)', '6. 生成循环（自回归）',
         'for step in range(max_new_tokens):', config['color_generation']),
        
        (5.3, 'Forward Pass', '前向传播',
         'logits = model(input_ids)', config['color_generation']),
        
        (4.1, 'Fusion Attention', '融合注意力',
         'Generated tokens attend to ALL segments', config['color_generation']),
        
        (2.9, 'Token Selection', '选择令牌',
         'next_token = argmax(logits)', config['color_generation']),
        
        (1.7, 'Update Input', '更新输入',
         'input_ids = concat(input_ids, next_token)', config['color_generation']),
        
        # 7. Output
        (0.5, '7. Decode & Return', '7. 解码并返回',
         'tokenizer.decode(generated)', config['color_output']),
    ]
    
    # 绘制所有步骤 / Draw all steps
    for y, title_en, title_zh, desc, color in steps:
        # 绘制框 / Draw box
        box = FancyBboxPatch(
            (center_x - box_w/2, y - box_h/2),
            box_w, box_h,
            boxstyle="round,pad=0.05",
            facecolor=color,
            edgecolor=config['edge_color'],
            linewidth=1.5
        )
        ax.add_patch(box)
        
        # 添加文字 / Add text
        ax.text(center_x, y + 0.15, title_en, 
                ha='center', va='center', fontsize=config['font_size_title'], 
                fontweight='bold')
        ax.text(center_x, y - 0.15, desc, 
                ha='center', va='center', fontsize=config['font_size_desc'],
                style='italic', color='#555')
    
    # 绘制箭头 / Draw arrows
    for i in range(len(steps) - 1):
        y_from = steps[i][0] - box_h/2
        y_to = steps[i+1][0] + box_h/2
        arrow = FancyArrowPatch(
            (center_x, y_from),
            (center_x, y_to),
            arrowstyle='->,head_width=0.4,head_length=0.4',
            color=config['edge_color'],
            linewidth=config['arrow_width'],
            zorder=1
        )
        ax.add_patch(arrow)
    
    # 添加标题 / Add title
    ax.text(center_x, 19.5, 
            'FlexAttention Code Flowchart / FlexAttention 代码流程图',
            ha='center', va='center', fontsize=16, fontweight='bold')
    
    # 添加图例 / Add legend
    legend_elements = [
        mpatches.Patch(color=config['color_input'], label='Input / 输入'),
        mpatches.Patch(color=config['color_processing'], label='Processing / 处理'),
        mpatches.Patch(color=config['color_attention'], label='Attention Mask / 注意力掩码'),
        mpatches.Patch(color=config['color_generation'], label='Generation / 生成'),
        mpatches.Patch(color=config['color_output'], label='Output / 输出')
    ]
    ax.legend(handles=legend_elements, loc='upper right', fontsize=10)
    
    plt.tight_layout()
    return fig

# 绘制流程图 / Draw flowchart
fig = draw_flowchart()
plt.show()

# 保存图片 / Save figure
# fig.savefig('flowchart.png', dpi=300, bbox_inches='tight')

---

## 第二部分：注意力掩码可视化 / Part 2: Attention Mask Visualization

### 可修改参数 / Modifiable Parameters

In [None]:
# ============================================================================
# 注意力掩码参数配置 / Attention Mask Configuration Parameters
# ============================================================================

# 场景 1: 小型示例（完整展示）/ Scenario 1: Small example (full display)
MASK_CONFIG_SMALL = {
    'num_paraphrases': 3,              # 改写数量 / Number of paraphrases
    'tokens_per_paraphrase': 15,       # 每个改写的令牌数 / Tokens per paraphrase
    'separator_tokens': 2,             # 分隔符令牌数 / Separator tokens
    'num_generated_tokens': 5,         # 生成的令牌数 / Generated tokens
    'display_mode': 'full',            # 显示模式: 'full' 或 'sampled' / Display mode
}

# 场景 2: 中型示例（智能采样）/ Scenario 2: Medium example (smart sampling)
MASK_CONFIG_MEDIUM = {
    'num_paraphrases': 5,              # 改写数量 / Number of paraphrases
    'tokens_per_paraphrase': 25,       # 每个改写的令牌数 / Tokens per paraphrase
    'separator_tokens': 3,             # 分隔符令牌数 / Separator tokens
    'num_generated_tokens': 8,         # 生成的令牌数 / Generated tokens
    'display_mode': 'sampled',         # 显示模式 / Display mode
    'max_display_positions': 30,       # 最大显示位置数 / Max display positions
}

# 场景 3: 大型示例（真实场景）/ Scenario 3: Large example (realistic scenario)
MASK_CONFIG_LARGE = {
    'num_paraphrases': 5,              # 改写数量 / Number of paraphrases  
    'tokens_per_paraphrase': 50,       # 每个改写的令牌数 / Tokens per paraphrase
    'separator_tokens': 4,             # 分隔符令牌数 / Separator tokens
    'num_generated_tokens': 15,        # 生成的令牌数 / Generated tokens
    'display_mode': 'sampled',         # 显示模式 / Display mode
    'max_display_positions': 35,       # 最大显示位置数 / Max display positions
}

In [None]:
def create_segment_positions(config):
    """
    根据配置创建分段位置
    Create segment positions based on configuration
    
    Returns:
        segment_positions: List of (start, end) tuples
        original_length: Total length of original content
        total_length: Total length including generated tokens
    """
    segment_positions = []
    current_pos = 0
    
    for i in range(config['num_paraphrases']):
        if i > 0:
            # Add separator
            current_pos += config['separator_tokens']
        
        start = current_pos
        end = current_pos + config['tokens_per_paraphrase']
        segment_positions.append((start, end))
        current_pos = end
    
    original_length = current_pos
    total_length = original_length + config['num_generated_tokens']
    
    return segment_positions, original_length, total_length


def create_attention_mask_function(segment_positions, original_length):
    """
    创建注意力掩码函数（与 flex_attention_generate.py 中的实现一致）
    Create attention mask function (consistent with implementation in flex_attention_generate.py)
    
    Args:
        segment_positions: List of (start, end) tuples
        original_length: Length of original content
        
    Returns:
        mask_func: Function (b, h, q_idx, kv_idx) -> bool
    """
    def mask_func(b, h, q_idx, kv_idx):
        # 因果约束 / Causal constraint
        if q_idx < kv_idx:
            return False
        
        # 生成的令牌可以关注所有之前的内容 / Generated tokens can attend to all previous
        if q_idx >= original_length:
            return True
        
        # 原始令牌只在同一分段内关注 / Original tokens only attend within same segment
        q_segment = None
        kv_segment = None
        
        for seg_id, (start, end) in enumerate(segment_positions):
            if start <= q_idx < end:
                q_segment = seg_id
            if start <= kv_idx < end:
                kv_segment = seg_id
        
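        if q_segment is not None and kv_segment is not None:
            return q_segment == kv_segment
        
        # 分隔符不属于任何分段 / Separator positions belong to no segment, so during
        # encoding they neither attend (not even to themselves) nor are attended to.
        return False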
    
    return mask_func


def smart_sample_positions(segment_positions, original_length, total_length, max_positions=30):
    """
    智能采样位置以展示掩码结构
    Smart sampling of positions to show mask structure
    
    Args:
        segment_positions: List of (start, end) tuples
        original_length: Length of original content
        total_length: Total sequence length
        max_positions: Maximum number of positions to display
        
    Returns:
        positions: List of sampled position indices
    """
    positions = set()
    
    # 1. 添加分段边界 / Add segment boundaries
    for start, end in segment_positions:
        positions.add(start)
        positions.add(end - 1)
        # 在分段中间添加一些位置 / Add some positions in the middle
        segment_len = end - start
        if segment_len > 4:
            positions.add(start + segment_len // 3)
            positions.add(start + 2 * segment_len // 3)
    
    # 2. 添加原始长度边界 / Add original_length boundary
    if original_length < total_length:
        positions.add(original_length - 1)
        positions.add(original_length)
    
    # 3. 添加生成的令牌位置 / Add generated token positions
    if total_length > original_length:
        gen_count = min(5, total_length - original_length)
        for i in range(gen_count):
            positions.add(original_length + i)
    
    # 4. 添加最后位置 / Add last position
    positions.add(total_length - 1)
    
    # 5. 填充剩余位置 / Fill remaining positions
    positions_list = sorted(list(positions))
    while len(positions_list) < max_positions and len(positions_list) < total_length:
        # 找到最大间隙 / Find largest gap
        max_gap = 0
        max_gap_idx = 0
        for i in range(len(positions_list) - 1):
            gap = positions_list[i+1] - positions_list[i]
            if gap > max_gap:
                max_gap = gap
                max_gap_idx = i
        
        if max_gap <= 1:
            break
        
        # 在最大间隙中插入中点 / Insert midpoint in largest gap
        mid = (positions_list[max_gap_idx] + positions_list[max_gap_idx + 1]) // 2
        positions_list.insert(max_gap_idx + 1, mid)
    
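    # positions_list 已排序 / positions_list is already sorted. If the structural
    # positions alone exceed the budget, thin them evenly rather than slicing,
    # which would silently drop the highest indices (the generated tokens).
    if len(positions_list) > max_positions:
        idx = np.linspace(0, len(positions_list) - 1, max_positions).round().astype(int)
        positions_list = [positions_list[i] for i in sorted(set(idx.tolist()))]
    return positions_list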


def visualize_attention_mask(config, title_suffix=''):
    """
    可视化注意力掩码矩阵
    Visualize attention mask matrix
    
    Args:
        config: Configuration dictionary
        title_suffix: Additional title text
    """
    # 创建分段位置 / Create segment positions
    segment_positions, original_length, total_length = create_segment_positions(config)
    
    # 创建掩码函数 / Create mask function
    mask_func = create_attention_mask_function(segment_positions, original_length)
    
    # 确定显示位置 / Determine display positions
    if config['display_mode'] == 'full' or total_length <= 25:
        positions = list(range(total_length))
        display_info = f"Full display: {total_length}×{total_length}"
    else:
        max_pos = config.get('max_display_positions', 30)
        positions = smart_sample_positions(segment_positions, original_length, 
                                          total_length, max_pos)
        display_info = f"Sampled display: showing {len(positions)} of {total_length} positions"
    
    # 创建掩码矩阵 / Create mask matrix
    n_pos = len(positions)
    mask_matrix = np.zeros((n_pos, n_pos))
    for i, q in enumerate(positions):
        for j, kv in enumerate(positions):
            mask_matrix[i, j] = 1 if mask_func(0, 0, q, kv) else 0
    
    # 绘制掩码矩阵 / Plot mask matrix
    fig, ax = plt.subplots(figsize=(12, 10))
    
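    # 使用颜色映射 / Use color mapping (white = cannot attend, green = can attend)
    from matplotlib.colors import ListedColormap  # public home of ListedColormap
    cmap = ListedColormap(['white', '#2E7D32'])
    # Fix vmin/vmax so the two colors map to 0 and 1 even if the matrix is uniform
    im = ax.imshow(mask_matrix, cmap=cmap, aspect='auto', interpolation='nearest',
                   vmin=0, vmax=1)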
    
    # 设置刻度 / Set ticks
    ax.set_xticks(range(n_pos))
    ax.set_yticks(range(n_pos))
    ax.set_xticklabels([str(p) for p in positions], rotation=45, ha='right', fontsize=8)
    ax.set_yticklabels([str(p) for p in positions], fontsize=8)
    
    # 添加网格 / Add grid
    ax.set_xticks(np.arange(n_pos) - 0.5, minor=True)
    ax.set_yticks(np.arange(n_pos) - 0.5, minor=True)
    ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)
    
    # 添加分段边界线 / Add segment boundary lines
    for start, end in segment_positions:
        if start in positions:
            idx = positions.index(start)
            ax.axhline(y=idx - 0.5, color='red', linewidth=2, linestyle='--', alpha=0.6)
            ax.axvline(x=idx - 0.5, color='red', linewidth=2, linestyle='--', alpha=0.6)
    
    # 标记生成开始位置 / Mark generation start
    if original_length in positions:
        idx = positions.index(original_length)
        ax.axhline(y=idx - 0.5, color='blue', linewidth=2, linestyle='--', alpha=0.6)
        ax.axvline(x=idx - 0.5, color='blue', linewidth=2, linestyle='--', alpha=0.6)
    
    # 设置标签 / Set labels
    ax.set_xlabel('Key/Value Position (KV)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Query Position (Q)', fontsize=12, fontweight='bold')
    
    # 标题 / Title
    title = f'Attention Mask Visualization{title_suffix}\n{display_info}'
    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
    
    # 添加颜色条 / Add colorbar
    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_ticks([0.25, 0.75])
    cbar.set_ticklabels(['Cannot Attend\n不可关注', 'Can Attend\n可关注'])
    
    # 添加信息文本 / Add info text
    info_text = (
        f"Configuration / 配置:\n"
        f"  • Paraphrases / 改写数: {config['num_paraphrases']}\n"
        f"  • Tokens per paraphrase / 每段令牌数: {config['tokens_per_paraphrase']}\n"
        f"  • Original length / 原始长度: {original_length}\n"
        f"  • Generated tokens / 生成令牌数: {config['num_generated_tokens']}\n"
        f"  • Total length / 总长度: {total_length}\n\n"
        f"Legend / 图例:\n"
        f"  ━━ Red dashed / 红色虚线: Segment boundary / 分段边界\n"
        f"  ━━ Blue dashed / 蓝色虚线: Generation start / 生成开始"
    )
    
    # 在右侧添加文本框 / Add text box on the right
    plt.gcf().text(0.98, 0.5, info_text, 
                   fontsize=9, 
                   bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3),
                   verticalalignment='center',
                   horizontalalignment='left')
    
    plt.tight_layout()
    return fig, mask_matrix, positions


print("函数定义完成！/ Function definitions complete!")
print("现在可以使用不同的配置来可视化注意力掩码 / Now you can visualize attention masks with different configs")

### 场景 1：小型示例（完整展示）/ Scenario 1: Small Example (Full Display)

适合理解基本原理 / Good for understanding basic principles
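
As a quick sanity check of what this scenario produces, the sketch below recomputes the segment layout for the `MASK_CONFIG_SMALL` values by hand. `layout` is a hypothetical helper written only for this illustration; it mirrors `create_segment_positions` and assumes separators are inserted only between paraphrases.

```python
def layout(num_paraphrases, tokens_per_paraphrase, separator_tokens, num_generated_tokens):
    """Return (segments, original_length, total_length) for the concatenated prompt."""
    segments = []
    pos = 0
    for i in range(num_paraphrases):
        if i > 0:
            pos += separator_tokens  # separators sit between paraphrases only
        segments.append((pos, pos + tokens_per_paraphrase))  # half-open [start, end)
        pos += tokens_per_paraphrase
    return segments, pos, pos + num_generated_tokens


# Values copied from MASK_CONFIG_SMALL: 3 paraphrases, 15 tokens each,
# 2 separator tokens, 5 generated tokens.
segments, original_length, total_length = layout(3, 15, 2, 5)
print(segments)         # [(0, 15), (17, 32), (34, 49)]
print(original_length)  # 49
print(total_length)     # 54
```

Positions 15-16 and 32-33 are separators, and positions 49-53 are the generated tokens, so the full display is a 54×54 matrix.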

In [None]:
# 可视化小型示例 / Visualize small example
fig1, matrix1, pos1 = visualize_attention_mask(
    MASK_CONFIG_SMALL, 
    title_suffix=' - Small Example / 小型示例'
)
plt.show()

# 打印统计信息 / Print statistics
print(f"\n小型示例统计 / Small Example Statistics:")
print(f"  矩阵大小 / Matrix size: {matrix1.shape}")
print(f"  可关注的 (Q, KV) 对数 / Allowed (Q, KV) pairs: {np.sum(matrix1):.0f}")
print(f"  (Q, KV) 对总数 / Total (Q, KV) pairs: {matrix1.size}")
print(f"  注意力比例 / Attention ratio: {np.sum(matrix1)/matrix1.size*100:.1f}%")

### 场景 2：中型示例（智能采样）/ Scenario 2: Medium Example (Smart Sampling)

展示采样策略如何保持结构 / Shows how sampling strategy preserves structure
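
Sampling keeps the structure because the displayed matrix is the full mask restricted to the chosen rows and columns. The sketch below illustrates this on a toy layout; the layout, the `allowed` function (which mirrors `mask_func`), and the hand-picked `positions` list are all assumptions made for the illustration, not values taken from the notebook.

```python
import numpy as np

# Hypothetical toy layout: two 4-token segments separated by 1 token,
# followed by 2 generated tokens.
segments = [(0, 4), (5, 9)]
original_length = 9
total_length = 11


def allowed(q, kv):
    """Pointwise mask rule, mirroring mask_func in this notebook."""
    if q < kv:
        return False  # causal constraint
    if q >= original_length:
        return True   # generated tokens see everything before them
    seg_q = next((i for i, (s, e) in enumerate(segments) if s <= q < e), None)
    seg_kv = next((i for i, (s, e) in enumerate(segments) if s <= kv < e), None)
    return seg_q is not None and seg_q == seg_kv


full = np.array([[allowed(q, kv) for kv in range(total_length)]
                 for q in range(total_length)])

# Keep segment starts/ends, the generation boundary, and the last position.
positions = [0, 3, 5, 8, 9, 10]

# np.ix_ selects the same indices on both axes, giving a 6x6 sub-matrix.
sampled = full[np.ix_(positions, positions)]
print(sampled.astype(int))
```

The sub-matrix still shows two lower-triangular blocks (one per segment) followed by rows for the generated tokens, which is the pattern the sampled plots rely on.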

In [None]:
# 可视化中型示例 / Visualize medium example
fig2, matrix2, pos2 = visualize_attention_mask(
    MASK_CONFIG_MEDIUM,
    title_suffix=' - Medium Example / 中型示例'
)
plt.show()

# 打印统计信息 / Print statistics
print(f"\n中型示例统计 / Medium Example Statistics:")
print(f"  显示位置数 / Displayed positions: {len(pos2)}")
print(f"  矩阵大小 / Matrix size: {matrix2.shape}")
print(f"  可关注的 (Q, KV) 对数（采样）/ Allowed (Q, KV) pairs (sampled): {np.sum(matrix2):.0f}")
print(f"  注意力比例（采样子矩阵）/ Attention ratio (sampled sub-matrix): {np.sum(matrix2)/matrix2.size*100:.1f}%")

### 场景 3：大型示例（真实场景）/ Scenario 3: Large Example (Realistic Scenario)

模拟真实使用场景 / Simulates realistic use case
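
In sampled mode the ratio printed below describes only the sampled sub-matrix. The density of the full mask can be computed in closed form instead, as in the sketch below. It assumes the mask rule used in this notebook: each segment of `t` tokens contributes `t * (t + 1) / 2` causal pairs, separators contribute none as queries, and a generated token at position `p` attends to `p + 1` positions. `exact_allowed_pairs` is a hypothetical helper, not part of the notebook.

```python
def exact_allowed_pairs(num_paraphrases, tokens_per_paraphrase,
                        separator_tokens, num_generated_tokens):
    """Count allowed (q, kv) pairs of the full mask without building it."""
    n, t, s, g = (num_paraphrases, tokens_per_paraphrase,
                  separator_tokens, num_generated_tokens)
    original_length = n * t + max(n - 1, 0) * s
    total_length = original_length + g
    # Encoding: one causal (lower-triangular) block per segment.
    encoding_pairs = n * t * (t + 1) // 2
    # Generation: token k (0-based) sits at original_length + k and sees
    # original_length + k + 1 positions, itself included.
    generation_pairs = g * original_length + g * (g + 1) // 2
    return encoding_pairs + generation_pairs, total_length


# Values copied from MASK_CONFIG_LARGE: 5 paraphrases, 50 tokens each,
# 4 separator tokens, 15 generated tokens.
pairs, total_length = exact_allowed_pairs(5, 50, 4, 15)
print(pairs, total_length)                      # 10485 281
print(f"{pairs / total_length**2 * 100:.1f}%")  # 13.3%
```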

In [None]:
# 可视化大型示例 / Visualize large example
fig3, matrix3, pos3 = visualize_attention_mask(
    MASK_CONFIG_LARGE,
    title_suffix=' - Large Example / 大型示例 (Realistic)'
)
plt.show()

# 打印统计信息 / Print statistics
segment_pos, orig_len, total_len = create_segment_positions(MASK_CONFIG_LARGE)
print(f"\n大型示例统计 / Large Example Statistics:")
print(f"  实际总长度 / Actual total length: {total_len}")
print(f"  显示位置数 / Displayed positions: {len(pos3)}")
print(f"  采样率 / Sampling rate: {len(pos3)/total_len*100:.1f}%")
print(f"  矩阵大小 / Matrix size: {matrix3.shape}")
print(f"  可关注的 (Q, KV) 对数（采样）/ Allowed (Q, KV) pairs (sampled): {np.sum(matrix3):.0f}")
print(f"  注意力比例（采样子矩阵）/ Attention ratio (sampled sub-matrix): {np.sum(matrix3)/matrix3.size*100:.1f}%")

---

## 第三部分：自定义配置实验 / Part 3: Custom Configuration Experimentation

### 修改参数进行实验 / Modify parameters for experimentation

你可以在下面的代码单元格中修改任何参数来探索不同的场景：
You can modify any parameters in the cell below to explore different scenarios:

In [None]:
# ============================================================================
# 自定义配置 / Custom Configuration
# ============================================================================
# 修改下面的参数来创建你自己的场景！
# Modify the parameters below to create your own scenario!

CUSTOM_CONFIG = {
    'num_paraphrases': 4,              # 尝试修改这个值！/ Try changing this!
    'tokens_per_paraphrase': 30,       # 尝试修改这个值！/ Try changing this!
    'separator_tokens': 3,             
    'num_generated_tokens': 10,        # 尝试修改这个值！/ Try changing this!
    'display_mode': 'sampled',         # 可选: 'full' 或 'sampled' / Options: 'full' or 'sampled'
    'max_display_positions': 30,       
}

# 可视化自定义配置 / Visualize custom configuration
fig_custom, matrix_custom, pos_custom = visualize_attention_mask(
    CUSTOM_CONFIG,
    title_suffix=' - Custom Configuration / 自定义配置'
)
plt.show()

# 打印详细信息 / Print detailed information
segment_pos_custom, orig_len_custom, total_len_custom = create_segment_positions(CUSTOM_CONFIG)

print("\n" + "="*60)
print("自定义配置详细信息 / Custom Configuration Details")
print("="*60)
print(f"\n分段信息 / Segment Information:")
for i, (start, end) in enumerate(segment_pos_custom):
    print(f"  Segment {i+1}: positions {start:3d} - {end-1:3d} (length: {end-start})")

print(f"\n长度统计 / Length Statistics:")
print(f"  原始内容长度 / Original content length: {orig_len_custom}")
print(f"  生成令牌数 / Generated tokens: {CUSTOM_CONFIG['num_generated_tokens']}")
print(f"  总长度 / Total length: {total_len_custom}")

print(f"\n显示信息 / Display Information:")
print(f"  显示位置数 / Displayed positions: {len(pos_custom)}")
print(f"  采样率 / Sampling rate: {len(pos_custom)/total_len_custom*100:.1f}%")

print(f"\n注意力统计 / Attention Statistics:")
print(f"  矩阵大小 / Matrix size: {matrix_custom.shape}")
print(f"  可关注的 (Q, KV) 对数 / Allowed (Q, KV) pairs: {np.sum(matrix_custom):.0f}")
print(f"  (Q, KV) 对总数 / Total (Q, KV) pairs: {matrix_custom.size}")
print(f"  注意力比例（所显示矩阵）/ Attention ratio (displayed matrix): {np.sum(matrix_custom)/matrix_custom.size*100:.1f}%")
print("="*60)

---

## 第四部分：注意力模式分析 / Part 4: Attention Pattern Analysis

### 对比不同阶段的注意力模式 / Compare attention patterns in different phases

In [None]:
def analyze_attention_patterns(config):
    """
    分析不同查询位置的注意力模式
    Analyze attention patterns for different query positions
    """
    segment_positions, original_length, total_length = create_segment_positions(config)
    mask_func = create_attention_mask_function(segment_positions, original_length)
    
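    # 选择代表性查询位置 / Select representative query positions
    # Use the true midpoint of each segment; a fixed "+ 5" offset is not the middle
    # and falls outside the segment when tokens_per_paraphrase <= 5.
    first_start, first_end = segment_positions[0]
    last_start, last_end = segment_positions[-1]
    query_positions = [
        (first_start + first_end) // 2,  # 第1段中间 / Middle of segment 1
        (last_start + last_end) // 2,    # 最后一段中间 / Middle of last segment
        original_length,                 # 第一个生成令牌 / First generated token
        total_length - 1,                # 最后一个生成令牌 / Last generated token
    ]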
    
    labels = [
        f'Segment 1 (pos {query_positions[0]})',
        f'Last Segment (pos {query_positions[1]})',
        f'First Generated (pos {query_positions[2]})',
        f'Last Generated (pos {query_positions[3]})'
    ]
    
    # 创建子图 / Create subplots
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()
    
    for idx, (q_pos, label) in enumerate(zip(query_positions, labels)):
        # 计算这个查询位置对所有键位置的注意力 / Calculate attention for all key positions
        attention_pattern = np.array([mask_func(0, 0, q_pos, kv) for kv in range(total_length)])
        
        # 绘制柱状图 / Plot bar chart
        ax = axes[idx]
        colors = ['green' if val else 'lightgray' for val in attention_pattern]
        ax.bar(range(total_length), attention_pattern, color=colors, width=1.0, edgecolor='none')
        
        # 标记分段边界 / Mark segment boundaries
        for start, end in segment_positions:
            ax.axvline(x=start, color='red', linestyle='--', linewidth=1, alpha=0.5)
        ax.axvline(x=original_length, color='blue', linestyle='--', linewidth=2, alpha=0.7)
        
        # 设置标签和标题 / Set labels and title
        ax.set_xlabel('Key/Value Position')
        ax.set_ylabel('Can Attend')
        ax.set_title(f'Attention Pattern for {label}', fontweight='bold')
        ax.set_ylim(-0.1, 1.1)
        ax.grid(axis='y', alpha=0.3)
        
        # 添加统计信息 / Add statistics
        attend_count = np.sum(attention_pattern)
        attend_ratio = attend_count / total_length * 100
        ax.text(0.98, 0.95, f'Attend: {attend_count:.0f}/{total_length} ({attend_ratio:.1f}%)',
                transform=ax.transAxes, ha='right', va='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    plt.suptitle('Attention Patterns Analysis / 注意力模式分析', 
                 fontsize=16, fontweight='bold', y=1.00)
    plt.tight_layout()
    return fig


# 分析中型示例的注意力模式 / Analyze attention patterns for medium example
fig_patterns = analyze_attention_patterns(MASK_CONFIG_MEDIUM)
plt.show()

print("\n注意力模式解读 / Attention Pattern Interpretation:")
print("  • 编码阶段（前两个图）：每个分段只关注自己 / Encoding phase (first two plots): Each segment attends only to itself")
print("  • 生成阶段（后两个图）：生成的令牌可以关注所有之前的内容 / Generation phase (last two plots): Generated tokens attend to all previous content")
print("  • 这种设计允许融合来自所有改写的信息 / This design allows fusion of information from all paraphrases")

---

## 第五部分：掩码矩阵导出 / Part 5: Mask Matrix Export

### 保存图片和数据 / Save figures and data

In [None]:
import os

# 创建输出目录 / Create output directory
output_dir = 'attention_mask_outputs'
os.makedirs(output_dir, exist_ok=True)

print(f"保存可视化结果到 / Saving visualizations to: {output_dir}/")
print("="*60)

# 保存流程图 / Save flowchart
if 'fig' in locals():
    flowchart_path = os.path.join(output_dir, 'flowchart.png')
    fig.savefig(flowchart_path, dpi=300, bbox_inches='tight')
    print(f"✓ 流程图已保存 / Flowchart saved: {flowchart_path}")

# 保存注意力掩码可视化 / Save attention mask visualizations
if 'fig1' in locals():
    fig1.savefig(os.path.join(output_dir, 'mask_small.png'), dpi=300, bbox_inches='tight')
    print(f"✓ 小型示例掩码已保存 / Small example mask saved")

if 'fig2' in locals():
    fig2.savefig(os.path.join(output_dir, 'mask_medium.png'), dpi=300, bbox_inches='tight')
    print(f"✓ 中型示例掩码已保存 / Medium example mask saved")

if 'fig3' in locals():
    fig3.savefig(os.path.join(output_dir, 'mask_large.png'), dpi=300, bbox_inches='tight')
    print(f"✓ 大型示例掩码已保存 / Large example mask saved")

if 'fig_custom' in locals():
    fig_custom.savefig(os.path.join(output_dir, 'mask_custom.png'), dpi=300, bbox_inches='tight')
    print(f"✓ 自定义掩码已保存 / Custom mask saved")

# 保存注意力模式分析 / Save attention pattern analysis
if 'fig_patterns' in locals():
    fig_patterns.savefig(os.path.join(output_dir, 'attention_patterns.png'), dpi=300, bbox_inches='tight')
    print(f"✓ 注意力模式分析已保存 / Attention patterns saved")

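# 保存掩码矩阵为 numpy 数组 / Save mask matrices as numpy arrays
# 注意 / Note: in 'sampled' mode this is the sampled sub-matrix. Row/column i
# corresponds to sequence position pos_custom[i], so the positions are saved too.
if 'matrix_custom' in locals():
    np.save(os.path.join(output_dir, 'mask_matrix_custom.npy'), matrix_custom)
    np.save(os.path.join(output_dir, 'mask_positions_custom.npy'), np.array(pos_custom))
    print(f"✓ 自定义掩码矩阵已保存 / Custom mask matrix saved (numpy)")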

print("="*60)
print(f"\n所有文件已保存到目录 / All files saved to: {output_dir}/")
print(f"\n可以使用以下代码加载掩码矩阵 / Load mask matrix with:")
print(f"  matrix = np.load('{output_dir}/mask_matrix_custom.npy')")

---

## 总结 / Summary

### 本笔记本提供的功能 / Features Provided by This Notebook:

1. **完整的代码流程图** - 展示从输入到输出的整个处理流程 / **Complete code flowchart** - Shows entire processing pipeline from input to output

2. **精确的注意力掩码可视化** - 三种预设场景 + 自定义配置 / **Precise attention mask visualization** - Three preset scenarios + custom configuration

3. **智能采样策略** - 处理大型序列时保持结构可见性 / **Smart sampling strategy** - Maintains structure visibility for large sequences

4. **注意力模式分析** - 对比编码和生成阶段的注意力行为 / **Attention pattern analysis** - Compares encoding and generation phase behavior

5. **易于修改的参数** - 所有配置集中在配置字典中 / **Easy-to-modify parameters** - All configs centralized in dictionaries

6. **导出功能** - 保存所有可视化和数据 / **Export functionality** - Save all visualizations and data

### 如何使用 / How to Use:

1. **查看预设示例**：运行所有单元格查看三种预设场景 / **View preset examples**: Run all cells to see three preset scenarios

2. **自定义实验**：修改 `CUSTOM_CONFIG` 字典中的参数 / **Custom experimentation**: Modify parameters in `CUSTOM_CONFIG` dictionary

3. **调整外观**：修改 `FLOWCHART_CONFIG` 来改变流程图样式 / **Adjust appearance**: Modify `FLOWCHART_CONFIG` to change flowchart style

4. **保存结果**：运行最后一个单元格导出所有可视化 / **Save results**: Run last cell to export all visualizations

### 关键见解 / Key Insights:

- **编码阶段**：每个改写在其自己的分段内隔离处理 / **Encoding phase**: Each paraphrase processed in isolation within its segment
- **生成阶段**：新令牌可以关注所有先前内容，实现融合 / **Generation phase**: New tokens attend to all previous content, enabling fusion
- **掩码形状**：清晰展示因果块对角结构（编码）+ 对所有先前位置的因果关注（生成）/ **Mask shape**: Clearly shows a causal block-diagonal structure (encoding) + causal attention over all previous positions (generation)
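
The mask described by these insights can also be built in one step as a boolean matrix, without calling a per-position function. The following is a minimal NumPy sketch under the same rule as `mask_func`. `build_mask` is a hypothetical name, and nothing here is confirmed to match the code in `flex_attention_generate.py`.

```python
import numpy as np


def build_mask(segments, original_length, total_length):
    """Boolean (total_length, total_length) mask; True means q may attend to kv."""
    # Segment id per position; -1 marks separators and generated tokens.
    seg_id = np.full(total_length, -1)
    for i, (start, end) in enumerate(segments):
        seg_id[start:end] = i

    q = np.arange(total_length)[:, None]   # query index as a column vector
    kv = np.arange(total_length)[None, :]  # key/value index as a row vector

    causal = q >= kv
    generated_q = q >= original_length
    # Both positions must lie in the same real segment (id >= 0).
    same_segment = (seg_id[:, None] == seg_id[None, :]) & (seg_id[:, None] >= 0)
    return causal & (generated_q | same_segment)


# Layout of the small scenario: three 15-token segments, 2 separator tokens
# between them, and 5 generated tokens.
mask = build_mask([(0, 15), (17, 32), (34, 49)], original_length=49, total_length=54)
print(mask.shape)       # (54, 54)
print(int(mask.sum()))  # 620
```

The 620 allowed pairs break down as 3 × (15 × 16 / 2) = 360 within segments plus 49 + 50 + 51 + 52 + 53 = 260 for the five generated tokens.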

---

**需要帮助？/ Need Help?**

- 查看 `flex_attention_generate.py` 了解实现细节 / See `flex_attention_generate.py` for implementation details
- 查看 `test_mask_visualization.py` 了解更多测试示例 / See `test_mask_visualization.py` for more test examples
- 查看 `docs/ARCHITECTURE.md` 了解架构说明 / See `docs/ARCHITECTURE.md` for architecture documentation