In [1]:
import onnxruntime as ort
from transformers import BertTokenizerFast
import numpy as np
import re

  from .autonotebook import tqdm as notebook_tqdm
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
def generate_html(tokens, attention_weights, output_path="attention_avg.html",
                  title="平均 Attention 可视化 ([CLS] → tokens)"):
    """Write an HTML page that shades each token by its attention weight.

    Weights are min-max normalised to [0, 1]. Each token is drawn on a
    background of ``rgba(int(255*w), 0, 0, w)``, so stronger attention shows
    as a more saturated, more opaque red.

    Args:
        tokens: Sequence of tokens to display, in order. Each one is converted
            with ``str()`` and HTML-escaped.
        attention_weights: 1-D array-like of per-token weights, aligned with
            ``tokens``. If the lengths differ, the extra items are ignored
            (pairs are formed with ``zip``).
        output_path: Destination file, written as UTF-8.
        title: Heading shown above the tokens (HTML-escaped).
    """
    from html import escape  # local import keeps the notebook's import cell unchanged

    weights = np.asarray(attention_weights, dtype=float)
    if weights.size:
        # 1e-8 avoids division by zero when all weights are equal.
        norm_weights = (weights - weights.min()) / (weights.max() - weights.min() + 1e-8)
    else:
        # min()/max() raise on an empty array; there is nothing to normalise.
        norm_weights = weights

    parts = [
        # The charset declaration makes browsers decode the UTF-8 file (Chinese heading) correctly.
        '<html><head><meta charset="utf-8"><style>span.token {padding:2px 5px;margin:2px;'
        'border-radius:5px;display:inline-block;font-family:monospace;}</style></head><body>',
        f"<h2>{escape(title)}</h2><div>",
    ]
    for token, weight in zip(tokens, norm_weights):
        red = int(255 * weight)
        color = f"rgba({red}, 0, 0, {weight:.2f})"
        # Escape so tokens such as "<", ">" or "&" cannot break (or inject into) the markup.
        parts.append(f'<span class="token" style="background-color:{color}">{escape(str(token))}</span>')
    parts.append("</div></body></html>")

    with open(output_path, "w", encoding="utf-8") as f:
        f.write("".join(parts))
    print(f"✔️ 已保存 HTML 到: {output_path}")

In [3]:
# Keywords marking a bracketed Weibo emoji (e.g. "[偷笑]") as worth keeping.
# Matching is by substring, so any emoji whose text contains one of these is kept.
# Exact duplicates from the previous inline list were removed (no behavioural effect).
_VALID_EMOJI_KEYWORDS = (
    '笑', '哭', '泪', '心', '怒', '汗', '抱', '喜欢', '亲', '色', '偷笑', '害羞', '惊讶', '开心', '哭泣',
    '调皮', '害怕', '生气', '思考', '微笑', '呲牙', '委屈', '感动', '鼓掌', '加油', '抱拳', '拍手', '星星眼',
    '晕', '晚安', '睡觉', '吐', '呆萌', '抓狂', '拍砖', '爱', '尴尬', '大哭', '坏笑', '高兴', '发怒',
    '兴奋', '酷', '赞', 'ok', '拜年', '卖萌', '抱抱', '转圈', '拜拜', '惊恐', '冷', '拜托', '拜谢', '炸裂',
    '流汗', '偷乐', '傻眼', '鄙视', '叹气', '纠结', '疑问', '点赞', '抱歉', '感恩', '感冒', '感情',
    '炸鸡', '雪人', '火', '狗', '猫', '熊', '兔', '猪', '骷髅', '鸡', '太阳', '月亮', '星', '花', '蛋糕',
    '巧克力', '糖果', '礼物', '礼花', '福', '平安', '红包', '祝福', '祝', '祝贺', '新年', '节日', '节', '圣诞',
    '生日', '万圣', '奥运', '火炬', '胜利', '拥抱', '握手', '拳头', '挥手', '招手',
)

_URL_RE = re.compile(r"http[s]?://\S+")
# Repost chain "//@user:" — Weibo text uses both the ASCII ":" and the full-width "：".
_REPOST_CHAIN_RE = re.compile(r"//@[^\s:：]+[:：]")
# A mention ends at whitespace or a colon, so "@user:body" does not swallow "body".
_MENTION_RE = re.compile(r"@[^\s:：]+[:：]?")
_EMOJI_RE = re.compile(r"\[[^\[\]]+\]")
_REPEATED_EMOJI_RE = re.compile(r"(\[[^\[\]]+\])\1+")
_REPEATED_PUNCT_RE = re.compile(r"([！？!。，、,.?])\1+")
_WHITESPACE_RE = re.compile(r"\s+")


def _keep_valid_emoji(match):
    """re.sub callback: keep a bracketed emoji if its text contains a valid keyword, else drop it."""
    emoji = match.group(0)
    content = emoji[1:-1]  # strip the surrounding brackets
    return emoji if any(keyword in content for keyword in _VALID_EMOJI_KEYWORDS) else ""


def clean_weibo_text(text):
    """Normalise a raw Weibo post before tokenization.

    Steps, in order:
      1. Remove http/https URLs.
      2. Remove repost-chain prefixes "//@user:" (ASCII or full-width colon).
      3. Remove remaining @mentions (a mention stops at whitespace or a colon).
      4. Drop bracketed emojis whose text contains none of _VALID_EMOJI_KEYWORDS.
      5. Collapse runs of the same emoji, then runs of the same punctuation mark.
      6. Remove the boilerplate phrase "转发微博" ("repost").
      7. Collapse whitespace to single spaces and strip both ends.

    Invalid emojis are dropped before repeats are collapsed, so "[笑][x][笑]"
    becomes "[笑]". "转发微博" is removed before whitespace normalisation, so it
    cannot leave stray spaces behind.

    NOTE(review): a model trained on text from the previous version of this
    cleaner may see slightly different input now; confirm there is no
    train/serve skew.

    Args:
        text: Raw post. Any non-string value yields "".

    Returns:
        The cleaned string, possibly empty.
    """
    if not isinstance(text, str):
        return ""

    text = _URL_RE.sub("", text)
    text = _REPOST_CHAIN_RE.sub("", text)
    text = _MENTION_RE.sub("", text)
    text = _EMOJI_RE.sub(_keep_valid_emoji, text)
    text = _REPEATED_EMOJI_RE.sub(r"\1", text)
    text = _REPEATED_PUNCT_RE.sub(r"\1", text)
    text = text.replace("转发微博", "")
    return _WHITESPACE_RE.sub(" ", text).strip()

In [4]:
# ========== Main inference + visualisation flow ==========
raw_text = "高兴死了"
# `text` keeps its old meaning (the cleaned input) in case later cells use it.
text = clean_weibo_text(raw_text)

id2label = {0: '积极', 1: '中性', 2: '消极'}  # positive / neutral / negative

tokenizer = BertTokenizerFast.from_pretrained("../minirbt-h256-with-emojis")
session = ort.InferenceSession("../model_cleaned/minirbt_with_attention_quant.onnx", providers=["CPUExecutionProvider"])

# Tokenize, padding/truncating to 64 tokens (presumably the length used at export — confirm).
inputs = tokenizer(text, return_tensors="np", padding="max_length", truncation=True, max_length=64)
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
attention_mask = inputs["attention_mask"][0]
# Count of real (non-padding) tokens; the [:seq_len] slices below assume right padding.
seq_len = int(np.sum(attention_mask))

# Run the ONNX model, feeding exactly the inputs the graph declares.
input_names = [i.name for i in session.get_inputs()]
output_names = [o.name for o in session.get_outputs()]
missing_inputs = [name for name in input_names if name not in inputs]
if missing_inputs:
    raise KeyError(f"ONNX model expects inputs the tokenizer did not produce: {missing_inputs}")
outputs = session.run(output_names, {k: inputs[k] for k in input_names})

# Outputs are unpacked positionally as (logits, attentions).
logits, attentions = outputs
print(type(logits), logits.shape)
pred = int(np.argmax(logits, axis=1)[0])
print(f"预测类别: {id2label[pred]}")

# Visualise the LAST layer, averaged over ALL heads, using the [CLS] row (index 0).
attentions = np.asarray(attentions)
print(f"attention shape: {attentions.shape}")
layer = attentions.shape[0] - 1  # last layer, derived rather than hard-coded
all_heads = attentions[layer][0]  # first batch item -> (num_heads, seq_len, seq_len)
avg_attention = np.mean(all_heads, axis=0)  # (seq_len, seq_len)
cls_attention = avg_attention[0][:seq_len]

# Drop padding tokens so they line up with cls_attention.
tokens = tokens[:seq_len]

generate_html(tokens, cls_attention, output_path="attention.html")

<class 'numpy.ndarray'> (1, 3)
预测类别: 积极
attention shape: (6, 1, 8, 64, 64)
✔️ 已保存 HTML 到: attention.html
