In [None]:
# Dragon Knowledge Base - DKB Demo
这是一个使用 Towhee 将文档知识库内容经模型转换为向量后存入 Milvus 数据库的示例。

In [None]:
# 查看环境变量参数
%env ENV_FOR_DYNACONF=local
import os 
os.chdir('/Users/wangjialong/Documents/code/saic_project/global-vehicle-dragon/data-processing')

# 重新加载模块
from importlib import reload
from config.settings import Settings
import db.db_manager
import config.settings
import os
reload(config.settings)
reload(db.db_manager)
print(f"Debug: {bool(Settings.DEBUG)}")


In [None]:
# Embed the knowledge-base rows read from the database with the facebook/dpr-ctx_encoder-single-nq-base model and store the vectors in Milvus
import pandas as pd
from sqlalchemy import func, create_engine
from db.db_manager import DBConn
from db.model.dragon_knowledge_base import DragonKnowledgeBase
from pymilvus import (
    connections,
    utility,
    FieldSchema, CollectionSchema, DataType,
    Collection,
)
from towhee import pipe, ops
import numpy as np
from towhee.datacollection import DataCollection
from db.milvus_manager import MilvusConn

# Upper bound on the length of `content` sent to the embedding model. It is
# applied in SQL through LEFT(content, max_token), so it truncates by
# characters rather than by model tokens.
max_token = 512
# Number of rows fetched from the database per page.
partition_num = 100

milvus_conn = MilvusConn()

# Insert pipeline: embed `content` with DPR, scale the vector to unit L2 norm,
# then insert (id, vec) into the `dkb` Milvus collection.
# NOTE(review): re-running this cell feeds the same rows through the pipeline
# again -- confirm how the collection treats ids that were already inserted.
insert_pipe = (
    pipe.input('id', 'name', 'content')
        .map('content', 'vec', ops.text_embedding.dpr(model_name='facebook/dpr-ctx_encoder-single-nq-base'))
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))
        .map(('id', 'vec'), 'insert_status', ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name='dkb'))
        .output()
)

# Count the rows once to know how many pages to fetch; release the session
# even if the query fails.
db_conn = DBConn()
try:
    total_num = db_conn.session.query(func.count(DragonKnowledgeBase.id)).scalar()
finally:
    db_conn.close()

# Ceiling division: number of pages of `partition_num` rows.
iteration_times = (total_num + partition_num - 1) // partition_num

print(f"Records Count: {total_num} | iterations time: {iteration_times}")

# One engine for the whole run instead of one per page.
engine = create_engine(Settings.DB_URL)
try:
    for i in range(iteration_times):
        start_num = i * partition_num
        # ORDER BY makes LIMIT/OFFSET paging deterministic; without it the
        # database may return rows in a different order per query, so pages
        # could overlap or skip rows.
        select_query = (
            f"SELECT id, name, LEFT(content, {max_token}) as content "
            f"FROM p_dragon.dragon_knowledge_base "
            f"ORDER BY id LIMIT {partition_num} OFFSET {start_num}"
        )
        # Only hold the DB connection while the page is being read; the
        # DataFrame is fully materialised before the (slow) embedding runs.
        with engine.connect() as connection, connection.begin():
            df = pd.read_sql_query(sql=select_query, con=connection.connection)

        for index, row in enumerate(df.itertuples(index=False), start=1):
            # Kept from the original: stop processing this page at the first
            # row without an id.
            # NOTE(review): a NULL id looks unexpected here -- confirm whether
            # this guard is still needed.
            if row.id is None:
                break

            print(f"处理数据 分片{i}|{index}")
            # Row fields are (id, name, content), matching pipe.input above.
            insert_pipe(*row)
finally:
    engine.dispose()

milvus_conn.flush()
milvus_conn.stats()
print("all done")

# 模型效果
Facebook 的 DPR 模型（facebook/dpr-ctx_encoder-single-nq-base）对英文的检索比较精准，但对中文的处理效果非常差。

In [None]:
from db.db_manager import DBConn
from db.model.dragon_knowledge_base import DragonKnowledgeBase

# Query text used for the demo search.
prompt_word = '物料响应确认'

# Search pipeline: embed the question with the same DPR model used at insert
# time, scale it to unit L2 norm, fetch the 5 nearest vectors from the `dkb`
# collection and keep the first field of each hit as an integer id.
ans_pipe = (
    pipe.input('question')
        .map('question', 'vec', ops.text_embedding.dpr(model_name="facebook/dpr-ctx_encoder-single-nq-base"))
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))
        .map('vec', 'res', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='dkb', limit=5))
        .map('res', 'ids', lambda x: [int(i[0]) for i in x])
        .output('question', 'ids')
)

ans = DataCollection(ans_pipe(prompt_word))
db_conn = DBConn()
try:
    for ans_result in ans:
        for d_id in ans_result['ids']:
            query_result = (
                db_conn.session.query(DragonKnowledgeBase)
                .filter(DragonKnowledgeBase.id == d_id)
                .first()
            )
            # .first() returns None when the id is in Milvus but not in the
            # database; report it and carry on instead of failing on
            # `.content` and losing the remaining hits.
            if query_result is None:
                print(f"{d_id} | <not found in database>")
                continue
            print(f"{d_id} | {query_result.content}")
except Exception as e:
    # NOTE(review): broad catch kept from the original so the demo cell never
    # raises; consider letting unexpected errors propagate with a traceback.
    print(f"{e}")
finally:
    db_conn.close()

# ans.show()