In [None]:
# Install Vanna with its chromadb, openai and mysql extras (via %pip, so it targets the kernel's environment).
# NOTE(review): the version is unpinned; consider pinning it for reproducible runs.
%pip install 'vanna[chromadb,openai,mysql]'

In [2]:
import os

from openai import OpenAI
from vanna.base import VannaBase
from vanna.chromadb import ChromaDB_VectorStore

# 自定义LLM

https://vanna.ai/docs/mysql-other-llm-chromadb/

In [3]:
# Most recent (prompt, answer) pair recorded by SiliconflowLLM.submit_prompt.
# Stays None until the first LLM call; later cells unpack it for inspection.
DEBUG_INFO=None

class SiliconflowLLM(VannaBase):
  """LLM backend for Vanna that delegates to a chat-completions client.

  The client only needs to expose ``chat.completions.create`` (the notebook
  passes an ``openai.OpenAI`` instance pointed at a custom base URL).
  """

  def __init__(self,client=None,config=None):
    """Remember the client and the model name.

    Args:
      client: object exposing ``chat.completions.create``.
      config: mapping that must contain a ``'model'`` key.
    """
    self.model=config['model']
    self.client=client

  def system_message(self,message: str):
    """Wrap ``message`` as a chat message with the ``system`` role."""
    return {'role':'system','content':message}

  def user_message(self, message: str):
    """Wrap ``message`` as a chat message with the ``user`` role."""
    return {'role':'user','content':message}

  def assistant_message(self, message: str):
    """Wrap ``message`` as a chat message with the ``assistant`` role."""
    return {'role':'assistant','content':message}

  def submit_prompt(self,prompt,**kwargs):
    """Send ``prompt`` to the model, echo the streamed reply, and return it.

    Args:
      prompt: list of chat messages passed through as ``messages``.
      **kwargs: accepted for interface compatibility; not used here.

    Returns:
      The full reply text, i.e. every streamed content fragment joined in
      order. The (prompt, answer) pair is also stored in ``DEBUG_INFO``.
    """
    global DEBUG_INFO

    resp = self.client.chat.completions.create(
      model=self.model,
      messages=prompt,
      stream=True  # stream the reply so it can be echoed as it arrives
    )

    fragments=[]
    for chunk in resp:
      # Some stream chunks carry no choices at all; skip them.
      if not chunk.choices:
        continue
      piece = chunk.choices[0].delta.content
      # Content can be None/empty (e.g. role-only or final chunks); do not
      # print "None" or let it replace the text gathered so far.
      if not piece:
        continue
      fragments.append(piece)
      print(piece, end='', flush=True)

    # Accumulate instead of keeping only the last fragment.
    answer=''.join(fragments)

    DEBUG_INFO=(prompt,answer)
    return answer

# Vanna客户端

In [4]:
class MyVanna(ChromaDB_VectorStore, SiliconflowLLM):
    """Vanna client that pairs the ChromaDB vector store with SiliconflowLLM."""

    def __init__(self, client=None, config=None):
        """Initialise both bases explicitly; only the LLM side receives ``client``."""
        ChromaDB_VectorStore.__init__(
            self,
            config=config,
        )
        SiliconflowLLM.__init__(
            self,
            client=client,
            config=config,
        )

In [5]:

# Read the API key from the environment rather than embedding it in the
# notebook, where it would be shared with every copy of the file.
# NOTE: a key that was previously committed here should be treated as leaked
# and rotated.
api_key = os.environ.get('SILICONFLOW_API_KEY')
if not api_key:
    raise RuntimeError(
        'Set the SILICONFLOW_API_KEY environment variable before running this cell.'
    )

# OpenAI client pointed at the SiliconFlow endpoint.
client = OpenAI(
    api_key=api_key,
    base_url='https://api.siliconflow.cn/v1'
)

# Vanna client: the LLM side uses `client`, and config['model'] selects the model.
vn=MyVanna(client=client,config={'model': 'deepseek-ai/DeepSeek-V3'})

# MySQL服务端

sudo docker run -d --name mysql-vanna -p 3306:3306 -e MYSQL_ROOT_PASSWORD=123456 mysql:8

mysql --protocol=tcp -hlocalhost -P3306 -uroot -p123456

CREATE DATABASE IF NOT EXISTS vanna;

USE vanna;

CREATE TABLE IF NOT EXISTS user (
        id INT PRIMARY KEY COMMENT '用户ID' ,
        name VARCHAR(100) COMMENT '姓名',
        age INT COMMENT '年龄'
    ) COMMENT '用户信息表';
    
insert into user values(1,'小鱼儿',34),(2,'小悲剧',36);

# 构造向量库

In [6]:
# Connection settings come from the environment so the password is not stored
# in the notebook. Non-secret settings fall back to the values used originally;
# MYSQL_PASSWORD has no fallback and must be set.
vn.connect_to_mysql(
    host=os.environ.get('MYSQL_HOST', 'www.hxfssc.com'),
    dbname=os.environ.get('MYSQL_DATABASE', 'vanna'),
    user=os.environ.get('MYSQL_USER', 'root'),
    password=os.environ['MYSQL_PASSWORD'],
    port=int(os.environ.get('MYSQL_PORT', '3306')),
)

In [7]:
# DDL of the `user` table, kept as a string so it can be passed to vn.train(ddl=...).
DDL = (
    "CREATE TABLE IF NOT EXISTS user (\n"
    "        id INT PRIMARY KEY COMMENT '用户ID' ,\n"
    "        name VARCHAR(100) COMMENT '姓名',\n"
    "        age INT COMMENT '年龄'\n"
    "    ) COMMENT '用户信息表';\n"
)

In [None]:
# Store the DDL in the vector store
vn.train(ddl=DDL)

In [None]:
# Store a documentation snippet in the vector store; it defines the custom
# term used by later questions as age >= 35.
vn.train(documentation='"福报"是指age>=35岁，也就是可以向社会输送的人才')

In [None]:
# Store a bare SQL statement in the vector store.
#   1. The LLM first builds a question from the SQL.
#   2. The pair is then stored as question-SQL JSON:
#          {
#              "question": question,
#              "sql": sql,
#          }
vn.train(
    sql='select name from user where age between 10 and 20',
)

In [None]:
# Show the most recent prompt/answer pair recorded in DEBUG_INFO.
Q, A = DEBUG_INFO
print(f"PROMPT: {Q[0]['content']}")
print(f"ANSWER: {A}")

In [None]:
# Store an explicit question-SQL pair in the vector store.
# The pair is stored as question-SQL JSON:
#     {
#         "question": question,
#         "sql": sql,
#     }
vn.train(
    question='小鱼儿的年龄',
    sql='select age from user where name="小鱼儿"',
)

In [None]:
# Review all the knowledge stored so far
vn.get_training_data()

# 开始查询

In [None]:
# Basic usage: generate SQL for a natural-language question.
result = vn.generate_sql('用户的平均年龄')
print(f"SQL: {result}")

# Show the prompt/answer pair recorded in DEBUG_INFO for that call.
Q, A = DEBUG_INFO
print(f"PROMPT: {Q[0]['content']}")
print(f"ANSWER: {A}")

In [None]:
# Same question as the generate_sql cell, this time through vn.ask.
vn.ask('用户的平均年龄')

In [None]:
# Generate SQL for a question that relies on the custom term "福报",
# which the documentation training cell defines as age >= 35.
result=vn.generate_sql('打算给一批员工送福报，把他们的名字过滤出来')

In [None]:
# Same custom-term question, this time through vn.ask.
vn.ask('打算给一批员工送福报，把他们的名字过滤出来')

In [None]:
# Show the most recent prompt/answer pair recorded in DEBUG_INFO.
Q, A = DEBUG_INFO
print(f"PROMPT: {Q[0]['content']}")
print(f"ANSWER: {A}")

In [None]:
# Count users per age bucket, spelling out the bucket definition
# (half-open 10-year ranges) inside the question itself.
vn.generate_sql('统计一下各年龄段的用户数量,年龄段是指0-10,10-20,20-30,30-40,40-50,50-60,60-70,70-80...左闭右开区间')

In [None]:
# Once the bucket definition is stored as documentation, the question can be
# asked directly without restating it.
vn.train(documentation='用户年龄段划分逻辑：0-10,10-20,20-30,30-40,40-50,50-60,60-70,70-80...左闭右开区间')
vn.generate_sql('统计一下各年龄段的用户数量')