####  用于将模型输出的embeddings通过库进行查找与可视化分析

In [71]:
import plotly.express as px
import numpy as np
from sklearn.decomposition import PCA
from semantic_eval import evaluate_semantic_mapping
import matplotlib
import matplotlib.pyplot as plt

# Plot styling: use the SimHei font for sans-serif text (the labels are Chinese)
# and turn off the Unicode minus glyph so negative tick labels display properly.
matplotlib.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

In [72]:
# Evaluation mode: 'Semantic' writes under semantic_eval_out/, anything else
# under visual_eval_out/.
test_type = 'Semantic'
# test_type = 'Visual'

# File holding the model's test-set output embeddings.
path = 'output_embeddings/Semantic_Doublelayer_(1024-512)-lr1e-4/test_embeddings_epoch_150.npy'
# path = 'output_embeddings/Semantic_Doublelayer_(512-256)_lr1e-4/test_embeddings_epoch_200_2'

# savedir is `path` with the leading embeddings root removed. The root is
# stripped by name rather than by a hard-coded character count, so changing
# the root cannot silently chop the wrong part of the path.
EMBEDDINGS_ROOT = 'output_embeddings/'
if path.startswith(EMBEDDINGS_ROOT):
    savedir = path[len(EMBEDDINGS_ROOT):]
else:
    savedir = path

# Prefix the result directory that matches the evaluation mode.
savedir = ('semantic_eval_out/' if test_type == 'Semantic' else 'visual_eval_out/') + savedir

In [73]:
# Vocabulary of the 61 word classes. load_test_embeddings uses the integer
# class id stored with each test embedding as an index into this list, so the
# order of the entries must not change.
label_list = [
    '丝瓜', '你', '关门', '凳子', '厕所', '口渴', '吃', '喝', '嘴巴', '外卖',
    '头疼', '家人', '小刀', '帮助', '平静', '心情', '怎样', '感觉', '愿意', '我',
    '手机', '找', '把', '护士', '拿', '换药', '放在', '是', '有', '朋友',
    '橙汁', '毛巾', '汤圆', '漂亮', '热水', '猪肉', '玩', '电脑', '看', '碗',
    '穿', '篮球', '米饭', '给', '脸盆', '菠萝', '葱花', '蒜泥', '衣服', '豆腐',
    '软糖', '醋', '钢琴', '问题', '需要', '青菜', '面条', '音乐', '预约', '香肠',
    '鱼块',
]

In [74]:
def load_npz_vis(path="Duin_vit_embeddings.npz"):
    """Load visual ground-truth embeddings from an ``.npz`` archive.

    Args:
        path: Location of an archive that stores the arrays ``chars``,
            ``embeddings`` and ``meta``.

    Returns:
        Tuple ``(chars, embeddings, meta)``; ``meta`` is the stored ``meta``
        entry converted with ``dict()``.

    Raises:
        KeyError: If one of the three expected entries is missing.
    """
    # NOTE: allow_pickle=True unpickles object arrays, which can run arbitrary
    # code from a crafted file; only load archives from trusted sources.
    # The context manager closes the underlying zip file once the arrays have
    # been read, instead of leaving the handle open for the kernel's lifetime.
    with np.load(path, allow_pickle=True) as data:
        return data["chars"], data["embeddings"], dict(data["meta"])

def load_npz_semantic(path):
    """Load semantic ground-truth embeddings from an ``.npz`` archive.

    Args:
        path: Location of an archive that stores the arrays ``words``,
            ``emb_cls``, ``emb_mean``, ``emb_max``, ``emb_weighted`` and
            ``emb_mixed``.

    Returns:
        Tuple ``(words, emb_cls, emb_mean, emb_max, emb_weighted, emb_mixed)``
        in exactly that order.

    Raises:
        KeyError: If one of the expected entries is missing.
    """
    entry_names = ('words', 'emb_cls', 'emb_mean', 'emb_max', 'emb_weighted', 'emb_mixed')
    # NOTE: allow_pickle=True unpickles object arrays, which can run arbitrary
    # code from a crafted file; only load archives from trusted sources.
    # Every entry is read inside the `with` block so the zip file can be
    # closed before returning.
    with np.load(path, allow_pickle=True) as data:
        return tuple(data[name] for name in entry_names)

In [75]:
def load_test_embeddings(path):
    """Load model-output embeddings and decode their class ids into words.

    The file is read as a 2-D array whose last column holds the class id of
    each row and whose remaining columns hold the embedding.

    Args:
        path: File to read with ``np.load``.

    Returns:
        Tuple ``(labels, embeddings)``: ``labels`` is an array of the words
        looked up in the module-level ``label_list``; ``embeddings`` is the
        loaded array without its last column.
    """
    raw = np.load(path, allow_pickle=True)
    # The last column stores the class id; cast it to int so it can index a list.
    class_ids = raw[:, -1].astype(int)
    words = np.array([label_list[class_id] for class_id in class_ids])
    return words, raw[:, :-1]

In [76]:
def visualization(emb, label, title, savepath):
    """Draw a labelled 2-D scatter plot with Plotly and save it as HTML.

    Args:
        emb: 2-D array; column 0 is plotted on the x axis, column 1 on the y axis.
        label: Per-point labels, passed both as the marker text and as the
            hover name.
        title: Figure title.
        savepath: Destination of the HTML file; it is written with
            ``auto_open=True``.
    """
    xs, ys = emb[:, 0], emb[:, 1]
    figure = px.scatter(x=xs, y=ys, text=label, hover_name=label, width=800, height=800)

    # Semi-transparent markers, with each label placed above its point.
    marker_style = dict(size=8, opacity=0.7)
    figure.update_traces(marker=marker_style, textposition="top center")

    # The axes are named after the first two principal components.
    figure.update_layout(title=title, xaxis_title="PC1", yaxis_title="PC2")

    figure.write_html(savepath, auto_open=True)

### Semantic

In [77]:
# Load the ground-truth (GT) embeddings and the model's output embeddings.
words, emb_cls, emb_mean, emb_max, emb_weighted, emb_mixed = load_npz_semantic(
    'GT_embeddings/61words/Duin_bert_embeddings.npz'
)
words_v, emb_v, _ = load_npz_vis(
    'GT_embeddings/61words/Duin_vit_embeddings_vit_whole.npz'
)

# Pick the GT set that matches the evaluation mode: the mean-pooled semantic
# embeddings for 'Semantic', otherwise the visual embeddings. The variable
# names keep the `_semantic` suffix in both modes.
GT_labels_semantic, GT_embeddings_semantic = (
    (words, emb_mean) if test_type == 'Semantic' else (words_v, emb_v)
)

test_labels_semantic, test_embeddings_semantic = load_test_embeddings(path)

# Sanity-check shapes and mean values of both embedding sets.
print('语义数据格式')
print('GT_labels_semantic:', GT_labels_semantic.shape)
print('GT_embeddings_semantic:', GT_embeddings_semantic.shape)
print('GT_embeddings_semantic Average:', np.mean(GT_embeddings_semantic))
print('test_labels_semantic:', test_labels_semantic.shape)
print('test_embeddings_semantic:', test_embeddings_semantic.shape)
print('test_embeddings_semantic Average:', np.mean(test_embeddings_semantic))

语义数据格式
GT_labels_semantic: (61,)
GT_embeddings_semantic: (61, 768)
GT_embeddings_semantic Average: -0.00010871876
test_labels_semantic: (320,)
test_embeddings_semantic: (320, 768)
test_embeddings_semantic Average: -0.00012762386423444987


In [78]:
# L2-normalise every embedding row to unit length.
def _unit_rows(mat):
    """Scale each row of ``mat`` to unit L2 norm; all-zero rows are left as zeros."""
    norms = np.linalg.norm(mat, axis=1, keepdims=True)
    # Dividing by a zero norm would turn the row into NaN, so divide such rows
    # by 1 instead. Rows with a non-zero norm are divided exactly as before.
    return mat / np.where(norms == 0, 1, norms)


GT_embeddings_semantic = _unit_rows(GT_embeddings_semantic)
test_embeddings_semantic = _unit_rows(test_embeddings_semantic)

In [79]:
# #semantic可视化：
# pca = PCA(n_components=2)
# GT_embeddings_semantic_2d = pca.fit_transform(GT_embeddings_semantic)
# test_embeddings_semantic_2d = pca.transform(test_embeddings_semantic)
# # 可视化test 语义embeddings
# visualization(test_embeddings_semantic_2d, test_labels_semantic,\
#                'Test Semantic Embeddings PCA Visualization', 'Test_visualization/modeloutput_semantic_embeddings_200epoch.html')

In [80]:
# Score the model's output embeddings against the GT embeddings.
# The sizes quoted below are taken from the recorded output of the loading cell.
results = evaluate_semantic_mapping(
    GT_labels_semantic,         # GT labels; shape (61,) in the recorded run
    GT_embeddings_semantic,     # GT embeddings; shape (61, 768) in the recorded run
    test_labels_semantic,       # test labels; shape (320,) in the recorded run, each taken from label_list
    test_embeddings_semantic,   # test embeddings; shape (320, 768) in the recorded run
    topk=(1, 3, 5, 10),         # K values; the recorded output reports hits@1/3/5/10
    reducer="pca",              # per the original note: "pca" | "tsne" | "umap" (umap needs umap-learn) — TODO confirm against semantic_eval
    out_dir=savedir,
    annotate_prototypes=True,   # per the original note: whether to annotate the Chinese labels on the plot
    random_state=0
)

print(results["overall"])       # overall metrics; the recorded output shows topk hits@K, MRR, accuracy, ARI and NMI


{'topk': {'hits@1': 0.02187499962747097, 'hits@3': 0.046875, 'hits@5': 0.08749999850988388, 'hits@10': 0.16562500596046448}, 'MRR': 0.08183050732524008, 'accuracy': 0.021875, 'ARI': 0.0, 'NMI': 0.0}


## Visual

In [34]:
# # Load the GT image embeddings and the model's visual output embeddings
# GT_labels_vis, GT_embeddings_vis, _ = load_npz_vis('GT_embeddings/Duin_vit_embeddings_vit_per_char.npz')
# test_labels_vis, test_embeddings_vis = load_test_embeddings('output_embeddings/test_embeddings_vis.npy')
# print('视觉数据格式')
# print('GT_labels_vis:', GT_labels_vis.shape)
# print('GT_embeddings_vis:', GT_embeddings_vis.shape)
# print('test_labels_vis:', test_labels_vis.shape)
# print('test_embeddings_vis:', test_embeddings_vis.shape)

In [35]:
# #visual可视化：
# pca = PCA(n_components=2)
# GT_embeddings_vis_2d = pca.fit_transform(GT_embeddings_vis)
# test_embeddings_vis_2d = pca.transform(test_embeddings_vis)
# # 可视化test 视觉embeddings
# visualization(test_embeddings_vis_2d, test_labels_vis,\
#                'Test Visual Embeddings PCA Visualization', 'Test_visualization/test_visual_embeddings.html')