In [None]:
import dspy



In [7]:
import os
import dspy
from dspy.retrieve.neo4j_rm import Neo4jRM
import openai

os.environ["NEO4J_URI"] = "bolt://localhost:7687"
os.environ["NEO4J_USERNAME"] = "tester"
os.environ["NEO4J_PASSWORD"] = "password"


neo4j_retriever = Neo4jRM(
    index_name="biobert_emb",
    text_node_property="term",
    k=10,
    embedding_provider="huggingface",
    embedding_model="text-embedding-ada-002",
)

dspy.settings.configure(rm=neo4j_retriever)

In [8]:
retriever = dspy.Retrieve(k=3)

In [6]:
query='When was the first FIFA World Cup held?'

# Call the retriever on a particular query.
topK_passages = retriever(query).passages

AttributeError: 'Embedder' object has no attribute 'client'

txt2SQL

In [1]:
import os

os.environ["DPS_CACHEBOOL"] = "False"
import json

import dspy
from dotenv import load_dotenv

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from langchain_community.utilities import SQLDatabase

In [3]:
load_dotenv(".env")
AACT_USER = os.getenv("AACT_USER")
AACT_PWD = os.getenv("AACT_PWD")

In [4]:
tables = [
    "browse_interventions",
    "sponsors",
    "outcome_analysis_groups",
    "detailed_descriptions",
    "facilities",
    "studies",
    "outcomes",
    "browse_conditions",
    "outcome_analyses",
    "keywords",
    "eligibilities",
    "id_information",
    "design_group_interventions",
    "reported_events",
    "brief_summaries",
    "designs",
    "drop_withdrawals",
    "outcome_measurements",
    "countries",
]

In [5]:
database = "aact"
host = "aact-db.ctti-clinicaltrials.org"
user = AACT_USER
password = AACT_PWD
port = 5432
db_uri = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
sql_db = SQLDatabase.from_uri(db_uri, include_tables=tables)

In [33]:
with open("./src/txt_2_sql/aact_schema.txt", "r") as f:
    aact_schema =  f.readlines()

aact_schema = "".join(aact_schema).replace("\n\n", "\n")

In [35]:
print(aact_schema)

CREATE TABLE brief_summaries (
	nct_id VARCHAR NOT NULL,  -- Clinical Trial study unique id
	description TEXT,  -- Clinical trial description / brief summary
	CONSTRAINT brief_summaries_pkey PRIMARY KEY (id), 
	CONSTRAINT brief_summaries_nct_id_fkey FOREIGN KEY(nct_id) REFERENCES studies (nct_id)
)
/*
3 rows from brief_summaries table:
nct_id	description
NCT01308385	Despite enormous progress insufficient postoperative pain management remains a frequent problem in t
NCT05280444	The purpose of this real-world study is to evaluate the safety and efficacy of lipiodol-TACE with id
NCT00372151	The aim of the proposed study is to investigate the efficacy and safety of add-on gamma-glutamylethy
*/

CREATE TABLE browse_conditions ( 
	nct_id VARCHAR NOT NULL,  -- Clinical Trial study unique id
	mesh_term VARCHAR,  -- clinical condition
	downcase_mesh_term VARCHAR, -- downcase clinical condition 
	mesh_type VARCHAR, 
	CONSTRAINT browse_conditions_pkey PRIMARY KEY (id), 
	CONSTRAINT browse_conditi

In [6]:
class Text2Sql(dspy.Signature):
    """Given an input question, first create a syntactically correct PostgreSQL query to run, then look at the results of the query and return the answer. You can order the results by a relevant column to return the most interesting examples in the database.
    Never query for all the columns from a specific table, only ask for a few relevant columns given the question.
    Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Pay attention to which column is in which table. Also, qualify column names with the table name when needed
    """

    db_schema = dspy.InputField(
        desc="Only use tables listed in the schema", prefix="Schema:"
    )
    question = dspy.InputField(prefix="Question:")
    sql_query = dspy.OutputField(prefix="SQLQuery:")

In [7]:
class CheckSqlQuery(dspy.Signature):
    "Take a SQL query and the error produced when running it and suggests a revised SQL query"
    db_schema = dspy.InputField(desc="SQL DB schema", prefix="Schema: ")
    question = dspy.InputField(desc="user question", prefix="Question: ")
    sql_query = dspy.InputField(desc="Original SQL query", prefix="SQLQuery: ")
    error = dspy.InputField(
        desc="Exception through when running SQL query", prefix="Exception: "
    )
    revised_sql = dspy.OutputField(desc="Revised SQL query", prefix="Revised:")

In [8]:
class QuestionSqlAnswer(dspy.Signature):
    question = dspy.InputField(prefix="Question: ")
    sql_output = dspy.InputField(prefix="SQL output: ")
    answer = dspy.OutputField(prefix="Answer:")

In [15]:
class ChainOfTables(dspy.Module):
    def __init__(self, sql_db: SQLDatabase) -> None:
        super().__init__()
        self.sql_db = sql_db
        self.text_2_sql = dspy.Predict(Text2Sql)
        self.review_query = dspy.Predict(CheckSqlQuery)
        self.question_sql_answer = dspy.Predict(QuestionSqlAnswer)
        self.schema = sql_db.get_table_info(sql_db.get_usable_table_names())

    def forward(self, question: str, n: int = 3):
        attempts = 0
        sql_output = None
        db_schema = self.schema
        response = self.text_2_sql(question=question, db_schema=db_schema)
        sql_query = response.sql_query
        # return response
        while attempts < n and sql_output is None:
            try:
                sql_output = self.sql_db.run(sql_query)
            except Exception as e:
                response = self.review_query(
                    question=question,
                    db_schema=db_schema,
                    sql_query=sql_query,
                    error=str(e),
                )
                sql_query = response.revised_sql
                print(sql_query)
            attempts += 1
        if sql_output is None:
            return {
                "question": question,
                "sql_query": sql_query,
                "answer": "Sorry I could not reply your question",
            }
        return self.question_sql_answer(question=question, sql_output=sql_output)

In [16]:
lm = dspy.OllamaLocal(model="sqlcoder", stop=["\n", "\n\n"])
dspy.settings.configure(lm=lm, temperature=0.0)

In [17]:
aact_sql_rag = ChainOfTables(sql_db)

question = (
    "Which clinical trial ids are associated with the condition 'Asthma' and conducted in"
    "the United States, China and India, while involving the"
    "intervention 'Xhance', and reporting more than five affected subjects"
    "in either deaths or serious adverse events?"
)

robust_sql_instructions = (
    "\nSQL query requirements:"
    # "Remove duplicates. "
    "Always include the nct_id in the SELECT statement. "
    "Keep the Query as simples as possible."
    # "Try to make the SQL query as robust as possible. "
    "Make all WHERE statements case insensitive using LOWER"
)

question = question.replace("\n", " ")
response = aact_sql_rag(question=question + robust_sql_instructions)


CREATE TABLE brief_summaries (
	id SERIAL NOT NULL, 
	nct_id VARCHAR, 
	description TEXT, 
	CONSTRAINT brief_summaries_pkey PRIMARY KEY (id), 
	CONSTRAINT brief_summaries_nct_id_fkey FOREIGN KEY(nct_id) REFERENCES studies (nct_id)
)

/*
3 rows from brief_summaries table:
id	nct_id	description
48854911	NCT01308385	Despite enormous progress insufficient postoperative pain management remains a frequent problem in t
48854912	NCT05280444	The purpose of this real-world study is to evaluate the safety and efficacy of lipiodol-TACE with id
49057466	NCT00372151	The aim of the proposed study is to investigate the efficacy and safety of add-on gamma-glutamylethy
*/


CREATE TABLE browse_conditions (
	id SERIAL NOT NULL, 
	nct_id VARCHAR, 
	mesh_term VARCHAR, 
	downcase_mesh_term VARCHAR, 
	mesh_type VARCHAR, 
	CONSTRAINT browse_conditions_pkey PRIMARY KEY (id), 
	CONSTRAINT browse_conditions_nct_id_fkey FOREIGN KEY(nct_id) REFERENCES studies (nct_id)
)

/*
3 rows from browse_conditions table:
id

In [18]:
response

{'question': "Which clinical trial ids are associated with the condition 'Asthma' and conducted inthe United States, China and India, while involving theintervention 'Xhance', and reporting more than five affected subjectsin either deaths or serious adverse events?\nSQL query requirements:Always include the nct_id in the SELECT statement. Keep the Query as simples as possible.Make all WHERE statements case insensitive using LOWER",
 'sql_query': "SELECT nct_id FROM clinical_trial WHERE interventional_name ILIKE '%Asthma%' AND country IN ('United States', 'China', 'India') AND intervention_name ILIKE '%Xhance%' AND adverse_events_reported > 5;",
 'answer': 'Sorry I could not reply your question'}