![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Healthcare/37.Text2SQL_Generation.ipynb)


# **Text2SQL Generation**

The Text-to-SQL task, which involves automatically converting natural language questions into corresponding SQL queries, has advanced significantly with the application of state-of-the-art models. In this direction, we are excited to introduce our new `Text2SQL` annotator, which translates natural language text prompts into SQL queries so that you can query databases without writing SQL by hand. Built on a state-of-the-art LLM, this annotator opens new possibilities for data retrieval and manipulation and helps streamline your workflow.

We also provide the new `text2sql_mimicsql` model, which is fine-tuned specifically on the MIMIC-III dataset schema to improve the precision of SQL queries generated from medical natural language questions about the MIMIC data.

In addition, we introduce two models that generate SQL queries from natural language questions and custom database schemas with a single table. Both are based on a large-size LLM fine-tuned by John Snow Labs on a dataset of single-table schemas.

The model "***text2sql_with_schema_single_table_augmented***", trained on an augmented dataset, achieves a new state-of-the-art (SOTA) result for this task.


Available models can be found at the [Models Hub](https://nlp.johnsnowlabs.com/models?annotator=Text2SQL).


## Colab Setup

📌 To run this yourself, you need to upload your license keys to the notebook. Run the cell below and select your license file when prompted; the cell saves it as `spark_jsl.json`. Alternatively, open the file explorer on the left side of the screen and upload the file as `spark_jsl.json`, which is the file name the cell looks for.
Otherwise, you can look at the example outputs shown below each cell.

In [None]:
import json
import os

from google.colab import files

if 'spark_jsl.json' not in os.listdir():
    license_keys = files.upload()
    os.rename(list(license_keys.keys())[0], 'spark_jsl.json')

with open('spark_jsl.json') as f:
    license_keys = json.load(f)

# Defining license key-value pairs as local variables
locals().update(license_keys)
os.environ.update(license_keys)
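The cells below read three entries from this file: `PUBLIC_VERSION` and `JSL_VERSION` in the pip commands, and `SECRET` both in the pip index URL and in `sparknlp_jsl.start()`. An optional early check can report a missing key before the installation fails. This is a minimal sketch; `missing_license_keys` is a hypothetical helper, not part of the original notebook, and your license file may contain further entries that it does not check.

```python
# Keys referenced later in this notebook ($PUBLIC_VERSION, $JSL_VERSION, $SECRET
# and license_keys['SECRET']); other entries in the file are not checked here.
REQUIRED_KEYS = ("SECRET", "PUBLIC_VERSION", "JSL_VERSION")


def missing_license_keys(keys):
    """Return the required keys that are absent or empty (hypothetical helper)."""
    return [name for name in REQUIRED_KEYS if not keys.get(name)]


# In the notebook, after the cell above has loaded spark_jsl.json:
# missing = missing_license_keys(license_keys)
# if missing:
#     raise KeyError(f"spark_jsl.json is missing: {missing}")
```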

In [None]:
# Installing pyspark and spark-nlp
! pip install --upgrade -q pyspark==3.1.2  spark-nlp==$PUBLIC_VERSION

# Installing Spark NLP Healthcare
! pip install --upgrade -q spark-nlp-jsl==$JSL_VERSION  --extra-index-url https://pypi.johnsnowlabs.com/$SECRET

# Installing Spark NLP Display Library for visualization
! pip install -q spark-nlp-display

In [None]:
import json
import os

from pyspark.ml import Pipeline, PipelineModel
from pyspark.sql import SparkSession

import sparknlp
import sparknlp_jsl

from sparknlp.annotator import *
from sparknlp_jsl.annotator import *
from sparknlp.base import *
from sparknlp.util import *
from sparknlp.pretrained import ResourceDownloader
from pyspark.sql import functions as F
import textwrap
import pandas as pd

pd.set_option('display.max_columns', None)
pd.set_option('display.expand_frame_repr', False)
pd.set_option('max_colwidth', None)

import string
import numpy as np

params = {"spark.driver.memory":"16G",
          "spark.kryoserializer.buffer.max":"2000M",
          "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
          "spark.driver.maxResultSize":"2000M"}

spark = sparknlp_jsl.start(license_keys['SECRET'], params=params)

print ("Spark NLP Version :", sparknlp.version())
print ("Spark NLP_JSL Version :", sparknlp_jsl.version())

spark

Spark NLP Version : 5.1.1
Spark NLP_JSL Version : 5.1.1


# 🔎 MODELS

<div align="center">

| **Index** | **Text2SQL models** |
|-----------|---------------------|
| 1 | [text2sql_mimicsql](https://nlp.johnsnowlabs.com/2023/08/14/text2sql_mimicsql_en.html) |
| 2 | [text2sql_with_schema_single_table](https://nlp.johnsnowlabs.com/2023/09/02/text2sql_with_schema_single_table_en.html) |
| 3 | [text2sql_with_schema_single_table_augmented](https://nlp.johnsnowlabs.com/2023/09/25/text2sql_with_schema_single_table_augmented_en.html) |


</div>

## 📑  **Text2SQL_MIMICSQL**

This model is based on the FlanT5-Large LLM, fine-tuned by John Snow Labs on a biomedical dataset (MIMICSQL). It generates SQL queries from medical natural language questions over the MIMIC-III dataset.

In [None]:
document_assembler = DocumentAssembler()\
    .setInputCol("prompt")\
    .setOutputCol("document_prompt")

text2sql_mimicsql = Text2SQL.pretrained("text2sql_mimicsql", "en", "clinical/models")\
    .setInputCols(["document_prompt"])\
    .setOutputCol("sql_query")

pipeline = Pipeline(stages=[document_assembler, text2sql_mimicsql])

model = pipeline.fit(spark.createDataFrame([[""]]).toDF("prompt"))

In [None]:
text = ["Find the average number of prescriptions per patient for patients with a specific diagnosis.",
        "give me the number of patients who had single internal mammary-coronary artery bypass.",
        "provide the drug code and drug dose for anna johnson.",
        "calculate the minimum age of married patients who had elective type hospital admission.",
        "What is the maximum age of patients who were hospitalized for 20 days and died before 2023 ?"]

data = spark.createDataFrame([(prompt,) for prompt in text], ["prompt"])

result = model.transform(data)

result.select("sql_query.result").show(truncate=False)

+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|result                                                                                                                                                                                                                                                                   |
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|[SELECT AVG ( DEMOGRAPHIC."AGE" ) FROM DEMOGRAPHIC INNER JOIN DIAGNOSES on DEMOGRAPHIC.HADM_ID = DIAGNOSES.HADM_ID INNER JOIN PRESCRIPTIONS on DEMOGRAPHIC.HADM_ID = PRESCRIPTIONS.HADM_ID WHERE DI

### **📍 LightPipelines**

In [None]:
light_model = LightPipeline(model)
light_result = light_model.annotate(text)

In [None]:
for i, res in enumerate(light_result, start=1):
    document_text = textwrap.fill(res['document_prompt'][0], width=120)
    sql_text = textwrap.fill(res['sql_query'][0], width=120)

    print("➤ User query: {}: \n{}".format(i, document_text))
    print("\n")
    print("➤ SQL query {}: \n{}".format(i, sql_text))
    print("\n")

➤ User query: 1: 
Find the average number of prescriptions per patient for patients with a specific diagnosis.


➤ SQL query 1: 
SELECT AVG ( DEMOGRAPHIC."AGE" ) FROM DEMOGRAPHIC INNER JOIN DIAGNOSES on DEMOGRAPHIC.HADM_ID = DIAGNOSES.HADM_ID INNER
JOIN PRESCRIPTIONS on DEMOGRAPHIC.HADM_ID = PRESCRIPTIONS.HADM_ID WHERE DIAGNOSES."SHORT_TITLE" = "Specific hst" AND
PRESCRIPTIONS."DRUG" = "1"


➤ User query: 2: 
give me the number of patients who had single internal mammary-coronary artery bypass.


➤ SQL query 2: 
SELECT COUNT ( DISTINCT DEMOGRAPHIC."SUBJECT_ID" ) FROM DEMOGRAPHIC INNER JOIN PROCEDURES on DEMOGRAPHIC.HADM_ID =
PROCEDURES.HADM_ID WHERE PROCEDURES."SHORT_TITLE" = "1 int mam-cor art bypass"


➤ User query: 3: 
provide the drug code and drug dose for anna johnson.


➤ SQL query 3: 
SELECT PRESCRIPTIONS."FORMULARY_DRUG_CD",PRESCRIPTIONS."DRUG_DOSE" FROM DEMOGRAPHIC INNER JOIN PRESCRIPTIONS on
DEMOGRAPHIC.HADM_ID = PRESCRIPTIONS.HADM_ID WHERE DEMOGRAPHIC."NAME" = "Anna Johnson
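The queries above use double quotes both around column names (`DIAGNOSES."SHORT_TITLE"`) and around string values (`"1 int mam-cor art bypass"`). In standard SQL, double quotes delimit identifiers, so a database that follows the standard may read those values as column names. A minimal post-processing sketch, not part of Spark NLP: the hypothetical `single_quote_literals` helper rewrites only a double-quoted token that directly follows `=` (which also covers `!=`, `>=` and `<=`). It assumes such a token is always a value, never a quoted column, and it ignores `LIKE`, `IN` and unterminated quotes.

```python
import re

# An "=" (also the tail of "!=", ">=", "<=") followed by a double-quoted token.
_DQ_VALUE = re.compile(r'(=\s*)"([^"]*)"')


def single_quote_literals(sql):
    """Rewrite  = "value"  as  = 'value', doubling any embedded single quotes."""
    def repl(match):
        value = match.group(2).replace("'", "''")
        return f"{match.group(1)}'{value}'"
    return _DQ_VALUE.sub(repl, sql)


query = ('SELECT COUNT ( DISTINCT DEMOGRAPHIC."SUBJECT_ID" ) FROM DEMOGRAPHIC '
         'INNER JOIN PROCEDURES on DEMOGRAPHIC.HADM_ID = PROCEDURES.HADM_ID '
         'WHERE PROCEDURES."SHORT_TITLE" = "1 int mam-cor art bypass"')
print(single_quote_literals(query))
# The join condition is untouched; the filter becomes:
# WHERE PROCEDURES."SHORT_TITLE" = '1 int mam-cor art bypass'
```

Quoted column names such as `"SUBJECT_ID"` are left alone because they are not preceded by `=`.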

## 📑  **Text2SQL_With_Schema_Single_Table**


This model generates SQL queries from natural language questions and a custom database schema with a single table. It is based on a large-size LLM fine-tuned by John Snow Labs on a dataset of single-table schemas.

In [None]:
query_schema = {"patient": ["ID","Name","Age","Gender","BloodType","Weight","Height","Address","Email","Phone"] }

text2sql_with_schema_single_table = Text2SQL.pretrained("text2sql_with_schema_single_table", "en", "clinical/models")\
    .setMaxNewTokens(200)\
    .setSchema(query_schema)\
    .setInputCols(["document_prompt"])\
    .setOutputCol("sql_query")

pipeline = Pipeline(stages=[document_assembler, text2sql_with_schema_single_table])

model = pipeline.fit(spark.createDataFrame([[""]]).toDF("prompt"))

text2sql_with_schema_single_table download started this may take some time.
[OK!]
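The schema passed to `setSchema` is a dictionary that maps the table name to its list of column names, as in `query_schema` above. If the table already exists in a database, that dictionary can be derived instead of typed by hand. A minimal sketch for SQLite using Python's built-in `sqlite3` module; the helper `schema_from_sqlite` and the example `drug` table are illustrative assumptions, and other databases expose column lists through their own catalog views.

```python
import sqlite3


def schema_from_sqlite(conn, table):
    """Return {table: [column names]} for an existing SQLite table.

    Hypothetical helper, not a Spark NLP API; the result has the same shape
    as query_schema above. `table` is interpolated directly, so it must come
    from a trusted source.
    """
    # PRAGMA table_info yields one row per column: (cid, name, type, notnull, dflt_value, pk)
    rows = conn.execute(f'PRAGMA table_info("{table}")').fetchall()
    if not rows:
        raise ValueError(f"table {table!r} not found")
    return {table: [row[1] for row in rows]}


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE drug (ID INTEGER, Name TEXT, Manufacturer TEXT, Price REAL)")
print(schema_from_sqlite(conn, "drug"))
# {'drug': ['ID', 'Name', 'Manufacturer', 'Price']}
```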


In [None]:
text = ["Calculate the average age of patients with blood type 'A-'",
        "Retrieve the names and email addresses of patients with blood type 'B+'",
        "Calculate the number of patients with blood type A- and weight above 100kg"
        ]

data = spark.createDataFrame([(prompt,) for prompt in text], ["prompt"])

result = model.transform(data)

result.select("sql_query.result").show(truncate=False)

+---------------------------------------------------------------------------+
|result                                                                     |
+---------------------------------------------------------------------------+
|[SELECT AVG(Age) FROM patient WHERE BloodType = "A-"]                      |
|[SELECT Name, Email FROM patient WHERE BloodType = "B+"]                   |
|[SELECT COUNT(Name) FROM patient WHERE BloodType = "A-" AND Weight > 100kg]|
+---------------------------------------------------------------------------+
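Generated queries are not guaranteed to be valid SQL: the third query above compares `Weight` with `100kg`, which is not a valid literal. A minimal sketch of a pre-flight check, assuming SQLite as a stand-in engine: the hypothetical `try_compile` helper builds empty in-memory tables from a schema dictionary of the same shape as `query_schema` and runs the query against them. It only shows that a query parses and references existing columns, not that it answers the question. The valid example is written with single-quoted values.

```python
import sqlite3


def try_compile(sql, schema):
    """Run `sql` against empty in-memory tables built from `schema`.

    Returns None if SQLite accepts the query, otherwise the error message.
    Hypothetical helper; SQLite is only a stand-in for your real database.
    """
    conn = sqlite3.connect(":memory:")
    try:
        for table, columns in schema.items():
            column_list = ", ".join(f'"{c}"' for c in columns)
            conn.execute(f'CREATE TABLE "{table}" ({column_list})')
        conn.execute(sql).fetchall()
        return None
    except sqlite3.Error as exc:
        return str(exc)
    finally:
        conn.close()


# Same columns as query_schema in the cell above.
patient_schema = {"patient": ["ID", "Name", "Age", "Gender", "BloodType",
                              "Weight", "Height", "Address", "Email", "Phone"]}

print(try_compile("SELECT AVG(Age) FROM patient WHERE BloodType = 'A-'", patient_schema))
# None
print(try_compile("SELECT COUNT(Name) FROM patient WHERE BloodType = 'A-' AND Weight > 100kg",
                  patient_schema))
# prints an SQLite error message instead of None
```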



Let's test with another custom database schema:

In [None]:
query_schema = {"drug": ["ID","Name","Manufacturer","Price","ExpiryDate","PrescriptionRequired","SideEffects","Dosage","Quantity"] }
text2sql_with_schema_single_table.setSchema(query_schema)

text = ["Retrieve the names and dosages of drugs containing '50mcg'",
        "Calculate the average price of drugs with a prescription requirement",
        "Retrieve the names and prices of drugs containing '600mg'"
        ]

data = spark.createDataFrame([(prompt,) for prompt in text], ["prompt"])

result = model.transform(data)

result.select("sql_query.result").show(truncate=False)

+----------------------------------------------------------------+
|result                                                          |
+----------------------------------------------------------------+
|[SELECT Name, Dosage FROM drug WHERE Quantity = "50mcg"]        |
|[SELECT AVG(Price) FROM drug WHERE PrescriptionRequired = "Yes"]|
|[SELECT Name, Price FROM drug WHERE Quantity = "600mg"]         |
+----------------------------------------------------------------+



### **📍 LightPipelines**

In [None]:
light_model = LightPipeline(model)
light_result = light_model.annotate(text)

for i, res in enumerate(light_result, start=1):
    document_text = textwrap.fill(res['document_prompt'][0], width=120)
    sql_text = textwrap.fill(res['sql_query'][0], width=120)

    print("➤ User query: {}: \n{}".format(i, document_text))
    print("\n")
    print("➤ SQL query {}: \n{}".format(i, sql_text))
    print("\n")

➤ User query: 1: 
Retrieve the names and dosages of drugs containing '50mcg'


➤ SQL query 1: 
SELECT Name, Dosage FROM drug WHERE Quantity = "50mcg"


➤ User query: 2: 
Calculate the average price of drugs with a prescription requirement


➤ SQL query 2: 
SELECT AVG(Price) FROM drug WHERE PrescriptionRequired = "Yes"


➤ User query: 3: 
Retrieve the names and prices of drugs containing '600mg'


➤ SQL query 3: 
SELECT Name, Price FROM drug WHERE Quantity = "600mg"




## 📑  **Text2SQL_With_Schema_Single_Table_Augmented**


This model achieves state-of-the-art (SOTA) results for generating SQL queries from natural language questions and custom database schemas with a single table. It is based on a large-size LLM fine-tuned by John Snow Labs on an augmented dataset of single-table schemas.

In [None]:
query_schema = {
    "medical_treatment": ["patient_id","patient_name","age","gender","diagnosis","treatment","doctor_name","hospital_name","admission_date","discharge_date"]
}

text2sql_with_schema_single_table_augmented = Text2SQL.pretrained("text2sql_with_schema_single_table_augmented", "en", "clinical/models")\
    .setMaxNewTokens(200)\
    .setSchema(query_schema)\
    .setInputCols(["document_prompt"])\
    .setOutputCol("sql_query")

pipeline = Pipeline(stages=[document_assembler, text2sql_with_schema_single_table_augmented])

model = pipeline.fit(spark.createDataFrame([[""]]).toDF("prompt"))

text2sql_with_schema_single_table_augmented download started this may take some time.
[OK!]


In [None]:
text = ["Which patients were admitted in September 2023?",
        "What is the average age of female patients with 'Diabetes'?",
        "Who are the patients treated by 'Dr. Brown'?"
        ]

data = spark.createDataFrame([(prompt,) for prompt in text], ["prompt"])

result = model.transform(data)

result.select("sql_query.result").show(truncate=False)

+-------------------------------------------------------------------------------------------+
|result                                                                                     |
+-------------------------------------------------------------------------------------------+
|[SELECT patient_name FROM medical_treatment WHERE admission_date = "September 2023"]       |
|[SELECT AVG(age) FROM medical_treatment WHERE gender = 'female' AND diagnosis = 'diabetes']|
|[SELECT patient_name FROM medical_treatment WHERE doctor_name = "Dr. Brown"]               |
+-------------------------------------------------------------------------------------------+



### **📍 LightPipelines**

In [None]:
light_model = LightPipeline(model)
light_result = light_model.annotate(text)

for i, res in enumerate(light_result, start=1):
    document_text = textwrap.fill(res['document_prompt'][0], width=120)
    sql_text = textwrap.fill(res['sql_query'][0], width=120)

    print("➤ User query: {}: \n{}".format(i, document_text))
    print("\n")
    print("➤ SQL query {}: \n{}".format(i, sql_text))
    print("\n")

➤ User query: 1: 
Which patients were admitted in September 2023?


➤ SQL query 1: 
SELECT patient_name FROM medical_treatment WHERE admission_date = "September 2023"


➤ User query: 2: 
What is the average age of female patients with 'Diabetes'?


➤ SQL query 2: 
SELECT AVG(age) FROM medical_treatment WHERE gender = 'female' AND diagnosis = 'diabetes'


➤ User query: 3: 
Who are the patients treated by 'Dr. Brown'?


➤ SQL query 3: 
SELECT patient_name FROM medical_treatment WHERE doctor_name = "Dr. Brown"


