# Env Dev: HF Dataset Traces from Snorkel

In [1]:
import os
from os.path import join

from omegaconf import OmegaConf

from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


## Agentic Finance Reasoning

In [2]:
# The environment settings are declared as a YAML string and parsed by OmegaConf.
# NOTE(review): `cache_dir` is an absolute, machine-specific path.
env_config = OmegaConf.create(
    """
name: "action_lm"
is_async: false
dataset_config:
  path: "snorkelai/agent-finance-reasoning"
  cache_dir: "/scr/mzhang/data/"
  split: "train"
"""
)

# Every key under `dataset_config` is forwarded to `load_dataset` as a keyword argument.
dataset_kwargs = env_config.dataset_config
ds = load_dataset(**dataset_kwargs)

Generating train split: 100%|██████████| 357/357 [00:00<00:00, 10261.63 examples/s]


In [3]:
ds

Dataset({
    features: ['id', 'model', 'user_query', 'system_prompt', 'company', 'company_10k_link', 'answer', 'correctness', 'correctness_rationale', 'completeness', 'completeness_rationale', 'trace'],
    num_rows: 357
})

In [8]:
print("Completeness: ", sum(ds["completeness"]) / len(ds))
print("Correctness: ", sum(ds["correctness"]) / len(ds))

Completeness:  0.9047619047619048
Correctness:  0.4677871148459384


In [9]:
ds = ds.filter(lambda x: x["correctness"])
ds

Filter: 100%|██████████| 357/357 [00:00<00:00, 23077.95 examples/s]


Dataset({
    features: ['id', 'model', 'user_query', 'system_prompt', 'company', 'company_10k_link', 'answer', 'correctness', 'correctness_rationale', 'completeness', 'completeness_rationale', 'trace'],
    num_rows: 167
})

In [10]:
for k in ds[0].keys():
    print(k)

id
model
user_query
system_prompt
company
company_10k_link
answer
correctness
correctness_rationale
completeness
completeness_rationale
trace


In [11]:
sample = ds[0]

print(sample["system_prompt"])

Only execute one tool call at a time. Think and reason step by step. Feel free to iteratively gather data as much as you need to answer the question. Although there is a limit of 50 turns, you can still gather data as much as you need to answer the question. When you have gathered all the data, answer the question. If                  

Once you have gathered all the data, then you can gradually start forming an answer. Once you have formed an answer, generate a 1 paragraph summary of it with all the relevant figures summarized.                   



In [17]:
sample["answer"]

"The interest cost ratio for postretirement benefits in 2024 is 4.6%, which indicates a relatively moderate interest burden on the company's postretirement obligations. This suggests that the company is managing its long-term benefit liabilities at a reasonable cost level, though this ratio should be monitored over time to ensure it remains sustainable as interest rates and benefit obligations evolve."

In [16]:
sample["user_query"]

"Here is the question : How significant are the company's postretirement benefit obligations in terms of interest burden, and what does this indicate about the company's long-term liability management in 2024?, Here are the companies name in the database to query for : at_t"

In [21]:
import json

messages = json.loads(sample["trace"])
messages

[{'content': "Here is the question : How significant are the company's postretirement benefit obligations in terms of interest burden, and what does this indicate about the company's long-term liability management in 2024?, Here are the companies name in the database to query for : at_t",
  'id': '2aafee6b-806f-456a-b0d8-a6baf7b6cb5b',
  'type': 'human',
  'additional_kwargs': {}},
 {'content': [{'text': "I'll help you analyze AT&T's postretirement benefit obligations and their interest burden to understand the company's long-term liability management in 2024. Let me start by exploring the available data.",
    'type': 'text'},
   {'id': 'toolu_01RZvNkP1xY2XmLrnnui8RUi',
    'input': {'company_name': 'at_t'},
    'name': 'get_descriptions',
    'type': 'tool_use'}],
  'id': 'run--4341b341-8c0b-4d4b-a537-272bc0140e9f-0',
  'type': 'ai',
  'additional_kwargs': {}},
 {'content': '["us_gaap_ScheduleOfComponentsOfIncomeTaxExpenseBenefitTableTextBlock", "us_gaap_ScheduleOfDebtTableTextBlock"

In [26]:
# Peek at the first ten messages of the trace: index, available keys, message type
# and the raw message itself.
for msg_idx, message in enumerate(messages[:10]):
    print(f"---- {msg_idx} ----")
    print(message.keys())
    # Single quotes inside the f-string: reusing the enclosing double quotes is only
    # accepted from Python 3.12 (PEP 701) and is a SyntaxError on earlier versions.
    print(f"message_type: {message['type']}")
    print(message)

---- 0 ----
dict_keys(['content', 'id', 'type', 'additional_kwargs'])
message_type: human
{'content': "Here is the question : How significant are the company's postretirement benefit obligations in terms of interest burden, and what does this indicate about the company's long-term liability management in 2024?, Here are the companies name in the database to query for : at_t", 'id': '2aafee6b-806f-456a-b0d8-a6baf7b6cb5b', 'type': 'human', 'additional_kwargs': {}}
---- 1 ----
dict_keys(['content', 'id', 'type', 'additional_kwargs'])
message_type: ai
{'content': [{'text': "I'll help you analyze AT&T's postretirement benefit obligations and their interest burden to understand the company's long-term liability management in 2024. Let me start by exploring the available data.", 'type': 'text'}, {'id': 'toolu_01RZvNkP1xY2XmLrnnui8RUi', 'input': {'company_name': 'at_t'}, 'name': 'get_descriptions', 'type': 'tool_use'}], 'id': 'run--4341b341-8c0b-4d4b-a537-272bc0140e9f-0', 'type': 'ai', 'additi

In [27]:
# Show only the messages whose content is NOT a list of content blocks.
for msg_idx, message in enumerate(messages):
    if isinstance(message["content"], list):
        continue  # block-structured content is skipped here
    print(
        f"---- {msg_idx} ----",
        f"message_type: {message['type']}",
        message.keys(),
        message,
        sep="\n",
    )

---- 0 ----
message_type: human
dict_keys(['content', 'id', 'type', 'additional_kwargs'])
{'content': "Here is the question : How significant are the company's postretirement benefit obligations in terms of interest burden, and what does this indicate about the company's long-term liability management in 2024?, Here are the companies name in the database to query for : at_t", 'id': '2aafee6b-806f-456a-b0d8-a6baf7b6cb5b', 'type': 'human', 'additional_kwargs': {}}
---- 2 ----
message_type: tool
dict_keys(['content', 'id', 'type', 'additional_kwargs'])
{'content': '["us_gaap_ScheduleOfComponentsOfIncomeTaxExpenseBenefitTableTextBlock", "us_gaap_ScheduleOfDebtTableTextBlock", "us_gaap_SupplierFinanceProgramTableTextBlock", "_AdditionalFinancialInformationConsolidatedStatementsOfIncomeTableTextBlock", "us_gaap_ScheduleOfEarningsPerShareBasicAndDilutedTableTextBlock", "us_gaap_LeaseCostTableTextBlock", "us_gaap_ScheduleOfDeferredTaxAssetsAndLiabilitiesTableTextBlock", "_ScheduleOfPensionAndP

In [29]:
# Print every "ai" message until the first one whose content is a list of blocks.
# The loop breaks on that message, so `message` stays bound to it afterwards;
# the following cells read `message` directly and therefore depend on this cell
# having been run last among the exploration loops.
for msg_idx, message in enumerate(messages):
    if message["type"] == "ai":
        print(f"---- {msg_idx} ----")
        print(f"message_type: {message['type']}")
        print(message.keys())
        print(message)
        # Stop at the first assistant message with block-structured content.
        if isinstance(message["content"], list):
            break

---- 1 ----
message_type: ai
dict_keys(['content', 'id', 'type', 'additional_kwargs'])
{'content': [{'text': "I'll help you analyze AT&T's postretirement benefit obligations and their interest burden to understand the company's long-term liability management in 2024. Let me start by exploring the available data.", 'type': 'text'}, {'id': 'toolu_01RZvNkP1xY2XmLrnnui8RUi', 'input': {'company_name': 'at_t'}, 'name': 'get_descriptions', 'type': 'tool_use'}], 'id': 'run--4341b341-8c0b-4d4b-a537-272bc0140e9f-0', 'type': 'ai', 'additional_kwargs': {}}


In [30]:
message["content"]

[{'text': "I'll help you analyze AT&T's postretirement benefit obligations and their interest burden to understand the company's long-term liability management in 2024. Let me start by exploring the available data.",
  'type': 'text'},
 {'id': 'toolu_01RZvNkP1xY2XmLrnnui8RUi',
  'input': {'company_name': 'at_t'},
  'name': 'get_descriptions',
  'type': 'tool_use'}]

In [33]:
tool_call_bos = "<tool_call>"
tool_call_eos = "</tool_call>"

# Flatten the content blocks of `message` into a single string: text blocks are kept
# verbatim, tool_use blocks become a JSON payload wrapped in the tool-call delimiters.
rendered_blocks: list[str] = []
for _msg in message["content"]:
    block_type = _msg["type"]
    if block_type == "tool_use":
        tool_call = json.dumps({"name": _msg["name"], "arguments": _msg["input"]})
        rendered_blocks.append(f"{tool_call_bos}\n{tool_call}\n{tool_call_eos}")
    elif block_type == "text":
        rendered_blocks.append(_msg["text"])
    else:
        raise ValueError(f"Unknown message type: {block_type}")

# Blocks are separated by a single newline in the rendered message.
_msg_content = "\n".join(rendered_blocks)
print(_msg_content)



I'll help you analyze AT&T's postretirement benefit obligations and their interest burden to understand the company's long-term liability management in 2024. Let me start by exploring the available data.
<tool_call>
{"name": "get_descriptions", "arguments": {"company_name": "at_t"}}
</tool_call>


#### Define as function

In [34]:
import json
from typing import Any

# Maps the message-type labels found in the dataset traces to chat roles.
convert_role = {
    "human": "user",
    "ai": "assistant",
    "tool": "tool",
    "system": "system",
}


def convert_msg_to_think_act(
    message: dict[str, Any],
    tool_call_bos: str = "<tool_call>",
    tool_call_eos: str = "</tool_call>",
) -> tuple[dict[str, Any], list[dict[str, Any]], list[str]]:
    """
    Convert an assistant message in the Agent-Finance-Reasoning dataset
    to our "think-act" format.

    The role is read from ``message["role"]`` if present, otherwise from
    ``message["type"]``, and mapped through ``convert_role``.

    * Non-assistant messages are returned unchanged, with empty tool/text lists.
    * If ``message["content"]`` is not a list, it is treated as the final answer and
      wrapped in a ``respond_user`` tool call.
    * If it is a list, each block is rendered in order and the pieces are joined with
      newlines: ``text`` blocks (and bare string blocks) are kept verbatim, ``tool_use``
      blocks become ``{tool_call_bos}\\n<json>\\n{tool_call_eos}`` where the JSON holds
      the tool ``name`` and its ``arguments`` (the block's ``input``).

    Args:
        message: A single trace message with a ``content`` entry and a ``role`` or
            ``type`` entry.
        tool_call_bos: Delimiter written before each serialized tool call.
        tool_call_eos: Delimiter written after each serialized tool call.

    Returns:
        A ``(new_msg, new_tools, old_texts)`` tuple where ``new_msg`` is
        ``{"role": "assistant", "content": <rendered str>}`` (or the original message
        for non-assistant roles), ``new_tools`` lists the ``{"name", "arguments"}``
        dicts of every tool call in the message, and ``old_texts`` lists the original
        free-text pieces.

    Raises:
        ValueError: If a content block has a type other than ``text`` / ``tool_use``.
    """
    role = message["role"] if "role" in message else message["type"]
    role = convert_role.get(role, role)

    if role != "assistant":  # Only apply to assistant messages
        return message, [], []

    def _render_tool_call(tool_dict: dict[str, Any]) -> str:
        # Serialized tool call placed between the begin/end delimiters.
        return f"{tool_call_bos}\n{json.dumps(tool_dict)}\n{tool_call_eos}"

    new_tools: list[dict[str, Any]] = []
    old_texts: list[str] = []
    content = message["content"]

    if not isinstance(content, list):  # single response -> explicit respond_user action
        tool_dict = {"name": "respond_user", "arguments": {"text": content}}
        new_tools.append(tool_dict)
        old_texts.append(content)
        new_msg_content = _render_tool_call(tool_dict)
    else:
        rendered: list[str] = []
        for block in content:
            # Presumably LangChain-serialized messages, whose content lists may mix
            # bare strings with dict blocks; a bare string is handled like a text block
            # instead of failing on `block["type"]`.
            if isinstance(block, str):
                rendered.append(block)
                old_texts.append(block)
                continue

            block_type = block["type"]
            if block_type == "text":
                rendered.append(block["text"])
                old_texts.append(block["text"])
            elif block_type == "tool_use":
                tool_dict = {"name": block["name"], "arguments": block["input"]}
                rendered.append(_render_tool_call(tool_dict))
                new_tools.append(tool_dict)
            else:
                raise ValueError(f"Unknown message type: {block_type}")
        new_msg_content = "\n".join(rendered)

    return {"role": role, "content": new_msg_content}, new_tools, old_texts


    

In [35]:
convert_msg_to_think_act(message)

({'role': 'assistant',
  'content': 'I\'ll help you analyze AT&T\'s postretirement benefit obligations and their interest burden to understand the company\'s long-term liability management in 2024. Let me start by exploring the available data.\n<tool_call>\n{"name": "get_descriptions", "arguments": {"company_name": "at_t"}}\n</tool_call>'},
 [{'name': 'get_descriptions', 'arguments': {'company_name': 'at_t'}}],
 ["I'll help you analyze AT&T's postretirement benefit obligations and their interest burden to understand the company's long-term liability management in 2024. Let me start by exploring the available data."])

### Convert all messages

In [36]:
sample["system_prompt"]

'Only execute one tool call at a time. Think and reason step by step. Feel free to iteratively gather data as much as you need to answer the question. Although there is a limit of 50 turns, you can still gather data as much as you need to answer the question. When you have gathered all the data, answer the question. If                  \n\nOnce you have gathered all the data, then you can gradually start forming an answer. Once you have formed an answer, generate a 1 paragraph summary of it with all the relevant figures summarized.                   \n'

In [38]:
def convert_msg(message: dict[str, Any]) -> tuple[dict[str, str], list[dict[str, Any]], list[str]]:
    """
    Normalize one trace message into a ``{"role", "content"}`` chat message.

    The role comes from ``message["role"]`` when present, otherwise from
    ``message["type"]``, and is mapped through ``convert_role``.

    Returns a ``(new_msg, new_tools, old_texts)`` tuple. User and tool messages keep
    their content as-is and carry empty tool/text lists; assistant messages are
    delegated to ``convert_msg_to_think_act``.

    Raises:
        ValueError: For any role other than user, tool or assistant.
    """
    role_key = "role" if "role" in message else "type"
    role = convert_role.get(message[role_key], message[role_key])

    if role == "assistant":
        # Assistant turns are rewritten into the think-act format.
        return convert_msg_to_think_act(message)

    if role in ("user", "tool"):
        # Pass-through: only the role label and content are kept.
        return {"role": role, "content": message["content"]}, [], []

    raise ValueError(f"Unknown role: {role}")
        

### Collect all tools

In [41]:
# Convert every message of the trace and show the result next to the tool calls
# that were extracted from it.
for msg_idx, message in enumerate(messages):
    new_msg, new_tools, old_texts = convert_msg(message)
    print(
        f"---- {msg_idx} ----",
        new_msg.keys(),
        new_msg,
        new_tools,
        sep="\n",
    )




---- 0 ----
dict_keys(['role', 'content'])
{'role': 'user', 'content': "Here is the question : How significant are the company's postretirement benefit obligations in terms of interest burden, and what does this indicate about the company's long-term liability management in 2024?, Here are the companies name in the database to query for : at_t"}
[]
---- 1 ----
dict_keys(['role', 'content'])
{'role': 'assistant', 'content': 'I\'ll help you analyze AT&T\'s postretirement benefit obligations and their interest burden to understand the company\'s long-term liability management in 2024. Let me start by exploring the available data.\n<tool_call>\n{"name": "get_descriptions", "arguments": {"company_name": "at_t"}}\n</tool_call>'}
[{'name': 'get_descriptions', 'arguments': {'company_name': 'at_t'}}]
---- 2 ----
dict_keys(['role', 'content'])
{'role': 'tool', 'content': '["us_gaap_ScheduleOfComponentsOfIncomeTaxExpenseBenefitTableTextBlock", "us_gaap_ScheduleOfDebtTableTextBlock", "us_gaap_Supp

In [49]:
# Registry of every distinct tool signature seen in the trace:
# "<tool name>_<argument names>" -> {"name", "property_names", "examples"}.
ALL_TOOLS = {}

for msg_idx, message in enumerate(messages):
    new_msg, new_tools, old_texts = convert_msg(message)
    print(f"---- {msg_idx} ----", new_msg.keys(), new_msg, sep="\n")

    for tool in new_tools:
        argnames = list(tool["arguments"].keys())
        tool_hash = f"{tool['name']}_{'_'.join(argnames)}"
        # Create the registry entry on first sight of this signature.
        tool_entry = ALL_TOOLS.setdefault(
            tool_hash,
            {"name": tool["name"], "property_names": argnames, "examples": []},
        )
        # Keep at most three rendered assistant messages as usage examples.
        if len(tool_entry["examples"]) < 3:
            tool_entry["examples"].append(new_msg["content"])





---- 0 ----
dict_keys(['role', 'content'])
{'role': 'user', 'content': "Here is the question : How significant are the company's postretirement benefit obligations in terms of interest burden, and what does this indicate about the company's long-term liability management in 2024?, Here are the companies name in the database to query for : at_t"}
---- 1 ----
dict_keys(['role', 'content'])
{'role': 'assistant', 'content': 'I\'ll help you analyze AT&T\'s postretirement benefit obligations and their interest burden to understand the company\'s long-term liability management in 2024. Let me start by exploring the available data.\n<tool_call>\n{"name": "get_descriptions", "arguments": {"company_name": "at_t"}}\n</tool_call>'}
---- 2 ----
dict_keys(['role', 'content'])
{'role': 'tool', 'content': '["us_gaap_ScheduleOfComponentsOfIncomeTaxExpenseBenefitTableTextBlock", "us_gaap_ScheduleOfDebtTableTextBlock", "us_gaap_SupplierFinanceProgramTableTextBlock", "_AdditionalFinancialInformationConsol

In [82]:
len(messages)

40

In [83]:
sample.keys()

dict_keys(['id', 'model', 'user_query', 'system_prompt', 'company', 'company_10k_link', 'answer', 'correctness', 'correctness_rationale', 'completeness', 'completeness_rationale', 'trace'])

In [55]:
new_msg

{'role': 'assistant',
 'content': '<tool_call>\n{"name": "respond_user", "arguments": {"text": "Based on my analysis of AT&T\'s postretirement benefit obligations and their interest burden, I can provide you with a comprehensive assessment:\\n\\n**Analysis Summary:**\\n\\nAT&T\'s postretirement benefit obligations represent a relatively modest portion of the company\'s overall interest burden, indicating effective long-term liability management in 2024. The postretirement benefit interest cost of $310 million accounts for approximately 4.6% of AT&T\'s total interest expense of $6,759 million, while the total postretirement benefit obligation of $6,339 million represents only 4.7% of the company\'s total long-term debt of approximately $136 billion. The company has successfully reduced its postretirement obligations from $6,693 million to $6,339 million year-over-year, demonstrating proactive liability management, while benefiting from higher discount rates (5.60% in 2024 vs 5.00% in 20

In [50]:
ALL_TOOLS

{'get_descriptions_company_name': {'name': 'get_descriptions',
  'property_names': ['company_name'],
  'examples': ['I\'ll help you analyze AT&T\'s postretirement benefit obligations and their interest burden to understand the company\'s long-term liability management in 2024. Let me start by exploring the available data.\n<tool_call>\n{"name": "get_descriptions", "arguments": {"company_name": "at_t"}}\n</tool_call>']},
 'get_table_info_company_name_table_name': {'name': 'get_table_info',
  'property_names': ['company_name', 'table_name'],
  'examples': ['Let me look for tables related to pension and postretirement benefits:\n<tool_call>\n{"name": "get_table_info", "arguments": {"company_name": "at_t", "table_name": "_ScheduleOfPensionAndPostretirementAssetsAndLiabilitiesAtFairValueByLevelTableTextBlock"}}\n</tool_call>',
   'Let me also check for plan obligations information:\n<tool_call>\n{"name": "get_table_info", "arguments": {"company_name": "at_t", "table_name": "_ScheduleOfPlanO

In [51]:
len(ALL_TOOLS)

5

### Go through all samples

In [53]:
type(ds[0])

dict

In [52]:
type(sample)

dict

In [54]:
ds[0]

{'id': 0,
 'model': 'claude-opus-4-20250514',
 'user_query': "Here is the question : How significant are the company's postretirement benefit obligations in terms of interest burden, and what does this indicate about the company's long-term liability management in 2024?, Here are the companies name in the database to query for : at_t",
 'system_prompt': 'Only execute one tool call at a time. Think and reason step by step. Feel free to iteratively gather data as much as you need to answer the question. Although there is a limit of 50 turns, you can still gather data as much as you need to answer the question. When you have gathered all the data, answer the question. If                  \n\nOnce you have gathered all the data, then you can gradually start forming an answer. Once you have formed an answer, generate a 1 paragraph summary of it with all the relevant figures summarized.                   \n',
 'company': 'at_t',
 'company_10k_link': 'https://www.sec.gov/ix?doc=/Archives/edga

In [1]:
def _record_tool_usage(
    tool: dict[str, Any],
    example: Any,
    all_tools_store: dict[str, Any],
    max_examples: int = 3,
) -> None:
    """Count one tool call in `all_tools_store`, keeping up to `max_examples` usage examples."""
    name = tool["name"]
    argnames = list(tool["arguments"].keys())
    # Tools are keyed by name plus argument names, so the same tool called with a
    # different argument set gets its own entry.
    tool_hash = f"{name}_{'_'.join(argnames)}"
    entry = all_tools_store.setdefault(
        tool_hash,
        {"name": name, "property_names": argnames, "examples": [], "count": 0},
    )
    if len(entry["examples"]) < max_examples:
        entry["examples"].append(example)
    entry["count"] += 1


def build_state_action_samples(
    # messages: list[dict[str, Any]], 
    sample: dict[str, Any],
    sample_idx: int,
    all_tools_store: dict[str, Any],
    discount_factor: float = 0.9,
) -> list[dict[str, Any]]:
    """
    Split one dataset trace into per-step (state, action) training rows.

    Every assistant message of the trace becomes one row whose ``state`` is the
    converted chat history before that message and whose ``action`` is the converted
    assistant message itself.

    Args:
        sample: Dataset row; reads ``trace`` (JSON string or already-parsed list of
            messages), ``correctness``, ``system_prompt`` and ``answer``.
        sample_idx: Index of the row in the source dataset, copied into every output row.
        all_tools_store: Mutable registry updated in place with every tool call seen
            (name, argument names, up to three example messages, call count).
        discount_factor: Per-step discount used for ``return_``.

    Returns:
        One dict per assistant message with the keys ``sample_idx``, ``state``,
        ``action``, ``done``, ``reward``, ``return_``, ``timestep``, ``max_timestep``,
        ``discount_factor``, ``system_prompt`` and ``answer``. ``reward`` is +1 for a
        correct trace and -1 otherwise; ``return_`` is that reward discounted by the
        number of actions remaining after the current one; ``done`` is True only for
        an action that is the very last message of the trace.
    """
    messages: list[dict[str, Any]] = (
        json.loads(sample["trace"]) if isinstance(sample["trace"], str) else sample["trace"]
    )
    reward = 2 * int(sample["correctness"]) - 1  # 1 if correct, -1 if incorrect

    # Convert everything up front so the number of actions can be counted directly.
    # Deriving it from `len(messages) // 2` is only right for strictly alternating
    # traces and otherwise yields a wrong `max_timestep` and returns with |value| > 1.
    converted = [convert_msg(message) for message in messages]
    num_actions = sum(1 for new_msg, _, _ in converted if new_msg["role"] == "assistant")
    max_timestep = num_actions - 1  # 0-indexed timestep of the final action

    new_samples: list[dict[str, Any]] = []
    running_chat: list[dict[str, Any]] = []  # converted history so far (no next_obs)
    last_msg_idx = len(converted) - 1
    timestep = 0

    for msg_idx, (new_msg, new_tools, _old_texts) in enumerate(converted):
        for tool in new_tools:
            _record_tool_usage(tool, new_msg["content"], all_tools_store)
        running_chat.append(new_msg)

        if new_msg["role"] != "assistant":
            continue  # only assistant turns are actions

        new_samples.append({
            "sample_idx": sample_idx,
            "state": running_chat[:-1],  # copy of the history before this action
            "action": new_msg,
            # "next_obs": None,
            "done": msg_idx == last_msg_idx,
            "reward": reward,
            "return_": reward * (discount_factor ** (max_timestep - timestep)),
            "timestep": timestep,
            "max_timestep": max_timestep,
            "discount_factor": discount_factor,
            "system_prompt": sample["system_prompt"],
            "answer": sample["answer"],
        })
        timestep += 1

    return new_samples
        

        
        



NameError: name 'Any' is not defined

### Try building new samples

In [136]:
from datasets import Dataset

In [151]:
# Reload the full dataset from the same YAML settings.
# NOTE(review): `cache_dir` is an absolute, machine-specific path.
env_config = OmegaConf.create(
    """
name: "action_lm"
is_async: false
dataset_config:
  path: "snorkelai/agent-finance-reasoning"
  cache_dir: "/scr/mzhang/data/"
  split: "train"
"""
)

# Every key under `dataset_config` is forwarded to `load_dataset` as a keyword argument.
dataset_kwargs = env_config.dataset_config
ds = load_dataset(**dataset_kwargs)
# The correctness filter is left disabled, so incorrect traces remain in `ds`:
# ds = ds.filter(lambda x: x["correctness"])

In [None]:
# Flatten every trace into per-step rows; ALL_TOOLS is filled in place along the way.
ALL_TOOLS = {}
all_samples = []

for sample_idx, sample in enumerate(ds):
    all_samples += build_state_action_samples(
        sample=sample,
        sample_idx=sample_idx,
        all_tools_store=ALL_TOOLS,
    )


In [153]:
ds = Dataset.from_list(all_samples)
ds

Dataset({
    features: ['sample_idx', 'state', 'action', 'done', 'reward', 'return_', 'timestep', 'max_timestep', 'discount_factor', 'system_prompt'],
    num_rows: 4610
})

In [154]:
df = ds.to_pandas()
df.head()


Unnamed: 0,sample_idx,state,action,done,reward,return_,timestep,max_timestep,discount_factor,system_prompt
0,0,[{'content': 'Here is the question : How signi...,{'content': 'I'll help you analyze AT&T's post...,False,1,0.135085,0,19,0.9,Only execute one tool call at a time. Think an...
1,0,[{'content': 'Here is the question : How signi...,{'content': 'Let me look for tables related to...,False,1,0.150095,1,19,0.9,Only execute one tool call at a time. Think an...
2,0,[{'content': 'Here is the question : How signi...,{'content': 'Let me also check for plan obliga...,False,1,0.166772,2,19,0.9,Only execute one tool call at a time. Think an...
3,0,[{'content': 'Here is the question : How signi...,{'content': 'This is exactly what I need! Let ...,False,1,0.185302,3,19,0.9,Only execute one tool call at a time. Think an...
4,0,[{'content': 'Here is the question : How signi...,"{'content': '<tool_call> {""name"": ""sql_query"",...",False,1,0.205891,4,19,0.9,Only execute one tool call at a time. Think an...


In [155]:
df[df["done"]]

Unnamed: 0,sample_idx,state,action,done,reward,return_,timestep,max_timestep,discount_factor,system_prompt
19,0,[{'content': 'Here is the question : How signi...,"{'content': '<tool_call> {""name"": ""respond_use...",True,1,1.0,19,19,0.9,Only execute one tool call at a time. Think an...
38,1,[{'content': 'Here is the question : What is t...,"{'content': '<tool_call> {""name"": ""respond_use...",True,1,1.0,18,18,0.9,Only execute one tool call at a time. Think an...
54,2,[{'content': 'Here is the question : How has t...,"{'content': '<tool_call> {""name"": ""respond_use...",True,-1,-1.0,15,15,0.9,Only execute one tool call at a time. Think an...
63,3,[{'content': 'Here is the question : How does ...,"{'content': '<tool_call> {""name"": ""respond_use...",True,1,1.0,8,8,0.9,Only execute one tool call at a time. Think an...
76,4,[{'content': 'Here is the question : What is t...,"{'content': '<tool_call> {""name"": ""respond_use...",True,-1,-1.0,12,12,0.9,Only execute one tool call at a time. Think an...
...,...,...,...,...,...,...,...,...,...,...
4581,352,[{'content': 'Here is the question : What is t...,"{'content': '<tool_call> {""name"": ""respond_use...",True,-1,-1.0,4,4,0.9,Only execute one tool call at a time. Think an...
4585,353,[{'content': 'Here is the question : How has t...,"{'content': '<tool_call> {""name"": ""respond_use...",True,-1,-1.0,3,3,0.9,Only execute one tool call at a time. Think an...
4589,354,[{'content': 'Here is the question : How has t...,"{'content': '<tool_call> {""name"": ""respond_use...",True,-1,-1.0,3,3,0.9,Only execute one tool call at a time. Think an...
4605,355,[{'content': 'Here is the question : How has t...,"{'content': '<tool_call> {""name"": ""respond_use...",True,-1,-1.0,15,15,0.9,Only execute one tool call at a time. Think an...


In [160]:
df[df["timestep"] == 19].reset_index(drop=True).iloc[0]["action"]

{'content': '<tool_call>\n{"name": "respond_user", "arguments": {"text": "Based on my analysis of AT&T\'s postretirement benefit obligations and their interest burden, I can provide you with a comprehensive assessment:\\n\\n**Analysis Summary:**\\n\\nAT&T\'s postretirement benefit obligations represent a relatively modest portion of the company\'s overall interest burden, indicating effective long-term liability management in 2024. The postretirement benefit interest cost of $310 million accounts for approximately 4.6% of AT&T\'s total interest expense of $6,759 million, while the total postretirement benefit obligation of $6,339 million represents only 4.7% of the company\'s total long-term debt of approximately $136 billion. The company has successfully reduced its postretirement obligations from $6,693 million to $6,339 million year-over-year, demonstrating proactive liability management, while benefiting from higher discount rates (5.60% in 2024 vs 5.00% in 2023) that help reduce t

In [157]:
ALL_TOOLS

{'get_descriptions_company_name': {'name': 'get_descriptions',
  'property_names': ['company_name'],
  'examples': ['I\'ll help you analyze AT&T\'s postretirement benefit obligations and their interest burden to understand the company\'s long-term liability management in 2024. Let me start by exploring the available data.\n<tool_call>\n{"name": "get_descriptions", "arguments": {"company_name": "at_t"}}\n</tool_call>',
   'I\'ll help you analyze Meta\'s lease financing strategy and its impact on cash flow obligations. Let me start by exploring the available data.\n<tool_call>\n{"name": "get_descriptions", "arguments": {"company_name": "meta"}}\n</tool_call>',
   'I\'ll help you analyze Disney\'s lease portfolio composition between operating and finance leases from 2022 to 2024. Let me start by exploring the available tables related to leases.\n<tool_call>\n{"name": "get_descriptions", "arguments": {"company_name": "disney"}}\n</tool_call>'],
  'count': 161},
 'get_table_info_company_nam

In [158]:
len(ALL_TOOLS)

5

### Use some frontier LLM magic to convert the tool representation above to Tool descriptions

In [159]:
# Shared wording for the `company_name` argument used by several tools.
_COMPANY_NAME_DESCRIPTION = (
    "Company identifier used by the dataset (typically a normalized slug such as "
    "'at_t', 'meta', 'disney')."
)


def _string_param(description):
    """Schema of one string-typed tool argument (a fresh dict on every call)."""
    return {"type": "string", "description": description}


def _function_tool(name, description, properties):
    """Function-tool spec in which every declared property is required, in declaration order."""
    return {
        "type": "function",
        "name": name,
        "description": description,
        "parameters": {
            "type": "object",
            "properties": properties,
            "required": list(properties),
        },
    }


# Tool name -> function-tool spec, one entry per tool observed in the traces
# plus the synthetic `respond_user` action.
TOOL_DESCRIPTIONS = {
    spec["name"]: spec
    for spec in (
        _function_tool(
            "get_descriptions",
            "Retrieve the available table/field descriptions for a given company dataset. "
            "Use this to discover what tables exist and what they contain before querying.",
            {"company_name": _string_param(_COMPANY_NAME_DESCRIPTION)},
        ),
        _function_tool(
            "get_table_info",
            "Retrieve schema/metadata for a specific table in a given company's dataset. "
            "Use this to understand the table's columns/fields before running SQL queries.",
            {
                "company_name": _string_param(_COMPANY_NAME_DESCRIPTION),
                "table_name": _string_param(
                    "Exact table name to inspect (e.g., an XBRL/US-GAAP text block or schedule table name)."
                ),
            },
        ),
        _function_tool(
            "sql_query",
            "Execute a SQL query against a specific table in a given company's dataset and return the results.",
            {
                "company_name": _string_param(_COMPANY_NAME_DESCRIPTION),
                "table_name": _string_param("Exact table name to query."),
                "query": _string_param(
                    "A SQL query string to run against the specified table. Must have filters "
                    "(e.g., 'SELECT ... FROM <table> WHERE ...')."
                ),
            },
        ),
        _function_tool(
            "calculator",
            "Evaluate a mathematical expression and return the result.",
            {
                "expression": _string_param(
                    "Mathematical expression to evaluate (e.g., '5399 + 8652', '(310 / 6759) * 100')."
                ),
            },
        ),
        _function_tool(
            "respond_user",
            "Respond or message the user.",
            {"text": _string_param("The text content to respond or message the user.")},
        ),
    )
}

In [161]:
tools = list(TOOL_DESCRIPTIONS.values())
tools

[{'type': 'function',
  'name': 'get_descriptions',
  'description': 'Retrieve the available table/field descriptions for a given company dataset. Use this to discover what tables exist and what they contain before querying.',
  'parameters': {'type': 'object',
   'properties': {'company_name': {'type': 'string',
     'description': "Company identifier used by the dataset (typically a normalized slug such as 'at_t', 'meta', 'disney')."}},
   'required': ['company_name']}},
 {'type': 'function',
  'name': 'get_table_info',
  'description': "Retrieve schema/metadata for a specific table in a given company's dataset. Use this to understand the table's columns/fields before running SQL queries.",
  'parameters': {'type': 'object',
   'properties': {'company_name': {'type': 'string',
     'description': "Company identifier used by the dataset (typically a normalized slug such as 'at_t', 'meta', 'disney')."},
    'table_name': {'type': 'string',
     'description': 'Exact table name to inspe

### Create final dataset and push to hub

In [164]:
# --- 1. Load the full source dataset (incorrect traces included) ---
# NOTE(review): `cache_dir` is an absolute, machine-specific path.
env_config = OmegaConf.create(
    """
name: "action_lm"
is_async: false
dataset_config:
  path: "snorkelai/agent-finance-reasoning"
  cache_dir: "/scr/mzhang/data/"
  split: "train"
"""
)

dataset_kwargs = env_config.dataset_config
ds = load_dataset(**dataset_kwargs)
# The correctness filter is left disabled, so incorrect traces remain in `ds`:
# ds = ds.filter(lambda x: x["correctness"])

# --- 2. Flatten every trace into per-step rows; ALL_TOOLS is filled in place ---
ALL_TOOLS = {}
all_samples = []

for sample_idx, sample in enumerate(ds):
    all_samples += build_state_action_samples(
        sample=sample,
        sample_idx=sample_idx,
        all_tools_store=ALL_TOOLS,
    )

# --- 3. Wrap the rows as a Dataset (`ds` is rebound here) and as a DataFrame ---
ds = Dataset.from_list(all_samples)
df = ds.to_pandas()

# Every row gets the same `tools` list (the very same object, not a copy).
# NOTE(review): the column is added to `df` only; `ds` does not contain it.
df["tools"] = df["timestep"].apply(lambda _: tools)

df.head()


Unnamed: 0,sample_idx,state,action,done,reward,return_,timestep,max_timestep,discount_factor,system_prompt,tools
0,0,[{'content': 'Here is the question : How signi...,{'content': 'I'll help you analyze AT&T's post...,False,1,0.135085,0,19,0.9,Only execute one tool call at a time. Think an...,"[{'type': 'function', 'name': 'get_description..."
1,0,[{'content': 'Here is the question : How signi...,{'content': 'Let me look for tables related to...,False,1,0.150095,1,19,0.9,Only execute one tool call at a time. Think an...,"[{'type': 'function', 'name': 'get_description..."
2,0,[{'content': 'Here is the question : How signi...,{'content': 'Let me also check for plan obliga...,False,1,0.166772,2,19,0.9,Only execute one tool call at a time. Think an...,"[{'type': 'function', 'name': 'get_description..."
3,0,[{'content': 'Here is the question : How signi...,{'content': 'This is exactly what I need! Let ...,False,1,0.185302,3,19,0.9,Only execute one tool call at a time. Think an...,"[{'type': 'function', 'name': 'get_description..."
4,0,[{'content': 'Here is the question : How signi...,"{'content': '<tool_call> {""name"": ""sql_query"",...",False,1,0.205891,4,19,0.9,Only execute one tool call at a time. Think an...,"[{'type': 'function', 'name': 'get_description..."


In [170]:
# Derive the destination repo id from the source dataset path: "/" and "-" are
# replaced with "_" and the result is placed under the "mzio/aprm-" prefix.
source_repo = "snorkelai/agent-finance-reasoning"
repo_slug = source_repo.translate(str.maketrans({"/": "_", "-": "_"}))
hf_repo_path = f"mzio/aprm-{repo_slug}"

# NOTE(review): this uploads `ds`; the "tools" column was added to `df` only, so it
# is not part of the pushed dataset — confirm that this is intended.
ds.push_to_hub(hf_repo_path)
print(hf_repo_path)

Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 10.21ba/s]
Processing Files (1 / 1): 100%|██████████| 2.39MB / 2.39MB,  0.00B/s  
New Data Upload: |          |  0.00B /  0.00B,  0.00B/s  
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  1.20 shards/s]


mzio/aprm-snorkelai_agent_finance_reasoning


In [171]:
hf_repo_path

'mzio/aprm-snorkelai_agent_finance_reasoning'