In [None]:
%pip install -qU langchain langchain-openai langchain-community langchain-experimental pandas

In [None]:
import getpass
import os

# Do not hard-code the key (and do not overwrite an already exported one with "").
# Reuse OPENAI_API_KEY from the environment when present; otherwise ask for it
# interactively without echoing it into the notebook.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")


In [None]:
import pandas as pd

# Load the drug-history extract (path is relative to the notebook's working directory)
# and print its (rows, columns) shape followed by the column names.
df = pd.read_csv("ZAVUS24_Drug_Histories.csv")
print(df.shape, df.columns.tolist(), sep="\n")

In [None]:
# Show the distinct values found in the 'patient_stock_name' column.
print(df['patient_stock_name'].unique())

In [None]:
# Show the distinct values found in the 'class_name' column.
print(df['class_name'].unique())

In [None]:
from langchain_community.utilities import SQLDatabase
from sqlalchemy import create_engine

# SQLite database file that the SQL agent will query; it is created in the
# working directory and persists between kernel restarts.
engine = create_engine("sqlite:///zavegepant.db")
# if_exists="replace": because the .db file survives a restart, the pandas default
# ("fail") raises ValueError("Table 'zavegepant' already exists") on every run
# after the first. Replacing the table keeps the cell re-runnable and in sync
# with the CSV that was just loaded.
df.to_sql("zavegepant", engine, index=False, if_exists="replace")

In [None]:
# Wrap the engine for LangChain and confirm the dialect and visible tables.
db = SQLDatabase(engine=engine)
print(db.dialect)
print(db.get_usable_table_names())
# Smoke test: only a handful of rows are needed to confirm the table is queryable;
# without the LIMIT every matching row is dumped into the cell output.
db.run("SELECT * FROM zavegepant WHERE molecule_name == 'Zavegepant' LIMIT 5;")

In [None]:
from langchain_community.agent_toolkits import create_sql_agent
from langchain_openai import ChatOpenAI

# temperature=0 is passed so the model's SQL generation is as repeatable as possible.
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")

# Baseline agent: the stock SQL-agent prompt, no custom examples or extra tools.
agent_executor = create_sql_agent(
    llm=llm,
    db=db,
    verbose=True,
    agent_type="openai-tools",
)

In [None]:
# Baseline question for the default agent (typo "where" -> "were" fixed so the
# question sent to the model is unambiguous).
agent_executor.invoke(
    "How many distinct patients were ON Zavegepant ON month 60?"
)

In [None]:
# Few-shot question/SQL pairs shown to the model. Every "query" must be a complete,
# runnable statement against the `zavegepant` table: the model imitates these
# verbatim, so a malformed example (e.g. one missing the SELECT keyword) teaches
# it to emit invalid SQL.
examples = [
    {
        "input": "List all different stocks.",
        "query": "SELECT DISTINCT patient_stock_name FROM zavegepant;",
    },
    {
        "input": "How many patients are ON the stock CGRP Nasal ON month 60?.",
        "query": "SELECT COUNT(DISTINCT patient_id) AS unique_patient_count FROM zavegepant WHERE month == 60 AND patient_stock_name = 'CGRP Nasal';",
    },
    {
        "input": "How many patients started ON a CGRP Oral anew over the past 12 months?",
        "query": "SELECT COUNT(DISTINCT patient_id) AS new_patients_count FROM zavegepant WHERE month > 48 AND class_name = 'CGRP Oral' AND patient_id NOT IN (SELECT DISTINCT patient_id FROM zavegepant WHERE month <= 48 AND class_name = 'CGRP Oral');",
    },
    {
        "input": "Find the average number of months a patient spends on a NSAID class.",
        "query": "SELECT AVG(month_count) AS average_months_on_NSAID FROM (SELECT patient_id, COUNT(DISTINCT month) AS month_count FROM zavegepant WHERE class_name = 'NSAID' GROUP BY patient_id) AS subquery;",
    },
    {
        "input": "List all different drug classes.",
        "query": "SELECT DISTINCT class_name FROM zavegepant;",
    },
    {
        "input": "How many different molecules are there? ",
        "query": "SELECT COUNT(DISTINCT molecule_name) AS distinct_molecule_count FROM zavegepant;",
    },
]

In [None]:
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings

# Index the curated examples in a FAISS store keyed on their "input" text, so that
# the 5 examples most similar to each incoming question are put into the prompt.
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples, OpenAIEmbeddings(), FAISS, input_keys=["input"], k=5
)

In [None]:
from langchain_core.prompts import (
    ChatPromptTemplate,
    FewShotPromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
    SystemMessagePromptTemplate,
)

# Fixed instructions placed ahead of the retrieved examples. {dialect} and {top_k}
# are template slots, declared as input variables of the few-shot template below.
system_prefix = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't know" as the answer.

Here are some examples of user inputs and their corresponding SQL queries:"""

# Prefix + the examples chosen by `example_selector`, each rendered as a
# "User input / SQL query" pair; nothing is appended after the examples.
few_shot_prompt = FewShotPromptTemplate(
    prefix=system_prefix,
    suffix="",
    input_variables=["input", "dialect", "top_k"],
    example_selector=example_selector,
    example_prompt=PromptTemplate.from_template(
        "User input: {input}\nSQL query: {query}"
    ),
)

In [None]:
# Chat prompt for the agent: the few-shot system message, then the user's question,
# then a slot for the agent's intermediate tool calls ("agent_scratchpad").
full_prompt = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate(prompt=few_shot_prompt),
    ("human", "{input}"),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
])

In [None]:
# Render the prompt once by hand (empty scratchpad) to inspect the final text,
# including which few-shot examples were selected for this question.
prompt_val = full_prompt.invoke(
    dict(
        input="How many patients were on a CGRP Injectable drug class on month 60?",
        dialect="SQLite",
        top_k=5,
        agent_scratchpad=[],
    )
)
print(prompt_val.to_string())

In [None]:
# Same model and database as the baseline agent, but driven by the few-shot prompt.
agent = create_sql_agent(
    db=db,
    llm=llm,
    agent_type="openai-tools",
    prompt=full_prompt,
    verbose=True,
)

In [None]:
# Ask the agent a question; the "input" key fills the "{input}" slot of the prompt.
agent.invoke({"input": "How many patients were on a CGRP Injectable drug class on month 60?"})

In [None]:
import ast
import re


def query_as_list(db, query):
    """Run ``query`` and return the distinct, cleaned string values it yields.

    Args:
        db: Object exposing ``run(query)`` that returns the result rows rendered
            as a Python-literal string (a list of tuples).
        query: SQL text handed to ``db.run`` unchanged.

    Returns:
        A sorted list of unique, non-empty strings. Falsy cells (``None``, ``""``)
        are skipped, standalone digit runs are removed and surrounding whitespace
        is stripped. An empty result yields ``[]``.
    """
    raw = db.run(query)
    # A query with no rows comes back as an empty string, which literal_eval
    # cannot parse (SyntaxError) -- treat it as "no values".
    if not raw:
        return []

    rows = ast.literal_eval(raw)
    cells = (cell for row in rows for cell in row if cell)
    cleaned = {re.sub(r"\b\d+\b", "", cell).strip() for cell in cells}
    # Values made up only of digits collapse to "" after cleaning; drop them so
    # callers never receive (or embed) an empty string.
    cleaned.discard("")
    # Sorted so the output is reproducible between runs (set order is arbitrary).
    return sorted(cleaned)


# Collect the proper nouns the agent may need to filter on. DISTINCT makes SQLite
# de-duplicate, so one row per value is serialised and parsed instead of one row
# per record of the table; the resulting lists contain the same values.
molecules = query_as_list(db, "SELECT DISTINCT molecule_name FROM zavegepant")
classes = query_as_list(db, "SELECT DISTINCT class_name FROM zavegepant")
stocks = query_as_list(db, "SELECT DISTINCT patient_stock_name FROM zavegepant")
molecules[:5]

In [None]:
from langchain.agents.agent_toolkits import create_retriever_tool

# One shared index over every molecule, class and stock name; a lookup returns
# the 5 closest entries to whatever spelling the agent searches for.
vector_db = FAISS.from_texts(molecules + classes + stocks, OpenAIEmbeddings())
retriever = vector_db.as_retriever(search_kwargs=dict(k=5))

# Tool description shown to the model so it knows when to call the lookup.
description = (
    "Use to look up values to filter on. Input is an approximate spelling of the proper noun, output is "
    "valid proper nouns. Use the noun most similar to the search."
)
retriever_tool = create_retriever_tool(
    retriever,
    description=description,
    name="search_proper_nouns",
)

In [None]:
# System prompt for the agent that has the proper-noun lookup tool.
# {dialect}, {top_k} and {table_names} are template slots. The drug-class and
# stock lists below are hard-coded and must be kept in sync with the data by hand.
# ("colunm" typos corrected to "column" so the column references read cleanly.)
system = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If you need to filter on a proper noun, you must ALWAYS first look up the filter value using the "search_proper_nouns" tool! 

You have access to the following tables: {table_names}

If the question does not seem related to the database, just return "I don't know" as the answer.


Any question pertaining to 'molecules' or 'drugs' refers to any of the different generic drug names column 'molecule_name' from the 'zavegepant' table.

Any question pertaining to 'drug classes' refers to any of the following: 
['lapsed' 'NSAID' 'Weak Opioid' 'Antiemetic' 'Steroid' 'Sedative'
 'Antipsychotic' 'Triptan' 'Antiepileptic' 'SSRI' 'Strong Opioid' 'SNRI'
 'Neural' 'Muscle Relaxant' 'Cardiovascular' 'Beta Blocker'
 'CGRP Injectable' 'Analgesic' 'Ergot' 'Tricyclic' 'Calcium Blocker'
 'CGRP Oral' 'Ditan' 'CGRP Nasal'],  which can be found on the column 'class_name' from the 'zavegepant' table.

Any question pertaining to 'stocks' refers to any of the following ['Lapsed' 'Preventive' 'Symptomatic' 'Naive' 'Triptans' 'CGRP Injectable' 'CGRP Oral' 'CGRP Nasal'],
which can be found on the column 'patient_stock_name' from the 'zavegepant' table.

If you can't find the exact molecule/class/stock asked for, please look up the closest noun. Allow for misspelled words.
"""

# Chat prompt: system instructions, the user's question, then the scratchpad slot
# for the agent's intermediate tool calls.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)

# Rebinds `agent`: this version adds the proper-noun retriever to the SQL tools.
agent = create_sql_agent(
    db=db,
    llm=llm,
    prompt=prompt,
    extra_tools=[retriever_tool],
    verbose=True,
    agent_type="openai-tools",
)

In [None]:
# NOTE(review): "ZavegOpant" is misspelled -- presumably on purpose, to exercise the
# "search_proper_nouns" lookup before the agent filters on molecule_name; confirm before "fixing" it.
agent.invoke({"input": "How many unique patients were ON the molecule_name ZavegOpant on month 60?"})