In [1]:
!pip install streamlit

[0m

In [4]:
!streamlit run text2sql_ui.py


Collecting usage statistics. To deactivate, set browser.gatherUsageStats to false.
[0m
[0m
[34m[1m  You can now view your Streamlit app in your browser.[0m
[0m
[34m  Local URL: [0m[1mhttp://localhost:8503[0m
[34m  Network URL: [0m[1mhttp://172.17.0.2:8503[0m
[34m  External URL: [0m[1mhttp://130.65.254.7:8503[0m
[0m
^C
[34m  Stopping...[0m


In [5]:
!ls

'=0.21.0'				      lora_adapters
 AGENT.ipynb				      results
 DATA266.ipynb				      spider-schema.csv
 DATA_266_Project_266_version_5_May_8.ipynb   spider_databases
 Untitled.ipynb				      spider_dataset
 Untitled1.ipynb			      spider_finetune.jsonl
 Untitled2.ipynb			      spider_schemas
 Untitled3.ipynb			      spider_scheme_to_sqlite.ipynb
 base_evaluation_results_basemodel.csv	      spider_sqlite
 database				      spider_sqlite#1
 evaluation_results_FINETUNED.csv	      text2sql_ui.py
 llama3-text2sql			      wandb
 llama3-text2sql_interrupted


In [2]:
pip install plotly

[0mNote: you may need to restart the kernel to use updated packages.


In [1]:
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
import sqlite3
import pandas as pd
import plotly.express as px
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
import torch
import webbrowser
from urllib.parse import quote
import pickle
import os

class EnhancedText2SQLUI:
    def __init__(self, model_path, db_path):
        self.model_path = model_path
        self.db_path = db_path
        self.query_history = []
        self.initialize_agent()
        self.create_ui()
        
    def initialize_agent(self):
        """Initialize the Text-to-SQL agent"""
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
            config = PeftConfig.from_pretrained(self.model_path)
            base_model = AutoModelForCausalLM.from_pretrained(
                config.base_model_name_or_path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto"
            )
            self.model = PeftModel.from_pretrained(
                base_model,
                self.model_path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
            ).merge_and_unload()
            
            self.schema = self.load_schema()
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model.to(self.device)
            
        except Exception as e:
            print(f"Error initializing agent: {str(e)}")
            raise
    
    def load_schema(self):
        """Load database schema"""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0] for table in cursor.fetchall()]
        
        schema = {}
        for table in tables:
            cursor.execute(f"PRAGMA table_info({table});")
            schema[table] = [{"name": col[1], "type": col[2]} for col in cursor.fetchall()]
        
        conn.close()
        return schema
    
    def generate_sql(self, question):
        """Generate SQL from natural language question"""
        schema_lines = []
        for table, cols in self.schema.items():
            col_defs = [f"{col['name']} ({col['type']})" for col in cols]
            schema_lines.append(f"Table {table}: {', '.join(col_defs)}")
        
        schema_prompt = "\n".join(schema_lines)
        
        prompt = f"""Given this database schema, translate the question to SQL:

{schema_prompt}

Question: {question}
SQL Query:"""
        
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=256,
                pad_token_id=self.tokenizer.eos_token_id,
                num_beams=5,
                temperature=0.7
            )
        
        return self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True
        ).strip()
    
    def execute_query(self, sql):
        """Execute SQL query and return results as DataFrame"""
        conn = sqlite3.connect(self.db_path)
        
        try:
            df = pd.read_sql_query(sql, conn)
            return df, None
        except Exception as e:
            return None, str(e)
        finally:
            conn.close()
    
    def create_visualization(self, df):
        """Create appropriate visualization based on data"""
        if len(df.columns) == 1:
            # Single column - show bar chart
            fig = px.bar(df, x=df.columns[0], title="Query Results")
            return fig
        elif len(df.columns) == 2:
            # Two columns - show relationship
            if pd.api.types.is_numeric_dtype(df.iloc[:, 1]):
                fig = px.bar(df, x=df.columns[0], y=df.columns[1], title="Query Results")
            else:
                fig = px.pie(df, names=df.columns[0], values=df.columns[1], title="Query Results")
            return fig
        else:
            # For more columns, show table with interactive features
            return None
    
    def on_run_click(self, b):
        """Handler for Run button click"""
        with self.output:
            clear_output()
            question = self.question_input.value.strip()
            
            if not question:
                print("Please enter a question")
                return
            
            # Generate SQL
            sql = self.generate_sql(question)
            self.sql_output.value = sql
            
            # Execute query
            df, error = self.execute_query(sql)
            
            if error:
                display(HTML(f"<div style='color:red; padding:10px; border:1px solid red;'>Error: {error}</div>"))
            else:
                # Add to history
                self.query_history.append({
                    'question': question,
                    'sql': sql,
                    'results': df.to_dict('records') if df is not None else []
                })
                
                # Save the current state
                self.save_state()
                
                # Display results
                display(HTML("<h3>Query Results</h3>"))
                
                if df is not None:
                    # Try to create visualization
                    fig = self.create_visualization(df)
                    
                    if fig is not None:
                        display(fig)
                    
                    # Always show data table
                    display(HTML(df.to_html(index=False, classes='table table-striped')))
                    display(HTML(f"<p>Returned {len(df)} rows</p>"))
                    
                    # Create shareable URL
                    share_url = f"data:application/vnd.ms-excel,{quote(df.to_csv(index=False))}"
                    share_button = widgets.Button(
                        description="Export Results",
                        icon='download'
                    )
                    share_button.on_click(lambda b: webbrowser.open(share_url))
                    display(share_button)
                else:
                    display(HTML("<p>No results returned</p>"))
    
    def save_state(self):
        """Save the current state of the UI"""
        state = {
            'question': self.question_input.value,
            'sql': self.sql_output.value,
            'query_history': self.query_history
        }
        with open('text2sql_ui_state.pkl', 'wb') as f:
            pickle.dump(state, f)
    
    def load_state(self):
        """Load the saved state of the UI"""
        if os.path.exists('text2sql_ui_state.pkl'):
            with open('text2sql_ui_state.pkl', 'rb') as f:
                state = pickle.load(f)
                self.question_input.value = state.get('question', '')
                self.sql_output.value = state.get('sql', '')
                self.query_history = state.get('query_history', [])
    
    def create_ui(self):
        """Create the enhanced interactive UI"""
        # Custom CSS
        display(HTML("""
        <style>
            .sql-output {
                font-family: monospace;
                background-color: #f5f5f5;
                padding: 10px;
                border-radius: 5px;
                border: 1px solid #ddd;
            }
            .table {
                width: 100%;
                margin-top: 20px;
            }
            .table-striped tbody tr:nth-child(odd) {
                background-color: #f9f9f9;
            }
            .schema-panel {
                background-color: #f8f9fa;
                padding: 15px;
                border-radius: 5px;
                height: 100%;
            }
        </style>
        """))
        
        # Input widgets
        self.question_input = widgets.Textarea(
            value='',
            placeholder='Enter your question (e.g., "Show students in Computer Science major")',
            description='Question:',
            layout=widgets.Layout(width='95%', height='100px')
        )
        
        self.run_button = widgets.Button(
            description='Run Query',
            button_style='success',
            icon='play',
            layout=widgets.Layout(width='200px')
        )
        self.run_button.on_click(self.on_run_click)
        
        # Output widgets
        self.sql_output = widgets.Textarea(
            value='',
            description='Generated SQL:',
            layout=widgets.Layout(width='95%', height='100px'),
            style={'description_width': 'initial'},
            disabled=True
        )
        
        self.output = widgets.Output(
            layout=widgets.Layout(width='95%', border='1px solid #eee', padding='10px')
        )
        
        # Schema viewer
        schema_html = "<div class='schema-panel'><h3>Database Schema</h3><ul>"
        for table, columns in self.schema.items():
            col_list = ", ".join([f"<code>{col['name']}</code> ({col['type']})" for col in columns])
            schema_html += f"<li><strong>{table}</strong>: {col_list}</li>"
        schema_html += "</ul></div>"
        
        self.schema_viewer = widgets.HTML(schema_html)
        
        # Documentation link
        docs_button = widgets.Button(
            description='Open Documentation',
            icon='book',
            layout=widgets.Layout(width='200px', margin='10px 0 0 0')
        )
        docs_button.on_click(lambda b: webbrowser.open("https://example.com/docs"))
        
        # Load saved state
        self.load_state()
        
        # Display the last query results if history exists
        if self.query_history:
            last_query = self.query_history[-1]
            with self.output:
                display(HTML("<h3>Query Results</h3>"))
                df = pd.DataFrame(last_query['results'])
                
                # Try to create visualization
                fig = self.create_visualization(df)
                
                if fig is not None:
                    display(fig)
                
                # Always show data table
                display(HTML(df.to_html(index=False, classes='table table-striped')))
                display(HTML(f"<p>Returned {len(df)} rows</p>"))
                
                # Create shareable URL
                share_url = f"data:application/vnd.ms-excel,{quote(df.to_csv(index=False))}"
                share_button = widgets.Button(
                    description="Export Results",
                    icon='download'
                )
                share_button.on_click(lambda b: webbrowser.open(share_url))
                display(share_button)
        
        # Assemble UI
        left_panel = widgets.VBox([
            widgets.HTML("<h1 style='margin-top:0;'>Text-to-SQL Agent</h1>"),
            self.question_input,
            self.run_button,
            self.sql_output,
            docs_button,
            self.output
        ], layout=widgets.Layout(width='70%'))
        
        right_panel = widgets.VBox([
            self.schema_viewer
        ], layout=widgets.Layout(width='30%'))
        
        display(widgets.HBox([left_panel, right_panel]))

# Initialize and display the UI
ui = EnhancedText2SQLUI(
    model_path="llama3-text2sql",
    db_path="spider_sqlite/activity_1/activity_1.sqlite"
)

  warn(


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

HBox(children=(VBox(children=(HTML(value="<h1 style='margin-top:0;'>Text-to-SQL Agent</h1>"), Textarea(value='…

In [2]:
#Agentic approach with followup question:

from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
import sqlite3
import pandas as pd
import plotly.express as px
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
import torch
import webbrowser
from urllib.parse import quote
import pickle
import os

class EnhancedText2SQLUI:
    def __init__(self, model_path, db_path):
        self.model_path = model_path
        self.db_path = db_path
        self.query_history = []
        self.current_context = {}
        self.initialize_agent()
        self.create_ui()
        
    def initialize_agent(self):
        """Initialize the Text-to-SQL agent"""
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
            config = PeftConfig.from_pretrained(self.model_path)
            base_model = AutoModelForCausalLM.from_pretrained(
                config.base_model_name_or_path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto"
            )
            self.model = PeftModel.from_pretrained(
                base_model,
                self.model_path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
            ).merge_and_unload()
            
            self.schema = self.load_schema()
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model.to(self.device)
            
        except Exception as e:
            print(f"Error initializing agent: {str(e)}")
            raise
    
    def load_schema(self):
        """Load database schema"""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0] for table in cursor.fetchall()]
        
        schema = {}
        for table in tables:
            cursor.execute(f"PRAGMA table_info({table});")
            schema[table] = [{"name": col[1], "type": col[2]} for col in cursor.fetchall()]
        
        conn.close()
        return schema
    
    def generate_sql(self, question, context=None):
        """Generate SQL from natural language question with context"""
        schema_lines = []
        for table, cols in self.schema.items():
            col_defs = [f"{col['name']} ({col['type']})" for col in cols]
            schema_lines.append(f"Table {table}: {', '.join(col_defs)}")
        
        schema_prompt = "\n".join(schema_lines)
        
        context_prompt = ""
        if context:
            context_prompt = f"\n\nContext from previous questions:\n{context}"
        
        prompt = f"""Given this database schema, translate the question to SQL:
{schema_prompt}{context_prompt}

Question: {question}
SQL Query:"""
        
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=256,
                pad_token_id=self.tokenizer.eos_token_id,
                num_beams=5,
                temperature=0.7
            )
        
        return self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True
        ).strip()
    
    def execute_query(self, sql):
        """Execute SQL query and return results as DataFrame"""
        conn = sqlite3.connect(self.db_path)
        
        try:
            df = pd.read_sql_query(sql, conn)
            return df, None
        except Exception as e:
            return None, str(e)
        finally:
            conn.close()
    
    def create_visualization(self, df):
        """Create appropriate visualization based on data"""
        if len(df.columns) == 1:
            # Single column - show bar chart
            fig = px.bar(df, x=df.columns[0], title="Query Results")
            return fig
        elif len(df.columns) == 2:
            # Two columns - show relationship
            if pd.api.types.is_numeric_dtype(df.iloc[:, 1]):
                fig = px.bar(df, x=df.columns[0], y=df.columns[1], title="Query Results")
            else:
                fig = px.pie(df, names=df.columns[0], values=df.columns[1], title="Query Results")
            return fig
        else:
            # For more columns, show table with interactive features
            return None
    
    def handle_ambiguity(self, question, error):
        """Handle ambiguous questions by asking follow-up"""
        clarification_prompt = f"""The following question is ambiguous when trying to convert to SQL:
Question: {question}
Error: {error}

Please suggest 2-3 clarifying questions that would help resolve this ambiguity, formatted as a bulleted list:"""
        
        inputs = self.tokenizer(clarification_prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=256,
                pad_token_id=self.tokenizer.eos_token_id,
                temperature=0.7
            )
        
        clarifications = self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True
        ).strip()
        
        return clarifications
    
    def on_run_click(self, b):
        """Handler for Run button click"""
        with self.output:
            clear_output()
            question = self.question_input.value.strip()
            
            if not question:
                display(HTML("<div style='color:red; padding:10px;'>Please enter a question</div>"))
                return
            
            # Generate SQL with context
            context = "\n".join([f"Previous question: {q['question']}\nSQL: {q['sql']}" 
                               for q in self.query_history[-3:]])  # Last 3 queries as context
            sql = self.generate_sql(question, context)
            self.sql_output.value = sql
            
            # Execute query
            df, error = self.execute_query(sql)
            
            if error:
                # Try to handle ambiguity
                clarifications = self.handle_ambiguity(question, error)
                
                display(HTML(f"""
                <div style='color:red; padding:10px; border:1px solid red;'>
                    <strong>Error:</strong> {error}
                    <div style='margin-top:10px;'>
                        <strong>Possible clarifications needed:</strong>
                        {clarifications}
                    </div>
                </div>
                """))
                
                # Add follow-up question widgets
                self.create_followup_questions(clarifications)
            else:
                # Add to history
                self.query_history.append({
                    'question': question,
                    'sql': sql,
                    'results': df.to_dict('records') if df is not None else []
                })
                
                # Save the current state
                self.save_state()
                
                # Display results
                self.display_results(df)
    
    def create_followup_questions(self, clarifications):
        """Create widgets for follow-up questions"""
        lines = [line.strip() for line in clarifications.split('\n') if line.strip()]
        self.followup_buttons = []
        
        display(HTML("<h4>Select a follow-up question:</h4>"))
        
        for i, line in enumerate(lines[:3]):  # Show max 3 follow-ups
            if line.startswith('-') or line.startswith('*'):
                line = line[1:].strip()
            btn = widgets.Button(
                description=line,
                layout=widgets.Layout(width='95%', margin='5px 0')
            )
            btn.on_click(lambda b, q=line: self.on_followup_click(q))
            self.followup_buttons.append(btn)
            display(btn)
    
    def on_followup_click(self, question):
        """Handler for follow-up question click"""
        self.question_input.value = question
        self.on_run_click(None)
    
    def display_results(self, df):
        """Display query results with visualization"""
        display(HTML("<h3>Query Results</h3>"))
        
        if df is not None:
            # Try to create visualization
            fig = self.create_visualization(df)
            
            if fig is not None:
                display(fig)
            
            # Always show data table
            display(HTML(df.to_html(index=False, classes='table table-striped')))
            display(HTML(f"<p>Returned {len(df)} rows</p>"))
            
            # Create shareable URL
            share_url = f"data:application/vnd.ms-excel,{quote(df.to_csv(index=False))}"
            share_button = widgets.Button(
                description="Export Results",
                icon='download'
            )
            share_button.on_click(lambda b: webbrowser.open(share_url))
            display(share_button)
        else:
            display(HTML("<p>No results returned</p>"))
    
    def save_state(self):
        """Save the current state of the UI"""
        state = {
            'question': self.question_input.value,
            'sql': self.sql_output.value,
            'query_history': self.query_history,
            'current_context': self.current_context
        }
        with open('text2sql_ui_state.pkl', 'wb') as f:
            pickle.dump(state, f)
    
    def load_state(self):
        """Load the saved state of the UI"""
        if os.path.exists('text2sql_ui_state.pkl'):
            with open('text2sql_ui_state.pkl', 'rb') as f:
                state = pickle.load(f)
                self.question_input.value = state.get('question', '')
                self.sql_output.value = state.get('sql', '')
                self.query_history = state.get('query_history', [])
                self.current_context = state.get('current_context', {})
    
    def create_ui(self):
        """Create the enhanced interactive UI"""
        # Custom CSS
        display(HTML("""
        <style>
            .sql-output {
                font-family: monospace;
                background-color: #f5f5f5;
                padding: 10px;
                border-radius: 5px;
                border: 1px solid #ddd;
            }
            .table {
                width: 100%;
                margin-top: 20px;
            }
            .table-striped tbody tr:nth-child(odd) {
                background-color: #f9f9f9;
            }
            .schema-panel {
                background-color: #f8f9fa;
                padding: 15px;
                border-radius: 5px;
                height: 100%;
            }
            .error-panel {
                color: red;
                padding: 10px;
                border: 1px solid red;
                margin: 10px 0;
            }
            .followup-panel {
                background-color: #f0f7ff;
                padding: 15px;
                border-radius: 5px;
                margin: 10px 0;
            }
        </style>
        """))
        
        # Input widgets
        self.question_input = widgets.Textarea(
            value='',
            placeholder='Enter your question (e.g., "Show students in Computer Science major")',
            description='Question:',
            layout=widgets.Layout(width='95%', height='100px')
        )
        
        self.run_button = widgets.Button(
            description='Run Query',
            button_style='success',
            icon='play',
            layout=widgets.Layout(width='200px')
        )
        self.run_button.on_click(self.on_run_click)
        
        # Output widgets
        self.sql_output = widgets.Textarea(
            value='',
            description='Generated SQL:',
            layout=widgets.Layout(width='95%', height='100px'),
            style={'description_width': 'initial'},
            disabled=True
        )
        
        self.output = widgets.Output(
            layout=widgets.Layout(width='95%', border='1px solid #eee', padding='10px')
        )
        
        # Schema viewer
        schema_html = "<div class='schema-panel'><h3>Database Schema</h3><ul>"
        for table, columns in self.schema.items():
            col_list = ", ".join([f"<code>{col['name']}</code> ({col['type']})" for col in columns])
            schema_html += f"<li><strong>{table}</strong>: {col_list}</li>"
        schema_html += "</ul></div>"
        
        self.schema_viewer = widgets.HTML(schema_html)
        
        # Documentation link
        docs_button = widgets.Button(
            description='Open Documentation',
            icon='book',
            layout=widgets.Layout(width='200px', margin='10px 0 0 0')
        )
        docs_button.on_click(lambda b: webbrowser.open("https://example.com/docs"))
        
        # Load saved state
        self.load_state()
        
        # Display the last query results if history exists
        if self.query_history:
            last_query = self.query_history[-1]
            with self.output:
                self.display_results(pd.DataFrame(last_query['results']))
        
        # Assemble UI
        left_panel = widgets.VBox([
            widgets.HTML("<h1 style='margin-top:0;'>Text-to-SQL Agent</h1>"),
            self.question_input,
            self.run_button,
            self.sql_output,
            docs_button,
            self.output
        ], layout=widgets.Layout(width='70%'))
        
        right_panel = widgets.VBox([
            self.schema_viewer
        ], layout=widgets.Layout(width='30%'))
        
        display(widgets.HBox([left_panel, right_panel]))

# Initialize and display the UI
ui = EnhancedText2SQLUI(
    model_path="llama3-text2sql",
    db_path="spider_sqlite/activity_1/activity_1.sqlite"
)



Current model requires 128 bytes of buffer for offloaded layers, which seems does not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using offload_buffers=True.



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]


Current model requires 256 bytes of buffer for offloaded layers, which seems does not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using offload_buffers=True.



HBox(children=(VBox(children=(HTML(value="<h1 style='margin-top:0;'>Text-to-SQL Agent</h1>"), Textarea(value='…

In [3]:
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
import sqlite3
import pandas as pd
import plotly.express as px
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
import torch
import webbrowser
from urllib.parse import quote
import pickle
import os
import re

class EnhancedText2SQLUI:
    def __init__(self, model_path, db_path):
        self.model_path = model_path
        self.db_path = db_path
        self.query_history = []
        self.current_context = {}
        self.followup_mode = False
        self.followup_question = ""
        self.initialize_agent()
        self.create_ui()
        
    def initialize_agent(self):
        """Initialize the Text-to-SQL agent with offloading for GPU memory issues"""
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
            config = PeftConfig.from_pretrained(self.model_path)
            base_model = AutoModelForCausalLM.from_pretrained(
                config.base_model_name_or_path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto",
                offload_buffers=True
            )
            self.model = PeftModel.from_pretrained(
                base_model,
                self.model_path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
            ).merge_and_unload()
            
            self.schema = self.load_schema()
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model.to(self.device)
            
        except Exception as e:
            print(f"Error initializing agent: {str(e)}")
            raise
    
    def load_schema(self):
        """Load database schema"""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0] for table in cursor.fetchall()]
        
        schema = {}
        for table in tables:
            cursor.execute(f"PRAGMA table_info({table});")
            schema[table] = [{"name": col[1], "type": col[2]} for col in cursor.fetchall()]
        
        conn.close()
        return schema
    
    def generate_sql(self, question, context=None):
        """Generate SQL from natural language question with context"""
        schema_lines = []
        for table, cols in self.schema.items():
            col_defs = [f"{col['name']} ({col['type']})" for col in cols]
            schema_lines.append(f"Table {table}: {', '.join(col_defs)}")
        
        schema_prompt = "\n".join(schema_lines)
        
        context_prompt = ""
        if context:
            context_prompt = f"\n\nContext from previous questions:\n{context}"
        
        prompt = f"""Given this database schema, translate the question to SQL:
{schema_prompt}{context_prompt}

Question: {question}
SQL Query:"""
        
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=512,
                pad_token_id=self.tokenizer.eos_token_id,
                num_beams=5,
                temperature=0.7
            )
        
        sql = self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True
        ).strip()
        
        # Robust cleaning to remove redundant INTERSECT clauses
        if "INTERSECT" in sql.upper():
            # Split into subqueries
            subqueries = sql.split("INTERSECT")
            # Keep only the first subquery if they're identical
            base_query = subqueries[0].strip()
            if all(subquery.strip() == base_query for subquery in subqueries[1:]):
                sql = base_query
            else:
                # If subqueries differ, keep the first valid one
                sql = base_query
        
        return sql
    
    def execute_query(self, sql):
        """Execute SQL query and return results as DataFrame"""
        conn = sqlite3.connect(self.db_path)
        
        try:
            df = pd.read_sql_query(sql, conn)
            return df, None
        except Exception as e:
            return None, str(e)
        finally:
            conn.close()
    
    def create_visualization(self, df):
        """Create appropriate visualization based on data"""
        if len(df.columns) == 1:
            fig = px.bar(df, x=df.columns[0], title="Query Results")
            return fig
        elif len(df.columns) == 2:
            if pd.api.types.is_numeric_dtype(df.iloc[:, 1]):
                fig = px.bar(df, x=df.columns[0], y=df.columns[1], title="Query Results")
            else:
                fig = px.pie(df, names=df.columns[0], values=df.columns[1], title="Query Results")
            return fig
        else:
            return None
    
    def detect_ambiguity(self, question):
        """Detect if a question is ambiguous"""
        all_tables = list(self.schema.keys())
        all_columns = [col['name'] for cols in self.schema.values() for col in cols]
        
        patterns = [
            r'\b(some|information|details|data)\b.*\b(about|for|of)\b',
            r'\b(show|find|get)\b.*\b(without|any)\b.*\b(field|column)\b',
            r'\b(students|faculty|activities)\b.*\b(without|no)\b.*\b(specify|filter)\b'
        ]
        
        for pattern in patterns:
            if re.search(pattern, question.lower()):
                return True, "The question is too vague. Please specify fields or conditions."
        
        words = re.findall(r'\b\w+\b', question.lower())
        found_table = any(word in [t.lower() for t in all_tables] for word in words)
        found_column = any(word in [c.lower() for c in all_columns] for word in words)
        
        if found_table and not found_column and "all" not in question.lower():
            return True, "The question references a table but not specific columns."
        
        return False, ""
    
    def handle_ambiguity(self, question, error=None, sql=None):
        """Handle ambiguous questions by generating follow-up questions"""
        is_ambiguous, ambiguity_message = self.detect_ambiguity(question)
        if is_ambiguous or error:
            clarification_prompt = f"""The following question is ambiguous or caused an error when converting to SQL:
Question: {question}
{'Generated SQL: ' + sql if sql else ''}
{'Error: ' + error if error else 'Message: ' + ambiguity_message}

Suggest 2-3 specific follow-up questions to resolve the ambiguity or error, formatted as a numbered list (1., 2., 3.) with clear, concise questions relevant to the schema and error/message:"""
            
            inputs = self.tokenizer(clarification_prompt, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    pad_token_id=self.tokenizer.eos_token_id,
                    temperature=0.7
                )
            
            clarifications = self.tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:],
                skip_special_tokens=True
            ).strip()
            
            questions = []
            for line in clarifications.split('\n'):
                match = re.match(r'^\d+\.\s*(.+)', line.strip())
                if match:
                    questions.append(match.group(1))
            
            return questions[:3]
        return []
    
    def on_run_click(self, b):
        """Handler for Run button click"""
        with self.output:
            clear_output()
            question = self.question_input.value.strip()
            
            if not question:
                display(HTML("<div style='color:red; padding:10px;'>Please enter a question</div>"))
                return
            
            if self.followup_mode:
                refined_context = f"{self.followup_question}\nUser response: {question}"
                sql = self.generate_sql(question, refined_context)
                self.followup_mode = False
                self.followup_question = ""
            else:
                context = "\n".join([f"Previous question: {q['question']}\nSQL: {q['sql']}" 
                                   for q in self.query_history[-3:]])
                sql = self.generate_sql(question, context)
                self.sql_output.value = sql
            
            # Check for ambiguity before execution
            is_ambiguous, ambiguity_message = self.detect_ambiguity(question)
            if is_ambiguous:
                followup_questions = self.handle_ambiguity(question, sql=sql)
                if followup_questions:
                    display(HTML(f"""
                    <div style='color:red; padding:10px; border:1px solid red;'>
                        <strong>Clarification Needed:</strong> {ambiguity_message}
                        <div style='margin-top:10px;'>
                            <strong>Please clarify by selecting or answering one of the following:</strong>
                            <ol>
                                {'<li>' + '</li><li>'.join(followup_questions) + '</li>'}
                            </ol>
                        </div>
                    </div>
                    """))
                    
                    self.followup_input = widgets.Textarea(
                        value='',
                        placeholder='Type your clarification here or select a question below',
                        layout=widgets.Layout(width='95%', height='60px')
                    )
                    self.followup_submit = widgets.Button(
                        description='Submit Clarification',
                        button_style='primary',
                        layout=widgets.Layout(width='200px', margin='10px 0')
                    )
                    self.followup_submit.on_click(lambda b: self.on_followup_submit(question))
                    
                    self.followup_buttons = []
                    for q in followup_questions:
                        btn = widgets.Button(
                            description=q[:40] + "..." if len(q) > 40 else q,
                            layout=widgets.Layout(width='95%', margin='5px 0'),
                            tooltip=q
                        )
                        btn.on_click(lambda b, q=q: self.on_followup_select(q))
                        self.followup_buttons.append(btn)
                        display(btn)
                    
                    display(widgets.VBox([self.followup_input, self.followup_submit]))
                return  # Stop execution until clarification
            
            df, error = self.execute_query(sql)
            
            if error:
                followup_questions = self.handle_ambiguity(question, error=error, sql=sql)
                if followup_questions:
                    display(HTML(f"""
                    <div style='color:red; padding:10px; border:1px solid red;'>
                        <strong>Error:</strong> {error}
                        <div style='margin-top:10px;'>
                            <strong>Please clarify by selecting or answering one of the following:</strong>
                            <ol>
                                {'<li>' + '</li><li>'.join(followup_questions) + '</li>'}
                            </ol>
                        </div>
                    </div>
                    """))
                    
                    self.followup_input = widgets.Textarea(
                        value='',
                        placeholder='Type your clarification here or select a question below',
                        layout=widgets.Layout(width='95%', height='60px')
                    )
                    self.followup_submit = widgets.Button(
                        description='Submit Clarification',
                        button_style='primary',
                        layout=widgets.Layout(width='200px', margin='10px 0')
                    )
                    self.followup_submit.on_click(lambda b: self.on_followup_submit(question))
                    
                    self.followup_buttons = []
                    for q in followup_questions:
                        btn = widgets.Button(
                            description=q[:40] + "..." if len(q) > 40 else q,
                            layout=widgets.Layout(width='95%', margin='5px 0'),
                            tooltip=q
                        )
                        btn.on_click(lambda b, q=q: self.on_followup_select(q))
                        self.followup_buttons.append(btn)
                        display(btn)
                    
                    display(widgets.VBox([self.followup_input, self.followup_submit]))
            else:
                self.query_history.append({
                    'question': question,
                    'sql': sql,
                    'results': df.to_dict('records') if df is not None else []
                })
                self.save_state()
                self.display_results(df)
    
    def on_followup_select(self, question):
        """Handle selection of a follow-up question"""
        self.question_input.value = question
        self.on_run_click(None)
    
    def on_followup_submit(self, original_question):
        """Handle custom clarification submission"""
        clarification = self.followup_input.value.strip()
        if clarification:
            self.followup_mode = True
            self.followup_question = f"Original question: {original_question}\nClarification: {clarification}"
            self.question_input.value = clarification
            self.on_run_click(None)
    
    def display_results(self, df):
        """Display query results with visualization"""
        display(HTML("<h3>Query Results</h3>"))
        
        if df is not None:
            fig = self.create_visualization(df)
            if fig is not None:
                display(fig)
            display(HTML(df.to_html(index=False, classes='table table-striped')))
            display(HTML(f"<p>Returned {len(df)} rows</p>"))
            share_url = f"data:application/vnd.ms-excel,{quote(df.to_csv(index=False))}"
            share_button = widgets.Button(description="Export Results", icon='download')
            share_button.on_click(lambda b: webbrowser.open(share_url))
            display(share_button)
        else:
            display(HTML("<p>No results returned</p>"))
    
    def save_state(self):
        """Save the current state of the UI"""
        state = {
            'question': self.question_input.value,
            'sql': self.sql_output.value,
            'query_history': self.query_history,
            'current_context': self.current_context,
            'followup_mode': self.followup_mode,
            'followup_question': self.followup_question
        }
        with open('text2sql_ui_state.pkl', 'wb') as f:
            pickle.dump(state, f)
    
    def load_state(self):
        """Load the saved state of the UI"""
        if os.path.exists('text2sql_ui_state.pkl'):
            with open('text2sql_ui_state.pkl', 'rb') as f:
                state = pickle.load(f)
                self.question_input.value = state.get('question', '')
                self.sql_output.value = state.get('sql', '')
                self.query_history = state.get('query_history', [])
                self.current_context = state.get('current_context', {})
                self.followup_mode = state.get('followup_mode', False)
                self.followup_question = state.get('followup_question', '')
    
    def create_ui(self):
        """Create the enhanced interactive UI"""
        display(HTML("""
        <style>
            .sql-output {font-family: monospace; background-color: #f5f5f5; padding: 10px; border-radius: 5px; border: 1px solid #ddd;}
            .table {width: 100%; margin-top: 20px;}
            .table-striped tbody tr:nth-child(odd) {background-color: #f9f9f9;}
            .schema-panel {background-color: #f8f9fa; padding: 15px; border-radius: 5px; height: 100%;}
            .error-panel {color: red; padding: 10px; border: 1px solid red; margin: 10px 0;}
            .followup-panel {background-color: #f0f7ff; padding: 15px; border-radius: 5px; margin: 10px 0;}
        </style>
        """))
        
        self.question_input = widgets.Textarea(
            value='',
            placeholder='Enter your question (e.g., "Show students in Computer Science major")',
            description='Question:',
            layout=widgets.Layout(width='95%', height='100px')
        )
        
        self.run_button = widgets.Button(
            description='Run Query',
            button_style='success',
            icon='play',
            layout=widgets.Layout(width='200px')
        )
        self.run_button.on_click(self.on_run_click)
        
        self.sql_output = widgets.Textarea(
            value='',
            description='Generated SQL:',
            layout=widgets.Layout(width='95%', height='100px'),
            style={'description_width': 'initial'},
            disabled=True
        )
        
        self.output = widgets.Output(
            layout=widgets.Layout(width='95%', border='1px solid #eee', padding='10px')
        )
        
        schema_html = "<div class='schema-panel'><h3>Database Schema</h3><ul>"
        for table, columns in self.schema.items():
            col_list = ", ".join([f"<code>{col['name']}</code> ({col['type']})" for col in columns])
            schema_html += f"<li><strong>{table}</strong>: {col_list}</li>"
        schema_html += "</ul></div>"
        
        self.schema_viewer = widgets.HTML(schema_html)
        
        docs_button = widgets.Button(
            description='Open Documentation',
            icon='book',
            layout=widgets.Layout(width='200px', margin='10px 0 0 0')
        )
        docs_button.on_click(lambda b: webbrowser.open("https://example.com/docs"))
        
        self.load_state()
        
        if self.query_history:
            last_query = self.query_history[-1]
            with self.output:
                self.display_results(pd.DataFrame(last_query['results']))
        
        left_panel = widgets.VBox([
            widgets.HTML("<h1 style='margin-top:0;'>Text-to-SQL Agent</h1>"),
            self.question_input,
            self.run_button,
            self.sql_output,
            docs_button,
            self.output
        ], layout=widgets.Layout(width='70%'))
        
        right_panel = widgets.VBox([
            self.schema_viewer
        ], layout=widgets.Layout(width='30%'))
        
        display(widgets.HBox([left_panel, right_panel]))

# Initialize and display the UI
ui = EnhancedText2SQLUI(
    model_path="llama3-text2sql",
    db_path="spider_sqlite/activity_1/activity_1.sqlite"
)


Current model requires 128 bytes of buffer for offloaded layers, which seems does not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using offload_buffers=True.



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]


Current model requires 256 bytes of buffer for offloaded layers, which seems does not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using offload_buffers=True.



HBox(children=(VBox(children=(HTML(value="<h1 style='margin-top:0;'>Text-to-SQL Agent</h1>"), Textarea(value='…