In [1]:
!pip install pyngrok streamlit unsloth ragas

In [2]:
# Step 1 (Colab only) — run this cell first: mount Google Drive so the app can read the
# fine-tuned adapters stored under /content/drive/MyDrive/SQL-Generation/models.
# force_remount=True refreshes the mount when this cell is re-run.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

Mounted at /content/drive


In [3]:
%%writefile sql_generator_app.py
import asyncio
import html
import os
import re
import time

# unsloth must be imported before peft/transformers so its patches are applied.
from unsloth import FastLanguageModel

import openai
import pandas as pd
import sqlglot
import streamlit as st
import torch
from peft import PeftModel
from ragas.dataset_schema import SingleTurnSample
from ragas.llms import llm_factory
from ragas.metrics import LLMSQLEquivalence

# Page setup comes first: Streamlit documents set_page_config as the first
# Streamlit command of a script.
st.set_page_config(
    page_title="LLM SQL Generator & Evaluator",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded"
)

# OpenAI API key used for the GPT-4o reference query and semantic scoring.
# It is entered by the user in the sidebar and kept only in this session's state.
if "OPENAI_API_KEY" not in st.session_state:
    st.session_state.OPENAI_API_KEY = None

# Mount Google Drive (needed for the adapter paths in MODEL_INFO) only if it is not
# already mounted — the notebook normally mounts it before launching this app.
# Best effort: outside Colab, or if mounting fails from this process, the app keeps
# going and load_model reports any adapter it cannot load.
if not os.path.isdir("/content/drive/MyDrive"):
    try:
        from google.colab import drive
        drive.mount('/content/drive')
    except Exception:
        pass  # Not running in Colab, or mounting unavailable from this process

# Define model paths
# Each entry pairs a Hugging Face base checkpoint with the PEFT adapter that
# load_model attaches on top of it. "type" selects the prompt template, decoding
# settings and output cleanup used for that model.
ADAPTER_ROOT = "/content/drive/MyDrive/SQL-Generation/models"  # on the mounted Google Drive

MODEL_INFO = {
    "Mistral-7B (Fine-tuned)": dict(
        base="unsloth/mistral-7b-instruct-v0.2",
        adapter=f"{ADAPTER_ROOT}/mistral_tuned",
        type="mistral",
    ),
    "Llama-3-8B (Fine-tuned)": dict(
        base="unsloth/llama-3-8b-bnb-4bit",
        adapter=f"{ADAPTER_ROOT}/llama_tuned",
        type="llama",
    ),
    "Phi-3-mini (Fine-tuned)": dict(
        base="unsloth/Phi-3-mini-4k-instruct",
        adapter=f"{ADAPTER_ROOT}/phi_tuned",
        type="phi",
    ),
}

# Global CSS, injected on every script run via st.markdown(unsafe_allow_html=True).
# Defines the header classes (.main-header, .sub-header) plus box, card, per-model
# border, table and score-colour classes that HTML snippets passed to st.markdown
# can reference. Note: the model comparison table is rendered with st.table, which
# does not apply the .comparison-table rules.
st.markdown("""
<style>
    .main-header {
        font-size: 2.5rem;
        font-weight: 600;
        color: #1E88E5;
    }
    .sub-header {
        font-size: 1.5rem;
        font-weight: 500;
        color: #42A5F5;
    }
    .success-box {
        background-color: #E8F5E9;
        padding: 20px;
        border-radius: 5px;
        border-left: 5px solid #4CAF50;
    }
    .info-box {
        background-color: #E3F2FD;
        padding: 20px;
        border-radius: 5px;
        border-left: 5px solid #2196F3;
    }
    .code-box {
        background-color: #263238;
        color: #FFFFFF;
        padding: 15px;
        border-radius: 5px;
        font-family: 'Courier New', Courier, monospace;
        overflow-x: auto;
    }
    .model-card {
        background-color: #F5F5F5;
        padding: 15px;
        border-radius: 5px;
        margin-bottom: 20px;
    }
    .mistral-box {
        border-left: 5px solid #4285F4;
    }
    .llama-box {
        border-left: 5px solid #FBBC04;
    }
    .phi-box {
        border-left: 5px solid #34A853;
    }
    .highlight {
        font-weight: 600;
        background-color: rgba(255, 235, 59, 0.2);
        padding: 0 5px;
    }
    .comparison-table {
        width: 100%;
        border-collapse: collapse;
    }
    .comparison-table th, .comparison-table td {
        padding: 8px 12px;
        text-align: left;
        border-bottom: 1px solid #ddd;
    }
    .comparison-table th {
        background-color: #f2f2f2;
    }
    .score-high {
        color: #4CAF50;
        font-weight: bold;
    }
    .score-medium {
        color: #FFC107;
        font-weight: bold;
    }
    .score-low {
        color: #F44336;
        font-weight: bold;
    }
</style>
""", unsafe_allow_html=True)

@st.cache_resource(show_spinner=False)
def _load_model_cached(base_name, adapter_path):
    """Load ``base_name`` in 4-bit, attach the adapter at ``adapter_path``, prepare for inference.

    Exceptions propagate on purpose: st.cache_resource stores return values, so
    raising (instead of returning a sentinel) keeps a failed load out of the cache
    and lets a later call retry it.

    Returns:
        (model, tokenizer)
    """
    base_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_name,
        max_seq_length=2048,
        dtype=None,
        load_in_4bit=True
    )

    # Llama: reuse EOS as the pad token (presumably the tokenizer defines none — TODO confirm)
    if "llama" in base_name.lower():
        tokenizer.pad_token = tokenizer.eos_token

    model = PeftModel.from_pretrained(base_model, adapter_path)
    FastLanguageModel.for_inference(model)
    return model, tokenizer


def load_model(model_key):
    """Load the fine-tuned model and tokenizer registered under ``model_key`` in MODEL_INFO.

    Successful loads are cached per (base, adapter) pair; failures are not cached,
    so a transient error (e.g. Drive not yet mounted) is retried on the next call.

    Returns:
        (model, tokenizer) on success, or (None, None) after reporting the error
        with st.error.

    Raises:
        KeyError: if ``model_key`` is not a MODEL_INFO key.
    """
    model_info = MODEL_INFO[model_key]

    try:
        with st.spinner(f"Loading {model_key}..."):
            return _load_model_cached(model_info["base"], model_info["adapter"])
    except Exception as e:
        st.error(f"Error loading {model_key}: {e}")
        return None, None

def format_prompt(question, schema, model_type="mistral"):
    """Build the generation prompt for ``question`` over ``schema``.

    Args:
        question: Natural-language question to answer.
        schema: Schema text (e.g. CREATE TABLE statements), inserted verbatim.
        model_type: "llama" selects the Alpaca-style template; every other value
            (including "phi" and the default "mistral") uses the shared
            "You are a SQL expert" template.

    Returns:
        The prompt string. The 8-space indentation on each continuation line is
        part of the returned text.
        NOTE(review): confirm this layout matches the prompts the adapters were
        fine-tuned on — the indentation looks like an artifact of source formatting.
    """
    if model_type == "llama":
        # Alpaca-style instruction / input / response layout
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
        ### Instruction:
        Write a SQL query that answers the following question based on the given database schema. Use SQLite syntax.

        ### Input:
        [SCHEMA] {schema}
        [QUESTION] {question}

        ### Response:
        """

    # Phi and Mistral use this exact template; clean_sql looks for its "# RESPONSE:" header.
    return f"""You are a SQL expert. Follow these instructions and provide an appropriate response.
        # TASK:
        Create a SQL query that solves the given question using the provided database schema. Use standard SQLite syntax.

        # CONTEXT:
        Database Schema:
        {schema}

        Question:
        {question}

        # RESPONSE:
        """

def clean_sql(sql, model_type="mistral"):
    """Strip prompt echo, special-token text and markdown fences from raw model output.

    Args:
        sql: Decoded generation text (in this app, the prompt echo plus completion,
            decoded without skipping special tokens).
        model_type: "llama", "phi", or anything else (handled as Mistral).

    Returns:
        The remaining SQL text, whitespace-trimmed.
    """
    if model_type == "llama":
        # Keep what follows the Alpaca response header, then drop Llama special-token text.
        if "### Response:" in sql:
            sql = sql.split("### Response:")[1]
        for special in ("<|begin_of_text|>", "<|end_of_text|>"):
            sql = sql.replace(special, "")
    elif model_type == "phi":
        # Keep the answer after "# RESPONSE:", cut at the first following section marker.
        if "# RESPONSE:" in sql:
            answer = sql.split("# RESPONSE:")[1].strip()
            stop_markers = ("# EXPLANATION:", "# TASK:", "# CONTEXT:", "# QUESTION:", '"""', "Question:")
            hits = [answer.find(marker) for marker in stop_markers]
            cut_at = min((pos for pos in hits if pos != -1), default=len(answer))
            sql = answer[:cut_at].strip()
    elif "# RESPONSE:" in sql:
        # Mistral: drop the echoed prompt up to the response header.
        sql = sql.split("# RESPONSE:")[1]

    # Leftovers any model may emit: EOS text, placeholder tag, quote and markdown fences.
    # Order matters: "```sql" must be removed before the bare "```".
    for junk in ("</s>", "<|endoftext|>", "[SQL QUERY]", '"""', "```sql", "```"):
        sql = sql.replace(junk, "")

    return sql.strip()

def generate_sql(model, tokenizer, question, schema, model_type="mistral"):
    """Generate a SQL query for ``question`` over ``schema`` with one fine-tuned model.

    Args:
        model: Inference-ready model (e.g. as returned by load_model).
        tokenizer: The model's tokenizer.
        question: Natural-language question.
        schema: Schema text inserted into the prompt.
        model_type: "llama", "phi" or "mistral" (default); selects the prompt
            template, decoding settings and output cleanup.

    Returns:
        (sql, seconds): the cleaned SQL text and the time spent in model.generate.
    """
    prompt = format_prompt(question, schema, model_type)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Model-specific decoding settings
    if model_type == "llama":
        decoding = dict(max_new_tokens=300, temperature=0.7, do_sample=True, repetition_penalty=1.1)
    elif model_type == "phi":
        decoding = dict(max_new_tokens=250, temperature=0.2, do_sample=True)
    else:
        # NOTE(review): do_sample is not set here, so unless the model's generation
        # config enables sampling, decoding is greedy and temperature has no effect.
        decoding = dict(max_new_tokens=200, temperature=0.1)

    # perf_counter is monotonic and high-resolution; time.time() can jump if the
    # system clock is adjusted during generation.
    start = time.perf_counter()
    outputs = model.generate(**inputs, use_cache=True, **decoding)
    elapsed = time.perf_counter() - start

    # Decode the whole sequence (prompt echo and special tokens included);
    # clean_sql strips everything except the SQL.
    raw_text = tokenizer.batch_decode(outputs)[0]
    return clean_sql(raw_text, model_type), elapsed

def highlight_sql_keywords(sql):
    """Return ``sql`` as HTML with keywords, string literals, numbers and function names coloured.

    The text is HTML-escaped first (``&``, ``<``, ``>``), so generated SQL such as
    ``WHERE a < b`` renders literally instead of being parsed as markup when shown
    with ``st.markdown(..., unsafe_allow_html=True)``. String literals keep their
    single quotes: rewriting them as double quotes would turn them into identifiers.

    Args:
        sql: SQL text to highlight.

    Returns:
        An HTML fragment (inline ``<span style=...>`` elements).
    """
    keywords = [
        "SELECT", "FROM", "WHERE", "JOIN", "GROUP BY", "ORDER BY", "HAVING",
        "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER", "TABLE",
        "AND", "OR", "NOT", "LIKE", "IN", "BETWEEN", "IS NULL", "IS NOT NULL",
        "ASC", "DESC", "DISTINCT", "COUNT", "SUM", "AVG", "MAX", "MIN",
        "INNER JOIN", "LEFT JOIN", "RIGHT JOIN", "FULL JOIN", "OUTER JOIN",
        "ON", "AS", "WITH", "UNION", "ALL", "CASE", "WHEN", "THEN", "ELSE", "END"
    ]

    # Longest first so multi-word keywords ("INNER JOIN") win over their parts ("JOIN");
    # multi-word keywords accept any whitespace run between words (e.g. "GROUP\nBY").
    alternatives = [kw.replace(" ", r"\s+") for kw in sorted(keywords, key=len, reverse=True)]
    pattern = r'\b(' + '|'.join(alternatives) + r')\b'

    # Escape markup characters; quote=False leaves quotes intact for the literal pass below.
    highlighted = html.escape(sql, quote=False)

    # Keywords (case insensitive)
    highlighted = re.sub(
        pattern,
        r'<span style="color: #FFA726; font-weight: bold;">\1</span>',
        highlighted,
        flags=re.IGNORECASE
    )

    # String literals, wrapped with their original single quotes
    highlighted = re.sub(
        r"'[^']*'",
        lambda match: f'<span style="color: #66BB6A;">{match.group(0)}</span>',
        highlighted
    )

    # Numbers
    highlighted = re.sub(
        r'\b(\d+)\b',
        r'<span style="color: #EF5350;">\1</span>',
        highlighted
    )

    # Function calls (identifier immediately followed by "(")
    highlighted = re.sub(
        r'\b(\w+)\(',
        r'<span style="color: #42A5F5;">\1</span>(',
        highlighted
    )

    return highlighted

def check_syntax_validity(sql):
    """Report whether ``sql`` looks like parseable SQL.

    Returns:
        True only if ``sql`` is non-empty, longer than 5 characters, and
        ``sqlglot.parse`` accepts it without raising; False otherwise (including
        None or non-string input).
    """
    try:
        if not (sql and len(sql) > 5):
            return False
        sqlglot.parse(sql)
    except Exception:
        return False
    return True

def parse_llm_score(score_text, default=None):
    """Parse a score from an LLM reply and clamp it to [0.0, 1.0].

    The whole text is tried as a float first; failing that, the first unsigned
    number found in it is used. Every parsed value is clamped, so a reply such as
    "Score: 10" yields 1.0 rather than 10.

    Args:
        score_text: Raw reply text (may be None).
        default: Value returned when no number is found or the value is NaN.

    Returns:
        A float in [0.0, 1.0], or ``default``.
    """
    try:
        score = float(score_text)
    except (TypeError, ValueError):
        match = re.search(r"([0-9]*\.?[0-9]+)", score_text or "")
        if match is None:
            return default
        score = float(match.group(1))
    if score != score:  # NaN is the only value not equal to itself
        return default
    return max(0.0, min(score, 1.0))


async def custom_evaluate_queries(sql_queries, reference_query, schema):
    """Grade each generated query against ``reference_query`` using direct GPT-4o calls.

    GPT-4o is asked to answer 1.0 when a query would return the same results as the
    reference, and a closeness score between 0.0 and 0.9 otherwise.

    Args:
        sql_queries: Mapping of model name -> generated SQL.
        reference_query: SQL treated as ground truth.
        schema: Schema text included in the grading prompt.

    Returns:
        Dict of model name -> score in [0.0, 1.0], or None when no OpenAI API key is
        set. A reply with no number scores 0.5; a failed API call scores 0.0.

    NOTE(review): although declared ``async``, the OpenAI client used here is
    synchronous, so the calls run one after another and block the event loop.
    """
    if not st.session_state.OPENAI_API_KEY:
        st.warning("OpenAI API key not set. Skipping semantic evaluation.")
        return None

    client = openai.OpenAI(api_key=st.session_state.OPENAI_API_KEY)
    results = {}

    for model_name, sql in sql_queries.items():
        try:
            prompt = f"""
            You are a SQL expert. Determine if this query correctly answers the question based on the given schema.

            Database Schema:
            {schema}

            Reference Query (known to be correct):
            {reference_query}

            Query to Evaluate:
            {sql}

            If the query is correct (would provide the same results as the reference query), respond with "1.0".
            If the query is incorrect, respond with a number between 0.0 and 0.9 that represents how close it is to being correct.

            Respond ONLY with a single number.
            """

            response = client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": "You are an expert SQL evaluator."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0
            )

            score_text = response.choices[0].message.content.strip()

            # Parse and clamp to [0, 1] — including numbers found inside free-form replies
            score = parse_llm_score(score_text)
            if score is None:
                st.warning(f"Couldn't extract score from: {score_text}")
                score = 0.5  # Default mid-range score

            results[model_name] = score

        except Exception as e:
            st.error(f"Error evaluating {model_name}: {str(e)}")
            results[model_name] = 0.0

    return results

async def evaluate_semantic_equivalence(sql_queries, reference_query, schema):
    """Score generated queries against ``reference_query`` with RAGAS, falling back to GPT-4o.

    For each query, RAGAS ``LLMSQLEquivalence`` (judged by gpt-4o) is tried first.
    If it raises, GPT-4o is asked directly for a 0.0-1.0 equivalence score.

    Args:
        sql_queries: Mapping of model name -> generated SQL.
        reference_query: SQL treated as ground truth.
        schema: Schema text passed as context.

    Returns:
        Dict of model name -> score clamped to [0.0, 1.0] (0.0 when every method
        fails, the reply holds no number, or the score is NaN), or None when no
        OpenAI API key is set.
    """
    if not st.session_state.OPENAI_API_KEY:
        st.warning("OpenAI API key not set. Skipping semantic evaluation.")
        return None

    # Also expose the key through the environment (presumably read by llm_factory —
    # TODO confirm); the fallback client receives it explicitly.
    os.environ["OPENAI_API_KEY"] = st.session_state.OPENAI_API_KEY
    client = openai.OpenAI(api_key=st.session_state.OPENAI_API_KEY)

    # RAGAS SQL-equivalence metric backed by gpt-4o
    semantic_checker = LLMSQLEquivalence()
    semantic_checker.llm = llm_factory(model="gpt-4o")

    st.info(f"Starting semantic evaluation with {len(sql_queries)} queries")

    results = {}

    for model_name, sql in sql_queries.items():
        try:
            sample = SingleTurnSample(
                response=sql,
                reference=reference_query,
                reference_contexts=[schema]
            )

            try:
                score = await semantic_checker.single_turn_ascore(sample)
                st.success(f"RAGAS evaluation succeeded for {model_name}")
            except Exception as e:
                # RAGAS failed: ask GPT-4o for a score directly
                st.error(f"RAGAS evaluation failed: {e}. Trying direct OpenAI call...")

                prompt = f"""
                Are these two SQL queries semantically equivalent? They should return the same results.

                Query 1: {reference_query}

                Query 2: {sql}

                Schema information: {schema}

                Score from 0.0 to 1.0, where 1.0 means completely equivalent:
                """

                response = client.chat.completions.create(
                    model="gpt-4o",
                    messages=[
                        {"role": "system", "content": "You are a SQL expert evaluating query equivalence."},
                        {"role": "user", "content": prompt}
                    ],
                    temperature=0
                )

                score_text = response.choices[0].message.content.strip()

                try:
                    score = float(score_text)
                except ValueError:
                    # Not a bare number: use the first number in the reply, else 0.0
                    match = re.search(r"([0-9]*\.?[0-9]+)", score_text)
                    score = float(match.group(1)) if match else 0.0

            score = float(score)
            # Clamp to [0, 1]; a NaN score counts as a failed evaluation.
            results[model_name] = max(0.0, min(score, 1.0)) if score == score else 0.0

        except Exception as e:
            st.error(f"All evaluation methods failed for {model_name}: {str(e)}")
            results[model_name] = 0.0

    return results

def extract_sql_from_reply(reply):
    """Extract the SQL statement from a chat-model reply.

    If the reply contains a fenced code block (```sql ... ``` or ``` ... ```), only its
    contents are kept, so fences never leak into the query. Leading prose is then
    dropped by cutting at the EARLIEST whole-word statement keyword
    (SELECT/INSERT/UPDATE/DELETE/CREATE/DROP/ALTER/WITH, case-insensitive). A
    ``WITH ... SELECT`` query therefore keeps its CTE, and words such as "created"
    do not count as keywords.

    Returns:
        The stripped SQL text ("" for an empty reply; the whole reply, stripped,
        when no keyword is found).
    """
    if not reply:
        return ""

    fenced = re.search(r"```(?:sql)?\s*(.*?)```", reply, flags=re.IGNORECASE | re.DOTALL)
    body = fenced.group(1) if fenced else reply

    first_keyword = re.search(
        r"\b(?:SELECT|INSERT|UPDATE|DELETE|CREATE|DROP|ALTER|WITH)\b",
        body,
        flags=re.IGNORECASE,
    )
    if first_keyword:
        body = body[first_keyword.start():]
    return body.strip()


def get_reference_query(question, schema):
    """Ask GPT-4o for a reference SQL query to use as ground truth.

    Args:
        question: Natural-language question.
        schema: Schema text.

    Returns:
        The extracted SQL string, or None when no API key is set, the request
        fails (reported via st.error), or the reply yields no text.
    """
    if not st.session_state.OPENAI_API_KEY:
        return None

    client = openai.OpenAI(api_key=st.session_state.OPENAI_API_KEY)

    prompt = f"""Generate a correct SQL query for the following question and schema, without any explanations:

    Schema:
    {schema}

    Question:
    {question}

    SQL Query:"""

    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "system", "content": "You are an expert SQL developer."},
                      {"role": "user", "content": prompt}],
            temperature=0
        )

        # message.content may be None; treat that as an empty reply
        sql = extract_sql_from_reply(response.choices[0].message.content or "")
        return sql or None
    except Exception as e:
        st.error(f"Error getting reference query: {str(e)}")
        return None

# Initialize session state for models
# (per-session registries filled by the model-loading loop below)
if 'models' not in st.session_state:
    st.session_state.models = {}
    st.session_state.tokenizers = {}
    st.session_state.loaded_models = []

# App header
st.markdown('<p class="main-header">LLM SQL Generator & Evaluator</p>', unsafe_allow_html=True)
st.markdown("""
This application uses fine-tuned language models to generate SQL queries from natural language descriptions.
Provide a database schema and a question, and see how different models respond! The app also evaluates query quality.
""")

# Sidebar for API key and model loading
st.sidebar.markdown('<p class="sub-header">Settings</p>', unsafe_allow_html=True)

# OpenAI API key input
openai_api_key = st.sidebar.text_input(
    "OpenAI API Key (for RAGAS evaluation)",
    type="password",
    value=st.session_state.OPENAI_API_KEY if st.session_state.OPENAI_API_KEY else "",
    help="Required for semantic evaluation of SQL queries"
)

# Mirror the field on every run so clearing it also clears the stored key
# (which disables evaluation); an empty field is stored as None.
st.session_state.OPENAI_API_KEY = openai_api_key or None
if openai_api_key:
    st.sidebar.success("API key set!")

# Model loading section
st.sidebar.markdown('<p class="sub-header">Model Loading</p>', unsafe_allow_html=True)

# Load all models once per session (load_model is backed by st.cache_resource).
# If none loaded successfully, this is attempted again on the next rerun.
if not st.session_state.loaded_models:
    progress_bar = st.sidebar.progress(0)
    loading_notice = st.sidebar.empty()
    loading_notice.info("Loading models... this may take a few minutes")

    total_models = len(MODEL_INFO)
    for loaded_count, model_name in enumerate(MODEL_INFO, start=1):
        model, tokenizer = load_model(model_name)

        if model is not None:
            st.session_state.models[model_name] = model
            st.session_state.tokenizers[model_name] = tokenizer
            st.session_state.loaded_models.append(model_name)
            st.sidebar.success(f"{model_name} loaded successfully!")
        else:
            st.sidebar.error(f"Failed to load {model_name}")

        # Advance after each load completes so the bar reflects finished work
        progress_bar.progress(loaded_count / total_models)

    # Remove the transient loading widgets once every model has been attempted
    progress_bar.empty()
    loading_notice.empty()

# Show loaded models
st.sidebar.subheader("Loaded Models:")
for model_name in st.session_state.loaded_models:
    st.sidebar.success(f"✅ {model_name}")

# Main content
st.markdown('<p class="sub-header">SQL Query Generator</p>', unsafe_allow_html=True)

# Show sample database schemas
# Two example schemas the user can paste into the form; the sample questions
# listed at the bottom of the page are written against them.
with st.expander("Sample Database Schemas (Click to expand)", expanded=False):
    st.markdown("""
    ### E-commerce Database
    ```sql
    CREATE TABLE customers (customer_id INT, name TEXT, email TEXT, registration_date DATE);
    CREATE TABLE products (product_id INT, name TEXT, category TEXT, price REAL, stock INT);
    CREATE TABLE orders (order_id INT, customer_id INT, order_date DATE, total_amount REAL);
    CREATE TABLE order_items (order_id INT, product_id INT, quantity INT, price REAL);
    ```

    ### Employee Management
    ```sql
    CREATE TABLE departments (dept_id INT, name TEXT, location TEXT);
    CREATE TABLE employees (emp_id INT, name TEXT, dept_id INT, salary REAL, hire_date DATE);
    CREATE TABLE projects (project_id INT, name TEXT, start_date DATE, end_date DATE);
    CREATE TABLE project_assignments (emp_id INT, project_id INT, role TEXT);
    ```
    """)

# Input form
# Schema and question side by side, submitted together with one button.
# `schema`, `question`, `with_evaluation` and `submitted` are read by the
# generation/results code that follows.
with st.form("query_form"):
    col1, col2 = st.columns(2)

    with col1:
        schema = st.text_area(
            "Database Schema (SQL CREATE statements)",
            height=200,
            placeholder="CREATE TABLE users (id INT, name TEXT, email TEXT);\nCREATE TABLE orders (id INT, user_id INT, amount REAL, date DATE);"
        )

    with col2:
        question = st.text_area(
            "Your Question",
            height=200,
            placeholder="What is the total amount spent by each user, ordered by highest amount first?"
        )

    # Pre-checked whenever an API key is already stored in session state.
    # NOTE(review): the label says "RAGAS", but the results code scores with
    # custom_evaluate_queries (direct GPT-4o calls) — confirm the wording.
    with_evaluation = st.checkbox("Include RAGAS evaluation (requires OpenAI API key)", value=st.session_state.OPENAI_API_KEY is not None)

    submitted = st.form_submit_button("Generate SQL Query from All Models")

# Generate and display SQL queries from all loaded models
if submitted and schema and question:
    if not st.session_state.loaded_models:
        st.error("No models have been successfully loaded. Please check the console for errors.")
    else:
        st.success(f"Generating SQL queries using {len(st.session_state.loaded_models)} models...")

        tabs = st.tabs(["All Models"] + st.session_state.loaded_models)

        # model name -> {"sql", "time", "syntax_valid"} (plus "semantic_score" once
        # evaluated), or {"error": message} when generation raised.
        all_results = {}

        # Generate SQL with each model
        for model_name in st.session_state.loaded_models:
            model = st.session_state.models[model_name]
            tokenizer = st.session_state.tokenizers[model_name]
            model_type = MODEL_INFO[model_name]["type"]

            try:
                sql_query, generation_time = generate_sql(model, tokenizer, question, schema, model_type)
                all_results[model_name] = {
                    "sql": sql_query,
                    "time": generation_time,
                    "syntax_valid": check_syntax_validity(sql_query),
                }
            except Exception as e:
                all_results[model_name] = {"error": str(e)}

        # Optional evaluation: GPT-4o reference query, then per-model semantic scores
        reference_query = None
        semantic_scores = None

        if with_evaluation and st.session_state.OPENAI_API_KEY:
            with st.spinner("Generating reference SQL query using GPT-4..."):
                reference_query = get_reference_query(question, schema)

            if reference_query:
                st.info("Reference query generated for evaluation")

                # Only models that produced SQL (not an error) are scored
                valid_queries = {
                    name: result["sql"] for name, result in all_results.items() if "sql" in result
                }

                with st.spinner("Evaluating semantic equivalence..."):
                    semantic_scores = asyncio.run(
                        custom_evaluate_queries(valid_queries, reference_query, schema)
                    )

                # custom_evaluate_queries can return None; treat that as "no scores"
                for model_name, score in (semantic_scores or {}).items():
                    all_results[model_name]["semantic_score"] = score

        # Boxed sections use st.container(border=True) (Streamlit >= 1.29). An opening
        # st.markdown('<div ...>') cannot enclose elements rendered by later calls.

        # All Models tab
        with tabs[0]:
            if reference_query:
                with st.container(border=True):
                    st.markdown("### Reference Query (GPT-4)")
                    st.code(reference_query, language="sql")

            # Comparison table (only when scores exist); models that errored are omitted
            if semantic_scores:
                st.markdown("### Model Performance Comparison")

                comparison_data = []
                for model_name in st.session_state.loaded_models:
                    result = all_results.get(model_name)
                    if result is None or "error" in result:
                        continue
                    comparison_data.append({
                        "Model": model_name,
                        "Generation Time (s)": f"{result['time']:.2f}",
                        "Syntax Valid": "✅" if result.get("syntax_valid", False) else "❌",
                        "Semantic Score": f"{result.get('semantic_score', 0):.2f}"
                    })

                st.table(pd.DataFrame(comparison_data))

            # One bordered card per model
            for model_name in st.session_state.loaded_models:
                model_result = all_results[model_name]

                with st.container(border=True):
                    st.subheader(model_name)

                    if "error" in model_result:
                        st.error(f"Error: {model_result['error']}")
                    else:
                        metrics_col1, metrics_col2 = st.columns(2)

                        with metrics_col1:
                            st.metric("Generation Time", f"{model_result['time']:.2f}s")

                        with metrics_col2:
                            if model_result.get("syntax_valid", False):
                                st.success("✅ Syntax Valid")
                            else:
                                st.error("❌ Invalid Syntax")

                        if "semantic_score" in model_result:
                            score = model_result["semantic_score"]
                            score_class = "score-high" if score > 0.7 else "score-medium" if score > 0.4 else "score-low"
                            # A single self-contained HTML snippet, so the CSS class applies here
                            st.markdown(f'<p>Semantic Equivalence: <span class="{score_class}">{score:.2f}</span></p>', unsafe_allow_html=True)

                        # Code block (Streamlit provides a copy button)
                        st.code(model_result['sql'], language="sql")

        # One tab per model (tabs[0] is "All Models")
        for model_tab, model_name in zip(tabs[1:], st.session_state.loaded_models):
            with model_tab:
                model_result = all_results[model_name]

                if "error" in model_result:
                    st.error(f"Error: {model_result['error']}")
                else:
                    st.success(f"SQL query generated in {model_result['time']:.2f} seconds")

                    # Syntax validation result
                    if model_result.get("syntax_valid", False):
                        st.success("✅ SQL syntax is valid")
                    else:
                        st.error("❌ SQL syntax is invalid")

                    # Semantic score, if evaluated
                    if "semantic_score" in model_result:
                        score = model_result["semantic_score"]
                        if score > 0.7:
                            st.success(f"✅ Semantic equivalence score: {score:.2f}/1.0 (Good)")
                        elif score > 0.4:
                            st.warning(f"⚠️ Semantic equivalence score: {score:.2f}/1.0 (Fair)")
                        else:
                            st.error(f"❌ Semantic equivalence score: {score:.2f}/1.0 (Poor)")

                    st.code(model_result['sql'], language="sql")

                    # Reference query for side-by-side comparison
                    if reference_query:
                        st.markdown("### Reference Query (GPT-4)")
                        st.code(reference_query, language="sql")

                    # Actual token count of the cleaned SQL under this model's own
                    # tokenizer, rather than a whitespace word count
                    output_tokens = len(
                        st.session_state.tokenizers[model_name].encode(
                            model_result['sql'], add_special_tokens=False
                        )
                    )
                    with st.expander("Model Details", expanded=False):
                        st.markdown(
                            f"- **Model**: {model_name}\n"
                            f"- **Generation Time**: {model_result['time']:.2f} seconds\n"
                            f"- **Output Token Count**: {output_tokens} tokens"
                        )

elif submitted:
    st.warning("Please provide both a database schema and a question.")

# Sample questions
st.markdown('<p class="sub-header">Sample Questions</p>', unsafe_allow_html=True)
# st.info draws the box itself. A separate st.markdown('<div class="info-box">')
# opening tag cannot enclose content rendered by a later call.
st.info("""
Try these sample questions with the appropriate schema:

1. **For E-commerce:**
   - What are the top 5 most purchased products?
   - Find customers who spent more than $1000 in total.
   - How many orders were placed in each month of 2023?

2. **For Employee Management:**
   - What is the average salary by department?
   - Which employees are assigned to multiple projects?
   - List all employees hired in the last year with their department names.
""")

# Footer
st.markdown("---")
st.markdown("LLM SQL Generator | Built with Streamlit, Unsloth, and PyTorch")

Writing sql_generator_app.py


# Final step — host the UI: run the cell below, then open the printed `*.ngrok-free.app` public URL.

In [4]:
# Launch the Streamlit app in the background and expose it publicly through ngrok.
import getpass
import os
import socket
import subprocess
import time

from pyngrok import ngrok

app_path = "sql_generator_app.py"  # written by the %%writefile cell above
STREAMLIT_PORT = 8501  # Streamlit's default port; the tunnel forwards to it
STARTUP_TIMEOUT_S = 120  # max seconds to wait for Streamlit to accept connections

# SECURITY: never hard-code credentials in a notebook. The token previously pasted
# into this cell was exposed and should be revoked/rotated in the ngrok dashboard.
# Read it from the environment (e.g. a Colab secret exported as NGROK_AUTH_TOKEN)
# or prompt for it without echoing.
ngrok_auth_token = os.environ.get("NGROK_AUTH_TOKEN") or getpass.getpass("ngrok auth token: ")
ngrok.set_auth_token(ngrok_auth_token)

# Kill any existing ngrok tunnels left over from a previous run
ngrok.kill()

# Raise rather than call exit(), which in a notebook tries to shut down the kernel
if not os.path.exists(app_path):
    raise FileNotFoundError(f"Error: {app_path} does not exist")

# Start Streamlit in the background (argument list, no shell); output goes to
# streamlit.log. The log handle stays open for the lifetime of the child process.
# NOTE(review): if a Streamlit server from an earlier run still holds port 8501,
# the new one may pick another port while the tunnel below targets 8501.
streamlit_log = open("streamlit.log", "w")
streamlit_process = subprocess.Popen(
    ["streamlit", "run", app_path, "--server.headless", "true"],
    stdout=streamlit_log,
    stderr=subprocess.STDOUT,
)


def wait_for_port(port, timeout_s, process):
    """Poll localhost:port until it accepts TCP connections.

    Returns True once connectable; False on timeout or if ``process`` exits first.
    """
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.settimeout(1)
            if probe.connect_ex(("127.0.0.1", port)) == 0:
                return True
        if process.poll() is not None:
            return False  # Streamlit exited early; see streamlit.log
        time.sleep(1)
    return False


# Poll for readiness instead of sleeping a fixed 60 seconds
print(f"Waiting for Streamlit to start (up to {STARTUP_TIMEOUT_S} seconds)...")
if not wait_for_port(STREAMLIT_PORT, STARTUP_TIMEOUT_S, streamlit_process):
    print("Streamlit is not accepting connections; check streamlit.log")

# Create the ngrok tunnel
print("Creating ngrok tunnel...")
public_url = None
try:
    public_url = ngrok.connect(STREAMLIT_PORT)
    print(f"Public URL: {public_url}")
except Exception as e:
    print(f"Error connecting ngrok: {e}")
    # Try again with explicit hostname
    try:
        public_url = ngrok.connect(addr=f"localhost:{STREAMLIT_PORT}", bind_tls=True)
        print(f"Public URL: {public_url}")
    except Exception as e:
        print(f"Error on second attempt: {e}")

Waiting for Streamlit to start (60 seconds)...
Creating ngrok tunnel...
Public URL: NgrokTunnel: "https://0e14-34-71-30-75.ngrok-free.app" -> "http://localhost:8501"
