In [172]:
import torch
import mx
from mx.mx_ops import quantize_mx_op, get_mx_quantize_params, apply_mx_quantize_with_param

from mx.elemwise_ops import quantize_elemwise_op
from mx.specs import MxSpecs

In [173]:
import transformers
# Load Meta-Llama-3-8B in float16 and take one projection weight
# (layer 0, self-attention k_proj) as the test tensor for the experiments below.
model = transformers.AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", torch_dtype=torch.float16)
model.eval()
# .data.clone() gives an independent copy, so later experiments cannot modify the model weights.
w = model.model.layers[0].self_attn.k_proj.weight.data.clone()
print("Original weight shape:", w.shape)

Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.66it/s]


Original weight shape: torch.Size([1024, 4096])


In [174]:
# Global MX quantization spec; quantize_mx_w and the exploratory cells read it directly.
# NOTE(review): the "In [176]" cell sets mx_specs['double_quant'] = True, so the
# state of this object depends on which cells have already been executed.
mx_specs = MxSpecs()
mx_specs["custom_cuda"] = True
mx_specs["w_elem_format"] = "fp4_e2m1"  # element format used for weights
mx_specs["w_scale_mode"] = 0  # mapped to the 'e8m0' format in quantize_mx_w's double-quant path
mx_specs["block_size"] = 32  # presumably the number of elements sharing one scale — confirm in mx.specs
mx_specs["round_mx_output"] = "nearest"


In [175]:
def quantize_mx_w(weight, specs=None):
    """
    Fake-quantize a weight tensor with MX block quantization along the last axis.

    The weight is first passed through ``quantize_elemwise_op`` (rounding mode
    ``specs["round_weight"]``), then block-quantized:

    * ``specs['double_quant']`` falsy: a single ``quantize_mx_op`` call with
      ``elem_format=specs['w_elem_format']`` and ``scale_mode=specs['w_scale_mode']``.
    * ``specs['double_quant']`` truthy: ``get_mx_quantize_params`` is called with
      ``scale_mode=2`` to obtain the per-block scales and the quantized elements;
      the scales are then quantized themselves (with ``block_size=-1``) into the
      format selected by ``specs['w_scale_mode']`` and multiplied back onto the
      elements (presumably yielding the dequantized weight — confirm against
      ``get_mx_quantize_params``).

    Args:
        weight: tensor to quantize; the computation runs on ``weight.float()``.
        specs: MxSpecs to use. Defaults to the notebook-global ``mx_specs``,
            which keeps the original one-argument call working.

    Returns:
        The fake-quantized tensor, cast back to ``weight.dtype``.

    Raises:
        ValueError: if double quantization is requested with a ``w_scale_mode``
            other than 143, 152 or 0.
    """
    if specs is None:
        specs = mx_specs
    dtype = weight.dtype

    # Element-wise pre-quantization of the weight.
    bf_weight = quantize_elemwise_op(
        weight.float(), mx_specs=specs, round=specs["round_weight"]
    )

    if not specs['double_quant']:
        qis_weight = quantize_mx_op(
            bf_weight,
            specs,
            elem_format=specs['w_elem_format'],
            scale_mode=specs['w_scale_mode'],
            axes=[-1],
            round=specs["round_mx_output"],
        )
    else:
        # Element format used to quantize the block scales, keyed by w_scale_mode.
        scale_elem_formats = {143: 'fp8_e4m3', 152: 'fp8_e5m2', 0: 'e8m0'}
        w_scale_mode = specs['w_scale_mode']
        # Validate before the expensive parameter computation.
        if w_scale_mode not in scale_elem_formats:
            raise ValueError(f"Unsupported scale mode: {w_scale_mode}")

        w_scale, _, _, q_w = get_mx_quantize_params(
            bf_weight,
            specs,
            elem_format=specs['w_elem_format'],
            scale_mode=2,
            axes=[-1],
            round=specs["round_mx_output"],
        )
        # Quantize the scales themselves, using a copy so `specs` is not modified.
        scale_specs = specs.copy()
        scale_specs['block_size'] = -1
        q_w_scale = quantize_mx_op(
            w_scale,
            scale_specs,
            elem_format=scale_elem_formats[w_scale_mode],
            scale_mode=2,
            axes=[-1],
            round=scale_specs["round_mx_output"],
        )
        qis_weight = q_w_scale * q_w

    return qis_weight.to(dtype)

In [176]:
# Compare single-level MX quantization (fq_w) against double quantization, in
# which the block scales are quantized too (fdq_w), on the same weight.
w_dev = w.to('cuda:7')

# Set the flag explicitly for BOTH runs: mx_specs is a global that this cell
# mutates, so relying on its previous value would make a re-run of this cell
# compute fq_w with double quantization as well.
mx_specs['double_quant'] = False
fq_w = quantize_mx_w(w_dev)
mx_specs['double_quant'] = True  # left True afterwards, as before
fdq_w = quantize_mx_w(w_dev)

a = torch.nn.functional.mse_loss(w_dev, fq_w)   # single-quant MSE
b = torch.nn.functional.mse_loss(w_dev, fdq_w)  # double-quant MSE

In [177]:
# Element-wise pre-quantize w, then get single-level MX quantization parameters
# with the configured w_scale_mode. Of the four returned values only the first
# (presumably the block scales) and the last (presumably the quantized elements)
# are kept. bf_weight is reused by the next cell.
bf_weight = quantize_elemwise_op(
            w.to('cuda:7').float(), mx_specs=mx_specs, round=mx_specs["round_weight"]
        )
w_scale1, _, _, q_w1 = get_mx_quantize_params(
    bf_weight,
    mx_specs,
    elem_format=mx_specs['w_elem_format'],
    scale_mode=mx_specs['w_scale_mode'],
    axes=[-1],
    round=mx_specs["round_mx_output"],
        )

In [178]:
def scale_mode_to_elem_format(scale_mode):
    """Return the element format used to quantize block scales for a numeric scale mode (143, 152 or 0)."""
    for mode_code, elem_format in ((143, 'fp8_e4m3'), (152, 'fp8_e5m2'), (0, 'e8m0')):
        if scale_mode == mode_code:
            return elem_format
    raise ValueError(f"Unsupported scale mode: {scale_mode}")


# Manual double quantization of bf_weight, step by step
# (the same sequence as the double_quant branch of quantize_mx_w).
# Step 1: per-block scales and quantized elements with scale_mode=2.
w_scale2, _, _, q_w2 = get_mx_quantize_params(
    bf_weight,
    mx_specs,
    elem_format=mx_specs['w_elem_format'],
    scale_mode=2,
    axes=[-1],
    round=mx_specs["round_mx_output"],
)

# Step 2: quantize the scales themselves (block_size=-1) into the format
# selected by w_scale_mode, using a copied spec so mx_specs is untouched.
scale_mx_specs = mx_specs.copy()
scale_mx_specs['block_size'] = -1
q_w_scale = quantize_mx_op(
    w_scale2,
    scale_mx_specs,
    elem_format=scale_mode_to_elem_format(mx_specs['w_scale_mode']),
    scale_mode=2,
    axes=[-1],
    round=scale_mx_specs["round_mx_output"],
)

# Step 3: recombine quantized scales with the quantized elements.
qis_weight = q_w_scale * q_w2

In [179]:
def compute_tile_quantization_loss(original, quantized, block_size=32, loss_type='mse', axis=-1):
    """
    Compute a quantization loss for every tile of ``block_size`` elements.

    Both tensors are split along ``axis`` into consecutive tiles. When the axis
    length is not a multiple of ``block_size`` the last tile is zero-padded;
    the 'mse', 'mae' and 'snr' averages for that tile are taken over its real
    elements only (zero padding does not change 'cosine').

    The computation runs in at least float32, because squared errors of
    float16 tensors easily underflow in float16.

    Args:
        original: reference tensor.
        quantized: quantized tensor with the same shape as ``original``.
        block_size: tile length along ``axis``; must be positive.
        loss_type: 'mse', 'mae', 'cosine' (1 - cosine similarity) or
            'snr' (negated SNR in dB, so larger means worse).
        axis: dimension to tile along; negative values count from the end.

    Returns:
        Tensor of shape ``(prod(other dims), num_tiles)``; a 1-D input gives
        ``(1, num_tiles)``. Rows follow the remaining dims in their original order.

    Raises:
        AssertionError: if the shapes differ.
        ValueError: if ``block_size`` is not positive or ``loss_type`` is unknown.
    """
    import math

    assert original.shape == quantized.shape, (
        f"shape mismatch: {tuple(original.shape)} vs {tuple(quantized.shape)}"
    )
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    if loss_type not in ('mse', 'mae', 'cosine', 'snr'):
        raise ValueError(f"Unsupported loss_type: {loss_type!r}")

    compute_dtype = torch.promote_types(
        torch.promote_types(original.dtype, quantized.dtype), torch.float32
    )

    # Move the tiled axis last so every tile is a contiguous run of the final dim.
    orig = original.to(compute_dtype).moveaxis(axis, -1)
    quant = quantized.to(compute_dtype).moveaxis(axis, -1)

    axis_size = orig.shape[-1]
    num_tiles = (axis_size + block_size - 1) // block_size
    padding = num_tiles * block_size - axis_size
    if padding > 0:
        # Zero-pad only the end of the (now last) tiled axis.
        orig = torch.nn.functional.pad(orig, (0, padding))
        quant = torch.nn.functional.pad(quant, (0, padding))

    # math.prod returns an int (1 for a 1-D input), which reshape requires.
    total_front = math.prod(orig.shape[:-1])
    orig_tiles = orig.reshape(total_front, num_tiles, block_size)
    quant_tiles = quant.reshape(total_front, num_tiles, block_size)
    diff = orig_tiles - quant_tiles

    # Real (non-padding) element count per tile; only the last tile can be short.
    counts = torch.full((num_tiles,), float(block_size), dtype=compute_dtype, device=orig.device)
    if padding > 0:
        counts[-1] = block_size - padding

    if loss_type == 'mse':
        tile_losses = (diff ** 2).sum(dim=-1) / counts
    elif loss_type == 'mae':
        tile_losses = diff.abs().sum(dim=-1) / counts
    elif loss_type == 'cosine':
        tile_losses = 1 - torch.nn.functional.cosine_similarity(orig_tiles, quant_tiles, dim=-1)
    else:  # 'snr'
        signal_power = (orig_tiles ** 2).sum(dim=-1) / counts
        noise_power = (diff ** 2).sum(dim=-1) / counts
        tile_losses = -10 * torch.log10(signal_power / (noise_power + 1e-8))

    return tile_losses

In [180]:
import matplotlib.pyplot as plt
import numpy as np

# Font fallback list so Chinese (CJK) text can be rendered in matplotlib figures.
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial Unicode MS', 'WenQuanYi Micro Hei']
plt.rcParams['axes.unicode_minus'] = False  # fixes minus signs being rendered incorrectly with these fonts

# Detailed two-scheme loss comparison plot (all figure text is in English).
def plot_detailed_quantization_analysis(q_t1_loss, q_t2_loss, q_t1_label='scale_mode=0', q_t2_label='scale_mode=2'):
    """
    Plot a per-tile comparison of two quantization-loss arrays and print summary statistics.

    Args:
        q_t1_loss: baseline per-tile losses (torch tensor or array-like). Any
            shape is accepted and flattened row-major, so the 2-D output of
            ``compute_tile_quantization_loss`` can be passed directly.
        q_t2_loss: per-tile losses of the compared scheme; same number of tiles.
        q_t1_label: legend label of the baseline.
        q_t2_label: legend label of the compared scheme.

    "Improvement" is ``(q_t1 - q_t2) / q_t1 * 100``, so positive values mean
    ``q_t2`` has the lower loss. Tiles whose baseline loss is 0 get 0%.

    Raises:
        ValueError: if the two inputs contain a different number of tiles.
    """
    def to_flat_numpy(x):
        # detach()/float(): tensors that require grad or are bfloat16 cannot be
        # converted to numpy directly, and float16 is too coarse for small losses.
        if hasattr(x, 'cpu'):
            x = x.detach().cpu().float().numpy()
        return np.asarray(x, dtype=np.float64).ravel()

    q_t1_np = to_flat_numpy(q_t1_loss)
    q_t2_np = to_flat_numpy(q_t2_loss)
    if q_t1_np.shape != q_t2_np.shape:
        raise ValueError(f"tile count mismatch: {q_t1_np.size} vs {q_t2_np.size}")

    num_tiles = q_t1_np.size
    tile_indices = np.arange(num_tiles)

    diff = q_t1_np - q_t2_np
    # Relative improvement in percent; defined as 0 where the baseline loss is 0.
    improvement = np.divide(diff * 100, q_t1_np, out=np.zeros_like(diff), where=q_t1_np != 0)

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # 1. Side-by-side bars of both losses.
    width = 0.35
    axes[0,0].bar(tile_indices - width/2, q_t1_np, width,
                  label=q_t1_label, color='lightblue', alpha=0.8)
    axes[0,0].bar(tile_indices + width/2, q_t2_np, width,
                  label=q_t2_label, color='darkblue', alpha=0.8)
    axes[0,0].set_xlabel('Group Index (Tile Index)')
    axes[0,0].set_ylabel('Quantization Loss (MSE)')
    axes[0,0].set_title('Quantization Loss Comparison')
    axes[0,0].legend()
    axes[0,0].grid(True, alpha=0.3)

    # 2. Absolute difference; green where the compared scheme is better.
    colors = ['green' if d > 0 else 'red' for d in diff]
    axes[0,1].bar(tile_indices, diff, color=colors, alpha=0.7)
    axes[0,1].set_xlabel('Group Index (Tile Index)')
    axes[0,1].set_ylabel('Loss Difference (q_t1 - q_t2)')
    axes[0,1].set_title(f'Loss Difference (Green: {q_t2_label} Better)')
    axes[0,1].grid(True, alpha=0.3)
    axes[0,1].axhline(y=0, color='black', linestyle='--', alpha=0.5)

    # 3. Relative improvement in percent.
    axes[1,0].bar(tile_indices, improvement, color='orange', alpha=0.7)
    axes[1,0].set_xlabel('Group Index (Tile Index)')
    axes[1,0].set_ylabel('Improvement Percentage (%)')
    axes[1,0].set_title(f'{q_t1_label} vs {q_t2_label} Improvement')
    axes[1,0].grid(True, alpha=0.3)
    axes[1,0].axhline(y=0, color='black', linestyle='--', alpha=0.5)

    # 4. Cumulative loss over tiles.
    cumsum_t1 = np.cumsum(q_t1_np)
    cumsum_t2 = np.cumsum(q_t2_np)
    axes[1,1].plot(tile_indices, cumsum_t1, 'o-', label=q_t1_label,
                   color='lightblue', linewidth=2, markersize=4)
    axes[1,1].plot(tile_indices, cumsum_t2, 's-', label=q_t2_label,
                   color='darkblue', linewidth=2, markersize=4)
    axes[1,1].set_xlabel('Group Index (Tile Index)')
    axes[1,1].set_ylabel('Cumulative Loss')
    axes[1,1].set_title('Cumulative Quantization Loss')
    axes[1,1].legend()
    axes[1,1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Total improvement relative to the baseline total (0 if the baseline total is 0).
    total_t1 = np.sum(q_t1_np)
    total_improvement = np.sum(diff) / total_t1 * 100 if total_t1 != 0 else 0.0

    print("=" * 60)
    print("Quantization Loss Statistical Analysis")
    print("=" * 60)
    print(f"{'Metric':<20} {q_t1_label:<15} {q_t2_label:<15} {'Improvement%':<10}")
    print("-" * 60)
    print(f"{'Average Loss':<20} {np.mean(q_t1_np):<15.6f} {np.mean(q_t2_np):<15.6f} {np.mean(improvement):<10.2f}")
    print(f"{'Total Loss':<20} {total_t1:<15.6f} {np.sum(q_t2_np):<15.6f} {total_improvement:<10.2f}")
    print(f"{'Max Loss':<20} {np.max(q_t1_np):<15.6f} {np.max(q_t2_np):<15.6f} {np.max(improvement):<10.2f}")
    print(f"{'Min Loss':<20} {np.min(q_t1_np):<15.6f} {np.min(q_t2_np):<15.6f} {np.min(improvement):<10.2f}")
    print(f"{'Std Dev':<20} {np.std(q_t1_np):<15.6f} {np.std(q_t2_np):<15.6f}")


In [181]:
def plot_scaled_data_histograms(scaled_pot, scaled_fp16, bins=100):
    """
    Compare the value distributions of data scaled with two different scale types.

    Draws overlapped and individual histograms plus empirical CDFs, prints
    summary statistics, and runs a two-sample Kolmogorov-Smirnov test.

    Args:
        scaled_pot: values scaled with power-of-two scales (scale_mode=0);
            torch tensor or array-like of any shape (flattened).
        scaled_fp16: values scaled with FP16 scales (scale_mode=2); any shape (flattened).
        bins: number of histogram bins.
    """
    def to_flat_numpy(x):
        # Flatten: matplotlib's hist treats 2-D input as one dataset per column,
        # and np.sort would sort each row separately, breaking the CDF.
        if hasattr(x, 'cpu'):
            x = x.detach().cpu().float().numpy()
        return np.asarray(x, dtype=np.float64).ravel()

    scaled_pot_np = to_flat_numpy(scaled_pot)
    scaled_fp16_np = to_flat_numpy(scaled_fp16)

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # 1. Overlapped, density-normalized histograms.
    axes[0,0].hist(scaled_pot_np, bins=bins, alpha=0.7, label='PoT Scale (scale_mode=0)',
                   color='lightblue', density=True)
    axes[0,0].hist(scaled_fp16_np, bins=bins, alpha=0.7, label='FP16 Scale (scale_mode=2)',
                   color='darkblue', density=True)
    axes[0,0].set_xlabel('Scaled Values')
    axes[0,0].set_ylabel('Density')
    axes[0,0].set_title('Overlapped Histograms of Scaled Data')
    axes[0,0].legend()
    axes[0,0].grid(True, alpha=0.3)

    # 2. PoT-scaled histogram.
    axes[0,1].hist(scaled_pot_np, bins=bins, color='lightblue', alpha=0.8)
    axes[0,1].set_xlabel('Scaled Values')
    axes[0,1].set_ylabel('Frequency')
    axes[0,1].set_title('PoT Scale Distribution (scale_mode=0)')
    axes[0,1].grid(True, alpha=0.3)

    # 3. FP16-scaled histogram.
    axes[1,0].hist(scaled_fp16_np, bins=bins, color='darkblue', alpha=0.8)
    axes[1,0].set_xlabel('Scaled Values')
    axes[1,0].set_ylabel('Frequency')
    axes[1,0].set_title('FP16 Scale Distribution (scale_mode=2)')
    axes[1,0].grid(True, alpha=0.3)

    # 4. Empirical cumulative distribution functions.
    sorted_pot = np.sort(scaled_pot_np)
    sorted_fp16 = np.sort(scaled_fp16_np)
    y_pot = np.arange(1, len(sorted_pot) + 1) / len(sorted_pot)
    y_fp16 = np.arange(1, len(sorted_fp16) + 1) / len(sorted_fp16)

    axes[1,1].plot(sorted_pot, y_pot, label='PoT Scale', color='lightblue', linewidth=2)
    axes[1,1].plot(sorted_fp16, y_fp16, label='FP16 Scale', color='darkblue', linewidth=2)
    axes[1,1].set_xlabel('Scaled Values')
    axes[1,1].set_ylabel('Cumulative Probability')
    axes[1,1].set_title('Cumulative Distribution Functions')
    axes[1,1].legend()
    axes[1,1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Summary statistics.
    print("=" * 70)
    print("Scaled Data Statistical Analysis")
    print("=" * 70)
    print(f"{'Metric':<20} {'PoT Scale':<15} {'FP16 Scale':<15} {'Difference':<15}")
    print("-" * 70)
    print(f"{'Mean':<20} {np.mean(scaled_pot_np):<15.6f} {np.mean(scaled_fp16_np):<15.6f} {np.mean(scaled_pot_np) - np.mean(scaled_fp16_np):<15.6f}")
    print(f"{'Std Dev':<20} {np.std(scaled_pot_np):<15.6f} {np.std(scaled_fp16_np):<15.6f} {np.std(scaled_pot_np) - np.std(scaled_fp16_np):<15.6f}")
    print(f"{'Min':<20} {np.min(scaled_pot_np):<15.6f} {np.min(scaled_fp16_np):<15.6f} {np.min(scaled_pot_np) - np.min(scaled_fp16_np):<15.6f}")
    print(f"{'Max':<20} {np.max(scaled_pot_np):<15.6f} {np.max(scaled_fp16_np):<15.6f} {np.max(scaled_pot_np) - np.max(scaled_fp16_np):<15.6f}")
    print(f"{'Median':<20} {np.median(scaled_pot_np):<15.6f} {np.median(scaled_fp16_np):<15.6f} {np.median(scaled_pot_np) - np.median(scaled_fp16_np):<15.6f}")
    print(f"{'25th Percentile':<20} {np.percentile(scaled_pot_np, 25):<15.6f} {np.percentile(scaled_fp16_np, 25):<15.6f} {np.percentile(scaled_pot_np, 25) - np.percentile(scaled_fp16_np, 25):<15.6f}")
    print(f"{'75th Percentile':<20} {np.percentile(scaled_pot_np, 75):<15.6f} {np.percentile(scaled_fp16_np, 75):<15.6f} {np.percentile(scaled_pot_np, 75) - np.percentile(scaled_fp16_np, 75):<15.6f}")

    # Distribution similarity: two-sample Kolmogorov-Smirnov test.
    from scipy import stats
    ks_stat, ks_p = stats.ks_2samp(scaled_pot_np, scaled_fp16_np)
    print(f"\nKolmogorov-Smirnov Test:")
    print(f"KS Statistic: {ks_stat:.6f}")
    print(f"P-value: {ks_p:.6f}")
    print(f"Distributions are {'similar' if ks_p > 0.05 else 'significantly different'} (α=0.05)")


In [182]:
from fast_hadamard_transform import hadamard_transform
import math
# token_roted = hadamard_transform(token.reshape(-1, token.shape[0] // 32,
#                                                       32), scale=1 / math.sqrt(32)).reshape(-1)

In [183]:

def analyze_tile_details(original, quantized, tile_index, block_size=32, title_prefix=""):
    """
    Plot and print an element-level breakdown of one quantization tile.

    Tiles are cut along the last axis. Every row (leading dims flattened) is
    zero-padded to a multiple of ``block_size`` and split into consecutive
    ``block_size``-element tiles, numbered row-major. For a 1-D input, tile
    ``k`` covers elements ``k*block_size`` .. ``(k+1)*block_size - 1``; for
    N-D input the numbering matches the row-major flattening of
    ``compute_tile_quantization_loss(..., axis=-1)``. Padding elements are
    dropped, so a trailing partial tile shows only its real elements.

    Args:
        original: reference tensor or array.
        quantized: quantized tensor or array with the same shape.
        tile_index: tile to analyze; negative values count from the end.
        block_size: tile length.
        title_prefix: text prepended to the first subplot title and the printed header.

    Raises:
        ValueError: if the shapes differ or ``block_size`` is not positive.
        IndexError: if ``tile_index`` is out of range.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    def to_numpy(x):
        # float64: squaring small float16 errors would underflow to 0.
        if hasattr(x, 'cpu'):
            x = x.detach().cpu().float().numpy()
        return np.atleast_1d(np.asarray(x, dtype=np.float64))

    original_np = to_numpy(original)
    quantized_np = to_numpy(quantized)
    if original_np.shape != quantized_np.shape:
        raise ValueError(f"shape mismatch: {original_np.shape} vs {quantized_np.shape}")

    # Pad each row (last axis) to whole tiles, then cut into tiles.
    row_len = original_np.shape[-1]
    tiles_per_row = (row_len + block_size - 1) // block_size
    padding = tiles_per_row * block_size - row_len
    pad_width = ((0, 0), (0, padding))  # pad only the last axis, at its end
    original_tiles = np.pad(original_np.reshape(-1, row_len), pad_width).reshape(-1, block_size)
    quantized_tiles = np.pad(quantized_np.reshape(-1, row_len), pad_width).reshape(-1, block_size)

    # range() normalizes negative indices and raises IndexError when out of range.
    flat_index = range(len(original_tiles))[tile_index]
    # Number of real elements in this tile (only a row's last tile can be short).
    n_valid = min(block_size, row_len - (flat_index % tiles_per_row) * block_size)
    orig_tile = original_tiles[flat_index, :n_valid]
    quant_tile = quantized_tiles[flat_index, :n_valid]

    error = orig_tile - quant_tile
    element_losses = error ** 2
    element_indices = np.arange(n_valid)
    max_loss_idx = int(np.argmax(element_losses))
    max_loss = element_losses[max_loss_idx]

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # 1. Original vs quantized values.
    axes[0,0].plot(element_indices, orig_tile, 'o-', label='Original',
                   color='blue', linewidth=2, markersize=4)
    axes[0,0].plot(element_indices, quant_tile, 's-', label='Quantized',
                   color='red', linewidth=2, markersize=4)
    axes[0,0].set_xlabel('Element Index in Tile')
    axes[0,0].set_ylabel('Value')
    axes[0,0].set_title(f'{title_prefix}Tile {tile_index}: Original vs Quantized')
    axes[0,0].legend()
    axes[0,0].grid(True, alpha=0.3)

    # 2. Signed quantization error per element.
    axes[0,1].bar(element_indices, error, color='orange', alpha=0.7)
    axes[0,1].set_xlabel('Element Index in Tile')
    axes[0,1].set_ylabel('Quantization Error (Original - Quantized)')
    axes[0,1].set_title(f'Tile {tile_index}: Quantization Error per Element')
    axes[0,1].grid(True, alpha=0.3)
    axes[0,1].axhline(y=0, color='black', linestyle='--', alpha=0.5)

    # 3. Squared error per element (contribution to the tile MSE); maxima in red.
    colors = ['red' if loss == max_loss else 'lightcoral' for loss in element_losses]
    axes[1,0].bar(element_indices, element_losses, color=colors, alpha=0.7)
    axes[1,0].set_xlabel('Element Index in Tile')
    axes[1,0].set_ylabel('Squared Error (MSE Contribution)')
    axes[1,0].set_title(f'Tile {tile_index}: MSE Contribution per Element')
    axes[1,0].grid(True, alpha=0.3)
    axes[1,0].annotate(f'Max Loss\nIndex: {max_loss_idx}\nValue: {max_loss:.6f}',
                       xy=(max_loss_idx, max_loss),
                       xytext=(max_loss_idx + 3, max_loss * 1.2),
                       arrowprops=dict(arrowstyle='->', color='black'),
                       fontsize=10, ha='left')

    # 4. Relative error in percent (1e-8 keeps zero originals finite).
    relative_error = np.abs(error) / (np.abs(orig_tile) + 1e-8) * 100
    axes[1,1].bar(element_indices, relative_error, color='purple', alpha=0.7)
    axes[1,1].set_xlabel('Element Index in Tile')
    axes[1,1].set_ylabel('Relative Error (%)')
    axes[1,1].set_title(f'Tile {tile_index}: Relative Error per Element')
    axes[1,1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print("=" * 80)
    print(f"Detailed Analysis for {title_prefix}Tile {tile_index}")
    print("=" * 80)
    print(f"Tile MSE: {np.mean(element_losses):.8f}")
    print(f"Tile Max Squared Error: {max_loss:.8f} (at index {max_loss_idx})")
    print(f"Tile Min Squared Error: {np.min(element_losses):.8f}")
    print(f"Tile Mean Absolute Error: {np.mean(np.abs(error)):.8f}")
    print()

    # The (up to) five elements with the largest squared error, largest first.
    top_loss_indices = np.argsort(element_losses)[-5:][::-1]
    print("Top 5 elements with highest squared error:")
    print(f"{'Index':<8} {'Original':<12} {'Quantized':<12} {'Error':<12} {'Squared Error':<15} {'Rel Error %':<12}")
    print("-" * 80)
    for idx in top_loss_indices:
        print(f"{idx:<8} {orig_tile[idx]:<12.6f} {quant_tile[idx]:<12.6f} {error[idx]:<12.6f} "
              f"{element_losses[idx]:<15.8f} {relative_error[idx]:<12.2f}")


In [184]:
def compare_rotation_effect(original_loss, rotated_loss, title="Rotation Effect Analysis", atol=1e-8):
    """
    Compare per-tile quantization losses before and after a rotation.

    Plots both loss series, their absolute and percentage change, and the
    distribution of after/before ratios. It then prints summary statistics and
    the tiles whose loss changed the most.

    Args:
        original_loss: per-tile losses before rotation (tensor or array-like,
            any shape; flattened row-major).
        rotated_loss: per-tile losses after rotation; same number of tiles.
        title: figure title and header of the printed report.
        atol: tiles with ``|after - before| <= atol`` count as unchanged. The
            increased / decreased / unchanged groups are disjoint.

    Returns:
        dict with 'increased_tiles' and 'decreased_tiles' (index arrays),
        'loss_diff' (after - before), 'loss_change_percent' and 'loss_ratio'.
        Every "change" value is after - before (positive = rotation increased loss).

    Raises:
        ValueError: if the inputs are empty or contain different numbers of tiles.
    """
    eps = 1e-8  # keeps ratios and percentages finite when a baseline loss is 0

    def to_flat_numpy(x):
        if hasattr(x, 'cpu'):
            x = x.detach().cpu().float().numpy()
        return np.asarray(x, dtype=np.float64).ravel()

    orig_np = to_flat_numpy(original_loss)
    rot_np = to_flat_numpy(rotated_loss)
    if orig_np.shape != rot_np.shape:
        raise ValueError(f"tile count mismatch: {orig_np.size} vs {rot_np.size}")
    num_tiles = orig_np.size
    if num_tiles == 0:
        raise ValueError("no tiles to compare")
    tile_indices = np.arange(num_tiles)

    loss_diff = rot_np - orig_np  # positive: loss increased after rotation
    loss_ratio = rot_np / (orig_np + eps)
    loss_change_percent = (loss_diff / (orig_np + eps)) * 100

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle(title)

    # 1. Loss before vs after rotation.
    width = 0.35
    axes[0,0].bar(tile_indices - width/2, orig_np, width,
                  label='Before Rotation', color='lightblue', alpha=0.8)
    axes[0,0].bar(tile_indices + width/2, rot_np, width,
                  label='After Rotation', color='orange', alpha=0.8)
    axes[0,0].set_xlabel('Tile Index')
    axes[0,0].set_ylabel('Quantization Loss (MSE)')
    axes[0,0].set_title('Quantization Loss: Before vs After Rotation')
    axes[0,0].legend()
    axes[0,0].grid(True, alpha=0.3)

    # 2. Absolute loss change.
    colors = ['red' if d > 0 else 'green' for d in loss_diff]
    axes[0,1].bar(tile_indices, loss_diff, color=colors, alpha=0.7)
    axes[0,1].set_xlabel('Tile Index')
    axes[0,1].set_ylabel('Loss Change (After - Before)')
    axes[0,1].set_title('Loss Change after Rotation (Red: Increased, Green: Decreased)')
    axes[0,1].grid(True, alpha=0.3)
    axes[0,1].axhline(y=0, color='black', linestyle='--', alpha=0.5)

    # 3. Percentage loss change.
    colors_pct = ['red' if p > 0 else 'green' for p in loss_change_percent]
    axes[1,0].bar(tile_indices, loss_change_percent, color=colors_pct, alpha=0.7)
    axes[1,0].set_xlabel('Tile Index')
    axes[1,0].set_ylabel('Loss Change Percentage (%)')
    axes[1,0].set_title('Percentage Change in Loss after Rotation')
    axes[1,0].grid(True, alpha=0.3)
    axes[1,0].axhline(y=0, color='black', linestyle='--', alpha=0.5)

    # 4. Distribution of after/before ratios.
    axes[1,1].hist(loss_ratio, bins=30, color='purple', alpha=0.7, edgecolor='black')
    axes[1,1].axvline(x=1, color='red', linestyle='--', linewidth=2, label='No Change (ratio=1)')
    axes[1,1].set_xlabel('Loss Ratio (After/Before)')
    axes[1,1].set_ylabel('Frequency')
    axes[1,1].set_title('Distribution of Loss Ratios')
    axes[1,1].legend()
    axes[1,1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Disjoint classification so the three percentages add up to 100%.
    increased_tiles = np.where(loss_diff > atol)[0]
    decreased_tiles = np.where(loss_diff < -atol)[0]
    unchanged_tiles = np.where(np.abs(loss_diff) <= atol)[0]

    print("=" * 80)
    print(f"{title}")
    print("=" * 80)
    print(f"Total tiles: {num_tiles}")
    print(f"Tiles with increased loss: {len(increased_tiles)} ({len(increased_tiles)/num_tiles*100:.1f}%)")
    print(f"Tiles with decreased loss: {len(decreased_tiles)} ({len(decreased_tiles)/num_tiles*100:.1f}%)")
    print(f"Tiles with unchanged loss: {len(unchanged_tiles)} ({len(unchanged_tiles)/num_tiles*100:.1f}%)")
    print()

    # Every "Change" column is After - Before.
    print("Overall Statistics:")
    print(f"{'Metric':<25} {'Before Rotation':<15} {'After Rotation':<15} {'Change':<15}")
    print("-" * 80)
    print(f"{'Mean Loss':<25} {np.mean(orig_np):<15.8f} {np.mean(rot_np):<15.8f} {np.mean(loss_diff):<15.8f}")
    print(f"{'Total Loss':<25} {np.sum(orig_np):<15.6f} {np.sum(rot_np):<15.6f} {np.sum(loss_diff):<15.6f}")
    print(f"{'Max Loss':<25} {np.max(orig_np):<15.8f} {np.max(rot_np):<15.8f} {np.max(rot_np) - np.max(orig_np):<15.8f}")
    print(f"{'Min Loss':<25} {np.min(orig_np):<15.8f} {np.min(rot_np):<15.8f} {np.min(rot_np) - np.min(orig_np):<15.8f}")
    print(f"{'Std Dev':<25} {np.std(orig_np):<15.8f} {np.std(rot_np):<15.8f} {np.std(rot_np) - np.std(orig_np):<15.8f}")
    print()

    print("Loss Change Analysis:")
    print(f"Mean loss change: {np.mean(loss_diff):.8f}")
    print(f"Mean percentage change: {np.mean(loss_change_percent):.2f}%")
    print(f"Max loss increase: {np.max(loss_diff):.8f} (tile {np.argmax(loss_diff)})")
    print(f"Max loss decrease: {np.min(loss_diff):.8f} (tile {np.argmin(loss_diff)})")
    print()

    def print_top_tiles(header, change_label, indices):
        # Print one table row per tile index (before, after, change, % change).
        print(header)
        print(f"{'Tile Index':<12} {'Before':<15} {'After':<15} {change_label:<15} {'% Change':<12}")
        print("-" * 80)
        for idx in indices:
            print(f"{idx:<12} {orig_np[idx]:<15.8f} {rot_np[idx]:<15.8f} "
                  f"{loss_diff[idx]:<15.8f} {loss_change_percent[idx]:<12.2f}")

    if len(increased_tiles) > 0:
        # Largest increases first.
        top_increase = increased_tiles[np.argsort(loss_diff[increased_tiles])[::-1]][:10]
        print_top_tiles("Top 10 tiles with highest loss increase:", 'Increase', top_increase)
        print()

    if len(decreased_tiles) > 0:
        # Most negative changes (largest decreases) first.
        top_decrease = decreased_tiles[np.argsort(loss_diff[decreased_tiles])][:10]
        print_top_tiles("Top 10 tiles with highest loss decrease:", 'Decrease', top_decrease)

    return {
        'increased_tiles': increased_tiles,
        'decreased_tiles': decreased_tiles,
        'loss_diff': loss_diff,
        'loss_change_percent': loss_change_percent,
        'loss_ratio': loss_ratio
    }

In [185]:
def scatter_largest_to_groups(sorted_idx, block_size):
    """
    Deal the entries of ``sorted_idx`` round-robin across groups of ``block_size``.

    With ``N = sorted_idx.shape[-1]`` and ``num_groups = N // block_size``, entry
    ``i`` goes to group ``i % num_groups`` at position ``i // num_groups``:
    ``out[g * block_size + p] == sorted_idx[p * num_groups + g]``. If
    ``sorted_idx`` lists indices from largest to smallest magnitude (as the name
    suggests — confirm at the call site), the largest entries end up spread
    one per group.

    Leading dimensions are treated as a batch; the permutation acts on the last dim.

    Args:
        sorted_idx: tensor whose last dimension has length N.
        block_size: group size; N must be a multiple of it.

    Returns:
        A new tensor with the same shape and dtype as ``sorted_idx``.

    Raises:
        ValueError: if ``block_size`` is not positive or does not divide N.
            (The previous loop left slots of ``torch.empty_like`` unwritten in
            that case and so returned uninitialized values.)
    """
    n = sorted_idx.shape[-1]
    if block_size <= 0 or n % block_size != 0:
        raise ValueError(
            f"last dimension ({n}) must be a multiple of a positive block_size ({block_size})"
        )
    num_groups = n // block_size
    lead = sorted_idx.shape[:-1]
    # View as (..., position, group), swap to (..., group, position), then flatten.
    grouped = sorted_idx.reshape(*lead, block_size, num_groups).transpose(-1, -2)
    # clone(): reshape may return a view of the input (e.g. num_groups == 1).
    return grouped.reshape(*lead, n).clone()


In [186]:
# Hadamard-rotate `sorted_token` in 32-element blocks (scale 1/sqrt(32)), flatten
# the result, then MX-quantize it to fp4_e2m1 with block_size=32 and scale_mode=0.
# FIXME(review): `sorted_token` is never defined in this notebook — this cell
# raised NameError. It presumably came from a deleted cell and must be
# recreated above before this cell can run.
# NOTE(review): the reshape assumes `sorted_token` is 1-D with a length divisible
# by 32 — confirm. Also, `axes=-1` is passed as an int here, whereas the other
# cells pass `axes=[-1]`.
sorted_roted_t =  hadamard_transform(sorted_token.reshape(-1, sorted_token.shape[0] // 32,
                                                      32), scale=1 / math.sqrt(32)).reshape(-1)
sorted_p_rt1 = quantize_mx_op(sorted_roted_t,
                      mx_specs,
                      elem_format="fp4_e2m1",
                      block_size=32,
                      axes=-1,
                      scale_mode=0,)

NameError: name 'sorted_token' is not defined