In [10]:
import json
from typing import Union
import os
from fastapi import FastAPI, HTTPException
from vanna.openai import OpenAI_Chat
from vanna.chromadb import ChromaDB_VectorStore
import pandas as pd
from dotenv import find_dotenv, load_dotenv
from vanna.chromadb import ChromaDB_VectorStore
from vanna.openai import OpenAI_Chat
from chromadb.utils import embedding_functions
from openai import OpenAI
from pydantic import BaseModel

# Load the .env file located by find_dotenv() into the process environment;
# later cells read OPENAI_* and DORIS_* settings via os.getenv.
load_dotenv(dotenv_path=find_dotenv())


True

In [11]:

class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)


In [12]:
# Select the embedding model to use.
# List of model names can be found here https://www.sbert.net/docs/pretrained_models.html
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="BAAI/bge-large-zh-v1.5",
)
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    base_url=os.getenv("OPENAI_BASE_URL"),
)
vn = MyVanna(
    config={
        "model": "qwen25:72b",
        "temperature": 0,
        "language": "chinese",
        "path": "..\chromadb_data",
        "embedding_function": sentence_transformer_ef,
    }
)

vn.connect_to_mysql(host=os.getenv("DORIS_HOST"), dbname=os.getenv("DORIS_DATABASE"), user=os.getenv("DORIS_USER"), password=os.getenv("DORIS_PASSWORD"), port=int(os.getenv("DORIS_PORT")))

In [None]:
# Column metadata for the connected database only. Filtering TABLE_SCHEMA to
# DATABASE() keeps system schemas (information_schema itself, etc.) and any other
# databases visible to this user out of the training plan.
df_information_schema = vn.run_sql(
    "SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = DATABASE()"
)


In [None]:
# Build a training plan from the information-schema rows fetched above; it breaks the
# schema into bite-sized chunks that can be referenced by the LLM.
# Inspect `plan` before running the training cell below.
plan = vn.get_training_plan_generic(df_information_schema)

In [None]:
# Train on the plan built above (review `plan` first).
# NOTE(review): this presumably persists into the ChromaDB store at the configured
# path, so re-running the cell may add duplicate entries — confirm before re-running.
vn.train(plan=plan)

In [None]:
# Question (in Chinese): "Purchase-order details for the top five by order quantity;
# every selected column must have a corresponding Chinese alias."
question = '订单数量排名前五的采购订单信息,查询的字段必须有其对应的中文别名。'
sql_query = vn.generate_sql(question=question)

In [None]:
# Execute the generated SQL over the Doris connection; the result is shown as cell output.
df_result = vn.run_sql(sql_query)
df_result

In [None]:
# Ask the LLM for Plotly code that charts the query result.
# generate_plotly_code's third parameter is `df_metadata` (a text description), so pass
# the column dtypes rather than the DataFrame itself — formatting the frame into the
# prompt would send row data instead of schema information.
# (Variable name kept as-is because the figure cell reads `poltly_code`.)
poltly_code = vn.generate_plotly_code(
    question=question,
    sql=sql_query,
    df_metadata=f"Running df.dtypes gives:\n {df_result.dtypes}",
)
poltly_code

In [None]:
# Build a Plotly figure from the generated code and the query result, light theme.
figure = vn.get_plotly_figure(
    plotly_code=poltly_code,
    df=df_result,
    dark_mode=False,
)
figure

In [11]:
import plotly.io as pio

# Serialize the figure to a JSON string; a later cell rebuilds it with from_json.
plotly_json = pio.to_json(fig=figure)



In [None]:
# Ask Vanna for a natural-language summary of the query result.
summary = vn.generate_summary(question=question, df=df_result)
summary

In [None]:
from plotly.io import from_json

# Rebuild a Figure from the serialized JSON and render it, checking the round trip.
fig = from_json(value=plotly_json)
fig.show()