# Multi-agent consensus tutorial

In this tutorial, we use two agents, Finch and Crow, to run a differential expression analysis (DEA) on RNA-seq data from [GEO series GSE52778](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE52778). We also apply consensus sampling with Finch to make the results more reliable.

The process follows four steps:
1. Differential expression analysis: run 5 DEAs in parallel with Finch
2. Consensus sampling: Aggregate the results of the DEAs with Finch
3. Literature search: Use Crow to search the literature for the top differentially expressed genes
4. Visualization: Use Finch to create a final interactive volcano plot containing all differentially expressed genes, their evidence and the evidence score.

Let's get started!

In [None]:
import json
import pandas as pd
import uuid

from futurehouse_client import FutureHouseClient, JobNames
from futurehouse_client.models import TaskRequest, RuntimeConfig
from futurehouse_client.models.app import AuthType
import fhda.prompts as prompts

In [None]:
# Here are the prompts we'll be using
TREATMENT = "dexamethasone"
MECHANISM = "airway smooth muscle cells"
CONTEXT = "asthma"
N_TOP_GENES = 10
DEA_PROMPT = """
Determine the effect of {treatment} on {mechanism} in {context}. 

Perform differential expression analysis and pathway analysis on relevant comparison groups. Map all gene IDs to gene symbols using annotation package such as ‘org.Hs.eg.db’.

Generate volcano plots and heatmap of differentially expressed genes, and dot plots for enriched pathways, use gene symbols for labels where relevant.

Output a single csv file named "dea_results.csv"  with the results for all tested genes of the most relevant contrast, report both gene ID and gene symbol.

If there is an error, keep trying, do not give up until you reach the end of the analysis. When mapping gene ID to gene symbol, consider all possible forms of gene IDs, keep trying until the gene symbols are obtained.
"""

CONSENSUS_PROMPT = f"""
Combine these differential expression analysis results by calculating the mode of log2FC and adjusted p values. Output the results in a file named ‘consensus_results.csv’, include the columns gene_symbol, log2FC and adjusted P values. In a separate file named ‘top_genes.csv’, output the top {N_TOP_GENES} gene symbols of the consensus most significant genes with the column name “gene_symbol”. 

Create a stacked bar plot showing gene regulation consistency across all analyses. Plot regulation direction (up vs down) on x-axis and percentage of genes in each category on y-axis. Color-code by significance category: all analyses, >50% of analyses and  <50% of analyses. Include percentages within each segment and a clear legend. Exclude genes that are non-significant across all analyses.
"""

PQA_PROMPT = """
    What are the possible mechanisms for {gene} in the effect of {treatment} on {mechanism} in {context}?
    From 1 to 5, with 1 being no evidence of association at all and 5 being strong association with supporting evidence, how strong is the evidence supporting this mechanism?
    Give a concise summary for the evidence in up to 10 words, and a short summary of mechanisms in up to 20 words. Do not include references or links.
    Please share this information in json format in the form of: `"gene_symbol": <gene_symbol>, "association_evidence_score":[1...5], "evidence_summary": <evidence_summary>, "mechanism_summary": <mechanism_summary>`.
    Share nothing else but the JSON output.
    """

VOLCANO_PROMPT = """
Make an interactive volcano plot. Colour-code by significance categories: top up-regulated genes, up-regulated genes, top down-regulated genes, down-regulated genes, and non-significant genes. Genes considered as top differentially expressed genes have extra annotation available in 'pqa_results.csv'.

Include hover information according to the categories, for the top genes, on hover, show gene symbol, log2FC, adjusted p value, mechanism, evidence and evidence score. For up and down regulated genes that are not in top differentially expressed genes, show gene symbol, log2FC and adjusted p value. For non-significant genes, do not include hover information.

For the annotations, remove all text in the brackets in the summary columns, and remove the fullstop at the end. For annotations with 6 words or more in a line, use text-wrap. Don’t include text on the plot itself. Include a legend explaining the color-codes.

PLEASE USE TEXT WRAP FOR THE HOVER INFORMATION!
"""


def augment_query(query, language):
    guidelines = prompts.GENERAL_NOTEBOOK_GUIDELINES.format(language=language)
    if language == "R":
        guidelines = prompts.R_SPECIFIC_GUIDELINES.format(language=language)
    return (
        f"{prompts.CHAIN_OF_THOUGHT_AGNOSTIC.format(language=language)}\n"
        f"{guidelines}"
        f"Here is the research question to address:\n"
        f"<query>\n"
        f"{query}\n"
        f"</query>\n"
    )
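
The prompts above use two templating styles: `CONSENSUS_PROMPT` is an f-string, so `{N_TOP_GENES}` is filled in as soon as the cell runs, while the other prompts keep their placeholders until `.format()` is called. The sketch below illustrates the difference with toy strings; the names `N_TOP`, `eager_prompt`, `lazy_prompt` and the `GENE_A`/`GENE_B` symbols are placeholders invented for this example.

```python
# Toy stand-ins for the tutorial's prompt constants (illustrative values only).
N_TOP = 3

# f-string: {N_TOP} is substituted immediately, when this line runs.
eager_prompt = f"Report the top {N_TOP} genes."

# Plain string: {gene} and {context} stay placeholders until .format() is called.
lazy_prompt = "What is the role of {gene} in {context}?"

# Changing the variable afterwards does not affect the already-built f-string.
N_TOP = 10

# One template can be reused for many genes.
filled = [lazy_prompt.format(gene=g, context="asthma") for g in ("GENE_A", "GENE_B")]

# Literal braces in a .format() template must be doubled.
json_hint = 'Reply as {{"gene_symbol": "{gene}"}}'.format(gene="GENE_A")

print(eager_prompt)
print(filled)
print(json_hint)
```

In practice this means `N_TOP_GENES` has to be set before `CONSENSUS_PROMPT` is defined, whereas `TREATMENT`, `MECHANISM` and `CONTEXT` can be changed at any point before the `.format()` call.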

In [None]:
# Here we instantiate the FutureHouse client and define the job names
FH_API_KEY = ""  # Add your API key here
# We will be creating three folders in GCS to store the results of the three steps
DEA_UPLOAD_ID = f"consensus_tutorial_dea_{str(uuid.uuid4())[:8]}"
CONSENSUS_UPLOAD_ID = f"consensus_tutorial_consensus_{str(uuid.uuid4())[:8]}"
PQA_UPLOAD_ID = f"consensus_tutorial_pqa_{str(uuid.uuid4())[:8]}"
INITIAL_RNASEQ_FILE = "datasets/GSE52778_All_Sample_FPKM_Matrix.txt.gz"
client = FutureHouseClient(
    auth_type=AuthType.API_KEY,
    api_key=FH_API_KEY,
)

In [None]:
# First let's upload the dataset to GCS and check the files were uploaded correctly
client.upload_file(
    JobNames.FINCH, file_path=INITIAL_RNASEQ_FILE, upload_id=DEA_UPLOAD_ID
)
# Check what files were uploaded to your GCS folder
client.list_files(JobNames.FINCH, upload_id=DEA_UPLOAD_ID)

In [None]:
# Now let's run 5 Finch DEA tasks in parallel
NUM_DEA_TASKS = 5
TIMEOUT = 15 * 60
runtime_config = RuntimeConfig(
    max_steps=30,
    upload_id=DEA_UPLOAD_ID,
    environment_config={
        "default_cot_prompt": False,
        "language": "R",
    },
)
task_request = TaskRequest(
    name=JobNames.FINCH,
    query=augment_query(
        DEA_PROMPT.format(treatment=TREATMENT, mechanism=MECHANISM, context=CONTEXT),
        "R",
    ),
    runtime_config=runtime_config,
)
dea_completed_tasks = await client.arun_tasks_until_done(
    [task_request for _ in range(NUM_DEA_TASKS)], progress_bar=True, timeout=TIMEOUT
)
dea_task_ids = [str(task.task_id) for task in dea_completed_tasks]
success = sum([task.status == "success" for task in dea_completed_tasks])
print(f"Task success rate: {success / NUM_DEA_TASKS * 100}%")
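
If some runs fail, it can help to separate their IDs before downloading results. The sketch below shows one way to do that. It uses `SimpleNamespace` stand-ins that model only the two attributes read above (`task_id` and `status`); the helper `split_by_status` and the `"fail"` status value are assumptions made for illustration, not part of the client.

```python
from types import SimpleNamespace

# Stand-ins for the task responses; the IDs and the "fail" status are made up.
completed = [
    SimpleNamespace(task_id="id-0", status="success"),
    SimpleNamespace(task_id="id-1", status="fail"),
    SimpleNamespace(task_id="id-2", status="success"),
]


def split_by_status(tasks, ok_status="success"):
    """Return (successful_ids, other_ids) as two lists of strings."""
    ok, other = [], []
    for task in tasks:
        target = ok if task.status == ok_status else other
        target.append(str(task.task_id))
    return ok, other


ok_ids, failed_ids = split_by_status(completed)
print(f"{len(ok_ids)}/{len(completed)} runs succeeded; not successful: {failed_ids}")
```

Passing only the successful IDs to the download loop would avoid requesting `dea_results.csv` from runs that never produced it.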

In [None]:
# The Finch runs should take between 3 and 10 minutes to complete.
# Once the runs have completed, let's download the results, upload them to a new folder in GCS, and run a consensus step
for c, task_id in enumerate(dea_task_ids):
    try:
        client.download_file(
            JobNames.FINCH,
            trajectory_id=task_id,
            file_path="dea_results.csv",
            destination_path=f"output/dea_results/dea_results_{c}.csv",
        )
    except Exception as e:
        print(f"Error downloading task results for task {task_id}: {e}")

# Now let's upload the whole directory of DEA results to GCS
client.upload_file(
    JobNames.FINCH, file_path="output/dea_results", upload_id=CONSENSUS_UPLOAD_ID
)

print("These files have been uploaded to GCS:")
print(client.list_files(JobNames.FINCH, upload_id=CONSENSUS_UPLOAD_ID))

In [None]:
# Now let's run a single consensus step
runtime_config = RuntimeConfig(
    max_steps=30,
    upload_id=CONSENSUS_UPLOAD_ID,
    environment_config={
        "default_cot_prompt": False,
        "language": "R",
    },
)
consensus_task_request = TaskRequest(
    name=JobNames.FINCH,
    query=augment_query(CONSENSUS_PROMPT, "R"),
    runtime_config=runtime_config,
)
consensus_task_response = client.run_tasks_until_done(
    [consensus_task_request], progress_bar=True, timeout=TIMEOUT
)
consensus_task_id = consensus_task_response[0].task_id
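
To make the consistency categories in `CONSENSUS_PROMPT` concrete, here is a rough local sketch of how genes could be grouped by how often they are significant across runs. It is not what Finch actually executes. The 0.05 cut-off, the toy p-values, the placeholder gene names, and the choice to put an exact 50% split in the lower bucket are all assumptions.

```python
ALPHA = 0.05  # assumed significance cut-off; the prompt does not fix one

# Toy adjusted p-values: one dict per DEA run, keyed by placeholder gene names.
runs = [
    {"GENE_A": 0.001, "GENE_B": 0.20, "GENE_C": 0.03, "GENE_D": 0.60},
    {"GENE_A": 0.004, "GENE_B": 0.01, "GENE_C": 0.04, "GENE_D": 0.70},
    {"GENE_A": 0.002, "GENE_B": 0.30, "GENE_C": 0.20, "GENE_D": 0.90},
]


def consistency(runs, alpha=ALPHA):
    """Map each gene to a category based on the share of runs where it is significant."""
    genes = sorted(set().union(*runs))
    labels = {}
    for gene in genes:
        # A gene missing from a run counts as not significant in that run.
        hits = sum(1 for run in runs if run.get(gene, 1.0) < alpha)
        if hits == 0:
            continue  # non-significant everywhere: excluded, as the prompt asks
        if hits == len(runs):
            labels[gene] = "all analyses"
        elif hits / len(runs) > 0.5:
            labels[gene] = ">50% of analyses"
        else:
            # Includes an exact 50% split, which the prompt leaves undefined.
            labels[gene] = "<50% of analyses"
    return labels


print(consistency(runs))
```

With an odd number of runs, such as the 5 used here, an exact 50% split cannot occur.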

In [None]:
# Once the consensus step is done, let's download the results
client.download_file(
    JobNames.FINCH,
    trajectory_id=consensus_task_id,
    file_path="consensus_results.csv",
    destination_path="output/consensus_results.csv",
)
client.download_file(
    JobNames.FINCH,
    trajectory_id=consensus_task_id,
    file_path="top_genes.csv",
    destination_path="output/top_genes.csv",
)

In [None]:
# Let's use PaperQA to give us a summary of each gene
top_genes_df = pd.read_csv("output/top_genes.csv")
display(top_genes_df.head())
gene_symbols = top_genes_df["gene_symbol"].tolist()
pqa_tasks = [
    {
        "name": JobNames.CROW,
        "query": PQA_PROMPT.format(
            gene=gene, treatment=TREATMENT, mechanism=MECHANISM, context=CONTEXT
        ),
    }
    for gene in gene_symbols
]
pqa_task_list = await client.arun_tasks_until_done(
    pqa_tasks, progress_bar=True, timeout=TIMEOUT, verbose=True
)

In [None]:
# When the PQA tasks are done, parse the answers and save them to CSV

answer_list = []
for task_response in pqa_task_list:
    try:
        answer = json.loads(
            task_response.environment_frame["state"]["state"]["response"]["answer"][
                "answer"
            ]
        )
        if isinstance(answer, list):
            answer = answer[0]
        answer_list.append(answer)
    except Exception as e:
        print(f"Error parsing answer for task {task_response.task_id}: {e}")

pqa_df = pd.DataFrame(answer_list)
pqa_df.to_csv("output/pqa_results.csv", index=False)
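
Language models sometimes wrap JSON in extra text, so a slightly more defensive parser may be useful before building the DataFrame. The following is a minimal sketch under that assumption. The helper `parse_answer` and the example reply are hypothetical; only the four key names come from `PQA_PROMPT`.

```python
import json

# Keys requested in PQA_PROMPT.
REQUIRED_KEYS = {
    "gene_symbol",
    "association_evidence_score",
    "evidence_summary",
    "mechanism_summary",
}


def parse_answer(text):
    """Extract a single JSON object from a model reply, or return None."""
    # Keep only the span from the first "{" to the last "}".
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end <= start:
        return None
    try:
        record = json.loads(text[start : end + 1])
    except json.JSONDecodeError:
        return None  # also covers replies containing several objects
    if not isinstance(record, dict) or not REQUIRED_KEYS.issubset(record):
        return None
    # Normalise the score to an int and enforce the 1 to 5 range from the prompt.
    try:
        score = int(record["association_evidence_score"])
    except (TypeError, ValueError):
        return None
    if not 1 <= score <= 5:
        return None
    record["association_evidence_score"] = score
    return record


reply = (
    'Sure, here is the JSON: {"gene_symbol": "GENE_A", '
    '"association_evidence_score": "4", "evidence_summary": "toy", '
    '"mechanism_summary": "toy"} Hope that helps.'
)
print(parse_answer(reply))
print(parse_answer("no json here"))
```

Returning `None` instead of raising keeps the loop simple: rows that fail to parse can be skipped and logged with the task ID.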

In [None]:
# Finally, let's create a beautiful interactive Plotly plot that brings all the results together
# First, upload the PQA and consensus results so Finch can read them
client.upload_file(
    JobNames.FINCH, file_path="output/pqa_results.csv", upload_id=PQA_UPLOAD_ID
)
client.upload_file(
    JobNames.FINCH, file_path="output/consensus_results.csv", upload_id=PQA_UPLOAD_ID
)
runtime_config = RuntimeConfig(
    max_steps=30,
    upload_id=PQA_UPLOAD_ID,
    environment_config={
        "default_cot_prompt": False,
        "language": "PYTHON",
    },
)
volcano_task_request = TaskRequest(
    name=JobNames.FINCH,
    query=augment_query(VOLCANO_PROMPT, "PYTHON"),
    runtime_config=runtime_config,
)
volcano_task_id = client.create_task(volcano_task_request)

print(
    f"Task running on platform, you can view progress live for our final results at: https://platform.futurehouse.org/trajectories/{volcano_task_id}"
)
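
For readers who want to sanity-check the plot, the five categories in `VOLCANO_PROMPT` can be approximated locally. This is a sketch, not the logic Finch will necessarily choose. The cut-offs (`|log2FC| >= 1`, adjusted p < 0.05), the toy rows, and the function `categorize` are assumptions for illustration.

```python
import math

LFC_CUTOFF = 1.0    # assumed effect-size threshold
PADJ_CUTOFF = 0.05  # assumed significance threshold


def categorize(log2fc, padj, is_top):
    """Assign one of the five volcano-plot categories to a gene."""
    if padj >= PADJ_CUTOFF or abs(log2fc) < LFC_CUTOFF:
        return "non-significant"
    direction = "up" if log2fc > 0 else "down"
    prefix = "top " if is_top else ""
    return f"{prefix}{direction}-regulated"


# Toy rows: (gene symbol, log2FC, adjusted p-value); names and values are made up.
rows = [("GENE_A", 2.5, 1e-8), ("GENE_B", -1.8, 1e-4), ("GENE_C", 0.3, 0.5)]
top_genes = {"GENE_A"}  # would come from top_genes.csv

points = [
    {
        "gene": gene,
        "x": lfc,
        # Clamp p-values to avoid log10(0) for genes reported with padj == 0.
        "y": -math.log10(max(padj, 1e-300)),
        "category": categorize(lfc, padj, gene in top_genes),
    }
    for gene, lfc, padj in rows
]

for point in points:
    print(point)
```

The `x` and `y` values are the usual volcano-plot coordinates: log2 fold change against the negative log10 of the adjusted p-value.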

The final trajectory contains the consensus DEA results as an interactive Plotly volcano plot showing the top differentially expressed genes, their evidence, and their evidence scores. All in about 20 minutes!