### Imports

In [2]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, logging

  from .autonotebook import tqdm as notebook_tqdm


### Initialize Picard model

In [3]:
pic_model_path = 'tscholak/cxmefzzi'
pic_model = AutoModelForSeq2SeqLM.from_pretrained(pic_model_path)
pic_tokenizer = AutoTokenizer.from_pretrained(pic_model_path)

: 

### Format Schema

In [21]:
import re

def parse_schema(sql_text):
    tables = {}
    # Split into individual statements by semicolon
    statements = sql_text.split(";")
    for statement in statements:
        # Check if it's a CREATE TABLE statement
        match = re.search(r'CREATE TABLE (\w+)', statement, re.IGNORECASE)
        if match:
            table_name = match.group(1)
            tables[table_name] = []
            # Extract column definitions within the parentheses
            columns_section = re.search(r'\((.*?)\)', statement, re.S)
            if columns_section:
                columns = re.findall(r'(\w+)\s+\w+', columns_section.group(1))
                tables[table_name].extend(columns)
    return tables

# Read and parse schema
with open("TestData/schema.sql", "r") as file:
    schema_sql = file.read()
schema_data = parse_schema(schema_sql)

print(schema_data)

{'Activity': ['actid', 'PRIMARY', 'activity_name'], 'Participates_in': ['stuid', 'actid', 'FOREIGN'], 'Faculty_Participates_in': ['FacID', 'actid', 'FOREIGN'], 'Student': ['StuID', 'PRIMARY', 'LName'], 'Faculty': ['FacID', 'PRIMARY', 'Lname']}


### Format Input and Question

In [22]:
def format_picard_input(question, db_id, schema):
    tables_info = []
    for table, columns in schema.items():
        table_info = f"{table} : " + " , ".join([f"{col} ( )" for col in columns])
        tables_info.append(table_info)
    formatted_schema = " | ".join(tables_info)
    return f"{question} | {db_id} | {formatted_schema}"

# Example
question = "How many students do soccer?"
db_id = "school"  # Replace with your database ID
picard_input = format_picard_input(question, db_id, schema_data)

print(picard_input)

How many students do soccer? | school | Activity : actid ( ) , PRIMARY ( ) , activity_name ( ) | Participates_in : stuid ( ) , actid ( ) , FOREIGN ( ) | Faculty_Participates_in : FacID ( ) , actid ( ) , FOREIGN ( ) | Student : StuID ( ) , PRIMARY ( ) , LName ( ) | Faculty : FacID ( ) , PRIMARY ( ) , Lname ( )


## Run models

In [23]:
# Tokenize input
pic_model_inputs = pic_tokenizer(picard_input, return_tensors="pt", truncation=True)

# Generate output
pic_outputs = pic_model.generate(**pic_model_inputs, max_length=1024)

# Decode output
pic_output_text = pic_tokenizer.batch_decode(pic_outputs, skip_special_tokens=True)

# Print the SQL query
print("Pic SQL Query:", pic_output_text)

Pic SQL Query: ['school | select count(*) from Participates_in as t1 join Activity as t2 on t2.actid = t2.actid where t2.activity_name = "soccer"']


In [10]:
import sqlite3

# Connect to the database (or create it if it doesn't exist)
conn = sqlite3.connect('TestData\\activity_1.sqlite')
cursor = conn.cursor()

query = 'select count(*) from Participates_in as t1 join Activity as t2 on t2.actid = t2.actid where t2.activity_name = "chess"'

# Execute the query
cursor.execute(query)

# Fetch the result
result = cursor.fetchone()  # Returns a tuple, e.g., (count_value,)
print(result)

# Close the connection
conn.close()

(0,)
