![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Healthcare/Spark%20v2.7.6%20Notebooks/18.Text2SQL.ipynb)

# Text2SQL (requires Spark NLP Enterprise v2.7 or later)

In [None]:
import json

from google.colab import files

# files.upload() returns a mapping keyed by the uploaded file name. Keep it under
# its own name so `license_keys` only ever refers to the parsed license JSON.
uploaded_files = files.upload()
if not uploaded_files:
    raise ValueError("No license file was uploaded.")

# Only the first uploaded file is read, as before.
license_file_name = next(iter(uploaded_files))
with open(license_file_name) as f:
    license_keys = json.load(f)

In [None]:
%%capture
# Export every key/value pair of the license JSON as an environment variable.
for k,v in license_keys.items(): 
    %set_env $k=$v

# Download and run the setup script from the spark-nlp-workshop repository.
# NOTE(review): `-p 2.4.4` is presumably the PySpark version to install — confirm against the script.
!wget https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/jsl_colab_setup.sh
!bash jsl_colab_setup.sh -p 2.4.4

In [None]:
import json
import os
from pyspark.ml import Pipeline,PipelineModel
from pyspark.sql import SparkSession

from sparknlp.annotator import *
from sparknlp_jsl.annotator import *
from sparknlp.base import *
import sparknlp_jsl
import sparknlp
from pyspark.sql import functions as F

# Extra Spark configuration handed to sparknlp_jsl.start below.
params = {"spark.driver.memory":"16G",
"spark.kryoserializer.buffer.max":"2000M",
"spark.driver.maxResultSize":"2000M"}

# Start the session with the SECRET entry of the uploaded license JSON;
# `spark` is used later to build the input DataFrames.
spark = sparknlp_jsl.start(license_keys['SECRET'],params=params)

# Print the versions of both libraries.
print (sparknlp.version())
print (sparknlp_jsl.version())

2.7.4
2.7.6


## Convert an SQLite schema to schema JSON

### Explore SQLite tables

In [None]:
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/text2sql/university_basketball.sqlite

In [None]:
import sqlite3

# Open the downloaded database and list the tables recorded in sqlite_master.
conn = sqlite3.connect('university_basketball.sqlite')
cursor = conn.cursor()
tables = cursor.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
tables

[('basketball_match',), ('university',)]

In [None]:
import pandas as pd

from IPython.display import display, HTML

# Each entry of `tables` is a 1-tuple holding the table name; show every table in full.
for (table_name,) in tables:
    table = pd.read_sql_query("SELECT * from %s" % table_name, conn)
    display(table)

Unnamed: 0,Team_ID,School_ID,Team_Name,ACC_Regular_Season,ACC_Percent,ACC_Home,ACC_Road,All_Games,All_Games_Percent,All_Home,All_Road,All_Neutral
0,1,1,North Carolina,14–2,0.875,6–2,8–0,35–2,0.946,14–2,13–0,9–1
1,2,2,Duke,13–3,0.813,7–1,6–2,28–6,0.824,15–1,8–2,5–3
2,3,4,Clemson,10–6,0.625,7–1,3–5,24–10,0.706,14–2,6–5,4–3
3,4,5,Virginia Tech,9–7,0.563,6–2,3–5,21–14,0.6,14–3,4–8,3–3


Unnamed: 0,School_ID,School,Location,Founded,Affiliation,Enrollment,Nickname,Primary_conference
0,1,University of Delaware,"Newark, DE",1743.0,Public,19067.0,Fightin' Blue Hens,Colonial Athletic Association ( D-I )
1,2,Lebanon Valley College,"Annville, PA",1866.0,Private/Methodist,2100.0,Flying Dutchmen,MAC Commonwealth Conference ( D-III )
2,3,University of Rhode Island,"Kingston, RI",1892.0,Public,19095.0,Rams,Atlantic 10 Conference ( D-I )
3,4,Rutgers University,"New Brunswick, NJ",1766.0,Public,56868.0,Scarlet Knights,American Athletic Conference ( D-I )
4,5,Stony Brook University,"Stony Brook, NY",1957.0,Public,23997.0,Seawolves,America East Conference ( D-I )


### Convert to Text2SQL format

In [None]:
from sparknlp_jsl._tf_graph_builders.text2sql.util import sqlite2json

# Path handed to sqlite2json as its second argument; the same path is later
# given to setSchemaPath when the pipelines are built.
schema_json_path = 'schema_converted.json'

# NOTE(review): presumably writes the converted schema to schema_json_path — confirm.
sqlite2json("university_basketball.sqlite",schema_json_path)


reading db:  university_basketball.sqlite


## Prepare DB schema

This is a one-time process for every new DB schema that you want to work on.

In [None]:
def prepare_db_schema(schema_json_path, output_json_path):
  """Embed the table metadata of a schema JSON and export it to ``output_json_path``.

  Builds and runs a Spark pipeline on a single row with empty text:
  DocumentAssembler -> Text2SQLSchemaParser -> Chunk2Doc -> Tokenizer ->
  WordEmbeddingsModel ("glove_6B_300") -> ChunkEmbeddings -> Text2SQLSchemaExporter.

  Args:
    schema_json_path: Passed to ``Text2SQLSchemaParser.setSchemaPath``.
    output_json_path: Passed to ``Text2SQLSchemaExporter.setOutputPath``.

  Returns:
    None. The transformed DataFrame is printed with ``show()``.
  """

  document = DocumentAssembler()\
      .setInputCol("text")\
      .setOutputCol("document")

  # Reads the schema file and emits its metadata as chunks.
  tables = Text2SQLSchemaParser() \
      .setOutputCol("table_metadata_chunk") \
      .setSchemaPath(schema_json_path) \
      .setInputCols(["document"])

  # Turn the chunks back into documents so they can be tokenized and embedded.
  chunk2doc = Chunk2Doc() \
      .setInputCols(["table_metadata_chunk"]) \
      .setOutputCol("table_metadata_doc")

  table_tokenizer = Tokenizer() \
      .setOutputCol("table_token") \
      .setInputCols(["table_metadata_doc"])

  table_embedding = WordEmbeddingsModel.pretrained("glove_6B_300", "xx") \
      .setInputCols(["table_metadata_doc", "table_token"]) \
      .setOutputCol("table_embedding")

  table_chunk_embeddings = ChunkEmbeddings() \
      .setOutputCol("table_metadata_chunk_embedding") \
      .setInputCols("table_metadata_chunk", "table_embedding")

  table_exporter = Text2SQLSchemaExporter()\
      .setInputCols(["table_metadata_chunk_embedding","table_metadata_chunk"])\
      .setOutputPath(output_json_path)

  table_pl = Pipeline() \
      .setStages([
      document,
      tables,
      chunk2doc,
      table_tokenizer,
      table_embedding,
      table_chunk_embeddings,
      table_exporter
  ])

  # The text is empty on purpose: the parser stage takes its content from the schema file.
  data = spark.createDataFrame([
              [1, ""]]) \
              .toDF("id", "text").cache()

  schema_df = table_pl.fit(data).transform(data)

  # transform() is lazy in Spark; show() is the action that actually executes
  # the pipeline, so run it before reporting success.
  schema_df.show()

  # Release the cached one-row input frame now that the pipeline has run.
  data.unpersist()

  print (output_json_path, 'is created and saved')

In [None]:
# Input: the schema JSON produced by sqlite2json above.
schema_json_path = 'schema_converted.json'
# Output: the path the exporter stage is configured to write to.
output_json_path = "db_embeddings.json"

prepare_db_schema(schema_json_path, output_json_path)

glove_6B_300 download started this may take some time.
Approximate size to download 426.2 MB
[OK!]
db_embeddings.json is created and saved
+---+----+--------------------+--------------------+--------------------+--------------------+--------------------+------------------------------+--------------------+
| id|text|            document|table_metadata_chunk|  table_metadata_doc|         table_token|     table_embedding|table_metadata_chunk_embedding|              export|
+---+----+--------------------+--------------------+--------------------+--------------------+--------------------+------------------------------+--------------------+
|  1|    |[[document, 0, -1...|[[chunk, 0, 15, b...|[[document, 0, 15...|[[token, 0, 9, ba...|[[word_embeddings...|          [[word_embeddings...|[[chunk, 0, 15, b...|
+---+----+--------------------+--------------------+--------------------+--------------------+--------------------+------------------------------+--------------------+



## Prepare Text2SQL pipeline

This is a one-time process for every new DB schema that you want to work on.

In [None]:
def get_text2sql_model(schema_json_path, output_json_path):
  """Build a LightPipeline that turns a question into SQL for one schema.

  Stages, in order: DocumentAssembler, SentenceDetectorDLModel, Tokenizer,
  WordEmbeddingsModel ("glove_6B_300") and Text2SQLModel ("text2sql_glove").
  The pipeline is fitted on a single empty-text row and wrapped in a LightPipeline.

  Args:
    schema_json_path: Passed to ``Text2SQLModel.setSchemaPath``.
    output_json_path: Passed to ``Text2SQLModel.setTableEmbeddingPath``
      (the file prepared by ``prepare_db_schema``).

  Returns:
    The fitted pipeline wrapped in a ``LightPipeline``.
  """
  assembler = DocumentAssembler().setInputCol("text").setOutputCol("document")

  sentence_splitter = SentenceDetectorDLModel.pretrained() \
      .setInputCols("document") \
      .setOutputCol("sentence")

  question_tokenizer = Tokenizer().setInputCols("sentence").setOutputCol("token")

  question_embedding = WordEmbeddingsModel.pretrained("glove_6B_300", "xx") \
      .setInputCols(["sentence", "token"]) \
      .setOutputCol("question_embedding")

  # NOTE(review): "chunk_emb" and "table_metadata_chunk" are not produced by any
  # stage of this pipeline; presumably the model takes them from the table
  # embedding file — confirm.
  sql_generator = Text2SQLModel.pretrained('text2sql_glove', 'en', 'clinical/models') \
      .setSchemaPath(schema_json_path) \
      .setTableEmbeddingPath(output_json_path) \
      .setInputCols(["token", "question_embedding", "chunk_emb", "table_metadata_chunk"]) \
      .setOutputCol("sql")

  stages = [
      assembler,
      sentence_splitter,
      question_tokenizer,
      question_embedding,
      sql_generator,
  ]
  sql_pipeline = Pipeline().setStages(stages)

  # Every stage is pretrained or stateless, so an empty frame is enough to fit.
  empty_frame = spark.createDataFrame([[""]]).toDF("text")
  fitted_pipeline = sql_pipeline.fit(empty_frame)

  light_pipeline = LightPipeline(fitted_pipeline)
  print('text2sql prediction model is built')
  return light_pipeline

In [None]:
sql_prediction_light = get_text2sql_model (schema_json_path, output_json_path)

sentence_detector_dl download started this may take some time.
Approximate size to download 354.6 KB
[OK!]
glove_6B_300 download started this may take some time.
Approximate size to download 426.2 MB
[OK!]
text2sql_glove download started this may take some time.
Approximate size to download 37.6 MB
[OK!]
text2sql prediction model is built


## Example queries

In [None]:
import sqlparse

def annotate_and_print(question, sql_light=sql_prediction_light, markdown=False, param=None, connection=None):
    """Translate a question to SQL, print the SQL, and show the query result.

    Args:
        question: Natural-language question passed to ``sql_light.annotate``.
        sql_light: Pipeline whose ``annotate`` result holds the generated SQL
            under the ``"sql"`` key. The default is bound when this cell is
            executed, so it keeps pointing at the ``sql_prediction_light``
            object that existed at that moment.
        markdown: If True, print the result as a markdown table; otherwise
            render it with ``display``.
        param: Forwarded to ``pd.read_sql`` as ``params``.
        connection: Database connection to run the SQL on. When omitted, the
            notebook-level ``conn`` is looked up at call time, i.e. whichever
            database was opened last. Pass it explicitly to avoid running a
            query against the wrong database.

    Returns:
        None. Everything is printed or displayed.
    """
    predictions = sql_light.annotate(question)["sql"]
    if not predictions:
        # Avoid an opaque IndexError when the pipeline yields no SQL.
        print("No SQL was generated for this question.")
        return
    sql = predictions[0]

    if connection is None:
        connection = conn

    print(sqlparse.format(sql, reindent=True, keyword_case='upper'))
    print("\n")

    result = pd.read_sql(sql, connection, params=param)
    if markdown:
        print(result.to_markdown())
    else:
        display(result)
    

In [None]:
annotate_and_print("What are the enrollment and primary conference for the university which was founded the earliest?")

SELECT T1.Enrollment,
       T1.Primary_conference
FROM university AS T1
ORDER BY T1.Founded ASC
LIMIT 1




Unnamed: 0,Enrollment,Primary_conference
0,19067.0,Colonial Athletic Association ( D-I )


In [None]:
annotate_and_print("What are the enrollment and primary conference for the university which was founded the earliest?", markdown=True)

SELECT T1.Enrollment,
       T1.Primary_conference
FROM university AS T1
ORDER BY T1.Founded ASC
LIMIT 1


|    |   Enrollment | Primary_conference                    |
|---:|-------------:|:--------------------------------------|
|  0 |        19067 | Colonial Athletic Association ( D-I ) |


In [None]:
annotate_and_print("What is the total and minimum enrollment of all schools?")

SELECT sum(T1.Enrollment),
       min(T1.Enrollment)
FROM university AS T1




Unnamed: 0,sum(T1.Enrollment),min(T1.Enrollment)
0,121127.0,2100.0


In [None]:
# NOTE: this is a paraphrase of the previous question. In the recorded output the
# generated SQL reads basketball_match.All_Neutral instead of university.Enrollment,
# so the result shown does not answer the question.
annotate_and_print("Return the total and minimum enrollments across all schools.")

SELECT sum(T1.All_Neutral),
       min(T1.All_Neutral)
FROM basketball_match AS T1




Unnamed: 0,sum(T1.All_Neutral),min(T1.All_Neutral)
0,21.0,3–3


In [None]:
annotate_and_print("Find the total student enrollment for different affiliation type schools.")

SELECT sum(T1.Enrollment),
       T1.Affiliation
FROM university AS T1
GROUP BY T1.Affiliation




Unnamed: 0,sum(T1.Enrollment),Affiliation
0,2100.0,Private/Methodist
1,119027.0,Public


In [None]:
# NOTE: the recorded SQL uses count(...) without DISTINCT, so it returns the number
# of university rows (5) rather than the number of distinct affiliation types.
annotate_and_print("Find how many different affiliation types there are.")

SELECT count(T1.Affiliation)
FROM university AS T1




Unnamed: 0,count(T1.Affiliation)
0,5


## Use case: Hospital records

In [None]:
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/text2sql/hospital_records.sqlite

In [None]:
import sqlite3

# Use the relative path that the download cell wrote to (and that sqlite2json
# reads below) instead of a Colab-only absolute '/content/...' location.
# NOTE: this rebinds the notebook-level `conn`, which annotate_and_print reads
# when it runs a query.
conn = sqlite3.connect('hospital_records.sqlite')

cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = cursor.fetchall()
tables

[('Physician',),
 ('Department',),
 ('Affiliated_With',),
 ('Procedures',),
 ('Trained_In',),
 ('Patient',),
 ('Nurse',),
 ('Appointment',),
 ('Medication',),
 ('Prescribes',),
 ('Block',),
 ('Room',),
 ('On_Call',),
 ('Stay',),
 ('Undergoes',)]

In [None]:
import pandas as pd

from IPython.display import display, HTML

# Each entry of `tables` is a 1-tuple holding the table name; preview five rows per table.
for (table_name,) in tables:
    table = pd.read_sql_query("SELECT * from %s" % table_name, conn)
    print(table_name)
    display(table.head(5))
    print('========')

Physician


Unnamed: 0,EmployeeID,Name,Position,SSN
0,1,John Dorian,Staff Internist,111111111
1,2,Elliot Reid,Attending Physician,222222222
2,3,Christopher Turk,Surgical Attending Physician,333333333
3,4,Percival Cox,Senior Attending Physician,444444444
4,5,Bob Kelso,Head Chief of Medicine,555555555


Department


Unnamed: 0,DepartmentID,Name,Head
0,1,General Medicine,4
1,2,Surgery,7
2,3,Psychiatry,9


Affiliated_With


Unnamed: 0,Physician,Department,PrimaryAffiliation
0,1,1,1
1,2,1,1
2,3,1,0
3,3,2,1
4,4,1,1


Procedures


Unnamed: 0,Code,Name,Cost
0,1,Reverse Rhinopodoplasty,1500.0
1,2,Obtuse Pyloric Recombobulation,3750.0
2,3,Folded Demiophtalmectomy,4500.0
3,4,Complete Walletectomy,10000.0
4,5,Obfuscated Dermogastrotomy,4899.0


Trained_In


Unnamed: 0,Physician,Treatment,CertificationDate,CertificationExpires
0,3,1,2008-01-01,2008-12-31
1,3,2,2008-01-01,2008-12-31
2,3,5,2008-01-01,2008-12-31
3,3,6,2008-01-01,2008-12-31
4,3,7,2008-01-01,2008-12-31


Patient


Unnamed: 0,SSN,Name,Address,Phone,InsuranceID,PCP
0,100000001,John Smith,42 Foobar Lane,555-0256,68476213,1
1,100000002,Grace Ritchie,37 Snafu Drive,555-0512,36546321,2
2,100000003,Random J. Patient,101 Omgbbq Street,555-1204,65465421,2
3,100000004,Dennis Doe,1100 Foobaz Avenue,555-2048,68421879,3


Nurse


Unnamed: 0,EmployeeID,Name,Position,Registered,SSN
0,101,Carla Espinosa,Head Nurse,1,111111110
1,102,Laverne Roberts,Nurse,1,222222220
2,103,Paul Flowers,Nurse,0,333333330


Appointment


Unnamed: 0,AppointmentID,Patient,PrepNurse,Physician,Start,End,ExaminationRoom
0,13216584,100000001,101.0,1,2008-04-24 10:00,2008-04-24 11:00,A
1,26548913,100000002,101.0,2,2008-04-24 10:00,2008-04-24 11:00,B
2,36549879,100000001,102.0,1,2008-04-25 10:00,2008-04-25 11:00,A
3,46846589,100000004,103.0,4,2008-04-25 10:00,2008-04-25 11:00,B
4,59871321,100000004,,4,2008-04-26 10:00,2008-04-26 11:00,C


Medication


Unnamed: 0,Code,Name,Brand,Description
0,1,Procrastin-X,X,
1,2,Thesisin,Foo Labs,
2,3,Awakin,Bar Laboratories,
3,4,Crescavitin,Baz Industries,
4,5,Melioraurin,Snafu Pharmaceuticals,


Prescribes


Unnamed: 0,Physician,Patient,Medication,Date,Appointment,Dose
0,1,100000001,1,2008-04-24 10:47,13216584.0,5
1,9,100000004,2,2008-04-27 10:53,86213939.0,10
2,9,100000004,2,2008-04-30 16:53,,5


Block


Unnamed: 0,BlockFloor,BlockCode
0,1,1
1,1,2
2,1,3
3,2,1
4,2,2


Room


Unnamed: 0,RoomNumber,RoomType,BlockFloor,BlockCode,Unavailable
0,101,Single,1,1,0
1,102,Single,1,1,0
2,103,Single,1,1,0
3,111,Single,1,2,0
4,112,Single,1,2,1


On_Call


Unnamed: 0,Nurse,BlockFloor,BlockCode,OnCallStart,OnCallEnd
0,101,1,1,2008-11-04 11:00,2008-11-04 19:00
1,101,1,2,2008-11-04 11:00,2008-11-04 19:00
2,102,1,3,2008-11-04 11:00,2008-11-04 19:00
3,103,1,1,2008-11-04 19:00,2008-11-05 03:00
4,103,1,2,2008-11-04 19:00,2008-11-05 03:00


Stay


Unnamed: 0,StayID,Patient,Room,StayStart,StayEnd
0,3215,100000001,111,2008-05-01,2008-05-04
1,3216,100000003,123,2008-05-03,2008-05-14
2,3217,100000004,112,2008-05-02,2008-05-03


Undergoes


Unnamed: 0,Patient,Procedures,Stay,DateUndergoes,Physician,AssistingNurse
0,100000001,6,3215,2008-05-02,3,101
1,100000001,2,3215,2008-05-03,7,101
2,100000004,1,3217,2008-05-07,3,102
3,100000004,5,3217,2008-05-09,6,105
4,100000001,7,3217,2008-05-10,7,101




In [None]:
from sparknlp_jsl._tf_graph_builders.text2sql.util import sqlite2json

sqlite2json("hospital_records.sqlite","hospital_schema_converted.json")


reading db:  hospital_records.sqlite


In [None]:
# NOTE: these rebind the notebook-level path names used for the basketball schema;
# re-running an earlier cell that reads them will now pick up the hospital files.
schema_json_path = "hospital_schema_converted.json"

output_json_path = "hospital_db_embeddings.json"

# Export the table embeddings for the hospital schema.
prepare_db_schema(schema_json_path, output_json_path)

# Build a separate LightPipeline for the hospital schema; it is passed explicitly
# to annotate_and_print in the cells below.
hospital_sql_prediction_light = get_text2sql_model (schema_json_path, output_json_path)

glove_6B_300 download started this may take some time.
Approximate size to download 426.2 MB
[OK!]
hospital_db_embeddings.json is created and saved
+---+----+--------------------+--------------------+--------------------+--------------------+--------------------+------------------------------+--------------------+
| id|text|            document|table_metadata_chunk|  table_metadata_doc|         table_token|     table_embedding|table_metadata_chunk_embedding|              export|
+---+----+--------------------+--------------------+--------------------+--------------------+--------------------+------------------------------+--------------------+
|  1|    |[[document, 0, -1...|[[chunk, 0, 8, ph...|[[document, 0, 8,...|[[token, 0, 8, ph...|[[word_embeddings...|          [[word_embeddings...|[[chunk, 0, 8, ph...|
+---+----+--------------------+--------------------+--------------------+--------------------+--------------------+------------------------------+--------------------+

sentence_de

In [None]:
annotate_and_print("Find the id of the appointment with the most recent start date", hospital_sql_prediction_light)

SELECT T1.Appointment
FROM Prescribes AS T1
JOIN Appointment AS T2 ON T1.Appointment = T2.AppointmentID
ORDER BY T2.Start DESC
LIMIT 1




Unnamed: 0,Appointment
0,86213939


In [None]:
annotate_and_print("What is the name of the patient who made the most recent appointment", hospital_sql_prediction_light, markdown=True)

SELECT T1.Name
FROM Patient AS T1
JOIN Appointment AS T2 ON T1.SSN = T2.Patient
ORDER BY T2.Start DESC
LIMIT 1


|    | Name       |
|---:|:-----------|
|  0 | Dennis Doe |


In [None]:
annotate_and_print("What is the name of the nurse has the most appointments?", hospital_sql_prediction_light)

SELECT T1.Name
FROM Nurse AS T1
JOIN Appointment AS T2 ON T1.EmployeeID = T2.PrepNurse
GROUP BY T2.prepnurse
ORDER BY count(*) DESC
LIMIT 1




Unnamed: 0,Name
0,Carla Espinosa


In [None]:
annotate_and_print("What is the name of the nurse has the most appointments?", hospital_sql_prediction_light, markdown=True)

SELECT T1.Name
FROM Nurse AS T1
JOIN Appointment AS T2 ON T1.EmployeeID = T2.PrepNurse
GROUP BY T2.prepnurse
ORDER BY count(*) DESC
LIMIT 1


|    | Name           |
|---:|:---------------|
|  0 | Carla Espinosa |


In [None]:
annotate_and_print("How many patients do each physician take care of? List their names and number of patients they take care of.", hospital_sql_prediction_light)

SELECT T1.Name,
       count(*)
FROM Physician AS T1
JOIN Patient AS T2 ON T1.EmployeeID = T2.PCP
GROUP BY T1.Name




Unnamed: 0,Name,count(*)
0,Christopher Turk,1
1,Elliot Reid,2
2,John Dorian,1


In [None]:
annotate_and_print("How many patients do each physician take care of? List their names and number of patients they take care of.", hospital_sql_prediction_light, markdown=True)

SELECT T1.Name,
       count(*)
FROM Physician AS T1
JOIN Patient AS T2 ON T1.EmployeeID = T2.PCP
GROUP BY T1.Name


|    | Name             |   count(*) |
|---:|:-----------------|-----------:|
|  0 | Christopher Turk |          1 |
|  1 | Elliot Reid      |          2 |
|  2 | John Dorian      |          1 |


In [None]:
annotate_and_print("Give me trhe name of the departments", hospital_sql_prediction_light, markdown=True)

SELECT T1.Name
FROM Department AS T1


|    | Name             |
|---:|:-----------------|
|  0 | General Medicine |
|  1 | Surgery          |
|  2 | Psychiatry       |


In [None]:
annotate_and_print("What is the most expensive procedure?", hospital_sql_prediction_light)

SELECT T1.Name
FROM Procedures AS T1
ORDER BY T1.Cost DESC
LIMIT 1




Unnamed: 0,Name
0,Complete Walletectomy


In [None]:
annotate_and_print("What is the cheapest procedure?", hospital_sql_prediction_light)

SELECT T1.Name
FROM Procedures AS T1
ORDER BY T1.Cost ASC
LIMIT 1




Unnamed: 0,Name
0,Follicular Demiectomy
