In [None]:
import warnings
warnings.filterwarnings("ignore")

import logging
logging.basicConfig(level=logging.ERROR)

import os
import json
import asyncio
from typing import List, Dict, Any
from dotenv import load_dotenv

from google.adk.agents import Agent
from google.adk.sessions import InMemorySessionService
from google.adk.runners import Runner
from google.genai import types

load_dotenv()
api_key = os.environ['GOOGLE_API_KEY']


In [2]:
import pandas as pd
import json
df = pd.read_csv(r"C:\Users\AKSHAT SHAW\OneDrive - iitr.ac.in\Desktop\Side-Projects\Agents\nl2sql_agent\Copy of Schema_dump_(1)(1).csv")
    

SCHEMA_SAMPLES = {}
for table_name, group in df.groupby('table_name'):
    SCHEMA_SAMPLES[table_name] = {
        'columns': group['column_name'].tolist(),
        'description': f"Schema for {table_name} table",
        'column_details': group.to_dict(orient='records')
    }
    

In [3]:

AGENT_MODEL = "gemini-2.0-flash"


SQL_SAMPLES = [
    {
        "question": "Find all users who registered in the last month",
        "sql": "SELECT user_id, name, email FROM users WHERE registration_date >= DATE_SUB(CURRENT_DATE(), INTERVAL 1 MONTH)"
    },
    {
        "question": "Count the number of orders by status",
        "sql": "SELECT status, COUNT(*) as count FROM orders GROUP BY status ORDER BY count DESC"
    },
    {
        "question": "Find the top 5 products by sales revenue",
        "sql": """
        SELECT p.product_id, p.name, SUM(o.quantity * o.price) as revenue
        FROM order_items o
        JOIN products p ON o.product_id = p.product_id
        GROUP BY p.product_id, p.name
        ORDER BY revenue DESC
        LIMIT 5
        """
    }
]


CUSTOM_INSTRUCTIONS = """
When generating SQL queries, follow these guidelines:
• Include driver license details (license_number, issue_date, expiration_date)
• Format dates according to SQL standards (YYYY-MM-DD)
• Include vehicle registration details when relevant
• Use proper table aliases for clarity
• Include comments explaining complex logic
• Follow best practices for SQL performance
• Use appropriate JOIN types (INNER, LEFT, RIGHT) based on the requirements
• Format the SQL query with proper indentation and line breaks for readability
"""

def search_similar_examples(question: str, samples: List[Dict[str, str]], top_k: int = 3) -> List[Dict[str, str]]:
    # Simplified implementation - just keyword matching
    # In a real implementation, use embeddings and cosine similarity
    results = []
    keywords = question.lower().split()
    
    for sample in samples:
        score = sum(1 for keyword in keywords if keyword in sample["question"].lower())
        if score > 0:
            results.append({"sample": sample, "score": score})
    
    # Sort by score and return top_k
    results.sort(key=lambda x: x["score"], reverse=True)
    return [item["sample"] for item in results[:top_k]]

def get_relevant_tables(question: str, schema: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
   
    relevant_tables = {}
    keywords = question.lower().split()
    
    for table_name, table_info in schema.items():
        # Check if table name is mentioned
        if table_name.lower() in question.lower():
            relevant_tables[table_name] = table_info
            continue
            
        # Check if any column is mentioned
        for column in table_info["columns"]:
            if column.lower() in question.lower():
                relevant_tables[table_name] = table_info
                break
                
        # Check for semantic matches (simplified)
        description_words = table_info["description"].lower().split()
        match_score = sum(1 for keyword in keywords if keyword in description_words)
        if match_score > 1:  # If multiple keywords match the description
            relevant_tables[table_name] = table_info
    
    # If no tables found, return top 2 most likely tables
    if not relevant_tables:
        table_scores = {}
        for table_name, table_info in schema.items():
            # Create a score based on column matches
            score = sum(1 for keyword in keywords for column in table_info["columns"] if keyword in column.lower())
            table_scores[table_name] = score
        
        # Get top 2 tables by score
        top_tables = sorted(table_scores.items(), key=lambda x: x[1], reverse=True)[:2]
        for table_name, _ in top_tables:
            relevant_tables[table_name] = schema[table_name]
    
    return relevant_tables

def generate_system_prompt(question: str) -> str:
    """
    Generates a comprehensive system prompt for SQL query generation.

    Args:
        question (str): The natural language query to generate a system prompt for.

    Returns:
        str: A markdown-formatted system prompt containing:
             - User's original question
             - Similar SQL query examples
             - Relevant database schema information
             - Custom SQL generation instructions

    Process:
        1. Search for similar SQL query examples
        2. Identify relevant database tables
        3. Construct a detailed prompt with context and instructions

    Notes:
        The system prompt is designed to provide maximum context to 
        the SQL query generation model, improving the accuracy of 
        generated queries.

    Example:
        prompt = generate_system_prompt("Show active user accounts")
        # Returns a detailed markdown prompt with examples, schema, 
        # and instructions for generating a SQL query
    """
    # Search for similar examples
    similar_examples = search_similar_examples(question, SQL_SAMPLES)
    
    # Get relevant tables
    relevant_tables = get_relevant_tables(question, SCHEMA_SAMPLES)
    
    # Build the system prompt
    prompt = "# SQL Query Generation\n\n"
    prompt += f"## User Question\n{question}\n\n"
    
    # Add similar examples
    prompt += "## Similar SQL Examples\n"
    for i, example in enumerate(similar_examples, 1):
        prompt += f"### Example {i}\n"
        prompt += f"Question: {example['question']}\n"
        prompt += f"SQL: ```sql\n{example['sql']}\n```\n\n"
    
    # Add relevant schema information
    prompt += "## Relevant Database Schema\n"
    for table_name, table_info in relevant_tables.items():
        prompt += f"### Table: {table_name}\n"
        prompt += f"Description: {table_info['description']}\n"
        prompt += "Columns:\n"
        for column in table_info['columns']:
            prompt += f"- {column}\n"
        prompt += "\n"
    
    # Add custom instructions
    prompt += "## Custom Instructions\n"
    prompt += CUSTOM_INSTRUCTIONS
    
    # Final instruction
    prompt += "\n\nGenerate a SQL query to answer the user's question based on the available schema and examples. Include an explanation of the query."
    
    return prompt

# # Define our SQL generation tool
# def generate_sql_query(query_text: str) -> Dict[str, str]:
#     """
#     Generates a SQL query from a natural language input.
#     """
#     try:
#         # Generate the system prompt with similar examples and schema info
#         system_prompt = generate_system_prompt(query_text)
        
#         # In a real implementation, this would call a language model
#         # For now, we'll return a template
        
#         # This would be replaced with an actual LLM call in production
#         # For demonstration, we'll just return a mock response
#         mock_sql = """
#         SELECT 
#             d.license_number, 
#             d.issue_date, 
#             d.expiration_date,
#             p.name,
#             v.make, 
#             v.model
#         FROM 
#             driver_licenses d
#         JOIN 
#             users p ON d.person_id = p.user_id
#         LEFT JOIN 
#             vehicle_registrations v ON p.user_id = v.owner_id
#         WHERE 
#             d.expiration_date > CURRENT_DATE
#         ORDER BY 
#             d.expiration_date ASC
#         """
        
#         explanation = """
#         This query retrieves driver license information along with owner details and their vehicle information.
#         It joins the driver_licenses table with the users table to get the owner's name.
#         It also performs a LEFT JOIN with vehicle_registrations to include any vehicles registered to the owner.
#         The WHERE clause filters to only include active licenses (not expired).
#         Results are ordered by expiration date (soonest expiring first).
#         """
        
#         return {
#             "sql": mock_sql.strip(),
#             "explanation": explanation.strip(),
#             "system_prompt": system_prompt  # Include the system prompt for transparency
#         }
#     except Exception as e:
#         return {"error": str(e)}


In [None]:

# Create our SQL Agent
sql_agent = Agent(
    name="sql_query_generator",
    model=AGENT_MODEL,
    description="Generates SQL queries from natural language questions using schema metadata and similar examples.",
    instruction=(
        "You are an expert SQL query generator. "
        "Given a query from user, use the tool first 'generate_system_prompt' to generate the contextual query and explanation."
        "Given a natural language question about data, generate an appropriate SQL query. "
        "Use the provided schema metadata and similar examples to create accurate queries. "
        "Your response should include the SQL query and an explanation of what it does and why specific joins, filters, "
        "or aggregations were chosen. Format SQL with proper indentation for readability."
    ),
    tools=[generate_system_prompt],
)

# Create session service and session
session_service = InMemorySessionService()

APP_NAME = "sql_query_app"
USER_ID = "user_1"
SESSION_ID = "session_001"

# Create the specific session where the conversation will happen
session = session_service.create_session(
    app_name=APP_NAME,
    user_id=USER_ID,
    session_id=SESSION_ID
)
print(f"Session created: App='{APP_NAME}', User='{USER_ID}', Session='{SESSION_ID}'")

# Create the runner
runner = Runner(
    agent=sql_agent,
    app_name=APP_NAME,
    session_service=session_service
)
print(f"Runner created for agent '{runner.agent.name}'.")

async def generate_sql_from_nl(query: str):
    """Sends a natural language query to the agent and returns the SQL response."""
    print(f"\n>>> User Query: {query}")

    # Prepare the user's message in ADK format
    content = types.Content(role='user', parts=[types.Part(text=query)])

    final_response_text = "Agent did not produce a final response."  # default

    async for event in runner.run_async(user_id=USER_ID, session_id=SESSION_ID, new_message=content):
        # You can uncomment the line below to see *all* events during execution
        # print(f"  [Event] Author: {event.author}, Type: {type(event).__name__}, Final: {event.is_final_response()}, Content: {event.content}")

        # Key Concept: is_final_response() marks the concluding message for the turn.
        if event.is_final_response():
            if event.content and event.content.parts:
                # Assuming text response in the first part
                final_response_text = event.content.parts[0].text
            elif event.actions and event.actions.escalate:  # Handle potential errors/escalations
                final_response_text = f"Agent escalated: {event.error_message or 'No specific message.'}"
            break

    print(f"<<< Agent Response: {final_response_text}")
    return final_response_text

Session created: App='sql_query_app', User='user_1', Session='session_001'
Runner created for agent 'sql_query_generator'.


In [None]:
test_queries = [
    "what is the total amount in collection_view?",
    "what is emi_status for total_amount greater than 10000 in collection_view"
]

# Run the first test query
result = await generate_sql_from_nl(test_queries[1])
print("Result:", result)


>>> User Query: what is emi_status for total_amount greater than 10000 in collection_view
<<< Agent Response: ```sql
SELECT emi_status
FROM collection_view
WHERE total_overdue_amount > 10000;
```

Explanation:
This SQL query selects the `emi_status` from the `collection_view` table for all entries where the `total_overdue_amount` is greater than 10000. This will give you the EMI statuses associated with those entries that meet the specified condition.
Result: ```sql
SELECT emi_status
FROM collection_view
WHERE total_overdue_amount > 10000;
```

Explanation:
This SQL query selects the `emi_status` from the `collection_view` table for all entries where the `total_overdue_amount` is greater than 10000. This will give you the EMI statuses associated with those entries that meet the specified condition.


In [1]:
import pandas as pd


  from pandas.core import (


In [11]:
df = pd.read_csv(r"C:\Users\AKSHAT SHAW\OneDrive - iitr.ac.in\Desktop\Side-Projects\Agents\nl2sql_agent\app\schema_dump.csv")


df2 = pd.read_excel(r"C:\Users\AKSHAT SHAW\OneDrive - iitr.ac.in\Desktop\Side-Projects\Agents\nl2sql_agent\app\Schema dump.xlsx", sheet_name="Table Desc")

In [12]:
df

Unnamed: 0,table_schema,table_name,column_name,data_type
0,sttash_website_live,bureau_account_history,id,bigint
1,sttash_website_live,bureau_account_history,customer_id,bigint
2,sttash_website_live,bureau_account_history,loan_id,bigint
3,sttash_website_live,bureau_account_history,report_id,bigint
4,sttash_website_live,bureau_account_history,account_list_id,bigint
...,...,...,...,...
2637,,,,
2638,,,,
2639,,,,
2640,,,,


In [14]:
df2

Unnamed: 0,Table Name,Table description,Primary column
0,bureau_account_history,This table contains the complete tradeline (cr...,"Customer_id,loan_id,report_id &dpd"
1,bureau_account_list,This table provides a summary view of all cred...,"Customer_id,loan_id,report_id"
2,bureau_account_summary,This table contains summarized information for...,"Customer_id,loan_id,report_id"
3,bureau_account_type,The table bureau_account_type is typically a m...,"id,account_type,loan_type"
4,bureau_address,bureau_address contains the address records as...,"customer_id,report_id, address,datereported"
...,...,...,...
64,st_razorpay_trasfer,Transaction table for razorpay gateway which m...,"customer_id,payment_id,overall_payment_id"
65,st_razorpay_validate_bankacc,This table stores penny drop/ reverse penny dr...,"customer_id,service_type,response_json"
66,st_state,Master table of state name and its id,"id,state_name"
67,st_test_leads,Testing lead table for QA team,"customer_id,active"


In [None]:
merged_df = df.merge(df2[['Table Name', 'Table description']], 
                     left_on='table_name', 
                     right_on='Table Name', 
                     how='left')
merged_df.drop(columns=['Table Name'], inplace=True)

          table_schema              table_name      column_name data_type  \
0  sttash_website_live  bureau_account_history               id    bigint   
1  sttash_website_live  bureau_account_history      customer_id    bigint   
2  sttash_website_live  bureau_account_history          loan_id    bigint   
3  sttash_website_live  bureau_account_history        report_id    bigint   
4  sttash_website_live  bureau_account_history  account_list_id    bigint   

                                   Table description  
0  This table contains the complete tradeline (cr...  
1  This table contains the complete tradeline (cr...  
2  This table contains the complete tradeline (cr...  
3  This table contains the complete tradeline (cr...  
4  This table contains the complete tradeline (cr...  


In [20]:
merged_df.dropna(inplace= True)

In [22]:
merged_df.to_csv("Final_schema.csv", index=False)

In [3]:
merged_df = pd.read_csv(r"C:\Users\AKSHAT SHAW\OneDrive - iitr.ac.in\Desktop\Side-Projects\Agents\nl2sql_agent\app\Final_schema.csv")

In [4]:
new_df = pd.read_csv("new_schema.csv")

In [11]:
len(merged_df['table_name'].unique())

69

In [5]:
new_df

Unnamed: 0,public,all_bad_dept_and_npa_provision_data,vintage,loan_id,sum_of_principal,entity,product_segment,dpd_group,Unnamed: 8,Unnamed: 9,...,Unnamed: 139,Unnamed: 140,Unnamed: 141,Unnamed: 142,Unnamed: 143,Unnamed: 144,Unnamed: 145,Unnamed: 146,Unnamed: 147,Unnamed: 148
0,public,closed_loans,loan_id,customer_id,last_pay_date,,,,,,...,,,,,,,,,,
1,public,crif_aggregated_featurs_jan_oct_2024,customer_id,candidate_no,account_type,credit_limit,closed_loans_flag,closed_loans_dates,,,...,,,,,,,,,,
2,public,crif_model_validation_onus_customers,loan_id,customer_id,loan_approval_date,loan_disbursal_date,approved_amount,disbursal_amount,productgroup,approved_rate,...,,,,,,,,,,
3,public,crif_model_validation_rejected_customers,customer_id,create_date,occupation,cibil_report_id,cibil_report_create_date,cibil_score,rk,max_dpd_in_03,...,,,,,,,,,,
4,public,crif_scrub_scorecard_jan_oct24_features,customer_id,create_date,crif_bureau_score,first_credit_open_date,latest_loan_closed_date,no_of_tradelines,no_of_active_tradeline,no_of_closed_tradeline,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
106,sttash_website_live,userdeviceapps,id,customer_id,app_name,package_name,version_name,version_code,installation_date,last_used_date,...,,,,,,,,,,
107,sttash_website_live,westerncap_request_log,id,url,headers,post_data,resp,loan_id,customer_id,create_date,...,,,,,,,,,,
108,sttash_website_live,st_emi_status,id,emi_status,status,,,,,,...,,,,,,,,,,
109,sttash_website_live,st_loans,ERROR_FETCHING_COLUMNS (Database error: relati...,,,,,,,,...,,,,,,,,,,


In [27]:
SCHEMA_SAMPLES = {}
for table_name, group in merged_df.groupby('table_name'):
    description = group['Table description'].dropna().unique()
    if len(description) > 0:
        table_description = description[0]
    else:
        table_description = f"Schema for {table_name} table"
    SCHEMA_SAMPLES[table_name] = {
        'columns': group['column_name'].tolist(),
        'description': table_description,
        'column_details': group.to_dict(orient='records')
    }
   

In [None]:
import faiss
import pickle
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", token=False)

def build_and_save_vector_store(schema: dict, index_path: str, metadata_path: str):
    texts = []
    table_keys = []
    
    for table_name, table_info in schema.items():
        text_blob = f"{table_name}. {table_info['description']}. {' '.join(table_info['columns'])}"
        texts.append(text_blob)
        table_keys.append(table_name)
    
    embeddings = embedder.encode(texts, normalize_embeddings=True)
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)
    index.add(embeddings)
    
    # Save FAISS index
    faiss.write_index(index, index_path)
    
    # Save table_keys
    with open(metadata_path, 'wb') as f:
        pickle.dump(table_keys, f)

def load_vector_store(index_path: str, metadata_path: str):
    index = faiss.read_index(index_path)
    with open(metadata_path, 'rb') as f:
        table_keys = pickle.load(f)
    return index, table_keys


build_and_save_vector_store(SCHEMA_SAMPLES, 'table_index.faiss', 'table_keys.pkl')

# Load (anytime later)
index, table_keys = load_vector_store('table_index.faiss', 'table_keys.pkl')


Relevant tables: {'bureau_account_history': "This table contains the complete tradeline (credit account) history of customers as reported by credit bureaus. It includes detailed information about each credit account such as account type, opening date, current balance, credit limit, payment history, and account status. The table is used to assess a customer's credit behavior and exposure over time.", 'overall_payment': 'This table typically stores comprehensive details related to loan/emi payments, tracking financial transactions for borrowers', 'installment_fip': 'The Installment FIP table stores essential details related to customer loan repayments, helping to track installment schedules and payment behavior, days past due etc.'}


In [37]:
query_embedding = embedder.encode(["hello"], normalize_embeddings=True)
D, I = index.search(query_embedding, 5)

In [42]:

# Search example
def search_relevant_tables(query: str, index, table_keys, top_k=3):
    query_embedding = embedder.encode([query], normalize_embeddings=True)
    D, I = index.search(query_embedding, top_k)
    results = {table_keys[i]: SCHEMA_SAMPLES[table_keys[i]] for i in I[0]}
    return results

query = "customer payment history"
relevant_tables = search_relevant_tables(query, index, table_keys)
print("Relevant tables:", relevant_tables)

Relevant tables: {'bureau_account_history': {'columns': ['id', 'customer_id', 'loan_id', 'report_id', 'account_list_id', 'history_date', 'month_number', 'dpd_char', 'dpd', 'payment_status', 'suit_filed_status', 'asset_classification_status', 'create_date'], 'description': "This table contains the complete tradeline (credit account) history of customers as reported by credit bureaus. It includes detailed information about each credit account such as account type, opening date, current balance, credit limit, payment history, and account status. The table is used to assess a customer's credit behavior and exposure over time.", 'column_details': [{'table_schema': 'sttash_website_live', 'table_name': 'bureau_account_history', 'column_name': 'id', 'data_type': 'bigint', 'Table description': "This table contains the complete tradeline (credit account) history of customers as reported by credit bureaus. It includes detailed information about each credit account such as account type, opening da

In [8]:
from pymongo import MongoClient
from openai import OpenAI
import os
from dotenv import load_dotenv
load_dotenv()
MONGO_URI = os.getenv("MONGO_URI")
DB_NAME = "sql_agent"
COLLECTION_NAME = "test1"
api_key=os.getenv("OPENAI_API_KEY")
def get_openai_embedding(text: str, model="text-embedding-ada-002") -> list:
    client = OpenAI(api_key=api_key)
    response = client.embeddings.create(
        input=text,
        model=model
    )
    return response.data[0].embedding

def search_mongodb_tables(query: str, mongo_uri: str, db_name: str, collection_name: str, top_k: int = 5):
    # Connect to MongoDB
    client = MongoClient(mongo_uri)
    db = client[db_name]
    collection = db[collection_name]

    # Embed the query
    query_embedding = get_openai_embedding(query)

    # Perform vector search with $vectorSearch
    pipeline = [
        {
            "$vectorSearch": {
                "index": "vector_index",  # The name of your Atlas vector index
                "path": "embedding",
                "queryVector": query_embedding,
                "numCandidates": 100,     # Number of candidates to search over
                "limit": top_k
            }
        },
        {
            "$project": {
                "_id": 0,
                "table_name": 1,
                "description": 1,
                "columns": 1,
                "score": {"$meta": "vectorSearchScore"}
            }
        }
    ]

    results = list(collection.aggregate(pipeline))
    return results

In [11]:
query = "what is the total loan amount pending?"
result = search_mongodb_tables(query, MONGO_URI, DB_NAME, COLLECTION_NAME)

In [12]:
result

[{'table_name': 'overall_payment',
  'description': 'This table typically stores comprehensive details related to loan/emi payments, tracking financial transactions for borrowers',
  'columns': ['id',
   'customer_id',
   'loan_id',
   'amt_payment',
   'received_date',
   'cheque_number',
   'urm_no',
   'payment_channel',
   'transaction_id',
   'ref_no',
   'utr_no',
   'neft_bank',
   'presentation_status',
   'bounce_reason',
   'presentation_date',
   'create_date',
   'extra_amount',
   'extra_amount_pif',
   'remarks',
   'add_user_id',
   'update_user_id',
   'is_delete',
   'is_refund',
   'update_date',
   'transaction_commit_status'],
  'score': 0.897179126739502},
 {'table_name': 'st_fip_detail',
  'description': 'This table stores disbursed loan and applied all the charges with loans',
  'columns': ['id',
   'customer_id',
   'loan_id',
   'requested_amount',
   'approved_amount',
   'processing_transaction_fee',
   'processing_fees_rate',
   'pf_tf_gst',
   'transaction_

# Stroing Vector Embd in MongoDB for the new csv file

In [1]:
import pandas as pd

  from pandas.core import (


In [41]:
df = pd.read_csv(r"C:\Users\AKSHAT SHAW\OneDrive - iitr.ac.in\Desktop\Side-Projects\Agents\nl2sql_agent\app\GEN_AI Use cases & Data Dict (1).csv")

In [42]:
df.dropna(inplace=True, how ="all")

In [43]:
df.fillna(value=" ", inplace=True)

In [44]:
df['Description'] = df['Description'].astype(str) + " " + df['Remarks'].astype(str)
df.drop(columns=['Remarks'], inplace=True)

In [45]:
df

Unnamed: 0,Sno.,Column,Description,Data Type
0,1.0,loan_id,unique id for loan,bigint
1,2.0,customer_id,customer unique id,bigint
2,3.0,loan_creation_date,Should be used for date filter - lead generati...,date
3,4.0,disbursal_month,the month loan is disbursed to the customer,date
4,5.0,utm_source,elev8_ref_prg\r\nelevateassigned\r\nvaluelea...,character varying
...,...,...,...,...
83,84.0,outstanding_amt_active_unsecured,Bureau feature of the lead,double precision
84,85.0,max_pl_closed_limit_2yrs,Bureau feature of the lead,double precision
85,86.0,max_active_cc_limit,Bureau feature of the lead,double precision
86,87.0,bureau_age,Bureau feature of the lead,double precision


In [48]:
from openai import OpenAI
import numpy as np
from pymongo import MongoClient
from typing import Dict
import os
from dotenv import load_dotenv
load_dotenv()
import pandas as pd

# updated schema with table description 
SCHEMA_SAMPLES = {}

for _, row in df.iterrows():
    column = row['Column']
    description = row['Description']
    data_type = row['Data Type']
    SCHEMA_SAMPLES[column] = {
        'description': description,
        'data_type': data_type
    }
    
# Make sure your OpenAI API key is set
# openai.api_key = os.getenv("OPENAI_API_KEY")
MONGO_URI = os.getenv("MONGO_URI")
DB_NAME = "sql_agent"
COLLECTION_NAME = "test2"


def get_openai_embedding(text: str, model="text-embedding-ada-002") -> list:
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    response = client.embeddings.create(
        input=text,
        model=model
    )
    return response.data[0].embedding

def build_and_save_vector_store(schema: Dict, mongo_uri: str, db_name: str, collection_name: str):
    client = MongoClient(mongo_uri)
    db = client[db_name]
    collection = db[collection_name]

    # Optional: clear previous entries
    collection.delete_many({})

    for table_name, table_info in schema.items():
        text_blob = f"{table_name}. {table_info['description']}. {' '.join(table_info['data_type'])}"
        embedding = get_openai_embedding(text_blob)

        document = {
            "table_name": table_name,
            "description": table_info['description'],
            "data_type": table_info['data_type'],
            "embedding": embedding
        }

        collection.insert_one(document)

    print("Vector store built and stored in MongoDB.")

if __name__== "__main__":
    build_and_save_vector_store(SCHEMA_SAMPLES, MONGO_URI, DB_NAME, COLLECTION_NAME)


Vector store built and stored in MongoDB.


In [53]:
from pymongo import MongoClient
from openai import OpenAI
import os
from dotenv import load_dotenv
load_dotenv()
MONGO_URI = os.getenv("MONGO_URI")
DB_NAME = "sql_agent"
COLLECTION_NAME = "test2"
api_key=os.getenv("OPENAI_API_KEY")
def get_openai_embedding(text: str, model="text-embedding-ada-002") -> list:
    client = OpenAI(api_key=api_key)
    response = client.embeddings.create(
        input=text,
        model=model
    )
    return response.data[0].embedding

def search_mongodb_tables(query: str, mongo_uri: str, db_name: str, collection_name: str, top_k: int = 5):
    # Connect to MongoDB
    client = MongoClient(mongo_uri)
    db = client[db_name]
    collection = db[collection_name]

    # Embed the query
    query_embedding = get_openai_embedding(query)

    # Perform vector search with $vectorSearch
    pipeline = [
        {
            "$vectorSearch": {
                "index": "vector_index",  # The name of your Atlas vector index
                "path": "embedding",
                "queryVector": query_embedding,
                "numCandidates": 100,     # Number of candidates to search over
                "limit": top_k
            }
        },
        {
            "$project": {
                "_id": 0,
                "table_name": 1,
                "description": 1,
                "data_type": 1,
                "score": {"$meta": "vectorSearchScore"}
            }
        }
    ]

    results = list(collection.aggregate(pipeline))
    return results

In [54]:
query = "deo_passed"
result = search_mongodb_tables(query, MONGO_URI, DB_NAME, COLLECTION_NAME, top_k=3)
result

[{'table_name': 'deo_passed',
  'description': '1 if passed else 0  ',
  'data_type': 'bigint',
  'score': 0.9369359016418457},
 {'table_name': 'chr_passed',
  'description': '1 if passed else 0  ',
  'data_type': 'bigint',
  'score': 0.9168753623962402},
 {'table_name': 'jud_passed',
  'description': '1 if passed else 0  ',
  'data_type': 'bigint',
  'score': 0.9155865907669067}]

# Testing PostgreSQL

In [None]:
# !pip install psycopg2-binary



In [3]:
import psycopg2

# Define connection parameters
conn = psycopg2.connect(
    host = "turntable.proxy.rlwy.net",
    port = "55145",
    dbname = "railway",
    user = "postgres",
    password = "tCwnKtFmoCkmBgReaPXShKFSRhMxpnqj"
)


In [None]:
sql = 'SELECT CAST(SUM(CASE WHEN loan_status = 'Under Processing' THEN 1 ELSE 0 END) AS FLOAT) * 100 / COUNT(*) AS percentage_under_processing FROM loan_data;'

In [11]:
# Define connection parameters
conn = psycopg2.connect(
    host = "turntable.proxy.rlwy.net",
    port = "55145",
    dbname = "railway",
    user = "postgres",
    password = "tCwnKtFmoCkmBgReaPXShKFSRhMxpnqj"
)

# Open cursor and execute SQL file
try:
    with conn:
        with conn.cursor() as cursor:
            with open('script.sql', 'r') as f:
                sql = f.read()
                cursor.execute(sql)
                print(cursor.fetchall())
    print("SQL script executed successfully.")
except Exception as e:
    print(f"Error executing SQL: {e}")
finally:
    conn.close()


[(0.0,)]
SQL script executed successfully.
