# Goals

* Create a pipeline to gather new datasets from GEO and SRA

# Init

In [1]:
import os
import operator
from functools import partial
from pprint import pprint
from typing import Annotated, List, Dict, Tuple, Optional, Union, Any, TypedDict, Sequence
from pydantic import BaseModel, Field
from Bio import Entrez
import pandas as pd
from dotenv import load_dotenv
from langgraph.types import Send
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langgraph.graph import START, END, StateGraph

In [2]:
from SRAgent.agents.entrez import create_entrez_agent
from SRAgent.agents.utils import create_step_summary_chain
from SRAgent.agents.convert import create_convert_graph, invoke_convert_graph
from SRAgent.agents.metadata import create_metadata_graph, invoke_metadata_graph, get_metadata_items

In [3]:
# Notebook setup: load variables via python-dotenv, widen the pandas column
# display, then set the DEBUG_MODE environment variable for this session.
load_dotenv()
pd.options.display.max_colwidth = 1000
os.environ.update({"DEBUG_MODE": "TRUE"})

In [4]:
# Sanity check: report when the DEBUG_MODE environment variable is "TRUE".
if os.environ.get("DEBUG_MODE") == "TRUE":
    print("DEBUG_MODE is enabled.")

DEBUG_MODE is enabled.


In [5]:
# Configure Biopython's Entrez module: the API key is read from the
# NCBI_API_KEY environment variable (None when unset), plus a contact e-mail.
Entrez.api_key = os.environ.get("NCBI_API_KEY")
Entrez.email = "nick.youngblut@arcinstitute.org"

# Graph

In [6]:
# List-of-strings field whose updates are combined with operator.add
# (i.e. concatenated) rather than overwriting the previous value.
_StrListAccumulator = Annotated[List[str], operator.add]


class TopState(TypedDict):
    """
    State shared by the nodes of the top-level graph.

    Fields annotated with ``operator.add`` accumulate: each update is
    concatenated onto the value already held in the state. The remaining
    fields are plain values.
    """
    # message history (accumulating)
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # name of the Entrez database the ID belongs to
    database: str
    # Entrez ID of the dataset
    entrez_id: str
    # SRX accessions (accumulating)
    SRX: _StrListAccumulator
    # The fields below accumulate one string per update; presumably each entry
    # is the metadata graph's answer for one accession -- TODO confirm.
    is_illumina: _StrListAccumulator
    is_single_cell: _StrListAccumulator
    is_paired_end: _StrListAccumulator
    is_10x: _StrListAccumulator
    organism: _StrListAccumulator

## Nodes

In [7]:
# Convert subgraph: pre-bind a newly created convert graph as the `graph`
# keyword of invoke_convert_graph; remaining arguments come from the caller.
invoke_convert_graph_p = partial(invoke_convert_graph, graph=create_convert_graph())

In [8]:
def create_convert_graph_node():
    """
    Build the node function that runs the convert subgraph.

    The subgraph is created once, when this factory is called, and is reused
    by the returned function on every invocation.

    Returns:
        A function taking the top-level state and returning whatever the
        convert subgraph's ``invoke`` returns.
    """
    convert_graph = create_convert_graph()

    def invoke_convert_graph_node(state: TopState) -> TopState:
        """Ask the convert subgraph for the SRX accessions of the state's Entrez ID."""
        # two-line request: the task, then the database the ID comes from
        request = (
            f"Convert Entrez ID {state['entrez_id']} to SRX accessions.\n"
            f"The Entrez ID is associated with the {state['database']} database."
        )
        return convert_graph.invoke({"messages": [HumanMessage(request)]})

    return invoke_convert_graph_node

In [9]:
# Metadata subgraph: pre-bind a newly created metadata graph as the `graph`
# keyword of invoke_metadata_graph; remaining arguments come from the caller.
invoke_metadata_graph_p = partial(invoke_metadata_graph, graph=create_metadata_graph())

In [10]:
# Fan-out to the metadata subgraph: one Send per SRX accession
def continue_to_metadata(state: TopState):
    """
    Create one ``Send`` to "metadata_graph_node" for each SRX accession.

    Every payload carries the database, the Entrez ID, a single accession and
    a prompt asking for the metadata items of that accession.

    Returns:
        A list of ``Send`` objects, one per entry in ``state["SRX"]``
        (empty when there are no accessions).
    """
    # Prompt template: a header line followed by each value returned by
    # get_metadata_items(). NOTE: the whole template goes through str.format,
    # so those values must not contain unescaped braces.
    header = "For the SRA accession {SRX_accession}, find the following information:"
    template = "\n".join([header, *get_metadata_items().values()])

    def build_payload(accession):
        """Assemble the state handed to the metadata node for one accession."""
        return {
            "database": state["database"],
            "entrez_id": state["entrez_id"],
            "SRX": accession,
            "messages": [HumanMessage(template.format(SRX_accession=accession))]
        }

    # one Send object per accession
    return [
        Send("metadata_graph_node", build_payload(accession))
        for accession in state["SRX"]
    ]

In [11]:
def final_state(state: TopState):
    """
    Last node of the top-level graph, run after the metadata lookups.

    Only the plain (overwrite) fields are returned. ``SRX`` is declared on
    ``TopState`` with an ``operator.add`` reducer, so a node's return value
    for that key is appended to the list already in the state. Returning
    ``state["SRX"]`` here would therefore duplicate every accession. The
    accumulated fields are left untouched and stay in the graph state as-is.

    Args:
        state: current top-level graph state.

    Returns:
        A state update holding the unchanged ``database`` and ``entrez_id``.
    """
    # Re-emit only keys without a reducer: writing them back is a no-op,
    # whereas re-emitting an accumulating key would double its contents.
    return {
        "database": state["database"],
        "entrez_id": state["entrez_id"],
    }

## Compile

In [12]:
#-- graph --#
workflow = StateGraph(TopState)

# nodes: (name, callable) pairs, registered in this order
for node_name, node_fn in (
    ("convert_graph_node", create_convert_graph_node()),
    ("metadata_graph_node", invoke_metadata_graph_p),
    ("final_state_node", final_state),
):
    workflow.add_node(node_name, node_fn)

# edges: START -> convert -> (one metadata branch per Send) -> final -> END
workflow.add_edge(START, "convert_graph_node")
workflow.add_conditional_edges(
    "convert_graph_node",
    continue_to_metadata,
    ["metadata_graph_node"],
)
workflow.add_edge("metadata_graph_node", "final_state_node")
workflow.add_edge("final_state_node", END)

# compile the graph into its runnable form
graph = workflow.compile()

In [13]:
from IPython.display import Image

# Draw the compiled graph as a Mermaid PNG; the trailing expression is what
# the notebook renders as the cell output.
mermaid_png = graph.get_graph().draw_mermaid_png()
Image(mermaid_png)

<IPython.core.display.Image object>

In [15]:
# Call the graph: SRA database
# The last streamed step is kept in `last_step`, not `final_state`:
# `final_state` is the node function registered on the workflow, and rebinding
# the name would break any later re-run of the graph-building cell.
last_step = None
graph_input = {
    "entrez_id": "34747624",
    "database": "sra",
}
stream_config = {"max_concurrency": 3, "recursion_limit": 200}

# stream every step (including subgraph steps) and remember the last one
for step in graph.stream(graph_input, subgraphs=True, config=stream_config):
    print(step)
    last_step = step
print("---")
print(last_step)

(('convert_graph_node:b32673e8-068f-1e0a-f558-05a3f664c915', 'convert_agent_node:8188cdd5-5aa4-38ea-5d32-5a4c1f3e07c1'), {'agent': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_AMssdjcGD2RDjnsooPoS3eHW', 'function': {'arguments': '{"message":"eFetch the Entrez ID 34747624 in the sra database to obtain the SRX accessions."}', 'name': 'invoke_efetch_agent'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 39, 'prompt_tokens': 1222, 'total_tokens': 1261, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 1024}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_7f6be3efb0', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-467772a8-9dc4-4464-905e-bcf1c14c1bf7-0', tool_calls=[{'name': 'invoke_efetch_agent', 'args': {'message': 'eFetch the E

In [15]:
# Call the graph: SRA database
# (No `final_state = None` here: the variable was never read, and rebinding
# the name would clobber the `final_state` node function.)
entrez_ids = ["36178521", "35087712", "27978912"]
stream_config = {"max_concurrency": 3, "recursion_limit": 200}
for entrez_id in entrez_ids:
    print(f"# Running graph for Entrez ID {entrez_id}")
    graph_input = {"entrez_id": entrez_id, "database": "sra"}
    # stream every step, including steps emitted by the subgraphs
    for step in graph.stream(graph_input, subgraphs=True, config=stream_config):
        print(step)

# Running graph for Entrez ID 36178521
(('convert_graph_node:3de78ecd-a5c2-0e29-3a9a-d6fc5fc16269', 'convert_agent_node:2f3f3fc3-87e2-68f6-926e-d19457e90229'), {'agent': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_zJpEfwx3JYOqOTsQuUJcK2Oj', 'function': {'arguments': '{"message":"Use esummary to retrieve SRX accessions for Entrez ID 36178521 in the sra database."}', 'name': 'invoke_esummary_agent'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 38, 'prompt_tokens': 738, 'total_tokens': 776, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_7f6be3efb0', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-9998e4b6-1a5f-4b2e-a6ce-3514b91e0568-0', tool_calls=[{'name': 'invoke_esummary

In [16]:
# Call the graph: SRA database
# (No `final_state = None` here: the variable was never read, and rebinding
# the name would clobber the `final_state` node function.)
entrez_ids = ["27978915", "36178506", "35087715"]
stream_config = {"max_concurrency": 3, "recursion_limit": 200}
for entrez_id in entrez_ids:
    print(f"# Running graph for Entrez ID {entrez_id}")
    graph_input = {"entrez_id": entrez_id, "database": "sra"}
    # stream every step, including steps emitted by the subgraphs
    for step in graph.stream(graph_input, subgraphs=True, config=stream_config):
        print(step)

# Running graph for Entrez ID 27978915
(('convert_graph_node:c3943ac1-37e4-8434-3850-31becf7c6b37', 'convert_agent_node:b011a2bd-9845-a8c5-99a7-e69807bce32b'), {'agent': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_mvrLYOAyKRW4FERKjk4u0451', 'function': {'arguments': '{"message":"Use esummary to retrieve SRX accessions for Entrez ID 27978915 in the sra database."}', 'name': 'invoke_esummary_agent'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 38, 'prompt_tokens': 738, 'total_tokens': 776, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_7f6be3efb0', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-9de8c618-e938-4115-9963-3df13cc83618-0', tool_calls=[{'name': 'invoke_esummary

#### GEO database

In [None]:
# Call the graph: GEO database
# (No `final_state = None` here: the variable was never read, and rebinding
# the name would clobber the `final_state` node function.)
entrez_ids = ['200254051']
stream_config = {"max_concurrency": 3, "recursion_limit": 200}
for entrez_id in entrez_ids:
    print(f"# Running graph for Entrez ID {entrez_id}")
    graph_input = {"entrez_id": entrez_id, "database": "gds"}
    # stream every step, including steps emitted by the subgraphs
    for step in graph.stream(graph_input, subgraphs=True, config=stream_config):
        print(step)

# Running graph for Entrez ID 200254051
(('convert_graph_node:671f7d58-7a32-3835-aaa5-22922958957b', 'convert_agent_node:f89472ac-64f3-e3a9-76ab-b5ebea02714e'), {'agent': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_BCWNnELx35wCF38FPCBKgOVM', 'function': {'arguments': '{"entrez_ids":["200254051"]}', 'name': 'geo2sra'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 21, 'prompt_tokens': 1080, 'total_tokens': 1101, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_7f6be3efb0', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-23c51cd0-c320-41d7-a4d9-067f5d491b89-0', tool_calls=[{'name': 'geo2sra', 'args': {'entrez_ids': ['200254051']}, 'id': 'call_BCWNnELx35wCF38FPCBKgOVM', 'type':

In [14]:
# Call the graph: GEO database
# (No `final_state = None` here: the variable was never read, and rebinding
# the name would clobber the `final_state` node function.)
entrez_ids = ['200282014', '200281914', '200278601']
stream_config = {"max_concurrency": 3, "recursion_limit": 200}
for entrez_id in entrez_ids:
    print(f"# Running graph for Entrez ID {entrez_id}")
    graph_input = {"entrez_id": entrez_id, "database": "gds"}
    # stream every step, including steps emitted by the subgraphs
    for step in graph.stream(graph_input, subgraphs=True, config=stream_config):
        print(step)

# Running graph for Entrez ID 200282014
(('convert_graph_node:c2d28681-c79a-828c-73dd-6cf963138f35', 'convert_agent_node:98a3f077-43fa-27b2-e603-8c139196b32c'), {'agent': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_nPZoTdtdrl9ng5qot1vt27zw', 'function': {'arguments': '{"entrez_ids":["200282014"]}', 'name': 'geo2sra'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 21, 'prompt_tokens': 1222, 'total_tokens': 1243, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 1024}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_7f6be3efb0', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-980f2c12-f28a-4439-8837-a9bcaf01aebf-0', tool_calls=[{'name': 'geo2sra', 'args': {'entrez_ids': ['200282014']}, 'id': 'call_nPZoTdtdrl9ng5qot1vt27zw', 'typ