In [5]:


import json
import os
import re
import archs4py
import nest_asyncio

# Allow re-entrant asyncio event loops (needed when async clients are driven
# from inside an already-running notebook loop).
nest_asyncio.apply()

# HDF5 archive passed to the archs4py metadata/expression lookups below.
archive_file = "human_gene_v2.5.h5"

def get_config_data():
    """Return the OpenAI API key stored in ``local_data.json``.

    The file is read from the current working directory and is expected to
    contain ``{"OPENAI_KEY": {"key": "..."}}``; a missing file or key raises.
    """
    with open("local_data.json") as config_handle:
        config = json.load(config_handle)
    return config["OPENAI_KEY"]["key"]
        
# Load the key from local_data.json and publish it as an environment variable
# so the OpenAI-backed llama_index classes created below can pick it up.
OPENAI_API_KEY = get_config_data()
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

#from llama_index.core import (
#    SimpleDirectoryReader,
#    VectorStoreIndex,
#    StorageContext,
#    load_index_from_storage,
#)
from llama_index.llms.openai import OpenAI

#llm = OpenAI(model="gpt-4o-mini")

from llama_index.core import Settings
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding

# Global llama_index defaults: an OpenAI chat model for generation and
# OpenAI's small embedding model for vectorising text.
Settings.llm = OpenAI(model="gpt-4o-mini")
Settings.embed_model = OpenAIEmbedding(
    model_name="text-embedding-3-small"
)

###

def remove_non_alphanumeric(strings):
    """Strip every character outside ``[a-zA-Z0-9]`` from each string.

    Parameters:
    - strings: iterable of str

    Returns:
    - list of cleaned strings, in the same order as the input
    """
    non_alnum = re.compile(r'[^a-zA-Z0-9]')
    return [non_alnum.sub('', text) for text in strings]

def split_and_retain_quoted(strings):
    """Return the contents of the double-quoted phrases in a text.

    The text is tokenised into quoted phrases and whitespace-separated words;
    only tokens that both start and end with a double quote are kept, with
    the surrounding quotes removed.

    Parameters:
    - strings: str, the text to scan (despite the name, a single string)

    Returns:
    - list of str, the unquoted phrase contents in order of appearance
    """
    tokens = re.findall(r'"[^"]*"|\S+', strings)
    return [
        token[1:-1]
        for token in tokens
        if token.startswith('"') and token.endswith('"')
    ]


def updated_archs4py_query(full_query):
    """
    Parse a free-text user query and fetch matching sample metadata with archs4py.

    The query is routed by the accession type it mentions (checked
    case-insensitively):

    - contains "GSM": every token containing "GSM" (after upper-casing and
      stripping non-alphanumeric characters) is passed, as one list, to
      ``archs4py.meta.samples``.
    - otherwise contains "GSE": each token containing "GSE" is looked up with
      ``archs4py.meta.series`` and the per-series results are concatenated
      (with a fresh index when there is more than one series).
    - otherwise: each double-quoted phrase is searched with
      ``archs4py.meta.meta`` (``remove_sc=True``), and only rows whose index
      appears in every phrase's result are kept.

    Parameters:
    - full_query: str, the unfiltered user query

    Returns:
    - dict ``{'meta': <metadata returned by archs4py>}`` on success, or a
      string starting with "Error retrieving data: " if anything fails or a
      free-text query contains no quoted phrases.
    """
    # Local import: this notebook cell can run before the cell that imports pandas.
    import pandas as pd

    meta_fields = ["geo_accession", "series_id", "characteristics_ch1",
                   "extract_protocol_ch1", "source_name_ch1", "title"]

    try:
        # Upper-case once so accession detection is case-insensitive.
        query_upper = full_query.upper()

        if "GSM" in query_upper:
            tokens = remove_non_alphanumeric(query_upper.split())
            sample_ids = [token for token in tokens if "GSM" in token]
            temp_meta = archs4py.meta.samples(archive_file, sample_ids,
                                              meta_fields=meta_fields)
            return {'meta': temp_meta}

        if "GSE" in query_upper:
            tokens = remove_non_alphanumeric(query_upper.split())
            series_ids = [token for token in tokens if "GSE" in token]
            series_meta = [
                archs4py.meta.series(archive_file, series_id, meta_fields=meta_fields)
                for series_id in series_ids
            ]
            # A single series is returned untouched (index preserved); several
            # are stacked with a fresh index, as before.
            if len(series_meta) == 1:
                return {'meta': series_meta[0]}
            return {'meta': pd.concat(series_meta, ignore_index=True)}

        # Free-text search: quoted phrases keep their original casing.
        terms = split_and_retain_quoted(full_query)
        if not terms:
            return ("Error retrieving data: no GSM/GSE accessions or quoted "
                    "search terms found in the query")

        combined = None
        for term in terms:
            term_meta = archs4py.meta.meta(archive_file, term,
                                           meta_fields=meta_fields, remove_sc=True)
            # Successive intersection: keep only samples matched by every term.
            if combined is None:
                combined = term_meta
            else:
                combined = term_meta[term_meta.index.isin(combined.index)]
        return {'meta': combined}

    except Exception as e:
        return f"Error retrieving data: {str(e)}"


def get_archs4py_expression_counts(query, file):
    """
    Fetch expression counts for the given samples via ``archs4py.data.samples``.

    A thin wrapper that turns any exception into an error string instead of
    raising.

    Parameters:
    - query: the sample selection forwarded to ``archs4py.data.samples``
      (presumably a list of GSM accessions — TODO confirm against callers)
    - file: the archive forwarded to ``archs4py.data.samples`` (presumably an
      HDF5 path such as ``archive_file`` — confirm)

    Returns:
    - whatever ``archs4py.data.samples`` returns on success, or a string
      starting with "Error retrieving data: " if the call raised
    """
    try:
        counts = archs4py.data.samples(file, query)
    except Exception as err:
        return f"Error retrieving data: {str(err)}"
    return counts



In [22]:
#test = updated_archs4py_query("Look up gene expression for GSM1132425, GSM1132426, GSM1132427, GSM1179927, and GSM1179928")
#print(test['meta'].shape)

(5, 6)


In [21]:
## Define tools
from llama_index.core.tools import QueryEngineTool, ToolMetadata
from llama_index.core.agent import ReActAgent
import pandas as pd
from llama_index.experimental.query_engine import PandasQueryEngine


# Tool to accept user input to retrieve data from GEO using existing functions and ARCHS4PY
from llama_index.core.tools import FunctionTool
from pydantic import Field


def get_weather(
    location: str = Field(
        description="A city name and state, formatted like '<name>, <state>'"
    ),
) -> str:
    """Useful for getting the weather for a given location."""
    # Demo stub: always reports sunny weather for the requested location.
    forecast = f"The weather in {location} is sunny."
    return forecast


def get_time(
    location: str = Field(
        description="A city name and state, formatted like '<name>, <state>'"
    ),
) -> str:
    """Useful for getting the current time for a given location."""
    # Demo stub: the reported time is fixed regardless of location.
    reply = f"The current time in {location} is 12:00 PM."
    return reply

def test_parse_archs4py_query(
    full_query: str = Field(
        description="A query for GEO datasets"
    ),
) -> str:
    """Useful for retrieving datasets from GEO.

    Routes the query by the accession type it mentions (case-insensitive):
    GSM sample IDs go to ``archs4py.meta.samples``; otherwise GSE series IDs
    go to ``archs4py.meta.series``, one call per series, with the results
    concatenated; otherwise each double-quoted phrase goes to
    ``archs4py.meta.meta`` and only samples matched by every phrase are kept.
    The resulting metadata is written to ``./testing_data/test.csv``.

    Returns a status message: the saved path on success, or a string
    starting with "Error retrieving data: " on failure.
    """
    archive_file = "human_gene_v2.5.h5"
    directory = "./testing_data"
    file_path = os.path.join(directory, "test.csv")
    meta_fields = ["geo_accession", "series_id", "characteristics_ch1",
                   "extract_protocol_ch1", "source_name_ch1", "title"]

    try:
        # Upper-case once so accession detection is case-insensitive.
        query_upper = full_query.upper()

        if "GSM" in query_upper:
            tokens = remove_non_alphanumeric(query_upper.split())
            sample_ids = [token for token in tokens if "GSM" in token]
            query_dataset = pd.DataFrame(
                archs4py.meta.samples(archive_file, sample_ids, meta_fields=meta_fields)
            )

        elif "GSE" in query_upper:
            tokens = remove_non_alphanumeric(query_upper.split())
            series_ids = [token for token in tokens if "GSE" in token]
            # Fetch every series before saving (previously only the first was kept).
            frames = [
                pd.DataFrame(archs4py.meta.series(archive_file, series_id,
                                                  meta_fields=meta_fields))
                for series_id in series_ids
            ]
            if len(frames) == 1:
                query_dataset = frames[0]
            else:
                query_dataset = pd.concat(frames, ignore_index=True)

        else:
            # Free-text search: quoted phrases keep their original casing.
            terms = split_and_retain_quoted(full_query)
            if not terms:
                return ("Error retrieving data: no GSM/GSE accessions or quoted "
                        "search terms found in the query")
            query_dataset = None
            for term in terms:
                term_meta = pd.DataFrame(
                    archs4py.meta.meta(archive_file, term,
                                       meta_fields=meta_fields, remove_sc=True)
                )
                # Successive intersection: keep only samples matched by every term.
                if query_dataset is None:
                    query_dataset = term_meta
                else:
                    query_dataset = term_meta[term_meta.index.isin(query_dataset.index)]

        # The output directory may not exist yet on a fresh checkout.
        os.makedirs(directory, exist_ok=True)
        query_dataset.to_csv(file_path, index=False)
        return f"DataFrame saved successfully at {file_path}"
    except Exception as e:
        return f"Error retrieving data: {str(e)}"


# Wrap plain functions as llama_index tools. The `name` is what the agent
# emits on its "Action:" lines when it chooses a tool.
archs4py_tool = FunctionTool.from_defaults(
    test_parse_archs4py_query,
    name="use_archs4py",
    description="Useful for retrieving gene expression datasets from GEO.",
)
weather_tool = FunctionTool.from_defaults(
    get_weather,
    name="get_weather",
    description="Useful for getting the weather for a given location.",
)
time_tool = FunctionTool.from_defaults(
    get_time,
    name="get_time",
    description="Useful for getting the current time for a given location.",
)
all_tools = [weather_tool, time_tool, archs4py_tool]


In [22]:
#Look up gene expression for GSM1132425, GSM1132426, GSM1132427, GSM1179927, and GSM1179928

## Create human in the loop agent with tools

from llama_index.core.agent import AgentRunner, ReActAgent
from llama_index.agent.openai import OpenAIAgentWorker, OpenAIAgent
from llama_index.agent.openai import OpenAIAgentWorker

# LLM that drives the ReAct reasoning loop over the tools defined above.
agent_llm = OpenAI(model="gpt-4o-mini")

agent = ReActAgent.from_tools(
    all_tools,
    llm=agent_llm,
    verbose=True,
    max_iterations=20,
)

def chat_repl(exit_when_done: bool = True):
    """Chat REPL around the module-level ``agent``.

    Each line entered at the ">> Human: " prompt becomes a new agent task,
    which is advanced one step at a time with ``agent.run_step``. Between
    steps the user may type feedback; non-empty feedback is passed as the
    next step's input. Typing "exit" at the task prompt ends the REPL;
    typing "exit" at the feedback prompt stops the current task.

    Args:
        exit_when_done(bool): if True, automatically exit when step is finished.
            Set to False if you want to keep going even if step is marked as finished by the agent.
            If False, you need to explicitly call "exit" to finalize a task execution.

    """
    while True:
        task_message = input(">> Human: ")
        if task_message == "exit":
            break

        task = agent.create_task(task_message)

        step_output = None
        message = None
        while True:
            # Blank (or not-yet-given) feedback means "continue without input".
            if message:
                step_output = agent.run_step(task.task_id, input=message)
            else:
                step_output = agent.run_step(task.task_id)
            if exit_when_done and step_output.is_last:
                print(
                    ">> Task marked as finished by the agent, executing task execution."
                )
                break

            message = input(
                ">> Add feedback during step? (press enter/leave blank to continue, and type 'exit' to stop): "
            )
            if message == "exit":
                break

        if step_output is None:
            print(">> You haven't run the agent. Task is discarded.")
        elif not step_output.is_last:
            print(">> The agent hasn't finished yet. Task is discarded.")
        else:
            # Only report a result when one was actually finalized
            # (previously "Agent: None" was printed for discarded tasks).
            response = agent.finalize_response(task.task_id)
            print(f"Agent: {str(response)}")


## Test agent
chat_repl()


> Running step 04ddeb49-61c0-4daa-a29d-203881069c11. Step input: Look up gene expression for GSM1132425, GSM1132426, GSM1132427, GSM1179927, and GSM1179928
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
[1;3;38;5;200mThought: The user is asking for gene expression data for specific GSM identifiers. I will use the tool to retrieve this information from GEO.
Action: use_archs4py
Action Input: {'full_query': 'GSM1132425, GSM1132426, GSM1132427, GSM1179927, GSM1179928'}
[0m[1;3;34mObservation: DataFrame saved successfully at ./testing_data\test.csv
[0m> Running step cd7f67ad-ec01-4c5f-8061-4461c23d060a. Step input: None
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
[1;3;38;5;200mThought: I have successfully retrieved the gene expression dat