In [None]:
# 分析数据集
import os
import numpy as np

input_file = '../data/wiki1m_for_simcse.txt'

# Read the corpus, one sentence per line. The encoding is pinned so that the
# result does not depend on the platform's default locale
# (the file is assumed to be UTF-8 -- TODO confirm).
with open(input_file, 'r', encoding='utf-8') as f:
    lines = f.readlines()

print('total lines:', len(lines))

# Sentence length = number of whitespace-separated words on the line.
lengths = [len(line.split()) for line in lines]

if lengths:
    print('max length:', max(lengths))
    print('min length:', min(lengths))
    print('avg length:', np.mean(lengths))
    print('median length:', np.median(lengths))

    # Count the sentences that are longer than 32 words.
    long_sentence_threshold = 32
    num_long_sentences = sum(1 for length in lengths if length > long_sentence_threshold)
    print(f'sentences longer than {long_sentence_threshold} words:',
          num_long_sentences, f'({num_long_sentences / len(lengths):.2%})')
else:
    # max()/min() raise ValueError on an empty list, so report the situation instead.
    print('no sentences found in', input_file)


In [None]:
# Sample long sentence used to compare its character count with its token count.
s = '''A surgeon interviewed by Australia's "A Current Affair" television show criticized the marketing of the supplement, stating that customers are "wasting their money, and for this product a large amount of money; and secondly, they may be led to believe they don’t need to take their effective treatments for conditions they may actually have.” Despite this, Wilson claims that the "purple powder" can help elderly people "keep their vibrations up", and at one performance she invited an audience member to speak with her about the "best thing for scar tissue" off camera, so that "trading standards don't become all uppity."'''
char_count = len(s)
print(char_count)

from transformers import AutoTokenizer

# Load the tokenizer of the bert-base-uncased checkpoint.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Convert the sentence into token ids with the tokenizer's encode() method.
input_ids = tokenizer.encode(s)

# Report how many token ids were produced.
token_count = len(input_ids)
print(f"Token count: {token_count}")

In [None]:
from urllib.parse import quote

# Prompt template; '{sentence}' is the placeholder around which the template is split.
template = 'This sentence : \'{sentence}\' means [MASK].'

# Optional: swap the literal [MASK] for the tokenizer's own mask token.
# template =template.replace('[MASK]', tokenizer.mask_token)

# A single split around the placeholder yields the text before and after it.
prompt_prefix, prompt_suffix = template.split('{sentence}')

# Show both halves, one per line.
print(prompt_prefix, prompt_suffix, sep='\n')

In [None]:
# 数据处理 做实体链接和实体消歧
from transformers import BertTokenizer, BertModel
import torch
import torch.nn.functional as F


model_name = 'princeton-nlp/unsup-simcse-bert-base-uncased'

# Load the BERT encoder and its tokenizer.
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

# Context sentence that contains the mention to disambiguate.
sentence = "She picked a ripe, juicy apple from the tree and took a big bite, savoring its crisp sweetness."
# sentence = "Apple recently unveiled its latest iPhone model, showcasing cutting-edge technology and sleek design that captivated tech enthusiasts worldwide."
target_entity = "Apple"
# Candidate description the mention is compared against.
entity_description = "Apple is a sweet, crisp fruit from the apple tree (Malus domestica), enjoyed worldwide for its flavor and nutritional benefits, including fiber and vitamin C."
# entity_description = "Apple is a company that designs, manufactures, and markets consumer electronics, computer software, and online services."


# Tokenize the sentence and the entity surface form.
inputs = tokenizer(sentence, return_tensors="pt")
tokens = tokenizer.tokenize(sentence)
entity_tokens = tokenizer(target_entity)

entity_id = entity_tokens['input_ids'][1:-1]  # strip [CLS] and [SEP]

# Add a leading dimension so the ids can be compared against [1, seq_len] input ids.
# NOTE(review): the comparisons below only broadcast correctly when the entity
# tokenizes to a single wordpiece -- confirm before using multi-token entities.
entity_id = torch.tensor(entity_id).unsqueeze(0)

# Encode the sentence. This is inference only, so no autograd graph is needed.
with torch.no_grad():
    outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state

input_ids = inputs['input_ids']

# Positions in the sentence whose token id equals the entity id.
mask = input_ids == entity_id

# Contextual embedding(s) of the entity mention in the sentence.
entity_embeddings = last_hidden_states[mask]
if entity_embeddings.shape[0] == 0:
    raise ValueError(f"entity {target_entity!r} does not occur in the sentence")
# The entity may be mentioned several times; keep the first mention (shape [1, hidden])
# so that the similarity below is a single scalar.
entity_embeddings = entity_embeddings[:1]

entity_description_inputs = tokenizer(entity_description, return_tensors="pt")
with torch.no_grad():
    entity_description_outputs = model(**entity_description_inputs)
# Represent the description by the entity token's embedding when the entity occurs
# in the description; otherwise fall back to the [CLS] embedding.
entity_mask = entity_description_inputs['input_ids'] == entity_id
if entity_mask.sum() > 0:
    entity_description_emb = entity_description_outputs.last_hidden_state[entity_mask]
    # The entity may occur several times in the description; take the first occurrence.
    entity_description_emb = entity_description_emb[0]
else:
    entity_description_emb = entity_description_outputs.last_hidden_state[:, 0, :]
# Cosine similarity between the mention and the description representation.
cos_sim = F.cosine_similarity(entity_embeddings, entity_description_emb, dim=-1)

print(cos_sim.item())



In [25]:
# Imported here so this cell does not depend on a later cell having been run first.
import json

# Output file for the entity-linking results.
output_file = '../data/wiki1m_for_simcse_ner_entity_linking.json'

# Path of the external entity-description dictionary. It is also assigned in the
# entity-linking cell further down; defining it here as well lets this cell run on a
# fresh kernel (Restart & Run All) instead of failing with NameError.
entity_dict_file = '../data/wiki1m_for_simcse_ner_entity_dict.json'

# Load the dictionary. The lookup code below keys it by the MD5 hex digest of the
# entity text and expects a list of records with a 'description' field.
with open(entity_dict_file, 'r', encoding='utf-8') as f:
    entity_dict = json.load(f)

In [30]:
# 做实体链接
from transformers import BertTokenizer, BertModel
import torch
from tqdm import tqdm
from datasets import load_dataset
import torch.nn.functional as F
import hashlib
import json

# Checkpoint shared by the tokenizer and the encoder.
model_name = 'princeton-nlp/unsup-simcse-bert-base-uncased'

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the tokenizer, then load the encoder weights and move them to `device`.
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)
model = model.to(device)

# Recognized entities whose confidence is below this value are skipped.
entity_reg_threshold = 0.8

# JSON file with the external entity descriptions.
entity_dict_file = '../data/wiki1m_for_simcse_ner_entity_dict.json'

def hash(text):
    """Return the hexadecimal MD5 digest of ``text``.

    The string is encoded with ``str.encode()``'s default encoding (UTF-8)
    before it is hashed.

    NOTE(review): this definition shadows the built-in ``hash`` for every cell
    that runs after it.
    """
    md5 = hashlib.md5()
    md5.update(text.encode())
    return md5.hexdigest()

# Load the NER-annotated sentences; each record has a 'text' and an 'entities' field.
dataset = load_dataset('json', data_files='../data/wiki1m_for_simcse_ner.json')['train']

# Only process the first 2000 records; clamp the range so that a smaller dataset
# does not make select() fail with an out-of-range index.
dataset = dataset.select(range(min(2000, len(dataset))))

# Longest sequence the encoder's position embeddings can handle; longer inputs
# are truncated instead of crashing the forward pass.
max_length = model.config.max_position_embeddings

item_list = []
# Inference only: disable autograd so no graph is built for the forward passes.
with torch.no_grad():
    for item in tqdm(dataset):
        sentence, entity_list = item['text'], item['entities']

        s_inputs = tokenizer(sentence, return_tensors="pt", truncation=True, max_length=max_length).to(device)
        s_output = model(**s_inputs)
        s_last_hidden_states = s_output.last_hidden_state  # [1, seq_len, hidden_size]

        # Output strategy (alternative): use the [CLS] embedding of the sentence.
        # s_output = s_output.last_hidden_state[:, 0, :]

        sentence_item = {'sentence': sentence, 'entity_list': [], 'ner_entity_list': entity_list}
        for entity in entity_list:
            # Skip low-confidence NER results.
            if entity['confidence'] < entity_reg_threshold:
                continue

            entity_text = entity['text']

            # The entity dictionary is keyed by the MD5 hex digest of the entity text.
            entity_text_hash = hash(entity_text)

            entity_knowledge = entity_dict.get(entity_text_hash, [])
            if not entity_knowledge:
                continue

            # Candidate descriptions for this entity (empty descriptions are dropped).
            entity_description_list = [
                knowledge['description']
                for knowledge in entity_knowledge
                if knowledge.get('description', '')
            ]
            if not entity_description_list:
                continue

            # The mention is located through the first wordpiece of the entity text.
            # Without special tokens an empty tokenization is detectable, instead of
            # silently matching [SEP].
            entity_wordpiece_ids = tokenizer(entity_text, add_special_tokens=False)['input_ids']
            if not entity_wordpiece_ids:
                continue
            entity_id = entity_wordpiece_ids[0]

            # Embeddings of every position in the sentence that carries that wordpiece.
            s_entity_output = s_last_hidden_states[s_inputs['input_ids'] == entity_id]
            # Skip the entity only when it is not found at all. A mention that occurs
            # exactly once must be kept (the previous `> 1` test dropped it).
            if s_entity_output.shape[0] == 0:
                continue
            # The wordpiece may occur several times; use the first occurrence.
            s_entity_output = s_entity_output[0]  # [hidden_size]

            # Encode all candidate descriptions in one padded batch. This is done only
            # after the mention was found, so no forward pass is wasted on skipped entities.
            entity_description_inputs = tokenizer(
                entity_description_list,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            ).to(device)  # [num_descriptions, max_len]
            entity_description_outputs = model(**entity_description_inputs)

            # Represent each description by its [CLS] embedding.
            entity_description_outputs = entity_description_outputs.last_hidden_state[:, 0, :]  # [num_descriptions, hidden_size]

            # Cosine similarity between the mention and every description.
            cos_sim = F.cosine_similarity(s_entity_output, entity_description_outputs, dim=-1)  # [num_descriptions]
            # Keep the most similar description.
            max_index = cos_sim.argmax().item()
            max_cos_sim = cos_sim[max_index].item()
            max_entity_description = entity_description_list[max_index]

            sentence_item['entity_list'].append({'entity': entity_text, 'description': max_entity_description, 'similarity': max_cos_sim})

        item_list.append(sentence_item)

# Persist the linking results as JSON.
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(item_list, f, indent=4)

100%|██████████| 2000/2000 [00:28<00:00, 69.19it/s]
