In [1]:
#!/usr/bin/env python3
"""
GGUF Model Loader - Based on your working notebook code
Uses the exact same approach that worked in gemma_quantization.ipynb
"""

import os
from huggingface_hub import hf_hub_download
from llama_cpp import Llama

class NIRFRankingModel:
    def __init__(self, model_path):
        """Initialize the NIRF ranking model - Same as your notebook"""
        self.model_path = model_path
        self.llm = None
        self.load_model()
    
    def load_model(self):
        """Load the GGUF model - Exact same code from your notebook"""
        try:
            print(f"Loading NIRF model from: {self.model_path}")
            self.llm = Llama(
                model_path=self.model_path,
                n_ctx=2048,  # Context window
                n_threads=4,  # Number of CPU threads  
                verbose=False
            )
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
    
    def generate_response(self, prompt, max_tokens=512, temperature=0.3):
        """Generate response - Same as your notebook"""
        if not self.llm:
            raise Exception("Model not loaded")
        
        try:
            response = self.llm(
                prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9,
                repeat_penalty=1.1,
                stop=["</s>", "\n\n"]
            )
            return response['choices'][0]['text'].strip()
        except Exception as e:
            print(f"Error generating response: {e}")
            return None
    
    def ask_nirf_question(self, question):
        """Ask NIRF question - Same as your notebook"""
        print(f"\nQuestion: {question}")
        print("Generating answer...")
        
        response = self.generate_response(question, max_tokens=300)
        if response:
            print(f"Answer: {response}")
        else:
            print("Sorry, I couldn't generate a response.")

def download_model():
    """Download the GGUF model from HuggingFace"""
    print("üì• Downloading GGUF model...")
    
    model_path = hf_hub_download(
        repo_id="coderop12/gemma2b-nirf-lookup-gguf",
        filename="gemma2b-nirf-lookup-f16.gguf",
        local_dir="./gguf_output",  # Same directory structure as your notebook
        local_dir_use_symlinks=False
    )
    
    print(f"‚úÖ Downloaded to: {model_path}")
    return model_path

def main():
    """Main function - Same structure as your notebook"""
    
    # Download model first
    model_path = download_model()
    
    # Initialize model with exact same approach as your notebook
    nirf_model = NIRFRankingModel(model_path)
    
    # Test with same sample questions from your notebook
    sample_questions = [
        "What is the NIRF ranking of IIT Delhi in 2024?",
        "Which are the top engineering colleges in NIRF 2024?",
        "What is NIRF ranking methodology?",
        "How are universities ranked in India?"
    ]
    
    print("=== NIRF Ranking Model Inference Test ===\n")
    
    # Test first question (same as your notebook)
    print("\n--- Test 1 ---")
    nirf_model.ask_nirf_question(sample_questions[0])
    print("-" * 50)
    
    # Interactive mode (same as your notebook)
    print("\n=== Interactive Mode ===")
    print("Ask your NIRF ranking questions (type 'quit' to exit):")
    
    while True:
        try:
            user_question = input("\nQuestion: ").strip()
            if user_question.lower() in ['quit', 'exit', 'q']:
                break
            
            if user_question:
                nirf_model.ask_nirf_question(user_question)
            
        except KeyboardInterrupt:
            print("\nExiting...")
            break
        except Exception as e:
            print(f"Error: {e}")

if __name__ == "__main__":
    main()

üì• Downloading GGUF model...
‚úÖ Downloaded to: gguf_output/gemma2b-nirf-lookup-f16.gguf
Loading NIRF model from: gguf_output/gemma2b-nirf-lookup-f16.gguf


For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder.
llama_context: n_ctx_per_seq (2048) < n_ctx_train (8192) -- the full capacity of the model will not be utilized
llama_kv_cache_unified_iswa: using full-size SWA cache (ref: https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)


Model loaded successfully!
=== NIRF Ranking Model Inference Test ===


--- Test 1 ---

Question: What is the NIRF ranking of IIT Delhi in 2024?
Generating answer...
Sorry, I couldn't generate a response.
--------------------------------------------------

=== Interactive Mode ===
Ask your NIRF ranking questions (type 'quit' to exit):

Question: What is the NIRF ranking of IIT Delhi in 2024?
Generating answer...
Sorry, I couldn't generate a response.

Question: answer?
Generating answer...
Answer: ```python
def find_longest_substring(s: str, k: int) -> str:
    if len(s) == 0 or k == 0:
        return ""
    start = 0
    max_len = 0
    end = 0
    freq = {}
    while end < len(s):
        char = s[end]
        if char in freq:
            start = max(start, freq[char] + 1)
        freq[char] = end
        max_len = max(max_len, end - start + 1)
        end += 1
    return s[start : start + max_len]
```
### Function Definition
The function `find_longest_substring` takes two arguments:


In [1]:
#!/usr/bin/env python3
"""
Debug script for GGUF model inference issues
Helps identify why the model loads but doesn't generate responses
"""

import os
from huggingface_hub import hf_hub_download
from llama_cpp import Llama

def debug_model_loading():
    """Debug model loading with verbose output"""
    print("üîç Debugging model loading...")
    
    # Download model
    model_path = hf_hub_download(
        repo_id="coderop12/gemma2b-nirf-lookup-gguf",
        filename="gemma2b-nirf-lookup-f16.gguf",
        local_dir="./gguf_output"
    )
    
    print(f"üìÅ Model path: {model_path}")
    print(f"üìä Model size: {os.path.getsize(model_path) / (1024**3):.2f} GB")
    
    # Load with verbose output to see what's happening
    print("\nüîÑ Loading model with verbose output...")
    try:
        llm = Llama(
            model_path=model_path,
            n_ctx=2048,
            n_threads=4,
            verbose=True,  # Enable verbose to see loading details
            use_mlock=False,
            use_mmap=True
        )
        print("‚úÖ Model loaded successfully!")
        return llm
    except Exception as e:
        print(f"‚ùå Error loading model: {e}")
        return None

def test_basic_inference(llm):
    """Test basic inference with different approaches"""
    if not llm:
        return
    
    print("\nüß™ Testing basic inference...")
    
    # Test 1: Very simple prompt
    print("\n--- Test 1: Simple prompt ---")
    try:
        response = llm("Hello", max_tokens=10, temperature=0.1)
        print(f"Response type: {type(response)}")
        print(f"Response keys: {response.keys() if isinstance(response, dict) else 'Not a dict'}")
        if isinstance(response, dict) and 'choices' in response:
            text = response['choices'][0]['text']
            print(f"Generated text: '{text}'")
        else:
            print(f"Raw response: {response}")
    except Exception as e:
        print(f"‚ùå Simple inference failed: {e}")
        import traceback
        traceback.print_exc()
    
    # Test 2: Direct call method
    print("\n--- Test 2: Direct call ---")
    try:
        output = llm.create_completion(
            prompt="Hi",
            max_tokens=5,
            temperature=0.1,
            echo=False
        )
        print(f"Direct call response: {output}")
    except Exception as e:
        print(f"‚ùå Direct call failed: {e}")
    
    # Test 3: Token generation
    print("\n--- Test 3: Token generation ---")
    try:
        tokens = llm.tokenize(b"Hello world")
        print(f"Tokenization works: {tokens}")
        
        # Try to generate one token
        token_gen = llm.generate(tokens, top_k=1)
        first_token = next(token_gen)
        print(f"Generated token: {first_token}")
    except Exception as e:
        print(f"‚ùå Token generation failed: {e}")

def test_nirf_specific(llm):
    """Test NIRF-specific prompts with debugging"""
    if not llm:
        return
    
    print("\nüéØ Testing NIRF-specific prompts...")
    
    # Test with different prompt formats
    test_prompts = [
        "NIRF",
        "What is NIRF?",
        "IIT Delhi ranking",
        "Tell me about NIRF ranking methodology",
        "NIRF 2024 rankings"
    ]
    
    for i, prompt in enumerate(test_prompts):
        print(f"\n--- NIRF Test {i+1}: '{prompt}' ---")
        try:
            response = llm(
                prompt,
                max_tokens=50,
                temperature=0.3,
                top_p=0.9,
                repeat_penalty=1.1,
                stop=["</s>", "\n\n"]
            )
            
            if isinstance(response, dict) and 'choices' in response:
                text = response['choices'][0]['text'].strip()
                if text:
                    print(f"‚úÖ Generated: '{text}'")
                else:
                    print("‚ö†Ô∏è Empty response")
            else:
                print(f"‚ùå Unexpected response format: {response}")
                
        except Exception as e:
            print(f"‚ùå Error with prompt '{prompt}': {e}")

def check_model_compatibility():
    """Check if the model file is compatible"""
    print("\nüîç Checking model compatibility...")
    
    model_path = "./gguf_output/gemma2b-nirf-lookup-f16.gguf"
    
    if not os.path.exists(model_path):
        print("‚ùå Model file not found")
        return False
    
    # Check file size
    size_gb = os.path.getsize(model_path) / (1024**3)
    print(f"üìä File size: {size_gb:.2f} GB")
    
    if size_gb < 1.0 or size_gb > 10.0:
        print("‚ö†Ô∏è Unusual file size for a 2B model")
    
    # Try to peek at file header
    try:
        with open(model_path, 'rb') as f:
            header = f.read(16)
            print(f"üìã File header: {header}")
            if b'GGUF' in header:
                print("‚úÖ Valid GGUF file format")
            else:
                print("‚ùå Not a valid GGUF file")
                return False
    except Exception as e:
        print(f"‚ùå Can't read file: {e}")
        return False
    
    return True

def minimal_working_example():
    """Minimal example that should work"""
    print("\nüîß Trying minimal working example...")
    
    model_path = "./gguf_output/gemma2b-nirf-lookup-f16.gguf"
    
    try:
        # Minimal configuration
        llm = Llama(
            model_path=model_path,
            n_ctx=512,      # Smaller context
            n_threads=2,    # Fewer threads
            verbose=False,
            n_batch=128     # Smaller batch
        )
        
        print("‚úÖ Minimal model loaded")
        
        # Very simple test
        result = llm("Hi", max_tokens=5, temperature=0.0)
        print(f"Minimal test result: {result}")
        
        return llm
        
    except Exception as e:
        print(f"‚ùå Even minimal example failed: {e}")
        return None

def main():
    """Main debugging function"""
    print("üöÄ Starting GGUF Model Debug Session")
    print("=" * 50)
    
    # Step 1: Check file compatibility
    if not check_model_compatibility():
        print("‚ùå Model file issues detected")
        return
    
    # Step 2: Try minimal example first
    minimal_llm = minimal_working_example()
    if minimal_llm:
        print("‚úÖ Minimal example works!")
        test_basic_inference(minimal_llm)
        return
    
    # Step 3: Debug full loading
    llm = debug_model_loading()
    if not llm:
        print("‚ùå Model loading failed")
        return
    
    # Step 4: Test inference
    test_basic_inference(llm)
    
    # Step 5: Test NIRF-specific
    test_nirf_specific(llm)
    
    print("\n" + "=" * 50)
    print("üèÅ Debug session complete")

if __name__ == "__main__":
    main()

üöÄ Starting GGUF Model Debug Session

üîç Checking model compatibility...
üìä File size: 4.88 GB
üìã File header: b'GGUF\x03\x00\x00\x00 \x01\x00\x00\x00\x00\x00\x00'
‚úÖ Valid GGUF file format

üîß Trying minimal working example...


llama_context: n_ctx_per_seq (512) < n_ctx_train (8192) -- the full capacity of the model will not be utilized
llama_kv_cache_unified_iswa: using full-size SWA cache (ref: https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)


‚úÖ Minimal model loaded
Minimal test result: {'id': 'cmpl-1f3d9e0e-29ab-4d93-bc7f-ecd2b889b478', 'object': 'text_completion', 'created': 1758968929, 'model': './gguf_output/gemma2b-nirf-lookup-f16.gguf', 'choices': [{'text': ' is there a way to', 'index': 0, 'logprobs': None, 'finish_reason': 'length'}], 'usage': {'prompt_tokens': 2, 'completion_tokens': 5, 'total_tokens': 7}}
‚úÖ Minimal example works!

üß™ Testing basic inference...

--- Test 1: Simple prompt ---
Response type: <class 'dict'>
Response keys: dict_keys(['id', 'object', 'created', 'model', 'choices', 'usage'])
Generated text: ' is there anyone who can help me with this task'

--- Test 2: Direct call ---
Direct call response: {'id': 'cmpl-20782c2f-0e9b-49d9-b54d-42902f88f445', 'object': 'text_completion', 'created': 1758968934, 'model': './gguf_output/gemma2b-nirf-lookup-f16.gguf', 'choices': [{'text': ' is there a way to', 'index': 0, 'logprobs': None, 'finish_reason': 'length'}], 'usage': {'prompt_tokens': 2, 'comple

In [2]:
#!/usr/bin/env python3
"""
Quick fix for the inference issue
Based on your working notebook code with error handling improvements
"""

from huggingface_hub import hf_hub_download
from llama_cpp import Llama

def load_model_safe():
    """Load model with error handling"""
    model_path = hf_hub_download(
        repo_id="coderop12/gemma2b-nirf-lookup-gguf",
        filename="gemma2b-nirf-lookup-f16.gguf",
        local_dir="./gguf_output"
    )
    
    print(f"Loading model from: {model_path}")
    
    # Use exact same parameters as your working notebook
    llm = Llama(
        model_path=model_path,
        n_ctx=2048,
        n_threads=4,
        verbose=False
    )
    
    print("Model loaded successfully!")
    return llm

def generate_response_debug(llm, prompt, max_tokens=512, temperature=0.3):
    """Generate response with detailed debugging"""
    print(f"üîç Generating response for: '{prompt}'")
    print(f"Parameters: max_tokens={max_tokens}, temperature={temperature}")
    
    try:
        # Try the exact same call as your notebook
        response = llm(
            prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=0.9,
            repeat_penalty=1.1,
            stop=["</s>", "\n\n"]
        )
        
        print(f"‚úÖ Response received: {type(response)}")
        
        if isinstance(response, dict) and 'choices' in response:
            text = response['choices'][0]['text'].strip()
            print(f"üìù Generated text length: {len(text)}")
            print(f"üìù Generated text: '{text}'")
            return text
        else:
            print(f"‚ùå Unexpected response format: {response}")
            return None
            
    except Exception as e:
        print(f"‚ùå Error during generation: {e}")
        print(f"Error type: {type(e)}")
        
        # Try alternative approach
        try:
            print("üîÑ Trying alternative approach...")
            response = llm.create_completion(
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                stop=["</s>"]
            )
            text = response['choices'][0]['text'].strip()
            print(f"‚úÖ Alternative approach worked: '{text}'")
            return text
        except Exception as e2:
            print(f"‚ùå Alternative approach also failed: {e2}")
            return None

def test_step_by_step():
    """Test inference step by step"""
    print("üöÄ Testing inference step by step...")
    
    # Load model
    llm = load_model_safe()
    
    # Test 1: Very simple
    print("\n--- Test 1: Very simple prompt ---")
    result = generate_response_debug(llm, "Hi", max_tokens=10, temperature=0.1)
    
    if result:
        print("‚úÖ Basic inference works!")
        
        # Test 2: NIRF question
        print("\n--- Test 2: NIRF question ---")
        nirf_result = generate_response_debug(
            llm, 
            "What is the NIRF ranking of IIT Delhi in 2024?", 
            max_tokens=200, 
            temperature=0.3
        )
        
        if nirf_result:
            print("‚úÖ NIRF inference works!")
        else:
            print("‚ùå NIRF inference failed")
            
            # Test 3: Try different NIRF prompt
            print("\n--- Test 3: Simpler NIRF prompt ---")
            simple_nirf = generate_response_debug(
                llm, 
                "NIRF ranking", 
                max_tokens=50, 
                temperature=0.1
            )
    else:
        print("‚ùå Basic inference failed - there's a fundamental issue")
        
        # Debug the model loading
        print("\nüîç Debugging model properties...")
        try:
            print(f"Model context size: {llm.n_ctx()}")
            print(f"Model vocab size: {llm.n_vocab()}")
            
            # Try tokenization
            tokens = llm.tokenize(b"Hello")
            print(f"Tokenization works: {tokens}")
            
            # Try detokenization
            text = llm.detokenize(tokens)
            print(f"Detokenization works: {text}")
            
        except Exception as e:
            print(f"‚ùå Model properties check failed: {e}")

def quick_interactive_test():
    """Quick interactive test"""
    print("\nüéØ Quick Interactive Test")
    
    llm = load_model_safe()
    
    while True:
        question = input("\nEnter a question (or 'quit'): ").strip()
        if question.lower() in ['quit', 'exit']:
            break
            
        print(f"ü§î Processing: {question}")
        result = generate_response_debug(llm, question, max_tokens=100, temperature=0.3)
        
        if result:
            print(f"‚úÖ Answer: {result}")
        else:
            print("‚ùå No response generated")

if __name__ == "__main__":
    # Run step by step test first
    test_step_by_step()
    
    # Then interactive if basic tests pass
    print("\n" + "="*50)
    response = input("Run interactive test? (y/n): ").strip().lower()
    if response == 'y':
        quick_interactive_test()

üöÄ Testing inference step by step...
Loading model from: gguf_output/gemma2b-nirf-lookup-f16.gguf


llama_context: n_ctx_per_seq (2048) < n_ctx_train (8192) -- the full capacity of the model will not be utilized
llama_kv_cache_unified_iswa: using full-size SWA cache (ref: https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)


Model loaded successfully!

--- Test 1: Very simple prompt ---
üîç Generating response for: 'Hi'
Parameters: max_tokens=10, temperature=0.1
‚úÖ Response received: <class 'dict'>
üìù Generated text length: 34
üìù Generated text: 'is there a way to get rid of the "'
‚úÖ Basic inference works!

--- Test 2: NIRF question ---
üîç Generating response for: 'What is the NIRF ranking of IIT Delhi in 2024?'
Parameters: max_tokens=200, temperature=0.3
‚úÖ Response received: <class 'dict'>
üìù Generated text length: 0
üìù Generated text: ''
‚ùå NIRF inference failed

--- Test 3: Simpler NIRF prompt ---
üîç Generating response for: 'NIRF ranking'
Parameters: max_tokens=50, temperature=0.1
‚úÖ Response received: <class 'dict'>
üìù Generated text length: 63
üìù Generated text: '2025 (Overall) is 
7. What does this NIRF Overall rank tell us?'


üéØ Quick Interactive Test
Loading model from: gguf_output/gemma2b-nirf-lookup-f16.gguf


llama_context: n_ctx_per_seq (2048) < n_ctx_train (8192) -- the full capacity of the model will not be utilized
llama_kv_cache_unified_iswa: using full-size SWA cache (ref: https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)


Model loaded successfully!
ü§î Processing: i got 600 rank suggest me some college under cse
üîç Generating response for: 'i got 600 rank suggest me some college under cse'
Parameters: max_tokens=100, temperature=0.3
‚úÖ Response received: <class 'dict'>
üìù Generated text length: 0
üìù Generated text: ''
‚ùå No response generated


In [12]:
#!/usr/bin/env python3
"""
Working Production RAG System: GGUF Model + Gemini SQL + Supabase PostgreSQL
Fixed version with proper PostgreSQL connection to your real JoSAA/NIRF data
"""

import psycopg2
import psycopg2.extras
import json
from typing import List, Dict, Optional, Tuple
import google.generativeai as genai

class PostgreSQLManager:
    """Handles PostgreSQL operations with your Supabase database"""
    
    def __init__(self, connection_params: Dict[str, str]):
        self.connection_params = connection_params
        self.table_info = ""
        self.connect_and_get_schema()
    
    def connect_and_get_schema(self):
        """Connect to database and get schema info"""
        try:
            # Test connection
            conn = psycopg2.connect(**self.connection_params)
            cursor = conn.cursor()
            
            print("‚úÖ PostgreSQL connection successful")
            
            # Get all tables
            cursor.execute("""
                SELECT table_name 
                FROM information_schema.tables 
                WHERE table_schema = 'public'
                ORDER BY table_name;
            """)
            tables = [row[0] for row in cursor.fetchall()]
            
            # Get column info for each table
            schema_info = "Available Tables:\n"
            for table in tables:
                schema_info += f"- {table}\n"
                
                cursor.execute("""
                    SELECT column_name, data_type 
                    FROM information_schema.columns 
                    WHERE table_name = %s AND table_schema = 'public'
                    ORDER BY ordinal_position;
                """, (table,))
                
                columns = cursor.fetchall()
                if columns:
                    schema_info += f"  Columns: {', '.join([f'{col[0]}({col[1]})' for col in columns[:5]])}...\n"
            
            conn.close()
            
            print(f"üìã Database Schema:\n{schema_info}")
            self.table_info = schema_info
            
        except Exception as e:
            print(f"‚ùå PostgreSQL connection failed: {e}")
            self.table_info = "Could not retrieve table information"
            raise
    
    def execute_query(self, sql_query: str) -> List[Dict]:
        """Execute SQL query and return results"""
        try:
            conn = psycopg2.connect(**self.connection_params)
            cursor = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
            
            cursor.execute(sql_query)
            results = [dict(row) for row in cursor.fetchall()]
            
            conn.close()
            print(f"üìä Found {len(results)} results")
            return results
            
        except Exception as e:
            print(f"‚ùå Database query failed: {e}")
            return []

class GeminiSQLGenerator:
    """Generates SQL queries using Gemini for your PostgreSQL schema"""
    
    def __init__(self, api_key: str, table_info: str = ""):
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel('gemini-2.0-flash-exp')
        self.table_info = table_info
        
        # Dynamic schema based on actual database
        self.schema_info = f"""
        PostgreSQL Database Schema (ACTUAL STRUCTURE):
        
        {table_info}
        
        Query Guidelines:
        - Use the actual table names shown above
        - For JoSAA data, look for tables with names like 'josaa', 'cutoff', 'admission', etc.
        - For NIRF data, look for tables with 'nirf', 'ranking', etc.
        - Use ILIKE for case-insensitive matching
        - When user mentions rank/AIR, use closing_rank >= user_rank for eligibility
        - Always ORDER BY closing_rank ASC and LIMIT 15
        """
    
    def generate_sql(self, user_query: str) -> Optional[str]:
        """Generate SQL query from natural language"""
        
        prompt = f"""
        You are an expert SQL generator for JoSAA counseling and NIRF ranking data in PostgreSQL.
        
        {self.schema_info}
        
        User Query: "{user_query}"
        
        Generate a precise PostgreSQL query to answer this question.
        
        Important Rules:
        1. When user mentions their rank/AIR (like "I have AIR 6000"), use WHERE closing_rank >= 6000
        2. For IIT queries, use WHERE institute ILIKE '%IIT%' or '%Indian Institute of Technology%'
        3. For specific programs, use WHERE program ILIKE '%Computer Science%'
        4. Always ORDER BY closing_rank ASC for best options first
        5. Use LIMIT 15 to avoid too many results
        6. For rank-based queries, focus on closing_rank comparison
        7. Use ILIKE for case-insensitive pattern matching
        
        Examples:
        - "I have AIR 5000" ‚Üí WHERE closing_rank >= 5000
        - "IIT programs" ‚Üí WHERE institute ILIKE '%Indian Institute of Technology%'
        - "Computer Science" ‚Üí WHERE program ILIKE '%Computer Science%'
        
        Return ONLY the SQL query, no explanations or markdown.
        """
        
        try:
            response = self.model.generate_content(prompt)
            sql_query = response.text.strip()
            
            # Clean up the response
            if sql_query.startswith('```'):
                sql_query = sql_query.replace('```sql', '').replace('```', '').strip()
            
            print(f"üîç Generated SQL: {sql_query}")
            return sql_query
            
        except Exception as e:
            print(f"‚ùå SQL generation failed: {e}")
            return None

class GGUFResponseGenerator:
    """Generates responses using your GGUF model (API calls)"""
    
    def __init__(self, gguf_api_endpoint: str = None):
        self.gguf_api_endpoint = gguf_api_endpoint
        self.model_available = gguf_api_endpoint is not None
        
        if self.model_available:
            print(f"ü§ñ GGUF Model API: {gguf_api_endpoint}")
        else:
            print("ü§ñ GGUF Model: Using Gemini simulation (deploy your GGUF for production)")
    
    def generate_response(self, query: str, context_data: List[Dict], gemini_model) -> str:
        """Generate response using GGUF model or Gemini simulation"""
        
        if self.model_available:
            return self._call_gguf_api(query, context_data)
        else:
            return self._simulate_gguf_response(query, context_data, gemini_model)
    
    def _call_gguf_api(self, query: str, context_data: List[Dict]) -> str:
        """Call your deployed GGUF model API"""
        # TODO: Implement actual API call to your GGUF model
        # Example:
        # response = requests.post(self.gguf_api_endpoint, json={
        #     'query': query,
        #     'context': context_data,
        #     'max_tokens': 300
        # })
        # return response.json()['text']
        
        return "GGUF API not implemented yet - using simulation"
    
    def _simulate_gguf_response(self, query: str, context_data: List[Dict], gemini_model) -> str:
        """Simulate GGUF response using Gemini (mimicking your model's style)"""
        
        if not context_data:
            return "Based on your query, I couldn't find matching programs in the database. You might want to broaden your search criteria or check different categories."
        
        # Format context data for Gemini
        context_text = "\n".join([
            f"- {item.get('institute', 'Unknown')}: {item.get('program', 'Unknown')} (Closing rank: {item.get('closing_rank', 'N/A')})"
            for item in context_data[:10]  # Limit context
        ])
        
        prompt = f"""
        You are a specialized NIRF/JoSAA counseling assistant. Generate a helpful response for the user's query.
        
        User Query: "{query}"
        
        Available Data:
        {context_text}
        
        Generate a response that:
        1. Directly answers the user's question
        2. Provides specific recommendations based on the data
        3. Includes practical counseling advice
        4. Uses a helpful, knowledgeable tone
        5. Focuses on actionable insights
        
        Keep the response concise but informative.
        """
        
        try:
            response = gemini_model.generate_content(prompt)
            return response.text.strip()
        except Exception as e:
            print(f"‚ùå Response generation failed: {e}")
            return "I encountered an error generating the response. Please try again."

class ProductionRAGProcessor:
    """Production RAG processor using your existing PostgreSQL database"""
    
    def __init__(self, db_params: Dict[str, str], gemini_api_key: str, gguf_api_endpoint: str = None):
        self.db_manager = PostgreSQLManager(db_params)
        self.sql_generator = GeminiSQLGenerator(gemini_api_key, self.db_manager.table_info)
        self.response_generator = GGUFResponseGenerator(gguf_api_endpoint)
        
        # Gemini model for response generation (if GGUF not available)
        genai.configure(api_key=gemini_api_key)
        self.gemini_model = genai.GenerativeModel('gemini-2.0-flash-exp')
    
    def process_query(self, user_query: str) -> Tuple[str, List[Dict]]:
        """Process user query end-to-end"""
        
        print(f"Query: {user_query}")
        
        # Step 1: Generate SQL query
        sql_query = self.sql_generator.generate_sql(user_query)
        if not sql_query:
            return "I couldn't understand your query. Please rephrase it.", []
        
        # Step 2: Execute query on your PostgreSQL database
        results = self.db_manager.execute_query(sql_query)
        
        # Step 3: Generate response using GGUF model (or simulation)
        response = self.response_generator.generate_response(
            user_query, results, self.gemini_model
        )
        
        return response, results

def main():
    """Main function using your existing PostgreSQL database"""
    
    # Your Supabase PostgreSQL connection parameters (hardcoded)
    DB_PARAMS = {
        'host': "DB_HOST_HERE",
        'database': "postgres",
        'user': "DB_USER_HERE", 
        'password': "DB_PASSWORD_HERE",
        'port': "6543"
    }
    
    GEMINI_API_KEY = "GEMINI_API_KEY_HERE"
    GGUF_API_ENDPOINT = None  # Set to your GGUF model API endpoint when deployed
    
    processor = ProductionRAGProcessor(DB_PARAMS, GEMINI_API_KEY, GGUF_API_ENDPOINT)
    
    print("üöÄ Production RAG System Ready!")
    print("- Database: Supabase PostgreSQL with your real JoSAA/NIRF data")
    print("- SQL Generation: Gemini 2.0 Flash")
    print("- Response Generation: GGUF Model (simulated)")
    
    # Test with the same queries that worked for you
    print("\n" + "="*60)
    print("=== Testing Working System ===")
    
    test_queries = [
        "I have AIR 6000, which IIT programs can I get?",
        "Show me Computer Science programs at IIT Goa",
        "What Mechanical Engineering programs have closing rank below 9000?"
    ]
    
    for i, query in enumerate(test_queries, 1):
        print(f"\n{'-'*60}")
        print(f"Test {i}: {query}")
        print('-'*60)
        
        response, results = processor.process_query(query)
        
        print(f"\nResponse:")
        print(response)
        
        if results:
            print(f"\nData (first 3 of {len(results)}):")
            for j, result in enumerate(results[:3], 1):
                print(f"  {j}. {result.get('institute', 'Unknown')}")
                program = result.get('program', 'Unknown program')
                print(f"     {program[:60]}...")
                print(f"     Closing: {result.get('closing_rank', 'N/A')}")
    
    print(f"\n{'='*60}")
    print(f"System Status:")
    print(f"- Database: ‚úÖ Connected and working")
    print(f"- SQL Generation: ‚úÖ Gemini-powered")
    print(f"- Response Generation: ‚úÖ Gemini (mimicking your model's style)")
    print(f"- Ready for: üöÄ Production deployment")
    
    print(f"\nNext Steps:")
    print(f"- Deploy your GGUF model separately on a server with more resources")
    print(f"- Replace Gemini calls with API calls to your deployed model")
    print(f"- This system provides the complete pipeline architecture")

def interactive_mode():
    """Interactive mode with your PostgreSQL database"""
    
    # Hardcoded Supabase credentials
    DB_PARAMS = {
        'host': "DB_HOST_HERE",
        'database': "postgres",
        'user': "DB_USER_HERE", 
        'password': "DB_PASSWORD_HERE",
        'port': "6543"
    }
    
    GEMINI_API_KEY = "GEMINI_API_KEY_HERE"
    
    processor = ProductionRAGProcessor(DB_PARAMS, GEMINI_API_KEY)
    
    print("\nüéØ Interactive Mode with Supabase PostgreSQL Database")
    print("Ask questions about JoSAA counseling, NIRF rankings, etc.")
    
    while True:
        query = input("\nYour question (or 'quit'): ").strip()
        if query.lower() in ['quit', 'exit']:
            break
        
        response, results = processor.process_query(query)
        print(f"\nResponse:\n{response}")
        
        if results:
            print(f"\nFound {len(results)} programs")

if __name__ == "__main__":
    print("Choose mode:")
    print("(m) Main demo with hardcoded DB params")
    print("(i) Interactive mode")
    
    choice = input("Choice: ").lower()
    if choice == 'i':
        interactive_mode()
    else:
        main()

Choose mode:
(m) Main demo with hardcoded DB params
(i) Interactive mode


‚úÖ PostgreSQL connection successful
üìã Database Schema:
Available Tables:
- institute_mapping
  Columns: id(integer), josaa_name(text), nirf_name(text), confidence_score(real), created_at(timestamp without time zone)...
- josaa_2024
  Columns: id(integer), year(integer), round(integer), institute(text), institute_type(text)...
- josaa_btech_2024
  Columns: id(integer), year(integer), round(integer), institute(text), institute_type(text)...
- josaa_nirf_combined
  Columns: id(integer), year(integer), round(integer), institute(text), institute_type(text)...
- nirf_rankings_2024
  Columns: id(integer), year(integer), category(text), rank(integer), institute(text)...

ü§ñ GGUF Model: Using Gemini simulation (deploy your GGUF for production)
üöÄ Production RAG System Ready!
- Database: Supabase PostgreSQL with your real JoSAA/NIRF data
- SQL Generation: Gemini 2.0 Flash
- Response Generation: GGUF Model (simulated)

=== Testing Working System ===

--------------------------------------

In [14]:
#!/usr/bin/env python3
"""
Working Production RAG System: GGUF Model + Gemini SQL + Supabase PostgreSQL
Fixed version with proper PostgreSQL connection to your real JoSAA/NIRF data
"""

import psycopg2
import psycopg2.extras
import json
from typing import List, Dict, Optional, Tuple
import google.generativeai as genai

class PostgreSQLManager:
    """Handles PostgreSQL operations with your Supabase database"""
    
    def __init__(self, connection_params: Dict[str, str]):
        self.connection_params = connection_params
        self.table_info = ""
        self.connect_and_get_schema()
    
    def connect_and_get_schema(self):
        """Connect to database and get detailed schema info with data types"""
        try:
            conn = psycopg2.connect(**self.connection_params)
            cursor = conn.cursor()
            
            print("‚úÖ PostgreSQL connection successful")
            
            # Get detailed schema information for each table
            schema_info = "DETAILED DATABASE SCHEMA:\n\n"
            
            cursor.execute("""
                SELECT table_name 
                FROM information_schema.tables 
                WHERE table_schema = 'public'
                ORDER BY table_name;
            """)
            tables = [row[0] for row in cursor.fetchall()]
            
            for table in tables:
                schema_info += f"TABLE: {table}\n"
                
                # Get all columns with detailed type information
                cursor.execute("""
                    SELECT column_name, data_type, is_nullable, column_default
                    FROM information_schema.columns 
                    WHERE table_name = %s AND table_schema = 'public'
                    ORDER BY ordinal_position;
                """, (table,))
                
                columns = cursor.fetchall()
                for col_name, data_type, nullable, default in columns:
                    schema_info += f"  - {col_name}: {data_type} {'NULL' if nullable == 'YES' else 'NOT NULL'}\n"
                
                # Get sample data to understand content
                try:
                    cursor.execute(f"SELECT * FROM {table} LIMIT 3;")
                    sample_rows = cursor.fetchall()
                    if sample_rows:
                        schema_info += f"  Sample data: {len(sample_rows)} rows found\n"
                except:
                    schema_info += f"  Sample data: Unable to fetch\n"
                
                schema_info += "\n"
            
            conn.close()
            
            print(f"üìã Detailed Database Schema:\n{schema_info}")
            self.table_info = schema_info
            
        except Exception as e:
            print(f"‚ùå PostgreSQL connection failed: {e}")
            self.table_info = "Could not retrieve table information"
            raise
    
    def execute_query(self, sql_query: str) -> List[Dict]:
        """Execute SQL query and return results"""
        try:
            conn = psycopg2.connect(**self.connection_params)
            cursor = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
            
            cursor.execute(sql_query)
            results = [dict(row) for row in cursor.fetchall()]
            
            conn.close()
            print(f"üìä Found {len(results)} results")
            return results
            
        except Exception as e:
            print(f"‚ùå Database query failed: {e}")
            return []

class GeminiSQLGenerator:
    """Advanced SQL generator using Gemini 2.5 Pro with sophisticated reasoning"""
    
    def __init__(self, api_key: str, table_info: str = ""):
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel('gemini-2.0-flash-exp')
        self.table_info = table_info
    
    def generate_sql(self, user_query: str) -> Optional[str]:
        """Generate SQL query using advanced prompting techniques and handling mixed data"""
        
        advanced_prompt = f"""
You are an expert PostgreSQL query generator with deep understanding of JoSAA counseling data. The rank data contains mixed values like "153P", "5P", "382P" etc.

=== DATABASE SCHEMA ===
{self.table_info}

=== CRITICAL DATA HANDLING ===
IMPORTANT: The closing_rank and opening_rank columns contain:
- Pure numbers: "1234", "5678" 
- Numbers with letters: "153P", "5P", "382P"
- Special values: "-", "NULL", empty strings

For rank comparisons, you MUST:
1. Extract numeric part using: REGEXP_REPLACE(closing_rank, '[^0-9]', '', 'g')
2. Handle non-numeric values: WHERE closing_rank ~ '^[0-9]+'
3. Cast the cleaned value: CAST(REGEXP_REPLACE(closing_rank, '[^0-9]', '', 'g') AS INTEGER)

=== USER QUERY ===
"{user_query}"

=== REASONING PROCESS ===
Follow this step-by-step reasoning:

1. INTENT ANALYSIS: What is the user trying to find?
   - Rank-based eligibility query?
   - Institute-specific search?
   - Program-specific search?

2. TABLE SELECTION: Which table to use?
   - josaa_btech_2024: Best for BTech programs (most complete)
   - josaa_2024: All programs
   - josaa_nirf_combined: Combined with NIRF data

3. RANK HANDLING: For rank comparisons:
   - Filter numeric ranks first: WHERE closing_rank ~ '^[0-9]+'
   - Extract numbers: REGEXP_REPLACE(closing_rank, '[^0-9]', '', 'g')
   - Cast to integer: CAST(REGEXP_REPLACE(closing_rank, '[^0-9]', '', 'g') AS INTEGER)

4. INSTITUTE MATCHING:
   - Full name search: institute ILIKE '%Indian Institute of Technology%'
   - Specific IIT: institute ILIKE '%Indian Institute of Technology Goa%'

=== EXAMPLE QUERIES ===

Example 1: "I have rank 6000, what IIT programs?"
```sql
SELECT institute, program, closing_rank,
       CAST(REGEXP_REPLACE(closing_rank, '[^0-9]', '', 'g') AS INTEGER) as rank_numeric
FROM josaa_btech_2024 
WHERE closing_rank ~ '^[0-9]+'
AND CAST(REGEXP_REPLACE(closing_rank, '[^0-9]', '', 'g') AS INTEGER) >= 6000
AND institute ILIKE '%Indian Institute of Technology%'
ORDER BY CAST(REGEXP_REPLACE(closing_rank, '[^0-9]', '', 'g') AS INTEGER) ASC 
LIMIT 15;
```

Example 2: "Computer Science at IIT Goa"
```sql
SELECT institute, program, closing_rank
FROM josaa_btech_2024 
WHERE institute ILIKE '%Indian Institute of Technology Goa%'
AND program ILIKE '%Computer Science%'
ORDER BY 
  CASE 
    WHEN closing_rank ~ '^[0-9]+' 
    THEN CAST(REGEXP_REPLACE(closing_rank, '[^0-9]', '', 'g') AS INTEGER)
    ELSE 999999 
  END ASC
LIMIT 15;
```

Example 3: "Mechanical Engineering below rank 9000"
```sql
SELECT institute, program, closing_rank,
       CAST(REGEXP_REPLACE(closing_rank, '[^0-9]', '', 'g') AS INTEGER) as rank_numeric
FROM josaa_btech_2024 
WHERE program ILIKE '%Mechanical Engineering%'
AND closing_rank ~ '^[0-9]+'
AND CAST(REGEXP_REPLACE(closing_rank, '[^0-9]', '', 'g') AS INTEGER) <= 9000
ORDER BY CAST(REGEXP_REPLACE(closing_rank, '[^0-9]', '', 'g') AS INTEGER) ASC 
LIMIT 15;
```

=== QUERY GENERATION RULES ===
1. ALWAYS filter numeric ranks first: WHERE closing_rank ~ '^[0-9]+'
2. Extract numbers: REGEXP_REPLACE(closing_rank, '[^0-9]', '', 'g')
3. Cast extracted numbers: CAST(REGEXP_REPLACE(...) AS INTEGER)
4. For rank eligibility: extracted_rank >= user_rank
5. For institute search: ILIKE '%Indian Institute of Technology%'
6. For programs: ILIKE '%Program Name%'
7. ORDER BY the extracted numeric rank
8. LIMIT 15

=== INSTITUTE NAME PATTERNS ===
- IIT queries: '%Indian Institute of Technology%'
- NIT queries: '%National Institute of Technology%'
- IIIT queries: '%Indian Institute of Information Technology%'
- Specific IIT: '%Indian Institute of Technology [City]%'

Now generate the SQL query following this reasoning. Return ONLY the SQL query, no explanations.
"""
        
        try:
            response = self.model.generate_content(advanced_prompt)
            sql_query = response.text.strip()
            
            # Clean up the response
            if sql_query.startswith('```'):
                sql_query = sql_query.replace('```sql', '').replace('```', '').strip()
            
            print(f"üß† Generated SQL (handling mixed rank data): {sql_query}")
            return sql_query
            
        except Exception as e:
            print(f"‚ùå Advanced SQL generation failed: {e}")
            return None

class GGUFResponseGenerator:
    """Generates responses using your GGUF model (API calls)"""
    
    def __init__(self, gguf_api_endpoint: str = None):
        self.gguf_api_endpoint = gguf_api_endpoint
        self.model_available = gguf_api_endpoint is not None
        
        if self.model_available:
            print(f"ü§ñ GGUF Model API: {gguf_api_endpoint}")
        else:
            print("ü§ñ GGUF Model: Using Gemini simulation (deploy your GGUF for production)")
    
    def generate_response(self, query: str, context_data: List[Dict], gemini_model) -> str:
        """Generate response using GGUF model or Gemini simulation"""
        
        if self.model_available:
            return self._call_gguf_api(query, context_data)
        else:
            return self._simulate_gguf_response(query, context_data, gemini_model)
    
    def _call_gguf_api(self, query: str, context_data: List[Dict]) -> str:
        """Call your deployed GGUF model API"""
        # TODO: Implement actual API call to your GGUF model
        # Example:
        # response = requests.post(self.gguf_api_endpoint, json={
        #     'query': query,
        #     'context': context_data,
        #     'max_tokens': 300
        # })
        # return response.json()['text']
        
        return "GGUF API not implemented yet - using simulation"
    
    def _simulate_gguf_response(self, query: str, context_data: List[Dict], gemini_model) -> str:
        """Simulate GGUF response using Gemini (mimicking your model's style)"""
        
        if not context_data:
            return "Based on your query, I couldn't find matching programs in the database. You might want to broaden your search criteria or check different categories."
        
        # Format context data for Gemini
        context_text = "\n".join([
            f"- {item.get('institute', 'Unknown')}: {item.get('program', 'Unknown')} (Closing rank: {item.get('closing_rank', 'N/A')})"
            for item in context_data[:10]  # Limit context
        ])
        
        prompt = f"""
        You are a specialized NIRF/JoSAA counseling assistant. Generate a helpful response for the user's query.
        
        User Query: "{query}"
        
        Available Data:
        {context_text}
        
        Generate a response that:
        1. Directly answers the user's question
        2. Provides specific recommendations based on the data
        3. Includes practical counseling advice
        4. Uses a helpful, knowledgeable tone
        5. Focuses on actionable insights
        
        Keep the response concise but informative.
        """
        
        try:
            response = gemini_model.generate_content(prompt)
            return response.text.strip()
        except Exception as e:
            print(f"‚ùå Response generation failed: {e}")
            return "I encountered an error generating the response. Please try again."

class ProductionRAGProcessor:
    """Production RAG processor using your existing PostgreSQL database"""
    
    def __init__(self, db_params: Dict[str, str], gemini_api_key: str, gguf_api_endpoint: str = None):
        self.db_manager = PostgreSQLManager(db_params)
        self.sql_generator = GeminiSQLGenerator(gemini_api_key, self.db_manager.table_info)
        self.response_generator = GGUFResponseGenerator(gguf_api_endpoint)
        
        # Gemini model for response generation (if GGUF not available)
        genai.configure(api_key=gemini_api_key)
        self.gemini_model = genai.GenerativeModel('gemini-2.0-flash-exp')
    
    def process_query(self, user_query: str) -> Tuple[str, List[Dict]]:
        """Process user query end-to-end"""
        
        print(f"Query: {user_query}")
        
        # Step 1: Generate SQL query
        sql_query = self.sql_generator.generate_sql(user_query)
        if not sql_query:
            return "I couldn't understand your query. Please rephrase it.", []
        
        # Step 2: Execute query on your PostgreSQL database
        results = self.db_manager.execute_query(sql_query)
        
        # Step 3: Generate response using GGUF model (or simulation)
        response = self.response_generator.generate_response(
            user_query, results, self.gemini_model
        )
        
        return response, results

def main():
    """Main function using your existing PostgreSQL database"""
    
    # Your Supabase PostgreSQL connection parameters (hardcoded)
    DB_PARAMS = {
        'host': "DB_HOST_HERE",
        'database': "postgres",
        'user': "DB_USER_HERE", 
        'password': "DB_PASSWORD_HERE",
        'port': "6543"
    }
    
    GEMINI_API_KEY = "GEMINI_API_KEY_HERE"
    GGUF_API_ENDPOINT = None  # Set to your GGUF model API endpoint when deployed
    
    processor = ProductionRAGProcessor(DB_PARAMS, GEMINI_API_KEY, GGUF_API_ENDPOINT)
    
    print("üöÄ Production RAG System Ready!")
    print("- Database: Supabase PostgreSQL with your real JoSAA/NIRF data")
    print("- SQL Generation: Gemini 2.0 Flash")
    print("- Response Generation: GGUF Model (simulated)")
    
    # Test with the same queries that worked for you
    print("\n" + "="*60)
    print("=== Testing Working System ===")
    
    test_queries = [
        "I have AIR 6000, which IIT programs can I get?",
        "Show me Computer Science programs at IIT Goa",
        "What Mechanical Engineering programs have closing rank below 9000?"
    ]
    
    for i, query in enumerate(test_queries, 1):
        print(f"\n{'-'*60}")
        print(f"Test {i}: {query}")
        print('-'*60)
        
        response, results = processor.process_query(query)
        
        print(f"\nResponse:")
        print(response)
        
        if results:
            print(f"\nData (first 3 of {len(results)}):")
            for j, result in enumerate(results[:3], 1):
                print(f"  {j}. {result.get('institute', 'Unknown')}")
                program = result.get('program', 'Unknown program')
                print(f"     {program[:60]}...")
                print(f"     Closing: {result.get('closing_rank', 'N/A')}")
    
    print(f"\n{'='*60}")
    print(f"System Status:")
    print(f"- Database: ‚úÖ Connected and working")
    print(f"- SQL Generation: ‚úÖ Gemini-powered")
    print(f"- Response Generation: ‚úÖ Gemini (mimicking your model's style)")
    print(f"- Ready for: üöÄ Production deployment")
    
    print(f"\nNext Steps:")
    print(f"- Deploy your GGUF model separately on a server with more resources")
    print(f"- Replace Gemini calls with API calls to your deployed model")
    print(f"- This system provides the complete pipeline architecture")

def interactive_mode():
    """Interactive mode with your PostgreSQL database"""
    
    # Hardcoded Supabase credentials
    DB_PARAMS = {
        'host': "DB_HOST_HERE",
        'database': "postgres",
        'user': "DB_USER_HERE", 
        'password': "DB_PASSWORD_HERE",
        'port': "6543"
    }
    
    GEMINI_API_KEY = "GEMINI_API_KEY_HERE"
    
    processor = ProductionRAGProcessor(DB_PARAMS, GEMINI_API_KEY)
    
    print("\nüéØ Interactive Mode with Supabase PostgreSQL Database")
    print("Ask questions about JoSAA counseling, NIRF rankings, etc.")
    
    while True:
        query = input("\nYour question (or 'quit'): ").strip()
        if query.lower() in ['quit', 'exit']:
            break
        
        response, results = processor.process_query(query)
        print(f"\nResponse:\n{response}")
        
        if results:
            print(f"\nFound {len(results)} programs")

if __name__ == "__main__":
    print("Choose mode:")
    print("(m) Main demo with hardcoded DB params")
    print("(i) Interactive mode")
    
    choice = input("Choice: ").lower()
    if choice == 'i':
        interactive_mode()
    else:
        main()

Choose mode:
(m) Main demo with hardcoded DB params
(i) Interactive mode
‚úÖ PostgreSQL connection successful
üìã Detailed Database Schema:
DETAILED DATABASE SCHEMA:

TABLE: institute_mapping
  - id: integer NOT NULL
  - josaa_name: text NOT NULL
  - nirf_name: text NOT NULL
  - confidence_score: real NULL
  - created_at: timestamp without time zone NULL
  Sample data: 3 rows found

TABLE: josaa_2024
  - id: integer NOT NULL
  - year: integer NULL
  - round: integer NULL
  - institute: text NULL
  - institute_type: text NULL
  - program: text NULL
  - quota: text NULL
  - category: text NULL
  - gender: text NULL
  - opening_rank: text NULL
  - closing_rank: text NULL
  - created_at: timestamp without time zone NULL
  Sample data: 3 rows found

TABLE: josaa_btech_2024
  - id: integer NOT NULL
  - year: integer NOT NULL
  - round: integer NOT NULL
  - institute: text NOT NULL
  - institute_type: text NOT NULL
  - program: text NOT NULL
  - quota: text NOT NULL
  - category: text NOT NU

In [21]:
#!/usr/bin/env python3
"""
Dynamic Supabase PostgreSQL schema + samples (robust)
- Lists public tables
- Prints columns (name, type, nullability, default)
- Row count (fast estimate by default; exact optional)
- Live sample rows
- Handles statement timeouts cleanly (ROLLBACK on error)
"""

from typing import List, Dict, Optional
from contextlib import contextmanager
import psycopg2
import psycopg2.extras
from psycopg2 import sql

# ====== HARDCODED CREDS (as requested) ======
DB_PARAMS = {
    "host": "DB_HOST_HERE",
    "database": "postgres",
    "user": "DB_USER_HERE",
    "password": "DB_PASSWORD_HERE",
    "port": "6543",
    "sslmode": "require",
    "connect_timeout": 10,
}

SCHEMA = "public"
SAMPLE_ROWS = 5
QUERY_TIMEOUT_MS = 5000  # 5s per statement


@contextmanager
def get_conn():
    conn = psycopg2.connect(**DB_PARAMS)
    # Use autocommit False (default) so we can SET LOCAL and ROLLBACK after errors
    try:
        yield conn
    finally:
        conn.close()


def set_stmt_timeout(cur, ms: int = QUERY_TIMEOUT_MS):
    # SET LOCAL requires an active txn; call at the start of each operation
    cur.execute("BEGIN;")
    cur.execute(sql.SQL("SET LOCAL statement_timeout = %s;"), [f"{ms}ms"])


def safe_rollback(conn):
    try:
        conn.rollback()
    except Exception:
        pass


def get_public_tables(cur) -> List[str]:
    cur.execute("""
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = %s
        ORDER BY table_name;
    """, (SCHEMA,))
    return [r[0] for r in cur.fetchall()]


def get_columns(cur, table: str) -> List[Dict]:
    set_stmt_timeout(cur)
    try:
        cur.execute("""
            SELECT column_name, data_type, is_nullable, column_default
            FROM information_schema.columns
            WHERE table_schema = %s AND table_name = %s
            ORDER BY ordinal_position;
        """, (SCHEMA, table))
        rows = cur.fetchall()
        cur.execute("COMMIT;")
        return [
            {
                "name": name,
                "type": dtype,
                "nullable": (nullable == "YES"),
                "default": default
            }
            for (name, dtype, nullable, default) in rows
        ]
    except Exception:
        safe_rollback(cur.connection)
        return []


def get_rowcount_estimate(cur, table: str) -> int:
    """Fast estimate via pg_class.reltuples"""
    set_stmt_timeout(cur)
    try:
        cur.execute("""
            SELECT COALESCE(c.reltuples::bigint, 0)
            FROM pg_class c
            JOIN pg_namespace n ON n.oid = c.relnamespace
            WHERE n.nspname = %s AND c.relname = %s;
        """, (SCHEMA, table))
        est = cur.fetchone()
        cur.execute("COMMIT;")
        return int(est[0]) if est and est[0] is not None else -1
    except Exception:
        safe_rollback(cur.connection)
        return -1


def get_rowcount_exact(cur, table: str, timeout_ms: int = QUERY_TIMEOUT_MS) -> Optional[int]:
    """Try exact COUNT(*) with timeout; if it fails, return None (don‚Äôt abort session)."""
    set_stmt_timeout(cur, timeout_ms)
    try:
        cur.execute(
            sql.SQL("SELECT COUNT(*) FROM {}.{};")
            .format(sql.Identifier(SCHEMA), sql.Identifier(table))
        )
        n = cur.fetchone()[0]
        cur.execute("COMMIT;")
        return int(n)
    except Exception:
        safe_rollback(cur.connection)
        return None


def get_samples(conn, table: str, limit: int = SAMPLE_ROWS) -> List[Dict]:
    """Fetch sample rows in a fresh cursor with its own timeout."""
    with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
        set_stmt_timeout(cur)
        try:
            cur.execute(
                sql.SQL("SELECT * FROM {}.{} LIMIT %s;")
                .format(sql.Identifier(SCHEMA), sql.Identifier(table)),
                (limit,)
            )
            rows = [dict(r) for r in cur.fetchall()]
            cur.execute("COMMIT;")
            return rows
        except Exception as e:
            safe_rollback(conn)
            return [{"__error__": str(e)}]


def truncate_vals(row: Dict, maxlen: int = 120) -> Dict:
    out = {}
    for k, v in row.items():
        s = str(v)
        out[k] = s if len(s) <= maxlen else s[: maxlen - 3] + "..."
    return out


def pretty_print_table(name: str, cols: List[Dict], count_txt: str, samples: List[Dict]):
    print(f"\nTABLE: {name}")
    for c in cols:
        null_txt = "NULL" if c.get("nullable") else "NOT NULL"
        default_txt = f" DEFAULT {c['default']}" if c.get("default") is not None else ""
        print(f"  - {c.get('name')}: {c.get('type')} {null_txt}{default_txt}")
    print(f"  Row count: {count_txt}")

    if samples:
        print(f"  Sample rows (up to {len(samples)}):")
        for i, row in enumerate(samples, 1):
            print(f"    {i}. {truncate_vals(row)}")
    else:
        print("  No sample rows found.")


def main():
    with get_conn() as conn:
        with conn.cursor() as cur:
            # Server version (session-level, no need for timeout)
            cur.execute("SHOW server_version;")
            print(f"‚úÖ PostgreSQL version: {cur.fetchone()[0]}")

            tables = get_public_tables(cur)
            print("\nüìã Public tables:", tables if tables else "(none)")

            for t in tables:
                cols = get_columns(cur, t)

                # Prefer fast estimate. If very small estimate (< 200k), try exact count quickly.
                est = get_rowcount_estimate(cur, t)
                count_txt = f"(estimate) {est}" if est >= 0 else "(estimate unavailable)"

                if est >= 0 and est < 200_000:
                    exact = get_rowcount_exact(cur, t, timeout_ms=3000)
                    if exact is not None:
                        count_txt = f"{exact}"

                samples = get_samples(conn, t, SAMPLE_ROWS)
                pretty_print_table(t, cols, count_txt, samples)


if __name__ == "__main__":
    main()


‚úÖ PostgreSQL version: 17.6

üìã Public tables: ['institute_mapping', 'josaa_2024', 'josaa_btech_2024', 'josaa_nirf_combined', 'nirf_rankings_2024']

TABLE: institute_mapping
  - id: integer NOT NULL DEFAULT nextval('institute_mapping_id_seq'::regclass)
  - josaa_name: text NOT NULL
  - nirf_name: text NOT NULL
  - confidence_score: real NULL DEFAULT 1.0
  - created_at: timestamp without time zone NULL DEFAULT CURRENT_TIMESTAMP
  Row count: (estimate unavailable)
  Sample rows (up to 5):
    1. {'id': '1', 'josaa_name': 'Indian Institute of Technology Madras', 'nirf_name': 'Indian Institute of Technology Madras', 'confidence_score': '1.0', 'created_at': '2025-09-27 09:50:19.732163'}
    2. {'id': '2', 'josaa_name': 'Indian Institute  of Technology Madras', 'nirf_name': 'Indian Institute of Technology Madras', 'confidence_score': '1.0', 'created_at': '2025-09-27 09:50:19.732163'}
    3. {'id': '3', 'josaa_name': 'Indian Institute of Technology Delhi', 'nirf_name': 'Indian Institute of

In [1]:
#!/usr/bin/env python3
"""
Two-Stage SQL RAG Pipeline:
A) Gemini Flash: query enhancement (2-3 variations)
B) Gemini 2.5 Pro: robust SQL generation (schema-aware)
C) DB execution: safe SELECT-only, pick the best result
D) GGUF (or Gemini sim): final natural language answer

- Hardcoded Supabase Postgres credentials (as requested)
- Hardcoded Gemini API key (as in your previous code)
"""

import re
from typing import List, Dict, Optional, Tuple

import psycopg2
import psycopg2.extras
from psycopg2.pool import SimpleConnectionPool

import google.generativeai as genai


# ========================
# HARD-CODED CREDENTIALS
# ========================
DB = {
    "host": "DB_HOST_HERE",
    "database": "postgres",
    "user": "DB_USER_HERE",
    "password": "DB_PASSWORD_HERE",
    "port": "6543",
    "sslmode": "require",
    "connect_timeout": 10,
}
GEMINI_API_KEY = "GEMINI_API_KEY_HERE"


# ========================
# CONSTANTS & HELPERS
# ========================
RANK_NUM_EXPR = (
    "CAST(NULLIF(regexp_replace(trim(closing_rank), '[^0-9]', '', 'g'), '') AS INTEGER)"
)

ALLOWED_READ = re.compile(r"^\s*(select|with)\b", re.IGNORECASE | re.DOTALL)
DANGEROUS = re.compile(
    r"\b(insert|update|delete|drop|alter|truncate|create|grant|revoke|copy|vacuum|analyze)\b",
    re.IGNORECASE,
)

def single_statement(sql: str) -> bool:
    return sql.strip().count(";") <= 1

def ensure_limit(sql: str, hard_limit: int = 15) -> str:
    if re.search(r"\blimit\b", sql, re.IGNORECASE):
        return sql
    return f"{sql.rstrip(';')} LIMIT {hard_limit};"

def sanitize_select(sql: str, hard_limit: int = 15) -> str:
    sql = sql.strip()
    if not single_statement(sql):
        raise ValueError("Multiple SQL statements not allowed.")
    if not ALLOWED_READ.match(sql):
        raise ValueError("Only SELECT/CTE read queries are allowed.")
    if DANGEROUS.search(sql):
        raise ValueError("Potentially dangerous SQL token detected.")
    return ensure_limit(sql, hard_limit)


# ========================
# DB LAYER
# ========================
class Pg:
    def __init__(self, params: Dict[str, str]):
        self.pool = SimpleConnectionPool(minconn=1, maxconn=6, **params)

    def _conn(self):
        conn = self.pool.getconn()
        with conn.cursor() as c:
            c.execute("SET LOCAL statement_timeout = '12000ms';")
            c.execute("SET LOCAL default_transaction_read_only = on;")
        return conn

    def fetch_schema_text(self) -> str:
        conn = self._conn()
        try:
            cur = conn.cursor()
            out = ["SCHEMA: public\n"]
            cur.execute("""
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'public'
                ORDER BY table_name;
            """)
            for (t,) in cur.fetchall():
                out.append(f"TABLE: {t}")
                cur.execute("""
                    SELECT column_name, data_type, is_nullable
                    FROM information_schema.columns
                    WHERE table_schema='public' AND table_name=%s
                    ORDER BY ordinal_position;
                """, (t,))
                for name, dtype, nullable in cur.fetchall():
                    out.append(f"  - {name}: {dtype} {'NULL' if nullable=='YES' else 'NOT NULL'}")
                out.append("")
            return "\n".join(out)
        finally:
            self.pool.putconn(conn)

    def run(self, sql: str) -> List[Dict]:
        conn = self._conn()
        try:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql)
                return [dict(r) for r in cur.fetchall()]
        finally:
            self.pool.putconn(conn)


# ========================
# LLM STAGE A: FLASH ENHANCER
# ========================
class QueryEnhancer:
    """
    Uses a fast Gemini model to expand the user's query into 2‚Äì3
    semantically-similar, clarified variants.
    """
    def __init__(self, api_key: str):
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel("gemini-2.0-flash-exp")

    def enhance(self, user_query: str) -> List[str]:
        prompt = f"""
You are a helpful assistant that rewrites the user's database question into 2‚Äì3
clear, semantically similar variants that might help retrieval.

Rules:
- Keep intent the same.
- Clarify institute types (IIT/NIT/IIIT) if mentioned.
- Add obvious filters (quota/category/gender) only if the user explicitly asked; otherwise keep generic.
- Keep each variant 1 sentence. Output each variant on its own line. No bullets, no numbering.

User query: "{user_query}"
Provide 2‚Äì3 rewritten variants (one per line).
"""
        res = self.model.generate_content(prompt)
        text = (res.text or "").strip()
        variants = [line.strip() for line in text.splitlines() if line.strip()]
        # clamp to 3
        return variants[:3] if variants else [user_query]


# ========================
# LLM STAGE B: 2.5 PRO SQL GENERATOR
# ========================
class SqlGenPro:
    """
    Uses Gemini 2.5 Pro with strong, schema-aware prompting to generate SQL.
    We tell it about TEXT ranks and how to cast them; no imaginary columns.
    """
    def __init__(self, api_key: str, schema_text: str):
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel("gemini-2.5-pro")
        self.schema_text = schema_text

    def to_sql(self, nl_query: str) -> Optional[str]:
        prompt = f"""
You are an expert PostgreSQL writer. Only output a single SELECT (optionally WITH).

DATABASE SCHEMA (public):
{self.schema_text}

IMPORTANT DATA RULES:
- opening_rank and closing_rank are TEXT. Derive numeric like:
  {RANK_NUM_EXPR}
- No rank_num column exists; compute inline as needed.
- Use ILIKE for case-insensitive matches.
- For eligibility with AIR R: use derived closing_rank_num >= R.
- Prefer table 'josaa_btech_2024' for B.Tech queries (it has year, round, institute, institute_type, program, quota, category, gender, opening_rank, closing_rank).
- Always ORDER BY the derived numeric rank when ranking results.
- Always LIMIT 15.
- Return ONLY the SQL, no commentary.

User query: "{nl_query}"
SQL:
"""
        try:
            res = self.model.generate_content(prompt)
            sql = (res.text or "").strip()
            if sql.startswith("```"):
                sql = sql.replace("```sql", "").replace("```", "").strip()
            return sql
        except Exception as e:
            print("SQL gen error:", e)
            return None


# ========================
# STAGE C: EXECUTION & SELECTION
# ========================
def pick_best_result(candidates: List[Tuple[str, List[Dict]]]) -> Tuple[Optional[str], List[Dict]]:
    """
    Choose the best (sql, rows) among candidates.
    Heuristic: prefer non-empty; tie-breaker = more rows; else first.
    """
    non_empty = [c for c in candidates if len(c[1]) > 0]
    if non_empty:
        non_empty.sort(key=lambda x: len(x[1]), reverse=True)
        return non_empty[0]
    return candidates[0] if candidates else (None, [])


# ========================
# STAGE D: GGUF (SIM) ANSWER
# ========================
class Answerer:
    """
    If you have a GGUF API, wire it here. For now, we simulate with Gemini Flash.
    """
    def __init__(self, api_key: str, gguf_endpoint: Optional[str] = None):
        self.gguf_endpoint = gguf_endpoint
        genai.configure(api_key=api_key)
        self.sim = genai.GenerativeModel("gemini-2.0-flash-exp")

    def answer(self, user_query: str, rows: List[Dict]) -> str:
        if not rows:
            return ("I couldn't find matching rows for that query. "
                    "Try broadening filters (institute/program/category) or adjusting the rank.")
        bullet = "\n".join([
            f"- {r.get('institute','N/A')}: {r.get('program','N/A')} "
            f"(closing_rank={r.get('closing_rank','?')})"
            for r in rows[:10]
        ])
        prompt = f"""
You are a JoSAA counseling assistant. Based on the rows below, answer the user briefly,
with 2‚Äì3 actionable tips (category/round effects, choice filling strategy).

User query: "{user_query}"

Rows:
{bullet}
"""
        res = self.sim.generate_content(prompt)
        return (res.text or "").strip()


# ========================
# ORCHESTRATOR
# ========================
class Pipeline:
    def __init__(self):
        self.pg = Pg(DB)
        self.schema_text = self.pg.fetch_schema_text()
        print("‚úÖ Loaded live schema from DB")

        self.enhancer = QueryEnhancer(GEMINI_API_KEY)
        self.sqlpro = SqlGenPro(GEMINI_API_KEY, self.schema_text)
        self.answerer = Answerer(GEMINI_API_KEY)

    def run(self, user_query: str) -> Tuple[str, str, List[Dict]]:
        # A) enhance
        variants = self.enhancer.enhance(user_query)
        if not variants:
            variants = [user_query]
        print("\nüîé Enhanced variants:")
        for v in variants:
            print(" -", v)

        # B) SQL for each variant
        sqls = []
        for v in variants:
            sql = self.sqlpro.to_sql(v)
            if not sql:
                continue
            try:
                sanitized = sanitize_select(sql, hard_limit=15)
            except Exception as e:
                print("Sanitize refused:", e)
                continue
            sqls.append(sanitized)

        if not sqls:
            return ("I couldn't produce a safe SQL for that.", "", [])

        # C) Execute each, collect results
        candidates: List[Tuple[str, List[Dict]]] = []
        for s in sqls:
            try:
                rows = self.pg.run(s)
                print(f"  ‚úÖ Executed, rows={len(rows)}")
                candidates.append((s, rows))
            except Exception as e:
                print("  ‚ùå Exec error:", e)

        if not candidates:
            return ("All generated SQLs failed to run.", sqls[0], [])

        best_sql, best_rows = pick_best_result(candidates)

        # D) Answer
        answer = self.answerer.answer(user_query, best_rows)
        return answer, best_sql, best_rows


# ========================
# DEMO
# ========================
def main():
    pipe = Pipeline()

    tests = [
        "I have AIR 6000, which IIT programs can I get?",
        "Show me Computer Science programs at IIT Goa",
        "Mechanical Engineering with closing rank below 9000",
    ]

    for i, q in enumerate(tests, 1):
        print("\n" + "-"*60)
        print(f"Test {i}: {q}")
        print("-"*60)
        answer, sql, rows = pipe.run(q)
        print("\nüß† Best SQL:")
        print(sql)
        print("\nüí¨ Answer:")
        print(answer)
        if rows:
            print("\nüìä First 3 rows:")
            for r in rows[:3]:
                print(" ", {k: r[k] for k in r.keys() & {'institute','program','quota','category','gender','closing_rank'}})

if __name__ == "__main__":
    main()


‚úÖ Loaded live schema from DB

------------------------------------------------------------
Test 1: I have AIR 6000, which IIT programs can I get?
------------------------------------------------------------

üîé Enhanced variants:
 - What programs in Indian Institutes of Technology (IITs) are available for a rank of 6000?
 - Which IIT courses can I apply to with an All India Rank of 6000?
 - Based on a rank of 6000, what are the possible IIT degree options?
  ‚úÖ Executed, rows=0
  ‚úÖ Executed, rows=0
  ‚úÖ Executed, rows=0

üß† Best SQL:
WITH josaa_ranks AS (
  SELECT
    institute,
    program,
    quota,
    category,
    gender,
    opening_rank,
    closing_rank,
    CAST(NULLIF(regexp_replace(trim(closing_rank), '[^0-9]', '', 'g'), '') AS INTEGER) AS closing_rank_num
  FROM
    josaa_btech_2024
  WHERE
    institute_type ILIKE 'Indian Institute of Technology'
)
SELECT
  institute,
  program,
  quota,
  category,
  gender,
  opening_rank,
  closing_rank
FROM
  josaa_ranks
WHE

In [2]:
#!/usr/bin/env python3
"""
Two-Stage SQL RAG Pipeline (with defaults, self-critique & live schema)
A) Gemini Flash: query enhancement (2‚Äì3 variants) w/ AIR defaults if unspecified
B) Gemini 2.5 Pro: robust SQL generation (schema-aware) + self-critique/fix
C) DB: safe SELECT/CTE-only execution; pick best result
D) GGUF (or Gemini sim): final answer

Hardcoded:
- Supabase Postgres credentials
- Gemini API key

IMPORTANT: This script does not mutate DB; read-only per statement.
"""

import re
from typing import List, Dict, Optional, Tuple

import psycopg2
import psycopg2.extras
from psycopg2.pool import SimpleConnectionPool
import google.generativeai as genai


# ========================
# HARD-CODED CREDENTIALS
# ========================
DB = {
    "host": "DB_HOST_HERE",
    "database": "postgres",
    "user": "DB_USER_HERE",
    "password": "DB_PASSWORD_HERE",
    "port": "6543",
    "sslmode": "require",
    "connect_timeout": 10,
}
GEMINI_API_KEY = "GEMINI_API_KEY_HERE"


# ========================
# CONSTANTS & HELPERS
# ========================
RANK_NUM_EXPR = (
    "CAST(NULLIF(regexp_replace(trim(closing_rank), '[^0-9]', '', 'g'), '') AS INTEGER)"
)

ALLOWED_READ = re.compile(r"^\s*(select|with)\b", re.IGNORECASE | re.DOTALL)
DANGEROUS = re.compile(
    r"\b(insert|update|delete|drop|alter|truncate|create|grant|revoke|copy|vacuum|analyze)\b",
    re.IGNORECASE,
)

def single_statement(sql: str) -> bool:
    return sql.strip().count(";") <= 1

def ensure_limit(sql: str, hard_limit: int = 15) -> str:
    if re.search(r"\blimit\b", sql, re.IGNORECASE):
        return sql
    return f"{sql.rstrip(';')} LIMIT {hard_limit};"

def sanitize_select(sql: str, hard_limit: int = 15) -> str:
    sql = sql.strip()
    if not single_statement(sql):
        raise ValueError("Multiple SQL statements not allowed.")
    if not ALLOWED_READ.match(sql):
        raise ValueError("Only SELECT/CTE read queries are allowed.")
    if DANGEROUS.search(sql):
        raise ValueError("Potentially dangerous SQL token detected.")
    return ensure_limit(sql, hard_limit)


# ========================
# DB LAYER
# ========================
class Pg:
    def __init__(self, params: Dict[str, str]):
        self.pool = SimpleConnectionPool(minconn=1, maxconn=6, **params)

    def _conn(self):
        conn = self.pool.getconn()
        with conn.cursor() as c:
            c.execute("SET LOCAL statement_timeout = '12000ms';")
            c.execute("SET LOCAL default_transaction_read_only = on;")
        return conn

    def fetch_schema_text(self) -> str:
        conn = self._conn()
        try:
            cur = conn.cursor()
            out = ["SCHEMA: public\n"]
            cur.execute("""
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'public'
                ORDER BY table_name;
            """)
            for (t,) in cur.fetchall():
                out.append(f"TABLE: {t}")
                cur.execute("""
                    SELECT column_name, data_type, is_nullable
                    FROM information_schema.columns
                    WHERE table_schema='public' AND table_name=%s
                    ORDER BY ordinal_position;
                """, (t,))
                for name, dtype, nullable in cur.fetchall():
                    out.append(f"  - {name}: {dtype} {'NULL' if nullable=='YES' else 'NOT NULL'}")
                out.append("")
            return "\n".join(out)
        finally:
            self.pool.putconn(conn)

    def run(self, sql: str) -> List[Dict]:
        conn = self._conn()
        try:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql)
                return [dict(r) for r in cur.fetchall()]
        finally:
            self.pool.putconn(conn)


# ========================
# STAGE A: FLASH ENHANCER
# ========================
class QueryEnhancer:
    """
    Expands the user's query into 2‚Äì3 variants.
    If it detects AIR/eligibility and no category/gender explicitly given,
    injects defaults: OPEN, Gender-Neutral, quota AI, exclude PwD, final round.
    """
    def __init__(self, api_key: str):
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel("gemini-2.0-flash-exp")

    def _air_defaults_hint(self) -> str:
        return (
            "If the user mentions an AIR or eligibility but does not specify "
            "category/gender/quota/round, add the following clarifying parenthetical "
            "to the rewritten variant (do not add if user explicitly provided them): "
            "(assume OPEN, Gender-Neutral, quota AI, exclude PwD, final round)."
        )

    def enhance(self, user_query: str) -> List[str]:
        prompt = f"""
You rewrite the user's database question into 2‚Äì3 clear, semantically similar variants.

Rules:
- Keep the same intent.
- Be concise, one line per variant, no bullets/numbers.
- Clarify institute types (IIT/NIT/IIIT) only if mentioned.
- Do NOT fabricate constraints that the user did not give.
- {self._air_defaults_hint()}

User query: "{user_query}"
Variants (one per line, 2‚Äì3 lines total):
"""
        res = self.model.generate_content(prompt)
        text = (res.text or "").strip()
        variants = [line.strip() for line in text.splitlines() if line.strip()]
        if not variants:
            return [user_query]
        return variants[:3]


# ========================
# STAGE B: 2.5 PRO SQL + SELF-CRITIQUE
# ========================
class SqlGenPro:
    """
    Generates SQL using Gemini 2.5 Pro with schema-aware rules,
    then runs a self-critique to fix common omissions.
    """
    def __init__(self, api_key: str, schema_text: str):
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel("gemini-2.5-pro")
        self.schema_text = schema_text

    def to_sql(self, nl_query: str) -> Optional[str]:
        rules = f"""
DATA RULES:
- opening_rank and closing_rank are TEXT. Use:
  {RANK_NUM_EXPR} AS closing_rank_num
- No 'rank_num' column exists.
- Use ILIKE for case-insensitive matches.
- Prefer table 'josaa_btech_2024' for B.Tech queries.
- If the user asks eligibility via AIR and does NOT specify category/gender/quota/round:
  * Add: quota = 'AI'
  * Add: category = 'OPEN'
  * Add: category NOT ILIKE '%PwD%'
  * Add: gender = 'Gender-Neutral'
  * Add: round = (SELECT MAX(round) FROM josaa_btech_2024 WHERE year = 2024)
- When filtering IITs, do NOT rely on institute_type equality; match by name:
  (institute ILIKE '%Indian Institute of Technology%' OR institute ILIKE 'IIT %')
- For eligibility 'AIR = R': use closing_rank_num >= R.
- ALWAYS ORDER BY closing_rank_num ASC and LIMIT 15.
- Return ONLY a single SQL SELECT (WITH allowed), no commentary.
"""
        prompt = f"""
You are an expert PostgreSQL writer.

DATABASE SCHEMA (public):
{self.schema_text}

{rules}

User query: "{nl_query}"

SQL:
"""
        try:
            res = self.model.generate_content(prompt)
            sql = (res.text or "").strip()
            if sql.startswith("```"):
                sql = sql.replace("```sql", "").replace("```", "").strip()
            if not sql:
                return None
            # self-critique / fix
            fixed = self.critique_fix(sql, nl_query, rules)
            return fixed or sql
        except Exception as e:
            print("SQL gen error:", e)
            return None

    def critique_fix(self, sql: str, nl_query: str, rules: str) -> Optional[str]:
        """
        Ask 2.5 Pro to validate and, if needed, correct the SQL based on policy.
        """
        critic = f"""
You are a PostgreSQL SQL critic/fixer.

Policy:
{rules}

Given the user query and the candidate SQL, check for:
- Proper numeric derivation of closing_rank (closing_rank_num).
- Defaults for AIR eligibility (OPEN, GN, AI, exclude PwD, final round) when unspecified.
- IIT name matching via institute ILIKE, not hard equality on institute_type.
- ORDER BY closing_rank_num ASC and LIMIT 15.
- Only a single SELECT (WITH allowed).

If the SQL is fully compliant, output ONLY the original SQL.
If not, output ONLY a corrected SQL that is compliant.

User query: "{nl_query}"

Candidate SQL:
{sql}

Output ONLY the SQL:
"""
        try:
            res = self.model.generate_content(critic)
            fixed = (res.text or "").strip()
            if fixed and fixed != sql:
                if fixed.startswith("```"):
                    fixed = fixed.replace("```sql", "").replace("```", "").strip()
                return fixed
            return None
        except Exception as e:
            print("Critic error:", e)
            return None


# ========================
# STAGE C: EXECUTION & SELECTION
# ========================
def pick_best_result(candidates: List[Tuple[str, List[Dict]]]) -> Tuple[Optional[str], List[Dict]]:
    non_empty = [c for c in candidates if len(c[1]) > 0]
    if non_empty:
        non_empty.sort(key=lambda x: len(x[1]), reverse=True)
        return non_empty[0]
    return candidates[0] if candidates else (None, [])


# ========================
# STAGE D: GGUF (SIM) ANSWER
# ========================
class Answerer:
    """
    If you have a GGUF API, wire it here. For now, we simulate with Gemini Flash.
    """
    def __init__(self, api_key: str, gguf_endpoint: Optional[str] = None):
        self.gguf_endpoint = gguf_endpoint
        genai.configure(api_key=api_key)
        self.sim = genai.GenerativeModel("gemini-2.0-flash-exp")

    def answer(self, user_query: str, rows: List[Dict]) -> str:
        if not rows:
            return ("I couldn't find matching rows for that query. "
                    "Try broadening filters (institute/program/category) or adjusting the rank.")
        bullet = "\n".join([
            f"- {r.get('institute','N/A')}: {r.get('program','N/A')} (closing_rank={r.get('closing_rank','?')})"
            for r in rows[:10]
        ])
        prompt = f"""
You are a JoSAA counseling assistant. Based on the rows below, answer the user briefly,
with 2‚Äì3 actionable tips (category/round effects, choice filling strategy).

User query: "{user_query}"

Rows:
{bullet}
"""
        res = self.sim.generate_content(prompt)
        return (res.text or "").strip()


# ========================
# ORCHESTRATOR
# ========================
class Pipeline:
    def __init__(self):
        self.pg = Pg(DB)
        self.schema_text = self.pg.fetch_schema_text()
        print("‚úÖ Loaded live schema from DB")

        self.enhancer = QueryEnhancer(GEMINI_API_KEY)
        self.sqlpro = SqlGenPro(GEMINI_API_KEY, self.schema_text)
        self.answerer = Answerer(GEMINI_API_KEY)

    def run(self, user_query: str) -> Tuple[str, str, List[Dict]]:
        # A) enhance
        variants = self.enhancer.enhance(user_query)
        if not variants:
            variants = [user_query]
        print("\nüîé Enhanced variants:")
        for v in variants:
            print(" -", v)

        # B) SQL for each variant (with sanitize + critique-fix)
        sqls = []
        for v in variants:
            sql = self.sqlpro.to_sql(v)
            if not sql:
                continue
            try:
                sanitized = sanitize_select(sql, hard_limit=15)
            except Exception as e:
                print("Sanitize refused:", e)
                continue
            sqls.append(sanitized)

        if not sqls:
            return ("I couldn't produce a safe SQL for that.", "", [])

        # C) Execute each, collect results
        candidates: List[Tuple[str, List[Dict]]] = []
        for s in sqls:
            try:
                rows = self.pg.run(s)
                print(f"  ‚úÖ Executed, rows={len(rows)}")
                candidates.append((s, rows))
            except Exception as e:
                print("  ‚ùå Exec error:", e)

        if not candidates:
            return ("All generated SQLs failed to run.", sqls[0], [])

        best_sql, best_rows = pick_best_result(candidates)

        # D) Answer
        answer = self.answerer.answer(user_query, best_rows)
        return answer, best_sql, best_rows


# ========================
# DEMO
# ========================
def main():
    pipe = Pipeline()

    tests = [
        "I have AIR 6000, which IIT programs can I get?",
        "Show me Computer Science programs at IIT Goa",
        "Mechanical Engineering with closing rank below 9000",
    ]

    for i, q in enumerate(tests, 1):
        print("\n" + "-"*60)
        print(f"Test {i}: {q}")
        print("-"*60)
        answer, sql, rows = pipe.run(q)
        print("\nüß† Best SQL:")
        print(sql)
        print("\nüí¨ Answer:")
        print(answer)
        if rows:
            print("\nüìä First 3 rows:")
            keys = {'institute','program','quota','category','gender','closing_rank'}
            for r in rows[:3]:
                print(" ", {k: r[k] for k in r.keys() & keys})

if __name__ == "__main__":
    main()


‚úÖ Loaded live schema from DB

------------------------------------------------------------
Test 1: I have AIR 6000, which IIT programs can I get?
------------------------------------------------------------

üîé Enhanced variants:
 - Which IIT programs are likely to admit someone with an AIR of 6000 (assume OPEN, Gender-Neutral, quota AI, exclude PwD, final round)?
 - Given an AIR of 6000, what IIT programs are within reach (assume OPEN, Gender-Neutral, quota AI, exclude PwD, final round)?
 - What IIT programs should I consider with an AIR of 6000 (assume OPEN, Gender-Neutral, quota AI, exclude PwD, final round)?
  ‚úÖ Executed, rows=15
  ‚úÖ Executed, rows=15
  ‚úÖ Executed, rows=15

üß† Best SQL:
WITH ranked_programs AS (
  SELECT
    institute,
    program,
    closing_rank,
    CAST(NULLIF(regexp_replace(trim(closing_rank), '[^0-9]', '', 'g'), '') AS INTEGER) AS closing_rank_num
  FROM
    josaa_btech_2024
  WHERE
    year = 2024
    AND (institute ILIKE '%Indian Institute of T

In [3]:
#!/usr/bin/env python3
"""
Two-Stage SQL RAG Pipeline (policy-driven, self-critique, live schema)
A) Gemini Flash: query enhancement (2‚Äì3 variants) + polite defaults hinting
B) Gemini 2.5 Pro: robust SQL generation (schema-aware) + self-critique/fix
C) DB: safe SELECT/CTE-only execution; pick best result, optional dedup
D) GGUF (or Gemini sim): final NL answer

Read-only: per-statement timeout, SELECT/CTE-only guard, LIMIT enforced.

Hardcoded:
- Supabase Postgres credentials
- Gemini API key
"""

import re
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple

import psycopg2
import psycopg2.extras
from psycopg2.pool import SimpleConnectionPool
import google.generativeai as genai


# ========================
# HARD-CODED CREDENTIALS
# ========================
DB = {
    "host": "DB_HOST_HERE",
    "database": "postgres",
    "user": "DB_USER_HERE",
    "password": "DB_PASSWORD_HERE",
    "port": "6543",
    "sslmode": "require",
    "connect_timeout": 10,
}
GEMINI_API_KEY = "GEMINI_API_KEY_HERE"


# ========================
# POLICY
# ========================
@dataclass
class Policy:
    exclude_pwd_default: bool = True      # exclude PwD when user didn't ask for it
    final_round_default: bool = True      # prefer latest round when unspecified
    apply_open_gn_ai_on_air: bool = True  # for AIR eligibility & unspecified: force OPEN/GN/AI


# ========================
# CONSTANTS & HELPERS
# ========================
RANK_NUM_EXPR = (
    "CAST(NULLIF(regexp_replace(trim(closing_rank), '[^0-9]', '', 'g'), '') AS INTEGER)"
)

ALLOWED_READ = re.compile(r"^\s*(select|with)\b", re.IGNORECASE | re.DOTALL)
DANGEROUS = re.compile(
    r"\b(insert|update|delete|drop|alter|truncate|create|grant|revoke|copy|vacuum|analyze)\b",
    re.IGNORECASE,
)

def single_statement(sql: str) -> bool:
    return sql.strip().count(";") <= 1

def ensure_limit(sql: str, hard_limit: int = 15) -> str:
    if re.search(r"\blimit\b", sql, re.IGNORECASE):
        return sql
    return f"{sql.rstrip(';')} LIMIT {hard_limit};"

def sanitize_select(sql: str, hard_limit: int = 15) -> str:
    sql = sql.strip()
    if not single_statement(sql):
        raise ValueError("Multiple SQL statements not allowed.")
    if not ALLOWED_READ.match(sql):
        raise ValueError("Only SELECT/CTE read queries are allowed.")
    if DANGEROUS.search(sql):
        raise ValueError("Potentially dangerous SQL token detected.")
    return ensure_limit(sql, hard_limit)


# ========================
# DB LAYER
# ========================
class Pg:
    def __init__(self, params: Dict[str, str]):
        self.pool = SimpleConnectionPool(minconn=1, maxconn=6, **params)

    def _conn(self):
        conn = self.pool.getconn()
        with conn.cursor() as c:
            c.execute("SET LOCAL statement_timeout = '12000ms';")
            c.execute("SET LOCAL default_transaction_read_only = on;")
        return conn

    def fetch_schema_text(self) -> str:
        """Build a concise, live schema string for prompt grounding."""
        conn = self._conn()
        try:
            cur = conn.cursor()
            out = ["SCHEMA: public\n"]
            cur.execute("""
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'public'
                ORDER BY table_name;
            """)
            for (t,) in cur.fetchall():
                out.append(f"TABLE: {t}")
                cur.execute("""
                    SELECT column_name, data_type, is_nullable
                    FROM information_schema.columns
                    WHERE table_schema='public' AND table_name=%s
                    ORDER BY ordinal_position;
                """, (t,))
                for name, dtype, nullable in cur.fetchall():
                    out.append(f"  - {name}: {dtype} {'NULL' if nullable=='YES' else 'NOT NULL'}")
                out.append("")
            return "\n".join(out)
        finally:
            self.pool.putconn(conn)

    def run(self, sql: str) -> List[Dict]:
        conn = self._conn()
        try:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql)
                return [dict(r) for r in cur.fetchall()]
        finally:
            self.pool.putconn(conn)


# ========================
# STAGE A: FLASH ENHANCER
# ========================
class QueryEnhancer:
    """
    Expands the user's query into 2‚Äì3 variants.
    If it detects AIR/eligibility and no category/gender explicitly given,
    injects defaults: OPEN, Gender-Neutral, quota AI, exclude PwD, final round.
    Also gently nudges excluding PwD and using final round when user didn't specify anything.
    """
    def __init__(self, api_key: str, policy: Policy):
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel("gemini-2.0-flash-exp")
        self.policy = policy

    def enhance(self, user_query: str) -> List[str]:
        # Build hints from policy
        hints = []
        if self.policy.exclude_pwd_default:
            hints.append("exclude PwD by default unless the user explicitly asks for PwD")
        if self.policy.final_round_default:
            hints.append("assume final (latest) JoSAA round when round is unspecified")
        generic_defaults_hint = "; ".join(hints) if hints else "no additional defaults"

        air_defaults_hint = ""
        if self.policy.apply_open_gn_ai_on_air:
            air_defaults_hint = ("If the user mentions AIR/eligibility but does not specify "
                                 "category/gender/quota/round, add a clarifying parenthetical: "
                                 "(assume OPEN, Gender-Neutral, quota AI, exclude PwD, final round).")

        prompt = f"""
Rewrite the user's database question into 2‚Äì3 clear, semantically similar variants.

Rules:
- Keep the same intent.
- Be concise, one line per variant, no bullets/numbers.
- Clarify institute types (IIT/NIT/IIIT) only if mentioned by the user.
- Do NOT fabricate constraints that the user did not give.
- When unspecified, {generic_defaults_hint}.
- {air_defaults_hint}

User query: "{user_query}"
Variants (one per line, 2‚Äì3 lines total):
"""
        res = self.model.generate_content(prompt)
        text = (res.text or "").strip()
        variants = [line.strip() for line in text.splitlines() if line.strip()]
        return variants[:3] if variants else [user_query]


# ========================
# STAGE B: 2.5 PRO SQL + SELF-CRITIQUE
# ========================
class SqlGenPro:
    """
    Generates SQL using Gemini 2.5 Pro with schema-aware rules,
    then runs a self-critique to fix common omissions.
    """
    def __init__(self, api_key: str, schema_text: str, policy: Policy):
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel("gemini-2.5-pro")
        self.schema_text = schema_text
        self.policy = policy

    def _rules(self) -> str:
        base = f"""
DATA RULES:
- opening_rank and closing_rank are TEXT. Use:
  {RANK_NUM_EXPR} AS closing_rank_num
- No 'rank_num' column exists.
- Use ILIKE for case-insensitive matches.
- Prefer table 'josaa_btech_2024' for B.Tech queries.
- When filtering IITs, do NOT rely on institute_type equality; match by name:
  (institute ILIKE '%Indian Institute of Technology%' OR institute ILIKE 'IIT %')
- ALWAYS ORDER BY closing_rank_num ASC and LIMIT 15.
"""
        # defaults when unspecified
        extra = []
        if self.policy.exclude_pwd_default:
            extra.append("If the user did not mention PwD, add: category NOT ILIKE '%PwD%'.")
        if self.policy.final_round_default:
            extra.append("If the user did not mention a round, add: round = (SELECT MAX(round) FROM josaa_btech_2024 WHERE year = 2024).")
        if self.policy.apply_open_gn_ai_on_air:
            extra.append(
                "If the user asks eligibility via AIR and does NOT specify category/gender/quota/round: "
                "add quota='AI', category='OPEN', category NOT ILIKE '%PwD%', gender='Gender-Neutral', "
                "round=(SELECT MAX(round) FROM josaa_btech_2024 WHERE year=2024); "
                "use closing_rank_num >= AIR."
            )
        if extra:
            base += "\n- " + "\n- ".join(extra) + "\n"
        return base

    def to_sql(self, nl_query: str) -> Optional[str]:
        rules = self._rules()
        prompt = f"""
You are an expert PostgreSQL writer.

DATABASE SCHEMA (public):
{self.schema_text}

{rules}

Return ONLY a single SQL SELECT (WITH allowed), no commentary.

User query: "{nl_query}"
SQL:
"""
        try:
            res = self.model.generate_content(prompt)
            sql = (res.text or "").strip()
            if sql.startswith("```"):
                sql = sql.replace("```sql", "").replace("```", "").strip()
            if not sql:
                return None
            fixed = self.critique_fix(sql, nl_query, rules)
            return fixed or sql
        except Exception as e:
            print("SQL gen error:", e)
            return None

    def critique_fix(self, sql: str, nl_query: str, rules: str) -> Optional[str]:
        critic = f"""
You are a PostgreSQL SQL critic/fixer.

Policy:
{rules}

Check the SQL covers:
- Numeric derivation (closing_rank_num).
- Defaults (exclude PwD; final round) if user didn't specify.
- AIR defaults (OPEN, GN, AI, exclude PwD, final round) when eligibility is asked & unspecified.
- IIT name matching via institute ILIKE (not institute_type equality).
- ORDER BY closing_rank_num ASC and LIMIT 15.
- Single SELECT (WITH allowed).

If compliant, output ONLY the original SQL. Otherwise output ONLY a corrected SQL.

User query: "{nl_query}"

Candidate SQL:
{sql}

Output ONLY the SQL:
"""
        try:
            res = self.model.generate_content(critic)
            fixed = (res.text or "").strip()
            if fixed and fixed != sql:
                if fixed.startswith("```"):
                    fixed = fixed.replace("```sql", "").replace("```", "").strip()
                return fixed
            return None
        except Exception as e:
            print("Critic error:", e)
            return None


# ========================
# STAGE D: GGUF (SIM) ANSWER
# ========================
class Answerer:
    """
    If you have a GGUF API, wire it here. For now, we simulate with Gemini Flash.
    """
    def __init__(self, api_key: str, gguf_endpoint: Optional[str] = None):
        self.gguf_endpoint = gguf_endpoint
        genai.configure(api_key=api_key)
        self.sim = genai.GenerativeModel("gemini-2.0-flash-exp")

    def answer(self, user_query: str, rows: List[Dict]) -> str:
        if not rows:
            return ("I couldn't find matching rows for that query. "
                    "Try broadening filters (institute/program/category) or adjusting the rank.")
        bullet = "\n".join([
            f"- {r.get('institute','N/A')}: {r.get('program','N/A')} (closing_rank={r.get('closing_rank','?')})"
            for r in rows[:10]
        ])
        prompt = f"""
You are a JoSAA counseling assistant. Based on the rows below, answer the user briefly,
with 2‚Äì3 actionable tips (category/round effects, choice filling strategy).

User query: "{user_query}"

Rows:
{bullet}
"""
        res = self.sim.generate_content(prompt)
        return (res.text or "").strip()


# ========================
# ORCHESTRATOR
# ========================
class Pipeline:
    def __init__(self):
        self.pg = Pg(DB)
        self.schema_text = self.pg.fetch_schema_text()
        print("‚úÖ Loaded live schema from DB")

        self.policy = Policy()
        self.enhancer = QueryEnhancer(GEMINI_API_KEY, self.policy)
        self.sqlpro = SqlGenPro(GEMINI_API_KEY, self.schema_text, self.policy)
        self.answerer = Answerer(GEMINI_API_KEY)

    def _dedup_rows(self, rows: List[Dict]) -> List[Dict]:
        """Deduplicate by (institute, program)."""
        seen = set()
        out = []
        for r in rows:
            key = (r.get("institute"), r.get("program"))
            if key in seen:
                continue
            seen.add(key)
            out.append(r)
        return out

    def run(self, user_query: str) -> Tuple[str, str, List[Dict]]:
        # A) enhance
        variants = self.enhancer.enhance(user_query) or [user_query]
        print("\nüîé Enhanced variants:")
        for v in variants:
            print(" -", v)

        # B) SQL for each variant (with sanitize + critique-fix)
        sqls = []
        for v in variants:
            sql = self.sqlpro.to_sql(v)
            if not sql:
                continue
            try:
                sanitized = sanitize_select(sql, hard_limit=15)
            except Exception as e:
                print("Sanitize refused:", e)
                continue
            sqls.append(sanitized)

        if not sqls:
            return ("I couldn't produce a safe SQL for that.", "", [])

        # C) Execute each, collect results
        candidates: List[Tuple[str, List[Dict]]] = []
        for s in sqls:
            try:
                rows = self.pg.run(s)
                print(f"  ‚úÖ Executed, rows={len(rows)}")
                candidates.append((s, rows))
            except Exception as e:
                print("  ‚ùå Exec error:", e)

        if not candidates:
            return ("All generated SQLs failed to run.", sqls[0], [])

        best_sql, best_rows = self._pick_best_result(candidates)
        best_rows = self._dedup_rows(best_rows)

        # D) Answer
        answer = self.answerer.answer(user_query, best_rows)
        return answer, best_sql, best_rows

    @staticmethod
    def _pick_best_result(candidates: List[Tuple[str, List[Dict]]]) -> Tuple[Optional[str], List[Dict]]:
        non_empty = [c for c in candidates if len(c[1]) > 0]
        if non_empty:
            non_empty.sort(key=lambda x: len(x[1]), reverse=True)
            return non_empty[0]
        return candidates[0] if candidates else (None, [])


# ========================
# DEMO
# ========================
def main():
    pipe = Pipeline()

    tests = [
        "I have AIR 6000, which IIT programs can I get?",
        "Show me Computer Science programs at IIT Goa",
        "Mechanical Engineering with closing rank below 9000",
    ]

    for i, q in enumerate(tests, 1):
        print("\n" + "-"*60)
        print(f"Test {i}: {q}")
        print("-"*60)
        answer, sql, rows = pipe.run(q)
        print("\nüß† Best SQL:")
        print(sql)
        print("\nüí¨ Answer:")
        print(answer)
        if rows:
            print("\nüìä First 3 rows:")
            keys = {'institute','program','quota','category','gender','closing_rank'}
            for r in rows[:3]:
                print(" ", {k: r[k] for k in r.keys() & keys})

if __name__ == "__main__":
    main()


‚úÖ Loaded live schema from DB

------------------------------------------------------------
Test 1: I have AIR 6000, which IIT programs can I get?
------------------------------------------------------------

üîé Enhanced variants:
 - *   With an AIR of 6000, what IIT programs am I eligible for (assume OPEN, Gender-Neutral, quota AI, exclude PwD, final round)?
 - *   What IIT programs can I get with a rank of 6000 (assume OPEN, Gender-Neutral, quota AI, exclude PwD, final round)?
 - *   Given an AIR of 6000, which IIT courses are available to me (assume OPEN, Gender-Neutral, quota AI, exclude PwD, final round)?
  ‚úÖ Executed, rows=15
  ‚úÖ Executed, rows=15
  ‚úÖ Executed, rows=15

üß† Best SQL:
WITH ranked_data AS (
  SELECT
    institute,
    program,
    closing_rank,
    CAST(NULLIF(regexp_replace(trim(closing_rank), '[^0-9]', '', 'g'), '') AS INTEGER) AS closing_rank_num
  FROM
    josaa_btech_2024
  WHERE
    year = 2024
    AND round = (SELECT MAX(round) FROM josaa_btech_202

In [5]:
#!/usr/bin/env python3
"""
Two-Stage SQL RAG Pipeline (policy-driven, self-critique, live schema)
A) Gemini Flash: query enhancement (2‚Äì3 variants) + polite defaults hinting
B) Gemini 2.5 Pro: robust SQL generation (schema-aware) + self-critique/fix
C) DB: safe SELECT/CTE-only execution; pick best result, optional dedup
D) GGUF (or Gemini sim): final NL answer

Read-only: per-statement timeout, SELECT/CTE-only guard, LIMIT enforced.

Hardcoded:
- Supabase Postgres credentials
- Gemini API key
"""

import re
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple

import psycopg2
import psycopg2.extras
from psycopg2.pool import SimpleConnectionPool
import google.generativeai as genai


# ========================
# HARD-CODED CREDENTIALS
# ========================
DB = {
    "host": "DB_HOST_HERE",
    "database": "postgres",
    "user": "DB_USER_HERE",
    "password": "DB_PASSWORD_HERE",
    "port": "6543",
    "sslmode": "require",
    "connect_timeout": 10,
}
GEMINI_API_KEY = "GEMINI_API_KEY_HERE"


# ========================
# POLICY
# ========================
@dataclass
class Policy:
    exclude_pwd_default: bool = True      # exclude PwD when user didn't ask for it
    final_round_default: bool = True      # prefer latest round when unspecified
    apply_open_gn_ai_on_air: bool = True  # for AIR eligibility & unspecified: force OPEN/GN/AI


# ========================
# CONSTANTS & HELPERS
# ========================
RANK_NUM_EXPR = (
    "CAST(NULLIF(regexp_replace(trim(closing_rank), '[^0-9]', '', 'g'), '') AS INTEGER)"
)

ALLOWED_READ = re.compile(r"^\s*(select|with)\b", re.IGNORECASE | re.DOTALL)
DANGEROUS = re.compile(
    r"\b(insert|update|delete|drop|alter|truncate|create|grant|revoke|copy|vacuum|analyze)\b",
    re.IGNORECASE,
)

def single_statement(sql: str) -> bool:
    return sql.strip().count(";") <= 1

def ensure_limit(sql: str, hard_limit: int = 15) -> str:
    if re.search(r"\blimit\b", sql, re.IGNORECASE):
        return sql
    return f"{sql.rstrip(';')} LIMIT {hard_limit};"

def sanitize_select(sql: str, hard_limit: int = 15) -> str:
    sql = sql.strip()
    if not single_statement(sql):
        raise ValueError("Multiple SQL statements not allowed.")
    if not ALLOWED_READ.match(sql):
        raise ValueError("Only SELECT/CTE read queries are allowed.")
    if DANGEROUS.search(sql):
        raise ValueError("Potentially dangerous SQL token detected.")
    return ensure_limit(sql, hard_limit)

def air_context_hint(text: str) -> str:
    """Soft hint for Pro when AIR/eligibility is detected."""
    return (
        "User intent hint: Eligibility by AIR is requested; "
        "apply OPEN/GN/AI, exclude PwD, and latest round unless the user overrides."
        if re.search(r'\b(AIR|rank)\b', text, re.I) else ""
    )

def fix_distinct_orderby(sql: str) -> str:
    """
    If SQL uses SELECT DISTINCT and ORDER BY closing_rank_num, but doesn't
    project closing_rank_num, add it to the SELECT list safely.
    """
    # quick checks
    if re.search(r"\bselect\s+distinct\b", sql, re.IGNORECASE) and \
       re.search(r"\border\s+by\s+.*closing_rank_num", sql, re.IGNORECASE):
        # extract select list
        m = re.search(r"(?is)select\s+distinct\s+(.*?)\s+from\s", sql)
        if m:
            select_list = m.group(1)
            if "closing_rank_num" not in select_list:
                # inject ", closing_rank_num" before FROM
                start, end = m.span(1)
                new_select_list = select_list.rstrip() + ", closing_rank_num"
                sql = sql[:start] + new_select_list + sql[end:]
    return sql


# ========================
# DB LAYER
# ========================
class Pg:
    def __init__(self, params: Dict[str, str]):
        self.pool = SimpleConnectionPool(minconn=1, maxconn=6, **params)

    def _conn(self):
        conn = self.pool.getconn()
        with conn.cursor() as c:
            c.execute("SET LOCAL statement_timeout = '12000ms';")
            c.execute("SET LOCAL default_transaction_read_only = on;")
        return conn

    def fetch_schema_text(self) -> str:
        """Build a concise, live schema string for prompt grounding."""
        conn = self._conn()
        try:
            cur = conn.cursor()
            out = ["SCHEMA: public\n"]
            cur.execute("""
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'public'
                ORDER BY table_name;
            """)
            for (t,) in cur.fetchall():
                out.append(f"TABLE: {t}")
                cur.execute("""
                    SELECT column_name, data_type, is_nullable
                    FROM information_schema.columns
                    WHERE table_schema='public' AND table_name=%s
                    ORDER BY ordinal_position;
                """, (t,))
                for name, dtype, nullable in cur.fetchall():
                    out.append(f"  - {name}: {dtype} {'NULL' if nullable=='YES' else 'NOT NULL'}")
                out.append("")
            return "\n".join(out)
        finally:
            self.pool.putconn(conn)

    def run(self, sql: str) -> List[Dict]:
        conn = self._conn()
        try:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql)
                return [dict(r) for r in cur.fetchall()]
        finally:
            self.pool.putconn(conn)


# ========================
# STAGE A: FLASH ENHANCER
# ========================
class QueryEnhancer:
    """
    Expands the user's query into 2‚Äì3 variants.
    If it detects AIR/eligibility and no category/gender explicitly given,
    injects defaults: OPEN, Gender-Neutral, quota AI, exclude PwD, final round.
    Also gently nudges excluding PwD and using final round when user didn't specify anything.
    """
    def __init__(self, api_key: str, policy: Policy):
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel("gemini-2.0-flash-exp")
        self.policy = policy

    def enhance(self, user_query: str) -> List[str]:
        # Build hints from policy
        hints = []
        if self.policy.exclude_pwd_default:
            hints.append("exclude PwD by default unless the user explicitly asks for PwD")
        if self.policy.final_round_default:
            hints.append("assume final (latest) JoSAA round when round is unspecified")
        generic_defaults_hint = "; ".join(hints) if hints else "no additional defaults"

        air_defaults_hint = ""
        if self.policy.apply_open_gn_ai_on_air:
            air_defaults_hint = ("If the user mentions AIR/eligibility but does not specify "
                                 "category/gender/quota/round, add a clarifying parenthetical: "
                                 "(assume OPEN, Gender-Neutral, quota AI, exclude PwD, final round).")

        prompt = f"""
Rewrite the user's database question into 2‚Äì3 clear, semantically similar variants.

Rules:
- Keep the same intent.
- Be concise, one line per variant, no bullets/numbers.
- Clarify institute types (IIT/NIT/IIIT) only if mentioned by the user.
- Do NOT fabricate constraints that the user did not give.
- When unspecified, {generic_defaults_hint}.
- {air_defaults_hint}

User query: "{user_query}"
Variants (one per line, 2‚Äì3 lines total):
"""
        res = self.model.generate_content(prompt)
        text = (res.text or "").strip()
        variants = [line.strip() for line in text.splitlines() if line.strip()]
        return variants[:3] if variants else [user_query]


# ========================
# STAGE B: 2.5 PRO SQL + SELF-CRITIQUE
# ========================
class SqlGenPro:
    """
    Generates SQL using Gemini 2.5 Pro with schema-aware rules,
    then runs a self-critique to fix common omissions.
    """
    def __init__(self, api_key: str, schema_text: str, policy: Policy):
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel("gemini-2.5-pro")
        self.schema_text = schema_text
        self.policy = policy

    def _rules(self) -> str:
        base = f"""
DATA RULES:
- opening_rank and closing_rank are TEXT. Use:
  {RANK_NUM_EXPR} AS closing_rank_num
- No 'rank_num' column exists.
- Use ILIKE for case-insensitive matches.
- Prefer table 'josaa_btech_2024' for B.Tech queries.
- When filtering IITs, do NOT rely on institute_type equality; match by name:
  (institute ILIKE '%Indian Institute of Technology%' OR institute ILIKE 'IIT %')
- Include year and round in SELECT outputs when present (e.g., year, round).
- If you use closing_rank_num in ORDER BY or filters, include closing_rank_num in the SELECT list.
- Use ORDER BY closing_rank_num ASC NULLS LAST.
- ALWAYS LIMIT 15.
"""
        # defaults when unspecified
        extra = []
        if self.policy.exclude_pwd_default:
            extra.append("If the user did not mention PwD, add: category NOT ILIKE '%PwD%'.")
        if self.policy.final_round_default:
            extra.append("If the user did not mention a round, add: round = (SELECT MAX(round) FROM josaa_btech_2024 WHERE year = 2024).")
        if self.policy.apply_open_gn_ai_on_air:
            extra.append(
                "If the user asks eligibility via AIR and does NOT specify category/gender/quota/round: "
                "add quota='AI', category='OPEN', category NOT ILIKE '%PwD%', gender='Gender-Neutral', "
                "round=(SELECT MAX(round) FROM josaa_btech_2024 WHERE year=2024); "
                "use closing_rank_num >= AIR."
            )
        if extra:
            base += "\n- " + "\n- ".join(extra) + "\n"
        # NOTE: We removed ‚ÄúPrefer DISTINCT ‚Ä¶‚Äù here. We do Python-side dedup.
        return base

    def to_sql(self, nl_query: str) -> Optional[str]:
        rules = self._rules()
        ctx = air_context_hint(nl_query)
        prompt = f"""
You are an expert PostgreSQL writer.

DATABASE SCHEMA (public):
{self.schema_text}

{rules}

{ctx}

Return ONLY a single SQL SELECT (WITH allowed), no commentary.

User query: "{nl_query}"
SQL:
"""
        try:
            res = self.model.generate_content(prompt)
            sql = (res.text or "").strip()
            if sql.startswith("```"):
                sql = sql.replace("```sql", "").replace("```", "").strip()
            if not sql:
                return None
            fixed = self.critique_fix(sql, nl_query, rules)
            sql_out = fixed or sql
            # Harden against DISTINCT/ORDER BY mismatch
            sql_out = fix_distinct_orderby(sql_out)
            return sql_out
        except Exception as e:
            print("SQL gen error:", e)
            return None

    def critique_fix(self, sql: str, nl_query: str, rules: str) -> Optional[str]:
        critic = f"""
You are a PostgreSQL SQL critic/fixer.

Policy:
{rules}

Check the SQL covers:
- Numeric derivation (closing_rank_num).
- If closing_rank_num appears in ORDER BY/filters, ensure it is also projected.
- Defaults (exclude PwD; final round) if user didn't specify.
- AIR defaults (OPEN, GN, AI, exclude PwD, final round) when eligibility is asked & unspecified.
- IIT name matching via institute ILIKE (not institute_type equality).
- Includes year and round in SELECT when present.
- ORDER BY closing_rank_num ASC NULLS LAST and LIMIT 15.
- Single SELECT (WITH allowed).

If compliant, output ONLY the original SQL. Otherwise output ONLY a corrected SQL.

User query: "{nl_query}"

Candidate SQL:
{sql}

Output ONLY the SQL:
"""
        try:
            res = self.model.generate_content(critic)
            fixed = (res.text or "").strip()
            if fixed and fixed != sql:
                if fixed.startswith("```"):
                    fixed = fixed.replace("```sql", "").replace("```", "").strip()
                return fixed
            return None
        except Exception as e:
            print("Critic error:", e)
            return None


# ========================
# STAGE D: GGUF (SIM) ANSWER
# ========================
class Answerer:
    """
    If you have a GGUF API, wire it here. For now, we simulate with Gemini Flash.
    """
    def __init__(self, api_key: str, gguf_endpoint: Optional[str] = None):
        self.gguf_endpoint = gguf_endpoint
        genai.configure(api_key=api_key)
        self.sim = genai.GenerativeModel("gemini-2.0-flash-exp")

    def answer(self, user_query: str, rows: List[Dict]) -> str:
        if not rows:
            return ("I couldn't find matching rows for that query. "
                    "Try broadening filters (institute/program/category) or adjusting the rank.")
        bullet = "\n".join([
            f"- {r.get('institute','N/A')} [{r.get('year','?')}/R{r.get('round','?')}]: "
            f"{r.get('program','N/A')} (closing_rank={r.get('closing_rank','?')}, "
            f"{r.get('quota','?')}/{r.get('category','?')}/{r.get('gender','?')})"
            for r in rows[:10]
        ])
        prompt = f"""
You are a JoSAA counseling assistant. Based on the rows below, answer the user briefly,
with 2‚Äì3 actionable tips (category/round effects, choice filling strategy).

User query: "{user_query}"

Rows:
{bullet}
"""
        res = self.sim.generate_content(prompt)
        return (res.text or "").strip()


# ========================
# ORCHESTRATOR
# ========================
class Pipeline:
    def __init__(self):
        self.pg = Pg(DB)
        self.schema_text = self.pg.fetch_schema_text()
        print("‚úÖ Loaded live schema from DB")

        self.policy = Policy()
        self.enhancer = QueryEnhancer(GEMINI_API_KEY, self.policy)
        self.sqlpro = SqlGenPro(GEMINI_API_KEY, self.schema_text, self.policy)
        self.answerer = Answerer(GEMINI_API_KEY)

    def _dedup_rows(self, rows: List[Dict]) -> List[Dict]:
        """Deduplicate by (institute, program, quota, category, gender, year, round)."""
        seen = set()
        out = []
        for r in rows:
            key = (
                r.get("institute"),
                r.get("program"),
                r.get("quota"),
                r.get("category"),
                r.get("gender"),
                r.get("year"),
                r.get("round"),
            )
            if key in seen:
                continue
            seen.add(key)
            out.append(r)
        return out

    def run(self, user_query: str) -> Tuple[str, str, List[Dict]]:
        # A) enhance; strip any stray bullets from the model output
        variants = self.enhancer.enhance(user_query) or [user_query]
        variants = [re.sub(r'^\s*[-*‚Ä¢]+\s*', '', v) for v in variants] or [user_query]

        print("\nüîé Enhanced variants:")
        for v in variants:
            print(" -", v)

        # B) SQL for each variant (with sanitize + critique-fix)
        sqls = []
        for v in variants:
            sql = self.sqlpro.to_sql(v)
            if not sql:
                continue
            try:
                sanitized = sanitize_select(sql, hard_limit=15)
            except Exception as e:
                print("Sanitize refused:", e)
                continue
            sqls.append(sanitized)

        if not sqls:
            return ("I couldn't produce a safe SQL for that.", "", [])

        # C) Execute each, collect results
        candidates: List[Tuple[str, List[Dict]]] = []
        for s in sqls:
            try:
                rows = self.pg.run(s)
                print(f"  ‚úÖ Executed, rows={len(rows)}")
                candidates.append((s, rows))
            except Exception as e:
                print("  ‚ùå Exec error:", e)

        if not candidates:
            return ("All generated SQLs failed to run.", sqls[0], [])

        best_sql, best_rows = self._pick_best_result(candidates)
        best_rows = self._dedup_rows(best_rows)

        # D) Answer
        answer = self.answerer.answer(user_query, best_rows)
        return answer, best_sql, best_rows

    @staticmethod
    def _pick_best_result(candidates: List[Tuple[str, List[Dict]]]) -> Tuple[Optional[str], List[Dict]]:
        non_empty = [c for c in candidates if len(c[1]) > 0]
        if non_empty:
            non_empty.sort(key=lambda x: len(x[1]), reverse=True)
            return non_empty[0]
        return candidates[0] if candidates else (None, [])


# ========================
# DEMO
# ========================
def main():
    pipe = Pipeline()

    tests = [
        "I have AIR 6000, which IIT programs can I get?",
        "Show me Computer Science programs at IIT Goa",
        "Mechanical Engineering with closing rank below 9000",
    ]

    for i, q in enumerate(tests, 1):
        print("\n" + "-"*60)
        print(f"Test {i}: {q}")
        print("-"*60)
        answer, sql, rows = pipe.run(q)
        print("\nüß† Best SQL:")
        print(sql)
        print("\nüí¨ Answer:")
        print(answer)
        if rows:
            print("\nüìä First 3 rows:")
            keys = {'year','round','institute','program','quota','category','gender','closing_rank'}
            for r in rows[:3]:
                print(" ", {k: r[k] for k in r.keys() & keys})

if __name__ == "__main__":
    main()


‚úÖ Loaded live schema from DB

------------------------------------------------------------
Test 1: I have AIR 6000, which IIT programs can I get?
------------------------------------------------------------

üîé Enhanced variants:
 - What IIT programs are available with an AIR of 6000 (assume OPEN, Gender-Neutral, quota AI, exclude PwD, final round)?
 - Which IIT courses can I get with an AIR of 6000 (assume OPEN, Gender-Neutral, quota AI, exclude PwD, final round)?
 - Given an AIR of 6000, what IIT programs am I eligible for (assume OPEN, Gender-Neutral, quota AI, exclude PwD, final round)?
  ‚úÖ Executed, rows=15
  ‚úÖ Executed, rows=15
  ‚úÖ Executed, rows=15

üß† Best SQL:
WITH josaa_ranks AS (
  SELECT
    *,
    CAST(NULLIF(regexp_replace(trim(closing_rank), '[^0-9]', '', 'g'), '') AS INTEGER) AS closing_rank_num
  FROM
    josaa_btech_2024
)
SELECT
  year,
  round,
  institute,
  program,
  quota,
  category,
  gender,
  closing_rank,
  closing_rank_num
FROM
  josaa_ranks
WH

In [1]:
#!/usr/bin/env python3
"""
Two-Stage SQL RAG Pipeline (policy-driven, self-critique, live schema)
A) Gemini Flash: query enhancement (2‚Äì3 variants) + polite defaults hinting
B) Gemini 2.5 Pro: robust SQL generation (schema-aware) + self-critique/fix
C) DB: safe SELECT/CTE-only execution; pick best result, optional dedup
D) GGUF (or Gemini sim): final NL answer

Read-only: per-statement timeout, SELECT/CTE-only guard, LIMIT enforced.

Hardcoded:
- Supabase Postgres credentials
- Gemini API key
"""

import re
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple

import psycopg2
import psycopg2.extras
from psycopg2.pool import SimpleConnectionPool
import google.generativeai as genai


# ========================
# HARD-CODED CREDENTIALS
# ========================
DB = {
    "host": "DB_HOST_HERE",
    "database": "postgres",
    "user": "DB_USER_HERE",
    "password": "DB_PASSWORD_HERE",
    "port": "6543",
    "sslmode": "require",
    "connect_timeout": 10,
}
GEMINI_API_KEY = "GEMINI_API_KEY_HERE"


# ========================
# POLICY
# ========================
@dataclass
class Policy:
    exclude_pwd_default: bool = True      # exclude PwD when user didn't ask for it
    final_round_default: bool = True      # prefer latest round when unspecified
    apply_open_gn_ai_on_air: bool = True  # for AIR eligibility & unspecified: force OPEN/GN/AI
    apply_open_gn_ai_on_numeric_eligibility: bool = True  # NEW: for numeric cutoff asks


# ========================
# CONSTANTS & HELPERS
# ========================
RANK_NUM_EXPR = (
    "CAST(NULLIF(regexp_replace(trim(closing_rank), '[^0-9]', '', 'g'), '') AS INTEGER)"
)

ALLOWED_READ = re.compile(r"^\s*(select|with)\b", re.IGNORECASE | re.DOTALL)
DANGEROUS = re.compile(
    r"\b(insert|update|delete|drop|alter|truncate|create|grant|revoke|copy|vacuum|analyze)\b",
    re.IGNORECASE,
)

def single_statement(sql: str) -> bool:
    return sql.strip().count(";") <= 1

def ensure_limit(sql: str, hard_limit: int = 15) -> str:
    if re.search(r"\blimit\b", sql, re.IGNORECASE):
        return sql
    return f"{sql.rstrip(';')} LIMIT {hard_limit};"

def sanitize_select(sql: str, hard_limit: int = 15) -> str:
    sql = sql.strip()
    if not single_statement(sql):
        raise ValueError("Multiple SQL statements not allowed.")
    if not ALLOWED_READ.match(sql):
        raise ValueError("Only SELECT/CTE read queries are allowed.")
    if DANGEROUS.search(sql):
        raise ValueError("Potentially dangerous SQL token detected.")
    return ensure_limit(sql, hard_limit)

def air_context_hint(text: str) -> str:
    """Soft hint for Pro when AIR/eligibility is detected."""
    return (
        "User intent hint: Eligibility by AIR is requested; "
        "apply OPEN/GN/AI, exclude PwD, and latest round unless the user overrides."
        if re.search(r'\b(AIR|rank)\b', text, re.I) else ""
    )

def numeric_eligibility_hint(text: str) -> bool:
    """
    Detects numeric cutoff intent, e.g., 'below 9000', '< 9000', '>= 6000',
    'closing rank under 7k', etc.
    """
    return bool(
        re.search(r'\b(under|below|less\s+than|greater\s+than|over|at\s+least|at\s+most)\s*\d+', text, re.I) or
        re.search(r'\b(<=|>=|<|>)\s*\d+', text) or
        re.search(r'\bclosing\s*rank\b.*\d+', text, re.I)
    )

def fix_distinct_orderby(sql: str) -> str:
    """
    If SQL uses SELECT DISTINCT and ORDER BY closing_rank_num, but doesn't
    project closing_rank_num, add it to the SELECT list safely.
    """
    if re.search(r"\bselect\s+distinct\b", sql, re.IGNORECASE) and \
       re.search(r"\border\s+by\s+.*closing_rank_num", sql, re.IGNORECASE):
        m = re.search(r"(?is)select\s+distinct\s+(.*?)\s+from\s", sql)
        if m:
            select_list = m.group(1)
            if "closing_rank_num" not in select_list:
                start, end = m.span(1)
                new_select_list = select_list.rstrip() + ", closing_rank_num"
                sql = sql[:start] + new_select_list + sql[end:]
    return sql


# ========================
# DB LAYER
# ========================
class Pg:
    def __init__(self, params: Dict[str, str]):
        self.pool = SimpleConnectionPool(minconn=1, maxconn=6, **params)

    def _conn(self):
        conn = self.pool.getconn()
        with conn.cursor() as c:
            c.execute("SET LOCAL statement_timeout = '12000ms';")
            c.execute("SET LOCAL default_transaction_read_only = on;")
        return conn

    def fetch_schema_text(self) -> str:
        """Build a concise, live schema string for prompt grounding."""
        conn = self._conn()
        try:
            cur = conn.cursor()
            out = ["SCHEMA: public\n"]
            cur.execute("""
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'public'
                ORDER BY table_name;
            """)
            for (t,) in cur.fetchall():
                out.append(f"TABLE: {t}")
                cur.execute("""
                    SELECT column_name, data_type, is_nullable
                    FROM information_schema.columns
                    WHERE table_schema='public' AND table_name=%s
                    ORDER BY ordinal_position;
                """, (t,))
                for name, dtype, nullable in cur.fetchall():
                    out.append(f"  - {name}: {dtype} {'NULL' if nullable=='YES' else 'NOT NULL'}")
                out.append("")
            return "\n".join(out)
        finally:
            self.pool.putconn(conn)

    def run(self, sql: str) -> List[Dict]:
        conn = self._conn()
        try:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql)
                return [dict(r) for r in cur.fetchall()]
        finally:
            self.pool.putconn(conn)


# ========================
# STAGE A: FLASH ENHANCER
# ========================
class QueryEnhancer:
    """
    Expands the user's query into 2‚Äì3 variants.
    If it detects AIR/eligibility and no category/gender explicitly given,
    injects defaults: OPEN, Gender-Neutral, quota AI, exclude PwD, final round.
    Also gently nudges excluding PwD and using final round when user didn't specify anything.
    """
    def __init__(self, api_key: str, policy: Policy):
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel("gemini-2.0-flash-exp")
        self.policy = policy

    def enhance(self, user_query: str) -> List[str]:
        hints = []
        if self.policy.exclude_pwd_default:
            hints.append("exclude PwD by default unless the user explicitly asks for PwD")
        if self.policy.final_round_default:
            hints.append("assume final (latest) JoSAA round when round is unspecified")
        if self.policy.apply_open_gn_ai_on_numeric_eligibility:
            hints.append("for numeric cutoff queries (e.g., closing rank < N) assume OPEN, Gender-Neutral, quota AI unless specified")
        generic_defaults_hint = "; ".join(hints) if hints else "no additional defaults"

        air_defaults_hint = ""
        if self.policy.apply_open_gn_ai_on_air:
            air_defaults_hint = ("If the user mentions AIR/eligibility but does not specify "
                                 "category/gender/quota/round, add a clarifying parenthetical: "
                                 "(assume OPEN, Gender-Neutral, quota AI, exclude PwD, final round).")

        prompt = f"""
Rewrite the user's database question into 2‚Äì3 clear, semantically similar variants.

Rules:
- Keep the same intent.
- Be concise, one line per variant, no bullets/numbers.
- Clarify institute types (IIT/NIT/IIIT) only if mentioned by the user.
- Do NOT fabricate constraints that the user did not give.
- When unspecified, {generic_defaults_hint}.
- {air_defaults_hint}

User query: "{user_query}"
Variants (one per line, 2‚Äì3 lines total):
"""
        res = self.model.generate_content(prompt)
        text = (res.text or "").strip()
        variants = [line.strip() for line in text.splitlines() if line.strip()]
        return variants[:3] if variants else [user_query]


# ========================
# STAGE B: 2.5 PRO SQL + SELF-CRITIQUE
# ========================
class SqlGenPro:
    """
    Generates SQL using Gemini 2.5 Pro with schema-aware rules,
    then runs a self-critique to fix common omissions.
    """
    def __init__(self, api_key: str, schema_text: str, policy: Policy):
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel("gemini-2.5-pro")
        self.schema_text = schema_text
        self.policy = policy

    def _rules(self) -> str:
        base = f"""
DATA RULES:
- opening_rank and closing_rank are TEXT. Use:
  {RANK_NUM_EXPR} AS closing_rank_num
- No 'rank_num' column exists.
- Use ILIKE for case-insensitive matches.
- Prefer table 'josaa_btech_2024' for B.Tech queries.
- When filtering IITs, do NOT rely on institute_type equality; match by name:
  (institute ILIKE '%Indian Institute of Technology%' OR institute ILIKE 'IIT %')
- Include year and round in SELECT outputs when present (e.g., year, round).
- If you use closing_rank_num in ORDER BY or filters, include closing_rank_num in the SELECT list.
- Use ORDER BY closing_rank_num ASC NULLS LAST.
- ALWAYS LIMIT 15.
"""
        extra = []
        if self.policy.exclude_pwd_default:
            extra.append("If the user did not mention PwD, add: category NOT ILIKE '%PwD%'.")
        if self.policy.final_round_default:
            extra.append("If the user did not mention a round, add: round = (SELECT MAX(round) FROM josaa_btech_2024 WHERE year = 2024).")
        if self.policy.apply_open_gn_ai_on_air:
            extra.append(
                "If the user asks eligibility via AIR and does NOT specify category/gender/quota/round: "
                "add quota='AI', category='OPEN', category NOT ILIKE '%PwD%', gender='Gender-Neutral', "
                "round=(SELECT MAX(round) FROM josaa_btech_2024 WHERE year=2024); "
                "use closing_rank_num >= AIR."
            )
        if self.policy.apply_open_gn_ai_on_numeric_eligibility:
            extra.append(
                "If the user specifies a rank/closing-rank numeric cutoff but does NOT specify category/gender/quota: "
                "add quota='AI', category='OPEN', category NOT ILIKE '%PwD%', gender='Gender-Neutral'."
            )
        if extra:
            base += "\n- " + "\n- ".join(extra) + "\n"
        return base

    def to_sql(self, nl_query: str) -> Optional[str]:
        rules = self._rules()
        ctx_lines = [air_context_hint(nl_query)]
        if self.policy.apply_open_gn_ai_on_numeric_eligibility and numeric_eligibility_hint(nl_query):
            ctx_lines.append(
                "User intent hint: Numeric closing-rank eligibility requested; "
                "apply OPEN/GN/AI, exclude PwD, and latest round unless the user overrides."
            )
        ctx = "\n".join([c for c in ctx_lines if c])

        prompt = f"""
You are an expert PostgreSQL writer.

DATABASE SCHEMA (public):
{self.schema_text}

{rules}

{ctx}

Return ONLY a single SQL SELECT (WITH allowed), no commentary.

User query: "{nl_query}"
SQL:
"""
        try:
            res = self.model.generate_content(prompt)
            sql = (res.text or "").strip()
            if sql.startswith("```"):
                sql = sql.replace("```sql", "").replace("```", "").strip()
            if not sql:
                return None
            fixed = self.critique_fix(sql, nl_query, rules)
            sql_out = fixed or sql
            sql_out = fix_distinct_orderby(sql_out)  # harden DISTINCT/ORDER BY mismatch
            return sql_out
        except Exception as e:
            print("SQL gen error:", e)
            return None

    def critique_fix(self, sql: str, nl_query: str, rules: str) -> Optional[str]:
        critic = f"""
You are a PostgreSQL SQL critic/fixer.

Policy:
{rules}

Check the SQL covers:
- Numeric derivation (closing_rank_num).
- If closing_rank_num appears in ORDER BY/filters, ensure it is also projected.
- Defaults (exclude PwD; final round) if user didn't specify.
- AIR defaults (OPEN, GN, AI, exclude PwD, final round) when eligibility is asked & unspecified.
- Numeric-cutoff defaults (OPEN, GN, AI, exclude PwD) when user gives a numeric cutoff but not category/gender/quota.
- IIT name matching via institute ILIKE (not institute_type equality).
- Includes year and round in SELECT when present.
- ORDER BY closing_rank_num ASC NULLS LAST and LIMIT 15.
- Single SELECT (WITH allowed).

If compliant, output ONLY the original SQL. Otherwise output ONLY a corrected SQL.

User query: "{nl_query}"

Candidate SQL:
{sql}

Output ONLY the SQL:
"""
        try:
            res = self.model.generate_content(critic)
            fixed = (res.text or "").strip()
            if fixed and fixed != sql:
                if fixed.startswith("```"):
                    fixed = fixed.replace("```sql", "").replace("```", "").strip()
                return fixed
            return None
        except Exception as e:
            print("Critic error:", e)
            return None


# ========================
# STAGE D: GGUF (SIM) ANSWER
# ========================
class Answerer:
    """
    If you have a GGUF API, wire it here. For now, we simulate with Gemini Flash.
    """
    def __init__(self, api_key: str, gguf_endpoint: Optional[str] = None):
        self.gguf_endpoint = gguf_endpoint
        genai.configure(api_key=api_key)
        self.sim = genai.GenerativeModel("gemini-2.0-flash-exp")

    def answer(self, user_query: str, rows: List[Dict]) -> str:
        if not rows:
            return ("I couldn't find matching rows for that query. "
                    "Try broadening filters (institute/program/category) or adjusting the rank.")
        bullet = "\n".join([
            f"- {r.get('institute','N/A')} [{r.get('year','?')}/R{r.get('round','?')}]: "
            f"{r.get('program','N/A')} (closing_rank={r.get('closing_rank','?')}, "
            f"{r.get('quota','?')}/{r.get('category','?')}/{r.get('gender','?')})"
            for r in rows[:10]
        ])
        prompt = f"""
You are a JoSAA counseling assistant. Based on the rows below, answer the user briefly,
with 2‚Äì3 actionable tips (category/round effects, choice filling strategy).

User query: "{user_query}"

Rows:
{bullet}
"""
        res = self.sim.generate_content(prompt)
        return (res.text or "").strip()


# ========================
# ORCHESTRATOR
# ========================
class Pipeline:
    def __init__(self):
        self.pg = Pg(DB)
        self.schema_text = self.pg.fetch_schema_text()
        print("‚úÖ Loaded live schema from DB")

        self.policy = Policy()
        self.enhancer = QueryEnhancer(GEMINI_API_KEY, self.policy)
        self.sqlpro = SqlGenPro(GEMINI_API_KEY, self.schema_text, self.policy)
        self.answerer = Answerer(GEMINI_API_KEY)

    def _dedup_rows(self, rows: List[Dict]) -> List[Dict]:
        """Deduplicate by (institute, program, quota, category, gender, year, round)."""
        seen = set()
        out = []
        for r in rows:
            key = (
                r.get("institute"),
                r.get("program"),
                r.get("quota"),
                r.get("category"),
                r.get("gender"),
                r.get("year"),
                r.get("round"),
            )
            if key in seen:
                continue
            seen.add(key)
            out.append(r)
        return out

    def run(self, user_query: str) -> Tuple[str, str, List[Dict]]:
        # A) enhance; strip any stray bullets from the model output
        variants = self.enhancer.enhance(user_query) or [user_query]
        variants = [re.sub(r'^\s*[-*‚Ä¢]+\s*', '', v) for v in variants] or [user_query]

        print("\nüîé Enhanced variants:")
        for v in variants:
            print(" -", v)

        # B) SQL for each variant (with sanitize + critique-fix)
        sqls = []
        for v in variants:
            sql = self.sqlpro.to_sql(v)
            if not sql:
                continue
            try:
                sanitized = sanitize_select(sql, hard_limit=15)
            except Exception as e:
                print("Sanitize refused:", e)
                continue
            sqls.append(sanitized)

        if not sqls:
            return ("I couldn't produce a safe SQL for that.", "", [])

        # C) Execute each, collect results
        candidates: List[Tuple[str, List[Dict]]] = []
        for s in sqls:
            try:
                rows = self.pg.run(s)
                print(f"  ‚úÖ Executed, rows={len(rows)}")
                candidates.append((s, rows))
            except Exception as e:
                print("  ‚ùå Exec error:", e)

        if not candidates:
            return ("All generated SQLs failed to run.", sqls[0], [])

        best_sql, best_rows = self._pick_best_result(candidates)
        best_rows = self._dedup_rows(best_rows)

        # D) Answer
        answer = self.answerer.answer(user_query, best_rows)
        return answer, best_sql, best_rows

    @staticmethod
    def _pick_best_result(candidates: List[Tuple[str, List[Dict]]]) -> Tuple[Optional[str], List[Dict]]:
        non_empty = [c for c in candidates if len(c[1]) > 0]
        if non_empty:
            non_empty.sort(key=lambda x: len(x[1]), reverse=True)
            return non_empty[0]
        return candidates[0] if candidates else (None, [])


# ========================
# DEMO
# ========================
def main():
    pipe = Pipeline()

    tests = [
        "I have AIR 6000, which IIT programs can I get?",
        "Show me Computer Science programs at IIT Goa",
        "Mechanical Engineering with closing rank below 9000",
    ]

    for i, q in enumerate(tests, 1):
        print("\n" + "-"*60)
        print(f"Test {i}: {q}")
        print("-"*60)
        answer, sql, rows = pipe.run(q)
        print("\nüß† Best SQL:")
        print(sql)
        print("\nüí¨ Answer:")
        print(answer)
        if rows:
            print("\nüìä First 3 rows:")
            keys = {'year','round','institute','program','quota','category','gender','closing_rank'}
            for r in rows[:3]:
                print(" ", {k: r[k] for k in r.keys() & keys})

if __name__ == "__main__":
    main()


‚úÖ Loaded live schema from DB

------------------------------------------------------------
Test 1: I have AIR 6000, which IIT programs can I get?
------------------------------------------------------------

üîé Enhanced variants:
 - Which IIT programs are attainable with an AIR of 6000 (assume OPEN, Gender-Neutral, quota AI, exclude PwD, final round)?
 - What IIT programs can I get into with an AIR of 6000 (assume OPEN, Gender-Neutral, quota AI, exclude PwD, final round)?
 - With an AIR of 6000, what are my chances of getting into different IIT programs (assume OPEN, Gender-Neutral, quota AI, exclude PwD, final round)?
  ‚úÖ Executed, rows=15
  ‚úÖ Executed, rows=15
  ‚úÖ Executed, rows=15

üß† Best SQL:
WITH ranked_data AS (
  SELECT
    *,
    CAST(NULLIF(regexp_replace(trim(closing_rank), '[^0-9]', '', 'g'), '') AS INTEGER) AS closing_rank_num
  FROM
    josaa_btech_2024
)
SELECT
  year,
  round,
  institute,
  program,
  category,
  gender,
  closing_rank,
  closing_rank_num
F