In [1]:
import sys
from pathlib import Path

# Make the parent directory importable (it is expected to hold the `src` package
# imported in the next cell). The path is resolved to an absolute one so the entry
# stays valid if the working directory changes later, and the membership check
# keeps repeated runs of this cell from appending duplicate entries.
PROJECT_ROOT = str(Path("..").resolve())
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
from src.text_to_sql import Postgre_VectorStore, Groq_Chat
from src.settings import settings

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Single settings mapping handed to VannaGroq, which forwards it unchanged to both
# of its base classes.
config = dict(
    api_key=settings.groq_api_key,
    model="llama-3.1-8b-instant",
    database_url=settings.postgres_url,
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
)

class VannaGroq(Postgre_VectorStore, Groq_Chat):
    """Client combining the Postgres vector store with the Groq chat backend."""

    def __init__(self, config=None):
        """Initialise both base classes with the same ``config`` mapping.

        Each base ``__init__`` is invoked explicitly (vector store first, then
        chat) rather than through ``super()``, so both receive ``config``.
        """
        for base in (Postgre_VectorStore, Groq_Chat):
            base.__init__(self, config=config)


# Build the combined client from the shared config.
# NOTE(review): the saved output of this cell echoes the full database URL,
# credentials included — clear the cell output before sharing the notebook.
vn = VannaGroq(config=config)

VectorStore bound to postgresql://myuser:mypassword@localhost:5432/postgres


In [4]:
from urllib.parse import unquote, urlparse

# Split the Postgres URL into the individual keyword arguments that
# connect_to_postgres takes.
parsed_url = urlparse(settings.postgres_url)

# urlparse leaves percent-escapes untouched in the userinfo part, so credentials
# containing reserved characters (e.g. "@" written as "%40") must be decoded
# before they are passed on as plain values.
db_user = unquote(parsed_url.username) if parsed_url.username is not None else None
db_password = unquote(parsed_url.password) if parsed_url.password is not None else None

vn.connect_to_postgres(
    host=parsed_url.hostname,
    dbname=parsed_url.path[1:],  # path is "/<dbname>"; drop the leading slash
    user=db_user,
    password=db_password,
    port=parsed_url.port,
)


In [5]:
# Sanity check: fetch three rows to confirm the connection works and that
# fraud_data is queryable. The result is displayed as this cell's output.
vn.run_sql("SELECT * FROM fraud_data LIMIT 3")

Unnamed: 0,trans_date_trans_time,cc_num,merchant,category,amt,first,last,gender,street,city,...,lat,long,city_pop,job,dob,trans_num,unix_time,merch_lat,merch_long,is_fraud
0,2020-09-05 12:16:08,4410582919485061752,fraud_Friesen Ltd,health_fitness,2.81,David,Cole,M,7955 Allen Orchard Apt. 336,Springfield,...,32.3697,-81.3618,9034,"Therapist, occupational",1969-03-02,2a2905ca7bbac865a8212f942831094f,1378383368,31.64911,-81.922042,False
1,2020-08-16 21:57:53,30551643947183,fraud_Morissette PLC,shopping_pos,2.78,Morgan,Smith,F,1441 Bradley Place,Grover,...,35.1836,-81.4552,5621,Toxicologist,1973-11-14,9983d4c6a4d15774ffec9d2999d578ae,1376690273,34.717705,-81.627907,False
2,2020-07-23 00:49:26,630451534402,fraud_Kuvalis Ltd,gas_transport,53.64,Rachel,Daniels,F,561 Little Plain Apt. 738,Wetmore,...,46.3535,-86.6345,765,Immunologist,1972-06-12,f598c71191a3ca4e78e8cc3083dcf8d1,1374540566,45.77693,-86.230599,False


# 'Train' Vanna (no model is actually trained; this only stores the DDL, documentation, and question–SQL pairs in the vector store for retrieval)

In [6]:
# Schema of the fraud_data table, registered with the client as DDL context.
fraud_data_ddl = """
CREATE TABLE public.fraud_data (
	trans_date_trans_time timestamp NULL,
	cc_num int8 NULL,
	merchant text NULL,
	category text NULL,
	amt float8 NULL,
	"first" text NULL,
	"last" text NULL,
	gender text NULL,
	street text NULL,
	city text NULL,
	state text NULL,
	zip int8 NULL,
	lat float8 NULL,
	long float8 NULL,
	city_pop int8 NULL,
	job text NULL,
	dob text NULL,
	trans_num text NULL,
	unix_time int8 NULL,
	merch_lat float8 NULL,
	merch_long float8 NULL,
	is_fraud bool NULL
)
"""

vn.train(ddl=fraud_data_ddl)

Adding ddl: 
CREATE TABLE public.fraud_data (
	trans_date_trans_time timestamp NULL,
	cc_num int8 NULL,
	merchant text NULL,
	category text NULL,
	amt float8 NULL,
	"first" text NULL,
	"last" text NULL,
	gender text NULL,
	street text NULL,
	city text NULL,
	state text NULL,
	zip int8 NULL,
	lat float8 NULL,
	long float8 NULL,
	city_pop int8 NULL,
	job text NULL,
	dob text NULL,
	trans_num text NULL,
	unix_time int8 NULL,
	merch_lat float8 NULL,
	merch_long float8 NULL,
	is_fraud bool NULL
)



In [7]:
# Free-text description of the dataset, registered as documentation context.
dataset_description = (
    "This is a simulated credit card transaction dataset containing legitimate and fraud "
    "transactions from the duration 1st Jan 2019 - 31st Dec 2020. "
    "It covers credit cards of 1000 customers doing transactions with a pool of 800 merchants."
)

vn.train(documentation=dataset_description)

Adding documentation....


In [8]:
# Question/SQL example pairs registered with the client one by one below.
#
# The two distance-based examples compute great-circle distance with the spherical
# law of cosines. Floating-point rounding can push the cosine marginally outside
# [-1, 1] (e.g. for identical points), and PostgreSQL's acos() then raises
# "input is out of range". The cosine is therefore computed in its own CTE and
# clamped with CASE before acos(). CASE is used instead of LEAST/GREATEST because
# those ignore NULL arguments and would turn a NULL coordinate into a bogus
# distance; with CASE a NULL cosine stays NULL.
training_data = [
    {
        "question": "How many transactions are there?",
        "sql": """
            SELECT COUNT(*) AS total_tx
            FROM public.fraud_data;
        """
    },
    {
        "question": "How many unique card numbers and merchants?",
        "sql": """
            SELECT COUNT(DISTINCT cc_num) AS cards,
                   COUNT(DISTINCT merchant) AS merchants
            FROM public.fraud_data;
        """
    },
    {
        "question": "What is the overall fraud rate?",
        "sql": """
            SELECT AVG((is_fraud)::int)::numeric(10,4) AS fraud_rate
            FROM public.fraud_data;
        """
    },
    {
        "question": "What is the daily fraud rate over the last two years?",
        "sql": """
            WITH bounds AS (
                SELECT (MAX(trans_date_trans_time) - INTERVAL '2 years') AS start_dt
                FROM public.fraud_data
            )
            SELECT date_trunc('day', trans_date_trans_time)::date AS day,
                COUNT(*) AS n_tx,
                AVG((is_fraud)::int)::numeric(10,4) AS fraud_rate
            FROM public.fraud_data f
            JOIN bounds b ON f.trans_date_trans_time >= b.start_dt
            GROUP BY 1
            ORDER BY 1;
        """
    },
    {
        "question": "What is the monthly fraud rate over the last two years?",
        "sql": """
            WITH bounds AS (
                SELECT (MAX(trans_date_trans_time) - INTERVAL '2 years') AS start_dt
                FROM public.fraud_data
            )
            SELECT date_trunc('month', trans_date_trans_time)::date AS month,
                COUNT(*) AS n_tx,
                SUM((is_fraud)::int) AS fraud_tx,
                AVG((is_fraud)::int)::numeric(10,4) AS fraud_rate
            FROM public.fraud_data f
            JOIN bounds b ON f.trans_date_trans_time >= b.start_dt
            GROUP BY 1
            ORDER BY 1;
        """
    },
    {
        "question": "Which merchants have the highest number of fraudulent transactions?",
        "sql": """
            SELECT merchant,
                   COUNT(*) AS n_tx,
                   SUM((is_fraud)::int) AS fraud_tx,
                   (SUM((is_fraud)::int)::float / COUNT(*)) AS fraud_rate
            FROM public.fraud_data
            GROUP BY merchant
            HAVING COUNT(*) >= 100
            ORDER BY fraud_tx DESC
            LIMIT 20;
        """
    },
    {
        "question": "Which merchant categories have the highest fraud rate?",
        "sql": """
            SELECT category,
                   COUNT(*) AS n_tx,
                   AVG((is_fraud)::int) AS fraud_rate
            FROM public.fraud_data
            GROUP BY category
            HAVING COUNT(*) >= 1000
            ORDER BY fraud_rate DESC
            LIMIT 20;
        """
    },
    {
        "question": "How much higher are fraud rates when the transaction counterpart is located far away (proxy for outside EEA)?",
        "sql": """
            WITH angles AS (
              SELECT is_fraud,
                     cos(radians(lat)) * cos(radians(merch_lat)) *
                     cos(radians(merch_long) - radians(long)) +
                     sin(radians(lat)) * sin(radians(merch_lat)) AS cos_angle
              FROM public.fraud_data
              WHERE lat IS NOT NULL AND long IS NOT NULL
                AND merch_lat IS NOT NULL AND merch_long IS NOT NULL
            ),
            dist AS (
              SELECT is_fraud,
                     6371 * acos(
                       CASE
                         WHEN cos_angle > 1 THEN 1
                         WHEN cos_angle < -1 THEN -1
                         ELSE cos_angle
                       END
                     ) AS km
              FROM angles
            )
            SELECT CASE
                     WHEN km >= 1000 THEN 'outside_region'
                     ELSE 'inside_region'
                   END AS region_flag,
                   COUNT(*) AS n_tx,
                   AVG((is_fraud)::int) AS fraud_rate
            FROM dist
            GROUP BY 1
            ORDER BY 1;
        """
    },
    # NOTE(review): the dataset documentation registered earlier describes data for
    # 2019-2020 only, so this H1 2023 window presumably matches no rows and the
    # share evaluates to NULL. Kept as a query-pattern example — confirm the period.
    {
        "question": "What share of total fraud value in H1 2023 was due to cross-border transactions (proxy via distance)?",
        "sql": """
            WITH angles AS (
              SELECT amt,
                     cos(radians(lat)) * cos(radians(merch_lat)) *
                     cos(radians(merch_long) - radians(long)) +
                     sin(radians(lat)) * sin(radians(merch_lat)) AS cos_angle
              FROM public.fraud_data
              WHERE trans_date_trans_time >= DATE '2023-01-01'
                AND trans_date_trans_time <  DATE '2023-07-01'
                AND is_fraud = TRUE
            ),
            base AS (
              SELECT amt,
                     6371 * acos(
                       CASE
                         WHEN cos_angle > 1 THEN 1
                         WHEN cos_angle < -1 THEN -1
                         ELSE cos_angle
                       END
                     ) AS km
              FROM angles
            )
            SELECT SUM(CASE WHEN km >= 1000 THEN amt ELSE 0 END) / NULLIF(SUM(amt), 0) AS share_cross_border
            FROM base;
        """
    }
]

# Register every question/SQL pair with the client.
for item in training_data:
    vn.train(question=item["question"], sql=item["sql"])

In [9]:
# Ask one of the registered example questions and keep the generated SQL.
# NOTE(review): allow_llm_to_see_data=True presumably permits query results to be
# sent to the hosted LLM — confirm that is acceptable for this dataset.
monthly_fraud_question = "What is the monthly fraud rate over the last two years?"
result = vn.generate_sql(monthly_fraud_question, allow_llm_to_see_data=True)

SQL Prompt: [{'role': 'system', 'content': 'You are a PostgreSQL expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. \n===Tables \n\nCREATE TABLE public.fraud_data (\n\ttrans_date_trans_time timestamp NULL,\n\tcc_num int8 NULL,\n\tmerchant text NULL,\n\tcategory text NULL,\n\tamt float8 NULL,\n\t"first" text NULL,\n\t"last" text NULL,\n\tgender text NULL,\n\tstreet text NULL,\n\tcity text NULL,\n\tstate text NULL,\n\tzip int8 NULL,\n\tlat float8 NULL,\n\tlong float8 NULL,\n\tcity_pop int8 NULL,\n\tjob text NULL,\n\tdob text NULL,\n\ttrans_num text NULL,\n\tunix_time int8 NULL,\n\tmerch_lat float8 NULL,\n\tmerch_long float8 NULL,\n\tis_fraud bool NULL\n)\n\n\n\n===Additional Context \n\nThis is a simulated credit card transaction dataset containing legitimate and fraud transactions from the duration 1st Jan 2019 - 31st Dec 2020. It covers credit cards of 10