In [61]:


import json
import os
import re
import archs4py
import nest_asyncio

archive_file = "human_gene_v2.5.h5"

nest_asyncio.apply()

def get_config_data():
    with open("local_data.json") as json_file:
        data = json.load(json_file)
        return data["OPENAI_KEY"]["key"]
        
OPENAI_API_KEY = get_config_data()

os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

#from llama_index.core import (
#    SimpleDirectoryReader,
#    VectorStoreIndex,
#    StorageContext,
#    load_index_from_storage,
#)
from llama_index.llms.openai import OpenAI

#llm = OpenAI(model="gpt-4o-mini")

from llama_index.core import Settings
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from pydantic import BaseModel, ConfigDict

# Use ollama in JSON mode
Settings.llm = OpenAI(
    model="gpt-4o-mini"
)
Settings.embed_model = OpenAIEmbedding(model_name="text-embedding-3-small")

class MyModel(BaseModel):
    df: pd.DataFrame

    model_config = ConfigDict(arbitrary_types_allowed=True)

###

def get_archs4py_expression_counts(query, file):

    """
    Accepts a parsed user query and filters the dataset using archs4py.
    
    Parameters:
    - query: str, the user query
    - file: the loaded dataset
    
    Returns:
    - temp_data: a subset of the dataset relevant to the query that consists of a list of data objects (e.g. meta and RNAseq data)
    """
    try:
        temp_data = archs4py.data.samples(file, query)
        return temp_data

    except Exception as e:
        return f"Error retrieving data: {str(e)}"


#global meta_file_counter
#meta_file_counter = 0
#global counts_file_counter 
#counts_file_counter = 0


In [39]:
#test = updated_archs4py_query("Look up gene expression for GSM1132425, GSM1132426, GSM1132427, GSM1179927, and GSM1179928")
#print(test['meta'].shape)



In [63]:
## Define tools
from llama_index.core.tools import QueryEngineTool, ToolMetadata
from llama_index.core.agent import ReActAgent
import pandas as pd
from llama_index.experimental.query_engine import PandasQueryEngine
from llama_index.core import SimpleDirectoryReader

# Tool to accept user input to retrieve data from GEO using existing functions and ARCHS4PY
from llama_index.core.tools import FunctionTool
from pydantic import Field

def get_weather(
    location: str = Field(
        description="A city name and state, formatted like '<name>, <state>'"
    ),
) -> str:
    """Useful for getting the weather for a given location."""
    return f"The weather in {location} is sunny."


def get_time(
    location: str = Field(
        description="A city name and state, formatted like '<name>, <state>'"
    ),
) -> str:
    """Useful for getting the current time for a given location."""
    return f"The current time in {location} is 12:00 PM."

def test_parse_archs4py_query(
    full_query: str = Field(
        description="A query for GEO datasets"
    ),
) -> str:
    """Useful for retrieving datasets from GEO and saving them locally. Once a dataset is saved, this tool shouldn't be used again without confirmation from the user."""

    try:
    # Convert the query to uppercase to ensure case-insensitive matching
        archive_file = "human_gene_v2.5.h5"
        directory = "./testing_data"
        if "GSM" in full_query:
            query_upper = full_query.upper()
            # Perform operation for queries containing "GSM"
            # Split the string into a list of elements using spaces as the separator
            temp_elements = query_upper.split()
            pattern = r'[^a-zA-Z0-9]'
            # Use a list comprehension to apply the regex substitution to each string
            temp_elements = [re.sub(pattern, '', s) for s in temp_elements]
        
            # Filter the list to include only elements containing "GSE"
            temp_search_query = [temp_element for temp_element in temp_elements if "GSM" in temp_element]
            temp_meta = archs4py.meta.samples(archive_file, temp_search_query, 
                meta_fields=["geo_accession", "series_id", "characteristics_ch1", "extract_protocol_ch1", "source_name_ch1", "title"])
            #query_dataset = pd.DataFrame(temp_meta)
            #global meta_file_counter
            #print(meta_file_counter)
            #meta_file_counter += 1
            #file_path = os.path.join(directory, f"temp_meta.csv")
            #query_dataset.to_csv(file_path, index=False)
            #return f"DataFrame saved successfully at {file_path}"
            #df_id = str(uuid.uuid4())
            #GLOBAL_DATAFRAME_STORE[df_id] = df
            #return df_id
            return df.to_json(query_dataset)

        elif "GSE" in full_query:
            query_upper = full_query.upper()
            # Perform operation for queries containing "GSE"
            # Split the string into a list of elements using spaces as the separator
            temp_elements = query_upper.split()
            #temp_elements = remove_non_alphanumeric(temp_elements)
            pattern = r'[^a-zA-Z0-9]'
            # Use a list comprehension to apply the regex substitution to each string
            temp_elements = [re.sub(pattern, '', s) for s in temp_elements]
        
            # Filter the list to include only elements containing "GSE"
            temp_search_query = [temp_element for temp_element in temp_elements if "GSE" in temp_element]

            if len(temp_search_query) > 1:
                prior_dataset = []
            for temp_index, temp_value in enumerate(temp_search_query):
                temp_meta = archs4py.meta.series(archive_file, temp_value, 
                    meta_fields=["geo_accession", "series_id", "characteristics_ch1", "extract_protocol_ch1", "source_name_ch1", "title"])
                query_dataset = pd.DataFrame(temp_meta)
                if temp_index > 0:
                    query_dataset = pd.concat([prior_dataset, 
                                                    query_dataset], 
                                                    ignore_index=True)
                if temp_index < len(temp_search_query) - 1:
                    prior_dataset = query_dataset
                if temp_index == len(temp_search_query) - 1:
                    #global meta_file_counter
                    #print(meta_file_counter)
                    #meta_file_counter += 1
                    #file_path = os.path.join(directory, f"temp_meta.csv")
                    #query_dataset.to_csv(file_path, index=False)
                    #return f"DataFrame saved successfully at {file_path}"
                    #df_id = str(uuid.uuid4())
                    #GLOBAL_DATAFRAME_STORE[df_id] = df
                    #return df_id
                    return df.to_json(query_dataset)
        else:
            # Perform default operation for other queries
            #pattern = r'"[^"]*"|\S+'
    
            # Find all matches in the text
            #matches = re.findall(pattern, full_query)
            #temp_search_query = [element for element in matches if element.startswith('"') and element.endswith('"')]
            #temp_search_query = [element[1:-1] for element in temp_search_query]
            temp_search_query = [full_query]
            if len(temp_search_query) > 1:
                prior_dataset = []
            for temp_index, temp_value in enumerate(temp_search_query):
                print(temp_value)
                temp_meta = archs4py.meta.meta(archive_file, temp_value, 
                    meta_fields=["geo_accession", "series_id", "characteristics_ch1", "extract_protocol_ch1", "source_name_ch1", "title"],
                    remove_sc=True)
                print(temp_meta.shape)
                if temp_index > 0:
                    temp_meta = temp_meta[temp_meta.index.isin(prior_dataset.index)]
                if temp_index < len(temp_search_query) - 1:
                    prior_dataset = temp_meta
                if temp_index == len(temp_search_query) - 1:
                    query_dataset = pd.DataFrame(temp_meta)
                    #global meta_file_counter
                    #print(meta_file_counter)
                    #meta_file_counter += 1
                    #file_path = os.path.join(directory, f"temp_meta.csv")
                    #query_dataset.to_csv(file_path, index=False)
                    #return f"DataFrame saved successfully at {file_path}"
                    #df_id = str(uuid.uuid4())
                    #GLOBAL_DATAFRAME_STORE[df_id] = df
                    #return df_id
                    return df.to_json(query_dataset)
    except Exception as e:
        return f"Error retrieving data: {str(e)}"


def create_metadata_query_engine(
    query_dataset: str = Field(
        description="A tool to load locally saved dataframes from csv files"
    ),
) -> str:
    """Useful for retrieving datasets from GEO and saving them locally."""
    try:
        #directory = "./testing_data"
        #file_path = os.path.join(directory, f"temp_meta.csv")
        #SimpleDirectoryReader(input_files=[file_path])
        #query_dataset = pd.read_csv(file_path)
        return PandasQueryEngine(df=query_dataset, verbose=True, synthesize_response=True)
    except Exception as e:
        return f"Error retrieving data: {str(e)}"

metadata_query_tool = QueryEngineTool(
        query_engine=create_metadata_query_engine(),
        metadata=ToolMetadata(
            name="GEO_metadata",
            description=(
                "Provides information about GEO_metadata for the selected dataset. "
                "The tool should use the dataframe stored in memory that was downloaded from GEO."
            ),
        ),
    )
archs4py_tool = FunctionTool.from_defaults(test_parse_archs4py_query, name="use_archs4py", description="Useful for retrieving gene expression datasets from GEO.")

from langchain.agents import Tool
import pandas as pd
import io

def load_csv_into_dataframe_and_query(json_str: str) -> pd.DataFrame:
    """Load a CSV text into a Pandas DataFrame object."""
    df = pd.read_json(io.StringIO(json_str))
    if df is None:
        return "No dataframe found for ID: " + df_id
    # Do something with `df`
    return str(df.head())
    #return df

#def query_dataframe(df: pd.DataFrame, query_text: str) -> str:
#    if df is None:
#        return "No dataframe found for ID: " + df_id
#    # Do something with `df`
#    return str(df.head())

load_csv_tool = FunctionTool.from_defaults(load_csv_into_dataframe_and_query, name="load_csv", description="Converts CSV text into a Pandas DataFrame object and queries it.")

#query_df_tool = FunctionTool.from_defaults(query_dataframe, name="query_df", description="Queries a Pandas DataFrame object with a question string.")

weather_tool = FunctionTool.from_defaults(get_weather, name="get_weather", description="Useful for getting the weather for a given location.")
time_tool = FunctionTool.from_defaults(get_time, name="get_time", description="Useful for getting the current time for a given location.")
#all_tools = [weather_tool, time_tool,archs4py_tool,metadata_query_tool]
all_tools = [weather_tool, time_tool,archs4py_tool,load_csv_tool]



In [6]:
temp_meta = archs4py.meta.meta(archive_file, "iPSC", 
                    meta_fields=["geo_accession", "series_id", "characteristics_ch1", "extract_protocol_ch1", "source_name_ch1", "title"], 
                    remove_sc=True)
print(temp_meta.shape)

100%|██████████| 6/6 [00:05<00:00,  1.19it/s]


(24663, 6)


In [64]:
#Look up gene expression for GSM1132425, GSM1132426, GSM1132427, GSM1179927, and GSM1179928
#Look up gene expression for the terms IPSC

## Create human in the loop agent with tools

#meta_file_counter = 0
#counts_file_counter = 0

from llama_index.core.agent import AgentRunner, ReActAgent
from llama_index.agent.openai import OpenAIAgentWorker, OpenAIAgent
from llama_index.agent.openai import OpenAIAgentWorker


agent_llm = OpenAI(model="gpt-4o-mini",timeout=120)
# agent_llm = OpenAI(model="gpt-4-1106-preview")

#agent = ReActAgent.from_tools(
#    all_tools, llm=agent_llm, verbose=True, max_iterations=20,
#)

agent = OpenAIAgent.from_tools(
    all_tools, llm=agent_llm, verbose=True, max_iterations=20,
)

def chat_repl(exit_when_done: bool = True):
    """Chat REPL.

    Args:
        exit_when_done(bool): if True, automatically exit when step is finished.
            Set to False if you want to keep going even if step is marked as finished by the agent.
            If False, you need to explicitly call "exit" to finalize a task execution.

    """
    task_message = None
    while task_message != "exit":
        task_message = input(">> Human: ")
        if task_message == "exit":
            break

        task = agent.create_task(task_message)

        response = None
        step_output = None
        message = None
        while message != "exit":
            if message is None or message == "":
                step_output = agent.run_step(task.task_id)
            else:
                step_output = agent.run_step(task.task_id, input=message)
            if exit_when_done and step_output.is_last:
                print(
                    ">> Task marked as finished by the agent, executing task execution."
                )
                break

            message = input(
                ">> Add feedback during step? (press enter/leave blank to continue, and type 'exit' to stop): "
            )
            if message == "exit":
                break

        if step_output is None:
            print(">> You haven't run the agent. Task is discarded.")
        elif not step_output.is_last:
            print(">> The agent hasn't finished yet. Task is discarded.")
        else:
            response = agent.finalize_response(task.task_id)
        print(f"Agent: {str(response)}")


## Test agent
chat_repl()


Added user message to memory: Look up gene expression for the terms IPSC
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
=== Calling Function ===
Calling function: use_archs4py with args: {"full_query":"IPSC"}
IPSC


100%|██████████| 6/6 [00:05<00:00,  1.11it/s]


(24663, 6)
Got output: Error retrieving data: name 'df' is not defined

>> The agent hasn't finished yet. Task is discarded.
Agent: None


In [22]:
test = pd.read_csv("./testing_data/temp_meta.csv")
print(test.shape)
test.head()

(636, 6)


Unnamed: 0,geo_accession,series_id,characteristics_ch1,extract_protocol_ch1,source_name_ch1,title
0,GSM1556306,"GSE63734,GSE80163","DIAGNOSIS: SCHIZOPHRENIA,GENDER: MALE,REPOSITO...","CELLS WERE LYSED IN RNA BEE (TEL-TEST, INC). R...",HIPSC FOREBRAIN NEURONS,SZ1_A
1,GSM1556307,"GSE63734,GSE80163","DIAGNOSIS: SCHIZOPHRENIA,GENDER: MALE,REPOSITO...","CELLS WERE LYSED IN RNA BEE (TEL-TEST, INC). R...",HIPSC FOREBRAIN NEURONS,SZ1_B
2,GSM1556308,"GSE63734,GSE80163","DIAGNOSIS: SCHIZOPHRENIA,GENDER: MALE,REPOSITO...","CELLS WERE LYSED IN RNA BEE (TEL-TEST, INC). R...",HIPSC FOREBRAIN NEURONS,SZ2
3,GSM1556309,"GSE63734,GSE80163","DIAGNOSIS: SCHIZOPHRENIA,GENDER: FEMALE,REPOSI...","CELLS WERE LYSED IN RNA BEE (TEL-TEST, INC). R...",HIPSC FOREBRAIN NEURONS,SZ3_3
4,GSM1556310,"GSE63734,GSE80163","DIAGNOSIS: SCHIZOPHRENIA,GENDER: FEMALE,REPOSI...","CELLS WERE LYSED IN RNA BEE (TEL-TEST, INC). R...",HIPSC FOREBRAIN NEURONS,SZ3_5
