In [1]:
!pip install --upgrade pip
!pip install transformers torch gradio


Collecting pip
  Downloading pip-25.1-py3-none-any.whl.metadata (3.6 kB)
Downloading pip-25.1-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m17.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.1
Collecting gradio
  Downloading gradio-5.27.1-py3-none-any.whl.metadata (16 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB

In [2]:
# Mount Google Drive into the Colab filesystem at /content/drive.
# The CHECKPOINT path used when loading the model lives under this mount,
# so this cell has to run before the model-loading cell.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
import re
import difflib
import pandas as pd
import sqlite3
import torch
import gradio as gr
from pathlib import Path
from transformers import BartForConditionalGeneration, BartTokenizerFast

# --- Configuration ---
# Folder (on the mounted Drive) holding the saved model and tokenizer files;
# point this at your own model folder if it lives elsewhere.
CHECKPOINT = "/content/drive/MyDrive/text2sql_checkpoint"
# Use the GPU when the runtime exposes one, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Model & tokenizer ---
tokenizer = BartTokenizerFast.from_pretrained(CHECKPOINT)
model = BartForConditionalGeneration.from_pretrained(CHECKPOINT)
model = model.to(DEVICE)
# Put the module in evaluation mode; this notebook only runs generation.
model.eval()

# --- Prompt builder ---
def extract_entities(q: str):
    """Return the entities found in question ``q``.

    Placeholder implementation: no extraction is performed, so the result is
    always an empty list regardless of the input.
    """
    return list()

def build_input(question: str, schema_txt: str) -> str:
    """Assemble the model prompt from a question and a schema description.

    The prompt consists of three parts, in this order:
    an ``[ENT]...[/ENT]`` section with the extracted entities joined by ``;``
    (or the literal ``NONE`` when there are none), a ``[SCHEMA]...[/SCHEMA]``
    section containing ``schema_txt`` verbatim, and ``Question: `` followed by
    the question itself.
    """
    entities = extract_entities(question)
    if entities:
        entity_field = ";".join(entities)
    else:
        entity_field = "NONE"

    ent_section = f"[ENT]{entity_field}[/ENT]"
    schema_section = f"[SCHEMA]{schema_txt}[/SCHEMA]"
    return f"{ent_section}{schema_section}Question: {question}"

# --- Load CSVs into SQLite and build schema string ---
def load_csvs_to_sqlite(files):
    """Load uploaded CSV files into a fresh in-memory SQLite database.

    Each file becomes one table, named ``t1``, ``t2``, ... in upload order.

    Args:
        files: Iterable of uploads. Each item may be a path (``str`` or
            ``pathlib.Path``) or any object exposing the path as ``.name``.
            ``None`` is treated the same as an empty list.

    Returns:
        ``(conn, schema)``: ``conn`` is the open ``sqlite3.Connection`` (the
        caller is responsible for closing it) and ``schema`` is a string of
        the form ``t1(col_a,col_b) | t2(col_c)``. ``schema`` is empty when no
        files were given.

    Raises:
        Any exception raised while reading a CSV or writing it to SQLite; the
        connection is closed before the error propagates so it does not leak.
    """
    conn = sqlite3.connect(":memory:")
    table_defs = []
    try:
        for idx, file in enumerate(files or [], start=1):
            tbl = f"t{idx}"
            # A str/Path is already the path; upload wrappers carry it in `.name`.
            # (Path is checked explicitly because Path.name is only the basename.)
            path = file if isinstance(file, (str, Path)) else file.name
            df = pd.read_csv(path)
            df.to_sql(tbl, conn, index=False, if_exists="replace")
            cols = ",".join(map(str, df.columns))
            table_defs.append(f"{tbl}({cols})")
    except Exception:
        conn.close()
        raise
    return conn, " | ".join(table_defs)

# --- Main query function with robust fallbacks ---
def _sql_literal(value) -> str:
    """Render ``value`` as a single-quoted SQL string literal (embedded quotes doubled)."""
    return "'" + str(value).replace("'", "''") + "'"


def _run_sql(sql, conn):
    """Execute ``sql`` on ``conn``; return the result frame, or ``None`` if execution raised."""
    try:
        return pd.read_sql_query(sql, conn)
    except Exception:
        return None


def query_multi_csv(files, question):
    """Answer ``question`` over the uploaded CSVs by generating and running SQL.

    Steps:
      1. Load the CSVs into an in-memory SQLite database (tables ``t1``, ``t2``, ...).
      2. Build the prompt, generate SQL with the model, and patch a few known
         malformations in the decoded text (``FROMt1``, ``FROM 1``, ``WHEREcol``).
      3. Execute the SQL. If it fails or returns no rows and it has the simple
         shape ``SELECT ... FROM tN WHERE col = 'value'``, retry with
         (a) a case-insensitive substring ``LIKE`` and then
         (b) the closest existing value of that column (difflib, cutoff 0.6).

    Args:
        files: Uploaded CSV files as delivered by the Gradio ``File`` component.
        question: Natural-language question.

    Returns:
        ``(sql, df)``: the SQL text that was last executed successfully (or the
        generated SQL if nothing executed) and its result frame. On invalid
        input or an unexpected error the first element is an ``"ERROR: ..."``
        message and the frame is empty. This function does not raise.
    """
    # Guard clauses: give the user a readable message instead of a TypeError.
    if not files:
        return "ERROR: please upload at least one CSV file.", pd.DataFrame()
    if not question or not question.strip():
        return "ERROR: please enter a question.", pd.DataFrame()

    conn = None
    try:
        conn, schema = load_csvs_to_sqlite(files)
        prompt = build_input(question, schema)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding="longest").to(DEVICE)
        out    = model.generate(**inputs, num_beams=5, max_length=256, early_stopping=True)
        sql    = tokenizer.decode(out[0], skip_special_tokens=True)

        # --- Post-process malformed SQL ---
        # fix missing space after FROM ("FROMt1" -> "FROM t1")
        sql = re.sub(r"\bFROMt(\d+)", r"FROM t\1", sql, flags=re.IGNORECASE)
        # fix bare "FROM 1" -> "FROM t1"
        sql = re.sub(r"\bFROM\s+1\b", "FROM t1", sql, flags=re.IGNORECASE)
        # ensure space after WHERE
        sql = re.sub(r"WHERE(?=[A-Za-z_])", "WHERE ", sql, flags=re.IGNORECASE)

        # execute generated SQL; a failed execution is treated like "no rows"
        df = _run_sql(sql, conn)
        if df is None:
            df = pd.DataFrame()

        # The fallbacks only apply to an empty result of a simple equality query.
        m = None
        if df.empty:
            m = re.match(
                r'SELECT\s+(?P<proj>.+?)\s+FROM\s+(?P<table>t\d+)\s+WHERE\s+(?P<col>\w+)\s*=\s*["\'](?P<val>.+?)["\']',
                sql, re.IGNORECASE
            )

        if m is not None:
            proj, table, col, val = m.group("proj", "table", "col", "val")

            # fallback 1: case-insensitive LIKE
            like_pattern = _sql_literal("%" + val.lower() + "%")
            like_sql = f"SELECT {proj} FROM {table} WHERE lower({col}) LIKE {like_pattern} COLLATE NOCASE"
            like_df = _run_sql(like_sql, conn)
            if like_df is not None:
                sql, df = like_sql, like_df

            # fallback 2: fuzzy match on column values
            if df.empty:
                # Take the single result column by position: the model may spell
                # the column with different casing than the CSV header.
                distinct = (
                    pd.read_sql_query(f"SELECT DISTINCT {col} FROM {table}", conn)
                    .iloc[:, 0]
                    .astype(str)
                    .tolist()
                )
                close = difflib.get_close_matches(val, distinct, n=1, cutoff=0.6)
                if close:
                    # The corrected value comes from the data and may contain an
                    # apostrophe, so it must be escaped before being embedded.
                    fuzzy_sql = f"SELECT {proj} FROM {table} WHERE {col} = {_sql_literal(close[0])}"
                    fuzzy_df = _run_sql(fuzzy_sql, conn)
                    if fuzzy_df is not None:
                        sql, df = fuzzy_sql, fuzzy_df

        return sql, df

    except Exception as e:
        return f"ERROR: {type(e).__name__}: {e}", pd.DataFrame()
    finally:
        # Each request builds its own in-memory database; release it.
        if conn is not None:
            conn.close()

# --- Gradio UI ---
# Build the demo page: a title, a row holding the CSV upload and the question
# box, then the Submit button, the SQL text box and the result table.
with gr.Blocks() as demo:
    gr.Markdown("## Text-to-SQL over Multiple CSVs")
    with gr.Row():
        # file_count="multiple" lets the user attach several CSVs; the handler
        # iterates over them and names the tables t1, t2, ... in that order.
        csv_inputs = gr.File(label="Upload one or more CSVs", file_types=['.csv'], file_count="multiple")
        question   = gr.Textbox(label="Question", placeholder="e.g. who is author of The Catcher in the Rye")
    submit    = gr.Button("Submit")
    sql_out   = gr.Textbox(label="Generated SQL")
    results   = gr.Dataframe(label="Query Results")
    # query_multi_csv returns (sql, dataframe), matching the two outputs in order.
    submit.click(fn=query_multi_csv, inputs=[csv_inputs, question], outputs=[sql_out, results])

# share=True publishes the app on a public gradio.live URL (see the cell
# output), so anyone holding the link can use it while the share link is live.
# NOTE(review): the launch log suggests debug=True to surface errors in Colab.
demo.launch(share=True)


Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://1a75b3d40b3027c661.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


