In [1]:
import sys

sys.path.append("../")

In [2]:
from vanna.qdrant import Qdrant_VectorStore
from qdrant_client import QdrantClient
from vanna.mistral import Mistral as VannaMistral

from app.settings import settings

In [3]:
import os
from mistralai import Mistral

# Request payload: which Mistral model to call and the single user turn to send.
model = "codestral-latest"
message = [
    {
        "role": "user",
        "content": "Write a function for fibonacci",
    }
]

# Client authenticated with the key held in the project settings object.
api_key = settings.mistral_api_key
client = Mistral(api_key=api_key)

# One chat-completion round trip; the response object is inspected in the next cell.
chat_response = client.chat.complete(model=model, messages=message)

In [4]:
# Display the message text of the first choice in the chat response.
chat_response.choices[0].message.content

'# Fibonacci Function\n\nHere\'s a Python function to calculate the nth Fibonacci number:\n\n```python\ndef fibonacci(n):\n    """\n    Calculate the nth Fibonacci number.\n\n    Args:\n        n (int): The position in the Fibonacci sequence (0-based index)\n\n    Returns:\n        int: The nth Fibonacci number\n    """\n    if n < 0:\n        raise ValueError("Input must be a non-negative integer")\n    elif n == 0:\n        return 0\n    elif n == 1:\n        return 1\n\n    a, b = 0, 1\n    for _ in range(2, n + 1):\n        a, b = b, a + b\n    return b\n```\n\n## Alternative Implementations\n\n### Recursive (less efficient for large n)\n```python\ndef fibonacci_recursive(n):\n    if n < 0:\n        raise ValueError("Input must be a non-negative integer")\n    if n == 0:\n        return 0\n    elif n == 1:\n        return 1\n    else:\n        return fibonacci_recursive(n-1) + fibonacci_recursive(n-2)\n```\n\n### Using memoization (optimized recursive)\n```python\nfrom functools im

In [5]:
# from sentence_transformers import SentenceTransformer

# embedding_model = SentenceTransformer('Qwen/Qwen3-Embedding-0.6B')
# query_vector = embedding_model.encode("how fraud is committed worldwide?")
# query_vector

In [6]:
# Qdrant client pointed at the URL from the project settings; it is handed to
# the Vanna vector store further down.
qdran_client = QdrantClient(url=settings.qdrant_url)

In [7]:
# hits = qdran_client.search(
#    collection_name="my_documents",
#    query_vector=query_vector,
#    limit=5  # Return 5 closest points
# )

# hits

In [8]:
class MyVanna(Qdrant_VectorStore, VannaMistral):
    """Vanna class combining the Qdrant vector store with the Mistral LLM.

    ``config`` is forwarded to ``Qdrant_VectorStore`` only. The Mistral side
    is always configured with the API key from ``settings`` and the
    ``codestral-latest`` model, regardless of what ``config`` contains.
    """

    def __init__(self, config=None):
        """Initialise both parent classes explicitly, vector store first."""
        Qdrant_VectorStore.__init__(self, config=config)

        llm_config = {
            'api_key': settings.mistral_api_key,
            'model': 'codestral-latest',
        }
        VannaMistral.__init__(self, config=llm_config)


# Single shared Vanna instance used by the rest of the notebook.
vn = MyVanna(config={'client': qdran_client})

In [None]:
from urllib.parse import unquote, urlparse


def _unquote_or_none(component):
    """Percent-decode a URL component, passing ``None``/empty values through.

    ``urlparse`` returns ``username``, ``password`` and ``path`` exactly as
    they appear in the URL, so credentials containing reserved characters
    (e.g. ``@`` written as ``%40``) must be decoded before they are used.
    """
    return unquote(component) if component else component


# Split the Postgres URL from settings into the separate keyword arguments
# that connect_to_postgres takes.
parsed_url = urlparse(settings.postgres_url)
vn.connect_to_postgres(
    host=parsed_url.hostname,
    dbname=_unquote_or_none(parsed_url.path[1:]),  # drop the leading "/"
    user=_unquote_or_none(parsed_url.username),
    password=_unquote_or_none(parsed_url.password),
    port=parsed_url.port,
)


In [10]:
# Fetch three rows from fraud_data through the Vanna connection and display them.
vn.run_sql("SELECT * FROM fraud_data LIMIT 3")

Unnamed: 0,trans_date_trans_time,cc_num,merchant,category,amt,first,last,gender,street,city,...,lat,long,city_pop,job,dob,trans_num,unix_time,merch_lat,merch_long,is_fraud
0,2020-06-21 12:14:25,2291163933867244,fraud_Kirlin and Sons,personal_care,2.86,Jeff,Elliott,M,351 Darlene Green,Columbia,...,33.9659,-80.9355,333497,Mechanical engineer,1968-03-19,2da90c7d74bd46a0caf3777415b3ebd3,1371816865,33.986391,-81.200714,False
1,2020-06-21 12:14:33,3573030041201292,fraud_Sporer-Keebler,personal_care,29.84,Joanne,Williams,F,3638 Marsh Union,Altonah,...,40.3207,-110.436,302,"Sales professional, IT",1990-01-17,324cc204407e99f51b0d6ca0055005e7,1371816873,39.450498,-109.960431,False
2,2020-06-21 12:14:53,3598215285024754,"fraud_Swaniawski, Nitzsche and Welch",health_fitness,41.28,Ashley,Lopez,F,9333 Valentine Point,Bellmore,...,40.6729,-73.5365,34496,"Librarian, public",1970-10-21,c81755dbbbea9d5c77f094348a7579be,1371816893,40.49581,-74.196111,False


# 'Train' Vanna (no model is actually trained; the DDL, documentation, and question–SQL pairs are just stored as vectors)

In [11]:
# vn.train(ddl="""
# CREATE TABLE public.fraud_data (
# 	trans_date_trans_time timestamp NULL,
# 	cc_num int8 NULL,
# 	merchant text NULL,
# 	category text NULL,
# 	amt float8 NULL,
# 	"first" text NULL,
# 	"last" text NULL,
# 	gender text NULL,
# 	street text NULL,
# 	city text NULL,
# 	state text NULL,
# 	zip int8 NULL,
# 	lat float8 NULL,
# 	long float8 NULL,
# 	city_pop int8 NULL,
# 	job text NULL,
# 	dob text NULL,
# 	trans_num text NULL,
# 	unix_time int8 NULL,
# 	merch_lat float8 NULL,
# 	merch_long float8 NULL,
# 	is_fraud bool NULL
# )
# """)

In [12]:
# vn.train(documentation="This is a simulated credit card transaction dataset containing legitimate and fraud transactions from the duration 1st Jan 2019 - 31st Dec 2020. It covers credit cards of 1000 customers doing transactions with a pool of 800 merchants.")

In [13]:
# Question/SQL example pairs for Vanna. Each entry is a dict with a
# natural-language "question" and the PostgreSQL "sql" that answers it against
# public.fraud_data.
#
# The two distance-based queries compute great-circle distance with the
# spherical law of cosines. The value passed to acos() is clamped to [-1, 1]
# because floating-point rounding can push it marginally outside that range
# (e.g. for coincident points), which makes PostgreSQL raise
# "input is out of range".
training_data = [
    {
        "question": "How many transactions are there?",
        "sql": """
            SELECT COUNT(*) AS total_tx
            FROM public.fraud_data;
        """
    },
    {
        "question": "How many unique card numbers and merchants?",
        "sql": """
            SELECT COUNT(DISTINCT cc_num) AS cards,
                   COUNT(DISTINCT merchant) AS merchants
            FROM public.fraud_data;
        """
    },
    {
        "question": "What is the overall fraud rate?",
        "sql": """
            SELECT AVG((is_fraud)::int)::numeric(10,4) AS fraud_rate
            FROM public.fraud_data;
        """
    },
    {
        "question": "What is the daily fraud rate over the last two years?",
        "sql": """
            WITH bounds AS (
                SELECT (MAX(trans_date_trans_time) - INTERVAL '2 years') AS start_dt
                FROM public.fraud_data
            )
            SELECT date_trunc('day', trans_date_trans_time)::date AS day,
                COUNT(*) AS n_tx,
                AVG((is_fraud)::int)::numeric(10,4) AS fraud_rate
            FROM public.fraud_data f
            JOIN bounds b ON f.trans_date_trans_time >= b.start_dt
            GROUP BY 1
            ORDER BY 1;
        """
    },
    {
        "question": "What is the monthly fraud rate over the last two years?",
        "sql": """
            WITH bounds AS (
                SELECT (MAX(trans_date_trans_time) - INTERVAL '2 years') AS start_dt
                FROM public.fraud_data
            )
            SELECT date_trunc('month', trans_date_trans_time)::date AS month,
                COUNT(*) AS n_tx,
                SUM((is_fraud)::int) AS fraud_tx,
                AVG((is_fraud)::int)::numeric(10,4) AS fraud_rate
            FROM public.fraud_data f
            JOIN bounds b ON f.trans_date_trans_time >= b.start_dt
            GROUP BY 1
            ORDER BY 1;
        """
    },
    {
        "question": "Which merchants have the highest number of fraudulent transactions?",
        "sql": """
            SELECT merchant,
                   COUNT(*) AS n_tx,
                   SUM((is_fraud)::int) AS fraud_tx,
                   (SUM((is_fraud)::int)::float / COUNT(*)) AS fraud_rate
            FROM public.fraud_data
            GROUP BY merchant
            HAVING COUNT(*) >= 100
            ORDER BY fraud_tx DESC
            LIMIT 20;
        """
    },
    {
        "question": "Which merchant categories have the highest fraud rate?",
        "sql": """
            SELECT category,
                   COUNT(*) AS n_tx,
                   AVG((is_fraud)::int) AS fraud_rate
            FROM public.fraud_data
            GROUP BY category
            HAVING COUNT(*) >= 1000
            ORDER BY fraud_rate DESC
            LIMIT 20;
        """
    },
    {
        "question": "How much higher are fraud rates when the transaction counterpart is located far away (proxy for outside EEA)?",
        "sql": """
            WITH dist AS (
              SELECT is_fraud,
                     6371 * acos(
                       LEAST(1.0, GREATEST(-1.0,
                         cos(radians(lat)) * cos(radians(merch_lat)) *
                         cos(radians(merch_long) - radians(long)) +
                         sin(radians(lat)) * sin(radians(merch_lat))
                       ))
                     ) AS km
              FROM public.fraud_data
              WHERE lat IS NOT NULL AND long IS NOT NULL
                AND merch_lat IS NOT NULL AND merch_long IS NOT NULL
            )
            SELECT CASE
                     WHEN km >= 1000 THEN 'outside_region'
                     ELSE 'inside_region'
                   END AS region_flag,
                   COUNT(*) AS n_tx,
                   AVG((is_fraud)::int) AS fraud_rate
            FROM dist
            GROUP BY 1
            ORDER BY 1;
        """
    },
    # NOTE(review): the dataset description used elsewhere in this notebook
    # says the data spans 2019-2020, so this 2023 window presumably matches no
    # rows -- confirm whether the dates are intentional. NULLIF guards the
    # division so a zero total yields NULL rather than a division-by-zero error.
    {
        "question": "What share of total fraud value in H1 2023 was due to cross-border transactions (proxy via distance)?",
        "sql": """
            WITH base AS (
              SELECT *,
                     6371 * acos(
                       LEAST(1.0, GREATEST(-1.0,
                         cos(radians(lat)) * cos(radians(merch_lat)) *
                         cos(radians(merch_long) - radians(long)) +
                         sin(radians(lat)) * sin(radians(merch_lat))
                       ))
                     ) AS km
              FROM public.fraud_data
              WHERE trans_date_trans_time >= DATE '2023-01-01'
                AND trans_date_trans_time <  DATE '2023-07-01'
                AND is_fraud = TRUE
            )
            SELECT SUM(CASE WHEN km >= 1000 THEN amt ELSE 0 END) / NULLIF(SUM(amt), 0) AS share_cross_border
            FROM base;
        """
    }
]

In [14]:
# for item in training_data:
#     vn.train(question=item["question"], sql=item["sql"])


In [15]:
def get_txt2sql_answer(question):
    """Answer a natural-language question by generating SQL and running it.

    The question is passed to ``vn.generate_sql`` and the SQL it returns is
    handed directly to ``vn.run_sql``; the result of that call is returned
    unchanged.
    """
    return vn.run_sql(vn.generate_sql(question))

In [16]:
# Run one of the questions from training_data through vn.ask end to end; the
# printed output below shows the SQL prompt that gets built for it.
# NOTE(review): auto_train=False presumably keeps this run from being stored
# as a new training pair -- confirm against the Vanna docs.
result = vn.ask("What is the monthly fraud rate over the last two years?", visualize=False, allow_llm_to_see_data=True, auto_train=False)

SQL Prompt: [{'role': 'system', 'content': 'You are a PostgreSQL expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. \n===Tables \n\nCREATE TABLE public.fraud_data (\n\ttrans_date_trans_time timestamp NULL,\n\tcc_num int8 NULL,\n\tmerchant text NULL,\n\tcategory text NULL,\n\tamt float8 NULL,\n\t"first" text NULL,\n\t"last" text NULL,\n\tgender text NULL,\n\tstreet text NULL,\n\tcity text NULL,\n\tstate text NULL,\n\tzip int8 NULL,\n\tlat float8 NULL,\n\tlong float8 NULL,\n\tcity_pop int8 NULL,\n\tjob text NULL,\n\tdob text NULL,\n\ttrans_num text NULL,\n\tunix_time int8 NULL,\n\tmerch_lat float8 NULL,\n\tmerch_long float8 NULL,\n\tis_fraud bool NULL\n)\n\n\n\n===Additional Context \n\nThis is a simulated credit card transaction dataset containing legitimate and fraud transactions from the duration 1st Jan 2019 - 31st Dec 2020. It covers credit cards of 10

In [17]:
# Second element of the value returned by vn.ask, displayed below as a table
# of monthly transaction counts and fraud rates.
result[1]

Unnamed: 0,month,n_tx,fraud_tx,fraud_rate
0,2020-06-01,30058,133,0.0044
1,2020-07-01,85848,321,0.0037
2,2020-08-01,88759,415,0.0047
3,2020-09-01,69533,340,0.0049
4,2020-10-01,69348,384,0.0055
5,2020-11-01,72635,294,0.004
6,2020-12-01,139538,258,0.0018
