In [3]:
!pip install torch transformers bitsandbytes accelerate sqlparse openai

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import re

# Hugging Face identifier of the text-to-SQL checkpoint used throughout the notebook.
model_name = "defog/sqlcoder-7b-2"
# Tokenizer and causal-LM weights are both fetched for `model_name`; the
# resulting `tokenizer` and `model` globals are read later by `inference`.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # NOTE(review): trust_remote_code=True presumably lets code shipped with the
    # checkpoint run locally — confirm it is actually required for this model.
    trust_remote_code=True,
    torch_dtype=torch.float16,  # request half-precision weights
    device_map="auto",  # let the loader decide device placement
    use_cache=True,
)


Collecting bitsandbytes
  Downloading bitsandbytes-0.42.0-py3-none-any.whl.metadata (9.9 kB)
Downloading bitsandbytes-0.42.0-py3-none-any.whl (105.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.0/105.0 MB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.42.0


tokenizer_config.json:   0%|          | 0.00/1.84k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/515 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/691 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/3.59G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

# Load Dataset

In [4]:
import pandas as pd
# CSV holding the questions and their translations; the preview below shows
# the columns `question` and `translation` plus an unnamed index column.
translation_path = "/kaggle/input/spider-translation-1/questions-translation.csv"
df = pd.read_csv(translation_path)
# Display the first rows as a quick sanity check of the load.
df.head()

Unnamed: 0,question,translation
0,How many heads of the departments are older th...,كم رئيسًا للأقسام هم أكبر من 56 سنة؟
1,"List the name, born state and age of the heads...",قائمة بأسماء رؤساء الأقسام، مكان ميلادهم، وأعم...
2,"List the creation year, name and budget of eac...",قائمة بسنوات الإنشاء، وأسماء وميزانيات كل قسم.
3,What are the maximum and minimum budget of the...,ما هي أقصى وأدنى ميزانية للأقسام؟
4,What is the average number of employees of the...,ما هو المتوسط ​​لعدد الموظفين في الأقسام الذين...


In [5]:
import re
def process(record):
    """Strip leftover ``[Translation]`` / ``[Question]`` tags from a record.

    Parameters
    ----------
    record : mapping-like (e.g. a DataFrame row or a dict)
        Must expose a ``'translation'`` entry and a ``copy()`` method.

    Returns
    -------
    The record itself, untouched, when ``record['translation']`` is not a
    string (e.g. a missing value); otherwise a *copy* of the record whose
    ``'translation'`` has both tags (and any whitespace right after them)
    removed.

    The input is never modified: ``DataFrame.apply`` does not support
    functions that mutate the object they are handed, so the cleaned value is
    written to a copy instead.
    """
    text = record['translation']
    if not isinstance(text, str):
        return record

    # Same two substitutions, in the same order, as before.
    for tag_pattern in (r'\[Translation]\s*', r'\[Question]\s*'):
        text = re.sub(tag_pattern, '', text)

    cleaned = record.copy()
    cleaned['translation'] = text
    return cleaned
# Run `process` on every row (axis=1) to obtain the tag-free translations.
df_processed = df.apply(process, axis=1)
# Preview the cleaned frame.
df_processed.head()

Unnamed: 0,question,translation
0,How many heads of the departments are older th...,كم رئيسًا للأقسام هم أكبر من 56 سنة؟
1,"List the name, born state and age of the heads...",قائمة بأسماء رؤساء الأقسام، مكان ميلادهم، وأعم...
2,"List the creation year, name and budget of eac...",قائمة بسنوات الإنشاء، وأسماء وميزانيات كل قسم.
3,What are the maximum and minimum budget of the...,ما هي أقصى وأدنى ميزانية للأقسام؟
4,What is the average number of employees of the...,ما هو المتوسط ​​لعدد الموظفين في الأقسام الذين...


In [6]:
import json
# Directory that `get_conn` joins with `<db_id>/<db_id>.sqlite` to locate a database.
spider_db_path = "/kaggle/input/yale-universitys-spider-10-nlp-dataset/spider/database"
# JSON file with the Spider training examples.
spider_train_path = "/kaggle/input/yale-universitys-spider-10-nlp-dataset/spider/train_spider.json"

# Parse the whole file; as the output below shows, each entry is a dict with
# keys such as 'db_id', 'query' and 'question'.
with open(spider_train_path, "r") as f:
    data = json.load(f)
    
# Show the first example to document the record layout.
data[0]

{'db_id': 'department_management',
 'query': 'SELECT count(*) FROM head WHERE age  >  56',
 'query_toks': ['SELECT',
  'count',
  '(',
  '*',
  ')',
  'FROM',
  'head',
  'WHERE',
  'age',
  '>',
  '56'],
 'query_toks_no_value': ['select',
  'count',
  '(',
  '*',
  ')',
  'from',
  'head',
  'where',
  'age',
  '>',
  'value'],
 'question': 'How many heads of the departments are older than 56 ?',
 'question_toks': ['How',
  'many',
  'heads',
  'of',
  'the',
  'departments',
  'are',
  'older',
  'than',
  '56',
  '?'],
 'sql': {'except': None,
  'from': {'conds': [], 'table_units': [['table_unit', 1]]},
  'groupBy': [],
  'having': [],
  'intersect': None,
  'limit': None,
  'orderBy': [],
  'select': [False, [[3, [0, [0, 0, False], None]]]],
  'union': None,
  'where': [[False, 3, [0, [0, 10, False], None], 56.0, None]]}}

In [7]:
import os
import sqlite3

def get_conn(db_id: str):
    """Open a SQLite connection to the database named ``db_id``.

    The file is looked up at ``<spider_db_path>/<db_id>/<db_id>.sqlite``.
    The caller owns the returned connection and is responsible for closing it.
    """
    sqlite_file = os.path.join(spider_db_path, db_id, f"{db_id}.sqlite")
    return sqlite3.connect(sqlite_file)

def get_db_schema(conn) -> str:
    """Return the DDL of every object in the database as one string.

    Reads all rows of ``sqlite_master`` and takes the last column of each row
    (the stored SQL text). Rows whose SQL text is NULL are skipped. Every
    statement is followed by a blank line; an empty database yields ``""``.
    """
    master_rows = conn.execute("SELECT * FROM sqlite_master").fetchall()
    statements = (row[-1] for row in master_rows)
    return "".join(f"{sql}\n\n" for sql in statements if sql is not None)

def _postgres_to_sqlite(query: str) -> str:
    """Rewrite a few PostgreSQL-only constructs into SQLite-compatible ones."""
    substitutions = [
        # Word boundaries keep identifiers that merely contain "ilike" intact.
        (r'\bilike\b', 'LIKE'),
        (r'serial\s*$', 'INTEGER PRIMARY KEY AUTOINCREMENT'),
        (r'start\s+with\s+(\d+)', 'CHECK (id >= \\1)'),
    ]

    for pattern, replacement in substitutions:
        query = re.sub(pattern, replacement, query, flags=re.IGNORECASE)

    return query


def _extract_sql(decoded: str) -> str:
    """Return the text between the last ```sql fence and the next closing fence."""
    after_open_fence = decoded.split("```sql")[-1]
    # Cut at the closing fence rather than rstrip-ing backticks: rstrip works on
    # a character set, so it would also eat a trailing identifier quote and
    # would leave the fence in place whenever any text follows it.
    return after_open_fence.split("```")[0]


def inference(question: str, schema: str) -> str:
    """Generate a SQLite query answering ``question`` for the given ``schema``.

    Parameters
    ----------
    question : str
        Natural-language question to answer.
    schema : str
        Database schema text inserted into the prompt.

    Returns
    -------
    str
        The SQL emitted by the model after the prompt's ```sql fence, with a
        few PostgreSQL-specific constructs rewritten for SQLite.

    Relies on the module-level ``tokenizer`` and ``model``. Decoding is greedy
    (``do_sample=False``) and capped at 400 new tokens.
    """
    prompt = """### Examples
Some example questions and corresponding SQL queries are provided based on similar problems:
Answer the following : ماهى العنواين الالكترونية للمستخدمين؟
SELECT email FROM users;
    
### Instructions:
Your task is convert a question into a SQL query, given a sqlite database schema.
Adhere to these rules:
- **Deliberately go through the question and database schema word by word** to appropriately answer the question
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
- When creating a ratio, always cast the numerator as float
- **Use only one column per sql query**
    
### Input:
Generate a SQL query that answers the question `{question}`.
This query will run on a database whose schema is represented in this string:
{schema}
### Response:
Based on your instructions, here is the SQL query I have generated to answer the question `{question}`:
```sql
""".format(question=question, schema=schema)
    eos_token_id = tokenizer.eos_token_id
    # Follow the model's own placement instead of assuming "cuda": the model is
    # loaded with device_map="auto", so its device is not fixed by this code.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=eos_token_id,
        pad_token_id=eos_token_id,
        max_new_tokens=400,
        do_sample=False,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Release cached GPU memory between calls; skipped when no GPU is present
    # so the function does not fail on CPU-only machines.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    postgres_query = _extract_sql(outputs[0])
    return _postgres_to_sqlite(postgres_query)

# Evaluation

In [8]:
def compare_sql(gold: str, gen: str, conn, is_ordered=False):
    """Score a generated query against the gold query by executing both.

    Parameters
    ----------
    gold : str
        Reference SQL query.
    gen : str
        Generated SQL query to evaluate.
    conn :
        Open database connection both queries are executed on.
    is_ordered : bool, default False
        When True, row ``i`` of the gold result must equal row ``i`` of the
        generated result. When False, row order is ignored.

    Returns
    -------
    int or float in [0, 1]
        * ``1`` if the gold query itself fails to run (the example is not
          held against the model); a ``[Ground Fail]`` line is printed.
        * ``0`` if the generated query fails to run; ``[Gen Fail]`` is printed.
        * If the gold result is empty: ``1`` when the generated result is
          empty too, else ``0``.
        * Otherwise the fraction of gold rows matched by the generated result.

    Rows are compared as whole tuples. In unordered mode each generated row
    can satisfy at most one gold row, so duplicates must be reproduced as many
    times as they occur. Missing values (NULL/NaN) compare equal to each other.
    """
    from collections import Counter

    try:
        gold_res = pd.read_sql(gold, conn)
    except Exception:
        print("[Ground Fail]", gold)
        return 1

    try:
        gen_res = pd.read_sql(gen, conn)
    except Exception:
        print("[Gen Fail]", gen)
        return 0

    if len(gold_res) == 0:
        return 1 if len(gen_res) == 0 else 0

    def _as_rows(frame):
        # Plain tuples give true row-level equality. `row in ndarray` would be
        # True as soon as ANY single cell matches, which is not a row match.
        # NaN is mapped to None so that NULLs on both sides compare equal.
        return [
            tuple(None if pd.isna(value) else value for value in row)
            for row in frame.itertuples(index=False, name=None)
        ]

    gold_rows = _as_rows(gold_res)
    gen_rows = _as_rows(gen_res)

    if is_ordered:
        # zip stops at the shorter result; unmatched gold rows count as misses.
        matched = sum(1 for g_row, p_row in zip(gold_rows, gen_rows) if g_row == p_row)
    else:
        available = Counter(gen_rows)
        matched = 0
        for g_row in gold_rows:
            if available[g_row] > 0:
                available[g_row] -= 1
                matched += 1

    return matched / len(gold_rows)

In [17]:
import os

from openai import OpenAI

# Never commit credentials to a notebook: the key is read from the
# OPENAI_API_KEY environment variable instead (a KeyError is raised when it is
# not set). Any key that was previously stored in this cell should be revoked.
api_key = os.environ["OPENAI_API_KEY"]
client = OpenAI(api_key=api_key)

def translate(client: OpenAI, text: str) -> str:
    """Translate ``text`` from Arabic to English with a chat-completion call.

    Parameters
    ----------
    client : OpenAI
        Client used to issue the request.
    text : str
        Arabic text to translate; it is embedded verbatim in the prompt.

    Returns
    -------
    str
        The content of the first returned choice, or an empty string when the
        response carries no text content.
    """
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {
                "role": "user",
                "content": "Imagine you are a translator from Arabic to "
                + "English on the following text. Give me the translation and "
                + f"the translation only.\n> Question: {text}\n> Answer: ",
            },
        ],
    )

    # `content` can be None; normalise so the declared `str` return type holds.
    return completion.choices[0].message.content or ""


In [19]:
import sqlparse

# Number of leading Spider examples to evaluate.
NUM_EVAL = 100
total_accuracy = 0
num_evaluated = 0
for i, h in enumerate(data):
    if i >= NUM_EVAL:
        break

    db_id = h['db_id']
    conn = get_conn(db_id)
    try:
        schema = get_db_schema(conn)
        # Use the cleaned translations (tags stripped by `process`), not the raw frame.
        arabic_question = df_processed.iloc[i]['translation']
        english_translation = translate(client, arabic_question)
        gen = inference(english_translation, schema)
        gold = h['query']
        sub_accuracy = compare_sql(gold, gen, conn)
    finally:
        # One connection is opened per example; always release it.
        conn.close()

    total_accuracy += sub_accuracy
    num_evaluated += 1
    # The sqlparse option is spelled `reindent`; unknown options are ignored,
    # so the previous misspelling left the SQL unformatted.
    format_gold = sqlparse.format(gold, reindent=True)
    format_gen = sqlparse.format(gen, reindent=True)
    print(f"[{i}] Accuracy: {sub_accuracy * 100}%\n> Question: {english_translation}\n> Ground:\n{format_gold}\n\n> Gen:{format_gen}\n")

# Average over the examples actually evaluated (may be fewer than NUM_EVAL).
total_accuracy /= max(num_evaluated, 1)
print(f"Execution accuracy: {total_accuracy * 100}%")
    
        

2024-02-08 01:27:03.253275: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-08 01:27:03.253366: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-08 01:27:03.406718: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


[0] Accuracy: 100.0%
> Question: How many department heads are older than 56 years old?
> Ground:
SELECT count(*) FROM head WHERE age  >  56

> Gen:
 SELECT COUNT(*) FROM head h WHERE h.age > 56;

[1] Accuracy: 40.0%
> Question: List of department heads' names, place of birth, and ages arranged by age.
> Ground:
SELECT name ,  born_state ,  age FROM head ORDER BY age

> Gen:
 SELECT h.name, h.born_state, CAST(h.age AS FLOAT) AS age FROM head h JOIN management m ON h.head_ID = m.head_ID ORDER BY age NULLS LAST;

[2] Accuracy: 100.0%
> Question: List of years of establishment, names, and budgets of each department.
> Ground:
SELECT creation ,  name ,  budget_in_billions FROM department

> Gen:
 SELECT d.creation, d.name, d.budget_in_billions FROM department d;

[3] Accuracy: 100.0%
> Question: What is the maximum and minimum budget for departments?
> Ground:
SELECT max(budget_in_billions) ,  min(budget_in_billions) FROM department

> Gen:
 SELECT MAX(d.Budget_in_Billions) AS max_budget, 