In [None]:
!pip install crewai crewai-tools
!pip install crewai



In [1]:
import sqlite3
import re
import json
import requests
from langchain.llms.base import LLM
from langchain.schema import Generation
from langchain.callbacks.manager import CallbackManagerForLLMRun
from typing import Any, Dict, Iterator, Optional, List

# LLM-related functionality
ngrok_url = "https://54bc-35-194-40-255.ngrok-free.app"

def query_llm(user_message):
    tokens = requests.post(f"{ngrok_url}get_num_tokens", headers={'Content-Type': 'application/json'}, data=json.dumps({"text": user_message}))
    tokens = tokens.json()
    print(tokens.get("num_tokens"))
    endpoint = f"{ngrok_url}/invoke_llm"
    question = f"""{user_message} assistant"""
    data = {"text": question}

    response = requests.post(endpoint, headers={'Content-Type': 'application/json'}, data=json.dumps(data))

    if response.status_code == 200:
        response_json = response.json()
        result_text = response_json.get('result', 'No result key in response')
        print(result_text)
        return result_text.strip()
    else:
        print(f"Request failed with status code: {response.status_code}")
        print(response.text)
        return response.text

# Custom LLM wrapper
class CustomLLM(LLM):
    custom_function: Any

    def __init__(self, custom_function: Any, **kwargs: Any):
        super().__init__(**kwargs)
        self.custom_function = custom_function

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        response = self.custom_function(prompt)
        return response

    def _stream(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> Iterator[Generation]:
        response = self.custom_function(prompt)
        for char in response:
            yield Generation(text=char)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {"model_name": "CustomChatModel"}

    @property
    def _llm_type(self) -> str:
        return "custom"

custom_llm = CustomLLM(custom_function=query_llm)


In [2]:
data_researcher_prompt_template = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a Data Researcher Agent. Your responsibility is to analyze the user's clinical trial-related question, extract relevant keywords using an LLM, and dynamically identify the appropriate tables and fields from the database schema. Then, provide structured information to be used for SQL query generation.

- **Tables**: Identify the relevant tables based on the extracted keywords.
- **Fields**: List the fields from these tables that are required to address the user query.
- **Conditions**: Identify any conditions or filters based on the user query.

Make sure the information is clear and structured for further SQL query generation.<|eot_id|>

<|start_header_id|>user<|end_header_id|>

{question}<|eot_id|>

<|start_header_id|>assistant<|end_header_id|>

Based on the user's query, the relevant database schema involves the following:
- **Tables**: {selected_tables}
- **Fields**: {selected_fields}
- **Conditions**: {conditions}

This information is ready for the next agent to process and generate the SQL query.<|eot_id|>
"""


sql_specialist_prompt_template = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are an SQL Specialist Agent. Your responsibility is to take structured data (tables, fields, and conditions) from the Data Researcher Agent and use this information to generate an SQL query. The SQL query should be optimized for fetching clinical trial data from the database, while considering any necessary filters or conditions.

- **Tables**: Use the provided tables to generate the FROM clause.
- **Fields**: Include the listed fields in the SELECT clause.
- **Conditions**: Add the specified conditions in the WHERE clause if applicable.

Generate a well-formed SQL query using the provided information.<|eot_id|>

<|start_header_id|>data_researcher<|end_header_id|>

- Tables: {tables}
- Fields: {fields}
- Conditions: {conditions}<|eot_id|>

<|start_header_id|>assistant<|end_header_id|>

Here is the generated SQL query based on the provided information:
```sql
{sql_query}
"""


execution_agent_prompt_template = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are an Execution Agent. Your task is to execute the nct ids driven from the clinical trials database. You take the user query and search for the appropriate nct_id related to it in the clinical trials database. Then, you will return the nct ids in a structured format.

Make sure to handle errors during execution and return an appropriate message.<|eot_id|>

<|start_header_id|>sql_specialist<|end_header_id|>

Here is the SQL query that needs to be executed:
```sql
{sql_query}
```<|eot_id|>

<|start_header_id|>assistant<|end_header_id|>

The query has been executed. Here are the results:
- **Results**: {query_results}

If there were errors, the output will indicate the issue.
This output is now ready for summarization.<|eot_id|>
"""


writer_agent_prompt_template = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a Writer Agent. Your role is to take raw results from the Execution Agent and summarize them in a human-readable format with the help of the llm. If the results are empty or no relevant data is found, generate an appropriate response indicating this.

Make sure the summary is concise, clear, and informative.

<|eot_id|>

<|start_header_id|>execution_agent<|end_header_id|>

Here are the query results that need to be summarized:
- **Results**: {query_results}<|eot_id|>

<|start_header_id|>assistant<|end_header_id|>

The summary of the retrieved data is:
{summary}

If there is no relevant data, output: "No relevant data found."<|eot_id|>
"""



In [3]:
import requests
import json
from typing import Any, Dict, List, Tuple
import sqlite3  

# Function to execute SQL query
def execute_sql_query(query: str) -> List[Dict[str, Any]]:
    print(f"Executing SQL query: {query}")
    # Replace with actual database interaction code
    try:
        conn = sqlite3.connect('clinical_trials.db')  
        cursor = conn.cursor()
        cursor.execute(query)
        results = cursor.fetchall()
        conn.close()
        return [dict(row) for row in results]  # Convert to list of dicts
    except Exception as e:
        print(f"Error executing query: {e}")
        return []


In [4]:
class DataResearcherAgent:
    """Fetches relevant information based on a user query."""
    def fetch_information(self, user_question):
        return f"DataResearcherAgent has fetched relevant research information for: {user_question}"

from typing import Dict, List, Any

class SQLSpecialistAgent:
    def __init__(self, schema_info: Dict[str, List[str]]):
        """
        Initialize the SQLSpecialistAgent with schema information.

        :param schema_info: A dictionary containing table names as keys and their respective field names as values.
        """
        self.schema_info = schema_info

    def generate_sql(self, extracted_info: Dict[str, Any]) -> str:
        """
        Generate SQL queries based on the extracted information, which includes
        tables, conditions, fields, and relationships.

        :param extracted_info: A dictionary containing the following keys:
            - "tables": List of table names to be queried.
            - "conditions": A dictionary of conditions for the WHERE clause.
            - "fields": List of field names to select. If empty, selects all fields.
        
        :return: A string containing the generated SQL query.
        """
        tables = extracted_info.get("tables", [])
        conditions = extracted_info.get("conditions", {})
        fields = extracted_info.get("fields", ["*"])  # Default to selecting all fields

        if not tables:
            return "SELECT * FROM studies LIMIT 10;"  # Default query if no specific table

        # Construct the SELECT clause with table prefixes
        if fields == ["*"]:
            select_clause = ', '.join(f"{table}.{field}" for table in tables for field in self.schema_info[table])
        else:
            select_clause = ', '.join(f"{table}.{field}" for field in fields for table in tables if field in self.schema_info[table])

        # Construct the FROM clause
        from_clause = ', '.join(tables)

        # Construct the WHERE clause based on conditions
        where_clauses = [f"{table}.{field} = '{value}'" for table, value in conditions.items() for table in tables if field in self.schema_info[table]]
        where_clause = ' AND '.join(where_clauses) if where_clauses else ""

        # Build the final SQL query
        sql_query = f"SELECT {select_clause} FROM {from_clause}"
        if where_clause:
            sql_query += f" WHERE {where_clause}"
        sql_query += " LIMIT 10;"  # Limiting results for safety

        return sql_query


In [5]:
class ExecutionAgent:
    def __init__(self, db_path='clinical_trials.db', llm=None):
        self.conn = self.connect_to_database(db_path)
        self.llm = llm

    def connect_to_database(self, db_path):
        try:
            conn = sqlite3.connect(db_path)
            print(f"Connected to database: {db_path}")
            return conn
        except sqlite3.Error as e:
            print(f"Error connecting to database: {e}")
            return None

    def fetch_nct_id(self, query_type, query_value):
        cursor = self.conn.cursor()
        if query_type == 'drug':
            cursor.execute(f"SELECT nct_id FROM all_interventions WHERE names LIKE ? LIMIT 50", (f'%{query_value}%',))
        elif query_type == 'condition':
            cursor.execute(f"SELECT nct_id FROM all_conditions WHERE names LIKE ? LIMIT 50", (f'%{query_value}%',))
        elif query_type == 'trial':
            cursor.execute(f"SELECT nct_id FROM all_keywords WHERE names LIKE ? LIMIT 50", (f'%{query_value}%',))

        results = cursor.fetchall()
        return [row[0] for row in results]

    def fetch_details_for_nct_id(self, nct_id):
        cursor = self.conn.cursor()
        cursor.execute(f"SELECT * FROM all_keywords WHERE nct_id = ?", (nct_id,))
        trial_details = cursor.fetchall()

        cursor.execute(f"SELECT * FROM all_interventions WHERE nct_id = ?", (nct_id,))
        drug_details = cursor.fetchall()

        cursor.execute(f"SELECT * FROM all_conditions WHERE nct_id = ?", (nct_id,))
        condition_details = cursor.fetchall()

        return trial_details, drug_details, condition_details

    def handle_user_query(self, user_question):
        query_type, query_value = self.map_question_to_query(user_question)
        if query_type and query_value:
            nct_ids = self.fetch_nct_id(query_type, query_value)
            if nct_ids:
                llm_input = f"Query results for {query_value} ({query_type}):\n"
                for nct_id in nct_ids:
                    trial_details, drug_details, condition_details = self.fetch_details_for_nct_id(nct_id)
                    llm_input += f"\nDetails for nct_id {nct_id}:\n"
                    llm_input += f"Trial Details: {trial_details}\n"
                    llm_input += f"Drug/Intervention Details: {drug_details}\n"
                    llm_input += f"Condition Details: {condition_details}\n"
                summary_request = f"Based on the following details:\n{llm_input}\nPlease summarize this information."
                if self.llm:
                    llm_output = self.llm._call(summary_request)
                    return llm_output
                else:
                    return None
            else:
                return None
        else:
            return None

    def map_question_to_query(self, question):
        question = question.lower()
        if "drug" in question or "intervention" in question or "treatment" in question:
            drug_match = re.search(r"(?:drug|intervention|treatment)\s*(?:for\s*)?(\w+)", question)
            if drug_match:
                return 'drug', drug_match.group(1)
        if "condition" in question or "disease" in question or "for" in question:
            condition_match = re.search(r"(?:condition|disease|for)\s*(\w+)", question)
            if condition_match:
                return 'condition', condition_match.group(1)
        if "trial" in question:
            trial_match = re.search(r"trial\s*(\w+)", question)
            if trial_match:
                return 'trial', trial_match.group(1)
        return None, None

    def close_connection(self):
        if self.conn:
            self.conn.close()


In [6]:
schema_info ={

 'all_id_information': ['nct_id', 'names'],
 'all_interventions': ['nct_id', 'names'],
 'all_intervention_types': ['nct_id', 'names'], 
'all_keywords': ['nct_id', 'names'], 
'all_overall_officials': ['nct_id', 'names'], 
'all_overall_official_affiliations': ['nct_id', 'names'],
 'all_primary_outcome_measures': ['nct_id', 'names'],
 'all_secondary_outcome_measures': ['nct_id', 'names'],
 'all_sponsors': ['nct_id', 'names'],
 'all_states': ['nct_id', 'names'],
 'analyzed_studies': ['id', 'nct_id', 'url', 'brief_title', 'start_month', 'start_year', 'overall_status', 'p_completion_month', 'p_completion_year', 'completion_month', 'completion_year', 'verification_month', 'verification_year', 'p_comp_mn', 'p_comp_yr', 'received_year', 'mntopcom', 'enrollment', 'number_of_arms', 'allocation', 'masking', 'phase', 'primary_purpose', 'sponsor_name', 'agency_class', 'collaborator_names', 'funding', 'responsible_party_type', 'responsible_party_organization', 'us_coderc', 'oversight', 'behavioral', 'biological', 'device', 'dietsup', 'drug', 'genetic', 'procedure', 'radiation', 'otherint', 'intervg1', 'results', 'resultsreceived_month', 'resultsreceived_year', 'firstreceived_results_dt', 't2result', 't2result_imp', 't2resmod', 'results12', 'delayed', 'dr_received_dt', 'mn2delay', 'delayed12'],
 'calculated_values': ['id', 'nct_id', 'number_of_facilities', 'number_of_nsae_subjects', 'number_of_sae_subjects', 'registered_in_calendar_year', 'nlm_download_date', 'actual_duration', 'were_results_reported', 'months_to_report_results', 'has_us_facility', 'has_single_facility', 'minimum_age_num', 'maximum_age_num', 'minimum_age_unit', 'maximum_age_unit', 'number_of_primary_outcomes_to_measure', 'number_of_secondary_outcomes_to_measure', 'number_of_other_outcomes_to_measure'],
 'designs': ['id', 'nct_id', 'allocation', 'intervention_model', 'observational_model', 'primary_purpose', 'time_perspective', 'masking', 'masking_description', 'intervention_model_description', 'subject_masked', 'caregiver_masked', 'investigator_masked', 'outcomes_assessor_masked'], 
'design_group_interventions': ['id', 'nct_id', 'design_group_id', 'intervention_id'],
 'participant_flows': ['id', 'nct_id', 'recruitment_details', 'pre_assignment_details', 'units_analyzed'],
 'responsible_parties': ['id', 'nct_id', 'responsible_party_type', 'name', 'title', 'organization', 'affiliation', 'old_name_title'], 
'result_agreements': ['id', 'nct_id', 'pi_employee', 'agreement', 'restriction_type', 'other_details', 'restrictive_agreement'], 
'result_contacts': ['id', 'nct_id', 'organization', 'name', 'phone', 'email', 'extension'],
 'result_groups': ['id', 'nct_id', 'ctgov_group_code', 'result_type', 'title', 'description', 'outcome_id'], 
'retractions': ['id', 'reference_id', 'pmid', 'source', 'nct_id'], 
'sponsors': ['id', 'nct_id', 'agency_class', 'lead_or_collaborator', 'name'],
 'all_conditions': ['nct_id', 'names'],
'studies': ['nct_id', 'nlm_download_date_description', 'study_first_submitted_date', 'results_first_submitted_date', 'disposition_first_submitted_date', 'last_update_submitted_date', 'study_first_submitted_qc_date', 'study_first_posted_date', 'study_first_posted_date_type', 'results_first_submitted_qc_date', 'results_first_posted_date', 'results_first_posted_date_type', 'disposition_first_submitted_qc_date', 'disposition_first_posted_date', 'disposition_first_posted_date_type', 'last_update_submitted_qc_date', 'last_update_posted_date', 'last_update_posted_date_type', 'start_month_year', 'start_date_type', 'start_date', 'verification_month_year', 'verification_date', 'completion_month_year', 'completion_date_type', 'completion_date', 'primary_completion_month_year', 'primary_completion_date_type', 'primary_completion_date', 'target_duration', 'study_type', 'acronym', 'baseline_population', 'brief_title', 'official_title', 'overall_status', 'last_known_status', 'phase', 'enrollment', 'enrollment_type', 'source', 'limitations_and_caveats', 'number_of_arms', 'number_of_groups', 'why_stopped', 'has_expanded_access', 'expanded_access_type_individual', 'expanded_access_type_intermediate', 'expanded_access_type_treatment', 'has_dmc', 'is_fda_regulated_drug', 'is_fda_regulated_device', 'is_unapproved_device', 'is_ppsd', 'is_us_export', 'biospec_retention', 'biospec_description', 'ipd_time_frame', 'ipd_access_criteria', 'ipd_url', 'plan_to_share_ipd', 'plan_to_share_ipd_description', 'created_at', 'updated_at', 'source_class', 'delayed_posting', 'expanded_access_nctid', 'expanded_access_status_for_nctid', 'fdaaa801_violation', 'baseline_type_units_analyzed', 'patient_registry'],
 'all_countries': ['nct_id', 'names'],
 'all_design_outcomes': ['nct_id', 'names'], 
'all_facilities': ['nct_id', 'names'],
 'all_group_types': ['nct_id', 'names']
}


In [7]:
# Integrating agents
class MultiAgentSystem:
    def __init__(self, llm, schema_info):
        self.data_researcher = DataResearcherAgent()
        self.sql_specialist = SQLSpecialistAgent(schema_info=schema_info)
        self.execution_agent = ExecutionAgent(llm=llm)

    def handle_query(self, user_question):
        # DataResearcherAgent output
        research_output = self.data_researcher.fetch_information(user_question)

        # SQL Specialist output
        extracted_info = {
            "tables": ["all_interventions", "all_conditions", "all_keywords"],  # Example tables to use
            "conditions": {},  # Placeholder for conditions, this can be expanded
            "fields": ["nct_id"],  # Placeholder for fields
        }
        sql_output = self.sql_specialist.generate_sql(extracted_info)

        # ExecutionAgent output
        execution_output = self.execution_agent.handle_user_query(user_question)

        # Combine and return the result from all agents
        combined_output = (
            f"Data Research Output: {research_output}\n"
            f"SQL Query Output: {sql_output}\n"
            f"Execution Output: {execution_output}"
        )
        return combined_output

    def close(self):
        self.execution_agent.close_connection()


In [8]:
# Dynamic question-answering process
def dynamic_question_process(multi_agent_system):
    print("Welcome to the Clinical Trials Query System!")
    print("You can ask questions related to drugs, conditions, or clinical trials.")
    while True:
        user_query = input("Enter your question (or type 'exit' to quit): ")
        if user_query.lower() == 'exit':
            break
        response = multi_agent_system.handle_query(user_query)
        print("Response:\n", response)

# Usage
if __name__ == "__main__":
    multi_agent_system = MultiAgentSystem(custom_llm, schema_info)
    dynamic_question_process(multi_agent_system)
    multi_agent_system.close()

Connected to database: clinical_trials.db
Welcome to the Clinical Trials Query System!
You can ask questions related to drugs, conditions, or clinical trials.


ConnectionError: HTTPSConnectionPool(host='54bc-35-194-40-255.ngrok-free.appget_num_tokens', port=443): Max retries exceeded with url: / (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x0000021096ABB130>: Failed to resolve '54bc-35-194-40-255.ngrok-free.appget_num_tokens' ([Errno 11001] getaddrinfo failed)"))