<a href="https://colab.research.google.com/github/ShadNygren/RAG-based_Architectures_for_Drug_Side_Effect_Retrieval_in_LLMs/blob/main/RAG_based_Architectures_for_Drug_Side_Effect_Retrieval_in_LLMs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# ***** Assessment of LLMs for drug side effect identification *****




# ===== Load the basic libraries =====

In [None]:
# essentials
import sys
import pandas as pd
import numpy as np

# plotting
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm

# ===== pip install libraries =====

In [None]:
%pip install langchain

In [None]:
from torch import cuda, bfloat16
import torch
import transformers
from transformers import AutoTokenizer
from time import time
#import chromadb
#from chromadb.config import Settings
#
# Comment out 2024-08-10
#from langchain.llms import HuggingFacePipeline
#from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
#from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
#from langchain.vectorstores import Chroma

# ===== Google Cloud Secrets Manager =====

In [None]:
# Install the Secret Manager Client Library

%pip install google-cloud-secret-manager

In [None]:
# Authenticate with Google Cloud

from google.colab import auth as google_auth
google_auth.authenticate_user()

In [None]:
from google.auth import default
from google.auth.transport.requests import Request

# Get the current application-default credentials and their project
credentials, project_id = default()

# Refresh the credentials (optional, but ensures they're up to date)
credentials.refresh(Request())

# Print the account email.
# Not every credential class exposes `service_account_email` or `client_email`,
# so look both up defensively instead of raising AttributeError on a purely
# informational print.
print(f"Authenticated as: {getattr(credentials, 'service_account_email', None) or getattr(credentials, 'client_email', None) or 'unknown account'}")


In [None]:
from google.cloud import secretmanager
from google.api_core.exceptions import NotFound, PermissionDenied, GoogleAPIError

def get_secret_value_from_google_cloud_old(project_id: str, secret_id: str, version_id: str = 'latest') -> str:
    """
    Fetch one secret version from Google Cloud Secret Manager and return it as text.

    Args:
        project_id (str): The Google Cloud project that owns the secret.
        secret_id (str): The name of the secret in Secret Manager.
        version_id (str): The secret version to read. Defaults to 'latest'.

    Returns:
        str: The secret payload decoded as UTF-8.

    Raises:
        ValueError: For every failure. NotFound, PermissionDenied,
            GoogleAPIError and any other exception are caught and re-raised
            as ValueError with the original exception chained as the cause.
    """
    try:
        sm_client = secretmanager.SecretManagerServiceClient()

        # Full resource path of the requested secret version.
        version_path = sm_client.secret_version_path(project_id, secret_id, version_id)

        payload = sm_client.access_secret_version(request={"name": version_path}).payload
        return payload.data.decode('UTF-8')

    except NotFound as err:
        raise ValueError(f"The specified secret '{secret_id}' version '{version_id}' was not found in project '{project_id}'.") from err
    except PermissionDenied as err:
        raise ValueError(f"Permission denied when accessing secret '{secret_id}' version '{version_id}' in project '{project_id}'.") from err
    except GoogleAPIError as err:
        raise ValueError(f"Failed to access the secret due to a Google API error: {err}") from err
    except Exception as err:
        raise ValueError(f"An unexpected error occurred: {err}") from err


In [None]:
import logging
from google.cloud import secretmanager
from google.auth import default
from google.auth.transport.requests import Request
from google.api_core.exceptions import NotFound, PermissionDenied, GoogleAPIError

# Set up detailed logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

def get_secret_value_from_google_cloud_debug_logging(project_id: str, secret_id: str, version_id: str = 'latest') -> str:
    """
    Retrieves the value of a secret from Google Cloud Secret Manager with extensive logging.

    Args:
        project_id (str): The ID of the Google Cloud project.
        secret_id (str): The ID of the secret in Secret Manager.
        version_id (str): The version of the secret to retrieve. Defaults to 'latest'.

    Returns:
        str: The secret value decoded as a UTF-8 string.

    Raises:
        ValueError: For every failure. NotFound, PermissionDenied,
            GoogleAPIError and any other exception are logged and re-raised
            as ValueError with the original exception chained as the cause.
    """
    # Bound before the try block so the PermissionDenied handler can always
    # reference it, even if credential discovery never ran.
    credentials = None
    try:
        # Initialize the Secret Manager Client
        client = secretmanager.SecretManagerServiceClient()

        # Log authentication details
        credentials, project_id_default = default()
        logging.debug(f"Authenticated with project_id: {project_id_default}")
        logging.debug(f"Credentials: {credentials}")
        # Not every credential class has `service_account_email`; a missing
        # attribute must not turn a diagnostic log line into a failed fetch.
        service_account = getattr(credentials, "service_account_email", None)
        logging.debug(f"Using service account: {service_account or 'None'}")

        # Build the resource name
        secret_name = client.secret_version_path(project_id, secret_id, version_id)
        logging.debug(f"Attempting to access Secret: {secret_name}")

        # Access the Secret Version
        response = client.access_secret_version(request={"name": secret_name})
        logging.debug(f"Successfully accessed Secret: {secret_name}")

        # Extract and return the Secret Value
        return response.payload.data.decode('UTF-8')

    except NotFound as e:
        logging.error(f"Secret '{secret_id}' version '{version_id}' not found in project '{project_id}'. Exception: {e}")
        raise ValueError(f"The specified secret '{secret_id}' version '{version_id}' was not found in project '{project_id}'.") from e
    except PermissionDenied as e:
        account = getattr(credentials, "service_account_email", None)
        logging.error(f"Permission denied when accessing secret '{secret_id}' version '{version_id}' in project '{project_id}'.")
        logging.error(f"Credentials: {credentials}")
        logging.error(f"Authenticated as: {account or 'User credentials'}")
        logging.error("Make sure the account has the necessary IAM roles (e.g., Secret Manager Secret Accessor).")
        logging.error(f"Google API Error: {e}")
        raise ValueError(f"Permission denied when accessing secret '{secret_id}' version '{version_id}' in project '{project_id}'.") from e
    except GoogleAPIError as e:
        logging.error(f"Failed to access the secret due to a Google API error: {e}")
        raise ValueError(f"Failed to access the secret due to a Google API error: {e}") from e
    except Exception as e:
        logging.error(f"An unexpected error occurred: {e}")
        raise ValueError(f"An unexpected error occurred: {e}") from e

# Example usage (replace with your actual values):
# secret_value = get_secret_value_from_google_cloud_debug_logging('DrugSideEffects', 'OPENAI_API_KEY')
# print(secret_value)


In [None]:
import google.auth
from google.auth.transport.requests import Request
from google.cloud import secretmanager

def get_secret_value_from_google_cloud(project_id: str, secret_id: str, version_id: str = 'latest') -> str:
    """
    Retrieves the value of a secret from Google Cloud Secret Manager.

    The application-default credentials are loaded, refreshed and passed
    explicitly to the Secret Manager client. The authenticated account and
    the credentials' project are printed and logged for troubleshooting.

    Args:
        project_id (str): The ID of the Google Cloud project.
        secret_id (str): The ID of the secret in Secret Manager.
        version_id (str): The version of the secret to retrieve. Defaults to 'latest'.

    Returns:
        str: The secret value decoded as a UTF-8 string.

    Raises:
        ValueError: For every failure. NotFound, PermissionDenied,
            GoogleAPIError and any other exception are logged and re-raised
            as ValueError with the original exception chained as the cause.
    """
    # Bound before the try block so the PermissionDenied handler can always
    # reference it.
    credentials = None
    try:
        # Load the default credentials and refresh them to ensure they're up to date
        credentials, project = google.auth.default()
        credentials.refresh(Request())

        # Not every credential class exposes these attributes; look them up
        # defensively so a diagnostic message cannot abort the secret fetch.
        account = (
            getattr(credentials, "service_account_email", None)
            or getattr(credentials, "client_email", None)
            or "unknown account"
        )
        logging.debug(f"Authenticated as: {account}")
        print(f"Authenticated as: {account}")
        print(f"Project ID: {project}")

        # Initialize the Secret Manager Client with explicit credentials
        client = secretmanager.SecretManagerServiceClient(credentials=credentials)
        logging.debug(f"Using credentials: {credentials}")

        # Build the resource name
        secret_name = client.secret_version_path(project_id, secret_id, version_id)
        logging.debug(f"Attempting to access Secret: {secret_name}")

        # Access the Secret Version
        response = client.access_secret_version(request={"name": secret_name})
        logging.debug(f"Successfully accessed Secret: {secret_name}")

        # Extract and return the Secret Value
        return response.payload.data.decode('UTF-8')

    except NotFound as e:
        logging.error(f"Secret '{secret_id}' version '{version_id}' not found in project '{project_id}'. Exception: {e}")
        raise ValueError(f"The specified secret '{secret_id}' version '{version_id}' was not found in project '{project_id}'.") from e
    except PermissionDenied as e:
        denied_account = getattr(credentials, "service_account_email", None)
        logging.error(f"Permission denied when accessing secret '{secret_id}' version '{version_id}' in project '{project_id}'.")
        logging.error(f"Credentials: {credentials}")
        logging.error(f"Authenticated as: {denied_account or 'User credentials'}")
        logging.error("Make sure the account has the necessary IAM roles (e.g., Secret Manager Secret Accessor).")
        logging.error(f"Google API Error: {e}")
        raise ValueError(f"Permission denied when accessing secret '{secret_id}' version '{version_id}' in project '{project_id}'.") from e
    except GoogleAPIError as e:
        logging.error(f"Failed to access the secret due to a Google API error: {e}")
        raise ValueError(f"Failed to access the secret due to a Google API error: {e}") from e
    except Exception as e:
        logging.error(f"An unexpected error occurred: {e}")
        raise ValueError(f"An unexpected error occurred: {e}") from e

# Example usage
# secret_value = get_secret_value_from_google_cloud('DrugSideEffects', 'OPENAI_API_KEY')
# print(secret_value)


# ===== API Keys =====

In [None]:
#from google.colab import userdata

In [None]:
import os

# Load the OpenAI and Pinecone API keys from Google Cloud Secret Manager.
# The commented-out lines are the earlier approach that read the same values
# from Colab `userdata`.

# The OpenAI key is exported through the environment.
#os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')
os.environ['OPENAI_API_KEY'] = get_secret_value_from_google_cloud(project_id='995753152222', secret_id='OPENAI_API_KEY_DrugSideEffects', version_id='latest')

# The Pinecone key is kept both as a module-level variable and in the environment.
#PINECONE_API_KEY = userdata.get('PINECONE_API_KEY')
PINECONE_API_KEY = get_secret_value_from_google_cloud(project_id='995753152222', secret_id='PINECONE_API_KEY', version_id='latest')
#os.environ["PINECONE_API_KEY"] = userdata.get('PINECONE_API_KEY')
os.environ["PINECONE_API_KEY"] = PINECONE_API_KEY
#PINECONE_ENV = 'us-west2-aws'

# Hugging Face login (currently disabled).
#from huggingface_hub import login
#mytoken = userdata.get('HF_TOKEN')
#login(mytoken)

In [None]:
# Load the Neo4j connection settings from Secret Manager into the environment;
# the GraphRAG cell reads these three variables when it creates the driver.
os.environ["NEO4J_URI"] = get_secret_value_from_google_cloud(project_id='995753152222', secret_id='NEO4J_URI_Diego', version_id='latest')
os.environ["NEO4J_USERNAME"] = get_secret_value_from_google_cloud(project_id='995753152222', secret_id='NEO4J_USERNAME_Diego', version_id='latest')
os.environ["NEO4J_PASSWORD"] = get_secret_value_from_google_cloud(project_id='995753152222', secret_id='NEO4J_PASSWORD_Diego', version_id='latest')

In [None]:
AWS_BEDROCK_URL = get_secret_value_from_google_cloud(project_id='995753152222', secret_id='AWS_BEDROCK_URL', version_id='latest')

# ===== Connect to Google Drive and define the root directory =====

In [None]:
from google.colab import userdata

In [None]:
# access to files in drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
root_dir = userdata.get('DrugSideEffects_Root_Dir')

# ===== Define the model, the device =====

In [None]:
# NOTE: `model_id` is not currently used to select the model that answers the
# prompts. The Bedrock Lambda would have to be modified to accept different
# LLM models; at the moment it uses Llama-3-8B-Instruct.
# `model_selected` is still used later to name the results file.
models2test = ["meta-llama/Meta-Llama-3-8B-Instruct", "ShadNygren/FineTuneTest-DrugAdverseEffects-SIDER-Diego1-50epochs-then-Diego2-10epochs", "ShadNygren/FineTuneTest-DrugAdverseEffects-SIDER-Diego1-50epochs", "ShadNygren/FineTuneTest-DrugAdverseEffects-SIDER-Diego2-10epochs"]
model_selected = 0
model_id = models2test[model_selected] # we need to run this for each model

# ===== Load Query Data =====

In [None]:
#print(root_dir)

In [None]:
# Read the drug / side-effect associations to evaluate. The active line loads
# the "ALLdrugs_10randomADRs" spreadsheet; the commented-out lines load smaller
# subsets (200 drugs, 20 rows, 1 row).
# The evaluation code reads the columns 'side effect', 'drug name' and 'label'.
query_data = pd.read_excel(root_dir + '/data/drug_side_effectsALLdrugs_10randomADRs.xlsx')
#query_data = pd.read_excel(root_dir + '/data/drug_side_effects200drugs.xlsx')
#query_data = pd.read_excel(root_dir + '/data/drug_side_effects20rows.xlsx')
#query_data = pd.read_excel(root_dir + '/data/drug_side_effects1rows.xlsx')

# ===== AWS Bedrock Stuff =====

In [None]:
import requests

# Define a function to make an API POST request to AWS Bedrock
def http_post_to_aws_bedrock(prompt, timeout=300):
    """
    POST a prompt to the AWS Bedrock endpoint and return the raw response body.

    Args:
        prompt: Value sent under the "prompt" key of the JSON payload.
            NOTE(review): callers in this notebook pass a dict of the form
            {"prompt": <text>}, so the payload is nested one level deeper than
            the parameter name suggests -- confirm the Lambda expects that.
        timeout: Seconds to wait for the endpoint before giving up. Without a
            timeout a stalled endpoint would block the calling thread
            forever. Pass None to wait indefinitely.

    Returns:
        str: The response body as text on success.
        dict: {"error": <message>} if the request failed (connection error,
            timeout, or non-2xx status). Callers must handle both shapes.
    """
    payload = {
        "prompt": prompt,
    }
    try:
        response = requests.post(AWS_BEDROCK_URL, json=payload, timeout=timeout)
        response.raise_for_status()  # Raise an error for bad status codes
    except requests.exceptions.RequestException as e:
        return {"error": str(e)}

    # The body is returned as plain text; no JSON decoding is done here.
    return response.text

# ===== Using OpenAI to compute embedding =====

In [None]:
%pip install openai

In [None]:
import openai

In [None]:
from openai import OpenAI
client = OpenAI()

def get_embedding(text, model="text-embedding-ada-002"):
    """Return the embedding of `text` computed by the given OpenAI embedding model.

    The text is sent as a single-element batch, and the embedding of that one
    element is returned.
    """
    embedding_response = client.embeddings.create(input=[text], model=model)
    first_item = embedding_response.data[0]
    return first_item.embedding

In [None]:
print(get_embedding("Hello world"))

# ===== Filter RAG =====

In [None]:
def filter_rag(rag_results: "list | None" = None, terms: "list | None" = None) -> list:
    """
    Keep only the RAG results whose text contains every search term (case-insensitive).

    Args:
        rag_results (list): Metadata dicts returned from a RAG query; the
            "text" entry of each dict is searched. A dict without "text" is
            treated as having empty text. Defaults to an empty list.
        terms (list): Strings that must all appear in a result's text for the
            result to be kept. Callers in this notebook pass
            [drug_name, side_effect]. Defaults to an empty list.

    Returns:
        list: The matching dicts, in their original order. If nothing matches
            and at least two terms were given, a single synthetic
            {"text": "No, the side effect <terms[1]> is not listed ... <terms[0]>"}
            entry is returned instead. With fewer than two terms that message
            cannot be built, so the (empty) list is returned as is.
    """
    # None defaults avoid sharing one mutable list object between calls.
    rag_results = [] if rag_results is None else rag_results
    terms = [] if terms is None else terms

    # Lower-case the terms once rather than once per result.
    lowered_terms = [term.lower() for term in terms]

    filtered_results = [
        result
        for result in rag_results
        if all(term in result.get("text", "").lower() for term in lowered_terms)
    ]

    # The fallback sentence needs both a drug (terms[0]) and a side effect
    # (terms[1]); guard the indexing so a short `terms` list cannot raise.
    if not filtered_results and len(terms) >= 2:
        filtered_results.append({"text": "No, the side effect " + terms[1] + " is not listed as an adverse effect, adverse reaction or side effect of the drug " + terms[0]})
    return filtered_results


# ===== Pinecone VectorDB for RAG =====

In [None]:
# To install with gRPC run:
#%pip install "pinecone-client[grpc]"

# To install without gRPC run:
%pip install pinecone-client

In [None]:
#from pinecone.grpc import PineconeGRPC as Pinecone
from pinecone import Pinecone
from pinecone import ServerlessSpec

In [None]:
pc = Pinecone(api_key=PINECONE_API_KEY)

In [None]:
def query_pinecone_old(query_embedding_vector, namespace="drug-side-effects-formatA", top_k=1):
    """Earlier variant of `query_pinecone`: query the index and return match metadata.

    Queries the "drug-side-effects-text-embedding-ada-002" index in the given
    namespace for the `top_k` nearest matches, prints the metadata of every
    match, and returns that metadata as a list.
    """
    pinecone_index = pc.Index("drug-side-effects-text-embedding-ada-002")

    response = pinecone_index.query(
        namespace=namespace,
        vector=query_embedding_vector,
        top_k=top_k,
        include_values=True,
        include_metadata=True,
    )

    # Collect only the metadata of each match.
    metadata_results = []
    for match in response['matches']:
        metadata_results.append(match['metadata'])

    print(metadata_results)
    return metadata_results


In [None]:
def query_pinecone(query_embedding_vector, namespace="drug-side-effects-formatA", top_k=1,
                   index_name="drug-side-effects-text-embedding-ada-002"):
    """
    Queries a Pinecone index with a given query embedding vector and retrieves metadata from the matched results.

    Parameters:
    -----------
    query_embedding_vector : list or numpy.ndarray
        The embedding vector used to query the Pinecone index. This should be a vector that represents the query in the same embedding space as the indexed data.

    namespace : str, optional
        The namespace within the Pinecone index to query. This allows for segmentation of data within the same index.
        The default is "drug-side-effects-formatA".

    top_k : int, optional
        The number of top results to return based on similarity to the query embedding vector. The default is 1.

    index_name : str, optional
        Name of the Pinecone index to query. The default is
        "drug-side-effects-text-embedding-ada-002".

    Returns:
    --------
    metadata_results : list[dict]
        A list of metadata dictionaries from the matched results. Each dictionary contains metadata associated with a match from the Pinecone index.
        The list is also printed before it is returned.

    Example:
    --------
    >>> query_embedding_vector = [0.1, 0.2, 0.3, 0.4]
    >>> metadata = query_pinecone(query_embedding_vector, top_k=5)
    >>> print(metadata)
    [{'text': 'example text 1', ...}, {'text': 'example text 2', ...}, ...]
    """
    index = pc.Index(index_name)

    query_results = index.query(
        namespace=namespace,
        vector=query_embedding_vector,
        top_k=top_k,
        # Only the metadata is used below, so do not download the stored
        # vector values for every match.
        include_values=False,
        include_metadata=True
    )

    # Retrieve metadata from the results
    metadata_results = [match['metadata'] for match in query_results['matches']]
    print(metadata_results)
    return metadata_results


In [None]:
def rag_query(query, namespace, drug_name, side_effect):
    """Answer `query` with the LLM, grounded on text retrieved from Pinecone.

    Steps:
      1. Embed the query with `get_embedding`.
      2. Retrieve the 5 nearest entries from the given Pinecone namespace.
      3. Keep only entries mentioning both `drug_name` and `side_effect`
         (`filter_rag` substitutes a "No, ..." entry when none do).
      4. Append the kept texts to a YES/NO instruction prompt and send it to
         the Bedrock endpoint.

    Intermediate values are printed for debugging.

    Returns:
        Whatever `http_post_to_aws_bedrock` returns for the assembled prompt.
    """
    query_vector = get_embedding(text=query)
    rag_results = query_pinecone(query_embedding_vector=query_vector, namespace=namespace, top_k=5)
    print("rag_query rag_results = " + str(rag_results))

    filtered_rag_results = filter_rag(rag_results=rag_results, terms=[drug_name, side_effect])
    print("rag_query filtered_rag_results = " + str(filtered_rag_results))

    # Build the "RAG Results" section: a header followed by one paragraph per
    # retained entry.
    sections = ["\n\n### RAG Results:\n\n"]
    for entry in filtered_rag_results:
        print("rag_query filtered_result = " + str(entry))
        sections.append(entry["text"] + "\n\n")
    str_rag_results = "".join(sections)
    print("rag_query str_rag_results = " + str_rag_results)

    instructions = (
        "You are asked to answer the following question with a single word: YES or NO. "
        "Base your answer strictly on the RAG Results provided below. "
        "After your YES or NO answer, briefly explain your reasoning using the information from the RAG Results. "
        "Do not infer or speculate beyond the provided information.\n\n"
    )
    modified_rag_prompt = instructions + "### Question:\n\n" + query + "\n\n" + str_rag_results
    print("modified_rag_prompt = " + str(modified_rag_prompt))

    # The prompt is wrapped in a dict here; the Bedrock helper returns the
    # response without decoding it.
    return http_post_to_aws_bedrock({"prompt": modified_rag_prompt})

# ----- RAG Testing -----


In [None]:
print(rag_query("Is diabetic hyperosmolar coma an adverse effect of econazole?", "drug-side-effects-formatA", "econazole", "diabetic hyperosmolar coma"))

In [None]:
print(rag_query("Is carpal tunnel syndrome an adverse effect of tesamorelin?", "drug-side-effects-formatA", "tesamorelin", "carpal tunnel syndrome"))

In [None]:
print(rag_query("Is haemoglobin an adverse effect of tesamorelin?", "drug-side-effects-formatA", "tesamorelin", "haemoglobin"))

In [None]:
print(rag_query("Is skin reaction an adverse effect of lomustine?", "drug-side-effects-formatB", "lomustine", "skin reaction"))

In [None]:
print(rag_query("Is infusion site paraesthesia an adverse effect of lomustine?", "drug-side-effects-formatB", "lomustine", "infusion site paraesthesia"))

In [None]:
print(rag_query("Is meningioma an adverse effect of gadobenate?", "drug-side-effects-formatA", "gadobenate", "meningioma"))

In [None]:
print(rag_query("Is microcephaly an adverse effect of gadobenate?", "drug-side-effects-formatA", "gadobenate", "microcephaly"))

In [None]:
print(rag_query("Is abnormal dreams an adverse effect of gadobenate?", "drug-side-effects-formatA", "gadobenate", "abnormal dreams"))

In [None]:
print(rag_query("Is abnormal dreams an adverse effect of gadobenate?", "drug-side-effects-formatB", "gadobenate", "abnormal dreams"))

# ===== GraphRAG Stuff =====

In [None]:
#%pip install --upgrade --quiet  langchain langchain-community langchain-openai langchain-experimental neo4j
%pip install --upgrade neo4j

In [None]:
from neo4j import GraphDatabase
import os

In [None]:
from neo4j import GraphDatabase

# Create the Neo4j driver from the connection settings stored in the
# environment. Sessions are opened from this driver where queries are run.
neo4j_driver = GraphDatabase.driver(
    uri = os.environ["NEO4J_URI"],
    auth = (os.environ["NEO4J_USERNAME"],
            os.environ["NEO4J_PASSWORD"]))

#session = driver.session()

In [None]:
# Sample Cypher query returning up to five (source)-[May_Cause_Side_Effect]->(target)
# triples. This cell only defines the query string; it does not run it.
default_cypher = "MATCH (s)-[r:May_Cause_Side_Effect]->(t) RETURN s,r,t LIMIT 5"

# Alternative sample restricted to one drug / side-effect pair:
#default_cypher = """
#MATCH (s)-[r:May_Cause_Side_Effect]->(t)
#WHERE s.name = 'carnitine' AND t.name = 'amblyopia'
#RETURN s, r, t
#"""

In [None]:
import re
from neo4j import GraphDatabase

def escape_special_characters(input_string):
    """Return `input_string` with every single quote and backslash prefixed by a backslash.

    Intended for embedding text inside a single-quoted Cypher string literal.
    """
    def _prefix_with_backslash(match):
        return "\\" + match.group(0)

    return re.sub(r"['\\]", _prefix_with_backslash, input_string)

def graphrag_query_drug_side_effect(query, drug_name, side_effect):
    """
    Answer `query` with the LLM, grounded on a lookup in the Neo4j knowledge graph.

    The graph is checked for a `May_Cause_Side_Effect` relationship from a
    node whose `name` equals the lower-cased `drug_name` to a node whose
    `name` equals the lower-cased `side_effect`. The outcome is turned into a
    "Yes, ..." or "No, ..." sentence, appended to a YES/NO instruction prompt,
    and sent to the Bedrock endpoint.

    Args:
        query (str): The natural-language question to put to the LLM.
        drug_name (str): Drug to look up (matched case-insensitively by
            lower-casing; node names are assumed to be stored lower-case).
        side_effect (str): Side effect to look up (same matching rule).

    Returns:
        Whatever `http_post_to_aws_bedrock` returns for the assembled prompt.
    """
    # Values are passed as query parameters rather than interpolated into the
    # Cypher text. This removes the need for manual escaping (and any risk of
    # Cypher injection), and keeps escape backslashes out of the names that
    # are later quoted in the LLM prompt.
    cypher = """
    MATCH (s)-[r:May_Cause_Side_Effect]->(t)
    WHERE s.name = $drug_name AND t.name = $side_effect
    RETURN s, r, t
    """
    parameters = {
        "drug_name": drug_name.lower(),
        "side_effect": side_effect.lower(),
    }

    with neo4j_driver.session() as session:
        # Results must be consumed while the session is still open.
        records = list(session.run(cypher, parameters=parameters))

    # Only the existence of a matching relationship matters for the answer.
    if records:
        query_result = "Yes, the side effect " + side_effect + " is listed as an adverse effect, adverse reaction or side effect of the drug " + drug_name
    else:
        query_result = "No, the side effect " + side_effect + " is not listed as an adverse effect, adverse reaction or side effect of the drug " + drug_name

    graphrag_results = "\n\n### GraphRAG Results:\n\n" + query_result + "\n\n"
    print("graphrag_query graphrag_results = " + graphrag_results)

    modified_graphrag_prompt = "You are asked to answer the following question with a single word: YES or NO. Base your answer strictly on the GraphRAG Results provided below. After your YES or NO answer, briefly explain your reasoning using the information from the GraphRAG Results. Do not infer or speculate beyond the provided information.\n\n" + "### Question:\n\n" + query + "\n\n" + graphrag_results
    print("modified_graphrag_prompt = " + str(modified_graphrag_prompt))

    return http_post_to_aws_bedrock({"prompt": modified_graphrag_prompt})


## Example usage
#neo4j_driver = GraphDatabase.driver(
#    uri=os.environ["NEO4J_URI"],
#    auth=(os.environ["NEO4J_USERNAME"], os.environ["NEO4J_PASSWORD"])
#)
#
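#result = graphrag_query_drug_side_effect("Is hodgkin's disease lymphocyte depletion type stage unspecified an adverse effect of hydralazine?", "hydralazine", "hodgkin's disease lymphocyte depletion type stage unspecified")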
#print(result)

# ===== GraphRAG Testing =====

In [None]:
# Example usage
#default_cypher = "MATCH (s)-[r:May_Cause_Side_Effect]->(t) RETURN s, r, t LIMIT 5"
#connected_nodes_string = graphrag_query(default_cypher)
#print(connected_nodes_string)

In [None]:
print(graphrag_query_drug_side_effect("Does carnitine cause amblyopia?", "carnitine", "amblyopia"))

In [None]:
print(graphrag_query_drug_side_effect("Does carnitine cause euphoria?", "carnitine", "euphoria"))

In [None]:
print(graphrag_query_drug_side_effect("Does Estrofem cause Depression?", "estrofem", "Depression"))

In [None]:
print(graphrag_query_drug_side_effect("Is sebaceous hyperplasia an adverse effect of gemeprost?", "gemeprost", "sebaceous hyperplasia"))

# Parallelized version of original code

In [None]:
import pandas as pd
import concurrent.futures
from tqdm import tqdm
import requests

In [None]:
def llm(prompt):
    """
    Ask the LLM a YES/NO question with no retrieved context (the baseline method).

    Args:
        prompt (str): The question text.

    Returns:
        Whatever `http_post_to_aws_bedrock` returns for the assembled prompt.
    """
    # This baseline supplies no retrieval results, so the instruction must not
    # tell the model to rely on "GraphRAG Results provided below" -- there are
    # none, and that wording could push the model towards answering NO.
    modified_prompt = "You are asked to answer the following question with a single word: YES or NO. After your YES or NO answer, briefly explain your reasoning.\n\n" + "### Question:\n\n" + prompt

    print("modified_prompt = " + str(modified_prompt))
    # The prompt is wrapped in a dict here; the Bedrock helper returns the
    # response without decoding it.
    response = http_post_to_aws_bedrock({"prompt": modified_prompt})
    return response

In [None]:
def binary_answer(text):
    """
    Convert an LLM response into a binary label.

    Args:
        text: The response to classify. Normally a string; the Bedrock helper
            returns an {"error": ...} dict when the request failed.

    Returns:
        int: 1 if `text` is a string containing the standalone word "yes" in
            any letter case, otherwise 0. Non-string input (such as an error
            dict) yields 0.
    """
    import re  # local import keeps this cell independent of earlier cells

    print("binary_answer text = " + str(text))

    # An error dict must not be classified by key membership.
    if not isinstance(text, str):
        return 0

    # Whole-word, case-insensitive match: accepts "Yes" / "yes" as well as
    # "YES", and does not fire on words that merely contain the letters
    # (e.g. "EYES").
    return 1 if re.search(r"\byes\b", text, flags=re.IGNORECASE) else 0

In [None]:
def process_row(i, query_data, questions):
    """Run all four answering methods for the row at position `i`.

    For every question template the placeholders [SE] and [DRUG] are replaced
    with the row's lower-cased side effect and its drug name. The resulting
    question is answered by:
      1. the plain LLM,
      2. LLM + RAG on namespace "drug-side-effects-formatA",
      3. LLM + RAG on namespace "drug-side-effects-formatB",
      4. LLM + GraphRAG.
    Each response is reduced to 0/1 with `binary_answer`.

    Args:
        i: Positional row number in `query_data`.
        query_data: DataFrame with the columns 'side effect', 'drug name' and 'label'.
        questions: Question templates containing [SE] and [DRUG].

    Returns:
        tuple: (i, results) where results maps 'prompt<k>', 'output_llm<k>',
            'output_ragA<k>', 'output_ragB<k>' and 'output_graphrag<k>'
            (k counted from 1) to the question and the four binary answers.
    """
    row = query_data.iloc[i]
    se = row['side effect'].lower()
    drug_name = row['drug name']
    # The label is read but not used by any of the methods below.
    _label = row['label']

    results = {}
    for number, template in enumerate(questions, start=1):
        suffix = str(number)
        q = template.replace('[SE]', se).replace('[DRUG]', drug_name)
        print("q = " + str(q))

        # Method 1: plain LLM
        response_llm = llm(prompt=q)
        print("response_llm = " + str(response_llm))

        # Method 2: LLM + RAG, format A
        response_rag_A = rag_query(query=q, namespace="drug-side-effects-formatA", drug_name=drug_name, side_effect=se)
        print("response_rag_A = " + str(response_rag_A))

        # Method 3: LLM + RAG, format B
        response_rag_B = rag_query(query=q, namespace="drug-side-effects-formatB", drug_name=drug_name, side_effect=se)
        print("response_rag_B = " + str(response_rag_B))

        # Method 4: LLM + GraphRAG
        response_graph_rag = graphrag_query_drug_side_effect(query=q, drug_name=drug_name, side_effect=se)
        print("response_graph_rag = " + str(response_graph_rag))

        # Record the question and the four binary outcomes for this template.
        results.update({
            'prompt' + suffix: q,
            'output_llm' + suffix: binary_answer(response_llm),
            'output_ragA' + suffix: binary_answer(response_rag_A),
            'output_ragB' + suffix: binary_answer(response_rag_B),
            'output_graphrag' + suffix: binary_answer(response_graph_rag),
        })

    return i, results

In [None]:
def update_dataframe(query_data, questions):
    """
    Evaluate every row of `query_data` in a thread pool and store the results in place.

    One `process_row` task is submitted per row. As tasks finish, each entry
    of the returned results dict is written into `query_data` under the
    column named by its key (columns are created as needed).

    Args:
        query_data: DataFrame to evaluate; it is modified in place.
        questions: Question templates forwarded to `process_row`.

    Raises:
        Any exception raised inside `process_row` propagates from
        `future.result()` and aborts the run.
    """
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # One task per positional row number.
        futures = [executor.submit(process_row, i, query_data, questions) for i in range(query_data.shape[0])]

        # NOTE(review): worker threads read rows of `query_data` while this
        # loop adds columns to it -- confirm this is safe for the pandas
        # version in use.
        for future in concurrent.futures.as_completed(futures):
            i, results = future.result()
            # `process_row` addresses rows by position (iloc), so translate
            # the position to its index label before writing with `.loc`.
            # Writing `.loc[i, ...]` directly is only correct for a default
            # 0..n-1 index; otherwise it hits the wrong row or appends one.
            row_label = query_data.index[i]
            for key, value in results.items():
                query_data.loc[row_label, key] = value

In [None]:
# Main logic: run the evaluation over `query_data` and save the results.
if True:  # Placeholder for your condition
    # Question templates; [SE] and [DRUG] are filled in per row.
    questions = ['Is [SE] an adverse effect of [DRUG]?']
    # Adds the prompt / output columns to query_data in place.
    update_dataframe(query_data, questions)

    # Save the results. The file name carries the index of the selected model
    # and should match the data set loaded above; the commented-out lines are
    # the names used for the smaller data sets.
    query_data.to_excel(root_dir + '/results/results_model_' + str(model_selected)+'_ALLdrugs_10randomADRs.xlsx')
    #query_data.to_excel(root_dir + '/results/results_model_' + str(model_selected)+'_200drugs.xlsx')
    #query_data.to_excel(root_dir + '/results/results_model_' + str(model_selected)+'_20rows.xlsx')
    #query_data.to_excel(root_dir + '/results/results_model_' + str(model_selected)+'_1rows.xlsx')