In [1]:

from typing import List
import asyncio
import os
import numpy as np
from dotenv import load_dotenv
import aiohttp
from gqlalchemy import Memgraph
import logging
import time
import sys
from typing import List, Dict, Any, Optional


In [2]:
# Logging: echo to stdout (the notebook output) and persist to a log file.
# The log messages in this notebook are Chinese, so the file handler is pinned to
# UTF-8; without an explicit encoding it uses the platform locale encoding and can
# fail to encode them (e.g. on Windows), printing "--- Logging error ---" instead.
# basicConfig is a no-op when the root logger already has handlers, so re-running
# this cell does not stack duplicate handlers.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler("vector_embedding.log", encoding="utf-8")
    ]
)
logger = logging.getLogger(__name__)

# Load environment variables (python-dotenv; presumably from a local .env file).
load_dotenv()

# The embedding API key is mandatory. Raise rather than sys.exit(): inside a Jupyter
# kernel SystemExit is intercepted and only a terse notice is shown, whereas an
# exception stops "Run All" with a clear traceback. A plain script still exits non-zero.
SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY")
if not SILICONFLOW_API_KEY:
    logger.error("未设置SILICONFLOW_API_KEY环境变量")
    raise RuntimeError("未设置SILICONFLOW_API_KEY环境变量")

# Memgraph connection settings, overridable via environment variables.
memgraph_host = os.getenv("MEMGRAPH_HOST", "localhost")
memgraph_port = int(os.getenv("MEMGRAPH_PORT", "7687"))

# Embedding API settings.
MAX_RETRIES = 3      # retries allowed after the first failed embedding request
RETRY_DELAY = 2      # base delay in seconds between retries
BATCH_API_SIZE = 8   # batch size for embedding API calls

async def connect_to_memgraph(max_retries: int = 3) -> Memgraph:
    """Connect to the Memgraph database, retrying on failure.

    Each attempt creates a client for ``memgraph_host:memgraph_port`` and runs
    ``RETURN 1`` to prove the connection works. Between failed attempts the
    coroutine sleeps ``RETRY_DELAY`` seconds.

    Args:
        max_retries: Total number of connection attempts. Values below 1 are
            treated as 1. Previously the loop body never ran in that case and
            the function silently returned ``None``, despite the ``Memgraph``
            return annotation.

    Returns:
        A ``Memgraph`` client that has successfully executed a test query.

    Raises:
        Exception: The error from the final attempt, re-raised unchanged.
    """
    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            # NOTE(review): these client calls are synchronous (not awaited), so
            # each attempt blocks the event loop until it succeeds or fails.
            memgraph = Memgraph(host=memgraph_host, port=memgraph_port)
            # list() consumes the result, forcing the test query to run even if
            # execute_and_fetch returns a lazy iterator.
            list(memgraph.execute_and_fetch("RETURN 1"))
            logger.info(f"成功连接到Memgraph: {memgraph_host}:{memgraph_port}")
            return memgraph
        except Exception as e:
            logger.warning(f"连接Memgraph失败 (尝试 {attempt}/{attempts}): {str(e)}")
            if attempt < attempts:
                await asyncio.sleep(RETRY_DELAY)
            else:
                logger.error(f"无法连接到Memgraph，已达最大重试次数: {str(e)}")
                raise

async def get_embedding(texts: List[str], api_key: str, retry_count: int = 0) -> List[List[float]]:
    """Fetch embedding vectors for a batch of texts from the SiliconFlow API.

    Sends a single POST to the ``/v1/embeddings`` endpoint with all ``texts``
    as the ``input`` list. A failed request is retried with exponential
    backoff (``RETRY_DELAY * 2**attempt`` seconds) until ``MAX_RETRIES``
    retries have been used.

    Args:
        texts: Texts to embed. An empty list (or ``None``) returns ``[]``
            without calling the API.
        api_key: SiliconFlow API key, sent as a Bearer token.
        retry_count: Number of retries already consumed. Kept for backward
            compatibility with the earlier recursive implementation; callers
            normally leave it at 0.

    Returns:
        One embedding per input text, in the same order as ``texts``.

    Raises:
        ValueError: On a non-200 response, or when the number of returned
            embeddings differs from ``len(texts)``, once retries are exhausted.
        Exception: Any other network/parsing error from the last attempt,
            re-raised unchanged.
    """
    if not texts:
        # Nothing to embed. Skip a request that would most likely be rejected
        # and then retried with backoff before finally failing.
        return []

    api_url = "https://api.siliconflow.cn/v1/embeddings"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    payload = {
        "model": "netease-youdao/bce-embedding-base_v1",
        "input": texts,
        # NOTE(review): presumably a server-side truncation limit; confirm
        # against the SiliconFlow embeddings API documentation.
        "max_token_size": 512
    }

    attempt = retry_count
    while True:
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(api_url, headers=headers, json=payload) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        raise ValueError(f"API错误: {response.status}, {error_text}")
                    result = await response.json()

            # Order by the per-item "index" when present (OpenAI-style responses
            # carry one), so vectors line up with `texts`. The sort is stable, so
            # items without an index keep the server's order.
            items = sorted(result["data"], key=lambda item: item.get("index") or 0)
            embeddings = [item["embedding"] for item in items]
            if len(embeddings) != len(texts):
                # A silent mismatch would attach vectors to the wrong texts downstream.
                raise ValueError(
                    f"API返回的嵌入数量不匹配: 期望 {len(texts)}, 实际 {len(embeddings)}"
                )
            return embeddings
        except Exception as e:
            if attempt >= MAX_RETRIES:
                logger.error(f"获取嵌入向量失败，超过最大重试次数: {str(e)}")
                raise
            logger.warning(f"获取嵌入向量失败，尝试重试 ({attempt+1}/{MAX_RETRIES}): {str(e)}")
            await asyncio.sleep(RETRY_DELAY * (2 ** attempt))  # exponential backoff
            attempt += 1


In [3]:
def format_entity_text(entity: Dict[str, Any]) -> str:
    """Render an entity as a single line of text to feed to the embedding model.

    ``entity`` must provide a ``"properties"`` mapping and a ``"labels"``
    sequence. Only the first label is used, and ``"Unknown"`` is used when the
    sequence is empty. The result joins, with ``", "``, whichever parts apply:

    * ``类型: <label>`` (type): when the label is truthy;
    * ``名称: <name>`` (name): when ``properties["name"]`` is truthy;
    * ``描述: <description>`` (description): when truthy and not equal to the name.

    Returns an empty string when none of the parts apply.
    """
    props = entity["properties"]
    labels = entity["labels"]
    label = labels[0] if labels else "Unknown"
    name = props.get("name", "")
    description = props.get("description", "")

    # Each slot holds either a formatted fragment or None; None slots are dropped.
    fragments = [
        f"类型: {label}" if label else None,
        f"名称: {name}" if name else None,
        # Skip the description when it merely repeats the name.
        f"描述: {description}" if description and description != name else None,
    ]
    return ", ".join(fragment for fragment in fragments if fragment is not None)