# Fine-tuning an LLM: Text to SQL

Generating SQL queries from natural language descriptions can significantly streamline data retrieval and improve efficiency. This task involves training a model to interpret text descriptions and convert them into accurate SQL queries that can be executed against a database to fetch the desired information.

In this notebook, we use the `gretelai/synthetic_text_to_sql` dataset to train our model. It contains pairs of natural-language descriptions and the corresponding SQL queries. We also give the model the SQL context (table definitions and sample rows), which helps it generate accurate, context-aware queries.
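Each training prompt puts the SQL context ahead of the task description. The sketch below captures that layout as a reusable function. It mirrors the f-string used later when preparing the seed data. `build_prompt` is a hypothetical helper name, not part of any library. Reusing one function at inference time is simply a way to keep prompts consistent with the training format.

```python
def build_prompt(sql_context: str, sql_prompt: str) -> str:
    """Assemble a prompt: SQL context first, then the query-generation task.

    Hypothetical helper; the layout matches the seed-data f-string.
    """
    return (
        f"## sql context :\n{sql_context}\n\n"
        f"## Query generation task:\n{sql_prompt}\n\n"
    )


prompt = build_prompt(
    "CREATE TABLE Clients (ClientID INT, State VARCHAR(20));",
    "How many clients live in Texas?",
)
print(prompt)
```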

Let's begin by exploring the dataset and preparing it for training.

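Sample dataset format (a JSON list of objects with `query` and `response` fields):
```json
[
    {
        "query": "<query text>",
        "response": "<expected response>"
    },
    {
        "query": "...",
        "response": "..."
    }
]
```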
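Before you upload a seed file, you can check it against this format. Below is a minimal sketch in plain Python. `validate_seed_records` is a hypothetical helper, not part of the Leeroo client. It only checks that each record is an object with non-empty string `query` and `response` fields.

```python
import json


def validate_seed_records(records):
    """Return a list of problems found in a list of seed records."""
    if not isinstance(records, list):
        return ["top-level JSON value must be a list"]
    problems = []
    for i, record in enumerate(records):
        if not isinstance(record, dict):
            problems.append(f"record {i}: expected an object, got {type(record).__name__}")
            continue
        # Both fields must be present, be strings, and not be blank.
        for field in ("query", "response"):
            value = record.get(field)
            if not isinstance(value, str) or not value.strip():
                problems.append(f"record {i}: missing or empty '{field}'")
    return problems


# Round-trip a tiny stand-in dataset through JSON, then validate it.
sample = json.loads(json.dumps([
    {"query": "## sql context :\nCREATE TABLE t (a INT);\n\n## Query generation task:\nCount the rows\n\n",
     "response": "SELECT COUNT(*) FROM t;"},
    {"query": "", "response": "SELECT 1;"},
]))
print(validate_seed_records(sample))  # ["record 1: missing or empty 'query'"]
```

In the notebook, you could run `validate_seed_records(json.load(f))` on `texttosql_seed_data.json` once it has been written.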

In [1]:
# prepare seed examples in required format
import json
import datasets
from tqdm import tqdm
from pprint import pprint

dataset = datasets.load_dataset("gretelai/synthetic_text_to_sql")
n_seed_samples = 1000
data = []

for d in tqdm(dataset['train']):
    data.append(
        dict(
            query = f"## sql context :\n{d['sql_context']}\n\n## Query generation task:\n{d['sql_prompt']}\n\n",
            response = d['sql']
        )
    )
    if len(data) == n_seed_samples:
        break

with open('texttosql_seed_data.json', 'w') as f:
    json.dump(data, f)
print(len(data))
pprint(data[-1])

  1%|          | 999/100000 [00:00<00:10, 9256.09it/s]

1000
{'query': '## sql context :\n'
          'CREATE TABLE ClientInvestments (ClientID INT, InvestmentType '
          'VARCHAR(20), Value FLOAT); INSERT INTO ClientInvestments (ClientID, '
          "InvestmentType, Value) VALUES (1, 'Stock', 10000), (1, 'Bond', "
          "20000), (2, 'Stock', 30000), (2, 'Bond', 15000), (3, 'Stock', "
          "5000), (3, 'Bond', 25000), (4, 'Stock', 40000), (4, 'Bond', 30000), "
          "(5, 'Stock', 7000), (5, 'Bond', 18000); CREATE TABLE Clients "
          '(ClientID INT, State VARCHAR(20)); INSERT INTO Clients (ClientID, '
          "State) VALUES (1, 'NY'), (2, 'TX'), (3, 'CA'), (4, 'NY'), (5, "
          "'TX');\n"
          '\n'
          '## Query generation task:\n'
          'What is the total value of investments in bonds for clients '
          'residing in Texas?\n'
          '\n',
 'response': 'SELECT SUM(Value) FROM ClientInvestments CI JOIN Clients C ON '
             "CI.ClientID = C.ClientID WHERE C.State = 'TX' AND Investmen




Create a client with your own Leeroo API key:

In [2]:
from dager.clients.client import LeerooClient

# Replace the placeholder with your own Leeroo API key.
leeroo_api_key = "<api-key-here>"
client = LeerooClient(leeroo_api_key)

User: arshad Logged in!


To design the experiments, you need to provide:

- `evaluation_description` (optional): a short natural-language summary of your application and the factors that matter most when evaluating it.
- `workflow_name`: the name of this experiment. It is saved along with the workflow ID.
- `seed_data_path` (optional): your dataset for the target application, in JSON format with `query` and `response` fields.
- `budget`: the budget for the workflow, in days.

In [3]:
evaluation_policy = \
"""
Extract SQL Context:
Review the SQL context given in the input, including table definitions and any sample data inserted into these tables.

Formulate Expected Query:
Based on the task description, determine the logical structure and components of the SQL query that should be generated. For instance, identify the relevant tables, columns, and conditions that should be included in the query.

Check Query Components:
Ensure the generated query includes the correct tables and columns specified in the SQL context.
Verify that the conditions and clauses in the query match the task description. For example, checking for conditions like InvestmentType = 'Bond' and State = 'TX'.

Syntax Validation:
Confirm that the generated query is syntactically correct according to SQL standards. It should be executable without syntax errors.

Logical Accuracy:
Ensure the logic of the query aligns with the task description. For instance, it should correctly aggregate the values as required by the task.

The output should be only the SQL query, with no explanation.
"""

In [9]:
workflow_configs = client.initialize_workflow_configs(
    evaluation_description=evaluation_policy,
    workflow_name="TextToSqlCheckSensitivityQATask",
    seed_data_path="texttosql_seed_data.json",
    budget=2,  # days
)
workflow_configs

<Response [200]>


{'data_gen_config': {'task_description': "\nExtract SQL Context:\nReview the SQL context given in the input, including table definitions and any sample data inserted into these tables.\n\nFormulate Expected Query:\nBased on the task description, determine the logical structure and components of the SQL query that should be generated. For instance, identify the relevant tables, columns, and conditions that should be included in the query.\n\nCheck Query Components:\nEnsure the generated query includes the correct tables and columns specified in the SQL context.\nVerify that the conditions and clauses in the query match the task description. For example, checking for conditions like InvestmentType = 'Bond' and State = 'TX'.\n\nSyntax Validation:\nConfirm that the generated query is syntactically correct according to SQL standards. It should be executable without syntax errors.\nLogical Accuracy:\n\nEnsure the logic of the query aligns with the task description. For instance, it should co
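The policy's "Syntax Validation" step asks that the generated query be executable. You can approximate this locally by replaying the SQL context in an in-memory SQLite database and then running the query there. This is an assumption on our part, not how the Leeroo evaluation works. `check_executable` is a hypothetical helper. SQLite's dialect also differs from other engines, so treat a failure here as a hint rather than proof of invalid SQL.

```python
import sqlite3


def check_executable(sql_context: str, query: str):
    """Replay sql_context in a fresh in-memory SQLite DB, then run query.

    Returns (True, rows) on success or (False, error_message) on failure.
    Only SQLite's dialect is checked.
    """
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(sql_context)  # CREATE TABLE / INSERT statements
        rows = conn.execute(query).fetchall()
        return True, rows
    except (sqlite3.Error, sqlite3.Warning) as exc:
        # sqlite3.Warning covers "one statement at a time" on older Pythons.
        return False, str(exc)
    finally:
        conn.close()


context = (
    "CREATE TABLE Clients (ClientID INT, State VARCHAR(20)); "
    "INSERT INTO Clients (ClientID, State) VALUES (1, 'NY'), (2, 'TX'), (5, 'TX');"
)
print(check_executable(context, "SELECT COUNT(*) FROM Clients WHERE State = 'TX';"))  # (True, [(2,)])
print(check_executable(context, "SELEC COUNT(*) FROM Clients;"))  # (False, <syntax error message>)
```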

🚀 Once you're happy with the hyperparameters, you can submit the training workflow. It will **automatically execute the experiments, evaluate them, and pick the best model** based on your customized evaluation system!

In [10]:
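# Optional: override hyperparameters before submitting, e.g. train experiment '0' for a single epoch
# workflow_configs['experiment_config']['0']['training_args']['num_train_epochs'] = 1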
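To apply the same override to every experiment rather than only experiment `'0'`, you can loop over the configs. This is a minimal sketch. It assumes every entry under `workflow_configs['experiment_config']` has a `training_args` dict, as experiment `'0'` does above. `set_training_arg` is a hypothetical helper, and the stand-in dict below is illustrative, not real output.

```python
def set_training_arg(workflow_configs: dict, name: str, value) -> int:
    """Set training_args[name] = value in every experiment config.

    Assumes workflow_configs["experiment_config"] maps experiment keys
    ("0", "1", ...) to dicts containing a "training_args" dict.
    Returns the number of experiments updated.
    """
    updated = 0
    for experiment in workflow_configs.get("experiment_config", {}).values():
        training_args = experiment.get("training_args")
        if isinstance(training_args, dict):
            training_args[name] = value
            updated += 1
    return updated


# Stand-in dict; in the notebook, pass the real workflow_configs instead.
configs = {"experiment_config": {"0": {"training_args": {"num_train_epochs": 3}},
                                 "1": {"training_args": {"num_train_epochs": 3}}}}
print(set_training_arg(configs, "num_train_epochs", 1))  # 2
```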

In [11]:
# Submit workflow for execution
running_workflow_status = client.submit_workflow(
    workflow_configs=workflow_configs
)
print(" Workflow running state:", running_workflow_status)

<Response [200]>
 Workflow running state: {'workflow_runnning_state_id': '1721586842'}


To get the status of all your workflows, call `client.all_workflows()`. The returned dictionary contains:

- `running_workflows`: the training workflows whose status is `running`.
- `finished_workflows`: the workflows that have finished executing.

In [12]:
# Retrieve user's workflows
user_workflows = client.all_workflows()

print(f"Total finished workflows : {len(user_workflows['finished_workflows'])}")
print(f"Total running workflows : {len(user_workflows['running_workflows'])}")

user_workflows['running_workflows']

<Response [200]>
Total finished workflows : 8
Total running workflows : 1


[{'user_id': 'arshad',
  'workflow_runnning_state_id': '1721586842',
  'workflow_name': 'TextToSqlCheckSensitivityQATask',
  'workflow_start_timestamp': 1721586842.828452,
  'status': 'running'}]
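Each entry stores its start time as a Unix `workflow_start_timestamp`. The sketch below prints entries in a more readable form. `summarize_workflows` is hypothetical. It assumes entries under `finished_workflows` share the keys of the running entry shown above.

```python
from datetime import datetime, timezone


def summarize_workflows(workflows):
    """Return one readable line per workflow entry, converting the Unix
    start timestamp to a UTC ISO-8601 string."""
    lines = []
    for wf in workflows:
        started = datetime.fromtimestamp(wf["workflow_start_timestamp"], tz=timezone.utc)
        lines.append(
            f"{wf['workflow_name']} (id {wf['workflow_runnning_state_id']}): "
            f"{wf.get('status')}, started {started.isoformat(timespec='seconds')}"
        )
    return lines


# Entry copied from the 'running_workflows' output above.
running = [{'user_id': 'arshad',
            'workflow_runnning_state_id': '1721586842',
            'workflow_name': 'TextToSqlCheckSensitivityQATask',
            'workflow_start_timestamp': 1721586842.828452,
            'status': 'running'}]
for line in summarize_workflows(running):
    print(line)
# TextToSqlCheckSensitivityQATask (id 1721586842): running, started 2024-07-21T18:34:02+00:00
```

In the notebook, you would call `summarize_workflows(user_workflows['running_workflows'])`.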

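For details on a specific workflow, pass its ID to `client.get_workflow_status()`. The response includes:

- `status`: the overall status of the workflow
- `workflow_node_status`: the status of each node (data generation, data preparation, training, evaluation-response generation, evaluation, and best-model selection)
- `workflow_name`: the name of your workflow
- `workflow_runnning_state_id`: the ID of your workflow (the key is spelled this way in the API response)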

In [25]:
# Check status of the running workflow
workflow_status = client.get_workflow_status('1721586842')
workflow_status

<Response [200]>


{'user_id': 'arshad',
 'workflow_runnning_state_id': '1721586842',
 'workflow_name': 'TextToSqlCheckSensitivityQATask',
 'workflow_start_timestamp': 1721586842.828452,
 'status': True,
 'workflow_node_status': {'DataGenConfig-172158684276520kyab': 'Executed',
  'DataPrepConfig-172158684276524frjf': 'Executed',
  'SFTrainingConfig-172158684276533sjlf': 'Executed',
  'EvalResponseGenConfig-172158684276536gikb': 'Executed',
  'EvalConfig-172158684276545vwmm': 'Executed',
  'PickBestConfig-172158684276526yrmk': 'Executed'},
 'workflow_completed_timestamp': 1721590603.874977}
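You can poll the status in a loop instead of re-running the cell by hand. Below is a minimal sketch. It assumes a workflow is finished once every value in `workflow_node_status` equals `'Executed'`, as in the output above. `wait_for_workflow` and the `'Running'` state in the test stub are hypothetical. The sketch does not detect failed nodes, so a failed run simply hits the timeout.

```python
import time


def wait_for_workflow(get_status, workflow_id, poll_seconds=60, timeout_seconds=3 * 24 * 3600):
    """Poll get_status(workflow_id) until all nodes report 'Executed'.

    Returns the final status dict, or raises TimeoutError after
    timeout_seconds.
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = get_status(workflow_id)
        nodes = status.get("workflow_node_status", {})
        if nodes and all(state == "Executed" for state in nodes.values()):
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"workflow {workflow_id} not finished after {timeout_seconds}s")
        time.sleep(poll_seconds)


# Usage in the notebook (not run here):
# final = wait_for_workflow(client.get_workflow_status, '1721586842')
# elapsed = final['workflow_completed_timestamp'] - final['workflow_start_timestamp']
```

For the run above, `workflow_completed_timestamp - workflow_start_timestamp` is about 3761 seconds, or roughly 63 minutes.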

In [40]:
# Deploy the workflow
workflow_id = '1721586842'
deployment_status = client.deploy_workflow(
    workflow_id
)
print(deployment_status)

<Response [200]>
{'cluster_name': 'DeploymentState-1721599750.281206', 'status': 'Deployment started'}


In [42]:
deployment_details = client.get_workflow_deployment_status('DeploymentState-1721599750.281206')
deployment_details

<Response [200]>


{'cluster_name': 'DeploymentState-1721599750.281206',
 'ip': '3.80.255.142',
 'gradio-playground': 'http://3.80.255.142:8000',
 'api-access': 'http://3.80.255.142:9000',
 'status': 'Deployed'}

In [None]:
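# Get the ID of the model served by the deployment (OpenAI-style /v1/models endpoint)
import requests

api_base = deployment_details['api-access']  # e.g. 'http://3.80.255.142:9000'
model_details = requests.get(f"{api_base}/v1/models").json()
model_id = model_details['data'][0]['id']
model_details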

In [74]:
# Inference: compare the model's output with the reference SQL for the last five seed examples
import json

with open('texttosql_seed_data.json') as f:
    sql_data = json.load(f)

url = "http://3.80.255.142:9000/v1/completions"

for d in sql_data[-5:]:
    payload = {
        "model": model_id,
        "prompt": [d['query']],
        "max_tokens": 200,
        "temperature": 0.0,  # greedy decoding
    }
    response = requests.post(url, json=payload)
    response.raise_for_status()
    print("Prompt :\n", d['query'], "\nLLM Response:")
    print(response.json()['choices'][0]['text'])
    print("\nOriginal Response:\n", d['response'])
    print("-----\n\n")

Prompt :
 ## sql context :
CREATE TABLE Funding (company_id INT, funding_year INT, amount INT); INSERT INTO Funding (company_id, funding_year, amount) VALUES (1, 2015, 3000000); INSERT INTO Funding (company_id, funding_year, amount) VALUES (2, 2017, 5000000); INSERT INTO Funding (company_id, funding_year, amount) VALUES (3, 2017, 7000000);

## Query generation task:
What is the total funding received by companies founded in 2017, ordered by the amount of funding?

 
LLM Response:
## Query:
SELECT company_id, SUM(amount) as total_funding FROM Funding WHERE funding_year = 2017 GROUP BY company_id ORDER BY total_funding DESC;

Original Response:
 SELECT company_id, SUM(amount) as total_funding FROM Funding WHERE funding_year = 2017 GROUP BY company_id ORDER BY total_funding DESC;
-----


Prompt :
 ## sql context :
CREATE TABLE GraduateStudents (StudentID INT, Name VARCHAR(50), Department VARCHAR(50), Publications INT, PublicationYear INT);

## Query generation task:
How many publications 
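The sample output shows the model prefixing its answer with `## Query:`. To score the printed pairs, strip that header and compare the two queries after normalizing cosmetic differences. This is a crude exact-match sketch, not a semantic equivalence check, and not the metric the Leeroo workflow uses. `extract_sql`, `normalize_sql` and `exact_match` are hypothetical helpers. Case is left untouched because string literals such as `'TX'` are case-sensitive.

```python
import re


def extract_sql(text: str) -> str:
    """Strip a leading markdown header like '## Query:' and surrounding whitespace."""
    text = text.strip()
    text = re.sub(r"^#+\s*Query\s*:\s*", "", text, flags=re.IGNORECASE)
    return text.strip()


def normalize_sql(sql: str) -> str:
    """Collapse runs of whitespace and drop a trailing semicolon."""
    return " ".join(sql.split()).rstrip(";").strip()


def exact_match(generated: str, reference: str) -> bool:
    return normalize_sql(extract_sql(generated)) == normalize_sql(reference)


# The first prompt/response pair from the output above.
generated = ("## Query:\nSELECT company_id, SUM(amount) as total_funding FROM Funding "
             "WHERE funding_year = 2017 GROUP BY company_id ORDER BY total_funding DESC;\n")
reference = ("SELECT company_id, SUM(amount) as total_funding FROM Funding "
             "WHERE funding_year = 2017 GROUP BY company_id ORDER BY total_funding DESC;")
print(exact_match(generated, reference))  # True
```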

In [76]:
# Shut down the deployment when you are done with it
client.kill_deployment(
    'DeploymentState-1721599750.281206'
)

<Response [200]>


{'cluster_name': 'DeploymentState-1721599750.281206',
 'status': 'Deployment Killed'}