# Text-to-SQL Assistance With Claude and CodeT5

This notebook serves as a proof of concept to use a private LLM API and a finetuned transformer model in conjunction to answer questions about databases.

In [1]:
# include your anthropic api key
anthropic_api_key = ""

In [2]:
import sqlite3
import anthropic
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# set up the anthropic api client
CLIENT = anthropic.Anthropic(
    api_key=anthropic_api_key,  # Your Anthropic API key
    
    # Optional 
    base_url="https://api.anthropic.com",  # Default API endpoint
    timeout=60,  # Request timeout in seconds
    max_retries=2,  # Number of times to retry failed requests
    default_headers=None,  # Additional headers to include in requests
)
CLAUDE = "claude-3-7-sonnet-20250219"

PATH = 'finetuned/codet5-base-wikisql'
CUDA = torch.cuda.is_available()

device = torch.device("cuda" if CUDA else "cpu")
tokenizer = AutoTokenizer.from_pretrained(PATH)
model = AutoModelForSeq2SeqLM.from_pretrained(PATH)
model.to(device)

  from .autonotebook import tqdm as notebook_tqdm


T5ForConditionalGeneration(
  (shared): Embedding(32100, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32100, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dro

## Query a Database

In [3]:
# function to interact with anthropic api
def ask_claude(question, schema):
    prompt = f"""Here is the schema for a database:
{schema}
Given this schema, can you output a SQL query to answer the following question? 
Only output the SQL query, use double quotes instead of single quotes in the query, and no markdown formatting or newline characters.
Question: {question}
"""
    try:
        response = CLIENT.messages.create(
            model=CLAUDE,
            max_tokens=512,
            messages=[{
                "role": 'user', "content":  prompt
            }]
        )
        return response.content[0].text
    except Exception as e:
        return f"Error querying API: {e}"

# function to generate from finetuned CodeT5-base
def ask_codet5(question, schema):
    prompt = 'schema: \n' + str(schema)[:420] + '\n\ntranslate to SQL: ' + str(question)
    inputs = tokenizer(prompt, truncation=True, padding='max_length', max_length=256, return_tensors='pt').to(device)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    output = model.generate(input_ids, attention_mask=attention_mask, max_length=256)[0]
    prediction = tokenizer.decode(output, skip_special_tokens=True)
            
    return prediction
    
# prompt database with query
def execute_query(path, query):
    """ execute an SQL query given a path to sqlite db"""
    conn = sqlite3.connect(path)
    cursor = conn.cursor()
    
    try:
        cursor.execute("PRAGMA case_sensitive_like = OFF;") # case requirements are annoying
        cursor.execute(query)
        result = cursor.fetchall()[:]
    except sqlite3.Error as e:
        result = f"SQL Error: {e}"
    
    conn.close()
    return result

def get_table_schema(path, table):
    conn = sqlite3.connect(path)
    cursor = conn.cursor()
    
    try:    
        schema = cursor.execute(f"PRAGMA table_info({table})").fetchall()
        result = f"CREATE TABLE {table} (\n" + "\n".join([f"{col[1]} {col[2]}" for col in schema]) + "\n)"
    except sqlite3.Error as e:
        result = f"SQL Error: {e}"
    
    conn.close()
    return result

def get_database_schema(path):
    conn = sqlite3.connect(path)
    cursor = conn.cursor()

    try:
        cursor.execute("SELECT sql FROM sqlite_master WHERE type='table';")
        result = cursor.fetchall()
    except sqlite3.Error as e:
        result = f"SQL Error: {e}"

    conn.close()
    return result

# acquire a table's schema, prompt model, execute selection
def ask(question, table, db_path):
    table_schema = get_table_schema(db_path, table)
    
    codet5_prediction = ask_codet5(question, table_schema)
    claude_prediction = ask_claude(question, table_schema)
    
    print(question)
    print(f'Schema: \n{table_schema}\n')
    print(f'CodeT5 says: \n{codet5_prediction}')
    print(f'Claude says: \n{claude_prediction}')
    print('/////////////////////////////////////////////\n')

    print("Pick a number: \n1) Try what CodeT5 said.\n2) Use Claude's answer.\n3) I'll give my own SQL query!")
    selection = input()
    while selection not in ['1','2','3']:
        selection = input("I didn't catch that.\nPlease try again: ")

    if selection == '1':
        choice = codet5_prediction
    elif selection == '2':
        choice = claude_prediction
    else: 
        choice = input("Okay, let's have your query: \n")
    
    return execute_query(db_path, choice)

def run():
    # example natural language question
    question = "What is the average UnitPrice in invoice_items?"
    table = "invoice_items"
    
    db_path = 'sqlite/chinook.db'

    result = ask(question, table, db_path)

    print(f"Here's what we got back:\n{result[:500]}...")



In [4]:
#run the thing
run()

What is the average UnitPrice in invoice_items?
Schema: 
CREATE TABLE invoice_items (
InvoiceLineId INTEGER
InvoiceId INTEGER
TrackId INTEGER
UnitPrice NUMERIC(10,2)
Quantity INTEGER
)

CodeT5 says: 
SELECT AVG UnitPrice in invoice_items FROM table WHERE IDATAInvoiceLineId =
InvoiceId AND IDATAInvoiceTrackId =
InvoiceId AND IDATAQuantity =

Claude says: 
SELECT AVG(UnitPrice) FROM invoice_items
/////////////////////////////////////////////

Pick a number: 
1) Try what CodeT5 said.
2) Use Claude's answer.
3) I'll give my own SQL query!


 2


Here's what we got back:
[(1.0395535714285522,)]...
