In [24]:
import os
import ast
import re

# Directory whose *.py tool modules are scanned for their get_info() schema.
# NOTE(review): this is an absolute, machine-specific Windows path, while the data
# cell further down uses the relative path 'envs/hr_payroll/data' — consider
# deriving both from one configurable project root.
TOOLS_DIR = r"e:\turing-tau\amazon-tau-bench-tasks\envs\hr_payroll\tools\interface_2"

def extract_get_info_dict_from_code(code):
    """Return the literal dict returned by a ``get_info`` function in ``code``.

    The source is parsed (never executed). Every ``return`` statement found
    inside a function named ``get_info`` is tried in turn, and the first one
    whose value is a pure literal evaluating to a ``dict`` is returned.

    Args:
        code: Python source text of a tool module.

    Returns:
        The dict literal returned by ``get_info``, or ``{}`` when the source
        cannot be parsed, no ``get_info`` exists, or none of its returns is a
        literal dict.
    """
    try:
        tree = ast.parse(code)
    except Exception as e:
        print(f"Error parsing AST: {e}")
        return {}

    last_error = None
    for node in ast.walk(tree):
        is_function = isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
        if not is_function or node.name != "get_info":
            continue
        for child in ast.walk(node):
            # A bare ``return`` carries no value to evaluate.
            if not isinstance(child, ast.Return) or child.value is None:
                continue
            try:
                value = ast.literal_eval(child.value)
            except Exception as e:
                # ast.walk is breadth-first, so a shallow non-literal return
                # (e.g. ``return build()``) may be visited before the literal
                # dict we want; remember the error and keep looking.
                last_error = e
                continue
            if isinstance(value, dict):
                return value

    # Only report a literal_eval failure if nothing usable was found.
    if last_error is not None:
        print(f"Error parsing AST: {last_error}")
    return {}

def extract_required_fields(parsed_dict):
    """Return the tool's required parameter names.

    Reads ``parsed_dict["function"]["parameters"]``. A non-empty ``required``
    entry is returned as-is; otherwise the keys of ``properties`` are returned
    as a list. Any error while reading the structure yields ``[]``.
    """
    try:
        function_spec = parsed_dict.get("function", {})
        parameters = function_spec.get("parameters", {})
        # Fall back to every declared property when nothing is marked required.
        return parameters.get("required", []) or list(parameters.get("properties", {}).keys())
    except Exception:
        return []

# Collect required fields for each tool: one entry per *.py module in TOOLS_DIR
# (excluding __init__.py), keyed by the module's file name without extension.
tool_required_fields = {}

for file in os.listdir(TOOLS_DIR):
    if not file.endswith('.py') or file == '__init__.py':
        continue

    file_path = os.path.join(TOOLS_DIR, file)
    # splitext drops only the trailing extension; str.replace('.py', '') would
    # also remove any '.py' occurring elsewhere in the file name.
    tool_name = os.path.splitext(file)[0]

    with open(file_path, "r", encoding="utf-8") as f:
        code = f.read()

    # Parsing works on the in-memory text, so the file is already closed here.
    get_info_dict = extract_get_info_dict_from_code(code)
    required_fields = extract_required_fields(get_info_dict)
    tool_required_fields[tool_name] = required_fields

tool_required_fields
# print("\n🧾 Required fields for each tool:")
# for tool, fields in tool_required_fields.items():
#     print(f"{tool}: {fields}")

{'approve_overtime_entry': ['time_entry_id'],
 'block_suspicious_payment': ['payment_id'],
 'block_virtual_card': ['card_id'],
 'check_user_virtual_cards': ['user_id'],
 'create_new_contract': ['worker_id', 'terms'],
 'create_payment': ['user_id', 'invoice_id', 'amount', 'currency'],
 'extend_contract_period': ['contract_id', 'new_end_date'],
 'fetch_time_summary_by_team': ['team_id'],
 'find_user': ['user_id',
  'email',
  'first_name',
  'last_name',
  'role',
  'status',
  'locale',
  'timezone'],
 'get_contracts': ['contract_id',
  'user_id',
  'worker_id',
  'status',
  'currency',
  'organization_id',
  'rate_type',
  'document_id',
  'min_rate',
  'max_rate',
  'start_date_from',
  'start_date_to',
  'end_date_from',
  'end_date_to'],
 'get_documents': ['document_id',
  'user_id',
  'worker_id',
  'title',
  'file_type',
  'status'],
 'get_payments': ['payment_id',
  'user_id',
  'invoice_id',
  'status',
  'currency',
  'processed_at',
  'min_amount',
  'max_amount'],
 'get_pay

In [25]:
# Numeric menu IDs used by the interactive cell: ID 0 is the pseudo-source
# "instruction", IDs 1-18 are tool names, numbered by position in this list.
id_map = dict(enumerate([
    "instruction",
    "approve_overtime_entry",
    "block_suspicious_payment",
    "block_virtual_card",
    "check_user_virtual_cards",
    "create_new_contract",
    "create_payment",
    "extend_contract_period",
    "fetch_time_summary_by_team",
    "find_user",
    "get_contracts",
    "get_documents",
    "get_payments",
    "get_payroll_run_details",
    "get_pending_reimbursements",
    "retrieve_worker_contracts_with_organization",
    "start_new_engagement",
    "update_document_status",
    "upload_document",
]))

In [28]:
import os
import json


# Directory (relative to the current working directory) holding one JSON file per table.
data_dir = 'envs/hr_payroll/data'

# Maps each file's base name (without ".json") to its parsed JSON content.
json_data = {}

for filename in os.listdir(data_dir):
    if filename.endswith('.json'):
        file_path = os.path.join(data_dir, filename)
        # Read as UTF-8 explicitly (as the tool-scanning cell does); without it
        # open() falls back to the platform's locale encoding.
        with open(file_path, 'r', encoding='utf-8') as file:
            # splitext removes only the trailing extension, unlike
            # str.replace('.json', '') which removes every occurrence.
            key_name = os.path.splitext(filename)[0]
            json_data[key_name] = json.load(file)

print("Imported JSON files:", list(json_data.keys()))

# Access specific data like this:
# json_data['users'], json_data['contracts'], etc.


def _first_record_fields(table_name):
    """Return the field names of the first record in ``json_data[table_name]``."""
    first_record = list(json_data[table_name].values())[0]
    return list(first_record.keys())


def output_manupulation(tool_id):
    """Return the list of output field names offered for the tool with ``tool_id``.

    ``tool_id`` is one of the numeric IDs 1-18 (the tool keys of ``id_map``).
    For most tools the result is the field names of the first record of a
    table in the module-level ``json_data``; for three tools one position of
    that list is overwritten with a fixed label; for the rest the list is
    hard-coded. Any other ``tool_id`` yields ``None``.
    """
    # (tool id, table, position to overwrite, replacement label)
    relabelled = (
        (1, "time_entries", -2, "status_approved"),
        (2, "payments", 3, "blocked_status"),
        (3, "virtual_cards", -1, "blocked_status"),
    )
    for known_id, table_name, position, label in relabelled:
        if tool_id == known_id:
            fields = _first_record_fields(table_name)
            fields[position] = label
            return fields

    # (tool id, table) pairs whose output is the table's field names unchanged
    passthrough = (
        (4, "virtual_cards"),
        (7, "contracts"),
        (9, "users"),
        (10, "contracts"),
        (11, "documents"),
        (12, "payments"),
        (13, "payroll_runs"),
        (14, "reimbursements"),
        (15, "contracts"),
        (16, "workers"),
        (17, "documents"),
    )
    for known_id, table_name in passthrough:
        if tool_id == known_id:
            return _first_record_fields(table_name)

    # Hard-coded outputs; a fresh list is built on every call.
    if tool_id == 5:
        contract_terms = ["start_date", "end_date", "rate", "rate_type", "currency"]
        return ["worker_id", contract_terms]

    if tool_id == 6:
        return ["invoice_id", "amount", "currency", "status", "processed_at", "user_id"]

    if tool_id == 8:
        entry_fields = ["worker_id", "duration_hours", "date"]
        return ["team_id", entry_fields, "daily_totals"]

    if tool_id == 18:
        return ["user_id", "worker_id", "title", "file_type", "status"]

    return None

Imported JSON files: ['bank_accounts', 'contracts', 'documents', 'financial_providers', 'invoices', 'organizations', 'org_departments', 'payments', 'payroll_items', 'payroll_runs', 'reimbursements', 'teams', 'team_members', 'time_entries', 'users', 'virtual_cards', 'workers']


In [30]:
# Connection entries appended by the interactive loop in this cell; re-running
# the cell starts again from an empty list.
all_data = []

def select_fields(fields, field_type="input"):
    """Interactively pick a subset of ``fields`` by index.

    Prints the fields with their indices, reads one line of comma-separated
    indices from the user, and returns the chosen fields in the order typed.

    Args:
        fields: Sequence of candidate fields to choose from.
        field_type: Label used in the prompts (e.g. "input" or "output").

    Returns:
        List of the selected fields. Tokens that are not non-negative decimal
        integers are ignored; indices past the end of ``fields`` are reported
        and skipped rather than raising IndexError.
    """
    print(f"\nAvailable {field_type} fields:")
    for i, field in enumerate(fields):
        print(f"{i}: {field}")
    indices = input(f"Select {field_type} fields (comma-separated indices): ")

    selected = []
    for token in indices.split(","):
        token = token.strip()
        # isdecimal() (unlike isdigit()) only accepts characters int() can
        # parse, so inputs such as "²" are ignored instead of raising ValueError.
        if not token.isdecimal():
            continue
        position = int(token)
        if position >= len(fields):
            print(f"Ignoring out-of-range index: {position}")
            continue
        selected.append(fields[position])
    return selected

# Interactive loop: each pass asks for a 'from' ID and a 'to' ID (keys of id_map),
# lets the user pick fields, and appends one connection entry to all_data.
while True:
    # Re-prompt instead of crashing on non-numeric input.
    try:
        user_input_for_from = int(input("Enter ID for 'from': "))
        user_input_for_to = int(input("Enter ID for 'to': "))
    except ValueError:
        print("IDs must be whole numbers; please try again.")
        continue

    if user_input_for_from not in id_map or user_input_for_to not in id_map:
        print(f"Unknown ID; valid IDs are: {sorted(id_map)}")
        continue

    # Direct dict lookups replace the linear scans over .items().
    from_values = id_map[user_input_for_from]
    to_values = id_map[user_input_for_to]

    # The 'to' side must be a tool with known required fields
    # (e.g. "instruction" is only valid as a source).
    if to_values not in tool_required_fields:
        print(f"'{to_values}' cannot be used as a 'to' target; please try again.")
        continue

    if user_input_for_from != 0:
        # Tool -> tool: choose among the target tool's required input fields.
        selected_fields = select_fields(tool_required_fields[to_values], "input")
    else:
        # Instruction -> tool: choose among the target tool's output fields.
        selected_fields = select_fields(output_manupulation(user_input_for_to), "output")

    # Both sides of the connection record the same selected field names.
    data_entry = {
        "from": from_values,
        "to": to_values,
        "connection": {
            "input": selected_fields,
            "output": selected_fields
        }
    }

    all_data.append(data_entry)

    user_choice = input("Do you want to add another entry? (y/n): ").lower()
    if user_choice != 'y':
        break

# Optional: Print the collected data
import json
print(json.dumps(all_data, indent=4))



Available output fields:
0: user_id
1: organization_id
2: worker_type
3: status
[
    {
        "from": "instruction",
        "to": "start_new_engagement",
        "connection": {
            "input": [
                "worker_type",
                "status"
            ],
            "output": [
                "worker_type",
                "status"
            ]
        }
    }
]


In [17]:
# NOTE(review): this cell's execution count (17) predates the cell that builds
# all_data (30), so the value displayed below may be stale — re-run after the loop.
len(all_data)

3