## 用于对语料库及 wiki 数据库进行查询

In [None]:
# 计算语料库的平均长度
import numpy as np
import matplotlib.pyplot as plt
# NOTE(review): other cells read corpora from '../data/'; confirm this relative path is intended.
file_path = 'data/wiki1m_knowledge.txt'

# Read the corpus (one sentence per line) and count whitespace-separated words per line.
with open(file_path, 'r', encoding='utf-8') as file:
    sentence_word_counts = [len(line.strip().split()) for line in file]

# Integer array so np.bincount can be used for the mode below.
sentence_word_counts_array = np.array(sentence_word_counts, dtype=np.int64)

# On an empty corpus np.median returns nan and np.argmax(np.bincount(...)) raises a
# cryptic error, so fail early with a clear message instead.
if sentence_word_counts_array.size == 0:
    raise ValueError(f"No lines found in {file_path}")

# Median
median = np.median(sentence_word_counts_array)

# Mode: word counts are non-negative ints, so bincount works; argmax picks the
# smallest count on ties.
mode = int(np.argmax(np.bincount(sentence_word_counts_array)))

# Mean
mean = np.mean(sentence_word_counts_array)

# Maximum and minimum
max_value = np.max(sentence_word_counts_array)
min_value = np.min(sentence_word_counts_array)

print("中位数:", median)
print("众数:", mode)
print("平均数:", mean)
print("最大值:", max_value)
print("最小值:", min_value)

# Plot the frequency distribution of sentence lengths as a histogram.
plt.hist(sentence_word_counts, bins=100, color='skyblue', edgecolor='black')
plt.title('Distribution of Sentence Lengths')
plt.xlabel('Length of Sentences')
plt.ylabel('Frequency')
plt.grid(True)
plt.show()

In [None]:
# bert prompt test
from transformers import BertTokenizer, BertModel

import torch

# NOTE(review): absolute path, while the download cell saves to the relative
# 'pretrain_model/bert-base-uncased'; confirm which location is intended.
path = '/pretrain_model/bert-base-uncased'

tokenizer = BertTokenizer.from_pretrained(path)
model = BertModel.from_pretrained(path)
text = "Example sentence to be tokenized."
inputs = tokenizer(text, return_tensors="pt")
# Inference only: no need to record the autograd graph.
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)


In [None]:
from transformers import BertModel
import os

from transformers import BertTokenizer

# Download and load BERT (weights and tokenizer) from the Hugging Face Hub.
model = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Save a local copy. The tokenizer files must be saved alongside the weights,
# otherwise BertTokenizer.from_pretrained(output_dir) cannot load from this directory.
output_dir = 'pretrain_model/bert-base-uncased'
os.makedirs(output_dir, exist_ok=True)
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)


In [None]:
from flair.models import SequenceTagger
from flair.data import Sentence
import json
import flair
import torch
from tqdm import tqdm

# Use the GPU when present; fall back to CPU so the cell still runs without CUDA.
flair.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the pretrained Flair NER model.
tagger = SequenceTagger.load("ner")

input_file = '../data/wiki1m_for_simcse.txt'
output_file = '../data/wiki1m_for_simcse_ner.json'

with open(input_file, 'r', encoding='utf-8') as file:
    lines = file.readlines()

batch_size = 512

result = []

for start in tqdm(range(0, len(lines), batch_size)):
    # Wrap every raw line of this chunk as a Flair Sentence.
    sentence_list = [Sentence(line) for line in lines[start:start + batch_size]]

    tagger.predict(sentence_list, mini_batch_size=128, verbose=False)

    for sentence in sentence_list:
        # One record per predicted NER span.
        entities_list = [
            {
                "text": entity.text,
                "start_position": entity.start_position,
                "end_position": entity.end_position,
                "label": entity.get_label('ner').value,
                "confidence": entity.score,
            }
            for entity in sentence.get_spans('ner')
        ]
        result.append({"text": sentence.to_original_text(), "entities": entities_list})

# Save the results.
with open(output_file, 'w', encoding='utf-8') as file:
    json.dump(result, file, ensure_ascii=False, indent=4)

# # 2024-10-11 10:52:22,912 SequenceTagger predicts: Dictionary with 20 tags: <unk>, O, S-ORG, S-MISC, B-PER, E-PER, S-LOC, B-ORG, E-ORG, I-PER, S-PER, B-MISC, I-MISC, E-MISC, I-ORG, B-LOC, E-LOC, I-LOC, <START>, <STOP>

In [None]:
# 数据梳理
import json
import os

input_file = '../data/wiki1m_for_simcse_ner.json'
output_file = '../data/wiki1m_for_simcse_ner_entity.txt'

with open(input_file, 'r', encoding='utf-8') as file:
    sentence_list = json.load(file)

# Flatten the entity records of every sentence into one list.
entity_list = [entity for sentence in sentence_list for entity in sentence['entities']]

print("实体数量:", len(entity_list))

# Deduplicate by surface text.
entity_set = {entity['text'] for entity in entity_list}

print("去重后的实体数量:", len(entity_set))
# Observed on the full corpus: 1977083 entities, 618928 after deduplication.

# Write in sorted order: iteration order of a set of str depends on hash
# randomization, so an unsorted dump would differ between runs.
with open(output_file, 'w', encoding='utf-8') as file:
    for entity in sorted(entity_set):
        file.write(entity + '\n')

In [None]:
import requests
import json
from tqdm import tqdm
import os
import hashlib
from concurrent.futures import ThreadPoolExecutor, as_completed

# Wikidata MediaWiki API endpoint.
wikidata_url = 'https://www.wikidata.org/w/api.php'

base_dir = '../data/'
input_file = base_dir + 'wiki1m_for_simcse_ner_entity.txt'
# Each entity's query result is cached here as its own JSON file.
output_dir = base_dir + 'wiki1m_for_simcse_ner_entity_search_all/'
# All cached results merged into one dict.
output_file = base_dir + 'wiki1m_for_simcse_ner_entity_search_dict.json'

os.makedirs(output_dir, exist_ok=True)

# One entity string per line.
with open(input_file, 'r', encoding='utf-8') as file:
    entity_list = file.read().splitlines()

entity_dict = {}

def search_wikidata(entity, timeout=30):
    """
    Search Wikidata for entities whose label or alias matches ``entity``.

    :param entity: entity surface string to look up
    :param timeout: seconds to wait for the HTTP response; without one, a stalled
        connection blocks a thread-pool worker indefinitely
    :return: the (possibly empty) list under the response's ``search`` key, or None
        when the request or the API call failed, so the caller does not cache a
        failure as "no results"
    """
    params = {
        'action': 'wbsearchentities',  # entity search
        'format': 'json',              # JSON response
        'language': 'en',              # search language
        'search': entity,              # text to search for
        'limit': 10                    # maximum number of results
    }

    try:
        response = requests.get(wikidata_url, params=params, timeout=timeout)
        if response.status_code != 200:
            print(f"Error {response.status_code}: {response.text}")
            return None
        data = response.json()
        # The MediaWiki API reports failures as an "error" object inside the JSON body.
        if 'error' in data:
            print(f"Error querying {entity}: {data['error']}")
            return None
        return data.get('search', [])
    except Exception as e:
        # Best-effort per entity: log and let the caller retry on a later run.
        print(f"Error querying {entity}: {e}")
        return None

def hash(text):
    """Return the hex MD5 digest of ``text`` encoded as UTF-8 (used as a cache file name).

    NOTE: this shadows the built-in ``hash`` in the notebook namespace.
    """
    digest = hashlib.md5(text.encode('utf-8'))
    return digest.hexdigest()

def process_entity(entity):
    """
    Query Wikidata for one entity and cache the result as a JSON file.

    :param entity: entity surface string
    :return: the search results, or None if a cache file already exists or the
        query failed (failed queries are not cached, so a rerun retries them)
    """
    output_file_path = output_dir + hash(entity) + '.json'
    if os.path.exists(output_file_path):
        return None

    entities = search_wikidata(entity)
    if entities is not None:
        # Write to a temporary file and rename atomically. A write interrupted
        # mid-way would otherwise leave a truncated .json that the existence check
        # above treats as done forever.
        tmp_path = output_file_path + '.tmp'
        with open(tmp_path, 'w', encoding='utf-8') as file:
            json.dump(entities, file, ensure_ascii=False, indent=4)
        os.replace(tmp_path, output_file_path)
    return entities

# Issue the queries concurrently with a thread pool.
max_workers = 10  # number of concurrent threads
with ThreadPoolExecutor(max_workers=max_workers) as executor:
    futures = [executor.submit(process_entity, entity) for entity in entity_list]
    for future in tqdm(as_completed(futures), total=len(futures)):
        future.result()  # re-raises any exception from the worker

# When everything has finished, merge the per-entity files into one dict keyed by
# the file's MD5 stem.
entity_dict = {}
for file in os.listdir(output_dir):
    # Only completed result files (skips e.g. leftover temporary files).
    if not file.endswith('.json'):
        continue
    with open(os.path.join(output_dir, file), 'r', encoding='utf-8') as f:
        try:
            entity_dict[file.split('.')[0]] = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # Report corrupt cache files instead of silently dropping them.
            print(f"Skipping unreadable file {file}: {e}")

with open(output_file, 'w', encoding='utf-8') as file:
    json.dump(entity_dict, file, ensure_ascii=False, indent=4)

In [None]:
import json
import os
from tqdm import tqdm

base_dir = '../data/'
# NOTE(review): this reads '..._entity_search/' while the query cell writes to
# '..._entity_search_all/'; confirm which directory is intended.
input_dir = base_dir + 'wiki1m_for_simcse_ner_entity_search/'
output_file = base_dir + 'wiki1m_for_simcse_ner_entity_dict.json'

entity_dict = {}
for file in tqdm(os.listdir(input_dir)):
    with open(os.path.join(input_dir, file), 'r', encoding='utf-8') as f:
        try:
            entity_dict[file.split('.')[0]] = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Print the names of unreadable files; a bare `except` would also
            # swallow KeyboardInterrupt and real bugs.
            print(file)

with open(output_file, 'w', encoding='utf-8') as file:
    json.dump(entity_dict, file, ensure_ascii=False)

In [None]:
import requests

def search_entities(entity_name, timeout=30):
    """
    Full-text search English Wikipedia and return the candidate pages.

    :param entity_name: query string
    :param timeout: seconds to wait for the HTTP response
    :return: list of {"page_id", "title"} dicts in the API's result order
    :raises requests.HTTPError: on a non-2xx response
    """
    url = "https://en.wikipedia.org/w/api.php"
    params = {
        "action": "query",
        "list": "search",
        "srsearch": entity_name,
        "format": "json",
    }

    response = requests.get(url, params=params, timeout=timeout)
    response.raise_for_status()
    search_results = response.json().get("query", {}).get("search", [])

    return [{"page_id": hit["pageid"], "title": hit["title"]} for hit in search_results]

def fetch_entity_info(page_id, timeout=30):
    """
    Fetch details for one Wikipedia page by page ID.

    Follows the API's ``continue`` protocol: categories and links are paginated,
    and even with ``cllimit``/``pllimit`` set to ``max`` a single response may hold
    only the first batch.

    :param page_id: Wikipedia page ID
    :param timeout: seconds to wait for each HTTP response
    :return: dict with page_id, title, url, extract (intro as plain text),
        categories and links; missing fields fall back to "N/A",
        "No summary available" or []
    """
    url = "https://en.wikipedia.org/w/api.php"
    params = {
        "action": "query",
        "format": "json",
        "pageids": page_id,
        "prop": "info|extracts|categories|links",
        "inprop": "url",
        "exintro": True,
        "explaintext": True,
        "cllimit": "max",
        "pllimit": "max"
    }

    page_info = {}
    categories = []
    links = []
    while True:
        response = requests.get(url, params=params, timeout=timeout)
        data = response.json()
        page = data.get("query", {}).get("pages", {}).get(str(page_id), {})
        # Scalar fields: keep the first value seen across continuation batches.
        for field in ("title", "fullurl", "extract"):
            if field in page:
                page_info.setdefault(field, page[field])
        # List fields: accumulate every batch.
        categories.extend(page.get("categories", []))
        links.extend(page.get("links", []))
        if "continue" not in data:
            break
        # Resend the original query plus the continuation tokens.
        params = {**params, **data["continue"]}

    entity_info = {
        "page_id": page_id,
        "title": page_info.get("title", "N/A"),
        "url": page_info.get("fullurl", "N/A"),
        "extract": page_info.get("extract", "No summary available"),
        "categories": [cat.get("title", "") for cat in categories],
        "links": [link.get("title", "") for link in links]
    }

    return entity_info

def fetch_all_entity_infos(entity_name):
    """
    Search Wikipedia for ``entity_name``, fetch details for every candidate page and print them.

    :param entity_name: query string
    :return: list of entity-info dicts from fetch_entity_info (empty if nothing was found)
    """
    # Candidate pages for the query.
    candidates = search_entities(entity_name)
    if not candidates:
        print(f"No results found for '{entity_name}'")
        return []

    # Fetch details for every candidate.
    all_entity_infos = []
    for candidate in candidates:
        print(f"Fetching info for '{candidate['title']}' (Page ID: {candidate['page_id']})...")
        all_entity_infos.append(fetch_entity_info(candidate["page_id"]))

    # Print each entity's information.
    for info in all_entity_infos:
        print("\n页面标题:", info['title'])
        print("页面 URL:", info['url'])
        print("简介:", info['extract'])
        print('---')

    return all_entity_infos


# Example usage (assigned so the notebook does not echo the returned list).
entity_name = "Winner advances to the second stage."
entity_infos = fetch_all_entity_infos(entity_name)

In [None]:
# 对假负样例进行分析
from transformers import BertTokenizer, BertModel
import torch.nn.functional as F
import torch
import requests

model_name = 'princeton-nlp/unsup-simcse-bert-base-uncased'

# NOTE(review): `device` is not used below; the model stays on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

sent_list = ["Chamod Wickramasuriya",
    "Chamod Wickramasuriya (born 27 May 1999) is a Sri Lankan cricketer.",
    "He made his Twenty20 debut on 15 January 2020, for Galle Cricket Club in the 2019–20 SLC Twenty20 Tournament.",
    "Comedian Bharti Singh will Host this show along with her husband writer Haarsh Limbachiyaa."
    ]
# Sentences 2 and 3 above joined together (the leading "C" of "Chamod" was missing).
base_sent = "Chamod Wickramasuriya (born 27 May 1999) is a Sri Lankan cricketer. He made his Twenty20 debut on 15 January 2020, for Galle Cricket Club in the 2019–20 SLC Twenty20 Tournament."

api_url = 'https://en.wikipedia.org/w/api.php'

# Identify the client to the Wikimedia API with a User-Agent that includes contact info.
headers = {
    "User-Agent": "Wiki Study/1.0 (905899183@qq.com)"
}

def search_wiki(text, timeout=30):
    """
    Full-text search English Wikipedia page text for ``text``.

    The query is not quoted, so this is an ordinary keyword search, not an
    exact-phrase match.

    :param text: query text
    :param timeout: seconds to wait for the HTTP response
    :return: list of {"page_id", "title"} for the top hits (API default: 10 results)
    """
    params = {
        "action": "query",
        "list": "search",
        "srsearch": text,   # unquoted: keyword search
        "srwhat": "text",   # search in page text
        "format": "json",
    }

    response = requests.get(api_url, params=params, headers=headers, timeout=timeout)
    search_results = response.json().get("query", {}).get("search", [])

    return [{"page_id": hit["pageid"], "title": hit["title"]} for hit in search_results]

# Print the Wikipedia search hits for each probe sentence.
for probe in sent_list:
    hits = search_wiki(probe)
    print(hits)

# n_sent = len(sent_list)

# sent_inputs = tokenizer(sent_list, return_tensors="pt", padding=True)
# base_sent_inputs = tokenizer(base_sent, return_tensors="pt", padding=True)

# sent_outputs = model(**sent_inputs)
# base_sent_outputs = model(**base_sent_inputs)

# sent_embeddings = sent_outputs.last_hidden_state[:, 0, :]   # cls
# base_sent_embeddings = base_sent_outputs.last_hidden_state[:, 0, :]  # cls

# cosine_similarities = F.cosine_similarity(base_sent_embeddings, sent_embeddings, dim=1)

# for i in range(n_sent):
#     print(f"Similarity between base sentence and sentence {i + 1}: {cosine_similarities[i].item()}")
    
# # sent_list 两两之间的相似度
# cos_sim = F.cosine_similarity(sent_embeddings.unsqueeze(1), sent_embeddings.unsqueeze(0), dim=-1)
# print("Similarity matrix:")
# print(cos_sim)

In [None]:
# 每条句子搜索wiki
import redis
import base64
from tqdm import tqdm
import requests
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
import random
from datetime import datetime
import base64

def text_encode(text):
    """Return the Base64 encoding of ``text``'s UTF-8 bytes as an ASCII string."""
    raw = text.encode('utf-8')
    return base64.b64encode(raw).decode('ascii')

def text_decode(text):
    """Inverse of text_encode: decode a Base64 string back to the original UTF-8 text."""
    raw = base64.b64decode(text.encode('utf-8'))
    return raw.decode('utf-8')

# Redis db 0 caches the Wikipedia search result for each sentence.
# NOTE(review): the password is hard-coded; consider reading it from the environment.
r = redis.Redis(
    host='localhost',
    port=6379,
    db=0,
    password='lyuredis579',
)
# Key namespace for cached search results.
prefix = 'wikisearch:'

def search_wiki(text, max_proxies=10, timeout=30):
    """
    Exact-phrase search of English Wikipedia page text through a local proxy.

    :param text: sentence to search for; wrapped in double quotes for phrase matching
    :param max_proxies: unused; kept for call-site compatibility
    :param timeout: seconds to wait for the HTTP response
    :return: list of {"page_id", "title"} for the top hits (API default: 10 results)
    :raises requests.HTTPError: on a non-2xx status
    :raises RuntimeError: when the API returns an "error" object, so callers do
        not cache a failed request as "no results"
    """
    api_url = 'https://en.wikipedia.org/w/api.php'
    params = {
        "action": "query",
        "list": "search",
        "srsearch": f'"{text}"',  # quoted: exact phrase
        "srwhat": "text",         # search in page text
        "format": "json",
    }

    proxy = {
        "https": "http://127.0.0.1:20172"
    }

    response = requests.get(api_url, params=params, proxies=proxy, timeout=timeout)
    # Small random pause between requests.
    time.sleep(random.uniform(0.05, 0.2))
    response.raise_for_status()
    data = response.json()
    if "error" in data:
        raise RuntimeError(f"MediaWiki API error: {data['error']}")
    search_results = data.get("query", {}).get("search", [])

    return [{"page_id": hit["pageid"], "title": hit["title"]} for hit in search_results]

def text_search_task(text, max=10):
    """
    Return the Wikipedia search hits for ``text``, using Redis as a cache.

    :param text: sentence to look up
    :param max: forwarded as search_wiki's second argument (name kept for
        compatibility even though it shadows the built-in)
    :return: list of hits (cached or freshly fetched), or None if the query failed;
        failures are not cached, so a rerun retries them
    """
    key = prefix + text_encode(text)

    # Single round trip: GET returns None when the key is absent.
    cached = r.get(key)
    if cached is not None:
        return json.loads(cached)

    try:
        result = search_wiki(text, max)
        r.set(key, json.dumps(result, ensure_ascii=False))
    except Exception as e:
        # Best-effort per sentence: log and continue with the rest of the pool.
        print(f"Error querying {text}: {e}")
        return None
    # Return the fresh result as well, matching the cached path.
    return result

dataset_path = '../data/wiki1m_for_simcse.txt'
with open(dataset_path, 'r', encoding='utf-8') as file:
    sent_list = file.read().splitlines()

# Fan the lookups out over a thread pool. max_workers is also passed as
# text_search_task's second argument, which search_wiki does not use.
max_workers = 30  # number of concurrent threads
with ThreadPoolExecutor(max_workers=max_workers) as executor:
    pending = [executor.submit(text_search_task, sentence, max_workers) for sentence in sent_list]
    for done in tqdm(as_completed(pending), total=len(pending)):
        done.result()  # surface any worker exception

In [None]:
# 统计查询为空的句子
import redis
import json
from tqdm import tqdm

# Connect to Redis db 0.
r = redis.Redis(host='localhost', port=6379, db=0, password='lyuredis579')

# Counters
empty_list_count = 0
total_count = 0
prefix = 'wikisearch:'

page_id_list = []

# Walk every cached search result: count the empty ones and collect the page ids of the rest.
for key in tqdm(r.scan_iter(prefix + '*')):
    value = r.get(key)
    if value is None:
        # Key removed between SCAN and GET: skip instead of crashing in json.loads(None).
        continue
    if value.decode() == '[]':
        empty_list_count += 1
    else:
        page_id_list.extend(item['page_id'] for item in json.loads(value))
    total_count += 1
print(f"Keys with empty list content: {empty_list_count}")
print(f"Total keys: {total_count}")

# Deduplicate page ids.
page_id_list = list(set(page_id_list))
print(f"Total page_id: {len(page_id_list)}")

# Store the deduplicated list for later cells.
r.set('page_id_list', json.dumps(page_id_list))


In [None]:
import requests
from tqdm import tqdm
import json
import redis
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
import random

# Connect to Redis db 0.
r = redis.Redis(host='localhost', port=6379, db=0, password='lyuredis579')

# Counters
empty_list_count = 0
total_count = 0
prefix = 'wikisearch:'

page_id_list_key = 'page_id_list'

# Reuse the stored deduplicated page-id list if present; otherwise rebuild and store it.
if r.exists(page_id_list_key):
    page_id_list = json.loads(r.get(page_id_list_key))
else:
    page_id_list = []

    # Walk every cached search result: count empty ones, collect page ids of the rest.
    for key in tqdm(r.scan_iter(prefix + '*'), desc='Scan keys'):
        value = r.get(key)
        if value is None:
            # Key removed between SCAN and GET: skip instead of crashing in json.loads(None).
            continue
        if value.decode() == '[]':
            empty_list_count += 1
        else:
            page_id_list.extend(item['page_id'] for item in json.loads(value))
        total_count += 1
    print(f"Keys with empty list content: {empty_list_count}")
    print(f"Total keys: {total_count}")

    # Deduplicate page ids.
    page_id_list = list(set(page_id_list))
    print(f"Total page_id: {len(page_id_list)}")
    r.set(page_id_list_key, json.dumps(page_id_list))

# From here on `r` points at Redis db 1 (page details are written there below).
r = redis.Redis(password='lyuredis579', host='localhost', port=6379, db=1)

# Local HTTPS proxy for the page-detail requests.
proxy = {"https": "http://127.0.0.1:20171"}

def get_detailed_page_info(page_id, language='en', timeout=30):
    """
    Fetch title, full plain-text extract, URL, categories, images and Wikidata id
    for one Wikipedia page.

    NOTE(review): no cllimit/imlimit is sent, so categories and images are capped
    at the API's default per-request limit; 'revisions' is requested but unused.

    :param page_id: Wikipedia page ID
    :param language: Wikipedia language subdomain
    :param timeout: seconds to wait for the HTTP response; without one, a stalled
        connection blocks a thread-pool worker indefinitely
    :return: dict of page details, or None on any error (not cached by callers)
    """
    url = f"https://{language}.wikipedia.org/w/api.php"
    params = {
        "action": "query",
        "pageids": page_id,
        "prop": "extracts|categories|info|images|pageprops|revisions",
        "explaintext": True,  # plain-text extract
        "inprop": "url",      # include the page URL
        "format": "json"
    }
    try:
        response = requests.get(url, params=params, proxies=proxy, timeout=timeout)
        time.sleep(random.uniform(0.05, 0.2))
        data = response.json()

        page_data = data['query']['pages'][str(page_id)]

        return {
            "title": page_data.get("title"),
            "summary": page_data.get("extract"),  # full plain-text extract
            "url": page_data.get("fullurl"),
            "categories": [cat['title'] for cat in page_data.get("categories", [])],
            "images": [img['title'] for img in page_data.get("images", [])],  # image titles
            "wikidata_id": page_data.get("pageprops", {}).get("wikibase_item")
        }
    except Exception as e:
        # Best-effort per page: log and let a rerun retry it.
        print(f"Error querying page ID {page_id}: {e}")
        return None

prefix = 'wikipage:'


def get_page_info(page_id):
    """
    Fetch details for ``page_id`` and cache them in Redis under ``prefix + page_id``,
    unless already cached. Failed fetches are not cached, so a rerun retries them.
    """
    key = prefix + str(page_id)
    if r.exists(key):
        return
    page_info = get_detailed_page_info(page_id)
    if page_info is not None:
        r.set(key, json.dumps(page_info, ensure_ascii=False))


# One concurrent pass over all pages. A separate sequential pass beforehand would
# fetch everything one request at a time and leave the pool nothing to do.
# get_page_info skips cached ids, so an interrupted run can simply be restarted.
max_workers = 30  # number of concurrent threads
with ThreadPoolExecutor(max_workers=max_workers) as executor:
    futures = [executor.submit(get_page_info, page_id) for page_id in page_id_list]
    for future in tqdm(as_completed(futures), total=len(futures)):
        future.result()  # re-raise any worker exception


In [None]:
# 查询尝试
import json
import redis
from tqdm import tqdm
import base64

# Redis db 0 holds the cached search results.
r = redis.Redis(db=0, host='localhost', port=6379, password='lyuredis579')

def text_encode(text):
    """Base64-encode the UTF-8 bytes of ``text`` and return the result as a str."""
    encoded = base64.standard_b64encode(text.encode('utf-8'))
    return encoded.decode('ascii')

page_dict = {}
prefix = 'wikisearch:'

input_file = '../data/wiki1m_for_simcse.txt'

# Load the corpus, keeping only the first 256 sentences for this probe.
with open(input_file, 'r', encoding='utf-8') as file:
    sent_list = file.read().splitlines()[:256]

# Map each probed sentence to its cached search hits, if it has a cache entry.
for sent in tqdm(sent_list):
    cached = r.get(prefix + text_encode(sent))
    if cached is not None:
        page_dict[sent] = json.loads(cached)

In [2]:
# page信息测试
import json
import redis

page_id = 62653684
prefix = 'wikipage:'

# NOTE(review): the page-detail crawl cell writes 'wikipage:' keys to db 1, but this
# reads db 0; confirm which database holds them.
r = redis.Redis(host='localhost', port=6379, db=0, password='lyuredis579')

# Pretty-print the cached page details, if present.
raw = r.get(prefix + str(page_id))
if raw is not None:
    page_info = json.loads(raw)
    print(json.dumps(page_info, ensure_ascii=False, indent=4))
    

{
    "title": "YMCA in South Australia",
    "summary": "South Australia (SA) has a unique position in Australia's history as, unlike the other states which were founded as colonies, South Australia began as a self governing province. Many were attracted to this and Adelaide and SA developed as an independent and free thinking state.\nThe compound of philosophical radicalism, evangelical religion and self reliant ability typical of its founders had given an equalitarian flavour to South Australian thinking from the beginning.\nIt was into this social setting that in February 1850 a meeting was called primarily for the formation of an Association (apparently meaning a Y.M.C.A.) for apprentices and others, after their day's work, to enjoy books, lectures, discussions, readings, friendly relief and recreation for a leisure hour. In September 1850 records show that this became \"The Young Men's Christian Association of South Australia\" as evidenced by a member's letter in London Y.M.C.A.

In [None]:
# 整理数据分布
import json
import redis
import numpy as np
from tqdm import tqdm
import base64

r = redis.Redis(host='localhost', port=6379, db=0, password='lyuredis579')

# Deduplicated page ids stored under 'page_id_list' by an earlier cell.
# NOTE: despite the name, these are Wikipedia page ids, not sentence ids.
sent_id_list = json.loads(r.get('page_id_list'))

# Largest and smallest page id.
max_value, min_value = np.max(sent_id_list), np.min(sent_id_list)

def text_encode(text):
    """Return ``text`` as UTF-8 bytes encoded in Base64, decoded to an ASCII string."""
    utf8_bytes = text.encode('utf-8')
    b64_bytes = base64.b64encode(utf8_bytes)
    return b64_bytes.decode('ascii')


with open('../data/wiki1m_for_simcse.txt', 'r', encoding='utf-8') as file:
    sent_list = file.read().splitlines()

# For each sentence (in corpus order): the page ids of its cached search hits,
# or [] when it has no cache entry.
data = []
# enumerate() hides the length from tqdm, so pass total explicitly for a real progress bar/ETA.
for i, sent in tqdm(enumerate(sent_list), total=len(sent_list), desc='Query'):
    # Single round trip: GET returns None when the key is absent.
    cached = r.get('wikisearch:' + text_encode(sent))
    if cached is None:
        data.append([])
    else:
        page_id_list = [item['page_id'] for item in json.loads(cached)]
        data.append(page_id_list)
#     if i == 1000:
#         break
# print(data)

In [None]:
# 统计batch内相关度
import numpy as np

bs = 256
# Only full batches are analysed; a trailing partial batch (len(data) % bs sentences)
# is ignored.
num_batches = len(data) // bs

# Build the reference set once. Rebuilding it (and a union with it) for every
# sentence made this loop scale with (#sentences x #page ids).
sent_id_set = set(sent_id_list)

# Per-batch sum of Jaccard similarity between each sentence's hit set and the
# global page-id set (sentences with no hits are skipped).
batch_scores = []
for i in range(num_batches):
    batch_data = data[i * bs: (i + 1) * bs]
    batch_score = 0.0
    for page_id_list in batch_data:
        if len(page_id_list) == 0:
            continue
        page_id_set = set(page_id_list)
        overlap = len(page_id_set & sent_id_set)
        # |A ∪ B| = |A| + |B| - |A ∩ B|, without materializing the union.
        batch_score += overlap / (len(page_id_set) + len(sent_id_set) - overlap)
    batch_scores.append(batch_score)

total_page_id_list = 0
total_page_unique_list = 0
for i in range(num_batches):
    batch_data = data[i * bs: (i + 1) * bs]
    batch_page_id_list = [item for sublist in batch_data for item in sublist]
    batch_page_unique_list = list(set(batch_page_id_list))
    total_page_id_list += len(batch_page_id_list)
    total_page_unique_list += len(batch_page_unique_list)
    if len(batch_page_unique_list) == 0:
        continue
    # Unique-page ratio within the batch (lower means more repeated pages).
    print(f"Batch {i}: {len(batch_page_unique_list) / len(batch_page_id_list)}")

# Guard: with no full batch, or no hits at all, the overall ratio is undefined.
if total_page_id_list:
    print(f"Total: {total_page_unique_list / total_page_id_list}")
else:
    print("Total: no page ids")