In [19]:
# create_db_sqlite.py
import os
from dotenv import load_dotenv
import pandas as pd
from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String, DateTime, ForeignKey, Numeric, func
from sqlalchemy import text
load_dotenv()
DB_URL = os.getenv("DB_URL", "sqlite:///./electronic_inc.db")

engine = create_engine(DB_URL, future=True)  
metadata = MetaData()

customers = Table(
    "customers", metadata,
    Column("id", Integer, primary_key=True),
    Column("name", String(100), nullable=False),
    Column("country", String(2), nullable=False),
    Column("age", Integer, nullable=True)
)

staff = Table(
    "staff", metadata,
    Column("id", Integer, primary_key=True),
    Column("name", String(100), nullable=False),
    Column("role", String(50), nullable=False),
    Column("department", String(50), nullable=False),
    Column("email", String(100), nullable=True, unique=True),
    Column("hire_date", DateTime, server_default=func.now())
)

# Create tables
metadata.create_all(engine)

# Seed some rows (id autoincrements on both SQLite/Postgres)
with engine.begin() as conn:
    conn.execute(customers.insert(), [
        {"name": "Alice", "country": "AT", "age": 30},
        {"name": "Bob",   "country": "DE", "age": 25},
        {"name": "Chun",  "country": "US", "age": 35},
    ])
    conn.execute(staff.insert(), [
        {"name": "Eve", "role": "Manager", "department": "Sales", "email": "eve@electronic.inc"},
        {"name": "Mallory", "role": "Technician", "department": "Support", "email": "mallory@electronic.inc"},
        {"name": "Trent", "role": "Analyst", "department": "Finance", "email": "trent@electronic.inc"},
        {"name": "Peggy", "role": "Clerk", "department": "HR", "email": "peggy@electronic.inc"},
        {"name": "Victor", "role": "Engineer", "department": "Development", "email": "victor@electronic.inc"},
    ])

In [2]:
import os
from dotenv import load_dotenv
import pandas as pd
from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String, DateTime, ForeignKey, Numeric, func
from sqlalchemy import text

In [3]:
# Connect to the database
load_dotenv()
DB_URL = os.getenv("DB_URL", "sqlite:///./electronic_inc.db")

read_only_engine = create_engine(DB_URL + "?mode=ro", future=True)


In [5]:
def get_db_schema(engine):
        """
        Returns a compact, LLM-friendly schema description:
        {
          "tables": [
             {"name": "staff",
              "columns": [{"name":"id","type":"INTEGER","nullable":False}, ...],
              "row_count": 5,
              "samples": [{"id":1,"name":"Eve",...}, ...]
             },
             ...
          ]
        }
        """
        
        tables = []
        with engine.connect() as conn:
            result = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table';"))
            table_names = [row[0] for row in result.fetchall()]
            print(table_names)
            for table_name in table_names:
                columns = []
                result = conn.execute(text(f"PRAGMA table_info({table_name});"))
                for col in result.fetchall():
                    columns.append({
                        "name": col[1],
                        "type": col[2],
                        "nullable": not bool(col[3])
                    })

                table_info = {"name": table_name, "columns": columns}

                tables.append(table_info)
                    

        return {"tables": tables}

In [6]:
import json

schema = get_db_schema(read_only_engine)
print(schema)
j = json.dumps(get_db_schema(read_only_engine), indent=2)
print(j)

['customers', 'staff']
{'tables': [{'name': 'customers', 'columns': [{'name': 'id', 'type': 'INTEGER', 'nullable': False}, {'name': 'name', 'type': 'VARCHAR(100)', 'nullable': False}, {'name': 'country', 'type': 'VARCHAR(2)', 'nullable': False}, {'name': 'age', 'type': 'INTEGER', 'nullable': True}]}, {'name': 'staff', 'columns': [{'name': 'id', 'type': 'INTEGER', 'nullable': False}, {'name': 'name', 'type': 'VARCHAR(100)', 'nullable': False}, {'name': 'role', 'type': 'VARCHAR(50)', 'nullable': False}, {'name': 'department', 'type': 'VARCHAR(50)', 'nullable': False}, {'name': 'email', 'type': 'VARCHAR(100)', 'nullable': True}, {'name': 'hire_date', 'type': 'DATETIME', 'nullable': True}]}]}
['customers', 'staff']
{
  "tables": [
    {
      "name": "customers",
      "columns": [
        {
          "name": "id",
          "type": "INTEGER",
          "nullable": false
        },
        {
          "name": "name",
          "type": "VARCHAR(100)",
          "nullable": false
        },


In [9]:
def execute_query(query: str, engine=read_only_engine) -> pd.DataFrame:
    with engine.connect() as conn:
        result = conn.execute(text(query))
        if result.returns_rows:  # Check if the query returns rows
            df = pd.DataFrame(result.fetchall(), columns=result.keys())
            return df
        else:
            return None 

In [10]:
s = execute_query("SELECT * FROM staff")
print(s)

   id     name        role   department                   email  \
0   1      Eve     Manager        Sales      eve@electronic.inc   
1   2  Mallory  Technician      Support  mallory@electronic.inc   
2   3    Trent     Analyst      Finance    trent@electronic.inc   
3   4    Peggy       Clerk           HR    peggy@electronic.inc   
4   5   Victor    Engineer  Development   victor@electronic.inc   

             hire_date  
0  2025-08-25 13:20:03  
1  2025-08-25 13:20:03  
2  2025-08-25 13:20:03  
3  2025-08-25 13:20:03  
4  2025-08-25 13:20:03  


In [19]:
s = execute_query("SELECT * FROM staff WHERE department = 'Support'")
print(s)

   id     name        role department                   email  \
0   2  Mallory  Technician    Support  mallory@electronic.inc   

             hire_date  
0  2025-08-25 13:20:03  


In [29]:
test = execute_query("Delete FROM staff WHERE department = 'Sales'")

In [30]:
from sqlalchemy import inspect

SYSTEM_SCHEMAS = {"pg_catalog", "information_schema", "sqlite_master", "mysql", "performance_schema", "sys"}

def get_db_schema(engine, sample_rows: int = 3, max_tables: int = 50):
    insp = inspect(engine)
    try:
        schemas = [s for s in insp.get_schema_names() if s not in SYSTEM_SCHEMAS]
        if not schemas:  # SQLite hat meist keine Schemas
            schemas = [None]
    except Exception:
        schemas = [None]

    tables_meta = []
    with engine.connect() as conn:
        count = 0
        for schema in schemas:
            try:
                tbls = insp.get_table_names(schema=schema)
            except Exception:
                tbls = []
            for t in tbls:
                if count >= max_tables:
                    break
                # Spalten
                cols_meta = []
                try:
                    for c in insp.get_columns(t, schema=schema):
                        cols_meta.append({
                            "name": c.get("name"),
                            "type": str(c.get("type")),
                            "nullable": c.get("nullable", True),
                        })
                except Exception:
                    pass
                # FKs
                fks = []
                try:
                    for fk in insp.get_foreign_keys(t, schema=schema):
                        fks.append({"constrained_columns": fk.get("constrained_columns", []),
                                    "referred_table": fk.get("referred_table"),
                                    "referred_schema": fk.get("referred_schema")})
                except Exception:
                    pass
                # Samples
                fq = f'"{schema}".\"{t}\"' if schema else f'"{t}"'
                samples = []
                try:
                    res = conn.execute(text(f"SELECT * FROM {fq} LIMIT {sample_rows}"))
                    rows = res.fetchall()
                    cols = res.keys()
                    for r in rows:
                        samples.append(dict(zip(cols, r)))
                except Exception:
                    pass

                tables_meta.append({
                    "table name": t,
                    "columns": cols_meta,
                    "foreign_keys": fks,
                    "samples": samples,
                })
                count += 1
    return {"tables": tables_meta}


In [31]:
schema = get_db_schema(read_only_engine)
j = json.dumps(schema, indent=2)
print(j)

{
  "tables": [
    {
      "table name": "customers",
      "columns": [
        {
          "name": "id",
          "type": "INTEGER",
          "nullable": false
        },
        {
          "name": "name",
          "type": "VARCHAR(100)",
          "nullable": false
        },
        {
          "name": "country",
          "type": "VARCHAR(2)",
          "nullable": false
        },
        {
          "name": "age",
          "type": "INTEGER",
          "nullable": true
        }
      ],
      "foreign_keys": [],
      "samples": [
        {
          "id": 1,
          "name": "Alice",
          "country": "AT",
          "age": 30
        },
        {
          "id": 2,
          "name": "Bob",
          "country": "DE",
          "age": 25
        },
        {
          "id": 3,
          "name": "Chun",
          "country": "US",
          "age": 35
        }
      ]
    },
    {
      "table name": "staff",
      "columns": [
        {
          "name": "id",
         

In [40]:
import re
from collections import Counter

def pick_candidate_tables(question: str, schema_json: dict, topk: int = 3):
    q_tokens = set(re.findall(r"[A-Za-z_]+", question.lower()))
    scored = []
    for t in schema_json["tables"]:
        t_name = f"{t['table_name']}".strip(".").lower()
        t_tokens = set(re.findall(r"[A-Za-z_]+", t_name))
        col_tokens = set()
        for c in t.get("columns", []):
            col_tokens |= set(re.findall(r"[A-Za-z_]+", (c.get("name","") + " " + str(c.get("type",""))).lower()))
        score = len(q_tokens & (t_tokens | col_tokens))
        scored.append((score, t))
    scored.sort(key=lambda x: x[0], reverse=True)
    return [t for _, t in scored[:topk] if _ > 0]

In [41]:
r = pick_candidate_tables("List all staff in the Sales department", schema, topk=2)
print(json.dumps(r, indent=2))

[
  {
    "schema": "",
    "table_name": "staff",
    "columns": [
      {
        "name": "id",
        "type": "INTEGER",
        "nullable": false
      },
      {
        "name": "name",
        "type": "VARCHAR(100)",
        "nullable": false
      },
      {
        "name": "role",
        "type": "VARCHAR(50)",
        "nullable": false
      },
      {
        "name": "department",
        "type": "VARCHAR(50)",
        "nullable": false
      },
      {
        "name": "email",
        "type": "VARCHAR(100)",
        "nullable": true
      },
      {
        "name": "hire_date",
        "type": "DATETIME",
        "nullable": true
      }
    ],
    "foreign_keys": [],
    "samples": [
      {
        "id": 1,
        "name": "Eve",
        "role": "Manager",
        "department": "Sales",
        "email": "eve@electronic.inc",
        "hire_date": "2025-08-25 13:20:03"
      },
      {
        "id": 2,
        "name": "Mallory",
        "role": "Technician",
        "depart

In [34]:
def build_sql_prompt(question: str, schema_json: dict):
    candidates = pick_candidate_tables(question, schema_json, topk=4)
    # Fallback: nimm alles, wenn kein Kandidat scored
    if not candidates:
        candidates = schema_json["tables"]

    def fmt_table(t):
        fq = f"{t['table name']}"
        cols = ", ".join(c['name'] for c in t.get("columns", []))
        sample_lines = []
        for s in t.get("samples", [])[:2]:
            sample_lines.append(json.dumps(s, ensure_ascii=False))
        return (
                f"- table name: {fq}\n"
                f"  columns: {cols}\n"
                f"  samples:\n    " + ("\n    ".join(sample_lines) if sample_lines else "(none)")
        )

    schema_block = "\n".join(fmt_table(t) for t in candidates)

    return f"""
You are an expert SQL generator. Generate a single **valid** SELECT statement for SQLite/SQLAlchemy.

Rules:
- Use only the tables and columns listed in SCHEMA below.
- Quote identifiers with double quotes if they contain capitals or special chars.
- Prefer joins by foreign keys if available.
- If aggregation is implied (e.g., totals, averages), include GROUP BY and aliases.
- Return ONLY the SQL; no prose.

QUESTION:
{question}

SCHEMA:
{schema_block}

Output ONLY the SQL:
""".strip()


In [39]:
prompt = build_sql_prompt("List all customers from EU", schema)
print(prompt)

KeyError: 'table name'

In [36]:
import json
import re
from sqlalchemy import text, inspect

SYSTEM_SCHEMAS = {
    "pg_catalog", "information_schema", "sqlite_master",
    "mysql", "performance_schema", "sys", "pg_toast"
}

def _normalize_schema(schema: str | None, dialect: str) -> str:
    """Return '' for default schema; hide SQLite's main/temp."""
    if not schema:
        return ""
    s = schema.strip()
    if dialect == "sqlite" and s.lower() in {"main", "temp"}:
        return ""
    return s

def _fq_name(schema: str, table: str) -> str:
    """Quote each identifier part separately."""
    return f'"{schema}"."{table}"' if schema else f'"{table}"'

def get_db_schema(engine, sample_rows: int = 3, max_tables: int = 50):
    insp = inspect(engine)
    dialect = engine.dialect.name  # 'sqlite', 'postgresql', 'mysql', ...
    # --- discover schemas ---
    try:
        if dialect == "sqlite":
            raw_schemas = [None]  # hide main/temp
        else:
            raw_schemas = [
                s for s in insp.get_schema_names()
                if s and s not in SYSTEM_SCHEMAS and not s.startswith("_")
            ] or [None]
    except Exception as e:
        print(f"[get_db_schema] get_schema_names failed: {e}")
        raw_schemas = [None]

    tables_meta = []
    count = 0

    with engine.connect() as conn:
        for raw_schema in raw_schemas:
            schema = _normalize_schema(raw_schema, dialect)

            # tables + views
            try:
                tbls = insp.get_table_names(schema=raw_schema) or []
            except Exception as e:
                print(f"[get_db_schema] get_table_names({raw_schema}) failed: {e}")
                tbls = []
            try:
                views = insp.get_view_names(schema=raw_schema) or []
            except Exception as e:
                print(f"[get_db_schema] get_view_names({raw_schema}) failed: {e}")
                views = []

            for t in (tbls + views):
                if count >= max_tables:
                    break

                # columns
                cols_meta = []
                try:
                    for c in insp.get_columns(t, schema=raw_schema) or []:
                        cols_meta.append({
                            "name": c.get("name"),
                            "type": str(c.get("type")),
                            "nullable": c.get("nullable", True),
                        })
                except Exception as e:
                    print(f"[get_db_schema] get_columns({t},{raw_schema}) failed: {e}")

                # FKs
                fks = []
                try:
                    for fk in insp.get_foreign_keys(t, schema=raw_schema) or []:
                        fks.append({
                            "constrained_columns": fk.get("constrained_columns", []),
                            "referred_table": fk.get("referred_table"),
                            "referred_schema": _normalize_schema(
                                fk.get("referred_schema"), dialect
                            ),
                        })
                except Exception as e:
                    print(f"[get_db_schema] get_foreign_keys({t},{raw_schema}) failed: {e}")

                # samples
                samples = []
                try:
                    fq = _fq_name(schema, t)
                    res = conn.execute(text(f"SELECT * FROM {fq} LIMIT {int(sample_rows)}"))
                    rows = res.fetchall()
                    cols = list(res.keys())
                    for r in rows:
                        samples.append(dict(zip(cols, r)))
                except Exception as e:
                    print(f"[get_db_schema] sampling {t} failed: {e}")

                tables_meta.append({
                    "schema": schema,   
                    "table_name": t,
                    "columns": cols_meta,
                    "foreign_keys": fks,
                    "samples": samples,
                })
                count += 1

            if count >= max_tables:
                break

    return {"tables": tables_meta}


In [37]:
schema = get_db_schema(read_only_engine)
j = json.dumps(schema, indent=2)

In [38]:
print(j)

{
  "tables": [
    {
      "schema": "",
      "table_name": "customers",
      "columns": [
        {
          "name": "id",
          "type": "INTEGER",
          "nullable": false
        },
        {
          "name": "name",
          "type": "VARCHAR(100)",
          "nullable": false
        },
        {
          "name": "country",
          "type": "VARCHAR(2)",
          "nullable": false
        },
        {
          "name": "age",
          "type": "INTEGER",
          "nullable": true
        }
      ],
      "foreign_keys": [],
      "samples": [
        {
          "id": 1,
          "name": "Alice",
          "country": "AT",
          "age": 30
        },
        {
          "id": 2,
          "name": "Bob",
          "country": "DE",
          "age": 25
        },
        {
          "id": 3,
          "name": "Chun",
          "country": "US",
          "age": 35
        }
      ]
    },
    {
      "schema": "",
      "table_name": "staff",
      "columns": [
   