![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Healthcare/37.Text2SQL_Generation.ipynb)


# **Text2SQL Generation**

The Text-to-SQL task, which automatically converts natural-language questions into the corresponding SQL queries, has advanced considerably with the application of state-of-the-art models. We are excited to introduce our new Text2SQL annotator, which translates natural-language text prompts into SQL queries. Because it integrates a state-of-the-art LLM, the annotator lets you retrieve and manipulate data by asking questions in plain language instead of writing SQL by hand, which streamlines your workflow.

We also provide a new `text2sql_mimicsql` model. It is fine-tuned specifically on the MIMIC-III dataset schema to improve the precision of SQL queries generated from medical natural-language questions about MIMIC data.


Available models can be found at the [Models Hub](https://nlp.johnsnowlabs.com/models?annotator=Text2SQL).


## Colab Setup

📌 To run this yourself, you will need to upload your license keys to the notebook. Run the cell below to do that, or open the file explorer on the left side of the screen and upload `spark_jsl.json` (the file name the cell below looks for) to the folder that opens.
Otherwise, you can look at the example outputs shown below each cell.
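The setup cells read three entries from `spark_jsl.json`: `SECRET` (used for the private package index and `sparknlp_jsl.start`), plus `PUBLIC_VERSION` and `JSL_VERSION` (used to pin package versions). The first cell also passes every value to `os.environ.update`, which only accepts strings. If you are unsure whether your file is complete, a check like the one below can catch problems before the install step. This is a minimal sketch; the helpers `check_license_keys` and `load_license_keys` are hypothetical and not part of any John Snow Labs library.

```python
import json

# Keys the setup and install cells of this notebook read from spark_jsl.json.
REQUIRED_KEYS = ("SECRET", "PUBLIC_VERSION", "JSL_VERSION")


def check_license_keys(license_keys: dict) -> list:
    """Return a list of problems found in the parsed license file (empty if none)."""
    # Required keys that are absent or set to an empty value.
    problems = ["missing or empty: " + key
                for key in REQUIRED_KEYS if not license_keys.get(key)]
    # os.environ.update() raises TypeError for non-string values.
    problems += ["not a string: " + key
                 for key, value in license_keys.items() if not isinstance(value, str)]
    return problems


def load_license_keys(path: str = "spark_jsl.json") -> dict:
    """Load the license file and fail early with a readable message."""
    with open(path) as f:
        license_keys = json.load(f)
    problems = check_license_keys(license_keys)
    if problems:
        raise ValueError("{}: {}".format(path, "; ".join(problems)))
    return license_keys
```

In the setup cell, `license_keys = load_license_keys()` could stand in for the plain `json.load` call.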

In [1]:
import json
import os

from google.colab import files

if 'spark_jsl.json' not in os.listdir():
  license_keys = files.upload()
  os.rename(list(license_keys.keys())[0], 'spark_jsl.json')

with open('spark_jsl.json') as f:
    license_keys = json.load(f)

# Defining license key-value pairs as local variables
locals().update(license_keys)
os.environ.update(license_keys)

In [2]:
# Installing pyspark and spark-nlp
! pip install --upgrade -q pyspark==3.1.2  spark-nlp==$PUBLIC_VERSION

# Installing Spark NLP Healthcare
! pip install --upgrade -q spark-nlp-jsl==$JSL_VERSION  --extra-index-url https://pypi.johnsnowlabs.com/$SECRET

# Installing Spark NLP Display Library for visualization
! pip install -q spark-nlp-display

In [3]:
import json
import os

from pyspark.ml import Pipeline, PipelineModel
from pyspark.sql import SparkSession

import sparknlp
import sparknlp_jsl

from sparknlp.annotator import *
from sparknlp_jsl.annotator import *
from sparknlp.base import *
from sparknlp.util import *
from sparknlp.pretrained import ResourceDownloader
from pyspark.sql import functions as F

import pandas as pd

pd.set_option('display.max_columns', None)
pd.set_option('display.expand_frame_repr', False)
pd.set_option('max_colwidth', None)

import string
import numpy as np

params = {"spark.driver.memory":"16G",
          "spark.kryoserializer.buffer.max":"2000M",
          "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
          "spark.driver.maxResultSize":"2000M"}

spark = sparknlp_jsl.start(license_keys['SECRET'], params=params)

print("Spark NLP Version :", sparknlp.version())
print("Spark NLP_JSL Version :", sparknlp_jsl.version())

spark

Spark NLP Version : 5.0.2
Spark NLP_JSL Version : 5.0.2


# 🔎 MODELS

<div align="center">

| **Index** | **Text2SQL models**        |
|---------------|----------------------|
| 1        |  [text2sql_mimicsql](https://nlp.johnsnowlabs.com/2023/08/14/text2sql_mimicsql_en.html)     |


</div>

## 📑  **Text2SQL_MIMICSQL**

This model is based on the Flan-T5-Large LLM and was fine-tuned by John Snow Labs on a biomedical dataset (MIMICSQL). It generates SQL queries against the MIMIC-III dataset from medical questions written in natural language.
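Generated SQL can be syntactically valid and still answer the wrong question, so it is worth executing a query before trusting it. The sketch below runs a query of the kind this model produces (copied from example output 2 later in this notebook) against a tiny in-memory SQLite database. The two tables are a hypothetical, heavily simplified subset of the MIMICSQL schema (`DEMOGRAPHIC` and `PROCEDURES`, which appear in the outputs), and every row is invented. The model writes values in double quotes, which standard SQL treats as identifiers, so the assumed helper `to_sqlite_literals` uses a regex heuristic (not a SQL parser) to turn double-quoted values after a comparison operator into single-quoted literals.

```python
import re
import sqlite3

# Heuristic, not a SQL parser: a double-quoted token right after a comparison
# operator is treated as a value. Quoted column names such as
# DEMOGRAPHIC."AGE" follow a dot, not an operator, so they are left alone.
_VALUE_AFTER_OP = re.compile(r'(<=|>=|=|<|>|\bLIKE)\s*"([^"]*)"', re.IGNORECASE)


def to_sqlite_literals(sql: str) -> str:
    """Rewrite  = "value"  as  = 'value'  (escaping any single quotes inside)."""
    return _VALUE_AFTER_OP.sub(
        lambda m: "{} '{}'".format(m.group(1), m.group(2).replace("'", "''")), sql
    )


def build_toy_db() -> sqlite3.Connection:
    """Hypothetical, simplified MIMICSQL-like tables filled with invented rows."""
    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE DEMOGRAPHIC (SUBJECT_ID INTEGER, HADM_ID INTEGER, NAME TEXT, AGE INTEGER);
        CREATE TABLE PROCEDURES (SUBJECT_ID INTEGER, HADM_ID INTEGER, SHORT_TITLE TEXT);
        INSERT INTO DEMOGRAPHIC VALUES (1, 100, 'Patient A', 54),
                                       (2, 200, 'Patient B', 67),
                                       (3, 300, 'Patient C', 45);
        INSERT INTO PROCEDURES VALUES (1, 100, '1 int mam-cor art bypass'),
                                      (2, 200, '1 int mam-cor art bypass'),
                                      (3, 300, 'Other procedure');
    """)
    return conn


# Generated query copied from example output 2 of this notebook.
generated = ('SELECT COUNT ( DISTINCT DEMOGRAPHIC."SUBJECT_ID" ) FROM DEMOGRAPHIC '
             'INNER JOIN PROCEDURES on DEMOGRAPHIC.HADM_ID = PROCEDURES.HADM_ID '
             'WHERE PROCEDURES."SHORT_TITLE" = "1 int mam-cor art bypass"')

conn = build_toy_db()
cleaned = to_sqlite_literals(generated)
count = conn.execute(cleaned).fetchone()[0]
print(cleaned)
print(count)  # 2 on this toy data: patients 1 and 2 had the procedure
```

Against a real MIMIC-III copy you would build the connection from your own database instead of `build_toy_db()`; the quoting heuristic may still miss patterns the model emits, so inspect `cleaned` before running it.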

In [5]:
document_assembler = DocumentAssembler()\
    .setInputCol("prompt")\
    .setOutputCol("document_prompt")

med_text_generator = Text2SQL.pretrained("text2sql_mimicsql", "en", "clinical/models")\
    .setInputCols("document_prompt")\
    .setOutputCol("sql_query")

pipeline = Pipeline(stages=[document_assembler, med_text_generator])

# Both stages are already trained, so fitting on an empty DataFrame only builds the PipelineModel
model = pipeline.fit(spark.createDataFrame([[""]]).toDF("prompt"))

text2sql_mimicsql download started this may take some time.
[OK!]


In [6]:
text = ["Find the average number of prescriptions per patient for patients with a specific diagnosis.",
        "give me the number of patients who had single internal mammary-coronary artery bypass.",
        "provide the drug code and drug dose for anna johnson.",
        "calculate the minimum age of married patients who had elective type hospital admission.",
        "What is the maximum age of patients who were hospitalized for 20 days and died before 2023 ?"]

data = spark.createDataFrame([(prompt,) for prompt in text], ["prompt"])

result = model.transform(data)

result.select("sql_query.result").show(truncate=False)

+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|result                                                                                                                                                                                                                                                                   |
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|[SELECT AVG ( DEMOGRAPHIC."AGE" ) FROM DEMOGRAPHIC INNER JOIN DIAGNOSES on DEMOGRAPHIC.HADM_ID = DIAGNOSES.HADM_ID INNER JOIN PRESCRIPTIONS on DEMOGRAPHIC.HADM_ID = PRESCRIPTIONS.HADM_ID WHERE DI

### **📍 LightPipelines**

In [7]:
light_model = LightPipeline(model)
light_result = light_model.annotate(text)
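`annotate` on a list of texts returns one dictionary per input, mapping each output column name to a list of result strings; the next cell reads element `[0]` of `document_prompt` and `sql_query`. If you want to review or share the prompt/SQL pairs outside the notebook, a small export like the one below is enough. It is a minimal sketch: `write_prompt_sql_pairs` is a hypothetical helper, and `sample` stands in for `light_result` using prompt and output 2 from this notebook.

```python
import csv
import io


def write_prompt_sql_pairs(light_result, fh):
    """Write one CSV row per input: the prompt and the first generated SQL string."""
    writer = csv.writer(fh)
    writer.writerow(["prompt", "sql_query"])
    for res in light_result:
        # csv.writer quotes fields and doubles the embedded double quotes in the SQL.
        writer.writerow([res["document_prompt"][0], res["sql_query"][0]])


# Stand-in for `light_result`, shaped like LightPipeline.annotate() output.
sample = [{
    "document_prompt": ["give me the number of patients who had single internal "
                        "mammary-coronary artery bypass."],
    "sql_query": ['SELECT COUNT ( DISTINCT DEMOGRAPHIC."SUBJECT_ID" ) FROM DEMOGRAPHIC '
                  'INNER JOIN PROCEDURES on DEMOGRAPHIC.HADM_ID = PROCEDURES.HADM_ID '
                  'WHERE PROCEDURES."SHORT_TITLE" = "1 int mam-cor art bypass"'],
}]

buffer = io.StringIO()
write_prompt_sql_pairs(sample, buffer)
rows = list(csv.reader(io.StringIO(buffer.getvalue())))
print(rows)
```

To write a real file, open it with `open("text2sql_results.csv", "w", newline="")` and pass `light_result` instead of `sample`; `newline=""` keeps the csv module in control of line endings.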

In [8]:
import textwrap

for i, res in enumerate(light_result, start=1):
    # Each result maps an output column name to a list of strings
    document_text = textwrap.fill(res['document_prompt'][0], width=120)
    sql_text = textwrap.fill(res['sql_query'][0], width=120)

    print("➤ User query: {}: \n{}".format(i, document_text))
    print("\n")
    print("➤ SQL query {}: \n{}".format(i, sql_text))
    print("\n")

➤ User query: 1: 
Find the average number of prescriptions per patient for patients with a specific diagnosis.


➤ SQL query 1: 
SELECT AVG ( DEMOGRAPHIC."AGE" ) FROM DEMOGRAPHIC INNER JOIN DIAGNOSES on DEMOGRAPHIC.HADM_ID = DIAGNOSES.HADM_ID INNER
JOIN PRESCRIPTIONS on DEMOGRAPHIC.HADM_ID = PRESCRIPTIONS.HADM_ID WHERE DIAGNOSES."SHORT_TITLE" = "Specific hst" AND
PRESCRIPTIONS."DRUG" = "1"


➤ User query: 2: 
give me the number of patients who had single internal mammary-coronary artery bypass.


➤ SQL query 2: 
SELECT COUNT ( DISTINCT DEMOGRAPHIC."SUBJECT_ID" ) FROM DEMOGRAPHIC INNER JOIN PROCEDURES on DEMOGRAPHIC.HADM_ID =
PROCEDURES.HADM_ID WHERE PROCEDURES."SHORT_TITLE" = "1 int mam-cor art bypass"


➤ User query: 3: 
provide the drug code and drug dose for anna johnson.


➤ SQL query 3: 
SELECT PRESCRIPTIONS."FORMULARY_DRUG_CD",PRESCRIPTIONS."DRUG_DOSE" FROM DEMOGRAPHIC INNER JOIN PRESCRIPTIONS on
DEMOGRAPHIC.HADM_ID = PRESCRIPTIONS.HADM_ID WHERE DEMOGRAPHIC."NAME" = "Anna Johnson