In [None]:
import json
import time
from pathlib import Path
from typing import List, Dict

import chromadb
from llama_index.core import VectorStoreIndex, StorageContext, Settings
from llama_index.core.schema import TextNode
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core import PromptTemplate
from transformers import AutoModelForCausalLM, AutoTokenizer
from llama_index.core.llms import ChatMessage
from llama_index.llms.huggingface import HuggingFaceLLM
import torch
import pandas as pd 
import json

In [None]:
import torch

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:

from llama_index.llms.openai_like import OpenAILike
from llama_index.core import Settings

# Connection settings for the vLLM server
class VLLMConfig:
    """Constants used to reach the OpenAI-compatible vLLM endpoint."""

    # Which model to request
    MODEL_NAME = "DeepSeek-R1-Distill-Qwen-1___5B"

    # Where and how to connect
    API_BASE = "http://localhost:8000/v1"  # default vLLM endpoint
    API_KEY = "no-key-required"  # vLLM does not require a key by default
    TIMEOUT = 60  # request timeout

# Build the LLM client (stands in for the original HuggingFaceLLM)
def init_vllm_llm():
    """Return an ``OpenAILike`` client configured from ``VLLMConfig``.

    The client is created as a chat model with temperature 0.3, a 1024-token
    completion cap, and ``<|im_end|>`` as the stop sequence.
    """
    connection = {
        "model": VLLMConfig.MODEL_NAME,
        "api_base": VLLMConfig.API_BASE,
        "api_key": VLLMConfig.API_KEY,
        "timeout": VLLMConfig.TIMEOUT,
    }
    sampling = {"temperature": 0.3, "max_tokens": 1024}
    # Stop sequence forwarded with every request.
    stop_config = {"stop": ["<|im_end|>"]}

    return OpenAILike(
        **connection,
        **sampling,
        is_chat_model=True,  # talk to the chat endpoint
        additional_kwargs=stop_config,
    )

# Register the vLLM-backed client as the global llama_index LLM.
# NOTE(review): init_models() (called in a later cell) assigns Settings.llm
# again, so this client is replaced once that cell runs — confirm which
# backend is meant to be active.
Settings.llm = init_vllm_llm()

In [None]:
# Read the company basic-info records; create_nodes() later reads the
# "id" and "text" keys of each record.
with open(
    '/Users/vitol/vscode/firm_perf/firm_perf_code/转JSON/json文件汇总/stk_basic_info.json',
    encoding='utf-8',
) as file:
    json_list = json.load(file)

In [None]:
def create_nodes(raw_data: List[Dict]) -> List[TextNode]:
    """Turn raw company records into ``TextNode`` objects with explicit, stable IDs.

    For every record the node text is the JSON serialisation of
    ``entry["text"]`` (non-ASCII characters are kept verbatim) and the node ID
    is ``entry["id"]``. When the full metadata cannot be read from
    ``entry["text"]`` (missing key, or the value is not a mapping), the node
    is still created with a reduced metadata dict whose only field is the
    first nine characters of the ID.

    Args:
        raw_data: Records to convert; each must provide ``"id"`` and ``"text"``.

    Returns:
        One node per record, in input order. Empty input returns an empty list.
    """
    nodes = []
    for entry in raw_data:
        content_str = json.dumps(entry['text'], ensure_ascii=False)
        node_id = entry["id"]

        try:
            metadata = {
                "公司全称": entry['text']['公司全称'],
                "公司简称": entry['text']['公司简称'],
                "股票代码": entry['text']['股票代码'],
                "行业分类": entry['text']['申万行业分类'],
            }
        except (KeyError, TypeError):
            # Incomplete record: report its code and fall back to the ID prefix
            # so the company remains retrievable.
            print(node_id[:9])
            metadata = {"股票代码": node_id[:9]}

        # Explicit id_ keeps node IDs stable across runs.
        nodes.append(TextNode(text=content_str, id_=node_id, metadata=metadata))

    # Only show a sample ID when there is one; indexing nodes[0] on an empty
    # list would raise IndexError.
    if nodes:
        print(f"生成 {len(nodes)} 个文本节点（ID示例：{nodes[0].id_}）")
    else:
        print("生成 0 个文本节点")
    return nodes

In [None]:
# Convert every record loaded from the JSON file into a TextNode.
nodes = create_nodes(json_list)

In [None]:
# Display the first node to eyeball its text and metadata.
nodes[0]

In [None]:
!pip install -U vllm

In [None]:
class Config:
    """Filesystem locations and retrieval settings shared across the notebook."""

    # Local model checkpoints
    EMBED_MODEL_PATH = "/Users/vitol/vscode/firm_perf/firm_perf_code/embed"
    LLM_MODEL_PATH = "/Users/vitol/vscode/firm_perf/firm_perf_code/model"

    # On-disk storage
    DATA_DIR = "./data"
    VECTOR_DB_DIR = "./chroma_db"  # path given to the Chroma persistent client
    PERSIST_DIR = "./storage"  # directory used for StorageContext persistence

    # Retrieval
    COLLECTION_NAME = "firm_info"
    TOP_K = 1  # passed as similarity_top_k to the query engine

In [None]:
def init_models():
    """Load the local embedding model and LLM, register them globally, and smoke-test.

    Both checkpoints are read from the paths in ``Config`` with
    ``local_files_only=True``. The models are placed on the module-level
    ``device`` chosen in an earlier cell.

    Side effects:
        Assigns ``Settings.embed_model`` and ``Settings.llm``.
        NOTE(review): this replaces the vLLM client assigned to
        ``Settings.llm`` in an earlier cell — confirm which backend is meant
        to answer queries.

    Returns:
        tuple: ``(embed_model, llm)``.
    """
    model = AutoModelForCausalLM.from_pretrained(Config.LLM_MODEL_PATH, local_files_only=True)
    tokenizer = AutoTokenizer.from_pretrained(Config.LLM_MODEL_PATH, local_files_only=True)

    # A ready-made model object is handed to HuggingFaceLLM below, so device
    # placement has to happen here. Loader options such as ``device_map`` in
    # ``model_kwargs`` presumably only apply when the wrapper loads the
    # checkpoint itself — confirm against the installed llama_index version.
    model = model.to(device)

    # Embedding model; the device is passed as a plain string ("cuda"/"cpu").
    embed_model = HuggingFaceEmbedding(
        model_name=Config.EMBED_MODEL_PATH,
        device=str(device),
    )

    # LLM
    llm = HuggingFaceLLM(
        model=model,
        tokenizer=tokenizer,
        model_kwargs={"trust_remote_code": True},
        tokenizer_kwargs={"trust_remote_code": True},
        # ``temperature`` only influences generation when sampling is enabled;
        # without ``do_sample`` decoding is greedy and the value is ignored.
        generate_kwargs={"temperature": 0.3, "do_sample": True},
    )

    Settings.embed_model = embed_model
    Settings.llm = llm

    # Smoke test: embed a short string and report the vector size.
    test_embedding = embed_model.get_text_embedding("测试文本")
    print(f"Embedding维度验证：{len(test_embedding)}")

    return embed_model, llm

In [None]:
def init_vector_store(nodes: List[TextNode]) -> VectorStoreIndex:
    """Create or load the Chroma-backed vector index and report docstore contents.

    A new index is built when the Chroma collection is empty and ``nodes`` is
    not ``None``; otherwise an index is attached to the existing collection.

    Args:
        nodes: Nodes to index on first run. May be ``None`` when the
            collection is already populated.

    Returns:
        The ``VectorStoreIndex`` bound to the Chroma collection.
    """
    chroma_client = chromadb.PersistentClient(path=Config.VECTOR_DB_DIR)
    chroma_collection = chroma_client.get_or_create_collection(
        name=Config.COLLECTION_NAME,
        metadata={"hnsw:space": "cosine"}
    )

    # One vector-store wrapper is shared by both branches below.
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    if chroma_collection.count() == 0 and nodes is not None:
        print(f"创建新索引（{len(nodes)}个节点）...")

        # Explicitly register the nodes in the docstore so the verification
        # below (and the persisted storage) can see them.
        storage_context.docstore.add_documents(nodes)

        index = VectorStoreIndex(
            nodes,
            storage_context=storage_context,
            show_progress=True
        )
        # The index was built on this very storage context, so a single
        # persist call writes everything.
        storage_context.persist(persist_dir=Config.PERSIST_DIR)
    else:
        print("加载已有索引...")
        # Reload the persisted docstore only if it exists; otherwise keep the
        # fresh storage context created above instead of failing on a missing
        # directory.
        if Path(Config.PERSIST_DIR).is_dir():
            storage_context = StorageContext.from_defaults(
                persist_dir=Config.PERSIST_DIR,
                vector_store=vector_store
            )
        index = VectorStoreIndex.from_vector_store(
            storage_context.vector_store,
            storage_context=storage_context,
            embed_model=Settings.embed_model
        )

    # Sanity check: report how many records the docstore holds.
    print("\n存储验证结果：")
    doc_count = len(storage_context.docstore.docs)
    print(f"DocStore记录数：{doc_count}")

    if doc_count > 0:
        sample_key = next(iter(storage_context.docstore.docs.keys()))
        print(f"示例节点ID：{sample_key}")
    else:
        print("警告：文档存储为空，请检查节点添加逻辑！")

    return index

In [None]:
# Load the local embedding model and LLM; this also assigns
# Settings.embed_model and Settings.llm.
embed_model,llm = init_models()

In [None]:
def main(nodes, use_custom_prompt: bool = False):
    """Build (or load) the index and run an interactive question/answer loop.

    Args:
        nodes: Nodes forwarded to ``init_vector_store``; used only when a new
            index has to be created.
        use_custom_prompt: When ``True``, the analyst prompt defined below is
            passed to the query engine as ``text_qa_template``. The default
            ``False`` keeps the engine's built-in prompt, which is what this
            function used before the parameter existed.

    The loop ends when the user enters ``q``, or when input is closed or
    interrupted. Blank input is ignored.
    """
    QA_TEMPLATE = (
        "<|im_start|>system\n"
        "你是一个专业的上市公司研究员，你的知识库中有一系列的块级数据，每一个块是一个上市公司的基本信息、包括股价、年化指标、员工评论数据。在检索公司信息时，首先，你需要根据客户提供的公司名和每个块的'id_'进行匹配，找到唯一的、最相似的块，然后就仅仅基于这个最相似的块进行进一步检索，不需要再检索其他块的信息。在检索到块后，你需要严格根据块内提供的公司信息，用中文回答问题。必要时根据客户的要求针对公司进行基于信息的适当分析。如果实在索引不到知识库中所提供的公司信息数据，那么请回复：我不知道。如果客户的问题比较泛化，比如问：“你对某个公司的发展有什么看法？”，或者问：“评价一下某公司？”，那么请根据公司信息进行分析总结后给出你的看法。最重要的一点是！你需要尽可能快的给我回复！\n"
        "相关的上市公司信息数据：\n{context_str}\n<|im_end|>\n"
        "<|im_start|>user\n{query_str}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    response_template = PromptTemplate(QA_TEMPLATE)

    print("\n初始化向量存储...")
    start_time = time.time()

    index = init_vector_store(nodes)
    print(f"索引加载耗时：{time.time()-start_time:.2f}s")

    # Build the query engine; the custom prompt is opt-in.
    engine_kwargs = {"similarity_top_k": Config.TOP_K, "verbose": True}
    if use_custom_prompt:
        engine_kwargs["text_qa_template"] = response_template
    query_engine = index.as_query_engine(**engine_kwargs)

    # Interactive loop
    while True:
        try:
            question = input("\n请输入公司相关问题（输入q退出）: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or an interrupted kernel ends the session quietly.
            break

        if not question:
            # Do not spend a retrieval + LLM call on blank input.
            continue
        if question.lower() == 'q':
            break

        response = query_engine.query(question)
        print(f"\n智能助手回答：\n{response.response}")



In [None]:
# Start the interactive Q&A session; this blocks on input() until the user enters q.
main(nodes)