In [None]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="umap")
import umap.umap_ as umap
UMAP = umap.UMAP


import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from sklearn.cluster import KMeans

# Simplified hidden-representation extraction.
def extract_hidden_representations(model, params, samples):
    """Return the mean patch embedding of every sample as a NumPy array.

    This is *not* the output of the full Transformer. Each configuration is
    split into patches (2D or 1D depending on ``model.two_dimensional``),
    passed through the ``params["embed"]`` affine layer only
    (``patches @ kernel + bias``), and averaged over the patch axis.

    Args:
        model: object exposing ``two_dimensional`` and ``patch_size``.
        params: parameter tree containing ``params["embed"]["kernel"]`` and
            ``params["embed"]["bias"]``.
        samples: array of configurations whose LAST axis is the site axis.
            Any leading axes (e.g. ``(n_chains, n_per_chain, N)``) are
            flattened into a single sample axis.

    Returns:
        ``np.ndarray`` of shape ``(n_total_samples, embed_dim)``, where
        ``embed_dim`` is the last dimension of the embedding kernel.
    """
    kernel = params["embed"]["kernel"]
    bias = params["embed"]["bias"]

    samples = jnp.asarray(samples)
    if samples.size == 0:
        # Nothing to embed; keep a 2D result so callers can still use .shape[1].
        return np.empty((0, kernel.shape[-1]))

    # Collapse every leading axis into one batch axis. Without this, a
    # multi-chain sample array would be iterated chain-by-chain and each
    # whole chain would be treated as a single (huge) configuration.
    flat = samples.reshape(-1, samples.shape[-1])

    extract = extract_patches2d if model.two_dimensional else extract_patches1d
    # Assumes the patch helpers are batch-first, (B, N) -> (B, n_patches, patch_dim).
    # The previous per-sample version relied on the same layout via `[0]`.
    patches = extract(flat, model.patch_size)

    # Apply the embedding layer to every patch: (B, n_patches, embed_dim).
    patch_embeddings = jnp.matmul(patches, kernel) + bias
    features = jnp.mean(patch_embeddings, axis=1)
    return np.asarray(jax.device_get(features))

# Build the hidden-representation figure (Figure 10).
def visualize_hidden_representations(model, params, samples):
    """Plot 2D UMAP projections of hidden representations before/after training.

    Two panels are drawn side by side:

    * "Random Initialization": features extracted with parameters re-drawn
      from a standard normal distribution (same tree structure, shapes and
      dtypes as ``params``).
    * "After Optimization": features extracted with the given ``params``.

    Both panels are coloured by the same per-sample value: the first UMAP
    coordinate of the *trained* embedding, min-max normalised to [0, 1]. The
    same sample therefore has the same colour in both panels. A shared
    colorbar is placed to the right of the two panels.

    Args:
        model: passed through to ``extract_hidden_representations``.
        params: trained parameter tree. It also serves as the shape/dtype
            template for the random parameters.
        samples: configurations passed to ``extract_hidden_representations``.

    Returns:
        The ``matplotlib.figure.Figure``. It is neither saved nor shown here.
    """
    # Stand-in for an untrained state: replace every parameter leaf with
    # standard-normal noise of the same shape and dtype. Each leaf gets its
    # own subkey. Drawing every leaf from one shared key would make all
    # same-shaped leaves identical, i.e. artificially correlated.
    rng_key = jax.random.PRNGKey(42)
    leaves, treedef = jax.tree_util.tree_flatten(params)
    leaf_keys = jax.random.split(rng_key, max(len(leaves), 1))
    random_params = jax.tree_util.tree_unflatten(
        treedef,
        [
            jax.random.normal(key, leaf.shape, leaf.dtype)
            for key, leaf in zip(leaf_keys, leaves)
        ],
    )

    # Extract the hidden representations for both parameter sets.
    print("提取训练后的隐藏表示...")
    features_trained = extract_hidden_representations(model, params, samples)
    print(f"训练后特征形状: {features_trained.shape}")

    print("提取随机初始化的隐藏表示...")
    features_random = extract_hidden_representations(model, random_params, samples)
    print(f"随机特征形状: {features_random.shape}")

    # Independent 2D UMAP projections with a fixed seed.
    print("对随机特征应用UMAP降维...")
    embedding_random = UMAP(random_state=42).fit_transform(features_random)

    print("对训练后特征应用UMAP降维...")
    embedding_trained = UMAP(random_state=42).fit_transform(features_trained)

    # Colour proxy: first trained UMAP coordinate, min-max normalised.
    color_values = embedding_trained[:, 0]
    color_min = np.min(color_values)
    color_span = np.max(color_values) - color_min
    if color_span > 0:
        normalized_colors = (color_values - color_min) / color_span
    else:
        # Constant coordinate: avoid 0/0 (NaN colours) and map everything to 0.
        normalized_colors = np.zeros_like(color_values)

    # Wide figure so a colorbar fits to the right of the two panels.
    fig = plt.figure(figsize=(16, 6))
    ax1 = fig.add_subplot(121)  # row 1, column 1 of a 1x2 grid
    ax2 = fig.add_subplot(122)  # row 1, column 2 of a 1x2 grid

    # Figure 10(a): hidden representations with random parameters.
    ax1.scatter(embedding_random[:, 0], embedding_random[:, 1],
                c=normalized_colors, cmap='viridis', alpha=0.8, s=20)
    ax1.set_title('Random Initialization', fontsize=14)
    ax1.set_xlabel('UMAP dimension 1', fontsize=12)
    ax1.set_ylabel('UMAP dimension 2', fontsize=12)

    # Figure 10(b): hidden representations with optimized parameters.
    scatter2 = ax2.scatter(embedding_trained[:, 0], embedding_trained[:, 1],
                           c=normalized_colors, cmap='viridis', alpha=0.8, s=20)
    ax2.set_title('After Optimization', fontsize=14)
    ax2.set_xlabel('UMAP dimension 1', fontsize=12)
    ax2.set_ylabel('UMAP dimension 2', fontsize=12)

    # Lay out the two subplots in the left 90% of the figure BEFORE adding the
    # hand-placed colorbar axes. tight_layout cannot manage axes created with
    # add_axes and emits a compatibility warning when they are present.
    fig.tight_layout(rect=[0, 0, 0.9, 1])  # [left, bottom, right, top]

    # Shared colorbar on the right. Both scatters use identical colour data,
    # so one colorbar is valid for both panels.
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    cbar = fig.colorbar(scatter2, cax=cbar_ax)
    cbar.set_label('Normalized Feature Value', rotation=270, labelpad=15)

    return fig

# Attention-map visualization (Figure 11).
def _as_square_grid(weights):
    """Return 1D ``weights`` reshaped to ``(side, side)`` if its length is a
    perfect square, otherwise ``None``."""
    length = weights.shape[0]
    side = int(np.sqrt(length))
    if side * side != length:
        return None
    return weights.reshape(side, side)


def visualize_attention_maps(params, n_layers=2, n_heads=12):
    """Plot the per-layer, per-head attention parameters ``J`` and their mean.

    Reads ``params["encoder"][f"layers_{i}"]["attn"]["J"]`` for each layer
    ``i < n_layers`` and displays two figures:

    * Figure 11(a): an ``n_layers x n_heads`` grid with one panel per head.
      A head's weights are shown as a square image when their length is a
      perfect square, otherwise as a bar chart.
    * Figure 11(b): the mean of ``|J|`` over layers and heads, shown either
      as an image with a colorbar or as a bar chart.

    Both figures are displayed with ``plt.show()``. Nothing is returned.

    Args:
        params: parameter tree containing the encoder attention parameters.
        n_layers: number of encoder layers to read.
        n_heads: number of heads per layer, indexed along the first axis of
            each layer's ``J``.
    """
    # Collect each layer's attention parameter.
    attention_maps = [
        params["encoder"][f"layers_{layer}"]["attn"]["J"] for layer in range(n_layers)
    ]
    # Expected shape (n_layers, n_heads, L_eff). Each [layer, head] slice is
    # treated as a 1D vector below. NOTE(review): confirm J is (n_heads, L_eff).
    attention_maps = np.array(jax.device_get(attention_maps))
    print(f"注意力图形状: {attention_maps.shape}")

    # Figure 11(a): one panel per (layer, head).
    # squeeze=False always yields a 2D array of Axes, so axs[layer, head]
    # works even when n_layers == 1 and/or n_heads == 1.
    _, axs = plt.subplots(n_layers, n_heads, figsize=(20, 5 * n_layers), squeeze=False)

    for layer in range(n_layers):
        for head in range(n_heads):
            ax = axs[layer, head]
            weights = attention_maps[layer, head]
            grid = _as_square_grid(weights)
            if grid is not None:
                ax.imshow(grid, cmap='viridis')
            else:
                # Not a perfect square: fall back to a bar chart.
                ax.bar(range(len(weights)), weights)
            ax.set_title(f'Layer {layer+1}, Head {head+1}')
            ax.axis('off')

    plt.tight_layout()
    plt.show()

    # Figure 11(b): mean absolute attention over all layers and heads.
    mean_attention = np.mean(np.abs(attention_maps), axis=(0, 1))  # shape: [L_eff]

    plt.figure(figsize=(8, 6))
    mean_grid = _as_square_grid(mean_attention)
    if mean_grid is not None:
        im = plt.imshow(mean_grid, cmap='viridis')
        plt.title('Mean Attention Map', fontsize=14)
        plt.axis('off')
        # A colorbar needs a mappable, and only the image branch produces one.
        # Calling colorbar without one raises in the bar-chart branch.
        plt.colorbar(im, label='Mean Absolute Attention')
    else:
        plt.bar(range(len(mean_attention)), mean_attention)
        plt.title('Mean Attention Weights', fontsize=14)
        plt.xlabel('Position', fontsize=12)
        plt.ylabel('Mean Absolute Value', fontsize=12)

    plt.tight_layout()
    plt.show()

# Cluster analysis of the hidden representations.
def analyze_hidden_clusters(model, params, samples, n_clusters=5):
    """Cluster hidden representations with K-means and plot them via UMAP.

    Features come from ``extract_hidden_representations``. They are
    clustered with K-means (``random_state=42``) and projected to 2D with
    UMAP (``random_state=42``). The projection is shown as a scatter plot
    coloured by cluster label, and the number of samples per cluster is
    printed.

    Args:
        model: passed through to ``extract_hidden_representations``.
        params: passed through to ``extract_hidden_representations``.
        samples: passed through to ``extract_hidden_representations``.
        n_clusters: number of K-means clusters.

    Returns:
        The K-means label array, one integer label per feature row.
    """
    from sklearn.cluster import KMeans

    feats = extract_hidden_representations(model, params, samples)

    labels = KMeans(n_clusters=n_clusters, random_state=42).fit_predict(feats)

    # 2D projection used only for display.
    coords = UMAP(random_state=42).fit_transform(feats)

    plt.figure(figsize=(10, 8))
    points = plt.scatter(coords[:, 0], coords[:, 1],
                         c=labels, cmap='tab10', s=30, alpha=0.8)
    plt.colorbar(points, label='Cluster ID')
    plt.xlabel('UMAP dimension 1', fontsize=12)
    plt.ylabel('UMAP dimension 2', fontsize=12)
    plt.title('Hidden Representation Clusters', fontsize=14)

    # Per-cluster sample counts. minlength guarantees one entry per cluster id.
    print("Cluster Statistics:")
    counts = np.bincount(labels, minlength=n_clusters)[:n_clusters]
    for cluster_id, count in enumerate(counts):
        print(f"Cluster {cluster_id}: {count} samples")

    plt.tight_layout()
    plt.show()

    return labels

# Run all visualizations. Any failure is reported with a full traceback
# instead of aborting the notebook cell.
try:
    # Draw samples from the variational state.
    samples = vqs.sample(n_samples=1000)
    # Flatten any leading axes (e.g. (n_chains, n_per_chain, N)) into a single
    # sample axis. The analysis helpers treat each row as one configuration.
    # A 2D (n_samples, N) array is left unchanged.
    # NOTE(review): assumes the last axis of `samples` is the site axis.
    samples = samples.reshape(-1, samples.shape[-1])

    # Hidden-representation figure.
    print("正在创建隐藏表示可视化...")
    fig = visualize_hidden_representations(model_no_symm, vqs.parameters, samples)
    # Save the returned figure explicitly rather than whichever figure is current.
    fig.savefig('figure10_hidden_representations.png', dpi=300)
    plt.show()

    # Attention-map figures.
    print("正在创建注意力图可视化...")
    visualize_attention_maps(vqs.parameters, n_layers=N_layers, n_heads=n_heads)

    # Cluster analysis of the hidden representations.
    print("正在进行隐藏表示聚类分析...")
    clusters = analyze_hidden_clusters(model_no_symm, vqs.parameters, samples, n_clusters=5)

except Exception as e:
    print(f"错误: {e}")
    import traceback
    traceback.print_exc()
