In [3]:
from transformers import BertModel, BertTokenizer
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

In [5]:
# 加载预训练的BERT模型和分词器
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = BertModel.from_pretrained(model_name)


Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
# 生成段落向量
def generate_paragraph_vector(paragraph, model, tokenizer):
    inputs = tokenizer(paragraph, return_tensors="pt", truncation=True, padding=True, max_length=512)
    outputs = model(**inputs)
    paragraph_vector = outputs.last_hidden_state.mean(dim=1).detach().numpy()
    return paragraph_vector

In [7]:
# 计算语义相似度
def calculate_similarity(query_vector, paragraph_vector):
    return cosine_similarity(query_vector, paragraph_vector)[0][0]

In [8]:
# 召回筛选
def retrieve_paragraphs(query, documents, model, tokenizer, threshold=0.8):
    query_vector = generate_paragraph_vector(query, model, tokenizer)
    
    relevant_paragraphs = []
    for doc_id, document in enumerate(documents):
        for para_id, paragraph in enumerate(document["paragraphs"]):
            para_vector = generate_paragraph_vector(paragraph, model, tokenizer)
            similarity = calculate_similarity(query_vector, para_vector)
            if similarity >= threshold:
                relevant_paragraphs.append({"doc_id": doc_id, "para_id": para_id, "paragraph": paragraph, "similarity": similarity})
                
    relevant_paragraphs.sort(key=lambda x: x["similarity"], reverse=True)
    return relevant_paragraphs


In [9]:
# 用户输入
query = "请找出和石油、故障、运维相关的段落。"

In [10]:
# 示例文档
documents = [
    {
        "title": "Document 1",
        "paragraphs": [
            "石油是一种重要的能源，广泛用于各个行业。",
            "石油开采过程中可能出现的故障会导致生产中断，需要及时进行运维。",
            "运维团队需要定期检查设备，确保生产顺利进行。"
        ]
    },
    {
        "title": "Document 2",
        "paragraphs": [
            "故障诊断是运维过程中的重要环节，可以及时发现并解决问题。",
            "石油行业的运维管理需要高度重视安全和环保。",
            "新能源的发展对石油产业产生了一定的影响。"
        ]
    }
]

In [11]:
# 执行召回筛选
relevant_paragraphs = retrieve_paragraphs(query, documents, bert_model, tokenizer, threshold=0.8)

In [12]:
# 输出结果
print("Relevant paragraphs:")
for para in relevant_paragraphs:
    print(f"Document {para['doc_id']}, Paragraph {para['para_id']}: {para['paragraph']} (Similarity: {para['similarity']:.4f})")

Relevant paragraphs:
Document 0, Paragraph 1: 石油开采过程中可能出现的故障会导致生产中断，需要及时进行运维。 (Similarity: 0.9557)
Document 1, Paragraph 2: 新能源的发展对石油产业产生了一定的影响。 (Similarity: 0.9471)
Document 0, Paragraph 0: 石油是一种重要的能源，广泛用于各个行业。 (Similarity: 0.9443)
Document 1, Paragraph 1: 石油行业的运维管理需要高度重视安全和环保。 (Similarity: 0.9407)
Document 0, Paragraph 2: 运维团队需要定期检查设备，确保生产顺利进行。 (Similarity: 0.9380)
Document 1, Paragraph 0: 故障诊断是运维过程中的重要环节，可以及时发现并解决问题。 (Similarity: 0.9136)
