In [1]:
import os

import pandas as pd
from vanna.base import VannaBase
from vanna.flask import VannaFlaskApp
from vanna.openai import OpenAI_Chat
from vanna.vannadb import VannaDB_VectorStore



In [2]:
class MyCustomVectorDB(VannaBase):
    """Minimal in-memory stand-in for a Vanna vector store.

    Training data is kept in three plain Python lists. Nothing is embedded or
    ranked: every ``get_*`` lookup ignores the question and returns the whole
    store. The data lives only as long as the instance does.
    """

    def __init__(self, config=None):
        """Initialise the base class and start with empty stores."""
        super().__init__(config)
        self.ddl_list = []
        self.documentation_list = []
        self.question_sql_list = []  # holds (question, sql) tuples

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement and return it; the text doubles as its id."""
        self.ddl_list.append(ddl)
        return ddl

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Store a documentation snippet and return it; the text doubles as its id."""
        self.documentation_list.append(doc)
        return doc

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL example and return the SQL, which serves as its id."""
        self.question_sql_list.append((question, sql))
        return sql

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Return all stored DDL statements (``question`` is ignored).

        A copy is returned so that callers which extend or edit the result
        cannot alter the store behind its back.
        """
        return list(self.ddl_list)

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Return all stored documentation snippets (``question`` is ignored).

        A copy is returned so that callers which extend or edit the result
        cannot alter the store behind its back.
        """
        return list(self.documentation_list)

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Return all stored examples as ``{"question": ..., "sql": ...}`` dicts.

        Vanna reads examples through the "question" and "sql" keys, so the
        internally stored tuples are converted here; handing back raw tuples
        would make the examples unusable for prompt building.
        """
        return [
            {"question": stored_question, "sql": stored_sql}
            for stored_question, stored_sql in self.question_sql_list
        ]

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return the three stores side by side in one DataFrame.

        Columns are ``ddl``, ``documentation`` and ``question_sql`` (the
        latter holding ``(question, sql)`` tuples). Shorter stores are padded
        with ``None`` so that all columns have the same length.
        """
        stores = {
            'ddl': self.ddl_list,
            'documentation': self.documentation_list,
            'question_sql': self.question_sql_list,
        }
        # Every column must be as long as the longest store.
        height = max(len(items) for items in stores.values())
        padded = {
            column: items + [None] * (height - len(items))
            for column, items in stores.items()
        }
        return pd.DataFrame(padded)

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Remove every stored entry whose id equals ``id``.

        The ids handed out by the ``add_*`` methods are the stored text itself
        (the DDL text, the documentation text, or the SQL of an example), so
        entries are matched on that text.

        Returns:
            True if at least one entry was removed, False if nothing matched.
        """
        def total() -> int:
            return (
                len(self.ddl_list)
                + len(self.documentation_list)
                + len(self.question_sql_list)
            )

        before = total()
        # Slice assignment keeps the same list objects alive for any holder.
        self.ddl_list[:] = [ddl for ddl in self.ddl_list if ddl != id]
        self.documentation_list[:] = [
            doc for doc in self.documentation_list if doc != id
        ]
        self.question_sql_list[:] = [
            pair for pair in self.question_sql_list if pair[1] != id
        ]
        return total() < before

    def generate_embedding(self, text: str, **kwargs) -> list:
        """Return a placeholder embedding: a list of ten zeros, whatever ``text`` is."""
        return [0] * 10


class MyVanna(OpenAI_Chat, MyCustomVectorDB):
    """Vanna agent combining the OpenAI chat client with the in-memory store.

    ``OpenAI_Chat`` is listed first on purpose, so it takes precedence over
    ``MyCustomVectorDB`` in the method resolution order.
    """

    def __init__(self, config=None):
        """Initialise both parents explicitly.

        Args:
            config: Settings dict handed to both parent initialisers. ``None``
                (or any other falsy value) is replaced by an empty dict.
        """
        config = config or {}
        # The chat client is set up first, as before.
        OpenAI_Chat.__init__(self, config=config)
        # Let the store initialise itself (its own __init__ also chains to the
        # base class through super()) instead of re-creating its lists here,
        # so attributes added to the store later are not silently missed.
        MyCustomVectorDB.__init__(self, config=config)

# The OpenAI key is read from the OPENAI_API_KEY environment variable so the
# secret never lives in the notebook; a missing variable raises KeyError.
# NOTE(review): the key that used to be hardcoded on this line must be treated
# as leaked and revoked.
vn = MyVanna(config={'api_key': os.environ['OPENAI_API_KEY'], 'model': 'gpt-4o-mini'})


In [3]:
# Connect the agent to the SQLite file `sqlite-sakila.db`; the path is relative,
# so it is resolved against the notebook's working directory.
vn.connect_to_sqlite('sqlite-sakila.db')


In [4]:
# Read every stored CREATE statement from SQLite's schema catalog; rows without
# SQL text are filtered out by the query itself.
df_ddl = vn.run_sql(
    "SELECT type, sql FROM sqlite_master WHERE sql is not null"
)

# Register each statement with the agent as DDL training data, one call each.
for ddl in df_ddl["sql"].tolist():
    vn.train(ddl=ddl)

Adding ddl: CREATE TABLE actor (
  actor_id numeric NOT NULL ,
  first_name VARCHAR(45) NOT NULL,
  last_name VARCHAR(45) NOT NULL,
  last_update TIMESTAMP NOT NULL,
  PRIMARY KEY  (actor_id)
  )
Adding ddl: CREATE INDEX idx_actor_last_name ON actor(last_name)

Adding ddl: CREATE TRIGGER actor_trigger_ai AFTER INSERT ON actor
 BEGIN
  UPDATE actor SET last_update = DATETIME('NOW')  WHERE rowid = new.rowid;
 END
Adding ddl: CREATE TRIGGER actor_trigger_au AFTER UPDATE ON actor
 BEGIN
  UPDATE actor SET last_update = DATETIME('NOW')  WHERE rowid = new.rowid;
 END
Adding ddl: CREATE TABLE country (
  country_id SMALLINT NOT NULL,
  country VARCHAR(50) NOT NULL,
  last_update TIMESTAMP,
  PRIMARY KEY  (country_id)
)
Adding ddl: CREATE TRIGGER country_trigger_ai AFTER INSERT ON country
 BEGIN
  UPDATE country SET last_update = DATETIME('NOW')  WHERE rowid = new.rowid;
 END
Adding ddl: CREATE TRIGGER country_trigger_au AFTER UPDATE ON country
 BEGIN
  UPDATE country SET last_update = DATETIM

In [5]:
# Wrap the trained agent in Vanna's Flask UI, passing allow_llm_to_see_data=False,
# then start serving; run() keeps this cell busy while the app is up.
flask_app = VannaFlaskApp(vn, allow_llm_to_see_data=False)
flask_app.run()

Your app is running at:
http://localhost:8084
 * Serving Flask app 'vanna.flask'
 * Debug mode: on
