## T5 Transformers

### Import Libraries

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install happytransformer

In [3]:
import transformers
print(transformers.__version__)

4.29.2


In [4]:
!pip uninstall transformers -y

Found existing installation: transformers 4.29.2
Uninstalling transformers-4.29.2:
  Successfully uninstalled transformers-4.29.2


In [None]:
# Restart Kernel
# https://github.com/huggingface/transformers/issues/22846

In [None]:
!pip install transformers==4.28.0

In [2]:
import transformers
print(transformers.__version__)

4.28.0


In [10]:
import pandas as pd
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from happytransformer import HappyTextToText
from happytransformer import TTTrainArgs
from happytransformer import TTSettings

### Load Dataset

In [None]:
dataset_sql_create_conext = load_dataset("b-mc2/sql-create-context")
dataset_sql_create_context_alpaca_style = load_dataset("lucasmccabe-lmi/sql-create-context_alpaca_style")

In [5]:
data = dataset_sql_create_context_alpaca_style

In [6]:
data

DatasetDict({
    train: Dataset({
        features: ['instruction', 'input', 'output'],
        num_rows: 78577
    })
})

In [7]:
df = data['train'].to_pandas()

In [8]:
df.head(3)

Unnamed: 0,instruction,input,output
0,Write a SQL query that answers the following q...,The relevant table was constructed using the f...,SELECT race_4 FROM table_name_51 WHERE race_1 ...
1,Write a SQL query that answers the following q...,The relevant table was constructed using the f...,SELECT AVG(total) FROM table_name_99 WHERE bro...
2,Write a SQL query that answers the following q...,The relevant table was constructed using the f...,"SELECT SUM(skin_depth), _inches FROM table_nam..."


In [11]:
# Update the data to required format

df_updated = pd.concat([df['input'] + '. ' + df['instruction'] + ' ' , df['output']], axis=1)
df_updated.columns = ['input', 'target']

In [12]:
df_updated.head()

Unnamed: 0,input,target
0,The relevant table was constructed using the f...,SELECT race_4 FROM table_name_51 WHERE race_1 ...
1,The relevant table was constructed using the f...,SELECT AVG(total) FROM table_name_99 WHERE bro...
2,The relevant table was constructed using the f...,"SELECT SUM(skin_depth), _inches FROM table_nam..."
3,The relevant table was constructed using the f...,"SELECT first_name, last_name, department_id FR..."
4,The relevant table was constructed using the f...,SELECT result FROM table_name_38 WHERE opponen...


In [13]:
len(df_updated)

78577

In [26]:
# Take 10,000 random examples to train our model on.

df_reduced = df_updated.sample(n=10000, random_state=42)

In [27]:
train_df, test_df = train_test_split(df_reduced, test_size=0.2, random_state=42)

In [28]:
train_df.head(2)

Unnamed: 0,input,target
2383,The relevant table was constructed using the f...,SELECT tries_for FROM table_name_47 WHERE losi...
9146,The relevant table was constructed using the f...,SELECT COUNT(attendance) FROM table_name_20 WH...


In [29]:
test_df.head(2)

Unnamed: 0,input,target
8022,The relevant table was constructed using the f...,SELECT team_nickname FROM table_26476336_2 WHE...
70230,The relevant table was constructed using the f...,SELECT player FROM table_name_79 WHERE score =...


In [30]:
# Save the data

train_df.to_csv('/content/df_train.csv', index=False)
test_df.to_csv('/content/df_test.csv', index=False)

In [31]:
len(train_df)

8000

### FineTune the Model

In [36]:
# Load the model if saved earlier

# wikisq2_t5 = HappyTextToText(model_type = "T5", model_name = "mrm8488/t5-base-finetuned-wikiSQL", load_path ='/content/drive/MyDrive/Text2SQL/model_tt_wikisql')
# t5_base = HappyTextToText(model_type = "T5", model_name = "t5-base", load_path ='/content/drive/MyDrive/Text2SQL/model')
beam_settings =  TTSettings(num_beams=5, min_length=1, max_length=70)

#### T5 Base

In [None]:
# T5 base pretrained model

t5_base = HappyTextToText("T5", "t5-base")
args = TTTrainArgs(batch_size=2, num_train_epochs=3, learning_rate=0.001)
beam_settings =  TTSettings(num_beams=5, min_length=1, max_length=70)

In [None]:
# Train the model

t5_base.train('/content/df_train.csv', args=args)

Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-2c5128a4554e562f/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-2c5128a4554e562f/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

Map:   0%|          | 0/24000 [00:00<?, ? examples/s]

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss
500,0.5819
1000,0.425
1500,0.3856
2000,0.3214
2500,0.3037
3000,0.288
3500,0.2629
4000,0.2466
4500,0.2537
5000,0.2515


KeyboardInterrupt: ignored

In [34]:
test_df

Unnamed: 0,input,target
8022,The relevant table was constructed using the f...,SELECT team_nickname FROM table_26476336_2 WHE...
70230,The relevant table was constructed using the f...,SELECT player FROM table_name_79 WHERE score =...
39192,The relevant table was constructed using the f...,SELECT venue FROM table_name_45 WHERE score = ...
18672,The relevant table was constructed using the f...,SELECT time_retired FROM table_name_46 WHERE l...
40074,The relevant table was constructed using the f...,SELECT AVG(week) FROM table_name_94 WHERE reco...
...,...,...
44005,The relevant table was constructed using the f...,SELECT buena_vista_edition FROM table_25173505...
57927,The relevant table was constructed using the f...,SELECT school_club_team FROM table_name_48 WHE...
28309,The relevant table was constructed using the f...,SELECT to_par FROM table_name_76 WHERE player ...
48067,The relevant table was constructed using the f...,SELECT construction AS date FROM table_2218035...


In [72]:
# Predicted vs True SQL queries 

example_1 = test_df['input'][test_df.index[0]] 
result_1 = t5_base.generate_text(example_1, args=beam_settings)
print(example_1)
print(" ")
print(result_1.text)
print(" ")
print(test_df['target'][test_df.index[0]])



The relevant table was constructed using the following SQL CREATE TABLE statement: CREATE TABLE table_26476336_2 (team_nickname VARCHAR, institution VARCHAR). Write a SQL query that answers the following question: What is the nickname at the University of Nebraska at Omaha? 
 
SELECT team_nickname FROM table_26476336_2 WHERE institution = "University of Nebraska"
 
SELECT team_nickname FROM table_26476336_2 WHERE institution = "University of Nebraska at Omaha"


In [73]:
# Predicted vs True SQL queries 

example_2 = test_df['input'][test_df.index[1]]
result_2 = t5_base.generate_text(example_2, args=beam_settings)
print(example_2)
print(" ")
print(result_2.text)
print(" ")
print(test_df['target'][test_df.index[1]])



The relevant table was constructed using the following SQL CREATE TABLE statement: CREATE TABLE table_name_79 (player VARCHAR, score VARCHAR). Write a SQL query that answers the following question: Who is the player with a 75-68-70=213 score? 
 
SELECT player FROM table_name_79 WHERE score = 75 - 68 - 70 = 213
 
SELECT player FROM table_name_79 WHERE score = 75 - 68 - 70 = 213


In [43]:
# Predicted vs True SQL queries 

example_3 = '''
The relevant table was constructed using the following SQL CREATE TABLE statement: 
CREATE TABLE Employee (id INTEGER, start_date DATE). 
Write a SQL query that answers the following question: How many employees were hired last month? 
'''

result_3 = t5_base.generate_text(example_3, args=beam_settings)
print(result_3.text)

SELECT MAX(id) FROM Employee ORDER BY start_date DESC LIMIT 1


In [70]:
# Predicted vs True SQL queries 

example_4 = '''
The relevant table was constructed using the following SQL CREATE TABLE statement: 
CREATE TABLE Employee (id INTEGER, gender VARCHAR). 
Write a SQL query that answers the following question: What percentage of employees are female? 
'''

result_4 = t5_base.generate_text(example_4, args=beam_settings)
print(result_4.text)

SELECT MAX(id) FROM Employee WHERE gender = "Female"


In [None]:
# Evaluation Loss on test dataset

result = t5_base.eval("/content/df_test.csv", args=args)
print(type(result))  
print(result)  
print(result.loss)

Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-d050b9c7281c68fe/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating eval split: 0 examples [00:00, ? examples/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-d050b9c7281c68fe/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

Map:   0%|          | 0/6000 [00:00<?, ? examples/s]



<class 'happytransformer.happy_trainer.EvalResult'>
EvalResult(loss=0.13618651032447815)
0.13618651032447815


#### T5 Base finetuned on wikiSQL

In [None]:
# T5 model pretrained on wikisql dataset

wikisq2_t5 = HappyTextToText("T5", "mrm8488/t5-base-finetuned-wikiSQL")
args = TTTrainArgs(batch_size=2, num_train_epochs=1, learning_rate=0.001)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.23k [00:00<?, ?B/s]

The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.


Downloading pytorch_model.bin:   0%|          | 0.00/1.19G [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.


Downloading (…)ve/main/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.


In [None]:
# Train the model

wikisq2_t5.train('/content/df_train.csv', args=args)

Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-3fdd6af5e7d32397/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-3fdd6af5e7d32397/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

Map:   0%|          | 0/24000 [00:00<?, ? examples/s]

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss
500,0.5107
1000,0.3998
1500,0.3622
2000,0.3234
2500,0.286
3000,0.2812
3500,0.2606
4000,0.2424
4500,0.2449
5000,0.2443


Step,Training Loss
500,0.5107
1000,0.3998
1500,0.3622
2000,0.3234
2500,0.286
3000,0.2812
3500,0.2606
4000,0.2424
4500,0.2449
5000,0.2443


KeyboardInterrupt: ignored

In [74]:
# Predicted vs True SQL queries 

example_1 = test_df['input'][test_df.index[0]]
result_1 = wikisq2_t5.generate_text(example_1, args=beam_settings)
print(example_1)
print(" ")
print(result_1.text)
print(" ")
print(test_df['target'][test_df.index[0]])

The relevant table was constructed using the following SQL CREATE TABLE statement: CREATE TABLE table_26476336_2 (team_nickname VARCHAR, institution VARCHAR). Write a SQL query that answers the following question: What is the nickname at the University of Nebraska at Omaha? 
 
SELECT team_nickname FROM table_26476336_2 WHERE institution = "University of Nebraska"
 
SELECT team_nickname FROM table_26476336_2 WHERE institution = "University of Nebraska at Omaha"


In [75]:
# Predicted vs True SQL queries 

example_2 = test_df['input'][test_df.index[1]]
result_2 = wikisq2_t5.generate_text(example_2, args=beam_settings)
print(example_2)
print(" ")
print(result_2.text)
print(" ")
print(test_df['target'][test_df.index[1]])

The relevant table was constructed using the following SQL CREATE TABLE statement: CREATE TABLE table_name_79 (player VARCHAR, score VARCHAR). Write a SQL query that answers the following question: Who is the player with a 75-68-70=213 score? 
 
SELECT player FROM table_name_79 WHERE score = 75 - 68 - 70 = 213
 
SELECT player FROM table_name_79 WHERE score = 75 - 68 - 70 = 213


In [50]:
# Predicted vs True SQL queries 

example_3 = '''
The relevant table was constructed using the following SQL CREATE TABLE statement: 
CREATE TABLE Employee (id INTEGER, start_date DATE). 
Write a SQL query that answers the following question: How many employees were hired last month? 
'''

result_3 = wikisq2_t5.generate_text(example_3, args=beam_settings)
print(result_3.text)

SELECT SUM(id) FROM Employee WHERE start_date = "September"


In [51]:
# Predicted vs True SQL queries 

example_4 = '''
The relevant table was constructed using the following SQL CREATE TABLE statement: 
CREATE TABLE Employee (id INTEGER, gender VARCHAR). 
Write a SQL query that answers the following question: What percentage of employees are female? 
'''

result_4 = wikisq2_t5.generate_text(example_4, args=beam_settings)
print(result_4.text)
print(" ")

SELECT MIN(id) FROM Employee WHERE gender = "Female"
 


### Save the Models

In [None]:
t5_base.save("/content/drive/MyDrive/Text2SQL/model/")
wikisq2_t5.save("/content/drive/MyDrive/Text2SQL/model_tt_wikisql/")

### Database

In [171]:
import sqlite3
import re

In [1]:
# Upload the .csv files of the data

In [109]:
# Read the dataset

df_earning = pd.read_csv('/content/drive/MyDrive/Text2SQL/Intern Assessment Data (HR) - Earning.csv')
df_employee = pd.read_csv('/content/drive/MyDrive/Text2SQL/Intern Assessment Data (HR) - Employee.csv')
df_employee_pay_roll = pd.read_csv('/content/drive/MyDrive/Text2SQL/Intern Assessment Data (HR) - EmployeePayrollRun.csv')
df_group = pd.read_csv('/content/drive/MyDrive/Text2SQL/Intern Assessment Data (HR) - Group.csv')
df_pay_group = pd.read_csv('/content/drive/MyDrive/Text2SQL/Intern Assessment Data (HR) - PayGroup.csv')
df_payroll_run = pd.read_csv('/content/drive/MyDrive/Text2SQL/Intern Assessment Data (HR) - PayrollRun.csv')

In [63]:
conn = sqlite3.connect(':memory:')

In [111]:
# Create tables in SQL

df_earning.to_sql('Earning', conn, index=False)
df_employee.to_sql('Employee', conn, index=False)
df_employee_pay_roll.to_sql('Employee Pay Roll', conn, index=False)
df_group.to_sql('Group', conn, index=False)
df_pay_group.to_sql('Pay Group', conn, index=False)
df_payroll_run.to_sql('Payroll Run', conn, index=False)

17

In [67]:
query = "SELECT * FROM Employee WHERE first_name = 'Judith'"
result = pd.read_sql_query(query, conn)

In [68]:
result

Unnamed: 0,id,remote_id,employee_number,company,first_name,last_name,display_full_name,username,groups,work_email,...,ssn,gender,ethnicity,marital_status,date_of_birth,start_date,remote_created_at,employment_status,termination_date,avatar
0,b53b1dff-6136-4ba1-880e-b7262f17c370,62274421,10365,3479aeef-f3fa-44ef-a319-83db557bbc62,Judith,Braun,Judith Braun,Judith.Braun,"8d809e31-6bf0-4840-b31d-e445027f2306,05454df5-...",Judith.Braun@ACME-United.com,...,896-24-9191,MALE,ASIAN,DIVORCED,1959-03-31,2013-05-11,2023-05-01,INACTIVE,2014-10-18,https://picsum.photos/568/180
1,a43735cd-3ced-4fbf-8688-247a9077f074,16993324,51289,3479aeef-f3fa-44ef-a319-83db557bbc62,Judith,Ross,Judith Ross,Judith.Ross,"aeead089-00c2-4446-a7cf-1d3bb8584895,d5a21ad4-...",Judith.Ross@ACME-United.com,...,759-76-6553,FEMALE,WHITE,DIVORCED,1995-08-27,2016-08-20,2014-10-01,INACTIVE,2020-03-19,https://picsum.photos/656/897
2,95f87766-dbb5-4921-8676-f5b6719cc50b,8331477,20595,3479aeef-f3fa-44ef-a319-83db557bbc62,Judith,Hernandez,Judith Hernandez,Judith.Hernandez,"c536f3ab-0912-442c-b53e-496af749d25e,abea7ee3-...",Judith.Hernandez@ACME-United.com,...,507-65-6291,MALE,HISPANIC,MARRIED,1969-10-26,2020-07-08,2015-07-13,ACTIVE,2014-06-27,https://placekitten.com/620/335


In [168]:
query_1 = '''
The relevant table was constructed using the following SQL CREATE TABLE statement: 
CREATE TABLE Employee (first_name VARCHAR, ethnicity VARCHAR). 
Write a SQL query that answers the following question: Who is the first_name with a ASIAN ethnicity?
'''

query_2 = '''
The relevant table was constructed using the following SQL CREATE TABLE statement: 
CREATE TABLE Employee (first_name VARCHAR, ssn INTEGER). 
Write a SQL query that answers the following question: Who is the first_name with a ssn 896-24-9191?
'''

query_3 = '''
The relevant table was constructed using the following SQL CREATE TABLE statement: 
CREATE TABLE Group (name VARCHAR, type VARCHAR). 
Write a SQL query that answers the following question: What is name when type is DEPARTMENT?
'''

In [179]:
def text_to_sql(query):
  '''
  Input: Natural Language Query
  Output: Tabular Data
  The function converts the Natural Language Query to a SQL query using our fine-tuned model.
  It then cleans the query to make it suitable to pass it to 'read_sql_query' function from pyhton.
  It then returns tabular data generated after executing the query.
  '''
  query_result = wikisq2_t5.generate_text(query, args=beam_settings)
  query_result = query_result.text
  output_string = re.sub(r'(?<=FROM\s)(\w+)(?=\sWHERE|$)', r"'\1'", query_result)
  result = pd.read_sql_query(output_string, conn)
  return result

In [131]:
df_query_1 = text_to_sql(query_1)



In [165]:
df_query_1.head(3)

Unnamed: 0,first_name
0,Judith
1,Tanya
2,Joy


In [133]:
df_query_2 = text_to_sql(query_2)



In [134]:
df_query_2

Unnamed: 0,first_name
0,Judith


In [180]:
df_query_3 = text_to_sql(query_3)



In [181]:
df_query_3.head(3)

Unnamed: 0,name
0,Environmental education officer Department
1,Astronomer Department
2,Research scientist (physical sciences) Department


Below are some more moderate example for which the model performs good.

In [None]:
def output_text_to_sql(query):
  query_result = wikisq2_t5.generate_text(query, args=beam_settings)
  query_result = query_result.text
  output_string = re.sub(r'(?<=FROM\s)(\w+)(?=\sWHERE|$)', r"'\1'", query_result)
  return output_string

In [182]:
query_4 = '''
The relevant table was constructed using the following SQL CREATE TABLE statement: 
CREATE TABLE Employee Pay Roll (gross_pay INTEGER, net_pay INTEGER, remote_was_deleted VARCHAR).
Write a SQL query that answers the following question: what is the average gross_pay when net_pay is more than 2000, remote_was_deleted is FALSE?
'''

In [195]:
query_4_in_sql = output_text_to_sql(query_4)
print(query_4_in_sql)

SELECT AVG(gross_pay) FROM Employee Pay Pay Roll WHERE net_pay > 2000 AND remote_was_deleted = "FALSE"


In [199]:
query_5 = '''
The relevant table was constructed using the following SQL CREATE TABLE statement: 
CREATE TABLE Employee (first_name VARCHAR, start_date VARCHAR).
Write a SQL query that answers the following question: What is the first_name when the start_date of the job is 2013-06-06?"
'''

In [200]:
query_5_in_sql = output_text_to_sql(query_5)
print(query_5_in_sql)

SELECT first_name FROM 'Employee' WHERE start_date = "2013-06-06"


In [206]:
query_6 = '''
The relevant table was constructed using the following SQL CREATE TABLE statement: 
CREATE TABLE Earning (id INTEGER, type VARCHAR).
Write a SQL query that answers the following question: How many id type was REIMBURSEMENT?"
'''

In [207]:
query_6_in_sql = output_text_to_sql(query_6)
print(query_6_in_sql)

SELECT SUM(id) FROM 'Earning' WHERE type = "ReIMBURSEMENT"
