In [None]:
from langchain_community.document_loaders import TextLoader
from langchain.embeddings.base import Embeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
# 可用，但下面的程序中没有使用
from langchain_openai import OpenAIEmbeddings
from transformers import AutoTokenizer, AutoModel
import torch
import faiss
from tqdm import tqdm
from typing import List

# Pretrained BioBERT checkpoint used by the embedding functions in this file.
model_name = "monologg/biobert_v1.1_pubmed"
# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# nn.Module.to() returns the module itself, so loading and moving can be chained.
model = AutoModel.from_pretrained(model_name).to(device)
import os
import pickle

def get_medical_embedding(text):
    """Embed ``text`` with the module-level ``model`` and return the first-token vector.

    The input is tokenized (truncated to at most 500 tokens), moved to
    ``device``, and run through ``model`` without gradient tracking. The hidden
    state at position 0 of the last layer (the [CLS] slot for BERT-style
    tokenizers) is returned as a NumPy array.

    Args:
        text: A string, or anything the module-level ``tokenizer`` accepts.

    Returns:
        ``numpy.ndarray``. All size-1 dimensions are squeezed away, so a single
        string yields a 1-D vector.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=500,
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state
    first_token = hidden[:, 0, :]
    return first_token.squeeze().cpu().numpy()

def batch_get_medical_embedding(texts, batch_size=8):
    """Embed ``texts`` in mini-batches and return one first-token vector per text.

    Each batch of ``batch_size`` texts is padded and truncated to at most 500
    tokens, run through the module-level ``model`` without gradients, and the
    position-0 hidden state of every row is collected. Progress is shown with
    a tqdm bar.

    Args:
        texts: A sliceable sequence of strings.
        batch_size: Number of texts per forward pass.

    Returns:
        A list of 1-D ``numpy.ndarray`` rows, in the same order as ``texts``.
    """
    vectors = []
    batch_starts = range(0, len(texts), batch_size)
    for start in tqdm(batch_starts, desc="生成嵌入向量"):
        encoded = tokenizer(
            texts[start:start + batch_size],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=500,
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            hidden = model(**encoded).last_hidden_state
        # Iterating a 2-D array in extend() appends its rows one by one.
        vectors.extend(hidden[:, 0, :].cpu().numpy())
    return vectors

def save_vectorstore(vectorstore, file_path="medical_vectorstore.pkl"):
    """Pickle ``vectorstore`` to ``file_path`` so later runs can reuse it.

    The object is first pickled into ``<file_path>.tmp``. Only after that
    succeeds is the temp file moved over ``file_path`` with ``os.replace``.
    A failure while pickling (e.g. an unpicklable attribute) therefore never
    truncates or corrupts an existing cache file.

    Args:
        vectorstore: Any picklable object (here, the FAISS vector store).
        file_path: Destination path of the pickle file.

    Raises:
        Whatever ``pickle.dump`` or the file operations raise. The temporary
        file is removed before the exception propagates.
    """
    tmp_path = f"{file_path}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(vectorstore, f)
        # Swap in the finished file; on POSIX this rename is atomic.
        os.replace(tmp_path, file_path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the old cache is untouched.
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass
        raise
    print(f"向量存储已保存到 {file_path}")

def load_vectorstore(file_path="medical_vectorstore.pkl"):
    """Load a vector store previously written by ``save_vectorstore``.

    Args:
        file_path: Path of the pickle cache file.

    Returns:
        The unpickled object, or ``None`` when the file does not exist or
        cannot be unpickled (empty/truncated file, or a class/module it
        references is no longer importable). ``None`` tells the caller to
        rebuild the store.
    """
    # SECURITY: pickle.load can execute arbitrary code. Only load cache files
    # this script wrote itself, never files obtained from untrusted sources.
    try:
        with open(file_path, "rb") as f:
            vectorstore = pickle.load(f)
    except FileNotFoundError:
        print(f"未找到向量存储文件 {file_path}，需要重新创建")
        return None
    except (EOFError, pickle.UnpicklingError, AttributeError, ImportError) as err:
        # An unusable cache is treated like a missing one so it gets rebuilt.
        print(f"向量存储文件 {file_path} 无法加载（{err!r}），需要重新创建")
        return None
    print(f"已从 {file_path} 加载向量存储")
    return vectorstore

# Load the raw QA data and split it into overlapping text chunks.
# The data file is assumed to be UTF-8 (the encoding the JSON spec requires).
# Pin the encoding so loading does not depend on the platform's locale default.
loader = TextLoader(
    file_path="QA_Health_and_Personal_Care.json",
    encoding="utf-8",
)

documents = loader.load()

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,      # maximum characters per chunk
    chunk_overlap=100,    # characters shared between consecutive chunks
    # Prefer paragraph/line/sentence boundaries (English and Chinese
    # punctuation). " " and "" are last-resort fallbacks: without them, a span
    # containing none of the other separators could yield a chunk longer than
    # chunk_size.
    separators=["\n\n", "\n", ".", "?", "!", "。", "？", "！", " ", ""],
)
chunks = text_splitter.split_documents(documents)

print(len(chunks))
if not chunks:
    raise ValueError("QA_Health_and_Personal_Care.json produced no text chunks")
print(chunks[0])

# Only the first MAX_CHUNKS chunks are embedded, to bound indexing time.
MAX_CHUNKS = 500
chunks = chunks[:MAX_CHUNKS]

class MedicalEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter over this file's BioBERT embedding helpers.

    The class is defined at module level, before the cache is loaded, rather
    than inside the cache-miss branch. A pickled FAISS store keeps a reference
    to its embeddings object, so ``pickle.load`` must be able to resolve this
    class by name when the cache is reloaded in a fresh session.
    """

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # Use the batched helper: one forward pass per mini-batch instead of
        # one per text.
        return batch_get_medical_embedding(texts)

    def embed_query(self, text: str) -> List[float]:
        return get_medical_embedding(text)

    # __call__ keeps compatibility with older LangChain versions that use the
    # embedding object as a plain callable.
    def __call__(self, text: str) -> List[float]:
        return self.embed_query(text)


vectorstore = load_vectorstore()

if vectorstore is None:
    medical_embeddings = MedicalEmbeddings()
    print("创建中...")
    vectorstore = FAISS.from_documents(chunks, medical_embeddings)

    # Persist the vector store so later runs can skip re-embedding.
    save_vectorstore(vectorstore)

def retrieve_context(query, k=5):
    """Return the page text of the ``k`` stored chunks most similar to ``query``.

    Searches the module-level ``vectorstore`` and keeps only each hit's
    ``page_content``, in the order the search returned them.
    """
    return [hit.page_content for hit in vectorstore.similarity_search(query, k=k)]
