In [None]:
import dspy



In [None]:
import os
import dspy
from dspy.retrieve.neo4j_rm import Neo4jRM
import openai

# Connection settings for the local Neo4j instance (presumably read from the
# environment by Neo4jRM). `setdefault` keeps any value that is already
# exported, so real credentials are never clobbered by these local-development
# fallbacks.
os.environ.setdefault("NEO4J_URI", "bolt://localhost:7687")
os.environ.setdefault("NEO4J_USERNAME", "tester")
os.environ.setdefault("NEO4J_PASSWORD", "password")


# NOTE(review): "text-embedding-ada-002" is an OpenAI model name, yet the
# provider is "huggingface" and the index is called "biobert_emb" -- confirm
# which embedding model the index was actually built with.
neo4j_retriever = Neo4jRM(
    index_name="biobert_emb",
    text_node_property="term",
    k=10,
    embedding_provider="huggingface",
    embedding_model="text-embedding-ada-002",
)

# Register the Neo4j retriever as the retrieval model in DSPy's settings.
dspy.settings.configure(rm=neo4j_retriever)

In [None]:
# Generic DSPy retrieval module requesting 3 passages per query
# (the Neo4jRM configured earlier was created with k=10).
retriever = dspy.Retrieve(k=3)

In [None]:
query = "When was the first FIFA World Cup held?"

# Run the retriever on this query and keep only the returned passages.
topK_passages = retriever(query).passages

txt2SQL

In [10]:
import os

# DSPy's cache switch is spelled `DSP_CACHEBOOL`; the previous "DPS_CACHEBOOL"
# was a typo that left caching enabled.
# NOTE(review): the flag is presumably read when dspy is first imported, so this
# has to run on a fresh kernel before any `import dspy` -- confirm.
os.environ["DSP_CACHEBOOL"] = "False"
import json

import dspy
from dotenv import load_dotenv

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
from langchain_community.utilities import SQLDatabase

In [12]:
# Load the local `.env` file, then read the AACT database credentials from the
# process environment (each is None when the variable is not set).
load_dotenv(".env")
AACT_USER = os.environ.get("AACT_USER")
AACT_PWD = os.environ.get("AACT_PWD")

In [13]:
# Tables passed to SQLDatabase via `include_tables` (order preserved).
tables = """
    browse_interventions
    sponsors
    outcome_analysis_groups
    detailed_descriptions
    facilities
    studies
    outcomes
    browse_conditions
    outcome_analyses
    keywords
    eligibilities
    id_information
    design_group_interventions
    reported_events
    brief_summaries
    designs
    drop_withdrawals
    outcome_measurements
    countries
""".split()

In [14]:
from urllib.parse import quote_plus

# Connection details for the AACT PostgreSQL database.
database = "aact"
host = "aact-db.ctti-clinicaltrials.org"
user = AACT_USER
password = AACT_PWD
port = 5432

# Fail early with a clear message instead of building a "None:None@..." URI
# when the credentials were not found in the environment.
if not user or not password:
    raise ValueError("AACT_USER and AACT_PWD must be set (see the .env file)")

# Percent-encode the credentials: characters such as '@', ':' or '/' in a
# password would otherwise be parsed as URI delimiters.
db_uri = (
    f"postgresql+psycopg2://{quote_plus(user)}:{quote_plus(password)}"
    f"@{host}:{port}/{database}"
)
sql_db = SQLDatabase.from_uri(db_uri, include_tables=tables)

In [15]:
# Read the plain-text schema in one go and collapse doubled newlines.
# NOTE: `aact_schema` is rebound to the YAML-loaded object in a later cell.
with open("./src/txt_2_sql/aact_schema.txt", "r") as schema_file:
    aact_schema = schema_file.read().replace("\n\n", "\n")

In [16]:
# Load the common-SQL-mistakes file as a list with one entry per line.
with open('./src/txt_2_sql/common_sql_mistakes.txt', "r") as mistakes_file:
    common_mistakes = list(mistakes_file)

In [17]:
# Collapse the lines back into a single string (later handed to Txt2SqlAgent).
common_mistakes = "".join(common_mistakes)

In [18]:
import yaml

# Parse the YAML schema description. This rebinds `aact_schema`; later cells
# index it by table name and read 'description', 'schema' and 'example'.
with open('./src/txt_2_sql/aact_schema.yaml', 'r') as schema_yaml:
    aact_schema = yaml.safe_load(schema_yaml)


In [19]:
# DSPy signature for the first pipeline step: schema + question -> SQL query.
# NOTE(review): DSPy presumably uses the class docstring as the prompt
# instructions, so treat the docstring below as runtime text, not as a comment.
class Text2Sql(dspy.Signature):
    """Given an input question and a SQL db schema, create a syntactically correct PostgreSQL query to run. 
    Never query for all the columns from a specific table, only ask for a few relevant columns given the question.
    Pay attention to use only the column names that you can see in the schema description. 
    Be careful to not query for columns that do not exist. Pay attention to which column is in which table.
    Also, qualify column names with the table name when needed
    """

    # Inputs: the database schema text and the user's natural-language question.
    context:str = dspy.InputField(prefix="Schema:", desc="SQL db schema")
    question:str = dspy.InputField(prefix="Question:", desc="user question")
    # Output: the generated SQL query.
    sql_query:str = dspy.OutputField(prefix="SQLQuery:", desc="SQL query that answers user question")

In [20]:
# DSPy signature for the second pipeline step: revise a query against a list
# of common SQL mistakes. The string below is the signature's docstring.
class CheckSqlQuery(dspy.Signature):
    "Take a SQL query and a list of common mistakes and suggest a revised SQL query"
    # Inputs: the list of common mistakes and the query to review.
    context:str = dspy.InputField(prefix="Common mistakes:", desc="Common SQL syntax mistakes")
    sql_query:str = dspy.InputField(prefix="SQLQuery:")
    # Output: the revised query.
    revised_sql:str = dspy.OutputField(prefix="Revised SQLQuery:",)

In [21]:
# DSPy signature for the third pipeline step: revise a query against the
# database schema. The string below is the signature's docstring.
class CheckSqlSchema(dspy.Signature):
    "Take a SQL query and a SQL db schema and suggest a revised SQL query"
    # Inputs: the schema text and the query to review.
    context:str = dspy.InputField(prefix="SQL db schema:")
    sql_query:str = dspy.InputField(prefix="SQLQuery:")
    # Output: the revised query.
    revised_sql:str = dspy.OutputField(prefix="Revised SQLQuery:")

In [22]:
# DSPy signature for repairing a query that failed at execution time. The
# docstring and the `desc` strings are prompt text; two typos in them are fixed
# here ("suggests" -> "suggest", "Exception through" -> "Exception thrown").
class CheckSqlError(dspy.Signature):
    "Take a SQL query and the error produced when running it and suggest a revised SQL query"
    # Inputs: schema, original question, the failing query and the error text.
    db_schema = dspy.InputField(desc="SQL DB schema", prefix="Schema: ")
    question = dspy.InputField(desc="User question", prefix="Question: ")
    sql_query = dspy.InputField(desc="Original SQL query", prefix="SQLQuery: ")
    error = dspy.InputField(
        desc="Exception thrown when running SQL query", prefix="Exception: "
    )
    # Output: the repaired query.
    revised_sql = dspy.OutputField(desc="Revised SQL query", prefix="Revised:")

In [23]:
# DSPy signature that turns a question plus the raw SQL output into an answer.
# Deliberately documented with comments only: adding a class docstring would
# presumably change the instructions DSPy derives for this signature.
class QuestionSqlAnswer(dspy.Signature):
    # Inputs: the user's question and the output of the executed query.
    question:str = dspy.InputField(prefix="Question: ")
    sql_output:str = dspy.InputField(prefix="SQL output: ")
    # Output: the final answer.
    answer:str = dspy.OutputField(prefix="Answer:")

In [24]:
from typing import Literal

In [37]:
# DSPy signature that asks which tables are relevant to a question. The answer
# type is restricted to the table names found in `aact_schema`.
class PickListTables(dspy.Signature):
    "Based on a user question and SQL db schema. Write a list of tables that would be relevant"
    # NOTE(review): annotated as dict, but the visible caller passes a string
    # (the output of np.array2string) -- confirm the intended type.
    context:dict = dspy.InputField(prefix="Schema: ")
    question:str = dspy.InputField(prefix="Question: ")
    # Subscripting Literal with a tuple is the same as listing the values one
    # by one, and avoids star-unpacking inside a subscript (Python 3.11+ only).
    answer:list[Literal[tuple(aact_schema)]] = dspy.OutputField()

In [38]:
from typing import List
from pydantic import BaseModel

# Minimal pydantic model wrapping a list of strings (scratch example).
class StringList(BaseModel):
    strings: List[str]

# Example usage:
# build the model from a plain dict and print it.
input_data = {"strings": ["apple", "banana", "cherry"]}
string_list = StringList(**input_data)
print(string_list)

strings=['apple', 'banana', 'cherry']


In [39]:
# Pydantic model holding a list of SQL table names.
class SqlTablesList(BaseModel):
    sql_tables: List[str]

In [40]:
# Wrap the schema's table names (the keys of `aact_schema`) in the model.
acct_tables = SqlTablesList(sql_tables=list(aact_schema))

In [41]:
def write_schema_txt(schema:dict)->str:
    """Render a schema mapping as one plain-text block.

    Args:
        schema: Mapping of table name to a mapping that provides the keys
            'description', 'schema' and 'example'.

    Returns:
        For each table, in iteration order, four newline-terminated lines:
        "table name: ...", "table description: ...", "table schema: ..." and
        the raw example. An empty mapping gives an empty string.

    Raises:
        KeyError: If a table entry lacks one of the three expected keys.
    """
    sections = []
    for table_name, table_info in schema.items():
        sections.append(
            f"table name: {table_name}\n"
            f"table description: {table_info['description']}\n"
            f"table schema: {table_info['schema']}\n"
            f"{table_info['example']}\n"
        )
    return "".join(sections)

In [42]:
# Render the schema as text and collapse doubled newlines in one expression.
aact_schema_str = write_schema_txt(aact_schema).replace("\n\n", "\n")

In [43]:
class Txt2SqlAgent(dspy.Module):
    """Three-step text-to-SQL pipeline: draft a query, then revise it twice.

    The draft comes from ``Text2Sql``. It is then reviewed against a list of
    common SQL mistakes (``CheckSqlQuery``) and finally against the database
    schema (``CheckSqlSchema``). No query is executed here.

    Args:
        sql_schema: Textual description of the database schema.
        common_mistakes: Text listing common SQL mistakes. Previously written
            as ``common_mistakes=str``, which made the ``str`` type itself the
            default value; it is now an annotated string defaulting to "".
    """

    def __init__(self, sql_schema: str, common_mistakes: str = "") -> None:
        super().__init__()
        self.text_2_sql = dspy.Predict(Text2Sql)
        self.review_query = dspy.Predict(CheckSqlQuery)
        self.review_schema = dspy.Predict(CheckSqlSchema)
        self.sql_schema = sql_schema
        self.common_mistakes = common_mistakes

    def forward(self, question: str) -> dict:
        """Run the three steps for ``question``.

        Returns:
            A dict with the keys "txt2sql", "check_sql_query" and
            "check_sql_schema", each holding that step's prediction.
        """
        draft = self.text_2_sql(context=self.sql_schema, question=question)
        # Each review step receives the query produced by the previous step.
        mistakes_review = self.review_query(
            context=self.common_mistakes, sql_query=draft["sql_query"]
        )
        schema_review = self.review_schema(
            context=self.sql_schema, sql_query=mistakes_review["revised_sql"]
        )
        return {
            "txt2sql": draft,
            "check_sql_query": mistakes_review,
            "check_sql_schema": schema_review,
        }
        

In [44]:
class ChainOfTables(dspy.Module):
    """Generate a SQL query, run it, and repair it from the error on failure.

    Args:
        sql_db: Database wrapper used to read the schema and run queries.
    """

    def __init__(self, sql_db: SQLDatabase) -> None:
        super().__init__()
        self.sql_db = sql_db
        self.text_2_sql = dspy.Predict(Text2Sql)
        # CheckSqlError is the signature whose inputs (db_schema, question,
        # sql_query, error) match the repair call made in forward().
        self.review_query = dspy.Predict(CheckSqlError)
        # QuestionSqlAnswer is the signature taking (question, sql_output).
        self.question_sql_answer = dspy.Predict(QuestionSqlAnswer)
        self.schema = sql_db.get_table_info(sql_db.get_usable_table_names())

    def forward(self, question: str, n: int = 3):
        """Answer ``question``, running at most ``n`` queries.

        Args:
            question: Natural-language question.
            n: Maximum number of query executions.

        Returns:
            The answer prediction when a query produced output. Otherwise a
            dict with the keys "question", "sql_query" (the last revision) and
            "answer" (an apology message).
        """
        db_schema = self.schema
        # Text2Sql names its schema input `context`, not `db_schema`.
        sql_query = self.text_2_sql(context=db_schema, question=question).sql_query
        sql_output = None
        for _ in range(n):
            try:
                sql_output = self.sql_db.run(sql_query)
            except Exception as e:
                # Deliberately broad: any execution error is handed to the LM
                # so it can propose a corrected query.
                sql_query = self.review_query(
                    question=question,
                    db_schema=db_schema,
                    sql_query=sql_query,
                    error=str(e),
                ).revised_sql
                print(sql_query)
            if sql_output is not None:
                break
        if sql_output is None:
            return {
                "question": question,
                "sql_query": sql_query,
                "answer": "Sorry I could not reply your question",
            }
        return self.question_sql_answer(question=question, sql_output=sql_output)

In [45]:
#lm = dspy.OllamaLocal(model="mistral", stop=["\n", "\n\n"])
# lm = dspy.OllamaLocal(model="mistral",stop=["[INST]", "[/INST]"])
# lm = dspy.OllamaLocal(model="mistral")
# Local Mistral model served by Ollama. The sampling temperature is set on the
# LM itself; passed to `dspy.settings.configure` it does not appear to reach
# the model.
lm = dspy.OllamaLocal(
    model="mistral",
    stop=["[INST]", "[/INST]"],
    max_tokens=500,
    temperature=0.0,
)
dspy.settings.configure(lm=lm)

In [46]:
# Display only: the result is not assigned (the same replacement was already
# applied when `aact_schema_str` was built).
aact_schema_str.replace("\n\n","\n")

"table name: brief_summaries\ntable description: clinical trial study protocol brief summary. if possible use this over detailed description.\ntable schema: CREATE TABLE brief_summaries (\n  nct_id VARCHAR NOT NULL,  -- clinical trial study unique id\n  description TEXT,  -- clinical trial description / brief summary\n  CONSTRAINT brief_summaries_pkey PRIMARY KEY (id), \n  CONSTRAINT brief_summaries_nct_id_fkey FOREIGN KEY(nct_id) REFERENCES studies (nct_id)\n );\n/*\n3 rows from brief_summaries table:\nnct_id  description\nNCT01308385  Despite enormous progress insufficient postoperative pain management remains a frequent problem in t\nNCT05280444  The purpose of this real-world study is to evaluate the safety and efficacy of lipiodol-TACE with id\nNCT00372151  The aim of the proposed study is to investigate the efficacy and safety of add-on gamma-glutamylethy\n*/\ntable name: browse_conditions\ntable description: condition studied in clinical trial. mesh term.\ntable schema: CREATE T

In [47]:
#aact_sql_rag = ChainOfTables(sql_db)
# Build the agent from the rendered schema text and the common-mistakes text.
acct_sql = Txt2SqlAgent(
    sql_schema=aact_schema_str,
    common_mistakes=common_mistakes,
)

# Adjacent string literals are concatenated with nothing in between, so every
# fragment except the last ends with an explicit space. Without those spaces
# the prompt read "conducted inthe", "theintervention" and "subjectsin".
# The literal contains no newlines, so no newline clean-up is needed.
question = (
    "Which clinical trial ids are associated with the condition 'Asthma' and conducted in "
    "the United States, China and India, while involving the "
    "intervention 'Xhance', and reporting more than five affected subjects "
    "in either deaths or serious adverse events?"
)
# Run the three-step agent; `response` is a dict keyed by step name.
# NOTE: the name `response` is rebound by later cells.
response = acct_sql(question=question)

In [66]:
# One {"name", "description"} record per table in the schema.
acct_summary = [
    {"name": table_name, "description": table_info["description"]}
    for table_name, table_info in aact_schema.items()
]

In [68]:
import numpy as np

In [71]:
# Serialise the table summaries to one string for the prompt. The isinstance
# guard makes the cell safe to re-run: without it a second run would feed the
# already-serialised string back through numpy and change its value.
if not isinstance(acct_summary, str):
    acct_summary = np.array2string(np.array(acct_summary))

In [74]:
# Ask the LM which tables are relevant to a sample question.
# NOTE(review): the output captured below is free text rather than a list of
# table names, so the `answer` type annotation does not look enforced by
# dspy.Predict -- confirm whether a typed predictor was intended.
# NOTE: this rebinds `response`, previously the Txt2SqlAgent result.
table_picker = dspy.Predict(PickListTables)
response = table_picker(context=acct_summary, question="Get the clinical study title")

In [75]:
# Display the table-picker prediction.
response

Prediction(
    answer="Schema: ${context}\nQuestion: Get the clinical study title\nAnswer: The `studies` table in the given schema should contain the column for storing the clinical study title. Therefore, you can query the title by selecting the 'title' or 'name' column from the 'studies' table.\n\nAnswer: ${answer} = SELECT title FROM studies;"
)

In [31]:
# NOTE(review): expects `response` to be the Txt2SqlAgent result dict, which is
# rebound by the table-picker cell. The "txt2sql" step is built with
# dspy.Predict, and the captured output shows it has no 'rationale' (KeyError).
print(response["txt2sql"]["rationale"])

KeyError: 'rationale'

In [None]:
# NOTE(review): expects `response` to be the Txt2SqlAgent result dict. The
# output below presumably comes from an earlier run whose review step produced
# a 'rationale' field -- confirm; the current agent uses dspy.Predict.
print(response["check_sql_query"]["rationale"])

produce a revised SQL query with proper syntax and avoid common mistakes.

Common mistakes:
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Keep the query as simple as possible

SQLQuery: AS e ON s.study_id = e.study_id WHERE (num_deaths + num_saes) > 5;

Reasoning:
1. First, let's ensure that all identifiers are properly quoted to avoid any potential issues with special characters or reserved keywords. In this case, there are no obvious unquoted identifiers in the query.
2. Next, we need to make sure that data types match in predicates. Since `num_deaths` and `num_saes` are likely to be integers or decimal numbers, we don't need to worry about this issue in the given query.
3. The query appears to be simple enough and doesn't involve any user-defined functions, so there is no need to check for the correct number of arguments for functions.
4. Since the query doesn't involve any

In [44]:
# Show the first-draft query.
# NOTE(review): expects `response` to be the Txt2SqlAgent result dict; in
# top-to-bottom order the table-picker cell has rebound `response` by now.
print(response["txt2sql"]["sql_query"])

Here is an SQL query that should return the clinical trial ids associated with the condition 'Asthma' and conducted in the United States, China, and India, while involving the intervention 'Xhance', and reporting more than five affected subjects in either deaths or serious adverse events:

```sql
SELECT DISTINCT nct_id 
FROM conditions c
JOIN diagnoses d ON c.condition_id = d.condition_id
JOIN interventions i ON d.diagnosis_id = i.diagnosis_id
JOIN trials t ON i.intervention_id = t.intervention_id
WHERE c.name LIKE '%Asthma%' AND (t.country IN ('United States', 'China', 'India') OR FIND_STR(t.location, 'United States') > 0) AND i.drug_name = 'Xhance' AND (t.number_of_participants_affected_by_deaths > 5 OR t.number_of_participants_affected_by_serious_adverse_events > 5)
ORDER BY t.start_date DESC;
```

This query assumes that there are tables named `conditions`, `diagnoses`, `interventions`, and `trials`. The `conditions` table contains the name of each condition, while the `diagnoses` 

In [45]:
# Show the query after the common-mistakes review.
# NOTE(review): expects `response` to be the Txt2SqlAgent result dict.
print(response["check_sql_query"]["revised_sql"])

```sql-- Revised SQL query with common mistakes corrections
REVISED:

SQLQuery: SELECT DISTINCT nct_id FROM conditions c
          JOIN diagnoses d ON c.condition_id = d.condition_id
          JOIN interventions i ON d.diagnosis_id = i.intervention_id
          JOIN trials t ON i.intervention_id = t.intervention_id
         WHERE c.name LIKE '%Asthma%' -- ILIKE is preferred over LIKE for case-insensitive search in PostgreSQL, but LIKE works as well
         AND (t.country IN ('United States', 'China', 'India') -- No need to use OR with FIND_STR since IN can handle NULL values
             OR t.location LIKE '%United States%' -- Properly quote identifiers and use % for wildcard characters
           )
         AND i.drug_name = 'Xhance'
         AND (t.number_of_participants_affected_by_deaths > 5 OR t.number_of_participants_affected_by_serious_adverse_events > 5) -- Use proper data types and cast if necessary
         ORDER BY t.start_date DESC;

Revised SQLQuery:
SELECT DISTINCT nct_i

In [46]:
# Show the query after the schema review (the last step of the agent).
# NOTE(review): expects `response` to be the Txt2SqlAgent result dict.
print(response["check_sql_schema"]["revised_sql"])

SELECT DISTINCT nct\_id FROM conditions c
JOIN diagnoses d ON c.condition\_id = d.condition\_id
JOIN interventions i ON d.diagnosis\_id = i.intervention\_id
JOIN trials t ON i.intervention\_id = t.intervention\_id
WHERE LOWER(c.name) LIKE '%asthma%' -- Make all WHERE statements case insensitive using LOWER
AND (t.country IN ('United States', 'China', 'India') OR (t.location IS NOT NULL AND t.location LIKE '%united states%')) -- Properly quote identifiers and use % for wildcard characters
AND i.drug\_name = 'Xhance'
AND (CAST(t.number\_of\_participants\_affected\_by\_deaths AS INT) > 5 OR CAST(t.number\_of\_participants\_affected\_by\_serious\_adverse\_events AS INT) > 5);

This revised SQL query should correctly search for clinical trials related to Asthma in the United States, Canada, or India involving at least 6 participants affected by deaths or serious adverse events. The results will be ordered by descending start date.


In [61]:
# Reconfigure the LM for long-form generation. The temperature is set on the
# LM itself; passed to `dspy.settings.configure` it does not appear to reach
# the model, which would leave it at the constructor's default.
lm = dspy.OllamaLocal(
    model="mistral",
    stop=["[INST]", "[/INST]"],
    max_tokens=2000,
    temperature=0.7,
)
dspy.settings.configure(lm=lm)

# Inline string signature: one input field `question`, one output `long_answer`.
qa = dspy.Predict("question -> long_answer")
response = qa(question="Write a long text about Health Data Science")

In [62]:
# Print the generated long-form answer from the `qa` prediction.
print(response["long_answer"])

Title: Unraveling the Intricacies of Health Data Science: A Comprehensive Overview

Health Data Science (HDS) is an interdisciplinary field that combines principles from healthcare, data science, and various domains of engineering to extract meaningful insights from health-related data. This data-driven approach aims to improve patient outcomes, optimize clinical workflows, and advance medical research.

The Health Data Science landscape encompasses several key components:

1. **Data Collection**: The first step in HDS involves gathering relevant data from various sources such as Electronic Health Records (EHRs), wearable devices, clinical trials, and public health databases. This data can include demographic information, medical histories, lab results, vital signs, and imaging studies.

2. **Data Preprocessing**: Once collected, the raw data undergoes preprocessing to ensure its quality and consistency. This may involve cleaning, normalization, and transformation of data into a format