In [1]:
# 导入必要的库
from IPython.display import display, HTML
from string import Template
import json
import sys
import os
import numpy as np

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm

# 指定字体路径
font_path = '/storage/U-ViT/tools/Visualization/language/fonts/SimHei/SimHei.ttf'
font_prop = fm.FontProperties(fname=font_path)

# 配置全局字体
matplotlib.rcParams['font.family'] = font_prop.get_name()
matplotlib.rcParams['axes.unicode_minus'] = False

# 导入display_attention函数

from attention_visualizer import display_attention

# # 创建测试数据
# # 示例句子和相应的注意力权重
# sentence = ["这是", "一个", "注意力", "可视化", "测试"]
# attention_weights = [0.2, 0.5, 0.9, 0.7, 0.3]

# 显示可视化结果
# display_attention(sentence, attention_weights)
# print("Attention visualization displayed successfully.")

In [2]:
# 测试简单的 HTML 渲染
display(HTML('<div style="background-color:yellow; padding:20px;">这是测试 HTML</div>'))

In [3]:
# # 测试边界情况
# print("测试高注意力值:")
# high_attention = ["高", "注意力", "值", "测试"]
# high_weights = [0.95, 0.99, 0.98, 0.97]
# display_attention(high_attention, high_weights)

# print("测试低注意力值:")
# low_attention = ["低", "注意力", "值", "测试"]
# low_weights = [0.05, 0.01, 0.02, 0.03]
# display_attention(low_attention, low_weights)

# # 测试不同offset值
# print("大字体基础大小:")
# display_attention(sentence, attention_weights, offset=30)

# # 测试边界情况
# print("测试高注意力值:")
# high_attention = ["高", "注意力", "值", "测试"]
# high_weights = [0.95, 0.99, 0.98, 0.97]
# display_attention(high_attention, high_weights)

# print("测试低注意力值:")
# low_attention = ["低", "注意力", "值", "测试"]
# low_weights = [0.05, 0.01, 0.02, 0.03]
# display_attention(low_attention, low_weights)

# # 测试不同offset值
# print("大字体基础大小:")
# display_attention(sentence, attention_weights, offset=30)

# 导入CLIP模型

In [4]:
import torch
import torch.nn as nn
from transformers import CLIPTokenizer, CLIPTextModel
from open_clip import create_model_and_transforms, get_tokenizer
import os

from transformers import (
    AutoConfig,
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    HfArgumentParser,
    PreTrainedTokenizerFast,
    Trainer,
    TrainingArguments,
    set_seed,
    AutoModel,
    
    CLIPPreTrainedModel,
    CLIPTextModel, 
    CLIPTextConfig,
    CLIPTokenizerFast, 
    PreTrainedModel, 
    CLIPConfig
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

import sys

# 将官方 timm 模块的路径添加到 sys.path 的最前面,避免导入libs下的timm
sys.path.insert(0, '/opt/conda/envs/uvit/lib/python3.10/site-packages/timm')


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device) # (batch_size, max_length)
        outputs = self.transformer(input_ids=tokens) # (batch_size, max_length, hidden_size)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):  
        return self(text) # 自动调用forward方法

class BioMedClipEmbedder(AbstractEncoder):
    """Uses the BiomedCLIP transformer encoder for text (from Hugging Face Hub)"""
    def __init__(self, version="hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224", device="cuda", max_length=256):
        super().__init__()
        # Load the model and tokenizer from Hugging Face Hub
        self.model, _, preprocess = create_model_and_transforms(version)
        self.encoder = self.model.text
        self.encoder.output_tokens = True

        self.tokenizer = get_tokenizer(version)
        self.device = device
        self.max_length = max_length
        self.freeze()

        print("\n")
        print("**TextEmbedder**:", version)
        print("\n")


    def freeze(self):
        """Freeze the model parameters to disable training."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        # Tokenize the input text
        token_embeddings = self.tokenizer(text, context_length=self.max_length).to(self.device)
        attention_mask = (token_embeddings != 0).long()  # 非零位置为1，零位置为0
        outputs = self.encoder(token_embeddings) # (batch_size, max_length, hidden_size)
        
        # Get the hidden states from the transformer
        z = outputs[1] # 取出hidden_size
        return token_embeddings, z, attention_mask

    def encode(self, text):
        return self(text)

class PubMedClipEmbedder(AbstractEncoder):
    """Uses the PubMedCLIP transformer encoder for text embeddings"""
    def __init__(self, version="flaviagiammarino/pubmed-clip-vit-base-patch32", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        attention_mask = batch_encoding["attention_mask"].to(self.device)
        tokens = batch_encoding["input_ids"].to(self.device)  # (batch_size, max_length)
        outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask)  # (batch_size, max_length, hidden_size)
        
        z = outputs.last_hidden_state
        
        # 返回原始tokens和特征，以及attention_mask，让调用者知道哪些是padding
        return tokens, z, attention_mask

    def encode(self, text):  
        return self(text) # 自动调用forward方法

class BertEmbedder(AbstractEncoder):
    """Uses the BERT transformer encoder for text (from Hugging Face)"""
    def __init__(self, version="michiyasunaga/BioLinkBERT-base", device="cuda", max_length=256):
        super().__init__()
        
        print("\n")
        print("**TextEmbedder**:", version)
        print("\n")

        # Load the model and tokenizer from Hugging Face Hub
        self.tokenizer = AutoTokenizer.from_pretrained(version)
        self.bert_model = AutoModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length

        self.freeze()

    def freeze(self):
        self.bert_model = self.bert_model.eval()
        for param in self.parameters():
            param.requires_grad = False
            
    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        attention_mask = batch_encoding["attention_mask"].to(self.device)
        tokens = batch_encoding["input_ids"].to(self.device)  # (batch_size, max_length)
        outputs = self.bert_model(input_ids=tokens, attention_mask=attention_mask)  # (batch_size, max_length, hidden_size)
        
        z = outputs.last_hidden_state
        
        # 返回原始tokens和特征，以及attention_mask，让调用者知道哪些是padding
        return tokens, z, attention_mask

    def encode(self, text):  
        return self(text) # 自动调用forward方法

可视化文字权重

In [5]:
def merge_subword_tokens_bert(tokens, weights):
    merged_tokens = []
    merged_weights = []
    
    current_word = []
    current_weights = []
    
    for token, weight in zip(tokens, weights):
        if token in ["[CLS]", "[SEP]"]:  # 特殊符号单独处理
            if current_word:
                merged_tokens.append("".join(current_word))
                merged_weights.append(sum(current_weights)/len(current_weights))
                current_word = []
                current_weights = []
            merged_tokens.append(token)
            merged_weights.append(weight)
        elif token.startswith("##"):
            current_word.append(token[2:])
            current_weights.append(weight)
        else:
            if current_word:
                merged_tokens.append("".join(current_word))
                merged_weights.append(sum(current_weights)/len(current_weights))
            current_word = [token]
            current_weights = [weight]
    
    # 处理剩余词
    if current_word:
        merged_tokens.append("".join(current_word))
        merged_weights.append(sum(current_weights)/len(current_weights))
    
    return merged_tokens, merged_weights


def merge_subword_tokens_clip(tokens, weights):
    """
    处理以 </w> 结尾的词尾标记，合并至标记位置
    
    示例输入：
        tokens = ["Hello</w>", "World</w>", "[SEP]", "Test", "ing</w>"]
        weights = [0.9, 0.8, 1.00.7, 0.6]
    
    输出：
        (["Hello", "World", "[SEP]", "Testing"], [0.9, 0.8, 1.0, 0.65])
    """
    merged_tokens = []
    merged_weights = []
    
    current_word = []
    current_weights = []
    
    for token, weight in zip(tokens, weights):
        if token in {"[CLS]", "[SEP]"}:
            if current_word:
                merged_tokens.append("".join(current_word))
                merged_weights.append(sum(current_weights)/len(current_weights))
                current_word = []
                current_weights = []
            merged_tokens.append(token)
            merged_weights.append(weight)
            continue
            
        if token.endswith("</w>"):
            current_word.append(token[:-4])
            current_weights.append(weight)
            merged_tokens.append("".join(current_word))
            merged_weights.append(sum(current_weights)/len(current_weights))
            current_word = []
            current_weights = []
        else:
            current_word.append(token)
            current_weights.append(weight)
    
    if current_word:
        merged_tokens.append("".join(current_word))
        merged_weights.append(sum(current_weights)/len(current_weights))
        
    return merged_tokens, merged_weights



def vis_word_label_attention(label, weights, version):
    """
    可视化每个类别的词汇注意力强度
    参数：
    label - 词汇标签列表（0-3的整数列表）
    weights - 对应词汇的注意力权重列表
    """
    # 验证输入合法性
    if len(label) != len(weights):
        raise ValueError("标签和权重列表长度必须一致")
    
    # 初始化分类容器
    category_weights = {0:[], 1:[], 2:[], 3:[]}
    
    # 遍历填充数据
    for lbl, w in zip(label, weights):
        if lbl in category_weights:
            category_weights[lbl].append(w)
    
    # 计算均值（处理空列表避免除零）
    category_means = {
        k: np.mean(v) if v else 0 
        for k, v in category_weights.items()
    }
    
    # 打印结果
    print("类别 | 平均注意力值")
    print("------------------")
    for k in sorted(category_means.keys()):
        print(f"{k:4} | {category_means[k]:.4f}")
    
    # 绘图设置
    plt.figure(figsize=(10, 6))
    categories = ['疾病名词', '程度描述', '位置描述', '其他']
    values = [category_means[0], category_means[1], category_means[2], category_means[3]]
    
    # 创建柱状图
    bars = plt.bar(categories, values, color=['#3a5988', '#4a89c0', '#80b8db', '#c5dcee'])
    
    # 添加数值标签
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height,
                 f'{height:.4f}',
                 ha='center', va='bottom')
    
    # 图表装饰
    plt.title('各语义类别平均注意力强度', fontsize=14)
    plt.xlabel('语义类别', fontsize=12)
    plt.ylabel('平均注意力值', fontsize=12)
    plt.ylim(0, max(values)*1.2)  # 自动调整y轴范围
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    # plt.show()
    # 保存文件
    prefix = version.split("/")[-1]
    plt.savefig(f'{prefix}_word_vis.jpg', dpi=300, bbox_inches='tight')
    plt.close()


def visualize_text_attention(text, version, model_type):
    if model_type == "biomedclip":
        clip = BioMedClipEmbedder(version)
    elif model_type == "pubmedclip":
        clip = PubMedClipEmbedder(version)
    else:
        clip = BertEmbedder(version)
    
    clip.eval()
    clip.to(device)
    
    with torch.no_grad():
        tokens, features, attention_mask = clip.encode(text)
    
    valid_length = attention_mask[0].sum().item()
    valid_tokens = tokens[0, :valid_length]
    valid_features = features[0, :valid_length]
    
    tokens_list = valid_tokens.cpu().numpy().tolist()
    if model_type == "biomedclip":
        decoded_tokens = clip.tokenizer.tokenizer.convert_ids_to_tokens(tokens_list)
    else:
        decoded_tokens = clip.tokenizer.convert_ids_to_tokens(tokens_list)
    
    # 提取[CLS] token的特征向量
    cls_token = valid_features[0]  # 维度 [hidden_dim]
    
    # 计算余弦相似度（带非线性增强）
    similarity_scores = torch.nn.functional.cosine_similarity(
        valid_features,
        cls_token.unsqueeze(0).expand_as(valid_features),
        dim=1
    )
    enhanced_p = {"bert_series":(5,[0.01, 0.95]), "biomedclip":(5,[0.02, 0.9]), "pubmedclip":(5,[0.02, 0.9])}
    
    print("enhanced_p:", enhanced_p[model_type][0], enhanced_p[model_type][1])
    print("\n")
    
    # 非线性增强（指数放大差异）
    enhanced_scores = torch.exp(enhanced_p[model_type][0] * similarity_scores)  # 可调节的放大系数
    
    # 分位数归一化（增强中间值的区分度）
    quantiles = torch.quantile(enhanced_scores, torch.tensor(enhanced_p[model_type][1]).to(device))  # 去除极端值
    # print(quantiles[0], quantiles[1])
    clipped_scores = torch.clamp(enhanced_scores, quantiles[0], quantiles[1])
    token_importance = (clipped_scores - clipped_scores.min()) / (clipped_scores.max() - clipped_scores.min())
    
    token_importance = token_importance.cpu().numpy().tolist()
    
    # 合并子词并计算平均权重
    if model_type == "pubmedclip":
        merged_tokens, merged_weights = merge_subword_tokens_clip(decoded_tokens, token_importance)
    else:
        merged_tokens, merged_weights = merge_subword_tokens_bert(decoded_tokens, token_importance)
    
    print("Report:",merged_tokens)
    print("length of report:", len(merged_tokens))
    print("\n")
    
    label =[3, 3, 3, 3, 3, 3, 2, 1, 2, 0, 3, 1, 2, 0, 3, 1, 0, 3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3, 1, 2, 0, 3, 2, 0, 3, 3, 3, 1, 0, 0, 3, 1, 2, 0, 3]
    print(len(label))  # 不包括[CLS]和[SEP]
    if model_type == "pubmedclip":
        vis_word_label_attention(label, merged_weights[0:-1], version)
    else:
        vis_word_label_attention(label, merged_weights[1:-1], version)
    
    
    # merged_tokens = [merged_token.replace("</w>", "") for merged_token in merged_tokens]
    display_attention(merged_tokens[1:-1], merged_weights[1:-1])  # 不可视化[CLS]和[SEP]

In [6]:
os.CUDA_VISIBLE_DEVICES = '0'
device = "cuda" if torch.cuda.is_available() else "cpu"

# 使用示例
# search_query = np.load("/cpfs01/projects-HDD/cfff-906dc71fafda_HDD/gbw_21307130160/U-ViT/tools/Visualization/language/0.npy", allow_pickle=True).item()["prompt"]
search_query = "The chest X-ray reveals bilateral diffuse interstitial edema with patchy alveolar infiltrates and multifocal consolidations in the mid-to-lower lung zones, accompanied by subtle air bronchograms and vascular congestion, suggestive of progressive hydrostatic edema and evolving lobar consolidation."
print(search_query)
print("\n")

The chest X-ray reveals bilateral diffuse interstitial edema with patchy alveolar infiltrates and multifocal consolidations in the mid-to-lower lung zones, accompanied by subtle air bronchograms and vascular congestion, suggestive of progressive hydrostatic edema and evolving lobar consolidation.




## Bert-series model

In [11]:
# michiyasunaga/BioLinkBERT-base 
# michiyasunaga/BioLinkBERT-large (hidden_size 1024, 似乎不太可行)
# microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext
# cambridgeltl/SapBERT-from-PubMedBERT-fulltext
# StanfordAIMI/RadBERT

visualize_text_attention(search_query, version="cambridgeltl/SapBERT-from-PubMedBERT-fulltext", model_type="bert_series")



**TextEmbedder**: cambridgeltl/SapBERT-from-PubMedBERT-fulltext






enhanced_p: 5 [0.01, 0.95]


Report: ['[CLS]', 'the', 'chest', 'x', '-', 'ray', 'reveals', 'bilateral', 'diffuse', 'interstitial', 'edema', 'with', 'patchy', 'alveolar', 'infiltrates', 'and', 'multifocal', 'consolidations', 'in', 'the', 'mid', '-', 'to', '-', 'lower', 'lung', 'zones', ',', 'accompanied', 'by', 'subtle', 'air', 'bronchograms', 'and', 'vascular', 'congestion', ',', 'suggestive', 'of', 'progressive', 'hydrostatic', 'edema', 'and', 'evolving', 'lobar', 'consolidation', '.', '[SEP]']
length of report: 48


46
类别 | 平均注意力值
------------------
   0 | 0.3045
   1 | 0.4063
   2 | 0.2675
   3 | 0.6339


## pubmedclip

In [13]:
visualize_text_attention(search_query, version="openai/clip-vit-base-patch32", model_type="pubmedclip")



enhanced_p: 5 [0.02, 0.9]


Report: ['<|startoftext|>the', 'chest', 'x', '-', 'ray', 'reveals', 'bilateral', 'diffuse', 'interstitial', 'edema', 'with', 'patchy', 'alveolar', 'infiltrates', 'and', 'multifocal', 'consolidations', 'in', 'the', 'mid', '-', 'to', '-', 'lower', 'lung', 'zones', ',', 'accompanied', 'by', 'subtle', 'air', 'bronchograms', 'and', 'vascular', 'congestion', ',', 'suggestive', 'of', 'progressive', 'hydrostatic', 'edema', 'and', 'evolving', 'lobar', 'consolidation', '.', '<|endoftext|>']
length of report: 47


46
类别 | 平均注意力值
------------------
   0 | 0.2420
   1 | 0.2588
   2 | 0.3368
   3 | 0.4186


## biomedclip

In [12]:
visualize_text_attention(search_query, version="hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224", model_type="biomedclip")





**TextEmbedder**: hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224


enhanced_p: 5 [0.02, 0.9]


Report: ['[CLS]', 'the', 'chest', 'x', '-', 'ray', 'reveals', 'bilateral', 'diffuse', 'interstitial', 'edema', 'with', 'patchy', 'alveolar', 'infiltrates', 'and', 'multifocal', 'consolidations', 'in', 'the', 'mid', '-', 'to', '-', 'lower', 'lung', 'zones', ',', 'accompanied', 'by', 'subtle', 'air', 'bronchograms', 'and', 'vascular', 'congestion', ',', 'suggestive', 'of', 'progressive', 'hydrostatic', 'edema', 'and', 'evolving', 'lobar', 'consolidation', '.', '[SEP]']
length of report: 48


46
类别 | 平均注意力值
------------------
   0 | 0.0790
   1 | 0.0954
   2 | 0.1274
   3 | 0.6564


收集所有编码器的文本表现，再画一个雷达图，结束文本可视化