In [14]:
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

def load_model(model_path):
    """
    加载模型和分词器。

    参数:
    model_path (str): 模型的路径。

    返回:
    tokenizer: 加载的分词器。
    model: 加载的模型。
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModel.from_pretrained(model_path)
    model.eval()
    return tokenizer, model

def get_embeddings(prompts, tokenizer, model):
    """
    计算给定提示的嵌入表示。

    参数:
    prompts (list): 自定义提示的列表。
    tokenizer: 加载的分词器。
    model: 加载的模型。

    返回:
    torch.Tensor: 嵌入表示的张量。
    """
    # Tokenize custom prompts
    encoded_input = tokenizer(prompts, padding=True, truncation=True, return_tensors='pt')

    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)
        # Perform pooling. In this case, cls pooling.
        sentence_embeddings = model_output[0][:, 0]

    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

    return sentence_embeddings

def calculate_cosine_similarity(embedding_1, embeddings):
    """
    计算一个嵌入与其他多个嵌入的余弦相似度。

    参数:
    embedding_1 (torch.Tensor): 第一个嵌入表示。
    embeddings (torch.Tensor): 其他嵌入表示的张量。

    返回:
    torch.Tensor: 相似度列表。
    """
    cosine_similarities = F.cosine_similarity(embedding_1.unsqueeze(0), embeddings)
    return cosine_similarities

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
import pandas as pd
def get_theme_prompts(lda_key_words_path):
    file_path = lda_key_words_path
    df = pd.read_excel(file_path)
    
    prompts = []
    
    for index, row in df.iterrows():
        theme_category = row['主题类别']
        keywords = row.dropna()[1:].tolist()
        prompt = f"新闻的主题是{theme_category}，关键词是{'、'.join(keywords)}"
        prompts.append(prompt)
    

    # 返回6个prompt
    return prompts


In [18]:
if __name__ == "__main__":
    lda_key_words_path = '~/LLM_news_emo_analyze/DATA/lda_key_words.xlsx'
    news_prompt = "证券时报网讯，中信建投研报指出，2024年上半年，根据样本数据，医疗器械板块收入同比增长1%，扣非归母净利润同比增长3%。"
    key_topic_lisy = get_theme_prompts(lda_key_words_path)
    model_path = '/root/.cache/LLMS/hub/BAAI/bge-large-zh-v1___5'  # BGE-large-zh-v1.5模型
    tokenizer, model = load_model(model_path)

    # 获取所有提示的嵌入
    all_prompts = [news_prompt] + key_topic_lisy
    embeddings = get_embeddings(all_prompts, tokenizer, model)

    # 计算prompt_1与其他提示的相似度
    similarity_scores = calculate_cosine_similarity(embeddings[0], embeddings[1:])

    # 找出相似度最高的提示及其相似度值
    max_similarity_index = torch.argmax(similarity_scores).item()
    max_similarity_score = similarity_scores[max_similarity_index].item()
    most_similar_prompt = key_topic_lisy[max_similarity_index]

    print("与news_prompt最相似的提示是:", most_similar_prompt)
    print("相似度值:", max_similarity_score)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


与news_prompt最相似的提示是: 新闻的主题是生产投资，关键词是制造、设备、研发、制造业、工业、生产线、机械、材料、设计、国外、钢铁、产能
相似度值: 0.39552390575408936


In [None]:
lda_key_words_path = '~/LLM_news_emo_analyze/DATA/lda_key_words.xlsx'
def find_most_similar_prompt(input_prompt,lda_key_words_path):
    theme_prompts = get_theme_prompts(lda_key_words_path)
    max_similarity = -1
    most_similar_prompt = None
    for theme_prompt in theme_prompts:
        embeddings = get_embeddings([input_prompt, theme_prompt], tokenizer, model)
        similarity = calculate_cosine_similarity(embeddings)
        
        if similarity > max_similarity:
            max_similarity = similarity
            most_similar_prompt = theme_prompt

return most_similar_prompt, max_similarity

In [None]:
    for theme_prompt in theme_prompts:
        embeddings = get_embeddings([input_prompt, theme_prompt], tokenizer, model)
        similarity = calculate_cosine_similarity(embeddings)
        
        if similarity > max_similarity:
            max_similarity = similarity
            most_similar_prompt = theme_prompt
    
    return most_similar_prompt, max_similarity

In [9]:
lda_key_words_path = '~/LLM_news_emo_analyze/DATA/lda_key_words.xlsx'


In [None]:
if __name__ == "__main__":
    prompt_1 = "这是第一个自定义的提示"
    prompt_2 = "这是第二个自定义的提示"
    model_path = '/root/.cache/LLMS/hub/BAAI/bge-large-zh-v1___5' # BG-large-zh-v1.5模型
    
    tokenizer, model = load_model(model_path)
    embeddings = get_embeddings([prompt_1, prompt_2], tokenizer, model)
    
    similarity = calculate_cosine_similarity(embeddings)
    print("余弦相似度:", similarity)