## 用于对语料库及 Wiki 数据库进行查询

In [None]:
# Corpus sentence-length statistics: median, mode, mean, max and min of the
# per-line whitespace token counts, followed by a histogram of the distribution.
import numpy as np
import matplotlib.pyplot as plt
file_path = 'data/wiki1m_knowledge.txt'

# Read the text file and count whitespace-separated tokens on each line
# (presumably one sentence per line — TODO confirm for this file).
with open(file_path, 'r', encoding='utf-8') as file:
    sentence_word_counts = [len(line.strip().split()) for line in file]

# bincount/argmax, max and min all raise opaque errors on an empty array,
# so fail early with an explicit message instead.
if not sentence_word_counts:
    raise ValueError(f"{file_path} contains no lines; cannot compute length statistics")

# Convert to a NumPy array for the vectorized statistics below.
sentence_word_counts_array = np.array(sentence_word_counts)

# Median
median = np.median(sentence_word_counts_array)

# Mode: bincount indexes by value, so argmax yields the most frequent length
# (the smallest such length on ties, since argmax returns the first maximum).
mode = int(np.argmax(np.bincount(sentence_word_counts_array)))

# Mean
mean = np.mean(sentence_word_counts_array)

# Maximum and minimum
max_value = np.max(sentence_word_counts_array)
min_value = np.min(sentence_word_counts_array)

print("中位数:", median)
print("众数:", mode)
print("平均数:", mean)
print("最大值:", max_value)
print("最小值:", min_value)

# Frequency-distribution plot: histogram of sentence lengths.
plt.hist(sentence_word_counts, bins=100, color='skyblue', edgecolor='black')
plt.title('Distribution of Sentence Lengths')
plt.xlabel('Length of Sentences')
plt.ylabel('Frequency')
plt.grid(True)
plt.show()

In [None]:
# BERT smoke test: load a tokenizer and encoder from a local checkpoint
# directory, encode one sentence and print the shape of the last hidden state.
import torch
from transformers import BertTokenizer, BertModel

# NOTE(review): absolute path; the download cell in this notebook saves to the
# relative 'pretrain_model/bert-base-uncased' — confirm which location is intended.
path = '/pretrain_model/bert-base-uncased'

tokenizer = BertTokenizer.from_pretrained(path)
model = BertModel.from_pretrained(path)
text = "Example sentence to be tokenized."
inputs = tokenizer(text, return_tensors="pt")
# Inference only: skip building the autograd graph to save memory and time.
with torch.no_grad():
    outputs = model(**inputs)
# last_hidden_state is (batch_size, sequence_length, hidden_size) per the
# transformers BertModel documentation.
print(outputs.last_hidden_state.shape)


In [25]:
# Download bert-base-uncased from the Hugging Face Hub and write a local copy
# (model weights + config AND tokenizer files) so the directory can later be
# loaded with BertModel/BertTokenizer.from_pretrained(output_dir).
from transformers import BertModel, BertTokenizer
import os

# Download and load the BERT model and its tokenizer.
model = BertModel.from_pretrained('bert-base-uncased')
# model.save_pretrained writes only the weights and config; without saving the
# tokenizer too, BertTokenizer.from_pretrained(output_dir) has no vocab to load.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Save a copy into the target directory (nothing is moved).
output_dir = 'pretrain_model/bert-base-uncased'
os.makedirs(output_dir, exist_ok=True)
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)


In [None]:
# Run Flair's pretrained "ner" SequenceTagger over every line of the wiki1m
# SimCSE corpus and dump, per sentence, the original text plus the detected
# entity spans (text, start/end position, label, confidence) to one JSON file.
from flair.models import SequenceTagger
from flair.data import Sentence
import json
import flair
import torch
from tqdm import tqdm

# Use the GPU when one is present; fall back to CPU instead of crashing.
flair.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the NER model.
tagger = SequenceTagger.load("ner")

input_file = '../data/wiki1m_for_simcse.txt'
output_file = '../data/wiki1m_for_simcse_ner.json'

with open(input_file, 'r', encoding='utf-8') as file:
    lines = file.readlines()

# Only batch_size Sentence objects are alive at a time; Flair itself runs the
# model over them in mini-batches of 128.
batch_size = 512

result = []

for start in tqdm(range(0, len(lines), batch_size)):
    sentence_list = [Sentence(line) for line in lines[start:start + batch_size]]

    # predict() attaches 'ner' span labels to each Sentence in place.
    tagger.predict(sentence_list, mini_batch_size=128, verbose=False)

    for sentence in sentence_list:
        entities_list = [
            {
                "text": entity.text,
                "start_position": entity.start_position,
                "end_position": entity.end_position,
                "label": entity.get_label('ner').value,
                "confidence": entity.score,
            }
            for entity in sentence.get_spans('ner')
        ]
        result.append({"text": sentence.to_original_text(), "entities": entities_list})

# Save the results.
with open(output_file, 'w', encoding='utf-8') as file:
    json.dump(result, file, ensure_ascii=False, indent=4)

# Log from a 2024-10-11 run: SequenceTagger predicts: Dictionary with 20 tags: <unk>, O, S-ORG, S-MISC, B-PER, E-PER, S-LOC, B-ORG, E-ORG, I-PER, S-PER, B-MISC, I-MISC, E-MISC, I-ORG, B-LOC, E-LOC, I-LOC, <START>, <STOP>

In [None]:
# Data consolidation: collect every NER entity mention from the tagged corpus,
# deduplicate by surface text, and write one unique entity per line.
import json
import os

input_file = '../data/wiki1m_for_simcse_ner.json'
output_file = '../data/wiki1m_for_simcse_ner_entity.txt'

with open(input_file, 'r', encoding='utf-8') as file:
    sentence_list = json.load(file)

# Flatten the per-sentence entity lists into one list of mentions.
entity_list = [entity for sentence in sentence_list for entity in sentence['entities']]

print("实体数量:", len(entity_list))

# Deduplicate by surface text only (label and position are ignored).
entity_set = {entity['text'] for entity in entity_list}

print("去重后的实体数量:", len(entity_set))
# Counts from a previous run:
# entity mentions: 1977083
# unique entities: 618928

# Sort before writing: set iteration order depends on per-process string-hash
# randomization, so unsorted output would differ from run to run.
with open(output_file, 'w', encoding='utf-8') as file:
    file.writelines(entity + '\n' for entity in sorted(entity_set))

In [None]:
import requests
import json
from tqdm import tqdm
import os
import hashlib
from concurrent.futures import ThreadPoolExecutor, as_completed

base_dir = '../data/'

input_file = base_dir + 'wiki1m_for_simcse_ner_entity.txt'
# Per-entity cache directory (trailing slash: file names are appended directly).
output_dir = base_dir + 'wiki1m_for_simcse_ner_entity_search/'
output_file = base_dir + 'wiki1m_for_simcse_ner_entity_search_dict.json'

os.makedirs(output_dir, exist_ok=True)

# One entity per line. Iterating the file splits only on line-break sequences
# ('\n', '\r', '\r\n'), whereas str.splitlines() would also split on characters
# such as '\x85' or '\u2028' that can occur inside an entity name. Blank lines
# are skipped so no empty search is ever sent.
with open(input_file, 'r', encoding='utf-8') as file:
    entity_list = [line.rstrip('\n') for line in file if line.strip()]

entity_dict = {}

# Base URL of the Wikidata API.
wikidata_url = 'https://www.wikidata.org/w/api.php'

def search_wikidata(entity, timeout=30):
    """Search Wikidata for items matching *entity* via ``wbsearchentities``.

    Args:
        entity: Surface string to search for (English, at most 10 hits).
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The list under the response's ``search`` key (possibly empty) on
        success, or ``None`` on any failure (non-200 status, API-level error,
        network error/timeout, undecodable JSON) so the caller can retry later.
    """
    params = {
        'action': 'wbsearchentities',  # entity search
        'format': 'json',              # JSON response
        'language': 'en',              # query language
        'search': entity,              # entity name to look up
        'limit': 10                    # max number of hits
    }
    # Wikimedia asks API clients to send a descriptive User-Agent.
    # TODO(owner): append contact information per the Wikimedia UA policy.
    headers = {'User-Agent': 'wiki1m-entity-search/0.1 (python-requests)'}

    try:
        # Without a timeout a stalled connection blocks its worker thread forever.
        response = requests.get(wikidata_url, params=params, headers=headers, timeout=timeout)
        if response.status_code == 200:
            data = response.json()
            # The MediaWiki API reports errors (e.g. rate limiting) as JSON with
            # an 'error' key under HTTP 200; treat that as a failure rather than
            # as "no results", so it is not cached as an empty list.
            if 'error' in data:
                print(f"Error querying {entity}: {data['error']}")
                return None
            return data.get('search', [])
        else:
            print(f"Error {response.status_code}: {response.text}")
            return None
    except Exception as e:
        # Best-effort: report and return None so one bad entity does not abort the run.
        print(f"Error querying {entity}: {e}")
        return None

def hash(text):
    return hashlib.md5(text.encode()).hexdigest()

def process_entity(entity):
    """Query Wikidata for *entity* and cache the hits on disk.

    Hits are stored as ``output_dir + hash(entity) + '.json'`` (``hash`` is the
    MD5 helper defined in this cell). An existing cache file means the entity
    was already fetched, so it is skipped.

    Returns:
        The search hits that were written, or ``None`` if the entity was
        already cached or the query failed (nothing is written on failure,
        so a re-run retries it).
    """
    output_file_path = output_dir + hash(entity) + '.json'
    if os.path.exists(output_file_path):
        return None

    # Query the entity.
    entities = search_wikidata(entity)
    if entities is not None:
        # Write to a temp file, then atomically rename it, so an interrupted
        # run never leaves a truncated .json that the existence check above
        # would treat as complete and json.load would later fail on.
        tmp_path = output_file_path + '.tmp'
        with open(tmp_path, 'w', encoding='utf-8') as file:
            json.dump(entities, file, ensure_ascii=False, indent=4)
        os.replace(tmp_path, output_file_path)
    return entities

# Fetch all entities concurrently with a thread pool (the work is network-bound).
max_workers = 10  # number of concurrent worker threads
with ThreadPoolExecutor(max_workers=max_workers) as executor:
    futures = [executor.submit(process_entity, entity) for entity in entity_list]
    for future in tqdm(as_completed(futures), total=len(futures)):
        future.result()  # wait for each task; re-raises any worker exception

# Once everything is done, merge all cached results into one dict keyed by the
# cache file's stem, i.e. the MD5 hex digest of the entity text.
entity_dict = {}
# Sorted for a deterministic key order in the merged output file.
for file_name in sorted(os.listdir(output_dir)):
    # Only completed cache files; skip leftover temp files or anything else.
    if not file_name.endswith('.json'):
        continue
    with open(output_dir + file_name, 'r', encoding='utf-8') as f:
        entity_dict[file_name.split('.')[0]] = json.load(f)

with open(output_file, 'w', encoding='utf-8') as file:
    json.dump(entity_dict, file, ensure_ascii=False, indent=4)