In [None]:
# This is a visualisation built with the Plotly library and Dash, served as a small web app.
# It is more interactive than matplotlib or PandasAI charts (you can hover over the graph to see a breakdown of the data).
# Not yet tested with a larger database.

In [21]:
!pip install "notebook>=5.3" "ipywidgets>=7.5"
!pip install plotly==5.22.0
!pip install dash
!pip install chart_studio

Collecting chart_studio
  Downloading chart_studio-1.1.0-py3-none-any.whl.metadata (1.3 kB)
Downloading chart_studio-1.1.0-py3-none-any.whl (64 kB)
   ---------------------------------------- 0.0/64.4 kB ? eta -:--:--
   ------------------------- -------------- 41.0/64.4 kB 960.0 kB/s eta 0:00:01
   ---------------------------------------- 64.4/64.4 kB 1.2 MB/s eta 0:00:00
Installing collected packages: chart_studio
Successfully installed chart_studio-1.1.0


In [1]:
import os
import pandas as pd
from pandasai import Agent
import dash
from dash import dcc
from dash import html 
from dash.dependencies import Input, Output
import plotly.express as px
import cohere

# By default, unless you choose a different LLM, PandasAI uses BambooLLM.
# You can get a free API key by signing up at https://pandabi.ai (you can also configure it in your .env file).
# setdefault keeps a key that is already exported in the environment instead of
# clobbering it with this empty placeholder; export PANDASAI_API_KEY before starting
# Jupyter rather than pasting the key into the notebook.
os.environ.setdefault("PANDASAI_API_KEY", "")

In [15]:
from urllib.parse import quote_plus

# IRIS connection settings, read from the environment so credentials are not saved in
# the notebook. Defaults reproduce the previous hard-coded values.
username = os.getenv('IRIS_USERNAME', '')
password = os.getenv('IRIS_PASSWORD', '')
hostname = os.getenv('IRIS_HOSTNAME', 'localhost')
port = os.getenv('IRIS_PORT', '1972')
namespace = os.getenv('IRIS_NAMESPACE', '')

# Percent-encode the credentials: characters such as '@', ':' or '/' in a password
# would otherwise be parsed as URL delimiters and corrupt the connection string.
CONNECTION_STRING = f"iris://{quote_plus(username)}:{quote_plus(password)}@{hostname}:{port}/{namespace}"

# Echo the URL for debugging, but never the password: cell output is stored in the file.
print(f"iris://{quote_plus(username)}:{'***' if password else ''}@{hostname}:{port}/{namespace}")

iris://:@localhost:1972/


In [3]:
# https://community.intersystems.com/post/langchain-fixed-sql-me
# https://community.intersystems.com/post/using-sqlalchemy-transfer-tables-and-iris
import sqlalchemy as db

# Build the engine from the iris:// URL assembled in the connection-settings cell;
# create_engine needs a full dialect URL (a blank string cannot be parsed and raises).
engine = db.create_engine(CONNECTION_STRING)
# Shared connection used by the pd.read_sql_table calls further down.
connection = engine.connect()

In [None]:
# One-time load of disease_data.csv into the IRIS table SQLUser.Location.
# The import is skipped when the table already exists, so Restart & Run All does not
# re-upload the CSV; set RELOAD_LOCATION = True to replace the table after editing the CSV.
RELOAD_LOCATION = False
csv_file_path = 'disease_data.csv'

if RELOAD_LOCATION or not db.inspect(engine).has_table('Location', schema='SQLUser'):
    df = pd.read_csv(csv_file_path)
    # Write to the same schema that the read_sql_table calls read from.
    df.to_sql('Location', engine, schema='SQLUser', if_exists='replace', index=False)
    print("Data has been successfully transferred to the IRIS database.")
else:
    print("Table SQLUser.Location already exists; skipping CSV import.")

In [19]:
# One-time load of Disease_symptom_and_patient_profile_dataset.csv into the IRIS table
# SQLUser.DiseaseProfile. Skipped when the table already exists, so Restart & Run All
# does not re-upload it; set RELOAD_DISEASE_PROFILE = True to replace the table.
RELOAD_DISEASE_PROFILE = False
csv_file_path = 'Disease_symptom_and_patient_profile_dataset.csv'

if RELOAD_DISEASE_PROFILE or not db.inspect(engine).has_table('DiseaseProfile', schema='SQLUser'):
    df = pd.read_csv(csv_file_path)
    # Write to the same schema that the read_sql_table call reads from.
    df.to_sql('DiseaseProfile', engine, schema='SQLUser', if_exists='replace', index=False)
    print("Data has been successfully transferred to the IRIS database.")
else:
    print("Table SQLUser.DiseaseProfile already exists; skipping CSV import.")

In [4]:
# Pull the entire SQLUser.DiseaseProfile table into a pandas DataFrame over the
# shared connection.
DiseaseProfile = pd.read_sql_table(
    table_name='DiseaseProfile',
    con=connection,
    schema="SQLUser",
)

In [5]:
# Preview the first five rows.
DiseaseProfile.head(n=5)

Unnamed: 0,Disease,Fever,Cough,Fatigue,Difficulty_Breathing,Age,Gender,Blood_Pressure,Cholesterol_Level,Outcome_Variable
0,Influenza,Yes,No,Yes,Yes,19,Female,Low,Normal,Positive
1,Common Cold,No,Yes,Yes,No,25,Female,Normal,Normal,Negative
2,Eczema,No,Yes,Yes,No,25,Female,Normal,Normal,Negative
3,Asthma,Yes,Yes,No,Yes,25,Male,Normal,Normal,Positive
4,Asthma,Yes,Yes,No,Yes,25,Male,Normal,Normal,Positive


In [8]:
from pandasai import SmartDataframe

from langchain_cohere import ChatCohere

# Wrap DiseaseProfile in a PandasAI SmartDataframe backed by a Cohere chat model
# (through LangChain) instead of the default BambooLLM.

# Read the key from the environment rather than pasting it into the notebook:
# saved notebooks (and their version-control history) would otherwise leak it.
cohere_api_key = os.environ.get("COHERE_API_KEY", "")

# https://docs.cohere.com/docs/models
model = "command"

# temperature 0 for the most repeatable answers to the same question.
temperature = 0
llm = ChatCohere(model=model, temperature=temperature, cohere_api_key=cohere_api_key)
smart_df = SmartDataframe(DiseaseProfile, name="DiseaseProfile", description="Dataset used to generate SQL query", config={"llm": llm})

In [9]:
# https://community.intersystems.com/post/langchain-fixed-sql-me
from langchain.sql_database import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.prompts.prompt import PromptTemplate

# Prompt for the text-to-SQL chain. Column names and value sets must match the real
# DiseaseProfile table exactly, otherwise the LLM generates queries that return nothing
# (e.g. the sample rows contain Blood_Pressure = 'Low').
_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.

The only table available is DiseaseProfile.

The columns are Disease VARCHAR(512), Fever VARCHAR(25), Cough VARCHAR(25), Fatigue VARCHAR(25), Difficulty_Breathing VARCHAR(25), Age INT, Gender VARCHAR(25), Blood_Pressure VARCHAR(25), Cholesterol_Level VARCHAR(25), Outcome_Variable VARCHAR(25).
Fever, Cough, Fatigue and Difficulty_Breathing are potential symptoms which the patients are experiencing. 

Columns and Usage:

Disease: The name of the disease or medical condition.
Fever: Indicates whether the patient has a fever (Yes/No).
Cough: Indicates whether the patient has a cough (Yes/No).
Fatigue: Indicates whether the patient experiences fatigue (Yes/No).
Difficulty_Breathing: Indicates whether the patient has difficulty breathing (Yes/No).
Age: The age of the patient in years.
Gender: The gender of the patient (Male/Female).
Blood_Pressure: The blood pressure level of the patient (Low/Normal/High).
Cholesterol_Level: The cholesterol level of the patient (Low/Normal/High).
Outcome_Variable: The outcome variable indicating the result of the diagnosis or assessment for the specific disease (Positive/Negative).

Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

The SQL query should NOT end with semi-colon
Question: {input}"""

PROMPT = PromptTemplate(
    input_variables=["input", "dialect"], template=_DEFAULT_TEMPLATE
)

db_sql = SQLDatabase.from_uri(CONNECTION_STRING) 

# return_direct=True hands back the raw SQL result instead of an LLM-written answer;
# return_intermediate_steps=True exposes the generated SQL for inspection.
db_chain = SQLDatabaseChain.from_llm(llm=llm, db=db_sql, return_direct=True, return_intermediate_steps=True, prompt=PROMPT, verbose=True) 

In [14]:
# dash web app 
# hosted on port 8053
# image of graph can be downloaded as a png by clicking the camera icon
# Import necessary libraries
import pandas as pd
from dash import Dash, html, dcc, dash_table, Input, Output, callback
from jupyter_dash import JupyterDash
import plotly.express as px
import cohere
#from pandasai import SmartDataframe
from cohere import ClassifyExample

# Load the Location table from IRIS (it could instead be read from 'disease_data.csv'
# in this directory with pd.read_csv).
df = pd.read_sql_table(table_name='Location', con=connection, schema="SQLUser")

import os

# Cohere client used to classify free-text chart requests.
# The key comes from the environment: a key hard-coded here is saved in the notebook
# file and its history, so any key that was ever pasted into this cell should be revoked.
cohere_api_key = os.environ.get("COHERE_API_KEY", "")
co = cohere.Client(cohere_api_key)

# Few-shot examples for co.classify. Every label is one of the chart types the
# callback draws: "bar", "histogram", "box" or "pie".
examples = [
    ClassifyExample(text="Plot a bar chart for the top 5 diseases by count", label="bar"),
    ClassifyExample(text="Show me a histogram of ages", label="histogram"),
    ClassifyExample(text="Create a box plot of disease counts", label="box"),
    ClassifyExample(text="Display a pie chart of disease distribution", label="pie"),
    ClassifyExample(text="I want a bar chart for the most common diseases", label="bar"),
    ClassifyExample(text="Give me a histogram of patient ages", label="histogram"),
    ClassifyExample(text="Make a box plot of the number of diseases", label="box"),
    ClassifyExample(text="Can you show a pie chart of how diseases are distributed?", label="pie"),
    ClassifyExample(text="Show a bar graph of disease occurrences", label="bar"),
    ClassifyExample(text="I need a histogram of age distribution", label="histogram"),
    ClassifyExample(text="Box plot the counts of different diseases", label="box"),
    ClassifyExample(text="Pie chart of disease counts", label="pie"),
]


def generate_visualization(query, model='large'):
    """Classify a natural-language chart request into a chart type.

    Args:
        query: Free-text request typed by the user, e.g. "Show me a histogram of ages".
        model: Cohere model identifier passed to ``co.classify``. Defaults to the
            previously hard-coded ``'large'``. NOTE(review): newer Cohere accounts may
            require a different (e.g. embed) model here; confirm against the classify docs.

    Returns:
        The predicted label, one of the labels used in ``examples``
        ("bar", "histogram", "box" or "pie").
    """
    response = co.classify(model=model, inputs=[query], examples=examples)
    # Exactly one input was sent, so take the first (only) classification.
    return response.classifications[0].prediction

# Create the Dash application object; it is served by app.run at the end of the cell.
app = Dash(__name__)

# Page structure: title, separator, a paged preview of the Location rows, the free-text
# query box with its Submit button, and the graph that the callback fills in.
app.layout = html.Div(children=[
    html.H1(children='Disease Data Visualization', style={'textAlign': 'center'}),
    html.Hr(),
    dash_table.DataTable(
        data=df.to_dict(orient='records'),
        page_size=6,
        style_table={'height': '300px', 'overflowY': 'auto'},
    ),
    dcc.Input(id='nl-query', type='text', value='', style={'width': '100%'}),
    html.Button(children='Submit', id='submit-query', n_clicks=0),
    dcc.Graph(id='visualization'),
])

# The query text is wired as a State, not an Input, so the graph (and the Cohere request
# behind it) is only recomputed when Submit is clicked, not on every keystroke.
from dash import State

# Chart-type label -> figure builder. Keys match the labels generate_visualization predicts.
CHART_BUILDERS = {
    "bar": lambda data: px.bar(data, x='Disease', y='Count'),
    "histogram": lambda data: px.histogram(data, x='Disease', y='Count'),
    "box": lambda data: px.box(data, y='Count'),
    "pie": lambda data: px.pie(data, names='Disease', values='Count'),
}


@app.callback(
    Output('visualization', 'figure'),
    Input('submit-query', 'n_clicks'),
    State('nl-query', 'value'),
)
def update_graph(n_clicks, nl_query):
    """Return the figure for the user's natural-language chart request.

    Args:
        n_clicks: Number of times Submit has been clicked (0 on page load).
        nl_query: Current text of the query box.

    Returns:
        A plotly figure of ``df``. The default bar chart is returned before the first
        submission, for a blank query, or when the classifier returns a label that has
        no builder (previously that case raised UnboundLocalError).
    """
    if not n_clicks or not (nl_query or '').strip():
        return CHART_BUILDERS["bar"](df)  # Default chart
    chart_type = generate_visualization(nl_query)
    build_figure = CHART_BUILDERS.get(chart_type, CHART_BUILDERS["bar"])
    return build_figure(df)

# Serve the Dash app on port 8053 (with debug tooling) when run as the main module,
# which includes executing this notebook cell.
if __name__ == '__main__':
    app.run(port=8053, debug=True)
