### 提取数据

In [77]:
from neo4j import GraphDatabase

class Neo4jEntityFetcher:
    """Read-only helper around a Neo4j driver that returns nodes as plain dicts.

    Every query method returns a list of dicts shaped as
    ``{"id": <element_id>, "labels": [<label>, ...], "properties": {...}}``.
    Results are fully materialised inside the session, so they stay usable
    after ``close()``.
    """

    def __init__(self, uri, user, password):
        """Create the underlying driver.

        Args:
            uri: Neo4j connection URI, e.g. ``"bolt://localhost:7687"``.
            user: Database user name.
            password: Database password.
        """
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    @staticmethod
    def _quote_identifier(name):
        """Backtick-quote a Cypher identifier (label or property key).

        Labels and property keys cannot be sent as query parameters, so they are
        escaped instead: embedded backticks are doubled so the name cannot break
        out of the quoted identifier.
        """
        return "`" + str(name).replace("`", "``") + "`"

    @staticmethod
    def _record_to_entity(node):
        """Convert a driver node into the dict format returned by all fetch methods."""
        return {"id": node.element_id, "labels": list(node.labels), "properties": dict(node)}

    def _fetch_nodes(self, query, **params):
        """Run ``query`` (which must ``RETURN n``) with ``params`` and convert every row."""
        with self.driver.session() as session:
            result = session.run(query, params)
            # Consume the result inside the session; it is not readable after the session closes.
            return [self._record_to_entity(record["n"]) for record in result]

    def get_all_entities(self):
        """Return every node in the database."""
        return self._fetch_nodes("MATCH (n) RETURN n")

    def get_entities_by_label(self, label):
        """Return every node carrying ``label``.

        Args:
            label: Node label; it is backtick-escaped rather than interpolated raw.
        """
        return self._fetch_nodes(f"MATCH (n:{self._quote_identifier(label)}) RETURN n")

    def get_entities_by_property(self, property_name, property_value):
        """Return nodes whose ``property_name`` equals ``property_value``.

        The value is sent as a query parameter, so quotes or Cypher syntax in it
        can no longer break or alter the query. It is converted with ``str()`` to
        keep the previous behaviour, which always compared against a quoted string
        literal.

        Args:
            property_name: Property key (backtick-escaped).
            property_value: Value to match, compared as a string.
        """
        key = self._quote_identifier(property_name)
        query = f"MATCH (n {{{key}: $value}}) RETURN n"
        return self._fetch_nodes(query, value=str(property_value))

    def close(self):
        """Close the driver and release its connections."""
        self.driver.close()


import os

# Connection settings are read from the environment so real credentials are not
# committed with the notebook; the defaults reproduce the previous hard-coded values.
uri = os.environ.get("NEO4J_URI", "bolt://localhost:7687")  # Neo4j database address
user = os.environ.get("NEO4J_USER", "neo4j")  # Neo4j user name
password = os.environ.get("NEO4J_PASSWORD", "password")  # Neo4j password
fetcher = Neo4jEntityFetcher(uri, user, password)

try:
    all_entities = fetcher.get_all_entities()

    knowledge_entities = fetcher.get_entities_by_label("knowledge")

    # "糖尿病" means "diabetes".
    diabetes_entities = fetcher.get_entities_by_property("name", "糖尿病")
finally:
    # Release the driver even if one of the queries raises.
    fetcher.close()

In [78]:
# Peek at the first "knowledge" entity to check the returned dict layout (id / labels / properties).
knowledge_entities[0]

{'id': '4:0cfb6f88-002e-4f58-87e3-6809e3bfecb2:0',
 'labels': ['knowledge'],
 'properties': {'name': '支气管扩张症的病因包括囊性纤维化、巨大气管支气管症、肺叶内肺隔离症、免疫缺陷性疾病、感染后、机械性气道阻塞、原发性或继发性纤毛运动障碍以及变态反应性支气管肺曲霉病等。'}}

### 嵌入模型

In [24]:
import warnings
# NOTE(review): with no category/module filter this silences every warning for the
# rest of the kernel session, including library warnings (e.g. from transformers or
# langchain); consider narrowing it to the specific warnings that are noisy.
warnings.filterwarnings("ignore")

import os
import pandas as pd
from tqdm import tqdm
from gensim.models import KeyedVectors
from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn as nn

In [26]:
def _dir_size(path):
    """Return the total size in bytes of ``path`` (a file, or every file under a directory).

    ``os.path.getsize`` on a directory only reports the size of the directory
    entry itself (typically a few KB), not of the files it contains.
    """
    if os.path.isfile(path):
        return os.path.getsize(path)
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            total += os.path.getsize(os.path.join(root, name))
    return total


def _build_modified_model(base_model_path, vector_path, device):
    """Load the base model and overwrite token embeddings with the external word vectors.

    Tokens found in the word-vector file get that vector, zero-padded to the
    model's hidden size. Every other token (including special tokens) keeps its
    pretrained embedding instead of a random initialisation.
    """
    word_vectors = KeyedVectors.load_word2vec_format(vector_path, binary=False)
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    model = AutoModel.from_pretrained(base_model_path).to(device)

    word_embeddings = model.embeddings.word_embeddings
    weight = word_embeddings.weight
    num_rows, embedding_dim = weight.shape
    if word_vectors.vector_size > embedding_dim:
        raise ValueError(
            f"word vectors have dimension {word_vectors.vector_size}, "
            f"larger than the model hidden size {embedding_dim}"
        )

    with torch.no_grad():
        for word, index in tokenizer.get_vocab().items():
            # Skip words without an external vector, and ids outside the embedding matrix.
            if index >= num_rows or word not in word_vectors.key_to_index:
                continue
            vector = torch.as_tensor(word_vectors[word], dtype=weight.dtype, device=weight.device)
            row = torch.zeros(embedding_dim, dtype=weight.dtype, device=weight.device)
            row[: vector.numel()] = vector  # zero-pad up to the hidden size
            weight[index] = row

    return model, tokenizer


def LoadModel(model_path='../model/modified_bge-large-zh-v1.5',
              base_model_path='../model/bge-large-zh-v1.5',
              vector_path='../data/vector.txt'):
    """Load the word-vector-augmented embedding model, building and saving it on first use.

    Args:
        model_path: Directory of the modified model. It is loaded if it exists
            and holds more than 1e9 bytes; otherwise the model is rebuilt and saved here.
        base_model_path: Pretrained model used as the starting point when rebuilding.
        vector_path: word2vec text-format file with the word vectors to inject.

    Returns:
        ``(model, tokenizer)`` with the model moved to CUDA when available, else CPU.

    Raises:
        ValueError: if the word vectors are wider than the model's hidden size.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): the 1e9-byte threshold presumably detects a completely saved
    # full-size model; a smaller or missing directory is treated as absent.
    if os.path.exists(model_path) and _dir_size(model_path) > 1e9:
        model = AutoModel.from_pretrained(model_path).to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
    else:
        model, tokenizer = _build_modified_model(base_model_path, vector_path, device)
        # Save where the next call will look for it.
        model.save_pretrained(model_path)
        tokenizer.save_pretrained(model_path)

    return model, tokenizer

In [None]:
def encode_text(model, tokenizer, text, max_length=512):
    """Encode one string or a list of strings into sentence embeddings.

    Token states from ``last_hidden_state`` are averaged over real tokens only.
    Padding positions (``attention_mask == 0``) are excluded, so a text's
    embedding no longer depends on how long the other texts in the same batch are.

    Args:
        model: Encoder model exposing ``.device``, ``.eval()`` and ``last_hidden_state``.
        tokenizer: Matching tokenizer; inputs are padded and truncated to ``max_length``.
        text: A string or a list of strings.
        max_length: Maximum number of tokens per text.

    Returns:
        CPU tensor with one pooled row per input text.
    """
    device = model.device
    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=max_length).to(device)

    model.eval()

    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state
        mask = inputs.get("attention_mask")
        if mask is None:
            # Without a mask every position counts, as a plain mean.
            pooled = hidden.mean(dim=1)
        else:
            weights = mask.unsqueeze(-1).to(hidden.dtype)
            # clamp avoids division by zero for an (unexpected) all-padding row.
            pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
    return pooled.cpu()

# Usage example: load the model and encode two short strings into embeddings.
text = ["输入文本",'22213']
model, tokenizer = LoadModel()
embeddings = encode_text(model, tokenizer, text)
embeddings

### 生成索引

In [90]:
from langchain.vectorstores import FAISS
from langchain.docstore.document import Document
from GetData import Neo4jEntityFetcher
from Embedding import *
from tqdm import tqdm
import redis
from langchain.vectorstores import FAISS
from langchain.vectorstores.redis import Redis

In [32]:
import os

# Same connection settings as the extraction step; override via environment variables.
uri = os.environ.get("NEO4J_URI", "bolt://localhost:7687")  # Neo4j database address
user = os.environ.get("NEO4J_USER", "neo4j")  # Neo4j user name
password = os.environ.get("NEO4J_PASSWORD", "password")  # Neo4j password
fetcher = Neo4jEntityFetcher(uri, user, password)
try:
    knowledge_entities = fetcher.get_entities_by_label("knowledge")
finally:
    # Release the driver's connections; the fetched list is already materialised.
    fetcher.close()

# Parallel lists: texts[i] is the 'name' property of the node whose element id is ids[i].
texts = [entity['properties']['name'] for entity in knowledge_entities]
ids = [entity['id'] for entity in knowledge_entities]

In [40]:
# Load the embedding model and tokenizer used by the batch encoding below.
model, tokenizer = LoadModel()
def batch_encode_texts(model, tokenizer, texts, batch_size=32):
    """Encode ``texts`` in consecutive chunks of ``batch_size`` using ``encode_text``.

    Returns a flat Python list whose items are the rows of each chunk's result,
    kept in input order. A tqdm bar tracks progress over the chunks.
    """
    chunk_starts = range(0, len(texts), batch_size)
    collected = []
    for start in tqdm(chunk_starts):
        collected.extend(encode_text(model, tokenizer, texts[start:start + batch_size]))
    return collected

# Encode all knowledge texts in chunks of 64; the result is a list of per-text rows aligned with `ids`.
embeddings = batch_encode_texts(model, tokenizer, texts, batch_size=64)

100%|█████████████████████████████████████████| 156/156 [45:30<00:00, 17.51s/it]


In [105]:
# Client for the Redis instance that stores the embeddings (non-default port 6380, database 0).
redis_client = redis.StrictRedis(
    host='localhost',
    port=6380,
    db=0,
)

In [106]:
# Store each knowledge entity in Redis as a hash keyed by its Neo4j element id.
# The embedding is written as raw float32 bytes; read it back with
# numpy.frombuffer(value, dtype=numpy.float32). The previous str(embedding) form was
# lossy: with torch's default print threshold (1000 elements), a 1024-element tensor
# is printed as an abbreviated "tensor([a, b, c, ..., x, y, z])" string.
# Writes are batched through a pipeline instead of one round trip per entity.
FLUSH_EVERY = 1000
pipe = redis_client.pipeline(transaction=False)
for count, (entity_id, embedding, text) in enumerate(zip(ids, embeddings, texts), start=1):
    pipe.hset(entity_id, mapping={
        'embedding': embedding.detach().cpu().float().numpy().tobytes(),
        'text': text,
    })
    if count % FLUSH_EVERY == 0:
        pipe.execute()  # bound the amount of buffered commands
pipe.execute()