In [82]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("/root/ProtAgent")
import os
import yaml
import random
import json

from agent.utils.others import setup_seed
from agent.tools.tool_manager import ToolManager
from openai import OpenAI
from tqdm import tqdm
from agent.agent.multi_agent_backbone import ProteinReActAgent

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [83]:
tool_manager = ToolManager(enable_quick_run=True)
tool_manager.set_out_dir("/home/public/ProtAgent/examples")


# Load all available tools

In [87]:
tool_dir = "/root/ProtAgent/agent/tools"
tools = {}
input_type2tool = {}
output_type2tool = {}

exclude_list = [
    "protrek_text2text",
    "evolla",
]

for tool_name, tool_cls in tool_manager.tools.items():
    config = tool_cls.config
    
    # Load tool document
    docs = config["document"]
    if isinstance(docs, dict):
        docs = [docs]
    
    for tool_doc in docs:
        if tool_doc["tool_name"] in exclude_list:
            continue
        
        if tool_doc["tool_name"] not in tool_manager.tools.keys():
            continue
        
        try:
            # Record input types
            input_type_cnt = {}
            input_types = []
            for param in tool_doc["required_parameters"]:
                detailed_type = param["detailed_type"]
                # Count the number of the same type
                input_type_cnt[detailed_type] = input_type_cnt.get(detailed_type, 0) + 1
                input_types.append(f"{detailed_type}_{input_type_cnt[detailed_type]}")
                
            # Record output types    
            output_type_cnt = {}
            output_types = []
            for param in tool_doc["return_values"]:
                detailed_type = param["detailed_type"]
                # Count the number of the same type
                output_type_cnt[detailed_type] = output_type_cnt.get(detailed_type, 0) + 1
                output_types.append(f"{detailed_type}_{output_type_cnt[detailed_type]}")
            
            tools[tool_doc["tool_name"]] = {
                "input_types": input_types,
                "output_types": output_types,
                "description": tool_doc["tool_description"],
            }
            
            for input_type in input_types:
                if input_type not in input_type2tool:
                    input_type2tool[input_type] = []
                input_type2tool[input_type].append(tool_doc["tool_name"])
            
            for output_type in output_types:
                if output_type not in output_type2tool:
                    output_type2tool[output_type] = []
                output_type2tool[output_type].append(tool_doc["tool_name"])
        
        except Exception as e:
            # print(dir_name, e)
            print(e)
            pass

In [88]:
for input_type in input_type2tool.keys():
    print(input_type)

AA_SEQUENCE_1
TEXT_1
FASTA_FILE_LIST_1
FASTA_PATH_1
AA_SEQUENCE_2
STRUCTURE_PATH_1
AA_POSITION_1
AA_POSITION_2
CHAIN_ID_1
CLUSTALW_ALN_PATH_1
HMMER_HMM_PATH_1
HHSUITE_HMM_PATH_1
HHSUITE_HMM_PATH_2
HHSUITE_A3M_PATH_1
HHSUITE_A3M_PATH_2
PDB_ID_1
PFAM_ID_1
UNIPROT_ID_1
FOLDSEEK_SEQUENCE_1
RFDIFFUSION_CONTIGS_1
LABEL_NUM_1
ADAPTOR_DIRECTORY_1
MUTATION_INFO_1
TRAINING_DATASET_1
STRUCTURE_PATH_2
SMILES_1
UNIPROT_KEYWORD_1


In [91]:

type2case = {
    "HHSUITE_A3M_PATH_1": "example_1.a3m",
    "FASTA_PATH_1": "example_1.fasta",
    "FASTA_FILE_LIST_1": ["class_0.fasta", "class_1.fasta"],
    "AA_SEQUENCE_1": "MSATAEQNARNPKGKGGFARTVSQRKRKRLFLIGGALAVLAVAVGLMLTAFNQDIRFFRTPADLTEQDMTSGARFRLGGLVEEGSVSRTGSELRFTVTDTIKTVKVVFEGIPPDLFREGQGVVAEGRFGSDGLFRADNVLAKHDENYVPKDLADSLKKKGVWEGK",
    "AA_SEQUENCE_2": "MITLDWEKANGLITTVVQDATTKQVLMVAYMNQESLAKTMATGETWFWSRSRKTLWHKGATSGNIQTVKTIAVDCDADTLLVTVDPAGPACHTGHISCFYRHYPEGKDLT",
    "STRUCTURE_PATH_1": "example_1.pdb",
    "STRUCTURE_PATH_2": "example_2.pdb",
    "FOLDSEEK_SEQUENCE_1": "dddddddddddddddpdpppvcppvnvvvvvvvvvvvvvvvvvvvvvvvvvvqdpqdedeqvrddpcqqpvqhkhkykafwappqwdddpqkiwtwghnppgiaieieghdappqddhrfikifiaghdpvrhtygdhidtdddpddddvvnvvvcvvvvndpdd",
    "TEXT_1": "Catalyzes the hydrolysis of cutin, a polyester that forms the structure of plant cuticle.",
    "RFDIFFUSION_CONTIGS_1": "A:50",
    "MUTATION_INFO_1": "A123B:C124D",
    "SMILES_1": "CC(=O)OC1=C(C(=C(C=C1)C(=O)O)C(=O)O)C(=O)O",
    "PDB_ID_1": "1A2B",
    "HHSUITE_HMM_PATH_1": "example.hmm",
    "UNIPROT_ID_1": "P05067",
    "PFAM_ID_1": "PF00085",
    "UNIPROT_KEYWORD_1": "cutinase",
    "AA_POSITION_1": "5",
    "AA_POSITION_2": "10",
    "CHAIN_ID_1": "A",
    "CLUSTALW_ALN_PATH_1": "result.aln",
    "HMMER_HMM_PATH_1": "example_hmmer.hmm",
    "HHSUITE_HMM_PATH_2": "example_hhsuite.hmm",
    "HHSUITE_A3M_PATH_2": "example_2.a3m",
    "LABEL_NUM_1": "2",
    "ADAPTOR_DIRECTORY_1": "adaptor_directory",
    "TRAINING_DATASET_1": "experiment_results.csv"
}

In [92]:
for type2case_key in type2case.keys():
    if type2case_key not in input_type2tool.keys():
        print(type2case_key)

for input_type in input_type2tool.keys():
    if input_type not in type2case.keys():
        print(input_type)

# Sample tool calling trajectories

In [95]:
setup_seed(250728)

# Decide the number of steps
step = 6
num = 10
cases = {}

while len(cases) < num:
    # Randomly sample the initial tool
    init_tool = random.choice(list(tools.keys()))
    # init_tool = "uniprot_query"
    
    tool_chain = [init_tool]
    # Inputs that the user has to provide
    required_inputs = set(tools[init_tool]["input_types"])
    # Inputs that later steps can obtain
    available_inputs = required_inputs.union(set(tools[init_tool]["output_types"]))
    
    for i in range(step-1):
        # Obtain available tools for the next step
        output_types = tools[tool_chain[-1]]["output_types"]
        next_tools = []
        for output_type in output_types:
            if output_type in input_type2tool:
                next_tools.extend(input_type2tool[output_type])
        
        # If no next tool is available, break the loop
        if next_tools == []:
            break
        
        # Randomly sample the next tool (can be the same as the previous one)
        while True:
            next_tool = random.choice(next_tools)
            if next_tool not in tool_chain:
                break
        tool_chain.append(next_tool)
        
        # Update the required inputs and available inputs
        new_required_inputs = set(tools[next_tool]["input_types"]) - available_inputs
        required_inputs = required_inputs.union(new_required_inputs)
        available_inputs = available_inputs.union(set(tools[next_tool]["output_types"]))
    
    # If some required inputs are not provided, skip this tool chain
    if required_inputs - type2case.keys() != set():
        continue
    
    else:
        tool_order = " -> ".join(tool_chain)
        cases[tool_order] = {
            "tool_chain": tool_chain,
            "required_inputs": required_inputs,
        }


print(cases)

{'seq2fasta -> hmmsearch -> hhblits -> hhalign_msa': {'tool_chain': ['seq2fasta', 'hmmsearch', 'hhblits', 'hhalign_msa'], 'required_inputs': {'HHSUITE_A3M_PATH_2', 'HMMER_HMM_PATH_1', 'AA_SEQUENCE_1'}}, 'esmfold -> rfdiffusion_motif_scaffolding -> pdb2aaseq -> deepab -> diffab_optimize': {'tool_chain': ['esmfold', 'rfdiffusion_motif_scaffolding', 'pdb2aaseq', 'deepab', 'diffab_optimize'], 'required_inputs': {'RFDIFFUSION_CONTIGS_1', 'AA_SEQUENCE_1', 'AA_SEQUENCE_2'}}, 'protrek_structure2protein -> seq2fasta -> hmmsearch -> fasta2seq -> saprot_tuned_inference_classification': {'tool_chain': ['protrek_structure2protein', 'seq2fasta', 'hmmsearch', 'fasta2seq', 'saprot_tuned_inference_classification'], 'required_inputs': {'LABEL_NUM_1', 'HMMER_HMM_PATH_1', 'FOLDSEEK_SEQUENCE_1', 'ADAPTOR_DIRECTORY_1'}}, 'biorxiv -> pinal -> saprot_tuned_inference_classification': {'tool_chain': ['biorxiv', 'pinal', 'saprot_tuned_inference_classification'], 'required_inputs': {'LABEL_NUM_1', 'TEXT_1', 'ADAP

In [None]:
cases.keys()

dict_keys(['protrek_structure2structure -> uniprot_fetch_byid -> saprot_mutation_bypos', 'proteinmpnn -> seq2fasta -> hmmsearch -> hhblits -> hhalign_msa', 'clustalw -> hmmbuild -> hmmsearch -> hmmscan -> hhblits -> hhfilter', 'clustalw -> hmmbuild -> hmmsearch -> hhblits -> hhalign_msa', 'umol -> pdb2aaseq -> saprot_mutation_bypos', 'uniprot_query -> hhblits -> hhalign_msa', 'saprot_tune_regression -> saprot_tuned_inference_classification', 'saprot_tuned_inference_token_classification', 'hmmsearch -> fasta2seq -> saprot_tuned_inference_regression', 'alphafold2 -> extract_peptide -> blast -> fasta2seq -> saprot_tuned_inference_classification', 'foldseek -> protrek_structure2structure -> protrek_structure2protein -> saprot_mutation_byinfo', 'saprot_tuned_inference_classification', 'protrek_text2protein -> umol -> pdb2aaseq -> esmfold -> foldseek -> protrek_structure2text', 'diffab_design', 'hhalign_msa', 'hmmscan -> clustalw -> hmmbuild -> hmmsearch -> hhblits -> hhfilter', 'tmalign', '

In [96]:
len(cases)

10

# Generate user query

In [97]:
# Record how can a input type be converted to another type by a tool
transfer_matrix = {}
for tool_name, obj in tool_manager.tools.items():
    doc = obj.config.document
    input_types = set([param["detailed_type"] for param in doc.required_parameters])
    output_types = set([param["detailed_type"] for param in doc.return_values])
    
    if len(input_types) == 1:
        input_type = input_types.pop()
        # Each output type can be generated given the input type and the tool
        if input_type not in transfer_matrix:
            transfer_matrix[input_type] = {output_type: [tool_name] for output_type in output_types}
        
        else:
            for output_type in output_types:
                transfer_matrix[input_type][output_type] = transfer_matrix[input_type].get(output_type, []) + [tool_name]
    

def find_path(input_type: str, output_type: str, exclusive: set = None) -> list:
    """
    Find the shortest paths from input_type to output_type in the transfer matrix.
    Args:
        input_type: Input type
        output_type: Output type

    Returns:
        A list of tool chains that can convert the input_type to output_type.
    """
    assert input_type != output_type, f"input type and output type are the same: {input_type}"
    
    if input_type not in transfer_matrix:
        return []
    
    # If the input type can be converted to the output type directly
    if output_type in transfer_matrix[input_type]:
        return [[tool] for tool in transfer_matrix[input_type][output_type]]
    
    else:
        if exclusive is None:
            exclusive = set()
        new_exclusive = exclusive.union({input_type})
        
        shortest_paths = []
        for available_output_type, tools in transfer_matrix[input_type].items():
            # If the output type is not in the exclusive list
            if available_output_type not in new_exclusive:
                # If the output type can be converted to the target output type
                paths = find_path(available_output_type, output_type, new_exclusive)
                if paths != []:
                    paths = [[tool] + path for tool in tools for path in paths]
                    shortest_paths.extend(paths)
        
        if shortest_paths:
            # Filter the paths to keep the shortest ones
            min_length = min([len(path) for path in shortest_paths])
            shortest_paths = [path for path in shortest_paths if len(path) == min_length]
        
        return shortest_paths


find_path("FASTA_PATH", "AA_SEQUENCE_LIST")

[]

In [98]:
def generate_query(tool_chain: list, required_inputs: set) -> str:
    prompt = """\
    You are a helpful AI assistant. Your task is to generate a reasonable user query based on a given tool chain so that this query could be answered after the tool chain is executed successfully.
    
    The description of each tool is provided below:
    PUT_TOOL_DESCRIPTION_HERE
    
    The order of the tools is provided below:
    PUT_TOOL_ORDER_HERE
    
    The inputs you should provide are:
    PUT_TOOL_INPUT_HERE
    
    You have to generate a query like a normal user. The query should be a natural language question or command that reflect what the user wants to do, usually regarding a real-world scenario. The query should be clear and concise. The query should contain the inputs you provided, not the keys but only the values. And you should not directly mention the tool names in your query.
    
    Example:
    1. Predict the structure of the sequence "AEGIKL", and the calculate the tmscore between this structure and "/example.pdb".
    
    2. Can you fetch the 3D structural data for the protein with UniProt ID P05067 and then align it with the structure in "/example_2.cif" to determine the TM-score?
    
    Now, please generate a user query based on the above information. Directly output the generated query.
    """
    tool_desc = tool_manager.brief_documents(tool_chain)
    
    tool_order = " -> ".join(tool_chain)
    user_input =  "\n".join([f"{key}: '{type2case[key]}'" for key in required_inputs])
    prompt = prompt.replace("PUT_TOOL_DESCRIPTION_HERE", tool_desc).replace("PUT_TOOL_ORDER_HERE", tool_order).replace("PUT_TOOL_INPUT_HERE", user_input)
    
    client = OpenAI(
        api_key="sk-QzZzg2fubuHu7DXQQnAKBSN2o1hpU4MipzgmoVpWpZC1ODxs",
        base_url="https://api.kwwai.top/v1"
    )
    
    response = client.chat.completions.create(
        model="claude-3-7-sonnet-20250219",  
        messages=[
            {"role": "user", "content": prompt},
        ],
        temperature=0
    )
    
    return response.choices[0].message.content

In [99]:
for key, case_dict in tqdm(cases.items()):
    query = generate_query(case_dict["tool_chain"], case_dict["required_inputs"])
    cases[key]["query"] = query

100%|██████████| 10/10 [00:53<00:00,  5.36s/it]


In [100]:
import csv

save_path = "/home/public/ProtAgent/agent_testset0729.tsv"
with open(save_path, "a", newline='', encoding='utf-8') as f:
    writer = csv.writer(f, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL)
    # writer.writerow([
    #     "tool_chain",
    #     "required_inputs",
    #     "query"
    # ])
    for key, case_dict in tqdm(cases.items()):
        writer.writerow([
            key,
            case_dict['required_inputs'],
            case_dict['query']
        ])

100%|██████████| 10/10 [00:00<00:00, 51025.60it/s]


In [101]:
import pandas as pd

df = pd.read_csv(save_path, sep='\t')

# Drop duplicated row[0]
df = df.drop_duplicates(subset=[df.columns[0]])

df.to_csv(f"{save_path}.unique.tsv", sep='\t', index=False)

In [65]:
origin_df = pd.read_csv("/home/public/ProtAgent/agent_testset0729.tsv", sep='\t')

def remove_condition(str):
    if "saprot_tune" in str:
        return True
    else:
        return False
# remove rows that column[0] contains "saprot_tune"
origin_df_filtered = origin_df[~origin_df.iloc[:, 0].apply(remove_condition)]
origin_df_filtered.to_csv("/home/public/ProtAgent/agent_testset0729.removed_saprot.tsv", sep='\t', index=False)

In [62]:
df_filtered = pd.read_csv("/home/public/ProtAgent/agent_testset0728.tsv.saprot.tsv")
df_filtered.drop_duplicates()
df_filtered.to_csv("/home/public/ProtAgent/agent_testset0728.tsv.saprot.unique.tsv", sep='\t', index=False)

In [72]:
new_chains = df_filtered.values.tolist()

In [78]:
def gen_case_by_chain(chain):
    # Split the chain by " -> "
    tool_chain = chain.split(" -> ")
    
    # Get the required inputs
    required_inputs = set()
    for tool in tool_chain:
        required_inputs = required_inputs.union(set(tools[tool]["input_types"]))
    
    # Get the query
    query = generate_query(tool_chain, required_inputs)
    
    return {
        "required_inputs": required_inputs,
        "query": query
    }

In [76]:
new_chains[0][0]

'saprot_tuned_inference_token_classification'

In [93]:
cases = {}
for chain in tqdm(new_chains):
    chain = chain[0]
    cases[chain] = gen_case_by_chain(chain)

  0%|          | 0/63 [00:00<?, ?it/s]

100%|██████████| 63/63 [05:19<00:00,  5.07s/it]
