## Zero-shot, One-shot and Few-shot Inference
This notebook shows how to add examples to the prompt.
LLM provided with examples can improve significantly results.

In [1]:
%pip install -U datasets
%pip install torch 
%pip install torchdata
%pip install transformers

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Collecting torchdata
  Downloading torchdata-0.9.0-cp312-cp312-manylinux1_x86_64.whl.metadata (5.5 kB)
Downloading torchdata-0.9.0-cp312-cp312-manylinux1_x86_64.whl (2.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: torchdata
Successfully installed torchdata-0.9.0
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
from transformers import AutoModelForSeq2SeqLM
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import GenerationConfig
import pprint
pp = pprint.PrettyPrinter(indent=4)


model_name='google/flan-t5-base'

model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

#get dataset with translating text description and schema into SQL query
dataset = load_dataset("gretelai/synthetic_text_to_sql")

#print 2 chosen records
example_indexes = [37, 67, 101, 78]
pp.pprint(dataset['test'][example_indices])

config.json:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

2024-11-10 07:30:40.881260: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1731220240.897958   23342 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1731220240.902991   23342 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-10 07:30:40.919917: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/8.18k [00:00<?, ?B/s]

(…)nthetic_text_to_sql_train.snappy.parquet:   0%|          | 0.00/32.4M [00:00<?, ?B/s]

(…)ynthetic_text_to_sql_test.snappy.parquet:   0%|          | 0.00/1.90M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/100000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5851 [00:00<?, ? examples/s]

NameError: name 'example_indices' is not defined

### Zero-shot inference

In [65]:
for i, index in enumerate(example_indexes):
    schema = dataset['test'][index]['sql_context']
    instruction = dataset['test'][index]['sql_prompt']
    sql = dataset['test'][index]['sql']

    prompt = f"""
Translate schema and description into SQL query. Respond only with SQL query without any additional characters.

Schema: {schema}

Instruction: {instruction}

SQL query:
    """
    
    inputs = tokenizer(prompt, return_tensors='pt')
    output = tokenizer.decode(
        model.generate(
            inputs["input_ids"],
            max_new_tokens=50,
        )[0],
        skip_special_tokens=True
    )
    
    print('Example ', i + 1)
    print(f'Task:\n{instruction}')
    print(f'Schema:\n{schema}')
    print(f'CORRECT SQL:\n{sql}')
    print(f'MODEL GENERATION SQL - ZERO SHOT:\n{output}\n')

Example  1
Task:
What is the number of medical supplies distributed by each organization, in East Africa, for the last 3 years, and the total cost of the supplies?
Schema:
CREATE TABLE medical_supplies (supply_id INT, organization_id INT, location VARCHAR(255), supply_type VARCHAR(255), supply_cost DECIMAL(10,2), distribution_date DATE); INSERT INTO medical_supplies VALUES (1, 1, 'Country A', 'Medicine', 5000, '2020-01-01'); INSERT INTO medical_supplies VALUES (2, 1, 'Country A', 'Medical Equipment', 7000, '2021-01-01'); INSERT INTO medical_supplies VALUES (3, 2, 'Country B', 'Vaccines', 10000, '2021-01-01'); INSERT INTO medical_supplies VALUES (4, 2, 'Country B', 'First Aid Kits', 8000, '2020-01-01');
CORRECT SQL:
SELECT organization_id, location as region, COUNT(*) as number_of_supplies, SUM(supply_cost) as total_supply_cost FROM medical_supplies WHERE location = 'East Africa' AND distribution_date >= DATE_SUB(CURRENT_DATE, INTERVAL 3 YEAR) GROUP BY organization_id, location;
MODEL G

### One-shot inference

In [54]:
def create_prompt_with_shots(schema: str, instruction: str, example_indexes: list)->str:
    """
    function that generates prompt based example indexes
    """
    prompt = f"""
Translate schema and description into SQL query. Respond only with SQL query without any additional characters.
Prepare result for last example based on previous examples below: 

    """
    for index in example_indexes:
        example_schema = dataset['test'][index]['sql_context']
        example_instruction = dataset['test'][index]['sql_prompt']
        example_sql = dataset['test'][index]['sql']

        prompt += f"""
Schema: {example_schema}

Instruction: {example_instruction}

SQL query:
{example_sql}
        """

    prompt += f"""
    Schema: {schema}

Instruction: {instruction}

SQL query:


"""
    return prompt

In [66]:
example_indexes = [20]
index_to_process = 78

one_shot_prompt = create_prompt_with_shots(
    dataset['test'][index_to_process]['sql_context'],
    dataset['test'][index_to_process]['sql_prompt'],
    example_indexes
)

print(one_shot_prompt)

sql = dataset['test'][index_to_process]['sql']

inputs = tokenizer(one_shot_prompt, return_tensors='pt')
output = tokenizer.decode(
    model.generate(
        inputs["input_ids"],
        max_new_tokens=50,
    )[0],
    skip_special_tokens=True
)

print(f'CORRECT SQL:\n{sql}')
print(f'MODEL GENERATION SQL - ZERO SHOT:\n{output}\n')


Translate schema and description into SQL query. Respond only with SQL query without any additional characters.
Prepare result for last example based on previous examples below: 

    
Schema: CREATE TABLE asthma (id INTEGER, county VARCHAR(255), state VARCHAR(255), age INTEGER, prevalence FLOAT);

Instruction: Which rural areas have the highest prevalence of asthma in children?

SQL query:
SELECT county, state, AVG(prevalence) AS avg_prevalence FROM asthma WHERE age < 18 AND county LIKE '%rural%' GROUP BY county, state ORDER BY avg_prevalence DESC LIMIT 10;
        
    Schema: CREATE TABLE unions (id INT, name TEXT); CREATE TABLE workers (id INT, union_id INT, industry TEXT, wage FLOAT); INSERT INTO unions (id, name) VALUES (1, 'Union Z'), (2, 'Union AA'), (3, 'Union AB'); INSERT INTO workers (id, union_id, industry, wage) VALUES (1, 1, 'retail', 500), (2, 1, 'retail', 550), (3, 2, 'retail', 600), (4, 2, 'retail', 650), (5, 3, 'retail', 700), (6, 3, 'retail', 750);

Instruction: Wha

### Few-shot inference

In [68]:
example_indexes = list(range(60,70))
index_to_process = 78

prompt = create_prompt_with_shots(
    dataset['test'][index_to_process]['sql_context'],
    dataset['test'][index_to_process]['sql_prompt'],
    example_indexes
)

print(one_shot_prompt)

sql = dataset['test'][index_to_process]['sql']

inputs = tokenizer(prompt, return_tensors='pt')
output = tokenizer.decode(
    model.generate(
        inputs["input_ids"],
        max_new_tokens=50,
    )[0],
    skip_special_tokens=True
)

print(f'CORRECT SQL:\n{sql}')
print(f'MODEL GENERATION SQL - ZERO SHOT:\n{output}\n')


Translate schema and description into SQL query. Respond only with SQL query without any additional characters.
Prepare result for last example based on previous examples below: 

    
Schema: CREATE TABLE asthma (id INTEGER, county VARCHAR(255), state VARCHAR(255), age INTEGER, prevalence FLOAT);

Instruction: Which rural areas have the highest prevalence of asthma in children?

SQL query:
SELECT county, state, AVG(prevalence) AS avg_prevalence FROM asthma WHERE age < 18 AND county LIKE '%rural%' GROUP BY county, state ORDER BY avg_prevalence DESC LIMIT 10;
        
    Schema: CREATE TABLE unions (id INT, name TEXT); CREATE TABLE workers (id INT, union_id INT, industry TEXT, wage FLOAT); INSERT INTO unions (id, name) VALUES (1, 'Union Z'), (2, 'Union AA'), (3, 'Union AB'); INSERT INTO workers (id, union_id, industry, wage) VALUES (1, 1, 'retail', 500), (2, 1, 'retail', 550), (3, 2, 'retail', 600), (4, 2, 'retail', 650), (5, 3, 'retail', 700), (6, 3, 'retail', 750);

Instruction: Wha