In [1]:
# # Add '--use-deprecated=legacy-resolver' to resolve dependency error

# !pip install -q python-dotenv --use-deprecated=legacy-resolver
# !pip install -q langchain --use-deprecated=legacy-resolver
# !pip install -q langchain_community --use-deprecated=legacy-resolver
# !pip install -q langchain_experimental --use-deprecated=legacy-resolver

# # If using Cohere:
# !pip install -q langchain-cohere --use-deprecated=legacy-resolver

# # Using IRIS:
# !pip install -q langchain-iris --use-deprecated=legacy-resolver

In [2]:
import os
from dotenv import load_dotenv, find_dotenv
from langchain_cohere import ChatCohere
from langchain.prompts.prompt import PromptTemplate
from langchain.sql_database import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain_iris import IRISVector

In [3]:
# # https://docs.cohere.com/docs/cohere-and-langchain
# # class `Cohere` was deprecated in LangChain 0.1.14

_ = load_dotenv(find_dotenv()) # read local .env file
cohere_api_key = "YOUR_COHERE_API_KEY"

# https://docs.cohere.com/docs/models
model = "command"

temperature = 0
llm = ChatCohere(model=model,temperature=0, cohere_api_key=cohere_api_key)

In [4]:
username = 'superuser'
password = 'sys'
hostname = os.getenv('IRIS_HOSTNAME', 'localhost')
port = '1972'
namespace = 'TEST'
CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
print(CONNECTION_STRING)

iris://superuser:sys@localhost:1972/TEST


In [15]:
# https://community.intersystems.com/post/langchain-fixed-sql-me

_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.

The only table available is DiseaseProfile.

The columns are Disease VARCHAR(512), Fever VARCHAR(25), Cough VARCHAR(25), Fatigue VARCHAR(25), DifficultyBreathing VARCHAR(25), Age INT, Gender VARCHAR(25), BloodPressure VARCHAR(25), CholesterolLevel VARCHAR(25), OutcomeVariable VARCHAR(25).
Fever, Cough, Fatigue and DifficultyBreathing are potential symptoms which the patients are experiencing. 

Columns and Usage:

Disease: The name of the disease or medical condition.
Fever: Indicates whether the patient has a fever (Yes/No).
Cough: Indicates whether the patient has a cough (Yes/No).
Fatigue: Indicates whether the patient experiences fatigue (Yes/No).
Difficulty Breathing: Indicates whether the patient has difficulty breathing (Yes/No).
Age: The age of the patient in years.
Gender: The gender of the patient (Male/Female).
Blood Pressure: The blood pressure level of the patient (Normal/High).
Cholesterol Level: The cholesterol level of the patient (Normal/High).
Outcome Variable: The outcome variable indicating the result of the diagnosis or assessment for the specific disease (Positive/Negative).

Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"

The SQL query should NOT end with semi-colon or "```"
Question: {input}"""

PROMPT = PromptTemplate(
    input_variables=["input", "dialect"], template=_DEFAULT_TEMPLATE
)

db_sql = SQLDatabase.from_uri(CONNECTION_STRING) 

db_chain = SQLDatabaseChain.from_llm(llm=llm, db=db_sql, prompt=PROMPT, verbose=True) 

In [13]:
db_chain.invoke('how many rows are there')



[1m> Entering new SQLDatabaseChain chain...[0m
how many rows are there
SQLQuery:[32;1m[1;3mSELECT COUNT(*) FROM DiseaseProfile
SQLResult:[0m
SQLResult: [33;1m[1;3m[(349,)][0m
Answer:[32;1m[1;3mSQLQuery: SELECT COUNT(*) FROM DiseaseProfile
SQLResult:[0m
[1m> Finished chain.[0m


{'query': 'how many rows are there',
 'result': 'SQLQuery: SELECT COUNT(*) FROM DiseaseProfile\nSQLResult:'}

In [16]:
db_chain.invoke('Find the disease names and number of occurances among patients below 30 who have normal cholesterol level')



[1m> Entering new SQLDatabaseChain chain...[0m
Find the disease names and number of occurances among patients below 30 who have normal cholesterol level
SQLQuery:[32;1m[1;3mSELECT Disease, COUNT(*) FROM DiseaseProfile WHERE Age < 30 AND CholesterolLevel = 'Normal' GROUP BY Disease;[0m
SQLResult: [33;1m[1;3m[('ASTHMA', 3), ('COMMON COLD', 2), ('DEPRESSION', 1), ('DIABETES', 1), ('ECZEMA', 2), ('GASTROENTERITIS', 1), ('HYPERTHYROIDISM', 2), ('INFLUENZA', 3), ('LIVER CANCER', 1), ('PANCREATITIS', 1), ('STROKE', 1)][0m
Answer:[32;1m[1;3m```sql
SELECT Disease, COUNT(*) FROM DiseaseProfile WHERE Age < 30 AND CholesterolLevel = 'Normal' GROUP BY Disease;
```
```json
[
    ['Disease', 'Count'],
    ['ASTHMA', 3L],
    ['COMMON COLD', 2L],
    ['DEPRESSION', 1L],
    ['DIABETES', 1L],
    ['ECZEMA', 2L],
    ['GASTROENTERITIS', 1L],
    ['HYPERTHYROIDISM', 2L],
    ['INFLUENZA', 3L],
    ['LIVER CANCER', 1L],
    ['PANCREATITIS', 1L],
    ['STROKE', 1L]
]
```[0m
[1m> Finished chai

{'query': 'Find the disease names and number of occurances among patients below 30 who have normal cholesterol level',
 'result': "```sql\nSELECT Disease, COUNT(*) FROM DiseaseProfile WHERE Age < 30 AND CholesterolLevel = 'Normal' GROUP BY Disease;\n```\n```json\n[\n    ['Disease', 'Count'],\n    ['ASTHMA', 3L],\n    ['COMMON COLD', 2L],\n    ['DEPRESSION', 1L],\n    ['DIABETES', 1L],\n    ['ECZEMA', 2L],\n    ['GASTROENTERITIS', 1L],\n    ['HYPERTHYROIDISM', 2L],\n    ['INFLUENZA', 3L],\n    ['LIVER CANCER', 1L],\n    ['PANCREATITIS', 1L],\n    ['STROKE', 1L]\n]\n```"}

In [12]:
# # THIS WILL DELETE THE SQL TABLE ON MANAGEMENT PORTAL [!!!] --> remove DDLAllowed

# db_chain.invoke('delete the table')