In [1]:
!pip install streamlit

[0m

In [4]:
# Launch the Streamlit front-end from text2sql_ui.py.
# NOTE: this runs in the foreground and blocks the cell until the kernel is
# interrupted (the captured output below ends with ^C / "Stopping...").
!streamlit run text2sql_ui.py


Collecting usage statistics. To deactivate, set browser.gatherUsageStats to false.
[0m
[0m
[34m[1m  You can now view your Streamlit app in your browser.[0m
[0m
[34m  Local URL: [0m[1mhttp://localhost:8503[0m
[34m  Network URL: [0m[1mhttp://172.17.0.2:8503[0m
[34m  External URL: [0m[1mhttp://130.65.254.7:8503[0m
[0m
^C
[34m  Stopping...[0m


In [5]:
# List the working directory to check which model/adapter folders and
# database directories are available to the cells below.
!ls

'=0.21.0'				      lora_adapters
 AGENT.ipynb				      results
 DATA266.ipynb				      spider-schema.csv
 DATA_266_Project_266_version_5_May_8.ipynb   spider_databases
 Untitled.ipynb				      spider_dataset
 Untitled1.ipynb			      spider_finetune.jsonl
 Untitled2.ipynb			      spider_schemas
 Untitled3.ipynb			      spider_scheme_to_sqlite.ipynb
 base_evaluation_results_basemodel.csv	      spider_sqlite
 database				      spider_sqlite#1
 evaluation_results_FINETUNED.csv	      text2sql_ui.py
 llama3-text2sql			      wandb
 llama3-text2sql_interrupted


In [2]:
pip install plotly

[0mNote: you may need to restart the kernel to use updated packages.


In [1]:
import html
import os
import pickle
import sqlite3
import webbrowser
from contextlib import closing
from pathlib import Path
from urllib.parse import quote

import ipywidgets as widgets
import pandas as pd
import plotly.express as px
import torch
from IPython.display import display, HTML, clear_output
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

class EnhancedText2SQLUI:
    """Notebook (ipywidgets) front-end for a PEFT fine-tuned Text-to-SQL model.

    Constructing an instance reads the schema of the SQLite database at
    ``db_path``, loads the base model referenced by the adapter stored at
    ``model_path`` and merges the adapter into it, then builds and displays
    the widget UI.

    The database is always opened read-only: the SQL that gets executed is
    produced by a language model, so it must never be able to modify the data.

    Attributes:
        model_path: Directory (or hub id) of the PEFT adapter and tokenizer.
        db_path: Path to the SQLite database file that is queried.
        query_history: List of dicts with keys ``question``, ``sql`` and
            ``results`` (list of row dicts), one per successful query.
        schema: ``{table_name: [{"name": ..., "type": ...}, ...]}``.
    """

    # File (relative to the working directory) used to persist UI state.
    STATE_FILE = 'text2sql_ui_state.pkl'
    # Target of the "Open Documentation" link.
    DOCS_URL = "https://example.com/docs"

    def __init__(self, model_path, db_path):
        self.model_path = model_path
        self.db_path = db_path
        self.query_history = []
        self.initialize_agent()
        self.create_ui()

    # ------------------------------------------------------------------
    # Database helpers
    # ------------------------------------------------------------------
    def _connect_readonly(self):
        """Open ``self.db_path`` read-only and return the connection.

        A ``file:`` URI with ``mode=ro`` is used so that (a) generated SQL
        cannot write to the database and (b) a mistyped path raises instead
        of silently creating an empty database file. ``Path.as_uri`` takes
        care of percent-escaping characters such as ``#`` or ``?`` in the path.

        Raises:
            sqlite3.OperationalError: If the database file cannot be opened.
        """
        uri = Path(self.db_path).resolve().as_uri() + "?mode=ro"
        return sqlite3.connect(uri, uri=True)

    @staticmethod
    def _quote_identifier(name):
        """Return ``name`` as a double-quoted SQLite identifier."""
        return '"' + str(name).replace('"', '""') + '"'

    def load_schema(self):
        """Read every table and its columns from the database.

        Returns:
            dict: ``{table_name: [{"name": column_name, "type": declared_type}, ...]}``.

        Raises:
            sqlite3.Error: If the database cannot be opened or inspected.
        """
        schema = {}
        # closing() guarantees the connection is released even when a
        # statement below raises.
        with closing(self._connect_readonly()) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            table_names = [row[0] for row in cursor.fetchall()]

            for table in table_names:
                # Quote the name so tables called e.g. "order", or containing
                # spaces, do not break the PRAGMA statement.
                cursor.execute(f"PRAGMA table_info({self._quote_identifier(table)});")
                schema[table] = [
                    {"name": col[1], "type": col[2]} for col in cursor.fetchall()
                ]
        return schema

    def execute_query(self, sql):
        """Run ``sql`` against the (read-only) database.

        Args:
            sql: The SQL text to execute.

        Returns:
            tuple: ``(DataFrame, None)`` on success or ``(None, error_message)``
            on failure. Failures include empty SQL, an unreadable database,
            invalid SQL and any statement that attempts to write.
        """
        if not sql or not sql.strip():
            return None, "No SQL was generated"

        try:
            with closing(self._connect_readonly()) as conn:
                return pd.read_sql_query(sql, conn), None
        except Exception as e:  # reported to the user instead of raised
            return None, str(e) or type(e).__name__

    # ------------------------------------------------------------------
    # Model
    # ------------------------------------------------------------------
    def initialize_agent(self):
        """Load the schema, tokenizer and merged model.

        Sets ``self.schema``, ``self.tokenizer``, ``self.model`` and
        ``self.device``.

        Raises:
            Exception: Any loading error is printed and re-raised.
        """
        try:
            # Cheap step first: a wrong db_path should fail before the
            # (slow) model checkpoint is loaded.
            self.schema = self.load_schema()

            dtype = torch.float16 if torch.cuda.is_available() else torch.float32

            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
            config = PeftConfig.from_pretrained(self.model_path)
            base_model = AutoModelForCausalLM.from_pretrained(
                config.base_model_name_or_path,
                torch_dtype=dtype,
                device_map="auto"
            )
            self.model = PeftModel.from_pretrained(
                base_model,
                self.model_path,
                torch_dtype=dtype
            ).merge_and_unload()
            self.model.eval()

            # device_map="auto" has already placed the weights, so the model
            # is not moved again with .to(); inputs are simply sent to the
            # device the model reports.
            self.device = self.model.device

        except Exception as e:
            print(f"Error initializing agent: {str(e)}")
            raise

    def _build_prompt(self, question):
        """Return the prompt sent to the model for ``question``."""
        schema_lines = []
        for table, cols in self.schema.items():
            col_defs = [f"{col['name']} ({col['type']})" for col in cols]
            schema_lines.append(f"Table {table}: {', '.join(col_defs)}")
        schema_prompt = "\n".join(schema_lines)

        return (
            "Given this database schema, translate the question to SQL:\n\n"
            f"{schema_prompt}\n\n"
            f"Question: {question}\n"
            "SQL Query:"
        )

    def generate_sql(self, question):
        """Generate SQL from a natural language question.

        Args:
            question: The user's question in plain English.

        Returns:
            str: The decoded model continuation (prompt tokens removed),
            stripped of surrounding whitespace.
        """
        prompt = self._build_prompt(question)

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=256,
                pad_token_id=self.tokenizer.eos_token_id,
                num_beams=5,
                # NOTE(review): temperature only has an effect when sampling
                # is enabled (e.g. by the model's generation config) - confirm
                # whether deterministic beam search is what is wanted here.
                temperature=0.7
            )

        # Keep only the newly generated tokens.
        prompt_length = inputs["input_ids"].shape[1]
        return self.tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True
        ).strip()

    # ------------------------------------------------------------------
    # Presentation helpers
    # ------------------------------------------------------------------
    def create_visualization(self, df):
        """Pick a chart for ``df`` when its shape allows one.

        * one column  -> bar chart of that column
        * two columns -> bar chart when the second column is numeric, pie
          chart (labels = second column) when only the first is numeric

        Returns:
            A plotly figure, or ``None`` when there is nothing sensible to
            plot (empty result, no numeric column, more than two columns).
        """
        if df is None or df.empty:
            return None

        column_count = len(df.columns)
        if column_count == 1:
            return px.bar(df, x=df.columns[0], title="Query Results")

        if column_count == 2:
            # Duplicate labels (e.g. SELECT name, name) cannot be addressed
            # unambiguously by column name, so fall back to the table view.
            if not df.columns.is_unique:
                return None

            first, second = df.columns[0], df.columns[1]
            if pd.api.types.is_numeric_dtype(df.iloc[:, 1]):
                return px.bar(df, x=first, y=second, title="Query Results")
            if pd.api.types.is_numeric_dtype(df.iloc[:, 0]):
                # Pie slices need numeric values; use the numeric column for
                # the sizes and the other one for the labels.
                return px.pie(df, names=second, values=first, title="Query Results")
            return None

        # Three or more columns: the data table alone is shown.
        return None

    def _export_href(self, df):
        """Return a ``data:`` URL holding ``df`` as percent-encoded CSV."""
        return f"data:text/csv;charset=utf-8,{quote(df.to_csv(index=False))}"

    def _render_results(self, df):
        """Display heading, optional chart, table, row count and export link.

        Must be called inside the output context that should receive the
        rendering (``with self.output:``).
        """
        display(HTML("<h3>Query Results</h3>"))

        fig = self.create_visualization(df)
        if fig is not None:
            display(fig)

        # Always show the data table (to_html escapes cell values).
        display(HTML(df.to_html(index=False, classes='table table-striped')))
        display(HTML(f"<p>Returned {len(df)} rows</p>"))

        # A download link is handled by the *user's* browser. Opening the URL
        # with webbrowser.open() would run on the machine hosting the kernel,
        # which is not where the user sits when Jupyter runs remotely.
        display(HTML(
            f'<a download="query_results.csv" href="{self._export_href(df)}">'
            'Export Results</a>'
        ))

    def on_run_click(self, b):
        """Handler for the Run button: generate SQL, execute it, show results."""
        # Prevent overlapping runs while the model is generating.
        self.run_button.disabled = True
        try:
            with self.output:
                clear_output()
                question = self.question_input.value.strip()

                if not question:
                    print("Please enter a question")
                    return

                sql = self.generate_sql(question)
                self.sql_output.value = sql

                df, error = self.execute_query(sql)

                if error is not None:
                    # The message can quote arbitrary SQL fragments - escape it.
                    display(HTML(
                        "<div style='color:red; padding:10px; border:1px solid red;'>"
                        f"Error: {html.escape(error)}</div>"
                    ))
                    return

                self.query_history.append({
                    'question': question,
                    'sql': sql,
                    'results': df.to_dict('records') if df is not None else []
                })

                if df is not None:
                    self._render_results(df)
                else:
                    display(HTML("<h3>Query Results</h3>"))
                    display(HTML("<p>No results returned</p>"))

                # Persist last, so a failure to write the state file cannot
                # prevent the results from being shown.
                self.save_state()
        finally:
            self.run_button.disabled = False

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------
    def save_state(self):
        """Pickle the current question, SQL and query history to STATE_FILE."""
        state = {
            'question': self.question_input.value,
            'sql': self.sql_output.value,
            'query_history': self.query_history
        }
        with open(self.STATE_FILE, 'wb') as f:
            pickle.dump(state, f)

    def load_state(self):
        """Restore the UI state saved by :meth:`save_state`, if present.

        A missing, unreadable or malformed state file is ignored (a message
        is printed) so that it cannot prevent the UI from being built.
        """
        if not os.path.exists(self.STATE_FILE):
            return

        try:
            with open(self.STATE_FILE, 'rb') as f:
                # SECURITY: pickle.load can execute arbitrary code. Only keep
                # this if the state file is written exclusively by this
                # notebook; consider a data-only format for shared folders.
                state = pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError,
                AttributeError, ImportError, IndexError) as exc:
            print(f"Ignoring unreadable saved state '{self.STATE_FILE}': {exc}")
            return

        if not isinstance(state, dict):
            print(f"Ignoring malformed saved state '{self.STATE_FILE}'")
            return

        self.question_input.value = state.get('question', '')
        self.sql_output.value = state.get('sql', '')
        self.query_history = state.get('query_history', [])

    # ------------------------------------------------------------------
    # UI construction
    # ------------------------------------------------------------------
    def create_ui(self):
        """Create the widgets, restore saved state and display the UI."""
        # Custom CSS
        display(HTML("""
        <style>
            .sql-output {
                font-family: monospace;
                background-color: #f5f5f5;
                padding: 10px;
                border-radius: 5px;
                border: 1px solid #ddd;
            }
            .table {
                width: 100%;
                margin-top: 20px;
            }
            .table-striped tbody tr:nth-child(odd) {
                background-color: #f9f9f9;
            }
            .schema-panel {
                background-color: #f8f9fa;
                padding: 15px;
                border-radius: 5px;
                height: 100%;
            }
        </style>
        """))

        # Input widgets
        self.question_input = widgets.Textarea(
            value='',
            placeholder='Enter your question (e.g., "Show students in Computer Science major")',
            description='Question:',
            layout=widgets.Layout(width='95%', height='100px')
        )

        self.run_button = widgets.Button(
            description='Run Query',
            button_style='success',
            icon='play',
            layout=widgets.Layout(width='200px')
        )
        self.run_button.on_click(self.on_run_click)

        # Output widgets
        self.sql_output = widgets.Textarea(
            value='',
            description='Generated SQL:',
            layout=widgets.Layout(width='95%', height='100px'),
            style={'description_width': 'initial'},
            disabled=True
        )

        self.output = widgets.Output(
            layout=widgets.Layout(width='95%', border='1px solid #eee', padding='10px')
        )

        # Schema viewer (names come from the database file, so escape them)
        schema_items = []
        for table, columns in self.schema.items():
            col_list = ", ".join(
                f"<code>{html.escape(str(col['name']))}</code> "
                f"({html.escape(str(col['type']))})"
                for col in columns
            )
            schema_items.append(
                f"<li><strong>{html.escape(str(table))}</strong>: {col_list}</li>"
            )
        self.schema_viewer = widgets.HTML(
            "<div class='schema-panel'><h3>Database Schema</h3><ul>"
            + "".join(schema_items)
            + "</ul></div>"
        )

        # Documentation link, opened by the user's browser in a new tab.
        docs_link = widgets.HTML(
            f'<a href="{html.escape(self.DOCS_URL)}" target="_blank" '
            'rel="noopener noreferrer">Open Documentation</a>',
            layout=widgets.Layout(margin='10px 0 0 0')
        )

        # Load saved state
        self.load_state()

        # Re-display the most recent results when a history was restored
        if self.query_history:
            last_query = self.query_history[-1]
            with self.output:
                self._render_results(pd.DataFrame(last_query.get('results', [])))

        # Assemble UI
        left_panel = widgets.VBox([
            widgets.HTML("<h1 style='margin-top:0;'>Text-to-SQL Agent</h1>"),
            self.question_input,
            self.run_button,
            self.sql_output,
            docs_link,
            self.output
        ], layout=widgets.Layout(width='70%'))

        right_panel = widgets.VBox([
            self.schema_viewer
        ], layout=widgets.Layout(width='30%'))

        display(widgets.HBox([left_panel, right_panel]))

# Build and display the UI. MODEL_PATH is the directory the adapter config
# and tokenizer are loaded from; DB_PATH is the SQLite file that is queried.
MODEL_PATH = "llama3-text2sql"
DB_PATH = "spider_sqlite/activity_1/activity_1.sqlite"

ui = EnhancedText2SQLUI(model_path=MODEL_PATH, db_path=DB_PATH)

  warn(


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

HBox(children=(VBox(children=(HTML(value="<h1 style='margin-top:0;'>Text-to-SQL Agent</h1>"), Textarea(value='…