In [None]:
!pip install squidpy scanpy --upgrade

In [None]:
!pip install langchain langgraph

In [None]:
!pip install langchain_google_vertexai

In [None]:
!pip install synapseclient

In [None]:
!pip install commot

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
import getpass
import io
import os
import re
import sys
import textwrap
import synapseclient
from synapseclient import Synapse
import gcsfs

import vertexai
import requests
from vertexai.preview import reasoning_engines
from langchain_google_vertexai import HarmBlockThreshold, HarmCategory

import scanpy as sc
import pandas as pd
from google.cloud import bigquery
import pandas_gbq


from IPython.display import HTML, display

def set_css(*_event_args):
  """Inject CSS that makes <pre> output blocks wrap long lines.

  Registered below as an IPython ``pre_run_cell`` callback so the style is
  re-applied before every cell runs. ``*_event_args`` absorbs the ``info``
  argument IPython passes to ``pre_run_cell`` callbacks; it is not used.
  """
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))


# Re-running this cell creates a new `set_css` function object, and registering
# it does not remove the callback left by an earlier run, so the CSS would be
# injected once per previous run. Unregister the last registered hook first.
_ipython_events = get_ipython().events
_previous_css_hook = globals().get("_registered_css_hook")
if _previous_css_hook is not None:
  try:
    _ipython_events.unregister('pre_run_cell', _previous_css_hook)
  except ValueError:
    pass  # hook already gone (e.g. removed manually); nothing to undo
_ipython_events.register('pre_run_cell', set_css)
_registered_css_hook = set_css

# Authenticate to Synapse with a personal access token taken from the
# SYNAPSE_AUTH_TOKEN environment variable, or typed at a hidden prompt, instead of
# pasting it into the notebook (where it would be saved with the .ipynb).
# Leaving the prompt blank passes authToken=None, so synapseclient can presumably
# fall back to its other credential sources (e.g. ~/.synapseConfig, if configured).
_synapse_token = os.environ.get("SYNAPSE_AUTH_TOKEN") or getpass.getpass(
    "Synapse personal access token (leave blank to use ~/.synapseConfig): "
)
syn = synapseclient.login(authToken=_synapse_token or None, silent=True)

# agent_1's prompt tells generated code to log in with
# synapseclient.login(silent=True), i.e. without an explicit token, so keep a
# real token available in the environment. Never overwrite it with an empty
# string, which would discard a token the user had already exported.
if _synapse_token:
    os.environ["SYNAPSE_AUTH_TOKEN"] = _synapse_token

In [None]:
# Tab-separated HTAN spatial-transcriptomics attribute dictionary stored in GCS
# (columns: attribute name and description).
file_path = 'gs://htan_st_datasets/HTAN_ST_metadata_ad.tsv'

# gcsfs exposes the object as a binary file handle that pandas can parse.
fs = gcsfs.GCSFileSystem()
with fs.open(file_path, 'rb') as tsv_handle:
    metadata_df = pd.read_csv(tsv_handle, delimiter='\t')

# Preview the first five attribute/description rows.
metadata_df.head()

Unnamed: 0,Attribute:,Description
0,Filename:,Name of a file
1,Run ID:,A unique identifier for this individual run (t...
2,File Format:,"Format of a file (e.g. txt, csv, fastq, bam, e..."
3,HTAN Parent Biospecimen ID:,HTAN Biospecimen Identifier (eg HTANx_yyy_zzz)...
4,HTAN Data File ID:,Self-identifier for this data file - HTAN ID o...


# ChatVertexAI Multi-agent LLM Using LangChain

In [None]:

from contextlib import redirect_stdout


In [None]:
## %%capture captured_output

import re
import os
import sys
import base64
import vertexai
from vertexai.generative_models import GenerativeModel, Part, SafetySetting

from langchain_google_vertexai import ChatVertexAI
from langchain.tools import tool
from langchain.agents import create_tool_calling_agent
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.prebuilt.tool_executor import ToolExecutor
from typing import TypedDict, Annotated,Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain.agents.output_parsers.tools import ToolAgentAction
from langchain_core.messages import BaseMessage
import operator
from langgraph.graph import END, StateGraph
from langgraph.checkpoint.memory import MemorySaver

# Initialise the Vertex AI SDK for this project/region (with a staging bucket).
vertexai.init(
    project="isb-cgc-external-004",
    location="us-central1",
    staging_bucket="gs://htan_st_datasets/staging_dir/",
)

# Gemini model name, kept under the name `model` in case later cells read it.
model = "gemini-1.5-flash-002"

# Sampling parameters for Gemini. agent_1 and agent_2 pass the module-level name
# `generation_config` to GenerativeModel.generate_content, but no setup cell
# defined it, so the first tool call failed with NameError. It is defined here
# and reused for `model_kwargs` below so the two stay in sync.
generation_config = {
    # temperature (float): The sampling temperature controls the degree of
    # randomness in token selection.
    "temperature": 0.28,
    # max_output_tokens (int): The token limit determines the maximum amount of
    # text output from one prompt.
    "max_output_tokens": 8000,
    # top_p (float): Tokens are selected from most probable to least until
    # the sum of their probabilities equals the top-p value.
    "top_p": 0.95,
    # top_k (int): The next token is selected from among the top-k most
    # probable tokens.
    "top_k": 40,
}

# Safety thresholds in the langchain_google_vertexai enum format. Kept under its
# own name because `safety_settings` is rebound below to the Vertex AI SDK list.
langchain_safety_settings = {
    HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
}

# Keyword arguments for a LangChain chat model (presumably ChatVertexAI in a
# later cell): the sampling parameters plus the LangChain-format safety map.
model_kwargs = {
    **generation_config,
    # safety_settings (Dict[HarmCategory, HarmBlockThreshold]): The safety
    # settings to use for generating content.
    "safety_settings": langchain_safety_settings,
}

# Safety settings in the Vertex AI SDK format (every listed filter OFF), passed
# by agent_1/agent_2 to GenerativeModel.generate_content.
safety_settings = [
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        threshold=SafetySetting.HarmBlockThreshold.OFF
    ),
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        threshold=SafetySetting.HarmBlockThreshold.OFF
    ),
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        threshold=SafetySetting.HarmBlockThreshold.OFF
    ),
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
        threshold=SafetySetting.HarmBlockThreshold.OFF
    ),
]

@tool
def agent_1(question: str) -> str:
    """
    Download files from HTAN or Synapse. You can directly download files by generating and executing Python code to retrieve datasets using APIs.

    Parameters:
    question (str): A user-defined question regarding dataset retrieval from HTAN or Synapse.

    Returns:
    str: Python code enclosed in ```python``` tags, allowing the user to easily extract the code.
         Code will be specific to either BigQuery for HTAN datasets or synapseclient for Synapse datasets.
    """
    # Implementation notes. These are comments rather than docstring text because
    # @tool sends the docstring to the routing LLM as this tool's description.
    # - Returns the raw Gemini response text, which contains the ```python``` block.
    # - The first ```python``` block is executed in the notebook's global namespace
    #   so objects it creates (e.g. a BigQuery DataFrame) remain usable afterwards.
    # - SECURITY: that code is model-generated and runs unsandboxed with the
    #   notebook's credentials.

    # Few-shot prompt: instructions, worked examples, metadata, then the question.
    prompt = """
    You are an expert in downloading datasets from Human Tumor Atlas Network (HTAN) and Synapse using APIs. You can directly download files. Return only the python code.

    For Synapse Datasets:
    You can write the Python code to download the Synapse data directly using the synapseclient by passing the synapse id and the download location. If the download location is not defined, download to ‘./’. You can login using synapseclient.login(silent=True) function.

    For Datasets with an HTAN biospecimen ID:
    You can use BigQuery in Google Cloud to do this. Assume that all the packages are already installed. You can use: project_id = "isb-cgc-external-004".

    Important considerations:

    1. Please write and execute the Python code enclosed in ```python``` tags so that the user can regex extract the code easily.

    2. You do not have to access any data. You are an expert coder, just write the code. When you're asked to load a dataset, use BigQuery to load it.

    3. Include all the necessary import packages to run the code. If using any Google Cloud services, you can use:
    project_id = "isb-cgc-external-004"

    4. Do not write any try except blocks.

    Make sure you follow the considerations.

    Here are some examples:

    Question: Can you load the cells from the HTAN biospecimen HTA7_1_3?
    Answer:
    '''
    import pandas_gbq

    project_id = "isb-cgc-external-004"

    query = '''
    WITH cells AS (
      SELECT  CellID, X_centroid, Y_centroid,
      FROM `isb-cgc-bq.HTAN.imaging_level4_HMS_mel_mask_current`
      WHERE HTAN_Biospecimen_ID = 'HTA7_1_3'
    )
    SELECT CellID, X_centroid,  Y_centroid
    FROM cells
    '''
    df = pandas_gbq.read_gbq(query, project_id=project_id)
    print(df)
    '''


    """

    #### BigQuery metadata example (looks up Synapse entityIds)

    prompt += """
    Question: Can you load the entityId of all data under HTAN.10xvisium_spatialtranscriptomics_scRNAseq_level4_metadata_current
    where the File_Format is hdf5?

    Answer:
    '''
    import pandas_gbq
    from google.cloud import bigquery

    project_id = "isb-cgc-external-004"

    query = '''
    SELECT entityId
    FROM `isb-cgc-bq.HTAN.10xvisium_spatialtranscriptomics_scRNAseq_level4_metadata_current`
    WHERE File_Format = 'hdf5'
    '''

    df = pandas_gbq.read_gbq(query, project_id=project_id)
    print(df)
    '''

    """

    ##  Synapse download example

    prompt += """
    Question: Can you download the synapse data 'syn51133602' to /content/datasets/?
    Answer:
    '''
    import synapseclient
    syn = synapseclient.login(silent=True)
    entity = syn.get('syn51133602', downloadLocation='/content/datasets')
    '''

    """

    ### Spatial example 1

    prompt += """
    Question: Can you categorize cells within a defined spatial region as either 'Tumor' or 'Other' based on threshold expression levels of specific markers (SOX10_cellRingMask, S100B_cellRingMask, and CD63_cellRingMask)?
    Answer:
    '''
    import pandas_gbq

    project_id = "isb-cgc-external-004"

    query = '''
      WITH cells AS (
      SELECT CellID, X_centroid, Y_centroid,
      IF (SOX10_cellRingMask > 3704.5 AND (S100B_cellRingMask > 7589.48 OR CD63_cellRingMask > 570.68),
      'Tumor', 'Other') AS celltype
      FROM `isb-cgc-bq.HTAN.imaging_level4_HMS_mel_mask_current`
      WHERE HTAN_Biospecimen_ID = 'HTA7_1_3')

      SELECT CellID, X_centroid, Y_centroid, celltype
      FROM cells
      WHERE X_centroid > 23076.9 AND X_centroid < 30384.6
      AND Y_centroid > 9615.3 AND Y_centroid < 15000
      '''
    df = pandas_gbq.read_gbq(query, project_id=project_id)
    print(df)
    '''

    """

    ### Spatial example 2
    prompt += """
    Question: Can you classify cells as 'Tumor' or 'Other', convert their pixel coordinates to geospatial points, and calculate distances between cell pairs that are within a 20-micrometer proximity threshold?
    Answer:
    '''
    import pandas_gbq

    project_id = "isb-cgc-external-004"

    query = '''
      WITH geodat AS (
      SELECT CellID, X_centroid, Y_centroid,
      IF (SOX10_cellRingMask > 3704.5 AND (S100B_cellRingMask > 7589.48 OR CD63_cellRingMask > 570.68),
      'Tumor', 'Other') AS celltype,
      ST_GeogPoint(X_centroid / 368570, Y_centroid / 368570) AS p
      FROM `isb-cgc-bq.HTAN.imaging_level4_HMS_mel_mask_current`
      WHERE HTAN_Biospecimen_ID = 'HTA7_1_3'
      )
    SELECT t1.CellID, t1.X_centroid, t1.Y_centroid, t1.p, t1.celltype,
    t2.CellID AS CellID_1, t2.X_centroid AS X_centroid_1, t2.Y_centroid AS Y_centroid_1, t2.p AS p_1, t2.celltype AS celltype_1,
    ST_Distance(t1.p, t2.p) AS Distance
    FROM geodat AS t1
    JOIN geodat AS t2
    ON ST_DWithin(t1.p, t2.p, 9.29324770787722)
    '''
    df = pandas_gbq.read_gbq(query, project_id=project_id)
    print(df)
    '''

    """

    ### Spatial example 3
    prompt += """
    Question: For each tumor cell within the specified spatial region, calculate the number of neighboring tumor cells
    Answer:
    '''
    import pandas_gbq

    project_id = "isb-cgc-external-004"

    query = '''
      WITH cellp AS (
      SELECT CellID, celltype, CellID_1, celltype_1
      FROM `isb-cgc-bq.temp15432.Melanoma_CyCIF_HTA7_1_3_points_within_20um`
      WHERE X_centroid > 23076.9 AND X_centroid < 30384.6
      AND Y_centroid > 9615.3 AND Y_centroid < 15000)

      SELECT CellID, COUNTIF(celltype_1 = 'Tumor') - 1 AS N_Tumor_Cells
      FROM cellp
      WHERE celltype = 'Tumor'
      GROUP BY CellID
    '''
    df = pandas_gbq.read_gbq(query, project_id=project_id)
    print(df)
    '''

    """


    ## Metadata

    # Include the full attribute/description table. str(metadata_df) follows the
    # pandas display options (descriptions cut at display.max_colwidth, 50 chars
    # by default, and long tables elided), so the model only saw truncated text.
    prompt += f"""
    The HTAN metadata contains a subset of these type of attributes. Also included is a description of these attributes. Internalize this information and use it to answer any queries related to metadata.

    Metadata Description:
    {metadata_df.to_string()}
    """



    ## ScRNAseq example 1
    prompt += """
    Question: What are the counts of unique cells, sex groupings, samples,
    cell types, and therapies by development stage in the MSK scRNAseq dataset?
    Answer:
    '''
    import pandas_gbq

    project_id = "isb-cgc-external-004"

    query = '''
    SELECT
      development_stage,
      count(distinct(iObs)) AS Number_Cells,
      count(distinct(sex)) AS Unique_Sex_Grouping,
      count(distinct(donor_id)) AS Number_Samples,
      count(distinct(cell_type)) AS Number_Cell_Types,
      count(distinct(treatment)) AS Number_Therapies
    FROM
      `isb-cgc-bq.HTAN.scRNAseq_MSK_SCLC_combined_samples_current`
    GROUP BY development_stage
    ORDER BY Number_Samples DESC
    '''
    df = pandas_gbq.read_gbq(query, project_id=project_id)
    print(df)
    '''

    """

    ## ScRNAseq example 2
    prompt += """
    Question: How many unique cell types, sex groupings, cells, and samples
    are present in a specific human stage of development in the MSK
    scRNAseq dataset (e.g. 74-year-old)?
    Answer:
    '''
    import pandas_gbq

    project_id = "isb-cgc-external-004"

    query = '''
    SELECT
      cell_type,
      count(distinct(sex)) AS Unique_Sex_Grouping,
      count(distinct(iObs)) AS Number_Cells,
      count(distinct(donor_id)) AS Number_Samples
    FROM
      `isb-cgc-bq.HTAN.scRNAseq_MSK_SCLC_combined_samples_current`
    WHERE
      development_stage = '74-year-old human stage'
    GROUP BY cell_type
    '''
    df = pandas_gbq.read_gbq(query, project_id=project_id)
    print(df)
    '''

    """

    ## ScRNAseq example 3
    prompt += """
    Question: How many genes and cells are associated with each sex and
    cell type in the MSK scRNAseq dataset for an individual
    (e.g. a 74-year-old human stage)?
    Answer:
    '''
    import pandas_gbq

    project_id = "isb-cgc-external-004"

    query = '''
    SELECT
      sex,
      Cell_Type,
      count(distinct(feature_name)) AS Number_Genes,
      count(distinct(iObs)) AS Number_Cells
    FROM
      `isb-cgc-bq.HTAN.scRNAseq_MSK_SCLC_combined_samples_current`
    WHERE development_stage = '74-year-old human stage'
    GROUP BY sex, Cell_Type
    ORDER BY Cell_Type DESC
    '''
    df = pandas_gbq.read_gbq(query, project_id=project_id)
    print(df)
    '''

    """

    ## ScRNAseq example 4
    prompt += """
    Question: How many genes and cells are there in each Seurat Cluster
    for males and females of the 'epithelial cell' type in the specific
    human stage (here 74-year-old)?
    Answer:
    '''
    import pandas_gbq

    project_id = "isb-cgc-external-004"

    query = '''
    SELECT
      sex,
      clusters,
      Cell_Type,
      count(distinct(feature_name)) AS Number_Genes,
      count(distinct(iObs)) AS Number_Cells
    FROM
      `isb-cgc-bq.HTAN.scRNAseq_MSK_SCLC_combined_samples_current`
    WHERE development_stage = '74-year-old human stage' AND Cell_Type = 'epithelial cell'
    GROUP BY sex, clusters, Cell_Type
    ORDER BY clusters ASC
    '''
    df = pandas_gbq.read_gbq(query, project_id=project_id)
    print(df)
    '''

    """

    ## ScRNAseq example 5
    prompt += """
    Question: How do the average expression values for genes differ between
    male and female epithelial cells in a specific cluster, and which genes
    show the greatest differences (here cluster 41 of the 74-year-old
    human stage)?
    Answer:
    '''
    import pandas_gbq

    project_id = "isb-cgc-external-004"

    query = '''
    SELECT
      A.feature_name,
      A.avg_counts_clust41 AS female_avg_counts,
      B.avg_counts_clust41 AS male_avg_counts,
      A.avg_counts_clust41 - B.avg_counts_clust41 AS mean_diff
    FROM (
      SELECT
        feature_name,
        AVG(X_value) AS avg_counts_clust41
      FROM
        `isb-cgc-bq.HTAN.scRNAseq_MSK_SCLC_combined_samples_current`
      WHERE development_stage = '74-year-old human stage' AND Cell_Type = 'epithelial cell' AND clusters = '41' AND sex = 'female'
      GROUP BY feature_name
    ) AS A
    INNER JOIN (
      SELECT
        feature_name,
        AVG(X_value) AS avg_counts_clust41
      FROM
        `isb-cgc-bq.HTAN.scRNAseq_MSK_SCLC_combined_samples_current`
      WHERE development_stage = '74-year-old human stage' AND Cell_Type = 'epithelial cell' AND clusters = '41' AND sex = 'male'
      GROUP BY feature_name
    ) AS B
    ON A.feature_name = B.feature_name
    ORDER BY mean_diff DESC
    '''
    df = pandas_gbq.read_gbq(query, project_id=project_id)
    print(df)
    '''

    """


    # Add the dynamic user question and the placeholder for the answer
    prompt += f"""
    Question: {question}
    Answer:
    """

    vertexai.init(project="isb-cgc-external-004", location="us-central1")
    model = GenerativeModel(
        "gemini-1.5-flash-002",
    )

    # Use `generation_config` from the notebook if a cell defines it; otherwise
    # derive it from `model_kwargs` (dropping its LangChain-format safety map),
    # instead of failing with NameError as the bare name did.
    gen_config = globals().get("generation_config")
    if gen_config is None:
        gen_config = {k: v for k, v in model_kwargs.items() if k != "safety_settings"}

    responses = model.generate_content(
        [prompt],
        generation_config=gen_config,
        safety_settings=safety_settings,
        stream=False,
    )

    # Accept optional spaces/CR around the fences: the few-shot examples above are
    # indented, and the model often mirrors that indentation in its reply.
    pattern = r'```python[ \t]*\r?\n(.*?)\r?\n[ \t]*```'
    match = re.search(pattern, responses.text, re.DOTALL)

    if match:
        # dedent() before strip(): strip() alone removes only the first line's
        # indent, turning uniformly indented code into an IndentationError.
        python_code = textwrap.dedent(match.group(1)).strip()
        print(python_code)

        # Run in the notebook's global namespace (not this function's throwaway
        # locals) so DataFrames and other results survive the call.
        # SECURITY: model-generated code, executed unsandboxed.
        exec(python_code, globals())

        print("Agent 1: Executed Python code.")
    else:
        print(responses.text)
        print("Agent 1: No executable code found.")

    return responses.text

@tool
def agent_2(question: str) -> str:
    """
    Generate and execute Python code to query metadata or retrieve datasets from HTAN via Google Cloud
    using BigQuery and relevant Python packages.

    Parameters:
    question (str): A user-defined question regarding metadata or dataset retrieval from HTAN.

    Returns:
    str: Python code enclosed in ```python``` tags for easy extraction. The code may include
         BigQuery queries or use of libraries like pandas, scanpy, or squidpy depending on
         the question.

    Notes:
    - Uses Google BigQuery with project ID 'isb-cgc-external-004' to access HTAN datasets.
    - Responds to general questions with a conversational response if code is not required.
    - Packages are imported as needed in the generated code.
    """
    # Implementation notes. These are comments rather than docstring text because
    # @tool sends the docstring to the routing LLM as this tool's description.
    # - Returns the raw Gemini response text (code block or conversational reply).
    # - The first ```python``` block is executed in the notebook's global namespace,
    #   where it can read and replace the notebook's `adata`.
    # - SECURITY: that code is model-generated and runs unsandboxed.

    prompt = f"""
    You are a bioinformatics expert coder using the Human Tumor Atlas Network (HTAN) datasets via Google Cloud and Jupyter.
    You can write Python code to answer questions by writing BigQuery queries, pandas, scanpy, squidpy, and commot code.
    You have access to an AnnData object named 'adata' that contains:
    - Gene expression data in adata.X
    - Cell metadata in adata.obs, including 'kmeans_9_clusters' for cell types
    - Gene names in adata.var_names
    - Spatial coordinates in adata.obsm['spatial'] if available

    Important technical details:
    - Gene expression data is stored as a sparse matrix and needs to be converted using .toarray() or scipy.sparse methods
    - Always handle sparse matrices appropriately
    - Add clear comments explaining the analysis steps
    - Include proper error handling for missing genes or data
    - Use sc.pl.spatial() for spatial plots
    - Use only 'cmap' parameter (not color_map) for color mapping
    - Include proper error handling for missing coordinates or genes
    - For multiple plots, use ncols parameter to arrange them
    - VERY IMPORTANT: Include all the necessary import packages to run the code.

    When calling ct.pp.ligand_receptor_database use database='CellChat'; for every command that has a 'database_name' parameter use 'cellchat'.
    The following results are attached after running ct.tl.spatial_communication. Please note that 'user_database' stands for the database_name that was used (e.g. 'cellchat') and 'Fgf1-Fgfr1' can be replaced with other ligand-receptor pairs:
    - adata.uns['commot-user_database-info']: Ligand-receptor database used.
    - adata.obsm['commot-user_database-sum-sender']: Total sent signals per LR pair and pathway.
    - adata.obsm['commot-user_database-sum-receiver']: Total received signals per LR pair and pathway.
    - adata.obsm['commot_sender_vf-user_database-Fgf1-Fgfr1']: Signaling directions for sent signals.
    - adata.obsm['commot_receiver_vf-user_database-Fgf1-Fgfr1']: Signaling directions for received signals.
    - adata.obsp['commot-user_database-Fgf1-Fgfr1']: Sparse matrix of cell-cell communication scores per spot.

    When responding to questions:
    1. Analyze what type of operation is being requested
    2. Generate appropriate Python code using scanpy, pandas, matplotlib, and other relevant libraries
    3. For questions that require code, return ONLY the Python code within ```python``` tags
    4. Include proper error handling and data validation
    5. Add clear comments explaining the analysis steps

    Examples of queries and their answers:

    Question: How do I preprocess my adata object?
    Answer:
    '''
    import commot as ct
    import scanpy as sc
    import pandas as pd
    import numpy as np
    from scipy import sparse
    import matplotlib.pyplot as plt
    sc.pp.normalize_total(adata, inplace=True)
    sc.pp.log1p(adata)
    '''

    Q: "Show me the distribution of cells across kmeans_9_clusters"
    A: ```python
    import scanpy as sc
    import pandas as pd
    import matplotlib.pyplot as plt

    # Get cluster counts
    cluster_counts = adata.obs['kmeans_9_clusters'].value_counts().sort_index()

    # Create bar plot
    plt.figure(figsize=(10, 6))
    cluster_counts.plot(kind='bar')
    plt.title('Distribution of Cells Across kmeans_9_clusters')
    plt.xlabel('Cluster')
    plt.ylabel('Number of Cells')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

    # Print numerical summary
    print("\\nNumerical distribution:")
    print(pd.DataFrame({{
        'Cluster': cluster_counts.index,
        'Count': cluster_counts.values,
        'Percentage': (cluster_counts.values / len(adata) * 100).round(2)
    }}))
    ```

    Q: "Find and plot the top 5 genes with highest Moran's I values"
    A: ```python
    import pandas as pd
    import matplotlib.pyplot as plt
    import scanpy as sc

    # Create DataFrame with Moran's I statistics
    morans_df = pd.DataFrame({{
        'Gene': adata.var_names,
        'Morans_I': adata.var['Morans_I'],
        'Adj_P_Value': adata.var['Morans_I_adj_p_val']
    }})

    # Sort by Moran's I value and get top 5
    top_genes = morans_df.sort_values('Morans_I', ascending=False).head(5)

    # Create spatial plots for top genes
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    # Plot each gene
    for idx, (gene, mi_value) in enumerate(zip(top_genes['Gene'], top_genes['Morans_I'])):
        sc.pl.spatial(adata,
                     color=gene,
                     title=f'{{gene}}\\nMoran\\'s I = {{mi_value:.3f}}',
                     ax=axes[idx],
                     show=False,
                     cmap='viridis')

    plt.tight_layout()
    plt.show()

    # Print summary table
    print("\\nTop 5 spatially autocorrelated genes:")
    print(top_genes)
    ```

    Q: "Generate a UMAP visualization comparing kmeans_9_clusters with kmeans_10_clusters"
    A: ```python
    import scanpy as sc
    import matplotlib.pyplot as plt

    # Compute UMAP if not already present
    if 'X_umap' not in adata.obsm_keys():
        sc.pp.neighbors(adata, n_pcs=30)
        sc.tl.umap(adata)

    # Create subplot with both clustering results
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Plot UMAP with kmeans_9_clusters
    sc.pl.umap(adata,
               color='kmeans_9_clusters',
               title='kmeans_9_clusters',
               ax=ax1,
               show=False)

    # Plot UMAP with kmeans_10_clusters
    sc.pl.umap(adata,
               color='kmeans_10_clusters',
               title='kmeans_10_clusters',
               ax=ax2,
               show=False)

    plt.tight_layout()
    plt.show()

    # Print cluster sizes
    print("\\nkmeans_9_clusters distribution:")
    print(adata.obs['kmeans_9_clusters'].value_counts())
    print("\\nkmeans_10_clusters distribution:")
    print(adata.obs['kmeans_10_clusters'].value_counts())
    ```

    Q: "Calculate correlation between the top 3 most highly expressed genes"
    A: ```python
    import pandas as pd
    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    from scipy import sparse

    # Calculate mean expression for each gene
    if sparse.issparse(adata.X):
        mean_expr = np.array(adata.X.mean(axis=0)).flatten()
    else:
        mean_expr = np.mean(adata.X, axis=0)

    # Get top 3 genes
    top_genes_idx = np.argsort(mean_expr)[-3:]
    top_genes = adata.var_names[top_genes_idx]

    # Extract expression data for top genes
    expr_matrix = adata[:, top_genes].X
    if sparse.issparse(expr_matrix):
        expr_matrix = expr_matrix.toarray()

    # Create correlation matrix
    corr_df = pd.DataFrame(expr_matrix, columns=top_genes)
    corr_matrix = corr_df.corr()

    # Plot correlation heatmap
    plt.figure(figsize=(8, 6))
    sns.heatmap(corr_matrix,
                annot=True,
                cmap='coolwarm',
                vmin=-1,
                vmax=1,
                center=0)
    plt.title('Correlation between top 3 highly expressed genes')
    plt.tight_layout()
    plt.show()

    # Print mean expression values
    print("\\nMean expression values:")
    for gene, mean_exp in zip(top_genes, mean_expr[top_genes_idx]):
        print(f"{{gene}}: {{mean_exp:.2f}}")
    ```

    Question: What are the different ligand and receptor pairs in the TGFb pathway?
    Answer:
    '''
    import commot as ct
    import scanpy as sc
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    df_ligrec=ct.pp.ligand_receptor_database(database='CellChat', species='human')
    tgfb = df_ligrec[df_ligrec.iloc[:,2]=='TGFb']
    tgfb
    '''

    Question: Perform cell-cell communication analysis for the TGFb pathway.
    Only access the df_ligrec database with .iloc.
    Answer:
    '''
    import commot as ct
    import scanpy as sc
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    df_ligrec=ct.pp.ligand_receptor_database(database='CellChat', species='human')
    tgfb = df_ligrec[df_ligrec.iloc[:,2]=='TGFb']
    ct.tl.spatial_communication(adata,
    database_name='cellchat', df_ligrec=tgfb, dis_thr=500, heteromeric=True, pathway_sum=True)
    '''

    Question: Can you construct a cell-cell communication networks of the TGFb pathway between cells within 500 µm?
    Answer:
    '''
    import commot as ct
    import scanpy as sc
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    df_ligrec=ct.pp.ligand_receptor_database(database='CellChat', species='human')
    tgfb = df_ligrec[df_ligrec.iloc[:,2]=='TGFb']
    ct.tl.spatial_communication(adata,
    database_name='cellchat', df_ligrec=tgfb, dis_thr=500, heteromeric=True, pathway_sum=True)
    '''

    Question: How do I visualize the amount of sent and received signal between SEMA3A and NRP1_PLXNA2?
    Answer:
    Here's how to plot the signaling levels for the SEMA3A-NRP1_PLXNA2 pair.
    '''
    import commot as ct
    import scanpy as sc
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    #define the LR pair
    LR = np.array([['SEMA3A', 'NRP1_PLXNA2', 'SEMA3']], dtype=str)
    LR = pd.DataFrame(data=LR)
    ct.tl.spatial_communication(adata,
    database_name='cellchat', df_ligrec=LR, dis_thr=500, heteromeric=True, pathway_sum=True)
    pts = adata.obsm['spatial']
    s = adata.obsm['commot-cellchat-sum-sender']['s-SEMA3A']
    r = adata.obsm['commot-cellchat-sum-receiver']['r-NRP1-PLXNA2']
    fig, ax = plt.subplots(1,2, figsize=(10,4))
    ax[0].scatter(pts[:,0], pts[:,1], c=s, s=5, cmap='Blues')
    ax[0].set_title('Sender')
    ax[1].scatter(pts[:,0], pts[:,1], c=r, s=5, cmap='Reds')
    ax[1].set_title('Receiver')
    '''

    Question: How can I visualize signaling directions for TGFB1-TGFBR1_TGFBR2 as vector fields?
    Answer:
    '''
    import commot as ct
    import scanpy as sc
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    ct.tl.communication_direction(adata, database_name='cellchat', lr_pair=('TGFB1','TGFBR1_TGFBR2'), k=5)
    ct.pl.plot_cell_communication(adata, database_name='cellchat', lr_pair=('TGFB1','TGFBR1_TGFBR2'), plot_method='grid', background_legend=True,
    scale=0.00003, ndsize=8, grid_density=0.4, summary='sender', background='image', clustering='leiden', cmap='Alphabet',
    normalize_v = True, normalize_v_quantile=0.995)
    '''

    Question: Can you show me the cell-cell interaction between TGFB1 and TGFBR1_TGFBR2?
    Answer:
    '''
    import commot as ct
    import scanpy as sc
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    ct.tl.communication_direction(adata, database_name='cellchat', lr_pair=('TGFB1','TGFBR1_TGFBR2'), k=5)
    ct.pl.plot_cell_communication(adata, database_name='cellchat', lr_pair=('TGFB1','TGFBR1_TGFBR2'), plot_method='grid', background_legend=True,
    scale=0.00003, ndsize=8, grid_density=0.4, summary='receiver', background='summary', clustering='leiden', cmap='Reds',
    normalize_v = True, normalize_v_quantile=0.995)
    '''

    Current question to analyze: {question}

    First, analyze the type of operation requested and then provide the appropriate code or response.
    """

    vertexai.init(project="isb-cgc-external-004", location="us-central1")
    model = GenerativeModel(
        "gemini-1.5-flash-002",
    )

    # Use `generation_config` from the notebook if a cell defines it; otherwise
    # derive it from `model_kwargs` (dropping its LangChain-format safety map),
    # instead of failing with NameError as the bare name did.
    gen_config = globals().get("generation_config")
    if gen_config is None:
        gen_config = {k: v for k, v in model_kwargs.items() if k != "safety_settings"}

    responses = model.generate_content(
        [prompt],
        generation_config=gen_config,
        safety_settings=safety_settings,
        stream=False,
    )

    # Accept optional spaces/CR around the fences: the ```python examples above
    # close with an indented fence, and the model often mirrors that.
    pattern = r'```python[ \t]*\r?\n(.*?)\r?\n[ \t]*```'
    match = re.search(pattern, responses.text, re.DOTALL)

    if match:
        # dedent() before strip(): strip() alone removes only the first line's
        # indent, turning uniformly indented code into an IndentationError.
        python_code = textwrap.dedent(match.group(1)).strip()
        print("Generated code:", python_code)

        # Execute directly in the notebook namespace. exec() with separate
        # globals and locals runs the code as if inside a class body, so
        # functions defined by the generated code could not see names that the
        # same code had imported or assigned (e.g. `np` inside a helper def).
        # New variables still land in the notebook globals, as before.
        # SECURITY: model-generated code, executed unsandboxed.
        exec(python_code, globals())

        print("Agent 2: Executed Python code with accessible variables.")
    else:
        print(responses.text)
        print("Agent 2: No executable code found.")

    return responses.text

# Tools offered to the tool-calling agent: agent_1 (HTAN/Synapse dataset download
# and BigQuery retrieval) and agent_2 (AnnData / scanpy / commot analysis code).
toolkit = [
    agent_1,
    agent_2,
]

from IPython.display import Markdown, display

def agent_3(question: str) -> None:
    """
    Render the raw output of the tool-calling agent as human-readable Markdown.

    The raw text (in this notebook, the string returned by `ask_agents`) is
    appended to a few-shot prompt and sent to Gemini; the model's reply is
    shown in the notebook with `display(Markdown(...))`.

    Parameters:
    question (str): Raw outputs from the tool calling agent.

    Returns:
    None. The formatted text is displayed, not returned.

    Notes:
    Reads the module-level `generation_config` and `safety_settings` that are
    defined elsewhere in the notebook.
    """

    # Few-shot prompt: three examples of raw agent output -> formatted output.
    prompt = """You are an experienced writer and editor. When given an input, convert it into a well formatted string output. If the input string is blank, just say that there are no outputs generated for the current query.

    Here are some examples:
    Example 1:
    Input: '''
    Generated code: import scanpy as sc
    import pandas as pd

    adata = sc.read_h5ad("/content/demo_data/6723_KL_1_unfiltered.h5ad")
    print(adata)
    AnnData object with n_obs × n_vars = 4992 × 19074
    obs: 'in_tissue', 'array_row', 'array_col', 'kmeans_7_clusters', 'kmeans_10_clusters', 'kmeans_4_clusters', 'kmeans_2_clusters', 'kmeans_6_clusters', 'graphclust', 'kmeans_3_clusters', 'kmeans_8_clusters', 'kmeans_9_clusters', 'kmeans_5_clusters'
    var: 'gene_ids', 'feature_types', 'genome', 'n_cells', 'Morans_I', 'Morans_I_p_val', 'Morans_I_adj_p_val', 'Feature Counts in Spots Under Tissue', 'Median Normalized Average Counts', 'Barcodes Detected per Feature'
    uns: 'spatial'
    obsm: 'spatial'
    Agent 2: Executed Python code with accessible variables.
    The agent action is tool='agent_2' tool_input={'question': 'Can you use scanpy to load the /content/demo_data/6723_KL_1_unfiltered.h5ad?'} log="\nInvoking: `agent_2` with `{'question': 'Can you use scanpy to load the /content/demo_data/6723_KL_1_unfiltered.h5ad?'}`\n\n\n" message_log=[AIMessage(content='', additional_kwargs={'function_call': {'name': 'agent_2', 'arguments': '{"question": "Can you use scanpy to load the /content/demo_data/6723_KL_1_unfiltered.h5ad?"}'}}, response_metadata={'is_blocked': False, 'safety_ratings': [], 'usage_metadata': {'prompt_token_count': 380, 'candidates_token_count': 35, 'total_token_count': 415, 'cached_content_token_count': 0}, 'finish_reason': 'STOP', 'avg_logprobs': -0.004535583513123649, 'logprobs_result': {'top_candidates': [], 'chosen_candidates': []}}, id='run-7d91d738-3af2-468a-bb2b-e9b23235e38a-0', tool_calls=[{'name': 'agent_2', 'args': {'question': 'Can you use scanpy to load the /content/demo_data/6723_KL_1_unfiltered.h5ad?'}, 'id': '62e66320-2760-4403-9049-43fd7a676e0f', 'type': 'tool_call'}], usage_metadata={'input_tokens': 380, 'output_tokens': 35, 'total_tokens': 415})] tool_call_id='62e66320-2760-4403-9049-43fd7a676e0f'
    The tool result is: None
    {'intermediate_steps': [(ToolAgentAction(tool='agent_2', tool_input={'question': 'Can you use scanpy to load the /content/demo_data/6723_KL_1_unfiltered.h5ad?'}, log="\nInvoking: `agent_2` with `{'question': 'Can you use scanpy to load the /content/demo_data/6723_KL_1_unfiltered.h5ad?'}`\n\n\n", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'name': 'agent_2', 'arguments': '{"question": "Can you use scanpy to load the /content/demo_data/6723_KL_1_unfiltered.h5ad?"}'}}, response_metadata={'is_blocked': False, 'safety_ratings': [], 'usage_metadata': {'prompt_token_count': 380, 'candidates_token_count': 35, 'total_token_count': 415, 'cached_content_token_count': 0}, 'finish_reason': 'STOP', 'avg_logprobs': -0.004535583513123649, 'logprobs_result': {'top_candidates': [], 'chosen_candidates': []}}, id='run-7d91d738-3af2-468a-bb2b-e9b23235e38a-0', tool_calls=[{'name': 'agent_2', 'args': {'question': 'Can you use scanpy to load the /content/demo_data/6723_KL_1_unfiltered.h5ad?'}, 'id': '62e66320-2760-4403-9049-43fd7a676e0f', 'type': 'tool_call'}], usage_metadata={'input_tokens': 380, 'output_tokens': 35, 'total_tokens': 415})], tool_call_id='62e66320-2760-4403-9049-43fd7a676e0f'), 'None')]}

    '''
    Well formated output:

    '''The dataset is loaded to the anndata object named adata. There are 4992 cells or spots, 19074 genes, and 13 feature columns. The data object is printed below:

    AnnData object with n_obs × n_vars = 4992 × 19074
    obs: 'in_tissue', 'array_row', 'array_col', 'kmeans_7_clusters', 'kmeans_10_clusters', 'kmeans_4_clusters', 'kmeans_2_clusters', 'kmeans_6_clusters', 'graphclust', 'kmeans_3_clusters', 'kmeans_8_clusters', 'kmeans_9_clusters', 'kmeans_5_clusters'
    var: 'gene_ids', 'feature_types', 'genome', 'n_cells', 'Morans_I', 'Morans_I_p_val', 'Morans_I_adj_p_val', 'Feature Counts in Spots Under Tissue', 'Median Normalized Average Counts', 'Barcodes Detected per Feature'
    uns: 'spatial'
    obsm: 'spatial'
    '''

    Example 2:
    Input: '''
    Welcome, arun.das!

    INFO:synapseclient_default:Welcome, arun.das!

    Downloading files:  77%|███████▋  | 16.8M/21.8M [00:00<00:00, 28.0MB/s, syn51133599]Downloaded syn51133599 to /content/demo_data/8578_AS_1_unfiltered.h5ad
    [INFO] Downloaded syn51133599 to /content/demo_data/8578_AS_1_unfiltered.h5ad
    Downloading files: 100%|██████████| 21.8M/21.8M [00:00<00:00, 31.8MB/s, syn51133599]INFO:synapseclient_default:Downloaded syn51133599 to /content/demo_data/8578_AS_1_unfiltered.h5ad
    Downloading files: 100%|██████████| 21.8M/21.8M [00:00<00:00, 31.6MB/s, syn51133599]import synapseclient
    syn = synapseclient.login(silent=True)
    entity = syn.get('syn51133599', downloadLocation='/content/demo_data')
    The agent action is tool='agent_1' tool_input={'question': 'Can you download the synapse dataset syn51133599 to /content/demo_data/?'} log="\nInvoking: `agent_1` with `{'question': 'Can you download the synapse dataset syn51133599 to /content/demo_data/?'}`\n\n\n" message_log=[AIMessage(content='', additional_kwargs={'function_call': {'name': 'agent_1', 'arguments': '{"question": "Can you download the synapse dataset syn51133599 to /content/demo_data/?"}'}}, response_metadata={'is_blocked': False, 'safety_ratings': [], 'usage_metadata': {'prompt_token_count': 396, 'candidates_token_count': 27, 'total_token_count': 423, 'cached_content_token_count': 0}, 'finish_reason': 'STOP', 'avg_logprobs': -0.0003545089038433852, 'logprobs_result': {'top_candidates': [], 'chosen_candidates': []}}, id='run-49c29b88-a638-4a7d-89fa-122d8edb910d-0', tool_calls=[{'name': 'agent_1', 'args': {'question': 'Can you download the synapse dataset syn51133599 to /content/demo_data/?'}, 'id': '48d647d4-94d2-4cae-bd6f-f5d5e2240945', 'type': 'tool_call'}], usage_metadata={'input_tokens': 396, 'output_tokens': 27, 'total_tokens': 423})] tool_call_id='48d647d4-94d2-4cae-bd6f-f5d5e2240945'
    The tool result is: None
    ----
    '''

    Well formated output:
    '''
    The synapse dataset with synapse id syn51133599 is successfully downloaded to /content/demo_data/8578_AS_1_unfiltered.h5ad.
    '''


    Example 3:
    Input: '''
    {'agent_outcome': [ToolAgentAction(tool='agent_1', tool_input={'question': 'Can you load the entityId of all data under HTAN.10xvisium_spatialtranscriptomics_scRNAseq_level4_metadata_current where the File_Format is hdf5?'}, log="\nInvoking: `agent_1` with `{'question': 'Can you load the entityId of all data under HTAN.10xvisium_spatialtranscriptomics_scRNAseq_level4_metadata_current where the File_Format is hdf5?'}`\n\n\n", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'name': 'agent_1', 'arguments': '{"question": "Can you load the entityId of all data under HTAN.10xvisium_spatialtranscriptomics_scRNAseq_level4_metadata_current where the File_Format is hdf5?"}'}}, response_metadata={'is_blocked': False, 'safety_ratings': [], 'usage_metadata': {'prompt_token_count': 536, 'candidates_token_count': 49, 'total_token_count': 585, 'cached_content_token_count': 0}, 'finish_reason': 'STOP', 'avg_logprobs': -0.0002904396352111077, 'logprobs_result': {'top_candidates': [], 'chosen_candidates': []}}, id='run-b2cf906b-d897-4164-b743-1120419ebdf9-0', tool_calls=[{'name': 'agent_1', 'args': {'question': 'Can you load the entityId of all data under HTAN.10xvisium_spatialtranscriptomics_scRNAseq_level4_metadata_current where the File_Format is hdf5?'}, 'id': '6987b536-4024-4b91-a0b1-720e41b1076e', 'type': 'tool_call'}], usage_metadata={'input_tokens': 536, 'output_tokens': 49, 'total_tokens': 585})], tool_call_id='6987b536-4024-4b91-a0b1-720e41b1076e')]}
    import pandas_gbq
    from google.cloud import bigquery
    import synapseclient

    project_id = "isb-cgc-external-004"

    query = '''
    SELECT entityId
    FROM `isb-cgc-bq.HTAN.10xvisium_spatialtranscriptomics_scRNAseq_level4_metadata_current`
    WHERE File_Format = 'hdf5'
    '''

    df = pandas_gbq.read_gbq(query, project_id=project_id)
    print(df)

    syn = synapseclient.login(silent=True)
    entity = syn.get('syn51133602', downloadLocation='/content/datasets')
    Downloading: 100%|##########|
        entityId
    0   syn51133519
    1   syn51133520
    2   syn51133521
    3   syn51133522
    4   syn51133523
    5   syn51133524
    6   syn51133525
    7   syn51133526
    8   syn51133527
    9   syn51133528
    10  syn51133529
    11  syn51133530
    12  syn51133531
    13  syn51133532
    14  syn51133533
    15  syn51133534
    16  syn51133537
    17  syn51133540
    18  syn51133578
    19  syn51133580
    20  syn51133583
    21  syn51133587
    22  syn51133591
    23  syn51133592
    24  syn51133593
    25  syn51133595
    26  syn51133596
    27  syn51133597
    28  syn51133598
    29  syn51133599
    30  syn51133600
    31  syn51133601
    32  syn51133602
    33  syn51133603
    34  syn51133604
    35  syn51133605
    36  syn51133606
    37  syn51133607
    38  syn51133608
    39  syn51133609
    40  syn51133612
    The agent action is tool='agent_1' tool_input={'question': 'Can you load the entityId of all data under HTAN.10xvisium_spatialtranscriptomics_scRNAseq_level4_metadata_current where the File_Format is hdf5?'} log="\nInvoking: `agent_1` with `{'question': 'Can you load the entityId of all data under HTAN.10xvisium_spatialtranscriptomics_scRNAseq_level4_metadata_current where the File_Format is hdf5?'}`\n\n\n" message_log=[AIMessage(content='', additional_kwargs={'function_call': {'name': 'agent_1', 'arguments': '{"question": "Can you load the entityId of all data under HTAN.10xvisium_spatialtranscriptomics_scRNAseq_level4_metadata_current where the File_Format is hdf5?"}'}}, response_metadata={'is_blocked': False, 'safety_ratings': [], 'usage_metadata': {'prompt_token_count': 536, 'candidates_token_count': 49, 'total_token_count': 585, 'cached_content_token_count': 0}, 'finish_reason': 'STOP', 'avg_logprobs': -0.0002904396352111077, 'logprobs_result': {'top_candidates': [], 'chosen_candidates': []}}, id='run-b2cf906b-d897-4164-b743-1120419ebdf9-0', tool_calls=[{'name': 'agent_1', 'args': {'question': 'Can you load the entityId of all data under HTAN.10xvisium_spatialtranscriptomics_scRNAseq_level4_metadata_current where the File_Format is hdf5?'}, 'id': '6987b536-4024-4b91-a0b1-720e41b1076e', 'type': 'tool_call'}], usage_metadata={'input_tokens': 536, 'output_tokens': 49, 'total_tokens': 585})] tool_call_id='6987b536-4024-4b91-a0b1-720e41b1076e'
    The tool result is: None
    {'intermediate_steps': [(ToolAgentAction(tool='agent_1', tool_input={'question': 'Can you load the entityId of all data under HTAN.10xvisium_spatialtranscriptomics_scRNAseq_level4_metadata_current where the File_Format is hdf5?'}, log="\nInvoking: `agent_1` with `{'question': 'Can you load the entityId of all data under HTAN.10xvisium_spatialtranscriptomics_scRNAseq_level4_metadata_current where the File_Format is hdf5?'}`\n\n\n", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'name': 'agent_1', 'arguments': '{"question": "Can you load the entityId of all data under HTAN.10xvisium_spatialtranscriptomics_scRNAseq_level4_metadata_current where the File_Format is hdf5?"}'}}, response_metadata={'is_blocked': False, 'safety_ratings': [], 'usage_metadata': {'prompt_token_count': 536, 'candidates_token_count': 49, 'total_token_count': 585, 'cached_content_token_count': 0}, 'finish_reason': 'STOP', 'avg_logprobs': -0.0002904396352111077, 'logprobs_result': {'top_candidates': [], 'chosen_candidates': []}}, id='run-b2cf906b-d897-4164-b743-1120419ebdf9-0', tool_calls=[{'name': 'agent_1', 'args': {'question': 'Can you load the entityId of all data under HTAN.10xvisium_spatialtranscriptomics_scRNAseq_level4_metadata_current where the File_Format is hdf5?'}, 'id': '6987b536-4024-4b91-a0b1-720e41b1076e', 'type': 'tool_call'}], usage_metadata={'input_tokens': 536, 'output_tokens': 49, 'total_tokens': 585})], tool_call_id='6987b536-4024-4b91-a0b1-720e41b1076e'), 'None')]}
    ---- Initial response captured ----
    '''
    Well formatted output:
    '''
    A total of 41 entity IDs were retrieved from the BigQuery table isb-cgc-bq.HTAN.10xvisium_spatialtranscriptomics_scRNAseq_level4_metadata_current where the File_Format is 'hdf5'.  The entity IDs are listed below::

    1. syn51133519
    2. syn51133520
    3. syn51133521
    4. syn51133522
    5. syn51133523
    6. syn51133524
    7. syn51133525
    8. syn51133526
    9. syn51133527
    10. syn51133528
    11. syn51133529
    12. syn51133530
    13. syn51133531
    14. syn51133532
    15. syn51133533
    16. syn51133534
    17. syn51133537
    18. syn51133540
    19. syn51133578
    20. syn51133580
    21. syn51133583
    22. syn51133587
    23. syn51133591
    24. syn51133592
    25. syn51133593
    26. syn51133595
    27. syn51133596
    28. syn51133597
    29. syn51133598
    30. syn51133599
    31. syn51133600
    32. syn51133601
    33. syn51133602
    34. syn51133603
    35. syn51133604
    36. syn51133605
    37. syn51133606
    38. syn51133607
    39. syn51133608
    40. syn51133609
    41. syn51133612
    '''
    """

    # Append the actual input to format after the few-shot examples.
    prompt += f"""
    Input: {question}
    Please generate the well formated output.
    """

    vertexai.init(project="isb-cgc-external-004", location="us-central1")
    model = GenerativeModel(
        "gemini-1.5-flash-002",
    )
    responses = model.generate_content(
        [prompt],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=False,
    )

    # Render the model's reply as Markdown in the notebook; nothing is returned,
    # so a bare `agent_3(...)` call at the end of a cell shows no extra repr.
    display(Markdown(responses.text))

# Supervisor LLM: decides which worker agent (tool) should handle a request.
llm = ChatVertexAI(model_name='gemini-1.5-flash-002')

# System prompt telling the supervisor what each worker agent is responsible for.
system_prompt = """
You are a supervisor who can select the right worker to complete a task. The workers are called agents, and we have two agents, agent_1 and agent_2, each skilled in specific tasks, as described below:

- Use `agent_1` to generate code for downloading files from Synapse or the HTAN database. If the user request involves data retrieval from an external source like Synapse, agent_1 is the correct choice.

- Use `agent_2` to generate code for handling tasks related to the AnnData object, named `adata`, which is available in the local variables. Specifically:
  - Select `agent_2` for tasks that involve loading, preprocessing, or analyzing the `adata` object.
  - Use `agent_2` when the user’s question requires performing operations on `adata`, such as filtering, normalization, dimensionality reduction, or any other preprocessing step.
  - For questions or analysis based on the data contained within `adata`, also use `agent_2`.

If you do not have a tool to answer the question, inform the user accordingly.

"""

# Message layout: system instructions, optional earlier conversation, the
# user's request, then the scratchpad where the agent's tool calls accumulate.
tool_calling_prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    MessagesPlaceholder("chat_history", optional=True),
    ("human", "{input}"),
    MessagesPlaceholder("agent_scratchpad"),
])

# Bind the worker tools and prompt to the LLM, producing the supervisor runnable.
tool_runnable = create_tool_calling_agent(llm=llm, tools=toolkit, prompt=tool_calling_prompt)


def run_tool_agent(state):
    """
    Graph node: ask the supervisor agent what to do next.

    The whole graph state is passed to `tool_runnable`. Its result is returned
    under `agent_outcome`. That field has no reducer in AgentState, so each run
    replaces the previous outcome.
    """
    return dict(agent_outcome=tool_runnable.invoke(state))

# Runs whichever tool from `toolkit` an agent action names; the
# `execute_tools` graph node delegates each pending action to it.
tool_executor = ToolExecutor(tools=toolkit)

# Graph node that runs the tool call(s) chosen by the supervisor agent.
def execute_tools(state):
    """
    Execute every pending agent action and record its observation.

    `state['agent_outcome']` holds either a single action or a list of actions
    (a tool-calling agent returns a list); a lone action is wrapped in a list
    so both cases are handled the same way. Each action is run through
    `tool_executor`; the action and its result are printed, and the result is
    stored as a string.

    Returns a partial state update: {"intermediate_steps": [(action, str(result)), ...]}.
    """
    outcome = state['agent_outcome']
    pending = outcome if type(outcome) is list else [outcome]

    observations = []
    for act in pending:
        result = tool_executor.invoke(act)
        print(f"The agent action is {act}")
        print(f"The tool result is: {result}")
        observations.append((act, str(result)))
    return {"intermediate_steps": observations}

class AgentState(TypedDict):
    """State shared by the `agent` and `action` nodes of the workflow graph."""

    # Raw request text from the user.
    input: str
    # Earlier conversation turns; fills the prompt's optional `chat_history` slot.
    chat_history: list[BaseMessage]
    # Most recent supervisor decision. It has no reducer, so every agent run
    # overwrites it. It starts as None; `list` is allowed because a
    # tool-calling agent returns a list of actions.
    agent_outcome: Union[AgentAction, list, ToolAgentAction, AgentFinish, None]
    # (action, observation) pairs. The `operator.add` reducer makes new steps
    # get appended to the existing list instead of replacing it.
    intermediate_steps: Annotated[list[Union[tuple[AgentAction, str], tuple[ToolAgentAction, str]]], operator.add]

def should_continue(data):
    """
    Routing function for the conditional edge leaving the `agent` node.

    Returns "END" when `data['agent_outcome']` is an AgentFinish (the graph
    stops). Otherwise it returns "CONTINUE", sending the pending action(s) to
    the `action` node.
    """
    finished = isinstance(data['agent_outcome'], AgentFinish)
    return "END" if finished else "CONTINUE"

# Build the supervisor/worker graph over the shared AgentState.
workflow = StateGraph(AgentState)

# "agent" node: the supervisor decides which tool (worker agent) to call, or finishes.
workflow.add_node("agent", run_tool_agent)


# "action" node: executes the tool call(s) chosen by the supervisor.
workflow.add_node("action", execute_tools)

# Every run starts at the supervisor.
workflow.set_entry_point("agent")

# Static edge: after a tool has run, control always returns to the supervisor,
# which decides whether the task is complete or another action is needed.
workflow.add_edge('action', 'agent')


# Conditional edge out of "agent": the string returned by should_continue is
# looked up in the mapping below to choose the next node.
workflow.add_conditional_edges(
    # Source node of the edge.
    "agent",
    # Routing function evaluated on the current state.
    should_continue,

    # Maps each routing result to a destination; every value should_continue
    # can return must appear here.
    {
        # An action is pending: run the tool(s).
        "CONTINUE": "action",
        # The agent produced an AgentFinish: stop the graph.
        "END": END
    }
)

# In-process checkpointer passed to compile(). NOTE(review): checkpoints are
# presumably keyed by the run config's `thread_id`; ask_agents defaults to
# thread "1", so repeated calls may share (and accumulate) state — confirm
# this is intended.
memory = MemorySaver()

# Compile the graph into the runnable `app` used by ask_agents.
app = workflow.compile(checkpointer = memory)

def ask_agents(user_question, config=None, max_chunks=2):
    """
    Run the agent graph on one question and capture everything it prints.

    Parameters:
    - user_question: A string representing the user's question or command.
    - config: Optional run configuration for the graph. Defaults to
      {"configurable": {"thread_id": "1"}}. A fresh dict is built on every call,
      so the default is never shared between calls.
    - max_chunks: Stop after this many streamed updates (default 2, i.e. the
      supervisor's decision followed by the first tool execution) and print a
      "captured" marker. Pass a positive int, or None to stream until the
      graph ends on its own.

    Returns:
    - Everything printed to stdout during the run (streamed updates, output of
      the tools, and any error message) as a single string.
    """
    if config is None:
        config = {"configurable": {"thread_id": "1"}}

    buffer = io.StringIO()
    inputs = {"input": user_question, "chat_history": []}

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # Everything printed below (including by the tools) lands in `buffer`.
        with redirect_stdout(buffer):
            try:
                for count, response_chunk in enumerate(app.stream(inputs, config=config), start=1):
                    # Each chunk is a dict keyed by node name; keep its single value.
                    result = next(iter(response_chunk.values()))
                    print(result)
                    if max_chunks is not None and count >= max_chunks:
                        print("---- Initial response captured ----")
                        break
            except Exception as e:
                # Deliberate best-effort: surface the failure inside the captured
                # text so downstream formatting (agent_3) still has something to show.
                print(f"Error during streaming execution: {e}")

    return buffer.getvalue()

# Example: browse HTAN metadata stored in BigQuery (a data-retrieval request,
# which the system prompt assigns to agent_1).
user_question = (
    "Can you load the entityId of all data under "
    "HTAN.10xvisium_spatialtranscriptomics_scRNAseq_level4_metadata_current "
    "where the File_Format is hdf5?"
)

# Other example questions for this pipeline:
# - Download from Synapse:
#   "Can you download the Synapse dataset with id syn51133612 to /content/demo_data/?"
# - Load an AnnData file:
#   "Can you load the /content/demo_data/8899_AS_8_unfiltered.h5ad to adata using scanpy?"
# - Preprocess:
#   "Please preprocess the adata object"
# - Visualize:
#   "Generate a UMAP visualization comparing kmeans_9_clusters with kmeans_10_clusters"

output = ask_agents(user_question)

In [None]:
# Ask the agents to load a local .h5ad file into `adata`.
user_question = "Can you load the /content/demo_data/8899_AS_8_unfiltered.h5ad to adata using scanpy?"
output = ask_agents(user_question)
# Show the raw captured agent output; printing renders newlines instead of an escaped repr.
print(output)

'{\'agent_outcome\': [ToolAgentAction(tool=\'agent_2\', tool_input={\'question\': \'Can you load the /content/demo_data/8899_AS_8_unfiltered.h5ad to adata using scanpy?\'}, log="\\nInvoking: `agent_2` with `{\'question\': \'Can you load the /content/demo_data/8899_AS_8_unfiltered.h5ad to adata using scanpy?\'}`\\n\\n\\n", message_log=[AIMessage(content=\'\', additional_kwargs={\'function_call\': {\'name\': \'agent_2\', \'arguments\': \'{"question": "Can you load the /content/demo_data/8899_AS_8_unfiltered.h5ad to adata using scanpy?"}\'}}, response_metadata={\'is_blocked\': False, \'safety_ratings\': [], \'usage_metadata\': {\'prompt_token_count\': 524, \'candidates_token_count\': 37, \'total_token_count\': 561, \'cached_content_token_count\': 0}, \'finish_reason\': \'STOP\', \'avg_logprobs\': -0.03173088705217516, \'logprobs_result\': {\'top_candidates\': [], \'chosen_candidates\': []}}, id=\'run-0fe65e7f-1452-4d60-8c6d-e9452bf2ae02-0\', tool_calls=[{\'name\': \'agent_2\', \'args\': {

In [None]:
# Format the raw captured agent output as readable Markdown.
agent_3(question=output)

```
The file `/content/demo_data/8899_AS_8_unfiltered.h5ad` was successfully loaded into the `adata` AnnData object using scanpy.  The object contains information for 4992 observations (cells or spots) and 18729 variables (genes).  Details are shown below:

AnnData object with n_obs × n_vars = 4992 × 18729
    obs: 'in_tissue', 'array_row', 'array_col', 'kmeans_7_clusters', 'kmeans_10_clusters', 'kmeans_4_clusters', 'kmeans_2_clusters', 'kmeans_6_clusters', 'graphclust', 'kmeans_3_clusters', 'kmeans_8_clusters', 'kmeans_9_clusters', 'kmeans_5_clusters'
    var: 'gene_ids', 'feature_types', 'genome', 'n_cells', 'Morans_I', 'Morans_I_p_val', 'Morans_I_adj_p_val', 'Feature Counts in Spots Under Tissue', 'Median Normalized Average Counts', 'Barcodes Detected per Feature'
    uns: 'spatial'
    obsm: 'spatial'
```
