In [1]:
import vanna
from vanna.remote import VannaDefault

In [3]:

vn = VannaDefault(model='model_name', api_key='your-api-key')


In [4]:
vn.connect_to_postgres(host='localhost', dbname='postgres', user='postgres', password='faisal', port='5432')

In [24]:
import psycopg2
from psycopg2 import sql

def get_database_schema_as_create_statements():
    # Connect to the database
    conn = psycopg2.connect(
        dbname="db_name,
        user="user_name",
        password="password",
        host="localhost",
        port="port_number"
    )

    cursor = conn.cursor()
    cursor.execute("""
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = 'public';
    """)
    tables = cursor.fetchall()
    schema_statements = {}

    for table in tables:
        table_name = table[0]
        get_create_statement_query = sql.SQL("""
            SELECT 'CREATE TABLE ' || tablename || ' (' ||
                array_to_string(array_agg(column_def), ', ') || ');'
            FROM (
                SELECT table_name AS tablename,
                       column_name || ' ' || data_type || 
                       COALESCE('(' || character_maximum_length || ')', '') || 
                       CASE WHEN is_nullable = 'NO' THEN ' NOT NULL' ELSE '' END AS column_def
                FROM information_schema.columns
                WHERE table_schema = 'public' AND table_name = %s
                ORDER BY ordinal_position
            ) AS table_definitions
            GROUP BY tablename;
        """)

        cursor.execute(get_create_statement_query, (table_name,))
        create_statement = cursor.fetchone()[0]
        schema_statements[table_name] = create_statement

    cursor.close()
    conn.close()

    # Store all CREATE TABLE statements in a single variable
    schema_output = ""
    for table_name, statement in schema_statements.items():
        schema_output += f"-- Table: {table_name}\n{statement}\n\n"

    return schema_output

# Store the schema in a variable





In [25]:
database_schema = get_database_schema_as_create_statements()

In [26]:
database_schema

'-- Table: rental\nCREATE TABLE rental (rental_id integer NOT NULL, rental_date timestamp without time zone NOT NULL, inventory_id integer NOT NULL, customer_id smallint NOT NULL, return_date timestamp without time zone, staff_id smallint NOT NULL, last_update timestamp without time zone NOT NULL);\n\n-- Table: staff\nCREATE TABLE staff (staff_id integer NOT NULL, first_name character varying(45) NOT NULL, last_name character varying(45) NOT NULL, address_id smallint NOT NULL, email character varying(50), store_id smallint NOT NULL, active boolean NOT NULL, username character varying(16) NOT NULL, password character varying(40), last_update timestamp without time zone NOT NULL, picture bytea);\n\n-- Table: payment\nCREATE TABLE payment (payment_id integer NOT NULL, customer_id smallint NOT NULL, staff_id smallint NOT NULL, rental_id integer NOT NULL, amount numeric NOT NULL, payment_date timestamp without time zone NOT NULL);\n\n-- Table: actor\nCREATE TABLE actor (actor_id integer NOT

In [27]:

vn.train(ddl=f"{database_schema}")


vn.train(documentation="The DVD Rental database is a schema that represents a DVD rental business. It contains tables for actors, films, rentals, payments, customers, staff, stores, and inventory. Each table has specific columns that reflect real-world attributes of a rental store.")

# You can also add SQL queries to your training data. This is useful if you have some queries already laying around. You can just copy and paste those from your editor to begin generating new SQL.
vn.train(sql="""
    SELECT * FROM film WHERE rental_rate > 5.00;
    
    SELECT first_name, last_name 
    FROM actor 
    WHERE actor_id IN (
        SELECT actor_id 
        FROM film_actor 
        WHERE film_id = 1
    );

    SELECT customer_id, SUM(amount) as total_spent 
    FROM payment 
    WHERE payment_date >= '2022-01-01' 
    GROUP BY customer_id;

    SELECT c.name AS category, SUM(p.amount) AS total_revenue
    FROM payment p
    JOIN rental r ON p.rental_id = r.rental_id
    JOIN inventory i ON r.inventory_id = i.inventory_id
    JOIN film f ON i.film_id = f.film_id
    JOIN film_category fc ON f.film_id = fc.film_id
    JOIN category c ON fc.category_id = c.category_id
    GROUP BY c.name;

    SELECT customer_id, COUNT(rental_id) AS rental_count
    FROM rental
    GROUP BY customer_id
    ORDER BY rental_count DESC;
""")


Adding ddl: -- Table: rental
CREATE TABLE rental (rental_id integer NOT NULL, rental_date timestamp without time zone NOT NULL, inventory_id integer NOT NULL, customer_id smallint NOT NULL, return_date timestamp without time zone, staff_id smallint NOT NULL, last_update timestamp without time zone NOT NULL);

-- Table: staff
CREATE TABLE staff (staff_id integer NOT NULL, first_name character varying(45) NOT NULL, last_name character varying(45) NOT NULL, address_id smallint NOT NULL, email character varying(50), store_id smallint NOT NULL, active boolean NOT NULL, username character varying(16) NOT NULL, password character varying(40), last_update timestamp without time zone NOT NULL, picture bytea);

-- Table: payment
CREATE TABLE payment (payment_id integer NOT NULL, customer_id smallint NOT NULL, staff_id smallint NOT NULL, rental_id integer NOT NULL, amount numeric NOT NULL, payment_date timestamp without time zone NOT NULL);

-- Table: actor
CREATE TABLE actor (actor_id integer NO

'6bfb8039021d2674285e65d03fd487d3-sql'

In [28]:
training_data = vn.get_training_data()
training_data

Unnamed: 0,id,training_data_type,question,content
0,2365236-doc,documentation,,The DVD Rental database is a schema that repre...
1,485635-sql,sql,What are the total revenue and total spent by ...,\n SELECT * FROM film WHERE rental_rate > 5...
2,341813-ddl,ddl,,-- Table: rental\nCREATE TABLE rental (rental_...


In [None]:
from vanna.flask import VannaFlaskApp
app = VannaFlaskApp(vn)
app.run()