In [1]:
%pip install vanna[openai]
%pip install python-dotenv
%pip install pandas

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
import hmac
import os
import secrets
from typing import Any

import flask
import pandas as pd
from dotenv import load_dotenv
from vanna.base import VannaBase
from vanna.flask import VannaFlaskApp
from vanna.flask.auth import AuthInterface
from vanna.openai import OpenAI_Chat
from vanna.vannadb import VannaDB_VectorStore
load_dotenv()

True

In [3]:
class MyCustomVectorDB(VannaBase):
    """Minimal in-memory training store for Vanna.

    Nothing is embedded or ranked here: every ``get_*`` lookup ignores the
    question and returns everything that has been stored so far. The data
    lives only in this object and is lost when the kernel restarts.
    """

    def __init__(self, config=None):
        super().__init__(config)
        self.ddl_list = []
        self.documentation_list = []
        self.question_sql_list = []

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement and return it."""
        self.ddl_list.append(ddl)
        return ddl

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Store a documentation snippet and return it."""
        self.documentation_list.append(doc)
        return doc

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL example and return the SQL.

        The pair is kept as a ``{"question": ..., "sql": ...}`` mapping.
        Vanna's prompt builder looks the two values up by those keys; a plain
        ``(question, sql)`` tuple does not pass its key check, so the example
        would never reach the LLM.
        """
        self.question_sql_list.append({"question": question, "sql": sql})
        return sql

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Return every stored DDL statement (``question`` is not used)."""
        # Hand out a copy so callers cannot grow or reorder the store.
        return list(self.ddl_list)

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Return every stored documentation snippet (``question`` is not used)."""
        # Hand out a copy so callers cannot grow or reorder the store.
        return list(self.documentation_list)

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Return every stored question/SQL mapping (``question`` is not used)."""
        # Hand out a copy so callers cannot grow or reorder the store.
        return list(self.question_sql_list)

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return the stored items as a three-column DataFrame.

        Columns are ``ddl``, ``documentation`` and ``question_sql``. Rows do
        not relate the columns to each other: each column is simply its list,
        padded with ``None`` up to the length of the longest one.
        """
        columns = {
            'ddl': self.ddl_list,
            'documentation': self.documentation_list,
            'question_sql': self.question_sql_list,
        }
        # DataFrame columns must share one length: pad the shorter ones with None.
        height = max(len(values) for values in columns.values())
        padded = {
            name: values + [None] * (height - len(values))
            for name, values in columns.items()
        }
        return pd.DataFrame(padded)

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Report success without removing anything.

        NOTE(review): stored items carry no id, so nothing can be deleted;
        the method still answers ``True``. Confirm this is intended.
        """
        return True

    def generate_embedding(self, text: str, **kwargs) -> list:
        """Return a placeholder embedding: ten zeros, whatever ``text`` is."""
        return [0] * 10


class MyVanna(OpenAI_Chat, MyCustomVectorDB):  # OpenAI_Chat comes first in the inheritance order
    """Vanna assistant combining the OpenAI chat backend with the in-memory store."""

    def __init__(self, config=None):
        # A missing or empty config is replaced by an empty dict.
        if not config:
            config = {}
        # The LLM side is initialised first.
        OpenAI_Chat.__init__(self, config=config)
        # The store side then runs VannaBase.__init__ and creates the three
        # empty training lists (ddl, documentation, question/SQL).
        MyCustomVectorDB.__init__(self, config=config)


# The API key is read from the environment.
vn = MyVanna(
    config={
        'api_key': os.getenv('OPENAI_API_KEY'),
        'model': 'gpt-4o-mini',
    }
)


In [4]:
# Path of the SQLite database file, relative to the notebook's working directory.
SAKILA_DB_PATH = 'sqlite-sakila.db'
vn.connect_to_sqlite(SAKILA_DB_PATH)


In [5]:
# Read every stored CREATE statement from SQLite's schema catalogue.
df_ddl = vn.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")

# Hand each statement to Vanna as DDL training data.
ddl_statements = df_ddl["sql"].tolist()
for ddl in ddl_statements:
    vn.train(ddl=ddl)

Adding ddl: CREATE TABLE actor (
  actor_id numeric NOT NULL ,
  first_name VARCHAR(45) NOT NULL,
  last_name VARCHAR(45) NOT NULL,
  last_update TIMESTAMP NOT NULL,
  PRIMARY KEY  (actor_id)
  )
Adding ddl: CREATE INDEX idx_actor_last_name ON actor(last_name)

Adding ddl: CREATE TRIGGER actor_trigger_ai AFTER INSERT ON actor
 BEGIN
  UPDATE actor SET last_update = DATETIME('NOW')  WHERE rowid = new.rowid;
 END
Adding ddl: CREATE TRIGGER actor_trigger_au AFTER UPDATE ON actor
 BEGIN
  UPDATE actor SET last_update = DATETIME('NOW')  WHERE rowid = new.rowid;
 END
Adding ddl: CREATE TABLE country (
  country_id SMALLINT NOT NULL,
  country VARCHAR(50) NOT NULL,
  last_update TIMESTAMP,
  PRIMARY KEY  (country_id)
)
Adding ddl: CREATE TRIGGER country_trigger_ai AFTER INSERT ON country
 BEGIN
  UPDATE country SET last_update = DATETIME('NOW')  WHERE rowid = new.rowid;
 END
Adding ddl: CREATE TRIGGER country_trigger_au AFTER UPDATE ON country
 BEGIN
  UPDATE country SET last_update = DATETIM

In [None]:
# Question/SQL pairs used to train Vanna. Each entry maps a natural-language
# question (in French) to a reference SQLite query on the Sakila schema.
training_data = [
    {
        "question": "Quels sont les films d'action les plus loués ?",
        "sql": """
        SELECT 
            film.title AS film_title, 
            category.name AS category_name, 
            COUNT(rental.rental_id) AS rental_count
        FROM film AS film
        JOIN film_category AS film_cat ON film.film_id = film_cat.film_id
        JOIN category AS category ON film_cat.category_id = category.category_id
        JOIN inventory AS inventory ON film.film_id = inventory.film_id
        JOIN rental AS rental ON inventory.inventory_id = rental.inventory_id
        WHERE category.name = 'Action'
        GROUP BY film.film_id, film.title, category.name
        ORDER BY rental_count DESC
        LIMIT 10
        """
    },
    {
        # NOT EXISTS keeps a film only when none of its copies has a rental.
        # A LEFT JOIN ... IS NULL filter would also list films that merely
        # have one unrented copy, once per such copy.
        "question": "Liste des films qui n'ont jamais été loués",
        "sql": """
        SELECT 
            film.title AS film_title, 
            film.release_year AS release_year, 
            film.rental_rate AS rental_price
        FROM film AS film
        WHERE NOT EXISTS (
            SELECT 1
            FROM inventory AS inventory
            JOIN rental AS rental ON inventory.inventory_id = rental.inventory_id
            WHERE inventory.film_id = film.film_id
        )
        """
    },
    {
    "question": "Quels sont les 10 films les plus loués ?",
    "sql": """
    SELECT 
        f.title AS film_title,
        COUNT(r.rental_id) AS rental_count
    FROM film AS f
    JOIN inventory AS i ON f.film_id = i.film_id
    JOIN rental AS r ON i.inventory_id = r.inventory_id
    GROUP BY f.film_id, f.title
    ORDER BY rental_count DESC
    LIMIT 10
    """
    },
    {
    "question": "Quels sont les 10 acteurs ayant joué dans le plus de films ?",
    "sql": """
    SELECT 
        a.first_name || ' ' || a.last_name AS actor_name,
        COUNT(fa.film_id) AS film_count
    FROM actor AS a
    JOIN film_actor AS fa ON a.actor_id = fa.actor_id
    GROUP BY a.actor_id
    ORDER BY film_count DESC
    LIMIT 10
    """
    },

    {
        "question": "Qui sont les 10 meilleurs clients en termes de montant dépensé ?",
        "sql": """
        SELECT 
            customer.first_name AS customer_first_name,
            customer.last_name AS customer_last_name,
            COUNT(rental.rental_id) AS total_rentals,
            SUM(payment.amount) AS total_spent
        FROM customer AS customer
        JOIN rental AS rental ON customer.customer_id = rental.customer_id
        JOIN payment AS payment ON rental.rental_id = payment.rental_id
        GROUP BY customer.customer_id, customer.first_name, customer.last_name
        ORDER BY total_spent DESC
        LIMIT 10
        """
    },
    {
    "question": "Quels clients ont dépensé le plus ?",
    "sql": """
    SELECT 
        c.first_name || ' ' || c.last_name AS customer_full_name,
        SUM(p.amount) AS total_spent
    FROM customer AS c
    JOIN payment AS p ON c.customer_id = p.customer_id
    GROUP BY c.customer_id, c.first_name, c.last_name
    ORDER BY total_spent DESC
    LIMIT 10
    """
    },

    # Geography questions
    {
        "question": "Nombre de clients par pays",
        "sql": """
        SELECT 
            country.country AS country_name, 
            COUNT(DISTINCT customer.customer_id) AS total_customers
        FROM customer AS customer
        JOIN address AS address ON customer.address_id = address.address_id
        JOIN city AS city ON address.city_id = city.city_id
        JOIN country AS country ON city.country_id = country.country_id
        GROUP BY country.country
        ORDER BY total_customers DESC
        """
    },
    {
    "question": "Combien de clients y a-t-il par pays ?",
    "sql": """
    SELECT 
        co.country AS country_name,
        COUNT(DISTINCT cu.customer_id) AS customer_count
    FROM customer AS cu
    JOIN address AS a ON cu.address_id = a.address_id
    JOIN city AS ci ON a.city_id = ci.city_id
    JOIN country AS co ON ci.country_id = co.country_id
    GROUP BY co.country
    ORDER BY customer_count DESC
    """
    },

    {
        "question": "Chiffre d'affaires par catégorie de film",
        "sql": """
        SELECT 
            category.name AS category_name,
            COUNT(rental.rental_id) AS total_rentals,
            SUM(payment.amount) AS total_revenue
        FROM category AS category
        JOIN film_category AS film_cat ON category.category_id = film_cat.category_id
        JOIN film AS film ON film_cat.film_id = film.film_id
        JOIN inventory AS inventory ON film.film_id = inventory.film_id
        JOIN rental AS rental ON inventory.inventory_id = rental.inventory_id
        JOIN payment AS payment ON rental.rental_id = payment.rental_id
        GROUP BY category.category_id, category.name
        ORDER BY total_revenue DESC
        """
    },
    {
    "question": "Combien de films sont disponibles dans chaque magasin ?",
    "sql": """
    SELECT 
        s.store_id,
        COUNT(i.inventory_id) AS film_count
    FROM store AS s
    JOIN inventory AS i ON s.store_id = i.store_id
    GROUP BY s.store_id
    ORDER BY film_count DESC
    """
    },
    {
        "question": "Performance des vendeurs (nombre de locations et montant)",
        "sql": """
        SELECT 
            staff.first_name || ' ' || staff.last_name AS staff_full_name,
            COUNT(rental.rental_id) AS total_rentals,
            SUM(payment.amount) AS total_sales
        FROM staff AS staff
        JOIN rental AS rental ON staff.staff_id = rental.staff_id
        JOIN payment AS payment ON rental.rental_id = payment.rental_id
        GROUP BY staff.staff_id, staff.first_name, staff.last_name
        ORDER BY total_sales DESC
        """
    },
    {
        "question": "Durée moyenne de location par catégorie de film",
        "sql": """
        SELECT 
            category.name AS category_name,
            AVG(JULIANDAY(rental.return_date) - JULIANDAY(rental.rental_date)) AS avg_duration_days
        FROM category AS category
        JOIN film_category AS film_cat ON category.category_id = film_cat.category_id
        JOIN film AS film ON film_cat.film_id = film.film_id
        JOIN inventory AS inventory ON film.film_id = inventory.film_id
        JOIN rental AS rental ON inventory.inventory_id = rental.inventory_id
        WHERE rental.return_date IS NOT NULL
        GROUP BY category.category_id, category.name
        ORDER BY avg_duration_days DESC
        """
    },
    {
    # Same NOT EXISTS form as the other "never rented" example: a film
    # qualifies only when no copy of it appears in the rental table.
    "question": "Quels sont les films qui n'ont jamais été loués ?",
    "sql": """
    SELECT 
        f.title AS film_title
    FROM film AS f
    WHERE NOT EXISTS (
        SELECT 1
        FROM inventory AS i
        JOIN rental AS r ON i.inventory_id = r.inventory_id
        WHERE i.film_id = f.film_id
    )
    """
    },
    {
    "question": "Quelle est la durée moyenne de location par catégorie de film ?",
    "sql": """
    SELECT 
        c.name AS category_name,
        AVG(JULIANDAY(r.return_date) - JULIANDAY(r.rental_date)) AS avg_rental_duration
    FROM category AS c
    JOIN film_category AS fc ON c.category_id = fc.category_id
    JOIN film AS f ON fc.film_id = f.film_id
    JOIN inventory AS i ON f.film_id = i.film_id
    JOIN rental AS r ON i.inventory_id = r.inventory_id
    WHERE r.return_date IS NOT NULL
    GROUP BY c.category_id
    ORDER BY avg_rental_duration DESC
    """
    },
    {
        "question": "Films les plus rentables (ratio revenu/coût)",
        "sql": """
        SELECT 
            film.title AS film_title,
            COUNT(rental.rental_id) AS rental_count,
            SUM(payment.amount) AS total_revenue,
            film.replacement_cost AS replacement_cost,
            ROUND(SUM(payment.amount) / film.replacement_cost, 2) AS roi
        FROM film AS film
        JOIN inventory AS inventory ON film.film_id = inventory.film_id
        JOIN rental AS rental ON inventory.inventory_id = rental.inventory_id
        JOIN payment AS payment ON rental.rental_id = payment.rental_id
        GROUP BY film.film_id, film.title, film.replacement_cost
        HAVING rental_count > 5
        ORDER BY roi DESC
        LIMIT 10
        """
    }
]

# Register each question/SQL pair with Vanna, logging the question as we go.
for pair in training_data:
    question = pair['question']
    print(f"Training with question: {question}")
    vn.train(question=question, sql=pair['sql'])

Training with question: Quels sont les films d'action les plus loués ?
Training with question: Liste des films qui n'ont jamais été loués
Training with question: Qui sont les 10 meilleurs clients en termes de montant dépensé ?
Training with question: Nombre de clients par pays
Training with question: Chiffre d'affaires par catégorie de film
Training with question: Performance des vendeurs (nombre de locations et montant)
Training with question: Durée moyenne de location par catégorie de film
Training with question: Films les plus rentables (ratio revenu/coût)


In [7]:
# VannaFlaskApp(vn, allow_llm_to_see_data=True).run()

In [8]:
class SimplePassword(AuthInterface):
    """Email/password login for the Vanna Flask app, backed by an in-memory user list.

    After a successful login the ``user`` cookie holds ``<email>:<signature>``,
    where the signature is an HMAC-SHA256 of the email under a key generated
    for this instance. ``get_user`` only accepts cookies whose signature
    verifies, so a cookie value typed in by hand is not treated as a session.
    The key is not persisted: creating a new instance invalidates old cookies.
    """

    def __init__(self, users: list):
        """
        Args:
            users: Sequence of ``{"email": ..., "password": ...}`` dicts
                describing the accounts that are allowed to log in.
        """
        self.users = users
        # Random per-instance key used to sign and verify the session cookie.
        self._secret = secrets.token_bytes(32)

    def _sign(self, email: str) -> str:
        """Return the hex HMAC-SHA256 signature of ``email``."""
        return hmac.new(self._secret, email.encode('utf-8'), 'sha256').hexdigest()

    @staticmethod
    def _same_text(expected, supplied) -> bool:
        """Compare two strings in constant time; non-strings never match."""
        if not isinstance(expected, str) or not isinstance(supplied, str):
            return False
        return hmac.compare_digest(expected.encode('utf-8'), supplied.encode('utf-8'))

    def get_user(self, flask_request) -> Any:
        """Return the email stored in a validly signed ``user`` cookie, else ``None``."""
        token = flask_request.cookies.get('user')
        if not token:
            return None
        # The signature is hex, so the last ':' always separates it from the email.
        email, separator, signature = token.rpartition(':')
        if not separator or not email:
            return None
        if not self._same_text(self._sign(email), signature):
            return None
        return email

    def is_logged_in(self, user: Any) -> bool:
        """A request is logged in when ``get_user`` produced a user."""
        return user is not None

    def override_config_for_user(self, user: Any, config: dict) -> dict:
        """Return ``config`` unchanged; there is no per-user configuration."""
        return config

    def login_form(self) -> str:
        """Return the HTML of the login page (the form posts to ``/auth/login``)."""
        return '''
        <div class="min-h-screen flex items-center justify-center bg-gradient-to-br from-indigo-500 to-purple-600 p-4">
            <div class="bg-white dark:bg-gray-800 rounded-2xl shadow-xl max-w-md w-full p-8 space-y-8">
                <!-- Logo/Icon -->
                <div class="text-center">
                    <div class="mx-auto h-16 w-16 bg-gradient-to-r from-pink-500 to-purple-500 rounded-full flex items-center justify-center">
                        <svg class="h-8 w-8 text-white" fill="none" stroke="currentColor" viewBox="0 0 24 24">
                            <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 15v2m-6 4h12a2 2 0 002-2v-6a2 2 0 00-2-2H6a2 2 0 00-2 2v6a2 2 0 002 2zm10-10V7a4 4 0 00-8 0v4h8z"/>
                        </svg>
                    </div>
                    <h2 class="mt-6 text-3xl font-extrabold text-gray-900 dark:text-white">
                        Bienvenue
                    </h2>
                    <p class="mt-2 text-sm text-gray-500 dark:text-gray-400">
                        Connectez-vous à votre compte
                    </p>
                </div>

                <!-- Formulaire -->
                <form class="mt-8 space-y-6" action="/auth/login" method="POST">
                    <div class="space-y-4">
                        <!-- Email -->
                        <div>
                            <label class="block text-sm font-medium text-gray-700 dark:text-gray-300">
                                Adresse email
                            </label>
                            <div class="mt-1 relative">
                                <input 
                                    type="email" 
                                    name="email" 
                                    required 
                                    class="appearance-none block w-full px-4 py-3 border border-gray-300 dark:border-gray-600 rounded-xl shadow-sm placeholder-gray-400 
                                    focus:outline-none focus:ring-2 focus:ring-purple-500 focus:border-transparent
                                    bg-white dark:bg-gray-700 text-gray-900 dark:text-white"
                                    placeholder="vous@exemple.com"
                                >
                            </div>
                        </div>

                        <!-- Mot de passe -->
                        <div>
                            <label class="block text-sm font-medium text-gray-700 dark:text-gray-300">
                                Mot de passe
                            </label>
                            <div class="mt-1 relative">
                                <input 
                                    type="password" 
                                    name="password" 
                                    required 
                                    class="appearance-none block w-full px-4 py-3 border border-gray-300 dark:border-gray-600 rounded-xl shadow-sm placeholder-gray-400 
                                    focus:outline-none focus:ring-2 focus:ring-purple-500 focus:border-transparent
                                    bg-white dark:bg-gray-700 text-gray-900 dark:text-white"
                                    placeholder="••••••••"
                                >
                            </div>
                        </div>

                        <!-- Options de connexion -->
                        <div class="flex items-center justify-between">
                            <div class="flex items-center">
                                <input 
                                    type="checkbox" 
                                    id="remember-me" 
                                    name="remember-me" 
                                    class="h-4 w-4 text-purple-600 focus:ring-purple-500 border-gray-300 rounded cursor-pointer"
                                >
                                <label for="remember-me" class="ml-2 block text-sm text-gray-700 dark:text-gray-300 cursor-pointer">
                                    Se souvenir de moi
                                </label>
                            </div>

                            <div class="text-sm">
                                <a href="#" class="font-medium text-purple-600 hover:text-purple-500">
                                    Mot de passe oublié?
                                </a>
                            </div>
                        </div>
                    </div>

                    <!-- Bouton de connexion -->
                    <div>
                        <button 
                            type="submit" 
                            class="w-full flex justify-center py-3 px-4 border border-transparent rounded-xl shadow-sm text-sm font-medium text-white 
                            bg-gradient-to-r from-purple-500 to-pink-500 hover:from-purple-600 hover:to-pink-600 
                            focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-purple-500 
                            transform transition-all duration-200 hover:scale-[1.02]"
                        >
                            Se connecter
                        </button>
                    </div>
                </form>

                <!-- Lien d'inscription -->
                <div class="text-center mt-4">
                    <p class="text-sm text-gray-600 dark:text-gray-400">
                        Pas encore de compte? 
                        <a href="#" class="font-medium text-purple-600 hover:text-purple-500">
                            Créer un compte
                        </a>
                    </p>
                </div>
            </div>
        </div>
        '''

    def login_handler(self, flask_request) -> str:
        """Check the posted ``email``/``password`` against the configured users.

        On a match, answers with a 302 redirect to ``/`` that sets the signed
        ``user`` cookie. Otherwise returns the text ``'Login failed'``.
        """
        email = flask_request.form['email']
        password = flask_request.form['password']
        for user in self.users:
            # Evaluate both comparisons so timing does not reveal which field failed.
            email_ok = self._same_text(user["email"], email)
            password_ok = self._same_text(user["password"], password)
            if email_ok and password_ok:
                response = flask.make_response('Logged in as ' + email)
                # HttpOnly keeps the cookie away from page scripts; SameSite=Lax
                # stops it being sent on cross-site POSTs.
                response.set_cookie(
                    'user',
                    email + ':' + self._sign(email),
                    httponly=True,
                    samesite='Lax',
                )
                # Redirect to the main page
                response.headers['Location'] = '/'
                response.status_code = 302
                return response
        return 'Login failed'

    def callback_handler(self, flask_request) -> str:
        """Refuse callback logins with a 403 response.

        This scheme has no external identity provider. Setting the session
        cookie from the ``user`` query parameter would log in any visitor as
        any account without a password, so the request is rejected instead.
        """
        return flask.make_response('Callback login is not supported', 403)

    def logout_handler(self, flask_request) -> str:
        """Clear the ``user`` cookie and confirm the logout."""
        response = flask.make_response('Logged out')
        response.delete_cookie('user')
        return response

# Account allowed to log in. The values come from the environment (or the
# .env file loaded at the top of the notebook) when VANNA_ADMIN_EMAIL /
# VANNA_ADMIN_PASSWORD are set; the fallbacks are demo credentials only and
# should be overridden before the app is exposed to anyone else.
admin_user = {
    "email": os.getenv("VANNA_ADMIN_EMAIL", "admin@example.com"),
    "password": os.getenv("VANNA_ADMIN_PASSWORD", "password"),
}

# Start the Vanna web UI; this call blocks the cell while the server runs.
VannaFlaskApp(
    vn=vn,
    auth=SimplePassword(users=[admin_user]),
    allow_llm_to_see_data=True,
    title="E2 - Sakila",
    subtitle="Interrogez la base de données Sakila",
    show_training_data=True,
    sql=False,
    table=True,
    chart=False,
    summarization=True,
    ask_results_correct=True,
).run()

Your app is running at:
http://localhost:8084
 * Serving Flask app 'vanna.flask'
 * Debug mode: on
Using model gpt-4o-mini for 4063.25 tokens (approx)
Using model gpt-4o-mini for 138.75 tokens (approx)
